add training script
This commit is contained in:
parent
c3a03fe338
commit
e6c2d4b914
66
configs/experiments/temporal-fusion-btcusdt-quantile.yaml
Normal file
66
configs/experiments/temporal-fusion-btcusdt-quantile.yaml
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
future_window:
|
||||||
|
value: 1
|
||||||
|
past_window:
|
||||||
|
value: 24
|
||||||
|
batch_size:
|
||||||
|
value: 64
|
||||||
|
max_epochs:
|
||||||
|
value: 30
|
||||||
|
data:
|
||||||
|
value:
|
||||||
|
dataset: "btc-usdt-5m:latest"
|
||||||
|
sliding_window: 4
|
||||||
|
validation: 0.2
|
||||||
|
fields:
|
||||||
|
time_index: "time_index"
|
||||||
|
target: "close_price"
|
||||||
|
group_ids: ["group_id"]
|
||||||
|
dynamic_unknown_real:
|
||||||
|
- "high_price"
|
||||||
|
- "low_price"
|
||||||
|
- "open_price"
|
||||||
|
- "close_price"
|
||||||
|
- "volume"
|
||||||
|
- "open_to_close_price"
|
||||||
|
- "high_to_close_price"
|
||||||
|
- "low_to_close_price"
|
||||||
|
- "high_to_low_price"
|
||||||
|
- "returns"
|
||||||
|
- "log_returns"
|
||||||
|
- "vol_1h"
|
||||||
|
- "macd"
|
||||||
|
- "macd_signal"
|
||||||
|
- "rsi"
|
||||||
|
- "low_bband_to_close_price"
|
||||||
|
- "up_bband_to_close_price"
|
||||||
|
- "mid_bband_to_close_price"
|
||||||
|
- "sma_1h_to_close_price"
|
||||||
|
- "sma_1d_to_close_price"
|
||||||
|
- "sma_7d_to_close_price"
|
||||||
|
- "ema_1h_to_close_price"
|
||||||
|
- "ema_1d_to_close_price"
|
||||||
|
dynamic_unknown_cat: []
|
||||||
|
dynamic_known_real: []
|
||||||
|
dynamic_known_cat:
|
||||||
|
- "hour"
|
||||||
|
static_real:
|
||||||
|
- "effective_rates"
|
||||||
|
- "vix_close_price"
|
||||||
|
- "fear_greed_index"
|
||||||
|
- "vol_1d"
|
||||||
|
- "vol_7d"
|
||||||
|
static_cat:
|
||||||
|
- "weekday"
|
||||||
|
loss:
|
||||||
|
value:
|
||||||
|
name: "Quantile"
|
||||||
|
quantiles: [0.02, 0.1, 0.5, 0.9, 0.98]
|
||||||
|
model:
|
||||||
|
value:
|
||||||
|
name: "TemporalFusionTransformer"
|
||||||
|
hidden_size: 64
|
||||||
|
dropout: 0.1
|
||||||
|
attention_head_size: 2
|
||||||
|
hidden_continuous_size: 8
|
||||||
|
learning_rate: 0.001
|
||||||
|
optimizer: "Adam"
|
||||||
36
configs/sweeps/temporal-fusion-btcusdt-quantile.yaml
Normal file
36
configs/sweeps/temporal-fusion-btcusdt-quantile.yaml
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
program: ./scripts/train.py
|
||||||
|
project: wne-masters-thesis-testing
|
||||||
|
command:
|
||||||
|
- ${env}
|
||||||
|
- ${interpreter}
|
||||||
|
- ${program}
|
||||||
|
- "./configs/experiments/temporal-fusion-btcusdt-quantile.yaml"
|
||||||
|
method: random
|
||||||
|
metric:
|
||||||
|
goal: minimize
|
||||||
|
name: val_loss
|
||||||
|
parameters:
|
||||||
|
past_window:
|
||||||
|
distribution: int_uniform
|
||||||
|
min: 5
|
||||||
|
max: 100
|
||||||
|
batch_size:
|
||||||
|
values: [64, 128, 256]
|
||||||
|
model:
|
||||||
|
parameters:
|
||||||
|
name:
|
||||||
|
value: "TemporalFusionTransformer"
|
||||||
|
share_single_variable_networks:
|
||||||
|
value: false
|
||||||
|
hidden_size:
|
||||||
|
values: [128, 256, 512, 1024]
|
||||||
|
dropout:
|
||||||
|
values: [0.0, 0.1, 0.2, 0.3, 0.4]
|
||||||
|
attention_head_size:
|
||||||
|
values: [1, 2, 4, 6]
|
||||||
|
hidden_continuous_size:
|
||||||
|
values: [4, 8, 16, 32]
|
||||||
|
learning_rate:
|
||||||
|
values: [0.01, 0.001, 0.0005, 0.0001]
|
||||||
|
optimizer:
|
||||||
|
values: ["Adam", "RMSprop", "Adagrad"]
|
||||||
0
data/.gitignore
vendored
Normal file
0
data/.gitignore
vendored
Normal file
@ -15,7 +15,9 @@ requires-python = ">=3.9"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"pytorch-forecasting==1.0.0",
|
"pytorch-forecasting==1.0.0",
|
||||||
"plotly==5.22.0",
|
"plotly==5.22.0",
|
||||||
"wandb==0.16.6"
|
"wandb==0.17.7",
|
||||||
|
"TA-lib==0.4.32",
|
||||||
|
"numpy==1.26.4"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
228
scripts/train.py
228
scripts/train.py
@ -0,0 +1,228 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import wandb
|
||||||
|
import pprint
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import lightning.pytorch as pl
|
||||||
|
|
||||||
|
from lightning.pytorch.utilities.model_summary import ModelSummary
|
||||||
|
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||||
|
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
|
||||||
|
from pytorch_forecasting.metrics import MAE, RMSE
|
||||||
|
from pytorch_forecasting import QuantileLoss
|
||||||
|
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_log_level(value):
    """Convert a ``--log-level`` argument into a numeric logging level.

    Accepts an integer string (``"10"``) or a standard level name
    (``"debug"``, ``"INFO"``, ...), matched case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is neither of those.
    """
    try:
        return int(value)
    except ValueError:
        pass

    # getLevelName maps a registered name to its number and returns a
    # "Level <name>" string for anything it does not know.
    level = logging.getLevelName(value.strip().upper())
    if not isinstance(level, int):
        raise argparse.ArgumentTypeError(f"invalid log level: {value!r}")
    return level


