zekun-li's picture
Update app.py
58dd7a5 verified
raw
history blame
1.71 kB
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr
# One-time setup at import: every request reuses this tokenizer/model pair.
model_name = "zekun-li/geolm-base-toponym-recognition"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the checkpoint, pin it to the CPU (pass "cuda" instead when a GPU is
# available; tensors fed to the model must then live on that device too) and
# switch to evaluation mode, all in a single chain.
model = (
    AutoModelForTokenClassification.from_pretrained(model_name)
    .to("cpu")
    .eval()
)
def _merge_touching_spans(spans):
    """Fuse ascending ``(start, end)`` spans that touch or overlap."""
    merged = []
    for start, end in spans:
        if merged and start <= merged[-1][1]:
            # Continuation of the previous span (e.g. the next word piece of
            # the same word): extend it instead of opening a new one.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged


def get_toponym_entities(text):
    """Tag place names in *text* and return them as character spans.

    The text is tokenized (truncated to at most 512 tokens), run through the
    module-level token-classification ``model``, and every token whose
    predicted label id is non-zero is reported as a toponym. Token spans that
    touch each other are fused so a name split into several word pieces is
    returned as one entity.

    Args:
        text: The raw input string. ``None``, empty, or whitespace-only input
            yields no entities without invoking the model.

    Returns:
        A dict ``{"text": ..., "entities": [...]}`` where each entity is
        ``{"start": int, "end": int, "entity": "Toponym"}`` with character
        offsets into ``text``.
    """
    # Nothing to tag: skip tokenization and the forward pass entirely.
    if not text or not text.strip():
        return {"text": text or "", "entities": []}

    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # The offsets are bookkeeping only; the model does not accept them, so
    # they are removed from the encoding before the forward pass.
    offsets = encoded.pop("offset_mapping")[0].tolist()
    # Keep the inputs on whatever device the model was placed on.
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits
    # Convert once to plain ints instead of comparing 0-d tensors per token.
    label_ids = logits.argmax(dim=2)[0].tolist()

    # Label id 0 is treated as "not a toponym"; zero-width offsets
    # (presumably special tokens) carry no text and are skipped.
    token_spans = [
        (start, end)
        for label_id, (start, end) in zip(label_ids, offsets)
        if label_id != 0 and end > start
    ]
    entities = [
        {"start": start, "end": end, "entity": "Toponym"}
        for start, end in _merge_touching_spans(token_spans)
    ]
    return {"text": text, "entities": entities}
# Assemble the web UI: a multi-line text box feeds get_toponym_entities and
# the returned spans are shown through a HighlightedText component.
_example_sentences = (
    "Minneapolis, officially the City of Minneapolis, is a city in Minnesota.",
    "Los Angeles is the most populous city in California.",
)

demo = gr.Interface(
    title="🌍 Toponym Recognition with GeoLM",
    description="Enter a paragraph and detect highlighted place names using the zekun-li/geolm-base-toponym-recognition model.",
    fn=get_toponym_entities,
    inputs=gr.Textbox(placeholder="Enter text with place names...", lines=10),
    outputs=gr.HighlightedText(),
    # Each example is a one-element row matching the single input component.
    examples=[[sentence] for sentence in _example_sentences],
)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()