progress and result.py fixes

This commit is contained in:
Oleg Sheynin 2025-07-24 06:51:46 +00:00
parent 577fb5c109
commit e30b0df4db
10 changed files with 587 additions and 855 deletions

View File

@ -1,21 +1,13 @@
{
"instrument_type_specifics": {
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"datafiles": [
"20250602.mktdata.ohlcv.db"
],
"db_table_name": "md_1min_bars",
"exchange_id": "BNBSPOT",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"datafiles": [
"20250602.mktdata.ohlcv.db"
],
"db_table_name": "md_1min_bars",
"exchange_id": "BNBSPOT",
"instrument_id_pfx": "STOCK-",
}
},

View File

@ -1,212 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
from pt_trading.fit_method import PairsTradingFitMethod
NanoPerMin = 1e9
class StaticFit(PairsTradingFitMethod):
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]: # abstractmethod
config = pair.config_
pair.get_datasets(training_minutes=config["training_minutes"])
try:
pair.predict()
except Exception as e:
print(f"{pair}: Prediction failed: {str(e)}")
return None
pair_trades = self.create_trading_signals(
pair=pair, config=config, result=bt_result
)
return pair_trades
def create_trading_signals(
self, pair: TradingPair, config: Dict, result: BacktestResult
) -> pd.DataFrame:
beta = pair.vecm_fit_.beta # type: ignore
colname_a, colname_b = pair.colnames()
predicted_df = pair.predicted_df_
if predicted_df is None:
# Return empty DataFrame with correct columns and dtypes
return pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
# Iterate through the testing dataset to find the first trading opportunity
open_row_index = None
for row_idx in range(len(predicted_df)):
curr_disequilibrium = predicted_df["scaled_disequilibrium"][row_idx]
# Check if current row has sufficient disequilibrium (not near-zero)
if curr_disequilibrium >= open_threshold:
open_row_index = row_idx
break
# If no row with sufficient disequilibrium found, skip this pair
if open_row_index is None:
print(f"{pair}: Insufficient disequilibrium in testing dataset. Skipping.")
return pd.DataFrame()
# Look for close signal starting from the open position
trading_signals_df = (
predicted_df["scaled_disequilibrium"][open_row_index:] < close_threshold
)
# Adjust indices to account for the offset from open_row_index
close_row_index = None
for idx, value in trading_signals_df.items():
if value:
close_row_index = idx
break
open_row = predicted_df.loc[open_row_index]
open_px_a = predicted_df.at[open_row_index, f"{colname_a}"]
open_px_b = predicted_df.at[open_row_index, f"{colname_b}"]
open_tstamp = predicted_df.at[open_row_index, "tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
abs_beta = abs(beta[1])
pred_px_b = predicted_df.loc[open_row_index][f"{colname_b}_pred"]
pred_px_a = predicted_df.loc[open_row_index][f"{colname_a}_pred"]
if pred_px_b * abs_beta - pred_px_a > 0:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
else:
open_side_b = "BUY"
open_side_a = "SELL"
close_side_b = "SELL"
close_side_a = "BUY"
# If no close signal found, print position and unrealized PnL
if close_row_index is None:
last_row_index = len(predicted_df) - 1
# Use the new method from BacktestResult to handle outstanding positions
result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=last_row_index,
open_side_a=open_side_a,
open_side_b=open_side_b,
open_px_a=float(open_px_a),
open_px_b=float(open_px_b),
open_tstamp=pd.Timestamp(open_tstamp),
)
# Return only open trades (no close trades)
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
]
else:
# Close signal found - create complete trade
close_row = predicted_df.loc[close_row_index]
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
print(f"{pair}: Close signal found at index {close_row_index}")
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_a,
pair.symbol_a_,
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_b,
pair.symbol_b_,
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
pass

View File

