informer_model/scripts/prepare_btc_data.py
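
"""Prepare a BTC-USDT 5-minute feature dataset and log it to Weights & Biases.

Pipeline overview (see main() below): load 1-minute klines from SQLite,
resample to 5-minute bars, compute technical-indicator features, merge daily
external series (VIX, Fear & Greed, effective rates), then log the result as a
single parquet artifact.

Illustrative invocation (the flags exist below; the paths are placeholders,
not required values):

    python prepare_btc_data.py \
        --db-pattern "~/data/*.db" \
        --db-table combined_hist_1min \
        --vix-file data/vix_daily.csv \
        --fear-greed-file data/fear_greed_index.csv \
        --eff-rate-file data/DFF.csv \
        --wandb-project my-project
"""
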
import argparse
import glob
import logging
import os
import sqlite3
import pandas as pd
import pandas_ta as ta
import numpy as np
import wandb
import tempfile
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def load_data_from_db(db_path, table_name="klines"):
    """Loads data from a specific table in an SQLite database."""
    logging.info(f"Reading data from {db_path}, table '{table_name}'...")
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Check that the table exists before querying it
        cursor = conn.cursor()
        cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';")
        if cursor.fetchone() is None:
            logging.warning(f"Table '{table_name}' not found in {db_path}. Skipping.")
            return None
        # Adjust column names if necessary based on your actual schema
        query = f"SELECT timestamp, open, high, low, close, volume FROM {table_name} WHERE instrument_id LIKE 'PAIR-BTC-%'"
        logging.info(f"Executing query: {query}")
        df = pd.read_sql_query(query, conn)
        logging.info(f"Read {len(df)} rows matching the criteria from {db_path}")
        if not df.empty:
            logging.info(f"Raw timestamp range: {df['timestamp'].min()} to {df['timestamp'].max()}")
    except sqlite3.Error as e:
        logging.error(f"Error reading database {db_path}: {e}")
        return None
    finally:
        if conn:
            conn.close()
    return df
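

# NOTE: the query above assumes a source table with at least these columns:
# timestamp (epoch seconds, per the unit='s' conversion in main), open, high,
# low, close, volume, and instrument_id (e.g. 'PAIR-BTC-USDT'). The exact
# schema of the input database is an assumption; adjust the query if yours
# differs. Illustrative shape only:
#   CREATE TABLE combined_hist_1min (
#       timestamp INTEGER, instrument_id TEXT,
#       open REAL, high REAL, low REAL, close REAL, volume REAL
#   );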


def calculate_features(df):
    """Calculates technical indicators and other features."""
    logging.info("Calculating base features...")
    # The DataFrame is already datetime-indexed and sorted in main().
    # Column names after the rename performed in main()
    open_col, high_col, low_col, close_col, vol_col = 'open_price', 'high_price', 'low_price', 'close_price', 'volume'
    # Drop rows with missing essential data before calculations
    df = df.dropna(subset=[open_col, high_col, low_col, close_col, vol_col])
    if df.empty:
        logging.warning("DataFrame is empty after dropping NaNs in essential columns.")
        return df
    # --- Basic Price Features ---
    df['open_to_close_price'] = df[close_col] / df[open_col] - 1
    df['high_to_close_price'] = df[high_col] / df[close_col] - 1
    df['low_to_close_price'] = df[low_col] / df[close_col] - 1
    df['high_to_low_price'] = df[high_col] / df[low_col] - 1
    # --- Returns ---
    # pct_change gives (close_t / close_{t-1}) - 1; the first row is NaN
    df['returns'] = df[close_col].pct_change()
    df['log_returns'] = np.log(df[close_col] / df[close_col].shift(1))
    # --- Time Features ---
    df['hour'] = df.index.hour.astype(str).astype("category")  # String/category as required by config
    df['weekday'] = df.index.weekday.astype(str).astype("category")
    # --- Technical Indicators using pandas_ta ---
    logging.info("Calculating technical indicators (this may take a while)...")
    custom_strategy = ta.Strategy(
        name="informer_features",
        description="Calculate features for Informer model based on config",
        ta=[
            # Volatility (length is a default choice; the config does not specify it)
            {"kind": "atr", "length": 14, "col_names": "atr"},
            # MACD
            {"kind": "macd", "fast": 12, "slow": 26, "signal": 9, "col_names": ("macd", "macd_hist", "macd_signal")},
            # RSI
            {"kind": "rsi", "length": 14, "col_names": "rsi"},
            # Bollinger Bands
            {"kind": "bbands", "length": 20, "std": 2, "col_names": ("low_bband", "mid_bband", "up_bband", "bandwidth", "percent")},
            # SMA (1h = 12 x 5min, 1d = 288 x 5min, 7d = 2016 x 5min)
            {"kind": "sma", "length": 12, "col_names": "sma_1h"},
            {"kind": "sma", "length": 288, "col_names": "sma_1d"},
            {"kind": "sma", "length": 2016, "col_names": "sma_7d"},
            # EMA (1h = 12 x 5min, 1d = 288 x 5min); config only lists ema_1h, ema_1d relative to close
            {"kind": "ema", "length": 12, "col_names": "ema_1h"},
            {"kind": "ema", "length": 288, "col_names": "ema_1d"},
        ]
    )
    df.ta.strategy(custom_strategy)
    # --- Volatility (using log returns, which is common for volatility calculation) ---
    df['vol_1h'] = df['log_returns'].rolling(window=12).std() * np.sqrt(12)      # Scaled 1h vol
    df['vol_1d'] = df['log_returns'].rolling(window=288).std() * np.sqrt(288)    # Scaled daily vol
    df['vol_7d'] = df['log_returns'].rolling(window=2016).std() * np.sqrt(2016)  # Scaled weekly vol
    # --- Relative Indicators (indicator / close_price - 1) ---
    logging.info("Calculating relative indicators...")
    for indicator in ['low_bband', 'mid_bband', 'up_bband', 'sma_1h', 'sma_1d', 'sma_7d', 'ema_1h', 'ema_1d']:
        if indicator in df.columns:
            df[f'{indicator}_to_close_price'] = df[indicator] / df[close_col] - 1
        else:
            logging.warning(f"Base indicator '{indicator}' not found for relative calculation.")
    # Intermediate columns (atr, macd_hist, bandwidth, percent, raw bband/sma/ema) could be dropped here if needed.
    # Initial NaNs from rolling windows/shifts are handled later: forward-fill
    # after merging external data, then dropna in main().
    return df
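

# NOTE: the longest lookback above is 2016 bars (sma_7d / vol_7d), i.e. roughly
# seven days of 5-minute data, so at least that many initial rows will contain
# NaNs and end up dropped by the final dropna() in main().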


def load_external_data(file_path, date_col, value_col, rename_to=None):
    """Loads external daily data like VIX or Fear/Greed Index."""
    logging.info(f"Loading external data from {file_path}...")
    try:
        df = pd.read_csv(file_path)
        df[date_col] = pd.to_datetime(df[date_col])
        # Keep only the date and value columns, renaming the value column
        df = df[[date_col, value_col]].rename(columns={value_col: rename_to or value_col})
        # Normalize the date index so every timestamp is midnight
        df = df.set_index(date_col).sort_index()
        df.index = df.index.normalize()
        logging.info(f"Loaded {len(df)} records from {file_path}. Index normalized.")
        return df
    except FileNotFoundError:
        logging.error(f"External data file not found: {file_path}")
        return None
    except Exception as e:
        logging.error(f"Error loading external data from {file_path}: {e}")
        return None
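

# The call sites in main() assume these CSV headers (an assumption about the
# source files; adjust if yours differ):
#   VIX:            date, close
#   Fear & Greed:   date, fng_value
#   Effective rate: observation_date, DFF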


def main(db_pattern, db_table, vix_file, fear_greed_file, eff_rate_file, args):
    """Main function to load, process, and save data."""
    db_files = glob.glob(os.path.expanduser(db_pattern), recursive=True)
    if not db_files:
        logging.error(f"No database files found matching pattern: {db_pattern}")
        return
    logging.info(f"Found {len(db_files)} database files.")
    all_data = []
    for db_file in db_files:
        df = load_data_from_db(db_file, table_name=db_table)
        if df is not None:
            all_data.append(df)
    if not all_data:
        logging.error("No data loaded from any database file.")
        return
    logging.info("Concatenating data from all databases...")
    btc_df = pd.concat(all_data, ignore_index=True)
    if not btc_df.empty:
        logging.info(f"Raw timestamp column info - dtype: {btc_df['timestamp'].dtype}, head:\n{btc_df['timestamp'].head()}")
    else:
        logging.warning("BTC DataFrame empty after concat, cannot check raw timestamp.")
    # --- Initial Processing ---
    # Convert timestamp (epoch seconds) to datetime and sort
    btc_df['datetime'] = pd.to_datetime(btc_df['timestamp'], unit='s')
    if not btc_df.empty:
        logging.info(f"Converted datetime range: {btc_df['datetime'].min()} to {btc_df['datetime'].max()}")
    else:
        logging.warning("BTC DataFrame is empty after concatenation, cannot check datetime range.")
    # Deduplicate on timestamp, keeping the first entry
    btc_df = btc_df.sort_values('datetime').drop_duplicates(subset=['timestamp'], keep='first')
    # Rename price columns to match the model config (done before feature calculation)
    rename_map = {'open': 'open_price', 'high': 'high_price', 'low': 'low_price', 'close': 'close_price'}
    btc_df = btc_df.rename(columns=rename_map)
    logging.info(f"Renamed columns: {rename_map}")
    # --- Set index and log info ---
    btc_df = btc_df.set_index('datetime').sort_index()
    if not btc_df.empty:
        logging.info(f"DataFrame index info - dtype: {btc_df.index.dtype}, timezone: {btc_df.index.tz}, range: {btc_df.index.min()} to {btc_df.index.max()}")
        logging.info(f"DataFrame head(1):\n{btc_df.head(1)}")
    else:
        logging.warning("BTC DataFrame empty after setting index.")
    logging.info(f"Total unique records after concatenation: {len(btc_df)}")
    # --- Resample to 5-minute intervals ---
    logging.info("Resampling 1-minute data to 5-minute intervals...")
    resampling_rules = {
        'open_price': 'first',
        'high_price': 'max',
        'low_price': 'min',
        'close_price': 'last',
        'volume': 'sum'
    }
    # Ensure the required columns exist before resampling
    missing_cols = [col for col in resampling_rules if col not in btc_df.columns]
    if missing_cols:
        logging.error(f"Cannot resample, required columns missing: {missing_cols}")
        return
    btc_df = btc_df[list(resampling_rules.keys())].resample('5T').agg(resampling_rules)
    # Drop rows where resampling produced all NaNs (e.g. gaps in the original data)
    btc_df.dropna(subset=['open_price', 'high_price', 'low_price', 'close_price'], inplace=True)
    logging.info(f"Resampled data shape: {btc_df.shape}")
    if not btc_df.empty:
        logging.info(f"Resampled index range: {btc_df.index.min()} to {btc_df.index.max()}")
        logging.info(f"Resampled head(1):\n{btc_df.head(1)}")
    else:
        logging.warning("DataFrame empty after resampling.")
        return  # Stop if empty after resampling
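    # By default resample('5T') labels each bar with the *start* of its
    # 5-minute interval; the close_time column added further down follows
    # that convention (index + 5 minutes).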
    # --- Feature Calculation ---
    # Operates on the 5-minute resampled data
    btc_df = calculate_features(btc_df)
    if btc_df.empty:
        logging.error("DataFrame became empty during feature calculation.")
        return
    # --- Load and Merge External Data ---
    # VIX data (daily)
    vix_df = load_external_data(vix_file, date_col='date', value_col='close', rename_to='vix_close_price')
    # Fear & Greed data (daily)
    fg_df = load_external_data(fear_greed_file, date_col='date', value_col='fng_value', rename_to='fear_greed_index')
    # Effective rates data (daily)
    eff_rates_df = load_external_data(eff_rate_file, date_col='observation_date', value_col='DFF', rename_to='effective_rates')
    # Log external data index info and timezones
    if vix_df is not None:
        logging.info(f"VIX index info - dtype: {vix_df.index.dtype}, timezone: {vix_df.index.tz}, range: {vix_df.index.min()} to {vix_df.index.max()}")
    if fg_df is not None:
        logging.info(f"F&G index info - dtype: {fg_df.index.dtype}, timezone: {fg_df.index.tz}, range: {fg_df.index.min()} to {fg_df.index.max()}")
    if eff_rates_df is not None:
        logging.info(f"EffRates index info - dtype: {eff_rates_df.index.dtype}, timezone: {eff_rates_df.index.tz}, range: {eff_rates_df.index.min()} to {eff_rates_df.index.max()}")
    # Log external data available at or before the first BTC timestamp
    if not btc_df.empty:
        first_btc_time = btc_df.index.min()
        logging.info(f"First BTC timestamp: {first_btc_time}")
        if vix_df is not None:
            logging.info(f"VIX data at/before start:\n{vix_df[vix_df.index <= first_btc_time].tail()}")
        if fg_df is not None:
            logging.info(f"F&G data at/before start:\n{fg_df[fg_df.index <= first_btc_time].tail()}")
        if eff_rates_df is not None:
            logging.info(f"EffRates data at/before start:\n{eff_rates_df[eff_rates_df.index <= first_btc_time].tail()}")
    # --- Merge external data with merge_asof on the DatetimeIndex ---
    logging.info("Performing merge_asof based on DatetimeIndex...")
    # Ensure DataFrames are sorted by index (should be already, but explicit is safer)
    btc_df = btc_df.sort_index()
    if vix_df is not None:
        vix_df = vix_df.sort_index()
    if fg_df is not None:
        fg_df = fg_df.sort_index()
    if eff_rates_df is not None:
        eff_rates_df = eff_rates_df.sort_index()
    if vix_df is not None:
        btc_df = pd.merge_asof(btc_df, vix_df, left_index=True, right_index=True, direction='backward')
        logging.info(f"Shape after VIX merge_asof: {btc_df.shape}, VIX NaNs: {btc_df['vix_close_price'].isna().sum()}")
    if fg_df is not None:
        btc_df = pd.merge_asof(btc_df, fg_df, left_index=True, right_index=True, direction='backward')
        logging.info(f"Shape after F&G merge_asof: {btc_df.shape}, F&G NaNs: {btc_df['fear_greed_index'].isna().sum()}")
    if eff_rates_df is not None:
        btc_df = pd.merge_asof(btc_df, eff_rates_df, left_index=True, right_index=True, direction='backward')
        logging.info(f"Shape after EffRates merge_asof: {btc_df.shape}, EffRates NaNs: {btc_df['effective_rates'].isna().sum()}")
    logging.info("Finished merge_asof operations.")
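    # direction='backward' assigns each 5-minute bar the most recent daily
    # value at or before its timestamp, so the merge never pulls in a later
    # external observation (no lookahead).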
    # Log state after merging external data
    logging.info(f"BTC data after merge - Shape: {btc_df.shape}, Null counts:\n{btc_df.isna().sum().sort_values(ascending=False).head()}")
    # --- Final Preparations ---
    logging.info("Performing final data preparation steps...")
    # Add required columns not generated yet
    btc_df['group_id'] = "BTC-USDT"  # Static group ID for this dataset
    btc_df['group_id'] = btc_df['group_id'].astype("category")
    # Create the sequential time index required by pytorch-forecasting
    btc_df = btc_df.sort_index()  # Ensure sorted before creating the index
    # Add close_time: the index marks the start of each 5-minute interval,
    # so the close time is 5 minutes later.
    btc_df['close_time'] = btc_df.index + pd.Timedelta(minutes=5)
    logging.info(f"Added 'close_time' column. Head:\n{btc_df['close_time'].head()}")
    btc_df = btc_df.reset_index()  # Bring datetime back as a column temporarily
    btc_df['time_index'] = btc_df.index  # Sequential integer index
    # Final columns based on the YAML config; names must match the generated features exactly
    final_columns = [
        "time_index", "group_id", "returns",  # Core fields
        "close_time",
        # dynamic_unknown_real
        "high_price", "low_price", "open_price", "close_price", "volume",
        "open_to_close_price", "high_to_close_price", "low_to_close_price", "high_to_low_price",
        "log_returns", "vol_1h", "macd", "macd_signal", "rsi",
        "low_bband_to_close_price", "up_bband_to_close_price", "mid_bband_to_close_price",
        "sma_1h_to_close_price", "sma_1d_to_close_price", "sma_7d_to_close_price",
        "ema_1h_to_close_price", "ema_1d_to_close_price",
        # dynamic_known_real (only present if the corresponding external merges succeeded)
        "vix_close_price", "fear_greed_index", "vol_1d", "vol_7d", "effective_rates",
        # dynamic_known_cat
        "hour", "weekday"
    ]
    # Select and reorder columns, handling potentially missing external columns
    cols_to_select = []
    for col in final_columns:
        if col in btc_df.columns:
            cols_to_select.append(col)
        else:
            logging.warning(f"Required column '{col}' not found in DataFrame. It will be excluded.")
    final_df = btc_df[cols_to_select]
    # --- Handle Missing Values ---
    # Forward fill is common for time series, especially after merges and indicator calculations.
    # Note: ffill is questionable for returns, but the initial NaNs in returns are expected.
    logging.info(f"Forward filling NaNs. Initial NaN count:\n{final_df.isna().sum().sort_values(ascending=False).head()}")
    final_df = final_df.ffill()
    # Drop any rows that *still* have NaNs (e.g. at the very beginning, before the
    # first external data point or before indicator windows are filled).
    initial_rows = len(final_df)
    # A less aggressive alternative would drop only rows missing critical columns
    # (OHLCV and returns); here rows with ANY remaining NaN value are dropped.
    logging.info("Dropping rows with any NaN values.")
    final_df = final_df.dropna()
    rows_dropped = initial_rows - len(final_df)
    if rows_dropped > 0:
        logging.warning(f"Dropped {rows_dropped} rows containing NaNs after forward filling.")
    # Final check
    if final_df.isna().any().any():
        logging.warning(f"NaN values still present after processing:\n{final_df.isna().sum()[final_df.isna().sum() > 0]}")
    else:
        logging.info("No remaining NaN values detected.")
    if final_df.empty:
        logging.error("Final DataFrame is empty after processing and NaN handling.")
        return
    # Data splitting and per-split (in-sample / out-of-sample) artifacts were
    # removed: the full dataset is logged as a single artifact and any split
    # is expected to happen downstream.
    # --- Log the full dataset as a single W&B artifact ---
    logging.info(f"Logging full dataset artifact to W&B project '{wandb.run.project}', run '{wandb.run.name}'...")
    try:
        with tempfile.TemporaryDirectory() as tempdir:
            # Save the entire final_df
            full_data_path = os.path.join(tempdir, 'full_data.parquet')
            final_df.to_parquet(full_data_path, index=False)
            logging.info(f"Temporary file saved to {tempdir}")
            # Create and log the single artifact
            full_artifact = wandb.Artifact(
                name=args.full_dataset_artifact_name,
                type='dataset',
                description=f'Full BTC 5min features data ({len(final_df)} rows). Prepared by run {wandb.run.id}.',
                metadata={'rows': len(final_df)}
            )
            full_artifact.add_file(full_data_path)
            wandb.log_artifact(full_artifact)
            logging.info(f"Logged full dataset artifact: {args.full_dataset_artifact_name}")
            logging.info("Artifact logged successfully.")
    except Exception as e:
        logging.error(f"Error logging artifacts to W&B: {e}")
        wandb.run.finish(exit_code=1)  # Finish run with error
        return
    wandb.run.finish()  # Finish run successfully
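

# Downstream consumption sketch (illustrative, not part of this script): a
# training run could pull the logged dataset roughly like this, assuming the
# same project and the default artifact name:
#
#     run = wandb.init(project="wne-masters-thesis-testing", job_type="training")
#     artifact_dir = run.use_artifact("btc-5m-features-full:latest").download()
#     df = pd.read_parquet(os.path.join(artifact_dir, "full_data.parquet"))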


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare BTC-USDT 5-minute data and log to W&B.")
    parser.add_argument(
        "--db-pattern",
        default="/home/yasha/develop/data/combined.coinbase_1min_hist.db",
        help="Pattern or exact path to find input SQLite database file(s)."
    )
    parser.add_argument(
        "--db-table",
        default="combined_hist_1min",
        help="Name of the table containing kline data within the SQLite files."
    )
    parser.add_argument(
        "--vix-file",
        default="data/vix_daily.csv",
        help="Path to the VIX index CSV file."
    )
    parser.add_argument(
        "--fear-greed-file",
        default="data/fear_greed_index.csv",
        help="Path to the Crypto Fear & Greed Index CSV file."
    )
    parser.add_argument(
        "--eff-rate-file",
        default="data/DFF.csv",
        help="Path to the effective federal funds rate (DFF) CSV file."
    )
    parser.add_argument(
        "--wandb-project",
        default="wne-masters-thesis-testing",
        help="W&B project name."
    )
    parser.add_argument(
        "--wandb-run-name",
        default="prepare-btc-data",
        help="W&B run name for this preparation job."
    )
    parser.add_argument(
        "--wandb-notes",
        default=None,
        help="Optional notes for the W&B run."
    )
    parser.add_argument(
        "--full-dataset-artifact-name",
        default="btc-5m-features-full",  # Match the YAML default
        help="Name for the single W&B artifact containing the full dataset."
    )
    args = parser.parse_args()
    # --- Initialize the W&B run ---
    run = wandb.init(
        project=args.wandb_project,
        name=args.wandb_run_name,
        notes=args.wandb_notes,
        job_type="data-preparation",
        config=vars(args)  # Log command line args
    )
    main(
        db_pattern=args.db_pattern,
        db_table=args.db_table,
        vix_file=args.vix_file,
        fear_greed_file=args.fear_greed_file,
        eff_rate_file=args.eff_rate_file,
        args=args  # Pass all args for artifact names etc.
    )