From 73c513c217dfcca04387e8ab78bd9faeb0eab14a Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Tue, 10 Sep 2024 20:32:05 +0200 Subject: [PATCH] Add informer model --- .../informer-btcusdt-quantile.yaml | 68 +++++ pyproject.toml | 1 + scripts/train.py | 23 +- src/informer/README.md | 1 + src/informer/__init__.py | 0 src/informer/attention.py | 190 ++++++++++++++ src/informer/decoder.py | 53 ++++ src/informer/embedding.py | 132 ++++++++++ src/informer/encoder.py | 101 +++++++ src/informer/masking.py | 19 ++ src/informer/model.py | 246 ++++++++++++++++++ src/ml/model.py | 226 +++++++++++++++- 12 files changed, 1035 insertions(+), 25 deletions(-) create mode 100644 configs/experiments/informer-btcusdt-quantile.yaml create mode 100644 src/informer/README.md create mode 100644 src/informer/__init__.py create mode 100644 src/informer/attention.py create mode 100644 src/informer/decoder.py create mode 100644 src/informer/embedding.py create mode 100644 src/informer/encoder.py create mode 100644 src/informer/masking.py create mode 100644 src/informer/model.py diff --git a/configs/experiments/informer-btcusdt-quantile.yaml b/configs/experiments/informer-btcusdt-quantile.yaml new file mode 100644 index 0000000..278d2b7 --- /dev/null +++ b/configs/experiments/informer-btcusdt-quantile.yaml @@ -0,0 +1,68 @@ +future_window: + value: 5 +past_window: + value: 48 +batch_size: + value: 64 +max_epochs: + value: 1 +data: + value: + dataset: "btc-usdt-5m:latest" + sliding_window: 4 + validation: 0.2 + fields: + time_index: "time_index" + target: "close_price" + group_ids: ["group_id"] + dynamic_unknown_real: + - "high_price" + - "low_price" + - "open_price" + - "close_price" + - "volume" + - "open_to_close_price" + - "high_to_close_price" + - "low_to_close_price" + - "high_to_low_price" + - "returns" + - "log_returns" + - "vol_1h" + - "macd" + - "macd_signal" + - "rsi" + - "low_bband_to_close_price" + - "up_bband_to_close_price" + - "mid_bband_to_close_price" + - "sma_1h_to_close_price" + - "sma_1d_to_close_price" + - "sma_7d_to_close_price" + - "ema_1h_to_close_price" + - "ema_1d_to_close_price" + dynamic_unknown_cat: [] + dynamic_known_real: + - "effective_rates" + - "vix_close_price" + - "fear_greed_index" + - "vol_1d" + - "vol_7d" + dynamic_known_cat: + - "hour" + - "weekday" + static_real: [] + static_cat: [] +loss: + value: + name: "Quantile" + quantiles: [0.02, 0.1, 0.5, 0.9, 0.98] +model: + value: + name: "Informer" + d_model: 256 + d_fully_connected: 512 + n_attention_heads: 2 + dropout: 0.1 + n_encoder_layers: 2 + n_decoder_layers: 1 + learning_rate: 0.001 + optimizer: "Adam" diff --git a/pyproject.toml b/pyproject.toml index 2382a55..4ef78a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = [ + "src/informer", "src/ml", "src/strategy" ] diff --git a/scripts/train.py b/scripts/train.py index 8ee088d..ba3a811 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -11,12 +11,10 @@ from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.callbacks import ModelCheckpoint from pytorch_forecasting.data.timeseries import TimeSeriesDataSet -from pytorch_forecasting.metrics import MAE, RMSE from pytorch_forecasting import QuantileLoss -from pytorch_forecasting.models.temporal_fusion_transformer import ( - TemporalFusionTransformer) from ml.loss import GMADL +from ml.model import get_model def get_args(): @@ -146,25 +144,6 @@ def get_loss(config): raise ValueError("Unknown loss") -def get_model(config, dataset, loss): - model_name = config['model']['name'] - - if model_name == 'TemporalFusionTransformer': - return TemporalFusionTransformer.from_dataset( - dataset, - hidden_size=config['model']['hidden_size'], - dropout=config['model']['dropout'], - attention_head_size=config['model']['attention_head_size'], - hidden_continuous_size=config['model']['hidden_continuous_size'], - learning_rate=config['model']['learning_rate'], - share_single_variable_networks=False, - loss=loss, - logging_metrics=[MAE(), RMSE()] - ) - - raise ValueError("Unknown model") - - def main(): args = get_args() logging.basicConfig(level=args.log_level) diff --git a/src/informer/README.md b/src/informer/README.md new file mode 100644 index 0000000..acf9b34 --- /dev/null +++ b/src/informer/README.md @@ -0,0 +1 @@ +Copied from https://github.com/martinwhl/Informer-PyTorch-Lightning diff --git a/src/informer/__init__.py b/src/informer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/informer/attention.py b/src/informer/attention.py new file mode 100644 index 0000000..f70fdf1 --- /dev/null +++ b/src/informer/attention.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn +import numpy as np +import math +from informer.masking import triangular_causal_mask, prob_mask + + +class FullAttention(nn.Module): + def __init__( + self, + mask_flag=True, + scale=None, + attention_dropout=0.1, + output_attention=False, **kwargs): + super(FullAttention, self).__init__() + self.mask_flag = mask_flag + self.scale = scale + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attention_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / math.sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + if self.mask_flag: + if attention_mask is None: + attention_mask = triangular_causal_mask( + B, L, device=queries.device) + scores.masked_fill_(attention_mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + return V.contiguous(), None + + +class ProbSparseAttention(nn.Module): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super(ProbSparseAttention, self).__init__() + self.mask_flag = mask_flag + self.factor = factor + self.scale = scale + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attention_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = torch.transpose(queries, 2, 1) + keys = torch.transpose(keys, 2, 1) + values = torch.transpose(values, 2, 1) + + U_part = int(self.factor * math.ceil(math.log(L_K))) # c * ln(L_K) + u = int(self.factor * math.ceil(math.log(L_Q))) # c * ln(L_Q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK( + queries, keys, sample_k=U_part, n_top=u) + + scale = self.scale or 1.0 / math.sqrt(D) + if scale is not None: + scores_top = scores_top * scale + + context = self._get_initial_context(values, L_Q) + # update the context with selected top_k queries + context, attention = self._update_context( + context, values, scores_top, index, L_Q, attention_mask) + + return context.transpose(2, 1).contiguous(), attention + + def _prob_QK(self, queries, keys, sample_k, n_top): + B, H, L_K, E = keys.shape + _, _, L_Q, _ = queries.shape + + # calculate the sampled Q_K + K_expand = keys.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + # real U = U_part(factor * ln(L_K)) * L_Q + index_sample = torch.randint(L_K, (L_Q, sample_k)) + K_sample = K_expand[:, :, torch.arange( + L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = (queries.unsqueeze(-2) @ + K_sample.transpose(-2, -1)).squeeze() + + # find the top_k query with sparsity measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = queries[torch.arange(B)[:, None, None], torch.arange(H)[ + None, :, None], M_top, :] # factor * ln(L_Q) + Q_K = Q_reduce @ keys.transpose(-2, -1) # factor * ln(L_Q) * L_K + + return Q_K, M_top + + def _get_initial_context(self, values, L_Q): + B, H, L_V, D = values.shape + if not self.mask_flag: + V_mean = values.mean(dim=-2) + context = \ + V_mean.unsqueeze(-2).expand(B, + H, L_Q, V_mean.size(-1)).clone() + else: + # requires that L_Q == L_V, i.e. for self-attention only + assert L_Q == L_V + context = values.cumsum(dim=-2) + return context + + def _update_context( + self, + context, + values, + scores, + index, + L_Q, + attention_mask): + B, H, L_V, D = values.shape + + if self.mask_flag: + attention_mask = prob_mask( + B, H, L_Q, index, scores, device=values.device) + scores.masked_fill_(attention_mask, -np.inf) + + attention = torch.softmax(scores, dim=-1) + + context[ + torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = ( + attention @ values + ).type_as(context) + if self.output_attention: + attentions = (torch.ones(B, H, L_V, L_V) / L_V).type_as(attention) + attentions[torch.arange(B)[:, None, None], torch.arange(H)[ + None, :, None], index, :] = attention + return context, attentions + return context, None + + +class AttentionLayer(nn.Module): + def __init__( + self, + attention, + d_model, + n_heads, + d_keys=None, + d_values=None, + mix=False): + super(AttentionLayer, self).__init__() + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_attention = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + + self.n_heads = n_heads + self.mix = mix + + def forward(self, queries, keys, values, attention_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_attention(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attention = self.inner_attention( + queries, keys, values, attention_mask) + if self.mix: + out = out.transpose(2, 1).contiguous() + out = out.view(B, L, -1) + + return self.out_projection(out), attention diff --git a/src/informer/decoder.py b/src/informer/decoder.py new file mode 100644 index 0000000..96a4fd1 --- /dev/null +++ b/src/informer/decoder.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class DecoderLayer(nn.Module): + def __init__( + self, + self_attention, + cross_attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu", + ): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention(x, x, x, + attention_mask=x_mask)[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention(x, cross, + cross, attention_mask=cross_mask)[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + + def forward(self, x, cross, x_mask=None, cross_mask=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + if self.norm is not None: + x = self.norm(x) + return x diff --git a/src/informer/embedding.py b/src/informer/embedding.py new file mode 100644 index 0000000..95d1079 --- /dev/null +++ b/src/informer/embedding.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import math + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_length=5000): + super(PositionalEmbedding, self).__init__() + embedding = torch.zeros(max_length, d_model) + + position = torch.arange(0, max_length).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() + * -(math.log(10000.0) / d_model)).exp() + + embedding[:, 0::2] = torch.sin(position * div_term) + embedding[:, 1::2] = torch.cos(position * div_term) + + embedding = embedding.unsqueeze(0) + self.register_buffer("embedding", embedding) + + def forward(self, x): + return self.embedding[:, : x.size(1)] + + +class TokenEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(TokenEmbedding, self).__init__() + padding = 1 if torch.__version__ >= "1.5.0" else 2 + self.token_conv = nn.Conv1d( + in_channels=c_in, + out_channels=d_model, + kernel_size=3, + padding=padding, + padding_mode="circular", + ) + nn.init.kaiming_normal_(self.token_conv.weight, + mode="fan_in", nonlinearity="leaky_relu") + + def forward(self, x): + return self.token_conv(x.permute(0, 2, 1)).transpose(1, 2) + + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + weight = torch.zeros(c_in, d_model) + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() + * -(math.log(10000.0) / d_model)).exp() + + weight[:, 0::2] = torch.sin(position * div_term) + weight[:, 1::2] = torch.cos(position * div_term) + + self.embedding = nn.Embedding(c_in, d_model) + self.embedding.weight = nn.Parameter(weight, requires_grad=False) + + def forward(self, x): + return self.embedding(x).detach() + + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embedding_type="fixed", frequency="h"): + super(TemporalEmbedding, self).__init__() + + MINUTE_SIZE = 4 + HOUR_SIZE = 24 + WEEKDAY_SIZE = 7 + DAY_SIZE = 32 + MONTH_SIZE = 13 + + Embedding = FixedEmbedding \ + if embedding_type == "fixed" else nn.Embedding + if frequency == "t": + self.minute_embedding = Embedding(MINUTE_SIZE, d_model) + self.hour_embedding = Embedding(HOUR_SIZE, d_model) + self.weekday_embedding = Embedding(WEEKDAY_SIZE, d_model) + self.day_embedding = Embedding(DAY_SIZE, d_model) + self.month_embedding = Embedding(MONTH_SIZE, d_model) + + def forward(self, x): + x = x.long() + + minute_x = self.minute_embedding(x[:, :, 4]) if hasattr( + self, "minute_embedding") else 0.0 + hour_x = self.hour_embedding(x[:, :, 3]) + weekday_x = self.weekday_embedding(x[:, :, 2]) + day_x = self.day_embedding(x[:, :, 1]) + month_x = self.month_embedding(x[:, :, 0]) + + return minute_x + hour_x + weekday_x + day_x + month_x + + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, frequency="h"): + super(TimeFeatureEmbedding, self).__init__() + + FREQUENCY_MAP = {"h": 4, "t": 5, "s": 6, + "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} + d_input = FREQUENCY_MAP[frequency] + self.embedding = nn.Linear(d_input, d_model) + + def forward(self, x): + return self.embedding(x) + + +class DataEmbedding(nn.Module): + def __init__( + self, + c_in, + d_model, + embedding_type="fixed", + frequency="h", + dropout=0.1 + ): + super(DataEmbedding, self).__init__() + + self.value_embedding = TokenEmbedding(c_in, d_model) + self.position_embedding = PositionalEmbedding(d_model) + if embedding_type != "timefeature": + self.temporal_embedding = TemporalEmbedding( + d_model, embedding_type=embedding_type, frequency=frequency) + else: + self.temporal_embedding = TimeFeatureEmbedding( + d_model, frequency=frequency) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x, x_mark): + x = self.value_embedding( + x) + self.position_embedding(x) + self.temporal_embedding(x_mark) + return self.dropout(x) diff --git a/src/informer/encoder.py b/src/informer/encoder.py new file mode 100644 index 0000000..8079753 --- /dev/null +++ b/src/informer/encoder.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SelfAttentionDistil(nn.Module): + def __init__(self, c_in): + super(SelfAttentionDistil, self).__init__() + self.conv = nn.Conv1d(c_in, c_in, kernel_size=3, + padding=2, padding_mode="circular") + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.conv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.max_pool(x) + x = torch.transpose(x, 1, 2) + return x + + +class EncoderLayer(nn.Module): + def __init__( + self, + attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attention_mask=None): + new_x, attention = self.attention( + x, x, x, attention_mask=attention_mask) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attention + + +class Encoder(nn.Module): + def __init__(self, attention_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attention_layers = nn.ModuleList(attention_layers) + self.conv_layers = nn.ModuleList( + conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attention_mask=None): + attentions = [] + if self.conv_layers is not None: + for attention_layer, conv_layer in zip( + self.attention_layers, self.conv_layers): + x, attention = attention_layer( + x, attention_mask=attention_mask) + x = conv_layer(x) + attentions.append(attention) + x, attention = self.attention_layers[-1](x) + attentions.append(attention) + else: + for attention_layer in self.attention_layers: + x, attention = attention_layer( + x, attention_mask=attention_mask) + attentions.append(attention) + if self.norm is not None: + x = self.norm(x) + return x, attentions + + +class EncoderStack(nn.Module): + def __init__(self, encoders): + super(EncoderStack).__init__() + self.encoders = nn.ModuleList(encoders) + + def forward(self, x, attention_mask=None): + inp_len = x.size(1) + x_stack = [] + attentions = [] + for encoder in self.encoders: + if encoder is None: + inp_len //= 2 + continue + x, attention = encoder(x[:, -inp_len:, :]) + x_stack.append(x) + attentions.append(attention) + inp_len //= 2 + x_stack = torch.cat(x_stack, -2) + return x_stack, attentions diff --git a/src/informer/masking.py b/src/informer/masking.py new file mode 100644 index 0000000..e439104 --- /dev/null +++ b/src/informer/masking.py @@ -0,0 +1,19 @@ +import torch + + +def triangular_causal_mask(B, L, device=torch.device("cpu")): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + mask = torch.triu(torch.ones( + mask_shape, dtype=torch.bool, device=device), diagonal=1) + return mask + + +def prob_mask(B, H, L, index, scores, device=torch.device("cpu")): + mask = torch.ones(L, scores.shape[-1], + dtype=torch.bool, device=device).triu(1) + mask_ex = mask[None, None, :].expand(B, H, L, scores.shape[-1]) + indicator = mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[ + None, :, None], index, :] + mask = indicator.view(scores.shape) + return mask diff --git a/src/informer/model.py b/src/informer/model.py new file mode 100644 index 0000000..0b686c6 --- /dev/null +++ b/src/informer/model.py @@ -0,0 +1,246 @@ +import torch.nn as nn +from informer.attention import ( + FullAttention, ProbSparseAttention, AttentionLayer) +from informer.embedding import DataEmbedding +from informer.encoder import ( + Encoder, + EncoderLayer, + EncoderStack, + SelfAttentionDistil, +) +from informer.decoder import Decoder, DecoderLayer + + +class BaseInformer(nn.Module): + def __init__( + self, + enc_in=7, + dec_in=7, + c_out=7, + out_len=24, + factor=5, + d_model=512, + n_heads=8, + num_encoder_layers=2, + num_decoder_layers=1, + d_ff=2048, + dropout=0.05, + attention_type="prob", + embedding_type="fixed", + frequency="h", + activation="gelu", + output_attention=False, + distil=True, + mix_attention=False, + **kwargs + ): + super(BaseInformer, self).__init__() + self.pred_len = out_len + self.attention_type = attention_type + self.output_attention = output_attention + + self.enc_embedding = DataEmbedding( + enc_in, d_model, embedding_type, frequency, dropout) + self.dec_embedding = DataEmbedding( + dec_in, d_model, embedding_type, frequency, dropout) + + Attention = ProbSparseAttention \ + if attention_type == "prob" else FullAttention + + self.encoder = None + + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + Attention(True, factor, attention_dropout=dropout, + output_attention=False), + d_model, + n_heads, + mix=mix_attention, + ), + AttentionLayer( + FullAttention( + False, + factor, + attention_dropout=dropout, + output_attention=False), + d_model, + n_heads, + mix=False, + ), + d_model, + d_ff, + dropout=dropout, + activation=activation, + ) + for _ in range(num_decoder_layers) + ], + nn.LayerNorm(d_model), + ) + + self.projection = nn.Linear(d_model, c_out) + + def forward( + self, + x_enc, + x_enc_mark, + x_dec, + x_dec_mark, + enc_self_mask=None, + dec_self_mask=None, + dec_enc_mask=None, + ): + enc_out = self.enc_embedding(x_enc, x_enc_mark) + enc_out, attentions = self.encoder( + enc_out, attention_mask=enc_self_mask) + + dec_out = self.dec_embedding(x_dec, x_dec_mark) + dec_out = self.decoder( + dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + dec_out = self.projection(dec_out) + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attentions + return dec_out[:, -self.pred_len:, :] + + +class Informer(BaseInformer): + def __init__( + self, + enc_in=7, + dec_in=7, + c_out=7, + out_len=24, + factor=5, + d_model=512, + n_heads=8, + num_encoder_layers=2, + num_decoder_layers=1, + d_ff=2048, + dropout=0.05, + attention_type="prob", + embedding_type="fixed", + frequency="h", + activation="gelu", + output_attention=False, + distil=True, + mix_attention=False, + **kwargs + ): + super(Informer, self).__init__( + enc_in, + dec_in, + c_out, + out_len, + factor=factor, + d_model=d_model, + n_heads=n_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + d_ff=d_ff, + dropout=dropout, + attention_type=attention_type, + embedding_type=embedding_type, + frequency=frequency, + activation=activation, + output_attention=output_attention, + distil=distil, + mix_attention=mix_attention, + ) + Attention = ProbSparseAttention \ + if attention_type == "prob" else FullAttention + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + Attention(False, factor, attention_dropout=dropout, + output_attention=output_attention), + d_model, + n_heads, + mix=False, + ), + d_model, + d_ff, + dropout=dropout, + activation=activation, + ) + for _ in range(num_encoder_layers) + ], + [SelfAttentionDistil(d_model) for _ in range( + num_encoder_layers - 1)] if distil else None, + nn.LayerNorm(d_model), + ) + + +class InformerStack(BaseInformer): + def __init__( + self, + enc_in=7, + dec_in=7, + c_out=7, + out_len=24, + factor=5, + d_model=512, + n_heads=8, + num_encoder_layers=2, + num_decoder_layers=1, + d_ff=2048, + dropout=0.05, + attention_type="prob", + embedding_type="fixed", + frequency="h", + activation="gelu", + output_attention=False, + distil=True, + mix_attention=False, + **kwargs + ): + super(InformerStack, self).__init__( + enc_in, + dec_in, + c_out, + out_len, + factor=factor, + d_model=d_model, + n_heads=n_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + d_ff=d_ff, + dropout=dropout, + attention_type=attention_type, + embedding_type=embedding_type, + frequency=frequency, + activation=activation, + output_attention=output_attention, + distil=distil, + mix_attention=mix_attention, + ) + Attention = ProbSparseAttention \ + if attention_type == "prob" else FullAttention + stacks = list(range(num_encoder_layers, 2, -1)) # customize here + encoders = [ + Encoder( + [ + EncoderLayer( + AttentionLayer( + Attention(False, factor, attention_dropout=dropout, + output_attention=output_attention), + d_model, + n_heads, + mix=False, + ), + d_model, + d_ff, + dropout=dropout, + activation=activation, + ) + for _ in range(el) + ], + [SelfAttentionDistil(d_model) + for _ in range(el - 1)] if distil else None, + nn.LayerNorm(d_model), + ) + for el in stacks + ] + self.encoder = EncoderStack(encoders) diff --git a/src/ml/model.py b/src/ml/model.py index b4a0caa..fee7a54 100644 --- a/src/ml/model.py +++ b/src/ml/model.py @@ -1,13 +1,233 @@ +from typing import Dict, List, Union, Tuple +import torch +import torch.nn as nn +import numpy as np import wandb -from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer +import functools +import operator +from copy import copy + +from pytorch_forecasting.models.nn import MultiEmbedding +from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.metrics import MAE, RMSE, QuantileLoss +from pytorch_forecasting.models.temporal_fusion_transformer import ( + TemporalFusionTransformer) + +from informer.attention import ( + FullAttention, ProbSparseAttention, AttentionLayer) +from informer.embedding import TokenEmbedding, PositionalEmbedding +from informer.encoder import ( + Encoder, + EncoderLayer, + SelfAttentionDistil, +) +from informer.decoder import Decoder, DecoderLayer # TODO: Maybe save all models on cpu + + +def get_model(config, dataset, loss): + model_name = config['model']['name'] + + if model_name == 'TemporalFusionTransformer': + return TemporalFusionTransformer.from_dataset( + dataset, + hidden_size=config['model']['hidden_size'], + dropout=config['model']['dropout'], + attention_head_size=config['model']['attention_head_size'], + hidden_continuous_size=config['model']['hidden_continuous_size'], + learning_rate=config['model']['learning_rate'], + share_single_variable_networks=False, + loss=loss, + logging_metrics=[MAE(), RMSE()] + ) + + if model_name == 'Informer': + return Informer.from_dataset( + dataset, + d_model=config['model']['d_model'], + d_fully_connected=config['model']['d_fully_connected'], + n_attention_heads=config['model']['n_attention_heads'], + n_encoder_layers=config['model']['n_encoder_layers'], + n_decoder_layers=config['model']['n_decoder_layers'], + dropout=config['model']['dropout'], + learning_rate=config['model']['learning_rate'], + loss=loss, + embedding_sizes={ + name: (len(encoder.classes_), config['model']['d_model']) + for name, encoder in dataset.categorical_encoders.items() + if name in dataset.categoricals + }, + logging_metrics=[MAE(), RMSE()] + ) + + raise ValueError("Unknown model") + + def load_model_from_wandb(run): model_name = run.config['model']['name'] model_path = f"{run.project}/model-{run.id}:best" model_artifact = wandb.Api().artifact(model_path) if model_name == 'TemporalFusionTransformer': - return TemporalFusionTransformer.load_from_checkpoint(model_artifact.file()) + return TemporalFusionTransformer.load_from_checkpoint( + model_artifact.file()) - raise ValueError("Invalid model name") \ No newline at end of file + raise ValueError("Invalid model name") + + +class Informer(BaseModelWithCovariates): + + def __init__( + self, + d_model=256, + d_fully_connected=512, + n_attention_heads=2, + n_encoder_layers=2, + n_decoder_layers=1, + dropout=0.1, + attention_type="prob", + activation="gelu", + factor=5, + mix_attention=False, + output_attention=False, + distil=True, + x_reals: List[str] = [], + x_categoricals: List[str] = [], + static_categoricals: List[str] = [], + static_reals: List[str] = [], + time_varying_reals_encoder: List[str] = [], + time_varying_reals_decoder: List[str] = [], + time_varying_categoricals_encoder: List[str] = [], + time_varying_categoricals_decoder: List[str] = [], + embedding_sizes: Dict[str, Tuple[int, int]] = {}, + embedding_paddings: List[str] = [], + embedding_labels: Dict[str, np.ndarray] = {}, + categorical_groups: Dict[str, List[str]] = {}, + output_size: Union[int, List[int]] = 1, + loss=None, + logging_metrics: nn.ModuleList = None, + **kwargs): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + **kwargs) + self.save_hyperparameters(ignore=['loss']) + self.attention_type = attention_type + + assert not static_reals + assert not static_categoricals + + self.cat_embeddings = MultiEmbedding( + embedding_sizes=embedding_sizes, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + + self.enc_real_embeddings = TokenEmbedding( + len(time_varying_reals_encoder), d_model) + self.enc_positional_embeddings = PositionalEmbedding(d_model) + + self.dec_real_embeddings = TokenEmbedding( + len(time_varying_reals_decoder), d_model) + self.dec_positional_embeddings = PositionalEmbedding(d_model) + + Attention = ProbSparseAttention \ + if attention_type == "prob" else FullAttention + + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + Attention(False, factor, attention_dropout=dropout, + output_attention=output_attention), + d_model, + n_attention_heads, + mix=False, + ), + d_model, + d_fully_connected, + dropout=dropout, + activation=activation, + ) + for _ in range(n_encoder_layers) + ], + [SelfAttentionDistil(d_model) for _ in range( + n_encoder_layers - 1)] if distil else None, + nn.LayerNorm(d_model), + ) + + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + Attention(True, factor, attention_dropout=dropout, + output_attention=False), + d_model, + n_attention_heads, + mix=mix_attention, + ), + AttentionLayer( + FullAttention( + False, + factor, + attention_dropout=dropout, + output_attention=False), + d_model, + n_attention_heads, + mix=False, + ), + d_model, + d_fully_connected, + dropout=dropout, + activation=activation, + ) + for _ in range(n_decoder_layers) + ], + nn.LayerNorm(d_model), + ) + + self.projection = nn.Linear(d_model, output_size) + + def forward( + self, + x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + decoder_length = x['decoder_lengths'].max() + + enc_out =\ + self.enc_real_embeddings(x['encoder_cont']) +\ + self.enc_positional_embeddings(x['encoder_cont']) +\ + functools.reduce(operator.add, [emb for emb in self.cat_embeddings( + x['encoder_cat']).values()]) + enc_out, attentions = self.encoder(enc_out) + + # Hacky solution to get only known reals, + # they are always stacked first. + # TODO: Make sure no unknown reals are passed to decoder. + dec_out =\ + self.dec_real_embeddings(x['decoder_cont'][..., :len( + self.hparams.time_varying_reals_decoder)]) +\ + self.dec_positional_embeddings(x['decoder_cont']) +\ + functools.reduce(operator.add, [emb for emb in self.cat_embeddings( + x['decoder_cat']).values()]) + dec_out = self.decoder(dec_out, enc_out) + + output = self.projection(dec_out) + output = output[:, -decoder_length:, :] + output = self.transform_output( + output, target_scale=x['target_scale']) + return self.to_network_output(prediction=output) + + @classmethod + def from_dataset( + cls, + dataset, + **kwargs + ): + new_kwargs = copy(kwargs) + new_kwargs.update(cls.deduce_default_output_parameters( + dataset, kwargs, QuantileLoss())) + + return super().from_dataset(dataset, **new_kwargs)