Lightning Crash Course
This project uses the PyTorch Lightning library to define and train its machine learning models. PyTorch Lightning is built on top of PyTorch and abstracts away much of the setup and boilerplate code that models need (such as writing out training loops by hand). In this notebook, we give a brief introduction to using the models and DataModules we’ve defined to train models.
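To make "abstracting away the training loop" concrete, here is a rough sketch, in plain Python without PyTorch, of the kind of loop that `Trainer.fit` runs for us: iterate over epochs, shuffle and batch the data, run a forward pass, compute the loss and its gradients, and update the parameters. It fits a one-feature linear model with mini-batch gradient descent on mean squared error. The helper `train_linear_model` is hypothetical and is not part of this project or of Lightning; the real Trainer also handles validation, logging, checkpointing, and devices.

```python
import random


def train_linear_model(xs, ys, lr=1e-2, max_epochs=10, batch_size=32, seed=0):
    """Fit y ~ w * x + b with mini-batch gradient descent on the mean squared error.

    Hypothetical helper: it spells out by hand the loop structure that
    Trainer.fit automates (epochs -> batches -> forward -> loss -> gradients -> update).
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    history = []  # mean training loss per epoch
    for _ in range(max_epochs):
        rng.shuffle(indices)  # reshuffle the training data every epoch
        epoch_loss = 0.0
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # forward pass, plus gradients of the batch MSE with respect to w and b
            grad_w = grad_b = 0.0
            for i in batch:
                err = (w * xs[i] + b) - ys[i]
                epoch_loss += err * err
                grad_w += 2 * err * xs[i] / len(batch)
                grad_b += 2 * err / len(batch)
            # optimizer step (plain gradient descent)
            w -= lr * grad_w
            b -= lr * grad_b
        history.append(epoch_loss / len(xs))
    return w, b, history
```

For example, calling `train_linear_model(xs, ys, lr=0.1, max_epochs=200, batch_size=16)` on points with x between 0 and 2 lying exactly on y = 3x + 1 should end close to w = 3 and b = 1, with the per-epoch loss in `history` shrinking toward zero.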
# imports
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from yeastdnnexplorer.data_loaders.synthetic_data_loader import SyntheticDataLoader
from yeastdnnexplorer.data_loaders.real_data_loader import RealDataLoader
from yeastdnnexplorer.ml_models.simple_model import SimpleModel
from yeastdnnexplorer.ml_models.customizable_model import CustomizableModel
In PyTorch Lightning, the data is kept completely separate from the models. This makes it easy to train one model on different datasets, or to train different models on the same dataset. DataModules encapsulate all of the logic for loading a specific dataset and splitting it into training, validation, and test sets. In this project, we have two data loaders defined: SyntheticDataLoader, for the in silico data (it takes many parameters that let you specify how the data is generated), and RealDataLoader, which contains all of the logic for loading the real experimental data and putting it into the form that the models expect.
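As a rough illustration of the splitting step a DataModule performs, the sketch below shuffles sample indices with a fixed seed and cuts them into train, validation, and test sets using fractional `val_size` and `test_size` arguments like the ones passed to our loaders. `train_val_test_split` is a hypothetical helper, not the actual logic in SyntheticDataLoader or RealDataLoader (see their source for that). In a LightningDataModule, this kind of logic typically lives in the `setup` method, and `train_dataloader`, `val_dataloader`, and `test_dataloader` then return a DataLoader over each split.

```python
import random


def train_val_test_split(n_samples, val_size=0.1, test_size=0.1, seed=42):
    """Shuffle sample indices and split them into (train, val, test) index lists.

    Hypothetical helper: val_size and test_size are fractions of the whole
    dataset, mirroring the val_size / test_size arguments used in this notebook.
    """
    if not 0 <= val_size + test_size < 1:
        raise ValueError("val_size + test_size must be in [0, 1)")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    n_val = int(n_samples * val_size)
    n_test = int(n_samples * test_size)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]  # everything left over is used for training
    return train, val, test
```

With 1,000 samples and the default fractions, this yields 800 training, 100 validation, and 100 test indices, and calling it again with the same seed reproduces the same split.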
Once you decide which model you want to train and which DataModule you want to use, you can bundle them with a Trainer object to train the model on the dataset.
If you’d like to learn more about the models and DataModules we’ve defined, each file contains extensive documentation explaining the purpose of each method.
# define an instance of our simple linear baseline model
model = SimpleModel(
input_dim=10,
output_dim=10,
lr=1e-2,
)
# define an instance of the synthetic data loader
# see the constructor for the full list of params and their explanations
data_module = SyntheticDataLoader(
batch_size=32,
num_genes=3000,
bound=[0.5] * 5,
n_sample=[1, 1, 2, 2, 4],
val_size=0.1,
test_size=0.1,
bound_mean=3.0,
)
# define a trainer instance
trainer = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu", # change to "gpu" if you have access to one
)
# train the model
trainer.fit(model, data_module)
# test the model (recall that data_module specifies the train / test split, we don't need to do it explicitly here)
test_results = trainer.test(model, data_module)
print(test_results)
It’s just as easy to train the same kind of model on a different dataset. For example, to use the real-world data, we can simply swap in the DataModule we’ve defined for it.
# create a fresh model instance (reusing `model` would start from its already-trained weights,
# and its input/output dims would not match the real dataset)
new_model = SimpleModel(
input_dim=30, # note that the input and output dims are equal to the num TFs in the dataset
output_dim=30,
lr=1e-2,
)
real_data_module = RealDataLoader(
batch_size=32,
val_size=0.1,
test_size=0.1,
data_dir_path="../../data/init_analysis_data_20240409/", # note that this is relative to where real_data_loader.py is
perturbation_dataset_title="hu_reimann_tfko",
)
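# we also create a new Trainer: a Trainer keeps its loop state (e.g. the epoch counter) after fit,
# so reusing the one above would most likely treat max_epochs as already reached and skip training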
trainer = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu", # change to "gpu" if you have access to one
)
trainer.fit(new_model, real_data_module)
test_results = trainer.test(new_model, real_data_module)
print(test_results)
If we wanted to do the same thing with our more complex CustomizableModel (which lets you configure many parameters, such as the number of hidden layers, the dropout rate, and the choice of optimizer), the code would be identical to the above, except that we would initialize a CustomizableModel instead of a SimpleModel. See the documentation in customizable_model.py for more details.
Checkpointing & Logging
PyTorch Lightning lets us define checkpoints and loggers to use during training. Checkpoint callbacks save snapshots of the model as it trains. In the following code, we define one checkpoint that saves the model’s state from the point in training where it achieved the lowest mean squared error on the validation set, and another that periodically saves the model every 2 training epochs. Checkpoints are useful because they can be reloaded later: you can continue training a model from a checkpoint, or test the checkpointed model on new data.
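The two checkpoint rules can be mimicked with a small toy function: given the validation MSE recorded at the end of each epoch, it reports which epochs a "best model" callback (`mode="min"`, `save_top_k`) and a periodic callback (`every_n_epochs`) would keep. This is a simplification for intuition only. `simulate_checkpoints` is hypothetical, and the `(epoch + 1) % every_n_epochs == 0` rule with 0-indexed epochs reflects our understanding of Lightning's behavior, so check the ModelCheckpoint documentation for your version.

```python
def simulate_checkpoints(val_mse_per_epoch, every_n_epochs=2, top_k=1):
    """Toy model of which epochs two ModelCheckpoint-style callbacks would keep.

    val_mse_per_epoch[e] is the validation MSE at the end of (0-indexed) epoch e.
    Returns (best_epochs, periodic_epochs). Hypothetical helper, not Lightning code.
    """
    epochs = range(len(val_mse_per_epoch))
    # "best model" rule (mode="min"): keep the top_k epochs with the lowest metric;
    # top_k=-1 means keep all, and sorted() is stable, so on ties the earlier epoch wins
    ranked = sorted(epochs, key=lambda e: val_mse_per_epoch[e])
    keep = len(ranked) if top_k == -1 else top_k
    best_epochs = sorted(ranked[:keep])
    # periodic rule: save whenever (epoch + 1) is a multiple of every_n_epochs
    periodic_epochs = [e for e in epochs if (e + 1) % every_n_epochs == 0]
    return best_epochs, periodic_epochs
```

For instance, with made-up validation MSEs of 0.9, 0.5, 0.6, 0.4, 0.45, 0.7 over six epochs, the best checkpoint comes from epoch 3 and the periodic checkpoints from epochs 1, 3, and 5.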
Loggers save metrics about the model as it trains so that we can inspect them later. We define two loggers below to track these metrics. See the comments above the TensorBoard logger for how to use TensorBoard to visualize the metrics as the model trains.
To use checkpoints and loggers, we pass them to the Trainer object that we use to train the model on a DataModule.
There are many more types of checkpoints and loggers you can create and use; PyTorch Lightning’s documentation is very helpful here.
# this will be used to save the model checkpoint that performs the best on the validation set
best_model_checkpoint = ModelCheckpoint(
monitor="val_mse", # we can monitor any metric that the model logs
mode="min", # lower val_mse is better
filename="best-model-{epoch:02d}-{val_mse:.2f}", # name the file after the monitored metric
save_top_k=1, # we can save more than just the top model if we want
)
# Callback to save checkpoints every 2 epochs, regardless of model performance
periodic_checkpoint = ModelCheckpoint(
filename="periodic-{epoch:02d}",
every_n_epochs=2,
save_top_k=-1, # Setting -1 saves all checkpoints
)
# csv logger is a very basic logger that will create a csv file with our metrics as we train
csv_logger = CSVLogger("logs/csv_logs") # we define the directory we want the logs to be saved in
# the TensorBoard logger is a more advanced logger that writes a directory of files that can be visualized with TensorBoard
# TensorBoard is run from the command line; it starts a local server that you can open in a web browser
# to view the training metrics interactively on a dashboard
# you can start it by running `tensorboard --logdir=path/to/log/dir` in the terminal (here, logs/tensorboard_logs)
tb_logger = TensorBoardLogger("logs/tensorboard_logs", name="test-run-2")
# If we wanted to use these checkpoints and loggers, we would pass them to the trainer like so:
trainer_with_checkpoints_and_loggers = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu",
callbacks=[best_model_checkpoint, periodic_checkpoint],
logger=[csv_logger, tb_logger],
)
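Once a run has finished, the CSV logger's output can be inspected with the standard library. With default settings, `CSVLogger("logs/csv_logs")` typically writes the first run to `logs/csv_logs/lightning_logs/version_0/metrics.csv` (the version number increases with each run). Each row corresponds to one logging call, so metrics logged at different times leave blank cells. The sketch below finds the epoch with the best value of a metric; `best_epoch` and `best_epoch_from_csv` are hypothetical helpers, and they assume the file has an `epoch` column and a `val_mse` column, as it should if the model logs `val_mse`.

```python
import csv
import math


def best_epoch(rows, metric="val_mse", mode="min"):
    """Return (epoch, value) for the best value of `metric` across CSV rows (dicts of strings)."""
    best = None
    for row in rows:
        raw = row.get(metric) or ""
        if raw == "":
            continue  # this row logged other metrics, not `metric`
        value = float(raw)
        if math.isnan(value):
            continue  # ignore NaN entries
        if best is None or (value < best[1] if mode == "min" else value > best[1]):
            epoch = row.get("epoch") or ""
            best = (int(float(epoch)) if epoch else None, value)
    if best is None:
        raise ValueError(f"no values logged for {metric!r}")
    return best


def best_epoch_from_csv(metrics_path, metric="val_mse", mode="min"):
    """Read a CSVLogger metrics.csv file and return best_epoch(...) for its rows."""
    with open(metrics_path, newline="") as f:
        return best_epoch(csv.DictReader(f), metric, mode)
```

After training with `trainer_with_checkpoints_and_loggers`, a call such as `best_epoch_from_csv("logs/csv_logs/lightning_logs/version_0/metrics.csv")` should point to the same epoch that the best-model checkpoint saved.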
Loading in and using a Checkpoint
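# Load a model from a checkpoint
import torch

path_to_checkpoint = "example/path/not/real.ckpt"
# note that we need to use the same model class that was used to save the checkpoint
model = SimpleModel.load_from_checkpoint(path_to_checkpoint)
# to continue training where it left off (including the optimizer state and epoch counter),
# pass ckpt_path to fit on a fresh Trainer whose max_epochs exceeds the checkpoint's epoch
resume_trainer = Trainer(
max_epochs=20,
deterministic=True,
accelerator="cpu",
)
resume_trainer.fit(model, data_module, ckpt_path=path_to_checkpoint)
# we could also test the loaded model
test_results = resume_trainer.test(model, data_module)
# or make predictions: the model's forward pass expects a batch of inputs, not a DataLoader
# (this assumes each batch is an (inputs, targets) pair)
model.eval()
with torch.no_grad():
    predictions = [model(x) for x, _ in data_module.test_dataloader()]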