import functools
import operator
from copy import copy
from typing import Dict, List, Union, Tuple

import numpy as np
import torch
import torch.nn as nn
import wandb

from pytorch_forecasting.metrics import MAE, RMSE, QuantileLoss
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.models.nn import MultiEmbedding
from pytorch_forecasting.models.temporal_fusion_transformer import (
    TemporalFusionTransformer,
)

from informer.attention import (
    AttentionLayer,
    FullAttention,
    ProbSparseAttention,
)
from informer.decoder import Decoder, DecoderLayer
from informer.embedding import PositionalEmbedding, TokenEmbedding
from informer.encoder import (
    Encoder,
    EncoderLayer,
    SelfAttentionDistil,
)

# TODO: Maybe save all models on CPU.


def get_model(config, dataset, loss):
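    """Build the forecasting model named in `config` from a dataset.

    `config` is expected to follow the key accesses below; a minimal sketch
    for the Informer branch (values are illustrative, not defaults):

        config = {
            'model': {
                'name': 'Informer',  # or 'TemporalFusionTransformer'
                'd_model': 256,
                'd_fully_connected': 512,
                'n_attention_heads': 2,
                'n_encoder_layers': 2,
                'n_decoder_layers': 1,
                'dropout': 0.1,
                'learning_rate': 1e-3,
            }
        }
    """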
    model_name = config['model']['name']

    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.from_dataset(
            dataset,
            hidden_size=config['model']['hidden_size'],
            dropout=config['model']['dropout'],
            attention_head_size=config['model']['attention_head_size'],
            hidden_continuous_size=config['model']['hidden_continuous_size'],
            learning_rate=config['model']['learning_rate'],
            share_single_variable_networks=False,
            loss=loss,
            logging_metrics=nn.ModuleList([MAE(), RMSE()]),
        )

    if model_name == 'Informer':
        return Informer.from_dataset(
            dataset,
            d_model=config['model']['d_model'],
            d_fully_connected=config['model']['d_fully_connected'],
            n_attention_heads=config['model']['n_attention_heads'],
            n_encoder_layers=config['model']['n_encoder_layers'],
            n_decoder_layers=config['model']['n_decoder_layers'],
            dropout=config['model']['dropout'],
            learning_rate=config['model']['learning_rate'],
            loss=loss,
            # Give every categorical an embedding of width d_model so the
            # embedded categoricals can be summed with the real-valued tokens.
            embedding_sizes={
                name: (len(encoder.classes_), config['model']['d_model'])
                for name, encoder in dataset.categorical_encoders.items()
                if name in dataset.categoricals
            },
            logging_metrics=nn.ModuleList([MAE(), RMSE()]),
        )

    raise ValueError(f"Unknown model: {model_name!r}")


def load_model_from_wandb(run):
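    """Load the best checkpoint of a W&B run back into a model.

    Assumes the run logged its checkpoints as an artifact named
    `model-<run id>` with a `best` alias (the PyTorch Lightning
    `WandbLogger` convention when model logging is enabled).

    Usage sketch (identifiers are illustrative):

        run = wandb.Api().run("entity/project/run_id")
        model = load_model_from_wandb(run)
    """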
    model_name = run.config['model']['name']
    model_path = f"{run.project}/model-{run.id}:best"
    model_artifact = wandb.Api().artifact(model_path)

    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.load_from_checkpoint(
            model_artifact.file())
    if model_name == 'Informer':
        return Informer.load_from_checkpoint(model_artifact.file())

    raise ValueError(f"Unknown model: {model_name!r}")


class Informer(BaseModelWithCovariates):
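    """Informer (Zhou et al., 2021) wrapped as a pytorch-forecasting
    `BaseModelWithCovariates`.

    Continuous inputs are projected by a `TokenEmbedding`, categoricals by a
    shared `MultiEmbedding`, and both are summed with a positional encoding
    before entering the encoder/decoder stacks.
    """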

    def __init__(
        self,
        d_model=256,
        d_fully_connected=512,
        n_attention_heads=2,
        n_encoder_layers=2,
        n_decoder_layers=1,
        dropout=0.1,
        attention_type="prob",
        activation="gelu",
        factor=5,
        mix_attention=False,
        output_attention=False,
        distil=True,
        x_reals: List[str] = [],
        x_categoricals: List[str] = [],
        static_categoricals: List[str] = [],
        static_reals: List[str] = [],
        time_varying_reals_encoder: List[str] = [],
        time_varying_reals_decoder: List[str] = [],
        time_varying_categoricals_encoder: List[str] = [],
        time_varying_categoricals_decoder: List[str] = [],
        embedding_sizes: Dict[str, Tuple[int, int]] = {},
        embedding_paddings: List[str] = [],
        embedding_labels: Dict[str, np.ndarray] = {},
        categorical_groups: Dict[str, List[str]] = {},
        output_size: Union[int, List[int]] = 1,
        loss=None,
        logging_metrics: nn.ModuleList = None,
        actual_n_encoder_reals: int = -1,  # unused below; kept so saved hyperparameters still load
        **kwargs,
    ):
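        """
        `d_model` is the token/embedding dimension, `factor` the ProbSparse
        sampling factor, and `distil` toggles the self-attention distilling
        layers between encoder blocks. The covariate and embedding arguments
        follow the `BaseModelWithCovariates` interface and are normally
        filled in by `from_dataset` rather than passed by hand.
        """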
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            **kwargs)
        # Hyperparameters are saved after super().__init__ so that the
        # dataset parameters injected by `from_dataset` are available.
        self.save_hyperparameters(ignore=['loss'])

        # `x_reals` lists every continuous variable fed to the encoder, so
        # its length gives the channel count of the real-valued embedding.
        n_encoder_reals = len(self.hparams.x_reals)
        print(f"Initializing enc_real_embeddings with {n_encoder_reals} "
              "channels (derived from len(hparams.x_reals)).")

        # NOTE: static and time-varying reals are not distinguished here;
        # every column of `encoder_cont` is embedded the same way.

        # One shared embedding module for all categoricals (encoder and decoder).
        self.cat_embeddings = MultiEmbedding(
            embedding_sizes=self.hparams.embedding_sizes,
            embedding_paddings=self.hparams.embedding_paddings,
            categorical_groups=self.hparams.categorical_groups,
            x_categoricals=self.hparams.x_categoricals,
        )

        # Token embedding projecting all continuous encoder variables to d_model.
        self.enc_real_embeddings = TokenEmbedding(
            n_encoder_reals, self.hparams.d_model)
        self.enc_positional_embeddings = PositionalEmbedding(self.hparams.d_model)

        # The decoder only receives reals that are known in the future.
        decoder_reals_list = self.hparams.time_varying_reals_decoder
        print(f"Initializing dec_real_embeddings with "
              f"{len(decoder_reals_list)} channels.")
        self.dec_real_embeddings = TokenEmbedding(
            len(decoder_reals_list), self.hparams.d_model)
        self.dec_positional_embeddings = PositionalEmbedding(self.hparams.d_model)

        # ProbSparse attention is the Informer's sub-quadratic approximation
        # of full attention; fall back to FullAttention when configured.
        Attention = (ProbSparseAttention
                     if self.hparams.attention_type == "prob"
                     else FullAttention)

        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attention(False, self.hparams.factor,
                                  attention_dropout=self.hparams.dropout,
                                  output_attention=self.hparams.output_attention),
                        self.hparams.d_model,
                        self.hparams.n_attention_heads,
                        mix=False,
                    ),
                    self.hparams.d_model,
                    self.hparams.d_fully_connected,
                    dropout=self.hparams.dropout,
                    activation=self.hparams.activation,
                )
                for _ in range(self.hparams.n_encoder_layers)
            ],
            # Distilling layers sit between encoder layers, hence one fewer
            # than the number of encoder layers.
            [SelfAttentionDistil(self.hparams.d_model)
             for _ in range(self.hparams.n_encoder_layers - 1)]
            if self.hparams.distil else None,
            nn.LayerNorm(self.hparams.d_model),
        )

        self.decoder = Decoder(
            [
                DecoderLayer(
                    # Masked self-attention over the decoder sequence.
                    AttentionLayer(
                        Attention(True, self.hparams.factor,
                                  attention_dropout=self.hparams.dropout,
                                  output_attention=False),
                        self.hparams.d_model,
                        self.hparams.n_attention_heads,
                        mix=self.hparams.mix_attention,
                    ),
                    # Cross-attention over the encoder output is always full.
                    AttentionLayer(
                        FullAttention(False, self.hparams.factor,
                                      attention_dropout=self.hparams.dropout,
                                      output_attention=False),
                        self.hparams.d_model,
                        self.hparams.n_attention_heads,
                        mix=False,
                    ),
                    self.hparams.d_model,
                    self.hparams.d_fully_connected,
                    dropout=self.hparams.dropout,
                    activation=self.hparams.activation,
                )
                for _ in range(self.hparams.n_decoder_layers)
            ],
            nn.LayerNorm(self.hparams.d_model),
        )

        self.projection = nn.Linear(self.hparams.d_model, self.hparams.output_size)

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
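        """Encode history and decode the forecast for one batch.

        `x` follows the batch layout of pytorch-forecasting dataloaders:
        `encoder_cont`/`encoder_cat` hold past inputs, `decoder_cont`/
        `decoder_cat` hold future (known) inputs, and `target_scale` is used
        to rescale the normalized network output.
        """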
        decoder_length = x['decoder_lengths'].max()

        # Encoder input: real-valued tokens + positional encoding + the sum
        # of all categorical embeddings.
        enc_out = (
            self.enc_real_embeddings(x['encoder_cont'])
            + self.enc_positional_embeddings(x['encoder_cont'])
            + functools.reduce(
                operator.add,
                self.cat_embeddings(x['encoder_cat']).values())
        )
        enc_out, attentions = self.encoder(enc_out)

        # Hacky solution to restrict the decoder to known reals: they are
        # always stacked first in `decoder_cont`.
        # TODO: Make sure no unknown reals are passed to the decoder.
        dec_out = (
            self.dec_real_embeddings(
                x['decoder_cont'][..., :len(self.hparams.time_varying_reals_decoder)])
            + self.dec_positional_embeddings(x['decoder_cont'])
            + functools.reduce(
                operator.add,
                self.cat_embeddings(x['decoder_cat']).values())
        )
        dec_out = self.decoder(dec_out, enc_out)

        output = self.projection(dec_out)
        # Keep only the prediction horizon and rescale to the target scale.
        output = output[:, -decoder_length:, :]
        output = self.transform_output(output, target_scale=x['target_scale'])
        return self.to_network_output(prediction=output)

    @classmethod
    def from_dataset(cls, dataset, **kwargs):
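        """Instantiate an Informer from a `TimeSeriesDataSet`, letting the
        base class derive the covariate lists and embedding metadata."""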
        new_kwargs = copy(kwargs)
        # QuantileLoss here only supplies output-size defaults when the
        # caller has not set them; this might be okay even if training
        # uses a different loss.
        new_kwargs.update(cls.deduce_default_output_parameters(
            dataset, kwargs, QuantileLoss()))

        # Let super().from_dataset populate the dataset parameters
        # (x_reals, x_categoricals, embedding sizes, etc.).
        return super().from_dataset(dataset, **new_kwargs)
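

# Usage sketch (illustrative only: assumes a pytorch_forecasting
# `TimeSeriesDataSet` named `training` and a Lightning `Trainer` named
# `trainer`; the variable names are not part of this module):
#
#     from pytorch_forecasting.metrics import QuantileLoss
#
#     model = Informer.from_dataset(
#         training,
#         d_model=256,
#         n_attention_heads=2,
#         dropout=0.1,
#         loss=QuantileLoss(),
#     )
#     trainer.fit(
#         model,
#         train_dataloaders=training.to_dataloader(train=True, batch_size=64),
#     )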