# Source code for binlearn.methods._chi2_binning

"""
Clean Chi-square binning implementation.

This module provides Chi2Binning that inherits from SupervisedBinningBase.
Uses chi-square statistic to find optimal bin boundaries for classification tasks.
"""

from __future__ import annotations

from typing import Any

import numpy as np
from scipy.stats import chi2, chi2_contingency

from ..base import SupervisedBinningBase
from ..config import apply_config_defaults
from ..utils import (
    BinEdgesDict,
    ConfigurationError,
    FittingError,
    create_equal_width_bins,
    create_param_dict_for_config,
    validate_positive_integer,
)


# pylint: disable=too-many-ancestors
class Chi2Binning(SupervisedBinningBase):
    """Chi-square binning implementation for supervised discretization.

    This class implements chi-square binning (ChiMerge-style), a supervised
    discretization method that uses the chi-square statistic to find bin
    boundaries for classification tasks. Adjacent bins whose class
    distributions are statistically indistinguishable are merged, so the
    surviving boundaries are the ones that separate the target classes.

    Chi-square binning is particularly effective for:
    - Binary and multi-class classification preprocessing
    - Creating bins that preserve class-discriminative information
    - Reducing feature dimensionality while maintaining predictive power
    - Handling continuous features with complex relationships to target classes

    Key Features:
    - Uses chi-square test of independence to guide bin boundary selection
    - Iterative merging process starting from fine initial discretization
    - Configurable stopping criteria (significance level, bin count limits)
    - Handles both binary and multi-class classification targets

    Algorithm:
    1. Create an initial fine-grained equal-width discretization
       (``initial_bins`` bins); empty initial bins are dropped.
    2. For each pair of adjacent bins, calculate the chi-square statistic.
    3. Merge the pair with the smallest (least significant) chi-square value.
    4. Repeat merging until a stopping criterion is met:
       - the minimum number of bins (``min_bins``) is reached, OR
       - at most ``max_bins`` bins remain and the smallest chi-square value
         exceeds the critical value for ``alpha``.
       While more than ``max_bins`` bins remain, merging always continues.
    5. Create final bin boundaries and representatives.

    Parameters:
        max_bins: Maximum number of bins to create. The algorithm will not
            exceed this limit regardless of statistical significance.
        min_bins: Minimum number of bins to maintain. The algorithm will not
            merge below this threshold even if chi-square values are not
            significant.
        alpha: Significance level for the chi-square test. Adjacent bins are
            merged while their chi-square statistic does not exceed the
            critical value for this level. Higher values result in more bins.
        initial_bins: Number of equal-width bins created before merging
            starts. Must be >= ``max_bins``.
        clip: Forwarded to ``SupervisedBinningBase``.
        preserve_dataframe: Forwarded to ``SupervisedBinningBase``.
        guidance_columns: Forwarded to ``SupervisedBinningBase``.
        bin_edges: Pre-computed bin edges, forwarded to the base class.
        bin_representatives: Pre-computed representatives, forwarded to the
            base class.
        class_: Accepted and ignored (reconstruction compatibility).
        module_: Accepted and ignored (reconstruction compatibility).

    Attributes:
        bin_edges_: Dictionary mapping column identifiers to lists of bin
            edges after fitting.
        bin_representatives_: Dictionary mapping column identifiers to lists
            of bin representatives (interval midpoints).

    Example:
        >>> import numpy as np
        >>> from binlearn.methods import Chi2Binning
        >>>
        >>> # Binary classification example
        >>> X = np.random.normal(0, 1, (1000, 2))
        >>> # Create target correlated with first feature
        >>> y = (X[:, 0] > 0).astype(int)
        >>>
        >>> binner = Chi2Binning(max_bins=5, alpha=0.05)
        >>> binner.fit(X, guidance_data=y.reshape(-1, 1))
        >>> X_binned = binner.transform(X)
        >>>
        >>> # Multi-class example with custom parameters
        >>> y_multi = np.random.choice([0, 1, 2], size=1000)
        >>> binner_multi = Chi2Binning(
        ...     max_bins=10,
        ...     min_bins=3,
        ...     alpha=0.01,
        ...     initial_bins=20
        ... )
        >>> binner_multi.fit(X, guidance_data=y_multi.reshape(-1, 1))

    Note:
        - Requires target data (guidance_data) during fitting
        - Works only with numeric input features and categorical targets
        - Performance depends on the relationship between features and target
        - May create fewer bins than max_bins if adjacent bins are not
          significantly different
        - Inherits clipping behavior and format preservation from
          SupervisedBinningBase
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        max_bins: int | None = None,
        min_bins: int | None = None,
        alpha: float | None = None,
        initial_bins: int | None = None,
        clip: bool | None = None,
        preserve_dataframe: bool | None = None,
        guidance_columns: Any = None,
        *,
        bin_edges: BinEdgesDict | None = None,
        bin_representatives: BinEdgesDict | None = None,
        class_: (  # pylint: disable=unused-argument
            str | None
        ) = None,  # For reconstruction compatibility
        module_: (  # pylint: disable=unused-argument
            str | None
        ) = None,  # For reconstruction compatibility
    ):
        """Initialize Chi-square binning.

        Explicit arguments are combined with the ``"chi2"`` configuration
        defaults; when neither supplies a value the built-in fallbacks are
        ``max_bins=10``, ``min_bins=2``, ``alpha=0.05``, ``initial_bins=20``.
        """
        # Use standardized initialization pattern
        user_params = create_param_dict_for_config(
            max_bins=max_bins,
            min_bins=min_bins,
            alpha=alpha,
            initial_bins=initial_bins,
            clip=clip,
            preserve_dataframe=preserve_dataframe,
        )

        # Apply configuration defaults for chi2 method
        params = apply_config_defaults("chi2", user_params)

        # Store chi-square specific parameters with config defaults
        self.max_bins = params.get("max_bins", max_bins if max_bins is not None else 10)
        self.min_bins = params.get("min_bins", min_bins if min_bins is not None else 2)
        self.alpha = params.get("alpha", alpha if alpha is not None else 0.05)
        self.initial_bins = params.get(
            "initial_bins", initial_bins if initial_bins is not None else 20
        )

        # Scratch slot written by _validate_contingency_table and read by
        # _perform_chi2_calculation.
        self._filtered_contingency: np.ndarray[Any, Any] | None = None

        # Initialize parent with resolved config parameters (no fit_jointly for supervised)
        # Note: guidance_columns, bin_edges, bin_representatives are never set from config
        SupervisedBinningBase.__init__(
            self,
            clip=params.get("clip"),
            preserve_dataframe=params.get("preserve_dataframe"),
            guidance_columns=guidance_columns,
            bin_edges=bin_edges,
            bin_representatives=bin_representatives,
        )

    def _validate_params(self) -> None:
        """Validate Chi-square binning specific parameters.

        Raises:
            ConfigurationError: If ``alpha`` is not a number strictly between
                0 and 1, if ``min_bins > max_bins``, or if
                ``initial_bins < max_bins``.
        """
        # Call parent validation
        SupervisedBinningBase._validate_params(self)

        # Use standardized validation utilities
        validate_positive_integer(self.max_bins, "max_bins")
        validate_positive_integer(self.min_bins, "min_bins")
        validate_positive_integer(self.initial_bins, "initial_bins")

        # Validate alpha (must be between 0 and 1)
        if not isinstance(self.alpha, int | float) or not 0.0 < self.alpha < 1.0:
            raise ConfigurationError(
                f"alpha must be a number between 0 and 1 (exclusive), got {self.alpha}",
                suggestions=["Example: alpha=0.05"],
            )

        # Validate bin constraints
        if self.min_bins > self.max_bins:
            raise ConfigurationError(
                f"min_bins ({self.min_bins}) must be <= max_bins ({self.max_bins})",
                suggestions=["Reduce min_bins or increase max_bins"],
            )

        # Validate initial_bins constraint
        if self.initial_bins < self.max_bins:
            raise ConfigurationError(
                f"initial_bins ({self.initial_bins}) must be >= max_bins ({self.max_bins})",
                suggestions=["Increase initial_bins or reduce max_bins"],
            )

    def _calculate_bins(
        self,
        x_col: np.ndarray[Any, Any],
        col_id: Any,
        guidance_data: np.ndarray[Any, Any] | None = None,
    ) -> tuple[list[float], list[float]]:
        """Calculate bin edges and representatives using chi-square optimization.

        Chi2 binning is a supervised method and requires guidance data.

        Args:
            x_col: Clean feature data (no missing values)
            col_id: Column identifier
            guidance_data: Target data with shape (n_samples, 1). Required.

        Returns:
            Tuple of (bin_edges, bin_representatives)

        Raises:
            ValueError: If guidance_data is None (supervised method requires targets)
        """
        if guidance_data is None:
            raise ValueError(
                "Chi2 binning is a supervised method and requires guidance data (targets)"
            )

        # Only the first guidance column is used as the target.
        y_col = guidance_data[:, 0]

        return self._calculate_chi2_bins(x_col, y_col, col_id)

    def _validate_data_requirements(
        self, x_col: np.ndarray[Any, Any], y_col: np.ndarray[Any, Any], col_id: Any
    ) -> np.ndarray[Any, Any]:
        """Validate data requirements and return the unique target classes.

        Raises:
            FittingError: If fewer than 2 data points or fewer than 2 distinct
                target classes are present.
        """
        if len(x_col) < 2:
            raise FittingError(
                f"Column {col_id} has too few data points ({len(x_col)}). "
                "Chi2 binning requires at least 2 data points."
            )

        # Get unique target classes
        unique_classes = np.unique(y_col)
        if len(unique_classes) < 2:
            raise FittingError(
                f"Column {col_id} target has insufficient class diversity "
                f"({len(unique_classes)} classes). "
                "Chi2 binning requires at least 2 target classes."
            )

        return np.asarray(unique_classes)

    def _create_initial_bins(
        self, x_col: np.ndarray[Any, Any]
    ) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
        """Create the initial equal-width binning.

        Returns:
            Tuple of (initial_edges, bin_indices) where ``bin_indices[i]`` is
            the zero-based initial bin of ``x_col[i]``.
        """
        # Use standardized equal-width binning utility
        initial_edges = create_equal_width_bins(x_col, self.initial_bins)

        # np.digitize is 1-based; the clip folds values on the outermost edges
        # into the first / last bin.
        bin_indices = np.digitize(x_col, initial_edges) - 1
        bin_indices = np.clip(bin_indices, 0, len(initial_edges) - 2)

        return initial_edges, bin_indices

    def _validate_intervals_created(self, intervals: list[dict[str, Any]], col_id: Any) -> None:
        """Raise FittingError when no initial interval could be built."""
        if not intervals:
            raise FittingError(
                f"Failed to create initial intervals for column {col_id}. "
                "Data distribution may be unsuitable for chi2 binning."
            )

    def _calculate_chi2_bins(
        self,
        x_col: np.ndarray[Any, Any],
        y_col: np.ndarray[Any, Any],
        col_id: Any,
    ) -> tuple[list[float], list[float]]:
        """Calculate chi-square optimized bin edges and representatives.

        Args:
            x_col: Preprocessed feature data (already handled by base class)
            y_col: Target data - 1D array (may have been filtered by SupervisedBinningBase)
            col_id: Column identifier

        Returns:
            Tuple of (bin_edges, bin_representatives)

        Raises:
            FittingError: If data is insufficient for chi-square binning
        """
        # Step 1: Validate data requirements
        unique_classes = self._validate_data_requirements(x_col, y_col, col_id)

        # Step 2: Create initial equal-width binning
        initial_edges, bin_indices = self._create_initial_bins(x_col)

        # Step 3: Build initial contingency information per interval
        intervals = self._build_intervals(bin_indices, y_col, initial_edges, unique_classes)

        # Step 4: Validate intervals were created
        self._validate_intervals_created(intervals, col_id)

        # Step 5: Iteratively merge intervals with smallest chi-square
        final_intervals = self._merge_intervals(intervals, unique_classes)

        # Step 6: Extract edges and representatives
        return self._extract_final_results(final_intervals)

    def _extract_final_results(
        self, final_intervals: list[dict[str, Any]]
    ) -> tuple[list[float], list[float]]:
        """Extract final bin edges and representatives from intervals.

        Edges are the first interval's ``min`` followed by every interval's
        ``max``; each representative is the midpoint of its interval's
        recorded ``min`` and ``max``.
        """
        edges = [final_intervals[0]["min"]]
        representatives = []

        for interval in final_intervals:
            edges.append(interval["max"])
            # Representative is the midpoint of the interval
            representatives.append((interval["min"] + interval["max"]) / 2)

        return edges, representatives

    def _build_intervals(
        self,
        bin_indices: np.ndarray[Any, Any],
        y_col: np.ndarray[Any, Any],
        initial_edges: np.ndarray[Any, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> list[dict[str, Any]]:
        """Build the initial intervals with per-class counts, skipping empty bins."""
        candidates = (
            self._create_interval_from_bin(i, bin_indices, y_col, initial_edges, unique_classes)
            for i in range(len(initial_edges) - 1)
        )
        return [interval for interval in candidates if interval is not None]

    def _is_valid_interval(self, interval: dict[str, Any]) -> bool:
        """Check if interval is valid (numeric, strictly positive total count)."""
        total_count = interval["total_count"]
        return isinstance(total_count, int | float) and total_count > 0

    def _create_interval_from_bin(
        self,
        bin_idx: int,
        bin_indices: np.ndarray[Any, Any],
        y_col: np.ndarray[Any, Any],
        initial_edges: np.ndarray[Any, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> dict[str, Any] | None:
        """Create an interval from a bin index, returns None for empty bins."""
        mask = bin_indices == bin_idx

        if not np.any(mask):
            return None  # Skip empty intervals

        # Count occurrences of each class in this interval
        y_interval = y_col[mask]
        class_counts = self._calculate_class_counts(y_interval, unique_classes)

        interval = {
            "min": float(initial_edges[bin_idx]),
            "max": float(initial_edges[bin_idx + 1]),
            "class_counts": class_counts,
            "total_count": int(np.sum(mask)),
        }

        # Only return non-empty intervals
        if not self._is_valid_interval(interval):
            return None

        return interval

    def _calculate_class_counts(
        self,
        y_interval: np.ndarray[Any, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> dict[Any, int]:
        """Count how many targets in ``y_interval`` equal each unique class."""
        return {cls: int(np.sum(y_interval == cls)) for cls in unique_classes}

    def _has_enough_bins_to_merge(self, current_intervals: list[dict[str, Any]]) -> bool:
        """Return True while there are more intervals than ``max_bins``."""
        return bool(len(current_intervals) > self.max_bins)

    def _should_perform_merge(self, merge_idx: int) -> bool:
        """Return True when ``merge_idx`` denotes a real merge candidate (>= 0)."""
        return merge_idx >= 0

    def _should_continue_merging(
        self,
        current_intervals: list[dict[str, Any]],
        min_chi2: float,
        unique_classes: np.ndarray[Any, Any],
    ) -> bool:
        """Check if merging should continue (opposite of _should_stop_merging)."""
        return not self._should_stop_merging(current_intervals, min_chi2, unique_classes)

    def _merge_intervals(
        self,
        intervals: list[dict[str, Any]],
        unique_classes: np.ndarray[Any, Any],
    ) -> list[dict[str, Any]]:
        """Iteratively merge the least significant adjacent pair of intervals.

        Merging is forced while more than ``max_bins`` intervals remain, then
        continues only while the smallest chi-square value is not significant,
        and never goes below ``min_bins``.

        Args:
            intervals: Initial intervals ordered by position; not mutated.
            unique_classes: All target classes present in the data.

        Returns:
            The merged list of intervals.
        """
        current_intervals = list(intervals)

        # A merge needs two intervals, whatever min_bins happens to be.
        while len(current_intervals) > 1:
            merge_idx, min_chi2 = self._find_best_merge_candidate(current_intervals, unique_classes)

            if not self._should_continue_merging(current_intervals, min_chi2, unique_classes):
                break

            # Without a usable candidate no progress is possible; stop rather
            # than loop forever.
            if not self._should_perform_merge(merge_idx):
                break

            current_intervals = self._perform_merge(current_intervals, merge_idx)

        return current_intervals

    def _attempt_merge_or_continue(
        self,
        current_intervals: list[dict[str, Any]],
        merge_idx: int,
    ) -> list[dict[str, Any]]:
        """Merge at ``merge_idx`` if it is valid, otherwise return the input unchanged."""
        if self._should_perform_merge(merge_idx):
            return self._perform_merge(current_intervals, merge_idx)
        return current_intervals

    def _find_best_merge_candidate(
        self,
        current_intervals: list[dict[str, Any]],
        unique_classes: np.ndarray[Any, Any],
    ) -> tuple[int, float]:
        """Find the pair of adjacent intervals with smallest chi-square statistic.

        Returns:
            ``(merge_idx, min_chi2)`` where ``merge_idx`` is the index of the
            left interval of the pair, or ``(-1, inf)`` when no pair exists.
        """
        min_chi2 = float("inf")
        merge_idx = -1

        for i in range(len(current_intervals) - 1):
            chi2_stat = self._calculate_chi2_for_merge(
                current_intervals[i], current_intervals[i + 1], unique_classes
            )

            # Strict comparison keeps the left-most pair on ties.
            if chi2_stat < min_chi2:
                min_chi2 = chi2_stat
                merge_idx = i

        return merge_idx, min_chi2

    def _at_minimum_bins(self, current_intervals: list[dict[str, Any]]) -> bool:
        """Return True when the interval count is at or below ``min_bins``."""
        return bool(len(current_intervals) <= self.min_bins)

    def _chi2_is_significant(self, min_chi2: float, unique_classes: np.ndarray[Any, Any]) -> bool:
        """Return True when ``min_chi2`` exceeds the critical value.

        Two adjacent intervals against ``k`` classes form a ``k x 2`` table,
        hence ``k - 1`` degrees of freedom.
        """
        return min_chi2 > self._get_chi2_critical_value(len(unique_classes) - 1)

    def _above_minimum_bins(self, current_intervals: list[dict[str, Any]]) -> bool:
        """Return True when the interval count is at least ``min_bins``."""
        return bool(len(current_intervals) >= self.min_bins)

    def _should_stop_merging(
        self,
        current_intervals: list[dict[str, Any]],
        min_chi2: float,
        unique_classes: np.ndarray[Any, Any],
    ) -> bool:
        """Determine if merging should stop based on significance and bin constraints.

        Returns:
            True if the minimum bin count is reached, or if no more than
            ``max_bins`` intervals remain and ``min_chi2`` is significant.
            Always False while more than ``max_bins`` intervals remain.
        """
        # Never merge below the configured minimum.
        if self._at_minimum_bins(current_intervals):
            return True

        # max_bins is a hard limit: significance cannot stop merging above it.
        if self._has_enough_bins_to_merge(current_intervals):
            return False

        # Within [min_bins, max_bins]: stop once even the most similar pair of
        # neighbours differs significantly.
        return self._chi2_is_significant(min_chi2, unique_classes)

    def _perform_merge(
        self,
        current_intervals: list[dict[str, Any]],
        merge_idx: int,
    ) -> list[dict[str, Any]]:
        """Return a new list with intervals ``merge_idx`` and ``merge_idx + 1`` merged."""
        merged_interval = self._merge_two_intervals(
            current_intervals[merge_idx], current_intervals[merge_idx + 1]
        )

        return (
            current_intervals[:merge_idx] + [merged_interval] + current_intervals[merge_idx + 2 :]
        )

    def _build_contingency_table(
        self,
        interval1: dict[str, Any],
        interval2: dict[str, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> np.ndarray[Any, Any]:
        """Build the (n_classes, 2) contingency table for two intervals.

        Row ``r`` holds the counts of ``unique_classes[r]`` in ``interval1``
        and ``interval2``; classes missing from an interval count as 0.
        """
        counts1 = interval1["class_counts"]
        counts2 = interval2["class_counts"]
        return np.array([[counts1.get(cls, 0), counts2.get(cls, 0)] for cls in unique_classes])

    def _validate_contingency_table(self, contingency_table: np.ndarray[Any, Any]) -> bool:
        """Check that the table is usable for a chi-square test.

        All-zero rows and columns are dropped. If at least a 2x2 table
        remains, it is stored in ``self._filtered_contingency`` and True is
        returned; otherwise False is returned and the stored table is left
        untouched.
        """
        # Remove empty rows/columns
        row_sums = contingency_table.sum(axis=1)
        col_sums = contingency_table.sum(axis=0)

        valid_rows = row_sums > 0
        valid_cols = col_sums > 0

        if not np.any(valid_rows) or not np.any(valid_cols):
            return False

        # Filter to valid rows/cols
        filtered_table = contingency_table[valid_rows][:, valid_cols]

        if filtered_table.size == 0 or filtered_table.shape[0] < 2 or filtered_table.shape[1] < 2:
            return False

        # Store filtered table for calculation
        self._filtered_contingency = filtered_table
        return True

    def _convert_chi2_result(self, chi2_stat: Any) -> float:
        """Convert chi2 statistic to float; non-numeric input yields 0.0."""
        if isinstance(chi2_stat, int | float | np.number):
            return float(chi2_stat)
        return 0.0

    def _compute_chi2_statistic(self, contingency_table: np.ndarray[Any, Any]) -> float:
        """Compute the chi-square statistic of a contingency table.

        Uses ``scipy.stats.chi2_contingency`` with its default settings.
        Returns 0.0 if it raises ``ValueError`` or ``RuntimeWarning``.
        """
        try:
            chi2_stat, _, _, _ = chi2_contingency(contingency_table)
            return self._convert_chi2_result(chi2_stat)
        except (ValueError, RuntimeWarning):
            return 0.0  # Treat an uncomputable statistic as "no evidence of difference"

    def _handle_chi2_calculation_errors(self, error: Exception) -> float:
        """Map ValueError / ZeroDivisionError to 0.0 and re-raise anything else."""
        if isinstance(error, ValueError | ZeroDivisionError):
            return 0.0
        # Re-raise other exceptions
        raise error

    def _calculate_chi2_for_merge(
        self,
        interval1: dict[str, Any],
        interval2: dict[str, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> float:
        """Calculate chi-square statistic for merging two intervals."""
        return self._safe_chi2_calculation(interval1, interval2, unique_classes)

    def _safe_chi2_calculation(
        self,
        interval1: dict[str, Any],
        interval2: dict[str, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> float:
        """Calculate chi2, delegating caught errors to _handle_chi2_calculation_errors."""
        try:
            return self._perform_chi2_calculation(interval1, interval2, unique_classes)
        except (ValueError, RuntimeWarning, KeyError) as e:
            return self._handle_chi2_calculation_errors(e)

    def _perform_chi2_calculation(
        self,
        interval1: dict[str, Any],
        interval2: dict[str, Any],
        unique_classes: np.ndarray[Any, Any],
    ) -> float:
        """Build, validate and evaluate the contingency table of two intervals.

        Returns 0.0 when the table is degenerate (fewer than two non-empty
        rows or columns), which makes such pairs the preferred merge
        candidates.
        """
        # Build contingency table for the two intervals
        contingency_table = self._build_contingency_table(interval1, interval2, unique_classes)

        # Validate and filter contingency table
        if not self._validate_contingency_table(contingency_table):
            return 0.0

        # _validate_contingency_table stored the filtered table on success
        if self._filtered_contingency is None:
            return 0.0
        return self._compute_chi2_statistic(self._filtered_contingency)

    def _merge_two_intervals(
        self,
        interval1: dict[str, Any],
        interval2: dict[str, Any],
    ) -> dict[str, Any]:
        """Merge two adjacent intervals.

        The result spans ``interval1["min"]`` to ``interval2["max"]`` and
        carries summed class counts and total count.
        """
        counts1 = interval1["class_counts"]
        counts2 = interval2["class_counts"]

        # Combine class counts over the union of classes seen in either interval
        merged_class_counts = {
            cls: counts1.get(cls, 0) + counts2.get(cls, 0) for cls in set(counts1) | set(counts2)
        }

        return {
            "min": interval1["min"],
            "max": interval2["max"],
            "class_counts": merged_class_counts,
            "total_count": interval1["total_count"] + interval2["total_count"],
        }

    def _get_chi2_critical_value(self, dof: int) -> float:
        """Get critical chi-square value for given degrees of freedom and alpha.

        Args:
            dof: Degrees of freedom of the test; values below 1 are treated
                as 1 because the distribution is undefined there.

        Returns:
            The value ``c`` such that a chi-square variable with ``dof``
            degrees of freedom exceeds ``c`` with probability ``self.alpha``.
        """
        effective_dof = max(int(dof), 1)
        # isf(alpha) is the upper-tail quantile, i.e. ppf(1 - alpha).
        return float(chi2.isf(self.alpha, effective_dof))