progress
This commit is contained in:
parent
9240d20e16
commit
95b25eddd7
@ -15,10 +15,6 @@
|
||||
"db_table_name": "md_1min_bars",
|
||||
"exchange_id": "ALPACA",
|
||||
"instrument_id_pfx": "STOCK-",
|
||||
# "instruments": [
|
||||
# "COIN",
|
||||
# "GBTC"
|
||||
# ],
|
||||
"trading_hours": {
|
||||
"begin_session": "9:30:00",
|
||||
"end_session": "16:00:00",
|
||||
@ -31,5 +27,6 @@
|
||||
"dis-equilibrium_close_trshld": 1.0,
|
||||
"training_minutes": 120,
|
||||
"funding_per_pair": 2000.0,
|
||||
"strategy_class": "strategies.StaticFitStrategy"
|
||||
# "strategy_class": "strategies.StaticFitStrategy"
|
||||
"strategy_class": "strategies.SlidingFitStrategy"
|
||||
}
|
||||
@ -12,7 +12,7 @@ import pandas as pd
|
||||
|
||||
from tools.data_loader import load_market_data
|
||||
from tools.trading_pair import TradingPair
|
||||
from results import BacktestResult, create_result_database, store_results_in_database
|
||||
from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict:
|
||||
@ -203,6 +203,29 @@ def main() -> None:
|
||||
all_results: Dict[str, Dict[str, Any]] = {}
|
||||
bt_results = BacktestResult(config=config)
|
||||
|
||||
# Store configuration in database for reference
|
||||
if args.result_db.upper() != "NONE":
|
||||
# Get list of all instruments for storage
|
||||
all_instruments = []
|
||||
for datafile in datafiles:
|
||||
if args.instruments:
|
||||
file_instruments = [inst.strip() for inst in args.instruments.split(",")]
|
||||
else:
|
||||
file_instruments = get_available_instruments_from_db(datafile, config)
|
||||
all_instruments.extend(file_instruments)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_instruments = list(dict.fromkeys(all_instruments))
|
||||
|
||||
store_config_in_database(
|
||||
db_path=args.result_db,
|
||||
config_file_path=args.config,
|
||||
config=config,
|
||||
strategy_class=strategy_class_name,
|
||||
datafiles=datafiles,
|
||||
instruments=unique_instruments
|
||||
)
|
||||
|
||||
# Process each data file
|
||||
price_column = config["price_column"]
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ sqlite3.register_converter("datetime", convert_datetime)
|
||||
|
||||
def create_result_database(db_path: str) -> None:
|
||||
"""
|
||||
Create the SQLite database and pt_bt_results table if they don't exist.
|
||||
Create the SQLite database and required tables if they don't exist.
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
@ -58,6 +58,7 @@ def create_result_database(db_path: str) -> None:
|
||||
pair_return REAL
|
||||
)
|
||||
''')
|
||||
cursor.execute("DELETE FROM pt_bt_results;")
|
||||
|
||||
# Create the outstanding_positions table for open positions
|
||||
cursor.execute('''
|
||||
@ -72,6 +73,21 @@ def create_result_database(db_path: str) -> None:
|
||||
open_side TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute("DELETE FROM outstanding_positions;")
|
||||
|
||||
# Create the config table for storing configuration JSON for reference
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_timestamp DATETIME,
|
||||
config_file_path TEXT,
|
||||
config_json TEXT,
|
||||
strategy_class TEXT,
|
||||
datafiles TEXT,
|
||||
instruments TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute("DELETE FROM config;")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -81,6 +97,51 @@ def create_result_database(db_path: str) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None:
|
||||
"""
|
||||
Store configuration information in the database for reference.
|
||||
"""
|
||||
import json
|
||||
|
||||
if db_path.upper() == "NONE":
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Convert config to JSON string
|
||||
config_json = json.dumps(config, indent=2, default=str)
|
||||
|
||||
# Convert lists to comma-separated strings for storage
|
||||
datafiles_str = ', '.join(datafiles)
|
||||
instruments_str = ', '.join(instruments)
|
||||
|
||||
# Insert configuration record
|
||||
cursor.execute('''
|
||||
INSERT INTO config (
|
||||
run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
datetime.now(),
|
||||
config_file_path,
|
||||
config_json,
|
||||
strategy_class,
|
||||
datafiles_str,
|
||||
instruments_str
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
print(f"Configuration stored in database")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error storing configuration in database: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None:
|
||||
"""
|
||||
Store backtest results in the SQLite database.
|
||||
|
||||
@ -4,7 +4,7 @@ import sys
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # type: ignore
|
||||
|
||||
from tools.trading_pair import TradingPair
|
||||
from results import BacktestResult
|
||||
@ -22,7 +22,7 @@ class PairsTradingStrategy(ABC):
|
||||
"pair",
|
||||
]
|
||||
@abstractmethod
|
||||
def run_pair(self, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]:
|
||||
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]:
|
||||
...
|
||||
|
||||
class StaticFitStrategy(PairsTradingStrategy):
|
||||
@ -49,7 +49,7 @@ class StaticFitStrategy(PairsTradingStrategy):
|
||||
return pair_trades
|
||||
|
||||
def create_trading_signals(self, pair: TradingPair, config: Dict, result: BacktestResult) -> pd.DataFrame:
|
||||
beta = pair.vecm_fit_.beta
|
||||
beta = pair.vecm_fit_.beta # type: ignore
|
||||
colname_a, colname_b = pair.colnames()
|
||||
|
||||
predicted_df = pair.predicted_df_
|
||||
@ -229,7 +229,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
|
||||
testing_size=1
|
||||
)
|
||||
|
||||
if len(pair.training_df_) < training_minutes:
|
||||
if len(pair.training_df_) < training_minutes: # type: ignore
|
||||
print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.")
|
||||
if pair.user_data_["state"] == PairState.OPEN:
|
||||
print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.")
|
||||
@ -251,7 +251,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
|
||||
try:
|
||||
is_cointegrated = pair.train_pair()
|
||||
except Exception as e:
|
||||
raise Exception(f"{pair}: Training failed: {str(e)}") from e
|
||||
raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e
|
||||
|
||||
if pair.user_data_["is_cointegrated"] != is_cointegrated:
|
||||
pair.user_data_["is_cointegrated"] = is_cointegrated
|
||||
@ -271,7 +271,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
|
||||
try:
|
||||
pair.predict()
|
||||
except Exception as e:
|
||||
raise Exception(f"{pair}: Prediction failed: {str(e)}") from e
|
||||
raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e
|
||||
|
||||
if pair.user_data_["state"] == PairState.INITIAL:
|
||||
|
||||
@ -296,7 +296,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
|
||||
|
||||
predicted_df = pair.predicted_df_
|
||||
|
||||
open_row = predicted_df.loc[0]
|
||||
# Check if we have any data to work with
|
||||
if len(predicted_df) == 0:
|
||||
return None
|
||||
|
||||
open_row = predicted_df.iloc[0]
|
||||
open_tstamp = open_row["tstamp"]
|
||||
open_disequilibrium = open_row["disequilibrium"]
|
||||
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
|
||||
@ -359,7 +363,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
|
||||
def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]:
|
||||
colname_a, colname_b = pair.colnames()
|
||||
|
||||
close_row = pair.predicted_df_.loc[0]
|
||||
# Check if we have any data to work with
|
||||
if len(pair.predicted_df_) == 0:
|
||||
return None
|
||||
|
||||
close_row = pair.predicted_df_.iloc[0]
|
||||
close_tstamp = close_row["tstamp"]
|
||||
close_disequilibrium = close_row["disequilibrium"]
|
||||
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
|
||||
|
||||
169
src/utils/db_inspector.py
Normal file
169
src/utils/db_inspector.py
Normal file
@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database inspector utility for pairs trading results database.
|
||||
Provides functionality to view all tables and their contents.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any
|
||||
|
||||
def list_tables(db_path: str) -> List[str]:
|
||||
"""List all tables in the database."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table'
|
||||
ORDER BY name
|
||||
""")
|
||||
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
conn.close()
|
||||
return tables
|
||||
|
||||
def view_table_schema(db_path: str, table_name: str):
|
||||
"""View the schema of a specific table."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
columns = cursor.fetchall()
|
||||
|
||||
print(f"\nTable: {table_name}")
|
||||
print("-" * 50)
|
||||
print("Column Name".ljust(20) + "Type".ljust(15) + "Not Null".ljust(10) + "Default")
|
||||
print("-" * 50)
|
||||
|
||||
for col in columns:
|
||||
cid, name, type_, not_null, default_value, pk = col
|
||||
print(f"{name}".ljust(20) + f"{type_}".ljust(15) + f"{bool(not_null)}".ljust(10) + f"{default_value or ''}")
|
||||
|
||||
conn.close()
|
||||
|
||||
def view_config_table(db_path: str, limit: int = 10):
|
||||
"""View entries from the config table."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(f"""
|
||||
SELECT id, run_timestamp, config_file_path, strategy_class,
|
||||
datafiles, instruments, config_json
|
||||
FROM config
|
||||
ORDER BY run_timestamp DESC
|
||||
LIMIT {limit}
|
||||
""")
|
||||
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
print("No configuration entries found.")
|
||||
return
|
||||
|
||||
print(f"\nMost recent {len(rows)} configuration entries:")
|
||||
print("=" * 80)
|
||||
|
||||
for row in rows:
|
||||
id, run_timestamp, config_file_path, strategy_class, datafiles, instruments, config_json = row
|
||||
|
||||
print(f"ID: {id} | {run_timestamp}")
|
||||
print(f"Config: {config_file_path} | Strategy: {strategy_class}")
|
||||
print(f"Files: {datafiles}")
|
||||
print(f"Instruments: {instruments}")
|
||||
print("-" * 80)
|
||||
|
||||
conn.close()
|
||||
|
||||
def view_results_summary(db_path: str):
|
||||
"""View summary of trading results."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get results summary
|
||||
cursor.execute("""
|
||||
SELECT date, COUNT(*) as trade_count,
|
||||
ROUND(SUM(symbol_return), 2) as total_return
|
||||
FROM pt_bt_results
|
||||
GROUP BY date
|
||||
ORDER BY date DESC
|
||||
""")
|
||||
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
print("No trading results found.")
|
||||
return
|
||||
|
||||
print(f"\nTrading Results Summary:")
|
||||
print("-" * 50)
|
||||
print("Date".ljust(15) + "Trades".ljust(10) + "Total Return %")
|
||||
print("-" * 50)
|
||||
|
||||
for date, trade_count, total_return in results:
|
||||
print(f"{date}".ljust(15) + f"{trade_count}".ljust(10) + f"{total_return}")
|
||||
|
||||
# Get outstanding positions summary
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) as position_count,
|
||||
ROUND(SUM(unrealized_return), 2) as total_unrealized
|
||||
FROM outstanding_positions
|
||||
""")
|
||||
|
||||
outstanding = cursor.fetchone()
|
||||
if outstanding and outstanding[0] > 0:
|
||||
print(f"\nOutstanding Positions: {outstanding[0]} positions")
|
||||
print(f"Total Unrealized Return: {outstanding[1]}%")
|
||||
|
||||
conn.close()
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python db_inspector.py <database_path> [command]")
|
||||
print("Commands:")
|
||||
print(" tables - List all tables")
|
||||
print(" schema - Show schema for all tables")
|
||||
print(" config - View configuration entries")
|
||||
print(" results - View trading results summary")
|
||||
print(" all - Show everything (default)")
|
||||
print("\nExample: python db_inspector.py results/equity.db config")
|
||||
sys.exit(1)
|
||||
|
||||
db_path = sys.argv[1]
|
||||
command = sys.argv[2] if len(sys.argv) > 2 else "all"
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
print(f"Database file not found: {db_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
if command in ["tables", "all"]:
|
||||
tables = list_tables(db_path)
|
||||
print(f"Tables in database: {', '.join(tables)}")
|
||||
|
||||
if command in ["schema", "all"]:
|
||||
tables = list_tables(db_path)
|
||||
for table in tables:
|
||||
view_table_schema(db_path, table)
|
||||
|
||||
if command in ["config", "all"]:
|
||||
if "config" in list_tables(db_path):
|
||||
view_config_table(db_path)
|
||||
else:
|
||||
print("Config table not found.")
|
||||
|
||||
if command in ["results", "all"]:
|
||||
if "pt_bt_results" in list_tables(db_path):
|
||||
view_results_summary(db_path)
|
||||
else:
|
||||
print("Results table not found.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error inspecting database: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user