From 6cb03200d0ec86941d913d1f78ee23a9f1e6dfa3 Mon Sep 17 00:00:00 2001 From: Linxiao ZENG Date: Wed, 28 Jul 2021 15:45:32 +0200 Subject: [PATCH] Better sampling and relax dependency (#2082) --- .github/workflows/push.yml | 2 +- onmt/translate/greedy_search.py | 5 ++--- requirements.opt.txt | 10 +++------- setup.py | 13 ++++++------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index caa95631f7..9780df63fc 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -19,7 +19,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install --upgrade setuptools==50.3.0 + pip install --upgrade setuptools pip install -e . pip install -r requirements.opt.txt pip install flake8 diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 60d4c711d9..1c4a43a62d 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -83,9 +83,8 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): logits = sample_topp(logits, keep_topp) if keep_topk > 0: logits = sample_topk(logits, keep_topk) - dist = torch.distributions.Multinomial( - logits=logits, total_count=1) - topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True) + dist = torch.distributions.Categorical(logits=logits) + topk_ids = dist.sample().view(-1, 1) topk_scores = logits.gather(dim=1, index=topk_ids) return topk_ids, topk_scores diff --git a/requirements.opt.txt b/requirements.opt.txt index 576c58f7ed..aaa18d7abc 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -1,8 +1,4 @@ -cffi==1.14.3 -joblib==0.17.0 -numba==0.43.0 -llvmlite==0.32.1 -pyrouge==0.1.3 +pyrouge git+git://github.com/NVIDIA/apex.git@700d6825e205732c1d6be511306ca4e595297070 -sentencepiece==0.1.94 -subword-nmt==0.3.7 +sentencepiece>=0.1.94 +subword-nmt>=0.3.7 diff --git a/setup.py b/setup.py index d3a0f03206..ec3244c69e 100644 --- a/setup.py +++ b/setup.py @@ -21,15 +21,14 @@ }, python_requires=">=3.5", install_requires=[ - "tqdm>=4.51,<5", - "torch==1.6.0", + "torch>=1.6.0", "torchtext==0.5.0", - "configargparse>=1.2.3,<2", - "tensorboard>=2.3,<3", - "flask==1.1.2", - "waitress==1.4.4", + "configargparse", + "tensorboard>=2.3", + "flask", + "waitress", "pyonmttok>=1.23,<2;platform_system=='Linux' or platform_system=='Darwin'", - "pyyaml==5.4", + "pyyaml", ], entry_points={ "console_scripts": [