
Synthetic Data Loader

Bases: LightningDataModule

A synthetic data loader that generates synthetic binding & perturbation effect data for training, validating, and testing a model. This class contains all of the logic for generating and parsing the synthetic data, as well as splitting it into train, validation, and test sets. It is a subclass of pytorch_lightning.LightningDataModule, which is similar to a regular PyTorch DataLoader but with added functionality for data loading.
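
As a quick orientation, here is a minimal usage sketch. It assumes only the import path shown below and a hypothetical LightningModule called MyModel (any model whose batches are (binding_effects, perturbation_effects) pairs would work):

import pytorch_lightning as pl

from yeastdnnexplorer.data_loaders.synthetic_data_loader import SyntheticDataLoader

# The Trainer normally calls prepare_data() and setup() itself during fit();
# they are called manually here only to inspect the generated data.
data_module = SyntheticDataLoader(batch_size=32, num_genes=1000)
data_module.prepare_data()
data_module.setup("fit")
print(data_module.final_data_tensor.shape)  # torch.Size([1000, 13, 5])

# model = MyModel(...)                 # hypothetical LightningModule
# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(model, data_module)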

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
class SyntheticDataLoader(LightningDataModule):
    """A class for a synthetic data loader that generates synthetic bindiing &
    perturbation effect data for training, validation, and testing a model This class
    contains all of the logic for generating and parsing the synthetic data, as well as
    splitting it into train, validation, and test sets It is a subclass of
    pytorch_lightning.LightningDataModule, which is similar to a regular PyTorch
    DataLoader but with added functionality for data loading."""

    def __init__(
        self,
        batch_size: int = 32,
        num_genes: int = 1000,
        bound: list[float] = [0.1, 0.2, 0.2, 0.4, 0.5],
        bound_mean: float = 3.0,
        n_sample: list[int] = [1, 2, 2, 4, 4],
        val_size: float = 0.1,
        test_size: float = 0.1,
        random_state: int = 42,
        max_mean_adjustment: float = 0.0,
        adjustment_function: Callable[
            [torch.Tensor, float, float, float], torch.Tensor
        ] = default_perturbation_effect_adjustment_function,
        tf_relationships: dict[int, list[int] | list[Relation]] = {},
    ) -> None:
        """
        Constructor of SyntheticDataLoader.

        :param batch_size: The number of samples in each mini-batch
        :type batch_size: int
        :param num_genes: The number of genes in the synthetic data (this is the number
            of datapoints in our dataset)
        :type num_genes: int
        :param bound: The proportion of genes in each sample group that are put in the
            bound group (i.e. have a non-zero binding effect and expression response)
        :type bound: list[float]
        :param n_sample: The number of samples to draw from each bound group
        :type n_sample: list[int]
        :param val_size: The proportion of the dataset to include in the validation
            split
        :type val_size: float
        :param test_size: The proportion of the dataset to include in the test split
        :type test_size: float
        :param random_state: The random seed to use for splitting the data (keep this
            consistent to ensure reproducibility)
        :type random_state: int
        :param bound_mean: The mean of the bound distribution
        :type bound_mean: float
        :param max_mean_adjustment: The maximum mean adjustment to apply to the mean
                                    of the bound perturbation effects
        :type max_mean_adjustment: float
        :param adjustment_function: A function that adjusts the mean of the bound
                                    perturbation effects
        :type adjustment_function: Callable[[torch.Tensor, float, float,
                                   float, dict[int, list[int]]], torch.Tensor]
        :param tf_relationships: A mapping from each TF index to the indices (or
            Relation objects) of the TFs it depends on; only used when a mean
            adjustment is applied
        :type tf_relationships: dict[int, list[int] | list[Relation]]
        :raises TypeError: If batch_size is not a positive integer
        :raises TypeError: If num_genes is not a positive integer
        :raises TypeError: If bound is not a list of integers or floats
        :raises TypeError: If n_sample is not a list of integers
        :raises TypeError: If val_size is not a float between 0 and 1 (exclusive)
        :raises TypeError: If test_size is not a float between 0 and 1 (exclusive)
        :raises TypeError: If random_state is not an integer
        :raises TypeError: If bound_mean is not a float
        :raises ValueError: If val_size + test_size is greater than 1 (i.e. the splits
            are too large)

        """
        if not isinstance(batch_size, int) or batch_size < 1:
            raise TypeError("batch_size must be a positive integer")
        if not isinstance(num_genes, int) or num_genes < 1:
            raise TypeError("num_genes must be a positive integer")
        if not isinstance(bound, list) or not all(
            isinstance(x, (int, float)) for x in bound
        ):
            raise TypeError("bound must be a list of integers or floats")
        if not isinstance(n_sample, list) or not all(
            isinstance(x, int) for x in n_sample
        ):
            raise TypeError("n_sample must be a list of integers")
        if not isinstance(val_size, (int, float)) or val_size <= 0 or val_size >= 1:
            raise TypeError("val_size must be a float between 0 and 1 (inclusive)")
        if not isinstance(test_size, (int, float)) or test_size <= 0 or test_size >= 1:
            raise TypeError("test_size must be a float between 0 and 1 (inclusive)")
        if not isinstance(random_state, int):
            raise TypeError("random_state must be an integer")
        if not isinstance(bound_mean, float):
            raise TypeError("bound_mean must be a float")
        if test_size + val_size > 1:
            raise ValueError("val_size + test_size must be less than or equal to 1")

        super().__init__()
        self.batch_size = batch_size
        self.num_genes = num_genes
        self.bound_mean = bound_mean
        self.bound = bound or [0.1, 0.15, 0.2, 0.25, 0.3]
        self.n_sample = n_sample or [1 for _ in range(len(self.bound))]
        self.num_tfs = sum(self.n_sample)  # sum of all n_sample is the number of TFs
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state

        self.max_mean_adjustment = max_mean_adjustment
        self.adjustment_function = adjustment_function
        self.tf_relationships = tf_relationships

        self.final_data_tensor: torch.Tensor | None = None
        self.binding_effect_matrix: torch.Tensor | None = None
        self.perturbation_effect_matrix: torch.Tensor | None = None
        self.val_dataset: TensorDataset | None = None
        self.test_dataset: TensorDataset | None = None

    def prepare_data(self) -> None:
        """Function to generate the raw synthetic data and save it in a tensor For
        explanations of the functions used to generate the data, see the
        generate_in_silico_data tutorial notebook in the docs No assertion checks are
        performed as that is handled in the functions in generate_data.py."""
        # this will be a list of length 10 with a GenePopulation object in each element
        gene_populations_list = []
        for bound_proportion, n_draws in zip(self.bound, self.n_sample):
            for _ in range(n_draws):
                gene_populations_list.append(
                    generate_gene_population(self.num_genes, bound_proportion)
                )

        # Generate binding data for each gene population
        binding_effect_list = [
            generate_binding_effects(gene_population)
            for gene_population in gene_populations_list
        ]

        # Calculate p-values for binding data
        binding_pvalue_list = [
            generate_pvalues(binding_data) for binding_data in binding_effect_list
        ]

        binding_data_combined = [
            torch.stack((gene_population.labels, binding_effect, binding_pval), dim=1)
            for gene_population, binding_effect, binding_pval in zip(
                gene_populations_list, binding_effect_list, binding_pvalue_list
            )
        ]

        # Stack along a new dimension (dim=1) to create a tensor of shape
        # [num_genes, num_TFs, 3]
        binding_data_tensor = torch.stack(binding_data_combined, dim=1)

        # if we are using a mean adjustment, we need to generate perturbation
        # effects in a slightly different way than if we are not using
        # a mean adjustment
        if self.max_mean_adjustment > 0:
            perturbation_effects_list = generate_perturbation_effects(
                binding_data_tensor,
                bound_mean=self.bound_mean,
                tf_index=0,  # unused
                max_mean_adjustment=self.max_mean_adjustment,
                adjustment_function=self.adjustment_function,
                tf_relationships=self.tf_relationships,
            )

            perturbation_pvalue_list = torch.zeros_like(perturbation_effects_list)
            for col_index in range(perturbation_effects_list.shape[1]):
                perturbation_pvalue_list[:, col_index] = generate_pvalues(
                    perturbation_effects_list[:, col_index]
                )

            # take absolute values
            perturbation_effects_list = torch.abs(perturbation_effects_list)

            perturbation_effects_tensor = perturbation_effects_list
            perturbation_pvalues_tensor = perturbation_pvalue_list
        else:
            perturbation_effects_list = [
                generate_perturbation_effects(
                    binding_data_tensor[:, tf_index, :].unsqueeze(1),
                    bound_mean=self.bound_mean,
                    tf_index=0,  # unused
                )
                for tf_index in range(sum(self.n_sample))
            ]
            perturbation_pvalue_list = [
                generate_pvalues(perturbation_effects)
                for perturbation_effects in perturbation_effects_list
            ]

            # take absolute values
            perturbation_effects_list = [
                torch.abs(perturbation_effects)
                for perturbation_effects in perturbation_effects_list
            ]

            # Convert lists to tensors
            perturbation_effects_tensor = torch.stack(perturbation_effects_list, dim=1)
            perturbation_pvalues_tensor = torch.stack(perturbation_pvalue_list, dim=1)

        # Add a trailing dimension so each perturbation tensor has shape
        # [n_genes, n_tfs, 1] and can be concatenated with the binding data
        perturbation_effects_tensor = perturbation_effects_tensor.unsqueeze(-1)
        perturbation_pvalues_tensor = perturbation_pvalues_tensor.unsqueeze(-1)

        # Concatenate along the last dimension to form a [n_genes, n_tfs, 5] tensor
        self.final_data_tensor = torch.cat(
            (
                binding_data_tensor,
                perturbation_effects_tensor,
                perturbation_pvalues_tensor,
            ),
            dim=2,
        )

    def setup(self, stage: str | None = None) -> None:
        """
        This function runs after prepare_data finishes and is used to split the data
        into train, validation, and test sets. It ensures that these datasets are of
        the correct dimensionality and size to be used by the model.

        :param stage: The stage of the data setup (either 'fit' for training, 'validate'
            for validation, or 'test' for testing), unused for now as the model is not
            complicated enough to necessitate this
        :type stage: Optional[str]

        """
        self.binding_effect_matrix = self.final_data_tensor[:, :, 1]
        self.perturbation_effect_matrix = self.final_data_tensor[:, :, 3]

        # split into train, val, and test
        X_train, X_temp, Y_train, Y_temp = train_test_split(
            self.binding_effect_matrix,
            self.perturbation_effect_matrix,
            test_size=(self.val_size + self.test_size),
            random_state=self.random_state,
        )

        # normalize test_size so that it is a fraction of the held-out remainder
        # (a local variable is used so repeated setup() calls do not re-normalize)
        relative_test_size = self.test_size / (self.val_size + self.test_size)
        X_val, X_test, Y_val, Y_test = train_test_split(
            X_temp,
            Y_temp,
            test_size=relative_test_size,
            random_state=self.random_state,
        )

        # Convert to float32 tensors (torch.as_tensor avoids copying data that is
        # already a tensor)
        X_train = torch.as_tensor(X_train, dtype=torch.float32)
        Y_train = torch.as_tensor(Y_train, dtype=torch.float32)
        X_val = torch.as_tensor(X_val, dtype=torch.float32)
        Y_val = torch.as_tensor(Y_val, dtype=torch.float32)
        X_test = torch.as_tensor(X_test, dtype=torch.float32)
        Y_test = torch.as_tensor(Y_test, dtype=torch.float32)

        # Set our datasets
        self.train_dataset = TensorDataset(X_train, Y_train)
        self.val_dataset = TensorDataset(X_val, Y_val)
        self.test_dataset = TensorDataset(X_test, Y_test)

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader. The data are shuffled so that the model does
        not learn from the order of the data.

        :return: The training dataloader
        :rtype: DataLoader

        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=15,
            shuffle=True,
            persistent_workers=True,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        :return: The validation dataloader
        :rtype: DataLoader

        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=15,
            shuffle=False,
            persistent_workers=True,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Return the testing dataloader.

        :return: The testing dataloader
        :rtype: DataLoader

        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=15,
            shuffle=False,
            persistent_workers=True,
        )

__init__(batch_size=32, num_genes=1000, bound=[0.1, 0.2, 0.2, 0.4, 0.5], bound_mean=3.0, n_sample=[1, 2, 2, 4, 4], val_size=0.1, test_size=0.1, random_state=42, max_mean_adjustment=0.0, adjustment_function=default_perturbation_effect_adjustment_function, tf_relationships={})

Constructor of SyntheticDataLoader.

Parameters:

Name Type Description Default
batch_size int

The number of samples in each mini-batch

32
num_genes int

The number of genes in the synthetic data (this is the number of datapoints in our dataset)

1000
bound list[float]

The proportion of genes in each sample group that are put in the bound group (i.e. have a non-zero binding effect and expression response)

[0.1, 0.2, 0.2, 0.4, 0.5]
n_sample list[int]

The number of samples to draw from each bound group

[1, 2, 2, 4, 4]
val_size float

The proportion of the dataset to include in the validation split

0.1
test_size float

The proportion of the dataset to include in the test split

0.1
random_state int

The random seed to use for splitting the data (keep this consistent to ensure reproducibility)

42
bound_mean float

The mean of the bound distribution

3.0
max_mean_adjustment float

The maximum mean adjustment to apply to the mean of the bound perturbation effects

0.0
adjustment_function Callable[[Tensor, float, float, float], Tensor]

A function that adjusts the mean of the bound perturbation effects

default_perturbation_effect_adjustment_function
tf_relationships dict[int, list[int] | list[Relation]]

A mapping from each TF index to the indices (or Relation objects) of the TFs it depends on; only used when a mean adjustment is applied

{}

Raises:

Type Description
TypeError

If batch_size is not a positive integer

TypeError

If num_genes is not a positive integer

TypeError

If bound is not a list of integers or floats

TypeError

If n_sample is not a list of integers

TypeError

If val_size is not a float between 0 and 1 (exclusive)

TypeError

If test_size is not a float between 0 and 1 (exclusive)

TypeError

If random_state is not an integer

TypeError

If bound_mean is not a float

ValueError

If val_size + test_size is greater than 1 (i.e. the splits are too large)
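
Putting the parameters above together, a construction sketch (values chosen only for illustration):

from yeastdnnexplorer.data_loaders.synthetic_data_loader import SyntheticDataLoader

# Five bound-proportion groups, sampled 1, 2, 2, 4, and 4 times -> 13 TFs.
loader = SyntheticDataLoader(
    batch_size=64,
    num_genes=2000,
    bound=[0.1, 0.2, 0.2, 0.4, 0.5],
    n_sample=[1, 2, 2, 4, 4],
    bound_mean=3.0,
    val_size=0.1,
    test_size=0.1,
    random_state=42,
)
assert loader.num_tfs == 13

# Out-of-range splits fail fast:
try:
    SyntheticDataLoader(val_size=0.6, test_size=0.6)
except ValueError as err:
    print(err)  # val_size + test_size must be less than or equal to 1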

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def __init__(
    self,
    batch_size: int = 32,
    num_genes: int = 1000,
    bound: list[float] = [0.1, 0.2, 0.2, 0.4, 0.5],
    bound_mean: float = 3.0,
    n_sample: list[int] = [1, 2, 2, 4, 4],
    val_size: float = 0.1,
    test_size: float = 0.1,
    random_state: int = 42,
    max_mean_adjustment: float = 0.0,
    adjustment_function: Callable[
        [torch.Tensor, float, float, float], torch.Tensor
    ] = default_perturbation_effect_adjustment_function,
    tf_relationships: dict[int, list[int] | list[Relation]] = {},
) -> None:
    """
    Constructor of SyntheticDataLoader.

    :param batch_size: The number of samples in each mini-batch
    :type batch_size: int
    :param num_genes: The number of genes in the synthetic data (this is the number
        of datapoints in our dataset)
    :type num_genes: int
    :param bound: The proportion of genes in each sample group that are put in the
        bound group (i.e. have a non-zero binding effect and expression response)
    :type bound: list[float]
    :param n_sample: The number of samples to draw from each bound group
    :type n_sample: list[int]
    :param val_size: The proportion of the dataset to include in the validation
        split
    :type val_size: float
    :param test_size: The proportion of the dataset to include in the test split
    :type test_size: float
    :param random_state: The random seed to use for splitting the data (keep this
        consistent to ensure reproducibility)
    :type random_state: int
    :param bound_mean: The mean of the bound distribution
    :type bound_mean: float
    :param max_mean_adjustment: The maximum mean adjustment to apply to the mean
                                of the bound perturbation effects
    :type max_mean_adjustment: float
    :param adjustment_function: A function that adjusts the mean of the bound
                                perturbation effects
    :type adjustment_function: Callable[[torch.Tensor, float, float,
                               float, dict[int, list[int]]], torch.Tensor]
    :param tf_relationships: A mapping from each TF index to the indices (or
        Relation objects) of the TFs it depends on; only used when a mean
        adjustment is applied
    :type tf_relationships: dict[int, list[int] | list[Relation]]
    :raises TypeError: If batch_size is not a positive integer
    :raises TypeError: If num_genes is not a positive integer
    :raises TypeError: If bound is not a list of integers or floats
    :raises TypeError: If n_sample is not a list of integers
    :raises TypeError: If val_size is not a float between 0 and 1 (exclusive)
    :raises TypeError: If test_size is not a float between 0 and 1 (exclusive)
    :raises TypeError: If random_state is not an integer
    :raises TypeError: If bound_mean is not a float
    :raises ValueError: If val_size + test_size is greater than 1 (i.e. the splits
        are too large)

    """
    if not isinstance(batch_size, int) or batch_size < 1:
        raise TypeError("batch_size must be a positive integer")
    if not isinstance(num_genes, int) or num_genes < 1:
        raise TypeError("num_genes must be a positive integer")
    if not isinstance(bound, list) or not all(
        isinstance(x, (int, float)) for x in bound
    ):
        raise TypeError("bound must be a list of integers or floats")
    if not isinstance(n_sample, list) or not all(
        isinstance(x, int) for x in n_sample
    ):
        raise TypeError("n_sample must be a list of integers")
    if not isinstance(val_size, (int, float)) or val_size <= 0 or val_size >= 1:
        raise TypeError("val_size must be a float between 0 and 1 (inclusive)")
    if not isinstance(test_size, (int, float)) or test_size <= 0 or test_size >= 1:
        raise TypeError("test_size must be a float between 0 and 1 (inclusive)")
    if not isinstance(random_state, int):
        raise TypeError("random_state must be an integer")
    if not isinstance(bound_mean, float):
        raise TypeError("bound_mean must be a float")
    if test_size + val_size > 1:
        raise ValueError("val_size + test_size must be less than or equal to 1")

    super().__init__()
    self.batch_size = batch_size
    self.num_genes = num_genes
    self.bound_mean = bound_mean
    self.bound = bound or [0.1, 0.15, 0.2, 0.25, 0.3]
    self.n_sample = n_sample or [1 for _ in range(len(self.bound))]
    self.num_tfs = sum(self.n_sample)  # sum of all n_sample is the number of TFs
    self.val_size = val_size
    self.test_size = test_size
    self.random_state = random_state

    self.max_mean_adjustment = max_mean_adjustment
    self.adjustment_function = adjustment_function
    self.tf_relationships = tf_relationships

    self.final_data_tensor: torch.Tensor | None = None
    self.binding_effect_matrix: torch.Tensor | None = None
    self.perturbation_effect_matrix: torch.Tensor | None = None
    self.val_dataset: TensorDataset | None = None
    self.test_dataset: TensorDataset | None = None

prepare_data()

Generate the raw synthetic data and save it in a tensor. For explanations of the functions used to generate the data, see the generate_in_silico_data tutorial notebook in the docs. No assertion checks are performed, as that is handled in the functions in generate_data.py.
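
To make the resulting layout concrete, a small sketch under the default arguments (the slice indices follow the concatenation order in the source below):

loader = SyntheticDataLoader()  # defaults: 1000 genes, 13 TFs
loader.prepare_data()

# dim 2 holds, in order: bound/unbound label, binding effect, binding p-value,
# |perturbation effect|, perturbation p-value
print(loader.final_data_tensor.shape)  # torch.Size([1000, 13, 5])
binding_effects = loader.final_data_tensor[:, :, 1]
perturbation_effects = loader.final_data_tensor[:, :, 3]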

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def prepare_data(self) -> None:
    """Function to generate the raw synthetic data and save it in a tensor For
    explanations of the functions used to generate the data, see the
    generate_in_silico_data tutorial notebook in the docs No assertion checks are
    performed as that is handled in the functions in generate_data.py."""
    # this will be a list of length 10 with a GenePopulation object in each element
    gene_populations_list = []
    for bound_proportion, n_draws in zip(self.bound, self.n_sample):
        for _ in range(n_draws):
            gene_populations_list.append(
                generate_gene_population(self.num_genes, bound_proportion)
            )

    # Generate binding data for each gene population
    binding_effect_list = [
        generate_binding_effects(gene_population)
        for gene_population in gene_populations_list
    ]

    # Calculate p-values for binding data
    binding_pvalue_list = [
        generate_pvalues(binding_data) for binding_data in binding_effect_list
    ]

    binding_data_combined = [
        torch.stack((gene_population.labels, binding_effect, binding_pval), dim=1)
        for gene_population, binding_effect, binding_pval in zip(
            gene_populations_list, binding_effect_list, binding_pvalue_list
        )
    ]

    # Stack along a new dimension (dim=1) to create a tensor of shape
    # [num_genes, num_TFs, 3]
    binding_data_tensor = torch.stack(binding_data_combined, dim=1)

    # if we are using a mean adjustment, we need to generate perturbation
    # effects in a slightly different way than if we are not using
    # a mean adjustment
    if self.max_mean_adjustment > 0:
        perturbation_effects_list = generate_perturbation_effects(
            binding_data_tensor,
            bound_mean=self.bound_mean,
            tf_index=0,  # unused
            max_mean_adjustment=self.max_mean_adjustment,
            adjustment_function=self.adjustment_function,
            tf_relationships=self.tf_relationships,
        )

        perturbation_pvalue_list = torch.zeros_like(perturbation_effects_list)
        for col_index in range(perturbation_effects_list.shape[1]):
            perturbation_pvalue_list[:, col_index] = generate_pvalues(
                perturbation_effects_list[:, col_index]
            )

        # take absolute values
        perturbation_effects_list = torch.abs(perturbation_effects_list)

        perturbation_effects_tensor = perturbation_effects_list
        perturbation_pvalues_tensor = perturbation_pvalue_list
    else:
        perturbation_effects_list = [
            generate_perturbation_effects(
                binding_data_tensor[:, tf_index, :].unsqueeze(1),
                bound_mean=self.bound_mean,
                tf_index=0,  # unused
            )
            for tf_index in range(sum(self.n_sample))
        ]
        perturbation_pvalue_list = [
            generate_pvalues(perturbation_effects)
            for perturbation_effects in perturbation_effects_list
        ]

        # take absolute values
        perturbation_effects_list = [
            torch.abs(perturbation_effects)
            for perturbation_effects in perturbation_effects_list
        ]

        # Convert lists to tensors
        perturbation_effects_tensor = torch.stack(perturbation_effects_list, dim=1)
        perturbation_pvalues_tensor = torch.stack(perturbation_pvalue_list, dim=1)

    # Add a trailing dimension so each perturbation tensor has shape
    # [n_genes, n_tfs, 1] and can be concatenated with the binding data
    perturbation_effects_tensor = perturbation_effects_tensor.unsqueeze(-1)
    perturbation_pvalues_tensor = perturbation_pvalues_tensor.unsqueeze(-1)

    # Concatenate along the last dimension to form a [n_genes, n_tfs, 5] tensor
    self.final_data_tensor = torch.cat(
        (
            binding_data_tensor,
            perturbation_effects_tensor,
            perturbation_pvalues_tensor,
        ),
        dim=2,
    )

setup(stage=None)

This function runs after prepare_data finishes and is used to split the data into train, validation, and test sets. It ensures that these datasets are of the correct dimensionality and size to be used by the model; a sketch of the resulting split sizes follows the parameter table below.

Parameters:

Name Type Description Default
stage str | None

The stage of the data setup (either ‘fit’ for training, ‘validate’ for validation, or ‘test’ for testing), unused for now as the model is not complicated enough to necessitate this

None
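
As referenced above, a sketch of the resulting split sizes under the default val_size=0.1 and test_size=0.1 (roughly an 80/10/10 split of the genes):

loader = SyntheticDataLoader(num_genes=1000)
loader.prepare_data()
loader.setup()

print(len(loader.train_dataset))  # 800
print(len(loader.val_dataset))    # 100
print(len(loader.test_dataset))   # 100

X, Y = loader.train_dataset[0]  # one gene: binding effects and perturbation effects
print(X.shape, Y.shape)         # torch.Size([13]) torch.Size([13])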
Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def setup(self, stage: str | None = None) -> None:
    """
    This function runs after prepare_data finishes and is used to split the data
    into train, validation, and test sets. It ensures that these datasets are of
    the correct dimensionality and size to be used by the model.

    :param stage: The stage of the data setup (either 'fit' for training, 'validate'
        for validation, or 'test' for testing), unused for now as the model is not
        complicated enough to necessitate this
    :type stage: Optional[str]

    """
    self.binding_effect_matrix = self.final_data_tensor[:, :, 1]
    self.perturbation_effect_matrix = self.final_data_tensor[:, :, 3]

    # split into train, val, and test
    X_train, X_temp, Y_train, Y_temp = train_test_split(
        self.binding_effect_matrix,
        self.perturbation_effect_matrix,
        test_size=(self.val_size + self.test_size),
        random_state=self.random_state,
    )

    # normalize test_size so that it is a fraction of the held-out remainder
    # (a local variable is used so repeated setup() calls do not re-normalize)
    relative_test_size = self.test_size / (self.val_size + self.test_size)
    X_val, X_test, Y_val, Y_test = train_test_split(
        X_temp,
        Y_temp,
        test_size=relative_test_size,
        random_state=self.random_state,
    )

    # Convert to float32 tensors (torch.as_tensor avoids copying data that is
    # already a tensor)
    X_train = torch.as_tensor(X_train, dtype=torch.float32)
    Y_train = torch.as_tensor(Y_train, dtype=torch.float32)
    X_val = torch.as_tensor(X_val, dtype=torch.float32)
    Y_val = torch.as_tensor(Y_val, dtype=torch.float32)
    X_test = torch.as_tensor(X_test, dtype=torch.float32)
    Y_test = torch.as_tensor(Y_test, dtype=torch.float32)

    # Set our datasets
    self.train_dataset = TensorDataset(X_train, Y_train)
    self.val_dataset = TensorDataset(X_val, Y_val)
    self.test_dataset = TensorDataset(X_test, Y_test)

test_dataloader()

Return the testing dataloader.

Returns:

Type Description
DataLoader

The testing dataloader

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def test_dataloader(self) -> DataLoader:
    """
    Return the testing dataloader.

    :return: The testing dataloader
    :rtype: DataLoader

    """
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=15,
        shuffle=False,
        persistent_workers=True,
    )

train_dataloader()

Return the training dataloader. The data are shuffled so that the model does not learn from the order of the data.

Returns:

Type Description
DataLoader

The training dataloader
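
A minimal iteration sketch (note that the hard-coded num_workers=15 assumes a machine with enough CPU cores to spare):

loader = SyntheticDataLoader(batch_size=32)
loader.prepare_data()
loader.setup("fit")

for binding_batch, perturbation_batch in loader.train_dataloader():
    print(binding_batch.shape, perturbation_batch.shape)
    # torch.Size([32, 13]) torch.Size([32, 13]) with the default 13 TFs
    break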

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def train_dataloader(self) -> DataLoader:
    """
    Return the training dataloader. The data are shuffled so that the model does
    not learn from the order of the data.

    :return: The training dataloader
    :rtype: DataLoader

    """
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        num_workers=15,
        shuffle=True,
        persistent_workers=True,
    )

val_dataloader()

Return the validation dataloader.

Returns:

Type Description
DataLoader

The validation dataloader

Source code in yeastdnnexplorer/data_loaders/synthetic_data_loader.py
def val_dataloader(self) -> DataLoader:
    """
    Return the validation dataloader.

    :return: The validation dataloader
    :rtype: DataLoader

    """
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        num_workers=15,
        shuffle=False,
        persistent_workers=True,
    )