[WIP] Adding dialogue state tracking TRADE #322

Merged
merged 160 commits into from
Feb 5, 2020
62f4bc0
process woz data
chiphuyen Oct 19, 2019
8b039cb
merge with glue
chiphuyen Oct 19, 2019
ebfff6e
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 11, 2019
0a39962
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 12, 2019
d110bd9
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 14, 2019
4970362
wip
chiphuyen Nov 15, 2019
a87dbb0
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 15, 2019
dc4ac48
wip
chiphuyen Nov 16, 2019
6f3e3af
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 16, 2019
2bdc211
wip
chiphuyen Nov 16, 2019
74883fc
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 18, 2019
5bbded4
wip
chiphuyen Nov 19, 2019
ed53a1e
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 19, 2019
a06e864
adding rnnencoder
chiphuyen Nov 19, 2019
d748c46
Merge branch 'master' of github.com:NVIDIA/NeMo into rnnencoder
chiphuyen Nov 19, 2019
a9ea100
rnn encoder
chiphuyen Nov 19, 2019
ab9e32d
Merge branch 'rnnencoder' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 19, 2019
1069bc5
wip
chiphuyen Nov 21, 2019
b74f58e
wip
chiphuyen Nov 21, 2019
9f1d37f
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 21, 2019
8757cd7
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Nov 21, 2019
5ef5264
wip
chiphuyen Dec 2, 2019
c2b33e8
Merge branch 'master' of github.com:NVIDIA/NeMo into woz
chiphuyen Dec 2, 2019
d59bece
merge master
chiphuyen Dec 12, 2019
f9563af
still bugs in dst dataset
chiphuyen Dec 13, 2019
6149977
Updated field names.
VahidooX Dec 17, 2019
23a781f
Fixed data layer bugs and indentations.
VahidooX Dec 18, 2019
f415c1a
Evaluation added.
VahidooX Dec 24, 2019
871ff50
Some bugs in the generator fixed.
VahidooX Dec 27, 2019
0ced200
Evaluation fixed. Mem usage improved.
VahidooX Dec 30, 2019
3628cb6
Added shuffling. Removed mem lang.
VahidooX Dec 31, 2019
e3de958
Adding progress bar.
VahidooX Dec 31, 2019
6e68b65
Added progress bar.
VahidooX Jan 2, 2020
7b6bf61
Added None option to grad_norm_clip.
VahidooX Jan 3, 2020
745f87a
Added None option to grad_norm_clip.
VahidooX Jan 3, 2020
c2644da
Fixed steps_per_epoch for multi gpu training.
VahidooX Jan 3, 2020
f8a08a6
Fixed steps_per_epoch for multi gpu training.
VahidooX Jan 3, 2020
1748754
Made progressbar optional.
VahidooX Jan 3, 2020
dd9b436
Cleaned.
VahidooX Jan 3, 2020
f43d1dd
Deterministic debugging!
VahidooX Jan 8, 2020
04fa8fa
Debugged!
VahidooX Jan 8, 2020
00237db
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
70850b5
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
18abbb8
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
08fa9cd
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
7e197ac
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
e7b0b55
Added input dropout parameter to EncoderRNN.
VahidooX Jan 8, 2020
3226eb5
Fixed the NAN bug caused by using torch.log().
VahidooX Jan 9, 2020
3c26010
Enabled weight decay.
VahidooX Jan 9, 2020
8b8e743
Fixed Nemo bug on list of params to optimize. Added input_dropout param.
VahidooX Jan 9, 2020
38da6a8
Fixed Nemo bug on list of params to optimize. Added input_dropout param.
VahidooX Jan 10, 2020
e2a41ab
Fixed evaluation bug.
VahidooX Jan 10, 2020
ad90953
Update max_res_len to max of eval batch.
VahidooX Jan 10, 2020
18b404d
Disabled masking in eval.
VahidooX Jan 10, 2020
e6533e5
Fixed input masking bug.
VahidooX Jan 13, 2020
9e34cad
Added is_training to data layer.
VahidooX Jan 13, 2020
3a46673
Fixed 'none' handling in dataset loading.
VahidooX Jan 14, 2020
0897b17
Fixed evaluation on 'none' values.
VahidooX Jan 15, 2020
1676168
Fixed evaluation on 'none' values.
VahidooX Jan 15, 2020
a7d0cd2
Matched creating dialogue history.
VahidooX Jan 15, 2020
c6ff227
Disabled seeds for random gens.
VahidooX Jan 15, 2020
04502ed
Fixed shuffle setting in data layer.
VahidooX Jan 15, 2020
ebdeb3f
First cleanup round.
VahidooX Jan 15, 2020
a881535
Made it work with no target during evaluation.
VahidooX Jan 15, 2020
d5da3e4
Reversed no target support.
VahidooX Jan 15, 2020
58748d1
Removed f1 score.
VahidooX Jan 16, 2020
bbf5b69
Removed progress bar. Code cleaned.
VahidooX Jan 16, 2020
fe1537f
Removed unused parameters.
VahidooX Jan 16, 2020
e91efd5
Added lr_policy support.
VahidooX Jan 16, 2020
90b8b39
Set default warmup to zero.
VahidooX Jan 16, 2020
224b724
Set default to MultiWOZ 2.1
VahidooX Jan 17, 2020
fd38b2b
Added min_lr parameter.
VahidooX Jan 23, 2020
488111e
Fixed min_lr parameter.
VahidooX Jan 23, 2020
2075413
Added datadesc and pipeline creaters.
VahidooX Jan 24, 2020
1b42293
Fixed pipeline creaters.
VahidooX Jan 24, 2020
4edd0a7
Cleaned process_multiwoz
VahidooX Jan 24, 2020
c0ca99c
Fixed the masking bottleneck which caused low gpu utility.
VahidooX Jan 28, 2020
fa56525
Fixed the DTH memory bottleneck which caused low gpu utility.
VahidooX Jan 28, 2020
da70cf8
Changed data_layers to data_desc in the callbacks.
VahidooX Jan 28, 2020
800a8d3
Updated data layer name for MultiWOZ
VahidooX Jan 28, 2020
68d67a1
Updated data layer name for MultiWOZ
VahidooX Jan 29, 2020
dfdacaa
init commit of nlp refactoring
yzhang123 Jan 25, 2020
889a2eb
fixed import errors
yzhang123 Jan 27, 2020
a14da52
make absolute imports
yzhang123 Jan 27, 2020
ffcd7fb
fix import error
yzhang123 Jan 27, 2020
52d8a70
fix imports
yzhang123 Jan 28, 2020
ff7774b
rebase master
yzhang123 Jan 28, 2020
7be1709
Improved masking in loss function.
VahidooX Jan 29, 2020
9fc7ed5
Enabled gpu computation in callbacks.
VahidooX Jan 29, 2020
9cc23d9
Enabled gpu computation in callbacks.
VahidooX Jan 29, 2020
5683269
Enabled gpu computation in callbacks.
VahidooX Jan 30, 2020
dd9c4e5
Enabled gpu computation in callbacks.
VahidooX Jan 30, 2020
fdc421b
add all changed nlp files
VahidooX Jan 31, 2020
4f81260
Updated thw whole test folder.
VahidooX Jan 31, 2020
0864e65
Changed nemo.logging to logging
VahidooX Jan 31, 2020
6df0778
Added transformer to the init
VahidooX Jan 31, 2020
44bd381
Fixed lgtm warnings.
VahidooX Jan 31, 2020
b3c1513
Fixed transformer package.
VahidooX Jan 31, 2020
8c58796
Fixed unused local variables.
VahidooX Jan 31, 2020
c5127a1
Fixed lgtm.
VahidooX Jan 31, 2020
ae83cff
Fixed lgtm.
VahidooX Jan 31, 2020
b851473
Fixed logging in examples.
VahidooX Jan 31, 2020
2137704
Merge remote-tracking branch 'remote/master' into nlp_refactoring_tmp
VahidooX Jan 31, 2020
b8f57bf
Moved __all__ after imports. Added more __all__:)
VahidooX Jan 31, 2020
58715fe
Added license to all the files except examples.
VahidooX Feb 1, 2020
a92ea9c
Added license to all examples.
VahidooX Feb 1, 2020
22de233
Fixed style.
VahidooX Feb 1, 2020
8590949
Fixed style.
VahidooX Feb 1, 2020
3bb2035
Updated examples names.
VahidooX Feb 1, 2020
30b7d44
Added licenses to init files.
VahidooX Feb 1, 2020
7a66bec
tested examples
yzhang123 Feb 3, 2020
ba24d5a
fix black style
yzhang123 Feb 3, 2020
8857fc5
updating jenkins after script renaming
yzhang123 Feb 3, 2020
c5a441f
updating changelog
yzhang123 Feb 3, 2020
b2e2475
merged dev-config-nm with nlp_refactor_tmp, all unit tests passed
tkornuta-nvidia Feb 3, 2020
03e02cd
import fixed
ekmb Feb 3, 2020
0d204d6
Moved scripts.
VahidooX Feb 3, 2020
b4349e5
Merge remote-tracking branch 'remote/nlp_refactoring_merged_config' i…
VahidooX Feb 3, 2020
9356e59
Fixed import.
VahidooX Feb 3, 2020
2cf27b8
tested examples scripts
yzhang123 Feb 3, 2020
1afe049
update jenkins
yzhang123 Feb 3, 2020
0363044
Merge branch 'nlp_refactoring_tmp' of https://github.com/NVIDIA/NeMo …
VahidooX Feb 3, 2020
4446f80
merge conflict on CHANGELOG resolved
tkornuta-nvidia Feb 4, 2020
efdb87b
Merge branch 'master' of github.com:NVIDIA/NeMo into merge-nlp-refact…
tkornuta-nvidia Feb 4, 2020
4e81846
LGTM fixes
tkornuta-nvidia Feb 4, 2020
cb11882
removed invalid argument in ipynb
tkornuta-nvidia Feb 4, 2020
4a9343d
removed unused import
tkornuta-nvidia Feb 4, 2020
4a159b4
removed empty nontrainables, as agreed during the meeting, if there a…
tkornuta-nvidia Feb 4, 2020
ecfd2cd
removed nontrainables reference
tkornuta-nvidia Feb 4, 2020
face974
nemo_core - lgtm import fix
tkornuta-nvidia Feb 4, 2020
e59fcdb
removed references to local_params - n-th time:]
tkornuta-nvidia Feb 4, 2020
720f5c5
Cleanups related to removing factory from BERT init calls
tkornuta-nvidia Feb 4, 2020
7f15ab9
Merged with nlp_refactoring.
VahidooX Feb 4, 2020
fa98235
Initial merge.
VahidooX Feb 4, 2020
0e3ac4b
Fixed bugs.
VahidooX Feb 4, 2020
1f0a434
Added licenses.
VahidooX Feb 4, 2020
01fc8a7
Added licenses.
VahidooX Feb 4, 2020
6ca8506
Unfixed the bug in the param list of the optimizer.
VahidooX Feb 4, 2020
19a6f14
Updated changes log.
VahidooX Feb 4, 2020
d6ea41e
Updated licenses.
VahidooX Feb 5, 2020
5cef512
Updated licenses.
VahidooX Feb 5, 2020
1d7af98
Merge remote-tracking branch 'remote/master' into dev_st_trade_new
VahidooX Feb 5, 2020
857864b
removed local_rank passed to pipeline creator.
VahidooX Feb 5, 2020
a2b17a7
Cleaned up the code.
VahidooX Feb 5, 2020
1259d48
Updated changelog.
VahidooX Feb 5, 2020
9a1dc3f
Fixed bugs.
VahidooX Feb 5, 2020
2422c8b
Updated logging style.
VahidooX Feb 5, 2020
9d11dc2
Fixed logging.
VahidooX Feb 5, 2020
bc012ed
Added comments for the input/output ports.
VahidooX Feb 5, 2020
2612258
Replaced prints with logging.
VahidooX Feb 5, 2020
95d7704
Fixed styles.
VahidooX Feb 5, 2020
ef3e150
Fixed styles.
VahidooX Feb 5, 2020
d1469bf
Added refs in process_multiwoz.py
VahidooX Feb 5, 2020
5d387ef
Fixed import bug.
VahidooX Feb 5, 2020
bdb46d2
Added multiwoz license.
VahidooX Feb 5, 2020
8631b50
Added #!/usr/bin/python to the process_multiwoz.py.
VahidooX Feb 5, 2020
5d37a1b
Fixed copyright header.
VahidooX Feb 5, 2020
5632f74
Fixed copyright header.
VahidooX Feb 5, 2020
a26e11a
Fixed logging warnings.
VahidooX Feb 5, 2020
14286aa
Fixed style.
VahidooX Feb 5, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -91,6 +91,9 @@ To release a new version, please update the changelog as followed:
- Updated nemo's use of the logging library. from nemo import logging is now the reccomended way of using the nemo logger. neural_factory.logger and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information.
([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc

- Added TRADE (dialogue state tracking model) on MultiWOZ dataset
([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX

### Dependencies Update
- Added dependency on `wrapt` (the new version of the `deprecated` warning) - @tkornuta-nvidia, @DEKHTIARJonathan

226 changes: 226 additions & 0 deletions examples/nlp/dialogue_state_tracking_trade.py
@@ -0,0 +1,226 @@
# =============================================================================
# Copyright 2019 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

""" An implementation of the paper "Transferable Multi-Domain State Generator
for Task-Oriented Dialogue Systems" (Wu et al., 2019 - ACL 2019)
Adopted from: https://github.com/jasonwu0731/trade-dst
"""

import argparse
import math
import os

import numpy as np

import nemo.collections.nlp as nemo_nlp
import nemo.core as nemo_core
from nemo import logging
from nemo.backends.pytorch.common import EncoderRNN
from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import MultiWOZDataDesc
from nemo.utils.lr_policies import get_lr_policy

parser = argparse.ArgumentParser(description='Dialog state tracking with TRADE model on MultiWOZ dataset')
parser.add_argument("--local_rank", default=None, type=int)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--eval_batch_size", default=16, type=int)
parser.add_argument("--num_gpus", default=1, type=int)
parser.add_argument("--num_epochs", default=10, type=int)
parser.add_argument("--lr_warmup_proportion", default=0.0, type=float)
parser.add_argument("--lr", default=0.001, type=float)
parser.add_argument("--lr_policy", default=None, type=str)
parser.add_argument("--min_lr", default=1e-4, type=float)
parser.add_argument("--weight_decay", default=0.0, type=float)
parser.add_argument("--emb_dim", default=400, type=int)
parser.add_argument("--hid_dim", default=400, type=int)
parser.add_argument("--n_layers", default=1, type=int)
parser.add_argument("--dropout", default=0.2, type=float)
parser.add_argument("--input_dropout", default=0.2, type=float)
parser.add_argument("--data_dir", default='data/statetracking/multiwoz2.1', type=str)
parser.add_argument("--train_file_prefix", default='train', type=str)
parser.add_argument("--eval_file_prefix", default='test', type=str)
parser.add_argument("--work_dir", default='outputs', type=str)
parser.add_argument("--save_epoch_freq", default=-1, type=int)
parser.add_argument("--save_step_freq", default=-1, type=int)
parser.add_argument("--optimizer_kind", default="adam", type=str)
parser.add_argument("--amp_opt_level", default="O0", type=str, choices=["O0", "O1", "O2"])
parser.add_argument("--shuffle_data", action='store_true')
parser.add_argument("--num_train_samples", default=-1, type=int)
parser.add_argument("--num_eval_samples", default=-1, type=int)
parser.add_argument("--grad_norm_clip", type=float, default=10, help="gradient clipping")
parser.add_argument("--teacher_forcing", default=0.5, type=float)
args = parser.parse_args()

# List of the domains to be considered
domains = {"attraction": 0, "restaurant": 1, "taxi": 2, "train": 3, "hotel": 4}

if not os.path.exists(args.data_dir):
    raise ValueError(f'Data not found at {args.data_dir}')

work_dir = f'{args.work_dir}/DST_TRADE'

data_desc = MultiWOZDataDesc(args.data_dir, domains)

nf = nemo_core.NeuralModuleFactory(
    backend=nemo_core.Backend.PyTorch,
    local_rank=args.local_rank,
    optimization_level=args.amp_opt_level,
    log_dir=work_dir,
    create_tb_writer=True,
    files_to_copy=[__file__],
    add_time_to_log_dir=True,
)

vocab_size = len(data_desc.vocab)
encoder = EncoderRNN(vocab_size, args.emb_dim, args.hid_dim, args.dropout, args.n_layers)

decoder = nemo_nlp.nm.trainables.TRADEGenerator(
    data_desc.vocab,
    encoder.embedding,
    args.hid_dim,
    args.dropout,
    data_desc.slots,
    len(data_desc.gating_dict),
    teacher_forcing=args.teacher_forcing,
)

gate_loss_fn = nemo_nlp.nm.losses.CrossEntropyLoss3D(num_classes=len(data_desc.gating_dict))
ptr_loss_fn = nemo_nlp.nm.losses.TRADEMaskedCrossEntropy()
total_loss_fn = nemo_nlp.nm.losses.LossAggregatorNM(num_inputs=2)


def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefix, is_training):
    logging.info(f"Loading {data_prefix} data...")
    shuffle = args.shuffle_data if is_training else False

    data_layer = nemo_nlp.nm.data_layers.MultiWOZDataLayer(
        args.data_dir,
        data_desc.domains,
        all_domains=data_desc.all_domains,
        vocab=data_desc.vocab,
        slots=data_desc.slots,
        gating_dict=data_desc.gating_dict,
        num_samples=num_samples,
        shuffle=shuffle,
        num_workers=0,
        batch_size=batch_size,
        mode=data_prefix,
        is_training=is_training,
        input_dropout=input_dropout,
    )

    src_ids, src_lens, tgt_ids, tgt_lens, gate_labels, turn_domain = data_layer()

    data_size = len(data_layer)
    logging.info(f'The length of data layer is {data_size}')

    if data_size < batch_size:
        logging.warning("Batch_size is larger than the dataset size")
        logging.warning("Reducing batch_size to dataset size")
        batch_size = data_size

    steps_per_epoch = math.ceil(data_size / (batch_size * num_gpus))
    logging.info(f"Steps_per_epoch = {steps_per_epoch}")

    outputs, hidden = encoder(inputs=src_ids, input_lens=src_lens)

    point_outputs, gate_outputs = decoder(
        encoder_hidden=hidden, encoder_outputs=outputs, input_lens=src_lens, src_ids=src_ids, targets=tgt_ids
    )

    gate_loss = gate_loss_fn(logits=gate_outputs, labels=gate_labels)
    ptr_loss = ptr_loss_fn(logits=point_outputs, targets=tgt_ids, loss_mask=tgt_lens)
    total_loss = total_loss_fn(loss_1=gate_loss, loss_2=ptr_loss)

    if is_training:
        tensors_to_evaluate = [total_loss, gate_loss, ptr_loss]
    else:
        tensors_to_evaluate = [total_loss, point_outputs, gate_outputs, gate_labels, turn_domain, tgt_ids, tgt_lens]

    return tensors_to_evaluate, total_loss, ptr_loss, gate_loss, steps_per_epoch, data_layer


(
    tensors_train,
    total_loss_train,
    ptr_loss_train,
    gate_loss_train,
    steps_per_epoch_train,
    data_layer_train,
) = create_pipeline(
    args.num_train_samples,
    batch_size=args.batch_size,
    num_gpus=args.num_gpus,
    input_dropout=args.input_dropout,
    data_prefix=args.train_file_prefix,
    is_training=True,
)

tensors_eval, total_loss_eval, ptr_loss_eval, gate_loss_eval, steps_per_epoch_eval, data_layer_eval = create_pipeline(
    args.num_eval_samples,
    batch_size=args.eval_batch_size,
    num_gpus=args.num_gpus,
    input_dropout=0.0,
    data_prefix=args.eval_file_prefix,
    is_training=False,
)

# Create callbacks for train and eval modes
train_callback = nemo_core.SimpleLossLoggerCallback(
    tensors=[total_loss_train, gate_loss_train, ptr_loss_train],
    print_func=lambda x: logging.info(
        f'Loss:{str(np.round(x[0].item(), 3))}, '
        f'Gate Loss:{str(np.round(x[1].item(), 3))}, '
        f'Pointer Loss:{str(np.round(x[2].item(), 3))}'
    ),
    tb_writer=nf.tb_writer,
    get_tb_values=lambda x: [["loss", x[0]], ["gate_loss", x[1]], ["pointer_loss", x[2]]],
    step_freq=steps_per_epoch_train,
)

eval_callback = nemo_core.EvaluatorCallback(
    eval_tensors=tensors_eval,
    user_iter_callback=lambda x, y: eval_iter_callback(x, y, data_desc),
    user_epochs_done_callback=lambda x: eval_epochs_done_callback(x, data_desc),
    tb_writer=nf.tb_writer,
    eval_step=steps_per_epoch_train,
)

ckpt_callback = nemo_core.CheckpointCallback(
    folder=nf.checkpoint_dir, epoch_freq=args.save_epoch_freq, step_freq=args.save_step_freq
)

if args.lr_policy is not None:
    total_steps = args.num_epochs * steps_per_epoch_train
    lr_policy_fn = get_lr_policy(
        args.lr_policy, total_steps=total_steps, warmup_ratio=args.lr_warmup_proportion, min_lr=args.min_lr
    )
else:
    lr_policy_fn = None

grad_norm_clip = args.grad_norm_clip if args.grad_norm_clip > 0 else None
nf.train(
    tensors_to_optimize=[total_loss_train],
    callbacks=[eval_callback, train_callback, ckpt_callback],
    lr_policy=lr_policy_fn,
    optimizer=args.optimizer_kind,
    optimization_params={
        "num_epochs": args.num_epochs,
        "lr": args.lr,
        "grad_norm_clip": grad_norm_clip,
        "weight_decay": args.weight_decay,
    },
)
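
Note on the `steps_per_epoch` computation in `create_pipeline` (the subject of the commits "Fixed steps_per_epoch for multi gpu training."): the script clamps the batch size to the dataset size and then divides by the effective batch size across all GPUs, rounding up. The standalone sketch below isolates just that arithmetic so it can be checked without NeMo installed. The helper name `steps_per_epoch` and its input validation are illustrative assumptions and are not part of the PR.

```python
import math


def steps_per_epoch(data_size: int, batch_size: int, num_gpus: int = 1) -> int:
    """Number of steps needed to see every sample once (illustrative helper)."""
    if data_size <= 0 or batch_size <= 0 or num_gpus <= 0:
        raise ValueError("data_size, batch_size and num_gpus must be positive")
    # Mirror the script: a batch cannot be larger than the dataset.
    batch_size = min(batch_size, data_size)
    # Each step consumes one batch per GPU, so the effective batch size is
    # batch_size * num_gpus; round up so a final partial batch is counted.
    return math.ceil(data_size / (batch_size * num_gpus))


if __name__ == "__main__":
    print(steps_per_epoch(1000, 16, 1))  # single GPU
    print(steps_per_epoch(1000, 16, 4))  # four GPUs share each step
    print(steps_per_epoch(10, 16, 1))    # batch size clamped to dataset size
```

Because the script passes `steps_per_epoch_train` as both `step_freq` and `eval_step`, this value also sets how often training losses are logged and evaluation is run, in addition to the total step count given to the learning-rate policy.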
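
Note on the pointer loss: the script calls `ptr_loss_fn(logits=point_outputs, targets=tgt_ids, loss_mask=tgt_lens)`, that is, it passes target lengths so that padded positions do not contribute to the loss. The code of `TRADEMaskedCrossEntropy` is not shown in this diff, so the following is only a minimal pure-Python sketch of the general length-masked cross-entropy idea. The function name, the nested-list input layout, the assumption that the inputs are probabilities, and the `eps` guard against `log(0)` are all assumptions for illustration, not NeMo's implementation.

```python
import math
from typing import Sequence


def masked_cross_entropy(
    probs: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[int]],
    lengths: Sequence[int],
    eps: float = 1e-10,
) -> float:
    """Mean negative log-likelihood over the unpadded target positions.

    probs[b][t][v] is the predicted probability of token v at step t of
    sequence b; targets[b][t] is the gold token id; lengths[b] is the number
    of real (non-padding) steps in sequence b.
    """
    total, count = 0.0, 0
    for seq_probs, seq_targets, length in zip(probs, targets, lengths):
        for t in range(length):  # steps at or beyond `length` are padding
            p = seq_probs[t][seq_targets[t]]
            # eps keeps log() finite when the model assigns probability 0
            total -= math.log(p + eps)
            count += 1
    # With no real positions there is nothing to average over.
    return total / count if count else 0.0


if __name__ == "__main__":
    probs = [[[0.5, 0.5], [0.0, 1.0]]]  # one sequence, two steps, vocab of 2
    targets = [[0, 0]]
    print(masked_cross_entropy(probs, targets, lengths=[1]))  # step 2 is masked
    print(masked_cross_entropy(probs, targets, lengths=[2]))  # step 2 has p = 0
```

Averaging over the count of real positions, rather than over the padded tensor size, keeps the loss scale independent of how much padding a batch happens to contain.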