progress and result.py fixes
This commit is contained in:
parent
577fb5c109
commit
e30b0df4db
@ -1,21 +1,13 @@
|
||||
{
|
||||
"instrument_type_specifics": {
|
||||
"market_data_loading": {
|
||||
"CRYPTO": {
|
||||
"data_directory": "./data/crypto",
|
||||
"datafiles": [
|
||||
"20250602.mktdata.ohlcv.db"
|
||||
],
|
||||
"db_table_name": "md_1min_bars",
|
||||
"exchange_id": "BNBSPOT",
|
||||
"instrument_id_pfx": "PAIR-",
|
||||
},
|
||||
"EQUITY": {
|
||||
"data_directory": "./data/equity",
|
||||
"datafiles": [
|
||||
"20250602.mktdata.ohlcv.db"
|
||||
],
|
||||
"db_table_name": "md_1min_bars",
|
||||
"exchange_id": "BNBSPOT",
|
||||
"instrument_id_pfx": "STOCK-",
|
||||
}
|
||||
},
|
||||
|
||||
@ -1,212 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
import pandas as pd # type: ignore[import]
|
||||
from pt_trading.results import BacktestResult
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
from pt_trading.fit_method import PairsTradingFitMethod
|
||||
|
||||
NanoPerMin = 1e9
|
||||
|
||||
|
||||
|
||||
class StaticFit(PairsTradingFitMethod):
|
||||
|
||||
def run_pair(
|
||||
self, pair: TradingPair, bt_result: BacktestResult
|
||||
) -> Optional[pd.DataFrame]: # abstractmethod
|
||||
config = pair.config_
|
||||
pair.get_datasets(training_minutes=config["training_minutes"])
|
||||
|
||||
try:
|
||||
pair.predict()
|
||||
except Exception as e:
|
||||
print(f"{pair}: Prediction failed: {str(e)}")
|
||||
return None
|
||||
|
||||
pair_trades = self.create_trading_signals(
|
||||
pair=pair, config=config, result=bt_result
|
||||
)
|
||||
|
||||
return pair_trades
|
||||
|
||||
def create_trading_signals(
|
||||
self, pair: TradingPair, config: Dict, result: BacktestResult
|
||||
) -> pd.DataFrame:
|
||||
beta = pair.vecm_fit_.beta # type: ignore
|
||||
colname_a, colname_b = pair.colnames()
|
||||
|
||||
predicted_df = pair.predicted_df_
|
||||
if predicted_df is None:
|
||||
# Return empty DataFrame with correct columns and dtypes
|
||||
return pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
|
||||
"time": "datetime64[ns]",
|
||||
"action": "string",
|
||||
"symbol": "string",
|
||||
"price": "float64",
|
||||
"disequilibrium": "float64",
|
||||
"scaled_disequilibrium": "float64",
|
||||
"pair": "object"
|
||||
})
|
||||
|
||||
open_threshold = config["dis-equilibrium_open_trshld"]
|
||||
close_threshold = config["dis-equilibrium_close_trshld"]
|
||||
|
||||
# Iterate through the testing dataset to find the first trading opportunity
|
||||
open_row_index = None
|
||||
for row_idx in range(len(predicted_df)):
|
||||
curr_disequilibrium = predicted_df["scaled_disequilibrium"][row_idx]
|
||||
|
||||
# Check if current row has sufficient disequilibrium (not near-zero)
|
||||
if curr_disequilibrium >= open_threshold:
|
||||
open_row_index = row_idx
|
||||
break
|
||||
|
||||
# If no row with sufficient disequilibrium found, skip this pair
|
||||
if open_row_index is None:
|
||||
print(f"{pair}: Insufficient disequilibrium in testing dataset. Skipping.")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Look for close signal starting from the open position
|
||||
trading_signals_df = (
|
||||
predicted_df["scaled_disequilibrium"][open_row_index:] < close_threshold
|
||||
)
|
||||
|
||||
# Adjust indices to account for the offset from open_row_index
|
||||
close_row_index = None
|
||||
for idx, value in trading_signals_df.items():
|
||||
if value:
|
||||
close_row_index = idx
|
||||
break
|
||||
|
||||
open_row = predicted_df.loc[open_row_index]
|
||||
open_px_a = predicted_df.at[open_row_index, f"{colname_a}"]
|
||||
open_px_b = predicted_df.at[open_row_index, f"{colname_b}"]
|
||||
open_tstamp = predicted_df.at[open_row_index, "tstamp"]
|
||||
open_disequilibrium = open_row["disequilibrium"]
|
||||
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
|
||||
|
||||
abs_beta = abs(beta[1])
|
||||
pred_px_b = predicted_df.loc[open_row_index][f"{colname_b}_pred"]
|
||||
pred_px_a = predicted_df.loc[open_row_index][f"{colname_a}_pred"]
|
||||
|
||||
if pred_px_b * abs_beta - pred_px_a > 0:
|
||||
open_side_a = "BUY"
|
||||
open_side_b = "SELL"
|
||||
close_side_a = "SELL"
|
||||
close_side_b = "BUY"
|
||||
else:
|
||||
open_side_b = "BUY"
|
||||
open_side_a = "SELL"
|
||||
close_side_b = "SELL"
|
||||
close_side_a = "BUY"
|
||||
|
||||
# If no close signal found, print position and unrealized PnL
|
||||
if close_row_index is None:
|
||||
|
||||
last_row_index = len(predicted_df) - 1
|
||||
|
||||
# Use the new method from BacktestResult to handle outstanding positions
|
||||
result.handle_outstanding_position(
|
||||
pair=pair,
|
||||
pair_result_df=predicted_df,
|
||||
last_row_index=last_row_index,
|
||||
open_side_a=open_side_a,
|
||||
open_side_b=open_side_b,
|
||||
open_px_a=float(open_px_a),
|
||||
open_px_b=float(open_px_b),
|
||||
open_tstamp=pd.Timestamp(open_tstamp),
|
||||
)
|
||||
|
||||
# Return only open trades (no close trades)
|
||||
trd_signal_tuples = [
|
||||
(
|
||||
open_tstamp,
|
||||
open_side_a,
|
||||
pair.symbol_a_,
|
||||
open_px_a,
|
||||
open_disequilibrium,
|
||||
open_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
(
|
||||
open_tstamp,
|
||||
open_side_b,
|
||||
pair.symbol_b_,
|
||||
open_px_b,
|
||||
open_disequilibrium,
|
||||
open_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# Close signal found - create complete trade
|
||||
close_row = predicted_df.loc[close_row_index]
|
||||
close_tstamp = close_row["tstamp"]
|
||||
close_disequilibrium = close_row["disequilibrium"]
|
||||
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
|
||||
close_px_a = close_row[f"{colname_a}"]
|
||||
close_px_b = close_row[f"{colname_b}"]
|
||||
|
||||
print(f"{pair}: Close signal found at index {close_row_index}")
|
||||
|
||||
trd_signal_tuples = [
|
||||
(
|
||||
open_tstamp,
|
||||
open_side_a,
|
||||
pair.symbol_a_,
|
||||
open_px_a,
|
||||
open_disequilibrium,
|
||||
open_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
(
|
||||
open_tstamp,
|
||||
open_side_b,
|
||||
pair.symbol_b_,
|
||||
open_px_b,
|
||||
open_disequilibrium,
|
||||
open_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
(
|
||||
close_tstamp,
|
||||
close_side_a,
|
||||
pair.symbol_a_,
|
||||
close_px_a,
|
||||
close_disequilibrium,
|
||||
close_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
(
|
||||
close_tstamp,
|
||||
close_side_b,
|
||||
pair.symbol_b_,
|
||||
close_px_b,
|
||||
close_disequilibrium,
|
||||
close_scaled_disequilibrium,
|
||||
pair,
|
||||
),
|
||||
]
|
||||
|
||||
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
|
||||
df = pd.DataFrame(
|
||||
trd_signal_tuples,
|
||||
columns=self.TRADES_COLUMNS,
|
||||
)
|
||||
# Ensure consistent dtypes
|
||||
return df.astype({
|
||||
"time": "datetime64[ns]",
|
||||
"action": "string",
|
||||
"symbol": "string",
|
||||
"price": "float64",
|
||||
"disequilibrium": "float64",
|
||||
"scaled_disequilibrium": "float64",
|
||||
"pair": "object"
|
||||
})
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -68,7 +68,8 @@ def create_result_database(db_path: str) -> None:
|
||||
close_quantity INTEGER,
|
||||
close_disequilibrium REAL,
|
||||
symbol_return REAL,
|
||||
pair_return REAL
|
||||
pair_return REAL,
|
||||
close_condition TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
@ -121,7 +122,7 @@ def store_config_in_database(
|
||||
config: Dict,
|
||||
fit_method_class: str,
|
||||
datafiles: List[str],
|
||||
instruments: List[str],
|
||||
instruments: List[Dict[str, str]],
|
||||
) -> None:
|
||||
"""
|
||||
Store configuration information in the database for reference.
|
||||
@ -140,7 +141,12 @@ def store_config_in_database(
|
||||
|
||||
# Convert lists to comma-separated strings for storage
|
||||
datafiles_str = ", ".join(datafiles)
|
||||
instruments_str = ", ".join(instruments)
|
||||
instruments_str = ", ".join(
|
||||
[
|
||||
f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
|
||||
for inst in instruments
|
||||
]
|
||||
)
|
||||
|
||||
# Insert configuration record
|
||||
cursor.execute(
|
||||
@ -170,6 +176,7 @@ def store_config_in_database(
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
|
||||
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
|
||||
if timestamp is None:
|
||||
@ -188,244 +195,6 @@ def convert_timestamp(timestamp: Any) -> Optional[datetime]:
|
||||
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
|
||||
|
||||
|
||||
def store_results_in_database(
|
||||
db_path: str, datafile: str, bt_result: "BacktestResult"
|
||||
) -> None:
|
||||
"""
|
||||
Store backtest results in the SQLite database.
|
||||
"""
|
||||
if db_path.upper() == "NONE":
|
||||
return
|
||||
|
||||
try:
|
||||
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
|
||||
filename = os.path.basename(datafile)
|
||||
date_str = filename.split(".")[0] # Extract date part
|
||||
|
||||
# Convert to proper date format
|
||||
try:
|
||||
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
|
||||
except ValueError:
|
||||
# If date parsing fails, use current date
|
||||
date_obj = datetime.now().date()
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Process each trade from bt_result
|
||||
trades = bt_result.get_trades()
|
||||
|
||||
for pair_name, symbols in trades.items():
|
||||
# Calculate pair return for this pair
|
||||
pair_return = 0.0
|
||||
pair_trades = []
|
||||
|
||||
# First pass: collect all trades and calculate returns
|
||||
for symbol, symbol_trades in symbols.items():
|
||||
if len(symbol_trades) == 0: # No trades for this symbol
|
||||
print(
|
||||
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
|
||||
# Handle both old and new tuple formats
|
||||
if len(symbol_trades[0]) == 2: # Old format: (action, price)
|
||||
entry_action, entry_price = symbol_trades[0]
|
||||
exit_action, exit_price = symbol_trades[1]
|
||||
open_disequilibrium = 0.0 # Fallback for old format
|
||||
open_scaled_disequilibrium = 0.0
|
||||
close_disequilibrium = 0.0
|
||||
close_scaled_disequilibrium = 0.0
|
||||
open_time = datetime.now()
|
||||
close_time = datetime.now()
|
||||
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
|
||||
(
|
||||
entry_action,
|
||||
entry_price,
|
||||
open_disequilibrium,
|
||||
open_scaled_disequilibrium,
|
||||
open_time,
|
||||
) = symbol_trades[0]
|
||||
(
|
||||
exit_action,
|
||||
exit_price,
|
||||
close_disequilibrium,
|
||||
close_scaled_disequilibrium,
|
||||
close_time,
|
||||
) = symbol_trades[1]
|
||||
|
||||
# Handle None values
|
||||
open_disequilibrium = (
|
||||
open_disequilibrium
|
||||
if open_disequilibrium is not None
|
||||
else 0.0
|
||||
)
|
||||
open_scaled_disequilibrium = (
|
||||
open_scaled_disequilibrium
|
||||
if open_scaled_disequilibrium is not None
|
||||
else 0.0
|
||||
)
|
||||
close_disequilibrium = (
|
||||
close_disequilibrium
|
||||
if close_disequilibrium is not None
|
||||
else 0.0
|
||||
)
|
||||
close_scaled_disequilibrium = (
|
||||
close_scaled_disequilibrium
|
||||
if close_scaled_disequilibrium is not None
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# Convert pandas Timestamps to Python datetime objects
|
||||
open_time = convert_timestamp(open_time) or datetime.now()
|
||||
close_time = convert_timestamp(close_time) or datetime.now()
|
||||
|
||||
# Calculate actual share quantities based on funding per pair
|
||||
# Split funding equally between the two positions
|
||||
funding_per_position = bt_result.config["funding_per_pair"] / 2
|
||||
shares = funding_per_position / entry_price
|
||||
|
||||
# Calculate symbol return
|
||||
symbol_return = 0.0
|
||||
if entry_action == "BUY" and exit_action == "SELL":
|
||||
symbol_return = (exit_price - entry_price) / entry_price * 100
|
||||
elif entry_action == "SELL" and exit_action == "BUY":
|
||||
symbol_return = (entry_price - exit_price) / entry_price * 100
|
||||
|
||||
pair_return += symbol_return
|
||||
|
||||
pair_trades.append(
|
||||
{
|
||||
"symbol": symbol,
|
||||
"entry_action": entry_action,
|
||||
"entry_price": entry_price,
|
||||
"exit_action": exit_action,
|
||||
"exit_price": exit_price,
|
||||
"symbol_return": symbol_return,
|
||||
"open_disequilibrium": open_disequilibrium,
|
||||
"open_scaled_disequilibrium": open_scaled_disequilibrium,
|
||||
"close_disequilibrium": close_disequilibrium,
|
||||
"close_scaled_disequilibrium": close_scaled_disequilibrium,
|
||||
"open_time": open_time,
|
||||
"close_time": close_time,
|
||||
"shares": shares,
|
||||
"is_completed": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Skip one-sided trades - they will be handled by outstanding_positions table
|
||||
elif len(symbol_trades) == 1:
|
||||
print(
|
||||
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
|
||||
)
|
||||
continue
|
||||
|
||||
else:
|
||||
# This should not happen, but handle unexpected cases
|
||||
print(
|
||||
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Second pass: insert completed trade records into database
|
||||
for trade in pair_trades:
|
||||
# Only store completed trades in pt_bt_results table
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO pt_bt_results (
|
||||
date, pair, symbol, open_time, open_side, open_price,
|
||||
open_quantity, open_disequilibrium, close_time, close_side,
|
||||
close_price, close_quantity, close_disequilibrium,
|
||||
symbol_return, pair_return
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pair_name,
|
||||
trade["symbol"],
|
||||
trade["open_time"],
|
||||
trade["entry_action"],
|
||||
trade["entry_price"],
|
||||
trade["shares"],
|
||||
trade["open_scaled_disequilibrium"],
|
||||
trade["close_time"],
|
||||
trade["exit_action"],
|
||||
trade["exit_price"],
|
||||
trade["shares"],
|
||||
trade["close_scaled_disequilibrium"],
|
||||
trade["symbol_return"],
|
||||
pair_return,
|
||||
),
|
||||
)
|
||||
|
||||
# Store outstanding positions in separate table
|
||||
outstanding_positions = bt_result.get_outstanding_positions()
|
||||
for pos in outstanding_positions:
|
||||
# Calculate position quantity (negative for SELL positions)
|
||||
position_qty_a = (
|
||||
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
|
||||
)
|
||||
position_qty_b = (
|
||||
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
|
||||
)
|
||||
|
||||
# Calculate unrealized returns
|
||||
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
|
||||
unrealized_return_a = (
|
||||
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
|
||||
) * (1 if pos["side_a"] == "BUY" else -1)
|
||||
unrealized_return_b = (
|
||||
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
|
||||
) * (1 if pos["side_b"] == "BUY" else -1)
|
||||
|
||||
# Store outstanding position for symbol A
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO outstanding_positions (
|
||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pos["pair"],
|
||||
pos["symbol_a"],
|
||||
position_qty_a,
|
||||
pos["current_px_a"],
|
||||
unrealized_return_a,
|
||||
pos["open_px_a"],
|
||||
pos["side_a"],
|
||||
),
|
||||
)
|
||||
|
||||
# Store outstanding position for symbol B
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO outstanding_positions (
|
||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pos["pair"],
|
||||
pos["symbol_b"],
|
||||
position_qty_b,
|
||||
pos["current_px_b"],
|
||||
unrealized_return_b,
|
||||
pos["open_px_b"],
|
||||
pos["side_b"],
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error storing results in database: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class BacktestResult:
|
||||
"""
|
||||
@ -437,6 +206,7 @@ class BacktestResult:
|
||||
self.trades: Dict[str, Dict[str, Any]] = {}
|
||||
self.total_realized_pnl = 0.0
|
||||
self.outstanding_positions: List[Dict[str, Any]] = []
|
||||
self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
def add_trade(
|
||||
self,
|
||||
@ -458,15 +228,16 @@ class BacktestResult:
|
||||
if symbol not in self.trades[pair_nm]:
|
||||
self.trades[pair_nm][symbol] = []
|
||||
self.trades[pair_nm][symbol].append(
|
||||
{"symbol":symbol,
|
||||
"side":side,
|
||||
"action":action,
|
||||
"price":price,
|
||||
"disequilibrium":disequilibrium,
|
||||
"scaled_disequilibrium":scaled_disequilibrium,
|
||||
"timestamp":timestamp,
|
||||
"status":status
|
||||
}
|
||||
{
|
||||
"symbol": symbol,
|
||||
"side": side,
|
||||
"action": action,
|
||||
"price": price,
|
||||
"disequilibrium": disequilibrium,
|
||||
"scaled_disequilibrium": scaled_disequilibrium,
|
||||
"timestamp": timestamp,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
def add_outstanding_position(self, position: Dict[str, Any]) -> None:
|
||||
@ -549,97 +320,126 @@ class BacktestResult:
|
||||
|
||||
def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""Calculate and print returns by day and pair."""
|
||||
def _symbol_return(trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float) -> float:
|
||||
if trade1_side == "BUY" and trade2_side == "SELL":
|
||||
return (trade2_px - trade1_px) / trade1_px * 100
|
||||
elif trade1_side == "SELL" and trade2_side == "BUY":
|
||||
return (trade1_px - trade2_px) / trade1_px * 100
|
||||
else:
|
||||
return 0
|
||||
|
||||
print("\n====== Returns By Day and Pair ======")
|
||||
|
||||
trades = []
|
||||
for filename, data in all_results.items():
|
||||
day_return = 0
|
||||
pairs = list(data["trades"].keys())
|
||||
for pair in pairs:
|
||||
self.pairs_trades_[pair] = []
|
||||
trades_dict = data["trades"][pair]
|
||||
for symbol in trades_dict.keys():
|
||||
trades.extend(trades_dict[symbol])
|
||||
trades = sorted(trades, key=lambda x: (x["timestamp"], x["symbol"]))
|
||||
|
||||
print(f"\n--- {filename} ---")
|
||||
|
||||
self.outstanding_positions = data["outstanding_positions"]
|
||||
|
||||
# Process each pair
|
||||
for pair, symbols in data["trades"].items():
|
||||
pair_return = 0
|
||||
pair_trades = []
|
||||
day_return = 0.0
|
||||
for idx in range(0, len(trades), 4):
|
||||
symbol_a = trades[idx]["symbol"]
|
||||
trade_a_1 = trades[idx]
|
||||
trade_a_2 = trades[idx + 2]
|
||||
|
||||
# Calculate individual symbol returns in the pair
|
||||
for symbol, trades in symbols.items():
|
||||
if len(trades) == 0:
|
||||
continue
|
||||
symbol_return = 0
|
||||
symbol_trades = [trade for trade in trades if trade["symbol"] == symbol]
|
||||
symbol_b = trades[idx + 1]["symbol"]
|
||||
trade_b_1 = trades[idx + 1]
|
||||
trade_b_2 = trades[idx + 3]
|
||||
|
||||
# Calculate returns for all trade combinations
|
||||
for idx in range(0, len(symbol_trades), 2):
|
||||
trade1 = trades[idx]
|
||||
trade2 = trades[idx + 1]
|
||||
symbol_return = 0
|
||||
assert (
|
||||
trade_a_1["timestamp"] < trade_a_2["timestamp"]
|
||||
), f"Trade 1: {trade_a_1['timestamp']} is not less than Trade 2: {trade_a_2['timestamp']}"
|
||||
assert (
|
||||
trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"
|
||||
), f"Trade 1: {trade_a_1['action']} and Trade 2: {trade_a_2['action']} are the same"
|
||||
|
||||
assert trade1["timestamp"] < trade2["timestamp"], f"Trade 1: {trade1['timestamp']} is not less than Trade 2: {trade2['timestamp']}"
|
||||
assert trade1["action"] == "OPEN" and trade2["action"] == "CLOSE", f"Trade 1: {trade1['action']} and Trade 2: {trade2['action']} are the same"
|
||||
# Calculate return based on action combination
|
||||
trade_return = 0
|
||||
symbol_a_return = _symbol_return(trade_a_1["side"], trade_a_1["price"], trade_a_2["side"], trade_a_2["price"])
|
||||
symbol_b_return = _symbol_return(trade_b_1["side"], trade_b_1["price"], trade_b_2["side"], trade_b_2["price"])
|
||||
|
||||
# Calculate return based on action combination
|
||||
trade_return = 0
|
||||
if trade1["side"] == "BUY" and trade2["side"] == "SELL":
|
||||
# Long position
|
||||
trade_return = (trade2["price"] - trade1["price"]) / trade1["price"] * 100
|
||||
elif trade1["side"] == "SELL" and trade2["side"] == "BUY":
|
||||
# Short position
|
||||
trade_return = (trade1["price"] - trade2["price"]) / trade1["price"] * 100
|
||||
pair_return = symbol_a_return + symbol_b_return
|
||||
|
||||
symbol_return += trade_return
|
||||
self.pairs_trades_[pair].append(
|
||||
{
|
||||
"symbol": symbol_a,
|
||||
"open_side": trade_a_1["side"],
|
||||
"open_action": trade_a_1["action"],
|
||||
"open_price": trade_a_1["price"],
|
||||
"close_side": trade_a_2["side"],
|
||||
"close_action": trade_a_2["action"],
|
||||
"close_price": trade_a_2["price"],
|
||||
"symbol_return": symbol_a_return,
|
||||
"open_disequilibrium": trade_a_1["disequilibrium"],
|
||||
"open_scaled_disequilibrium": trade_a_1["scaled_disequilibrium"],
|
||||
"close_disequilibrium": trade_a_2["disequilibrium"],
|
||||
"close_scaled_disequilibrium": trade_a_2["scaled_disequilibrium"],
|
||||
"open_time": trade_a_1["timestamp"],
|
||||
"close_time": trade_a_2["timestamp"],
|
||||
"shares": self.config["funding_per_pair"] / 2 / trade_a_1["price"],
|
||||
"is_completed": True,
|
||||
"close_condition": trade_a_2["status"],
|
||||
"pair_return": pair_return
|
||||
}
|
||||
)
|
||||
self.pairs_trades_[pair].append(
|
||||
{
|
||||
"symbol": symbol_b,
|
||||
"open_side": trade_b_1["side"],
|
||||
"open_action": trade_b_1["action"],
|
||||
"open_price": trade_b_1["price"],
|
||||
"close_side": trade_b_2["side"],
|
||||
"close_action": trade_b_2["action"],
|
||||
"close_price": trade_b_2["price"],
|
||||
"symbol_return": symbol_b_return,
|
||||
"open_disequilibrium": trade_b_1["disequilibrium"],
|
||||
"open_scaled_disequilibrium": trade_b_1["scaled_disequilibrium"],
|
||||
"close_disequilibrium": trade_b_2["disequilibrium"],
|
||||
"close_scaled_disequilibrium": trade_b_2["scaled_disequilibrium"],
|
||||
"open_time": trade_b_1["timestamp"],
|
||||
"close_time": trade_b_2["timestamp"],
|
||||
"shares": self.config["funding_per_pair"] / 2 / trade_b_1["price"],
|
||||
"is_completed": True,
|
||||
"close_condition": trade_b_2["status"],
|
||||
"pair_return": pair_return
|
||||
}
|
||||
)
|
||||
|
||||
# Store trade details for reporting
|
||||
pair_trades.append(
|
||||
(
|
||||
symbol,
|
||||
trade1["timestamp"],
|
||||
trade2["timestamp"],
|
||||
trade1["side"],
|
||||
trade1["price"],
|
||||
trade2["side"],
|
||||
trade2["price"],
|
||||
trade_return,
|
||||
trade1["scaled_disequilibrium"],
|
||||
trade2["scaled_disequilibrium"],
|
||||
f"{idx + 1}", # Trade sequence number
|
||||
)
|
||||
)
|
||||
|
||||
pair_return += symbol_return
|
||||
# Print pair returns with disequilibrium information
|
||||
day_return = 0.0
|
||||
if self.pairs_trades_[pair]:
|
||||
|
||||
# Print pair returns with disequilibrium information
|
||||
if pair_trades:
|
||||
print(f" {pair}:")
|
||||
for (
|
||||
symbol,
|
||||
trade1["timestamp"],
|
||||
trade2["timestamp"],
|
||||
trade1["side"],
|
||||
trade1["price"],
|
||||
trade2["side"],
|
||||
trade2["price"],
|
||||
trade_return,
|
||||
trade1["scaled_disequilibrium"],
|
||||
trade2["scaled_disequilibrium"],
|
||||
trade_num,
|
||||
) in pair_trades:
|
||||
disequil_info = ""
|
||||
if (
|
||||
trade1["scaled_disequilibrium"] is not None
|
||||
and trade2["scaled_disequilibrium"] is not None
|
||||
):
|
||||
disequil_info = f" | Open Dis-eq: {trade1["scaled_disequilibrium"]:.2f},"
|
||||
f" Close Dis-eq: {trade2["scaled_disequilibrium"]:.2f}"
|
||||
print(f"{pair}:")
|
||||
pair_return = 0.0
|
||||
for trd in self.pairs_trades_[pair]:
|
||||
disequil_info = ""
|
||||
if (
|
||||
trd["open_scaled_disequilibrium"] is not None
|
||||
and trd["open_scaled_disequilibrium"] is not None
|
||||
):
|
||||
disequil_info = f" | Open Dis-eq: {trd['open_scaled_disequilibrium']:.2f},"
|
||||
f" Close Dis-eq: {trd['open_scaled_disequilibrium']:.2f}"
|
||||
|
||||
print(
|
||||
f" {trade2['timestamp'].time()} {symbol} (Trade #{trade_num}):"
|
||||
f" {trade1["side"]} @ ${trade1["price"]:.2f},"
|
||||
f" {trade2["side"]} @ ${trade2["price"]:.2f},"
|
||||
f" Return: {trade_return:.2f}%{disequil_info}"
|
||||
)
|
||||
print(f" Pair Total Return: {pair_return:.2f}%")
|
||||
day_return += pair_return
|
||||
print(
|
||||
f" {trd['open_time'].time()} {trd['symbol']}: "
|
||||
f" {trd['open_side']} @ ${trd['open_price']:.2f},"
|
||||
f" {trd["close_side"]} @ ${trd["close_price"]:.2f},"
|
||||
f" Return: {trd['symbol_return']:.2f}%{disequil_info}"
|
||||
)
|
||||
pair_return += trd["symbol_return"]
|
||||
|
||||
print(f" Pair Total Return: {pair_return:.2f}%")
|
||||
day_return += pair_return
|
||||
|
||||
# Print day total return and add to global realized PnL
|
||||
if day_return != 0:
|
||||
@ -811,3 +611,132 @@ class BacktestResult:
|
||||
)
|
||||
|
||||
return current_value_a, current_value_b, total_current_value
|
||||
|
||||
def store_results_in_database(
|
||||
self, db_path: str, datafile: str
|
||||
) -> None:
|
||||
"""
|
||||
Store backtest results in the SQLite database.
|
||||
"""
|
||||
if db_path.upper() == "NONE":
|
||||
return
|
||||
|
||||
try:
|
||||
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
|
||||
filename = os.path.basename(datafile)
|
||||
date_str = filename.split(".")[0] # Extract date part
|
||||
|
||||
# Convert to proper date format
|
||||
try:
|
||||
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
|
||||
except ValueError:
|
||||
# If date parsing fails, use current date
|
||||
date_obj = datetime.now().date()
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Process each trade from bt_result
|
||||
trades = self.get_trades()
|
||||
|
||||
for pair_name, _ in trades.items():
|
||||
|
||||
# Second pass: insert completed trade records into database
|
||||
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]):
|
||||
# Only store completed trades in pt_bt_results table
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO pt_bt_results (
|
||||
date, pair, symbol, open_time, open_side, open_price,
|
||||
open_quantity, open_disequilibrium, close_time, close_side,
|
||||
close_price, close_quantity, close_disequilibrium,
|
||||
symbol_return, pair_return, close_condition
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pair_name,
|
||||
trade_pair["symbol"],
|
||||
trade_pair["open_time"],
|
||||
trade_pair["open_side"],
|
||||
trade_pair["open_price"],
|
||||
trade_pair["shares"],
|
||||
trade_pair["open_scaled_disequilibrium"],
|
||||
trade_pair["close_time"],
|
||||
trade_pair["close_side"],
|
||||
trade_pair["close_price"],
|
||||
trade_pair["shares"],
|
||||
trade_pair["close_scaled_disequilibrium"],
|
||||
trade_pair["symbol_return"],
|
||||
trade_pair["pair_return"],
|
||||
trade_pair["close_condition"]
|
||||
),
|
||||
)
|
||||
|
||||
# Store outstanding positions in separate table
|
||||
outstanding_positions = self.get_outstanding_positions()
|
||||
for pos in outstanding_positions:
|
||||
# Calculate position quantity (negative for SELL positions)
|
||||
position_qty_a = (
|
||||
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
|
||||
)
|
||||
position_qty_b = (
|
||||
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
|
||||
)
|
||||
|
||||
# Calculate unrealized returns
|
||||
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
|
||||
unrealized_return_a = (
|
||||
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
|
||||
) * (1 if pos["side_a"] == "BUY" else -1)
|
||||
unrealized_return_b = (
|
||||
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
|
||||
) * (1 if pos["side_b"] == "BUY" else -1)
|
||||
|
||||
# Store outstanding position for symbol A
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO outstanding_positions (
|
||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pos["pair"],
|
||||
pos["symbol_a"],
|
||||
position_qty_a,
|
||||
pos["current_px_a"],
|
||||
unrealized_return_a,
|
||||
pos["open_px_a"],
|
||||
pos["side_a"],
|
||||
),
|
||||
)
|
||||
|
||||
# Store outstanding position for symbol B
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO outstanding_positions (
|
||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
date_obj,
|
||||
pos["pair"],
|
||||
pos["symbol_b"],
|
||||
position_qty_b,
|
||||
pos["current_px_b"],
|
||||
unrealized_return_b,
|
||||
pos["open_px_b"],
|
||||
pos["side_b"],
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error storing results in database: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from typing import Dict, List, cast
|
||||
import pandas as pd
|
||||
|
||||
|
||||
|
||||
def load_sqlite_to_dataframe(db_path, query):
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
@ -35,20 +36,27 @@ def convert_time_to_UTC(value: str, timezone: str) -> str:
|
||||
return result.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
|
||||
from tools.data_loader import load_sqlite_to_dataframe
|
||||
def load_market_data(
|
||||
datafile: str,
|
||||
instruments: List[Dict[str, str]],
|
||||
db_table_name: str,
|
||||
trading_hours: Dict = {},
|
||||
) -> pd.DataFrame:
|
||||
|
||||
instrument_ids = [
|
||||
'"' + config["instrument_id_pfx"] + instrument + '"'
|
||||
for instrument in config["instruments"]
|
||||
insts = [
|
||||
'"' + instrument["instrument_id_pfx"] + instrument["symbol"] + '"'
|
||||
for instrument in instruments
|
||||
]
|
||||
exchange_id = config["exchange_id"]
|
||||
instrument_ids = list(set(insts))
|
||||
exchange_ids = list(
|
||||
set(['"' + instrument["exchange_id"] + '"' for instrument in instruments])
|
||||
)
|
||||
|
||||
query = "select"
|
||||
query += " tstamp"
|
||||
query += ", tstamp_ns as time_ns"
|
||||
|
||||
query += f", substr(instrument_id, {len(config['instrument_id_pfx']) + 1}) as symbol"
|
||||
query += f", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
|
||||
query += ", open"
|
||||
query += ", high"
|
||||
query += ", low"
|
||||
@ -57,74 +65,76 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
|
||||
query += ", num_trades"
|
||||
query += ", vwap"
|
||||
|
||||
query += f" from {config['db_table_name']}"
|
||||
query += f" where exchange_id ='{exchange_id}'"
|
||||
query += f" from {db_table_name}"
|
||||
query += f" where exchange_id in ({','.join(exchange_ids)})"
|
||||
query += f" and instrument_id in ({','.join(instrument_ids)})"
|
||||
|
||||
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
|
||||
|
||||
# Trading Hours
|
||||
date_str = df["tstamp"][0][0:10]
|
||||
trading_hours = config["trading_hours"]
|
||||
if len(df) > 0 and len(trading_hours) > 0:
|
||||
date_str = df["tstamp"][0][0:10]
|
||||
|
||||
start_time = convert_time_to_UTC(
|
||||
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
|
||||
)
|
||||
end_time = convert_time_to_UTC(
|
||||
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"]
|
||||
)
|
||||
start_time = convert_time_to_UTC(
|
||||
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
|
||||
)
|
||||
end_time = convert_time_to_UTC(
|
||||
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"]
|
||||
)
|
||||
|
||||
# Perform boolean selection
|
||||
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
|
||||
df["tstamp"] = pd.to_datetime(df["tstamp"])
|
||||
# Perform boolean selection
|
||||
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
|
||||
df["tstamp"] = pd.to_datetime(df["tstamp"])
|
||||
|
||||
return cast(pd.DataFrame, df)
|
||||
|
||||
|
||||
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
|
||||
"""
|
||||
Auto-detect available instruments from the database by querying distinct instrument_id values.
|
||||
Returns instruments without the configured prefix.
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(datafile)
|
||||
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
|
||||
# """
|
||||
# Auto-detect available instruments from the database by querying distinct instrument_id values.
|
||||
# Returns instruments without the configured prefix.
|
||||
# """
|
||||
# try:
|
||||
# conn = sqlite3.connect(datafile)
|
||||
|
||||
# Build exclusion list with full instrument_ids
|
||||
exclude_instruments = config.get("exclude_instruments", [])
|
||||
prefix = config.get("instrument_id_pfx", "")
|
||||
exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
|
||||
# # Build exclusion list with full instrument_ids
|
||||
# exclude_instruments = config.get("exclude_instruments", [])
|
||||
# prefix = config.get("instrument_id_pfx", "")
|
||||
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
|
||||
|
||||
# Query to get distinct instrument_ids
|
||||
query = f"""
|
||||
SELECT DISTINCT instrument_id
|
||||
FROM {config['db_table_name']}
|
||||
WHERE exchange_id = ?
|
||||
"""
|
||||
# # Query to get distinct instrument_ids
|
||||
# query = f"""
|
||||
# SELECT DISTINCT instrument_id
|
||||
# FROM {config['db_table_name']}
|
||||
# WHERE exchange_id = ?
|
||||
# """
|
||||
|
||||
# Add exclusion clause if there are instruments to exclude
|
||||
if exclude_instrument_ids:
|
||||
placeholders = ','.join(['?' for _ in exclude_instrument_ids])
|
||||
query += f" AND instrument_id NOT IN ({placeholders})"
|
||||
cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids))
|
||||
else:
|
||||
cursor = conn.execute(query, (config["exchange_id"],))
|
||||
instrument_ids = [row[0] for row in cursor.fetchall()]
|
||||
conn.close()
|
||||
# # Add exclusion clause if there are instruments to exclude
|
||||
# if exclude_instrument_ids:
|
||||
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
|
||||
# query += f" AND instrument_id NOT IN ({placeholders})"
|
||||
# cursor = conn.execute(
|
||||
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
|
||||
# )
|
||||
# else:
|
||||
# cursor = conn.execute(query, (config["exchange_id"],))
|
||||
# instrument_ids = [row[0] for row in cursor.fetchall()]
|
||||
# conn.close()
|
||||
|
||||
# Remove the configured prefix to get instrument symbols
|
||||
instruments = []
|
||||
for instrument_id in instrument_ids:
|
||||
if instrument_id.startswith(prefix):
|
||||
symbol = instrument_id[len(prefix) :]
|
||||
instruments.append(symbol)
|
||||
else:
|
||||
instruments.append(instrument_id)
|
||||
# # Remove the configured prefix to get instrument symbols
|
||||
# instruments = []
|
||||
# for instrument_id in instrument_ids:
|
||||
# if instrument_id.startswith(prefix):
|
||||
# symbol = instrument_id[len(prefix) :]
|
||||
# instruments.append(symbol)
|
||||
# else:
|
||||
# instruments.append(instrument_id)
|
||||
|
||||
return sorted(instruments)
|
||||
# return sorted(instruments)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
|
||||
return []
|
||||
# except Exception as e:
|
||||
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
|
||||
# return []
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
from tools.config import expand_filename, load_config
|
||||
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
||||
from tools.data_loader import get_available_instruments_from_db
|
||||
from pt_trading.results import (
|
||||
BacktestResult,
|
||||
create_result_database,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -9,62 +9,62 @@ import pandas as pd
|
||||
|
||||
from research.research_tools import create_pairs
|
||||
from tools.config import expand_filename, load_config
|
||||
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
||||
from pt_trading.results import (
|
||||
BacktestResult,
|
||||
create_result_database,
|
||||
store_config_in_database,
|
||||
store_results_in_database,
|
||||
)
|
||||
from pt_trading.fit_method import PairsTradingFitMethod
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
|
||||
|
||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Resolve the list of data files to process.
|
||||
CLI datafiles take priority over config datafiles.
|
||||
Supports wildcards in config but not in CLI.
|
||||
"""
|
||||
if cli_datafiles:
|
||||
# CLI override - comma-separated list, no wildcards
|
||||
datafiles = [f.strip() for f in cli_datafiles.split(",")]
|
||||
# Make paths absolute relative to data directory
|
||||
data_dir = config.get("data_directory", "./data")
|
||||
resolved_files = []
|
||||
for df in datafiles:
|
||||
if not os.path.isabs(df):
|
||||
df = os.path.join(data_dir, df)
|
||||
resolved_files.append(df)
|
||||
return resolved_files
|
||||
|
||||
# Use config datafiles with wildcard support
|
||||
config_datafiles = config.get("datafiles", [])
|
||||
data_dir = config.get("data_directory", "./data")
|
||||
def resolve_datafiles(
|
||||
config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
|
||||
) -> List[str]:
|
||||
resolved_files = []
|
||||
|
||||
for pattern in config_datafiles:
|
||||
for inst in instruments:
|
||||
pattern = date_pattern
|
||||
inst_type = inst["instrument_type"]
|
||||
data_dir = config["market_data_loading"][inst_type]["data_directory"]
|
||||
if "*" in pattern or "?" in pattern:
|
||||
# Handle wildcards
|
||||
if not os.path.isabs(pattern):
|
||||
pattern = os.path.join(data_dir, pattern)
|
||||
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
|
||||
matched_files = glob.glob(pattern)
|
||||
resolved_files.extend(matched_files)
|
||||
else:
|
||||
# Handle explicit file path
|
||||
if not os.path.isabs(pattern):
|
||||
pattern = os.path.join(data_dir, pattern)
|
||||
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
|
||||
resolved_files.append(pattern)
|
||||
|
||||
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
||||
|
||||
|
||||
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
|
||||
|
||||
instruments = [
|
||||
{
|
||||
"symbol": inst.split(":")[0],
|
||||
"instrument_type": inst.split(":")[1],
|
||||
"exchange_id": inst.split(":")[2],
|
||||
"instrument_id_pfx": config["market_data_loading"][inst.split(":")[1]][
|
||||
"instrument_id_pfx"
|
||||
],
|
||||
"db_table_name": config["market_data_loading"][inst.split(":")[1]][
|
||||
"db_table_name"
|
||||
],
|
||||
}
|
||||
for inst in args.instruments.split(",")
|
||||
]
|
||||
return instruments
|
||||
|
||||
|
||||
def run_backtest(
|
||||
config: Dict,
|
||||
datafile: str,
|
||||
price_column: str,
|
||||
fit_method: PairsTradingFitMethod,
|
||||
instruments: List[str],
|
||||
instruments: List[Dict[str, str]],
|
||||
) -> BacktestResult:
|
||||
"""
|
||||
Run backtest for all pairs using the specified instruments.
|
||||
@ -72,13 +72,14 @@ def run_backtest(
|
||||
bt_result: BacktestResult = BacktestResult(config=config)
|
||||
|
||||
pairs_trades = []
|
||||
for pair in create_pairs(
|
||||
pairs = create_pairs(
|
||||
datafile=datafile,
|
||||
fit_method=fit_method,
|
||||
price_column=price_column,
|
||||
config=config,
|
||||
instruments=instruments,
|
||||
):
|
||||
)
|
||||
for pair in pairs:
|
||||
single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
|
||||
if single_pair_trades is not None and len(single_pair_trades) > 0:
|
||||
pairs_trades.append(single_pair_trades)
|
||||
@ -98,16 +99,16 @@ def main() -> None:
|
||||
"--config", type=str, required=True, help="Path to the configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datafiles",
|
||||
"--date_pattern",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Comma-separated list of data files (overrides config). No wildcards supported.",
|
||||
required=True,
|
||||
help="Date YYYYMMDD, allows * and ? wildcards",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruments",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
|
||||
required=True,
|
||||
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result_db",
|
||||
@ -128,7 +129,8 @@ def main() -> None:
|
||||
fit_method = getattr(module, class_name)()
|
||||
|
||||
# Resolve data files (CLI takes priority over config)
|
||||
datafiles = resolve_datafiles(config, args.datafiles)
|
||||
instruments = get_instruments(args, config)
|
||||
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
|
||||
|
||||
if not datafiles:
|
||||
print("No data files found to process.")
|
||||
@ -149,18 +151,8 @@ def main() -> None:
|
||||
# Store configuration in database for reference
|
||||
if args.result_db.upper() != "NONE":
|
||||
# Get list of all instruments for storage
|
||||
all_instruments = []
|
||||
for datafile in datafiles:
|
||||
if args.instruments:
|
||||
file_instruments = [
|
||||
inst.strip() for inst in args.instruments.split(",")
|
||||
]
|
||||
else:
|
||||
file_instruments = get_available_instruments_from_db(datafile, config)
|
||||
all_instruments.extend(file_instruments)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_instruments = list(dict.fromkeys(all_instruments))
|
||||
|
||||
store_config_in_database(
|
||||
db_path=args.result_db,
|
||||
@ -168,7 +160,7 @@ def main() -> None:
|
||||
config=config,
|
||||
fit_method_class=fit_method_class_name,
|
||||
datafiles=datafiles,
|
||||
instruments=unique_instruments,
|
||||
instruments=instruments,
|
||||
)
|
||||
|
||||
# Process each data file
|
||||
@ -177,20 +169,6 @@ def main() -> None:
|
||||
for datafile in datafiles:
|
||||
print(f"\n====== Processing {os.path.basename(datafile)} ======")
|
||||
|
||||
# Determine instruments to use
|
||||
if args.instruments:
|
||||
# Use CLI-specified instruments
|
||||
instruments = [inst.strip() for inst in args.instruments.split(",")]
|
||||
print(f"Using CLI-specified instruments: {instruments}")
|
||||
else:
|
||||
# Auto-detect instruments from database
|
||||
instruments = get_available_instruments_from_db(datafile, config)
|
||||
print(f"Auto-detected instruments: {instruments}")
|
||||
|
||||
if not instruments:
|
||||
print(f"No instruments found for {datafile}, skipping...")
|
||||
continue
|
||||
|
||||
# Process data for this file
|
||||
try:
|
||||
fit_method.reset()
|
||||
@ -212,7 +190,15 @@ def main() -> None:
|
||||
|
||||
# Store results in database
|
||||
if args.result_db.upper() != "NONE":
|
||||
store_results_in_database(args.result_db, datafile, bt_results)
|
||||
bt_results.calculate_returns(
|
||||
{
|
||||
filename: {
|
||||
"trades": bt_results.trades.copy(),
|
||||
"outstanding_positions": bt_results.outstanding_positions.copy(),
|
||||
}
|
||||
}
|
||||
)
|
||||
bt_results.store_results_in_database(args.result_db, datafile)
|
||||
|
||||
print(f"Successfully processed {filename}")
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
||||
from pt_trading.fit_method import PairsTradingFitMethod
|
||||
|
||||
|
||||
|
||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Resolve the list of data files to process.
|
||||
@ -44,28 +43,41 @@ def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List
|
||||
|
||||
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
||||
|
||||
def create_pairs(datafile: str, fit_method: PairsTradingFitMethod, price_column: str, config: Dict, instruments: List[str]) -> List:
|
||||
from tools.data_loader import load_market_data
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
all_indexes = range(len(instruments))
|
||||
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
||||
pairs = []
|
||||
|
||||
# Update config to use the specified instruments
|
||||
config_copy = config.copy()
|
||||
config_copy["instruments"] = instruments
|
||||
def create_pairs(
|
||||
datafile: str,
|
||||
fit_method: PairsTradingFitMethod,
|
||||
price_column: str,
|
||||
config: Dict,
|
||||
instruments: List[Dict[str, str]],
|
||||
) -> List:
|
||||
from tools.data_loader import load_market_data
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
|
||||
market_data_df = load_market_data(datafile, config=config_copy)
|
||||
all_indexes = range(len(instruments))
|
||||
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
||||
pairs = []
|
||||
|
||||
for a_index, b_index in unique_index_pairs:
|
||||
from research.pt_backtest import TradingPair
|
||||
pair = fit_method.create_trading_pair(
|
||||
config=config_copy,
|
||||
market_data=market_data_df,
|
||||
symbol_a=instruments[a_index],
|
||||
symbol_b=instruments[b_index],
|
||||
price_column=price_column,
|
||||
)
|
||||
pairs.append(pair)
|
||||
return pairs
|
||||
# Update config to use the specified instruments
|
||||
config_copy = config.copy()
|
||||
config_copy["instruments"] = instruments
|
||||
|
||||
market_data_df = load_market_data(
|
||||
datafile=datafile,
|
||||
instruments=instruments,
|
||||
db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"],
|
||||
trading_hours=config_copy["trading_hours"],
|
||||
)
|
||||
|
||||
for a_index, b_index in unique_index_pairs:
|
||||
from research.pt_backtest import TradingPair
|
||||
|
||||
pair = fit_method.create_trading_pair(
|
||||
config=config_copy,
|
||||
market_data=market_data_df,
|
||||
symbol_a=instruments[a_index]["symbol"],
|
||||
symbol_b=instruments[b_index]["symbol"],
|
||||
price_column=price_column,
|
||||
)
|
||||
pairs.append(pair)
|
||||
return pairs
|
||||
|
||||
Binary file not shown.
@ -20,8 +20,6 @@ from pt_trading.fit_methods import PairsTradingFitMethod
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
|
||||
|
||||
|
||||
|
||||
def run_strategy(
|
||||
config: Dict,
|
||||
datafile: str,
|
||||
@ -44,7 +42,14 @@ def run_strategy(
|
||||
config_copy = config.copy()
|
||||
config_copy["instruments"] = instruments
|
||||
|
||||
market_data_df = load_market_data(datafile, config=config_copy)
|
||||
market_data_df = load_market_data(
|
||||
datafile=datafile,
|
||||
exchange_id=config_copy["exchange_id"],
|
||||
instruments=config_copy["instruments"],
|
||||
instrument_id_pfx=config_copy["instrument_id_pfx"],
|
||||
db_table_name=config_copy["db_table_name"],
|
||||
trading_hours=config_copy["trading_hours"],
|
||||
)
|
||||
|
||||
for a_index, b_index in unique_index_pairs:
|
||||
pair = fit_method.create_trading_pair(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user