pairs_trading/lib/pt_strategy/model_data_policy.py
2025-07-30 04:08:02 +00:00

61 lines
1.8 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List
@dataclass
class DataParams:
    """Mutable description of the current model-training window.

    Data policies adjust these fields in place each time they advance.
    """

    # Length of the training window.
    training_size: int
    # Offset at which the training window begins.
    training_start_index: int
class ModelDataPolicy(ABC):
    """Base class for policies that decide which data window a model trains on.

    A policy owns a mutable ``DataParams`` describing the current training
    window and moves it forward one step per call to :meth:`advance`.
    Concrete policies are instantiated from configuration via :meth:`create`.
    """

    config_: Dict[str, Any]
    current_data_params_: DataParams

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and initialise the training window.

        Args:
            config: Strategy configuration. ``"training_size"`` sets the
                initial window length; it defaults to 120 when absent.
        """
        self.config_ = config
        # Every policy starts its window at index 0; subclasses move it.
        self.current_data_params_ = DataParams(
            training_size=config.get("training_size", 120),
            training_start_index=0,
        )

    @abstractmethod
    def advance(self) -> DataParams:
        """Move the training window one step and return the updated params."""
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> ModelDataPolicy:
        """Instantiate the policy class named in ``config``.

        ``config["model_data_policy_class"]`` must be a fully qualified dotted
        path such as ``"package.module.ClassName"``. The named class is
        imported and called as ``ClassName(config=config)``.

        Args:
            config: Strategy configuration; also passed to the new policy.

        Returns:
            The constructed policy object.

        Raises:
            ValueError: if the key is missing/empty, or the value is not a
                ``module.ClassName`` dotted path.
            ImportError: if the module cannot be imported.
            AttributeError: if the module has no attribute of that name.
        """
        import importlib

        class_path = config.get("model_data_policy_class")
        # Explicit check rather than ``assert``: asserts are stripped under -O.
        if not class_path:
            raise ValueError("config is missing 'model_data_policy_class'")

        module_name, sep, class_name = class_path.rpartition(".")
        # Reject paths without a module part (or with an empty class part) up
        # front instead of failing later with an opaque unpacking/import error.
        if not sep or not module_name or not class_name:
            raise ValueError(
                "'model_data_policy_class' must be a dotted path "
                f"'module.ClassName', got {class_path!r}"
            )

        module = importlib.import_module(module_name)
        policy_class = getattr(module, class_name)
        return cast(ModelDataPolicy, policy_class(config=config))
class RollingWindowDataPolicy(ModelDataPolicy):
    """Slide a fixed-length training window forward one position per step.

    ``training_size`` is never changed by this policy; ``training_start_index``
    grows by one on every :meth:`advance`. The step number is echoed to
    stdout as an in-place (carriage-return) progress counter.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # 1-based number of the next advance() call; only used for the
        # progress counter printed to stdout.
        self.count_ = 1

    def advance(self) -> DataParams:
        """Shift the window start by one and return the updated params.

        The same ``DataParams`` instance is mutated and returned on every call.
        """
        self.current_data_params_.training_start_index += 1
        # '\r' keeps the counter on a single terminal line. Flush explicitly:
        # interactive stdout is line-buffered, so without a newline the text
        # would otherwise sit in the buffer instead of updating on screen.
        print(self.count_, end="\r", flush=True)
        self.count_ += 1
        return self.current_data_params_
class ExpandingWindowDataPolicy(ModelDataPolicy):
    """Grow the training window by one position per step.

    ``training_start_index`` is left unchanged by this policy, while
    ``training_size`` increases by one on every :meth:`advance`.
    """

    # No constructor override: ModelDataPolicy.__init__(config) does all setup.

    def advance(self) -> DataParams:
        """Lengthen the window by one and return the updated params.

        The same ``DataParams`` instance is mutated and returned on every call.
        """
        params = self.current_data_params_
        params.training_size += 1
        return params