from __future__ import annotations

import importlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List

# Training size used when the config has no "training_size" entry.
_DEFAULT_TRAINING_SIZE = 120


@dataclass
class DataParams:
    """Training-window parameters handed out by a data policy.

    Holds a window size and a window start index. The units (rows, bars,
    days, ...) are not fixed here; they are whatever the consumer of the
    policy indexes by.
    """

    training_size: int
    training_start_index: int


class ModelDataPolicy(ABC):
    """Base class for policies that step a training window forward.

    A policy owns a single ``DataParams`` instance, initialised from the
    config, and mutates it on every ``advance()`` call.
    """

    config_: Dict[str, Any]
    current_data_params_: DataParams

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and set up the initial window.

        Args:
            config: Policy configuration. ``"training_size"`` is read for the
                initial window size (default 120); the start index is 0.
        """
        self.config_ = config
        self.current_data_params_ = DataParams(
            training_size=config.get("training_size", _DEFAULT_TRAINING_SIZE),
            training_start_index=0,
        )

    @abstractmethod
    def advance(self) -> DataParams:
        """Move the window one step and return the updated parameters."""
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> ModelDataPolicy:
        """Instantiate the policy class named in ``config``.

        Args:
            config: Must contain ``"model_data_policy_class"``, a dotted path
                of the form ``"package.module.ClassName"``. The whole config
                is passed to the class as the ``config`` keyword argument.

        Returns:
            The newly constructed policy object.

        Raises:
            AssertionError: If ``"model_data_policy_class"`` is missing or
                ``None``.
            ValueError: If the path has no module part.
            ImportError: If the module cannot be imported.
            AttributeError: If the module has no attribute with that name.
        """
        class_path = config.get("model_data_policy_class")
        if class_path is None:
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept so
            # existing callers observe the same error as before.
            raise AssertionError(
                "config is missing required key 'model_data_policy_class'"
            )

        module_name, _, class_name = class_path.rpartition(".")
        if not module_name:
            raise ValueError(
                "model_data_policy_class must be a dotted path like "
                f"'package.module.ClassName', got {class_path!r}"
            )

        module = importlib.import_module(module_name)
        policy_class = getattr(module, class_name)
        # The cast is for type checkers only; no runtime type check is done.
        return cast(ModelDataPolicy, policy_class(config=config))


class RollingWindowDataPolicy(ModelDataPolicy):
    """Policy that moves the window start forward, keeping the size fixed."""

    def __init__(self, config: Dict[str, Any]):
        """Initialise the window and start the step counter at 1."""
        super().__init__(config)
        self.count_ = 1

    def advance(self) -> DataParams:
        """Shift the window start by one and return the shared parameters.

        Also writes the current step counter to stdout terminated by a
        carriage return (so successive calls overwrite the same terminal
        line), then increments the counter.

        Returns:
            The policy's own ``DataParams`` object, mutated in place; the
            same instance is returned on every call.
        """
        self.current_data_params_.training_start_index += 1
        print(self.count_, end='\r')
        self.count_ += 1
        return self.current_data_params_


class ExpandingWindowDataPolicy(ModelDataPolicy):
    """Policy that grows the window by one, keeping the start index fixed."""

    def advance(self) -> DataParams:
        """Increase the training size by one and return the shared parameters.

        Returns:
            The policy's own ``DataParams`` object, mutated in place; the
            same instance is returned on every call.
        """
        self.current_data_params_.training_size += 1
        return self.current_data_params_