added close position and trade session

This commit is contained in:
Oleg Sheynin 2025-07-16 18:06:33 +00:00
parent 20f150a6b7
commit 9c34d935bd
9 changed files with 1679 additions and 8107 deletions

View File

@ -21,6 +21,11 @@
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"exclude_instruments": ["CAN"]
"close_outstanding_positions": true,
"trading_hours": {
"begin_session": "15:30:00",
"end_session": "20:00:00",
"timezone": "UTC"
}
}

View File

@ -21,6 +21,7 @@
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"exclude_instruments": ["CAN"]
"exclude_instruments": ["CAN"],
"close_outstanding_positions": false
}

View File

@ -21,6 +21,7 @@
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"exclude_instruments": ["CAN"]
"exclude_instruments": ["CAN"],
"close_outstanding_positions": false
}

View File

@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
import pandas as pd # type: ignore[import]
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
@ -22,7 +22,7 @@ class PairsTradingFitMethod(ABC):
@abstractmethod
def run_pair(
self, config: Dict, pair: TradingPair, bt_result: BacktestResult
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]: ...
@abstractmethod
@ -33,3 +33,4 @@ class PairState(Enum):
INITIAL = 1
OPEN = 2
CLOSED = 3
CLOSED_POSITIONS = 4

View File

@ -16,12 +16,15 @@ NanoPerMin = 1e9
class SlidingFit(PairsTradingFitMethod):
def __init__(self) -> None:
super().__init__()
self.curr_training_start_idx_ = 0
def run_pair(
self, config: Dict, pair: TradingPair, bt_result: BacktestResult
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
curr_training_start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
@ -39,22 +42,23 @@ class SlidingFit(PairsTradingFitMethod):
training_minutes = config["training_minutes"]
curr_predicted_row_idx = 0
while True:
print(self.curr_training_start_idx_, end="\r")
print(curr_training_start_idx, end="\r")
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=self.curr_training_start_idx_,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(pair.training_df_) < training_minutes:
print(
f"{pair}: current offset={self.curr_training_start_idx_}"
f"{pair}: current offset={curr_training_start_idx}"
f" * Training data length={len(pair.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
try:
# ================================ TRAINING ================================
is_cointegrated = pair.train_pair()
except Exception as e:
raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e
@ -64,30 +68,33 @@ class SlidingFit(PairsTradingFitMethod):
if not is_cointegrated:
if pair.user_data_["state"] == PairState.OPEN:
print(
f"{pair} {self.curr_training_start_idx_} LOST COINTEGRATION. Consider closing positions..."
f"{pair} {curr_training_start_idx} LOST COINTEGRATION. Consider closing positions..."
)
else:
print(
f"{pair} {self.curr_training_start_idx_} IS NOT COINTEGRATED. Moving on"
f"{pair} {curr_training_start_idx} IS NOT COINTEGRATED. Moving on"
)
else:
print("*" * 80)
print(
f"Pair {pair} ({self.curr_training_start_idx_}) IS COINTEGRATED"
f"Pair {pair} ({curr_training_start_idx}) IS COINTEGRATED"
)
print("*" * 80)
if not is_cointegrated:
self.curr_training_start_idx_ += 1
curr_training_start_idx += 1
continue
try:
# ================================ PREDICTION ================================
pair.predict()
except Exception as e:
raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e
# break
self.curr_training_start_idx_ += 1
curr_training_start_idx += 1
if curr_training_start_idx > end_index:
break
curr_predicted_row_idx += 1
self._create_trading_signals(pair, config, bt_result)
@ -105,7 +112,7 @@ class SlidingFit(PairsTradingFitMethod):
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(pair.predicted_df_)):
pred_row = pair.predicted_df_.iloc[curr_predicted_row_idx]
if pair.user_data_["state"] in [PairState.INITIAL, PairState.CLOSED]:
if pair.user_data_["state"] in [PairState.INITIAL, PairState.CLOSED, PairState.CLOSED_POSITIONS]:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
@ -130,18 +137,29 @@ class SlidingFit(PairsTradingFitMethod):
f"{pair}: *** Position is NOT CLOSED. ***"
)
# outstanding positions
# last_row_index = self.curr_training_start_idx_ + training_minutes
if pair.predicted_df_ is not None:
bt_result.handle_outstanding_position(
if config["close_outstanding_positions"]:
close_position_trades = self._get_close_position_trades(
pair=pair,
pair_result_df=pair.predicted_df_,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
row=pred_row,
close_threshold=close_threshold,
)
if close_position_trades is not None:
close_position_trades["status"] = "CLOSE_POSITION"
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSED_POSITIONS
else:
if pair.predicted_df_ is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=pair.predicted_df_,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
@ -284,5 +302,61 @@ class SlidingFit(PairsTradingFitMethod):
"pair": "object"
})
def _get_close_position_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
if len(pair.predicted_df_) == 0:
return None
close_position_row = row
close_position_tstamp = close_position_row["tstamp"]
close_position_disequilibrium = close_position_row["disequilibrium"]
close_position_scaled_disequilibrium = close_position_row["scaled_disequilibrium"]
close_position_px_a = close_position_row[f"{colname_a}"]
close_position_px_b = close_position_row[f"{colname_b}"]
close_position_side_a = pair.user_data_["close_side_a"]
close_position_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_position_tstamp,
close_position_side_a,
pair.symbol_a_,
close_position_px_a,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
(
close_position_tstamp,
close_position_side_b,
pair.symbol_b_,
close_position_px_b,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
self.curr_training_start_idx_ = 0
curr_training_start_idx = 0

View File

@ -14,8 +14,9 @@ NanoPerMin = 1e9
class StaticFit(PairsTradingFitMethod):
def run_pair(
self, config: Dict, pair: TradingPair, bt_result: BacktestResult
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]: # abstractmethod
config = pair.config_
pair.get_datasets(training_minutes=config["training_minutes"])
try:
is_cointegrated = pair.train_pair()

View File

@ -23,11 +23,17 @@ class TradingPair:
predicted_df_: Optional[pd.DataFrame]
def __init__(
self, market_data: pd.DataFrame, symbol_a: str, symbol_b: str, price_column: str
self, config: Dict[str, Any], market_data: pd.DataFrame, symbol_a: str, symbol_b: str, price_column: str
):
self.symbol_a_ = symbol_a
self.symbol_b_ = symbol_b
self.price_column_ = price_column
self.set_market_data(market_data)
self.user_data_ = {}
self.predicted_df_ = None
self.config_ = config
def set_market_data(self, market_data: pd.DataFrame) -> None:
self.market_data_ = pd.DataFrame(
self._transform_dataframe(market_data)[["tstamp"] + self.colnames()]
)
@ -36,9 +42,19 @@ class TradingPair:
self.market_data_['tstamp'] = pd.to_datetime(self.market_data_['tstamp'])
self.market_data_ = self.market_data_.sort_values('tstamp')
self.user_data_ = {}
self.predicted_df_ = None
def get_begin_index(self) -> int:
if "trading_hours" not in self.config_:
return 0
start_time = pd.to_datetime(self.config_["trading_hours"]["begin_session"]).time()
mask = self.market_data_['tstamp'].dt.time >= start_time
return int(self.market_data_.index[mask].min())
def get_end_index(self) -> int:
if "trading_hours" not in self.config_:
return 0
end_time = pd.to_datetime(self.config_["trading_hours"]["end_session"]).time()
mask = self.market_data_['tstamp'].dt.time <= end_time
return int(self.market_data_.index[mask].max())
def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
# Select only the columns we need

File diff suppressed because one or more lines are too long

View File

@ -84,6 +84,7 @@ def run_backtest(
for a_index, b_index in unique_index_pairs:
pair = TradingPair(
config=config_copy,
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
@ -95,7 +96,7 @@ def run_backtest(
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
pair=pair, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)