Skip to content

task_loss

TaskLoss

Bases: Module

Wrapper class for task-specific loss functions.

Source code in lightning_boost/modules/loss/task_loss.py
class TaskLoss(Module):
    """
    Wrapper that pairs a loss function with a task name and a weight.

    ``forward`` evaluates the wrapped loss unchanged; ``task`` and ``weight``
    are only stored on the instance for use by whatever combines the losses.
    """

    def __init__(self, instance: Module, task: str = 'base-task', weight: float = 1.0) -> None:
        """
        Initializes task-specific loss function.

        Args:
            instance (Module): Loss function to wrap; ``forward`` calls it as
                ``instance(y_hat, y)``.
            task (str, optional): Name of the task this loss belongs to.
                Defaults to 'base-task'.
            weight (float, optional): Weight of this loss in the sum of all
                loss functions. Stored only; ``forward`` does not apply it.
                Defaults to 1.0.
        """
        super(TaskLoss, self).__init__()
        # Attributes are set in the same order as before: instance, task, weight.
        self.instance, self.task, self.weight = instance, task, weight

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the wrapped loss function.

        Args:
            y_hat (torch.Tensor): Prediction.
            y (torch.Tensor): Target.

        Returns:
            torch.Tensor: Unweighted loss as returned by ``instance(y_hat, y)``.
        """
        loss_fn = self.instance
        loss = loss_fn(y_hat, y)
        return loss

__init__(instance, task='base-task', weight=1.0)

Initializes task-specific loss function.

Parameters:

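| Name | Type | Description | Default |
| --- | --- | --- | --- |
| instance | Module | Loss function to wrap. | required |
| task | str | Name of the task this loss belongs to. Defaults to 'base-task'. | 'base-task' |
| weight | float | Weight of this loss in the sum of all loss functions. Defaults to 1.0. | 1.0 |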
Source code in lightning_boost/modules/loss/task_loss.py
def __init__(self, instance: Module, task: str = 'base-task', weight: float = 1.0) -> None:
    """
    Initializes task-specific loss function.

    Stores the wrapped loss function together with its task name and weight;
    no computation happens here.

    Args:
        instance (Module): Loss function to wrap; ``forward`` calls it as
            ``instance(y_hat, y)``.
        task (str, optional): Name of the task this loss belongs to.
            Defaults to 'base-task'.
        weight (float, optional): Weight of this loss in the sum of all loss
            functions. Only stored here; it is not applied in ``forward``.
            Defaults to 1.0.
    """

    super().__init__()
    self.instance = instance
    self.task = task
    self.weight = weight

forward(y_hat, y)

Evaluates loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| y_hat | Tensor | Prediction. | required |
| y | Tensor | Target. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Loss returned by the wrapped loss function (unweighted). |

Source code in lightning_boost/modules/loss/task_loss.py
def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Evaluates loss function.

    Delegates directly to the wrapped loss function; ``self.weight`` is not
    applied here, so the returned value is unweighted.

    Args:
        y_hat (torch.Tensor): Prediction.
        y (torch.Tensor): Target.

    Returns:
        torch.Tensor: Loss returned by ``self.instance(y_hat, y)``.
    """

    return self.instance(y_hat, y)