Skip to content

base_model

BaseModel

Bases: LightningModule

Base class for DL model.

Source code in lightning_boost/models/base_model.py
class BaseModel(LightningModule):
    """Base class for DL model."""

    def __init__(self, name: str | None = None) -> None:
        super().__init__()
        self.name = name

    def forward(self, *args: Tensor) -> Tensor | Sequence[Tensor]:
        """
        Performs forward pass.

        Args:
            Tensor: Inputs.

        Raises:
            NotImplementedError: Needs to be implemented for a concrete DL model.

        Returns:
            Tensor | Sequence[Tensor]: Predictions.
        """

        raise NotImplementedError

forward(*args)

Performs forward pass.

Parameters:

Name Type Description Default
*args Tensor

Inputs.

required

Raises:

Type Description
NotImplementedError

Needs to be implemented for a concrete DL model.

Returns:

Type Description
Tensor | Sequence[Tensor]

Predictions.

Source code in lightning_boost/models/base_model.py
def forward(self, *args: Tensor) -> Tensor | Sequence[Tensor]:
    """
    Performs forward pass.

    Args:
        Tensor: Inputs.

    Raises:
        NotImplementedError: Needs to be implemented for a concrete DL model.

    Returns:
        Tensor | Sequence[Tensor]: Predictions.
    """

    raise NotImplementedError