diff --git a/configs/experiments/temporal-fusion-btcusdt-quantile.yaml b/configs/experiments/temporal-fusion-btcusdt-quantile.yaml
new file mode 100644
index 0000000..900b570
--- /dev/null
+++ b/configs/experiments/temporal-fusion-btcusdt-quantile.yaml
@@ -0,0 +1,66 @@
+future_window:
+  value: 1
+past_window:
+  value: 24
+batch_size:
+  value: 64
+max_epochs:
+  value: 30
+data:
+  value:
+    dataset: "btc-usdt-5m:latest"
+    sliding_window: 4
+    validation: 0.2
+    fields:
+      time_index: "time_index"
+      target: "close_price"
+      group_ids: ["group_id"]
+      dynamic_unknown_real:
+        - "high_price"
+        - "low_price"
+        - "open_price"
+        - "close_price"
+        - "volume"
+        - "open_to_close_price"
+        - "high_to_close_price"
+        - "low_to_close_price"
+        - "high_to_low_price"
+        - "returns"
+        - "log_returns"
+        - "vol_1h"
+        - "macd"
+        - "macd_signal"
+        - "rsi"
+        - "low_bband_to_close_price"
+        - "up_bband_to_close_price"
+        - "mid_bband_to_close_price"
+        - "sma_1h_to_close_price"
+        - "sma_1d_to_close_price"
+        - "sma_7d_to_close_price"
+        - "ema_1h_to_close_price"
+        - "ema_1d_to_close_price"
+      dynamic_unknown_cat: []
+      dynamic_known_real: []
+      dynamic_known_cat:
+        - "hour"
+      static_real:
+        - "effective_rates"
+        - "vix_close_price"
+        - "fear_greed_index"
+        - "vol_1d"
+        - "vol_7d"
+      static_cat:
+        - "weekday"
+loss:
+  value:
+    name: "Quantile"
+    quantiles: [0.02, 0.1, 0.5, 0.9, 0.98]
+model:
+  value:
+    name: "TemporalFusionTransformer"
+    hidden_size: 64
+    dropout: 0.1
+    attention_head_size: 2
+    hidden_continuous_size: 8
+    learning_rate: 0.001
+    optimizer: "Adam"
diff --git a/configs/sweeps/temporal-fusion-btcusdt-quantile.yaml b/configs/sweeps/temporal-fusion-btcusdt-quantile.yaml
new file mode 100644
index 0000000..f3c0c9e
--- /dev/null
+++ b/configs/sweeps/temporal-fusion-btcusdt-quantile.yaml
@@ -0,0 +1,36 @@
+program: ./scripts/train.py
+project: wne-masters-thesis-testing
+command:
+  - ${env}
+  - ${interpreter}
+  - ${program}
+  - "./configs/experiments/temporal-fusion-btcusdt-quantile.yaml"
+method: random
+metric:
+  goal: minimize
+  name: val_loss
+parameters:
+  past_window:
+    distribution: int_uniform
+    min: 5
+    max: 100
+  batch_size:
+    values: [64, 128, 256]
+  model:
+    parameters:
+      name:
+        value: "TemporalFusionTransformer"
+      share_single_variable_networks:
+        value: false
+      hidden_size:
+        values: [128, 256, 512, 1024]
+      dropout:
+        values: [0.0, 0.1, 0.2, 0.3, 0.4]
+      attention_head_size:
+        values: [1, 2, 4, 6]
+      hidden_continuous_size:
+        values: [4, 8, 16, 32]
+      learning_rate:
+        values: [0.01, 0.001, 0.0005, 0.0001]
+      optimizer:
+        values: ["Adam", "RMSprop", "Adagrad"]
\ No newline at end of file
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..e69de29
diff --git a/pyproject.toml b/pyproject.toml
index 435c620..0903b4d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,9 @@ requires-python = ">=3.9"
 dependencies = [
     "pytorch-forecasting==1.0.0",
     "plotly==5.22.0",
-    "wandb==0.16.6"
+    "wandb==0.17.7",
+    "TA-lib==0.4.32",
+    "numpy==1.26.4"
 ]
 
 [tool.pytest.ini_options]
diff --git a/scripts/train.py b/scripts/train.py
index e69de29..e1a2462 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -0,0 +1,229 @@
+import argparse
+import logging
+import wandb
+import pprint
+import os
+import pandas as pd
+import lightning.pytorch as pl
+
+from lightning.pytorch.utilities.model_summary import ModelSummary
+from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
+from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
+from pytorch_forecasting.metrics import MAE, RMSE
+from pytorch_forecasting import QuantileLoss
+from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument(
+        "config", help="Experiment configuration file in yaml format.")
+
+    parser.add_argument(
+        "-p",
+        "--project",
+        default="wne-masters-thesis-testing",
+        help="W&B project name.")
+
+    parser.add_argument(
+        "-l",
+        "--log-level",
+        default=logging.INFO,
+        type=int,
+        help="Sets the log level.")
+
+    parser.add_argument(
+        "-s",
+        "--seed",
+        default=42,
+        type=int,
+        help="Random seed for the training.")
+
+    parser.add_argument(
+        "-n",
+        "--log-interval",
+        default=100,
+        type=int,
+        help="Log every n steps."
+    )
+
+    parser.add_argument(
+        '-v',
+        '--val-check-interval',
+        default=300,
+        type=int,
+        help="Run validation every n batches."
+    )
+
+    parser.add_argument('--no-wandb', action='store_true',
+                        help='Disables wandb, for testing.')
+
+    return parser.parse_args()
+
+
+def get_dataset(config, project):
+    artifact_name = f"{project}/{config['data']['dataset']}"
+    artifact = wandb.Api().artifact(artifact_name)
+    base_path = artifact.download()
+    logging.info(f"Artifacts downloaded to {base_path}")
+
+    name = artifact.metadata['name']
+    part_name = f"in-sample-{config['data']['sliding_window']}"
+    data = pd.read_csv(os.path.join(
+        base_path, name + '-' + part_name + '.csv'))
+    logging.info(f"Using part: {part_name}")
+
+    # TODO: Fix in dataset
+    data['weekday'] = data['weekday'].astype('str')
+    data['hour'] = data['hour'].astype('str')
+
+    validation_part = config['data']['validation']
+    logging.info(f"Using {validation_part} of the in-sample part for validation.")
+    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
+    val_data = data.iloc[len(train_data) - config['past_window']:]
+    logging.info(f"Training part size: {len(train_data)}")
+    logging.info(
+        f"Validation part size: {len(val_data)} " +
+        f"({len(data) - len(train_data)} + {config['past_window']})")
+
+    logging.info("Building time series dataset for training.")
+    train = TimeSeriesDataSet(
+        train_data,
+        time_idx=config['data']['fields']['time_index'],
+        target=config['data']['fields']['target'],
+        group_ids=config['data']['fields']['group_ids'],
+        min_encoder_length=config['past_window'],
+        max_encoder_length=config['past_window'],
+        min_prediction_length=config['future_window'],
+        max_prediction_length=config['future_window'],
+        static_reals=config['data']['fields']['static_real'],
+        static_categoricals=config['data']['fields']['static_cat'],
+        time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
+        time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
+        time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
+        time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
+        randomize_length=False
+    )
+
+    logging.info("Building time series dataset for validation.")
+    val = TimeSeriesDataSet.from_dataset(
+        train, val_data, stop_randomization=True)
+
+    return train, val
+
+
+def get_loss(config):
+    loss_name = config['loss']['name']
+
+    if loss_name == 'Quantile':
+        return QuantileLoss(config['loss']['quantiles'])
+
+    raise ValueError("Unknown loss")
+
+
+def get_model(config, dataset, loss):
+    model_name = config['model']['name']
+
+    if model_name == 'TemporalFusionTransformer':
+        return TemporalFusionTransformer.from_dataset(
+            dataset,
+            hidden_size=config['model']['hidden_size'],
+            dropout=config['model']['dropout'],
+            attention_head_size=config['model']['attention_head_size'],
+            hidden_continuous_size=config['model']['hidden_continuous_size'],
+            learning_rate=config['model']['learning_rate'],
+            optimizer=config['model']['optimizer'],
+            share_single_variable_networks=False,
+            loss=loss,
+            logging_metrics=[MAE(), RMSE()]
+        )
+
+    raise ValueError("Unknown model")
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=args.log_level)
+    pl.seed_everything(args.seed)
+
+    run = wandb.init(
+        project=args.project,
+        config=args.config,
+        job_type="train",
+        mode="disabled" if args.no_wandb else "online"
+    )
+    config = run.config
+    logging.info("Using experiment config:\n%s", pprint.pformat(config))
+
+    # Get time series dataset
+    train, valid = get_dataset(config, args.project)
+    logging.info("Train dataset parameters:\n" +
+                 f"{pprint.pformat(train.get_parameters())}")
+
+    # Get loss
+    loss = get_loss(config)
+    logging.info(f"Using loss {loss}")
+
+    # Get model
+    model = get_model(config, train, loss)
+    logging.info(f"Using model {config['model']['name']}")
+    logging.info(f"{ModelSummary(model)}")
+    logging.info(
+        "Model hyperparameters:\n" +
+        f"{pprint.pformat(model.hparams)}")
+
+    # Checkpoint for saving the model
+    checkpoint_callback = ModelCheckpoint(
+        monitor='val_loss',
+        save_top_k=3,
+        mode='min',
+    )
+
+    # Logger for W&B
+    wandb_logger = WandbLogger(
+        project=args.project,
+        experiment=run,
+        log_model="all") if not args.no_wandb else None
+
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        mode="min",
+        patience=5)
+
+    batch_size = config['batch_size']
+    logging.info(f"Training batch size {batch_size}.")
+
+    epochs = config['max_epochs']
+    logging.info(f"Training for {epochs} epochs.")
+
+    trainer = pl.Trainer(
+        accelerator="auto",
+        max_epochs=epochs,
+        logger=wandb_logger,
+        callbacks=[
+            checkpoint_callback,
+            early_stopping
+        ],
+        log_every_n_steps=args.log_interval,
+        val_check_interval=args.val_check_interval
+    )
+
+    if epochs > 0:
+        logging.info("Starting training:")
+        trainer.fit(
+            model,
+            train_dataloaders=train.to_dataloader(
+                batch_size=batch_size,
+                num_workers=3,
+            ),
+            val_dataloaders=valid.to_dataloader(
+                batch_size=batch_size, train=False, num_workers=3
+            ))
+
+
+if __name__ == '__main__':
+    main()
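
Usage note (not part of the diff): a minimal sketch of how a checkpoint written by the ModelCheckpoint callback above could be reloaded to produce quantile forecasts. The function name predict_quantiles is illustrative, and the dataloader is assumed to be built with valid.to_dataloader(train=False, ...) exactly as in scripts/train.py.

from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer


def predict_quantiles(checkpoint_path, dataloader):
    """Reload a trained TFT checkpoint and return per-quantile forecasts."""
    model = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)
    # mode="quantiles" returns one forecast per configured quantile
    # (0.02, 0.1, 0.5, 0.9, 0.98 in the experiment config), i.e. the median
    # close-price forecast plus the 80% and 96% prediction intervals.
    return model.predict(dataloader, mode="quantiles")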