import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr
# Load model and tokenizer once
model_name = "zekun-li/geolm-base-toponym-recognition"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to("cpu") # Use "cuda" if you have GPU
model.eval()
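
# Optional sketch (not part of the original app): pick the device automatically
# instead of hard-coding CPU. Assumes a CUDA-capable PyTorch install.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)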
def get_toponym_entities(text):
    # Tokenize with character offsets so token predictions can be mapped back to the input text
    inputs = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_attention_mask=True,
    )
    offset_mapping = inputs.pop("offset_mapping")[0]

    with torch.no_grad():
        outputs = model(**inputs)

    # Per-token label ids; 0 is the "O" (non-toponym) class
    predictions = torch.argmax(outputs.logits, dim=2)[0]

    entities = []
    current_entity = None
    for pred, offset in zip(predictions, offset_mapping):
        start, end = offset.tolist()
        if start == end:  # skip special tokens ([CLS], [SEP], padding)
            continue
        if pred != 0:  # non-O label
            if current_entity is None:
                current_entity = {"start": start, "end": end}
            else:
                # Extend the current entity span (B/I labels are not distinguished,
                # so consecutive toponym tokens are merged into one span)
                current_entity["end"] = end
        else:
            if current_entity is not None:
                current_entity["entity"] = "Topo"
                entities.append(current_entity)
                current_entity = None

    # Catch any lingering entity at the end of the sequence
    if current_entity is not None:
        current_entity["entity"] = "Topo"
        entities.append(current_entity)

    return {"text": text, "entities": entities}
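
# Illustrative example (hypothetical output; exact offsets depend on the model):
# calling the function directly returns the dict format that gr.HighlightedText accepts, e.g.
#   get_toponym_entities("Minneapolis is a city in Minnesota.")
#   -> {"text": "Minneapolis is a city in Minnesota.",
#       "entities": [{"start": 0, "end": 11, "entity": "Topo"}, ...]}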
# Launch Gradio app
demo = gr.Interface(
    fn=get_toponym_entities,
    inputs=gr.Textbox(lines=10, placeholder="Enter text with place names..."),
    outputs=gr.HighlightedText(),
    title="🌍 Toponym Recognition with GeoLM",
    description="Enter a paragraph and detect place names using the zekun-li/geolm-base-toponym-recognition model.",
    examples=[
        ["Minneapolis, officially the City of Minneapolis, is a city in the state of Minnesota and the county seat of Hennepin County. As of the 2020 census the population was 429,954, making it the largest city in Minnesota and the 46th-most-populous in the United States. Nicknamed the City of Lakes, Minneapolis is abundant in water, with thirteen lakes, wetlands, the Mississippi River, creeks, and waterfalls."],
        ["Los Angeles, often referred to by its initials L.A., is the most populous city in California, the most populous U.S. state. It is the commercial, financial, and cultural center of Southern California. Los Angeles is the second-most populous city in the United States after New York City, with a population of roughly 3.9 million residents within the city limits as of 2020."],
    ],
)

if __name__ == "__main__":
    demo.launch()