stratified_cv
Stratified cross-validation for tfbpmodeling.
tfbpmodeling.stratified_cv
stratified_cv_modeling
stratified_cv_modeling(
    y,
    X,
    classes,
    estimator=LassoCV(),
    skf=StratifiedKFold(n_splits=4, shuffle=True, random_state=42),
    sample_weight=None,
    **kwargs
)
Fit a model using stratified cross-validation splits.
This function wraps a scikit-learn estimator with user-defined stratified folds.
While it defaults to LassoCV, any estimator with a cv attribute can be used.
Parameters:
- y (DataFrame) – Response variable. Must be a single-column DataFrame.
- X (DataFrame) – Predictor matrix. Must be a DataFrame with the same number of rows as y.
- classes (ndarray) – Array of class labels for stratification, typically generated by stratification_classification().
- estimator (BaseEstimator, default: LassoCV()) – scikit-learn estimator to use for modeling. Must support cv as an attribute.
- skf (StratifiedKFold, default: StratifiedKFold(n_splits=4, shuffle=True, random_state=42)) – StratifiedKFold object to control how splits are generated.
- sample_weight (ndarray | None, default: None) – Optional array of per-sample weights for the estimator.
- kwargs – Additional arguments passed to the estimator's fit() method.
Returns:
- BaseEstimator – A fitted estimator with the best parameters determined via cross-validation.
Raises:
- ValueError – If inputs are misformatted or incompatible with the estimator.
Source code in tfbpmodeling/stratified_cv.py
def stratified_cv_modeling(
    y: pd.DataFrame,
    X: pd.DataFrame,
    classes: np.ndarray,
    estimator: BaseEstimator = LassoCV(),
    skf: StratifiedKFold = StratifiedKFold(n_splits=4, shuffle=True, random_state=42),
    sample_weight: np.ndarray | None = None,
    **kwargs,
) -> BaseEstimator:
    """
    Fit a model using stratified cross-validation splits.

    This function wraps a scikit-learn estimator with user-defined stratified folds.
    While it defaults to `LassoCV`, any estimator with a `cv` attribute can be used.

    :param y: Response variable. Must be a single-column DataFrame.
    :param X: Predictor matrix. Must be a DataFrame with the same number of rows as `y`.
    :param classes: Array of class labels for stratification, typically generated by
        `stratification_classification()`.
    :param estimator: scikit-learn estimator to use for modeling. Must support `cv` as
        an attribute.
    :param skf: StratifiedKFold object to control how splits are generated.
    :param sample_weight: Optional array of per-sample weights for the estimator.
    :param kwargs: Additional arguments passed to the estimator's `fit()` method.

    :return: A fitted estimator with the best parameters determined via
        cross-validation.

    :raises ValueError: If inputs are misformatted or incompatible with the estimator.
    """
    # Validate data
    if not isinstance(y, pd.DataFrame):
        raise ValueError("The response variable y must be a DataFrame.")
    if y.shape[1] != 1:
        raise ValueError("The response variable y must be a single column DataFrame.")
    if not isinstance(X, pd.DataFrame):
        raise ValueError("The predictors X must be a DataFrame.")
    if X.shape[0] != y.shape[0]:
        raise ValueError("The number of rows in X must match the number of rows in y.")
    # Check the type first so that a non-array input raises ValueError rather than
    # AttributeError on `.size`
    if not isinstance(classes, np.ndarray) or classes.size == 0:
        raise ValueError("The classes must be a non-empty numpy array.")

    # Verify estimator has a `cv` attribute
    if not hasattr(estimator, "cv"):
        raise ValueError("The estimator must support a `cv` parameter.")

    # Generate the stratified splits, logging any warnings raised along the way
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # default setting for shuffle is False, which means the partitioning is
        # deterministic and static. Recommendation for bootstrapping is to
        # set shuffle=True and use random_state = bootstrap_iteration in order to
        # have random, but reproducible, partitions
        folds = list(skf.split(X, classes))
        for warning in w:
            logger.debug(
                f"Warning encountered during stratified k-fold split: {warning.message}"
            )

    # Clone the estimator and set the `cv` attribute with predefined folds
    model = clone(estimator)
    model.cv = folds

    # Fit the model using the custom cross-validation folds, forwarding any
    # extra keyword arguments to `fit()` as documented
    model.fit(
        X,
        y.values.ravel(),
        sample_weight=sample_weight,
        **kwargs,
    )

    return model
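The central step of the function above, handing precomputed stratified folds to an estimator through its cv attribute, can be reproduced with scikit-learn alone. The following is a minimal sketch on synthetic data; the column names, sample sizes, and class labels are invented for illustration and are not part of tfbpmodeling.

```python
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.linear_model import LassoCV
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)

# Synthetic data (hypothetical): 80 samples, 5 predictors, 4 balanced classes.
X = pd.DataFrame(rng.normal(size=(80, 5)), columns=[f"tf_{i}" for i in range(5)])
y = pd.DataFrame({"response": 2.0 * X["tf_0"] + rng.normal(scale=0.1, size=80)})
classes = np.repeat(np.arange(4), 20)

# Build the stratified folds once, from the class labels rather than from y.
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)
folds = list(skf.split(X, classes))

# Clone the estimator so the caller's object is left untouched, then make it
# select its regularization strength on exactly these folds.
model = clone(LassoCV())
model.cv = folds
model.fit(X, y.values.ravel())

print(f"chosen alpha: {model.alpha_:.4f}")
print(f"MSE path shape (n_alphas, n_folds): {model.mse_path_.shape}")
```

Passing a list of (train, test) index pairs as cv is a documented scikit-learn option, so the estimator never sees the class labels; they only shape the folds.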
Overview
The stratified_cv module provides cross-validation in which every fold keeps approximately the same proportion of each stratification class as the full dataset. This matters for tfbpmodeling because the data may have natural groupings, or strata, that should be represented in both the training and validation portions of every split.
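What preserving the strata means in practice can be seen by counting class members per test fold. This is a small sketch using scikit-learn's StratifiedKFold directly; the imbalanced labels are made up for the example.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced labels: 60 samples in class 0, 30 in class 1, 10 in class 2.
classes = np.array([0] * 60 + [1] * 30 + [2] * 10)
X = np.zeros((classes.size, 1))  # feature values do not influence how folds are built

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Count how many members of each class land in every test fold.
fold_counts = [
    np.bincount(classes[test_idx], minlength=3)
    for _, test_idx in skf.split(X, classes)
]

for i, counts in enumerate(fold_counts):
    print(f"fold {i}: class counts in test split = {counts.tolist()}")
```

Because each class size here is divisible by the number of splits, every test fold receives the same 12/6/2 composition, mirroring the 60/30/10 ratio of the full label array.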
Key Features
- Stratified Sampling: Maintains data distribution across CV folds
- Bootstrap Integration: Works with bootstrap resampling
- Flexible Stratification: Multiple stratification strategies
- Robust Validation: Reduces bias in cross-validation estimates
Usage Examples
Basic Stratified CV
from tfbpmodeling.stratified_cv import StratifiedCV

# Create stratified CV object
cv = StratifiedCV(
    n_splits=5,
    stratification_variable='binding_strength_bins',
    random_state=42
)

# Generate CV folds
for train_idx, test_idx in cv.split(X, y):
    # Train and evaluate model
    pass
Bootstrap Integration
import numpy as np
from sklearn.linear_model import LassoCV

from tfbpmodeling.stratified_cv import bootstrap_stratified_cv

# Perform bootstrap with stratified CV
cv_scores = bootstrap_stratified_cv(
    X=predictor_data,
    y=response_data,
    estimator=LassoCV(),
    n_bootstraps=1000,
    cv_folds=5,
    stratification_bins=[0, 8, 12, np.inf]
)
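The source of stratified_cv_modeling recommends, for bootstrapping, setting shuffle=True and using the bootstrap iteration as random_state so that partitions are random but reproducible. The sketch below illustrates that recommendation with scikit-learn only. It is not the implementation of bootstrap_stratified_cv; the data, the resampling scheme, and the small n_bootstraps value are assumptions chosen to keep the example short.

```python
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)
n_samples = 60

# Synthetic data (hypothetical): 4 predictors, 3 balanced strata.
X = rng.normal(size=(n_samples, 4))
y = 1.5 * X[:, 0] + rng.normal(scale=0.1, size=n_samples)
classes = np.repeat(np.arange(3), 20)

n_bootstraps = 3  # kept tiny for illustration
coefs = []
for i in range(n_bootstraps):
    # Resample rows with replacement; seeding with i makes each draw repeatable.
    idx = np.random.default_rng(i).integers(0, n_samples, size=n_samples)
    Xb, yb, cb = X[idx], y[idx], classes[idx]

    # Use the bootstrap iteration as random_state: folds differ between
    # iterations but are identical when the same iteration is rerun.
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=i)
    folds = list(skf.split(Xb, cb))

    model = LassoCV(cv=folds).fit(Xb, yb)
    coefs.append(model.coef_)

coefs = np.vstack(coefs)  # one row of coefficients per bootstrap iteration
print(coefs.round(3))
```

Note that in a plain bootstrap sample a row drawn more than once can end up on both sides of a split, which is worth keeping in mind when interpreting the cross-validated error.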
Stratification Methods
Binding Strength Bins
Stratifies data based on transcription factor binding strength:
# Define binding strength bins
bins = [0, 0.1, 0.5, 1.0]  # Low, medium, high binding

cv = StratifiedCV(
    n_splits=5,
    stratification_method='binding_bins',
    bins=bins
)
Expression Level Bins
Stratifies based on expression level ranges:
import numpy as np

# Expression-based stratification
cv = StratifiedCV(
    n_splits=5,
    stratification_method='expression_bins',
    bins=[-np.inf, -1, 0, 1, np.inf]
)
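Bin-based stratification amounts to turning a continuous variable into integer class labels and stratifying on those labels. The sketch below shows one way to do this with np.digitize and scikit-learn's StratifiedKFold, using the same bin edges as the expression example. It is an illustration under assumed data, not the behaviour of stratification_classification() or of the classes shown above.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical expression values (e.g. log fold changes) for 40 samples.
expression = np.array([-2.5, -1.2, -0.7, -0.3, 0.2, 0.6, 1.4, 2.1] * 5)

# Same edges as in the example above; only the interior edges are needed
# because np.digitize treats values beyond the outer edges as the end bins.
bins = [-np.inf, -1, 0, 1, np.inf]
classes = np.digitize(expression, bins[1:-1])  # integer labels 0..3

X = expression.reshape(-1, 1)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Every test fold should contain samples from each expression bin.
test_counts = [
    np.bincount(classes[test_idx], minlength=4)
    for _, test_idx in skf.split(X, classes)
]

print("samples per bin:", np.bincount(classes).tolist())
for i, counts in enumerate(test_counts):
    print(f"fold {i}: per-bin test counts = {counts.tolist()}")
```

An array built this way has the form expected by the classes argument of stratified_cv_modeling: one label per row of X.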