def get_args(argv=None):
    """Parse the command-line arguments of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``config``, ``project``, ``log_level``,
        ``seed``, ``log_interval``, ``val_check_interval`` and ``no_wandb``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "config", help="Experiment configuration file in yaml format.")

    parser.add_argument(
        "-p",
        "--project",
        default="wne-masters-thesis-testing",
        help="W&B project name.")

    parser.add_argument(
        "-l",
        "--log-level",
        default=logging.INFO,
        # Numeric levels keep working; names such as DEBUG are accepted too.
        type=_parse_log_level,
        help="Sets the log level (number or name, e.g. 10 or DEBUG).")

    parser.add_argument(
        "-s",
        "--seed",
        default=42,
        type=int,
        help="Random seed for the training.")

    parser.add_argument(
        "-n",
        "--log-interval",
        default=100,
        type=int,
        help="Log every n steps.")

    parser.add_argument(
        "-v",
        "--val-check-interval",
        default=300,
        type=int,
        help="Run validation every n batches.")

    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disables wandb, for testing.")

    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_in_sample(data, validation_part, past_window):
    """Split ``data`` by row position into training and validation frames.

    The leading ``1 - validation_part`` share of rows forms the training
    frame. The validation frame starts up to ``past_window`` rows before the
    end of the training frame (presumably so the earliest validation samples
    have a full encoder window -- TODO confirm).

    Returns:
        ``(train_data, val_data, lookback)`` where ``lookback`` is the number
        of training rows repeated at the head of ``val_data``.
    """
    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
    # Clamp to 0: a negative start would make ``iloc`` count from the end of
    # the frame and silently drop most of the validation rows.
    val_start = max(len(train_data) - past_window, 0)
    return train_data, data.iloc[val_start:], len(train_data) - val_start


def get_dataset(config, project):
    """Download the dataset artifact and build train/validation datasets.

    Args:
        config: Experiment configuration; reads ``data``, ``past_window`` and
            ``future_window``.
        project: W&B project that owns the dataset artifact.

    Returns:
        ``(train, val)`` tuple of ``TimeSeriesDataSet`` objects; ``val`` is
        derived from ``train`` via ``TimeSeriesDataSet.from_dataset``.
    """
    data_config = config['data']
    fields = data_config['fields']
    past_window = config['past_window']
    future_window = config['future_window']

    artifact = wandb.Api().artifact(f"{project}/{data_config['dataset']}")
    base_path = artifact.download()
    logging.info(f"Artifacts downloaded to {base_path}")

    # The CSV is named "<artifact metadata name>-in-sample-<sliding_window>".
    part_name = f"in-sample-{data_config['sliding_window']}"
    file_name = f"{artifact.metadata['name']}-{part_name}.csv"
    data = pd.read_csv(os.path.join(base_path, file_name))
    logging.info(f"Using part: {part_name}")

    # TODO: Fix in dataset
    data['weekday'] = data['weekday'].astype('str')
    data['hour'] = data['hour'].astype('str')

    validation_part = data_config['validation']
    logging.info(f"Using {validation_part} of in sample part for validation.")
    train_data, val_data, lookback = _split_in_sample(
        data, validation_part, past_window)
    logging.info(f"Training part size: {len(train_data)}")
    logging.info(
        f"Validation part size: {len(val_data)} "
        f"({len(data) - len(train_data)} + {lookback})")

    logging.info("Building time series dataset for training.")
    train = TimeSeriesDataSet(
        train_data,
        time_idx=fields['time_index'],
        target=fields['target'],
        group_ids=fields['group_ids'],
        # Fixed-length encoder and decoder windows (min == max).
        min_encoder_length=past_window,
        max_encoder_length=past_window,
        min_prediction_length=future_window,
        max_prediction_length=future_window,
        static_reals=fields['static_real'],
        static_categoricals=fields['static_cat'],
        time_varying_known_reals=fields['dynamic_known_real'],
        time_varying_known_categoricals=fields['dynamic_known_cat'],
        time_varying_unknown_reals=fields['dynamic_unknown_real'],
        time_varying_unknown_categoricals=fields['dynamic_unknown_cat'],
        randomize_length=False,
    )

    logging.info("Building time series dataset for validation.")
    val = TimeSeriesDataSet.from_dataset(
        train, val_data, stop_randomization=True)

    return train, val
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss(config):
    """Build the loss selected by ``config['loss']['name']``.

    Only ``'Quantile'`` is supported; it is constructed from
    ``config['loss']['quantiles']``.

    Raises:
        ValueError: If the configured loss name is not supported. The message
            includes the rejected name so config typos are easy to spot.
    """
    loss_config = config['loss']
    loss_name = loss_config['name']

    if loss_name == 'Quantile':
        return QuantileLoss(loss_config['quantiles'])

    raise ValueError(f"Unknown loss: {loss_name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(config, dataset, loss):
    """Build the model selected by ``config['model']['name']``.

    Only ``'TemporalFusionTransformer'`` is supported; it is created with
    ``TemporalFusionTransformer.from_dataset`` from the hyperparameters under
    ``config['model']``.

    Args:
        config: Experiment configuration; reads the ``model`` section.
        dataset: Training dataset passed to ``from_dataset``.
        loss: Loss object handed to the model.

    Raises:
        ValueError: If the configured model name is not supported. The message
            includes the rejected name so config typos are easy to spot.
    """
    model_config = config['model']
    model_name = model_config['name']

    if model_name != 'TemporalFusionTransformer':
        raise ValueError(f"Unknown model: {model_name!r}")

    # Forward the configured optimizer instead of silently ignoring it; when
    # the key is absent the library default is kept. The name must be one
    # pytorch-forecasting can resolve (e.g. a torch.optim class name).
    optional = {}
    optimizer = model_config.get('optimizer')
    if optimizer is not None:
        optional['optimizer'] = optimizer

    return TemporalFusionTransformer.from_dataset(
        dataset,
        hidden_size=model_config['hidden_size'],
        dropout=model_config['dropout'],
        attention_head_size=model_config['attention_head_size'],
        hidden_continuous_size=model_config['hidden_continuous_size'],
        learning_rate=model_config['learning_rate'],
        # Honour the config value when present; False matches the previous
        # hard-coded behaviour.
        share_single_variable_networks=model_config.get(
            'share_single_variable_networks', False),
        loss=loss,
        logging_metrics=[MAE(), RMSE()],
        **optional,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run one training job described by the experiment config file.

    Parses the CLI arguments, seeds Lightning, starts a W&B run (disabled
    when ``--no-wandb`` is given), builds the datasets, loss and model, and
    fits the model with checkpointing and early stopping on ``val_loss``.
    Nothing is fitted when ``max_epochs`` is not positive.
    """
    args = get_args()
    logging.basicConfig(level=args.log_level)
    pl.seed_everything(args.seed)

    use_wandb = not args.no_wandb
    run = wandb.init(
        project=args.project,
        config=args.config,
        job_type="train",
        mode="online" if use_wandb else "disabled",
    )
    config = run.config
    logging.info("Using experiment config:\n%s", pprint.pformat(config))

    # Datasets
    train, valid = get_dataset(config, args.project)
    dataset_params = pprint.pformat(train.get_parameters())
    logging.info("Train dataset parameters:\n" + f"{dataset_params}")

    # Loss
    loss = get_loss(config)
    logging.info(f"Using loss {loss}")

    # Model
    model = get_model(config, train, loss)
    logging.info(f"Using model {config['model']['name']}")
    logging.info(f"{ModelSummary(model)}")
    hparams_text = pprint.pformat(model.hparams)
    logging.info("Model hyperparameters:\n" + f"{hparams_text}")

    # Keep the three checkpoints with the lowest validation loss.
    checkpointing = ModelCheckpoint(monitor='val_loss', save_top_k=3, mode='min')

    # W&B logger, only when wandb is enabled.
    if use_wandb:
        run_logger = WandbLogger(
            project=args.project, experiment=run, log_model="all")
    else:
        run_logger = None

    stopper = EarlyStopping(monitor="val_loss", mode="min", patience=5)

    batch_size = config['batch_size']
    logging.info(f"Training batch size {batch_size}.")

    epochs = config['max_epochs']
    logging.info(f"Training for {epochs} epochs.")

    trainer = pl.Trainer(
        accelerator="auto",
        max_epochs=epochs,
        logger=run_logger,
        callbacks=[checkpointing, stopper],
        log_every_n_steps=args.log_interval,
        val_check_interval=args.val_check_interval,
    )

    if epochs <= 0:
        return

    logging.info("Starting training:")
    train_loader = train.to_dataloader(batch_size=batch_size, num_workers=3)
    valid_loader = valid.to_dataloader(
        batch_size=batch_size, train=False, num_workers=3)
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||||
Loading…
x
Reference in New Issue
Block a user