Add gmadl loss

This commit is contained in:
Filip Stefaniuk 2024-09-05 15:51:25 +02:00
parent 878679e526
commit 93fa4009dc

15
src/ml/loss.py Normal file
View File

@ -0,0 +1,15 @@
import torch
from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
class GMADL(MultiHorizonMetric):
"""GMADL loss function."""
def __init__(self, a=1000, b=2, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def loss(self, y_pred, target):
return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b)