Add gmadl loss function
This commit is contained in:
parent
6923b6cb05
commit
dbd99bba3b
67
configs/experiments/temporal-fusion-btcusdt-gmadl.yaml
Normal file
67
configs/experiments/temporal-fusion-btcusdt-gmadl.yaml
Normal file
@ -0,0 +1,67 @@
|
||||
future_window:
|
||||
value: 1
|
||||
past_window:
|
||||
value: 24
|
||||
batch_size:
|
||||
value: 64
|
||||
max_epochs:
|
||||
value: 30
|
||||
data:
|
||||
value:
|
||||
dataset: "btc-usdt-5m:latest"
|
||||
sliding_window: 4
|
||||
validation: 0.2
|
||||
fields:
|
||||
time_index: "time_index"
|
||||
target: "returns"
|
||||
group_ids: ["group_id"]
|
||||
dynamic_unknown_real:
|
||||
- "high_price"
|
||||
- "low_price"
|
||||
- "open_price"
|
||||
- "close_price"
|
||||
- "volume"
|
||||
- "open_to_close_price"
|
||||
- "high_to_close_price"
|
||||
- "low_to_close_price"
|
||||
- "high_to_low_price"
|
||||
- "returns"
|
||||
- "log_returns"
|
||||
- "vol_1h"
|
||||
- "macd"
|
||||
- "macd_signal"
|
||||
- "rsi"
|
||||
- "low_bband_to_close_price"
|
||||
- "up_bband_to_close_price"
|
||||
- "mid_bband_to_close_price"
|
||||
- "sma_1h_to_close_price"
|
||||
- "sma_1d_to_close_price"
|
||||
- "sma_7d_to_close_price"
|
||||
- "ema_1h_to_close_price"
|
||||
- "ema_1d_to_close_price"
|
||||
dynamic_unknown_cat: []
|
||||
dynamic_known_real: []
|
||||
dynamic_known_cat:
|
||||
- "hour"
|
||||
static_real:
|
||||
- "effective_rates"
|
||||
- "vix_close_price"
|
||||
- "fear_greed_index"
|
||||
- "vol_1d"
|
||||
- "vol_7d"
|
||||
static_cat:
|
||||
- "weekday"
|
||||
loss:
|
||||
value:
|
||||
name: "GMADL"
|
||||
a: 1000
|
||||
b: 2
|
||||
model:
|
||||
value:
|
||||
name: "TemporalFusionTransformer"
|
||||
hidden_size: 64
|
||||
dropout: 0.1
|
||||
attention_head_size: 2
|
||||
hidden_continuous_size: 8
|
||||
learning_rate: 0.001
|
||||
optimizer: "Adam"
|
||||
38
configs/sweeps/temporal-fusion-btcusdt-gmadl.yaml
Normal file
38
configs/sweeps/temporal-fusion-btcusdt-gmadl.yaml
Normal file
@ -0,0 +1,38 @@
|
||||
program: ./scripts/train.py
|
||||
project: wne-masters-thesis-testing
|
||||
command:
|
||||
- ${env}
|
||||
- ${interpreter}
|
||||
- ${program}
|
||||
- "./configs/experiments/temporal-fusion-btcusdt-gmadl.yaml"
|
||||
- "--patience"
|
||||
- "10"
|
||||
method: random
|
||||
metric:
|
||||
goal: minimize
|
||||
name: val_loss
|
||||
parameters:
|
||||
past_window:
|
||||
distribution: int_uniform
|
||||
min: 20
|
||||
max: 100
|
||||
batch_size:
|
||||
values: [64, 128, 256]
|
||||
model:
|
||||
parameters:
|
||||
name:
|
||||
value: "TemporalFusionTransformer"
|
||||
share_single_variable_networks:
|
||||
value: false
|
||||
hidden_size:
|
||||
values: [128, 256, 512, 1024]
|
||||
dropout:
|
||||
values: [0.0, 0.1, 0.2, 0.3]
|
||||
attention_head_size:
|
||||
values: [1, 2, 4, 6]
|
||||
hidden_continuous_size:
|
||||
values: [4, 8, 16, 32]
|
||||
learning_rate:
|
||||
value: 0.001
|
||||
optimizer:
|
||||
value: "Adam"
|
||||
@ -15,6 +15,8 @@ from pytorch_forecasting.metrics import MAE, RMSE
|
||||
from pytorch_forecasting import QuantileLoss
|
||||
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
|
||||
|
||||
from ml.loss import GMADL
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -130,6 +132,12 @@ def get_loss(config):
|
||||
if loss_name == 'Quantile':
|
||||
return QuantileLoss(config['loss']['quantiles'])
|
||||
|
||||
if loss_name == 'GMADL':
|
||||
return GMADL(
|
||||
a=config['loss']['a'],
|
||||
b=config['loss']['b']
|
||||
)
|
||||
|
||||
raise ValueError("Unknown loss")
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user