How-to Guides
In this section, you can quickly look up how to implement a subclass of each base class in Lightning-Boost and how to use its features.
Directory Structure
The recommended directory structure for a Lightning-Boost project is as follows:
.
├── data/
│ ├── datamodules/
│ └── datasets/
├── models/
├── modules/
│ ├── loss/
│ ├── metrics/
│ ├── preprocessing/
│ └── trainable/
└── systems/
The purpose of most directories should be self-explanatory from their names.
In particular, we distinguish between two pairs of components.
Loss functions and metrics differ in whether they are included in the computation of gradients.
Models and systems differ in their role: models are used as functions with well-defined inputs and outputs, while systems are instances that manage the entire training process for one or more such models (see Explanation/Models vs. Systems).
Like metrics, components in the preprocessing directory should not contribute to the computation of gradients; instead, they transform the input and target data before it enters the model(s).
Modules in the trainable directory, by contrast, are intended as recurring lower-level building blocks of models.
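If you are starting a new project, a few lines of standard-library Python can create the layout above. The following is a minimal sketch: the helper name scaffold() is hypothetical. The empty __init__.py files make each directory importable as a package, which the __init__.py-based imports described later rely on.

```python
import tempfile
from pathlib import Path
from typing import List

# Directories of the recommended layout, relative to the project root.
PACKAGES: List[str] = [
    "data",
    "data/datamodules",
    "data/datasets",
    "models",
    "modules",
    "modules/loss",
    "modules/metrics",
    "modules/preprocessing",
    "modules/trainable",
    "systems",
]


def scaffold(root: str) -> None:
    """Create each directory with an empty __init__.py (existing files keep their contents)."""
    for package in PACKAGES:
        directory = Path(root) / package
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "__init__.py").touch(exist_ok=True)


# Usage: scaffold into a temporary directory and collect the created package files.
with tempfile.TemporaryDirectory() as tmp:
    scaffold(tmp)
    created = sorted(str(p.relative_to(tmp)) for p in Path(tmp).rglob("__init__.py"))
    print(len(created))  # 10
```

Running scaffold() on an existing project is harmless: existing directories are reused and the contents of existing __init__.py files are left untouched.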
Dataset
Base class: lightning_boost.data.datasets.BaseDataset
Mandatory methods
__init__(self, root: str, download: bool = False, transform: Optional[BaseTransform] = None, **kwargs) -> None
- Call the superclass method first.
- Load the dataset from disk, stored at self.path.
get_item(self, index: int) -> Tuple[Dict[str, Any], Dict[str, Any]]
- Return the input and target data (as dictionaries) at index.
__len__(self) -> int
- Return the dataset size.
Optional methods
download(self) -> None
- Download the dataset from the internet.
- Store it on disk at self.path.
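To make the method contract above concrete, here is a minimal sketch of a dataset subclass that stores toy (x, y) pairs as JSON. The class name ToyRegressionDataset, the file name data.json and the offline download() body are hypothetical. A tiny stand-in replaces BaseDataset so that the sketch runs without Lightning-Boost installed. The stand-in sets self.path to root and calls download() when download=True. Both behaviors are assumptions about the real class, not its documented behavior. In a project, delete the stand-in and import BaseDataset from lightning_boost.data.datasets. In the real signature, transform is typed as Optional[BaseTransform].

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional, Tuple


# Hypothetical stand-in so the sketch runs on its own. In a project, delete it and use:
#   from lightning_boost.data.datasets import BaseDataset
class BaseDataset:
    def __init__(self, root: str, download: bool = False, transform: Optional[Any] = None, **kwargs: Any) -> None:
        self.path = root  # assumption: the real class may derive self.path differently
        self.transform = transform
        if download:  # assumption about what the download flag triggers
            self.download()

    def download(self) -> None:
        pass


class ToyRegressionDataset(BaseDataset):
    """Hypothetical dataset of (x, y) pairs stored as JSON at self.path/data.json."""

    def __init__(self, root: str, download: bool = False, transform: Optional[Any] = None, **kwargs: Any) -> None:
        # Call the superclass method first ...
        super().__init__(root, download=download, transform=transform, **kwargs)
        # ... then load the dataset from disk, stored at self.path.
        with open(os.path.join(self.path, "data.json")) as f:
            self.samples = json.load(f)

    def download(self) -> None:
        # Offline placeholder for a real download: writes five toy samples to self.path.
        os.makedirs(self.path, exist_ok=True)
        samples = [{"x": float(i), "y": 2.0 * i} for i in range(5)]
        with open(os.path.join(self.path, "data.json"), "w") as f:
            json.dump(samples, f)

    def get_item(self, index: int) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        sample = self.samples[index]
        # Inputs and targets are returned as two separate dictionaries.
        return {"x": sample["x"]}, {"y": sample["y"]}

    def __len__(self) -> int:
        return len(self.samples)


with tempfile.TemporaryDirectory() as tmp:
    dataset = ToyRegressionDataset(root=tmp, download=True)
    print(len(dataset), dataset.get_item(3))  # 5 ({'x': 3.0}, {'y': 6.0})
```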
Transform
Base class: lightning_boost.modules.preprocessing.BaseTransform
Mandatory methods
__call__(self, inputs: Dict[str, Any], targets: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]
- Perform transforms on the elements of the inputs and targets dictionaries.
- Return two dictionaries again; their keys can differ from those of the received dictionaries.
Implementation tips
- Define a subclass of BaseTransform for each subclass of BaseDataset.
- To implement a composition of multiple transforms, use the class lightning_boost.modules.preprocessing.CompositeTransform.
- Embed the CompositeTransform instance as an attribute in another BaseTransform subclass, then invoke its __call__() method in the __call__() method of the latter.
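The sketch below illustrates two things: the dictionaries-in, dictionaries-out contract of __call__(), and the tip of embedding a composition as an attribute and delegating to it. Several parts are stand-ins. BaseTransform is replaced by a local class so the code runs on its own. Compose is a hypothetical stand-in for CompositeTransform, whose constructor signature is not documented in this section. The transform classes Scale, AddSquare and ToyTransform are hypothetical as well.

```python
from typing import Any, Dict, List, Tuple

Data = Dict[str, Any]


# Stand-in so the sketch runs on its own; in a project, use
#   from lightning_boost.modules.preprocessing import BaseTransform
class BaseTransform:
    def __call__(self, inputs: Data, targets: Data) -> Tuple[Data, Data]:
        raise NotImplementedError


class Scale(BaseTransform):
    """Multiply input "x" by a constant factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, inputs: Data, targets: Data) -> Tuple[Data, Data]:
        return {**inputs, "x": inputs["x"] * self.factor}, targets


class AddSquare(BaseTransform):
    """Add a new key: the output keys may differ from the received ones."""

    def __call__(self, inputs: Data, targets: Data) -> Tuple[Data, Data]:
        return {**inputs, "x_squared": inputs["x"] ** 2}, targets


class Compose(BaseTransform):
    """Hypothetical stand-in for CompositeTransform: applies transforms in order."""

    def __init__(self, transforms: List[BaseTransform]) -> None:
        self.transforms = transforms

    def __call__(self, inputs: Data, targets: Data) -> Tuple[Data, Data]:
        for transform in self.transforms:
            inputs, targets = transform(inputs, targets)
        return inputs, targets


class ToyTransform(BaseTransform):
    """Per-dataset transform: embeds the composition and delegates to its __call__()."""

    def __init__(self) -> None:
        self.composite = Compose([Scale(0.5), AddSquare()])

    def __call__(self, inputs: Data, targets: Data) -> Tuple[Data, Data]:
        return self.composite(inputs, targets)


print(ToyTransform()({"x": 4.0}, {"y": 8.0}))
# ({'x': 2.0, 'x_squared': 4.0}, {'y': 8.0})
```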
Collator
Base class: lightning_boost.modules.preprocessing.BaseCollator
Mandatory methods
get_collate_fn(self) -> Dict[str, Callable[[List[Tensor]], Tensor]]
- Return a dictionary of collate functions, one per key/data type in the transform's output.
- If possible, use the predefined collate functions pad_collate_nd() (stacks n-dimensional, potentially differently shaped tensors along a new dimension, using padding) and flatten_collate() (concatenates zero-dimensional tensors/scalars into a one-dimensional tensor).
- For pad_collate_nd(), you can specify custom padding values and padding shapes through the initialization parameters pad_val, pad_shape and pad_dim of BaseCollator.
Implementation tips
- Define a subclass of BaseCollator for each subclass of BaseDataset.
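To illustrate what the two predefined collate functions do, the following pure-Python sketch uses lists in place of tensors and covers only the one-dimensional case of padding. The functions pad_collate_1d() and flatten_collate_scalars(), the class ToyCollator and the helper collate() are hypothetical analogues, not the library's pad_collate_nd() and flatten_collate(). The collate() helper merely assumes that the dictionary returned by get_collate_fn() is applied key by key to a list of samples.

```python
from typing import Any, Callable, Dict, List, Sequence

CollateFn = Callable[[List[Any]], Any]

# Pure-Python analogues for illustration only (lists instead of tensors);
# these are NOT the library's pad_collate_nd() / flatten_collate().


def pad_collate_1d(batch: List[Sequence[float]], pad_val: float = 0.0) -> List[List[float]]:
    """Stack variable-length sequences along a new batch dimension, right-padding with pad_val."""
    max_len = max(len(seq) for seq in batch)
    return [list(seq) + [pad_val] * (max_len - len(seq)) for seq in batch]


def flatten_collate_scalars(batch: List[float]) -> List[float]:
    """Concatenate scalars into a single one-dimensional list."""
    return list(batch)


class ToyCollator:
    """Hypothetical collator mirroring get_collate_fn(): one collate function per key."""

    def get_collate_fn(self) -> Dict[str, CollateFn]:
        return {"tokens": pad_collate_1d, "label": flatten_collate_scalars}


def collate(samples: List[Dict[str, Any]], collate_fns: Dict[str, CollateFn]) -> Dict[str, Any]:
    """Assumed usage: gather each key's values across samples, then apply that key's function."""
    return {key: fn([sample[key] for sample in samples]) for key, fn in collate_fns.items()}


samples = [{"tokens": [1, 2, 3], "label": 0}, {"tokens": [4], "label": 1}]
print(collate(samples, ToyCollator().get_collate_fn()))
# {'tokens': [[1, 2, 3], [4, 0.0, 0.0]], 'label': [0, 1]}
```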
Datamodule
Base class: lightning_boost.data.datamodules.BaseDatamodule
Mandatory methods
get_collator(self, **kwargs) -> BaseCollator
- Return the collator instance for the dataset used.
get_dataset_type(self, **kwargs) -> Type[BaseDataset]
- Return the type (not an instance!) of the dataset used.
get_transform(self, **kwargs) -> BaseTransform
- Return the transform instance for the dataset used.
get_train_test_split(self) -> Tuple[BaseDataset, BaseDataset]
- Return the training-test split of the dataset used.
- Use the attribute test_ratio to build the split based on predefined ratios.
- Use the method instantiate_dataset() to get an instance of the dataset used without explicitly passing the BaseDataset parameters root, download and transform.
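The core of a ratio-based split can be sketched as follows: shuffle the indices with a fixed seed and hold out a test_ratio fraction of them. The helper ratio_split() and its seed parameter are hypothetical. How BaseDatamodule uses test_ratio internally is not described in this section.

```python
import random
from typing import List, Tuple


def ratio_split(n: int, test_ratio: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle indices 0..n-1 with a fixed seed and hold out round(n * test_ratio) as test indices."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_test = round(n * test_ratio)
    return indices[n_test:], indices[:n_test]  # (train, test)


train_idx, test_idx = ratio_split(10, test_ratio=0.2)
print(len(train_idx), len(test_idx))  # 8 2
```

Using a dedicated random.Random instance keeps the split reproducible without touching Python's global random state.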
Optional methods
get_train_val_split(self) -> Tuple[BaseDataset, BaseDataset]
- Return the training-validation split of the dataset used.
- By default, uses the attribute test_ratio to build the split based on predefined ratios.
get_cv_train_val_split(self) -> Tuple[BaseDataset, BaseDataset]
- Return the training-validation split of the dataset used when performing cross-validation.
- By default, splits the dataset into k parts and selects (k-1)/k of it as the training split and 1/k as the validation split.
determine_cv_indices(self) -> None
- Determine the permutation of data indices for cross-validation.
- By default, generates a random permutation using the RNG seed fold_seed.
determine_fold_len(self) -> None
- Determine the fold sizes for cross-validation.
- By default, uses equisized folds.
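The default cross-validation behavior described above can be sketched in plain Python: a seeded random permutation, (nearly) equisized folds, and one fold held out for validation. The function cv_fold_split() and its parameters num_folds and fold_index are hypothetical; fold_seed mirrors the attribute named above. When n is not divisible by k, this sketch gives the first n mod k folds one extra element, which may differ from the library's choice.

```python
import random
from typing import List, Tuple


def cv_fold_split(n: int, num_folds: int, fold_index: int, fold_seed: int = 0) -> Tuple[List[int], List[int]]:
    """Return (train, val) indices for fold `fold_index` of k-fold cross-validation."""
    # Random permutation of the data indices, seeded by fold_seed.
    perm = list(range(n))
    random.Random(fold_seed).shuffle(perm)
    # (Nearly) equisized folds: the first n % k folds get one extra element.
    base, extra = divmod(n, num_folds)
    fold_lens = [base + (1 if i < extra else 0) for i in range(num_folds)]
    start = sum(fold_lens[:fold_index])
    end = start + fold_lens[fold_index]
    val = perm[start:end]  # about 1/k of the data
    train = perm[:start] + perm[end:]  # about (k-1)/k of the data
    return train, val


for k in range(3):
    train, val = cv_fold_split(10, num_folds=3, fold_index=k)
    print(k, len(train), len(val))
# 0 6 4
# 1 7 3
# 2 7 3
```

Because every fold uses the same seeded permutation, the validation folds are disjoint and together cover the whole dataset.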
Model
Base class: lightning_boost.models.BaseModel
Mandatory methods
__init__(self, name: str | None = None) -> None
- Call the superclass method first.
- Define the model's submodules.
- If multiple instances of the same model class are used in the system, explicitly pass a unique value for the parameter name to each instance.
forward(self, *args: Tensor) -> Tensor | Sequence[Tensor]
- Perform the forward pass by feeding the input tensors into the model's submodules; return the output tensor(s).
System
Base class: lightning_boost.systems.BaseSystem
Mandatory methods
step(self, inputs: Dict[str, Tensor], targets: Dict[str, Tensor]) -> Dict[str, Tensor]
- Extract the input tensors from the inputs dictionary.
- Feed the input tensors into the model(s) and receive the prediction tensor(s).
- Return a dictionary of prediction tensor(s) whose keys correspond to tasks (they must match the keys of the target data dictionary!).
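A minimal sketch of the three steps of step() follows. To keep it runnable without PyTorch or Lightning-Boost, ToySystem is a hypothetical stand-in rather than a BaseSystem subclass. Lists stand in for tensors, a plain function stands in for the model, and the keys "x" and "y" are assumptions. The explicit key check only makes the matching requirement above visible; it is not taken from the library.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]


class ToySystem:
    """Hypothetical stand-in mirroring BaseSystem.step() (not a BaseSystem subclass)."""

    def __init__(self, model: Callable[[List[float]], List[float]]) -> None:
        self.model = model

    def step(self, inputs: Batch, targets: Batch) -> Batch:
        # 1. Extract the input "tensor(s)" from the inputs dictionary.
        x = inputs["x"]
        # 2. Feed them into the model and receive the prediction(s).
        y_hat = self.model(x)
        # 3. Return predictions keyed by task; keys must match the target dictionary.
        predictions = {"y": y_hat}
        if predictions.keys() != targets.keys():
            raise KeyError(f"prediction keys {sorted(predictions)} != target keys {sorted(targets)}")
        return predictions


system = ToySystem(model=lambda xs: [2.0 * v for v in xs])
print(system.step({"x": [1.0, 2.0]}, {"y": [2.0, 4.0]}))  # {'y': [2.0, 4.0]}
```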
Loss functions and metrics
Loss functions and metrics do not need to be implemented using a base class.
Instead, simply import them in the __init__.py files in the corresponding directories ./modules/loss and ./modules/metrics.
Custom loss functions and metrics can be implemented by subclassing torch.nn.Module and torchmetrics.Metric, respectively.
Main script
First, import all components (usually models, systems, datamodules, datasets, loss functions, and metrics) that you want to make accessible via the command line interface.
This can be simplified by creating __init__.py files in the corresponding directories.
Also import lightning_boost.cli.LightningBoostCLI.
Then, instantiate LightningBoostCLI in the main function.
Execution
Command line interface
Call the main script with one of the subcommands fit, validate or test and the mandatory arguments --data and --system, to which you pass the class names of the datamodule and the system to be used, respectively.
Their parameters without default values must then be set as well.
While BaseDatamodule has no such parameters, BaseSystem requires the following:
- models: one or more subclasses of BaseModel.
- loss: one or more subclasses of lightning_boost.modules.loss.TaskLoss.
- optimizer: a subclass of torch.optim.Optimizer.
Further optional parameters are:
- lr_scheduler: a subclass of torch.optim.LRScheduler.
- lr_scheduling_policy: lightning_boost.modules.optim.LRSchedulingPolicy (default); adapt its parameters if needed.
- train_metrics: one or more subclasses of lightning_boost.modules.metrics.TaskMetric.
- val_metrics: one or more subclasses of lightning_boost.modules.metrics.TaskMetric.
- test_metrics: one or more subclasses of lightning_boost.modules.metrics.TaskMetric.
In the CLI,
- arguments are passed using the operator =, e.g., --arg_x=...
- initialization parameters are set using dot notation, e.g., --arg_x=class_a --arg_x.param_1=...
- list arguments can be passed using the operator += and specifying the initialization parameters directly afterwards, e.g., --arg_x+=class_a --arg_x.param_1=... --arg_x+=class_b --arg_x.param_1=...
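Putting these rules together, a full fit command might look like the argument list below. It is written in Python so that it can also be launched with subprocess.run(). All class names (MyDatamodule, MySystem, MyModel, MyTaskLoss) are hypothetical placeholders for your own classes. Any required initialization parameters of your loss class would be set with the same dot notation. Adam and its lr parameter come from torch.optim. The two MyModel entries get unique names, as recommended for multiple instances of the same model class.

```python
import shlex
from typing import List

# Hypothetical class names: replace them with your own datamodule, system, model and loss classes.
args: List[str] = [
    "python", "main.py", "fit",
    "--data=MyDatamodule",
    "--system=MySystem",
    # List argument: append an entry with +=, then set its init parameters with dot notation.
    "--system.models+=MyModel",
    "--system.models.name=encoder",
    "--system.models+=MyModel",
    "--system.models.name=decoder",
    "--system.loss+=MyTaskLoss",
    # Single class argument followed by one of its init parameters.
    "--system.optimizer=Adam",
    "--system.optimizer.lr=0.001",
]

# Print the equivalent shell command.
print(shlex.join(args))
# To launch training from Python instead: subprocess.run(args, check=True)
```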
YAML configuration
Create a YAML configuration file by adding the flag --print_config to the command and redirecting its output with > config.yaml.
Default parameters can be removed from the configuration file to improve readability.
Run with a configuration file using the argument --config=config.yaml.