Add num_processes to reader.train() to configure multiprocessing #271

Merged · 1 commit · Jul 29, 2020
10 changes: 8 additions & 2 deletions — haystack/reader/farm.py

@@ -1,6 +1,7 @@
 import logging
 from pathlib import Path
 from typing import List, Optional, Union
+import multiprocessing
 
 import numpy as np
 from farm.data_handler.data_silo import DataSilo
@@ -117,6 +118,7 @@ def train(
         dev_split: Optional[float] = 0.1,
         evaluate_every: int = 300,
         save_dir: Optional[str] = None,
+        num_processes: Optional[int] = 0
     ):
         """
         Fine-tune a model on a QA dataset. Options:
@@ -139,13 +141,17 @@ def train(
                            Options for different schedules are available in FARM.
         :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
         :param save_dir: Path to store the final model
+        :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
+                              Set to a value of 0 to disable multiprocessing. Set to None to use all CPU cores minus one.
         :return: None
         """
 
-
         if dev_filename:
             dev_split = None
 
+        if num_processes is None:
+            num_processes = multiprocessing.cpu_count() - 1
+
         set_all_seeds(seed=42)
 
         # For these variables, by default, we use the value set when initializing the FARMReader.
@@ -177,7 +183,7 @@ def train(
 
         # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
         # and calculates a few descriptive statistics of our datasets
-        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
+        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)
 
         # Quick-fix until this is fixed upstream in FARM:
         # We must avoid applying DataParallel twice (once when loading the inferencer,
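The defaulting behaviour this PR adds can be sketched in isolation. The helper name `resolve_num_processes` below is hypothetical (not part of the PR); it only mirrors the `if num_processes is None` branch from the diff:

```python
import multiprocessing
from typing import Optional

def resolve_num_processes(num_processes: Optional[int] = 0) -> int:
    """Mirror the PR's handling of num_processes:
    0 disables multiprocessing, None means all CPU cores minus one,
    and any other value is passed through unchanged."""
    if num_processes is None:
        return multiprocessing.cpu_count() - 1
    return num_processes
```

In the actual change, the resolved value is then forwarded to FARM's `DataSilo` via its `max_processes` argument, which is what ultimately controls the preprocessing `multiprocessing.Pool`.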