@ -68,7 +68,8 @@ def create_result_database(db_path: str) -> None:
close_quantity INTEGER,
close_disequilibrium REAL,
symbol_return REAL,
pair_return REAL
pair_return REAL,
close_condition TEXT
)
"""
)
@ -121,7 +122,7 @@ def store_config_in_database(
config: Dict,
fit_method_class: str,
datafiles: List[str],
instruments: List[str],
instruments: List[Dict[str, str]],
) -> None:
"""
Store configuration information in the database for reference.
@ -140,7 +141,12 @@ def store_config_in_database(
# Convert lists to comma-separated strings for storage
datafiles_str = ", ".join(datafiles)
instruments_str = ", ".join(instruments)
instruments_str = ", ".join(
[
f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
for inst in instruments
]
)
# Insert configuration record
cursor.execute(
@ -170,6 +176,7 @@ def store_config_in_database(
traceback.print_exc()
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None:
@ -188,244 +195,6 @@ def convert_timestamp(timestamp: Any) -> Optional[datetime]:
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
def store_results_in_database(
db_path: str, datafile: str, bt_result: "BacktestResult"
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile)
date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = bt_result.get_trades()
for pair_name, symbols in trades.items():
# Calculate pair return for this pair
pair_return = 0.0
pair_trades = []
# First pass: collect all trades and calculate returns
for symbol, symbol_trades in symbols.items():
if len(symbol_trades) == 0: # No trades for this symbol
print(
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
)
continue
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
# Handle both old and new tuple formats
if len(symbol_trades[0]) == 2: # Old format: (action, price)
entry_action, entry_price = symbol_trades[0]
exit_action, exit_price = symbol_trades[1]
open_disequilibrium = 0.0 # Fallback for old format
open_scaled_disequilibrium = 0.0
close_disequilibrium = 0.0
close_scaled_disequilibrium = 0.0
open_time = datetime.now()
close_time = datetime.now()
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
(
entry_action,
entry_price,
open_disequilibrium,
open_scaled_disequilibrium,
open_time,
) = symbol_trades[0]
(
exit_action,
exit_price,
close_disequilibrium,
close_scaled_disequilibrium,
close_time,
) = symbol_trades[1]
# Handle None values
open_disequilibrium = (
open_disequilibrium
if open_disequilibrium is not None
else 0.0
)
open_scaled_disequilibrium = (
open_scaled_disequilibrium
if open_scaled_disequilibrium is not None
else 0.0
)
close_disequilibrium = (
close_disequilibrium
if close_disequilibrium is not None
else 0.0
)
close_scaled_disequilibrium = (
close_scaled_disequilibrium
if close_scaled_disequilibrium is not None
else 0.0
)
# Convert pandas Timestamps to Python datetime objects
open_time = convert_timestamp(open_time) or datetime.now()
close_time = convert_timestamp(close_time) or datetime.now()
# Calculate actual share quantities based on funding per pair
# Split funding equally between the two positions
funding_per_position = bt_result.config["funding_per_pair"] / 2
shares = funding_per_position / entry_price
# Calculate symbol return
symbol_return = 0.0
if entry_action == "BUY" and exit_action == "SELL":
symbol_return = (exit_price - entry_price) / entry_price * 100
elif entry_action == "SELL" and exit_action == "BUY":
symbol_return = (entry_price - exit_price) / entry_price * 100
pair_return += symbol_return
pair_trades.append(
{
"symbol": symbol,
"entry_action": entry_action,
"entry_price": entry_price,
"exit_action": exit_action,
"exit_price": exit_price,
"symbol_return": symbol_return,
"open_disequilibrium": open_disequilibrium,
"open_scaled_disequilibrium": open_scaled_disequilibrium,
"close_disequilibrium": close_disequilibrium,
"close_scaled_disequilibrium": close_scaled_disequilibrium,
"open_time": open_time,
"close_time": close_time,
"shares": shares,
"is_completed": True,
}
)
# Skip one-sided trades - they will be handled by outstanding_positions table
elif len(symbol_trades) == 1:
print(
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
)
continue
else:
# This should not happen, but handle unexpected cases
print(
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
)
continue
# Second pass: insert completed trade records into database
for trade in pair_trades:
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade["symbol"],
trade["open_time"],
trade["entry_action"],
trade["entry_price"],
trade["shares"],
trade["open_scaled_disequilibrium"],
trade["close_time"],
trade["exit_action"],
trade["exit_price"],
trade["shares"],
trade["close_scaled_disequilibrium"],
trade["symbol_return"],
pair_return,
),
)
# Store outstanding positions in separate table
outstanding_positions = bt_result.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()
class BacktestResult:
"""
@ -437,6 +206,7 @@ class BacktestResult:
self.trades: Dict[str, Dict[str, Any]] = {}
self.total_realized_pnl = 0.0
self.outstanding_positions: List[Dict[str, Any]] = []
self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {}
def add_trade(
self,
@ -458,14 +228,15 @@ class BacktestResult:
if symbol not in self.trades[pair_nm]:
self.trades[pair_nm][symbol] = []
self.trades[pair_nm][symbol].append(
{"symbol":symbol,
"side":side,
"action":action,
"price":price,
"disequilibrium":disequilibrium,
"scaled_disequilibrium":scaled_disequilibrium,
"timestamp":timestamp,
"status":status
{
"symbol": symbol,
"side": side,
"action": action,
"price": price,
"disequilibrium": disequilibrium,
"scaled_disequilibrium": scaled_disequilibrium,
"timestamp": timestamp,
"status": status,
}
)
@ -549,98 +320,127 @@ class BacktestResult:
def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
"""Calculate and print returns by day and pair."""
def _symbol_return(trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float) -> float:
if trade1_side == "BUY" and trade2_side == "SELL":
return (trade2_px - trade1_px) / trade1_px * 100
elif trade1_side == "SELL" and trade2_side == "BUY":
return (trade1_px - trade2_px) / trade1_px * 100
else:
return 0
print("\n====== Returns By Day and Pair ======")
trades = []
for filename, data in all_results.items():
day_return = 0
pairs = list(data["trades"].keys())
for pair in pairs:
self.pairs_trades_[pair] = []
trades_dict = data["trades"][pair]
for symbol in trades_dict.keys():
trades.extend(trades_dict[symbol])
trades = sorted(trades, key=lambda x: (x["timestamp"], x["symbol"]))
print(f"\n--- {filename} ---")
self.outstanding_positions = data["outstanding_positions"]
# Process each pair
for pair, symbols in data["trades"].items():
pair_return = 0
pair_trades = []
day_return = 0.0
for idx in range(0, len(trades), 4):
symbol_a = trades[idx]["symbol"]
trade_a_1 = trades[idx]
trade_a_2 = trades[idx + 2]
symbol_b = trades[idx + 1]["symbol"]
trade_b_1 = trades[idx + 1]
trade_b_2 = trades[idx + 3]
# Calculate individual symbol returns in the pair
for symbol, trades in symbols.items():
if len(trades) == 0:
continue
symbol_return = 0
symbol_trades = [trade for trade in trades if trade["symbol"] == symbol]
# Calculate returns for all trade combinations
for idx in range(0, len(symbol_trades), 2):
trade1 = trades[idx]
trade2 = trades[idx + 1]
assert trade1["timestamp"] < trade2["timestamp"], f"Trade 1: {trade1['timestamp']} is not less than Trade 2: {trade2['timestamp']}"
assert trade1["action"] == "OPEN" and trade2["action"] == "CLOSE", f"Trade 1: {trade1['action']} and Trade 2: {trade2['action']} are the same"
assert (
trade_a_1["timestamp"] < trade_a_2["timestamp"]
), f"Trade 1: {trade_a_1['timestamp']} is not less than Trade 2: {trade_a_2['timestamp']}"
assert (
trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"
), f"Trade 1: {trade_a_1['action']} and Trade 2: {trade_a_2['action']} are the same"
# Calculate return based on action combination
trade_return = 0
if trade1["side"] == "BUY" and trade2["side"] == "SELL":
# Long position
trade_return = (trade2["price"] - trade1["price"]) / trade1["price"] * 100
elif trade1["side"] == "SELL" and trade2["side"] == "BUY":
# Short position
trade_return = (trade1["price"] - trade2["price"]) / trade1["price"] * 100
symbol_a_return = _symbol_return(trade_a_1["side"], trade_a_1["price"], trade_a_2["side"], trade_a_2["price"])
symbol_b_return = _symbol_return(trade_b_1["side"], trade_b_1["price"], trade_b_2["side"], trade_b_2["price"])
symbol_return += trade_return
pair_return = symbol_a_return + symbol_b_return
# Store trade details for reporting
pair_trades.append(
(
symbol,
trade1["timestamp"],
trade2["timestamp"],
trade1["side"],
trade1["price"],
trade2["side"],
trade2["price"],
trade_return,
trade1["scaled_disequilibrium"],
trade2["scaled_disequilibrium"],
f"{idx + 1}", # Trade sequence number
self.pairs_trades_[pair].append(
{
"symbol": symbol_a,
"open_side": trade_a_1["side"],
"open_action": trade_a_1["action"],
"open_price": trade_a_1["price"],
"close_side": trade_a_2["side"],
"close_action": trade_a_2["action"],
"close_price": trade_a_2["price"],
"symbol_return": symbol_a_return,
"open_disequilibrium": trade_a_1["disequilibrium"],
"open_scaled_disequilibrium": trade_a_1["scaled_disequilibrium"],
"close_disequilibrium": trade_a_2["disequilibrium"],
"close_scaled_disequilibrium": trade_a_2["scaled_disequilibrium"],
"open_time": trade_a_1["timestamp"],
"close_time": trade_a_2["timestamp"],
"shares": self.config["funding_per_pair"] / 2 / trade_a_1["price"],
"is_completed": True,
"close_condition": trade_a_2["status"],
"pair_return": pair_return
}
)
self.pairs_trades_[pair].append(
{
"symbol": symbol_b,
"open_side": trade_b_1["side"],
"open_action": trade_b_1["action"],
"open_price": trade_b_1["price"],
"close_side": trade_b_2["side"],
"close_action": trade_b_2["action"],
"close_price": trade_b_2["price"],
"symbol_return": symbol_b_return,
"open_disequilibrium": trade_b_1["disequilibrium"],
"open_scaled_disequilibrium": trade_b_1["scaled_disequilibrium"],
"close_disequilibrium": trade_b_2["disequilibrium"],
"close_scaled_disequilibrium": trade_b_2["scaled_disequilibrium"],
"open_time": trade_b_1["timestamp"],
"close_time": trade_b_2["timestamp"],
"shares": self.config["funding_per_pair"] / 2 / trade_b_1["price"],
"is_completed": True,
"close_condition": trade_b_2["status"],
"pair_return": pair_return
}
)
pair_return += symbol_return
# Print pair returns with disequilibrium information
if pair_trades:
print(f" {pair}:")
for (
symbol,
trade1["timestamp"],
trade2["timestamp"],
trade1["side"],
trade1["price"],
trade2["side"],
trade2["price"],
trade_return,
trade1["scaled_disequilibrium"],
trade2["scaled_disequilibrium"],
trade_num,
) in pair_trades:
day_return = 0.0
if self.pairs_trades_[pair]:
print(f"{pair}:")
pair_return = 0.0
for trd in self.pairs_trades_[pair]:
disequil_info = ""
if (
trade1["scaled_disequilibrium"] is not None
and trade2["scaled_disequilibrium"] is not None
trd["open_scaled_disequilibrium"] is not None
and trd["open_scaled_disequilibrium"] is not None
):
disequil_info = f" | Open Dis-eq: {trade1["scaled_disequilibrium"]:.2f},"
f" Close Dis-eq: {trade2["scaled_disequilibrium"]:.2f}"
disequil_info = f" | Open Dis-eq: {trd['open_scaled_disequilibrium']:.2f},"
f" Close Dis-eq: {trd['open_scaled_disequilibrium']:.2f}"
print(
f" {trade2['timestamp'].time()} {symbol} (Trade #{trade_num}):"
f" {trade1["side"]} @ ${trade1["price"]:.2f},"
f" {trade2["side"]} @ ${trade2["price"]:.2f},"
f" Return: {trade_return:.2f}%{disequil_info}"
f" {trd['open_time'].time()} {trd['symbol']}: "
f" {trd['open_side']} @ ${trd['open_price']:.2f},"
f" {trd["close_side"]} @ ${trd["close_price"]:.2f},"
f" Return: {trd['symbol_return']:.2f}%{disequil_info}"
)
pair_return += trd["symbol_return"]
print(f" Pair Total Return: {pair_return:.2f}%")
day_return += pair_return
# Print day total return and add to global realized PnL
if day_return != 0:
print(f" Day Total Return: {day_return:.2f}%")
@ -811,3 +611,132 @@ class BacktestResult:
)
return current_value_a, current_value_b, total_current_value
def store_results_in_database(
self, db_path: str, datafile: str
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile)
date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = self.get_trades()
for pair_name, _ in trades.items():
# Second pass: insert completed trade records into database
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]):
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return, close_condition
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade_pair["symbol"],
trade_pair["open_time"],
trade_pair["open_side"],
trade_pair["open_price"],
trade_pair["shares"],
trade_pair["open_scaled_disequilibrium"],
trade_pair["close_time"],
trade_pair["close_side"],
trade_pair["close_price"],
trade_pair["shares"],
trade_pair["close_scaled_disequilibrium"],
trade_pair["symbol_return"],
trade_pair["pair_return"],
trade_pair["close_condition"]
),
)
# Store outstanding positions in separate table
outstanding_positions = self.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import sqlite3
from typing import Dict, List, cast
import pandas as pd
def load_sqlite_to_dataframe(db_path, query):
try:
conn = sqlite3.connect(db_path)
@ -35,20 +36,27 @@ def convert_time_to_UTC(value: str, timezone: str) -> str:
return result.strftime("%Y-%m-%d %H:%M:%S")
def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
from tools.data_loader import load_sqlite_to_dataframe
def load_market_data(
datafile: str,
instruments: List[Dict[str, str]],
db_table_name: str,
trading_hours: Dict = {},
) -> pd.DataFrame:
instrument_ids = [
'"' + config["instrument_id_pfx"] + instrument + '"'
for instrument in config["instruments"]
insts = [
'"' + instrument["instrument_id_pfx"] + instrument["symbol"] + '"'
for instrument in instruments
]
exchange_id = config["exchange_id"]
instrument_ids = list(set(insts))
exchange_ids = list(
set(['"' + instrument["exchange_id"] + '"' for instrument in instruments])
)
query = "select"
query += " tstamp"
query += ", tstamp_ns as time_ns"
query += f", substr(instrument_id, {len(config['instrument_id_pfx']) + 1}) as symbol"
query += f", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
query += ", open"
query += ", high"
query += ", low"
@ -57,15 +65,15 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
query += ", num_trades"
query += ", vwap"
query += f" from {config['db_table_name']}"
query += f" where exchange_id ='{exchange_id}'"
query += f" from {db_table_name}"
query += f" where exchange_id in ({','.join(exchange_ids)})"
query += f" and instrument_id in ({','.join(instrument_ids)})"
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
# Trading Hours
if len(df) > 0 and len(trading_hours) > 0:
date_str = df["tstamp"][0][0:10]
trading_hours = config["trading_hours"]
start_time = convert_time_to_UTC(
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
@ -81,50 +89,52 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
return cast(pd.DataFrame, df)
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
"""
Auto-detect available instruments from the database by querying distinct instrument_id values.
Returns instruments without the configured prefix.
"""
try:
conn = sqlite3.connect(datafile)
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
# """
# Auto-detect available instruments from the database by querying distinct instrument_id values.
# Returns instruments without the configured prefix.
# """
# try:
# conn = sqlite3.connect(datafile)
# Build exclusion list with full instrument_ids
exclude_instruments = config.get("exclude_instruments", [])
prefix = config.get("instrument_id_pfx", "")
exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# # Build exclusion list with full instrument_ids
# exclude_instruments = config.get("exclude_instruments", [])
# prefix = config.get("instrument_id_pfx", "")
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# Query to get distinct instrument_ids
query = f"""
SELECT DISTINCT instrument_id
FROM {config['db_table_name']}
WHERE exchange_id = ?
"""
# # Query to get distinct instrument_ids
# query = f"""
# SELECT DISTINCT instrument_id
# FROM {config['db_table_name']}
# WHERE exchange_id = ?
# """
# Add exclusion clause if there are instruments to exclude
if exclude_instrument_ids:
placeholders = ','.join(['?' for _ in exclude_instrument_ids])
query += f" AND instrument_id NOT IN ({placeholders})"
cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids))
else:
cursor = conn.execute(query, (config["exchange_id"],))
instrument_ids = [row[0] for row in cursor.fetchall()]
conn.close()
# # Add exclusion clause if there are instruments to exclude
# if exclude_instrument_ids:
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
# query += f" AND instrument_id NOT IN ({placeholders})"
# cursor = conn.execute(
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
# )
# else:
# cursor = conn.execute(query, (config["exchange_id"],))
# instrument_ids = [row[0] for row in cursor.fetchall()]
# conn.close()
# Remove the configured prefix to get instrument symbols
instruments = []
for instrument_id in instrument_ids:
if instrument_id.startswith(prefix):
symbol = instrument_id[len(prefix) :]
instruments.append(symbol)
else:
instruments.append(instrument_id)
# # Remove the configured prefix to get instrument symbols
# instruments = []
# for instrument_id in instrument_ids:
# if instrument_id.startswith(prefix):
# symbol = instrument_id[len(prefix) :]
# instruments.append(symbol)
# else:
# instruments.append(instrument_id)
return sorted(instruments)
# return sorted(instruments)
except Exception as e:
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
return []
# except Exception as e:
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
# return []
# if __name__ == "__main__":

