296 lines
9.4 KiB
Python
296 lines
9.4 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Union, Optional
|
|
|
|
class DataLoader:
    """
    Utility class for loading converted data files (CSV and JSON).

    Files are resolved relative to ``data_dir``. If ``data_dir`` contains a
    ``conversion_mapping.json`` file, it is read once at construction time and
    used to translate original ``.mat`` filenames into converted filenames.
    """

    def __init__(self, data_dir: Union[str, Path] = "data"):
        """
        Initialize the data loader.

        Parameters:
        -----------
        data_dir : str or Path
            Directory containing the converted data files
        """
        self.data_dir = Path(data_dir)
        self.mapping_file = self.data_dir / "conversion_mapping.json"
        self.mapping = self._load_mapping()

    def _load_mapping(self) -> Dict[str, str]:
        """Return the .mat -> converted-file mapping, or {} if the file is absent."""
        if not self.mapping_file.exists():
            return {}
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(self.mapping_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_csv(self, filename: str) -> pd.DataFrame:
        """
        Load a CSV file from ``data_dir``.

        Parameters:
        -----------
        filename : str
            Name of the CSV file to load

        Returns:
        --------
        pd.DataFrame
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist.
        """
        filepath = self.data_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")
        return pd.read_csv(filepath)

    def load_json(self, filename: str) -> Dict[str, Any]:
        """
        Load a JSON file from ``data_dir``.

        Parameters:
        -----------
        filename : str
            Name of the JSON file to load

        Returns:
        --------
        dict
            Loaded data

        Raises:
        -------
        FileNotFoundError
            If the file does not exist.
        """
        filepath = self.data_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self, original_mat_filename: str) -> Union[pd.DataFrame, Dict[str, Any]]:
        """
        Load data using the original .mat filename.

        Parameters:
        -----------
        original_mat_filename : str
            Original .mat filename (e.g., 'earnannFile.mat')

        Returns:
        --------
        Union[pd.DataFrame, dict]
            Loaded data (DataFrame for CSV, dict for JSON)

        Raises:
        -------
        ValueError
            If the name is not in the conversion mapping, or it maps to a
            file that is neither ``.csv`` nor ``.json``.
        """
        if original_mat_filename not in self.mapping:
            raise ValueError(f"No conversion found for {original_mat_filename}")

        converted_filename = self.mapping[original_mat_filename]

        if converted_filename.endswith('.csv'):
            return self.load_csv(converted_filename)
        if converted_filename.endswith('.json'):
            return self.load_json(converted_filename)
        raise ValueError(f"Unsupported file format: {converted_filename}")

    @staticmethod
    def _sorted_prefixed_columns(columns, prefix: str) -> list:
        """
        Return the columns named ``<prefix>_<n>...`` ordered by the integer ``n``.

        ``n`` is the segment immediately after ``<prefix>_`` up to the next
        ``_``. Columns whose segment is not a decimal integer get key 0; the
        sort is stable, so they keep their original relative order.
        """
        head = f'{prefix}_'
        matching = [col for col in columns
                    if isinstance(col, str) and col.startswith(head)]

        def index_of(col: str) -> int:
            # Parse after the full prefix, so prefixes that contain '_'
            # (e.g. 'adj_cl') still find the numeric index.
            token = col[len(head):].split('_', 1)[0]
            # isdecimal (not isdigit) guarantees int() accepts the token.
            return int(token) if token.isdecimal() else 0

        return sorted(matching, key=index_of)

    def get_timeseries_data(self, filename: str,
                            time_col: str = 'tday',
                            price_cols: Optional[list] = None) -> Dict[str, np.ndarray]:
        """
        Load time series data and extract common fields.

        Parameters:
        -----------
        filename : str
            Original .mat filename (resolved via the conversion mapping) or a
            converted filename (``.csv``; anything else is read as JSON)
        time_col : str
            Name of the time column; its values are stored under key 'tday'
        price_cols : list, optional
            List of price column prefixes to extract
            (default ['cl', 'op', 'hi', 'lo', 'vol'])

        Returns:
        --------
        dict
            'tday' (if ``time_col`` exists), one entry per price prefix found,
            and 'contracts' / 'syms' / 'stocks' when such columns exist.
            Prefixed column groups (``cl_1``, ``cl_2``, ...) become 2-D arrays
            ordered by index; a single plain column becomes a 1-D array.

        Raises:
        -------
        ValueError
            If the loaded data cannot be treated as a table.
        """
        if filename.endswith('.mat'):
            data = self.load_data(filename)
        elif filename.endswith('.csv'):
            data = self.load_csv(filename)
        else:
            data = self.load_json(filename)

        # JSON payloads (loaded directly or through a .mat mapping) arrive as
        # dicts; a dict whose values are all lists is treated as columns.
        if isinstance(data, dict) and all(isinstance(v, list) for v in data.values()):
            data = pd.DataFrame(data)

        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data is not in tabular format")

        result = {}

        if time_col in data.columns:
            result['tday'] = data[time_col].values

        if price_cols is None:
            price_cols = ['cl', 'op', 'hi', 'lo', 'vol']

        for price_type in price_cols:
            matching_cols = self._sorted_prefixed_columns(data.columns, price_type)
            if matching_cols:
                result[price_type] = data[matching_cols].values
            elif price_type in data.columns:
                # Single column case
                result[price_type] = data[price_type].values

        # Optional per-instrument metadata stored as <key>_<n> columns.
        for key in ('contracts', 'syms', 'stocks'):
            cols = self._sorted_prefixed_columns(data.columns, key)
            if cols:
                result[key] = data[cols].values

        return result

    def get_earnings_data(self) -> Dict[str, np.ndarray]:
        """
        Load earnings announcement data (from 'earnannFile.mat' via the mapping).

        Returns:
        --------
        dict
            Dictionary with earnings data
        """
        return self.get_timeseries_data('earnannFile.mat')

    def get_interest_rate_data(self, currency: str = 'AUD') -> Dict[str, Any]:
        """
        Load interest rate data from ``<currency>_interestRate.json``.

        Parameters:
        -----------
        currency : str
            Currency code ('AUD' or 'CAD')

        Returns:
        --------
        dict
            Interest rate data
        """
        filename = f"{currency}_interestRate.json"
        return self.load_json(filename)

    def get_futures_data(self, symbol: str, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load futures data for a specific symbol.

        Parameters:
        -----------
        symbol : str
            Futures symbol (e.g., 'TU', 'CL', 'VX', 'HO2', 'C2', 'HG', 'BR')
        date : str, optional
            Date string (e.g., '20120813'). When omitted, the lexicographically
            first ``inputDataDaily_<symbol>_*.csv`` file is used (for
            YYYYMMDD dates that is presumably the earliest one).

        Returns:
        --------
        dict
            Futures data

        Raises:
        -------
        FileNotFoundError
            If no file matches ``symbol`` (when ``date`` is omitted).
        """
        if date:
            filename = f"inputDataDaily_{symbol}_{date}.csv"
        else:
            # glob() order is filesystem-dependent; sort so the pick is reproducible.
            possible_files = sorted(self.data_dir.glob(f"inputDataDaily_{symbol}_*.csv"))
            if not possible_files:
                raise FileNotFoundError(f"No data files found for symbol {symbol}")
            filename = possible_files[0].name

        return self.get_timeseries_data(filename)

    def get_etf_data(self) -> Dict[str, np.ndarray]:
        """
        Load ETF data from 'inputData_ETF.csv'.

        Returns:
        --------
        dict
            ETF data
        """
        return self.get_timeseries_data('inputData_ETF.csv')

    def get_stock_data(self, date: Optional[str] = None) -> Dict[str, np.ndarray]:
        """
        Load stock OHLC data.

        Parameters:
        -----------
        date : str, optional
            Date string (e.g., '20120424'); defaults to the 20120424 file

        Returns:
        --------
        dict
            Stock data
        """
        if date:
            filename = f"inputDataOHLCDaily_stocks_{date}.csv"
        else:
            filename = "inputDataOHLCDaily_stocks_20120424.csv"

        return self.get_timeseries_data(filename)

    def list_available_files(self) -> Dict[str, str]:
        """
        List all available converted files.

        Returns:
        --------
        dict
            A copy of the mapping of original .mat files to converted files
        """
        return self.mapping.copy()
|
|
|
|
# Shared loader used by the convenience functions below. It is rooted at the
# relative directory "data" (resolved against the current working directory)
# and reads that directory's conversion_mapping.json, if any, at import time.
default_loader = DataLoader(data_dir="data")
|
|
|
|
# Convenience functions
def load_earnings_data():
    """Return earnings announcement data loaded through ``default_loader``."""
    loader = default_loader
    return loader.get_earnings_data()
|
|
|
|
def load_futures_data(symbol: str, date: Optional[str] = None):
    """
    Load futures data for a specific symbol using ``default_loader``.

    Parameters:
    -----------
    symbol : str
        Futures symbol (e.g., 'CL')
    date : str, optional
        Date string used in the filename; when omitted, file selection is
        delegated to ``DataLoader.get_futures_data``.
    """
    return default_loader.get_futures_data(symbol, date)
|
|
|
|
def load_etf_data():
    """Return ETF data loaded through ``default_loader``."""
    loader = default_loader
    return loader.get_etf_data()
|
|
|
|
def load_stock_data(date: Optional[str] = None):
    """
    Load stock OHLC data using ``default_loader``.

    Parameters:
    -----------
    date : str, optional
        Date string used in the filename; when omitted, the default file
        chosen by ``DataLoader.get_stock_data`` is used.
    """
    return default_loader.get_stock_data(date)
|
|
|
|
def load_interest_rates(currency: str = 'AUD'):
    """Return interest rate data for ``currency`` loaded through ``default_loader``."""
    loader = default_loader
    return loader.get_interest_rate_data(currency)