基于bert继续预训练
在目前的各项NLP任务中,如果要在特定任务或者领域应用文本分类,数据分布一定是有一些差距的。这时候可以考虑进行深度预训练。进行继续预训练,有利于提升任务的性能。
·
文章目录
前言
在目前的各项NLP任务中,如果要在特定任务或者领域应用文本分类,数据分布一定是有一些差距的。这时候可以考虑进行深度预训练。进行继续预训练,有利于提升任务的性能。
一、继续预训练是什么?
根据基于bert模型的下游NLP任务的特定语料,对模型进行领域的继续训练。主要类型如下:
Within-task pre-training:Bert在训练语料上进行预训练
In-domain pre-training:在同一领域上的语料进行预训练
Cross-domain pre-training:在不同领域上的语料进行预训练
二、代码
1.引入库
代码如下(示例):
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
from transformers import (AutoModelForMaskedLM,AutoTokenizer, LineByLineTextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments)
2.读入数据
代码如下(示例):
train_data = pd.read_csv('data/train/train.csv', sep='\t')
test_data = pd.read_csv('data/test/test.csv', sep='\t')
train_data['text'] = train_data['title'] + '.' + train_data['abstract']
test_data['text'] = test_data['title'] + '.' + test_data['abstract']
data = pd.concat([train_data, test_data])
data['text'] = data['text'].apply(lambda x: x.replace('\n', ''))
text = '\n'.join(data.text.tolist())
3.构建数据集
代码如下(示例):
model_name = 'roberta-base'
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained('./paper_roberta_base')
train_dataset = LineByLineTextDataset(
tokenizer=tokenizer,
file_path="text.txt", # mention train text file here
block_size=256)
valid_dataset = LineByLineTextDataset(
tokenizer=tokenizer,
file_path="text.txt", # mention valid text file here
block_size=256)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
4.开始训练
training_args = TrainingArguments(
output_dir="./paper_roberta_base_chk", # select model path for checkpoint
overwrite_output_dir=True,
num_train_epochs=5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy='steps',
save_total_limit=2,
eval_steps=200,
metric_for_best_model='eval_loss',
greater_is_better=False,
load_best_model_at_end=True,
prediction_loss_only=True,
report_to="none")
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=valid_dataset)
trainer.train()
trainer.save_model(f'./paper_roberta_base')
总结
面对一些NLP比赛,可以考虑进行深度预训练。
更多推荐
已为社区贡献1条内容
所有评论(0)