added changes

This commit is contained in:
yasha 2025-04-30 05:07:31 +00:00
parent 0ded87a0fb
commit 5865814464
16 changed files with 101282 additions and 518 deletions

4
.gitignore vendored
View File

@ -200,4 +200,6 @@ lightning_logs/
wne-masters-thesis-testing/ wne-masters-thesis-testing/
notebooks/cache/ notebooks/cache/
notebooks/images/ notebooks/images/
.DS_Store .DS_Store
data/
venv-py310/

View File

@ -17,9 +17,10 @@ parameters:
data: data:
parameters: parameters:
dataset: dataset:
value: "btc-usdt-5m:latest" value: "btc-5m-features-full:latest"
validation: validation:
value: 0.2 value: 0.2
sliding_window: sliding_window:
min: 0 # min: 0 # Use values for grid search
max: 5 # max: 5
values: [0, 1, 2, 3, 4, 5] # Explicitly list values for grid search

View File

@ -8,7 +8,7 @@ max_epochs:
value: 40 value: 40
data: data:
value: value:
dataset: "btc-usdt-5m:latest" dataset: "btc-5m-features-full:latest"
sliding_window: 0 sliding_window: 0
validation: 0.2 validation: 0.2
fields: fields:

View File

@ -8,9 +8,11 @@ max_epochs:
value: 30 value: 30
data: data:
value: value:
dataset: "btc-usdt-5m:latest" # in_sample_artifact_name: "btc-5m-features-in_sample:latest" # Reverted
# out_of_sample_artifact_name: "btc-5m-features-out_of_sample:latest" # Reverted
dataset: "btc-5m-features-full:latest" # Use a single artifact name
sliding_window: 0 sliding_window: 0
validation: 0.2 validation: 0.2 # This likely controls the in-sample vs out-of-sample split in train.py
fields: fields:
value: value:
time_index: "time_index" time_index: "time_index"

View File

@ -8,7 +8,7 @@ command:
- "./configs/experiments/informer-btcusdt-5m-gmadl.yaml" - "./configs/experiments/informer-btcusdt-5m-gmadl.yaml"
- "--patience" - "--patience"
- "15" - "15"
method: random method: bayes
metric: metric:
goal: minimize goal: minimize
name: val_loss name: val_loss

View File

@ -13,6 +13,8 @@ metric:
goal: minimize goal: minimize
name: val_loss name: val_loss
parameters: parameters:
val_check_interval:
value: 1.0 # Validate once per epoch
past_window: past_window:
distribution: int_uniform distribution: int_uniform
min: 20 min: 20

1
data/.gitignore vendored
View File

@ -1 +0,0 @@
*

File diff suppressed because one or more lines are too long

173
prompts/design.txt Normal file
View File

