pairs_trading/lib/tools/data_loader.py
Oleg Sheynin 85c9d2ab93 progress
2025-07-10 18:14:37 +00:00

139 lines
4.3 KiB
Python

import sqlite3
from typing import Dict, List, cast
import pandas as pd
def load_sqlite_to_dataframe(db_path, query, params=None):
    """Run *query* against the SQLite database at *db_path* and return the result.

    Args:
        db_path: Path of the SQLite database file (passed to ``sqlite3.connect``).
        query: SQL text to execute.
        params: Optional bind parameters for ``?`` placeholders in *query*;
            forwarded unchanged to ``pandas.read_sql_query``. Defaults to None
            (no parameters), which matches the previous behaviour.

    Returns:
        The query result as a ``pandas.DataFrame``.

    Raises:
        Whatever ``sqlite3.connect`` / ``pandas.read_sql_query`` raise. The
        message is printed and the ORIGINAL exception is re-raised unchanged.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        return pd.read_sql_query(query, conn, params=params)
    except sqlite3.Error as excpt:
        print(f"SQLite error: {excpt}")
        raise
    except Exception as excpt:
        print(f"Error: {excpt}")
        # Re-raise the original error. Wrapping it in a bare Exception() discarded
        # its type and message (e.g. pandas' "Execution failed on sql ...").
        raise
    finally:
        # conn stays None when connect() itself failed.
        if conn is not None:
            conn.close()
def convert_time_to_UTC(value: str, timezone: str) -> str:
    """Translate a local wall-clock timestamp in *timezone* to UTC.

    *value* must look like ``"YYYY-MM-DD HH:MM:SS"``. It is parsed as a naive
    datetime, tagged with the IANA zone *timezone*, converted to UTC, and
    formatted back into the same ``"YYYY-MM-DD HH:MM:SS"`` layout.
    """
    from datetime import datetime
    from zoneinfo import ZoneInfo

    layout = "%Y-%m-%d %H:%M:%S"
    wall_clock = datetime.strptime(value, layout)
    aware_local = wall_clock.replace(tzinfo=ZoneInfo(timezone))
    return aware_local.astimezone(ZoneInfo("UTC")).strftime(layout)
def _sql_string_literal(value) -> str:
    """Render *value* as a single-quoted SQLite string literal (quotes doubled)."""
    return "'" + str(value).replace("'", "''") + "'"


def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
    """Load bar data for the configured instruments, limited to the trading session.

    Config keys read: ``instrument_id_pfx``, ``instruments``, ``security_type``,
    ``exchange_id``, ``db_table_name`` and ``trading_hours`` (with ``timezone``,
    ``begin_session`` and ``end_session``).

    Returned columns: ``tstamp`` (converted to datetime), ``time_ns``, ``symbol``
    (the instrument_id with the prefix stripped), ``open``, ``high``, ``low``,
    ``close``, ``volume``, ``num_trades`` and ``vwap``.

    Rows are kept only if ``tstamp`` lies within the session
    [begin_session, end_session], inclusive. Both bounds are converted to UTC
    for the date of the first row. An empty query result is returned as-is,
    with an empty ``tstamp`` column.
    """
    prefix = config["instrument_id_pfx"]
    # Single-quoted literals: SQLite reads "..." as an identifier and only falls
    # back to a string via its legacy DQS quirk.
    instrument_list = ",".join(
        _sql_string_literal(prefix + instrument) for instrument in config["instruments"]
    )

    if config["security_type"] == "CRYPTO":
        # tstamp_ns / 1e9 with 'unixepoch' -> "YYYY-MM-DD HH:MM:SS" UTC string.
        # NOTE(review): time_ns here reads the table's own `tstamp` column, not
        # tstamp_ns as in the other branch -- confirm this is intended.
        time_columns = [
            "strftime('%Y-%m-%d %H:%M:%S', tstamp_ns/1000000000, 'unixepoch') as tstamp",
            "tstamp as time_ns",
        ]
    else:
        time_columns = ["tstamp", "tstamp_ns as time_ns"]

    columns = time_columns + [
        f"substr(instrument_id, {len(prefix) + 1}) as symbol",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "num_trades",
        "vwap",
    ]
    query = (
        f"select {', '.join(columns)}"
        f" from {config['db_table_name']}"
        f" where exchange_id = {_sql_string_literal(config['exchange_id'])}"
        f" and instrument_id in ({instrument_list})"
    )

    # Module-level function; the former `from tools.data_loader import ...`
    # self-import failed unless the package was importable as `tools`.
    df = load_sqlite_to_dataframe(db_path=datafile, query=query)

    if df.empty:
        # No first row to take the session date from; keep the tstamp dtype
        # consistent with the non-empty path.
        df["tstamp"] = pd.to_datetime(df["tstamp"])
        return cast(pd.DataFrame, df)

    # Trading Hours -- bounds are built for the date of the first row only, so
    # this assumes the file covers a single trading day (TODO confirm).
    date_str = df["tstamp"].iloc[0][0:10]
    trading_hours = config["trading_hours"]
    tz_name = trading_hours["timezone"]
    start_time = convert_time_to_UTC(f"{date_str} {trading_hours['begin_session']}", tz_name)
    end_time = convert_time_to_UTC(f"{date_str} {trading_hours['end_session']}", tz_name)

    # String comparison: assumes tstamp is a UTC "YYYY-MM-DD HH:MM:SS" string,
    # for which lexicographic order equals chronological order.
    in_session = (df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)
    # .copy() so the assignment below writes to a real frame, not a slice
    # (avoids SettingWithCopyWarning).
    df = df.loc[in_session].copy()
    df["tstamp"] = pd.to_datetime(df["tstamp"])
    return cast(pd.DataFrame, df)
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
    """
    Auto-detect available instruments from the database by querying distinct instrument_id values.

    Config keys read:
        db_table_name: table to query (interpolated; identifiers cannot be bound).
        exchange_id: only rows with this exchange_id are considered (bound param).
        instrument_id_pfx: optional prefix stripped from each id (default "").
        exclude_instruments: optional symbols (without prefix) to leave out.

    Returns:
        Sorted symbols with the configured prefix removed. Ids that do not start
        with the prefix are returned unchanged. On any error the message is
        printed and an empty list is returned (best-effort detection).
    """
    try:
        conn = sqlite3.connect(datafile)
        try:
            prefix = config.get("instrument_id_pfx", "")
            excluded_ids = [
                f"{prefix}{inst}" for inst in (config.get("exclude_instruments") or [])
            ]

            query = f"""
            SELECT DISTINCT instrument_id
            FROM {config['db_table_name']}
            WHERE exchange_id = ?
            """
            params = [config["exchange_id"]]
            if excluded_ids:
                # One "?" per excluded id, e.g. "?,?,?".
                query += f" AND instrument_id NOT IN ({','.join('?' * len(excluded_ids))})"
                params.extend(excluded_ids)

            instrument_ids = [row[0] for row in conn.execute(query, params).fetchall()]
        finally:
            # Close on both success and failure; previously an execute/fetch
            # error skipped close() and leaked the connection.
            conn.close()

        return sorted(instrument_id.removeprefix(prefix) for instrument_id in instrument_ids)
    except Exception as e:
        print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
        return []
# if __name__ == "__main__":
# df1 = load_sqlite_to_dataframe(sys.argv[1], table_name="md_1min_bars")
# print(df1)