TokenClassification
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, BertTokenizerFast
from transformers import pipeline
# Load the tokenizer and the fine-tuned NER model from local directories.
# The ckiplab model card (see references) recommends BertTokenizerFast
# rather than AutoTokenizer for this model.
tokenizer = BertTokenizerFast.from_pretrained("models/bert-base-chinese")
model = AutoModelForTokenClassification.from_pretrained(
    "models/bert-base-chinese-ner"
)

# Token-level NER pipeline. As the listing below shows, it yields one dict
# per token that received an entity label.
nlp = pipeline(task="ner", model=model, tokenizer=tokenizer)

example = (
    "我叫王大,喜欢去旺角餐厅吃牛角包, 今年买了阿里巴巴的股票,"
    "亏得舅老爷的裤衩都没了,我的手机号是13587677888"
)
ner_results = nlp(example)
输出结果是逐个 token 的分类结果,恰如其名:AutoModelForTokenClassification。标签前缀 B-、I-、E- 分别表示实体的开头、中间和结尾;被预测为 O(非实体)的 token 已被 pipeline 过滤掉。
[{'entity': 'B-PERSON',
'score': 0.999997,
'index': 3,
'word': '王',
'start': 2,
'end': 3},
{'entity': 'E-PERSON',
'score': 0.999998,
'index': 4,
'word': '大',
'start': 3,
'end': 4},
{'entity': 'B-FAC',
'score': 0.99907553,
'index': 9,
'word': '旺',
'start': 8,
'end': 9},
{'entity': 'I-FAC',
'score': 0.9980399,
'index': 10,
'word': '角',
'start': 9,
'end': 10},
{'entity': 'I-FAC',
'score': 0.99831957,
'index': 11,
'word': '餐',
'start': 10,
'end': 11},
{'entity': 'E-FAC',
'score': 0.99899906,
'index': 12,
'word': '厅',
'start': 11,
'end': 12},
{'entity': 'B-DATE',
'score': 0.9999993,
'index': 18,
'word': '今',
'start': 18,
'end': 19},
{'entity': 'E-DATE',
'score': 0.99999917,
'index': 19,
'word': '年',
'start': 19,
'end': 20},
{'entity': 'B-PERSON',
'score': 0.99988663,
'index': 22,
'word': '阿',
'start': 22,
'end': 23},
{'entity': 'I-PERSON',
'score': 0.99996257,
'index': 23,
'word': '里',
'start': 23,
'end': 24},
{'entity': 'I-PERSON',
'score': 0.9999058,
'index': 24,
'word': '巴',
'start': 24,
'end': 25},
{'entity': 'E-PERSON',
'score': 0.99996555,
'index': 25,
'word': '巴',
'start': 25,
'end': 26},
{'entity': 'B-CARDINAL',
'score': 0.99999297,
'index': 48,
'word': '135',
'start': 48,
'end': 51},
{'entity': 'I-CARDINAL',
'score': 0.9999881,
'index': 49,
'word': '##87',
'start': 51,
'end': 53},
{'entity': 'I-CARDINAL',
'score': 0.99998856,
'index': 50,
'word': '##67',
'start': 53,
'end': 55},
{'entity': 'I-CARDINAL',
'score': 0.9999887,
'index': 51,
'word': '##78',
'start': 55,
'end': 57},
{'entity': 'E-CARDINAL',
'score': 0.99999154,
'index': 52,
'word': '##88',
'start': 57,
'end': 59}]
如果不想用 pipeline,也可以用如下方式:
def token_classify(example):
    """Label every token of *example* with the model's predicted class.

    Unlike the ``ner`` pipeline, this keeps the special tokens ([CLS]/[SEP])
    and tokens predicted as ``O``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        example: the input text.

    Returns:
        A list of ``{"word": token, "entity": label}`` dicts, one per input
        token in model order, including [CLS] and [SEP].
    """
    # Tokenize once and reuse the same ids for inference and for the token
    # strings, so both lists are aligned by construction. Move the tensors to
    # the model's device so a GPU-placed model also works.
    inputs = tokenizer([example], return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # logits is (batch=1, seq_len, num_labels): pick the best label per token
    # of the single sequence in the batch.
    predictions = logits.argmax(dim=-1)[0]
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions]
    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    return [
        {"word": word, "entity": label}
        for word, label in zip(input_tokens, predicted_token_class)
    ]