@ -0,0 +1,173 @@
Implementation Guide: GMADL Informer Strategy for High-Frequency Bitcoin Trading (Based on arXiv:2503.18096v1)1. IntroductionThe research detailed in arXiv:2503.18096v1 investigates the application of the Informer deep learning architecture to develop automated trading strategies for high-frequency Bitcoin (BTC) data.1 The study specifically compares the efficacy of the Informer model when trained using three distinct loss functions: Root Mean Squared Error (RMSE), Quantile loss, and a novel loss function termed Generalized Mean Absolute Directional Loss (GMADL).1 The core objective was to assess whether Informer-based strategies could outperform standard benchmarks like Buy & Hold (B&H) and strategies based on classical technical indicators such as Moving Average Convergence Divergence (MACD) and Relative Strength Index (RSI).2This report provides a comprehensive guide to implementing the final, optimal configuration of the Informer model trained with the GMADL loss function, as described in the paper. This specific configuration demonstrated superior trading performance, particularly when applied to higher frequency data intervals.1 It is crucial to clarify that "GMADL" in the context of this study refers specifically to the loss function employed during the model training phase, not a distinct architectural layer or modification to the Informer structure itself.1 The primary contribution highlighted by the authors involves the systematic application and evaluation of these different loss functions within the Informer framework to forecast future returns and subsequently generate trading signals.1The study evaluated model performance across Bitcoin data sampled at 5-minute, 15-minute, and 30-minute intervals, finding that the GMADL variant exhibited particular strength at the 5-minute frequency.1 The performance was rigorously benchmarked against B&H, MACD, RSI strategies, and the Informer models trained using RMSE and Quantile loss.2 For practitioners seeking to replicate or build upon these findings, the authors note the availability of an open-source implementation framework on GitLab, designed to facilitate the comparison of trading strategies and aid in result reproduction.42. Dataset Specification and Feature EngineeringThe foundation of the GMADL Informer strategy implementation lies in the specific dataset utilized and the extensive feature engineering process applied.2.1. Primary Dataset
Asset: The study exclusively used data for the Bitcoin/Tether (BTC/USDT) cryptocurrency pair.3 This choice was motivated by Bitcoin's characteristic high volatility, the availability of a long historical data record, its continuous 24/7 trading nature (which circumvents issues like overnight gaps or holiday closures common in traditional markets), and the relative ease of accessing high-quality, fine-grained historical data from cryptocurrency exchanges.2
Data Source: While not explicitly named, the data was sourced from cryptocurrency exchanges providing historical data access, potentially via API or downloadable files.4
Time Period: The historical data spanned from August 21, 2019, to July 24, 2024.2
Frequencies: The analysis was conducted using data aggregated into 5-minute, 15-minute, and 30-minute intervals.2
Raw Data Fields: For each interval, the dataset included the open time, close time, open price, high price, low price, close price, and trading volume.3
2.2. Input Features (Model Input)The input to the Informer model was not limited to raw price and volume but comprised a rich set of engineered features, categorized as real and categorical variables.3
Real Variables:
Basic Market Data: Open price, high price, low price, close price, volume.
Returns: Calculated returns (the specific calculation method, e.g., log or simple returns, is not detailed in the provided summary but is listed as 'returns').3
Price Ratios: Ratios derived from OHLC prices: open-to-close, high-to-close, low-to-close, high-to-low.3 These ratios can capture intraday price dynamics and momentum.
Volatility Measures: Historical volatility calculated over 1-hour, 1-day, and 7-day lookback windows.3 Including multiple timeframes allows the model to consider volatility dynamics at different scales.
Moving Average Ratios: Ratios of Simple Moving Averages (SMA) over 1-hour, 1-day, and 7-day periods to the close price; Ratios of Exponential Moving Averages (EMA) over 1-hour and 1-day periods to the close price.3 These features provide information about price trends relative to historical averages over multiple horizons.
Technical Indicators: Standard indicators including MACD, MACD signal line, RSI, and Bollinger Bands ratios (lower band/close, upper band/close, middle band/close).3 These directly encode common technical analysis signals related to momentum, trend, and volatility bands.
External Data: The Cboe Volatility Index (VIX), Federal Funds effective rates, and the Crypto Fear & Greed Index were incorporated.3 The inclusion of these typically lower-frequency indicators (daily for VIX/Fear&Greed, less frequent for Fed rates) suggests an effort to capture the influence of broader market sentiment and macroeconomic conditions on even high-frequency Bitcoin price movements. This implies an underlying assumption that short-term crypto dynamics are not entirely isolated from traditional market factors. Implementation requires careful alignment of these lower-frequency data points with the high-frequency intervals (e.g., forward-filling daily values).
Categorical Variables:
Hour of the day (0-23) derived from the interval's close time.3
Day of the week (0-6) derived from the interval's close time.3 These allow the model to potentially learn time-based patterns like intraday or weekly seasonality.
The sheer breadth and complexity of this feature set indicate a significant effort in feature engineering. The model is explicitly provided with information related not only to price action but also to volatility across multiple timescales, momentum, trend strength, mean-reversion tendencies, market sentiment (both crypto-specific and traditional), and macroeconomic context. This suggests a belief that a comprehensive view incorporating these diverse factors is necessary for effective high-frequency prediction. The performance attributed to the GMADL Informer likely relies substantially on this rich, multi-faceted input representation, extending beyond the contributions of the architecture or loss function alone. Replicating the study's results necessitates the accurate calculation and integration of this entire feature suite.3. Data Preprocessing and StructuringAppropriate data preprocessing and a robust data splitting strategy are essential for training the model and evaluating its performance reliably.
Normalization: All real-valued input features were normalized before being passed into the Informer model.3 Normalization is a standard procedure in deep learning, ensuring that features with different scales do not disproportionately influence model training and helps maintain numerical stability. The specific normalization technique (e.g., Z-score standardization, Min-Max scaling) employed is not specified in the available summary 3, which represents a potential ambiguity for exact replication. Standardization is a common choice for financial time series.
Categorical Feature Handling: The categorical features representing the hour (0-23) and weekday (0-6) were transformed into real-valued embeddings.3 This is a standard technique allowing neural networks to learn meaningful representations for discrete inputs and capture potential cyclical patterns associated with time. The dimensionality of these embeddings contributes to the overall model dimension (dd).
Data Splitting Strategy (Rolling Windows): A key aspect of the methodology was the use of a rolling window approach for training and evaluation.3 The entire dataset (21.08.2019 - 24.07.2024) was divided into six consecutive windows. Each window comprised:
An in-sample period of 24 months (2 years) used for model training and validation.
An out-of-sample period of 6 months immediately following the in-sample period, used exclusively for testing the trained model's performance.3
The in-sample data within each window was further subdivided into a training set (80%) and a validation set (20%).3 The validation set is used for tasks like hyperparameter tuning and triggering early stopping.
This rolling window methodology is critical for evaluating time series models in finance. Financial markets exhibit non-stationarity, meaning their statistical properties (like mean and variance) change over time. A simple train-test split on such data can lead to overly optimistic or misleading results that fail to generalize. By repeatedly retraining the model on updated 2-year windows and testing on the subsequent 6 months, the evaluation simulates a more realistic scenario where the trading strategy must adapt to evolving market dynamics. The use of six distinct windows provides a more robust assessment of the strategy's performance across different market regimes present during the nearly 5-year study period.3 Any attempt at replication should adhere strictly to this rolling window protocol to ensure comparable results.
Missing Values: The provided summary 3 does not explicitly mention the strategy for handling potential missing values in the high-frequency data. While the 24/7 nature of crypto markets minimizes systematic gaps 4, sporadic missing data points can still occur. Standard techniques like forward-filling or interpolation might have been used, but confirmation would require consulting the full paper or source code.
4. GMADL Informer Model ArchitectureThe core of the strategy is the Informer model, configured with specific parameters that were found to be optimal when using the GMADL loss function.
Core Architecture: The model is based on the Informer architecture proposed by Zhou et al..2 Its key distinguishing features include the multi-head probe-sparse self-attention mechanism, designed to handle long input sequences efficiently, along with an encoder-decoder structure typical of Transformer-based models.3
Input and Output Sequences:
Input Sequence Length (past window): This parameter determines how many past time steps of features are fed into the model. Crucially, it was treated as a hyperparameter tuned separately for each data frequency.3 The optimal values were:
5-minute data: 28 steps
15-minute data: 30 steps
30-minute data: 26 steps
Output Sequence Length: While not explicitly stated in the summary 3, the objective is to forecast future returns to generate trading signals.1 This strongly implies a prediction horizon of the next single time step, making the output sequence length likely equal to 1.
Dimensionality:
Model/Embedding Dimension (dd): This defines the dimensionality of the input embeddings (for both real and categorical features) and the internal hidden states throughout the model. The optimal value was found to be 256 across all data frequencies (5-min, 15-min, 30-min).3
Feed-Forward Network Dimension (ff): This refers to the inner dimension of the position-wise feed-forward layers within the Informer's encoder and decoder blocks. It was also set to 256 for all frequencies.3
Attention Dimension: The dimension of the attention mechanism is intrinsically linked to the model dimension (dd) and the number of attention heads (hh). Specifically, the dimension per head is calculated as dd / hh.
Informer Configuration: The number of layers and attention heads were also tuned and varied depending on the data frequency.3
Encoder Layers: A single encoder layer was used consistently for all frequencies (5-min, 15-min, 30-min).3
Decoder Layers: The depth of the decoder varied:
5-minute data: 3 layers
15-minute data: 1 layer
30-minute data: 3 layers
Attention Heads (hh): The number of parallel attention heads also varied:
5-minute data: 2 heads
15-minute data: 2 heads
30-minute data: 4 heads
Attention Mechanism: The model employs the multi-head probe-sparse self-attention mechanism.3 This mechanism enhances computational efficiency compared to standard self-attention by allowing each query to attend only to a selected subset of keys, identified based on a sparsity measurement, rather than all keys. This makes the Informer architecture particularly suitable for the potentially long input sequences used in time series forecasting. The summary 3 does not specify the exact value of the factor controlling sparsity (e.g., the number of selected keys, u).
GMADL Integration: It is essential to reiterate that GMADL is the loss function used to train the model, not an architectural component.1 The Informer architecture itself follows the standard design; the GMADL function simply defines the objective minimized during the training process.
The observation that optimal architectural parameters (decoder layers, attention heads, input sequence length) differ across data frequencies is significant. It suggests that the nature and complexity of predictive temporal patterns vary with the sampling interval. For instance, the 5-minute and 30-minute data benefited from deeper decoders (3 layers) compared to the 15-minute data (1 layer). This might imply that capturing the dynamics at these frequencies requires more hierarchical processing or modeling more complex relationships. Similarly, the use of more attention heads (4) for 30-minute data could allow the model to attend to a wider variety of patterns over the longer intervals represented by each step. This frequency-specificity underscores the need for careful hyperparameter tuning tailored to the granularity of the input data; a single architecture is unlikely to be optimal across all frequencies.Furthermore, the consistent use of a shallow encoder (1 layer) across all frequencies, while varying the decoder depth, presents an interesting design choice. It might suggest that a single encoding layer is sufficient to extract the necessary information from the rich input feature set across the lookback window. The decoding process, which involves generating the prediction, appears more sensitive to the data frequency, requiring different levels of complexity (1 vs. 3 layers) to effectively model the forecasting task for each interval.The following table summarizes the key architectural parameters specific to the GMADL Informer for each data frequency:Table 1: GMADL Informer Architecture Parameters per Data FrequencyParameter5-min Data15-min Data30-min DataInput Sequence Length283026Output Sequence Length1 (Inferred)1 (Inferred)1 (Inferred)Model Dimension (dd)256256256Feed-Forward Dim (ff)256256256Encoder Layers111Decoder Layers313Attention Heads (hh)224Source: 35. Training Hyperparameters and ProcedureThe training process involved specific hyperparameter settings and procedures, again optimized for the GMADL loss function and varying by data frequency.
Optimizer: The specific optimizer used for training (e.g., Adam, AdamW, SGD) is not explicitly mentioned in the provided summary.3 Adam or its variants are common choices for training Transformer-based models due to their effectiveness and efficiency.
Learning Rate (LR): The initial learning rate was tuned and set differently for each frequency 3:
5-minute data: 0.0001
15-minute data: 0.001
30-minute data: 0.0001
Whether a learning rate schedule (e.g., warmup followed by decay) was employed is not specified; the values likely represent the initial or constant learning rate used.3
Batch Size: The number of samples processed in each training iteration also varied 3:
5-minute data: 256
15-minute data: 128
30-minute data: 128
Training Epochs: Models were trained for a maximum duration of 40 epochs.3
Early Stopping: To prevent overfitting and reduce unnecessary computation, an early stopping mechanism was used. Training was halted if the loss calculated on the validation set did not show improvement for 15 consecutive validation checks. These validation checks were performed periodically, specifically after every 300 training batches.3
Regularization: Dropout was applied as a regularization technique to mitigate overfitting.3
Dropout Rate: A relatively low dropout rate of 0.01 was used consistently across all three data frequencies (5-min, 15-min, 30-min).3
The variations in learning rate and batch size across frequencies likely reflect attempts to balance training stability and convergence speed. For the 5-minute data, which has the highest number of samples within each 2-year training window, a larger batch size (256) could accelerate training per epoch. However, this was paired with a smaller learning rate (0.0001), potentially to ensure more stable convergence given the larger batch size and the potentially higher noise level in 5-minute data. Conversely, the 15-minute data used a smaller batch size (128) but a significantly larger learning rate (0.001), perhaps aiming for faster convergence with fewer data points per window. The 30-minute data used the smaller learning rate (0.0001) and the smaller batch size (128), possibly indicating a need for more cautious and stable updates for the dynamics at this lower frequency. These choices highlight that optimal training hyperparameters are data-dependent and involve trade-offs. The low dropout rate (0.01) suggests that overfitting might not have been the primary challenge, or that the early stopping mechanism was effective in controlling it.The following table summarizes the training hyperparameters for the GMADL Informer model:Table 2: GMADL Informer Training Hyperparameters per Data FrequencyParameter5-min Data15-min Data30-min DataOptimizerNot SpecifiedNot SpecifiedNot SpecifiedLearning Rate0.00010.0010.0001Batch Size256128128Max Epochs404040Early Stopping Patience (Checks)151515Early Stopping Freq. (Batches)300300300Dropout Rate0.010.010.01Source: 36. Generalized Mean Absolute Directional Loss (GMADL)The defining characteristic of the best-performing strategy identified in the paper is its use of the GMADL loss function during training.1
Definition: The GMADL function is mathematically defined as follows 3:
GMADL=N1i=1∑N[(1)⋅(1+ea⋅yi⋅y^i121)⋅(yi)b]
Where:
yi is the actual observed return for observation i.
y^i is the return predicted by the Informer model for observation i.
N is the total number of observations (typically the batch size during training).
a and b are parameters controlling the behavior of the loss function.
Parameters: In this study, the parameters were fixed at a=100 and b=2.3
Explanation of Components:
Directional Term: The core component (1+ea⋅y⋅y^121) focuses heavily on directional accuracy. The term 1+ex1 is the sigmoid function, let's denote it σ(x). So this part is σ(a⋅y⋅y^)0.5.
If the predicted direction matches the actual direction (y and y^ have the same sign), their product y⋅y^ is positive. Given the large value of a=100, the argument a⋅y⋅y^ becomes a large positive number. The sigmoid function σ(large positive)≈1. Thus, the term becomes approximately 10.5=0.5.
If the predicted direction is opposite to the actual direction (y and y^ have opposite signs), their product y⋅y^ is negative. The argument a⋅y⋅y^ becomes a large negative number. The sigmoid function σ(large negative)≈0. Thus, the term becomes approximately 00.5=0.5.
The large value of a=100 ensures a very steep transition around y⋅y^=0. This means the loss function strongly penalizes predictions that get the direction wrong, even if the magnitude of the error is small but crosses the zero threshold.
The leading factor of (1) in the overall formula flips these contributions. A correct direction results in a loss component close to 0.5×(y)b, while an incorrect direction results in a loss component close to 0.5×(y)b. Since the goal is minimization, the model is heavily incentivized to predict the correct direction.
Magnitude Weighting: The term (y)b weights the directional error component by the magnitude of the actual return, raised to the power b=2. This means that directional errors made when the actual price movement (y) was large are penalized much more severely (quadratically, in this case, due to b=2) than directional errors made when the actual price movement was small.
The design of the GMADL function reveals its specific objective: to train a model that excels at predicting the direction of future returns, particularly prioritizing accuracy during larger market movements. Unlike standard regression losses like RMSE, which penalize prediction errors based purely on magnitude regardless of direction, GMADL explicitly incorporates directional correctness via the steep sigmoid term (controlled by a) and emphasizes the importance of getting large moves right via the magnitude weighting (controlled by b). This objective aligns closely with the practical goals of many trading strategies, where capturing the direction of significant price swings is often more critical to profitability than minimizing small prediction errors around zero. The reported success of GMADL, especially in the noisy high-frequency domain 1, suggests that optimizing directly for this trade-oriented objective can yield superior results compared to optimizing for pure predictive accuracy (like RMSE).7. Evaluation MetricsThe effectiveness of the strategies derived from the Informer model's predictions was not measured using typical machine learning regression metrics (like MAE or MSE alone). Instead, the evaluation focused on a suite of standard financial performance metrics that assess the quality of the resulting trading strategy.3
Metric Suite: The performance evaluation employed the following metrics:
Annualized Return Compounded (ARC): The geometric average annual rate of return.
Annualized Standard Deviation (ASD): The annualized volatility of the strategy's returns, a measure of risk.
Information Ratio (IR*): A measure of risk-adjusted return, often calculated as the excess return over a benchmark (like B&H) divided by the standard deviation of that excess return.
Maximum Drawdown (MD): The largest percentage decline from a peak to a subsequent trough in the strategy's equity curve, indicating downside risk.
Modified Information Ratio (IR**): An adjusted version of the Information Ratio, potentially accounting for non-normality (skewness, kurtosis) in return distributions (e.g., Sharpe Ratio, Sortino Ratio). The summary highlights this as a key metric where the GMADL strategy excelled.3
Number of trades (N): Total number of trades executed by the strategy, relevant for assessing trading costs and activity.
Percentage of Time in Long Position (LONG): The proportion of the evaluation period the strategy held a long position.
Percentage of Time in Short Position (SHORT): The proportion of the evaluation period the strategy held a short position.
Trading Signal Generation: While the exact mechanism for converting the model's return forecast (y^) into a trading signal (Buy/Long, Sell/Short, Hold/Flat) is not detailed in the summary 3, a common approach involves using thresholds. For example, if y^ exceeds a positive threshold, a long position is initiated; if y^ falls below a negative threshold, a short position is initiated; otherwise, the strategy remains flat. These thresholds could be fixed or potentially optimized.
The choice of evaluation metrics underscores the study's practical focus. The ultimate goal was not merely accurate price forecasting but the development of profitable and risk-managed automated trading strategies.1 Financial metrics like ARC, IR**, and MD directly measure the simulated profitability and risk profile of acting upon the model's predictions. This is crucial because a model achieving low prediction error (e.g., low MSE) might not necessarily translate into a successful trading strategy if its errors, even if small on average, occur at critical moments for trading decisions. Therefore, evaluating success based on these portfolio-level metrics provides a more relevant assessment of the model's utility in a real-world trading context. This aligns well with the use of the GMADL loss function, which itself prioritizes aspects (direction, large moves) directly relevant to trading profitability.8. Reported Performance and Key FindingsThe experimental results presented in the paper strongly favor the Informer model trained with the GMADL loss function, particularly for high-frequency data.
GMADL Superiority: The central finding is that the GMADL Informer strategy significantly outperformed the benchmark strategies (Buy & Hold, MACD, RSI) and the other Informer variants trained with RMSE and Quantile loss functions.1 The Quantile loss-based Informer strategy reportedly did not manage to outperform the benchmarks.1
Frequency Dependence: A critical observation was the differing impact of data frequency on the performance of GMADL versus RMSE models:
The performance of the GMADL Informer strategy generally improved as the data frequency increased, achieving its best results on the 5-minute interval data.1
Conversely, the performance of the Informer strategy trained with RMSE worsened when applied to higher frequency data.1
Key 5-Minute GMADL Results: The paper highlights specific performance figures for the top-performing GMADL Informer strategy using 5-minute data 3:
Annualized Return Compounded (ARC): 115.88%
Modified Information Ratio (IR**): 7.552
These figures represent substantial outperformance compared to typical benchmark returns over the period.
Statistical Significance: The superiority of the GMADL Informer strategy was statistically validated. A t-test confirmed that its Information Ratio (IR*) was significantly greater than that of the passive Buy & Hold benchmark strategy.3
The contrasting performance trends between GMADL and RMSE with increasing data frequency offer valuable insights. High-frequency financial data is notoriously noisy. RMSE, by penalizing squared errors, might be overly sensitive to this noise, potentially leading the model to overfit or struggle to predict magnitudes accurately amidst fluctuations. GMADL, however, focuses primarily on directional correctness, particularly for larger moves (due to the b=2 weighting). This focus might allow it to more effectively filter the noise and capture the underlying directional signal, which could be more discernible or exploitable at higher frequencies where numerous small directional shifts occur. The success of GMADL at the 5-minute interval suggests that this loss function, combined with the Informer architecture and the comprehensive feature set, can effectively model very short-term predictive patterns crucial for high-frequency trading. This implies that GMADL could be a particularly well-suited loss function for developing HFT strategies where capturing directionality is paramount, potentially more so than achieving pinpoint accuracy in magnitude prediction.The following table provides a concise comparison of key performance metrics for the best GMADL strategy against a benchmark, illustrating its reported advantage:Table 3: Key Performance Metrics Summary (Average over 6 Test Periods)StrategyData FrequencyAnnualized Return (ARC)Modified Information Ratio (IR**)GMADL Informer5-min115.88%7.552(RMSE Informer 5-min)(5-min)(Value Not Provided)(Value Not Provided)(Buy & Hold)(N/A)(Value Not Provided)(Value Not Provided)Source:.3 Note: Specific comparable values for RMSE Informer (5-min) and Buy & Hold ARC/IR* were not available in the provided summary 3, but the text confirms GMADL's significant outperformance.*9. Implementation ResourcesFor those seeking to replicate the study or utilize the methodology, the authors have indicated the provision of supporting resources.
Open Source Framework: The paper states that an open-sourced implementation framework, designed for the efficient comparison of trading strategies (including the GMADL Informer strategy discussed), is made available on GitLab.4 While the specific GitLab repository URL is not provided in the summaries, the existence of this code is a significant asset for reproducibility. Accessing this framework would likely provide clarity on implementation details not fully specified in the paper's text, such as the exact normalization method, the optimizer choice, the precise logic for trading signal generation from forecasts, and the implementation of the feature calculations.
10. ConclusionThis report has detailed the implementation specifics of the GMADL Informer trading strategy for high-frequency Bitcoin data, based on the findings presented in arXiv:2503.18096v1. Successful replication requires careful attention to several key components:
Data: Utilizing the BTC/USDT dataset for the specified period (2019-2024) at 5-minute, 15-minute, or 30-minute frequencies.
Features: Implementing the extensive feature engineering pipeline, including price/volume derivatives, multi-scale volatility and moving average ratios, technical indicators, and external market data (VIX, Fed rates, Fear/Greed index).
Preprocessing: Normalizing real-valued features, embedding categorical time features, and crucially, adhering to the 6-fold rolling window structure (24 months in-sample train/validation, 6 months out-of-sample test).
Architecture: Configuring the Informer model with frequency-specific parameters for input sequence length, decoder layers, and attention heads, using a model dimension (dd) and feed-forward dimension (ff) of 256, and employing probe-sparse attention.
Training: Employing frequency-specific learning rates and batch sizes, training for up to 40 epochs with early stopping based on validation loss (patience 15 checks, frequency 300 batches), and using a dropout rate of 0.01.
Loss Function: Implementing the Generalized Mean Absolute Directional Loss (GMADL) with parameters a=100 and b=2.
Evaluation: Measuring performance using financial trading metrics such as ARC, IR**, and MD, rather than solely relying on prediction accuracy metrics.
The core takeaway from the research is that the combination of the Informer architecture, a rich feature set, and particularly the GMADL loss function yields a potent automated trading strategy for high-frequency Bitcoin data.1 The GMADL variant demonstrated statistically significant outperformance against benchmarks and other Informer configurations, with its effectiveness notably increasing at higher frequencies (peaking at 5-minute intervals) where traditional RMSE-based models struggled.1 This suggests that explicitly optimizing for directional accuracy weighted by move magnitude, as GMADL does, is a highly effective approach in the context of noisy, high-frequency financial markets.Researchers and practitioners aiming to implement this strategy should meticulously follow the frequency-specific architectural and training parameters outlined. Given the complexity, consulting the authors' publicly available GitLab repository, if accessible, is highly recommended for resolving ambiguities and ensuring faithful replication.

