2025-06-05 08:48:33 +02:00

296 lines
9.4 KiB
Python

import pandas as pd
import numpy as np
import json
from pathlib import Path
from typing import Dict, Any, Union, Optional
class DataLoader:
    """
    Utility class for loading converted data files (CSV and JSON).

    File names are resolved relative to ``data_dir``. If that directory
    contains a ``conversion_mapping.json`` file, it is read once at
    construction time and used to translate original ``.mat`` file names
    into the names of their converted CSV/JSON counterparts.
    """

    # Price column prefixes extracted by get_timeseries_data when the caller
    # does not supply ``price_cols``.
    _DEFAULT_PRICE_COLS = ('cl', 'op', 'hi', 'lo', 'vol')

    # Additional "<prefix>_<index>" column families that get_timeseries_data
    # always collects when present.
    _INFO_PREFIXES = ('contracts', 'syms', 'stocks')

    def __init__(self, data_dir: str = "data"):
        """
        Initialize the data loader.

        Parameters:
        -----------
        data_dir : str
            Directory containing the converted data files
        """
        self.data_dir = Path(data_dir)
        self.mapping_file = self.data_dir / "conversion_mapping.json"
        self.mapping = self._load_mapping()

    def _load_mapping(self) -> Dict[str, str]:
        """Load the conversion mapping file, or return {} if it is absent."""
        if self.mapping_file.exists():
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(self.mapping_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def load_csv(self, filename: str) -> pd.DataFrame:
        """
        Load a CSV file.

        Parameters:
        -----------
        filename : str
            Name of the CSV file to load (relative to ``data_dir``)

        Returns:
        --------
        pd.DataFrame
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist
        """
        filepath = self.data_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")
        return pd.read_csv(filepath)

    def load_json(self, filename: str) -> Dict[str, Any]:
        """
        Load a JSON file.

        Parameters:
        -----------
        filename : str
            Name of the JSON file to load (relative to ``data_dir``)

        Returns:
        --------
        dict
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist
        """
        filepath = self.data_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self, original_mat_filename: str) -> Union[pd.DataFrame, Dict[str, Any]]:
        """
        Load data using the original .mat filename.

        Parameters:
        -----------
        original_mat_filename : str
            Original .mat filename (e.g., 'earnannFile.mat')

        Returns:
        --------
        Union[pd.DataFrame, dict]
            Loaded data (DataFrame for CSV, dict for JSON)

        Raises:
        -------
        ValueError
            If the name is not in the conversion mapping, or the mapped file
            is neither ``.csv`` nor ``.json``
        FileNotFoundError
            If the mapped file does not exist
        """
        if original_mat_filename not in self.mapping:
            raise ValueError(f"No conversion found for {original_mat_filename}")
        converted_filename = self.mapping[original_mat_filename]
        if converted_filename.endswith('.csv'):
            return self.load_csv(converted_filename)
        if converted_filename.endswith('.json'):
            return self.load_json(converted_filename)
        raise ValueError(f"Unsupported file format: {converted_filename}")

    @staticmethod
    def _column_index(column: str) -> int:
        """Return the integer in the second '_'-separated field of *column*, else 0."""
        parts = column.split('_')
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if len(parts) > 1 and parts[1].isdecimal():
            return int(parts[1])
        return 0

    @classmethod
    def _indexed_columns(cls, columns, prefix: str) -> list:
        """Return the columns named '<prefix>_...' ordered by their numeric index."""
        marker = f'{prefix}_'
        matching = [
            col for col in columns
            if isinstance(col, str) and col.startswith(marker)
        ]
        # sorted() is stable, so columns with equal index keep their file order.
        return sorted(matching, key=cls._column_index)

    def get_timeseries_data(self, filename: str,
                            time_col: str = 'tday',
                            price_cols: Optional[list] = None) -> Dict[str, np.ndarray]:
        """
        Load time series data and extract common fields.

        Parameters:
        -----------
        filename : str
            Original .mat filename or converted filename. Names ending in
            '.mat' go through the conversion mapping, names ending in '.csv'
            are read as CSV, anything else is read as JSON.
        time_col : str
            Name of the time column
        price_cols : list, optional
            List of price column prefixes to extract (e.g., ['cl', 'op', 'hi', 'lo']).
            Defaults to ['cl', 'op', 'hi', 'lo', 'vol'].

        Returns:
        --------
        dict
            Dictionary with extracted time series data. The time column is
            stored under the key 'tday' regardless of ``time_col``. For each
            prefix, columns named '<prefix>_<n>' are stacked (ordered by n)
            into a 2-D array; otherwise a column named exactly '<prefix>' is
            returned as a 1-D array. Fields with no matching column are omitted.

        Raises:
        -------
        ValueError
            If the loaded data cannot be represented as a table
        """
        # Load the data
        if filename.endswith('.mat'):
            data = self.load_data(filename)
        elif filename.endswith('.csv'):
            data = self.load_csv(filename)
        else:
            data = self.load_json(filename)

        # Convert JSON to DataFrame if it contains arrays
        if isinstance(data, dict) and all(isinstance(v, list) for v in data.values()):
            data = pd.DataFrame(data)
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data is not in tabular format")

        result: Dict[str, np.ndarray] = {}

        # Extract time data
        if time_col in data.columns:
            result['tday'] = data[time_col].values

        # Extract price data
        if price_cols is None:
            price_cols = self._DEFAULT_PRICE_COLS
        for price_type in price_cols:
            matching_cols = self._indexed_columns(data.columns, price_type)
            if matching_cols:
                result[price_type] = data[matching_cols].values
            elif price_type in data.columns:
                # Single column case
                result[price_type] = data[price_type].values

        # Extract contract / symbol / stock information if available
        for prefix in self._INFO_PREFIXES:
            matching_cols = self._indexed_columns(data.columns, prefix)
            if matching_cols:
                result[prefix] = data[matching_cols].values

        return result

    def get_earnings_data(self) -> Dict[str, np.ndarray]:
        """
        Load earnings announcement data.

        Returns:
        --------
        dict
            Dictionary with earnings data
        """
        return self.get_timeseries_data('earnannFile.mat')

    def get_interest_rate_data(self, currency: str = 'AUD') -> Dict[str, Any]:
        """
        Load interest rate data.

        Parameters:
        -----------
        currency : str
            Currency code ('AUD' or 'CAD')

        Returns:
        --------
        dict
            Interest rate data
        """
        filename = f"{currency}_interestRate.json"
        return self.load_json(filename)

    def get_futures_data(self, symbol: str, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load futures data for a specific symbol.

        Parameters:
        -----------
        symbol : str
            Futures symbol (e.g., 'TU', 'CL', 'VX', 'HO2', 'C2', 'HG', 'BR')
        date : str, optional
            Date string (e.g., '20120813'). When omitted, the matching file
            whose name sorts first is used.

        Returns:
        --------
        dict
            Futures data

        Raises:
        -------
        FileNotFoundError
            If no file exists for the symbol (or for the symbol/date pair)
        """
        if date:
            filename = f"inputDataDaily_{symbol}_{date}.csv"
        else:
            # glob() yields entries in filesystem order, which is not stable
            # across platforms; sort so the same file is chosen every time.
            candidates = sorted(
                path.name
                for path in self.data_dir.glob(f"inputDataDaily_{symbol}_*.csv")
            )
            if not candidates:
                raise FileNotFoundError(f"No data files found for symbol {symbol}")
            filename = candidates[0]
        return self.get_timeseries_data(filename)

    def get_etf_data(self) -> Dict[str, np.ndarray]:
        """
        Load ETF data.

        Returns:
        --------
        dict
            ETF data
        """
        return self.get_timeseries_data('inputData_ETF.csv')

    def get_stock_data(self, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load stock OHLC data.

        Parameters:
        -----------
        date : str, optional
            Date string (e.g., '20120424'). Defaults to '20120424'.

        Returns:
        --------
        dict
            Stock data
        """
        if date:
            filename = f"inputDataOHLCDaily_stocks_{date}.csv"
        else:
            filename = "inputDataOHLCDaily_stocks_20120424.csv"
        return self.get_timeseries_data(filename)

    def list_available_files(self) -> Dict[str, str]:
        """
        List all available converted files.

        Returns:
        --------
        dict
            Mapping of original .mat files to converted files (a copy, so
            callers cannot mutate the loader's own mapping)
        """
        return self.mapping.copy()
# Create a default instance for easy importing.
# NOTE: this runs at import time. DataLoader() uses the relative "data"
# directory and reads data/conversion_mapping.json if that file exists, so
# importing this module touches the filesystem.
default_loader = DataLoader()
# Convenience functions
def load_earnings_data():
    """Return the earnings announcement data from the module-level default loader."""
    loader = default_loader
    return loader.get_earnings_data()
def load_futures_data(symbol: str, date: str = None):
    """
    Return futures data for ``symbol`` from the module-level default loader.

    ``date`` is forwarded unchanged; when it is omitted the loader chooses a
    matching file for the symbol itself.
    """
    loader = default_loader
    return loader.get_futures_data(symbol, date=date)
def load_etf_data():
    """Return the ETF data from the module-level default loader."""
    loader = default_loader
    return loader.get_etf_data()
def load_stock_data(date: str = None):
    """
    Return stock OHLC data from the module-level default loader.

    ``date`` is forwarded unchanged to the loader.
    """
    loader = default_loader
    return loader.get_stock_data(date=date)
def load_interest_rates(currency: str = 'AUD'):
    """Return interest rate data for ``currency`` from the module-level default loader."""
    loader = default_loader
    return loader.get_interest_rate_data(currency=currency)