Skip to content

taishan1994/pytorch_bert_chinese_text_classification

Repository files navigation

基于pytorch+bert的中文文本分类

本项目是基于pytorch+bert的中文文本分类。

使用依赖

torch==1.6.0
transformers==4.5.1

相关说明

--logs:存放日志
--checkpoints:存放保存的模型
--data:存放数据
--utils:存放辅助函数
--bert_config.py:相关配置
--data_loader.py:制作数据为torch所需的格式
--models.py:存放模型代码
--main.py:主运行程序,包含训练、验证、测试、预测以及相关评价指标的计算

在hugging face上预先下载好预训练的bert模型,放在和该项目同级下的model_hub文件夹下。

裁判文书分类

数据集地址:裁判文书网NLP文本分类数据集 - Heywhale.com

一般步骤

  • 1、在data下新建一个存放该数据集的文件夹,这里是cpws,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。
  • 2、在data_loader.py里面新建一个类,类可参考CPWSDataset,主要是返回一个列表:[(文本,标签)]。
  • 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。
  • 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。

运行指令

python main.py \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cpws/" \
--data_name="cpws" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=5 \
--seed=123 \
--gpu_ids="0" \
--max_seq_len=256 \
--lr=3e-5 \
--train_batch_size=16 \
--train_epochs=5 \
--eval_batch_size=16 \
--do_train \
--do_test \
--do_predict

结果

这里运行了300步之后手动停止了。

{'盗窃罪': 0, '交通肇事罪': 1, '诈骗罪': 2, '故意伤害罪': 3, '危险驾驶罪': 4}
========进行测试========testloss22.040314 accuracy0.9823 micro_f10.9823 macro_f10.9822
              precision    recall  f1-score   support

         盗窃罪       0.97      0.99      0.98      1998
       交通肇事罪       0.97      0.99      0.98      1996
         诈骗罪       0.99      0.98      0.99      1998
       故意伤害罪       0.99      0.99      0.99      1999
       危险驾驶罪       0.99      0.95      0.97      2000

    accuracy                           0.98      9991
   macro avg       0.98      0.98      0.98      9991
weighted avg       0.98      0.98      0.98      9991
公诉机关指控12015年3月18日18时许被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内趁工作人员不注意盗走工地内的脚手架扣件70个价值人民币252元)。22015年3月19日13时和17时被告人余某分两次窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层的中间配电室内利用随身携带的铁钳盗走该配电室内的电缆线共计574米价值人民币4707元)。32015年3月21日7时30分许被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层靠东边的配电室内利用随身携带的铁钳要将该配电室内的电缆线共156米价值人民币1279元盗走时被工地负责人洪某某发现后被工地保安吴某某抓获并扭送公安机关公诉机关认为被告人余某的行为已构成××,本案第三起盗窃系犯罪未遂建议对被告人余某在××至一年六个月的幅度内处以刑罚并处罚金预测标签盗窃罪
真实标签盗窃罪
==========================

新闻文本分类

使用的数据集是THUCNews,数据地址:THUCNews

一般步骤

  • 1、在data下新建一个存放该数据集的文件夹,这里是cnews,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。
  • 2、在data_loader.py里面新建一个类,类可参考CNEWSDataset,主要是返回一个列表:[(文本,标签)]。
  • 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。
  • 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。

运行

python main.py \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--gpu_ids="0" \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=16 \
--train_epochs=5 \
--eval_batch_size=16 \
--do_predict

结果

这里运行了800步手动停止了。

{'房产': 0, '娱乐': 1, '教育': 2, '体育': 3, '家居': 4, '时政': 5, '财经': 6, '时尚': 7, '游戏': 8, '科技': 9}
========进行测试========testloss76.024950 accuracy0.9697 micro_f10.9697 macro_f10.9696
              precision    recall  f1-score   support

          房产       0.91      0.92      0.92      1000
          娱乐       0.99      0.99      0.99      1000
          教育       0.97      0.96      0.97      1000
          体育       1.00      1.00      1.00      1000
          家居       0.98      0.91      0.94      1000
          时政       0.98      0.94      0.96      1000
          财经       0.96      0.99      0.97      1000
          时尚       0.94      1.00      0.97      1000
          游戏       0.99      0.99      0.99      1000
          科技       0.98      0.99      0.99      1000

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000

