Lightning Crash Course

This project uses the PyTorch Lightning library to define and train its machine learning models. PyTorch Lightning is built on top of PyTorch and abstracts away much of the setup and boilerplate code that models require (such as writing out training loops). In this notebook, we briefly introduce how to use the models and DataModules we've defined to train models.

# imports
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from yeastdnnexplorer.data_loaders.synthetic_data_loader import SyntheticDataLoader
from yeastdnnexplorer.data_loaders.real_data_loader import RealDataLoader
from yeastdnnexplorer.ml_models.simple_model import SimpleModel
from yeastdnnexplorer.ml_models.customizable_model import CustomizableModel

In PyTorch Lightning, the data is kept completely separate from the models. This makes it easy to train one model on different datasets, or to train different models on the same dataset. DataModules encapsulate all of the logic for loading a specific dataset and splitting it into training, validation, and test sets. This project defines two DataModules: SyntheticDataLoader for the in silico data (its many parameters let you specify how the data is generated) and RealDataLoader, which contains all of the logic for loading the real experimental data and putting it into the form the models expect.

Once you have decided which model to train and which DataModule to use, pass both to a Trainer object, which trains the model on the dataset.

If you'd like to learn more about the models and DataModules we've defined, each file contains extensive documentation explaining the purpose of every method.

# define an instance of our simple linear baseline model
model = SimpleModel(
    input_dim=10,
    output_dim=10,
    lr=1e-2,
)

# define an instance of the synthetic data loader
# see the constructor for the full list of params and their explanations
data_module = SyntheticDataLoader(
    batch_size=32,
    num_genes=3000,
    bound=[0.5] * 5,
    n_sample=[1, 1, 2, 2, 4],
    val_size=0.1,
    test_size=0.1,
    bound_mean=3.0,
)

# define a trainer instance
trainer = Trainer(
    max_epochs=10,
    deterministic=True,
    accelerator="cpu", # change to "gpu" if you have access to one
)

# train the model
trainer.fit(model, data_module)

# test the model (data_module already defines the train/test split, so we don't need to do it explicitly here)
test_results = trainer.test(model, data_module)
print(test_results)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/ericjia/yeastdnnexplorer/docs/tutorials/lightning_logs
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:260: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_train, Y_train = torch.tensor(X_train, dtype=torch.float32), torch.tensor(
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:263: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_val, Y_val = torch.tensor(X_val, dtype=torch.float32), torch.tensor(
/Users/ericjia/yeastdnnexplorer/yeastdnnexplorer/data_loaders/synthetic_data_loader.py:266: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  X_test, Y_test = torch.tensor(X_test, dtype=torch.float32), torch.tensor(

  | Name    | Type              | Params
----------------------------------------------
0 | mae     | MeanAbsoluteError | 0     
1 | SMSE    | SMSE              | 0     
2 | linear1 | Linear            | 110   
----------------------------------------------
110       Trainable params
0         Non-trainable params
110       Total params
0.000     Total estimated model params size (MB)

/Users/ericjia/Library/Caches/pypoetry/virtualenvs/yeastdnnexplorer-iu4_cpc2-py3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 15 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(

`Trainer.fit` stopped: `max_epochs=10` reached.

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_mae            1.1637259721755981
        test_mse            1.8661913871765137
        test_smse           10.101052284240723
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_mse': 1.8661913871765137, 'test_mae': 1.1637259721755981, 'test_smse': 10.101052284240723}]

It's very easy to train the same model on a different dataset. For example, to use real-world data, we just swap in the DataModule we've defined for the real data.

# create a fresh model instance; reusing `model` would pick up from its already-trained weights
new_model = SimpleModel(
    input_dim=30,  # note that the input and output dims must equal the number of TFs in the dataset
    output_dim=30,
    lr=1e-2,
)

real_data_module = RealDataLoader(
    batch_size=32,
    val_size=0.1,
    test_size=0.1,
    data_dir_path="../../data/init_analysis_data_20240409/", # relative paths are resolved from the current working directory (here, the notebook's directory)
    perturbation_dataset_title="hu_reimann_tfko",
)

# use a fresh Trainer: a Trainer keeps its loop state (such as the epoch counter) between fit calls,
# so reusing one that has already reached max_epochs may stop training immediately
trainer = Trainer(
    max_epochs=10,
    deterministic=True,
    accelerator="cpu", # change to "gpu" if you have access to one
)

trainer.fit(new_model, real_data_module)
test_results = trainer.test(new_model, real_data_module)
print(test_results)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
File ~/yeastdnnexplorer/yeastdnnexplorer/data_loaders/real_data_loader.py:118, in RealDataLoader.prepare_data(self)
    106 """
    107 This function reads in the binding data and perturbation data from the CSV files
    108 that we have for these datasets.
   (...)
    113 
    114 """
    116 brent_cc_path = os.path.join(self.data_dir_path, "binding/brent_nf_cc")
    117 brent_nf_csv_files = [
--> 118     f for f in os.listdir(brent_cc_path) if f.endswith(".csv")
    119 ]
    120 perturb_dataset_path = os.path.join(
    121     self.data_dir_path, f"perturbation/{self.perturbation_dataset_title}"
    122 )
    123 perturb_dataset_csv_files = [
    124     f for f in os.listdir(perturb_dataset_path) if f.endswith(".csv")
    125 ]

FileNotFoundError: [Errno 2] No such file or directory: '../../data/init_analysis_data_20240409/binding/brent_nf_cc'

If we wanted to do the same thing with our more complex CustomizableModel (which lets you set many parameters, such as the number of hidden layers, the dropout rate, and the choice of optimizer), the code would be identical to the above, except that we would initialize a CustomizableModel instead of a SimpleModel. See the documentation in customizable_model.py for more details.

Checkpointing & Logging

PyTorch Lightning lets us define checkpoints and loggers to be used during training. Checkpoint callbacks save snapshots of your model's state during training. In the following code, we define one checkpoint that saves the model's state at the point where it achieved the lowest mean squared error on the validation set, and another that saves a checkpoint every 2 training epochs. Checkpoints are useful because they can be reloaded later: you can continue training a model from its checkpoint, or evaluate the checkpointed model on new data.

Loggers save metrics about the model as it trains, so that we can inspect them later. We define two loggers below to track this data. See the comments above the TensorBoard logger for how to use TensorBoard to visualize the metrics as the model trains.

To use checkpoints and loggers, we pass them to the Trainer object that we use to train the model with a DataModule.

There are many more types of checkpoints and loggers you can create and use; PyTorch Lightning's documentation is very helpful here.

# this will be used to save the model checkpoint that performs the best on the validation set
best_model_checkpoint = ModelCheckpoint(
    monitor="val_mse",  # we can monitor any metric the model logs
    mode="min",
    filename="best-model-{epoch:02d}-{val_mse:.2f}",  # include the monitored metric in the filename
    save_top_k=1, # we can save more than just the top model if we want
)

# Callback to save checkpoints every 2 epochs, regardless of model performance
periodic_checkpoint = ModelCheckpoint(
    filename="periodic-{epoch:02d}",
    every_n_epochs=2,
    save_top_k=-1,  # Setting -1 saves all checkpoints  
)

# csv logger is a very basic logger that will create a csv file with our metrics as we train
csv_logger = CSVLogger("logs/csv_logs")  # we define the directory we want the logs to be saved in

# the TensorBoard logger is a more advanced logger that writes a directory of files that can be visualized with TensorBoard
# TensorBoard is a library that can be run from the command line; it starts a local server, accessible via a web browser,
# that displays the training metrics interactively (on a dashboard)
# you can start it by running `tensorboard --logdir=path/to/log/dir` in the terminal
tb_logger = TensorBoardLogger("logs/tensorboard_logs", name="test-run-2")

# If we wanted to use these checkpoints and loggers, we would pass them to the trainer like so:
trainer_with_checkpoints_and_loggers = Trainer(
    max_epochs=10,
    deterministic=True,
    accelerator="cpu",
    callbacks=[best_model_checkpoint, periodic_checkpoint],
    logger=[csv_logger, tb_logger],
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

Loading in and using a Checkpoint

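import torch

# Load a model from a checkpoint
# path_to_checkpoint is a placeholder; replace it with the path to a real .ckpt file
path_to_checkpoint = "example/path/not/real.ckpt"

# note that we need to use the same model class that was used to save the checkpoint
model = SimpleModel.load_from_checkpoint(path_to_checkpoint)

# to continue training where it left off (optimizer state, epoch count, etc.), also pass the
# checkpoint path to fit; max_epochs must be larger than the number of epochs already completed
resume_trainer = Trainer(max_epochs=20, deterministic=True, accelerator="cpu")
resume_trainer.fit(model, data_module, ckpt_path=path_to_checkpoint)

# we could also test the loaded model
test_results = resume_trainer.test(model, data_module)

# or make predictions; the model's forward pass expects a tensor rather than a DataLoader,
# so we loop over the test batches (assuming each batch is an (inputs, targets) pair)
model.eval()
with torch.no_grad():
    predictions = [model(x) for x, _ in data_module.test_dataloader()]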
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/ericjia/yeastdnnexplorer/docs/tutorials/example/path/not/real.ckpt'