task_loss
TaskLoss
Bases: Module
Wrapper class for task-specific loss functions.
Source code in lightning_boost/modules/loss/task_loss.py
__init__(instance, task='base-task', weight=1.0)
Initializes the task-specific loss function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
instance | Module | Loss function to wrap. | required |
task | str | Name of the task this loss belongs to. Defaults to 'base-task'. | 'base-task' |
weight | float | Weight of this loss in the sum of all loss functions. Defaults to 1.0. | 1.0 |
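If the `weight` description is taken literally, the combined objective over a set of tasks $\mathcal{T}$ is a weighted sum. Here $w_t$ is the `weight` of task $t$ and $\mathcal{L}_t$ is its wrapped `instance`. This page does not say where or how the library computes the sum, so treat this form as an assumption.

$$
\mathcal{L}_{\text{total}} = \sum_{t \in \mathcal{T}} w_t \, \mathcal{L}_t(\hat{y}_t, y_t)
$$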
forward(y_hat, y)
Evaluates the loss function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
y_hat | Tensor | Prediction. | required |
y | Tensor | Target. | required |
Returns:
Type | Description |
---|---|
Tensor | Loss value. |
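The following is a hypothetical, dependency-free mirror of the documented interface. The names `TaskLossSketch`, `mse`, `l1` and `total_loss` are illustrative and do not come from lightning_boost. Plain callables and lists of floats stand in for `torch.nn.Module` and `Tensor`. The sketch assumes two things the source does not confirm: `forward` simply delegates to the wrapped `instance`, and the caller applies `weight` when it sums the task losses.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


# Hypothetical stand-in for TaskLoss (not the library's code): a plain
# callable replaces torch.nn.Module so the sketch runs without PyTorch.
@dataclass
class TaskLossSketch:
    instance: Callable[[Sequence[float], Sequence[float]], float]  # wrapped loss function
    task: str = "base-task"  # task name, same default as documented
    weight: float = 1.0  # weight in the sum of all losses, same default as documented

    def forward(self, y_hat, y):
        # Assumption: forward only delegates to the wrapped loss function.
        return self.instance(y_hat, y)

    # Mimic torch.nn.Module, where calling the module dispatches to forward.
    __call__ = forward


def mse(y_hat, y):
    # Mean squared error over paired elements.
    return sum((a - b) ** 2 for a, b in zip(y_hat, y)) / len(y)


def l1(y_hat, y):
    # Mean absolute error over paired elements.
    return sum(abs(a - b) for a, b in zip(y_hat, y)) / len(y)


def total_loss(losses, preds, targets):
    # Hypothetical helper: weighted sum of per-task losses, keyed by task name.
    return sum(loss.weight * loss(preds[loss.task], targets[loss.task]) for loss in losses)


losses = [
    TaskLossSketch(mse, task="regression", weight=0.5),
    TaskLossSketch(l1, task="aux", weight=2.0),
]
preds = {"regression": [1.0, 2.0], "aux": [0.0, 1.0]}
targets = {"regression": [0.0, 2.0], "aux": [1.0, 1.0]}
print(total_loss(losses, preds, targets))  # 0.5 * 0.5 + 2.0 * 0.5 = 1.25
```

With PyTorch, the same pattern would subclass `torch.nn.Module` and implement `forward`. Calling the module instance then routes through `forward`, which is the behavior `__call__ = forward` imitates here.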