View File

@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data
from tools.data_loader import get_available_instruments_from_db
from pt_trading.results import (
BacktestResult,
create_result_database,

File diff suppressed because one or more lines are too long

View File

@ -9,62 +9,62 @@ import pandas as pd
from research.research_tools import create_pairs
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
"""
Resolve the list of data files to process.
CLI datafiles take priority over config datafiles.
Supports wildcards in config but not in CLI.
"""
if cli_datafiles:
# CLI override - comma-separated list, no wildcards
datafiles = [f.strip() for f in cli_datafiles.split(",")]
# Make paths absolute relative to data directory
data_dir = config.get("data_directory", "./data")
def resolve_datafiles(
config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> List[str]:
resolved_files = []
for df in datafiles:
if not os.path.isabs(df):
df = os.path.join(data_dir, df)
resolved_files.append(df)
return resolved_files
# Use config datafiles with wildcard support
config_datafiles = config.get("datafiles", [])
data_dir = config.get("data_directory", "./data")
resolved_files = []
for pattern in config_datafiles:
for inst in instruments:
pattern = date_pattern
inst_type = inst["instrument_type"]
data_dir = config["market_data_loading"][inst_type]["data_directory"]
if "*" in pattern or "?" in pattern:
# Handle wildcards
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, pattern)
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
matched_files = glob.glob(pattern)
resolved_files.extend(matched_files)
else:
# Handle explicit file path
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, pattern)
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
resolved_files.append(pattern)
return sorted(list(set(resolved_files))) # Remove duplicates and sort
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
instruments = [
{
"symbol": inst.split(":")[0],
"instrument_type": inst.split(":")[1],
"exchange_id": inst.split(":")[2],
"instrument_id_pfx": config["market_data_loading"][inst.split(":")[1]][
"instrument_id_pfx"
],
"db_table_name": config["market_data_loading"][inst.split(":")[1]][
"db_table_name"
],
}
for inst in args.instruments.split(",")
]
return instruments
def run_backtest(
config: Dict,
datafile: str,
price_column: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
instruments: List[Dict[str, str]],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
@ -72,13 +72,14 @@ def run_backtest(
bt_result: BacktestResult = BacktestResult(config=config)
pairs_trades = []
for pair in create_pairs(
pairs = create_pairs(
datafile=datafile,
fit_method=fit_method,
price_column=price_column,
config=config,
instruments=instruments,
):
)
for pair in pairs:
single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
@ -98,16 +99,16 @@ def main() -> None:
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafiles",
"--date_pattern",
type=str,
required=False,
help="Comma-separated list of data files (overrides config). No wildcards supported.",
required=True,
help="Date YYYYMMDD, allows * and ? wildcards",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
required=True,
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
)
parser.add_argument(
"--result_db",
@ -128,7 +129,8 @@ def main() -> None:
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles)
instruments = get_instruments(args, config)
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
if not datafiles:
print("No data files found to process.")
@ -149,18 +151,8 @@ def main() -> None:
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
@ -168,7 +160,7 @@ def main() -> None:
config=config,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
instruments=instruments,
)
# Process each data file
@ -177,20 +169,6 @@ def main() -> None:
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found for {datafile}, skipping...")
continue
# Process data for this file
try:
fit_method.reset()
@ -212,7 +190,15 @@ def main() -> None:
# Store results in database
if args.result_db.upper() != "NONE":
store_results_in_database(args.result_db, datafile, bt_results)
bt_results.calculate_returns(
{
filename: {
"trades": bt_results.trades.copy(),
"outstanding_positions": bt_results.outstanding_positions.copy(),
}
}
)
bt_results.store_results_in_database(args.result_db, datafile)
print(f"Successfully processed {filename}")

