
NER Project -- GitHub -- A Unified MRC Framework for Named Entity Recognition

A code walkthrough of the A Unified MRC Framework for Named Entity Recognition project

  • Overview
  • Project Structure
  • models
    • model_config.py
    • classifier.py
    • bert_tagger.py
    • bert_query_ner.py
  • train
    • mrc_ner_trainer.py
  • ner2mrc
    • msra2mrc.py
  • datasets
    • mrc_ner_dataset.py
  • evaluate
    • mrc_ner_evaluate.py
  • inference
    • mrc_ner_inference.py
  • Summary
  • Afterword

Project link: https://github.com/ShannonAI/mrc-for-flat-nested-ner
Paper link: https://arxiv.org/abs/1910.11476

Overview

The paper recasts named entity recognition as a machine reading comprehension (MRC) task: entities in a text sequence are extracted by answering a question. Each entity category gets its own query; for example, to extract entities of the org category, the query could be "What organizations are mentioned in the text?". BERT is used as the backbone: the query and the text are concatenated into one sequence and fed to BERT, and two binary classifiers are applied to BERT's final token representations, one predicting for each token the probability that it is the start index of an entity, the other the probability that it is the end index.
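
To make the input construction concrete, here is a minimal sketch (not from the repository; the query wording and checkpoint name are illustrative assumptions) of how one NER example becomes a query/context pair for BERT:

from transformers import BertTokenizer

# hypothetical example: extract ORG entities from one sentence
query = "What organizations are mentioned in the text?"  # one query per entity category
context = "Shannon AI was founded in Beijing."

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
# BERT input: [CLS] query [SEP] context [SEP];
# token_type_ids are 0 for the query segment and 1 for the context segment
encoded = tokenizer(query, context, return_tensors="pt")
print(encoded["input_ids"].shape)    # [1, seq_len]
print(encoded["token_type_ids"][0])  # 0s for the query, then 1s for the context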

Project Structure

  • datasets --files that build the datasets
  • evaluate --files for evaluation
  • inference --files for forward inference
  • metrics --files implementing the metrics
  • models --files that build the models
  • ner2mrc --files that convert data into the format required by MRC
  • scripts --launch scripts for training on each dataset
  • tests --files needed for testing
  • train --model training files
  • utils --other helper files
  • README.md --project details
  • requirements.txt --Python packages required by the project

models

We start with the files under the models directory to understand the whole model-building process.

model_config.py

This file defines the config classes required by the MRC-style entity extraction framework and by the tag-style entity extraction approach, respectively.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: model_config.py

from transformers import BertConfig


class BertQueryNerConfig(BertConfig):
    # config required for MRC-framework entity extraction
    def __init__(self, **kwargs):
        super(BertQueryNerConfig, self).__init__(**kwargs)
        self.mrc_dropout = kwargs.get("mrc_dropout", 0.1)
        self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
        self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")


class BertTaggerConfig(BertConfig):
    # config required for tag-style entity extraction
    def __init__(self, **kwargs):
        super(BertTaggerConfig, self).__init__(**kwargs)
        self.num_labels = kwargs.get("num_labels", 6)
        self.classifier_dropout = kwargs.get("classifier_dropout", 0.1)
        self.classifier_sign = kwargs.get("classifier_sign", "multi_nonlinear")
        self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")
        self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
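
As a quick usage check, the config can be instantiated from a pretrained checkpoint with the MRC-specific fields overridden, which mirrors how the trainer builds it later; the checkpoint name here is an illustrative assumption:

# hypothetical usage; the trainer passes args.bert_config_dir instead
config = BertQueryNerConfig.from_pretrained("bert-base-uncased",
                                            mrc_dropout=0.3,
                                            classifier_act_func="gelu",
                                            classifier_intermediate_hidden_size=1024)
print(config.mrc_dropout)  # 0.3, overriding the 0.1 default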

classifier.py

Defines a classification head for each of the two entity extraction approaches.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: classifier.py

import torch.nn as nn
from torch.nn import functional as F


class SingleLinearClassifier(nn.Module):  # predicts classes directly through a single fully connected layer
    def __init__(self, hidden_size, num_label):
        super(SingleLinearClassifier, self).__init__()
        self.num_label = num_label
        self.classifier = nn.Linear(hidden_size, num_label)

    def forward(self, input_features):
        features_output = self.classifier(input_features)
        return features_output


class MultiNonLinearClassifier(nn.Module):  # head used by the MRC framework: two fully connected layers process the features output by BERT
    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            features_output1 = F.tanh(features_output1)
        else:
            raise ValueError
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2


class BERTTaggerClassifier(nn.Module):  # classification head for the tag-style approach
    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(BERTTaggerClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            features_output1 = F.tanh(features_output1)
        else:
            raise ValueError
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2
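
A minimal shape check for MultiNonLinearClassifier (the sizes are illustrative): the two linear layers act on the last dimension, so a [batch, seq_len, hidden] feature tensor comes out as [batch, seq_len, num_label].

import torch

head = MultiNonLinearClassifier(hidden_size=768, num_label=2, dropout_rate=0.1)
features = torch.randn(4, 16, 768)  # e.g. BERT token features [batch, seq_len, hidden]
logits = head(features)
print(logits.shape)  # torch.Size([4, 16, 2])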

bert_tagger.py

Tag-style entity extraction with BERT.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: bert_tagger.py
#

import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from models.classifier import BERTTaggerClassifier


# feed the hidden states from BERT's last layer directly into a classifier for per-token classification
class BertTagger(BertPreTrainedModel):
    def __init__(self, config):
        super(BertTagger, self).__init__(config)
        self.bert = BertModel(config)  # initialize the BERT model from the config

        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if config.classifier_sign == "multi_nonlinear":  # use the tag-style classification head
            self.classifier = BERTTaggerClassifier(self.hidden_size, self.num_labels,
                                                   config.classifier_dropout,
                                                   act_func=config.classifier_act_func,
                                                   intermediate_hidden_size=config.classifier_intermediate_hidden_size)
        else:
            self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # keyword arguments, so token_type_ids is not silently passed into
        # the attention_mask slot of transformers' BertModel.forward
        last_bert_layer, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        last_bert_layer = last_bert_layer.view(-1, self.hidden_size)
        last_bert_layer = self.dropout(last_bert_layer)
        logits = self.classifier(last_bert_layer)
        return logits
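
Note that view(-1, self.hidden_size) merges the batch and sequence dimensions, so the logits come back as [batch*seq_len, num_labels]. A smoke test with a randomly initialized (untrained) model, assuming a transformers version where BertModel returns a tuple:

import torch

config = BertTaggerConfig(return_dict=False)  # defaults: hidden_size=768, num_labels=6
model = BertTagger(config)
input_ids = torch.randint(0, config.vocab_size, (2, 8))  # [batch=2, seq_len=8] dummy ids
logits = model(input_ids)
print(logits.shape)  # torch.Size([16, 6]): batch and seq_len are flattened together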

bert_query_ner.py

As in the paper, the model predicts for every token in the text sequence whether it is an entity start index and whether it is an entity end index, then predicts whether each (start, end) pair is a match, and returns all three sets of logits.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: bert_query_ner.py

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from models.classifier import MultiNonLinearClassifier


class BertQueryNER(BertPreTrainedModel):
    def __init__(self, config):
        super(BertQueryNER, self).__init__(config)
        self.bert = BertModel(config)  # initialize BERT

        self.start_outputs = nn.Linear(config.hidden_size, 1)  # scores each token as an entity start index
        self.end_outputs = nn.Linear(config.hidden_size, 1)  # scores each token as an entity end index
        # predicts whether positions i and j form a matching (start, end) pair
        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout,
                                                       intermediate_hidden_size=config.classifier_intermediate_hidden_size)

        self.hidden_size = config.hidden_size

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """ Args: input_ids: bert input tokens, tensor of shape [seq_len] token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len],query在前,文本在后 attention_mask: attention mask, tensor of shape [seq_len] Returns: start_logits: start/non-start probs of shape [seq_len] end_logits: end/non-end probs of shape [seq_len] match_logits: start-end-match probs of shape [seq_len, 1],此处的1表示匹配行的得分 """

        bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        sequence_heatmap = bert_outputs[0]  # [batch, seq_len, hidden]
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len], start-index predictions
        end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len], end-index predictions

        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        # [batch, seq_len, hidden]->[batch, seq_len, 1, hidden]->[batch, seq_len, seq_len, hidden]
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # [batch, seq_len, hidden]->[batch, 1, seq_len, hidden]->[batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden]+[batch, seq_len, seq_len, hidden]->[batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len, hidden*2]->[batch, seq_len, seq_len, 1]->[batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)  # predictions for start/end index matching

        return start_logits, end_logits, span_logits
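
At inference time the three outputs are typically combined by thresholding: a pair (i, j) with i <= j is kept as an entity span when token i scores as a start, token j scores as an end, and span_logits[i, j] is positive. A simplified single-example decoding sketch (the repository's actual logic lives in the evaluate and inference files):

import torch

def decode_spans(start_logits, end_logits, span_logits):
    # start_logits, end_logits: [seq_len]; span_logits: [seq_len, seq_len]
    starts = (start_logits > 0).nonzero(as_tuple=True)[0].tolist()
    ends = (end_logits > 0).nonzero(as_tuple=True)[0].tolist()
    spans = []
    for i in starts:
        for j in ends:
            if i <= j and span_logits[i, j] > 0:
                spans.append((i, j))  # candidate entity span [i, j]
    return spans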

train

mrc_ner_trainer.py

This project uses the PyTorch Lightning framework for training and testing. With PyTorch Lightning you only define the core training_step(), validation_step() and test_step() functions instead of writing the training for-loops yourself; the framework drives the loops automatically. The code below is adapted from the official code, adding a WandbLogger for experiment tracking, and can be read alongside these notes.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: mrc_ner_trainer.py

import os
import re
import argparse
import logging
from collections import namedtuple
from typing import Dict

import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from tokenizers import BertWordPieceTokenizer
from torch import Tensor
from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
from torch.optim import SGD
from pytorch_lightning.loggers import WandbLogger

from datasets.mrc_ner_dataset import MRCNERDataset
from datasets.truncate_dataset import TruncateDataset
from datasets.collate_functions import collate_to_max_length
from metrics.query_span_f1 import QuerySpanF1
from models.bert_query_ner import BertQueryNER
from models.model_config import BertQueryNerConfig
from utils.get_parser import get_parser
from utils.random_seed import set_random_seed

set_random_seed(0)  # fix the random seed so results are reproducible
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # use GPU 0


class BertLabeling(pl.LightningModule):  # networks built with pytorch lightning inherit from LightningModule, the analogue of nn.Module
    def __init__(self, args: argparse.Namespace):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        format = '%(asctime)s - %(name)s - %(message)s'
        if isinstance(args, argparse.Namespace):  # argparse.Namespace is the simple class that parse_args()
            # returns by default: an object that just holds the parsed attributes; this branch is training mode
            self.save_hyperparameters(args)  # save the args into the checkpoint; they can be accessed via self.hparams
            self.args = args
            logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_result_log.txt"),
                                level=logging.INFO)  # create the log file and set the logging level and format
        else:
            # eval mode
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)
            logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_test.txt"),
                                level=logging.INFO)

        self.bert_dir = args.bert_config_dir  # path to the pretrained BERT model
        self.data_dir = self.args.data_dir  # path to the data
        # build the BERT config
        bert_config = BertQueryNerConfig.from_pretrained(args.bert_config_dir,
                                                         hidden_dropout_prob=args.bert_dropout,
                                                         attention_probs_dropout_prob=args.bert_dropout,
                                                         mrc_dropout=args.mrc_dropout,
                                                         classifier_act_func=args.classifier_act_func,
                                                         classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size)
        # initialize the BertQueryNER model
        self.model = BertQueryNER.from_pretrained(args.bert_config_dir, config=bert_config)
        logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args))  # record the args / training parameters in the log
        self.result_logger = logging.getLogger(__name__)
        self.result_logger.setLevel(logging.INFO)
        self.result_logger.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args))
        self.bce_loss = BCEWithLogitsLoss(reduction="none")  # loss function
        # set the weights of the three losses
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat  # whether the dataset is flat NER (no nested entities)
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)  # custom nn.Module that computes span F1
        self.chinese = args.chinese  # whether the dataset is Chinese
        self.optimizer = args.optimizer  # optimizer
        self.span_loss_candidates = args.span_loss_candidates

    @staticmethod
    def add_model_specific_args(parent_parser):  # add the remaining hyperparameters needed for training
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--mrc_dropout", type=float, default=0.3,
                            help="mrc dropout rate")
        parser.add_argument("--bert_dropout", type=float, default=0.1,
                            help="bert dropout rate")
        parser.add_argument("--classifier_act_func", type=str, default="gelu")  # 分类头的激活函数
        parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024)  # 分类头中的中间隐变量大小
        parser.add_argument("--weight_start", type=float, default=1.0)  # 开始索引损失的权重
        parser.add_argument("--weight_end", type=float, default=1.0)  # 结束索引损失的权重
        parser.add_argument("--weight_span", type=float, default=0.1)  # 开始索引和结束索引匹配损失的权重
        parser.add_argument("--flat", action="store_true", help="is flat ner")  # 数据集是否是flat
        parser.add_argument("--span_loss_candidates", choices=["all", "pred_and_gold", "pred_gold_random", "gold"],
                            default="pred_and_gold", help="Candidates used to compute span loss")  # span_loss
        parser.add_argument("--chinese", action="store_true",
                            help="is chinese dataset")  # 数据集是否是中文
        parser.add_argument("--optimizer", choices=["adamw", "sgd", "torch.adam"], default="adamw",
                            help="loss type")  # 可选优化器
        parser.add_argument("--final_div_factor", type=float, default=20,
                            help="final div factor of linear decay scheduler")  # 线性衰减策略的最终div因子
        parser.add_argument("--lr_scheduler", type=str, default="onecycle" 
