Lightning Crash Course¶
This project uses the PyTorch Lightning library to define and train its machine learning models. PyTorch Lightning is built on top of PyTorch and abstracts away much of the setup and boilerplate that models require (such as writing out training loops). In this notebook, we provide a brief introduction to using the models and DataModules we've defined to train models.
# imports
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from yeastdnnexplorer.data_loaders.synthetic_data_loader import SyntheticDataLoader
from yeastdnnexplorer.data_loaders.real_data_loader import RealDataLoader
from yeastdnnexplorer.ml_models.simple_model import SimpleModel
from yeastdnnexplorer.ml_models.customizable_model import CustomizableModel
In PyTorch Lightning, the data is kept completely separate from the models. This allows you to easily train a model on different datasets, or train different models on the same dataset. DataModules encapsulate all the logic for loading a specific dataset and splitting it into training, validation, and testing sets. In this project, we have two data loaders defined: SyntheticDataLoader for the in silico data (which takes many parameters that let you specify how the data is generated) and RealDataLoader, which contains all of the logic for loading the real experiment data and putting it into the form the models expect.
Once you decide which model you want to train and which DataModule you want to use, you can bundle these with a Trainer object to train the model on the dataset.
If you'd like to learn more about the models and DataModules we've defined, there is extensive documentation in each file explaining the purpose of each method.
# define an instance of our simple linear baseline model
model = SimpleModel(
input_dim=10,
output_dim=10,
lr=1e-2,
)
# define an instance of the synthetic data loader
# see the constructor for the full list of params and their explanations
data_module = SyntheticDataLoader(
batch_size=32,
num_genes=3000,
bound=[0.5] * 5,
n_sample=[1, 1, 2, 2, 4],
val_size=0.1,
test_size=0.1,
bound_mean=3.0,
)
# define a trainer instance
trainer = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu", # change to "gpu" if you have access to one
)
# train the model
trainer.fit(model, data_module)
# test the model (recall that data_module specifies the train / test split; we don't need to do it explicitly here)
test_results = trainer.test(model, data_module)
print(test_results)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/ericjia/yeastdnnexplorer/docs/tutorials/lightning_logs
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:260: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_train, Y_train = torch.tensor(X_train, dtype=torch.float32), torch.tensor(
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:263: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_val, Y_val = torch.tensor(X_val, dtype=torch.float32), torch.tensor(
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:266: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_test, Y_test = torch.tensor(X_test, dtype=torch.float32), torch.tensor(

  | Name    | Type              | Params
----------------------------------------------
0 | mae     | MeanAbsoluteError | 0
1 | SMSE    | SMSE              | 0
2 | linear1 | Linear            | 110
----------------------------------------------
110       Trainable params
0         Non-trainable params
110       Total params
0.000     Total estimated model params size (MB)
/Users/ericjia/Library/Caches/pypoetry/virtualenvs/yeastdnnexplorer-iu4_cpc2-py3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 15 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
`Trainer.fit` stopped: `max_epochs=10` reached.
────────────────────────────────────────────────────────────
       Test metric                   DataLoader 0
────────────────────────────────────────────────────────────
        test_mae                  1.1637259721755981
        test_mse                  1.8661913871765137
        test_smse                 10.101052284240723
────────────────────────────────────────────────────────────
[{'test_mse': 1.8661913871765137, 'test_mae': 1.1637259721755981, 'test_smse': 10.101052284240723}]
It's very easy to train the same model on a different dataset: for example, to use real-world data, we can simply swap in the data module we've defined for it.
# we need to define a new model instance with the same params unless we want training to pick up where it left off
new_model = SimpleModel(
input_dim=30, # note that the input and output dims are equal to the num TFs in the dataset
output_dim=30,
lr=1e-2,
)
real_data_module = RealDataLoader(
batch_size=32,
val_size=0.1,
test_size=0.1,
data_dir_path="../../data/init_analysis_data_20240409/", # note that this is relative to where real_data_loader.py is
perturbation_dataset_title="hu_reimann_tfko",
)
# we also define a fresh Trainer instance; a Trainer is designed to manage a single training run, so it's safest not to reuse the old one
trainer = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu", # change to "gpu" if you have access to one
)
trainer.fit(new_model, real_data_module)
test_results = trainer.test(new_model, real_data_module)
print(test_results)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[3], line 23
---> 23 trainer.fit(new_model, real_data_module)
...
File ~/yeastdnnexplorer/yeastdnnexplorer/data_loaders/real_data_loader.py:118, in RealDataLoader.prepare_data(self)
    116 brent_cc_path = os.path.join(self.data_dir_path, "binding/brent_nf_cc")
    117 brent_nf_csv_files = [
--> 118     f for f in os.listdir(brent_cc_path) if f.endswith(".csv")
    119 ]
FileNotFoundError: [Errno 2] No such file or directory: '../../data/init_analysis_data_20240409/binding/brent_nf_cc'
If we wanted to do the same thing with our more complex and customizable CustomizableModel (which allows you to pass in many params, like the number of hidden layers, dropout rate, and choice of optimizer), the code would look identical to the above, except that we would initialize a CustomizableModel instead of a SimpleModel. See the documentation in customizable_model.py for more details.
Checkpointing & Logging¶
PyTorch Lightning gives us the power to define checkpoints and loggers to use during training. Checkpoint callbacks save snapshots of your model as it trains. In the following code, we define one checkpoint that saves the model state that achieved the lowest mean squared error on the validation set, and another that saves the model after every 2 training epochs regardless of performance. These checkpoints are powerful because they can be reloaded later: you can continue training a model from a checkpoint, or test that checkpoint on new data.
Loggers are responsible for recording metrics about the model as it trains so we can inspect them later. We define several loggers to track this data. See the comments above the TensorBoard logger for how to use TensorBoard to visualize the metrics while the model trains.
To use checkpoints and loggers, we pass them into the Trainer object that trains the model with a DataModule.
There are many more types of checkpoints and loggers you can create and use; PyTorch Lightning's documentation is very helpful here.
# this will be used to save the model checkpoint that performs the best on the validation set
best_model_checkpoint = ModelCheckpoint(
monitor="val_mse", # we can depend on any metric we want
mode="min",
filename="best-model-{epoch:02d}-{val_mse:.2f}", # filename placeholders are filled from logged metrics, so we use the monitored val_mse
save_top_k=1, # we can save more than just the top model if we want
)
# Callback to save checkpoints every 2 epochs, regardless of model performance
periodic_checkpoint = ModelCheckpoint(
filename="periodic-{epoch:02d}",
every_n_epochs=2,
save_top_k=-1, # Setting -1 saves all checkpoints
)
# csv logger is a very basic logger that will create a csv file with our metrics as we train
csv_logger = CSVLogger("logs/csv_logs") # we define the directory we want the logs to be saved in
# tensorboard logger is a more advanced logger that creates a directory of files that can be visualized with tensorboard
# tensorboard is a tool that can be run from the command line; it starts a local server, accessible via a web browser,
# that displays the training metrics in a more interactive way (on a dashboard)
# you can run tensorboard with the command `tensorboard --logdir=path/to/log/dir` in the terminal
tb_logger = TensorBoardLogger("logs/tensorboard_logs", name="test-run-2")
# If we wanted to use these checkpoints and loggers, we would pass them to the trainer like so:
trainer_with_checkpoints_and_loggers = Trainer(
max_epochs=10,
deterministic=True,
accelerator="cpu",
callbacks=[best_model_checkpoint, periodic_checkpoint],
logger=[csv_logger, tb_logger],
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
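With the callbacks and loggers attached, training proceeds exactly as before; afterwards, the ModelCheckpoint callback records where its best checkpoint was written. A minimal sketch, reusing the SimpleModel and synthetic data module from earlier:
# train with checkpointing and logging enabled
model = SimpleModel(input_dim=10, output_dim=10, lr=1e-2)
trainer_with_checkpoints_and_loggers.fit(model, data_module)

# the callback exposes the path of the best checkpoint it saved,
# which can be passed to load_from_checkpoint later
print(best_model_checkpoint.best_model_path)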
Loading in and using a Checkpoint¶
# we can load a model from a checkpoint like so:
path_to_checkpoint = "example/path/not/real.ckpt"
# note that we need to use the same model class that was used to save the checkpoint
model = SimpleModel.load_from_checkpoint(path_to_checkpoint)
# we can load the model and continue training from where it left off
trainer.fit(model, data_module)
# we could also load the model and test it
test_results = trainer.test(model, data_module)
# we could also load the model and make predictions
# note: the model's forward expects a tensor of features, not a dataloader,
# so we pull out a batch first (assuming batches are (features, targets) tuples)
x, _ = next(iter(data_module.test_dataloader()))
predictions = model(x)
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[5], line 6
----> 6 model = SimpleModel.load_from_checkpoint(path_to_checkpoint)
...
FileNotFoundError: [Errno 2] No such file or directory: '/Users/ericjia/yeastdnnexplorer/docs/tutorials/example/path/not/real.ckpt'