514
scripts/prepare_btc_data.py Normal file
View File

@ -0,0 +1,514 @@
import argparse
import glob
import logging
import os
import sqlite3
import pandas as pd
import pandas_ta as ta
import numpy as np
import wandb
import tempfile
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_data_from_db(db_path, table_name="klines"):
"""Loads data from a specific table in an SQLite database."""
logging.info(f"Reading data from {db_path}, table '{table_name}'...")
try:
conn = sqlite3.connect(db_path)
# Query to check if table exists
cursor = conn.cursor()
cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';")
if cursor.fetchone() is None:
logging.warning(f"Table '{table_name}' not found in {db_path}. Skipping.")
return None
# Adjust column names if necessary based on your actual schema
query = f"SELECT timestamp, open, high, low, close, volume FROM {table_name} WHERE instrument_id LIKE 'PAIR-BTC-%'"
# --- Add logging for the query --- New
logging.info(f"Executing query: {query}")
df = pd.read_sql_query(query, conn)
# --- Add logging for rows read --- New
logging.info(f"Read {len(df)} rows matching the criteria from {db_path}")
# --- Log raw timestamp range --- New
if not df.empty:
logging.info(f"Raw timestamp range: {df['timestamp'].min()} to {df['timestamp'].max()}")
# --- End logging ---
except sqlite3.Error as e:
logging.error(f"Error reading database {db_path}: {e}")
return None
finally:
if conn:
conn.close()
return df
def calculate_features(df):
"""Calculates technical indicators and other features."""
logging.info("Calculating base features...")
# df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') # Removed: Already indexed in main
# df = df.set_index('datetime').sort_index() # Removed: Already indexed in main
# Adjust column names if your input df has different names
# open_col, high_col, low_col, close_col, vol_col = 'open', 'high', 'low', 'close', 'volume' # Old names
# --- Use the renamed column names --- New
open_col, high_col, low_col, close_col, vol_col = 'open_price', 'high_price', 'low_price', 'close_price', 'volume'
# Drop rows with missing essential data before calculations
df = df.dropna(subset=[open_col, high_col, low_col, close_col, vol_col])
if df.empty:
logging.warning("DataFrame is empty after dropping NaNs in essential columns.")
return df
# --- Basic Price Features ---
df['open_to_close_price'] = df[close_col] / df[open_col] - 1
df['high_to_close_price'] = df[high_col] / df[close_col] - 1
df['low_to_close_price'] = df[low_col] / df[close_col] - 1
df['high_to_low_price'] = df[high_col] / df[low_col] - 1
# --- Returns ---
# Shift(1) calculates return based on previous close: (close_t / close_{t-1}) - 1
df['returns'] = df[close_col].pct_change()
df['log_returns'] = np.log(df[close_col] / df[close_col].shift(1))
# --- Time Features ---
df['hour'] = df.index.hour.astype(str).astype("category") # Use string/category as required by config
df['weekday'] = df.index.weekday.astype(str).astype("category") # Use string/category
# --- Technical Indicators using pandas_ta ---
logging.info("Calculating technical indicators (this may take a while)...")
custom_strategy = ta.Strategy(
name="informer_features",
description="Calculate features for Informer model based on config",
ta=[
# Volatility (adjust lengths as needed, config doesn't specify)
{"kind": "atr", "length": 14, "col_names": "atr"}, # Example ATR
# MACD
{"kind": "macd", "fast": 12, "slow": 26, "signal": 9, "col_names": ("macd", "macd_hist", "macd_signal")},
# RSI
{"kind": "rsi", "length": 14, "col_names": "rsi"},
# Bollinger Bands
{"kind": "bbands", "length": 20, "std": 2, "col_names": ("low_bband", "mid_bband", "up_bband", "bandwidth", "percent")},
# SMA (1h=12*5m, 1d=288*5m, 7d=2016*5m)
{"kind": "sma", "length": 12, "col_names": "sma_1h"},
{"kind": "sma", "length": 288, "col_names": "sma_1d"},
{"kind": "sma", "length": 2016, "col_names": "sma_7d"},
# EMA (1h=12*5m, 1d=288*5m) - Note: Config only lists ema_1h, ema_1d relative to close
{"kind": "ema", "length": 12, "col_names": "ema_1h"},
{"kind": "ema", "length": 288, "col_names": "ema_1d"},
]
)
df.ta.strategy(custom_strategy)
# --- Volatility (Calculated on Price/Returns - Choose appropriate source) ---
# Using log returns is common for volatility calculation
df['vol_1h'] = df['log_returns'].rolling(window=12).std() * np.sqrt(12) # Scaled 1h vol
df['vol_1d'] = df['log_returns'].rolling(window=288).std() * np.sqrt(288) # Scaled daily vol
df['vol_7d'] = df['log_returns'].rolling(window=2016).std() * np.sqrt(2016) # Scaled weekly vol
# --- Relative Indicators (indicator / close_price) ---
logging.info("Calculating relative indicators...")
for indicator in ['low_bband', 'mid_bband', 'up_bband', 'sma_1h', 'sma_1d', 'sma_7d', 'ema_1h', 'ema_1d']:
if indicator in df.columns:
df[f'{indicator}_to_close_price'] = df[indicator] / df[close_col] -1
else:
logging.warning(f"Base indicator '{indicator}' not found for relative calculation.")
# --- Clean up intermediate columns if needed ---
# df = df.drop(columns=['atr', 'macd_hist', 'low_bband', 'mid_bband', 'up_bband', 'bandwidth', 'percent', 'sma_1h', 'sma_1d', 'sma_7d', 'ema_1h', 'ema_1d'])
# --- Handle initial NaNs introduced by rolling windows/shifts ---
# returns and log_returns will have NaN for the first row.
# Indicators will have NaNs for their window length.
# We will forward-fill later after merging external data.
return df
def load_external_data(file_path, date_col, value_col, rename_to=None):
"""Loads external daily data like VIX or Fear/Greed Index."""
logging.info(f"Loading external data from {file_path}...")
try:
df = pd.read_csv(file_path)
df[date_col] = pd.to_datetime(df[date_col])
# Keep only date and value, rename value column
df = df[[date_col, value_col]].rename(columns={value_col: rename_to or value_col})
# --- Normalize the date index --- New
df = df.set_index(date_col).sort_index()
df.index = df.index.normalize() # Ensure time is midnight
logging.info(f"Loaded {len(df)} records from {file_path}. Index normalized.")
return df
except FileNotFoundError:
logging.error(f"External data file not found: {file_path}")
return None
except Exception as e:
logging.error(f"Error loading external data from {file_path}: {e}")
return None
def main(db_pattern, db_table, vix_file, fear_greed_file, eff_rate_file, args):
"""Main function to load, process, and save data."""
db_files = glob.glob(os.path.expanduser(db_pattern), recursive=True)
if not db_files:
logging.error(f"No database files found matching pattern: {db_pattern}")
return
logging.info(f"Found {len(db_files)} database files.")
all_data = []
for db_file in db_files:
df = load_data_from_db(db_file, table_name=db_table)
if df is not None:
all_data.append(df)
if not all_data:
logging.error("No data loaded from any database file.")
return
logging.info("Concatenating data from all databases...")
btc_df = pd.concat(all_data, ignore_index=True)
# --- Log raw timestamp info --- New
if not btc_df.empty:
logging.info(f"Raw timestamp column info - dtype: {btc_df['timestamp'].dtype}, head:\n{btc_df['timestamp'].head()}")
else:
logging.warning("BTC DataFrame empty after concat, cannot check raw timestamp.")
# --- End logging ---
# --- Initial Processing ---
# Convert timestamp to datetime and sort
btc_df['datetime'] = pd.to_datetime(btc_df['timestamp'], unit='s')
# --- Add logging to check converted dates --- New
if not btc_df.empty:
logging.info(f"Converted datetime range: {btc_df['datetime'].min()} to {btc_df['datetime'].max()}")
else:
logging.warning("BTC DataFrame is empty after concatenation, cannot check datetime range.")
# --- End logging ---
# Deduplicate based on timestamp, keep first entry
btc_df = btc_df.sort_values('datetime').drop_duplicates(subset=['timestamp'], keep='first')
# --- Rename price columns to match config --- Moved Earlier
rename_map = {'open': 'open_price', 'high': 'high_price', 'low': 'low_price', 'close': 'close_price'}
btc_df = btc_df.rename(columns=rename_map)
logging.info(f"Renamed columns: {rename_map}")
# --- Set index and log info ---
btc_df = btc_df.set_index('datetime').sort_index()
if not btc_df.empty:
# --- Log index details --- Modified
logging.info(f"DataFrame index info - dtype: {btc_df.index.dtype}, timezone: {btc_df.index.tz}, range: {btc_df.index.min()} to {btc_df.index.max()}") # Added timezone check
logging.info(f"DataFrame head(1):\n{btc_df.head(1)}") # Added head check
else:
logging.warning("BTC DataFrame empty after setting index.")
# --- End logging ---
logging.info(f"Total unique records after concatenation: {len(btc_df)}")
# --- Resample to 5-minute Intervals --- New
logging.info("Resampling 1-minute data to 5-minute intervals...")
resampling_rules = {
'open_price': 'first',
'high_price': 'max',
'low_price': 'min',
'close_price': 'last',
'volume': 'sum'
}
# Ensure columns exist before resampling
missing_cols = [col for col in resampling_rules if col not in btc_df.columns]
if missing_cols:
logging.error(f"Cannot resample, required columns missing: {missing_cols}")
return
btc_df = btc_df[list(resampling_rules.keys())].resample('5T').agg(resampling_rules)
# Drop rows where resampling might have produced all NaNs (e.g., gaps in original data)
btc_df.dropna(subset=['open_price', 'high_price', 'low_price', 'close_price'], inplace=True)
logging.info(f"Resampled data shape: {btc_df.shape}")
if not btc_df.empty:
logging.info(f"Resampled index range: {btc_df.index.min()} to {btc_df.index.max()}")
logging.info(f"Resampled head(1):\n{btc_df.head(1)}")
else:
logging.warning("DataFrame empty after resampling.")
return # Stop if empty after resampling
# --- End Resampling ---
# --- Feature Calculation ---
# Now operates on the 5-minute resampled data
btc_df = calculate_features(btc_df)
if btc_df.empty:
logging.error("DataFrame became empty during feature calculation.")
return
# --- Load and Merge External Data ---
# VIX Data - Assuming daily data
# vix_df = load_external_data(vix_file, date_col='Date', value_col='VIX Close', rename_to='vix_close_price') # Old
vix_df = load_external_data(vix_file, date_col='date', value_col='close', rename_to='vix_close_price') # Corrected
# Fear & Greed Data - Assuming daily data
# fg_df = load_external_data(fear_greed_file, date_col='timestamp', value_col='value', rename_to='fear_greed_index') # Old
fg_df = load_external_data(fear_greed_file, date_col='date', value_col='fng_value', rename_to='fear_greed_index') # Corrected
# --- Load Effective Rates Data ---
eff_rates_df = load_external_data(eff_rate_file, date_col='observation_date', value_col='DFF', rename_to='effective_rates')
# --- Log External Data Index Info & Timezones --- Modified
if vix_df is not None: logging.info(f"VIX index info - dtype: {vix_df.index.dtype}, timezone: {vix_df.index.tz}, range: {vix_df.index.min()} to {vix_df.index.max()}")
if fg_df is not None: logging.info(f"F&G index info - dtype: {fg_df.index.dtype}, timezone: {fg_df.index.tz}, range: {fg_df.index.min()} to {fg_df.index.max()}")
if eff_rates_df is not None: logging.info(f"EffRates index info - dtype: {eff_rates_df.index.dtype}, timezone: {eff_rates_df.index.tz}, range: {eff_rates_df.index.min()} to {eff_rates_df.index.max()}")
# --- Log external data near BTC start --- New
if not btc_df.empty:
first_btc_time = btc_df.index.min()
logging.info(f"First BTC timestamp: {first_btc_time}")
if vix_df is not None:
logging.info(f"VIX data at/before start:\n{vix_df[vix_df.index <= first_btc_time].tail()}")
if fg_df is not None:
logging.info(f"F&G data at/before start:\n{fg_df[fg_df.index <= first_btc_time].tail()}")
if eff_rates_df is not None:
logging.info(f"EffRates data at/before start:\n{eff_rates_df[eff_rates_df.index <= first_btc_time].tail()}")
# --- End logging ---
# --- Perform merge_asof ---
logging.info("Performing merge_asof based on DatetimeIndex...")
# Ensure DataFrames are sorted by index (should be, but explicit is safer)
btc_df = btc_df.sort_index()
if vix_df is not None: vix_df = vix_df.sort_index()
if fg_df is not None: fg_df = fg_df.sort_index()
if eff_rates_df is not None: eff_rates_df = eff_rates_df.sort_index()
if vix_df is not None:
btc_df = pd.merge_asof(btc_df, vix_df, left_index=True, right_index=True, direction='backward')
logging.info(f"Shape after VIX merge_asof: {btc_df.shape}, VIX NaNs: {btc_df['vix_close_price'].isna().sum()}")
if fg_df is not None:
btc_df = pd.merge_asof(btc_df, fg_df, left_index=True, right_index=True, direction='backward')
logging.info(f"Shape after F&G merge_asof: {btc_df.shape}, F&G NaNs: {btc_df['fear_greed_index'].isna().sum()}")
if eff_rates_df is not None:
btc_df = pd.merge_asof(btc_df, eff_rates_df, left_index=True, right_index=True, direction='backward')
logging.info(f"Shape after EffRates merge_asof: {btc_df.shape}, EffRates NaNs: {btc_df['effective_rates'].isna().sum()}")
logging.info("Finished merge_asof operations.")
# --- End Merge Block ---
# --- Add logging after merge ---
logging.info(f"BTC data after merge - Shape: {btc_df.shape}, Null counts:\n{btc_df.isna().sum().sort_values(ascending=False).head()}")
# --- Final Preparations ---
logging.info("Performing final data preparation steps...")
# Add required columns not generated yet
btc_df['group_id'] = "BTC-USDT" # Static group ID for this dataset
btc_df['group_id'] = btc_df['group_id'].astype("category")
# Create the sequential time index required by pytorch-forecasting
btc_df = btc_df.sort_index() # Ensure sorted before creating index
# --- Add close_time column --- New
# The index represents the start of the 5min interval
# Close time is 5 minutes after the start
btc_df['close_time'] = btc_df.index + pd.Timedelta(minutes=5)
logging.info(f"Added 'close_time' column. Head:\n{btc_df['close_time'].head()}")
# --- End Add close_time ---
btc_df = btc_df.reset_index() # Bring datetime back as a column temporarily
btc_df['time_index'] = btc_df.index # Create sequential integer index
# Define final columns based on the YAML config (ensure all generated features are included)
# Make sure these names match exactly what was generated
final_columns = [
"time_index", "group_id", "returns", # Core fields
"close_time", # Added missing column
# dynamic_unknown_real
"high_price", "low_price", "open_price", "close_price", "volume",
"open_to_close_price", "high_to_close_price", "low_to_close_price", "high_to_low_price",
"log_returns", "vol_1h", "macd", "macd_signal", "rsi",
"low_bband_to_close_price", "up_bband_to_close_price", "mid_bband_to_close_price",
"sma_1h_to_close_price", "sma_1d_to_close_price", "sma_7d_to_close_price",
"ema_1h_to_close_price", "ema_1d_to_close_price",
# dynamic_known_real (Check if these exist after merge)
"vix_close_price", "fear_greed_index", "vol_1d", "vol_7d", "effective_rates",
# dynamic_known_cat
"hour", "weekday"
]
# TODO: Add 'effective_rates' back if you load and merge it
# Select and reorder columns, handling potential missing external cols
cols_to_select = []
for col in final_columns:
if col in btc_df.columns:
cols_to_select.append(col)
else:
logging.warning(f"Required column '{col}' not found in DataFrame. It will be excluded.")
final_df = btc_df[cols_to_select]
# --- Handle Missing Values ---
# Forward fill is common for time series, especially after merges and indicator calculations
# Note: Ffill might not be suitable for returns, but initial NaNs in returns are expected.
# Consider specific handling if needed.
logging.info(f"Forward filling NaNs. Initial NaN count:\n{final_df.isna().sum().sort_values(ascending=False).head()}")
final_df = final_df.ffill()
# Drop any rows that *still* have NaNs (e.g., at the very beginning before first external data point or calc window)
initial_rows = len(final_df)
# --- Modify dropna to be less aggressive --- New # Comment Needs Update
# Define critical columns that *must* be present
# critical_cols = ['open_price', 'high_price', 'low_price', 'close_price', 'volume', 'returns']
# Check if critical columns exist before using them in subset
# subset_cols = [col for col in critical_cols if col in final_df.columns]
# if not subset_cols:
# logging.warning("No critical columns found for dropna subset. Skipping dropna.")
# else:
# logging.info(f"Dropping rows where any of {subset_cols} are NaN.")
# final_df = final_df.dropna(subset=subset_cols)
# --- Drop rows with ANY NaN value --- Modified
logging.info(f"Dropping rows with any NaN values.")
final_df = final_df.dropna()
# --- End modification ---
rows_dropped = initial_rows - len(final_df)
if rows_dropped > 0:
logging.warning(f"Dropped {rows_dropped} rows containing NaNs after forward filling.")
# Final check
if final_df.isna().any().any():
logging.warning(f"NaN values still present after processing:\n{final_df.isna().sum()[final_df.isna().sum() > 0]}")
else:
logging.info("No remaining NaN values detected.")
if final_df.empty:
logging.error("Final DataFrame is empty after processing and NaN handling.")
return
# --- Removed Data Splitting Logic ---
# split_ratio = 0.8 # Use 80% for in-sample
# split_index = int(len(final_df) * split_ratio)
#
# in_sample_df = final_df.iloc[:split_index]
# out_of_sample_df = final_df.iloc[split_index:]
#
# logging.info(f"Split data: {len(in_sample_df)} in-sample rows, {len(out_of_sample_df)} out-of-sample rows.")
# logging.info(f"In-sample time range: {in_sample_df['time_index'].min()} to {in_sample_df['time_index'].max()}")
# logging.info(f"Out-of-sample time range: {out_of_sample_df['time_index'].min()} to {out_of_sample_df['time_index'].max()}")
# --- End Split Removal ---
# --- Log Single Artifact to W&B --- Modified
logging.info(f"Logging full dataset artifact to W&B project '{wandb.run.project}', run '{wandb.run.name}'...")
try:
with tempfile.TemporaryDirectory() as tempdir:
# Save the entire final_df
full_data_path = os.path.join(tempdir, 'full_data.parquet')
final_df.to_parquet(full_data_path, index=False)
logging.info(f"Temporary file saved to {tempdir}")
# Create and log the single artifact
full_artifact = wandb.Artifact(
name=args.full_dataset_artifact_name, # Use new arg
type='dataset',
description=f'Full BTC 5min features data ({len(final_df)} rows). Prepared by run {wandb.run.id}.',
metadata={'rows': len(final_df)}
)
full_artifact.add_file(full_data_path)
wandb.log_artifact(full_artifact)
logging.info(f"Logged full dataset artifact: {args.full_dataset_artifact_name}")
# --- Removed logging for separate artifacts ---
# # Create and log the IN-SAMPLE artifact
# in_sample_artifact = wandb.Artifact(
# name=args.in_sample_artifact_name, # Use arg
# type='dataset',
# description=f'In-sample BTC 5min data ({len(in_sample_df)} rows). Prepared by run {wandb.run.id}.',
# metadata={'rows': len(in_sample_df), 'split': 'in_sample'}
# )
# in_sample_artifact.add_file(in_sample_path)
# wandb.log_artifact(in_sample_artifact)
# logging.info(f"Logged in-sample artifact: {args.in_sample_artifact_name}")
#
# # Create and log the OUT-OF-SAMPLE artifact
# out_of_sample_artifact = wandb.Artifact(
# name=args.out_of_sample_artifact_name, # Use arg
# type='dataset',
# description=f'Out-of-sample BTC 5min data ({len(out_of_sample_df)} rows). Prepared by run {wandb.run.id}.',
# metadata={'rows': len(out_of_sample_df), 'split': 'out_of_sample'}
# )
# out_of_sample_artifact.add_file(out_of_sample_path)
# wandb.log_artifact(out_of_sample_artifact)
# logging.info(f"Logged out-of-sample artifact: {args.out_of_sample_artifact_name}")
logging.info("Artifact logged successfully.")
except Exception as e:
logging.error(f"Error logging artifacts to W&B: {e}")
wandb.run.finish(exit_code=1) # Finish run with error
return
# --- End W&B Logging ---
wandb.run.finish() # Finish run successfully
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prepare BTC-USDT 5-minute data and log to W&B.")
parser.add_argument(
"--db-pattern",
default="/home/yasha/develop/data/combined.coinbase_1min_hist.db",
help="Pattern or exact path to find input SQLite database file(s)."
)
parser.add_argument(
"--db-table",
default="combined_hist_1min",
help="Name of the table containing kline data within the SQLite files."
)
parser.add_argument(
"--vix-file",
default="data/vix_daily.csv",
help="Path to the VIX index CSV file."
)
parser.add_argument(
"--fear-greed-file",
default="data/fear_greed_index.csv",
help="Path to the Crypto Fear & Greed Index CSV file."
)
parser.add_argument(
"--eff-rate-file",
default="data/DFF.csv",
help="Path to the Effective Rates CSV file."
)
parser.add_argument(
"--wandb-project",
default="wne-masters-thesis-testing",
help="W&B project name."
)
parser.add_argument(
"--wandb-run-name",
default="prepare-btc-data",
help="W&B run name for this preparation job."
)
parser.add_argument(
"--wandb-notes",
default=None,
help="Optional notes for the W&B run."
)
parser.add_argument(
"--full-dataset-artifact-name",
default="btc-5m-features-full", # Match YAML default
help="Name for the single W&B artifact containing the full dataset."
)
args = parser.parse_args()
# --- Initialize W&B Run --- New
run = wandb.init(
project=args.wandb_project,
name=args.wandb_run_name,
notes=args.wandb_notes,
job_type="data-preparation",
config=vars(args) # Log command line args
)
# --- End W&B Init ---
# --- Pass args to main --- Modified
main(
db_pattern=args.db_pattern,
db_table=args.db_table,
vix_file=args.vix_file,
fear_greed_file=args.fear_greed_file,
eff_rate_file=args.eff_rate_file,
args=args # Pass all args for artifact names etc.
)

