clean up training script
This commit is contained in:
parent
a3c5c76000
commit
c5163bce4d
@ -81,76 +81,6 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# def get_dataset(config, project):
|
||||
# artifact_name = f"{project}/{config['data']['dataset']}"
|
||||
# artifact = wandb.Api().artifact(artifact_name)
|
||||
# base_path = artifact.download()
|
||||
# logging.info(f"Artifacts downloaded to {base_path}")
|
||||
|
||||
# name = artifact.metadata['name']
|
||||
# part_name = f"in-sample-{config['data']['sliding_window']}"
|
||||
# data = pd.read_csv(os.path.join(
|
||||
# base_path, name + '-' + part_name + '.csv'))
|
||||
# logging.info(f"Using part: {part_name}")
|
||||
|
||||
# # TODO: Fix in dataset
|
||||
# data['weekday'] = data['weekday'].astype('str')
|
||||
# data['hour'] = data['hour'].astype('str')
|
||||
|
||||
# validation_part = config['data']['validation']
|
||||
# logging.info(f"Using {validation_part} of in sample part for validation.")
|
||||
# train_data = data.iloc[:int(len(data) * (1 - validation_part))]
|
||||
# val_data = data.iloc[len(train_data) - config['past_window']:]
|
||||
# logging.info(f"Trainin part size: {len(train_data)}")
|
||||
# logging.info(
|
||||
# f"Validation part size: {len(val_data)} "
|
||||
# + f"({len(data) - len(train_data)} + {config['past_window']})")
|
||||
|
||||
# logging.info("Building time series dataset for training.")
|
||||
# train = TimeSeriesDataSet(
|
||||
# train_data,
|
||||
# time_idx=config['data']['fields']['time_index'],
|
||||
# target=config['data']['fields']['target'],
|
||||
# group_ids=config['data']['fields']['group_ids'],
|
||||
# min_encoder_length=config['past_window'],
|
||||
# max_encoder_length=config['past_window'],
|
||||
# min_prediction_length=config['future_window'],
|
||||
# max_prediction_length=config['future_window'],
|
||||
# static_reals=config['data']['fields']['static_real'],
|
||||
# static_categoricals=config['data']['fields']['static_cat'],
|
||||
# time_varying_known_reals=config['data']['fields'][
|
||||
# 'dynamic_known_real'],
|
||||
# time_varying_known_categoricals=config['data']['fields'][
|
||||
# 'dynamic_known_cat'],
|
||||
# time_varying_unknown_reals=config['data']['fields'][
|
||||
# 'dynamic_unknown_real'],
|
||||
# time_varying_unknown_categoricals=config['data']['fields'][
|
||||
# 'dynamic_unknown_cat'],
|
||||
# randomize_length=False
|
||||
# )
|
||||
|
||||
# logging.info("Building time series dataset for validation.")
|
||||
# val = TimeSeriesDataSet.from_dataset(
|
||||
# train, val_data, stop_randomization=True)
|
||||
|
||||
# return train, val
|
||||
|
||||
|
||||
# def get_loss(config):
|
||||
# loss_name = config['loss']['name']
|
||||
|
||||
# if loss_name == 'Quantile':
|
||||
# return QuantileLoss(config['loss']['quantiles'])
|
||||
|
||||
# if loss_name == 'GMADL':
|
||||
# return GMADL(
|
||||
# a=config['loss']['a'],
|
||||
# b=config['loss']['b']
|
||||
# )
|
||||
|
||||
# raise ValueError("Unknown loss")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=args.log_level)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user