Skip to content

Commit

Permalink
re-change imputer path
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Jun 19, 2024
1 parent 4293055 commit d6417e3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
13 changes: 5 additions & 8 deletions icedqcd/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,13 @@ def process_data(args):
if param['predict'] in ['xgb', 'xgb_logistic']:

print(f'Evaluating MVA-model "{ID}" \n')

## 1. Impute data
if args['imputation_param']['active']:

fmodel = f'{args["datadir"]}/imputer_{args["__hash_genesis__"]}.pkl'
cprint(f'Loading imputer from: {fmodel}', 'green')

fmodel = f'{args["modeldir"]}/imputer.pkl'
imputer = pickle.load(open(fmodel, 'rb'))

data['data'], _ = process.impute_datasets(data=data['data'], features=None, args=args['imputation_param'], imputer=imputer)
data['data'], _ = process.impute_datasets(data=data['data'], features=None, args=args['imputation_param'], imputer=imputer)

## 2. Apply the input variable set reductor
X,ids = aux.red(X=data['data'].x, ids=data['data'].ids, param=param)
Expand Down Expand Up @@ -349,7 +347,7 @@ def get_predictor(args, param, feature_names=None):

elif param['predict'] == 'xgb_logistic':
func_predict, model = predict.pred_xgb_logistic(args=args, param=param, feature_names=feature_names, return_model=True)

#elif param['predict'] == 'torch_vector':
# func_predict = predict.pred_torch_generic(args=args, param=param)

Expand All @@ -360,4 +358,3 @@ def get_predictor(args, param, feature_names=None):
raise Exception(__name__ + f'.get_predictor: Unknown param["predict"] = {param["predict"]}')

return func_predict, model

4 changes: 2 additions & 2 deletions icenet/tools/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def process_data(args, predata, func_factor, mvavars, runmode):
output['trn']['data'], imputer = impute_datasets(data=output['trn']['data'], features=impute_vars, args=args['imputation_param'], imputer=None)
output['val']['data'], imputer = impute_datasets(data=output['val']['data'], features=impute_vars, args=args['imputation_param'], imputer=imputer)

fmodel = f'{args["datadir"]}/imputer_{args["__hash_genesis__"]}.pkl'
fmodel = f'{args["modeldir"]}/imputer.pkl'

cprint(__name__ + f'.process_data: Saving imputer to: {fmodel}', 'green')
pickle.dump(imputer, open(fmodel, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
Expand Down Expand Up @@ -686,7 +686,7 @@ def process_data(args, predata, func_factor, mvavars, runmode):
## Imputate
if args['imputation_param']['active']:

fmodel = f'{args["datadir"]}/imputer_{args["__hash_genesis__"]}.pkl'
fmodel = f'{args["modeldir"]}/imputer.pkl'

cprint(__name__ + f'.process_data: Loading imputer from: {fmodel}', 'green')
imputer = pickle.load(open(fmodel, 'rb'))
Expand Down

0 comments on commit d6417e3

Please sign in to comment.