View File

@ -7,6 +7,7 @@ import tempfile
import torch import torch
import lightning.pytorch as pl import lightning.pytorch as pl
import pandas as pd import pandas as pd
import warnings
from lightning.pytorch.utilities.model_summary import ModelSummary from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.early_stopping import EarlyStopping
@ -21,6 +22,13 @@ from ml.data import (
build_time_series_dataset build_time_series_dataset
) )
# --- Suppress specific sklearn UserWarning ---
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.utils.validation")
# ---
# --- Set Matmul Precision for Tensor Cores ---
torch.set_float32_matmul_precision('medium')
# ---
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -60,7 +68,7 @@ def get_args():
parser.add_argument( parser.add_argument(
'-v', '-v',
'--val-check-interval', '--val-check-interval',
default=300, default=100,
type=int, type=int,
help="Run validation every n batches." help="Run validation every n batches."
) )

View File

@ -1,60 +1,155 @@
import os import os
import pandas as pd
import wandb import wandb
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet import pandas as pd
import logging
from pytorch_forecasting.data import TimeSeriesDataSet
def get_dataset_from_wandb(run, window=None): def get_dataset_from_wandb(run):
artifact_name = f"{run.project}/{run.config['data']['dataset']}" """Downloads the specified dataset artifact and splits it based on sliding window and validation split."""
artifact = wandb.Api().artifact(artifact_name)
base_path = artifact.download() # Construct artifact name from run config
# Example: "btc-5m-features-full:latest"
dataset_artifact_name = run.config.get('data', {}).get('dataset', None)
if not dataset_artifact_name:
raise ValueError("Dataset artifact name not found in run configuration (run.config.data.dataset)")
full_artifact_name = f"{run.project}/{dataset_artifact_name}"
logging.info(f"Attempting to download artifact: {full_artifact_name}")
try:
artifact = wandb.Api().artifact(full_artifact_name)
base_path = artifact.download()
logging.info(f"Artifact downloaded to: {base_path}")
except Exception as e:
logging.error(f"Failed to download artifact {full_artifact_name}: {e}")
raise # Re-raise the exception
name = artifact.metadata['name'] # Load the single parquet file
in_sample_name =\ full_data_file = os.path.join(base_path, 'full_data.parquet')
f"in-sample-{window or run.config['data']['sliding_window']}" logging.info(f"Loading full dataset from: {full_data_file}")
in_sample_data = pd.read_csv(os.path.join( if not os.path.exists(full_data_file):
base_path, name + '-' + in_sample_name + '.csv')) raise FileNotFoundError(f"Expected parquet file 'full_data.parquet' not found in artifact directory: {base_path}")
out_of_sample_name =\
f"out-of-sample-{window or run.config['data']['sliding_window']}" full_df = pd.read_parquet(full_data_file)
out_of_sample_data = pd.read_csv(os.path.join( logging.info(f"Loaded full dataset with shape: {full_df.shape}")
base_path, name + '-' + out_of_sample_name + '.csv'))
return in_sample_data, out_of_sample_data # --- Get Parameters for Splitting --- Modified
# Validation split now defines the size of the FINAL TEST SET
test_set_fraction = run.config.get('data', {}).get('validation', 0.2)
if not (0 < test_set_fraction < 1):
raise ValueError(f"Invalid final test set fraction (config.data.validation): {test_set_fraction}.")
# Sliding window index determines the end of the current in-sample data
sliding_window_idx = run.config.get('data', {}).get('sliding_window', 0)
# *** Assumption: Total number of windows (e.g., 0 to 5 means 6 windows) ***
# This should ideally match the sweep range if sweeping over sliding_window
total_num_windows = 6 # Hardcoded assumption - Adjust if needed or make configurable
if not (0 <= sliding_window_idx < total_num_windows):
raise ValueError(f"Invalid sliding_window index: {sliding_window_idx}. Must be between 0 and {total_num_windows-1}.")
# ---
# --- Calculate Splits --- Modified
N = len(full_df)
# End index of the pool used for all training/validation windows (excludes final test set)
train_val_pool_end_idx = int(N * (1 - test_set_fraction))
# Size of each window's data block within the training/validation pool
window_block_size = train_val_pool_end_idx // total_num_windows
if window_block_size == 0:
raise ValueError("Dataset too small for the number of windows and test set fraction.")
# End index for the current window's in-sample data (expanding window)
current_in_sample_end_idx = window_block_size * (sliding_window_idx + 1)
# Ensure the last window uses all data up to the test set
if sliding_window_idx == total_num_windows - 1:
current_in_sample_end_idx = train_val_pool_end_idx
in_sample_df = full_df.iloc[0 : current_in_sample_end_idx].copy()
# Out-of-sample is now the fixed final test set
out_of_sample_df = full_df.iloc[train_val_pool_end_idx :].copy()
logging.info(f"Sliding Window: {sliding_window_idx}, Test Fraction: {test_set_fraction:.1%}")
logging.info(f"In-sample indices: 0:{current_in_sample_end_idx}, Out-of-sample (Test) indices: {train_val_pool_end_idx}:{N}")
logging.info(f"Returned data shapes: In-sample={in_sample_df.shape}, Out-of-sample={out_of_sample_df.shape}")
# ---
return in_sample_df, out_of_sample_df
def get_train_validation_split(config, in_sample_data): def get_train_validation_split(config, data):
validation_part = config['data']['validation'] """Splits the provided (in-sample) data into training and validation sets."""
train_data = in_sample_data.iloc[:int( # data here is the in_sample_df for the current sliding window
len(in_sample_data) * (1 - validation_part))]
val_data = in_sample_data.iloc[len(train_data) - config['past_window']:] # Use the 'validation' fraction again, but now it defines the val set size *within* the in-sample data
# This takes the LATEST part of the current window for validation.
return train_data, val_data validation_fraction_within_window = config.get('data', {}).get('validation', 0.2)
if not (0 < validation_fraction_within_window < 1):
raise ValueError(f"Invalid validation fraction for train/val split (config.data.validation): {validation_fraction_within_window}.")
N_in_sample = len(data)
# Calculate the start index of the validation set within the current in-sample data
validation_start_idx = int(N_in_sample * (1 - validation_fraction_within_window))
train_data = data.iloc[:validation_start_idx]
valid_data = data.iloc[validation_start_idx:]
logging.info(f"Split in-sample data ({1-validation_fraction_within_window:.1%} / {validation_fraction_within_window:.1%}): {len(train_data)} train, {len(valid_data)} validation rows.")
return train_data, valid_data
def build_time_series_dataset(config, data): def build_time_series_dataset(config, data):
data = data.copy() """Builds TimeSeriesDataSet from configuration and data."""
# TODO: Fix in dataset
data['weekday'] = data['weekday'].astype('str') fields = config.get('fields', {})
data['hour'] = data['hour'].astype('str') time_idx = fields.get('time_index', 'time_idx') # Default if not specified
target = fields.get('target', 'target')
group_ids = fields.get('group_ids', [])
# Extract features based on types defined in config
time_varying_known_reals = fields.get('dynamic_known_real', [])
time_varying_known_categoricals = fields.get('dynamic_known_cat', [])
time_varying_unknown_reals = fields.get('dynamic_unknown_real', [])
time_varying_unknown_categoricals = fields.get('dynamic_unknown_cat', [])
static_reals = fields.get('static_real', [])
static_categoricals = fields.get('static_cat', [])
# Max lengths from config
max_encoder_length = config.get('past_window', 24)
max_prediction_length = config.get('future_window', 6)
time_series_dataset = TimeSeriesDataSet( # Ensure all specified columns exist in the dataframe
data, required_cols = (
time_idx=config['fields']['time_index'], [time_idx, target] + group_ids +
target=config['fields']['target'], time_varying_known_reals + time_varying_known_categoricals +
group_ids=config['fields']['group_ids'], time_varying_unknown_reals + time_varying_unknown_categoricals +
min_encoder_length=config['past_window'], static_reals + static_categoricals
max_encoder_length=config['past_window'],
min_prediction_length=config['future_window'],
max_prediction_length=config['future_window'],
static_reals=config['fields']['static_real'],
static_categoricals=config['fields']['static_cat'],
time_varying_known_reals=config['fields']['dynamic_known_real'],
time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
time_varying_unknown_categoricals=config['fields'][
'dynamic_unknown_cat'],
randomize_length=False,
) )
missing_cols = [col for col in required_cols if col not in data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in DataFrame: {missing_cols}")
return time_series_dataset logging.info("Building TimeSeriesDataSet...")
# Ensure target is float for regression/quantile tasks
# data[target] = data[target].astype(float)
dataset = TimeSeriesDataSet(
data,
time_idx=time_idx,
target=target,
group_ids=group_ids,
max_encoder_length=max_encoder_length,
max_prediction_length=max_prediction_length,
static_categoricals=static_categoricals,
static_reals=static_reals,
time_varying_known_categoricals=time_varying_known_categoricals,
time_varying_known_reals=time_varying_known_reals,
time_varying_unknown_categoricals=time_varying_unknown_categoricals,
time_varying_unknown_reals=time_varying_unknown_reals,
add_relative_time_idx=True, # Often useful
add_target_scales=True, # Often useful
add_encoder_length=True, # Often useful
allow_missing_timesteps=True # Set based on your data characteristics
)
logging.info("TimeSeriesDataSet built successfully.")
return dataset

View File

@ -109,88 +109,104 @@ class Informer(BaseModelWithCovariates):
output_size: Union[int, List[int]] = 1, output_size: Union[int, List[int]] = 1,
loss=None, loss=None,
logging_metrics: nn.ModuleList = None, logging_metrics: nn.ModuleList = None,
actual_n_encoder_reals: int = -1,
**kwargs): **kwargs):
# --- Call super().__init__ first ---
super().__init__( super().__init__(
loss=loss, loss=loss,
logging_metrics=logging_metrics, logging_metrics=logging_metrics,
**kwargs) **kwargs)
# ---
# Save hparams after super().__init__ so dataset parameters are available
self.save_hyperparameters(ignore=['loss']) self.save_hyperparameters(ignore=['loss'])
self.attention_type = attention_type self.attention_type = attention_type
assert not static_reals # --- Calculate n_encoder_reals using self.hparams (populated by save_hyperparameters) ---
assert not static_categoricals n_encoder_reals = len(self.hparams.x_reals)
print(f"Initializing enc_real_embeddings with {n_encoder_reals} channels (derived from len(hparams.x_reals)).")
# ---
# assertions (can remain commented)
# assert isinstance(loss, PyTorchMetric), "Loss has to be PyTorch Metric"
# assert not static_reals # Ensure this line remains commented out
# --- Use self.hparams for MultiEmbedding as well ---
self.cat_embeddings = MultiEmbedding( self.cat_embeddings = MultiEmbedding(
embedding_sizes=embedding_sizes, embedding_sizes=self.hparams.embedding_sizes,
embedding_paddings=embedding_paddings, embedding_paddings=self.hparams.embedding_paddings,
categorical_groups=categorical_groups, categorical_groups=self.hparams.categorical_groups,
x_categoricals=x_categoricals, x_categoricals=self.hparams.x_categoricals,
) )
self.enc_real_embeddings = TokenEmbedding( # Initialize with the derived total number of continuous encoder variables
len(time_varying_reals_encoder), d_model) self.enc_real_embeddings = TokenEmbedding(n_encoder_reals, self.hparams.d_model)
self.enc_positional_embeddings = PositionalEmbedding(d_model)
self.enc_positional_embeddings = PositionalEmbedding(self.hparams.d_model)
# Decoder embedding initialization using hparams
decoder_reals_list = self.hparams.time_varying_reals_decoder
print(f"Initializing dec_real_embeddings with {len(decoder_reals_list)} channels.")
self.dec_real_embeddings = TokenEmbedding( self.dec_real_embeddings = TokenEmbedding(
len(time_varying_reals_decoder), d_model) len(decoder_reals_list), self.hparams.d_model)
self.dec_positional_embeddings = PositionalEmbedding(d_model) self.dec_positional_embeddings = PositionalEmbedding(self.hparams.d_model)
Attention = ProbSparseAttention \ Attention = ProbSparseAttention \
if attention_type == "prob" else FullAttention if self.hparams.attention_type == "prob" else FullAttention
# --- Initialize Encoder/Decoder using self.hparams ---
self.encoder = Encoder( self.encoder = Encoder(
[ [
EncoderLayer( EncoderLayer(
AttentionLayer( AttentionLayer(
Attention(False, factor, attention_dropout=dropout, Attention(False, self.hparams.factor, attention_dropout=self.hparams.dropout,
output_attention=output_attention), output_attention=self.hparams.output_attention),
d_model, self.hparams.d_model,
n_attention_heads, self.hparams.n_attention_heads,
mix=False, mix=False,
), ),
d_model, self.hparams.d_model,
d_fully_connected, self.hparams.d_fully_connected,
dropout=dropout, dropout=self.hparams.dropout,
activation=activation, activation=self.hparams.activation,
) )
for _ in range(n_encoder_layers) for _ in range(self.hparams.n_encoder_layers)
], ],
[SelfAttentionDistil(d_model) for _ in range( [SelfAttentionDistil(self.hparams.d_model) for _ in range(
n_encoder_layers - 1)] if distil else None, self.hparams.n_encoder_layers - 1)] if self.hparams.distil else None,
nn.LayerNorm(d_model), nn.LayerNorm(self.hparams.d_model),
) )
self.decoder = Decoder( self.decoder = Decoder(
[ [
DecoderLayer( DecoderLayer(
AttentionLayer( AttentionLayer(
Attention(True, factor, attention_dropout=dropout, Attention(True, self.hparams.factor, attention_dropout=self.hparams.dropout,
output_attention=False), output_attention=False),
d_model, self.hparams.d_model,
n_attention_heads, self.hparams.n_attention_heads,
mix=mix_attention, mix=self.hparams.mix_attention,
), ),
AttentionLayer( AttentionLayer(
FullAttention( FullAttention(
False, False,
factor, self.hparams.factor,
attention_dropout=dropout, attention_dropout=self.hparams.dropout,
output_attention=False), output_attention=False),
d_model, self.hparams.d_model,
n_attention_heads, self.hparams.n_attention_heads,
mix=False, mix=False,
), ),
d_model, self.hparams.d_model,
d_fully_connected, self.hparams.d_fully_connected,
dropout=dropout, dropout=self.hparams.dropout,
activation=activation, activation=self.hparams.activation,
) )
for _ in range(n_decoder_layers) for _ in range(self.hparams.n_decoder_layers)
], ],
nn.LayerNorm(d_model), nn.LayerNorm(self.hparams.d_model),
) )
self.projection = nn.Linear(d_model, output_size) self.projection = nn.Linear(self.hparams.d_model, self.hparams.output_size)
def forward( def forward(
self, self,
@ -230,6 +246,7 @@ class Informer(BaseModelWithCovariates):
): ):
new_kwargs = copy(kwargs) new_kwargs = copy(kwargs)
new_kwargs.update(cls.deduce_default_output_parameters( new_kwargs.update(cls.deduce_default_output_parameters(
dataset, kwargs, QuantileLoss())) dataset, kwargs, QuantileLoss())) # Using QuantileLoss for defaults might be okay
# Let super().from_dataset handle populating dataset_parameters correctly
return super().from_dataset(dataset, **new_kwargs) return super().from_dataset(dataset, **new_kwargs)