鲍勃库西奖归谁属NCAA最强控卫是坎巴还是弗神新浪体育讯如今本赛季的NCAA进入到了末段各项奖项的评选结果也即将出炉其中评选最佳控卫的鲍勃-库西奖就将在下周最终四强战时公布鲍勃-库西奖是由奈史密斯篮球名人堂提供旨在奖励年度最佳大学控卫最终获奖的球员也即将在以下几名热门人选中产生。〈〈〈 NCAA疯狂三月专题主页上线点击链接查看精彩内容吉梅尔-弗雷戴特杨百翰大学弗神吉梅尔-弗雷戴特一直都备受关注他不仅仅是一名射手他会用终结对手脚踝一样的变向过掉面前的防守者并且他可以用任意一支手完成得分如果他被犯规了可以提前把这两份划入他的帐下了因为他是一名命中率高达90%的罚球手弗雷戴特具有所有伟大控卫都具备的一点特质他是一位赢家也是一位领导者。“他整个赛季至始至终的稳定领导着球队前进这是无可比拟的。”杨百翰大学主教练戴夫-罗斯称赞道,“他的得分能力毋庸置疑但是我认为他带领球队获胜的能力才是他最重要的控卫职责我们在主场之外的比赛(客场或中立场)共取胜19场他都表现的很棒。”弗雷戴特能否在NBA取得成功当然但是有很多专业人士比我们更有资格去做出这样的判断。“我喜爱他。”凯尔特人主教练多克-里弗斯说道,“他很棒我看过ESPN的片段剪辑从剪辑来看他是个超级巨星我认为他很成为一名优秀的NBA球员。”诺兰-史密斯杜克大学当赛季初球队宣布大一天才控卫凯瑞-厄尔文因脚趾的伤病缺席赛季大部分比赛后诺兰-史密斯便开始接管球权他在进攻端上足发条在ACC联盟(杜克大学所在分区)的得分榜上名列前茅但同时他在分区助攻榜上也占据头名这在众强林立的ACC联盟前无古人。“我不认为全美有其他的球员能在凯瑞-厄尔文受伤后如此好的接管球队并且之前毫无准备。”杜克主教练迈克-沙舍夫斯基赞扬道,“他会将比赛带入自己的节奏得分组织领导球队无所不能而且他现在是攻防俱佳对持球人的防守很有提高总之他拥有了辉煌的赛季。”坎巴-沃克康涅狄格大学坎巴-沃克带领康涅狄格在赛季初的毛伊岛邀请赛一路力克密歇根州大和肯塔基等队夺冠他场均30分4助攻得到最佳球员在大东赛区锦标赛和全国锦标赛中他场均27.16.1个篮板5.1次助攻依旧如此给力他以疯狂的表现开始这个赛季也将以疯狂的表现结束这个赛季。“我们在全国锦标赛中前进着并且之前曾经5天连赢5场赢得了大东赛区锦标赛的冠军这些都归功于坎巴-沃克。”康涅狄格大学主教练吉姆-卡洪称赞道,“他是一名纯正的控卫而且能为我们得分他有过单场42分有过单场17助攻也有过单场15篮板这些都是一名6英尺175镑的球员所完成的啊我们有很多好球员但他才是最好的领导者为球队所做的贡献也是最大。”乔丹-泰勒威斯康辛大学全美没有一个持球者能像乔丹-泰勒一样很少失误他4.26的助攻失误在全美遥遥领先在大十赛区的比赛中他平均35.8分钟才会有一次失误他还是名很出色的得分手全场砍下39分击败印第安纳大学的比赛就是最好的证明其中下半场他曾经连拿18分。“那个夜晚他证明自己值得首轮顺位。”当时的见证者印第安纳大学主教练汤姆-克雷恩说道。“对一名控卫的所有要求不过是领导球队使球队变的更好带领球队成功乔丹-泰勒全做到了。”威斯康辛教练博-莱恩说道诺里斯-科尔克利夫兰州大诺里斯-科尔的草根传奇正在上演默默无闻的他被克利夫兰州大招募后便开始刻苦地训练去年夏天他曾加练上千次跳投来提高这个可能的弱点他在本赛季与杨斯顿州大的比赛中得到40分20篮板和9次助攻在他之前过去15年只有一位球员曾经在NCAA一级联盟做到过40+20他的名字是布雷克-格里芬。“他可以很轻松地防下对方王牌。”克利夫兰州大主教练加里-沃特斯如此称赞自己的弟子,“同时他还能得分并为球队助攻他几乎能做到一个成功的团队所有需要的事。”这其中四名球员都带领自己的球队进入到了甜蜜16强虽然有3个球员和他们各自的球队被挡在8强的大门之外但是他们已经表现的足够出色不远的将来他们很可能出现在一所你熟悉的NBA球馆里。(clay)
预测标签体育
真实标签体育
==========================

Dataparallel分布式训练

python main_dataparallel.py \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--gpu_ids="0,1,3" \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

Distributed单机多卡分布式训练(windows下)

linux下没有测试过。运行需要在powershell里面运行,右键点击开始菜单,选择powershell。nvidia.bat用于监控运行之后GPU的使用情况。需要pytorch版本至少大于1.7,这里使用的是pytorch==1.12。

使用torch.distributed.launch启动

python -m torch.distributed.launch --nnode=1 --node_rank=0 --nproc_per_node=4 main_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test

说明:文件里面通过os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'来选择使用的GPU。nproc_per_node为使用的GPU的数目,local_world_size为使用的GPU的数目。

使用torch.multiprocessing启动

python main_mp_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test

说明:文件里面通过os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'来选择使用的GPU。local_world_size为使用的GPU的数目。

补充

Q:怎么训练自己的数据集?

A:按照样例的一般步骤里面进行即可。

更新日志

  • 2022-08-08:重构了代码,使得总体结构更加简单,更易于用于不同的数据集上。

  • 2022-08-09:新增是否加载模型继续训练,运行参数加上--retrain。

  • 2023-03-30:新增基于dataparallel的分布式训练。

  • 2023-04-02:新增基于distributed的分布式训练。

About

基于pytorch+bert的中文文本分类

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published