Skip to content

Commit

Permalink
Better sampling and relax dependency (#2082)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zenglinxiao authored Jul 28, 2021
1 parent 54c777a commit 6cb0320
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --upgrade setuptools==50.3.0
pip install --upgrade setuptools
pip install -e .
pip install -r requirements.opt.txt
pip install flake8
Expand Down
5 changes: 2 additions & 3 deletions onmt/translate/greedy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp):
logits = sample_topp(logits, keep_topp)
if keep_topk > 0:
logits = sample_topk(logits, keep_topk)
dist = torch.distributions.Multinomial(
logits=logits, total_count=1)
topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
dist = torch.distributions.Categorical(logits=logits)
topk_ids = dist.sample().view(-1, 1)
topk_scores = logits.gather(dim=1, index=topk_ids)
return topk_ids, topk_scores

Expand Down
10 changes: 3 additions & 7 deletions requirements.opt.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
cffi==1.14.3
joblib==0.17.0
numba==0.43.0
llvmlite==0.32.1
pyrouge==0.1.3
pyrouge
git+git://github.com/NVIDIA/apex.git@700d6825e205732c1d6be511306ca4e595297070
sentencepiece==0.1.94
subword-nmt==0.3.7
sentencepiece>=0.1.94
subword-nmt>=0.3.7
13 changes: 6 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
},
python_requires=">=3.5",
install_requires=[
"tqdm>=4.51,<5",
"torch==1.6.0",
"torch>=1.6.0",
"torchtext==0.5.0",
"configargparse>=1.2.3,<2",
"tensorboard>=2.3,<3",
"flask==1.1.2",
"waitress==1.4.4",
"configargparse",
"tensorboard>=2.3",
"flask",
"waitress",
"pyonmttok>=1.23,<2;platform_system=='Linux' or platform_system=='Darwin'",
"pyyaml==5.4",
"pyyaml",
],
entry_points={
"console_scripts": [
Expand Down

0 comments on commit 6cb0320

Please sign in to comment.