stratified_cv

Stratified cross-validation for tfbpmodeling.

tfbpmodeling.stratified_cv

stratified_cv_modeling

stratified_cv_modeling(
    y,
    X,
    classes,
    estimator=LassoCV(),
    skf=StratifiedKFold(
        n_splits=4, shuffle=True, random_state=42
    ),
    sample_weight=None,
    **kwargs
)

Fit a model using stratified cross-validation splits.

This function wraps a scikit-learn estimator with user-defined stratified folds. While it defaults to LassoCV, any estimator with a cv attribute can be used.

Parameters:
  • y (DataFrame) –

    Response variable. Must be a single-column DataFrame.

  • X (DataFrame) –

    Predictor matrix. Must be a DataFrame with the same number of rows as y.

  • classes (ndarray) –

    Array of class labels for stratification, typically generated by stratification_classification().

  • estimator (BaseEstimator, default: LassoCV() ) –

    scikit-learn estimator to use for modeling. Must support cv as an attribute.

  • skf (StratifiedKFold, default: StratifiedKFold(n_splits=4, shuffle=True, random_state=42) ) –

    StratifiedKFold object to control how splits are generated.

  • sample_weight (ndarray | None, default: None ) –

    Optional array of per-sample weights for the estimator.

  • kwargs

    Additional arguments passed to the estimator's fit() method.

Returns:
  • BaseEstimator

    A fitted estimator with the best parameters determined via cross-validation.

Raises:
  • ValueError

    If inputs are misformatted or incompatible with the estimator.
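
As a rough illustration of what a call amounts to, the sketch below reproduces the documented behaviour with scikit-learn alone: folds are generated by stratifying on classes, attached to a cloned LassoCV through its cv attribute, and the clone is fitted on the flattened response. The synthetic data and the quartile-based class labels are assumptions made for this example; they stand in for real inputs and for stratification_classification(), whose behaviour is not shown on this page.

```python
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.linear_model import LassoCV
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)

# Synthetic inputs (assumption): 80 samples, 5 predictors, one response column.
X = pd.DataFrame(rng.normal(size=(80, 5)), columns=[f"tf_{i}" for i in range(5)])
y = pd.DataFrame({"response": 2.0 * X["tf_0"] + rng.normal(scale=0.1, size=80)})

# Hypothetical class labels: quartiles of the first predictor (values 0..3).
classes = pd.qcut(X["tf_0"], q=4, labels=False).to_numpy()

# Precompute stratified folds, as the function does with its skf argument.
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)
folds = list(skf.split(X, classes))

# Attach the folds to a cloned estimator and fit on the flattened response.
model = clone(LassoCV())
model.cv = folds
model.fit(X, y.values.ravel())

print(len(folds), model.alpha_)
```

Because the folds are materialised as a list of (train, test) index pairs, the estimator's internal search over alpha uses exactly those partitions instead of drawing its own.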

Source code in tfbpmodeling/stratified_cv.py
def stratified_cv_modeling(
    y: pd.DataFrame,
    X: pd.DataFrame,
    classes: np.ndarray,
    estimator: BaseEstimator = LassoCV(),
    skf: StratifiedKFold = StratifiedKFold(n_splits=4, shuffle=True, random_state=42),
    sample_weight: np.ndarray | None = None,
    **kwargs,
) -> BaseEstimator:
    """
    Fit a model using stratified cross-validation splits.

    This function wraps a scikit-learn estimator with user-defined stratified folds.
    While it defaults to `LassoCV`, any estimator with a `cv` attribute can be used.

    :param y: Response variable. Must be a single-column DataFrame.
    :param X: Predictor matrix. Must be a DataFrame with the same number of rows as `y`.
    :param classes: Array of class labels for stratification, typically generated by
        `stratification_classification()`.
    :param estimator: scikit-learn estimator to use for modeling. Must support `cv` as
        an attribute.
    :param skf: StratifiedKFold object to control how splits are generated.
    :param sample_weight: Optional array of per-sample weights for the estimator.
    :param kwargs: Additional arguments passed to the estimator's `fit()` method.

    :return: A fitted estimator with the best parameters determined via
        cross-validation.

    :raises ValueError: If inputs are misformatted or incompatible with the estimator.

    """
    # Validate data
    if not isinstance(y, pd.DataFrame):
        raise ValueError("The response variable y must be a DataFrame.")
    if y.shape[1] != 1:
        raise ValueError("The response variable y must be a single column DataFrame.")
    if not isinstance(X, pd.DataFrame):
        raise ValueError("The predictors X must be a DataFrame.")
    if X.shape[0] != y.shape[0]:
        raise ValueError("The number of rows in X must match the number of rows in y.")
    if not isinstance(classes, np.ndarray) or classes.size == 0:
        raise ValueError("The classes must be a non-empty numpy array.")

    # Verify estimator has a `cv` attribute
    if not hasattr(estimator, "cv"):
        raise ValueError("The estimator must support a `cv` parameter.")

    # Initialize StratifiedKFold for stratified splits
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # default setting for shuffle is False, which means the partitioning is
        # deterministic and static. Recommendation for bootstrapping is to
        # set shuffle=True and use random_state = bootstrap_iteration in order to
        # have random, but reproducible, partitions
        folds = list(skf.split(X, classes))
        for warning in w:
            logger.debug(
                f"Warning encountered during stratified k-fold split: {warning.message}"
            )

    # Clone the estimator and set the `cv` attribute with predefined folds
    model = clone(estimator)
    model.cv = folds

    # Fit the model using the custom cross-validation folds
    model.fit(
        X,
        y.values.ravel(),
        sample_weight=sample_weight,
        **kwargs,
    )

    return model
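
The comment inside the source recommends shuffle=True with random_state set to the bootstrap iteration, so that partitions differ between iterations yet can be reproduced. A minimal sketch of that idea with scikit-learn follows; the toy labels and the helper name folds_for_iteration are assumptions for illustration and are not part of tfbpmodeling.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data (assumption): 40 samples in 4 equally sized strata.
classes = np.repeat([0, 1, 2, 3], 10)
X = np.zeros((classes.size, 1))  # split() only needs the number of rows


def folds_for_iteration(iteration: int):
    """Return the test indices of each fold for one bootstrap iteration."""
    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=iteration)
    return [test_idx for _, test_idx in skf.split(X, classes)]


first = folds_for_iteration(0)
repeat = folds_for_iteration(0)  # same seed, so the same partition
other = folds_for_iteration(1)   # different seed, so a reshuffled partition

same_as_repeat = all(np.array_equal(a, b) for a, b in zip(first, repeat))
same_as_other = all(np.array_equal(a, b) for a, b in zip(first, other))
print(same_as_repeat, same_as_other)
```

An skf object built this way can be passed as the skf argument on each iteration, which keeps every bootstrap replicate reproducible from its iteration number alone.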

Overview

The stratified_cv module fits cross-validated estimators on folds that preserve the proportion of each stratification class. This matters for tfbpmodeling because the data can fall into natural strata (for example, binding-strength bins) that should be represented in every training and validation split.

Key Features

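  • Stratified Sampling: Each CV fold keeps roughly the same class proportions as the full dataset
  • Bootstrap Integration: Fold assignment can be reseeded per bootstrap iteration (shuffle=True with a per-iteration random_state)
  • Flexible Stratification: Class labels can come from any binning scheme, such as binding-strength or expression-level bins
  • Robust Validation: Reduces the chance that a fold under-represents a stratum, which would bias cross-validation estimates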
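
To make the first feature concrete, the sketch below checks how scikit-learn's StratifiedKFold spreads an imbalanced label vector across folds. The labels are made up for the example; only the scikit-learn and NumPy calls are real.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Made-up imbalanced strata: 60 low, 30 medium, 10 high.
classes = np.array([0] * 60 + [1] * 30 + [2] * 10)
X = np.zeros((classes.size, 1))  # placeholder predictors

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Count how many samples of each class land in every test fold.
fold_counts = [
    np.bincount(classes[test_idx], minlength=3).tolist()
    for _, test_idx in skf.split(X, classes)
]
print(fold_counts)
```

Every test fold receives 12, 6 and 2 samples of the three classes, the same 6:3:1 ratio as the full label vector.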

Usage Examples

Basic Stratified CV

from tfbpmodeling.stratified_cv import StratifiedCV

# Create stratified CV object
cv = StratifiedCV(
    n_splits=5,
    stratification_variable='binding_strength_bins',
    random_state=42
)

# Generate CV folds
for train_idx, test_idx in cv.split(X, y):
    # Train and evaluate model
    pass

Bootstrap Integration

import numpy as np
from sklearn.linear_model import LassoCV

from tfbpmodeling.stratified_cv import bootstrap_stratified_cv

# Perform bootstrap with stratified CV
cv_scores = bootstrap_stratified_cv(
    X=predictor_data,
    y=response_data,
    estimator=LassoCV(),
    n_bootstraps=1000,
    cv_folds=5,
    stratification_bins=[0, 8, 12, np.inf]
)

Stratification Methods

Binding Strength Bins

Stratifies data based on transcription factor binding strength:

# Define binding strength bins
bins = [0, 0.1, 0.5, 1.0]  # Low, medium, high binding

cv = StratifiedCV(
    n_splits=5,
    stratification_method='binding_bins',
    bins=bins
)
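
How the bin edges translate into class labels is not shown on this page. One plausible reading, sketched here with NumPy only, assigns each value the index of the interval it falls into. The sample values and the use of np.digitize are assumptions and may differ from what tfbpmodeling does internally.

```python
import numpy as np

# Bin edges from the example above: low, medium and high binding.
bins = [0, 0.1, 0.5, 1.0]

# Hypothetical binding scores for six samples.
binding = np.array([0.02, 0.08, 0.3, 0.45, 0.7, 0.95])

# Passing only the interior edges gives labels 0 (low), 1 (medium), 2 (high).
labels = np.digitize(binding, bins[1:-1])
print(labels.tolist())  # [0, 0, 1, 1, 2, 2]
```

An integer array like labels is the kind of input the classes parameter of stratified_cv_modeling expects.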

Expression Level Bins

Stratifies based on expression level ranges:

import numpy as np

# Expression-based stratification
cv = StratifiedCV(
    n_splits=5,
    stratification_method='expression_bins',
    bins=[-np.inf, -1, 0, 1, np.inf]
)
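
Whichever bins are chosen, every stratum needs enough members to be spread over the folds: scikit-learn's StratifiedKFold emits a warning when the least populated class has fewer members than n_splits. The sketch below, which uses NumPy only, counts samples per bin for the open-ended edges shown above. The expression values and the small n_splits are hypothetical and chosen to keep the toy example readable.

```python
import numpy as np

n_splits = 3  # small value chosen for the toy data
bins = [-np.inf, -1, 0, 1, np.inf]

# Hypothetical expression values for twelve samples.
expression = np.array(
    [-2.5, -1.4, -0.8, -0.3, -0.1, 0.2, 0.4, 0.9, 1.1, 1.6, 2.2, 3.0]
)

# Interior edges only; the infinite outer edges mean no value falls outside.
labels = np.digitize(expression, bins[1:-1])
counts = np.bincount(labels, minlength=len(bins) - 1)

# Strata with fewer members than n_splits cannot appear in every fold.
too_small = [int(i) for i in np.flatnonzero(counts < n_splits)]
print(counts.tolist(), too_small)  # [2, 3, 3, 4] [0]
```

Here the lowest bin holds only two samples, so either the edges or the number of splits would need adjusting before stratifying on these labels.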