150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
from typing import Dict, List, cast
|
|
import pandas as pd
|
|
|
|
|
|
def load_sqlite_to_dataframe(db_path: str, query: str) -> pd.DataFrame:
    """Run a SQL query against a SQLite database file and return the result.

    Args:
        db_path: Path to an existing SQLite database file.
        query: SQL text passed verbatim to ``pandas.read_sql_query``.

    Returns:
        The query result, or an empty ``DataFrame`` (no columns) when
        ``db_path`` does not exist.

    Raises:
        sqlite3.Error: Re-raised unchanged when raised by the sqlite3 driver
            (e.g. the file cannot be opened).
        Exception: Any other failure (e.g. the error pandas raises for an
            invalid query), re-raised with the original message included and
            chained to the original error via ``__cause__``.
    """
    import os
    from contextlib import closing

    if not os.path.exists(db_path):
        # Checked up front because sqlite3.connect() would otherwise silently
        # create a new, empty database file at this path.
        print(f"WARNING: database file {db_path} does not exist")
        return pd.DataFrame()

    try:
        # A sqlite3.Connection used as a context manager only commits or rolls
        # back; closing() guarantees the connection is actually closed.
        with closing(sqlite3.connect(db_path)) as conn:
            return pd.read_sql_query(query, conn)
    except sqlite3.Error as excpt:
        print(f"SQLite error: {excpt}")
        raise
    except Exception as excpt:
        print(f"Error: {excpt}")
        # Keep the original message; a bare Exception() would hide it.
        raise Exception(f"Failed to run query against {db_path}: {excpt}") from excpt
|
|
|
|
|
|
def convert_time_to_UTC(value: str, timezone: str) -> str:
    """Convert a wall-clock timestamp in a named time zone to UTC.

    Args:
        value: Naive timestamp formatted ``YYYY-MM-DD HH:MM:SS``.
        timezone: IANA zone name accepted by ``zoneinfo.ZoneInfo``
            (e.g. ``"America/New_York"``).

    Returns:
        The same instant expressed in UTC, formatted ``YYYY-MM-DD HH:MM:SS``.

    Raises:
        ValueError: If ``value`` does not match the expected format.
        zoneinfo.ZoneInfoNotFoundError: If ``timezone`` is not a known zone.
    """
    from datetime import datetime
    from zoneinfo import ZoneInfo

    stamp_format = "%Y-%m-%d %H:%M:%S"

    # Attach the zone to the parsed wall-clock time, then shift to UTC.
    # strptime yields fold=0, so ambiguous DST wall times resolve to the
    # earlier of the two instants (zoneinfo's default).
    wall_clock = datetime.strptime(value, stamp_format)
    localized = wall_clock.replace(tzinfo=ZoneInfo(timezone))
    in_utc = localized.astimezone(ZoneInfo("UTC"))

    return in_utc.strftime(stamp_format)
