fix formatting

This commit is contained in:
Filip Stefaniuk 2024-09-06 11:16:55 +02:00
parent 93fa4009dc
commit f36b7a8e30
3 changed files with 15880 additions and 15 deletions

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, RMSE from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting import QuantileLoss from pytorch_forecasting import QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer from pytorch_forecasting.models.temporal_fusion_transformer import (
TemporalFusionTransformer)
from ml.loss import GMADL from ml.loss import GMADL
@ -112,10 +113,14 @@ def get_dataset(config, project):
max_prediction_length=config['future_window'], max_prediction_length=config['future_window'],
static_reals=config['data']['fields']['static_real'], static_reals=config['data']['fields']['static_real'],
static_categoricals=config['data']['fields']['static_cat'], static_categoricals=config['data']['fields']['static_cat'],
time_varying_known_reals=config['data']['fields']['dynamic_known_real'], time_varying_known_reals=config['data']['fields'][
time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'], 'dynamic_known_real'],
time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'], time_varying_known_categoricals=config['data']['fields'][
time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'], 'dynamic_known_cat'],
time_varying_unknown_reals=config['data']['fields'][
'dynamic_unknown_real'],
time_varying_unknown_categoricals=config['data']['fields'][
'dynamic_unknown_cat'],
randomize_length=False randomize_length=False
) )
@ -246,6 +251,8 @@ def main():
batch_size=batch_size, train=False, num_workers=3), batch_size=batch_size, train=False, num_workers=3),
ckpt_path=ckpt_path) ckpt_path=ckpt_path)
# TODO: Clean up non-best models to save space on wandb
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -12,4 +12,6 @@ class GMADL(MultiHorizonMetric):
self.b = b self.b = b
def loss(self, y_pred, target): def loss(self, y_pred, target):
return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b) return -1 * \
(1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)
) - 0.5) * torch.pow(torch.abs(target), self.b)