
[mthreads] deepspeed llama2 #354

Merged: 11 commits into FlagOpen:mthreads_llama2 on Dec 21, 2023

Conversation

@shang-mt (Contributor) commented Dec 6, 2023

No description provided.

jamesruio and others added 3 commits December 6, 2023 11:26
…(FlagOpen#346)

* [kunlunxin] fix tacotron2 running error and add 1x1 & 2x8 config

* [kunlunxin] modify tacotron2 test_config

* [kunlunxin] update tacotron2 readme

* [kunlunxin] modify tacotron2 torch.load()
* update iluvatar/swin_transformer-pytorch

* update

* update

* update

* fix batch size mistake in readme

* correct val_loss to final acc1

* add finnal_acc1 and mem in readme

* correct readme mem

---------

Co-authored-by: 魏杰 <[email protected]>
Co-authored-by: 杨智超 <[email protected]>
Co-authored-by: clveryang <[email protected]>
@shang-mt force-pushed the llama2-base branch 3 times, most recently from 4bade8b to 9c7f506 on December 7, 2023 08:23
forestlee95 and others added 3 commits December 7, 2023 17:10
* Update README.md

* Update README.md
* iluvatar bertlarge MLM inference case

* update ixrt readme

---------

Co-authored-by: 杨智超 <[email protected]>
@shang-mt force-pushed the llama2-base branch 2 times, most recently from b9abb42 to 9a02a7d on December 11, 2023 03:41
mingyuanw-mt and others added 2 commits December 11, 2023 17:56
* support bert_hf fp32/amp/bf16 training for mthreads

* update readme

* prevent overrun

* 1x1/2x8 not support
* support resnet50 training on mthreads

* fix typo

* support rn50 amp training on mthreads

* add test config (should revert this commit)

* update config & readme

* add get_system_info fn

* update

* 1x1/2x8 not support

---------

Co-authored-by: Zhou Yu <[email protected]>
@@ -54,18 +54,19 @@ def get_argument_parser():

def train(model_engine, dataloader):
    model_engine.train()
    device = torch.device('musa:' + str(args.local_rank))
Collaborator:

If it's changed like this, won't other vendors be unable to run it?

Contributor Author:

Fixed.
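
One plausible shape for that fix is to derive the device type from the vendor setting instead of hardcoding 'musa'. A minimal sketch, assuming the vendor name is available from the test config (the helper name and its plumbing are illustrative, not the PR's actual code):

import torch

def get_device(vendor: str, local_rank: int) -> torch.device:
    # 'musa' is the torch backend registered by torch_musa (Moore Threads);
    # other vendors keep the stock 'cuda' device type. Illustrative mapping only.
    device_type = "musa" if vendor == "mthreads" else "cuda"
    return torch.device(f"{device_type}:{local_rank}")

# e.g. inside train(): device = get_device(config.vendor, args.local_rank)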

@@ -6,5 +6,6 @@ def get_llama_model(model_config_dir, flashattn):
    config = LlamaConfig.from_pretrained(model_config_dir)
    config._flash_attn_2_enabled = flashattn
    model = LlamaForCausalLM(config)
    model.gradient_checkpointing_enable()
Collaborator:

This needs to be changed so the switch is controlled by the config file: off by default for NV, with each vendor free to toggle it.

Contributor Author:

Added.
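
A minimal sketch of the requested change, assuming the switch is passed in from the vendor's config file (the parameter name is illustrative):

from transformers import LlamaConfig, LlamaForCausalLM

def get_llama_model(model_config_dir, flashattn, gradient_checkpointing=False):
    config = LlamaConfig.from_pretrained(model_config_dir)
    config._flash_attn_2_enabled = flashattn
    model = LlamaForCausalLM(config)
    # Off by default (the NV baseline); vendors opt in via their own config.
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
    return model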


- ##### Optimization strategies

- None
Collaborator:

Note here the optimization strategies used, e.g.:
- flash attention (1/2/sdp-attn)
- checkpointing

Contributor Author:

Got it; will fill this in after testing.

Contributor Author:

Added the details.

@shang-mt changed the base branch from main to mthreads_llama2 on December 14, 2023 09:49
datafilename = "openwebtext_llama2_100M.npy"
epochs = 1
theoryflops = 98000000000000.0
flashattn = True
Collaborator:

flashattn = True  # using sdp attention
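
For context, "sdp attention" refers to PyTorch's fused scaled_dot_product_attention (available since PyTorch 2.0). A minimal standalone illustration, with arbitrary shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
# Dispatches to a fused kernel (flash / memory-efficient / math) when available.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])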

@shh2000 shh2000 merged commit 9c64606 into FlagOpen:mthreads_llama2 Dec 21, 2023