Add gmadl loss
This commit is contained in:
parent
878679e526
commit
93fa4009dc
15
src/ml/loss.py
Normal file
15
src/ml/loss.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
|
||||
|
||||
|
||||
class GMADL(MultiHorizonMetric):
|
||||
"""GMADL loss function."""
|
||||
|
||||
def __init__(self, a=1000, b=2, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def loss(self, y_pred, target):
|
||||
return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b)
|
||||
Loading…
x
Reference in New Issue
Block a user