Add Informer model

Filip Stefaniuk 2024-09-10 20:32:05 +02:00
parent 03bd6ad9d8
commit 73c513c217
12 changed files with 1035 additions and 25 deletions

View File

@@ -0,0 +1,68 @@
future_window:
value: 5
past_window:
value: 48
batch_size:
value: 64
max_epochs:
value: 1
data:
value:
dataset: "btc-usdt-5m:latest"
sliding_window: 4
validation: 0.2
fields:
time_index: "time_index"
target: "close_price"
group_ids: ["group_id"]
dynamic_unknown_real:
- "high_price"
- "low_price"
- "open_price"
- "close_price"
- "volume"
- "open_to_close_price"
- "high_to_close_price"
- "low_to_close_price"
- "high_to_low_price"
- "returns"
- "log_returns"
- "vol_1h"
- "macd"
- "macd_signal"
- "rsi"
- "low_bband_to_close_price"
- "up_bband_to_close_price"
- "mid_bband_to_close_price"
- "sma_1h_to_close_price"
- "sma_1d_to_close_price"
- "sma_7d_to_close_price"
- "ema_1h_to_close_price"
- "ema_1d_to_close_price"
dynamic_unknown_cat: []
dynamic_known_real:
- "effective_rates"
- "vix_close_price"
- "fear_greed_index"
- "vol_1d"
- "vol_7d"
dynamic_known_cat:
- "hour"
- "weekday"
static_real: []
static_cat: []
loss:
value:
name: "Quantile"
quantiles: [0.02, 0.1, 0.5, 0.9, 0.98]
model:
value:
name: "Informer"
d_model: 256
d_fully_connected: 512
n_attention_heads: 2
dropout: 0.1
n_encoder_layers: 2
n_decoder_layers: 1
learning_rate: 0.001
optimizer: "Adam"
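The keys above follow the W&B config convention, where every entry is wrapped in a `value` field. Below is a minimal sketch (not part of the commit) of how such a file could be flattened before being handed to `get_model` in `src/ml/model.py`; the file path and the flattening helper are hypothetical.

```python
# Hypothetical loader for the W&B-style config above; the path is illustrative.
import yaml

with open("configs/informer.yaml") as f:  # assumed location, not from the commit
    raw = yaml.safe_load(f)

# Strip the "value" wrappers so nested keys can be accessed directly.
config = {key: entry["value"] for key, entry in raw.items()}

assert config["model"]["name"] == "Informer"
assert config["loss"]["quantiles"] == [0.02, 0.1, 0.5, 0.9, 0.98]
# get_model(config, dataset, loss) in src/ml/model.py then reads
# config['model']['d_model'], 'n_attention_heads', etc.
```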

View File

@@ -4,6 +4,7 @@ build-backend = "hatchling.build"
 [tool.hatch.build.targets.wheel]
 packages = [
+    "src/informer",
     "src/ml",
     "src/strategy"
 ]

View File

@@ -11,12 +11,10 @@ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
 from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
-from pytorch_forecasting.metrics import MAE, RMSE
 from pytorch_forecasting import QuantileLoss
-from pytorch_forecasting.models.temporal_fusion_transformer import (
-    TemporalFusionTransformer)
 from ml.loss import GMADL
+from ml.model import get_model


 def get_args():
@@ -146,25 +144,6 @@ def get_loss(config):
     raise ValueError("Unknown loss")


-def get_model(config, dataset, loss):
-    model_name = config['model']['name']
-    if model_name == 'TemporalFusionTransformer':
-        return TemporalFusionTransformer.from_dataset(
-            dataset,
-            hidden_size=config['model']['hidden_size'],
-            dropout=config['model']['dropout'],
-            attention_head_size=config['model']['attention_head_size'],
-            hidden_continuous_size=config['model']['hidden_continuous_size'],
-            learning_rate=config['model']['learning_rate'],
-            share_single_variable_networks=False,
-            loss=loss,
-            logging_metrics=[MAE(), RMSE()]
-        )
-    raise ValueError("Unknown model")
-
-
 def main():
     args = get_args()
     logging.basicConfig(level=args.log_level)

src/informer/README.md Normal file
View File

@@ -0,0 +1 @@
Copied from https://github.com/martinwhl/Informer-PyTorch-Lightning

src/informer/__init__.py Normal file
View File

src/informer/attention.py Normal file
View File

@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import numpy as np
import math
from informer.masking import triangular_causal_mask, prob_mask
class FullAttention(nn.Module):
def __init__(
self,
mask_flag=True,
scale=None,
attention_dropout=0.1,
output_attention=False, **kwargs):
super(FullAttention, self).__init__()
self.mask_flag = mask_flag
self.scale = scale
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attention_mask):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / math.sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
if attention_mask is None:
attention_mask = triangular_causal_mask(
B, L, device=queries.device)
scores.masked_fill_(attention_mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return V.contiguous(), A
return V.contiguous(), None
class ProbSparseAttention(nn.Module):
def __init__(
self,
mask_flag=True,
factor=5,
scale=None,
attention_dropout=0.1,
output_attention=False,
):
super(ProbSparseAttention, self).__init__()
self.mask_flag = mask_flag
self.factor = factor
self.scale = scale
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attention_mask):
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = torch.transpose(queries, 2, 1)
keys = torch.transpose(keys, 2, 1)
values = torch.transpose(values, 2, 1)
U_part = int(self.factor * math.ceil(math.log(L_K))) # c * ln(L_K)
u = int(self.factor * math.ceil(math.log(L_Q))) # c * ln(L_Q)
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
scores_top, index = self._prob_QK(
queries, keys, sample_k=U_part, n_top=u)
scale = self.scale or 1.0 / math.sqrt(D)
if scale is not None:
scores_top = scores_top * scale
context = self._get_initial_context(values, L_Q)
# update the context with selected top_k queries
context, attention = self._update_context(
context, values, scores_top, index, L_Q, attention_mask)
return context.transpose(2, 1).contiguous(), attention
def _prob_QK(self, queries, keys, sample_k, n_top):
B, H, L_K, E = keys.shape
_, _, L_Q, _ = queries.shape
# calculate the sampled Q_K
K_expand = keys.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# real U = U_part(factor * ln(L_K)) * L_Q
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(
L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = (queries.unsqueeze(-2) @
K_sample.transpose(-2, -1)).squeeze()
# find the top_k query with sparsity measurement
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = queries[torch.arange(B)[:, None, None], torch.arange(H)[
None, :, None], M_top, :] # factor * ln(L_Q)
Q_K = Q_reduce @ keys.transpose(-2, -1) # factor * ln(L_Q) * L_K
return Q_K, M_top
def _get_initial_context(self, values, L_Q):
B, H, L_V, D = values.shape
if not self.mask_flag:
V_mean = values.mean(dim=-2)
context = \
V_mean.unsqueeze(-2).expand(B,
H, L_Q, V_mean.size(-1)).clone()
else:
# requires that L_Q == L_V, i.e. for self-attention only
assert L_Q == L_V
context = values.cumsum(dim=-2)
return context
def _update_context(
self,
context,
values,
scores,
index,
L_Q,
attention_mask):
B, H, L_V, D = values.shape
if self.mask_flag:
attention_mask = prob_mask(
B, H, L_Q, index, scores, device=values.device)
scores.masked_fill_(attention_mask, -np.inf)
attention = torch.softmax(scores, dim=-1)
context[
torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :] = (
attention @ values
).type_as(context)
if self.output_attention:
attentions = (torch.ones(B, H, L_V, L_V) / L_V).type_as(attention)
attentions[torch.arange(B)[:, None, None], torch.arange(H)[
None, :, None], index, :] = attention
return context, attentions
return context, None
class AttentionLayer(nn.Module):
def __init__(
self,
attention,
d_model,
n_heads,
d_keys=None,
d_values=None,
mix=False):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
self.mix = mix
def forward(self, queries, keys, values, attention_mask):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attention = self.inner_attention(
queries, keys, values, attention_mask)
if self.mix:
out = out.transpose(2, 1).contiguous()
out = out.view(B, L, -1)
return self.out_projection(out), attention
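A minimal usage sketch (not part of the commit) exercising `AttentionLayer` with `ProbSparseAttention` the way the encoder does (no causal mask); batch size, sequence length and model width are arbitrary.

```python
import torch
from informer.attention import ProbSparseAttention, AttentionLayer

B, L, d_model, n_heads = 4, 96, 256, 2
x = torch.randn(B, L, d_model)

# Encoder-style self-attention: ProbSparse scoring, no causal masking.
layer = AttentionLayer(
    ProbSparseAttention(mask_flag=False, factor=5, attention_dropout=0.1),
    d_model,
    n_heads,
)
out, attn = layer(x, x, x, attention_mask=None)
print(out.shape)  # torch.Size([4, 96, 256]); attn is None unless output_attention=True
```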

src/informer/decoder.py Normal file
View File

@@ -0,0 +1,53 @@
import torch.nn as nn
import torch.nn.functional as F
class DecoderLayer(nn.Module):
def __init__(
self,
self_attention,
cross_attention,
d_model,
d_ff=None,
dropout=0.1,
activation="relu",
):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None):
x = x + self.dropout(self.self_attention(x, x, x,
attention_mask=x_mask)[0])
x = self.norm1(x)
x = x + self.dropout(self.cross_attention(x, cross,
cross, attention_mask=cross_mask)[0])
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)
class Decoder(nn.Module):
def __init__(self, layers, norm_layer=None):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
def forward(self, x, cross, x_mask=None, cross_mask=None):
for layer in self.layers:
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
if self.norm is not None:
x = self.norm(x)
return x
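A minimal sketch (not part of the commit) wiring one `DecoderLayer`: masked ProbSparse self-attention over the decoder tokens followed by full cross-attention over the encoder output; all sizes are arbitrary.

```python
import torch
import torch.nn as nn
from informer.attention import FullAttention, ProbSparseAttention, AttentionLayer
from informer.decoder import Decoder, DecoderLayer

d_model, n_heads = 256, 2
layer = DecoderLayer(
    AttentionLayer(ProbSparseAttention(mask_flag=True), d_model, n_heads),  # masked self-attention
    AttentionLayer(FullAttention(mask_flag=False), d_model, n_heads),       # cross-attention
    d_model,
)
decoder = Decoder([layer], norm_layer=nn.LayerNorm(d_model))

x = torch.randn(4, 72, d_model)      # decoder tokens
cross = torch.randn(4, 48, d_model)  # encoder output
print(decoder(x, cross).shape)       # torch.Size([4, 72, 256])
```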

src/informer/embedding.py Normal file
View File

@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import math
class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_length=5000):
super(PositionalEmbedding, self).__init__()
embedding = torch.zeros(max_length, d_model)
position = torch.arange(0, max_length).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
embedding[:, 0::2] = torch.sin(position * div_term)
embedding[:, 1::2] = torch.cos(position * div_term)
embedding = embedding.unsqueeze(0)
self.register_buffer("embedding", embedding)
def forward(self, x):
return self.embedding[:, : x.size(1)]
class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(TokenEmbedding, self).__init__()
padding = 1 if torch.__version__ >= "1.5.0" else 2
self.token_conv = nn.Conv1d(
in_channels=c_in,
out_channels=d_model,
kernel_size=3,
padding=padding,
padding_mode="circular",
)
nn.init.kaiming_normal_(self.token_conv.weight,
mode="fan_in", nonlinearity="leaky_relu")
def forward(self, x):
return self.token_conv(x.permute(0, 2, 1)).transpose(1, 2)
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
weight = torch.zeros(c_in, d_model)
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
weight[:, 0::2] = torch.sin(position * div_term)
weight[:, 1::2] = torch.cos(position * div_term)
self.embedding = nn.Embedding(c_in, d_model)
self.embedding.weight = nn.Parameter(weight, requires_grad=False)
def forward(self, x):
return self.embedding(x).detach()
class TemporalEmbedding(nn.Module):
def __init__(self, d_model, embedding_type="fixed", frequency="h"):
super(TemporalEmbedding, self).__init__()
MINUTE_SIZE = 4
HOUR_SIZE = 24
WEEKDAY_SIZE = 7
DAY_SIZE = 32
MONTH_SIZE = 13
Embedding = FixedEmbedding \
if embedding_type == "fixed" else nn.Embedding
if frequency == "t":
self.minute_embedding = Embedding(MINUTE_SIZE, d_model)
self.hour_embedding = Embedding(HOUR_SIZE, d_model)
self.weekday_embedding = Embedding(WEEKDAY_SIZE, d_model)
self.day_embedding = Embedding(DAY_SIZE, d_model)
self.month_embedding = Embedding(MONTH_SIZE, d_model)
def forward(self, x):
x = x.long()
minute_x = self.minute_embedding(x[:, :, 4]) if hasattr(
self, "minute_embedding") else 0.0
hour_x = self.hour_embedding(x[:, :, 3])
weekday_x = self.weekday_embedding(x[:, :, 2])
day_x = self.day_embedding(x[:, :, 1])
month_x = self.month_embedding(x[:, :, 0])
return minute_x + hour_x + weekday_x + day_x + month_x
class TimeFeatureEmbedding(nn.Module):
def __init__(self, d_model, frequency="h"):
super(TimeFeatureEmbedding, self).__init__()
FREQUENCY_MAP = {"h": 4, "t": 5, "s": 6,
"m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
d_input = FREQUENCY_MAP[frequency]
self.embedding = nn.Linear(d_input, d_model)
def forward(self, x):
return self.embedding(x)
class DataEmbedding(nn.Module):
def __init__(
self,
c_in,
d_model,
embedding_type="fixed",
frequency="h",
dropout=0.1
):
super(DataEmbedding, self).__init__()
self.value_embedding = TokenEmbedding(c_in, d_model)
self.position_embedding = PositionalEmbedding(d_model)
if embedding_type != "timefeature":
self.temporal_embedding = TemporalEmbedding(
d_model, embedding_type=embedding_type, frequency=frequency)
else:
self.temporal_embedding = TimeFeatureEmbedding(
d_model, frequency=frequency)
self.dropout = nn.Dropout(dropout)
def forward(self, x, x_mark):
x = self.value_embedding(
x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
return self.dropout(x)
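A minimal sketch (not part of the commit) showing the tensor layout `DataEmbedding` expects with the default fixed/hourly temporal embedding: continuous inputs plus an integer time-mark tensor whose columns are (month, day, weekday, hour), matching the indexing in `TemporalEmbedding.forward`.

```python
import torch
from informer.embedding import DataEmbedding

B, L, c_in, d_model = 4, 96, 7, 256
x = torch.randn(B, L, c_in)
x_mark = torch.stack(
    [
        torch.randint(1, 13, (B, L)),  # month   (MONTH_SIZE = 13)
        torch.randint(1, 32, (B, L)),  # day     (DAY_SIZE = 32)
        torch.randint(0, 7, (B, L)),   # weekday (WEEKDAY_SIZE = 7)
        torch.randint(0, 24, (B, L)),  # hour    (HOUR_SIZE = 24)
    ],
    dim=-1,
)

emb = DataEmbedding(c_in, d_model, embedding_type="fixed", frequency="h")
print(emb(x, x_mark).shape)  # torch.Size([4, 96, 256])
```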

src/informer/encoder.py Normal file
View File

@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttentionDistil(nn.Module):
def __init__(self, c_in):
super(SelfAttentionDistil, self).__init__()
self.conv = nn.Conv1d(c_in, c_in, kernel_size=3,
padding=2, padding_mode="circular")
self.norm = nn.BatchNorm1d(c_in)
self.activation = nn.ELU()
self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
x = self.conv(x.permute(0, 2, 1))
x = self.norm(x)
x = self.activation(x)
x = self.max_pool(x)
x = torch.transpose(x, 1, 2)
return x
class EncoderLayer(nn.Module):
def __init__(
self,
attention,
d_model,
d_ff=None,
dropout=0.1,
activation="relu"):
super(EncoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.attention = attention
self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, attention_mask=None):
new_x, attention = self.attention(
x, x, x, attention_mask=attention_mask)
x = x + self.dropout(new_x)
y = x = self.norm1(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm2(x + y), attention
class Encoder(nn.Module):
def __init__(self, attention_layers, conv_layers=None, norm_layer=None):
super(Encoder, self).__init__()
self.attention_layers = nn.ModuleList(attention_layers)
self.conv_layers = nn.ModuleList(
conv_layers) if conv_layers is not None else None
self.norm = norm_layer
def forward(self, x, attention_mask=None):
attentions = []
if self.conv_layers is not None:
for attention_layer, conv_layer in zip(
self.attention_layers, self.conv_layers):
x, attention = attention_layer(
x, attention_mask=attention_mask)
x = conv_layer(x)
attentions.append(attention)
x, attention = self.attention_layers[-1](x)
attentions.append(attention)
else:
for attention_layer in self.attention_layers:
x, attention = attention_layer(
x, attention_mask=attention_mask)
attentions.append(attention)
if self.norm is not None:
x = self.norm(x)
return x, attentions
class EncoderStack(nn.Module):
def __init__(self, encoders):
super(EncoderStack, self).__init__()
self.encoders = nn.ModuleList(encoders)
def forward(self, x, attention_mask=None):
inp_len = x.size(1)
x_stack = []
attentions = []
for encoder in self.encoders:
if encoder is None:
inp_len //= 2
continue
x, attention = encoder(x[:, -inp_len:, :])
x_stack.append(x)
attentions.append(attention)
inp_len //= 2
x_stack = torch.cat(x_stack, -2)
return x_stack, attentions
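A small sketch (not part of the commit) of the distilling step: each `SelfAttentionDistil` roughly halves the sequence length via the stride-2 max pool, which is what keeps memory in check as encoder layers are stacked.

```python
import torch
from informer.encoder import SelfAttentionDistil

x = torch.randn(4, 96, 256)  # (batch, length, d_model)
distil = SelfAttentionDistil(256)
print(distil(x).shape)       # torch.Size([4, 49, 256]): kernel-3 conv with padding 2, then stride-2 pooling
```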

src/informer/masking.py Normal file
View File

@@ -0,0 +1,19 @@
import torch
def triangular_causal_mask(B, L, device=torch.device("cpu")):
mask_shape = [B, 1, L, L]
with torch.no_grad():
mask = torch.triu(torch.ones(
mask_shape, dtype=torch.bool, device=device), diagonal=1)
return mask
def prob_mask(B, H, L, index, scores, device=torch.device("cpu")):
mask = torch.ones(L, scores.shape[-1],
dtype=torch.bool, device=device).triu(1)
mask_ex = mask[None, None, :].expand(B, H, L, scores.shape[-1])
indicator = mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[
None, :, None], index, :]
mask = indicator.view(scores.shape)
return mask
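A tiny sketch (not part of the commit): `triangular_causal_mask` marks the positions a query must not attend to (strictly upper triangular), which `FullAttention` fills with `-inf` before the softmax.

```python
import torch
from informer.masking import triangular_causal_mask

mask = triangular_causal_mask(1, 4)
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
```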

src/informer/model.py Normal file
View File

@@ -0,0 +1,246 @@
import torch.nn as nn
from informer.attention import (
FullAttention, ProbSparseAttention, AttentionLayer)
from informer.embedding import DataEmbedding
from informer.encoder import (
Encoder,
EncoderLayer,
EncoderStack,
SelfAttentionDistil,
)
from informer.decoder import Decoder, DecoderLayer
class BaseInformer(nn.Module):
def __init__(
self,
enc_in=7,
dec_in=7,
c_out=7,
out_len=24,
factor=5,
d_model=512,
n_heads=8,
num_encoder_layers=2,
num_decoder_layers=1,
d_ff=2048,
dropout=0.05,
attention_type="prob",
embedding_type="fixed",
frequency="h",
activation="gelu",
output_attention=False,
distil=True,
mix_attention=False,
**kwargs
):
super(BaseInformer, self).__init__()
self.pred_len = out_len
self.attention_type = attention_type
self.output_attention = output_attention
self.enc_embedding = DataEmbedding(
enc_in, d_model, embedding_type, frequency, dropout)
self.dec_embedding = DataEmbedding(
dec_in, d_model, embedding_type, frequency, dropout)
Attention = ProbSparseAttention \
if attention_type == "prob" else FullAttention
self.encoder = None
self.decoder = Decoder(
[
DecoderLayer(
AttentionLayer(
Attention(True, factor, attention_dropout=dropout,
output_attention=False),
d_model,
n_heads,
mix=mix_attention,
),
AttentionLayer(
FullAttention(
False,
factor,
attention_dropout=dropout,
output_attention=False),
d_model,
n_heads,
mix=False,
),
d_model,
d_ff,
dropout=dropout,
activation=activation,
)
for _ in range(num_decoder_layers)
],
nn.LayerNorm(d_model),
)
self.projection = nn.Linear(d_model, c_out)
def forward(
self,
x_enc,
x_enc_mark,
x_dec,
x_dec_mark,
enc_self_mask=None,
dec_self_mask=None,
dec_enc_mask=None,
):
enc_out = self.enc_embedding(x_enc, x_enc_mark)
enc_out, attentions = self.encoder(
enc_out, attention_mask=enc_self_mask)
dec_out = self.dec_embedding(x_dec, x_dec_mark)
dec_out = self.decoder(
dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
dec_out = self.projection(dec_out)
if self.output_attention:
return dec_out[:, -self.pred_len:, :], attentions
return dec_out[:, -self.pred_len:, :]
class Informer(BaseInformer):
def __init__(
self,
enc_in=7,
dec_in=7,
c_out=7,
out_len=24,
factor=5,
d_model=512,
n_heads=8,
num_encoder_layers=2,
num_decoder_layers=1,
d_ff=2048,
dropout=0.05,
attention_type="prob",
embedding_type="fixed",
frequency="h",
activation="gelu",
output_attention=False,
distil=True,
mix_attention=False,
**kwargs
):
super(Informer, self).__init__(
enc_in,
dec_in,
c_out,
out_len,
factor=factor,
d_model=d_model,
n_heads=n_heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ff=d_ff,
dropout=dropout,
attention_type=attention_type,
embedding_type=embedding_type,
frequency=frequency,
activation=activation,
output_attention=output_attention,
distil=distil,
mix_attention=mix_attention,
)
Attention = ProbSparseAttention \
if attention_type == "prob" else FullAttention
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
Attention(False, factor, attention_dropout=dropout,
output_attention=output_attention),
d_model,
n_heads,
mix=False,
),
d_model,
d_ff,
dropout=dropout,
activation=activation,
)
for _ in range(num_encoder_layers)
],
[SelfAttentionDistil(d_model) for _ in range(
num_encoder_layers - 1)] if distil else None,
nn.LayerNorm(d_model),
)
class InformerStack(BaseInformer):
def __init__(
self,
enc_in=7,
dec_in=7,
c_out=7,
out_len=24,
factor=5,
d_model=512,
n_heads=8,
num_encoder_layers=2,
num_decoder_layers=1,
d_ff=2048,
dropout=0.05,
attention_type="prob",
embedding_type="fixed",
frequency="h",
activation="gelu",
output_attention=False,
distil=True,
mix_attention=False,
**kwargs
):
super(InformerStack, self).__init__(
enc_in,
dec_in,
c_out,
out_len,
factor=factor,
d_model=d_model,
n_heads=n_heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ff=d_ff,
dropout=dropout,
attention_type=attention_type,
embedding_type=embedding_type,
frequency=frequency,
activation=activation,
output_attention=output_attention,
distil=distil,
mix_attention=mix_attention,
)
Attention = ProbSparseAttention \
if attention_type == "prob" else FullAttention
stacks = list(range(num_encoder_layers, 2, -1)) # customize here
encoders = [
Encoder(
[
EncoderLayer(
AttentionLayer(
Attention(False, factor, attention_dropout=dropout,
output_attention=output_attention),
d_model,
n_heads,
mix=False,
),
d_model,
d_ff,
dropout=dropout,
activation=activation,
)
for _ in range(el)
],
[SelfAttentionDistil(d_model)
for _ in range(el - 1)] if distil else None,
nn.LayerNorm(d_model),
)
for el in stacks
]
self.encoder = EncoderStack(encoders)
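A minimal sketch (not part of the commit) of the stand-alone `Informer` defined above: a 96-step encoder input, a 48-step decoder start token plus 24 steps to predict, and time-mark tensors carrying (month, day, weekday, hour) integers for the default fixed/hourly embedding; all sizes are arbitrary.

```python
import torch
from informer.model import Informer

B, seq_len, label_len, out_len, n_features = 2, 96, 48, 24, 7

def time_marks(length):
    # (month, day, weekday, hour) integer features, see informer.embedding.TemporalEmbedding
    return torch.stack(
        [
            torch.randint(1, 13, (B, length)),
            torch.randint(1, 32, (B, length)),
            torch.randint(0, 7, (B, length)),
            torch.randint(0, 24, (B, length)),
        ],
        dim=-1,
    )

model = Informer(enc_in=n_features, dec_in=n_features, c_out=n_features, out_len=out_len)
x_enc = torch.randn(B, seq_len, n_features)
x_dec = torch.randn(B, label_len + out_len, n_features)
pred = model(x_enc, time_marks(seq_len), x_dec, time_marks(label_len + out_len))
print(pred.shape)  # torch.Size([2, 24, 7])
```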

View File

@@ -1,13 +1,233 @@
from typing import Dict, List, Union, Tuple
import torch
import torch.nn as nn
import numpy as np
import wandb
import functools
import operator
from copy import copy
from pytorch_forecasting.models.nn import MultiEmbedding
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.metrics import MAE, RMSE, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import (
TemporalFusionTransformer)
from informer.attention import (
FullAttention, ProbSparseAttention, AttentionLayer)
from informer.embedding import TokenEmbedding, PositionalEmbedding
from informer.encoder import (
Encoder,
EncoderLayer,
SelfAttentionDistil,
)
from informer.decoder import Decoder, DecoderLayer
# TODO: Maybe save all models on cpu
def get_model(config, dataset, loss):
model_name = config['model']['name']
if model_name == 'TemporalFusionTransformer':
return TemporalFusionTransformer.from_dataset(
dataset,
hidden_size=config['model']['hidden_size'],
dropout=config['model']['dropout'],
attention_head_size=config['model']['attention_head_size'],
hidden_continuous_size=config['model']['hidden_continuous_size'],
learning_rate=config['model']['learning_rate'],
share_single_variable_networks=False,
loss=loss,
logging_metrics=[MAE(), RMSE()]
)
if model_name == 'Informer':
return Informer.from_dataset(
dataset,
d_model=config['model']['d_model'],
d_fully_connected=config['model']['d_fully_connected'],
n_attention_heads=config['model']['n_attention_heads'],
n_encoder_layers=config['model']['n_encoder_layers'],
n_decoder_layers=config['model']['n_decoder_layers'],
dropout=config['model']['dropout'],
learning_rate=config['model']['learning_rate'],
loss=loss,
embedding_sizes={
name: (len(encoder.classes_), config['model']['d_model'])
for name, encoder in dataset.categorical_encoders.items()
if name in dataset.categoricals
},
logging_metrics=[MAE(), RMSE()]
)
raise ValueError("Unknown model")
def load_model_from_wandb(run):
    model_name = run.config['model']['name']
    model_path = f"{run.project}/model-{run.id}:best"
    model_artifact = wandb.Api().artifact(model_path)
    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.load_from_checkpoint(
            model_artifact.file())
    raise ValueError("Invalid model name")
class Informer(BaseModelWithCovariates):
def __init__(
self,
d_model=256,
d_fully_connected=512,
n_attention_heads=2,
n_encoder_layers=2,
n_decoder_layers=1,
dropout=0.1,
attention_type="prob",
activation="gelu",
factor=5,
mix_attention=False,
output_attention=False,
distil=True,
x_reals: List[str] = [],
x_categoricals: List[str] = [],
static_categoricals: List[str] = [],
static_reals: List[str] = [],
time_varying_reals_encoder: List[str] = [],
time_varying_reals_decoder: List[str] = [],
time_varying_categoricals_encoder: List[str] = [],
time_varying_categoricals_decoder: List[str] = [],
embedding_sizes: Dict[str, Tuple[int, int]] = {},
embedding_paddings: List[str] = [],
embedding_labels: Dict[str, np.ndarray] = {},
categorical_groups: Dict[str, List[str]] = {},
output_size: Union[int, List[int]] = 1,
loss=None,
logging_metrics: nn.ModuleList = None,
**kwargs):
super().__init__(
loss=loss,
logging_metrics=logging_metrics,
**kwargs)
self.save_hyperparameters(ignore=['loss'])
self.attention_type = attention_type
assert not static_reals
assert not static_categoricals
self.cat_embeddings = MultiEmbedding(
embedding_sizes=embedding_sizes,
embedding_paddings=embedding_paddings,
categorical_groups=categorical_groups,
x_categoricals=x_categoricals,
)
self.enc_real_embeddings = TokenEmbedding(
len(time_varying_reals_encoder), d_model)
self.enc_positional_embeddings = PositionalEmbedding(d_model)
self.dec_real_embeddings = TokenEmbedding(
len(time_varying_reals_decoder), d_model)
self.dec_positional_embeddings = PositionalEmbedding(d_model)
Attention = ProbSparseAttention \
if attention_type == "prob" else FullAttention
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
Attention(False, factor, attention_dropout=dropout,
output_attention=output_attention),
d_model,
n_attention_heads,
mix=False,
),
d_model,
d_fully_connected,
dropout=dropout,
activation=activation,
)
for _ in range(n_encoder_layers)
],
[SelfAttentionDistil(d_model) for _ in range(
n_encoder_layers - 1)] if distil else None,
nn.LayerNorm(d_model),
)
self.decoder = Decoder(
[
DecoderLayer(
AttentionLayer(
Attention(True, factor, attention_dropout=dropout,
output_attention=False),
d_model,
n_attention_heads,
mix=mix_attention,
),
AttentionLayer(
FullAttention(
False,
factor,
attention_dropout=dropout,
output_attention=False),
d_model,
n_attention_heads,
mix=False,
),
d_model,
d_fully_connected,
dropout=dropout,
activation=activation,
)
for _ in range(n_decoder_layers)
],
nn.LayerNorm(d_model),
)
self.projection = nn.Linear(d_model, output_size)
def forward(
self,
x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
decoder_length = x['decoder_lengths'].max()
enc_out =\
self.enc_real_embeddings(x['encoder_cont']) +\
self.enc_positional_embeddings(x['encoder_cont']) +\
functools.reduce(operator.add, [emb for emb in self.cat_embeddings(
x['encoder_cat']).values()])
enc_out, attentions = self.encoder(enc_out)
# Hacky solution to get only known reals,
# they are always stacked first.
# TODO: Make sure no unknown reals are passed to decoder.
dec_out =\
self.dec_real_embeddings(x['decoder_cont'][..., :len(
self.hparams.time_varying_reals_decoder)]) +\
self.dec_positional_embeddings(x['decoder_cont']) +\
functools.reduce(operator.add, [emb for emb in self.cat_embeddings(
x['decoder_cat']).values()])
dec_out = self.decoder(dec_out, enc_out)
output = self.projection(dec_out)
output = output[:, -decoder_length:, :]
output = self.transform_output(
output, target_scale=x['target_scale'])
return self.to_network_output(prediction=output)
@classmethod
def from_dataset(
cls,
dataset,
**kwargs
):
new_kwargs = copy(kwargs)
new_kwargs.update(cls.deduce_default_output_parameters(
dataset, kwargs, QuantileLoss()))
return super().from_dataset(dataset, **new_kwargs)