"""Configuration management for bayesian_feature_selection."""
import warnings
from dataclasses import asdict, dataclass, field, fields, replace
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import yaml
@dataclass
class DataConfig:
    """Settings describing the input data and how it is prepared.

    Attributes:
        data_path: Optional path string; ``None`` by default.
            (Presumably the location of the dataset -- the consumer is not
            part of this module.)
        target_col: Optional target column name; ``None`` by default.
        feature_cols: Explicit feature column names. If ``None``, use all
            columns except the target.
        test_size: Fraction for the test split; ``0`` means no split.
        standardize: Whether to standardize features.
        random_seed: Integer seed, ``42`` by default. (Presumably used for
            the split -- confirm against the code that reads it.)
    """

    data_path: Optional[str] = None
    target_col: Optional[str] = None
    feature_cols: Optional[List[str]] = None
    test_size: float = 0.0
    standardize: bool = False
    random_seed: int = 42
@dataclass
class InferenceConfig:
    """Configuration for inference.

    Every numeric setting is validated on construction regardless of which
    ``method`` is selected; the cross-parameter check and the advisory
    warnings then run only for the selected method.

    Attributes:
        method: Inference method, either ``"mcmc"`` or ``"svi"``.
        num_warmup: MCMC warm-up count; must be positive.
        num_samples: MCMC sample count; must be positive and, when
            ``method == "mcmc"``, at least ``num_warmup``.
        num_chains: MCMC chain count; must be positive.
        num_steps: SVI step count; must be positive.
        learning_rate: SVI learning rate; must lie strictly inside (0, 1).
        use_gpu: Performance flag, not read by this class. (Presumably
            selects GPU execution -- confirm against the consumer.)
        progress_bar: Performance/UI flag, not read by this class.
            (Presumably toggles a progress bar -- confirm against the
            consumer.)

    Raises:
        ValueError: If ``method`` is not ``"mcmc"`` or ``"svi"``, if any
            count is not positive, if ``learning_rate`` is outside (0, 1),
            or if ``method == "mcmc"`` and ``num_samples < num_warmup``.
    """

    method: Literal["mcmc", "svi"] = "mcmc"
    num_warmup: int = 1000
    num_samples: int = 2000
    num_chains: int = 4
    # SVI specific
    num_steps: int = 10000
    learning_rate: float = 0.001
    # Performance
    use_gpu: bool = True
    progress_bar: bool = True

    def __post_init__(self):
        """Validate configuration parameters after initialization."""
        # ``Literal`` is not enforced at runtime, so without this check a
        # typo such as "MCMC" would silently skip every method-specific
        # check below.
        if self.method not in ("mcmc", "svi"):
            raise ValueError(
                f"method must be 'mcmc' or 'svi', got {self.method!r}"
            )

        # MCMC counts first, then the SVI count -- same order and messages
        # as one explicit check per field.
        for name in ("num_warmup", "num_samples", "num_chains", "num_steps"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")

        if not 0 < self.learning_rate < 1:
            raise ValueError(
                f"learning_rate must be in (0, 1), got {self.learning_rate}"
            )

        # Frame 1 is this method, frame 2 the dataclass-generated __init__,
        # so level 3 attributes warnings to the code building the config.
        caller = 3

        if self.method == "mcmc":
            # Cross-parameter validation
            if self.num_samples < self.num_warmup:
                raise ValueError(
                    f"num_samples ({self.num_samples}) should be >= "
                    f"num_warmup ({self.num_warmup}) for MCMC"
                )
            # Warn about potentially slow configurations
            if self.num_chains > 10:
                warnings.warn(
                    f"num_chains={self.num_chains} is high and may be slow. "
                    "Consider using fewer chains for faster inference.",
                    UserWarning,
                    stacklevel=caller,
                )
            if self.num_samples > 10000:
                warnings.warn(
                    f"num_samples={self.num_samples} is very high. "
                    "This may take a long time to run.",
                    UserWarning,
                    stacklevel=caller,
                )
        else:  # "svi" -- the only other value accepted above
            # Warn about potentially suboptimal SVI configurations
            if self.num_steps < 1000:
                warnings.warn(
                    f"num_steps={self.num_steps} may be too low for SVI convergence. "
                    "Consider using at least 1000 steps.",
                    UserWarning,
                    stacklevel=caller,
                )
            if self.learning_rate > 0.1:
                warnings.warn(
                    f"learning_rate={self.learning_rate} is quite high for SVI. "
                    "Consider using a smaller value (e.g., 0.001-0.01).",
                    UserWarning,
                    stacklevel=caller,
                )
@dataclass
class ModelConfig:
    """Settings for the statistical model.

    Attributes:
        family: One of ``"gaussian"``, ``"binomial"`` or ``"poisson"``;
            defaults to ``"gaussian"``. (Presumably the likelihood family --
            the model code is not part of this module.)
        scale_global: Float defaulting to ``1.0``. (Presumably a global
            scale hyperparameter of the prior -- confirm against the model.)
    """

    family: Literal["gaussian", "binomial", "poisson"] = "gaussian"
    scale_global: float = 1.0
@dataclass
class SelectionConfig:
    """Settings for the feature-selection step.

    Attributes:
        method: One of ``"beta"``, ``"lambda"`` or ``"both"``; defaults to
            ``"beta"``. (Presumably names the posterior quantity used to
            decide selection -- confirm against the selection code.)
        threshold: Float defaulting to ``0.5``. (Presumably the selection
            cutoff -- its exact meaning is defined by the consumer.)
    """

    method: Literal["beta", "lambda", "both"] = "beta"
    threshold: float = 0.5
@dataclass
class OutputConfig:
    """Switches controlling which artifacts are saved.

    Attributes:
        save_plots: Save plots; on by default.
        save_diagnostics: Save diagnostics; on by default.
        save_samples: Save samples; off by default.
    """

    save_plots: bool = True
    save_diagnostics: bool = True
    save_samples: bool = False
@dataclass
class ExperimentConfig:
    """Complete experiment configuration.

    Bundles one config object per section (``data``, ``model``,
    ``inference``, ``selection``, ``output``), each defaulting to that
    section's own defaults, and converts the bundle to and from YAML.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    selection: SelectionConfig = field(default_factory=SelectionConfig)
    output: OutputConfig = field(default_factory=OutputConfig)

    @classmethod
    def from_yaml(cls, yaml_path: Path) -> "ExperimentConfig":
        """Load configuration from YAML file.

        A section that is absent from the file, or present but empty, falls
        back to that section's defaults; an empty file therefore yields an
        all-default configuration. Top-level keys other than the five
        section names are ignored.

        Args:
            yaml_path: Path of the YAML file to read.

        Returns:
            A new ``ExperimentConfig`` built from the file.

        Raises:
            ValueError: If the top level of the document is not a mapping,
                or if a section's values fail that section's validation.
            TypeError: If a section is not a mapping or contains a key that
                is not a field of its config class.
        """
        with open(yaml_path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        # safe_load returns None for an empty document.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                "Top level of the YAML config must be a mapping, "
                f"got {type(loaded).__name__}"
            )

        def section(name: str) -> Dict[str, Any]:
            # "data:" with nothing under it parses as None; treat as empty.
            return loaded.get(name) or {}

        return cls(
            data=DataConfig(**section("data")),
            model=ModelConfig(**section("model")),
            inference=InferenceConfig(**section("inference")),
            selection=SelectionConfig(**section("selection")),
            output=OutputConfig(**section("output")),
        )

    def to_yaml(self, yaml_path: Path) -> None:
        """Save configuration to YAML file.

        Args:
            yaml_path: Path of the YAML file to write (overwritten if it
                already exists).
        """
        # asdict recurses into the nested section dataclasses, giving one
        # mapping per section in field-declaration order.
        config_dict = asdict(self)
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)

    def update_from_dict(self, updates: Dict[str, Any]) -> "ExperimentConfig":
        """Update configuration from dictionary (e.g., CLI overrides).

        Unknown sections and unknown keys are skipped with a ``UserWarning``
        so that typos in overrides are visible. All accepted overrides are
        validated before any of them is applied, so a rejected update leaves
        the configuration unchanged. Section objects are mutated in place.

        Args:
            updates: Mapping of section name to a mapping of
                ``{field name: new value}``. A section whose value is
                ``None`` or empty is a no-op.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If the overridden values fail a section's
                validation (for example a non-positive ``num_samples``).
        """
        section_names = {spec.name for spec in fields(self)}
        pending = []

        for section, params in updates.items():
            # Only declared sections are updatable; a bare hasattr() check
            # would also match methods such as "to_yaml".
            if section not in section_names:
                warnings.warn(
                    f"Ignoring unknown config section {section!r}",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            section_obj = getattr(self, section)
            known_keys = {spec.name for spec in fields(section_obj)}
            accepted = {}
            for key, value in (params or {}).items():
                if key in known_keys:
                    accepted[key] = value
                else:
                    warnings.warn(
                        f"Ignoring unknown config key {key!r} "
                        f"in section {section!r}",
                        UserWarning,
                        stacklevel=2,
                    )

            if accepted:
                # Build a throwaway copy so the section's __post_init__
                # validation runs; plain setattr would bypass it.
                replace(section_obj, **accepted)
                pending.append((section_obj, accepted))

        # Everything validated: now apply in place, so references held to
        # the section objects keep seeing the current values.
        for section_obj, accepted in pending:
            for key, value in accepted.items():
                setattr(section_obj, key, value)
        return self