Fix multi-gpu training via DataParallel (#234)
tholor authored Jul 15, 2020
1 parent 5c1a5fe commit c9d3146
Showing 4 changed files with 18 additions and 5 deletions.
1 change: 1 addition & 0 deletions haystack/__init__.py
@@ -12,5 +12,6 @@
 logging.getLogger('farm.infer').setLevel(logging.INFO)
 logging.getLogger('transformers').setLevel(logging.WARNING)
 logging.getLogger('farm.eval').setLevel(logging.INFO)
+logging.getLogger('farm.modeling.optimization').setLevel(logging.INFO)


14 changes: 13 additions & 1 deletion haystack/reader/farm.py
@@ -10,10 +10,12 @@
 from farm.infer import QAInferencer
 from farm.modeling.optimization import initialize_optimizer
 from farm.modeling.predictions import QAPred, QACandidate
+from farm.modeling.adaptive_model import BaseAdaptiveModel
 from farm.train import Trainer
 from farm.eval import Evaluator
 from farm.utils import set_all_seeds, initialize_device_settings
 from scipy.special import expit
+import shutil

 from haystack.database.base import Document
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
@@ -177,9 +179,17 @@ def train(
         # and calculates a few descriptive statistics of our datasets
         data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)

+        # Quick-fix until this is fixed upstream in FARM:
+        # We must avoid applying DataParallel twice (once when loading the inferencer,
+        # once when calling initialize_optimizer)
+        self.inferencer.model.save("tmp_model")
+        model = BaseAdaptiveModel.load(load_dir="tmp_model", device=device, strict=True)
+        shutil.rmtree('tmp_model')
+
         # 3. Create an optimizer and pass the already initialized model
         model, optimizer, lr_schedule = initialize_optimizer(
-            model=self.inferencer.model,
+            model=model,
+            # model=self.inferencer.model,
             learning_rate=learning_rate,
             schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
             n_batches=len(data_silo.loaders["train"]),
@@ -197,6 +207,8 @@ def train(
             evaluate_every=evaluate_every,
             device=device,
         )
+
+
         # 5. Let it grow!
         self.inferencer.model = trainer.train()
         self.save(Path(save_dir))
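For context on the quick-fix above: when more than one GPU is visible, FARM wraps the model in torch.nn.DataParallel, and applying the wrapper a second time nests it. Below is a minimal sketch of that failure mode and of an in-memory alternative to the save/reload detour; the toy model and the unwrap step are illustrative assumptions, not the commit's code.

import torch
from torch import nn

# Toy stand-in for FARM's AdaptiveModel (illustrative only).
model = nn.Linear(10, 2)

if torch.cuda.device_count() > 1:
    # First wrap, e.g. applied when the QAInferencer is loaded.
    model = nn.DataParallel(model)

# A second wrap (e.g. inside an optimizer initializer that also applies
# DataParallel) would nest the wrappers, so parameters move under
# `module.module.*` and saved checkpoints no longer line up:
# model = nn.DataParallel(model)  # <-- the double-wrap this commit avoids

# The commit avoids it by saving the wrapped model and reloading it as a
# plain model via BaseAdaptiveModel.load before initialize_optimizer runs.
# A hypothetical in-memory alternative is to unwrap via `.module`:
if isinstance(model, nn.DataParallel):
    model = model.module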
4 changes: 2 additions & 2 deletions tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb
@@ -90,10 +90,10 @@
    }
   ],
   "source": [
-    "reader = FARMReader(model_name_or_path=\"distilbert-base-uncased-distilled-squad\", use_gpu=False)\n",
+    "reader = FARMReader(model_name_or_path=\"distilbert-base-uncased-distilled-squad\", use_gpu=True)\n",
     "train_data = \"data/squad20\"\n",
     "# train_data = \"PATH/TO_YOUR/TRAIN_DATA\" \n",
-    "reader.train(data_dir=train_data, train_filename=\"dev-v2.0.json\", use_gpu=False, n_epochs=1, save_dir=\"my_model\")"
+    "reader.train(data_dir=train_data, train_filename=\"dev-v2.0.json\", use_gpu=True, n_epochs=1, save_dir=\"my_model\")"
    ]
   },
   {
4 changes: 2 additions & 2 deletions tutorials/Tutorial2_Finetune_a_model_on_your_data.py
@@ -34,10 +34,10 @@

 #**Recommendation: Run training on a GPU. To do so change the `use_gpu` arguments below to `True`

-reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
+reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=True)
 train_data = "data/squad20"
 # train_data = "PATH/TO_YOUR/TRAIN_DATA"
-reader.train(data_dir=train_data, train_filename="dev-v2.0.json", use_gpu=False, n_epochs=1, save_dir="my_model")
+reader.train(data_dir=train_data, train_filename="dev-v2.0.json", use_gpu=True, n_epochs=1, save_dir="my_model")

 # Saving the model happens automatically at the end of training into the `save_dir` you specified
 # However, you could also save a reader manually again via:
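A usage note on the tutorial snippet (not part of the commit): with use_gpu=True and the DataParallel fix in place, training uses every GPU that CUDA exposes. Restricting it to specific devices is standard CUDA environment handling; the device ids below are assumptions for illustration.

import os

# Standard CUDA behavior, not specific to this commit: limit which GPUs
# DataParallel can see. Must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # assumed device ids

from haystack.reader.farm import FARMReader

reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=True)
reader.train(data_dir="data/squad20", train_filename="dev-v2.0.json",
             use_gpu=True, n_epochs=1, save_dir="my_model")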
