type
status
date
slug
summary
tags
category
icon
password
Data Collator是HuggingFace开源的transformers模块进行数据处理的重要部分。它的输入是由数据集元素组成的列表,将其组装成批次,其中数据集元素为相同数据类型的train_dataset或者eval_dataset
为了组装成数据批次,Datacollators 会应用某些处理(比如padding),有些(比如DataCollatorForLanguageModeling)还会在数据批次上应用随机数据增强(比如随机masking)。
 
Data collators是为了特定任务而设计的,如下:
  • Causal language modeling (CLM)
  • Masking language modeling (MLM)
  • Sequence classification
  • Seq2Seq
  • Token classification
 
以sequence classification任务为例,data collator只需将所有序列填充到一个小批次中,以确保它们具有相同的长度。
 
当我们用transformers的Trainer模块进行训练时,常常会根据任务的不同,搭配一个DataCollactor。例如很简单的微调Sequence Classification的场景中:
 
当然,我们有时候在用 trl 库的 SFTTrainer 来做的时候,也可以不用指定:
在SFTTrainer的内部实现上看,其实对这个类有一个初始化:
这个 DataCollatorForLanguageModeling 在transformers中的描述为:
 
 
其实感觉实际上和 DataCollatorWithPadding 的效果是类似的,唯一不同的是 DataCollatorForLanguageModeling 会多一个随机mask的过程,而 DataCollatorWithPadding 没有
 

具体例子解读

transformers模块源码中,data collators源码位于transformers/data/data_collator.py脚本中,所有data collator的父类为DataCollatorMixin,代码如下:
这个类根据输入return_tensors决定处理哪种矩阵torch、tensorflow、numpy,一般不直接使用。
以下是其派生的子类,我们将一一进行了解。
 

DefaultDataCollator

 
DefaultDataCollator类为Trainer中默认DataCollator类,它不进行任何padding或者truncation,并且假设所有的输入样本拥有相同的长度。如果输入样本长度不一样,则会报错。
该类一般不直接使用。
 

DataCollatorWithPadding

DataCollatorWithPadding类将输入的input_ids、attention_mask等向量做padding处理,截断或补充padding,使得它们的长度保持统一。
输出结果如下
这里使用的transformers模块版本为4.46.1tokenizers模块版本为0.20.3
我们为Qwen1.5-7B模型的tokenier为例。从上面的输出中可以看到,原始的数据经tokenize后,每一行的input_idsattention_mask长度并不一样,使用DataCollatorWithPadding后它们的长度一样,使用151643token_id(对应token为<|endoftext|>)进行填充。
DataCollatorWithPadding类还有其他参数,比如padding, max_length, pad_to_multiple_of, return_tensors等,我们可以指定序列最大长度(自动truncate截断),比如
DataCollatorWithPadding(tokenizer, max_length=10, padding='max_length')
输出
此时如果你想paddingmax_length,还需要额外参数 padding="max_length"
 
此时输出就自动填充到了设定的最大长度:
 
 
这里还有一个细节,有的模型是在left padding(比如T5, chronos等),有的模型是在右边right padding(Qwen等),这里就不是在DataCollator中设置了,而是在tokenizer中设置padding_side参数。例如
输出:
此时填充的值就跑到左边去了
 

DataCollatorForTokenClassification

DataCollatorForTokenClassification类适用于序列标注任务(token classification),比如命名实体识别(NER),每个token都会对应一个预测label,当一个序列的token比label多时,对额外的label位置添加-100处理(意味着不计算loss),使得在计算交叉熵损失时,-100位置的label损失为0不用考虑。
 
输出结果为:
 
可以看到,此时仍对序列进行了填充,使得所有序列保持同一长度,同时对token多于label的场景使用-100进行填充。

DataCollatorForSeq2Seq

DataCollatorForSeq2Seq类适用于Seq2Seq任务,包括机器翻译,文本生成任务等。该类根据前面的序列来预测后面序列,后面序列为labels,一个batch内多条后面序列长度不同时,用-100来填充。同样地,在计算损失时,-100位置的损失不予考虑。
输出结果为:
可以看到,后面序列labels长度不一致,使用-100填充至batch内最大长度。
 

DataCollatorForLanguageModeling

