A Unified MRC Framework for Named Entity Recognition项目代码
- 简述
- 项目结构
- models
-
- model_config.py
- classifier.py
- bert_tagger.py
- bert_query_ner.py
- train
-
- mrc_ner_trainer.py
- ner2mrc
-
- msra2mrc.py
- datasets
-
- mrc_ner_dataset.py
- evaluate
-
- mrc_ner_evaluate.py
- inference
-
- mrc_ner_inference.py
- 总结
- 后记
项目链接:https://github.com/ShannonAI/mrc-for-flat-nested-ner 论文链接:https://arxiv.org/abs/1910.11476
简述
论文将命名实体识别任务转换为机器阅读理解(MRC)任务,也就是说,文本序列中的实体通过回答一个问题来提取;一般来说,问题与具体的实体类别对应,例如提取org实体类别时,query可以是"文本序列中提到的组织是什么?"。其使用BERT作为backbone,将query和文本拼接后作为序列输入BERT,并使用两个二分类器对BERT输出的每个token进行分类:一个分类器判断每个token是实体开始索引的可能性,另一个分类器判断每个token是实体结束索引的可能性。
项目结构
- datasets --构建数据集文件
- evaluate --用于评估的文件
- inference --用于前向推理的文件
- metrics --实现metric的文件
- models --构建模型的文件
- ner2mrc --将数据转换为mrc所需格式的文件
- scripts --在每个数据集上训练的启动文件
- tests --测试所需的文件
- train --模型训练文件
- utils --其他的辅助文件
- README.md --项目详情
- requirements.txt --项目所需的python包
models
先分析models路径下的文件,了解模型构建的整个过程
model_config.py
该文件分别为MRC框架的实体提取任务和Tag方式的实体抽取任务定义了所需的config类
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# file: model_config.py

from transformers import BertConfig


class BertQueryNerConfig(BertConfig):
    """Config for MRC-framework entity extraction (used by BertQueryNER)."""

    def __init__(self, **kwargs):
        super(BertQueryNerConfig, self).__init__(**kwargs)
        # dropout applied inside the MRC classification heads
        self.mrc_dropout = kwargs.get("mrc_dropout", 0.1)
        # hidden size of the intermediate layer in the span classifier
        self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
        # activation used between the two classifier layers
        self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")


class BertTaggerConfig(BertConfig):
    """Config for sequence-tagging style entity extraction (used by BertTagger)."""

    def __init__(self, **kwargs):
        super(BertTaggerConfig, self).__init__(**kwargs)
        # number of tag labels predicted per token
        self.num_labels = kwargs.get("num_labels", 6)
        # dropout inside the tagging classification head
        self.classifier_dropout = kwargs.get("classifier_dropout", 0.1)
        # which head to build: "multi_nonlinear" -> two-layer MLP, else a single Linear
        self.classifier_sign = kwargs.get("classifier_sign", "multi_nonlinear")
        self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")
        self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
classifier.py
为两种实体抽取方式分别定义分类头
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# file: classifier.py
import torch.nn as nn
from torch.nn import functional as F
class SingleLinearClassifier(nn.Module):
    """Classification head that maps features straight to label logits
    through one fully-connected layer."""

    def __init__(self, hidden_size, num_label):
        super(SingleLinearClassifier, self).__init__()
        self.num_label = num_label
        self.classifier = nn.Linear(hidden_size, num_label)

    def forward(self, input_features):
        """Project ``input_features`` (..., hidden_size) to (..., num_label) logits."""
        return self.classifier(input_features)
