Bayesian Feature Selection
Scalable Bayesian feature selection using horseshoe priors with NumPyro/JAX. Select relevant features from high-dimensional data with a Bayesian generalized linear model (GLM) that uses the regularized horseshoe prior. This prior shrinks the coefficients of irrelevant features strongly toward zero while leaving large, truly relevant coefficients largely unshrunk.
Free software: MIT license
Documentation: https://macromagic.github.io/bayesian_feature_selection/
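For reference, the regularized horseshoe prior is usually written as follows, following Piironen and Vehtari (2017). This README does not state the package's exact hyperpriors, so the global scale τ₀ and the slab parameters ν and s shown here are assumptions.

```latex
\begin{aligned}
\beta_j \mid \lambda_j, \tau, c &\sim \mathcal{N}\!\left(0,\ \tau^2 \tilde{\lambda}_j^2\right),
\qquad \tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2},
\qquad j = 1, \dots, p, \\
\lambda_j &\sim \mathcal{C}^{+}(0, 1), \\
\tau &\sim \mathcal{C}^{+}(0, \tau_0), \\
c^2 &\sim \operatorname{Inv\text{-}Gamma}\!\left(\tfrac{\nu}{2},\ \tfrac{\nu s^2}{2}\right).
\end{aligned}
```

The two limits explain the behaviour described above. When τ²λ_j² ≪ c², the effective variance τ²λ̃_j² ≈ τ²λ_j², which is the plain horseshoe and shrinks small coefficients hard. When τ²λ_j² ≫ c², the effective variance approaches c², so large coefficients see a Gaussian slab of scale c instead of being left completely unregularized.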
Features
Horseshoe Prior GLM — Regularized horseshoe prior for automatic feature selection
Multiple GLM Families — Gaussian (regression), Binomial (classification), and Poisson (count data)
MCMC & SVI Inference — Full posterior via NUTS or fast approximation via Stochastic Variational Inference
Flexible Feature Selection — Beta-based, lambda-based, or combined inclusion probabilities
Data Loading & Preprocessing — Built-in CSV loading, train/test splitting, and standardization
CLI Interface — Run experiments from the command line with YAML configuration files
Visualization — Feature importance plots and MCMC diagnostic plots via ArviZ
Quick Start
import numpy as np
from bayesian_feature_selection import HorseshoeGLM, InferenceConfig
# Create synthetic data
rng = np.random.RandomState(42)
n, p = 100, 10
X = rng.randn(n, p)
true_beta = np.array([3.0, -2.0, 0, 0, 1.5, 0, 0, 0, 0, 0])
y = X @ true_beta + rng.randn(n) * 0.5
# Fit model
model = HorseshoeGLM(family="gaussian")
config = InferenceConfig(
    method="mcmc", num_warmup=500, num_samples=1000,
    num_chains=2, use_gpu=False, progress_bar=False
)
model.fit(X, y, config=config)
# Get selected features
importance = model.get_feature_importance(threshold=0.5)
selected = importance[importance["selected"]]
print(selected[["feature_idx", "beta_mean", "beta_inclusion_prob"]])
# Make predictions
predictions = model.predict(X)
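The Quick Start reports a `beta_inclusion_prob` per feature and selects features with `threshold=0.5`. This README does not define how the beta-based and lambda-based probabilities are computed. The sketch below shows two common ways to derive such quantities from posterior draws. The function names, the `eps` cutoff and the use of the horseshoe shrinkage factor κ_j = 1 / (1 + n τ² λ_j² / σ²), which assumes standardized unit-variance columns, are all assumptions rather than the package's documented definitions.

```python
import numpy as np


def beta_inclusion_prob(beta_draws, eps=0.1):
    """Hypothetical beta-based rule: share of posterior draws with |beta_j| > eps."""
    beta_draws = np.asarray(beta_draws)  # shape (num_draws, p)
    return (np.abs(beta_draws) > eps).mean(axis=0)


def lambda_inclusion_prob(lambda_draws, tau_draws, sigma_draws, n):
    """Hypothetical lambda-based rule: posterior mean of 1 - kappa_j.

    kappa_j = 1 / (1 + n * tau^2 * lambda_j^2 / sigma^2) is the horseshoe shrinkage
    factor for a standardized column: near 1 means "shrunk to zero",
    near 0 means "left essentially unshrunk".
    """
    lam = np.asarray(lambda_draws)           # (num_draws, p)
    tau = np.asarray(tau_draws)[:, None]     # (num_draws, 1), broadcasts over features
    sigma = np.asarray(sigma_draws)[:, None]
    kappa = 1.0 / (1.0 + n * tau**2 * lam**2 / sigma**2)
    return (1.0 - kappa).mean(axis=0)


# Made-up draws for illustration: feature 0 clearly nonzero, feature 1 shrunk to ~0.
rng = np.random.default_rng(0)
num_draws = 1000
beta_draws = np.column_stack([rng.normal(3.0, 0.1, num_draws),
                              rng.normal(0.0, 0.01, num_draws)])
lambda_draws = np.column_stack([np.full(num_draws, 10.0), np.full(num_draws, 1e-3)])
tau_draws = np.full(num_draws, 0.1)
sigma_draws = np.full(num_draws, 0.5)

p_beta = beta_inclusion_prob(beta_draws, eps=0.1)
p_lambda = lambda_inclusion_prob(lambda_draws, tau_draws, sigma_draws, n=100)
selected = np.flatnonzero(p_beta > 0.5)  # same style of cutoff as threshold=0.5 above
print(p_beta, p_lambda, selected)
```

Under the regularized prior, λ̃_j can be substituted for λ_j in κ_j. How the package forms its "combined" probability is not specified here.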
CLI Usage
$ bayesian-fs -c configs/default.yaml
Installation
$ pip install bayesian_feature_selection
For development, install in editable mode with the dev extras from a clone of the repository:
$ pip install -e ".[dev]"
Environment Setup (Python 3.12 + CUDA 12)
Python 3.12 (CPU only)
$ pip install bayesian_feature_selection
This installs JAX ≥ 0.7.0 and NumPyro ≥ 0.15.0 automatically.
Python 3.12 + CUDA 12 (GPU)
Install the package, then upgrade JAX with the CUDA 12 backend:
$ pip install bayesian_feature_selection
$ pip install "jax[cuda12]"
Verify the setup:
import jax
print(jax.devices()) # should show CudaDevice(id=0) when GPU is available
Important version notes
NumPyro ≥ 0.15.0 requires JAX ≥ 0.7.0.
JAX ≥ 0.10.0 removed internal symbols that NumPyro ≤ 0.20.1 depends on; use JAX 0.7.x – 0.9.x until NumPyro releases a compatible update.
The gpu extra (pip install "bayesian_feature_selection[gpu]") installs jax[cuda12] and is the recommended way to enable GPU support.
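The version notes above can be turned into a quick local check. This is a sketch that restates the notes exactly as written: NumPyro ≥ 0.15.0, JAX ≥ 0.7.0, and no JAX ≥ 0.10.0 with NumPyro ≤ 0.20.1. These bounds may go stale as new releases appear. `parse_version` and `check_jax_numpyro` are hypothetical helpers, and the parser is a simplified one, not a full PEP 440 implementation.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v):
    """Simplified parser: '0.7.2rc1' -> (0, 7, 2); missing parts are padded with 0."""
    parts = []
    for piece in (v.split(".") + ["0", "0", "0"])[:3]:
        digits = ""
        for ch in piece:  # keep only the leading digits of each component
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def check_jax_numpyro(jax_version, numpyro_version):
    """Return a list of problems according to the version notes in this README."""
    jv, nv = parse_version(jax_version), parse_version(numpyro_version)
    problems = []
    if nv < (0, 15, 0):
        problems.append(f"NumPyro {numpyro_version} is older than 0.15.0")
    if jv < (0, 7, 0):
        problems.append(f"JAX {jax_version} is older than 0.7.0")
    if jv >= (0, 10, 0) and nv <= (0, 20, 1):
        problems.append(f"JAX {jax_version} is too new for NumPyro {numpyro_version}; "
                        "use JAX 0.7.x to 0.9.x")
    return problems


if __name__ == "__main__":
    try:
        print(check_jax_numpyro(version("jax"), version("numpyro")) or "OK")
    except PackageNotFoundError as exc:
        print(f"not installed: {exc}")
```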
Credits
This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.