2025-01-16 12:14:59 -05:00

37 lines
905 B
Python

import torch
from pytorch_forecasting import QuantileLoss, RMSE
from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
def get_loss(config):
    """Build the training loss selected by ``config['loss']``.

    Args:
        config: Mapping with a ``'loss'`` sub-mapping. ``config['loss']['name']``
            selects the loss, and the other keys read depend on that name:

            * ``'Quantile'`` -- ``config['loss']['quantiles']`` is passed to
              ``QuantileLoss``.
            * ``'GMADL'`` -- ``config['loss']['a']`` and ``config['loss']['b']``
              are passed to ``GMADL``.
            * ``'RMSE'`` -- no extra keys are read.

    Returns:
        The constructed loss object.

    Raises:
        KeyError: If ``'loss'``, ``'name'`` or a key required by the selected
            loss is missing from ``config``.
        ValueError: If the loss name is not one of the supported names.
    """
    loss_config = config['loss']
    loss_name = loss_config['name']
    if loss_name == 'Quantile':
        return QuantileLoss(loss_config['quantiles'])
    if loss_name == 'GMADL':
        return GMADL(
            a=loss_config['a'],
            b=loss_config['b'],
        )
    if loss_name == 'RMSE':
        return RMSE()
    # Name the offending value and the accepted ones so a config typo is
    # obvious from the traceback alone.
    raise ValueError(
        f"Unknown loss {loss_name!r}; expected one of 'Quantile', 'GMADL', 'RMSE'"
    )
class GMADL(MultiHorizonMetric):
    """GMADL loss function.

    For each element the loss is::

        -(sigmoid(a * y_hat * y) - 0.5) * |y| ** b

    Here ``y_hat`` is ``self.to_prediction(y_pred)`` and ``y`` is ``target``.
    For a positive ``a``:

    * The term is negative (rewarded) when prediction and target have the same
      sign.
    * The term is positive (penalised) when their signs differ.
    * ``|y| ** b`` scales each element by the target's magnitude.

    Args:
        a: Steepness of the sigmoid applied to ``y_hat * y`` (default 1000).
        b: Exponent applied to ``|target|`` to weight each element (default 2).
        **kwargs: Forwarded unchanged to ``MultiHorizonMetric.__init__``.
    """

    def __init__(self, a=1000, b=2, **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.b = b

    def loss(self, y_pred, target):
        """Return the element-wise GMADL of ``y_pred`` against ``target``.

        No reduction is applied in this method. The returned tensor has the
        broadcast shape of the point prediction and ``target``.
        """
        prediction = self.to_prediction(y_pred)
        # torch.sigmoid is equivalent to 1 / (1 + exp(-x)) but numerically
        # stable. With a large ``a`` (default 1000), exp(-x) overflows to inf in
        # float32. The forward value stays finite, but autograd would then
        # compute 0 * inf = NaN gradients.
        direction = torch.sigmoid(self.a * prediction * target) - 0.5
        return -direction * torch.pow(torch.abs(target), self.b)