progress and result.py fixes
This commit is contained in:
parent
577fb5c109
commit
e30b0df4db
@ -1,21 +1,13 @@
|
|||||||
{
|
{
|
||||||
"instrument_type_specifics": {
|
"market_data_loading": {
|
||||||
"CRYPTO": {
|
"CRYPTO": {
|
||||||
"data_directory": "./data/crypto",
|
"data_directory": "./data/crypto",
|
||||||
"datafiles": [
|
|
||||||
"20250602.mktdata.ohlcv.db"
|
|
||||||
],
|
|
||||||
"db_table_name": "md_1min_bars",
|
"db_table_name": "md_1min_bars",
|
||||||
"exchange_id": "BNBSPOT",
|
|
||||||
"instrument_id_pfx": "PAIR-",
|
"instrument_id_pfx": "PAIR-",
|
||||||
},
|
},
|
||||||
"EQUITY": {
|
"EQUITY": {
|
||||||
"data_directory": "./data/equity",
|
"data_directory": "./data/equity",
|
||||||
"datafiles": [
|
|
||||||
"20250602.mktdata.ohlcv.db"
|
|
||||||
],
|
|
||||||
"db_table_name": "md_1min_bars",
|
"db_table_name": "md_1min_bars",
|
||||||
"exchange_id": "BNBSPOT",
|
|
||||||
"instrument_id_pfx": "STOCK-",
|
"instrument_id_pfx": "STOCK-",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,212 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Dict, Optional, cast
|
|
||||||
|
|
||||||
import pandas as pd # type: ignore[import]
|
|
||||||
from pt_trading.results import BacktestResult
|
|
||||||
from pt_trading.trading_pair import TradingPair
|
|
||||||
from pt_trading.fit_method import PairsTradingFitMethod
|
|
||||||
|
|
||||||
NanoPerMin = 1e9
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class StaticFit(PairsTradingFitMethod):
|
|
||||||
|
|
||||||
def run_pair(
|
|
||||||
self, pair: TradingPair, bt_result: BacktestResult
|
|
||||||
) -> Optional[pd.DataFrame]: # abstractmethod
|
|
||||||
config = pair.config_
|
|
||||||
pair.get_datasets(training_minutes=config["training_minutes"])
|
|
||||||
|
|
||||||
try:
|
|
||||||
pair.predict()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"{pair}: Prediction failed: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
pair_trades = self.create_trading_signals(
|
|
||||||
pair=pair, config=config, result=bt_result
|
|
||||||
)
|
|
||||||
|
|
||||||
return pair_trades
|
|
||||||
|
|
||||||
def create_trading_signals(
|
|
||||||
self, pair: TradingPair, config: Dict, result: BacktestResult
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
beta = pair.vecm_fit_.beta # type: ignore
|
|
||||||
colname_a, colname_b = pair.colnames()
|
|
||||||
|
|
||||||
predicted_df = pair.predicted_df_
|
|
||||||
if predicted_df is None:
|
|
||||||
# Return empty DataFrame with correct columns and dtypes
|
|
||||||
return pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
|
|
||||||
"time": "datetime64[ns]",
|
|
||||||
"action": "string",
|
|
||||||
"symbol": "string",
|
|
||||||
"price": "float64",
|
|
||||||
"disequilibrium": "float64",
|
|
||||||
"scaled_disequilibrium": "float64",
|
|
||||||
"pair": "object"
|
|
||||||
})
|
|
||||||
|
|
||||||
open_threshold = config["dis-equilibrium_open_trshld"]
|
|
||||||
close_threshold = config["dis-equilibrium_close_trshld"]
|
|
||||||
|
|
||||||
# Iterate through the testing dataset to find the first trading opportunity
|
|
||||||
open_row_index = None
|
|
||||||
for row_idx in range(len(predicted_df)):
|
|
||||||
curr_disequilibrium = predicted_df["scaled_disequilibrium"][row_idx]
|
|
||||||
|
|
||||||
# Check if current row has sufficient disequilibrium (not near-zero)
|
|
||||||
if curr_disequilibrium >= open_threshold:
|
|
||||||
open_row_index = row_idx
|
|
||||||
break
|
|
||||||
|
|
||||||
# If no row with sufficient disequilibrium found, skip this pair
|
|
||||||
if open_row_index is None:
|
|
||||||
print(f"{pair}: Insufficient disequilibrium in testing dataset. Skipping.")
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# Look for close signal starting from the open position
|
|
||||||
trading_signals_df = (
|
|
||||||
predicted_df["scaled_disequilibrium"][open_row_index:] < close_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust indices to account for the offset from open_row_index
|
|
||||||
close_row_index = None
|
|
||||||
for idx, value in trading_signals_df.items():
|
|
||||||
if value:
|
|
||||||
close_row_index = idx
|
|
||||||
break
|
|
||||||
|
|
||||||
open_row = predicted_df.loc[open_row_index]
|
|
||||||
open_px_a = predicted_df.at[open_row_index, f"{colname_a}"]
|
|
||||||
open_px_b = predicted_df.at[open_row_index, f"{colname_b}"]
|
|
||||||
open_tstamp = predicted_df.at[open_row_index, "tstamp"]
|
|
||||||
open_disequilibrium = open_row["disequilibrium"]
|
|
||||||
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
|
|
||||||
|
|
||||||
abs_beta = abs(beta[1])
|
|
||||||
pred_px_b = predicted_df.loc[open_row_index][f"{colname_b}_pred"]
|
|
||||||
pred_px_a = predicted_df.loc[open_row_index][f"{colname_a}_pred"]
|
|
||||||
|
|
||||||
if pred_px_b * abs_beta - pred_px_a > 0:
|
|
||||||
open_side_a = "BUY"
|
|
||||||
open_side_b = "SELL"
|
|
||||||
close_side_a = "SELL"
|
|
||||||
close_side_b = "BUY"
|
|
||||||
else:
|
|
||||||
open_side_b = "BUY"
|
|
||||||
open_side_a = "SELL"
|
|
||||||
close_side_b = "SELL"
|
|
||||||
close_side_a = "BUY"
|
|
||||||
|
|
||||||
# If no close signal found, print position and unrealized PnL
|
|
||||||
if close_row_index is None:
|
|
||||||
|
|
||||||
last_row_index = len(predicted_df) - 1
|
|
||||||
|
|
||||||
# Use the new method from BacktestResult to handle outstanding positions
|
|
||||||
result.handle_outstanding_position(
|
|
||||||
pair=pair,
|
|
||||||
pair_result_df=predicted_df,
|
|
||||||
last_row_index=last_row_index,
|
|
||||||
open_side_a=open_side_a,
|
|
||||||
open_side_b=open_side_b,
|
|
||||||
open_px_a=float(open_px_a),
|
|
||||||
open_px_b=float(open_px_b),
|
|
||||||
open_tstamp=pd.Timestamp(open_tstamp),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return only open trades (no close trades)
|
|
||||||
trd_signal_tuples = [
|
|
||||||
(
|
|
||||||
open_tstamp,
|
|
||||||
open_side_a,
|
|
||||||
pair.symbol_a_,
|
|
||||||
open_px_a,
|
|
||||||
open_disequilibrium,
|
|
||||||
open_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
open_tstamp,
|
|
||||||
open_side_b,
|
|
||||||
pair.symbol_b_,
|
|
||||||
open_px_b,
|
|
||||||
open_disequilibrium,
|
|
||||||
open_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Close signal found - create complete trade
|
|
||||||
close_row = predicted_df.loc[close_row_index]
|
|
||||||
close_tstamp = close_row["tstamp"]
|
|
||||||
close_disequilibrium = close_row["disequilibrium"]
|
|
||||||
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
|
|
||||||
close_px_a = close_row[f"{colname_a}"]
|
|
||||||
close_px_b = close_row[f"{colname_b}"]
|
|
||||||
|
|
||||||
print(f"{pair}: Close signal found at index {close_row_index}")
|
|
||||||
|
|
||||||
trd_signal_tuples = [
|
|
||||||
(
|
|
||||||
open_tstamp,
|
|
||||||
open_side_a,
|
|
||||||
pair.symbol_a_,
|
|
||||||
open_px_a,
|
|
||||||
open_disequilibrium,
|
|
||||||
open_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
open_tstamp,
|
|
||||||
open_side_b,
|
|
||||||
pair.symbol_b_,
|
|
||||||
open_px_b,
|
|
||||||
open_disequilibrium,
|
|
||||||
open_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
close_tstamp,
|
|
||||||
close_side_a,
|
|
||||||
pair.symbol_a_,
|
|
||||||
close_px_a,
|
|
||||||
close_disequilibrium,
|
|
||||||
close_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
close_tstamp,
|
|
||||||
close_side_b,
|
|
||||||
pair.symbol_b_,
|
|
||||||
close_px_b,
|
|
||||||
close_disequilibrium,
|
|
||||||
close_scaled_disequilibrium,
|
|
||||||
pair,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
|
|
||||||
df = pd.DataFrame(
|
|
||||||
trd_signal_tuples,
|
|
||||||
columns=self.TRADES_COLUMNS,
|
|
||||||
)
|
|
||||||
# Ensure consistent dtypes
|
|
||||||
return df.astype({
|
|
||||||
"time": "datetime64[ns]",
|
|
||||||
"action": "string",
|
|
||||||
"symbol": "string",
|
|
||||||
"price": "float64",
|
|
||||||
"disequilibrium": "float64",
|
|
||||||
"scaled_disequilibrium": "float64",
|
|
||||||
"pair": "object"
|
|
||||||
})
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ def create_result_database(db_path: str) -> None:
|
|||||||
if db_dir and not os.path.exists(db_dir):
|
if db_dir and not os.path.exists(db_dir):
|
||||||
os.makedirs(db_dir, exist_ok=True)
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
print(f"Created directory: {db_dir}")
|
print(f"Created directory: {db_dir}")
|
||||||
|
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@ -68,7 +68,8 @@ def create_result_database(db_path: str) -> None:
|
|||||||
close_quantity INTEGER,
|
close_quantity INTEGER,
|
||||||
close_disequilibrium REAL,
|
close_disequilibrium REAL,
|
||||||
symbol_return REAL,
|
symbol_return REAL,
|
||||||
pair_return REAL
|
pair_return REAL,
|
||||||
|
close_condition TEXT
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -121,7 +122,7 @@ def store_config_in_database(
|
|||||||
config: Dict,
|
config: Dict,
|
||||||
fit_method_class: str,
|
fit_method_class: str,
|
||||||
datafiles: List[str],
|
datafiles: List[str],
|
||||||
instruments: List[str],
|
instruments: List[Dict[str, str]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Store configuration information in the database for reference.
|
Store configuration information in the database for reference.
|
||||||
@ -140,7 +141,12 @@ def store_config_in_database(
|
|||||||
|
|
||||||
# Convert lists to comma-separated strings for storage
|
# Convert lists to comma-separated strings for storage
|
||||||
datafiles_str = ", ".join(datafiles)
|
datafiles_str = ", ".join(datafiles)
|
||||||
instruments_str = ", ".join(instruments)
|
instruments_str = ", ".join(
|
||||||
|
[
|
||||||
|
f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
|
||||||
|
for inst in instruments
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Insert configuration record
|
# Insert configuration record
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@ -170,6 +176,7 @@ def store_config_in_database(
|
|||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
|
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
|
||||||
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
|
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
|
||||||
if timestamp is None:
|
if timestamp is None:
|
||||||
@ -187,244 +194,6 @@ def convert_timestamp(timestamp: Any) -> Optional[datetime]:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
|
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
|
||||||
|
|
||||||
|
|
||||||
def store_results_in_database(
|
|
||||||
db_path: str, datafile: str, bt_result: "BacktestResult"
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Store backtest results in the SQLite database.
|
|
||||||
"""
|
|
||||||
if db_path.upper() == "NONE":
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
|
|
||||||
filename = os.path.basename(datafile)
|
|
||||||
date_str = filename.split(".")[0] # Extract date part
|
|
||||||
|
|
||||||
# Convert to proper date format
|
|
||||||
try:
|
|
||||||
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
|
|
||||||
except ValueError:
|
|
||||||
# If date parsing fails, use current date
|
|
||||||
date_obj = datetime.now().date()
|
|
||||||
|
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Process each trade from bt_result
|
|
||||||
trades = bt_result.get_trades()
|
|
||||||
|
|
||||||
for pair_name, symbols in trades.items():
|
|
||||||
# Calculate pair return for this pair
|
|
||||||
pair_return = 0.0
|
|
||||||
pair_trades = []
|
|
||||||
|
|
||||||
# First pass: collect all trades and calculate returns
|
|
||||||
for symbol, symbol_trades in symbols.items():
|
|
||||||
if len(symbol_trades) == 0: # No trades for this symbol
|
|
||||||
print(
|
|
||||||
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
|
|
||||||
# Handle both old and new tuple formats
|
|
||||||
if len(symbol_trades[0]) == 2: # Old format: (action, price)
|
|
||||||
entry_action, entry_price = symbol_trades[0]
|
|
||||||
exit_action, exit_price = symbol_trades[1]
|
|
||||||
open_disequilibrium = 0.0 # Fallback for old format
|
|
||||||
open_scaled_disequilibrium = 0.0
|
|
||||||
close_disequilibrium = 0.0
|
|
||||||
close_scaled_disequilibrium = 0.0
|
|
||||||
open_time = datetime.now()
|
|
||||||
close_time = datetime.now()
|
|
||||||
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
|
|
||||||
(
|
|
||||||
entry_action,
|
|
||||||
entry_price,
|
|
||||||
open_disequilibrium,
|
|
||||||
open_scaled_disequilibrium,
|
|
||||||
open_time,
|
|
||||||
) = symbol_trades[0]
|
|
||||||
(
|
|
||||||
exit_action,
|
|
||||||
exit_price,
|
|
||||||
close_disequilibrium,
|
|
||||||
close_scaled_disequilibrium,
|
|
||||||
close_time,
|
|
||||||
) = symbol_trades[1]
|
|
||||||
|
|
||||||
# Handle None values
|
|
||||||
open_disequilibrium = (
|
|
||||||
open_disequilibrium
|
|
||||||
if open_disequilibrium is not None
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
open_scaled_disequilibrium = (
|
|
||||||
open_scaled_disequilibrium
|
|
||||||
if open_scaled_disequilibrium is not None
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
close_disequilibrium = (
|
|
||||||
close_disequilibrium
|
|
||||||
if close_disequilibrium is not None
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
close_scaled_disequilibrium = (
|
|
||||||
close_scaled_disequilibrium
|
|
||||||
if close_scaled_disequilibrium is not None
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert pandas Timestamps to Python datetime objects
|
|
||||||
open_time = convert_timestamp(open_time) or datetime.now()
|
|
||||||
close_time = convert_timestamp(close_time) or datetime.now()
|
|
||||||
|
|
||||||
# Calculate actual share quantities based on funding per pair
|
|
||||||
# Split funding equally between the two positions
|
|
||||||
funding_per_position = bt_result.config["funding_per_pair"] / 2
|
|
||||||
shares = funding_per_position / entry_price
|
|
||||||
|
|
||||||
# Calculate symbol return
|
|
||||||
symbol_return = 0.0
|
|
||||||
if entry_action == "BUY" and exit_action == "SELL":
|
|
||||||
symbol_return = (exit_price - entry_price) / entry_price * 100
|
|
||||||
elif entry_action == "SELL" and exit_action == "BUY":
|
|
||||||
symbol_return = (entry_price - exit_price) / entry_price * 100
|
|
||||||
|
|
||||||
pair_return += symbol_return
|
|
||||||
|
|
||||||
pair_trades.append(
|
|
||||||
{
|
|
||||||
"symbol": symbol,
|
|
||||||
"entry_action": entry_action,
|
|
||||||
"entry_price": entry_price,
|
|
||||||
"exit_action": exit_action,
|
|
||||||
"exit_price": exit_price,
|
|
||||||
"symbol_return": symbol_return,
|
|
||||||
"open_disequilibrium": open_disequilibrium,
|
|
||||||
"open_scaled_disequilibrium": open_scaled_disequilibrium,
|
|
||||||
"close_disequilibrium": close_disequilibrium,
|
|
||||||
"close_scaled_disequilibrium": close_scaled_disequilibrium,
|
|
||||||
"open_time": open_time,
|
|
||||||
"close_time": close_time,
|
|
||||||
"shares": shares,
|
|
||||||
"is_completed": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip one-sided trades - they will be handled by outstanding_positions table
|
|
||||||
elif len(symbol_trades) == 1:
|
|
||||||
print(
|
|
||||||
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
else:
|
|
||||||
# This should not happen, but handle unexpected cases
|
|
||||||
print(
|
|
||||||
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Second pass: insert completed trade records into database
|
|
||||||
for trade in pair_trades:
|
|
||||||
# Only store completed trades in pt_bt_results table
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO pt_bt_results (
|
|
||||||
date, pair, symbol, open_time, open_side, open_price,
|
|
||||||
open_quantity, open_disequilibrium, close_time, close_side,
|
|
||||||
close_price, close_quantity, close_disequilibrium,
|
|
||||||
symbol_return, pair_return
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
date_obj,
|
|
||||||
pair_name,
|
|
||||||
trade["symbol"],
|
|
||||||
trade["open_time"],
|
|
||||||
trade["entry_action"],
|
|
||||||
trade["entry_price"],
|
|
||||||
trade["shares"],
|
|
||||||
trade["open_scaled_disequilibrium"],
|
|
||||||
trade["close_time"],
|
|
||||||
trade["exit_action"],
|
|
||||||
trade["exit_price"],
|
|
||||||
trade["shares"],
|
|
||||||
trade["close_scaled_disequilibrium"],
|
|
||||||
trade["symbol_return"],
|
|
||||||
pair_return,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store outstanding positions in separate table
|
|
||||||
outstanding_positions = bt_result.get_outstanding_positions()
|
|
||||||
for pos in outstanding_positions:
|
|
||||||
# Calculate position quantity (negative for SELL positions)
|
|
||||||
position_qty_a = (
|
|
||||||
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
|
|
||||||
)
|
|
||||||
position_qty_b = (
|
|
||||||
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate unrealized returns
|
|
||||||
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
|
|
||||||
unrealized_return_a = (
|
|
||||||
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
|
|
||||||
) * (1 if pos["side_a"] == "BUY" else -1)
|
|
||||||
unrealized_return_b = (
|
|
||||||
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
|
|
||||||
) * (1 if pos["side_b"] == "BUY" else -1)
|
|
||||||
|
|
||||||
# Store outstanding position for symbol A
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO outstanding_positions (
|
|
||||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
date_obj,
|
|
||||||
pos["pair"],
|
|
||||||
pos["symbol_a"],
|
|
||||||
position_qty_a,
|
|
||||||
pos["current_px_a"],
|
|
||||||
unrealized_return_a,
|
|
||||||
pos["open_px_a"],
|
|
||||||
pos["side_a"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store outstanding position for symbol B
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO outstanding_positions (
|
|
||||||
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
date_obj,
|
|
||||||
pos["pair"],
|
|
||||||
pos["symbol_b"],
|
|
||||||
position_qty_b,
|
|
||||||
pos["current_px_b"],
|
|
||||||
unrealized_return_b,
|
|
||||||
pos["open_px_b"],
|
|
||||||
pos["side_b"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error storing results in database: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
class BacktestResult:
|
class BacktestResult:
|
||||||
@ -437,7 +206,8 @@ class BacktestResult:
|
|||||||
self.trades: Dict[str, Dict[str, Any]] = {}
|
self.trades: Dict[str, Dict[str, Any]] = {}
|
||||||
self.total_realized_pnl = 0.0
|
self.total_realized_pnl = 0.0
|
||||||
self.outstanding_positions: List[Dict[str, Any]] = []
|
self.outstanding_positions: List[Dict[str, Any]] = []
|
||||||
|
self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
def add_trade(
|
def add_trade(
|
||||||
self,
|
self,
|
||||||
pair_nm: str,
|
pair_nm: str,
|
||||||
@ -458,15 +228,16 @@ class BacktestResult:
|
|||||||
if symbol not in self.trades[pair_nm]:
|
if symbol not in self.trades[pair_nm]:
|
||||||
self.trades[pair_nm][symbol] = []
|
self.trades[pair_nm][symbol] = []
|
||||||
self.trades[pair_nm][symbol].append(
|
self.trades[pair_nm][symbol].append(
|
||||||
{"symbol":symbol,
|
{
|
||||||
"side":side,
|
"symbol": symbol,
|
||||||
"action":action,
|
"side": side,
|
||||||
"price":price,
|
"action": action,
|
||||||
"disequilibrium":disequilibrium,
|
"price": price,
|
||||||
"scaled_disequilibrium":scaled_disequilibrium,
|
"disequilibrium": disequilibrium,
|
||||||
"timestamp":timestamp,
|
"scaled_disequilibrium": scaled_disequilibrium,
|
||||||
"status":status
|
"timestamp": timestamp,
|
||||||
}
|
"status": status,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_outstanding_position(self, position: Dict[str, Any]) -> None:
|
def add_outstanding_position(self, position: Dict[str, Any]) -> None:
|
||||||
@ -549,97 +320,126 @@ class BacktestResult:
|
|||||||
|
|
||||||
def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
|
def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
|
||||||
"""Calculate and print returns by day and pair."""
|
"""Calculate and print returns by day and pair."""
|
||||||
|
def _symbol_return(trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float) -> float:
|
||||||
|
if trade1_side == "BUY" and trade2_side == "SELL":
|
||||||
|
return (trade2_px - trade1_px) / trade1_px * 100
|
||||||
|
elif trade1_side == "SELL" and trade2_side == "BUY":
|
||||||
|
return (trade1_px - trade2_px) / trade1_px * 100
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
print("\n====== Returns By Day and Pair ======")
|
print("\n====== Returns By Day and Pair ======")
|
||||||
|
|
||||||
|
trades = []
|
||||||
for filename, data in all_results.items():
|
for filename, data in all_results.items():
|
||||||
day_return = 0
|
pairs = list(data["trades"].keys())
|
||||||
|
for pair in pairs:
|
||||||
|
self.pairs_trades_[pair] = []
|
||||||
|
trades_dict = data["trades"][pair]
|
||||||
|
for symbol in trades_dict.keys():
|
||||||
|
trades.extend(trades_dict[symbol])
|
||||||
|
trades = sorted(trades, key=lambda x: (x["timestamp"], x["symbol"]))
|
||||||
|
|
||||||
print(f"\n--- {filename} ---")
|
print(f"\n--- {filename} ---")
|
||||||
|
|
||||||
self.outstanding_positions = data["outstanding_positions"]
|
self.outstanding_positions = data["outstanding_positions"]
|
||||||
|
|
||||||
|
day_return = 0.0
|
||||||
|
for idx in range(0, len(trades), 4):
|
||||||
|
symbol_a = trades[idx]["symbol"]
|
||||||
|
trade_a_1 = trades[idx]
|
||||||
|
trade_a_2 = trades[idx + 2]
|
||||||
|
|
||||||
# Process each pair
|
symbol_b = trades[idx + 1]["symbol"]
|
||||||
for pair, symbols in data["trades"].items():
|
trade_b_1 = trades[idx + 1]
|
||||||
pair_return = 0
|
trade_b_2 = trades[idx + 3]
|
||||||
pair_trades = []
|
|
||||||
|
|
||||||
# Calculate individual symbol returns in the pair
|
symbol_return = 0
|
||||||
for symbol, trades in symbols.items():
|
assert (
|
||||||
if len(trades) == 0:
|
trade_a_1["timestamp"] < trade_a_2["timestamp"]
|
||||||
continue
|
), f"Trade 1: {trade_a_1['timestamp']} is not less than Trade 2: {trade_a_2['timestamp']}"
|
||||||
symbol_return = 0
|
assert (
|
||||||
symbol_trades = [trade for trade in trades if trade["symbol"] == symbol]
|
trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"
|
||||||
|
), f"Trade 1: {trade_a_1['action']} and Trade 2: {trade_a_2['action']} are the same"
|
||||||
|
|
||||||
# Calculate returns for all trade combinations
|
# Calculate return based on action combination
|
||||||
for idx in range(0, len(symbol_trades), 2):
|
trade_return = 0
|
||||||
trade1 = trades[idx]
|
symbol_a_return = _symbol_return(trade_a_1["side"], trade_a_1["price"], trade_a_2["side"], trade_a_2["price"])
|
||||||
trade2 = trades[idx + 1]
|
symbol_b_return = _symbol_return(trade_b_1["side"], trade_b_1["price"], trade_b_2["side"], trade_b_2["price"])
|
||||||
|
|
||||||
assert trade1["timestamp"] < trade2["timestamp"], f"Trade 1: {trade1['timestamp']} is not less than Trade 2: {trade2['timestamp']}"
|
|
||||||
assert trade1["action"] == "OPEN" and trade2["action"] == "CLOSE", f"Trade 1: {trade1['action']} and Trade 2: {trade2['action']} are the same"
|
|
||||||
|
|
||||||
# Calculate return based on action combination
|
|
||||||
trade_return = 0
|
|
||||||
if trade1["side"] == "BUY" and trade2["side"] == "SELL":
|
|
||||||
# Long position
|
|
||||||
trade_return = (trade2["price"] - trade1["price"]) / trade1["price"] * 100
|
|
||||||
elif trade1["side"] == "SELL" and trade2["side"] == "BUY":
|
|
||||||
# Short position
|
|
||||||
trade_return = (trade1["price"] - trade2["price"]) / trade1["price"] * 100
|
|
||||||
|
|
||||||
symbol_return += trade_return
|
|
||||||
|
|
||||||
# Store trade details for reporting
|
|
||||||
pair_trades.append(
|
|
||||||
(
|
|
||||||
symbol,
|
|
||||||
trade1["timestamp"],
|
|
||||||
trade2["timestamp"],
|
|
||||||
trade1["side"],
|
|
||||||
trade1["price"],
|
|
||||||
trade2["side"],
|
|
||||||
trade2["price"],
|
|
||||||
trade_return,
|
|
||||||
trade1["scaled_disequilibrium"],
|
|
||||||
trade2["scaled_disequilibrium"],
|
|
||||||
f"{idx + 1}", # Trade sequence number
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
pair_return += symbol_return
|
|
||||||
|
|
||||||
# Print pair returns with disequilibrium information
|
pair_return = symbol_a_return + symbol_b_return
|
||||||
if pair_trades:
|
|
||||||
print(f" {pair}:")
|
|
||||||
for (
|
|
||||||
symbol,
|
|
||||||
trade1["timestamp"],
|
|
||||||
trade2["timestamp"],
|
|
||||||
trade1["side"],
|
|
||||||
trade1["price"],
|
|
||||||
trade2["side"],
|
|
||||||
trade2["price"],
|
|
||||||
trade_return,
|
|
||||||
trade1["scaled_disequilibrium"],
|
|
||||||
trade2["scaled_disequilibrium"],
|
|
||||||
trade_num,
|
|
||||||
) in pair_trades:
|
|
||||||
disequil_info = ""
|
|
||||||
if (
|
|
||||||
trade1["scaled_disequilibrium"] is not None
|
|
||||||
and trade2["scaled_disequilibrium"] is not None
|
|
||||||
):
|
|
||||||
disequil_info = f" | Open Dis-eq: {trade1["scaled_disequilibrium"]:.2f},"
|
|
||||||
f" Close Dis-eq: {trade2["scaled_disequilibrium"]:.2f}"
|
|
||||||
|
|
||||||
print(
|
self.pairs_trades_[pair].append(
|
||||||
f" {trade2['timestamp'].time()} {symbol} (Trade #{trade_num}):"
|
{
|
||||||
f" {trade1["side"]} @ ${trade1["price"]:.2f},"
|
"symbol": symbol_a,
|
||||||
f" {trade2["side"]} @ ${trade2["price"]:.2f},"
|
"open_side": trade_a_1["side"],
|
||||||
f" Return: {trade_return:.2f}%{disequil_info}"
|
"open_action": trade_a_1["action"],
|
||||||
)
|
"open_price": trade_a_1["price"],
|
||||||
print(f" Pair Total Return: {pair_return:.2f}%")
|
"close_side": trade_a_2["side"],
|
||||||
day_return += pair_return
|
"close_action": trade_a_2["action"],
|
||||||
|
"close_price": trade_a_2["price"],
|
||||||
|
"symbol_return": symbol_a_return,
|
||||||
|
"open_disequilibrium": trade_a_1["disequilibrium"],
|
||||||
|
"open_scaled_disequilibrium": trade_a_1["scaled_disequilibrium"],
|
||||||
|
"close_disequilibrium": trade_a_2["disequilibrium"],
|
||||||
|
"close_scaled_disequilibrium": trade_a_2["scaled_disequilibrium"],
|
||||||
|
"open_time": trade_a_1["timestamp"],
|
||||||
|
"close_time": trade_a_2["timestamp"],
|
||||||
|
"shares": self.config["funding_per_pair"] / 2 / trade_a_1["price"],
|
||||||
|
"is_completed": True,
|
||||||
|
"close_condition": trade_a_2["status"],
|
||||||
|
"pair_return": pair_return
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.pairs_trades_[pair].append(
|
||||||
|
{
|
||||||
|
"symbol": symbol_b,
|
||||||
|
"open_side": trade_b_1["side"],
|
||||||
|
"open_action": trade_b_1["action"],
|
||||||
|
"open_price": trade_b_1["price"],
|
||||||
|
"close_side": trade_b_2["side"],
|
||||||
|
"close_action": trade_b_2["action"],
|
||||||
|
"close_price": trade_b_2["price"],
|
||||||
|
"symbol_return": symbol_b_return,
|
||||||
|
"open_disequilibrium": trade_b_1["disequilibrium"],
|
||||||
|
"open_scaled_disequilibrium": trade_b_1["scaled_disequilibrium"],
|
||||||
|
"close_disequilibrium": trade_b_2["disequilibrium"],
|
||||||
|
"close_scaled_disequilibrium": trade_b_2["scaled_disequilibrium"],
|
||||||
|
"open_time": trade_b_1["timestamp"],
|
||||||
|
"close_time": trade_b_2["timestamp"],
|
||||||
|
"shares": self.config["funding_per_pair"] / 2 / trade_b_1["price"],
|
||||||
|
"is_completed": True,
|
||||||
|
"close_condition": trade_b_2["status"],
|
||||||
|
"pair_return": pair_return
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Print pair returns with disequilibrium information
|
||||||
|
day_return = 0.0
|
||||||
|
if self.pairs_trades_[pair]:
|
||||||
|
|
||||||
|
print(f"{pair}:")
|
||||||
|
pair_return = 0.0
|
||||||
|
for trd in self.pairs_trades_[pair]:
|
||||||
|
disequil_info = ""
|
||||||
|
if (
|
||||||
|
trd["open_scaled_disequilibrium"] is not None
|
||||||
|
and trd["open_scaled_disequilibrium"] is not None
|
||||||
|
):
|
||||||
|
disequil_info = f" | Open Dis-eq: {trd['open_scaled_disequilibrium']:.2f},"
|
||||||
|
f" Close Dis-eq: {trd['open_scaled_disequilibrium']:.2f}"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" {trd['open_time'].time()} {trd['symbol']}: "
|
||||||
|
f" {trd['open_side']} @ ${trd['open_price']:.2f},"
|
||||||
|
f" {trd["close_side"]} @ ${trd["close_price"]:.2f},"
|
||||||
|
f" Return: {trd['symbol_return']:.2f}%{disequil_info}"
|
||||||
|
)
|
||||||
|
pair_return += trd["symbol_return"]
|
||||||
|
|
||||||
|
print(f" Pair Total Return: {pair_return:.2f}%")
|
||||||
|
day_return += pair_return
|
||||||
|
|
||||||
# Print day total return and add to global realized PnL
|
# Print day total return and add to global realized PnL
|
||||||
if day_return != 0:
|
if day_return != 0:
|
||||||
@ -716,7 +516,7 @@ class BacktestResult:
|
|||||||
|
|
||||||
print("-" * 100)
|
print("-" * 100)
|
||||||
|
|
||||||
total_value += pos["total_current_value"]
|
total_value += pos["total_current_value"]
|
||||||
|
|
||||||
print(f"{'TOTAL OUTSTANDING VALUE':<80} ${total_value:<12.2f}")
|
print(f"{'TOTAL OUTSTANDING VALUE':<80} ${total_value:<12.2f}")
|
||||||
|
|
||||||
@ -811,3 +611,132 @@ class BacktestResult:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return current_value_a, current_value_b, total_current_value
|
return current_value_a, current_value_b, total_current_value
|
||||||
|
|
||||||
|
def store_results_in_database(
|
||||||
|
self, db_path: str, datafile: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Store backtest results in the SQLite database.
|
||||||
|
"""
|
||||||
|
if db_path.upper() == "NONE":
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
|
||||||
|
filename = os.path.basename(datafile)
|
||||||
|
date_str = filename.split(".")[0] # Extract date part
|
||||||
|
|
||||||
|
# Convert to proper date format
|
||||||
|
try:
|
||||||
|
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
|
||||||
|
except ValueError:
|
||||||
|
# If date parsing fails, use current date
|
||||||
|
date_obj = datetime.now().date()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Process each trade from bt_result
|
||||||
|
trades = self.get_trades()
|
||||||
|
|
||||||
|
for pair_name, _ in trades.items():
|
||||||
|
|
||||||
|
# Second pass: insert completed trade records into database
|
||||||
|
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]):
|
||||||
|
# Only store completed trades in pt_bt_results table
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO pt_bt_results (
|
||||||
|
date, pair, symbol, open_time, open_side, open_price,
|
||||||
|
open_quantity, open_disequilibrium, close_time, close_side,
|
||||||
|
close_price, close_quantity, close_disequilibrium,
|
||||||
|
symbol_return, pair_return, close_condition
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
date_obj,
|
||||||
|
pair_name,
|
||||||
|
trade_pair["symbol"],
|
||||||
|
trade_pair["open_time"],
|
||||||
|
trade_pair["open_side"],
|
||||||
|
trade_pair["open_price"],
|
||||||
|
trade_pair["shares"],
|
||||||
|
trade_pair["open_scaled_disequilibrium"],
|
||||||
|
trade_pair["close_time"],
|
||||||
|
trade_pair["close_side"],
|
||||||
|
trade_pair["close_price"],
|
||||||
|
trade_pair["shares"],
|
||||||
|
trade_pair["close_scaled_disequilibrium"],
|
||||||
|
trade_pair["symbol_return"],
|
||||||
|
trade_pair["pair_return"],
|
||||||
|
trade_pair["close_condition"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store outstanding positions in separate table
|
||||||
|
outstanding_positions = self.get_outstanding_positions()
|
||||||
|
for pos in outstanding_positions:
|
||||||
|
# Calculate position quantity (negative for SELL positions)
|
||||||
|
position_qty_a = (
|
||||||
|
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
|
||||||
|
)
|
||||||
|
position_qty_b = (
|
||||||
|
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate unrealized returns
|
||||||
|
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
|
||||||
|
unrealized_return_a = (
|
||||||
|
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
|
||||||
|
) * (1 if pos["side_a"] == "BUY" else -1)
|
||||||
|
unrealized_return_b = (
|
||||||
|
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
|
||||||
|
) * (1 if pos["side_b"] == "BUY" else -1)
|
||||||
|
|
||||||
|
# Store outstanding position for symbol A
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO outstanding_positions (
|
||||||
|
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
date_obj,
|
||||||
|
pos["pair"],
|
||||||
|
pos["symbol_a"],
|
||||||
|
position_qty_a,
|
||||||
|
pos["current_px_a"],
|
||||||
|
unrealized_return_a,
|
||||||
|
pos["open_px_a"],
|
||||||
|
pos["side_a"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store outstanding position for symbol B
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO outstanding_positions (
|
||||||
|
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
date_obj,
|
||||||
|
pos["pair"],
|
||||||
|
pos["symbol_b"],
|
||||||
|
position_qty_b,
|
||||||
|
pos["current_px_b"],
|
||||||
|
unrealized_return_b,
|
||||||
|
pos["open_px_b"],
|
||||||
|
pos["side_b"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error storing results in database: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Dict, List, cast
|
from typing import Dict, List, cast
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_sqlite_to_dataframe(db_path, query):
|
def load_sqlite_to_dataframe(db_path, query):
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
@ -35,20 +36,27 @@ def convert_time_to_UTC(value: str, timezone: str) -> str:
|
|||||||
return result.strftime("%Y-%m-%d %H:%M:%S")
|
return result.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
|
def load_market_data(
|
||||||
from tools.data_loader import load_sqlite_to_dataframe
|
datafile: str,
|
||||||
|
instruments: List[Dict[str, str]],
|
||||||
|
db_table_name: str,
|
||||||
|
trading_hours: Dict = {},
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
|
||||||
instrument_ids = [
|
insts = [
|
||||||
'"' + config["instrument_id_pfx"] + instrument + '"'
|
'"' + instrument["instrument_id_pfx"] + instrument["symbol"] + '"'
|
||||||
for instrument in config["instruments"]
|
for instrument in instruments
|
||||||
]
|
]
|
||||||
exchange_id = config["exchange_id"]
|
instrument_ids = list(set(insts))
|
||||||
|
exchange_ids = list(
|
||||||
|
set(['"' + instrument["exchange_id"] + '"' for instrument in instruments])
|
||||||
|
)
|
||||||
|
|
||||||
query = "select"
|
query = "select"
|
||||||
query += " tstamp"
|
query += " tstamp"
|
||||||
query += ", tstamp_ns as time_ns"
|
query += ", tstamp_ns as time_ns"
|
||||||
|
|
||||||
query += f", substr(instrument_id, {len(config['instrument_id_pfx']) + 1}) as symbol"
|
query += f", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
|
||||||
query += ", open"
|
query += ", open"
|
||||||
query += ", high"
|
query += ", high"
|
||||||
query += ", low"
|
query += ", low"
|
||||||
@ -57,74 +65,76 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
|
|||||||
query += ", num_trades"
|
query += ", num_trades"
|
||||||
query += ", vwap"
|
query += ", vwap"
|
||||||
|
|
||||||
query += f" from {config['db_table_name']}"
|
query += f" from {db_table_name}"
|
||||||
query += f" where exchange_id ='{exchange_id}'"
|
query += f" where exchange_id in ({','.join(exchange_ids)})"
|
||||||
query += f" and instrument_id in ({','.join(instrument_ids)})"
|
query += f" and instrument_id in ({','.join(instrument_ids)})"
|
||||||
|
|
||||||
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
|
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
|
||||||
|
|
||||||
# Trading Hours
|
# Trading Hours
|
||||||
date_str = df["tstamp"][0][0:10]
|
if len(df) > 0 and len(trading_hours) > 0:
|
||||||
trading_hours = config["trading_hours"]
|
date_str = df["tstamp"][0][0:10]
|
||||||
|
|
||||||
start_time = convert_time_to_UTC(
|
start_time = convert_time_to_UTC(
|
||||||
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
|
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
|
||||||
)
|
)
|
||||||
end_time = convert_time_to_UTC(
|
end_time = convert_time_to_UTC(
|
||||||
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"]
|
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform boolean selection
|
# Perform boolean selection
|
||||||
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
|
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
|
||||||
df["tstamp"] = pd.to_datetime(df["tstamp"])
|
df["tstamp"] = pd.to_datetime(df["tstamp"])
|
||||||
|
|
||||||
return cast(pd.DataFrame, df)
|
return cast(pd.DataFrame, df)
|
||||||
|
|
||||||
|
|
||||||
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
|
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
|
||||||
"""
|
# """
|
||||||
Auto-detect available instruments from the database by querying distinct instrument_id values.
|
# Auto-detect available instruments from the database by querying distinct instrument_id values.
|
||||||
Returns instruments without the configured prefix.
|
# Returns instruments without the configured prefix.
|
||||||
"""
|
# """
|
||||||
try:
|
# try:
|
||||||
conn = sqlite3.connect(datafile)
|
# conn = sqlite3.connect(datafile)
|
||||||
|
|
||||||
# Build exclusion list with full instrument_ids
|
# # Build exclusion list with full instrument_ids
|
||||||
exclude_instruments = config.get("exclude_instruments", [])
|
# exclude_instruments = config.get("exclude_instruments", [])
|
||||||
prefix = config.get("instrument_id_pfx", "")
|
# prefix = config.get("instrument_id_pfx", "")
|
||||||
exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
|
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
|
||||||
|
|
||||||
# Query to get distinct instrument_ids
|
|
||||||
query = f"""
|
|
||||||
SELECT DISTINCT instrument_id
|
|
||||||
FROM {config['db_table_name']}
|
|
||||||
WHERE exchange_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Add exclusion clause if there are instruments to exclude
|
|
||||||
if exclude_instrument_ids:
|
|
||||||
placeholders = ','.join(['?' for _ in exclude_instrument_ids])
|
|
||||||
query += f" AND instrument_id NOT IN ({placeholders})"
|
|
||||||
cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids))
|
|
||||||
else:
|
|
||||||
cursor = conn.execute(query, (config["exchange_id"],))
|
|
||||||
instrument_ids = [row[0] for row in cursor.fetchall()]
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# Remove the configured prefix to get instrument symbols
|
# # Query to get distinct instrument_ids
|
||||||
instruments = []
|
# query = f"""
|
||||||
for instrument_id in instrument_ids:
|
# SELECT DISTINCT instrument_id
|
||||||
if instrument_id.startswith(prefix):
|
# FROM {config['db_table_name']}
|
||||||
symbol = instrument_id[len(prefix) :]
|
# WHERE exchange_id = ?
|
||||||
instruments.append(symbol)
|
# """
|
||||||
else:
|
|
||||||
instruments.append(instrument_id)
|
|
||||||
|
|
||||||
return sorted(instruments)
|
# # Add exclusion clause if there are instruments to exclude
|
||||||
|
# if exclude_instrument_ids:
|
||||||
|
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
|
||||||
|
# query += f" AND instrument_id NOT IN ({placeholders})"
|
||||||
|
# cursor = conn.execute(
|
||||||
|
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# cursor = conn.execute(query, (config["exchange_id"],))
|
||||||
|
# instrument_ids = [row[0] for row in cursor.fetchall()]
|
||||||
|
# conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
# # Remove the configured prefix to get instrument symbols
|
||||||
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
|
# instruments = []
|
||||||
return []
|
# for instrument_id in instrument_ids:
|
||||||
|
# if instrument_id.startswith(prefix):
|
||||||
|
# symbol = instrument_id[len(prefix) :]
|
||||||
|
# instruments.append(symbol)
|
||||||
|
# else:
|
||||||
|
# instruments.append(instrument_id)
|
||||||
|
|
||||||
|
# return sorted(instruments)
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
|
||||||
|
# return []
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from tools.config import expand_filename, load_config
|
from tools.config import expand_filename, load_config
|
||||||
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
from tools.data_loader import get_available_instruments_from_db
|
||||||
from pt_trading.results import (
|
from pt_trading.results import (
|
||||||
BacktestResult,
|
BacktestResult,
|
||||||
create_result_database,
|
create_result_database,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@ -9,62 +9,62 @@ import pandas as pd
|
|||||||
|
|
||||||
from research.research_tools import create_pairs
|
from research.research_tools import create_pairs
|
||||||
from tools.config import expand_filename, load_config
|
from tools.config import expand_filename, load_config
|
||||||
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
|
||||||
from pt_trading.results import (
|
from pt_trading.results import (
|
||||||
BacktestResult,
|
BacktestResult,
|
||||||
create_result_database,
|
create_result_database,
|
||||||
store_config_in_database,
|
store_config_in_database,
|
||||||
store_results_in_database,
|
|
||||||
)
|
)
|
||||||
from pt_trading.fit_method import PairsTradingFitMethod
|
from pt_trading.fit_method import PairsTradingFitMethod
|
||||||
from pt_trading.trading_pair import TradingPair
|
from pt_trading.trading_pair import TradingPair
|
||||||
|
|
||||||
|
|
||||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
def resolve_datafiles(
|
||||||
"""
|
config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
|
||||||
Resolve the list of data files to process.
|
) -> List[str]:
|
||||||
CLI datafiles take priority over config datafiles.
|
|
||||||
Supports wildcards in config but not in CLI.
|
|
||||||
"""
|
|
||||||
if cli_datafiles:
|
|
||||||
# CLI override - comma-separated list, no wildcards
|
|
||||||
datafiles = [f.strip() for f in cli_datafiles.split(",")]
|
|
||||||
# Make paths absolute relative to data directory
|
|
||||||
data_dir = config.get("data_directory", "./data")
|
|
||||||
resolved_files = []
|
|
||||||
for df in datafiles:
|
|
||||||
if not os.path.isabs(df):
|
|
||||||
df = os.path.join(data_dir, df)
|
|
||||||
resolved_files.append(df)
|
|
||||||
return resolved_files
|
|
||||||
|
|
||||||
# Use config datafiles with wildcard support
|
|
||||||
config_datafiles = config.get("datafiles", [])
|
|
||||||
data_dir = config.get("data_directory", "./data")
|
|
||||||
resolved_files = []
|
resolved_files = []
|
||||||
|
for inst in instruments:
|
||||||
for pattern in config_datafiles:
|
pattern = date_pattern
|
||||||
|
inst_type = inst["instrument_type"]
|
||||||
|
data_dir = config["market_data_loading"][inst_type]["data_directory"]
|
||||||
if "*" in pattern or "?" in pattern:
|
if "*" in pattern or "?" in pattern:
|
||||||
# Handle wildcards
|
# Handle wildcards
|
||||||
if not os.path.isabs(pattern):
|
if not os.path.isabs(pattern):
|
||||||
pattern = os.path.join(data_dir, pattern)
|
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
|
||||||
matched_files = glob.glob(pattern)
|
matched_files = glob.glob(pattern)
|
||||||
resolved_files.extend(matched_files)
|
resolved_files.extend(matched_files)
|
||||||
else:
|
else:
|
||||||
# Handle explicit file path
|
# Handle explicit file path
|
||||||
if not os.path.isabs(pattern):
|
if not os.path.isabs(pattern):
|
||||||
pattern = os.path.join(data_dir, pattern)
|
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
|
||||||
resolved_files.append(pattern)
|
resolved_files.append(pattern)
|
||||||
|
|
||||||
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
||||||
|
|
||||||
|
|
||||||
|
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
|
||||||
|
|
||||||
|
instruments = [
|
||||||
|
{
|
||||||
|
"symbol": inst.split(":")[0],
|
||||||
|
"instrument_type": inst.split(":")[1],
|
||||||
|
"exchange_id": inst.split(":")[2],
|
||||||
|
"instrument_id_pfx": config["market_data_loading"][inst.split(":")[1]][
|
||||||
|
"instrument_id_pfx"
|
||||||
|
],
|
||||||
|
"db_table_name": config["market_data_loading"][inst.split(":")[1]][
|
||||||
|
"db_table_name"
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for inst in args.instruments.split(",")
|
||||||
|
]
|
||||||
|
return instruments
|
||||||
|
|
||||||
|
|
||||||
def run_backtest(
|
def run_backtest(
|
||||||
config: Dict,
|
config: Dict,
|
||||||
datafile: str,
|
datafile: str,
|
||||||
price_column: str,
|
price_column: str,
|
||||||
fit_method: PairsTradingFitMethod,
|
fit_method: PairsTradingFitMethod,
|
||||||
instruments: List[str],
|
instruments: List[Dict[str, str]],
|
||||||
) -> BacktestResult:
|
) -> BacktestResult:
|
||||||
"""
|
"""
|
||||||
Run backtest for all pairs using the specified instruments.
|
Run backtest for all pairs using the specified instruments.
|
||||||
@ -72,13 +72,14 @@ def run_backtest(
|
|||||||
bt_result: BacktestResult = BacktestResult(config=config)
|
bt_result: BacktestResult = BacktestResult(config=config)
|
||||||
|
|
||||||
pairs_trades = []
|
pairs_trades = []
|
||||||
for pair in create_pairs(
|
pairs = create_pairs(
|
||||||
datafile=datafile,
|
datafile=datafile,
|
||||||
fit_method=fit_method,
|
fit_method=fit_method,
|
||||||
price_column=price_column,
|
price_column=price_column,
|
||||||
config=config,
|
config=config,
|
||||||
instruments=instruments,
|
instruments=instruments,
|
||||||
):
|
)
|
||||||
|
for pair in pairs:
|
||||||
single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
|
single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
|
||||||
if single_pair_trades is not None and len(single_pair_trades) > 0:
|
if single_pair_trades is not None and len(single_pair_trades) > 0:
|
||||||
pairs_trades.append(single_pair_trades)
|
pairs_trades.append(single_pair_trades)
|
||||||
@ -98,16 +99,16 @@ def main() -> None:
|
|||||||
"--config", type=str, required=True, help="Path to the configuration file."
|
"--config", type=str, required=True, help="Path to the configuration file."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--datafiles",
|
"--date_pattern",
|
||||||
type=str,
|
type=str,
|
||||||
required=False,
|
required=True,
|
||||||
help="Comma-separated list of data files (overrides config). No wildcards supported.",
|
help="Date YYYYMMDD, allows * and ? wildcards",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--instruments",
|
"--instruments",
|
||||||
type=str,
|
type=str,
|
||||||
required=False,
|
required=True,
|
||||||
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
|
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--result_db",
|
"--result_db",
|
||||||
@ -128,7 +129,8 @@ def main() -> None:
|
|||||||
fit_method = getattr(module, class_name)()
|
fit_method = getattr(module, class_name)()
|
||||||
|
|
||||||
# Resolve data files (CLI takes priority over config)
|
# Resolve data files (CLI takes priority over config)
|
||||||
datafiles = resolve_datafiles(config, args.datafiles)
|
instruments = get_instruments(args, config)
|
||||||
|
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
|
||||||
|
|
||||||
if not datafiles:
|
if not datafiles:
|
||||||
print("No data files found to process.")
|
print("No data files found to process.")
|
||||||
@ -149,18 +151,8 @@ def main() -> None:
|
|||||||
# Store configuration in database for reference
|
# Store configuration in database for reference
|
||||||
if args.result_db.upper() != "NONE":
|
if args.result_db.upper() != "NONE":
|
||||||
# Get list of all instruments for storage
|
# Get list of all instruments for storage
|
||||||
all_instruments = []
|
|
||||||
for datafile in datafiles:
|
|
||||||
if args.instruments:
|
|
||||||
file_instruments = [
|
|
||||||
inst.strip() for inst in args.instruments.split(",")
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
file_instruments = get_available_instruments_from_db(datafile, config)
|
|
||||||
all_instruments.extend(file_instruments)
|
|
||||||
|
|
||||||
# Remove duplicates while preserving order
|
# Remove duplicates while preserving order
|
||||||
unique_instruments = list(dict.fromkeys(all_instruments))
|
|
||||||
|
|
||||||
store_config_in_database(
|
store_config_in_database(
|
||||||
db_path=args.result_db,
|
db_path=args.result_db,
|
||||||
@ -168,7 +160,7 @@ def main() -> None:
|
|||||||
config=config,
|
config=config,
|
||||||
fit_method_class=fit_method_class_name,
|
fit_method_class=fit_method_class_name,
|
||||||
datafiles=datafiles,
|
datafiles=datafiles,
|
||||||
instruments=unique_instruments,
|
instruments=instruments,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process each data file
|
# Process each data file
|
||||||
@ -177,20 +169,6 @@ def main() -> None:
|
|||||||
for datafile in datafiles:
|
for datafile in datafiles:
|
||||||
print(f"\n====== Processing {os.path.basename(datafile)} ======")
|
print(f"\n====== Processing {os.path.basename(datafile)} ======")
|
||||||
|
|
||||||
# Determine instruments to use
|
|
||||||
if args.instruments:
|
|
||||||
# Use CLI-specified instruments
|
|
||||||
instruments = [inst.strip() for inst in args.instruments.split(",")]
|
|
||||||
print(f"Using CLI-specified instruments: {instruments}")
|
|
||||||
else:
|
|
||||||
# Auto-detect instruments from database
|
|
||||||
instruments = get_available_instruments_from_db(datafile, config)
|
|
||||||
print(f"Auto-detected instruments: {instruments}")
|
|
||||||
|
|
||||||
if not instruments:
|
|
||||||
print(f"No instruments found for {datafile}, skipping...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Process data for this file
|
# Process data for this file
|
||||||
try:
|
try:
|
||||||
fit_method.reset()
|
fit_method.reset()
|
||||||
@ -212,7 +190,15 @@ def main() -> None:
|
|||||||
|
|
||||||
# Store results in database
|
# Store results in database
|
||||||
if args.result_db.upper() != "NONE":
|
if args.result_db.upper() != "NONE":
|
||||||
store_results_in_database(args.result_db, datafile, bt_results)
|
bt_results.calculate_returns(
|
||||||
|
{
|
||||||
|
filename: {
|
||||||
|
"trades": bt_results.trades.copy(),
|
||||||
|
"outstanding_positions": bt_results.outstanding_positions.copy(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
bt_results.store_results_in_database(args.result_db, datafile)
|
||||||
|
|
||||||
print(f"Successfully processed {filename}")
|
print(f"Successfully processed {filename}")
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
|||||||
from pt_trading.fit_method import PairsTradingFitMethod
|
from pt_trading.fit_method import PairsTradingFitMethod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Resolve the list of data files to process.
|
Resolve the list of data files to process.
|
||||||
@ -44,28 +43,41 @@ def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List
|
|||||||
|
|
||||||
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
return sorted(list(set(resolved_files))) # Remove duplicates and sort
|
||||||
|
|
||||||
def create_pairs(datafile: str, fit_method: PairsTradingFitMethod, price_column: str, config: Dict, instruments: List[str]) -> List:
|
|
||||||
from tools.data_loader import load_market_data
|
|
||||||
from pt_trading.trading_pair import TradingPair
|
|
||||||
all_indexes = range(len(instruments))
|
|
||||||
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
|
||||||
pairs = []
|
|
||||||
|
|
||||||
# Update config to use the specified instruments
|
def create_pairs(
|
||||||
config_copy = config.copy()
|
datafile: str,
|
||||||
config_copy["instruments"] = instruments
|
fit_method: PairsTradingFitMethod,
|
||||||
|
price_column: str,
|
||||||
|
config: Dict,
|
||||||
|
instruments: List[Dict[str, str]],
|
||||||
|
) -> List:
|
||||||
|
from tools.data_loader import load_market_data
|
||||||
|
from pt_trading.trading_pair import TradingPair
|
||||||
|
|
||||||
market_data_df = load_market_data(datafile, config=config_copy)
|
all_indexes = range(len(instruments))
|
||||||
|
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
||||||
|
pairs = []
|
||||||
|
|
||||||
for a_index, b_index in unique_index_pairs:
|
# Update config to use the specified instruments
|
||||||
from research.pt_backtest import TradingPair
|
config_copy = config.copy()
|
||||||
pair = fit_method.create_trading_pair(
|
config_copy["instruments"] = instruments
|
||||||
config=config_copy,
|
|
||||||
market_data=market_data_df,
|
market_data_df = load_market_data(
|
||||||
symbol_a=instruments[a_index],
|
datafile=datafile,
|
||||||
symbol_b=instruments[b_index],
|
instruments=instruments,
|
||||||
price_column=price_column,
|
db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"],
|
||||||
)
|
trading_hours=config_copy["trading_hours"],
|
||||||
pairs.append(pair)
|
)
|
||||||
return pairs
|
|
||||||
|
|
||||||
|
for a_index, b_index in unique_index_pairs:
|
||||||
|
from research.pt_backtest import TradingPair
|
||||||
|
|
||||||
|
pair = fit_method.create_trading_pair(
|
||||||
|
config=config_copy,
|
||||||
|
market_data=market_data_df,
|
||||||
|
symbol_a=instruments[a_index]["symbol"],
|
||||||
|
symbol_b=instruments[b_index]["symbol"],
|
||||||
|
price_column=price_column,
|
||||||
|
)
|
||||||
|
pairs.append(pair)
|
||||||
|
return pairs
|
||||||
|
|||||||
Binary file not shown.
@ -20,8 +20,6 @@ from pt_trading.fit_methods import PairsTradingFitMethod
|
|||||||
from pt_trading.trading_pair import TradingPair
|
from pt_trading.trading_pair import TradingPair
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_strategy(
|
def run_strategy(
|
||||||
config: Dict,
|
config: Dict,
|
||||||
datafile: str,
|
datafile: str,
|
||||||
@ -44,7 +42,14 @@ def run_strategy(
|
|||||||
config_copy = config.copy()
|
config_copy = config.copy()
|
||||||
config_copy["instruments"] = instruments
|
config_copy["instruments"] = instruments
|
||||||
|
|
||||||
market_data_df = load_market_data(datafile, config=config_copy)
|
market_data_df = load_market_data(
|
||||||
|
datafile=datafile,
|
||||||
|
exchange_id=config_copy["exchange_id"],
|
||||||
|
instruments=config_copy["instruments"],
|
||||||
|
instrument_id_pfx=config_copy["instrument_id_pfx"],
|
||||||
|
db_table_name=config_copy["db_table_name"],
|
||||||
|
trading_hours=config_copy["trading_hours"],
|
||||||
|
)
|
||||||
|
|
||||||
for a_index, b_index in unique_index_pairs:
|
for a_index, b_index in unique_index_pairs:
|
||||||
pair = fit_method.create_trading_pair(
|
pair = fit_method.create_trading_pair(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user