type
status
date
slug
summary
tags
category
icon
password
Data Collator是HuggingFace开源的transformers模块进行数据处理的重要部分。它的输入是由数据集元素组成的列表,将其组装成批次,其中数据集元素为相同数据类型的
train_dataset或者eval_dataset。为了组装成数据批次,Datacollators 会应用某些处理(比如padding),有些(比如
DataCollatorForLanguageModeling)还会在数据批次上应用随机数据增强(比如随机masking)。Data collators是为了特定任务而设计的,如下:
- Causal language modeling (CLM)
- Masking language modeling (MLM)
- Sequence classification
- Seq2Seq
- Token classification
以sequence classification任务为例,data collator只需将所有序列填充到一个小批次中,以确保它们具有相同的长度。
当我们用transformers的Trainer模块进行训练时,常常会根据任务的不同,搭配一个
DataCollactor。例如很简单的微调Sequence Classification的场景中:当然,我们有时候在用
trl 库的 SFTTrainer 来做的时候,也可以不用指定:在SFTTrainer的内部实现上看,其实对这个类有一个初始化:
这个
DataCollatorForLanguageModeling 在transformers中的描述为:其实感觉实际上和
DataCollatorWithPadding 的效果是类似的,唯一不同的是 DataCollatorForLanguageModeling 会多一个随机mask的过程,而 DataCollatorWithPadding 没有具体例子解读
在
transformers模块源码中,data collators源码位于transformers/data/data_collator.py脚本中,所有data collator的父类为DataCollatorMixin,代码如下:这个类根据输入
return_tensors决定处理哪种矩阵torch、tensorflow、numpy,一般不直接使用。以下是其派生的子类,我们将一一进行了解。
DefaultDataCollator
DefaultDataCollator类为Trainer中默认DataCollator类,它不进行任何padding或者truncation,并且假设所有的输入样本拥有相同的长度。如果输入样本长度不一样,则会报错。该类一般不直接使用。
DataCollatorWithPadding
DataCollatorWithPadding类将输入的input_ids、attention_mask等向量做padding处理,截断或补充padding,使得它们的长度保持统一。
输出结果如下
这里使用的
transformers模块版本为4.46.1, tokenizers模块版本为0.20.3。我们为
Qwen1.5-7B模型的tokenier为例。从上面的输出中可以看到,原始的数据经tokenize后,每一行的input_ids和attention_mask长度并不一样,使用DataCollatorWithPadding后它们的长度一样,使用151643的token_id(对应token为<|endoftext|>)进行填充。DataCollatorWithPadding类还有其他参数,比如padding, max_length, pad_to_multiple_of, return_tensors等,我们可以指定序列最大长度(自动truncate截断),比如
DataCollatorWithPadding(tokenizer, max_length=10, padding='max_length')输出
此时如果你想
padding到max_length,还需要额外参数 padding="max_length"此时输出就自动填充到了设定的最大长度:
这里还有一个细节,有的模型是在left padding(比如T5, chronos等),有的模型是在右边right padding(Qwen等),这里就不是在DataCollator中设置了,而是在
tokenizer中设置padding_side参数。例如输出:
此时填充的值就跑到左边去了
DataCollatorForTokenClassification
DataCollatorForTokenClassification类适用于序列标注任务(token classification),比如命名实体识别(NER),每个token都会对应一个预测label,当一个序列的token比label多时,对额外的label位置添加-100处理(意味着不计算loss),使得在计算交叉熵损失时,-100位置的label损失为0不用考虑。输出结果为:
可以看到,此时仍对序列进行了填充,使得所有序列保持同一长度,同时对token多于label的场景使用-100进行填充。
DataCollatorForSeq2Seq
DataCollatorForSeq2Seq类适用于Seq2Seq任务,包括机器翻译,文本生成任务等。该类根据前面的序列来预测后面序列,后面序列为labels,一个batch内多条后面序列长度不同时,用-100来填充。同样地,在计算损失时,-100位置的损失不予考虑。输出结果为:
可以看到,后面序列labels长度不一致,使用-100填充至batch内最大长度。
DataCollatorForLanguageModeling
DataCollatorForLanguageModeling类适用于语言模型,比如BERT系列等。该类有个mlm参数,其默认值为True, 取值如下:
- 如果设置为False,则labels与inputs一致,使用-100进行填充
- 如果设置为True,则labels对于non-masked tokens设置为-100,对于masked token,其处理方式同BERT模型:80% MASK(MASK值), 10% random(随机替换成其它token值), 10% original(原始值)。
mlm_probability参数默认值为0.15,即对input_ids中的token会有15%的概率进行mask。我们来对比一下看看是啥效果
DataCollatorWithPadding
DataCollatorForLanguageModeling
可以看到主要区别就在右边labels不为-100的区域,也就是这些非-100的区域进行了替换,我们可以对照左边结果查一下:
- 第一行没变
- 第二行
1298→103,
- 第三行
2356→103
- 第四行
776→103,3441→2438
我们查一下lookup,看看 这几个id分别是什么
可以看到,其中3个字符(南、京、市)被替换为了[MASK],而其中 1个字符(桥)被随机替换为其他字符(廁)
DataCollatorForWholeWordMask
DataCollatorForWholeWordMask类继承了DataCollatorForLanguageModeling类,区别在于其全词掩盖功能(WholeWordMask),该遮词方法需要在tokenizer分词的时候就对期望连续遮住的词汇的非第一个字前加上##标记。
DataCollatorForSOP
DataCollatorForSOP适用于句子顺序预测任务,将会在后续版本中进行移除,这里不再介绍。
DataCollatorForPermutationLanguageModeling
DataCollatorForPermutationLanguageModeling适用于置换语言模型(permutation language modeling),比如XLNet模型等,可以参考XLNet模型的相关解读,这里不予介绍。
NER实战(DataCollatorForTokenClassification)
我们使用DataCollatorForTokenClassification类来完成命名实体识别任务,数据集用
peoples_daily_ner,模型采用bert-base-chinese。数据集描述:‣
训练脚本如下:
这里我们的数据预处理方式比较粗糙,样本的token与label并没有严格对应,不过对于
bert-base-chinese模型,大多数样本已经满足要求了。我们对测试集进行评估(用到了
sklearn.metrics的 classification_report功能)评估指标如下:
此外,有人反映 huggingface中出的evaluate模块和sklearn的classification_report的结果不一致:https://discuss.huggingface.co/t/help-model-evaluation-for-ner-yields-different-results-sklearn-vs-metric-compute/10528
我们这里测试一下
ㅤ | LOC | ORG | PER | overall_precision | overall_recall | overall_f1 | overall_accuracy |
precision | 0.901765 | 0.855913 | 0.945902 | 0.899059 | 0.880628 | 0.889748 | 0.987769 |
recall | 0.865774 | 0.864531 | 0.928648 | 0.899059 | 0.880628 | 0.889748 | 0.987769 |
f1 | 0.883403 | 0.860200 | 0.937195 | 0.899059 | 0.880628 | 0.889748 | 0.987769 |
number | 3658.000000 | 2185.000000 | 1864.000000 | 0.899059 | 0.880628 | 0.889748 | 0.987769 |
确实不一样,但是统计范围似乎不同,这里还是建议使用sklearn的方法去评估。
致谢
本文主要内容来基于对个人博客的补充和修缮:https://percent4.github.io/NLP(九十四)transformers模块中的DataCollator/
- Author:Yixin Huang
- URL:https://yixinhuang.cn/article/transformers-datacollactor
- Copyright:All articles in this blog, except for special statements, adopt BY-NC-SA agreement. Please indicate the source!




