From c5163bce4dbc3d5af5368f19d425dc9a5057b787 Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Sat, 14 Sep 2024 14:11:13 +0200 Subject: [PATCH] clean up training script --- scripts/train.py | 70 ------------------------------------------------ 1 file changed, 70 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 89b116f..d09f065 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -81,76 +81,6 @@ def get_args(): return parser.parse_args() -# def get_dataset(config, project): -# artifact_name = f"{project}/{config['data']['dataset']}" -# artifact = wandb.Api().artifact(artifact_name) -# base_path = artifact.download() -# logging.info(f"Artifacts downloaded to {base_path}") - -# name = artifact.metadata['name'] -# part_name = f"in-sample-{config['data']['sliding_window']}" -# data = pd.read_csv(os.path.join( -# base_path, name + '-' + part_name + '.csv')) -# logging.info(f"Using part: {part_name}") - -# # TODO: Fix in dataset -# data['weekday'] = data['weekday'].astype('str') -# data['hour'] = data['hour'].astype('str') - -# validation_part = config['data']['validation'] -# logging.info(f"Using {validation_part} of in sample part for validation.") -# train_data = data.iloc[:int(len(data) * (1 - validation_part))] -# val_data = data.iloc[len(train_data) - config['past_window']:] -# logging.info(f"Trainin part size: {len(train_data)}") -# logging.info( -# f"Validation part size: {len(val_data)} " -# + f"({len(data) - len(train_data)} + {config['past_window']})") - -# logging.info("Building time series dataset for training.") -# train = TimeSeriesDataSet( -# train_data, -# time_idx=config['data']['fields']['time_index'], -# target=config['data']['fields']['target'], -# group_ids=config['data']['fields']['group_ids'], -# min_encoder_length=config['past_window'], -# max_encoder_length=config['past_window'], -# min_prediction_length=config['future_window'], -# max_prediction_length=config['future_window'], -# static_reals=config['data']['fields']['static_real'], -# static_categoricals=config['data']['fields']['static_cat'], -# time_varying_known_reals=config['data']['fields'][ -# 'dynamic_known_real'], -# time_varying_known_categoricals=config['data']['fields'][ -# 'dynamic_known_cat'], -# time_varying_unknown_reals=config['data']['fields'][ -# 'dynamic_unknown_real'], -# time_varying_unknown_categoricals=config['data']['fields'][ -# 'dynamic_unknown_cat'], -# randomize_length=False -# ) - -# logging.info("Building time series dataset for validation.") -# val = TimeSeriesDataSet.from_dataset( -# train, val_data, stop_randomization=True) - -# return train, val - - -# def get_loss(config): -# loss_name = config['loss']['name'] - -# if loss_name == 'Quantile': -# return QuantileLoss(config['loss']['quantiles']) - -# if loss_name == 'GMADL': -# return GMADL( -# a=config['loss']['a'], -# b=config['loss']['b'] -# ) - -# raise ValueError("Unknown loss") - - def main(): args = get_args() logging.basicConfig(level=args.log_level)