View File

@ -3,8 +3,9 @@ import itertools
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import functools import functools
from tqdm import tqdm # from tqdm import tqdm # Remove standard tqdm
from multiprocessing import Pool from tqdm.notebook import tqdm # Import notebook-specific tqdm
from joblib import Parallel, delayed
from strategy import metrics from strategy import metrics
from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION
from strategy.strategy import StrategyBase from strategy.strategy import StrategyBase
@ -30,26 +31,33 @@ def parameter_sweep(
result = [] result = []
total = len(param_sets) total = len(param_sets)
# Evaluate sets of different hyperparameters in parallel # Prepare the function with fixed arguments
with Pool(num_workers) as pool, tqdm(total=total) as pbar: evaluate_func_partial = functools.partial(
for chunk in (param_sets[i:i + log_every]
for i in range(0, total, log_every)):
tmp = list(
pool.map(
functools.partial(
evaluate_strategy, evaluate_strategy,
data, data,
exchange_fee=exchange_fee, exchange_fee=exchange_fee,
interval=interval, interval=interval,
padding=padding, padding=padding,
include_arrays=False), include_arrays=False
map( )
lambda p: strategy_class(
**p), chunk))) # Prepare the list of delayed strategy instantiations and function calls
pbar.update(len(tmp)) tasks = [
result += list(zip(tmp, map( delayed(evaluate_func_partial)(strategy_class(**p))
lambda p: strategy_class( for p in param_sets
**p), chunk))) ]
print(f"Running {total} tasks in parallel with joblib (n_jobs={num_workers})...")
# Run in parallel
# Using loky backend for potentially better pickling robustness
tmp_results = Parallel(n_jobs=num_workers, backend="loky")(
tqdm(tasks)
)
# Let's re-instantiate for simplicity here, though less efficient
evaluated_strategies = [strategy_class(**p) for p in param_sets]
result = list(zip(tmp_results, evaluated_strategies))
print("Parallel processing finished.")
return sorted(result, key=lambda x: x[0][sort_by], reverse=True) return sorted(result, key=lambda x: x[0][sort_by], reverse=True)

View File

@ -241,9 +241,17 @@ class ModelPredictionsStrategyBase(StrategyBase):
data, self.predictions, on=['time_index', 'group_id'], data, self.predictions, on=['time_index', 'group_id'],
how='left') how='left')
return self.get_positions(merged_data) # merged_data['prediction'] = merged_data['prediction'].fillna(0).infer_objects(copy=False) # Old fix
# Explicitly convert to numeric, coercing errors, then fill NaNs
merged_data['prediction'] = pd.to_numeric(merged_data['prediction'], errors='coerce').fillna(0)
arr_preds = merged_data['prediction'].to_numpy()
def get_positions(self, data): # arr_preds = arr_preds[:, self.future, 0] # Incorrect indexing for 1D array
# No change needed if arr_preds is already the correct 1D array of predictions
return self._get_positions(merged_data, arr_preds)
def _get_positions(self, data, arr_preds):
raise NotImplementedError() raise NotImplementedError()
@ -277,11 +285,7 @@ class ModelGmadlPredictionsStrategy(ModelPredictionsStrategyBase):
'exit_short': self.exit_short 'exit_short': self.exit_short
} }
def get_positions(self, data): def _get_positions(self, data, arr_preds):
# bfill() is a hack to make it work with non predicted data
arr_preds = np.stack(data['prediction'].ffill().bfill().to_numpy())
arr_preds = arr_preds[:, self.future, 0]
enter_long = arr_preds > (self.enter_long or np.infty) enter_long = arr_preds > (self.enter_long or np.infty)
exit_long = arr_preds < (self.exit_long or -np.infty) exit_long = arr_preds < (self.exit_long or -np.infty)
enter_short = arr_preds < ( enter_short = arr_preds < (
@ -344,7 +348,7 @@ class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase):
'quantile_exit_short': self.quantile_exit_short 'quantile_exit_short': self.quantile_exit_short
} }
def get_positions(self, data): def _get_positions(self, data, arr_preds):
if self.new_impl: if self.new_impl:
return self.get_positions2(data) return self.get_positions2(data)
return self.get_positions1(data) return self.get_positions1(data)
@ -528,12 +532,7 @@ class ModelQuantileReturnsPredictionsStrategy(ModelPredictionsStrategyBase):
'quantile_exit_short': self.quantile_exit_short 'quantile_exit_short': self.quantile_exit_short
} }
def get_positions(self, data): def _get_positions(self, data, arr_preds):
arr_target = data[self.target].to_numpy()
arr_preds = np.stack(
# bfill() is a hack to make it work with non predicted data
data['prediction'].ffill().bfill().to_numpy())
enter_long = (((arr_preds[ enter_long = (((arr_preds[
:, self.future - 1, self.get_quantile_idx( :, self.future - 1, self.get_quantile_idx(
round(1 - self.quantile_enter_long, 2))] round(1 - self.quantile_enter_long, 2))]

View File

@ -5,6 +5,7 @@ import pandas as pd
import numpy as np import numpy as np
from numba import jit from numba import jit
from numba import int32, float64, optional from numba import int32, float64, optional
import logging
def get_sweep_data_windows(sweep_id): def get_sweep_data_windows(sweep_id):
@ -26,22 +27,66 @@ def get_sweep_data_windows(sweep_id):
def get_data_windows(project, dataset_name, min_window=0, max_window=5): def get_data_windows(project, dataset_name, min_window=0, max_window=5):
artifact_name = f"{project}/{dataset_name}" artifact_name = f"{project}/{dataset_name}"
logging.info(f"Downloading artifact: {artifact_name}")
artifact = wandb.Api().artifact(artifact_name) artifact = wandb.Api().artifact(artifact_name)
base_path = artifact.download() base_path = artifact.download()
name = artifact.metadata['name'] # name = artifact.name # We don't need the artifact name itself anymore
# --- Load the single full dataset file --- Modified
full_data_file = os.path.join(base_path, 'full_data.parquet')
logging.info(f"Loading full dataset from: {full_data_file}")
if not os.path.exists(full_data_file):
raise FileNotFoundError(f"Expected parquet file not found in artifact: {full_data_file}")
full_df = pd.read_parquet(full_data_file)
logging.info(f"Loaded full dataset with shape: {full_df.shape}")
# --- End Load ---
result = [] result = []
for i in range(min_window, max_window+1): N = len(full_df)
in_sample_name =\ # Determine how many out-of-sample windows we need based on the loop
f"in-sample-{i}" num_oos_windows = max_window - min_window + 1
in_sample_data = pd.read_csv(os.path.join( if N < num_oos_windows:
base_path, name + '-' + in_sample_name + '.csv')) raise ValueError(f"Dataset length ({N}) is too short for the requested number of windows ({num_oos_windows})")
out_of_sample_name =\
f"out-of-sample-{i}" # Simplistic split: Assume the last ~20% is for OOS, split into required windows
out_of_sample_data = pd.read_csv(os.path.join( # This might need refinement based on how splits were actually done in training sweeps.
base_path, name + '-' + out_of_sample_name + '.csv')) # A more robust approach might get split info from artifact metadata if available.
oos_total_size = int(N * 0.2) # Example: use last 20% for all out-of-sample periods
if oos_total_size < num_oos_windows:
raise ValueError(f"Calculated out-of-sample size ({oos_total_size}) is too small for {num_oos_windows} windows.")
oos_chunk_size = oos_total_size // num_oos_windows
start_of_oos_data = N - oos_total_size
logging.info(f"Total OOS size: {oos_total_size}, Num OOS windows: {num_oos_windows}, OOS chunk size: {oos_chunk_size}, OOS starts at index: {start_of_oos_data}")
for i in range(min_window, max_window + 1):
# Calculate split points for this specific window i
# In-sample data is everything *before* the current OOS chunk starts
current_oos_start_index = start_of_oos_data + (i - min_window) * oos_chunk_size
current_oos_end_index = current_oos_start_index + oos_chunk_size
# Ensure the last chunk goes to the end if division wasn't perfect
if i == max_window:
current_oos_end_index = N
in_sample_data = full_df.iloc[0:current_oos_start_index].copy()
out_of_sample_data = full_df.iloc[current_oos_start_index:current_oos_end_index].copy()
logging.info(f"Window {i}: In-sample indices 0:{current_oos_start_index}, Out-of-sample indices {current_oos_start_index}:{current_oos_end_index}")
result.append((in_sample_data, out_of_sample_data)) result.append((in_sample_data, out_of_sample_data))
# --- Old file reading logic removed ---
# in_sample_name =\
# f"in-sample-{i}"
# in_sample_data = pd.read_csv(os.path.join(
# base_path, name + '-' + in_sample_name + '.csv'))
# out_of_sample_name =\
# f"out-of-sample-{i}"
# out_of_sample_data = pd.read_csv(os.path.join(
# base_path, name + '-' + out_of_sample_name + '.csv'))
# result.append((in_sample_data, out_of_sample_data))
# --- End removal ---
return result return result
@ -61,9 +106,15 @@ def get_sweep_window_predictions(sweep_id, part):
artifact_path = window_prediction.download() artifact_path = window_prediction.download()
index = torch.load(os.path.join( index = torch.load(os.path.join(
artifact_path, 'index.pt'), map_location=torch.device('cpu')) artifact_path, 'index.pt'),
map_location=torch.device('cpu'),
weights_only=False # Allow loading non-tensor objects
)
preds = torch.load(os.path.join( preds = torch.load(os.path.join(
artifact_path, 'predictions.pt'), map_location=torch.device('cpu')) artifact_path, 'predictions.pt'),
map_location=torch.device('cpu'),
weights_only=False # Allow loading non-tensor objects (safer to add here too)
)
result.append((window_num, index, preds.numpy())) result.append((window_num, index, preds.numpy()))