From 93fa4009dc042b3c8cb4b9906d8f52852c2dee36 Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Thu, 5 Sep 2024 15:51:25 +0200 Subject: [PATCH] Add gmadl loss --- src/ml/loss.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 src/ml/loss.py diff --git a/src/ml/loss.py b/src/ml/loss.py new file mode 100644 index 0000000..2174c88 --- /dev/null +++ b/src/ml/loss.py @@ -0,0 +1,15 @@ +import torch + +from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric + + +class GMADL(MultiHorizonMetric): + """GMADL loss function.""" + + def __init__(self, a=1000, b=2, **kwargs): + super().__init__(**kwargs) + self.a = a + self.b = b + + def loss(self, y_pred, target): + return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b)