"""Helpers for loading converted data files (CSV and JSON)."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd


class DataLoader:
    """
    Utility class for loading converted data files (CSV and JSON).

    On construction the loader reads ``conversion_mapping.json`` from the
    data directory (when present). ``load_data`` uses that mapping to
    translate an original ``.mat`` filename into the converted file's name.
    """

    def __init__(self, data_dir: str = "data"):
        """
        Initialize the data loader.

        Parameters:
        -----------
        data_dir : str
            Directory containing the converted data files
        """
        self.data_dir = Path(data_dir)
        self.mapping_file = self.data_dir / "conversion_mapping.json"
        self.mapping = self._load_mapping()

    def _load_mapping(self) -> Dict[str, str]:
        """Load the conversion mapping file, or return {} if it is absent."""
        if not self.mapping_file.exists():
            return {}
        # Explicit UTF-8 so the result does not depend on the platform's
        # default text encoding.
        with open(self.mapping_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _existing_path(self, filename: str) -> Path:
        """Return ``data_dir / filename``; raise FileNotFoundError if missing."""
        filepath = self.data_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")
        return filepath

    @staticmethod
    def _indexed_columns(columns, prefix: str) -> List[str]:
        """
        Return the columns named ``<prefix>_<index>...`` ordered by index.

        The index is the token directly after ``<prefix>_``. Parsing it
        relative to the prefix (rather than taking the second '_'-separated
        token of the whole name) keeps the ordering numeric for prefixes that
        themselves contain an underscore. Columns whose index token is not a
        decimal number get sort key 0; the sort is stable, so such columns
        keep their original relative order.
        """
        marker = f'{prefix}_'

        def index_of(col: str) -> int:
            token = col[len(marker):].split('_')[0]
            # isdecimal() (unlike isdigit()) only accepts text int() can parse.
            return int(token) if token.isdecimal() else 0

        return sorted((col for col in columns if col.startswith(marker)),
                      key=index_of)

    def load_csv(self, filename: str) -> pd.DataFrame:
        """
        Load a CSV file.

        Parameters:
        -----------
        filename : str
            Name of the CSV file to load (relative to the data directory)

        Returns:
        --------
        pd.DataFrame
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist in the data directory
        """
        return pd.read_csv(self._existing_path(filename))

    def load_json(self, filename: str) -> Dict[str, Any]:
        """
        Load a JSON file.

        Parameters:
        -----------
        filename : str
            Name of the JSON file to load (relative to the data directory)

        Returns:
        --------
        dict
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist in the data directory
        """
        with open(self._existing_path(filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self, original_mat_filename: str) -> Union[pd.DataFrame, Dict[str, Any]]:
        """
        Load data using the original .mat filename.

        Parameters:
        -----------
        original_mat_filename : str
            Original .mat filename (e.g., 'earnannFile.mat')

        Returns:
        --------
        Union[pd.DataFrame, dict]
            Loaded data (DataFrame for CSV, dict for JSON)

        Raises:
        -------
        ValueError
            If the filename is not in the conversion mapping, or the mapped
            file is neither .csv nor .json
        FileNotFoundError
            If the mapped file does not exist
        """
        if original_mat_filename not in self.mapping:
            raise ValueError(f"No conversion found for {original_mat_filename}")

        converted_filename = self.mapping[original_mat_filename]
        if converted_filename.endswith('.csv'):
            return self.load_csv(converted_filename)
        if converted_filename.endswith('.json'):
            return self.load_json(converted_filename)
        raise ValueError(f"Unsupported file format: {converted_filename}")

    def get_timeseries_data(self, filename: str,
                            time_col: str = 'tday',
                            price_cols: Optional[list] = None) -> Dict[str, np.ndarray]:
        """
        Load time series data and extract common fields.

        Parameters:
        -----------
        filename : str
            Original .mat filename or converted filename. Names ending in
            '.mat' go through the conversion mapping, names ending in '.csv'
            are read as CSV, anything else is read as JSON.
        time_col : str
            Name of the time column
        price_cols : list, optional
            List of price column prefixes to extract
            (defaults to ['cl', 'op', 'hi', 'lo', 'vol'])

        Returns:
        --------
        dict
            Dictionary with extracted time series data. The time column is
            always stored under the key 'tday'. Each price prefix maps to a
            2-D array built from its '<prefix>_<index>' columns (ordered by
            index), or to a 1-D array when only a single '<prefix>' column
            exists. 'contracts', 'syms' and 'stocks' are added when columns
            with those prefixes are present.

        Raises:
        -------
        ValueError
            If the loaded data cannot be treated as a table
        FileNotFoundError
            If the underlying file does not exist
        """
        # Load the data
        if filename.endswith('.mat'):
            data = self.load_data(filename)
        elif filename.endswith('.csv'):
            data = self.load_csv(filename)
        else:
            data = self.load_json(filename)

        # Convert JSON to DataFrame if it contains arrays
        if isinstance(data, dict) and all(isinstance(v, list) for v in data.values()):
            data = pd.DataFrame(data)

        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data is not in tabular format")

        result = {}

        # Extract time data
        if time_col in data.columns:
            result['tday'] = data[time_col].values

        # Extract price data
        if price_cols is None:
            price_cols = ['cl', 'op', 'hi', 'lo', 'vol']

        for price_type in price_cols:
            matching_cols = self._indexed_columns(data.columns, price_type)
            if matching_cols:
                result[price_type] = data[matching_cols].values
            elif price_type in data.columns:
                # Single column case
                result[price_type] = data[price_type].values

        # Extract contract / symbol / stock information if available
        for group in ('contracts', 'syms', 'stocks'):
            group_cols = self._indexed_columns(data.columns, group)
            if group_cols:
                result[group] = data[group_cols].values

        return result

    def get_earnings_data(self) -> Dict[str, np.ndarray]:
        """
        Load earnings announcement data.

        Returns:
        --------
        dict
            Dictionary with earnings data, as produced by
            ``get_timeseries_data('earnannFile.mat')``
        """
        return self.get_timeseries_data('earnannFile.mat')

    def get_interest_rate_data(self, currency: str = 'AUD') -> Dict[str, Any]:
        """
        Load interest rate data.

        Parameters:
        -----------
        currency : str
            Currency code ('AUD' or 'CAD'); the file read is
            '<currency>_interestRate.json'

        Returns:
        --------
        dict
            Interest rate data
        """
        filename = f"{currency}_interestRate.json"
        return self.load_json(filename)

    def get_futures_data(self, symbol: str, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load futures data for a specific symbol.

        Parameters:
        -----------
        symbol : str
            Futures symbol (e.g., 'TU', 'CL', 'VX', 'HO2', 'C2', 'HG', 'BR')
        date : str, optional
            Date string (e.g., '20120813'). When omitted, the matching file
            whose path sorts first is used.

        Returns:
        --------
        dict
            Futures data

        Raises:
        -------
        FileNotFoundError
            If no file exists for the symbol (or for the symbol/date pair)
        """
        if date:
            filename = f"inputDataDaily_{symbol}_{date}.csv"
        else:
            # glob() yields entries in an unspecified order; sort so the same
            # file is chosen on every run and every platform.
            possible_files = sorted(self.data_dir.glob(f"inputDataDaily_{symbol}_*.csv"))
            if not possible_files:
                raise FileNotFoundError(f"No data files found for symbol {symbol}")
            filename = possible_files[0].name

        return self.get_timeseries_data(filename)

    def get_etf_data(self) -> Dict[str, np.ndarray]:
        """
        Load ETF data.

        Returns:
        --------
        dict
            ETF data read from 'inputData_ETF.csv'
        """
        return self.get_timeseries_data('inputData_ETF.csv')

    def get_stock_data(self, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load stock OHLC data.

        Parameters:
        -----------
        date : str, optional
            Date string (e.g., '20120424'); defaults to the 20120424 file

        Returns:
        --------
        dict
            Stock data
        """
        if date:
            filename = f"inputDataOHLCDaily_stocks_{date}.csv"
        else:
            filename = "inputDataOHLCDaily_stocks_20120424.csv"

        return self.get_timeseries_data(filename)

    def list_available_files(self) -> Dict[str, str]:
        """
        List all available converted files.

        Returns:
        --------
        dict
            Copy of the mapping of original .mat files to converted files
        """
        return self.mapping.copy()


# Create a default instance for easy importing
default_loader = DataLoader()


# Convenience functions
def load_earnings_data() -> Dict[str, np.ndarray]:
    """Load earnings announcement data."""
    return default_loader.get_earnings_data()


def load_futures_data(symbol: str, date: Optional[str] = None) -> Dict[str, np.ndarray]:
    """Load futures data for a specific symbol."""
    return default_loader.get_futures_data(symbol, date)


def load_etf_data() -> Dict[str, np.ndarray]:
    """Load ETF data."""
    return default_loader.get_etf_data()


def load_stock_data(date: Optional[str] = None) -> Dict[str, np.ndarray]:
    """Load stock OHLC data."""
    return default_loader.get_stock_data(date)


def load_interest_rates(currency: str = 'AUD') -> Dict[str, Any]:
    """Load interest rate data."""
    return default_loader.get_interest_rate_data(currency)