Simple model
SimpleModel
Bases: LightningModule
A class for a simple linear model that takes in the binding effects for each transcription factor and predicts gene expression values. This class contains all of the logic for setting up, training, validating, and testing the model, and it defines how data is passed through the model. It is a subclass of pytorch_lightning.LightningModule, which is similar to a regular PyTorch nn.Module but adds functionality for training and validation.
Source code in yeastdnnexplorer/ml_models/simple_model.py
__init__(input_dim, output_dim, lr=0.001)
Constructor of SimpleModel.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim | int | The number of input features to the model: the binding effect of each transcription factor for a specific gene. | required |
output_dim | int | The number of output features of the model: the predicted gene expression value for each TF. | required |
lr | float | The learning rate for the optimizer. | 0.001 |
Raises:
Type | Description |
---|---|
TypeError | If input_dim is not an integer. |
TypeError | If output_dim is not an integer. |
TypeError | If lr is not a positive float. |
ValueError | If input_dim or output_dim is not positive. |
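The checks in the Raises table can be sketched as a small standalone function. This is a hypothetical helper (`validate_simple_model_args` is not part of the library), and the order in which the checks run is an assumption; it only mirrors the documented exception types.

```python
def validate_simple_model_args(input_dim, output_dim, lr=0.001):
    """Hypothetical helper mirroring the checks listed in the Raises table."""
    if not isinstance(input_dim, int):
        raise TypeError("input_dim must be an integer")
    if not isinstance(output_dim, int):
        raise TypeError("output_dim must be an integer")
    # Documented as a TypeError (not a ValueError) when lr is not a positive float
    if not isinstance(lr, float) or lr <= 0:
        raise TypeError("lr must be a positive float")
    if input_dim <= 0 or output_dim <= 0:
        raise ValueError("input_dim and output_dim must be positive")


# Usage: each argument tuple below triggers one of the documented errors.
for args in [(10.0, 10), (10, "10"), (10, 10, -0.1), (0, 10)]:
    try:
        validate_simple_model_args(*args)
    except (TypeError, ValueError) as err:
        print(f"{args}: {type(err).__name__}: {err}")
```

Under this reading, an integer learning rate such as `lr=1` is also rejected with a TypeError, because the table asks for a positive float.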
configure_optimizers()
Configure the optimizer for the model.
Returns:
Type | Description |
---|---|
Optimizer | The optimizer for the model. |
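A minimal sketch of what this hook can look like, assuming Adam over all model parameters with the stored learning rate. This page does not say which optimizer SimpleModel uses, and `LinearSketch` is a hypothetical plain `nn.Module` stand-in (not a LightningModule) so the example runs without Lightning. In Lightning, the Trainer calls `configure_optimizers()` itself to obtain the optimizer.

```python
import torch
from torch import nn


class LinearSketch(nn.Module):
    """Hypothetical stand-in for SimpleModel; plain nn.Module to stay self-contained."""

    def __init__(self, input_dim: int, output_dim: int, lr: float = 0.001):
        super().__init__()
        self.lr = lr
        # Assumption: a single linear layer from TF binding effects to expression
        self.linear = nn.Linear(input_dim, output_dim)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Assumption: Adam over all model parameters, using the stored learning rate
        return torch.optim.Adam(self.parameters(), lr=self.lr)


model = LinearSketch(input_dim=4, output_dim=4, lr=0.001)
optimizer = model.configure_optimizers()
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])  # Adam 0.001
```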
forward(x)
Forward pass of the model, i.e., how predictions are made for a given input.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | The input data to the model (without the target y values). | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The predicted y values for the input x, as a tensor of shape (batch_size, output_dim). |
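Since the model is described as linear, a minimal sketch of the forward pass is a single fully connected layer computing ŷ = xWᵀ + b. The single `nn.Linear` layer and the dimensions below are assumptions for illustration; this page does not show SimpleModel's actual layers. The sketch checks the documented output shape and that the layer matches the explicit formula.

```python
import torch
from torch import nn

# Hypothetical dimensions chosen for illustration only.
batch_size, input_dim, output_dim = 16, 6, 6

torch.manual_seed(0)
# Assumption: the forward pass is a single linear layer.
linear = nn.Linear(input_dim, output_dim)


def forward(x: torch.Tensor) -> torch.Tensor:
    """Map a (batch_size, input_dim) tensor to (batch_size, output_dim) predictions."""
    return linear(x)


x = torch.randn(batch_size, input_dim)  # one row of TF binding effects per gene
y_pred = forward(x)

# nn.Linear computes x @ W^T + b, so the explicit formula agrees up to rounding.
y_manual = x @ linear.weight.T + linear.bias
print(y_pred.shape)                                 # torch.Size([16, 6])
print(torch.allclose(y_pred, y_manual, atol=1e-6))  # True
```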
test_step(batch, batch_idx)
Test step for the model. This is called for each batch of data during testing. Testing is performed only after training and validation, once a final model has been chosen. We want to evaluate the final model on unseen data, which is why validation sets, rather than the test set, are used to “test” the model during training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The batch of data to test on (this will have size (batch_size, input_dim)). | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The loss for the test batch. |
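The requirement that the test data stay unseen can be illustrated with a three-way split. This is a sketch under assumptions: the random data and the 70/15/15 split are hypothetical, and the library may prepare its splits differently. The point is that the test indices never overlap the training or validation indices.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical data: 100 genes, 4 TF binding effects each, 4 expression targets.
x = torch.randn(100, 4)
y = torch.randn(100, 4)
dataset = TensorDataset(x, y)

# Assumption: a 70/15/15 train/validation/test split with a fixed seed.
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(dataset, [70, 15, 15], generator=generator)

# The test indices never appear in training or validation, so the test step
# sees only data the final model was neither fit on nor selected with.
train_idx, val_idx, test_idx = (set(s.indices) for s in (train_set, val_set, test_set))
print(len(train_set), len(val_set), len(test_set))  # 70 15 15
print(test_idx.isdisjoint(train_idx | val_idx))     # True
```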
training_step(batch, batch_idx)
Training step for the model. This is called for each batch of data during training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The batch of data to train on. | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The loss for the training batch. |
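A minimal sketch of a training step, assuming each batch is an `(x, y)` pair and the loss is mean squared error. Neither the batch format nor the loss function is stated on this page, and the bare `nn.Linear` model is a hypothetical stand-in for SimpleModel. The step only computes and returns the loss. With automatic optimization, Lightning backpropagates that returned value and steps the optimizer.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 4)  # hypothetical stand-in for SimpleModel's layer


def training_step(batch, batch_idx):
    # Assumption: each batch is an (x, y) pair and the loss is mean squared error.
    x, y = batch
    y_pred = model(x)
    return F.mse_loss(y_pred, y)


batch = (torch.randn(8, 4), torch.randn(8, 4))
loss = training_step(batch, batch_idx=0)

# Lightning would call backward() on the returned loss; doing it by hand here
# shows that gradients flow back to the model weights.
loss.backward()
print(loss.dim(), model.weight.grad is not None)  # 0 True
```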
validation_step(batch, batch_idx)
Validation step for the model. This is called for each batch of data during validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The batch of data to validate on. | required |
batch_idx | int | The index of the batch. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The loss for the validation batch. |
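A validation step usually computes the same loss as training but without updating the model. In the sketch below, the `(x, y)` batch format, the MSE loss, and the bare `nn.Linear` stand-in are all assumptions. Lightning runs validation with the model in eval mode and gradient tracking disabled. Outside a Trainer, `model.eval()` and `torch.no_grad()` reproduce that.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 4)  # hypothetical stand-in for SimpleModel's layer


def validation_step(batch, batch_idx):
    # Assumption: same (x, y) batch format and MSE loss as the training step.
    x, y = batch
    return F.mse_loss(model(x), y)


batch = (torch.randn(8, 4), torch.randn(8, 4))

# Evaluate without tracking gradients, then restore training mode.
model.eval()
with torch.no_grad():
    val_loss = validation_step(batch, batch_idx=0)
model.train()

print(val_loss.requires_grad)  # False
```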