|
|
|
|
|
from transformers import pipeline |
|
from transformers import Tool |
|
|
|
class NamedEntityRecognitionTool(Tool):
    """Tool that labels the entities found in a piece of text, token by token.

    The text is run through the default ``pipeline("ner")`` and every
    prediction is turned into one or more ``{"token", "label", "entity_text"}``
    records.
    """

    name = "ner_tool"
    description = "Identifies and labels various entities in a given text."
    inputs = ["text"]
    outputs = ["text"]

    # Cached NER pipeline. Built lazily on first use and stored on the
    # instance so the model is not reloaded on every call.
    _ner_analyzer = None

    def _get_ner_analyzer(self):
        """Return the NER pipeline, creating it the first time it is needed."""
        if self._ner_analyzer is None:
            self._ner_analyzer = pipeline("ner")
        return self._ner_analyzer

    @staticmethod
    def _entity_text(text, start, end):
        """Return ``text[start:end]`` stripped, or ``""`` when offsets are unknown."""
        # A missing or ``None`` offset must not be used for slicing:
        # ``text[None:None]`` would silently return the whole input.
        if start is None or end is None:
            return ""
        return text[start:end].strip()

    def __call__(self, text: str):
        """Run NER over ``text`` and return the token-level entities.

        Args:
            text: The text to analyze.

        Returns:
            A dict ``{"entities": [...]}`` where each item is a dict with keys
            ``"token"`` (the word piece with any ``##`` markers removed),
            ``"label"`` (the predicted entity tag, ``"UNKNOWN"`` if absent) and
            ``"entity_text"`` (the matching span of ``text``, or ``""`` when the
            prediction carries no character offsets).
        """
        entities = self._get_ner_analyzer()(text)

        token_entities = []
        for entity in entities:
            label = entity.get("entity", "UNKNOWN")
            word = entity.get("word", "")
            entity_text = self._entity_text(
                text, entity.get("start"), entity.get("end")
            )

            # "##" marks word-piece continuations (e.g. "##ing"). Splitting on
            # it leaves empty strings at the edges, which are dropped so no
            # blank tokens are emitted. If nothing is left (or there was no
            # marker at all), fall back to the raw word so the prediction is
            # never lost.
            pieces = [piece for piece in word.split("##") if piece] or [word]
            for piece in pieces:
                token_entities.append(
                    {"token": piece, "label": label, "entity_text": entity_text}
                )

        print(f"Token-level Entities: {token_entities}")

        return {"entities": token_entities}
|
|