使用 bert-base-chinese-ner 模型实现中文NER

发布于:2024-04-18 ⋅ 阅读:(22) ⋅ 点赞:(0)

TokenClassification

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Load tokenizer and NER head.  The original snippet called
# BertTokenizerFast.from_pretrained without importing it; AutoTokenizer
# (already imported above) resolves to the same fast BERT tokenizer here.
tokenizer = AutoTokenizer.from_pretrained('models/bert-base-chinese')
model = AutoModelForTokenClassification.from_pretrained('models/bert-base-chinese-ner')

# Token-level NER pipeline; by default it drops 'O' (non-entity) tokens.
nlp = pipeline("ner", model=model, tokenizer=tokenizer)
example = "我叫王大,喜欢去旺角餐厅吃牛角包, 今年买了阿里巴巴的股票,亏得舅老爷的裤衩都没了,我的手机号是13587677888"
ner_results = nlp(example)

输出结果是按token进行分类,恰如其名:AutoModelForTokenClassification

[{'entity': 'B-PERSON',
  'score': 0.999997,
  'index': 3,
  'word': '王',
  'start': 2,
  'end': 3},
 {'entity': 'E-PERSON',
  'score': 0.999998,
  'index': 4,
  'word': '大',
  'start': 3,
  'end': 4},
 {'entity': 'B-FAC',
  'score': 0.99907553,
  'index': 9,
  'word': '旺',
  'start': 8,
  'end': 9},
 {'entity': 'I-FAC',
  'score': 0.9980399,
  'index': 10,
  'word': '角',
  'start': 9,
  'end': 10},
 {'entity': 'I-FAC',
  'score': 0.99831957,
  'index': 11,
  'word': '餐',
  'start': 10,
  'end': 11},
 {'entity': 'E-FAC',
  'score': 0.99899906,
  'index': 12,
  'word': '厅',
  'start': 11,
  'end': 12},
 {'entity': 'B-DATE',
  'score': 0.9999993,
  'index': 18,
  'word': '今',
  'start': 18,
  'end': 19},
 {'entity': 'E-DATE',
  'score': 0.99999917,
  'index': 19,
  'word': '年',
  'start': 19,
  'end': 20},
 {'entity': 'B-PERSON',
  'score': 0.99988663,
  'index': 22,
  'word': '阿',
  'start': 22,
  'end': 23},
 {'entity': 'I-PERSON',
  'score': 0.99996257,
  'index': 23,
  'word': '里',
  'start': 23,
  'end': 24},
 {'entity': 'I-PERSON',
  'score': 0.9999058,
  'index': 24,
  'word': '巴',
  'start': 24,
  'end': 25},
 {'entity': 'E-PERSON',
  'score': 0.99996555,
  'index': 25,
  'word': '巴',
  'start': 25,
  'end': 26},
 {'entity': 'B-CARDINAL',
  'score': 0.99999297,
  'index': 48,
  'word': '135',
  'start': 48,
  'end': 51},
 {'entity': 'I-CARDINAL',
  'score': 0.9999881,
  'index': 49,
  'word': '##87',
  'start': 51,
  'end': 53},
 {'entity': 'I-CARDINAL',
  'score': 0.99998856,
  'index': 50,
  'word': '##67',
  'start': 53,
  'end': 55},
 {'entity': 'I-CARDINAL',
  'score': 0.9999887,
  'index': 51,
  'word': '##78',
  'start': 55,
  'end': 57},
 {'entity': 'E-CARDINAL',
  'score': 0.99999154,
  'index': 52,
  'word': '##88',
  'start': 57,
  'end': 59}]

如果不想用 pipeline,也可以用如下方式:

def token_classify(example):
    """Run token-level NER on *example* without using the pipeline helper.

    Relies on module-level ``tokenizer`` and ``model``. Returns a list of
    ``{"word": token, "entity": label}`` dicts, one per input token
    (including [CLS]/[SEP]).
    """
    inputs = tokenizer([example], return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
    # Decode the same ids we already fed the model instead of re-tokenizing
    # the text a second time (the original called tokenizer(example) again).
    input_tokens = [tokenizer.decode(i) for i in inputs.input_ids[0]]
    return [{"word": i, "entity": j} for i, j in zip(input_tokens, predicted_token_class)]
[{'word': '[CLS]', 'entity': 'O'},
 {'word': '我', 'entity': 'O'},
 {'word': '叫', 'entity': 'O'},
 {'word': '王', 'entity': 'B-PERSON'},
 {'word': '大', 'entity': 'E-PERSON'},
 {'word': ',', 'entity': 'O'},
 {'word': '喜', 'entity': 'O'},
 {'word': '欢', 'entity': 'O'},
 {'word': '去', 'entity': 'O'},
 {'word': '旺', 'entity': 'B-FAC'},
 {'word': '角', 'entity': 'I-FAC'},
 {'word': '餐', 'entity': 'I-FAC'},
 {'word': '厅', 'entity': 'E-FAC'},
 {'word': '吃', 'entity': 'O'},
 {'word': '牛', 'entity': 'O'},
 {'word': '角', 'entity': 'O'},
 {'word': '包', 'entity': 'O'},
 {'word': ',', 'entity': 'O'},
 {'word': '今', 'entity': 'B-DATE'},
 {'word': '年', 'entity': 'E-DATE'},
 {'word': '买', 'entity': 'O'},
 {'word': '了', 'entity': 'O'},
 {'word': '阿', 'entity': 'B-PERSON'},
 {'word': '里', 'entity': 'I-PERSON'},
 {'word': '巴', 'entity': 'I-PERSON'},
 {'word': '巴', 'entity': 'E-PERSON'},
 {'word': '的', 'entity': 'O'},
 {'word': '股', 'entity': 'O'},
 {'word': '票', 'entity': 'O'},
 {'word': ',', 'entity': 'O'},
 {'word': '亏', 'entity': 'O'},
 {'word': '得', 'entity': 'O'},
 {'word': '舅', 'entity': 'O'},
 {'word': '老', 'entity': 'O'},
 {'word': '爷', 'entity': 'O'},
 {'word': '的', 'entity': 'O'},
 {'word': '裤', 'entity': 'O'},
 {'word': '衩', 'entity': 'O'},
 {'word': '都', 'entity': 'O'},
 {'word': '没', 'entity': 'O'},
 {'word': '了', 'entity': 'O'},
 {'word': ',', 'entity': 'O'},
 {'word': '我', 'entity': 'O'},
 {'word': '的', 'entity': 'O'},
 {'word': '手', 'entity': 'O'},
 {'word': '机', 'entity': 'O'},
 {'word': '号', 'entity': 'O'},
 {'word': '是', 'entity': 'O'},
 {'word': '135', 'entity': 'B-CARDINAL'},
 {'word': '##87', 'entity': 'I-CARDINAL'},
 {'word': '##67', 'entity': 'I-CARDINAL'},
 {'word': '##78', 'entity': 'I-CARDINAL'},
 {'word': '##88', 'entity': 'E-CARDINAL'},
 {'word': '[SEP]', 'entity': 'O'}]

实体识别

为了得到正常的NER识别结果,需要把token聚合成实体。我用如下代码实现,仅供参考:

def allow_merge(a, b):
    """Return True if BIES tag *b* may continue the entity ended by tag *a*.

    Tags look like 'B-PERSON', 'I-FAC', 'E-DATE'.  A merge is allowed only
    when both tags share the same entity type and the flags form a valid
    continuation: B->I, B->E, I->I or I->E.  Tags without a dash (e.g. the
    plain 'O' label) can never merge; the original code raised ValueError
    on them.
    """
    if '-' not in a or '-' not in b:
        return False
    a_flag, a_type = a.split('-', 1)
    b_flag, b_type = b.split('-', 1)
    # Same entity type, and the flag pair is a legal BIES continuation.
    return a_type == b_type and (a_flag, b_flag) in {
        ("B", "I"),
        ("B", "E"),
        ("I", "I"),
        ("I", "E"),
    }

def divide_entities(ner_results):
    """Group consecutive token-level results into per-entity lists.

    Sorts results by token index and starts a new group whenever
    ``allow_merge`` rejects the transition between adjacent tags.
    Returns a list of groups (each a list of result dicts).  Unlike the
    original, an empty input yields ``[]`` instead of ``[[]]``, which
    would have crashed ``merge_entities`` downstream.
    """
    divided_entities = []
    current_entity = []

    for item in sorted(ner_results, key=lambda x: x['index']):
        if not current_entity:
            current_entity.append(item)
        elif allow_merge(current_entity[-1]['entity'], item['entity']):
            current_entity.append(item)
        else:
            divided_entities.append(current_entity)
            current_entity = [item]
    # Flush the trailing group only if anything was collected.
    if current_entity:
        divided_entities.append(current_entity)
    return divided_entities

def merge_entities(same_entities):
    """Collapse one group of token results into a single entity dict.

    The entity label is the type part of the first tag (e.g. 'PERSON'
    from 'B-PERSON'), the score is the mean of all token scores, and the
    surface form joins the token strings with WordPiece '##' markers
    stripped.  Span boundaries come from the first and last tokens.
    """
    scores = [token['score'] for token in same_entities]
    pieces = [token['word'].replace('##', '') for token in same_entities]
    first, last = same_entities[0], same_entities[-1]
    return {
        'entity': first['entity'].split("-")[1],
        'score': sum(scores) / len(scores),
        'word': ''.join(pieces),
        'start': first['start'],
        'end': last['end'],
    }

def post_process(ner_results):
    """Turn raw token-level NER output into merged entity dicts."""
    return list(map(merge_entities, divide_entities(ner_results)))
[{'entity': 'PERSON',
  'score': 0.9999974966049194,
  'word': '王大',
  'start': 2,
  'end': 4},
 {'entity': 'FAC',
  'score': 0.9986085146665573,
  'word': '旺角餐厅',
  'start': 8,
  'end': 12},
 {'entity': 'DATE',
  'score': 0.9999992251396179,
  'word': '今年',
  'start': 18,
  'end': 20},
 {'entity': 'PERSON',
  'score': 0.9999301433563232,
  'word': '阿里巴巴',
  'start': 22,
  'end': 26},
 {'entity': 'CARDINAL',
  'score': 0.9999899625778198,
  'word': '13587677888',
  'start': 48,
  'end': 59}]

参考:

  • https://huggingface.co/docs/transformers/tasks/token_classification
  • https://huggingface.co/google-bert/bert-base-chinese
  • https://hf-mirror.com/ckiplab/bert-base-chinese-ner