[mthreads] deepspeed llama2 #354
Conversation
…gOpen#346)
* [kunlunxin] fix tacotron2 running error and add 1x1 & 2x8 config
* [kunlunxin] modify tacotron2 test_config
* [kunlunxin] update tacotron2 readme
* [kunlunxin] modify tacotron2 torch.load()
* update iluvatar/swin_transformer-pytorch
* update
* update
* update
* fix batch size mistake in readme
* correct val_loss to final acc1
* add final_acc1 and mem in readme
* correct readme mem
---------
Co-authored-by: 魏杰 <[email protected]>
Co-authored-by: 杨智超 <[email protected]>
Co-authored-by: clveryang <[email protected]>
Co-authored-by: zhouyu <[email protected]>
force-pushed from 4bade8b to 9c7f506
Co-authored-by: sen.li <[email protected]>
* Update README.md
* Update README.md
* iluvatar bertlarge MLM inference case
* update ixrt readme
---------
Co-authored-by: 杨智超 <[email protected]>
force-pushed from b9abb42 to 9a02a7d
force-pushed from 9a02a7d to ad1500b
* support bert_hf fp32/amp/bf16 training for mthreads
* update readme
* prevent overrun
* 1x1/2x8 not support
* support resnet50 training on mthreads
* fix typo
* support rn50 amp training on mthreads
* add test config (should revert this commit)
* update config & readme
* add get_system_info fn
* update
* 1x1/2x8 not support
---------
Co-authored-by: Zhou Yu <[email protected]>
@@ -54,18 +54,19 @@ def get_argument_parser():

def train(model_engine, dataloader):
    model_engine.train()
    device = torch.device('musa:'+str(args.local_rank))
If it's changed like this, won't other vendors be unable to run it?
Fixed.
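A minimal sketch of the kind of vendor-neutral device selection the reviewer is asking for. The `get_device` helper, the vendor string, and the fallback order are assumptions for illustration, not the code actually merged:

```python
import torch

def get_device(vendor: str, local_rank: int) -> torch.device:
    # Hypothetical helper: derive the device type from a vendor/config
    # value instead of hard-coding 'musa', so other vendors still run.
    if vendor == "mthreads":
        import torch_musa  # noqa: F401 -- registers the 'musa' device type
        return torch.device(f"musa:{local_rank}")
    if torch.cuda.is_available():
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")
```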
@@ -6,5 +6,6 @@ def get_llama_model(model_config_dir, flashattn):
    config = LlamaConfig.from_pretrained(model_config_dir)
    config._flash_attn_2_enabled = flashattn
    model = LlamaForCausalLM(config)
    model.gradient_checkpointing_enable()
This needs to be controlled by a switch in the config file: off by default for NVIDIA, with each vendor toggling it themselves.
Added.
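One way the requested config-driven switch could look; `use_gradient_checkpointing` is a hypothetical parameter name, not necessarily what the PR added:

```python
from transformers import LlamaConfig, LlamaForCausalLM

def get_llama_model(model_config_dir, flashattn, use_gradient_checkpointing=False):
    config = LlamaConfig.from_pretrained(model_config_dir)
    config._flash_attn_2_enabled = flashattn
    model = LlamaForCausalLM(config)
    # Off by default (the NVIDIA baseline); vendors flip it in their config.
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
    return model
```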
* fixllama
* add t/tflops
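For context on the "add t/tflops" commit, a rough sketch of how tokens/s can be converted to achieved TFLOPS and compared against the `theoryflops` config value; the 6·params·tokens cost model and the example numbers are illustrative assumptions, not measurements from this PR:

```python
def achieved_tflops(tokens_per_second: float, n_params: float) -> float:
    # Common approximation: ~6 FLOPs per parameter per token
    # for one forward+backward training step.
    return tokens_per_second * 6.0 * n_params / 1e12

# Illustrative only: a 7B-parameter llama2 at 1,000 tokens/s.
tflops = achieved_tflops(1_000, 7e9)  # ~42 TFLOPS
mfu = tflops * 1e12 / 98e12           # vs. theoryflops from the config
```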
force-pushed from ad1500b to 5fae218
- ##### 优化策略 (optimization strategies)

- 无 (none)
Please note the optimization strategies used here, e.g.:
- flash attention (1/2/sdp-attn)
- checkpointing
Got it; will fill this in after testing.
Added.
datafilename = "openwebtext_llama2_100M.npy"
epochs = 1
theoryflops = 98000000000000.0
flashattn = True
flashattn = True  # using sdp attention
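A hedged illustration of what the suggested annotation implies: with sdp attention, the attention computation goes through PyTorch's fused scaled-dot-product kernel rather than a separate flash-attention package. The tensor shapes below are arbitrary examples:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- arbitrary example shapes
q = k = v = torch.randn(1, 8, 128, 64)

# PyTorch 2.x dispatches this to a fused (flash/memory-efficient) kernel
# when one is available for the current device and dtype.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```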
force-pushed from a1c98e3 to 8b18240
No description provided.