File size: 1,990 Bytes
bdf7636
24d9d43
bdf7636
5826177
bdf7636
24d9d43
3469da9
bdf7636
32eb862
 
bdf7636
5826177
24d9d43
 
bdf7636
 
24d9d43
bdf7636
32eb862
 
 
24d9d43
 
 
 
5826177
24d9d43
 
 
5826177
 
 
 
 
 
 
32eb862
 
24d9d43
32eb862
 
 
 
 
 
 
 
 
 
 
 
24d9d43
5826177
 
 
 
 
 
 
 
 
 
32eb862
 
bdf7636
 
24d9d43
 
 
 
 
5826177
24d9d43
 
5826177
bdf7636
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
from collections import Counter, defaultdict

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# Hugging Face model id; tokenizer and model must come from the same checkpoint.
MODEL_NAME = "d4data/biomedical-ner-all"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)

# Non-interactive backend: plots are only rendered to images for the web UI,
# never shown in a window (presumably no display is available on the host).
plt.switch_backend("Agg")

# Map each example note's text to its "label"; ner() shows it as the "Rating"
# output when the input exactly matches an example. JSON is UTF-8 by spec, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLE_MAP = {x["text"]: x["label"] for x in example_json}

# Token-classification pipeline; "simple" aggregation merges sub-word tokens
# into whole entities, each labelled with an "entity_group".
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")


def group_by_entity(raw):
    """Count how many entities the NER pipeline found for each entity group.

    Args:
        raw: Iterable of entity dicts as produced by the aggregated NER
            pipeline; only each item's ``"entity_group"`` key is read.

    Returns:
        A ``collections.Counter`` mapping entity-group label -> count, with
        keys in first-seen order. Looking up an absent group yields ``0``.
    """
    return Counter(ent["entity_group"] for ent in raw)


def plot_to_figure(grouped):
    """Render entity-group counts as a bar chart.

    Args:
        grouped: Mapping of entity-group label -> occurrence count, e.g. the
            output of ``group_by_entity``.

    Returns:
        A ``matplotlib.figure.Figure`` with one bar per entity group and the
        x tick labels rotated 90 degrees so long labels stay readable.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure it creates alive in a global registry until explicitly closed,
    # so the old plt.figure() call leaked one figure per request. Its implicit
    # "current figure" state is also global, so if handlers ever run
    # concurrently, one call's plt.bar could land on another call's figure.
    fig = Figure()
    ax = fig.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.tick_params(axis="x", labelrotation=90)
    return fig


def ner(text):
    """Run biomedical NER on *text* and build the four outputs of the UI.

    Args:
        text: Free-text note entered by the user or picked from the examples.

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        * ``ner_content`` -- ``{"text": text, "entities": [...]}`` payload for
          the highlighted-text output; each entity carries its group label,
          matched word, score, and ``start``/``end`` offsets into *text*.
        * ``meta`` -- summary dict: per-group counts (``entity_counts``), the
          number of distinct entity groups (``entities``), and the total number
          of entities found (``counts``).
        * ``label`` -- the label stored for *text* in ``EXAMPLE_MAP``, or
          ``"Unknown"`` when *text* is not one of the examples.
        * ``figure`` -- bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                # The pipeline may hand back NumPy scalars here; a builtin
                # float keeps the payload plain-JSON serializable.
                "score": float(x["score"]),
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLE_MAP.get(text, "Unknown")

    meta = {
        "entity_counts": grouped,
        # Mapping keys are already unique: the group count is just its size.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }

    return (ner_content, meta, label, figure)


# One text input and four outputs; the output order mirrors the tuple that
# ner() returns: highlighted entities, summary JSON, rating label, bar chart.
interface = gr.Interface(
    fn=ner,
    inputs=gr.Textbox(value="", label="Note text"),
    outputs=[
        gr.HighlightedText(combine_adjacent=True, label="NER"),
        gr.JSON(label="Entity Counts"),
        gr.Label(label="Rating"),
        gr.Plot(label="Bar"),
    ],
    # Offer every example note loaded from examples.json as a clickable input.
    examples=list(EXAMPLE_MAP),
    allow_flagging="never",
)

interface.launch()