diff --git a/configuration/crypto.cfg b/configuration/crypto.cfg index 5adcead..eb2da76 100644 --- a/configuration/crypto.cfg +++ b/configuration/crypto.cfg @@ -29,5 +29,5 @@ "dis-equilibrium_close_trshld": 0.5, "training_minutes": 120, "funding_per_pair": 2000.0, - "strategy_class": "trading.strategies.StaticFitStrategy" + "fit_method_class": "pt_trading.fit_methods.StaticFit" } \ No newline at end of file diff --git a/configuration/equity.cfg b/configuration/equity.cfg index dffd077..ae64dea 100644 --- a/configuration/equity.cfg +++ b/configuration/equity.cfg @@ -19,8 +19,7 @@ "dis-equilibrium_close_trshld": 1.0, "training_minutes": 120, "funding_per_pair": 2000.0, - # "strategy_class": "strategies.StaticFitStrategy" - "strategy_class": "trading.strategies.SlidingFitStrategy" + "fit_method_class": "pt_trading.fit_methods.SlidingFit", "exclude_instruments": ["CAN"] } \ No newline at end of file diff --git a/src/cvtt/mkt_data.py b/lib/cvtt/mkt_data.py similarity index 100% rename from src/cvtt/mkt_data.py rename to lib/cvtt/mkt_data.py diff --git a/src/trading/strategies.py b/lib/pt_trading/fit_methods.py similarity index 97% rename from src/trading/strategies.py rename to lib/pt_trading/fit_methods.py index a327fd5..0789bd0 100644 --- a/src/trading/strategies.py +++ b/lib/pt_trading/fit_methods.py @@ -1,16 +1,15 @@ from abc import ABC, abstractmethod from enum import Enum - from typing import Dict, Optional, cast import pandas as pd # type: ignore[import] -from trading.trading_pair import TradingPair -from trading.results import BacktestResult +from pt_trading.results import BacktestResult +from pt_trading.trading_pair import TradingPair NanoPerMin = 1e9 -class PairsTradingStrategy(ABC): +class PairsTradingFitMethod(ABC): TRADES_COLUMNS = [ "time", "action", @@ -28,7 +27,7 @@ class PairsTradingStrategy(ABC): def reset(self): ... 
-class StaticFitStrategy(PairsTradingStrategy): +class StaticFit(PairsTradingFitMethod): def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod pair.get_datasets(training_minutes=config["training_minutes"]) @@ -203,7 +202,7 @@ class StaticFitStrategy(PairsTradingStrategy): columns=self.TRADES_COLUMNS, # type: ignore ) - def reset(self): + def reset(self) -> None: pass class PairState(Enum): @@ -211,8 +210,8 @@ class PairState(Enum): OPEN = 2 CLOSED = 3 -class SlidingFitStrategy(PairsTradingStrategy): - def __init__(self): +class SlidingFit(PairsTradingFitMethod): + def __init__(self) -> None: super().__init__() self.curr_training_start_idx_ = 0 @@ -235,7 +234,7 @@ class SlidingFitStrategy(PairsTradingStrategy): testing_size=1 ) - if len(pair.training_df_) < training_minutes: # type: ignore + if len(pair.training_df_) < training_minutes: print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.") if pair.user_data_["state"] == PairState.OPEN: print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.") diff --git a/src/trading/results.py b/lib/pt_trading/results.py similarity index 69% rename from src/trading/results.py rename to lib/pt_trading/results.py index 56e34d9..db3def4 100644 --- a/src/trading/results.py +++ b/lib/pt_trading/results.py @@ -11,18 +11,22 @@ def adapt_date_iso(val): """Adapt datetime.date to ISO 8601 date.""" return val.isoformat() + def adapt_datetime_iso(val): """Adapt datetime.datetime to timezone-naive ISO 8601 date.""" return val.isoformat() + def convert_date(val): """Convert ISO 8601 date to datetime.date object.""" return datetime.fromisoformat(val.decode()).date() + def convert_datetime(val): """Convert ISO 8601 datetime to datetime.datetime object.""" return datetime.fromisoformat(val.decode()) + # Register the adapters and converters sqlite3.register_adapter(date, adapt_date_iso) sqlite3.register_adapter(datetime, 
adapt_datetime_iso) @@ -37,9 +41,10 @@ def create_result_database(db_path: str) -> None: try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # Create the pt_bt_results table for completed trades - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS pt_bt_results ( date DATE, pair TEXT, @@ -57,11 +62,13 @@ def create_result_database(db_path: str) -> None: symbol_return REAL, pair_return REAL ) - ''') + """ + ) cursor.execute("DELETE FROM pt_bt_results;") # Create the outstanding_positions table for open positions - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS outstanding_positions ( date DATE, pair TEXT, @@ -72,120 +79,138 @@ def create_result_database(db_path: str) -> None: open_price REAL, open_side TEXT ) - ''') + """ + ) cursor.execute("DELETE FROM outstanding_positions;") - + # Create the config table for storing configuration JSON for reference - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS config ( id INTEGER PRIMARY KEY AUTOINCREMENT, run_timestamp DATETIME, config_file_path TEXT, config_json TEXT, - strategy_class TEXT, + fit_method_class TEXT, datafiles TEXT, instruments TEXT ) - ''') + """ + ) cursor.execute("DELETE FROM config;") - + conn.commit() conn.close() - + except Exception as e: print(f"Error creating result database: {str(e)}") raise -def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None: +def store_config_in_database( + db_path: str, + config_file_path: str, + config: Dict, + fit_method_class: str, + datafiles: List[str], + instruments: List[str], +) -> None: """ Store configuration information in the database for reference. 
""" import json - + if db_path.upper() == "NONE": return - + try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # Convert config to JSON string config_json = json.dumps(config, indent=2, default=str) - + # Convert lists to comma-separated strings for storage - datafiles_str = ', '.join(datafiles) - instruments_str = ', '.join(instruments) - + datafiles_str = ", ".join(datafiles) + instruments_str = ", ".join(instruments) + # Insert configuration record - cursor.execute(''' + cursor.execute( + """ INSERT INTO config ( - run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments + run_timestamp, config_file_path, config_json, fit_method_class, datafiles, instruments ) VALUES (?, ?, ?, ?, ?, ?) - ''', ( - datetime.now(), - config_file_path, - config_json, - strategy_class, - datafiles_str, - instruments_str - )) - + """, + ( + datetime.now(), + config_file_path, + config_json, + fit_method_class, + datafiles_str, + instruments_str, + ), + ) + conn.commit() conn.close() - + print(f"Configuration stored in database") - + except Exception as e: print(f"Error storing configuration in database: {str(e)}") import traceback + traceback.print_exc() -def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None: +def store_results_in_database( + db_path: str, datafile: str, bt_result: "BacktestResult" +) -> None: """ Store backtest results in the SQLite database. 
""" if db_path.upper() == "NONE": return - + def convert_timestamp(timestamp): """Convert pandas Timestamp to Python datetime object for SQLite compatibility.""" if timestamp is None: return None - if hasattr(timestamp, 'to_pydatetime'): + if hasattr(timestamp, "to_pydatetime"): return timestamp.to_pydatetime() return timestamp - + try: # Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db) filename = os.path.basename(datafile) - date_str = filename.split('.')[0] # Extract date part - + date_str = filename.split(".")[0] # Extract date part + # Convert to proper date format try: - date_obj = datetime.strptime(date_str, '%Y%m%d').date() + date_obj = datetime.strptime(date_str, "%Y%m%d").date() except ValueError: # If date parsing fails, use current date date_obj = datetime.now().date() - + conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # Process each trade from bt_result trades = bt_result.get_trades() - + for pair_name, symbols in trades.items(): # Calculate pair return for this pair pair_return = 0.0 pair_trades = [] - + # First pass: collect all trades and calculate returns for symbol, symbol_trades in symbols.items(): if len(symbol_trades) == 0: # No trades for this symbol - print(f"Warning: No trades found for symbol {symbol} in pair {pair_name}") + print( + f"Warning: No trades found for symbol {symbol} in pair {pair_name}" + ) continue - + elif len(symbol_trades) >= 2: # Completed trades (entry + exit) # Handle both old and new tuple formats if len(symbol_trades[0]) == 2: # Old format: (action, price) @@ -198,138 +223,190 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR open_time = datetime.now() close_time = datetime.now() else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp) - entry_action, entry_price, open_disequilibrium, open_scaled_disequilibrium, open_time = symbol_trades[0] - exit_action, exit_price, close_disequilibrium, 
close_scaled_disequilibrium, close_time = symbol_trades[1] - + ( + entry_action, + entry_price, + open_disequilibrium, + open_scaled_disequilibrium, + open_time, + ) = symbol_trades[0] + ( + exit_action, + exit_price, + close_disequilibrium, + close_scaled_disequilibrium, + close_time, + ) = symbol_trades[1] + # Handle None values - open_disequilibrium = open_disequilibrium if open_disequilibrium is not None else 0.0 - open_scaled_disequilibrium = open_scaled_disequilibrium if open_scaled_disequilibrium is not None else 0.0 - close_disequilibrium = close_disequilibrium if close_disequilibrium is not None else 0.0 - close_scaled_disequilibrium = close_scaled_disequilibrium if close_scaled_disequilibrium is not None else 0.0 - + open_disequilibrium = ( + open_disequilibrium + if open_disequilibrium is not None + else 0.0 + ) + open_scaled_disequilibrium = ( + open_scaled_disequilibrium + if open_scaled_disequilibrium is not None + else 0.0 + ) + close_disequilibrium = ( + close_disequilibrium + if close_disequilibrium is not None + else 0.0 + ) + close_scaled_disequilibrium = ( + close_scaled_disequilibrium + if close_scaled_disequilibrium is not None + else 0.0 + ) + # Convert pandas Timestamps to Python datetime objects open_time = convert_timestamp(open_time) or datetime.now() close_time = convert_timestamp(close_time) or datetime.now() - + # Calculate actual share quantities based on funding per pair # Split funding equally between the two positions funding_per_position = bt_result.config["funding_per_pair"] / 2 shares = funding_per_position / entry_price - + # Calculate symbol return symbol_return = 0.0 if entry_action == "BUY" and exit_action == "SELL": symbol_return = (exit_price - entry_price) / entry_price * 100 elif entry_action == "SELL" and exit_action == "BUY": symbol_return = (entry_price - exit_price) / entry_price * 100 - + pair_return += symbol_return - - pair_trades.append({ - 'symbol': symbol, - 'entry_action': entry_action, - 'entry_price': 
entry_price, - 'exit_action': exit_action, - 'exit_price': exit_price, - 'symbol_return': symbol_return, - 'open_disequilibrium': open_disequilibrium, - 'open_scaled_disequilibrium': open_scaled_disequilibrium, - 'close_disequilibrium': close_disequilibrium, - 'close_scaled_disequilibrium': close_scaled_disequilibrium, - 'open_time': open_time, - 'close_time': close_time, - 'shares': shares, - 'is_completed': True - }) - + + pair_trades.append( + { + "symbol": symbol, + "entry_action": entry_action, + "entry_price": entry_price, + "exit_action": exit_action, + "exit_price": exit_price, + "symbol_return": symbol_return, + "open_disequilibrium": open_disequilibrium, + "open_scaled_disequilibrium": open_scaled_disequilibrium, + "close_disequilibrium": close_disequilibrium, + "close_scaled_disequilibrium": close_scaled_disequilibrium, + "open_time": open_time, + "close_time": close_time, + "shares": shares, + "is_completed": True, + } + ) + # Skip one-sided trades - they will be handled by outstanding_positions table elif len(symbol_trades) == 1: - print(f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table") + print( + f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table" + ) continue - + else: # This should not happen, but handle unexpected cases - print(f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}") + print( + f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}" + ) continue - + # Second pass: insert completed trade records into database for trade in pair_trades: # Only store completed trades in pt_bt_results table - cursor.execute(''' + cursor.execute( + """ INSERT INTO pt_bt_results ( date, pair, symbol, open_time, open_side, open_price, open_quantity, open_disequilibrium, close_time, close_side, close_price, close_quantity, close_disequilibrium, 
symbol_return, pair_return ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - date_obj, - pair_name, - trade['symbol'], - trade['open_time'], - trade['entry_action'], - trade['entry_price'], - trade['shares'], - trade['open_scaled_disequilibrium'], - trade['close_time'], - trade['exit_action'], - trade['exit_price'], - trade['shares'], - trade['close_scaled_disequilibrium'], - trade['symbol_return'], - pair_return - )) - + """, + ( + date_obj, + pair_name, + trade["symbol"], + trade["open_time"], + trade["entry_action"], + trade["entry_price"], + trade["shares"], + trade["open_scaled_disequilibrium"], + trade["close_time"], + trade["exit_action"], + trade["exit_price"], + trade["shares"], + trade["close_scaled_disequilibrium"], + trade["symbol_return"], + pair_return, + ), + ) + # Store outstanding positions in separate table outstanding_positions = bt_result.get_outstanding_positions() for pos in outstanding_positions: # Calculate position quantity (negative for SELL positions) - position_qty_a = pos['shares_a'] if pos['side_a'] == 'BUY' else -pos['shares_a'] - position_qty_b = pos['shares_b'] if pos['side_b'] == 'BUY' else -pos['shares_b'] - + position_qty_a = ( + pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"] + ) + position_qty_b = ( + pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"] + ) + # Calculate unrealized returns # For symbol A: (current_price - open_price) / open_price * 100 * position_direction - unrealized_return_a = ((pos['current_px_a'] - pos['open_px_a']) / pos['open_px_a'] * 100) * (1 if pos['side_a'] == 'BUY' else -1) - unrealized_return_b = ((pos['current_px_b'] - pos['open_px_b']) / pos['open_px_b'] * 100) * (1 if pos['side_b'] == 'BUY' else -1) - + unrealized_return_a = ( + (pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100 + ) * (1 if pos["side_a"] == "BUY" else -1) + unrealized_return_b = ( + (pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100 + ) * (1 if 
pos["side_b"] == "BUY" else -1) + # Store outstanding position for symbol A - cursor.execute(''' + cursor.execute( + """ INSERT INTO outstanding_positions ( date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - date_obj, - pos['pair'], - pos['symbol_a'], - position_qty_a, - pos['current_px_a'], - unrealized_return_a, - pos['open_px_a'], - pos['side_a'] - )) - + """, + ( + date_obj, + pos["pair"], + pos["symbol_a"], + position_qty_a, + pos["current_px_a"], + unrealized_return_a, + pos["open_px_a"], + pos["side_a"], + ), + ) + # Store outstanding position for symbol B - cursor.execute(''' + cursor.execute( + """ INSERT INTO outstanding_positions ( date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - date_obj, - pos['pair'], - pos['symbol_b'], - position_qty_b, - pos['current_px_b'], - unrealized_return_b, - pos['open_px_b'], - pos['side_b'] - )) - + """, + ( + date_obj, + pos["pair"], + pos["symbol_b"], + position_qty_b, + pos["current_px_b"], + unrealized_return_b, + pos["open_px_b"], + pos["side_b"], + ), + ) + conn.commit() conn.close() - + except Exception as e: print(f"Error storing results in database: {str(e)}") import traceback + traceback.print_exc() @@ -344,7 +421,16 @@ class BacktestResult: self.total_realized_pnl = 0.0 self.outstanding_positions: List[Dict[str, Any]] = [] - def add_trade(self, pair_nm, symbol, action, price, disequilibrium=None, scaled_disequilibrium=None, timestamp=None): + def add_trade( + self, + pair_nm, + symbol, + action, + price, + disequilibrium=None, + scaled_disequilibrium=None, + timestamp=None, + ): """Add a trade to the results tracking.""" pair_nm = str(pair_nm) @@ -352,7 +438,9 @@ class BacktestResult: self.trades[pair_nm] = {symbol: []} if symbol not in self.trades[pair_nm]: self.trades[pair_nm][symbol] = [] - self.trades[pair_nm][symbol].append((action, 
price, disequilibrium, scaled_disequilibrium, timestamp)) + self.trades[pair_nm][symbol].append( + (action, price, disequilibrium, scaled_disequilibrium, timestamp) + ) def add_outstanding_position(self, position: Dict[str, Any]): """Add an outstanding position to tracking.""" @@ -390,13 +478,17 @@ class BacktestResult: action = row.action symbol = row.symbol price = row.price - disequilibrium = getattr(row, 'disequilibrium', None) - scaled_disequilibrium = getattr(row, 'scaled_disequilibrium', None) - timestamp = getattr(row, 'time', None) + disequilibrium = getattr(row, "disequilibrium", None) + scaled_disequilibrium = getattr(row, "scaled_disequilibrium", None) + timestamp = getattr(row, "time", None) self.add_trade( - pair_nm=row.pair, action=action, symbol=symbol, price=price, - disequilibrium=disequilibrium, scaled_disequilibrium=scaled_disequilibrium, - timestamp=timestamp + pair_nm=row.pair, + action=action, + symbol=symbol, + price=price, + disequilibrium=disequilibrium, + scaled_disequilibrium=scaled_disequilibrium, + timestamp=timestamp, ) def print_single_day_results(self): @@ -447,19 +539,31 @@ class BacktestResult: else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp) entry_action, entry_price = trades[0][:2] exit_action, exit_price = trades[1][:2] - open_disequilibrium = trades[0][2] if len(trades[0]) > 2 else None - open_scaled_disequilibrium = trades[0][3] if len(trades[0]) > 3 else None - close_disequilibrium = trades[1][2] if len(trades[1]) > 2 else None - close_scaled_disequilibrium = trades[1][3] if len(trades[1]) > 3 else None + open_disequilibrium = ( + trades[0][2] if len(trades[0]) > 2 else None + ) + open_scaled_disequilibrium = ( + trades[0][3] if len(trades[0]) > 3 else None + ) + close_disequilibrium = ( + trades[1][2] if len(trades[1]) > 2 else None + ) + close_scaled_disequilibrium = ( + trades[1][3] if len(trades[1]) > 3 else None + ) # Calculate return based on action symbol_return = 0 if entry_action 
== "BUY" and exit_action == "SELL": # Long position - symbol_return = (exit_price - entry_price) / entry_price * 100 + symbol_return = ( + (exit_price - entry_price) / entry_price * 100 + ) elif entry_action == "SELL" and exit_action == "BUY": # Short position - symbol_return = (entry_price - exit_price) / entry_price * 100 + symbol_return = ( + (entry_price - exit_price) / entry_price * 100 + ) pair_trades.append( ( @@ -489,9 +593,12 @@ class BacktestResult: close_scaled_disequilibrium, ) in pair_trades: disequil_info = "" - if open_scaled_disequilibrium is not None and close_scaled_disequilibrium is not None: + if ( + open_scaled_disequilibrium is not None + and close_scaled_disequilibrium is not None + ): disequil_info = f" | Open Dis-eq: {open_scaled_disequilibrium:.2f}, Close Dis-eq: {close_scaled_disequilibrium:.2f}" - + print( f" {symbol}: {entry_action} @ ${entry_price:.2f}, {exit_action} @ ${exit_price:.2f}, Return: {symbol_return:.2f}%{disequil_info}" ) @@ -582,9 +689,17 @@ class BacktestResult: print(f"\n====== GRAND TOTALS ACROSS ALL PAIRS ======") print(f"Total Realized PnL: {self.get_total_realized_pnl():.2f}%") - def handle_outstanding_position(self, pair, pair_result_df, last_row_index, - open_side_a, open_side_b, open_px_a, open_px_b, - open_tstamp): + def handle_outstanding_position( + self, + pair, + pair_result_df, + last_row_index, + open_side_a, + open_side_b, + open_px_a, + open_px_b, + open_tstamp, + ): """ Handle calculation and tracking of outstanding positions when no close signal is found. 
@@ -648,9 +763,15 @@ class BacktestResult: # Print position details print(f"{pair}: NO CLOSE SIGNAL FOUND - Position held until end of session") print(f" Open: {open_tstamp} | Last: {last_tstamp}") - print(f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}") - print(f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}") + print( + f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}" + ) + print( + f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}" + ) print(f" Total Value: ${total_current_value:.2f}") - print(f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}") + print( + f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}" + ) - return current_value_a, current_value_b, total_current_value \ No newline at end of file + return current_value_a, current_value_b, total_current_value diff --git a/src/trading/trading_pair.py b/lib/pt_trading/trading_pair.py similarity index 100% rename from src/trading/trading_pair.py rename to lib/pt_trading/trading_pair.py diff --git a/lib/tools/config.py b/lib/tools/config.py new file mode 100644 index 0000000..a88aa74 --- /dev/null +++ b/lib/tools/config.py @@ -0,0 +1,17 @@ +import hjson +from typing import Dict +from datetime import datetime + + +def load_config(config_path: str) -> Dict: + with open(config_path, "r") as f: + config = hjson.load(f) + return dict(config) + + +def expand_filename(filename: str) -> str: + # expand %T + res = filename.replace("%T", datetime.now().strftime("%Y%m%d_%H%M%S")) + # expand %D + return res.replace("%D", datetime.now().strftime("%Y%m%d")) + diff --git a/src/tools/data_loader.py b/lib/tools/data_loader.py 
similarity index 100% rename from src/tools/data_loader.py rename to lib/tools/data_loader.py diff --git a/src/utils/db_inspector.py b/lib/utils/db_inspector.py similarity index 91% rename from src/utils/db_inspector.py rename to lib/utils/db_inspector.py index 3267920..99da030 100644 --- a/src/utils/db_inspector.py +++ b/lib/utils/db_inspector.py @@ -25,7 +25,7 @@ def list_tables(db_path: str) -> List[str]: conn.close() return tables -def view_table_schema(db_path: str, table_name: str): +def view_table_schema(db_path: str, table_name: str) -> None: """View the schema of a specific table.""" conn = sqlite3.connect(db_path) cursor = conn.cursor() @@ -44,13 +44,13 @@ def view_table_schema(db_path: str, table_name: str): conn.close() -def view_config_table(db_path: str, limit: int = 10): +def view_config_table(db_path: str, limit: int = 10) -> None: """View entries from the config table.""" conn = sqlite3.connect(db_path) cursor = conn.cursor() cursor.execute(f""" - SELECT id, run_timestamp, config_file_path, strategy_class, + SELECT id, run_timestamp, config_file_path, fit_method_class, datafiles, instruments, config_json FROM config ORDER BY run_timestamp DESC @@ -67,17 +67,17 @@ def view_config_table(db_path: str, limit: int = 10): print("=" * 80) for row in rows: - id, run_timestamp, config_file_path, strategy_class, datafiles, instruments, config_json = row + id, run_timestamp, config_file_path, fit_method_class, datafiles, instruments, config_json = row print(f"ID: {id} | {run_timestamp}") - print(f"Config: {config_file_path} | Strategy: {strategy_class}") + print(f"Config: {config_file_path} | Strategy: {fit_method_class}") print(f"Files: {datafiles}") print(f"Instruments: {instruments}") print("-" * 80) conn.close() -def view_results_summary(db_path: str): +def view_results_summary(db_path: str) -> None: """View summary of trading results.""" conn = sqlite3.connect(db_path) cursor = conn.cursor() @@ -119,7 +119,7 @@ def view_results_summary(db_path: str): 
conn.close() -def main(): +def main() -> None: if len(sys.argv) < 2: print("Usage: python db_inspector.py [command]") print("Commands:") diff --git a/pyrightconfig.json b/pyrightconfig.json index d22e8a7..9bf3138 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,6 +1,6 @@ { "include": [ - "src" + "lib" ], "exclude": [ "**/node_modules", @@ -16,7 +16,7 @@ "autoImportCompletions": true, "autoSearchPaths": true, "extraPaths": [ - "src" + "lib" ], "stubPath": "./typings", "venvPath": ".", diff --git a/src/notebooks/pt_pair_backtest.ipynb b/research/notebooks/pt_pair_backtest.ipynb similarity index 99% rename from src/notebooks/pt_pair_backtest.ipynb rename to research/notebooks/pt_pair_backtest.ipynb index f732329..cd4b2ef 100644 --- a/src/notebooks/pt_pair_backtest.ipynb +++ b/research/notebooks/pt_pair_backtest.ipynb @@ -62,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -87,10 +87,10 @@ "from IPython.display import clear_output\n", "\n", "# Import our modules\n", - "from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n", + "from pt_trading.fit_methods import StaticFit, SlidingFit, PairState\n", "from tools.data_loader import load_market_data\n", - "from trading.trading_pair import TradingPair\n", - "from trading.results import BacktestResult\n", + "from pt_trading.trading_pair import TradingPair\n", + "from pt_trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -149,34 +149,34 @@ " print(f\"Unexpected error loading config from {config_file}: {e}\")\n", " return None\n", "\n", - "def instantiate_strategy_from_config(config: Dict):\n", + "def instantiate_fit_method_from_config(config: Dict):\n", " \"\"\"Dynamically instantiate strategy from config\"\"\"\n", - " 
strategy_class_name = config.get(\"strategy_class\", \"strategies.StaticFitStrategy\")\n", - " \n", + " fit_method_class_name = config.get(\"fit_method_class\", None)\n", + " assert fit_method_class_name is not None\n", " try:\n", " # Split module and class name\n", - " if '.' in strategy_class_name:\n", - " module_name, class_name = strategy_class_name.rsplit('.', 1)\n", + " if '.' in fit_method_class_name:\n", + " module_name, class_name = fit_method_class_name.rsplit('.', 1)\n", " else:\n", - " module_name = \"strategies\"\n", - " class_name = strategy_class_name\n", + " module_name = \"fit_methods\"\n", + " class_name = fit_method_class_name\n", " \n", " # Import module and get class\n", " module = importlib.import_module(module_name)\n", - " strategy_class = getattr(module, class_name)\n", + " fit_method_class = getattr(module, class_name)\n", " \n", " # Instantiate strategy\n", - " return strategy_class()\n", + " return fit_method_class()\n", " \n", " except Exception as e:\n", - " print(f\"Error instantiating strategy {strategy_class_name}: {e}\")\n", + " print(f\"Error instantiating strategy {fit_method_class_name}: {e}\")\n", " print(\"Falling back to StaticFitStrategy\")\n", " return StaticFitStrategy()\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -230,7 +230,7 @@ " print(f\" Close threshold: {pt_bt_config['dis-equilibrium_close_trshld']}\")\n", " \n", " # Instantiate strategy from config\n", - " STRATEGY = instantiate_strategy_from_config(pt_bt_config)\n", + " STRATEGY = instantiate_fit_method_from_config(pt_bt_config)\n", " print(f\" Strategy: {type(STRATEGY).__name__}\")\n", " \n", " # Automatically construct data file name based on date and config type\n", @@ -576,7 +576,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -831,12 +831,12 @@ " max_demo_iterations = min(200, max_iterations)\n", " print(f\"Processing 
first {max_demo_iterations} iterations for demonstration...\")\n", " \n", - " # Initialize pair state for sliding strategy\n", + " # Initialize pair state for sliding fit method\n", " pair.user_data_['state'] = PairState.INITIAL\n", " pair.user_data_[\"trades\"] = pd.DataFrame(columns=pd.Index(STRATEGY.TRADES_COLUMNS, dtype=str))\n", " pair.user_data_[\"is_cointegrated\"] = False\n", " \n", - " # Run the sliding strategy\n", + " # Run the sliding fit method\n", " pair_trades = STRATEGY.run_pair(config=pt_bt_config, pair=pair, bt_result=bt_result)\n", " \n", " if pair_trades is not None and len(pair_trades) > 0:\n", diff --git a/src/notebooks/pt_sliding.ipynb b/research/notebooks/pt_sliding.ipynb similarity index 99% rename from src/notebooks/pt_sliding.ipynb rename to research/notebooks/pt_sliding.ipynb index 4e539a1..6867a28 100644 --- a/src/notebooks/pt_sliding.ipynb +++ b/research/notebooks/pt_sliding.ipynb @@ -111,10 +111,10 @@ "from IPython.display import clear_output\n", "\n", "# Import our modules\n", - "from strategies import SlidingFitStrategy, PairState\n", + "from pt_trading.fit_methods import SlidingFit, PairState\n", "from tools.data_loader import load_market_data\n", - "from trading.trading_pair import TradingPair\n", - "from trading.results import BacktestResult\n", + "from pt_trading.trading_pair import TradingPair\n", + "from pt_trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", diff --git a/src/notebooks/pt_static.ipynb b/research/notebooks/pt_static.ipynb similarity index 98% rename from src/notebooks/pt_static.ipynb rename to research/notebooks/pt_static.ipynb index 201152d..4c202b4 100644 --- a/src/notebooks/pt_static.ipynb +++ b/research/notebooks/pt_static.ipynb @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -96,10 +96,10 @@ "from typing import Dict, List, Optional\n", "\n", "# Import our modules\n", - 
"from strategies import StaticFitStrategy, SlidingFitStrategy\n", + "from pt_trading.fit_methods import StaticFit, SlidingFit\n", "from tools.data_loader import load_market_data\n", - "from trading.trading_pair import TradingPair\n", - "from trading.results import BacktestResult\n", + "from pt_trading.trading_pair import TradingPair\n", + "from pt_trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", @@ -226,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -246,11 +246,11 @@ "DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n", "\n", "# Choose strategy\n", - "STRATEGY = StaticFitStrategy()\n", + "FIT_METHOD = StaticFit()\n", "\n", "print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n", "print(f\"Data file: {DATA_FILE}\")\n", - "print(f\"Strategy: {type(STRATEGY).__name__}\")" + "print(f\"Strategy: {type(FIT_METHOD).__name__}\")" ] }, { @@ -548,7 +548,7 @@ "\n", " # Run strategy\n", " bt_result = BacktestResult(config=CONFIG)\n", - " pair_trades = STRATEGY.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n", + " pair_trades = FIT_METHOD.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n", "\n", " if pair_trades is not None and len(pair_trades) > 0:\n", " print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n", @@ -674,7 +674,7 @@ "print(\"=\" * 60)\n", "\n", "print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n", - "print(f\"Strategy: {type(STRATEGY).__name__}\")\n", + "print(f\"Strategy: {type(FIT_METHOD).__name__}\")\n", "print(f\"Data file: {DATA_FILE}\")\n", "print(f\"Training period: {training_minutes} minutes\")\n", "\n", diff --git a/src/pt_backtest.py b/research/pt_backtest.py similarity index 91% rename from src/pt_backtest.py rename to research/pt_backtest.py index 9e734c6..ea33f15 100644 --- a/src/pt_backtest.py +++ b/research/pt_backtest.py @@ -1,29 +1,22 @@ import argparse -import hjson -import importlib import glob 
+import importlib import os -from datetime import datetime, date - +from datetime import date, datetime from typing import Any, Dict, List, Optional import pandas as pd +from tools.config import expand_filename, load_config from tools.data_loader import get_available_instruments_from_db, load_market_data -from trading.strategies import PairsTradingStrategy -from trading.trading_pair import TradingPair -from trading.results import ( +from pt_trading.results import ( BacktestResult, create_result_database, - store_results_in_database, store_config_in_database, + store_results_in_database, ) - - -def load_config(config_path: str) -> Dict: - with open(config_path, "r") as f: - config = hjson.load(f) - return dict(config) +from pt_trading.fit_methods import PairsTradingFitMethod +from pt_trading.trading_pair import TradingPair def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: @@ -69,7 +62,7 @@ def run_backtest( config: Dict, datafile: str, price_column: str, - strategy: PairsTradingStrategy, + fit_method: PairsTradingFitMethod, instruments: List[str], ) -> BacktestResult: """ @@ -101,7 +94,7 @@ def run_backtest( pairs_trades = [] for pair in _create_pairs(config, instruments): - single_pair_trades = strategy.run_pair( + single_pair_trades = fit_method.run_pair( pair=pair, config=config, bt_result=bt_result ) if single_pair_trades is not None and len(single_pair_trades) > 0: @@ -148,11 +141,12 @@ def main() -> None: config: Dict = load_config(args.config) - # Dynamically instantiate strategy class - strategy_class_name = config.get("strategy_class", "strategies.StaticFitStrategy") - module_name, class_name = strategy_class_name.rsplit(".", 1) + # Dynamically instantiate fit method class + fit_method_class_name = config.get("fit_method_class", None) + assert fit_method_class_name is not None + module_name, class_name = fit_method_class_name.rsplit(".", 1) module = importlib.import_module(module_name) - strategy = getattr(module, 
class_name)()
+    fit_method = getattr(module, class_name)()
 
     # Resolve data files (CLI takes priority over config)
     datafiles = resolve_datafiles(config, args.datafiles)
@@ -167,6 +161,7 @@ def main() -> None:
 
     # Create result database if needed
     if args.result_db.upper() != "NONE":
+        args.result_db = expand_filename(args.result_db)
         create_result_database(args.result_db)
 
     # Initialize a dictionary to store all trade results
@@ -192,7 +187,7 @@ def main() -> None:
             db_path=args.result_db,
             config_file_path=args.config,
             config=config,
-            strategy_class=strategy_class_name,
+            fit_method_class=fit_method_class_name,
             datafiles=datafiles,
             instruments=unique_instruments,
         )
@@ -219,13 +214,13 @@ def main() -> None:
 
         # Process data for this file
         try:
-            strategy.reset()
+            fit_method.reset()
 
             bt_results = run_backtest(
                 config=config,
                 datafile=datafile,
                 price_column=price_column,
-                strategy=strategy,
+                fit_method=fit_method,
                 instruments=instruments,
             )
 
diff --git a/strategy/pair_strategy.py b/strategy/pair_strategy.py
new file mode 100644
index 0000000..7407115
--- /dev/null
+++ b/strategy/pair_strategy.py
@@ -0,0 +1,229 @@
+import argparse
+import asyncio
+import glob
+import importlib
+import os
+from datetime import date, datetime
+from typing import Any, Dict, List, Optional
+
+import hjson
+import pandas as pd
+
+from tools.config import load_config
+from tools.data_loader import get_available_instruments_from_db, load_market_data
+from pt_trading.results import (
+    BacktestResult,
+    create_result_database,
+    store_config_in_database,
+    store_results_in_database,
+)
+from pt_trading.fit_methods import PairsTradingFitMethod
+from pt_trading.trading_pair import TradingPair
+
+
+def run_strategy(
+    config: Dict,
+    datafile: str,
+    price_column: str,
+    fit_method: PairsTradingFitMethod,
+    instruments: List[str],
+) -> BacktestResult:
+    """
+    Run the fit method over every unordered pair of ``instruments``.
+
+    Market data is loaded once from ``datafile``. Trades from all pairs are
+    concatenated, indexed by time and passed to the result's
+    ``collect_single_day_results``; if no pair yields trades, the result is
+    returned without that step.
+    """
+    bt_result: BacktestResult = BacktestResult(config=config)
+
+    def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
+        nonlocal datafile
+        all_indexes = range(len(instruments))
+        unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
+        pairs = []
+
+        # Update config to use the specified instruments
+        config_copy = config.copy()
+        config_copy["instruments"] = instruments
+
+        market_data_df = load_market_data(datafile, config=config_copy)
+
+        for a_index, b_index in unique_index_pairs:
+            pair = TradingPair(
+                market_data=market_data_df,
+                symbol_a=instruments[a_index],
+                symbol_b=instruments[b_index],
+                price_column=price_column,
+            )
+            pairs.append(pair)
+        return pairs
+
+    pairs_trades = []
+    for pair in _create_pairs(config, instruments):
+        single_pair_trades = fit_method.run_pair(
+            pair=pair, config=config, bt_result=bt_result
+        )
+        if single_pair_trades is not None and len(single_pair_trades) > 0:
+            pairs_trades.append(single_pair_trades)
+
+    # Check if result_list has any data before concatenating
+    if len(pairs_trades) == 0:
+        print("No trading signals found for any pairs")
+        return bt_result
+
+    result = pd.concat(pairs_trades, ignore_index=True)
+    result["time"] = pd.to_datetime(result["time"])
+    result = result.set_index("time").sort_index()
+
+    bt_result.collect_single_day_results(result)
+    return bt_result
+
+
+def main() -> None:
+    """Parse CLI arguments and run the fit method over each data file."""
+    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to the configuration file."
+    )
+    parser.add_argument(
+        "--datafiles",
+        type=str,
+        required=False,
+        help="Comma-separated list of data files (overrides config). No wildcards supported.",
+    )
+    parser.add_argument(
+        "--instruments",
+        type=str,
+        required=False,
+        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
+    )
+    parser.add_argument(
+        "--result_db",
+        type=str,
+        required=True,
+        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
+    )
+
+    args = parser.parse_args()
+
+    config: Dict = load_config(args.config)
+
+    # Dynamically instantiate fit method class
+    fit_method_class_name = config.get("fit_method_class", None)
+    if fit_method_class_name is None:
+        raise ValueError("'fit_method_class' is missing from the configuration")
+    module_name, class_name = fit_method_class_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    fit_method = getattr(module, class_name)()
+
+    # Resolve data files (CLI takes priority over config)
+    # NOTE(review): resolve_datafiles is neither defined nor imported in this
+    # module (research/pt_backtest.py has a function of that name) - this call
+    # raises NameError until it is moved to a shared module and imported.
+    datafiles = resolve_datafiles(config, args.datafiles)
+
+    if not datafiles:
+        print("No data files found to process.")
+        return
+
+    print(f"Found {len(datafiles)} data files to process:")
+    for df in datafiles:
+        print(f" - {df}")
+
+    # Create result database if needed
+    if args.result_db.upper() != "NONE":
+        create_result_database(args.result_db)
+
+    # Initialize a dictionary to store all trade results
+    all_results: Dict[str, Dict[str, Any]] = {}
+
+    # Store configuration in database for reference
+    if args.result_db.upper() != "NONE":
+        # Get list of all instruments for storage
+        all_instruments = []
+        for datafile in datafiles:
+            if args.instruments:
+                file_instruments = [
+                    inst.strip() for inst in args.instruments.split(",")
+                ]
+            else:
+                file_instruments = get_available_instruments_from_db(datafile, config)
+            all_instruments.extend(file_instruments)
+
+        # Remove duplicates while preserving order
+        unique_instruments = list(dict.fromkeys(all_instruments))
+
+        store_config_in_database(
+            db_path=args.result_db,
+            config_file_path=args.config,
+            config=config,
+            fit_method_class=fit_method_class_name,
+            datafiles=datafiles,
+            instruments=unique_instruments,
+        )
+
+    # Process each data file
+    price_column = config["price_column"]
+
+    for datafile in datafiles:
+        print(f"\n====== Processing {os.path.basename(datafile)} ======")
+
+        # Determine instruments to use
+        if args.instruments:
+            # Use CLI-specified instruments
+            instruments = [inst.strip() for inst in args.instruments.split(",")]
+            print(f"Using CLI-specified instruments: {instruments}")
+        else:
+            # Auto-detect instruments from database
+            instruments = get_available_instruments_from_db(datafile, config)
+            print(f"Auto-detected instruments: {instruments}")
+
+        if not instruments:
+            print(f"No instruments found for {datafile}, skipping...")
+            continue
+
+        # Process data for this file
+        try:
+            fit_method.reset()
+
+            bt_results = run_strategy(
+                config=config,
+                datafile=datafile,
+                price_column=price_column,
+                fit_method=fit_method,
+                instruments=instruments,
+            )
+
+            # Store results with file name as key
+            filename = os.path.basename(datafile)
+            all_results[filename] = {"trades": bt_results.trades.copy()}
+
+            # Store results in database
+            if args.result_db.upper() != "NONE":
+                store_results_in_database(args.result_db, datafile, bt_results)
+
+            print(f"Successfully processed {filename}")
+
+        except Exception as err:
+            print(f"Error processing {datafile}: {str(err)}")
+            import traceback
+
+            traceback.print_exc()
+
+    # Calculate and print results using a new BacktestResult instance for aggregation
+    if all_results:
+        aggregate_bt_results = BacktestResult(config=config)
+        aggregate_bt_results.calculate_returns(all_results)
+        aggregate_bt_results.print_grand_totals()
+        aggregate_bt_results.print_outstanding_positions()
+
+        if args.result_db.upper() != "NONE":
+            print(f"\nResults stored in database: {args.result_db}")
+    else:
+        print("No results to display.")
+
+
+if __name__ == "__main__":
+    main()