diff --git a/lib/pt_trading/trading_pair.py b/lib/pt_trading/trading_pair.py index 5b6da8c..8ba825d 100644 --- a/lib/pt_trading/trading_pair.py +++ b/lib/pt_trading/trading_pair.py @@ -376,16 +376,16 @@ class TradingPair: open_trades = self.user_data_["open_trades"] if len(open_trades) == 0: return 0.0 - def _stock_return(stock: str) -> float: - stock_open_trades = open_trades[open_trades["symbol"] == stock] - stock_sign = -1 if stock_open_trades["action"].iloc[0] == "SELL" else 1 - stock_price = predicted_row[f"{self.price_column_}_{stock}"] - stock_return = stock_sign * (stock_price - stock_open_trades["price"].iloc[0]) / stock_open_trades["price"].iloc[0] - return float(stock_return) + def _single_instrument_return(symbol: str) -> float: + instrument_open_trades = open_trades[open_trades["symbol"] == symbol] + instrument_sign = -1 if instrument_open_trades["action"].iloc[0] == "SELL" else 1 + instrument_price = predicted_row[f"{self.price_column_}_{symbol}"] + instrument_return = instrument_sign * (instrument_price - instrument_open_trades["price"].iloc[0]) / instrument_open_trades["price"].iloc[0] + return float(instrument_return) - stock_a_return = _stock_return(self.symbol_a_) - stock_b_return = _stock_return(self.symbol_b_) - return (stock_a_return + stock_b_return) * 100.0 + instrument_a_return = _single_instrument_return(self.symbol_a_) + instrument_b_return = _single_instrument_return(self.symbol_b_) + return (instrument_a_return + instrument_b_return) * 100.0 return 0.0 def __repr__(self) -> str: