# informer_model/src/ml/model.py
from typing import Dict, List, Union, Tuple
import torch
import torch.nn as nn
import numpy as np
import wandb
import functools
import operator
from copy import copy
from pytorch_forecasting.models.nn import MultiEmbedding
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.metrics import MAE, RMSE, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import (
TemporalFusionTransformer)
from informer.attention import (
FullAttention, ProbSparseAttention, AttentionLayer)
from informer.embedding import TokenEmbedding, PositionalEmbedding
from informer.encoder import (
Encoder,
EncoderLayer,
SelfAttentionDistil,
)
from informer.decoder import Decoder, DecoderLayer


# TODO: Maybe save all models on CPU.
def get_model(config, dataset, loss):
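    """Instantiate a forecasting model from a config dict and a
    pytorch_forecasting ``TimeSeriesDataSet``.

    A minimal ``config`` sketch, inferred from the lookups below (values
    are illustrative only, not recommendations)::

        config = {'model': {
            'name': 'Informer',  # or 'TemporalFusionTransformer'
            'd_model': 256,
            'd_fully_connected': 512,
            'n_attention_heads': 2,
            'n_encoder_layers': 2,
            'n_decoder_layers': 1,
            'dropout': 0.1,
            'learning_rate': 1e-3,
        }}
    """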
model_name = config['model']['name']
if model_name == 'TemporalFusionTransformer':
return TemporalFusionTransformer.from_dataset(
dataset,
hidden_size=config['model']['hidden_size'],
dropout=config['model']['dropout'],
attention_head_size=config['model']['attention_head_size'],
hidden_continuous_size=config['model']['hidden_continuous_size'],
learning_rate=config['model']['learning_rate'],
share_single_variable_networks=False,
loss=loss,
            logging_metrics=nn.ModuleList([MAE(), RMSE()])
)
if model_name == 'Informer':
return Informer.from_dataset(
dataset,
d_model=config['model']['d_model'],
d_fully_connected=config['model']['d_fully_connected'],
n_attention_heads=config['model']['n_attention_heads'],
n_encoder_layers=config['model']['n_encoder_layers'],
n_decoder_layers=config['model']['n_decoder_layers'],
dropout=config['model']['dropout'],
learning_rate=config['model']['learning_rate'],
loss=loss,
embedding_sizes={
name: (len(encoder.classes_), config['model']['d_model'])
for name, encoder in dataset.categorical_encoders.items()
if name in dataset.categoricals
},
            logging_metrics=nn.ModuleList([MAE(), RMSE()])
)
raise ValueError("Unknown model")
def load_model_from_wandb(run):
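    """Load the best checkpoint of a W&B run.

    Assumes the run logged its checkpoints as an artifact named
    ``model-<run id>`` carrying a ``best`` alias, matching what PyTorch
    Lightning's ``WandbLogger`` produces when model logging is enabled.
    """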
model_name = run.config['model']['name']
model_path = f"{run.project}/model-{run.id}:best"
model_artifact = wandb.Api().artifact(model_path)
if model_name == 'TemporalFusionTransformer':
return TemporalFusionTransformer.load_from_checkpoint(
model_artifact.file())
if model_name == 'Informer':
return Informer.load_from_checkpoint(model_artifact.file())
raise ValueError("Invalid model name")
class Informer(BaseModelWithCovariates):
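    """Informer (Zhou et al., 2021) adapted to the pytorch_forecasting API.

    Subclassing ``BaseModelWithCovariates`` lets the model be constructed via
    ``from_dataset`` and reuse the library's embedding and output-rescaling
    machinery.
    """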
def __init__(
self,
d_model=256,
d_fully_connected=512,
n_attention_heads=2,
n_encoder_layers=2,
n_decoder_layers=1,
dropout=0.1,
attention_type="prob",
activation="gelu",
factor=5,
mix_attention=False,
output_attention=False,
distil=True,
x_reals: List[str] = [],
x_categoricals: List[str] = [],
static_categoricals: List[str] = [],
static_reals: List[str] = [],
time_varying_reals_encoder: List[str] = [],
time_varying_reals_decoder: List[str] = [],
time_varying_categoricals_encoder: List[str] = [],
time_varying_categoricals_decoder: List[str] = [],
embedding_sizes: Dict[str, Tuple[int, int]] = {},
embedding_paddings: List[str] = [],
embedding_labels: Dict[str, np.ndarray] = {},
categorical_groups: Dict[str, List[str]] = {},
output_size: Union[int, List[int]] = 1,
loss=None,
logging_metrics: nn.ModuleList = None,
            actual_n_encoder_reals: int = -1,  # unused here; retained in hparams via save_hyperparameters
**kwargs):
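        """Architecture hyperparameters follow the Informer reference
        implementation; dataset-derived arguments (``x_reals``,
        ``embedding_sizes``, ...) are normally filled in by ``from_dataset``
        rather than passed by hand."""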
        # Call super().__init__ before save_hyperparameters so that the
        # dataset-derived parameters passed via **kwargs are already set.
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            **kwargs)
        self.save_hyperparameters(ignore=['loss'])
        self.attention_type = attention_type
        # The encoder token embedding covers every continuous variable, so its
        # channel count is the length of hparams.x_reals.
        n_encoder_reals = len(self.hparams.x_reals)
        print(f"Initializing enc_real_embeddings with {n_encoder_reals} "
              f"channels (derived from len(hparams.x_reals)).")
        # Categorical variables are embedded once and shared between the
        # encoder and the decoder.
self.cat_embeddings = MultiEmbedding(
embedding_sizes=self.hparams.embedding_sizes,
embedding_paddings=self.hparams.embedding_paddings,
categorical_groups=self.hparams.categorical_groups,
x_categoricals=self.hparams.x_categoricals,
)
# Initialize with the derived total number of continuous encoder variables
self.enc_real_embeddings = TokenEmbedding(n_encoder_reals, self.hparams.d_model)
self.enc_positional_embeddings = PositionalEmbedding(self.hparams.d_model)
# Decoder embedding initialization using hparams
decoder_reals_list = self.hparams.time_varying_reals_decoder
print(f"Initializing dec_real_embeddings with {len(decoder_reals_list)} channels.")
self.dec_real_embeddings = TokenEmbedding(
len(decoder_reals_list), self.hparams.d_model)
self.dec_positional_embeddings = PositionalEmbedding(self.hparams.d_model)
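        # "prob" selects ProbSparse self-attention, which subsamples queries
        # to cut the quadratic attention cost to roughly O(L log L); any other
        # value falls back to standard full attention.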
        Attention = (ProbSparseAttention
                     if self.hparams.attention_type == "prob"
                     else FullAttention)
        # Encoder: stacked self-attention layers with optional distillation
        # (downsampling) between layers, as in the Informer paper.
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
Attention(False, self.hparams.factor, attention_dropout=self.hparams.dropout,
output_attention=self.hparams.output_attention),
self.hparams.d_model,
self.hparams.n_attention_heads,
mix=False,
),
self.hparams.d_model,
self.hparams.d_fully_connected,
dropout=self.hparams.dropout,
activation=self.hparams.activation,
)
for _ in range(self.hparams.n_encoder_layers)
],
[SelfAttentionDistil(self.hparams.d_model) for _ in range(
self.hparams.n_encoder_layers - 1)] if self.hparams.distil else None,
nn.LayerNorm(self.hparams.d_model),
)
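        # Decoder: masked (causal) self-attention over decoder tokens,
        # followed by full cross-attention to the encoder output.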
self.decoder = Decoder(
[
DecoderLayer(
AttentionLayer(
Attention(True, self.hparams.factor, attention_dropout=self.hparams.dropout,
output_attention=False),
self.hparams.d_model,
self.hparams.n_attention_heads,
mix=self.hparams.mix_attention,
),
AttentionLayer(
FullAttention(
False,
self.hparams.factor,
attention_dropout=self.hparams.dropout,
output_attention=False),
self.hparams.d_model,
self.hparams.n_attention_heads,
mix=False,
),
self.hparams.d_model,
self.hparams.d_fully_connected,
dropout=self.hparams.dropout,
activation=self.hparams.activation,
)
for _ in range(self.hparams.n_decoder_layers)
],
nn.LayerNorm(self.hparams.d_model),
)
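        # Final projection from model dimension to the output size
        # (number of targets, or number of quantiles for quantile losses).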
        self.projection = nn.Linear(self.hparams.d_model, self.hparams.output_size)

    def forward(
self,
x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
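        """Map a batch from ``TimeSeriesDataSet`` to network output.

        ``x`` is the input dict produced by the dataloader (``encoder_cont``,
        ``encoder_cat``, ``decoder_cont``, ``decoder_cat``, lengths and
        target scale).
        """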
        decoder_length = x['decoder_lengths'].max()
        # Encoder input: token embedding of all continuous variables, plus
        # positional encoding, plus the sum of all categorical embeddings.
        enc_out = (
            self.enc_real_embeddings(x['encoder_cont'])
            + self.enc_positional_embeddings(x['encoder_cont'])
            + functools.reduce(
                operator.add, self.cat_embeddings(x['encoder_cat']).values())
        )
        enc_out, attentions = self.encoder(enc_out)
        # Hacky solution to get only known reals; they are always stacked
        # first. TODO: Make sure no unknown reals are passed to the decoder.
        dec_out = (
            self.dec_real_embeddings(
                x['decoder_cont'][..., :len(self.hparams.time_varying_reals_decoder)])
            + self.dec_positional_embeddings(x['decoder_cont'])
            + functools.reduce(
                operator.add, self.cat_embeddings(x['decoder_cat']).values())
        )
        dec_out = self.decoder(dec_out, enc_out)
        output = self.projection(dec_out)
        # Keep only the last decoder_length steps and rescale to target scale.
        output = output[:, -decoder_length:, :]
        output = self.transform_output(output, target_scale=x['target_scale'])
        return self.to_network_output(prediction=output)

    @classmethod
    def from_dataset(cls, dataset, **kwargs):
        """Create an Informer from a ``TimeSeriesDataSet``, deriving output
        size and related defaults from the dataset."""
        new_kwargs = copy(kwargs)
        # QuantileLoss is only used here to deduce default output parameters;
        # the actual training loss is whatever was passed in via kwargs.
        new_kwargs.update(cls.deduce_default_output_parameters(
            dataset, kwargs, QuantileLoss()))
        # super().from_dataset populates the dataset-derived parameters
        # (x_reals, categoricals, embedding sizes, ...).
        return super().from_dataset(dataset, **new_kwargs)
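

# Usage sketch (illustrative; `config`, `training_dataset` and
# `val_dataloader` are assumed to exist elsewhere in the project):
#
#     from pytorch_forecasting.metrics import QuantileLoss
#
#     model = get_model(config, training_dataset, loss=QuantileLoss())
#     predictions = model.predict(val_dataloader)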