[{'word': '[CLS]', 'entity': 'O'},
{'word': '我', 'entity': 'O'},
{'word': '叫', 'entity': 'O'},
{'word': '王', 'entity': 'B-PERSON'},
{'word': '大', 'entity': 'E-PERSON'},
{'word': ',', 'entity': 'O'},
{'word': '喜', 'entity': 'O'},
{'word': '欢', 'entity': 'O'},
{'word': '去', 'entity': 'O'},
{'word': '旺', 'entity': 'B-FAC'},
{'word': '角', 'entity': 'I-FAC'},
{'word': '餐', 'entity': 'I-FAC'},
{'word': '厅', 'entity': 'E-FAC'},
{'word': '吃', 'entity': 'O'},
{'word': '牛', 'entity': 'O'},
{'word': '角', 'entity': 'O'},
{'word': '包', 'entity': 'O'},
{'word': ',', 'entity': 'O'},
{'word': '今', 'entity': 'B-DATE'},
{'word': '年', 'entity': 'E-DATE'},
{'word': '买', 'entity': 'O'},
{'word': '了', 'entity': 'O'},
{'word': '阿', 'entity': 'B-PERSON'},
{'word': '里', 'entity': 'I-PERSON'},
{'word': '巴', 'entity': 'I-PERSON'},
{'word': '巴', 'entity': 'E-PERSON'},
{'word': '的', 'entity': 'O'},
{'word': '股', 'entity': 'O'},
{'word': '票', 'entity': 'O'},
{'word': ',', 'entity': 'O'},
{'word': '亏', 'entity': 'O'},
{'word': '得', 'entity': 'O'},
{'word': '舅', 'entity': 'O'},
{'word': '老', 'entity': 'O'},
{'word': '爷', 'entity': 'O'},
{'word': '的', 'entity': 'O'},
{'word': '裤', 'entity': 'O'},
{'word': '衩', 'entity': 'O'},
{'word': '都', 'entity': 'O'},
{'word': '没', 'entity': 'O'},
{'word': '了', 'entity': 'O'},
{'word': ',', 'entity': 'O'},
{'word': '我', 'entity': 'O'},
{'word': '的', 'entity': 'O'},
{'word': '手', 'entity': 'O'},
{'word': '机', 'entity': 'O'},
{'word': '号', 'entity': 'O'},
{'word': '是', 'entity': 'O'},
{'word': '135', 'entity': 'B-CARDINAL'},
{'word': '##87', 'entity': 'I-CARDINAL'},
{'word': '##67', 'entity': 'I-CARDINAL'},
{'word': '##78', 'entity': 'I-CARDINAL'},
{'word': '##88', 'entity': 'E-CARDINAL'},
{'word': '[SEP]', 'entity': 'O'}]
实体识别
为了得到实体级别的 NER 结果(而不是逐个 token 的标签),需要把相邻的 token 聚合成实体。朕用如下代码实现,仅供参考:
def allow_merge(a, b):
    """Return True if a token labelled *b* may extend an entity ending in *a*.

    Labels have the form ``"<flag>-<type>"`` (e.g. ``"B-PERSON"``). A merge is
    allowed only when both labels carry the same entity type and the flag
    transition continues an open entity: B->I, B->E, I->I or I->E. Labels
    without a "-" (such as ``"O"``) never merge.

    Args:
        a: label of the previous token.
        b: label of the current token.

    Returns:
        True if the two tokens belong to the same entity, else False.
    """
    # Use partition instead of split: "O" no longer raises ValueError, and
    # entity types that themselves contain "-" are kept whole.
    a_flag, _, a_type = a.partition("-")
    b_flag, _, b_type = b.partition("-")
    if not a_type or a_type != b_type:
        return False
    return (a_flag, b_flag) in {("B", "I"), ("B", "E"), ("I", "I"), ("I", "E")}