DataCollatorForLanguageModeling类适用于语言模型,比如BERT系列等。
该类有个mlm参数,其默认值为True, 取值如下:
  • 如果设置为False,则labels与inputs一致,使用-100进行填充
  • 如果设置为True,则labels对于non-masked tokens设置为-100,对于masked token,其处理方式同BERT模型:80% MASK(MASK值), 10% random(随机替换成其它token值), 10% original(原始值)。
mlm_probability参数默认值为0.15,即对input_ids中的token会有15%的概率进行mask。
我们来对比一下看看是啥效果
DataCollatorWithPadding
 
DataCollatorForLanguageModeling
 
 
 
可以看到主要区别就在右边labels不为-100的区域,也就是这些非-100的区域进行了替换,我们可以对照左边结果查一下:
  • 第一行没变
  • 第二行 1298103,
  • 第三行 2356103
  • 第四行 77610334412438
我们查一下lookup,看看 这几个id分别是什么
可以看到,其中3个字符(南、京、市)被替换为了[MASK],而其中 1个字符(桥)被随机替换为其他字符(廁)
 
 
 

DataCollatorForWholeWordMask

DataCollatorForWholeWordMask类继承了DataCollatorForLanguageModeling类,区别在于其全词掩盖功能(WholeWordMask),该遮词方法需要在tokenizer分词的时候就对期望连续遮住的词汇的非第一个字前加上##标记。
 
 

DataCollatorForSOP

DataCollatorForSOP适用于句子顺序预测任务,将会在后续版本中进行移除,这里不再介绍。
 
 
 

DataCollatorForPermutationLanguageModeling

DataCollatorForPermutationLanguageModeling适用于置换语言模型(permutation language modeling),比如XLNet模型等,可以参考XLNet模型的相关解读,这里不予介绍。
 
 
 

NER实战(DataCollatorForTokenClassification)

我们使用DataCollatorForTokenClassification类来完成命名实体识别任务,数据集用peoples_daily_ner,模型采用bert-base-chinese
数据集描述:
 
 
 
训练脚本如下:
这里我们的数据预处理方式比较粗糙,样本的token与label并没有严格对应,不过对于bert-base-chinese模型,大多数样本已经满足要求了。
我们对测试集进行评估(用到了sklearn.metricsclassification_report功能)
评估指标如下:
 
此外,有人反映 huggingface中出的evaluate模块和sklearn的classification_report的结果不一致:https://discuss.huggingface.co/t/help-model-evaluation-for-ner-yields-different-results-sklearn-vs-metric-compute/10528
我们这里测试一下
LOC
ORG
PER
overall_precision
overall_recall
overall_f1
overall_accuracy
precision
0.901765
0.855913
0.945902
0.899059
0.880628
0.889748
0.987769
recall
0.865774
0.864531
0.928648
0.899059
0.880628
0.889748
0.987769
f1
0.883403
0.860200
0.937195
0.899059
0.880628
0.889748
0.987769
number
3658.000000
2185.000000
1864.000000
0.899059
0.880628
0.889748
0.987769
确实不一样,但是统计范围似乎不同,这里还是建议使用sklearn的方法去评估。
 
 

致谢

本文主要内容来基于对个人博客的补充和修缮:https://percent4.github.io/NLP(九十四)transformers模块中的DataCollator/
 
VLM系列论文阅读-Mixed Preference Optimization (MPO)MobileAgent系列学习 — Mobile Agent v2
Loading...
Yixin Huang
Yixin Huang
一个热爱生活的算法工程师
Latest posts
时间序列论文阅读 — TimeCMA(AAAI 2025)
2025-4-23
时间序列论文阅读-ChatTime: A Unified Multimodal Time Series Foundation Model Bridging
2025-4-23
VLM系列论文阅读-Mixed Preference Optimization (MPO)
2025-2-6
VLM系列论文阅读 — Flamingo
2025-2-6
认识你自己,才是这件事的最终乐趣 — 抄录
2025-2-5
用GPT4学量化投资 — Junior Level - Unit 1: Introduction to Stock Markets and Data Handling
2025-1-23
Announcement
🎉NotionNext 4.5已经上线🎉
-- 感谢您的支持 ---
👏欢迎更新体验👏