View File

@ -5,7 +5,6 @@ from typing import Dict, List, Optional
from pt_trading.fit_method import PairsTradingFitMethod
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
"""
Resolve the list of data files to process.
@ -44,9 +43,17 @@ def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List
return sorted(list(set(resolved_files))) # Remove duplicates and sort
def create_pairs(datafile: str, fit_method: PairsTradingFitMethod, price_column: str, config: Dict, instruments: List[str]) -> List:
def create_pairs(
datafile: str,
fit_method: PairsTradingFitMethod,
price_column: str,
config: Dict,
instruments: List[Dict[str, str]],
) -> List:
from tools.data_loader import load_market_data
from pt_trading.trading_pair import TradingPair
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
@ -55,17 +62,22 @@ def create_pairs(datafile: str, fit_method: PairsTradingFitMethod, price_column:
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(datafile, config=config_copy)
market_data_df = load_market_data(
datafile=datafile,
instruments=instruments,
db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"],
trading_hours=config_copy["trading_hours"],
)
for a_index, b_index in unique_index_pairs:
from research.pt_backtest import TradingPair
pair = fit_method.create_trading_pair(
config=config_copy,
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
symbol_a=instruments[a_index]["symbol"],
symbol_b=instruments[b_index]["symbol"],
price_column=price_column,
)
pairs.append(pair)
return pairs

View File

@ -20,8 +20,6 @@ from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
@ -44,7 +42,14 @@ def run_strategy(
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(datafile, config=config_copy)
market_data_df = load_market_data(
datafile=datafile,
exchange_id=config_copy["exchange_id"],
instruments=config_copy["instruments"],
instrument_id_pfx=config_copy["instrument_id_pfx"],
db_table_name=config_copy["db_table_name"],
trading_hours=config_copy["trading_hours"],
)
for a_index, b_index in unique_index_pairs:
pair = fit_method.create_trading_pair(