diff --git a/src/ml/model.py b/src/ml/model.py index fee7a54..91bebbf 100644 --- a/src/ml/model.py +++ b/src/ml/model.py @@ -72,6 +72,8 @@ def load_model_from_wandb(run): if model_name == 'TemporalFusionTransformer': return TemporalFusionTransformer.load_from_checkpoint( model_artifact.file()) + if model_name == 'Informer': + return Informer.load_from_checkpoint(model_artifact.file()) raise ValueError("Invalid model name")