Spaces:
Running
Running
File size: 1,990 Bytes
bdf7636 24d9d43 bdf7636 5826177 bdf7636 24d9d43 3469da9 bdf7636 32eb862 bdf7636 5826177 24d9d43 bdf7636 24d9d43 bdf7636 32eb862 24d9d43 5826177 24d9d43 5826177 32eb862 24d9d43 32eb862 24d9d43 5826177 32eb862 bdf7636 24d9d43 5826177 24d9d43 5826177 bdf7636 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import gradio as gr
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
# Hugging Face model id for the biomedical NER model; stated once so the
# tokenizer and the model cannot drift apart.
MODEL_NAME = "d4data/biomedical-ner-all"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# Non-interactive backend: figures are only rendered to images for the web UI.
plt.switch_backend("Agg")
# Map each example note text to its label; ner() looks texts up here to fill
# the "Rating" output. Read as UTF-8 explicitly so decoding does not depend on
# the host's locale default.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLE_MAP = {x["text"]: x["label"] for x in example_json}
# aggregation_strategy="simple" merges sub-word predictions into entity spans;
# each result then carries the "entity_group" key used by the functions below.
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
def group_by_entity(raw):
    """Tally how many predicted entities fall into each entity group.

    Args:
        raw: Iterable of pipeline predictions, each a mapping that has an
            ``"entity_group"`` key.

    Returns:
        A ``defaultdict(int)`` mapping group name -> number of occurrences,
        with groups in first-seen order.
    """
    tally = defaultdict(int)
    for group in (prediction["entity_group"] for prediction in raw):
        tally[group] = tally[group] + 1
    return tally
def plot_to_figure(grouped):
    """Draw a bar chart with one bar per entity group.

    Args:
        grouped: Mapping of entity-group name -> count.

    Returns:
        A ``matplotlib.figure.Figure`` with one bar per group and the group
        labels on the x axis rotated by 90 degrees.
    """
    # Build the figure with the object-oriented API rather than pyplot.
    # A bare Figure is not registered with pyplot's global figure manager, so
    # it is freed once the caller drops it; pyplot figures stay open until
    # plt.close() and would accumulate on every request. It also avoids
    # pyplot's shared "current figure" state, which concurrent requests could
    # otherwise draw onto.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.tick_params(axis="x", labelrotation=90)
    return fig
def ner(text):
    """Run biomedical NER on *text* and build every output of the UI.

    Args:
        text: Note text from the input textbox (or a selected example).

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        - ``ner_content``: dict with the original ``"text"`` and a list of
          ``"entities"`` (entity group, word, score, start/end offsets), used
          by the highlighted-text output.
        - ``meta``: dict with per-group counts (``"entity_counts"``), the
          number of distinct groups (``"entities"``) and the total number of
          entities (``"counts"``).
        - ``label``: the label stored for *text* in ``EXAMPLE_MAP``, or
          ``"Unknown"`` for text that is not one of the examples.
        - ``figure``: bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                # The pipeline may return NumPy scalar scores; cast to a plain
                # float so the payload stays JSON-serializable.
                "score": float(x["score"]),
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLE_MAP.get(text, "Unknown")
    meta = {
        "entity_counts": grouped,
        # Dict keys are already unique, so the group count is just len().
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }
    return (ner_content, meta, label, figure)
# Gradio UI: one note-text input; outputs are listed in the same order as the
# tuple returned by ner(). Components are created input-first, as before.
note_input = gr.Textbox(label="Note text", value="")
output_components = [
    gr.HighlightedText(label="NER", combine_adjacent=True),
    gr.JSON(label="Entity Counts"),
    gr.Label(label="Rating"),
    gr.Plot(label="Bar"),
]
interface = gr.Interface(
    fn=ner,
    inputs=note_input,
    outputs=output_components,
    # Every known example text becomes a clickable example.
    examples=list(EXAMPLE_MAP),
    allow_flagging="never",
)
interface.launch()
|