def divide_entities(ner_results):
    """Group token-level NER results into per-entity runs.

    Tokens are visited in ``index`` order. A token joins the current run only
    if it directly follows the run's last token (its ``index`` is one
    greater) and ``allow_merge`` accepts the label transition. Otherwise the
    token starts a new run.

    Args:
        ner_results: token dicts as returned by the ``ner`` pipeline. Each
            must carry ``index`` and ``entity`` keys.

    Returns:
        A list of non-empty lists of token dicts, or ``[]`` for empty input.
    """
    divided_entities = []
    current_entity = []
    for item in sorted(ner_results, key=lambda x: x["index"]):
        if (
            current_entity
            # The pipeline drops "O" tokens, so a gap in ``index`` means the
            # two tokens are not adjacent in the text and must not be merged.
            and item["index"] == current_entity[-1]["index"] + 1
            and allow_merge(current_entity[-1]["entity"], item["entity"])
        ):
            current_entity.append(item)
        else:
            if current_entity:
                divided_entities.append(current_entity)
            current_entity = [item]
    # Flush the last run. Skip it when empty, because an empty group would
    # make merge_entities fail on inputs that contain no entities.
    if current_entity:
        divided_entities.append(current_entity)
    return divided_entities
def merge_entities(same_entities):
    """Collapse one run of token dicts into a single entity-level dict.

    Args:
        same_entities: a non-empty list of consecutive token dicts belonging
            to one entity (as produced by ``divide_entities``). Each dict
            must have ``entity``, ``score``, ``word``, ``start`` and ``end``.

    Returns:
        A dict with these keys:
            entity: the entity type, i.e. the first token's label without
                its positional prefix (e.g. ``"B-"``).
            score: the mean token score.
            word: the concatenated token text with WordPiece ``##``
                continuation prefixes removed.
            start: the first token's ``start``.
            end: the last token's ``end``.
    """
    first, last = same_entities[0], same_entities[-1]
    scores = [e["score"] for e in same_entities]
    return {
        # Split only once so entity types that contain "-" survive intact.
        "entity": first["entity"].split("-", 1)[1],
        "score": sum(scores) / len(scores),
        # "##" marks a WordPiece continuation only as a prefix; strip just that.
        "word": "".join(
            e["word"][2:] if e["word"].startswith("##") else e["word"]
            for e in same_entities
        ),
        "start": first["start"],
        "end": last["end"],
    }
def post_process(ner_results):
    """Turn token-level ``ner`` pipeline output into entity-level results.

    The tokens are first grouped into runs by ``divide_entities``. Each run
    is then collapsed into one dict by ``merge_entities``.
    """
    groups = divide_entities(ner_results)
    return list(map(merge_entities, groups))
[{'entity': 'PERSON',
'score': 0.9999974966049194,
'word': '王大',
'start': 2,
'end': 4},
{'entity': 'FAC',
'score': 0.9986085146665573,
'word': '旺角餐厅',
'start': 8,
'end': 12},
{'entity': 'DATE',
'score': 0.9999992251396179,
'word': '今年',
'start': 18,
'end': 20},
{'entity': 'PERSON',
'score': 0.9999301433563232,
'word': '阿里巴巴',
'start': 22,
'end': 26},
{'entity': 'CARDINAL',
'score': 0.9999899625778198,
'word': '13587677888',
'start': 48,
'end': 59}]
参考:
- https://huggingface.co/docs/transformers/tasks/token_classification
- https://huggingface.co/google-bert/bert-base-chinese
- https://hf-mirror.com/ckiplab/bert-base-chinese-ner