class MultiNonLinearClassifier(nn.Module):
    """Two-layer MLP classification head used by the MRC framework:
    Linear -> activation -> dropout -> Linear.

    Args:
        hidden_size: size of the incoming features.
        num_label: number of output logits.
        dropout_rate: dropout probability applied after the activation.
        act_func: one of "gelu", "relu", "tanh".
        intermediate_hidden_size: hidden size of the middle layer;
            defaults to ``hidden_size`` when None.
    """

    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        """Return logits of shape (..., num_label).

        Raises:
            ValueError: if ``act_func`` is not one of the supported names.
        """
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            # tensor method instead of the deprecated F.tanh alias
            features_output1 = features_output1.tanh()
        else:
            # message added so a typo in the config is diagnosable
            raise ValueError(f"unsupported act_func: {self.act_func!r} (expected gelu/relu/tanh)")
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2
class BERTTaggerClassifier(nn.Module):
    """Two-layer MLP classification head for tag-style NER:
    Linear -> activation -> dropout -> Linear.

    Args:
        hidden_size: size of the incoming features.
        num_label: number of output logits.
        dropout_rate: dropout probability applied after the activation.
        act_func: one of "gelu", "relu", "tanh".
        intermediate_hidden_size: hidden size of the middle layer;
            defaults to ``hidden_size`` when None.
    """

    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(BERTTaggerClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        """Return logits of shape (..., num_label).

        Raises:
            ValueError: if ``act_func`` is not one of the supported names.
        """
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            # tensor method instead of the deprecated F.tanh alias
            features_output1 = features_output1.tanh()
        else:
            # message added so a typo in the config is diagnosable
            raise ValueError(f"unsupported act_func: {self.act_func!r} (expected gelu/relu/tanh)")
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2
bert_tagger.py
使用bert进行tag方式的实体抽取
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# file: bert_tagger.py
#
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from models.classifier import BERTTaggerClassifier
# 直接把bert最后一层输出的隐含状态直接连接一个分类器进行状态分类
class BertTagger(BertPreTrainedModel):
    """Tag-style NER model: feed BERT's last hidden states into a classifier
    that predicts one label per token."""

    def __init__(self, config):
        super(BertTagger, self).__init__(config)
        self.bert = BertModel(config)  # BERT backbone built from config
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if config.classifier_sign == "multi_nonlinear":
            # two-layer non-linear tagging head
            self.classifier = BERTTaggerClassifier(self.hidden_size, self.num_labels,
                                                   config.classifier_dropout,
                                                   act_func=config.classifier_act_func,
                                                   intermediate_hidden_size=config.classifier_intermediate_hidden_size)
        else:
            # fall back to a single linear projection
            self.classifier = nn.Linear(self.hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,):
        """Return per-token label logits, flattened to [batch*seq_len, num_labels].

        FIX: the original positional call ``self.bert(input_ids, token_type_ids,
        attention_mask)`` fed ``token_type_ids`` into BertModel's
        ``attention_mask`` parameter (and vice versa), because
        BertModel.forward's order is (input_ids, attention_mask,
        token_type_ids, ...). Keyword arguments make the binding explicit.
        """
        bert_outputs = self.bert(input_ids,
                                 token_type_ids=token_type_ids,
                                 attention_mask=attention_mask)
        # index 0 works for both tuple and ModelOutput return types
        last_bert_layer = bert_outputs[0]
        last_bert_layer = last_bert_layer.view(-1, self.hidden_size)
        last_bert_layer = self.dropout(last_bert_layer)
        logits = self.classifier(last_bert_layer)
        return logits
bert_query_ner.py
如论文中一样,分别对文本序列中每个token进行实体开始索引预测和结束索引预测,再对每一对(开始索引, 结束索引)计算匹配得分,并将三种预测结果一起返回
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# file: bert_query_ner.py
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from models.classifier import MultiNonLinearClassifier
class BertQueryNER(BertPreTrainedModel):
    """MRC-style NER model: for every token it predicts start/end logits,
    and for every (i, j) position pair a start-end match logit."""

    def __init__(self, config):
        super(BertQueryNER, self).__init__(config)
        self.bert = BertModel(config)  # BERT backbone
        # per-token binary logits for entity start / end positions
        self.start_outputs = nn.Linear(config.hidden_size, 1)
        self.end_outputs = nn.Linear(config.hidden_size, 1)
        # scores whether positions (i, j) form a valid start/end pair
        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout,
                                                       intermediate_hidden_size=config.classifier_intermediate_hidden_size)
        self.hidden_size = config.hidden_size
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Run the MRC heads on top of BERT.

        Args:
            input_ids: BERT input token ids, [batch, seq_len].
            token_type_ids: 0 for the query segment, 1 for the context
                segment (query comes first), [batch, seq_len].
            attention_mask: attention mask, [batch, seq_len].

        Returns:
            start_logits: start/non-start logits, [batch, seq_len].
            end_logits: end/non-end logits, [batch, seq_len].
            span_logits: start-end match logits, [batch, seq_len, seq_len].
        """
        bert_out = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        hidden = bert_out[0]  # [batch, seq_len, hidden]
        batch_size, seq_len, hid_size = hidden.size()

        start_logits = self.start_outputs(hidden).squeeze(-1)  # [batch, seq_len]
        end_logits = self.end_outputs(hidden).squeeze(-1)      # [batch, seq_len]

        # Build every (i, j) pair: row i carries the candidate start token's
        # features, column j the candidate end token's features.
        # [batch, seq_len, hidden] -> [batch, seq_len, seq_len, hidden]
        row_feats = hidden.unsqueeze(2).expand(-1, -1, seq_len, -1)
        col_feats = hidden.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # concat along the feature dim -> [batch, seq_len, seq_len, hidden*2]
        pair_feats = torch.cat([row_feats, col_feats], 3)
        # [batch, seq_len, seq_len, hidden*2] -> [batch, seq_len, seq_len]
        span_logits = self.span_embedding(pair_feats).squeeze(-1)

        return start_logits, end_logits, span_logits
train
mrc_ner_trainer.py
该项目主要使用PyTorch Lightning框架实现训练、测试等过程。使用PyTorch Lightning只需定义主要的训练training_step()、验证validation_step()和测试test_step()函数,而不用写复杂的for循环,框架会自动进行训练。代码如下,该代码是在官方代码的基础上进行调整,增加了使用WandbLogger进行数据记录,可配合笔记进行阅读
#!/usr/bin/env python3 # -*- coding: utf-8 -*- # file: mrc_ner_trainer.py import os import re import argparse import logging from collections import namedtuple from typing import Dict import torch import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from tokenizers import BertWordPieceTokenizer from torch import Tensor from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss from torch.utils.data import DataLoader from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup from torch.optim import SGD from pytorch_lightning.loggers import WandbLogger from datasets.mrc_ner_dataset import MRCNERDataset from datasets.truncate_dataset import TruncateDataset from datasets.collate_functions import collate_to_max_length from metrics.query_span_f1 import QuerySpanF1 from models.bert_query_ner import BertQueryNER from models.model_config import BertQueryNerConfig from utils.get_parser import get_parser from utils.random_seed import set_random_seed set_random_seed(0) # 设置随机数种子,固定随机数,保证模型的复现性 os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 使用GPU_0 class BertLabeling(pl.LightningModule): # 使用pytorch lightning定义网络结构时,要继承其的LightningModule类,类似于nn.Module类 def __init__(self, args: argparse.Namespace): """Initialize a model, tokenizer and config.""" super().__init__() format = '%(asctime)s - %(name)s - %(message)s' if isinstance(args, argparse.Namespace): # argparse.Namespace是parse_args( # )默认使用的简单类,用于创建一个包含属性的对象,并返回这个对象;此处表示训练模式 self.save_hyperparameters(args) # 将args中的参数保存在检查点,可通过self.hparams进行访问 self.args = args logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_result_log.txt"), level=logging.INFO) # 创建日志记录文件,设置日志等级和格式 else: # eval mode TmpArgs = namedtuple("tmp_args", field_names=list(args.keys())) self.args = args = TmpArgs(**args) 
logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_test.txt"), level=logging.INFO) self.bert_dir = args.bert_config_dir # bert预训练模型路径 self.data_dir = self.args.data_dir # 数据所在路径 # 构建bert的config bert_config = BertQueryNerConfig.from_pretrained(args.bert_config_dir, hidden_dropout_prob=args.bert_dropout, attention_probs_dropout_prob=args.bert_dropout, mrc_dropout=args.mrc_dropout, classifier_act_func=args.classifier_act_func, classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size) # 初始化BertQueryNER模型 self.model = BertQueryNER.from_pretrained(args.bert_config_dir, config=bert_config) logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args)) # 日志中记录args/训练参数 self.result_logger = logging.getLogger(__name__) self.result_logger.setLevel(logging.INFO) self.result_logger.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args)) self.bce_loss = BCEWithLogitsLoss(reduction="none") # 损失函数 # 设置三个损失的权重 weight_sum = args.weight_start + args.weight_end + args.weight_span self.weight_start = args.weight_start / weight_sum self.weight_end = args.weight_end / weight_sum self.weight_span = args.weight_span / weight_sum self.flat_ner = args.flat # /数据集是否包含嵌套实体 self.span_f1 = QuerySpanF1(flat=self.flat_ner) #自定义计算SpanF1的nn.module self.chinese = args.chinese # 是否为中文 self.optimizer = args.optimizer # 优化器 self.span_loss_candidates = args.span_loss_candidates @staticmethod def add_model_specific_args(parent_parser): # 补充模型训练所需的其他的超参数 parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument("--mrc_dropout", type=float, default=0.3, help="mrc dropout rate") parser.add_argument("--bert_dropout", type=float, default=0.1, help="bert dropout rate") parser.add_argument("--classifier_act_func", type=str, default="gelu") # 分类头的激活函数 parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024) # 分类头中的中间隐变量大小 
parser.add_argument("--weight_start", type=float, default=1.0) # 开始索引损失的权重 parser.add_argument("--weight_end", type=float, default=1.0) # 结束索引损失的权重 parser.add_argument("--weight_span", type=float, default=0.1) # 开始索引和结束索引匹配损失的权重 parser.add_argument("--flat", action="store_true", help="is flat ner") # 数据集是否是flat parser.add_argument("--span_loss_candidates", choices=["all", "pred_and_gold", "pred_gold_random", "gold"], default="pred_and_gold", help="Candidates used to compute span loss") # span_loss parser.add_argument("--chinese", action="store_true", help="is chinese dataset") # 数据集是否是中文 parser.add_argument("--optimizer", choices=["adamw", "sgd", "torch.adam"], default="adamw", help="loss type") # 可选优化器 parser.add_argument("--final_div_factor", type=float, default=20, help="final div factor of linear decay scheduler") # 线性衰减策略的最终div因子 parser.add_argument("--lr_scheduler", type=str, default="onecycle"