61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, Dict, Optional, cast, Generator, List
|
|
|
|
|
|
@dataclass
class DataParams:
    """Mutable description of a training window: its size and where it starts."""

    # Number of items in the training window (units not shown here — TODO confirm).
    training_size: int
    # Index at which the training window begins.
    training_start_index: int
|
|
|
|
class ModelDataPolicy(ABC):
    """Abstract base for policies that decide the training-data window.

    Subclasses implement :meth:`advance`, which updates and returns
    :attr:`current_data_params_`. Use :meth:`create` to build a concrete
    policy from a configuration dict.
    """

    # The configuration dict this policy was constructed with.
    config_: Dict[str, Any]
    # Current training-window parameters (initialised in ``__init__``).
    current_data_params_: DataParams

    def __init__(self, config: Dict[str, Any]):
        """Store *config* and initialise the training window.

        Args:
            config: Policy configuration. ``"training_size"`` (default 120)
                sets the initial window size; the window starts at index 0.
        """
        self.config_ = config
        self.current_data_params_ = DataParams(
            training_size=config.get("training_size", 120),
            training_start_index=0,
        )

    @abstractmethod
    def advance(self) -> DataParams:
        """Advance the policy by one step and return the resulting data parameters."""
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> ModelDataPolicy:
        """Instantiate the policy class named in *config*.

        Args:
            config: Must contain ``"model_data_policy_class"``, a dotted
                ``"module.ClassName"`` path. The whole dict is passed to
                that class as ``config=``.

        Returns:
            The constructed policy instance.

        Raises:
            ValueError: If the class path is missing or is not of the form
                ``"module.ClassName"``.
            ImportError: If the module cannot be imported.
            AttributeError: If the module has no attribute named ``ClassName``.
        """
        import importlib

        class_path = config.get("model_data_policy_class")
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if not class_path:
            raise ValueError("config is missing 'model_data_policy_class'")

        module_name, sep, class_name = class_path.rpartition(".")
        if not sep or not module_name or not class_name:
            raise ValueError(
                "'model_data_policy_class' must be a dotted 'module.ClassName' "
                f"path, got {class_path!r}"
            )

        module = importlib.import_module(module_name)
        policy_class = getattr(module, class_name)
        return cast(ModelDataPolicy, policy_class(config=config))
|
|
|
|
class RollingWindowDataPolicy(ModelDataPolicy):
    """Data policy that slides the training window forward one index per step.

    Each :meth:`advance` increments ``training_start_index`` and leaves
    ``training_size`` unchanged. By default the step counter is written to
    stdout as an in-place progress indicator; set ``config["show_progress"]``
    to a falsy value to silence it.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise the base window and the step counter.

        Args:
            config: Policy configuration (see ``ModelDataPolicy``).
                ``"show_progress"`` (default ``True``) controls whether
                ``advance`` prints the step counter.
        """
        super().__init__(config)
        # 1-based number of the next advance; printed as a progress indicator.
        self.count_ = 1
        self.show_progress_ = bool(config.get("show_progress", True))

    def advance(self) -> DataParams:
        """Shift the window start forward by one and return the params object.

        The returned object is ``self.current_data_params_`` itself, mutated
        in place.
        """
        self.current_data_params_.training_start_index += 1
        if self.show_progress_:
            # '\r' does not trigger a line-buffered flush; flush explicitly so
            # the counter actually updates instead of appearing in bursts.
            print(self.count_, end="\r", flush=True)
        self.count_ += 1
        return self.current_data_params_
|
|
|
|
|
|
class ExpandingWindowDataPolicy(ModelDataPolicy):
    """Data policy whose training window grows by one per step.

    The start index is never changed here; only ``training_size`` increases.
    """

    def __init__(self, config: Dict[str, Any]):
        # No extra state beyond what the base class sets up.
        super().__init__(config)

    def advance(self) -> DataParams:
        """Grow the window by one and return the (same, mutated) params object."""
        params = self.current_data_params_
        params.training_size += 1
        return params
|
|
|
|
|