Spaces:
Sleeping
Sleeping
from functools import lru_cache | |
import spacy | |
import gradio as gr | |
# Ask spaCy to use a GPU when one is usable; called before any pipeline is
# loaded so that later spacy.load() calls pick up the preference.
spacy.prefer_gpu()
import os | |
# Pipeline used when the caller does not pick one; can be overridden through
# the MODEL_NAME environment variable.
DEFAULT_MODEL_NAME = os.environ.get("MODEL_NAME", "en_core_web_sm")
# Pipeline package names offered in the model dropdown (list dated 2024-Aug).
# One name per line; the leading backslash keeps the string from starting
# with an empty line.
ALL_MODELS = """\
ca_core_news_sm
ca_core_news_md
ca_core_news_lg
ca_core_news_trf
zh_core_web_sm
zh_core_web_md
zh_core_web_lg
zh_core_web_trf
hr_core_news_sm
hr_core_news_md
hr_core_news_lg
da_core_news_sm
da_core_news_md
da_core_news_lg
da_core_news_trf
nl_core_news_sm
nl_core_news_md
nl_core_news_lg
en_core_web_sm
en_core_web_md
en_core_web_lg
en_core_web_trf
fi_core_news_sm
fi_core_news_md
fi_core_news_lg
fr_core_news_sm
fr_core_news_md
fr_core_news_lg
fr_dep_news_trf
de_core_news_sm
de_core_news_md
de_core_news_lg
de_dep_news_trf
el_core_news_sm
el_core_news_md
el_core_news_lg
it_core_news_sm
it_core_news_md
it_core_news_lg
ja_core_news_sm
ja_core_news_md
ja_core_news_lg
ja_core_news_trf
ko_core_news_sm
ko_core_news_md
ko_core_news_lg
lt_core_news_sm
lt_core_news_md
lt_core_news_lg
mk_core_news_sm
mk_core_news_md
mk_core_news_lg
xx_ent_wiki_sm
xx_sent_ud_sm
nb_core_news_sm
nb_core_news_md
nb_core_news_lg
pl_core_news_sm
pl_core_news_md
pl_core_news_lg
pt_core_news_sm
pt_core_news_md
pt_core_news_lg
ro_core_news_sm
ro_core_news_md
ro_core_news_lg
ru_core_news_sm
ru_core_news_md
ru_core_news_lg
sl_core_news_sm
sl_core_news_md
sl_core_news_lg
sl_core_news_trf
es_core_news_sm
es_core_news_md
es_core_news_lg
es_dep_news_trf
sv_core_news_sm
sv_core_news_md
sv_core_news_lg
uk_core_news_sm
uk_core_news_md
uk_core_news_lg
uk_core_news_trf"""

# Same names as a list: surrounding whitespace removed, blank lines dropped.
ALL_MODELS_LIST = [
    name
    for name in map(str.strip, ALL_MODELS.split("\n"))
    if name
]
def load_model(model_name=DEFAULT_MODEL_NAME):
    """Load a spaCy pipeline, downloading its package first if necessary.

    Args:
        model_name: Name of the spaCy pipeline package to load. Defaults to
            ``DEFAULT_MODEL_NAME``.

    Returns:
        The object returned by ``spacy.load(model_name)``.

    Raises:
        OSError: If the pipeline cannot be loaded and the download command
            fails, or if it still cannot be loaded after a successful
            download.
    """
    # Function-local imports keep this block self-contained.
    import importlib
    import subprocess
    import sys

    try:
        return spacy.load(model_name)
    except OSError as load_error:
        # spaCy reports a pipeline that is not installed as an OSError. Only
        # that case triggers a download; anything else (KeyboardInterrupt,
        # programming errors) propagates untouched.
        #
        # The command is passed as an argument list, so ``model_name`` is
        # never interpreted by a shell, and ``sys.executable`` guarantees the
        # package is installed into the interpreter that is running this code.
        completed = subprocess.run(
            [sys.executable, "-m", "spacy", "download", model_name],
            check=False,
        )
        if completed.returncode != 0:
            raise OSError(
                f"Could not download spaCy model {model_name!r} "
                f"(exit code {completed.returncode})"
            ) from load_error

        # A package was installed while this process is running; make sure
        # the import system notices it before retrying.
        importlib.invalidate_caches()
        return spacy.load(model_name)
# Load the default pipeline once at import time. The returned object is
# discarded; the point is that a missing default model gets downloaded
# before the first request is served.
load_model()
def ner(sentence, model_name):
    """Split ``sentence`` into labelled and unlabelled text segments.

    The text is run through the pipeline named ``model_name``. Every entity
    found becomes a ``(text, label)`` tuple, and each stretch of text before,
    between, or after entities becomes a ``(text, None)`` tuple. Segments are
    returned in document order and are cut from ``sentence`` by character
    offsets.

    Args:
        sentence: The text to analyse.
        model_name: Name of the spaCy pipeline to load via ``load_model``.

    Returns:
        A list of ``(segment_text, label_or_None)`` tuples.
    """
    doc = load_model(model_name)(sentence)

    segments = []
    cursor = 0  # offset of the first character not yet emitted
    for ent in doc.ents:
        begin, end = ent.start_char, ent.end_char
        if cursor < begin:
            # Plain text sitting between the previous entity and this one.
            segments.append((sentence[cursor:begin], None))
        segments.append((sentence[begin:end], ent.label_))
        cursor = end

    total = len(sentence)
    if cursor < total:
        # Trailing text after the last entity (or the whole input when no
        # entity was found).
        segments.append((sentence[cursor:total], None))
    return segments
# Gradio UI: a text box plus a model picker feed `ner`; the returned
# (text, label) segments are rendered as highlighted text with a legend.
demo = gr.Interface(
    fn=ner,
    inputs=[
        gr.Text(),
        gr.Dropdown(choices=ALL_MODELS_LIST, value=DEFAULT_MODEL_NAME),
    ],
    outputs=[gr.HighlightedText(show_legend=True)],
    examples=[
        [
            "This is 24/08/2024.\nAngela Zhang is one of the most important artist in Coca-Cola arena Dubai.",
            "en_core_web_sm",
        ],
    ],
    # cache_examples=True,
)
from fastapi import FastAPI | |
# from gradio2api.gr_fastapi import LocalGradioAppRouter | |
# router = LocalGradioAppRouter(demo) | |
# ASGI application: a fresh FastAPI app with the Gradio demo mounted at the
# site root. Alternatives kept for reference:
# app.include_router(router=router, prefix="/api")
# app = gr.mount_gradio_app(app, demo, path="/gui")
app = gr.mount_gradio_app(FastAPI(), demo, path="/")