pairs_trading/lib/tools/data_loader.py
Oleg Sheynin 71822c64b0 progress
2025-07-25 20:39:59 +00:00

152 lines
4.8 KiB
Python

from __future__ import annotations
import sqlite3
from typing import Dict, List, cast
import pandas as pd
def load_sqlite_to_dataframe(db_path: str, query: str) -> pd.DataFrame:
    """Run ``query`` against the SQLite database at ``db_path``.

    Args:
        db_path: Filesystem path of the SQLite database file.
        query: SQL text handed unchanged to ``pandas.read_sql_query``.

    Returns:
        The query result as a DataFrame.  When ``db_path`` does not exist a
        warning is printed and an empty DataFrame (no rows, no columns) is
        returned instead.

    Raises:
        sqlite3.Error: Re-raised unchanged after being printed.
        Exception: Any other failure, printed and then wrapped with a
            descriptive message; the original error is chained as the cause.
    """
    import os

    # Check up front: sqlite3.connect() would otherwise create an empty
    # database file at a path that does not exist.
    if not os.path.exists(db_path):
        print(f"WARNING: database file {db_path} does not exist")
        return pd.DataFrame()

    # Tracked explicitly so ``finally`` knows whether there is anything to
    # close when connect() itself fails.
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        return pd.read_sql_query(query, conn)
    except sqlite3.Error as excpt:
        print(f"SQLite error: {excpt}")
        raise
    except Exception as excpt:
        print(f"Error: {excpt}")
        # Keep the plain ``Exception`` type for existing callers, but carry
        # the reason in the message instead of raising an empty exception.
        raise Exception(f"Failed to load data from {db_path}: {excpt}") from excpt
    finally:
        if conn is not None:
            conn.close()
def convert_time_to_UTC(value: str, timezone: str, extra_minutes: int = 0) -> str:
    """Convert a local wall-clock timestamp string to its UTC equivalent.

    Args:
        value: Timestamp formatted as ``YYYY-MM-DD HH:MM:SS`` with no
            timezone information.
        timezone: Zone name understood by ``zoneinfo.ZoneInfo`` that
            ``value`` is expressed in.
        extra_minutes: Minutes added to the local time before the timezone
            is attached.

    Returns:
        The shifted instant rendered in UTC using the same
        ``YYYY-MM-DD HH:MM:SS`` layout.
    """
    from datetime import datetime, timedelta
    from zoneinfo import ZoneInfo

    layout = "%Y-%m-%d %H:%M:%S"

    # The offset is applied to the naive local time, before localisation.
    shifted = datetime.strptime(value, layout) + timedelta(minutes=extra_minutes)
    localized = shifted.replace(tzinfo=ZoneInfo(timezone))
    return localized.astimezone(ZoneInfo("UTC")).strftime(layout)
