
Data processing #3

Open · Lj4040 opened this issue Nov 3, 2023 · 3 comments

@Lj4040 commented Nov 3, 2023

Hello! I read your paper and found it very interesting, so I wanted to study it and reproduce your work from scratch. I downloaded the original dataset, but the code seems to cover only the right half of Figure 2 of the paper; I couldn't find the code for constructing the detection templates. The paper says the predicted labels are obtained through GECToR, but what GECToR outputs is also a sentence, right? And the gold labels are obtained through ERRANT. How did you handle this part? In the script you provided I noticed the place I marked below: is this pre-processed data that already contains the templates? How can such data be obtained?
[screenshot]

@li-aolong (Owner)

The processing scripts for this part indeed haven't been cleaned up yet, since they are rather messy, but the underlying idea is consistent throughout.

To build a template, you first need a detection result for every token in the sentence, e.g. a binary classification. There are two ways to obtain detection results. The first is to train a classifier yourself that predicts the detection result directly. The second is to first correct the source sentence, compute an ERRANT (.m2) file from the corrected sentence and the gold target, and then extract the error positions from that edit file.
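
For illustration, the second method can also be run directly from Python with ERRANT's API rather than via .m2 files on disk. A minimal sketch (the sentences, variable names, and printed fields are my own illustration, not the repository's script):

import errant

# Load ERRANT's English annotator (assumes `pip install errant` plus the
# spaCy English model it depends on).
annotator = errant.load('en')

src = 'This are a sentence .'   # source sentence, space-tokenized
cor = 'This is a sentence .'    # corrected sentence (system output or gold target)

orig_doc = annotator.parse(src)
cor_doc = annotator.parse(cor)

# Each edit exposes its source token span and error type -- the same
# information the label-extraction script below reads out of an m2 file.
for e in annotator.annotate(orig_doc, cor_doc):
    print(e.o_start, e.o_end, e.type, e.c_str)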

Since GECToR itself is not a detection model, only the second method applies: correct the sentences once, obtain an m2 file, and then extract GECToR's detection results from it. That is how the predicted detection information is obtained; the gold error information can likewise only be obtained via the second method.
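
For reference, an entry in an m2 edit file pairs an S (source) line with A (annotation) lines of the form "A start end|||type|||correction|||REQUIRED|||-NONE-|||annotator_id". An illustrative example (not taken from the actual file):

S This are a sentence .
A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0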

For example, for the bea-dev dataset the m2 edit file is ABCN.dev.gold.bea19.m2. Below is a script that extracts 2-class or 4-class detection labels from such an edit file.

import os

skip = {"noop", "UNK", "Um"}   # edit types to ignore
class_4 = ['R', 'M', 'U']      # 4-class error labels (plus C for correct)
class_type = '2-class'         # '2-class' (C/I) or '4-class' (C/R/M/U)

dataset = 'bea'
split = 'dev'
print(dataset, split)

m2_file = 'ABCN.dev.gold.bea19.m2'
m2 = open(m2_file).read().strip('\n').split("\n\n")

# C-correct R-replacement M-missing U-unnecessary
labels = []
for i, sent in enumerate(m2):
    sent_split = sent.split("\n")
    src = sent_split[0].split(' ')[1:]  # tokens of the "S ..." source line
    edits = sent_split[1:]              # the "A ..." annotation lines

    label = ['C'] * len(src)  # default: every token is correct
    for edit in edits:
        edit = edit.split("|||")
        err_type = edit[1]
        if err_type in skip:
            continue
        # keep only the edits of annotator 0
        coder = int(edit[-1])
        if coder != 0:
            continue

        err_type_4_class = edit[1][0]
        assert err_type_4_class in class_4, i

        span = edit[0].split(' ')[1:] # Ignore "A "
        start = int(span[0])
        end = int(span[1])
        edit_tgt = edit[2]

        if start != end:
            if class_type == '2-class':
                label[start : end] = ['I'] * (end - start)
            if class_type == '4-class':
                if edit_tgt == '':
                    assert err_type.startswith('U:'), i
                    label[start : end] = ['U'] * (end - start)
                else:
                    label[start : end] = ['R'] * (end - start)
        else:
            # if the error (a missing word) is at the end of the sentence,
            # mark the last token of the sentence instead
            if class_type == '2-class':
                if start == len(label):
                    label[-1] = 'I'
                else:
                    label[start] = 'I'
            if class_type == '4-class':
                if start == len(label):
                    label[-1] = 'M'
                else:
                    label[start] = 'M'
    
    assert len(label) == len(src), i
    labels.append(' '.join(label) + '\n')

os.makedirs(f'./{dataset}/{split}', exist_ok=True)
open(f'./{dataset}/{split}/{split}.{class_type}.label', 'w').writelines(labels)
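
To get GECToR's predicted labels with this same script, one would first build an m2 file from the source sentences and GECToR's corrections, e.g. with ERRANT's command-line tool. A hedged sketch (the file names are placeholders, not files shipped with the repo):

import subprocess

# Build a "predicted" m2 file from GECToR's output, then point m2_file
# in the script above at it. All file names here are placeholders.
subprocess.run([
    'errant_parallel',
    '-orig', 'bea.dev.src',        # source sentences, one per line
    '-cor', 'bea.dev.gector.out',  # GECToR's corrected sentences
    '-out', 'bea.dev.pred.m2',     # resulting m2 edit file
], check=True)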

This yields detection labels for every token of each sentence, for example:

C C C I I C C C C C C C C C C C C C C C C C C C C C C C C C C C C
C C C C I C C C C C C C C C C C C C C C C C I C C C C C C
C C C C C C I C C C C I C C C I C C C C
C C C C C C C C C C C C C C C C C C C C C C C C C C C C I C
C C C C C C C C C C

Then the template is built from the detection labels and the source sentence; the script is as follows:

import json

n_class = '2-class'

dataset = 'bea'
split = 'valid'
noise = '' # .noise
print(dataset, split)

label_file = f'./{n_class}/{dataset}/{split}/{split}.{n_class}{noise}.label'

labels = open(label_file).read().strip('\n').split('\n')
sources = [json.loads(line)['source'] for line in open(f'./{n_class}/{dataset}/{split}/{split}.{n_class}.json').readlines()]
assert len(labels) == len(sources)

t2s_simple_srcs = []
for i, (label, src_list) in enumerate(zip(labels, sources)):
    if i % 100000 == 0:
        print(i)

    label_list = label.split(' ')

    # abort on a label/source length mismatch
    if len(label_list) != len(src_list):
        print(i)
        exit()

    const = []        # dropped (erroneous) spans, each preceded by its sentinel
    template = []     # source with each erroneous span replaced by a sentinel
    tmp = []          # tokens of the error span currently being collected
    count = 0         # sentinel index
    last_label = 'C'
    for tok_label, token in zip(label_list, src_list):
        # cap the number of sentinels per sentence
        if count == 20:
            break
        # correct token outside an error span: copy it into the template
        if tok_label == 'C' and tmp == []:
            template.append(token)

            last_label = tok_label
            continue

        # label changed while a span is open: flush the span as one sentinel
        if tok_label != last_label and tmp != []:
            const.append(f'<extra_id_{count}>')
            const.extend(tmp)
            template.append(f'<extra_id_{count}>')

            count += 1

            tmp = []
            if tok_label != 'C':
                tmp.append(token)
            else:
                template.append(token)

            last_label = tok_label
            continue

        # erroneous token: extend (or start) the current span
        if tok_label != 'C':
            if tmp == []:
                assert last_label == 'C', i
            else:
                assert last_label == tok_label, i
            tmp.append(token)

            last_label = tok_label
            continue

    # flush a span that runs to the end of the sentence
    if tmp != []:
        const.append(f'<extra_id_{count}>')
        const.extend(tmp)
        template.append(f'<extra_id_{count}>')

    # final input: "<dropped spans> </s> <template>"; the constraint part
    # is empty when no error was detected
    if const == []:
        t2s_simple_srcs.append('</s> ' + ' '.join(template) + '\n')
    else:
        t2s_simple_srcs.append(' '.join(const) + ' </s> ' + ' '.join(template) + '\n')

output_file = label_file.replace('.label', '.t5.src')
open(output_file, 'w').writelines(t2s_simple_srcs)
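
To make the output concrete, a made-up 2-class example (not from the repo's data):

source:  He go to school .
labels:  C I C C C
t5 src:  <extra_id_0> go </s> He <extra_id_0> to school .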

The gold here marks the pred/gold pair: the code uses gold to check whether the paired file exists, and gold is the detection information computed from the real target. The later tok and bpe steps are processing scripts for fairseq; for details see: https://zhuanlan.zhihu.com/p/401844695

@Lj4040 (Author) commented Nov 6, 2023

Okay, thank you very much for your reply!

@Lj4040 (Author) commented Dec 5, 2023

Hello, while reproducing the T5 experiments I set everything up according to your scripts/model_t5/1_train_t5_template-only.sh and then ran it:
[screenshot]
It reports that some arguments are invalid, such as loss_name, which looks like exactly the KL divergence mentioned in your paper. Am I running it incorrectly? I also tried running scripts/model_t5/1_train_t5_template-consistency.sh, but after stepping through it in the debugger I could not find the KL divergence from the paper, nor the final summing of the losses; there seems to be only the ordinary sequence loss.
[screenshot]

From the paper, the model is fed two roughly identical inputs: one processed from the original Seq2Edit predictions and one modified according to the gold labels; both go into the model, and the losses and the KL divergence are then computed. But the data you provided seems to contain only one of them. Am I misunderstanding?
[screenshot]
[screenshot]
The input should be these two files, right?
[screenshot]
If so, how should I modify the arguments, i.e. the train_file and valid_file parameters? If it's convenient for you, I'd appreciate your guidance. Thank you!