|
|
|
|
|
|
def _sql_string_literal(text: str) -> str:
    """Return *text* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + text.replace("'", "''") + "'"


def load_market_data(
    datafile: str,
    instruments: List[Dict[str, str]],
    db_table_name: str,
    trading_hours: Dict | None = None,
) -> pd.DataFrame:
    """Load bar data for the given instruments from a SQLite market-data file.

    Args:
        datafile: Path to the SQLite database file.
        instruments: Dicts each providing ``"instrument_id_pfx"``, ``"symbol"``
            and ``"exchange_id"``. Rows are selected whose ``exchange_id`` is
            one of these exchange ids and whose ``instrument_id`` is one of
            the ``instrument_id_pfx + symbol`` values.
        db_table_name: Table to read. It is interpolated into the SQL
            unescaped, so it must come from trusted configuration.
        trading_hours: Optional dict with ``"timezone"``, ``"begin_session"``
            and ``"end_session"`` (``HH:MM:SS`` wall-clock times in that zone).
            When given and non-empty, rows outside the session are dropped.

    Returns:
        Columns ``tstamp, time_ns, symbol, open, high, low, close, volume,
        num_trades, vwap``. ``symbol`` is the part of ``instrument_id`` after
        its first ``-`` (the whole id if it has none). ``tstamp`` is converted
        to datetime only when the trading-hours filter runs; otherwise it is
        left as read from the database. The frame is empty when the file is
        missing or nothing matches.
    """
    # Use single-quoted SQL literals: double quotes denote identifiers in SQL
    # and only work as strings through SQLite's legacy fallback. Deduplicate
    # and sort so the generated query is deterministic.
    instrument_ids = sorted(
        {
            _sql_string_literal(inst["instrument_id_pfx"] + inst["symbol"])
            for inst in instruments
        }
    )
    exchange_ids = sorted(
        {_sql_string_literal(inst["exchange_id"]) for inst in instruments}
    )

    query = (
        "select tstamp"
        ", tstamp_ns as time_ns"
        ", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
        ", open, high, low, close, volume, num_trades, vwap"
        f" from {db_table_name}"
        f" where exchange_id in ({','.join(exchange_ids)})"
        f" and instrument_id in ({','.join(instrument_ids)})"
    )

    df = load_sqlite_to_dataframe(db_path=datafile, query=query)

    # Trading hours
    if len(df) > 0 and trading_hours:
        # NOTE(review): session bounds are built from the first row's date
        # only, so with multi-day data only that day's session survives.
        # Confirm each file holds a single trading day.
        date_str = df["tstamp"].iloc[0][0:10]
        tz_name = trading_hours["timezone"]
        start_time = convert_time_to_UTC(
            f"{date_str} {trading_hours['begin_session']}", tz_name
        )
        end_time = convert_time_to_UTC(
            f"{date_str} {trading_hours['end_session']}", tz_name
        )

        # Lexicographic string comparison. This assumes tstamp is stored as
        # UTC text starting with "YYYY-MM-DD HH:MM:SS" (TODO confirm).
        in_session = (df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)
        # .copy() so the assignment below writes to a fresh frame, not a view
        # of the unfiltered one (avoids SettingWithCopyWarning).
        df = df[in_session].copy()
        df["tstamp"] = pd.to_datetime(df["tstamp"])

    return cast(pd.DataFrame, df)
|
|
|
|
|
|
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
|
|
# """
|
|
# Auto-detect available instruments from the database by querying distinct instrument_id values.
|
|
# Returns instruments without the configured prefix.
|
|
# """
|
|
# try:
|
|
# conn = sqlite3.connect(datafile)
|
|
|
|
# # Build exclusion list with full instrument_ids
|
|
# exclude_instruments = config.get("exclude_instruments", [])
|
|
# prefix = config.get("instrument_id_pfx", "")
|
|
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
|
|
|
|
# # Query to get distinct instrument_ids
|
|
# query = f"""
|
|
# SELECT DISTINCT instrument_id
|
|
# FROM {config['db_table_name']}
|
|
# WHERE exchange_id = ?
|
|
# """
|
|
|
|
# # Add exclusion clause if there are instruments to exclude
|
|
# if exclude_instrument_ids:
|
|
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
|
|
# query += f" AND instrument_id NOT IN ({placeholders})"
|
|
# cursor = conn.execute(
|
|
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
|
|
# )
|
|
# else:
|
|
# cursor = conn.execute(query, (config["exchange_id"],))
|
|
# instrument_ids = [row[0] for row in cursor.fetchall()]
|
|
# conn.close()
|
|
|
|
# # Remove the configured prefix to get instrument symbols
|
|
# instruments = []
|
|
# for instrument_id in instrument_ids:
|
|
# if instrument_id.startswith(prefix):
|
|
# symbol = instrument_id[len(prefix) :]
|
|
# instruments.append(symbol)
|
|
# else:
|
|
# instruments.append(instrument_id)
|
|
|
|
# return sorted(instruments)
|
|
|
|
# except Exception as e:
|
|
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
|
|
# return []
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# df1 = load_sqlite_to_dataframe(sys.argv[1], table_name="md_1min_bars")
|
|
|
|
# print(df1)
|