import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr

# Load model and tokenizer once at startup
model_name = "zekun-li/geolm-base-toponym-recognition"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()


def get_toponym_entities(text):
    # Tokenize with character offsets so predictions can be mapped back to the input text
    inputs = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_attention_mask=True,
    )
    offset_mapping = inputs.pop("offset_mapping")[0]
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)[0]

    # Merge consecutive non-O tokens into a single entity span
    entities = []
    current_entity = None
    for pred, offset in zip(predictions, offset_mapping):
        start, end = offset.tolist()
        if start == end:  # Skip special tokens ([CLS], [SEP], padding)
            continue
        if pred != 0:  # Non-O label: start a new span or extend the current one
            if current_entity is None:
                current_entity = {"start": start, "end": end}
            else:
                current_entity["end"] = end
        elif current_entity is not None:
            current_entity["entity"] = "Topo"
            entities.append(current_entity)
            current_entity = None

    # Catch an entity that runs to the end of the sequence
    if current_entity is not None:
        current_entity["entity"] = "Topo"
        entities.append(current_entity)

    return {"text": text, "entities": entities}


# Launch the Gradio app
demo = gr.Interface(
    fn=get_toponym_entities,
    inputs=gr.Textbox(lines=10, placeholder="Enter text with place names..."),
    outputs=gr.HighlightedText(),
    title="🌍 Toponym Recognition with GeoLM",
    description="Enter a paragraph and detect place names using the zekun-li/geolm-base-toponym-recognition model.",
    examples=[
        ["Minneapolis, officially the City of Minneapolis, is a city in the state of Minnesota and the county seat of Hennepin County. As of the 2020 census the population was 429,954, making it the largest city in Minnesota and the 46th-most-populous in the United States. Nicknamed the City of Lakes, Minneapolis is abundant in water, with thirteen lakes, wetlands, the Mississippi River, creeks, and waterfalls."],
        ["Los Angeles, often referred to by its initials L.A., is the most populous city in California, the most populous U.S. state. It is the commercial, financial, and cultural center of Southern California. Los Angeles is the second-most populous city in the United States after New York City, with a population of roughly 3.9 million residents within the city limits as of 2020."],
    ],
)

if __name__ == "__main__":
    demo.launch()
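
# --- Optional sanity check (a minimal sketch, not part of the Gradio app) ---
# A quick way to exercise get_toponym_entities from a Python shell before
# launching the UI; the sentence below is only an illustrative input and is
# not taken from the original script.
#
#   result = get_toponym_entities("I flew from Minneapolis to Los Angeles in 2020.")
#   for ent in result["entities"]:
#       print(ent, "->", result["text"][ent["start"]:ent["end"]])
#
# Each printed span should correspond to a detected place name, e.g. "Minneapolis".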