fix formatting
This commit is contained in:
parent
93fa4009dc
commit
f36b7a8e30
File diff suppressed because it is too large
Load Diff
@ -13,7 +13,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
|
||||
from pytorch_forecasting.metrics import MAE, RMSE
|
||||
from pytorch_forecasting import QuantileLoss
|
||||
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
|
||||
from pytorch_forecasting.models.temporal_fusion_transformer import (
|
||||
TemporalFusionTransformer)
|
||||
|
||||
from ml.loss import GMADL
|
||||
|
||||
@ -112,10 +113,14 @@ def get_dataset(config, project):
|
||||
max_prediction_length=config['future_window'],
|
||||
static_reals=config['data']['fields']['static_real'],
|
||||
static_categoricals=config['data']['fields']['static_cat'],
|
||||
time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
|
||||
time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
|
||||
time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
|
||||
time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
|
||||
time_varying_known_reals=config['data']['fields'][
|
||||
'dynamic_known_real'],
|
||||
time_varying_known_categoricals=config['data']['fields'][
|
||||
'dynamic_known_cat'],
|
||||
time_varying_unknown_reals=config['data']['fields'][
|
||||
'dynamic_unknown_real'],
|
||||
time_varying_unknown_categoricals=config['data']['fields'][
|
||||
'dynamic_unknown_cat'],
|
||||
randomize_length=False
|
||||
)
|
||||
|
||||
@ -246,6 +251,8 @@ def main():
|
||||
batch_size=batch_size, train=False, num_workers=3),
|
||||
ckpt_path=ckpt_path)
|
||||
|
||||
# TODO: Clean up non-best models to save space on wandb
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@ -12,4 +12,6 @@ class GMADL(MultiHorizonMetric):
|
||||
self.b = b
|
||||
|
||||
def loss(self, y_pred, target):
|
||||
return -1 * (1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)) - 0.5) * torch.pow(torch.abs(target), self.b)
|
||||
return -1 * \
|
||||
(1 / (1 + torch.exp(-self.a * self.to_prediction(y_pred) * target)
|
||||
) - 0.5) * torch.pow(torch.abs(target), self.b)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user