clean up training script

This commit is contained in:
Filip Stefaniuk 2024-09-14 14:11:13 +02:00
parent a3c5c76000
commit c5163bce4d

View File

@ -81,76 +81,6 @@ def get_args():
return parser.parse_args()
# def get_dataset(config, project):
# artifact_name = f"{project}/{config['data']['dataset']}"
# artifact = wandb.Api().artifact(artifact_name)
# base_path = artifact.download()
# logging.info(f"Artifacts downloaded to {base_path}")
# name = artifact.metadata['name']
# part_name = f"in-sample-{config['data']['sliding_window']}"
# data = pd.read_csv(os.path.join(
# base_path, name + '-' + part_name + '.csv'))
# logging.info(f"Using part: {part_name}")
# # TODO: Fix in dataset
# data['weekday'] = data['weekday'].astype('str')
# data['hour'] = data['hour'].astype('str')
# validation_part = config['data']['validation']
# logging.info(f"Using {validation_part} of in sample part for validation.")
# train_data = data.iloc[:int(len(data) * (1 - validation_part))]
# val_data = data.iloc[len(train_data) - config['past_window']:]
# logging.info(f"Trainin part size: {len(train_data)}")
# logging.info(
# f"Validation part size: {len(val_data)} "
# + f"({len(data) - len(train_data)} + {config['past_window']})")
# logging.info("Building time series dataset for training.")
# train = TimeSeriesDataSet(
# train_data,
# time_idx=config['data']['fields']['time_index'],
# target=config['data']['fields']['target'],
# group_ids=config['data']['fields']['group_ids'],
# min_encoder_length=config['past_window'],
# max_encoder_length=config['past_window'],
# min_prediction_length=config['future_window'],
# max_prediction_length=config['future_window'],
# static_reals=config['data']['fields']['static_real'],
# static_categoricals=config['data']['fields']['static_cat'],
# time_varying_known_reals=config['data']['fields'][
# 'dynamic_known_real'],
# time_varying_known_categoricals=config['data']['fields'][
# 'dynamic_known_cat'],
# time_varying_unknown_reals=config['data']['fields'][
# 'dynamic_unknown_real'],
# time_varying_unknown_categoricals=config['data']['fields'][
# 'dynamic_unknown_cat'],
# randomize_length=False
# )
# logging.info("Building time series dataset for validation.")
# val = TimeSeriesDataSet.from_dataset(
# train, val_data, stop_randomization=True)
# return train, val
# def get_loss(config):
# loss_name = config['loss']['name']
# if loss_name == 'Quantile':
# return QuantileLoss(config['loss']['quantiles'])
# if loss_name == 'GMADL':
# return GMADL(
# a=config['loss']['a'],
# b=config['loss']['b']
# )
# raise ValueError("Unknown loss")
def main():
args = get_args()
logging.basicConfig(level=args.log_level)