fix formatting
This commit is contained in:
parent
93fa4009dc
commit
f36b7a8e30
File diff suppressed because it is too large
Load Diff
@ -13,7 +13,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint
|
|||||||
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
|
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
|
||||||
from pytorch_forecasting.metrics import MAE, RMSE
|
from pytorch_forecasting.metrics import MAE, RMSE
|
||||||
from pytorch_forecasting import QuantileLoss
|
from pytorch_forecasting import QuantileLoss
|
||||||
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
|
from pytorch_forecasting.models.temporal_fusion_transformer import (
|
||||||
|
TemporalFusionTransformer)
|
||||||
|
|
||||||
from ml.loss import GMADL
|
from ml.loss import GMADL
|
||||||
|
|
||||||
@ -112,10 +113,14 @@ def get_dataset(config, project):
|
|||||||
max_prediction_length=config['future_window'],
|
max_prediction_length=config['future_window'],
|
||||||
static_reals=config['data']['fields']['static_real'],
|
static_reals=config['data']['fields']['static_real'],
|
||||||
static_categoricals=config['data']['fields']['static_cat'],
|
static_categoricals=config['data']['fields']['static_cat'],
|
||||||
time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
|
time_varying_known_reals=config['data']['fields'][
|
||||||
time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
|
'dynamic_known_real'],
|
||||||
time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
|
time_varying_known_categoricals=config['data']['fields'][
|
||||||
time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
|
'dynamic_known_cat'],
|
||||||
|
time_varying_unknown_reals=config['data']['fields'][
|
||||||
|
'dynamic_unknown_real'],
|
||||||
|
time_varying_unknown_categoricals=config['data']['fields'][
|
||||||
|
'dynamic_unknown_cat'],
|
||||||
randomize_length=False
|
randomize_length=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -246,6 +251,8 @@ def main():
|
|||||||
batch_size=batch_size, train=False, num_workers=3),
|
batch_size=batch_size, train=False, num_workers=3),
|
||||||
ckpt_path=ckpt_path)
|
ckpt_path=ckpt_path)
|
||||||
|
|
||||||
|
# TODO: Clean up non-best models to save space on wandb
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -12,4 +12,6 @@ class GMADL(MultiHorizonMetric):
|
|||||||
self.b = b
|
self.b = b
|
||||||
|
|
||||||
def loss(self, y_pred, target):
|
def loss(self, y_pred, target):
|
||||||
return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b)
|
return -1 * \
|
||||||
|
(1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)
|
||||||
|
) - 0.5) * torch.pow(torch.abs(target), self.b)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user