Skip to content

Commit

Permalink
Add support for local dataset file path to be specified
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Schork committed Apr 26, 2023
1 parent c0364e8 commit 3f8f74f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ test: test_3_9 test_3_10 test_3_11 ## Test all container versions

.PHONY: test_3_9
test_3_9: build_3_9 ## Test Python 3.9 pickle
docker run -i --rm --volume /tmp:/tmp numerai-predict_py3.9:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.9.pkl
docker run -i --rm --volume /tmp:/tmp ${NAME}_py_3_9:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.9.pkl

.PHONY: test_colab
test_colab: build_3_9 ## Test Python 3.9 pickle colab export
docker run -i --rm --volume /tmp:/tmp numerai-predict_py3.9:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/colab_3.9.16.pkl
docker run -i --rm --volume /tmp:/tmp ${NAME}_py_3_9:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/colab_3.9.16.pkl

.PHONY: test_3_10
test_3_10: build_3_10 ## Test Python 3.10 pickle
docker run -i --rm --volume /tmp:/tmp numerai-predict_py3.10:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.10.pkl
docker run -i --rm --volume /tmp:/tmp ${NAME}_py_3_10:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.10.pkl

.PHONY: test_3_11
test_3_11: build_3_11 ## Test Python 3.11 pickle
docker run -i --rm --volume /tmp:/tmp numerai-predict_py3.11:latest --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.11.pkl
docker run -i --rm --volume /tmp:/tmp ${NAME}_py_3_11:latest --dataset /tmp/v4.1/live.parquet --model https://huggingface.co/pschork/hello-numerai-models/resolve/main/model_3.11.pkl

.PHONY: release
release: release_3_9 release_3_10 release_3_11 ## Push all container tagged releases
Expand Down
23 changes: 14 additions & 9 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset", default="v4.1/live.parquet", help="Numerapi dataset path"
"--dataset", default="v4.1/live.parquet", help="Numerapi dataset path or local file"
)
parser.add_argument("--model", required=True, help="Pickled model file or URL")
parser.add_argument("--output_dir", default="/tmp", help="File output dir")
Expand Down Expand Up @@ -70,12 +70,17 @@ def predict(args):
model = pd.read_pickle(model_pkl)
logging.debug(model)

napi = NumerAPI()
current_round = napi.get_current_round()

dataset_path = os.path.join(args.output_dir, args.dataset)
logging.info(f"Downloading {args.dataset} for round {current_round}")
napi.download_dataset(args.dataset, dataset_path)
if os.path.exists(args.dataset):
dataset_path = args.dataset
logging.info(f"Using local {dataset_path} for live data")
elif args.dataset.startswith("/"):
logging.error(f"Local dataset not found - {args.dataset} does not exist!")
sys.exit(1)
else:
dataset_path = os.path.join(args.output_dir, args.dataset)
logging.info(f"Using NumerAPI to download {args.dataset} for live data")
napi = NumerAPI()
napi.download_dataset(args.dataset, dataset_path)

logging.info(f"Loading live features {dataset_path}")
live_features = pd.read_parquet(dataset_path)
Expand All @@ -87,14 +92,14 @@ def predict(args):
logging.debug(predictions)

predictions_csv = os.path.join(
args.output_dir, f"live_predictions_{current_round}-{secrets.token_hex(6)}.csv"
args.output_dir, f"live_predictions-{secrets.token_hex(6)}.csv"
)
logging.info(f"Saving predictions to {predictions_csv}")
with open(predictions_csv, "w") as f:
predictions.to_csv(f)

if args.post_url:
logging.info(f"Uploading predictions to {args.post}")
logging.info(f"Uploading predictions to {args.post_url}")
files = {"file": open(predictions_csv, "rb")}
r = requests.post(args.post_url, data=args.post_data, files=files)
if r.status_code != 200:
Expand Down

0 comments on commit 3f8f74f

Please sign in to comment.