Add informer model

parent 03bd6ad9d8
commit 73c513c217

configs/experiments/informer-btcusdt-quantile.yaml (new file, 68 lines)
@ -0,0 +1,68 @@
future_window:
  value: 5
past_window:
  value: 48
batch_size:
  value: 64
max_epochs:
  value: 1
data:
  value:
    dataset: "btc-usdt-5m:latest"
    sliding_window: 4
    validation: 0.2
    fields:
      time_index: "time_index"
      target: "close_price"
      group_ids: ["group_id"]
      dynamic_unknown_real:
        - "high_price"
        - "low_price"
        - "open_price"
        - "close_price"
        - "volume"
        - "open_to_close_price"
        - "high_to_close_price"
        - "low_to_close_price"
        - "high_to_low_price"
        - "returns"
        - "log_returns"
        - "vol_1h"
        - "macd"
        - "macd_signal"
        - "rsi"
        - "low_bband_to_close_price"
        - "up_bband_to_close_price"
        - "mid_bband_to_close_price"
        - "sma_1h_to_close_price"
        - "sma_1d_to_close_price"
        - "sma_7d_to_close_price"
        - "ema_1h_to_close_price"
        - "ema_1d_to_close_price"
      dynamic_unknown_cat: []
      dynamic_known_real:
        - "effective_rates"
        - "vix_close_price"
        - "fear_greed_index"
        - "vol_1d"
        - "vol_7d"
      dynamic_known_cat:
        - "hour"
        - "weekday"
      static_real: []
      static_cat: []
loss:
  value:
    name: "Quantile"
    quantiles: [0.02, 0.1, 0.5, 0.9, 0.98]
model:
  value:
    name: "Informer"
    d_model: 256
    d_fully_connected: 512
    n_attention_heads: 2
    dropout: 0.1
    n_encoder_layers: 2
    n_decoder_layers: 1
    learning_rate: 0.001
    optimizer: "Adam"
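
This is a wandb-style config in which every top-level key wraps its setting in a value field. A minimal sketch of how such a file could be flattened before being handed to get_model, assuming PyYAML is available (hypothetical, not part of the commit):

# Hypothetical sketch: flatten the wandb-style config (each key wraps its
# setting in a "value" entry) into a plain nested dict.
import yaml

with open("configs/experiments/informer-btcusdt-quantile.yaml") as f:
    raw = yaml.safe_load(f)

config = {key: entry["value"] for key, entry in raw.items()}
print(config["model"]["name"])              # "Informer"
print(config["data"]["fields"]["target"])   # "close_price"
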
pyproject.toml
@ -4,6 +4,7 @@ build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = [
    "src/informer",
    "src/ml",
    "src/strategy"
]

@ -11,12 +11,10 @@ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, RMSE
from pytorch_forecasting import QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import (
    TemporalFusionTransformer)

from ml.loss import GMADL
from ml.model import get_model


def get_args():
@ -146,25 +144,6 @@ def get_loss(config):
    raise ValueError("Unknown loss")


def get_model(config, dataset, loss):
    model_name = config['model']['name']

    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.from_dataset(
            dataset,
            hidden_size=config['model']['hidden_size'],
            dropout=config['model']['dropout'],
            attention_head_size=config['model']['attention_head_size'],
            hidden_continuous_size=config['model']['hidden_continuous_size'],
            learning_rate=config['model']['learning_rate'],
            share_single_variable_networks=False,
            loss=loss,
            logging_metrics=[MAE(), RMSE()]
        )

    raise ValueError("Unknown model")


def main():
    args = get_args()
    logging.basicConfig(level=args.log_level)

src/informer/README.md (new file, 1 line)
@ -0,0 +1 @@
Copied from https://github.com/martinwhl/Informer-PyTorch-Lightning

src/informer/__init__.py (new file, empty)

src/informer/attention.py (new file, 190 lines)
@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import numpy as np
import math
from informer.masking import triangular_causal_mask, prob_mask


class FullAttention(nn.Module):
    def __init__(
            self,
            mask_flag=True,
            scale=None,
            attention_dropout=0.1,
            output_attention=False, **kwargs):
        super(FullAttention, self).__init__()
        self.mask_flag = mask_flag
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attention_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / math.sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attention_mask is None:
                attention_mask = triangular_causal_mask(
                    B, L, device=queries.device)
            scores.masked_fill_(attention_mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        return V.contiguous(), None


class ProbSparseAttention(nn.Module):
    def __init__(
            self,
            mask_flag=True,
            factor=5,
            scale=None,
            attention_dropout=0.1,
            output_attention=False,
    ):
        super(ProbSparseAttention, self).__init__()
        self.mask_flag = mask_flag
        self.factor = factor
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attention_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = torch.transpose(queries, 2, 1)
        keys = torch.transpose(keys, 2, 1)
        values = torch.transpose(values, 2, 1)

        U_part = int(self.factor * math.ceil(math.log(L_K)))  # c * ln(L_K)
        u = int(self.factor * math.ceil(math.log(L_Q)))  # c * ln(L_Q)

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(
            queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1.0 / math.sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale

        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attention = self._update_context(
            context, values, scores_top, index, L_Q, attention_mask)

        return context.transpose(2, 1).contiguous(), attention

    def _prob_QK(self, queries, keys, sample_k, n_top):
        B, H, L_K, E = keys.shape
        _, _, L_Q, _ = queries.shape

        # calculate the sampled Q_K
        K_expand = keys.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        # real U = U_part(factor * ln(L_K)) * L_Q
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(
            L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = (queries.unsqueeze(-2) @
                      K_sample.transpose(-2, -1)).squeeze()

        # find the top_k query with sparsity measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = queries[torch.arange(B)[:, None, None], torch.arange(H)[
            None, :, None], M_top, :]  # factor * ln(L_Q)
        Q_K = Q_reduce @ keys.transpose(-2, -1)  # factor * ln(L_Q) * L_K

        return Q_K, M_top

    def _get_initial_context(self, values, L_Q):
        B, H, L_V, D = values.shape
        if not self.mask_flag:
            V_mean = values.mean(dim=-2)
            context = \
                V_mean.unsqueeze(-2).expand(B,
                                            H, L_Q, V_mean.size(-1)).clone()
        else:
            # requires that L_Q == L_V, i.e. for self-attention only
            assert L_Q == L_V
            context = values.cumsum(dim=-2)
        return context

    def _update_context(
            self,
            context,
            values,
            scores,
            index,
            L_Q,
            attention_mask):
        B, H, L_V, D = values.shape

        if self.mask_flag:
            attention_mask = prob_mask(
                B, H, L_Q, index, scores, device=values.device)
            scores.masked_fill_(attention_mask, -np.inf)

        attention = torch.softmax(scores, dim=-1)

        context[
            torch.arange(B)[:, None, None],
            torch.arange(H)[None, :, None],
            index, :] = (
                attention @ values
        ).type_as(context)
        if self.output_attention:
            attentions = (torch.ones(B, H, L_V, L_V) / L_V).type_as(attention)
            attentions[torch.arange(B)[:, None, None], torch.arange(H)[
                None, :, None], index, :] = attention
            return context, attentions
        return context, None


class AttentionLayer(nn.Module):
    def __init__(
            self,
            attention,
            d_model,
            n_heads,
            d_keys=None,
            d_values=None,
            mix=False):
        super(AttentionLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_attention = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)

        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values, attention_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_attention(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attention = self.inner_attention(
            queries, keys, values, attention_mask)
        if self.mix:
            out = out.transpose(2, 1).contiguous()
        out = out.view(B, L, -1)

        return self.out_projection(out), attention
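
A quick shape check for the layer above: AttentionLayer projects (B, L, d_model) queries, keys and values into n_heads heads and returns a tensor of the same shape. Hypothetical sketch, not part of the commit:

# Hypothetical shape check: AttentionLayer preserves (B, L, d_model).
import torch
from informer.attention import FullAttention, AttentionLayer

layer = AttentionLayer(FullAttention(mask_flag=False), d_model=256, n_heads=2)
q = k = v = torch.randn(8, 48, 256)        # batch of 8, sequence length 48
out, attn = layer(q, k, v, attention_mask=None)
print(out.shape)                           # torch.Size([8, 48, 256])
print(attn)                                # None unless output_attention=True
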
src/informer/decoder.py (new file, 53 lines)
@ -0,0 +1,53 @@
import torch.nn as nn
import torch.nn.functional as F


class DecoderLayer(nn.Module):
    def __init__(
            self,
            self_attention,
            cross_attention,
            d_model,
            d_ff=None,
            dropout=0.1,
            activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(x, x, x,
                                                 attention_mask=x_mask)[0])
        x = self.norm1(x)

        x = x + self.dropout(self.cross_attention(x, cross,
                                                  cross, attention_mask=cross_mask)[0])

        y = x = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm3(x + y)


class Decoder(nn.Module):
    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        for layer in self.layers:
            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x
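
A DecoderLayer runs masked self-attention, cross-attention over the encoder output, and a position-wise feed-forward built from 1x1 convolutions, keeping the (B, L, d_model) shape. A hypothetical shape check (not part of the commit):

# Hypothetical shape check: decoder tokens attend to themselves and to a
# longer encoder output, and come back with the same (B, L, d_model) shape.
import torch
from informer.attention import FullAttention, AttentionLayer
from informer.decoder import DecoderLayer

layer = DecoderLayer(
    AttentionLayer(FullAttention(mask_flag=True), d_model=64, n_heads=2),
    AttentionLayer(FullAttention(mask_flag=False), d_model=64, n_heads=2),
    d_model=64, d_ff=128)
x = torch.randn(4, 24, 64)       # decoder tokens
cross = torch.randn(4, 48, 64)   # encoder output
print(layer(x, cross).shape)     # torch.Size([4, 24, 64])
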
src/informer/embedding.py (new file, 132 lines)
@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import math


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_length=5000):
        super(PositionalEmbedding, self).__init__()
        embedding = torch.zeros(max_length, d_model)

        position = torch.arange(0, max_length).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        embedding[:, 0::2] = torch.sin(position * div_term)
        embedding[:, 1::2] = torch.cos(position * div_term)

        embedding = embedding.unsqueeze(0)
        self.register_buffer("embedding", embedding)

    def forward(self, x):
        return self.embedding[:, : x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1 if torch.__version__ >= "1.5.0" else 2
        self.token_conv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
        )
        nn.init.kaiming_normal_(self.token_conv.weight,
                                mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x):
        return self.token_conv(x.permute(0, 2, 1)).transpose(1, 2)


class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        weight = torch.zeros(c_in, d_model)

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        weight[:, 0::2] = torch.sin(position * div_term)
        weight[:, 1::2] = torch.cos(position * div_term)

        self.embedding = nn.Embedding(c_in, d_model)
        self.embedding.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, x):
        return self.embedding(x).detach()


class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embedding_type="fixed", frequency="h"):
        super(TemporalEmbedding, self).__init__()

        MINUTE_SIZE = 4
        HOUR_SIZE = 24
        WEEKDAY_SIZE = 7
        DAY_SIZE = 32
        MONTH_SIZE = 13

        Embedding = FixedEmbedding \
            if embedding_type == "fixed" else nn.Embedding
        if frequency == "t":
            self.minute_embedding = Embedding(MINUTE_SIZE, d_model)
        self.hour_embedding = Embedding(HOUR_SIZE, d_model)
        self.weekday_embedding = Embedding(WEEKDAY_SIZE, d_model)
        self.day_embedding = Embedding(DAY_SIZE, d_model)
        self.month_embedding = Embedding(MONTH_SIZE, d_model)

    def forward(self, x):
        x = x.long()

        minute_x = self.minute_embedding(x[:, :, 4]) if hasattr(
            self, "minute_embedding") else 0.0
        hour_x = self.hour_embedding(x[:, :, 3])
        weekday_x = self.weekday_embedding(x[:, :, 2])
        day_x = self.day_embedding(x[:, :, 1])
        month_x = self.month_embedding(x[:, :, 0])

        return minute_x + hour_x + weekday_x + day_x + month_x


class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, frequency="h"):
        super(TimeFeatureEmbedding, self).__init__()

        FREQUENCY_MAP = {"h": 4, "t": 5, "s": 6,
                         "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
        d_input = FREQUENCY_MAP[frequency]
        self.embedding = nn.Linear(d_input, d_model)

    def forward(self, x):
        return self.embedding(x)


class DataEmbedding(nn.Module):
    def __init__(
            self,
            c_in,
            d_model,
            embedding_type="fixed",
            frequency="h",
            dropout=0.1
    ):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in, d_model)
        self.position_embedding = PositionalEmbedding(d_model)
        if embedding_type != "timefeature":
            self.temporal_embedding = TemporalEmbedding(
                d_model, embedding_type=embedding_type, frequency=frequency)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model, frequency=frequency)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(
            x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)
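
DataEmbedding sums the convolutional TokenEmbedding of the raw values, the sinusoidal PositionalEmbedding, and the TemporalEmbedding of the calendar marks. A hypothetical shape check, assuming hourly marks ordered [month, day, weekday, hour] (not part of the commit):

# Hypothetical shape check: (B, L, c_in) values plus (B, L, 4) calendar marks
# are embedded into (B, L, d_model).
import torch
from informer.embedding import DataEmbedding

emb = DataEmbedding(c_in=7, d_model=64)
x = torch.randn(4, 48, 7)
x_mark = torch.stack([
    torch.randint(1, 13, (4, 48)),   # month
    torch.randint(1, 32, (4, 48)),   # day of month
    torch.randint(0, 7, (4, 48)),    # weekday
    torch.randint(0, 24, (4, 48)),   # hour
], dim=-1)
print(emb(x, x_mark).shape)          # torch.Size([4, 48, 64])
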
src/informer/encoder.py (new file, 101 lines)
@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionDistil(nn.Module):
    def __init__(self, c_in):
        super(SelfAttentionDistil, self).__init__()
        self.conv = nn.Conv1d(c_in, c_in, kernel_size=3,
                              padding=2, padding_mode="circular")
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.max_pool(x)
        x = torch.transpose(x, 1, 2)
        return x


class EncoderLayer(nn.Module):
    def __init__(
            self,
            attention,
            d_model,
            d_ff=None,
            dropout=0.1,
            activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attention_mask=None):
        new_x, attention = self.attention(
            x, x, x, attention_mask=attention_mask)
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attention


class Encoder(nn.Module):
    def __init__(self, attention_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attention_layers = nn.ModuleList(attention_layers)
        self.conv_layers = nn.ModuleList(
            conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attention_mask=None):
        attentions = []
        if self.conv_layers is not None:
            for attention_layer, conv_layer in zip(
                    self.attention_layers, self.conv_layers):
                x, attention = attention_layer(
                    x, attention_mask=attention_mask)
                x = conv_layer(x)
                attentions.append(attention)
            x, attention = self.attention_layers[-1](x)
            attentions.append(attention)
        else:
            for attention_layer in self.attention_layers:
                x, attention = attention_layer(
                    x, attention_mask=attention_mask)
                attentions.append(attention)
        if self.norm is not None:
            x = self.norm(x)
        return x, attentions


class EncoderStack(nn.Module):
    def __init__(self, encoders):
        super(EncoderStack, self).__init__()
        self.encoders = nn.ModuleList(encoders)

    def forward(self, x, attention_mask=None):
        inp_len = x.size(1)
        x_stack = []
        attentions = []
        for encoder in self.encoders:
            if encoder is None:
                inp_len //= 2
                continue
            x, attention = encoder(x[:, -inp_len:, :])
            x_stack.append(x)
            attentions.append(attention)
            inp_len //= 2
        x_stack = torch.cat(x_stack, -2)
        return x_stack, attentions
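
SelfAttentionDistil is the distilling step between encoder layers: a circular convolution followed by max-pooling with stride 2, which roughly halves the sequence length. That is why the encoder builds one distilling layer fewer than it has attention layers. A hypothetical shape check (not part of the commit):

# Hypothetical shape check: the distilling block roughly halves the length
# dimension while keeping the channel dimension.
import torch
from informer.encoder import SelfAttentionDistil

distil = SelfAttentionDistil(c_in=256)
x = torch.randn(8, 48, 256)        # (batch, length, d_model)
print(distil(x).shape)             # torch.Size([8, 25, 256])
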
src/informer/masking.py (new file, 19 lines)
@ -0,0 +1,19 @@
import torch


def triangular_causal_mask(B, L, device=torch.device("cpu")):
    mask_shape = [B, 1, L, L]
    with torch.no_grad():
        mask = torch.triu(torch.ones(
            mask_shape, dtype=torch.bool, device=device), diagonal=1)
    return mask


def prob_mask(B, H, L, index, scores, device=torch.device("cpu")):
    mask = torch.ones(L, scores.shape[-1],
                      dtype=torch.bool, device=device).triu(1)
    mask_ex = mask[None, None, :].expand(B, H, L, scores.shape[-1])
    indicator = mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[
        None, :, None], index, :]
    mask = indicator.view(scores.shape)
    return mask
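
triangular_causal_mask returns a boolean (B, 1, L, L) tensor that is True strictly above the diagonal, i.e. at the positions FullAttention fills with -inf. A hypothetical check (not part of the commit):

# Hypothetical check of the causal mask layout.
import torch
from informer.masking import triangular_causal_mask

mask = triangular_causal_mask(B=2, L=4)
print(mask.shape)        # torch.Size([2, 1, 4, 4])
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
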
src/informer/model.py (new file, 246 lines)
@ -0,0 +1,246 @@
import torch.nn as nn
from informer.attention import (
    FullAttention, ProbSparseAttention, AttentionLayer)
from informer.embedding import DataEmbedding
from informer.encoder import (
    Encoder,
    EncoderLayer,
    EncoderStack,
    SelfAttentionDistil,
)
from informer.decoder import Decoder, DecoderLayer


class BaseInformer(nn.Module):
    def __init__(
            self,
            enc_in=7,
            dec_in=7,
            c_out=7,
            out_len=24,
            factor=5,
            d_model=512,
            n_heads=8,
            num_encoder_layers=2,
            num_decoder_layers=1,
            d_ff=2048,
            dropout=0.05,
            attention_type="prob",
            embedding_type="fixed",
            frequency="h",
            activation="gelu",
            output_attention=False,
            distil=True,
            mix_attention=False,
            **kwargs
    ):
        super(BaseInformer, self).__init__()
        self.pred_len = out_len
        self.attention_type = attention_type
        self.output_attention = output_attention

        self.enc_embedding = DataEmbedding(
            enc_in, d_model, embedding_type, frequency, dropout)
        self.dec_embedding = DataEmbedding(
            dec_in, d_model, embedding_type, frequency, dropout)

        Attention = ProbSparseAttention \
            if attention_type == "prob" else FullAttention

        self.encoder = None

        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attention(True, factor, attention_dropout=dropout,
                                  output_attention=False),
                        d_model,
                        n_heads,
                        mix=mix_attention,
                    ),
                    AttentionLayer(
                        FullAttention(
                            False,
                            factor,
                            attention_dropout=dropout,
                            output_attention=False),
                        d_model,
                        n_heads,
                        mix=False,
                    ),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_decoder_layers)
            ],
            nn.LayerNorm(d_model),
        )

        self.projection = nn.Linear(d_model, c_out)

    def forward(
            self,
            x_enc,
            x_enc_mark,
            x_dec,
            x_dec_mark,
            enc_self_mask=None,
            dec_self_mask=None,
            dec_enc_mask=None,
    ):
        enc_out = self.enc_embedding(x_enc, x_enc_mark)
        enc_out, attentions = self.encoder(
            enc_out, attention_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_dec_mark)
        dec_out = self.decoder(
            dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attentions
        return dec_out[:, -self.pred_len:, :]


class Informer(BaseInformer):
    def __init__(
            self,
            enc_in=7,
            dec_in=7,
            c_out=7,
            out_len=24,
            factor=5,
            d_model=512,
            n_heads=8,
            num_encoder_layers=2,
            num_decoder_layers=1,
            d_ff=2048,
            dropout=0.05,
            attention_type="prob",
            embedding_type="fixed",
            frequency="h",
            activation="gelu",
            output_attention=False,
            distil=True,
            mix_attention=False,
            **kwargs
    ):
        super(Informer, self).__init__(
            enc_in,
            dec_in,
            c_out,
            out_len,
            factor=factor,
            d_model=d_model,
            n_heads=n_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ff=d_ff,
            dropout=dropout,
            attention_type=attention_type,
            embedding_type=embedding_type,
            frequency=frequency,
            activation=activation,
            output_attention=output_attention,
            distil=distil,
            mix_attention=mix_attention,
        )
        Attention = ProbSparseAttention \
            if attention_type == "prob" else FullAttention
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attention(False, factor, attention_dropout=dropout,
                                  output_attention=output_attention),
                        d_model,
                        n_heads,
                        mix=False,
                    ),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_encoder_layers)
            ],
            [SelfAttentionDistil(d_model) for _ in range(
                num_encoder_layers - 1)] if distil else None,
            nn.LayerNorm(d_model),
        )


class InformerStack(BaseInformer):
    def __init__(
            self,
            enc_in=7,
            dec_in=7,
            c_out=7,
            out_len=24,
            factor=5,
            d_model=512,
            n_heads=8,
            num_encoder_layers=2,
            num_decoder_layers=1,
            d_ff=2048,
            dropout=0.05,
            attention_type="prob",
            embedding_type="fixed",
            frequency="h",
            activation="gelu",
            output_attention=False,
            distil=True,
            mix_attention=False,
            **kwargs
    ):
        super(InformerStack, self).__init__(
            enc_in,
            dec_in,
            c_out,
            out_len,
            factor=factor,
            d_model=d_model,
            n_heads=n_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ff=d_ff,
            dropout=dropout,
            attention_type=attention_type,
            embedding_type=embedding_type,
            frequency=frequency,
            activation=activation,
            output_attention=output_attention,
            distil=distil,
            mix_attention=mix_attention,
        )
        Attention = ProbSparseAttention \
            if attention_type == "prob" else FullAttention
        stacks = list(range(num_encoder_layers, 2, -1))  # customize here
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            Attention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention),
                            d_model,
                            n_heads,
                            mix=False,
                        ),
                        d_model,
                        d_ff,
                        dropout=dropout,
                        activation=activation,
                    )
                    for _ in range(el)
                ],
                [SelfAttentionDistil(d_model)
                 for _ in range(el - 1)] if distil else None,
                nn.LayerNorm(d_model),
            )
            for el in stacks
        ]
        self.encoder = EncoderStack(encoders)
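
The standalone Informer above (as opposed to the pytorch-forecasting wrapper in src/ml/model.py below) takes encoder and decoder value tensors plus calendar marks. A rough smoke test, assuming hourly marks ordered [month, day, weekday, hour] and hypothetical toy sizes (not part of the commit):

# Hypothetical smoke test: forward pass with dummy values and calendar marks.
import torch
from informer.model import Informer

model = Informer(enc_in=7, dec_in=7, c_out=7, out_len=12,
                 d_model=64, n_heads=2, d_ff=128)

B, L_enc, L_dec = 4, 48, 24                  # decoder window covers out_len
x_enc = torch.randn(B, L_enc, 7)
x_dec = torch.randn(B, L_dec, 7)
marks_enc = torch.stack([
    torch.randint(1, 13, (B, L_enc)),        # month
    torch.randint(1, 32, (B, L_enc)),        # day of month
    torch.randint(0, 7, (B, L_enc)),         # weekday
    torch.randint(0, 24, (B, L_enc)),        # hour
], dim=-1)
marks_dec = marks_enc[:, :L_dec, :]

out = model(x_enc, marks_enc, x_dec, marks_dec)
print(out.shape)                             # torch.Size([4, 12, 7])
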
src/ml/model.py (224 lines)
@ -1,13 +1,233 @@
from typing import Dict, List, Union, Tuple
import torch
import torch.nn as nn
import numpy as np
import wandb
import functools
import operator
from copy import copy

from pytorch_forecasting.models.nn import MultiEmbedding
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.metrics import MAE, RMSE, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer import (
    TemporalFusionTransformer)

from informer.attention import (
    FullAttention, ProbSparseAttention, AttentionLayer)
from informer.embedding import TokenEmbedding, PositionalEmbedding
from informer.encoder import (
    Encoder,
    EncoderLayer,
    SelfAttentionDistil,
)
from informer.decoder import Decoder, DecoderLayer

# TODO: Maybe save all models on cpu


def get_model(config, dataset, loss):
    model_name = config['model']['name']

    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.from_dataset(
            dataset,
            hidden_size=config['model']['hidden_size'],
            dropout=config['model']['dropout'],
            attention_head_size=config['model']['attention_head_size'],
            hidden_continuous_size=config['model']['hidden_continuous_size'],
            learning_rate=config['model']['learning_rate'],
            share_single_variable_networks=False,
            loss=loss,
            logging_metrics=[MAE(), RMSE()]
        )

    if model_name == 'Informer':
        return Informer.from_dataset(
            dataset,
            d_model=config['model']['d_model'],
            d_fully_connected=config['model']['d_fully_connected'],
            n_attention_heads=config['model']['n_attention_heads'],
            n_encoder_layers=config['model']['n_encoder_layers'],
            n_decoder_layers=config['model']['n_decoder_layers'],
            dropout=config['model']['dropout'],
            learning_rate=config['model']['learning_rate'],
            loss=loss,
            embedding_sizes={
                name: (len(encoder.classes_), config['model']['d_model'])
                for name, encoder in dataset.categorical_encoders.items()
                if name in dataset.categoricals
            },
            logging_metrics=[MAE(), RMSE()]
        )

    raise ValueError("Unknown model")


def load_model_from_wandb(run):
    model_name = run.config['model']['name']
    model_path = f"{run.project}/model-{run.id}:best"
    model_artifact = wandb.Api().artifact(model_path)

    if model_name == 'TemporalFusionTransformer':
        return TemporalFusionTransformer.load_from_checkpoint(
            model_artifact.file())

    raise ValueError("Invalid model name")


class Informer(BaseModelWithCovariates):

    def __init__(
            self,
            d_model=256,
            d_fully_connected=512,
            n_attention_heads=2,
            n_encoder_layers=2,
            n_decoder_layers=1,
            dropout=0.1,
            attention_type="prob",
            activation="gelu",
            factor=5,
            mix_attention=False,
            output_attention=False,
            distil=True,
            x_reals: List[str] = [],
            x_categoricals: List[str] = [],
            static_categoricals: List[str] = [],
            static_reals: List[str] = [],
            time_varying_reals_encoder: List[str] = [],
            time_varying_reals_decoder: List[str] = [],
            time_varying_categoricals_encoder: List[str] = [],
            time_varying_categoricals_decoder: List[str] = [],
            embedding_sizes: Dict[str, Tuple[int, int]] = {},
            embedding_paddings: List[str] = [],
            embedding_labels: Dict[str, np.ndarray] = {},
            categorical_groups: Dict[str, List[str]] = {},
            output_size: Union[int, List[int]] = 1,
            loss=None,
            logging_metrics: nn.ModuleList = None,
            **kwargs):
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            **kwargs)
        self.save_hyperparameters(ignore=['loss'])
        self.attention_type = attention_type

        assert not static_reals
        assert not static_categoricals

        self.cat_embeddings = MultiEmbedding(
            embedding_sizes=embedding_sizes,
            embedding_paddings=embedding_paddings,
            categorical_groups=categorical_groups,
            x_categoricals=x_categoricals,
        )

        self.enc_real_embeddings = TokenEmbedding(
            len(time_varying_reals_encoder), d_model)
        self.enc_positional_embeddings = PositionalEmbedding(d_model)

        self.dec_real_embeddings = TokenEmbedding(
            len(time_varying_reals_decoder), d_model)
        self.dec_positional_embeddings = PositionalEmbedding(d_model)

        Attention = ProbSparseAttention \
            if attention_type == "prob" else FullAttention

        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attention(False, factor, attention_dropout=dropout,
                                  output_attention=output_attention),
                        d_model,
                        n_attention_heads,
                        mix=False,
                    ),
                    d_model,
                    d_fully_connected,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(n_encoder_layers)
            ],
            [SelfAttentionDistil(d_model) for _ in range(
                n_encoder_layers - 1)] if distil else None,
            nn.LayerNorm(d_model),
        )

        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attention(True, factor, attention_dropout=dropout,
                                  output_attention=False),
                        d_model,
                        n_attention_heads,
                        mix=mix_attention,
                    ),
                    AttentionLayer(
                        FullAttention(
                            False,
                            factor,
                            attention_dropout=dropout,
                            output_attention=False),
                        d_model,
                        n_attention_heads,
                        mix=False,
                    ),
                    d_model,
                    d_fully_connected,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(n_decoder_layers)
            ],
            nn.LayerNorm(d_model),
        )

        self.projection = nn.Linear(d_model, output_size)

    def forward(
            self,
            x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        decoder_length = x['decoder_lengths'].max()

        enc_out = \
            self.enc_real_embeddings(x['encoder_cont']) + \
            self.enc_positional_embeddings(x['encoder_cont']) + \
            functools.reduce(operator.add, [emb for emb in self.cat_embeddings(
                x['encoder_cat']).values()])
        enc_out, attentions = self.encoder(enc_out)

        # Hacky solution to get only known reals,
        # they are always stacked first.
        # TODO: Make sure no unknown reals are passed to decoder.
        dec_out = \
            self.dec_real_embeddings(x['decoder_cont'][..., :len(
                self.hparams.time_varying_reals_decoder)]) + \
            self.dec_positional_embeddings(x['decoder_cont']) + \
            functools.reduce(operator.add, [emb for emb in self.cat_embeddings(
                x['decoder_cat']).values()])
        dec_out = self.decoder(dec_out, enc_out)

        output = self.projection(dec_out)
        output = output[:, -decoder_length:, :]
        output = self.transform_output(
            output, target_scale=x['target_scale'])
        return self.to_network_output(prediction=output)

    @classmethod
    def from_dataset(
            cls,
            dataset,
            **kwargs
    ):
        new_kwargs = copy(kwargs)
        new_kwargs.update(cls.deduce_default_output_parameters(
            dataset, kwargs, QuantileLoss()))

        return super().from_dataset(dataset, **new_kwargs)
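
The encoder and decoder inputs in Informer.forward above are built by summing the real-valued TokenEmbedding, the PositionalEmbedding, and the per-feature categorical embeddings returned by MultiEmbedding. A minimal illustration of that reduce-sum, using stand-in tensors instead of the real embedding modules (hypothetical, not part of the commit):

# Hypothetical illustration of the functools.reduce(operator.add, ...) pattern
# used in the forward pass; the tensors stand in for embedding outputs.
import functools
import operator
import torch

B, L, d_model = 4, 48, 256
cat_embeddings = {                      # stand-in for MultiEmbedding output
    "hour": torch.randn(B, L, d_model),
    "weekday": torch.randn(B, L, d_model),
}
real_emb = torch.randn(B, L, d_model)   # stand-in for TokenEmbedding output
pos_emb = torch.randn(1, L, d_model)    # stand-in for PositionalEmbedding output

enc_out = real_emb + pos_emb + functools.reduce(
    operator.add, [emb for emb in cat_embeddings.values()])
print(enc_out.shape)                    # torch.Size([4, 48, 256])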