refactored code. before cleaning
This commit is contained in:
parent
1d73ce8070
commit
1b6b5e5735
@ -27,10 +27,9 @@ class ZScoreOLSModel(PairsTradingModel):
|
||||
|
||||
assert zscore_df is not None
|
||||
return Prediction(
|
||||
tstamp_=pair.market_data_.index[-1],
|
||||
disequilibrium_=self.training_df_["dis-equilibrium"].iloc[-1],
|
||||
scaled_disequilibrium_=self.training_df_["scaled_dis-equilibrium"].iloc[-1],
|
||||
pair_=pair,
|
||||
tstamp=pair.market_data_.iloc[-1]["tstamp"],
|
||||
disequilibrium=self.training_df_["dis-equilibrium"].iloc[-1],
|
||||
scaled_disequilibrium=self.training_df_["scaled_dis-equilibrium"].iloc[-1],
|
||||
)
|
||||
|
||||
def _fit_zscore(self, pair: TradingPair) -> pd.DataFrame:
|
||||
|
||||
@ -9,12 +9,15 @@ import pandas as pd
|
||||
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
|
||||
@dataclass
|
||||
class Prediction:
|
||||
tstamp_: pd.Timestamp
|
||||
disequilibrium_: float
|
||||
scaled_disequilibrium_: float
|
||||
pair_: TradingPair
|
||||
|
||||
def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float):
|
||||
self.tstamp_ = tstamp
|
||||
self.disequilibrium_ = disequilibrium
|
||||
self.scaled_disequilibrium_ = scaled_disequilibrium
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
@ -22,10 +25,10 @@ class Prediction:
|
||||
"disequilibrium": self.disequilibrium_,
|
||||
"signed_scaled_disequilibrium": self.scaled_disequilibrium_,
|
||||
"scaled_disequilibrium": abs(self.scaled_disequilibrium_),
|
||||
"pair": self.pair_,
|
||||
# "pair": self.pair_,
|
||||
}
|
||||
def to_pd_series(self) -> pd.Series:
|
||||
return pd.DataFrame([self.to_dict()]).iloc[0]
|
||||
def to_df(self) -> pd.DataFrame:
|
||||
return pd.DataFrame([self.to_dict()])
|
||||
|
||||
class PairsTradingModel(ABC):
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ class PtResearchStrategy:
|
||||
pt_mkt_data_: PtMarketData
|
||||
|
||||
trades_: List[pd.DataFrame]
|
||||
predictions_: pd.DataFrame
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -41,6 +42,7 @@ class PtResearchStrategy:
|
||||
self.trades_ = []
|
||||
self.trading_pair_ = TradingPair(config=config, instruments=instruments)
|
||||
self.model_data_policy_ = ModelDataPolicy.create(config)
|
||||
self.predictions_ = pd.DataFrame()
|
||||
|
||||
import copy
|
||||
|
||||
@ -83,6 +85,7 @@ class PtResearchStrategy:
|
||||
prediction = self.trading_pair_.run(
|
||||
market_data_df, self.model_data_policy_.advance()
|
||||
)
|
||||
self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True)
|
||||
assert prediction is not None
|
||||
|
||||
trades = self._create_trades(
|
||||
|
||||
79
lib/tools/viz/viz_prices.py
Normal file
79
lib/tools/viz/viz_prices.py
Normal file
@ -0,0 +1,79 @@
|
||||
from pt_strategy.trading_strategy import PtResearchStrategy
|
||||
|
||||
|
||||
def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None:
|
||||
# Plot raw price data
|
||||
import matplotlib.pyplot as plt
|
||||
# Set plotting style
|
||||
import seaborn as sns
|
||||
|
||||
pair = strategy.trading_pair_
|
||||
SYMBOL_A = pair.symbol_a_
|
||||
SYMBOL_B = pair.symbol_b_
|
||||
TRD_DATE = f"{trading_date[0:4]}-{trading_date[4:6]}-{trading_date[6:8]}"
|
||||
|
||||
plt.style.use('seaborn-v0_8')
|
||||
sns.set_palette("husl")
|
||||
plt.rcParams['figure.figsize'] = (15, 10)
|
||||
|
||||
# Get column names for the trading pair
|
||||
colname_a, colname_b = pair.colnames()
|
||||
price_data = strategy.pt_mkt_data_.market_data_df_.copy()
|
||||
|
||||
# Create separate subplots for better visibility
|
||||
fig_price, price_axes = plt.subplots(2, 1, figsize=(18, 10))
|
||||
|
||||
# Plot SYMBOL_A
|
||||
price_axes[0].plot(price_data['tstamp'], price_data[colname_a], alpha=0.7,
|
||||
label=f'{SYMBOL_A}', linewidth=1, color='blue')
|
||||
price_axes[0].set_title(f'{SYMBOL_A} Price Data ({TRD_DATE})')
|
||||
price_axes[0].set_ylabel(f'{SYMBOL_A} Price')
|
||||
price_axes[0].legend()
|
||||
price_axes[0].grid(True)
|
||||
|
||||
# Plot SYMBOL_B
|
||||
price_axes[1].plot(price_data['tstamp'], price_data[colname_b], alpha=0.7,
|
||||
label=f'{SYMBOL_B}', linewidth=1, color='red')
|
||||
price_axes[1].set_title(f'{SYMBOL_B} Price Data ({TRD_DATE})')
|
||||
price_axes[1].set_ylabel(f'{SYMBOL_B} Price')
|
||||
price_axes[1].set_xlabel('Time')
|
||||
price_axes[1].legend()
|
||||
price_axes[1].grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
# Plot individual prices
|
||||
fig, axes = plt.subplots(2, 1, figsize=(18, 12))
|
||||
|
||||
# Normalized prices for comparison
|
||||
norm_a = price_data[colname_a] / price_data[colname_a].iloc[0]
|
||||
norm_b = price_data[colname_b] / price_data[colname_b].iloc[0]
|
||||
|
||||
axes[0].plot(price_data['tstamp'], norm_a, label=f'{SYMBOL_A} (normalized)', alpha=0.8, linewidth=1)
|
||||
axes[0].plot(price_data['tstamp'], norm_b, label=f'{SYMBOL_B} (normalized)', alpha=0.8, linewidth=1)
|
||||
axes[0].set_title(f'Normalized Price Comparison (Base = 1.0) ({TRD_DATE})')
|
||||
axes[0].set_ylabel('Normalized Price')
|
||||
axes[0].legend()
|
||||
axes[0].grid(True)
|
||||
|
||||
# Price ratio
|
||||
price_ratio = price_data[colname_a] / price_data[colname_b]
|
||||
axes[1].plot(price_data['tstamp'], price_ratio, label=f'{SYMBOL_A}/{SYMBOL_B} Ratio', color='green', alpha=0.8, linewidth=1)
|
||||
axes[1].set_title(f'Price Ratio Px({SYMBOL_A})/Px({SYMBOL_B}) ({TRD_DATE})')
|
||||
axes[1].set_ylabel('Ratio')
|
||||
axes[1].set_xlabel('Time')
|
||||
axes[1].legend()
|
||||
axes[1].grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# Print basic statistics
|
||||
print(f"\nPrice Statistics:")
|
||||
print(f" {SYMBOL_A}: Mean=${price_data[colname_a].mean():.2f}, Std=${price_data[colname_a].std():.2f}")
|
||||
print(f" {SYMBOL_B}: Mean=${price_data[colname_b].mean():.2f}, Std=${price_data[colname_b].std():.2f}")
|
||||
print(f" Price Ratio: Mean={price_ratio.mean():.2f}, Std={price_ratio.std():.2f}")
|
||||
print(f" Correlation: {price_data[colname_a].corr(price_data[colname_b]):.4f}")
|
||||
|
||||
507
lib/tools/viz/viz_trades.py
Normal file
507
lib/tools/viz/viz_trades.py
Normal file
@ -0,0 +1,507 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pt_strategy.results import (PairResearchResult, create_result_database,
|
||||
store_config_in_database)
|
||||
from pt_strategy.trading_strategy import PtResearchStrategy
|
||||
from tools.filetools import resolve_datafiles
|
||||
from tools.instruments import get_instruments
|
||||
|
||||
|
||||
def visualize_trades(strategy: PtResearchStrategy, results: PairResearchResult, trading_date: str) -> None:
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
import plotly.offline as pyo
|
||||
from IPython.display import HTML
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
pair = strategy.trading_pair_
|
||||
trades = results.trades_[trading_date].copy()
|
||||
origin_mkt_data_df = strategy.pt_mkt_data_.origin_mkt_data_df_
|
||||
mkt_data_df = strategy.pt_mkt_data_.market_data_df_
|
||||
TRD_DATE = f"{trading_date[0:4]}-{trading_date[4:6]}-{trading_date[6:8]}"
|
||||
SYMBOL_A = pair.symbol_a_
|
||||
SYMBOL_B = pair.symbol_b_
|
||||
|
||||
|
||||
print(f"\nCreated trading pair: {pair}")
|
||||
print(f"Market data shape: {pair.market_data_.shape}")
|
||||
print(f"Column names: {pair.colnames()}")
|
||||
|
||||
# Configure plotly for offline mode
|
||||
pyo.init_notebook_mode(connected=True)
|
||||
|
||||
# Strategy-specific interactive visualization
|
||||
assert strategy.config_ is not None
|
||||
|
||||
print("=== SLIDING FIT INTERACTIVE VISUALIZATION ===")
|
||||
print("Note: Rolling Fit strategy visualization with interactive plotly charts")
|
||||
|
||||
|
||||
# Create consistent timeline - superset of timestamps from both dataframes
|
||||
all_timestamps = sorted(set(mkt_data_df['tstamp']))
|
||||
|
||||
|
||||
# Create a unified timeline dataframe for consistent plotting
|
||||
timeline_df = pd.DataFrame({'tstamp': all_timestamps})
|
||||
|
||||
# Merge with predicted data to get dis-equilibrium values
|
||||
timeline_df = timeline_df.merge(strategy.predictions_[['tstamp', 'disequilibrium', 'scaled_disequilibrium', 'signed_scaled_disequilibrium']],
|
||||
on='tstamp', how='left')
|
||||
|
||||
# Get Symbol_A and Symbol_B market data
|
||||
colname_a, colname_b = pair.colnames()
|
||||
symbol_a_data = mkt_data_df[['tstamp', colname_a]].copy()
|
||||
symbol_b_data = mkt_data_df[['tstamp', colname_b]].copy()
|
||||
|
||||
norm_a = symbol_a_data[colname_a] / symbol_a_data[colname_a].iloc[0]
|
||||
norm_b = symbol_b_data[colname_b] / symbol_b_data[colname_b].iloc[0]
|
||||
|
||||
print(f"Using consistent timeline with {len(timeline_df)} timestamps")
|
||||
print(f"Timeline range: {timeline_df['tstamp'].min()} to {timeline_df['tstamp'].max()}")
|
||||
|
||||
# Create subplots with price charts at bottom
|
||||
fig = make_subplots(
|
||||
rows=4, cols=1,
|
||||
row_heights=[0.3, 0.4, 0.15, 0.15],
|
||||
subplot_titles=[
|
||||
f'Dis-equilibrium with Trading Thresholds ({TRD_DATE})',
|
||||
f'Normalized Price Comparison with BUY/SELL Signals - {SYMBOL_A}&{SYMBOL_B} ({TRD_DATE})',
|
||||
f'{SYMBOL_A} Market Data with Trading Signals ({TRD_DATE})',
|
||||
f'{SYMBOL_B} Market Data with Trading Signals ({TRD_DATE})',
|
||||
],
|
||||
vertical_spacing=0.06,
|
||||
specs=[[{"secondary_y": False}],
|
||||
[{"secondary_y": False}],
|
||||
[{"secondary_y": False}],
|
||||
[{"secondary_y": False}]]
|
||||
)
|
||||
|
||||
# 1. Scaled dis-equilibrium with thresholds - using consistent timeline
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=timeline_df['tstamp'],
|
||||
y=timeline_df['scaled_disequilibrium'],
|
||||
name='Absolute Scaled Dis-equilibrium',
|
||||
line=dict(color='green', width=2),
|
||||
opacity=0.8
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=timeline_df['tstamp'],
|
||||
y=timeline_df['signed_scaled_disequilibrium'],
|
||||
name='Scaled Dis-equilibrium',
|
||||
line=dict(color='darkmagenta', width=2),
|
||||
opacity=0.8
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add threshold lines to first subplot
|
||||
fig.add_shape(
|
||||
type="line",
|
||||
x0=timeline_df['tstamp'].min(),
|
||||
x1=timeline_df['tstamp'].max(),
|
||||
y0=strategy.config_['dis-equilibrium_open_trshld'],
|
||||
y1=strategy.config_['dis-equilibrium_open_trshld'],
|
||||
line=dict(color="purple", width=2, dash="dot"),
|
||||
opacity=0.7,
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_shape(
|
||||
type="line",
|
||||
x0=timeline_df['tstamp'].min(),
|
||||
x1=timeline_df['tstamp'].max(),
|
||||
y0=-strategy.config_['dis-equilibrium_open_trshld'],
|
||||
y1=-strategy.config_['dis-equilibrium_open_trshld'],
|
||||
line=dict(color="purple", width=2, dash="dot"),
|
||||
opacity=0.7,
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_shape(
|
||||
type="line",
|
||||
x0=timeline_df['tstamp'].min(),
|
||||
x1=timeline_df['tstamp'].max(),
|
||||
y0=strategy.config_['dis-equilibrium_close_trshld'],
|
||||
y1=strategy.config_['dis-equilibrium_close_trshld'],
|
||||
line=dict(color="brown", width=2, dash="dot"),
|
||||
opacity=0.7,
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_shape(
|
||||
type="line",
|
||||
x0=timeline_df['tstamp'].min(),
|
||||
x1=timeline_df['tstamp'].max(),
|
||||
y0=-strategy.config_['dis-equilibrium_close_trshld'],
|
||||
y1=-strategy.config_['dis-equilibrium_close_trshld'],
|
||||
line=dict(color="brown", width=2, dash="dot"),
|
||||
opacity=0.7,
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_shape(
|
||||
type="line",
|
||||
x0=timeline_df['tstamp'].min(),
|
||||
x1=timeline_df['tstamp'].max(),
|
||||
y0=0,
|
||||
y1=0,
|
||||
line=dict(color="black", width=1, dash="solid"),
|
||||
opacity=0.5,
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add normalized price lines
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=mkt_data_df['tstamp'],
|
||||
y=norm_a,
|
||||
name=f'{SYMBOL_A} (Normalized)',
|
||||
line=dict(color='blue', width=2),
|
||||
opacity=0.8
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=mkt_data_df['tstamp'],
|
||||
y=norm_b,
|
||||
name=f'{SYMBOL_B} (Normalized)',
|
||||
line=dict(color='orange', width=2),
|
||||
opacity=0.8,
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Add BUY and SELL signals if available
|
||||
if trades is not None and len(trades) > 0:
|
||||
# Define signal groups to avoid legend repetition
|
||||
signal_groups = {}
|
||||
|
||||
# Process all trades and group by signal type (ignore OPEN/CLOSE status)
|
||||
for _, trade in trades.iterrows():
|
||||
symbol = trade['symbol']
|
||||
side = trade['side']
|
||||
# status = trade['status']
|
||||
action = trade['action']
|
||||
|
||||
# Create signal group key (without status to combine OPEN/CLOSE)
|
||||
signal_key = f"{symbol} {side} {action}"
|
||||
|
||||
# Find normalized price for this trade
|
||||
trade_time = trade['time']
|
||||
if symbol == SYMBOL_A:
|
||||
closest_idx = mkt_data_df['tstamp'].searchsorted(trade_time)
|
||||
if closest_idx < len(norm_a):
|
||||
norm_price = norm_a.iloc[closest_idx]
|
||||
else:
|
||||
norm_price = norm_a.iloc[-1]
|
||||
else: # SYMBOL_B
|
||||
closest_idx = mkt_data_df['tstamp'].searchsorted(trade_time)
|
||||
if closest_idx < len(norm_b):
|
||||
norm_price = norm_b.iloc[closest_idx]
|
||||
else:
|
||||
norm_price = norm_b.iloc[-1]
|
||||
|
||||
# Initialize group if not exists
|
||||
if signal_key not in signal_groups:
|
||||
signal_groups[signal_key] = {
|
||||
'times': [],
|
||||
'prices': [],
|
||||
'actual_prices': [],
|
||||
'symbol': symbol,
|
||||
'side': side,
|
||||
# 'status': status,
|
||||
'action': trade['action']
|
||||
}
|
||||
|
||||
# Add to group
|
||||
signal_groups[signal_key]['times'].append(trade_time)
|
||||
signal_groups[signal_key]['prices'].append(norm_price)
|
||||
signal_groups[signal_key]['actual_prices'].append(trade['price'])
|
||||
|
||||
# Add each signal group as a single trace
|
||||
for signal_key, group_data in signal_groups.items():
|
||||
symbol = group_data['symbol']
|
||||
side = group_data['side']
|
||||
# status = group_data['status']
|
||||
|
||||
# Determine marker properties (same for all OPEN/CLOSE of same side)
|
||||
is_close: bool = (group_data['action'] == "CLOSE")
|
||||
|
||||
if 'BUY' in side:
|
||||
marker_color = 'green'
|
||||
marker_symbol = 'triangle-up'
|
||||
marker_size = 14
|
||||
else: # SELL
|
||||
marker_color = 'red'
|
||||
marker_symbol = 'triangle-down'
|
||||
marker_size = 14
|
||||
|
||||
# Create hover text for each point in the group
|
||||
hover_texts = []
|
||||
for i, (time, norm_price, actual_price) in enumerate(zip(group_data['times'],
|
||||
group_data['prices'],
|
||||
group_data['actual_prices'])):
|
||||
# Find the corresponding trade to get the status for hover text
|
||||
trade_info = trades[(trades['time'] == time) &
|
||||
(trades['symbol'] == symbol) &
|
||||
(trades['side'] == side)]
|
||||
if len(trade_info) > 0:
|
||||
action = trade_info.iloc[0]['action']
|
||||
hover_texts.append(f'<b>{signal_key} {action}</b><br>' +
|
||||
f'Time: {time}<br>' +
|
||||
f'Normalized Price: {norm_price:.4f}<br>' +
|
||||
f'Actual Price: ${actual_price:.2f}')
|
||||
else:
|
||||
hover_texts.append(f'<b>{signal_key}</b><br>' +
|
||||
f'Time: {time}<br>' +
|
||||
f'Normalized Price: {norm_price:.4f}<br>' +
|
||||
f'Actual Price: ${actual_price:.2f}')
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=group_data['times'],
|
||||
y=group_data['prices'],
|
||||
mode='markers',
|
||||
name=signal_key,
|
||||
marker=dict(
|
||||
color=marker_color,
|
||||
size=marker_size,
|
||||
symbol=marker_symbol,
|
||||
line=dict(width=2, color='black') if is_close else None
|
||||
),
|
||||
showlegend=True,
|
||||
hovertemplate='%{text}<extra></extra>',
|
||||
text=hover_texts
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=symbol_a_data['tstamp'],
|
||||
y=symbol_a_data[colname_a],
|
||||
name=f'{SYMBOL_A} Price',
|
||||
line=dict(color='blue', width=2),
|
||||
opacity=0.8
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# Filter trades for Symbol_A
|
||||
symbol_a_trades = trades[trades['symbol'] == SYMBOL_A]
|
||||
print(f"\nSymbol_A trades:\n{symbol_a_trades}")
|
||||
|
||||
if len(symbol_a_trades) > 0:
|
||||
# Separate trades by action and status for different colors
|
||||
buy_open_trades = symbol_a_trades[(symbol_a_trades['side'].str.contains('BUY', na=False)) &
|
||||
(symbol_a_trades['action'].str.contains('OPEN', na=False))]
|
||||
buy_close_trades = symbol_a_trades[(symbol_a_trades['side'].str.contains('BUY', na=False)) &
|
||||
(symbol_a_trades['action'].str.contains('CLOSE', na=False))]
|
||||
|
||||
sell_open_trades = symbol_a_trades[(symbol_a_trades['side'].str.contains('SELL', na=False)) &
|
||||
(symbol_a_trades['action'].str.contains('OPEN', na=False))]
|
||||
sell_close_trades = symbol_a_trades[(symbol_a_trades['side'].str.contains('SELL', na=False)) &
|
||||
(symbol_a_trades['action'].str.contains('CLOSE', na=False))]
|
||||
|
||||
# Add BUY OPEN signals
|
||||
if len(buy_open_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_open_trades['time'],
|
||||
y=buy_open_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_A} BUY OPEN',
|
||||
marker=dict(color='green', size=12, symbol='triangle-up'),
|
||||
showlegend=True
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# Add BUY CLOSE signals
|
||||
if len(buy_close_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_close_trades['time'],
|
||||
y=buy_close_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_A} BUY CLOSE',
|
||||
marker=dict(color='green', size=12, symbol='triangle-up'),
|
||||
line=dict(width=2, color='black'),
|
||||
showlegend=True
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# Add SELL OPEN signals
|
||||
if len(sell_open_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_open_trades['time'],
|
||||
y=sell_open_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_A} SELL OPEN',
|
||||
marker=dict(color='red', size=12, symbol='triangle-down'),
|
||||
showlegend=True
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# Add SELL CLOSE signals
|
||||
if len(sell_close_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_close_trades['time'],
|
||||
y=sell_close_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_A} SELL CLOSE',
|
||||
marker=dict(color='red', size=12, symbol='triangle-down'),
|
||||
line=dict(width=2, color='black'),
|
||||
showlegend=True
|
||||
),
|
||||
row=3, col=1
|
||||
)
|
||||
|
||||
# 4. Symbol_B Market Data with Trading Signals
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=symbol_b_data['tstamp'],
|
||||
y=symbol_b_data[colname_b],
|
||||
name=f'{SYMBOL_B} Price',
|
||||
line=dict(color='orange', width=2),
|
||||
opacity=0.8
|
||||
),
|
||||
row=4, col=1
|
||||
)
|
||||
|
||||
# Add trading signals for Symbol_B if available
|
||||
symbol_b_trades = trades[trades['symbol'] == SYMBOL_B]
|
||||
print(f"\nSymbol_B trades:\n{symbol_b_trades}")
|
||||
|
||||
if len(symbol_b_trades) > 0:
|
||||
# Separate trades by action and status for different colors
|
||||
buy_open_trades = symbol_b_trades[(symbol_b_trades['side'].str.contains('BUY', na=False)) &
|
||||
(symbol_b_trades['action'].str.startswith('OPEN', na=False))]
|
||||
buy_close_trades = symbol_b_trades[(symbol_b_trades['side'].str.contains('BUY', na=False)) &
|
||||
(symbol_b_trades['action'].str.startswith('CLOSE', na=False))]
|
||||
|
||||
sell_open_trades = symbol_b_trades[(symbol_b_trades['side'].str.contains('SELL', na=False)) &
|
||||
(symbol_b_trades['action'].str.contains('OPEN', na=False))]
|
||||
sell_close_trades = symbol_b_trades[(symbol_b_trades['side'].str.contains('SELL', na=False)) &
|
||||
(symbol_b_trades['action'].str.contains('CLOSE', na=False))]
|
||||
|
||||
# Add BUY OPEN signals
|
||||
if len(buy_open_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_open_trades['time'],
|
||||
y=buy_open_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_B} BUY OPEN',
|
||||
marker=dict(color='darkgreen', size=12, symbol='triangle-up'),
|
||||
showlegend=True
|
||||
),
|
||||
row=4, col=1
|
||||
)
|
||||
|
||||
# Add BUY CLOSE signals
|
||||
if len(buy_close_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=buy_close_trades['time'],
|
||||
y=buy_close_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_B} BUY CLOSE',
|
||||
marker=dict(color='green', size=12, symbol='triangle-up'),
|
||||
line=dict(width=2, color='black'),
|
||||
showlegend=True
|
||||
),
|
||||
row=4, col=1
|
||||
)
|
||||
|
||||
# Add SELL OPEN signals
|
||||
if len(sell_open_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_open_trades['time'],
|
||||
y=sell_open_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_B} SELL OPEN',
|
||||
marker=dict(color='red', size=12, symbol='triangle-down'),
|
||||
showlegend=True
|
||||
),
|
||||
row=4, col=1
|
||||
)
|
||||
|
||||
# Add SELL CLOSE signals
|
||||
if len(sell_close_trades) > 0:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=sell_close_trades['time'],
|
||||
y=sell_close_trades['price'],
|
||||
mode='markers',
|
||||
name=f'{SYMBOL_B} SELL CLOSE',
|
||||
marker=dict(color='red', size=12, symbol='triangle-down'),
|
||||
line=dict(width=2, color='black'),
|
||||
showlegend=True
|
||||
),
|
||||
row=4, col=1
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
height=1600,
|
||||
title_text=f"Strategy Analysis - {SYMBOL_A} & {SYMBOL_B} ({TRD_DATE})",
|
||||
showlegend=True,
|
||||
template="plotly_white",
|
||||
plot_bgcolor='lightgray',
|
||||
)
|
||||
|
||||
# Update y-axis labels
|
||||
fig.update_yaxes(title_text="Scaled Dis-equilibrium", row=1, col=1)
|
||||
fig.update_yaxes(title_text=f"{SYMBOL_A} Price ($)", row=2, col=1)
|
||||
fig.update_yaxes(title_text=f"{SYMBOL_B} Price ($)", row=3, col=1)
|
||||
fig.update_yaxes(title_text="Normalized Price (Base = 1.0)", row=4, col=1)
|
||||
|
||||
# Update x-axis labels and ensure consistent time range
|
||||
time_range = [timeline_df['tstamp'].min(), timeline_df['tstamp'].max()]
|
||||
fig.update_xaxes(range=time_range, row=1, col=1)
|
||||
fig.update_xaxes(range=time_range, row=2, col=1)
|
||||
fig.update_xaxes(range=time_range, row=3, col=1)
|
||||
fig.update_xaxes(title_text="Time", range=time_range, row=4, col=1)
|
||||
|
||||
# Display using plotly offline mode
|
||||
# pyo.iplot(fig)
|
||||
fig.show()
|
||||
|
||||
else:
|
||||
print("No interactive visualization data available - strategy may not have run successfully")
|
||||
|
||||
print(f"\nChart shows:")
|
||||
print(f"- {SYMBOL_A} and {SYMBOL_B} prices normalized to start at 1.0")
|
||||
print(f"- BUY signals shown as green triangles pointing up")
|
||||
print(f"- SELL signals shown as orange triangles pointing down")
|
||||
print(f"- All BUY signals per symbol grouped together, all SELL signals per symbol grouped together")
|
||||
print(f"- Hover over markers to see individual trade details (OPEN/CLOSE status)")
|
||||
|
||||
if trades is not None and len(trades) > 0:
|
||||
print(f"- Total signals displayed: {len(trades)}")
|
||||
print(f"- {SYMBOL_A} signals: {len(trades[trades['symbol'] == SYMBOL_A])}")
|
||||
print(f"- {SYMBOL_B} signals: {len(trades[trades['symbol'] == SYMBOL_B])}")
|
||||
else:
|
||||
print("- No trading signals to display")
|
||||
|
||||
File diff suppressed because one or more lines are too long
111
research/viz_test.py
Normal file
111
research/viz_test.py
Normal file
@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pt_strategy.results import (PairResearchResult, create_result_database,
|
||||
store_config_in_database)
|
||||
from pt_strategy.trading_strategy import PtResearchStrategy
|
||||
from tools.filetools import resolve_datafiles
|
||||
from tools.instruments import get_instruments
|
||||
from tools.viz.viz_trades import visualize_trades
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
from tools.config import expand_filename, load_config
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, required=True, help="Path to the configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--date_pattern",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Date YYYYMMDD, allows * and ? wildcards",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruments",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result_db",
|
||||
type=str,
|
||||
required=False,
|
||||
default="NONE",
|
||||
help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config: Dict = load_config(args.config)
|
||||
|
||||
# Resolve data files (CLI takes priority over config)
|
||||
instruments = get_instruments(args, config)
|
||||
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
|
||||
|
||||
days = list(set([day for day, _ in datafiles]))
|
||||
print(f"Found {len(datafiles)} data files to process:")
|
||||
for df in datafiles:
|
||||
print(f" - {df}")
|
||||
|
||||
# Create result database if needed
|
||||
if args.result_db.upper() != "NONE":
|
||||
args.result_db = expand_filename(args.result_db)
|
||||
create_result_database(args.result_db)
|
||||
|
||||
# Initialize a dictionary to store all trade results
|
||||
all_results: Dict[str, Dict[str, Any]] = {}
|
||||
is_config_stored = False
|
||||
# Process each data file
|
||||
|
||||
results = PairResearchResult(config=config)
|
||||
for day in sorted(days):
|
||||
md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
|
||||
if not all([os.path.exists(datafile) for datafile in md_datafiles]):
|
||||
print(f"WARNING: insufficient data files: {md_datafiles}")
|
||||
continue
|
||||
print(f"\n====== Processing {day} ======")
|
||||
|
||||
if not is_config_stored:
|
||||
store_config_in_database(
|
||||
db_path=args.result_db,
|
||||
config_file_path=args.config,
|
||||
config=config,
|
||||
datafiles=datafiles,
|
||||
instruments=instruments,
|
||||
)
|
||||
is_config_stored = True
|
||||
|
||||
pt_strategy = PtResearchStrategy(
|
||||
config=config, datafiles=md_datafiles, instruments=instruments
|
||||
)
|
||||
pt_strategy.run()
|
||||
results.add_day_results(
|
||||
day=day,
|
||||
trades=pt_strategy.day_trades(),
|
||||
outstanding_positions=pt_strategy.outstanding_positions(),
|
||||
)
|
||||
|
||||
|
||||
results.analyze_pair_performance()
|
||||
|
||||
|
||||
visualize_trades(pt_strategy, results, day)
|
||||
|
||||
|
||||
if args.result_db.upper() != "NONE":
|
||||
print(f"\nResults stored in database: {args.result_db}")
|
||||
else:
|
||||
print("No results to display.")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user