def load_market_data(
    datafile: str,
    instruments: List[Dict[str, str]],
    db_table_name: str,
    trading_hours: Dict | None = None,
    extra_minutes: int = 0,
) -> pd.DataFrame:
    """Load bar data for ``instruments`` from a SQLite file into a DataFrame.

    Args:
        datafile: Path of the SQLite database file.
        instruments: One dict per instrument; the keys ``instrument_id_pfx``,
            ``symbol`` and ``exchange_id`` are read.  The prefix and symbol
            are concatenated to form the ``instrument_id`` to select.
        db_table_name: Table to read from (inserted into the SQL as-is, so it
            must come from trusted configuration).
        trading_hours: Optional dict with ``begin_session``, ``end_session``
            and ``timezone``.  When given and non-empty, rows are restricted
            to that session.  ``None`` or an empty dict disables filtering.
        extra_minutes: Minutes added to the session end, to get the
            execution price of the bar(s) following the close.

    Returns:
        DataFrame with columns ``tstamp`` (converted to datetime),
        ``time_ns``, ``symbol`` (the part of ``instrument_id`` after the first
        ``-``), ``open``, ``high``, ``low``, ``close``, ``volume``,
        ``num_trades`` and ``vwap``.  If the loader returns a frame without a
        ``tstamp`` column (it does so when ``datafile`` is missing), that
        frame is returned unchanged.
    """

    def sql_literal(text: str) -> str:
        # Single quotes denote a string literal in SQL.  Double quotes denote
        # an identifier, which SQLite only falls back to treating as a string
        # when no column of that name exists.
        return "'" + text.replace("'", "''") + "'"

    # Sets drop duplicates; sorting keeps the generated SQL deterministic.
    instrument_ids = sorted(
        {
            sql_literal(instrument["instrument_id_pfx"] + instrument["symbol"])
            for instrument in instruments
        }
    )
    exchange_ids = sorted(
        {sql_literal(instrument["exchange_id"]) for instrument in instruments}
    )

    query = (
        "select"
        " tstamp"
        ", tstamp_ns as time_ns"
        ", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
        ", open"
        ", high"
        ", low"
        ", close"
        ", volume"
        ", num_trades"
        ", vwap"
        f" from {db_table_name}"
        f" where exchange_id in ({','.join(exchange_ids)})"
        f" and instrument_id in ({','.join(instrument_ids)})"
    )

    df = load_sqlite_to_dataframe(db_path=datafile, query=query)

    # A missing database file yields a frame with no columns at all; return
    # it as-is rather than failing with a KeyError on "tstamp" below.
    if "tstamp" not in df.columns:
        return cast(pd.DataFrame, df)

    # Trading Hours
    if len(df) > 0 and trading_hours:
        # The session date is taken from the first row only.
        # NOTE(review): presumably a file holds a single trading day and
        # tstamp is stored as UTC "YYYY-MM-DD HH:MM:SS" text -- confirm.
        date_str = df["tstamp"].iloc[0][0:10]
        start_time = convert_time_to_UTC(
            f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
        )
        end_time = convert_time_to_UTC(
            f"{date_str} {trading_hours['end_session']}",
            trading_hours["timezone"],
            extra_minutes=extra_minutes,  # to get execution price
        )

        # String comparison against the formatted UTC session bounds.
        in_session = (df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)
        df = df[in_session]

    # assign() builds a new frame, so the conversion never writes into a
    # slice of the unfiltered data (avoids SettingWithCopyWarning).
    df = df.assign(tstamp=pd.to_datetime(df["tstamp"]))

    return cast(pd.DataFrame, df)
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
# """
# Auto-detect available instruments from the database by querying distinct instrument_id values.
# Returns instruments without the configured prefix.
# """
# try:
# conn = sqlite3.connect(datafile)
# # Build exclusion list with full instrument_ids
# exclude_instruments = config.get("exclude_instruments", [])
# prefix = config.get("instrument_id_pfx", "")
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# # Query to get distinct instrument_ids
# query = f"""
# SELECT DISTINCT instrument_id
# FROM {config['db_table_name']}
# WHERE exchange_id = ?
# """
# # Add exclusion clause if there are instruments to exclude
# if exclude_instrument_ids:
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
# query += f" AND instrument_id NOT IN ({placeholders})"
# cursor = conn.execute(
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
# )
# else:
# cursor = conn.execute(query, (config["exchange_id"],))
# instrument_ids = [row[0] for row in cursor.fetchall()]
# conn.close()
# # Remove the configured prefix to get instrument symbols
# instruments = []
# for instrument_id in instrument_ids:
# if instrument_id.startswith(prefix):
# symbol = instrument_id[len(prefix) :]
# instruments.append(symbol)
# else:
# instruments.append(instrument_id)
# return sorted(instruments)
# except Exception as e:
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
# return []
# if __name__ == "__main__":
# df1 = load_sqlite_to_dataframe(sys.argv[1], query="select * from md_1min_bars")
# print(df1)