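# Spaces: Build error
# (The line above appears to be status text captured from the hosting page
# together with the source; it is not part of the program.)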
import base64
import html
import json
import sys
import urllib.parse

import en_core_web_md
import geospacy
import pandas as pd
import spacy
import streamlit as st
from PIL import Image
from spacy import displacy
from spacy.tokens import Span, Doc, Token

from utils import geoutil
from utils import llm_coding
# Highlight colour used by displaCy for each spatial entity label.
colors = {
    "GPE": "#43c6fc",
    "LOC": "#fd9720",
    "RSE": "#a6e22d",
}

# Default displaCy options: render all three labels with the colours above.
options = {"ents": list(colors), "colors": colors}

# Wrapper the rendered displaCy markup is formatted into before display.
HTML_WRAPPER = """<div style="overflow-x: auto; border: none solid #a6e22d; border-radius: 0.25rem; padding: 1rem">{}</div>"""

# Sidebar state shared between functions through ``global`` statements.
model = ""             # model name chosen in the sidebar
gpe_selected = "GPE"   # label kept by set_selected_entities
loc_selected = "LOC"   # label kept by set_selected_entities
rse_selected = "RSE"   # label kept by set_selected_entities
types = ""             # letters "g"/"l"/"r" for the ticked labels

# Prefix for the "Locate" links; an alternative value that was left
# commented out in the file is "http://localhost:8080/".
BASE_URL = ""
def set_header():
    """Render the "GeOspaCy" page header with the logo image embedded inline.

    Emits a CSS block and then an HTML fragment through ``st.markdown``.
    The logo file ``tetis-1.png`` is read from the working directory and
    embedded as a base64 ``data:`` URI.

    NOTE: ``set_header`` is defined a second time further down in this file;
    that later definition rebinds the name, so this version is not the one
    that ``main`` ends up calling.

    Raises:
        OSError: if the logo file cannot be opened.
    """
    LOGO_IMAGE = "tetis-1.png"
    st.markdown(
        """
<style>
.container {
display: flex;
}
.logo-text {
font-weight:700 !important;
font-size:50px !important;
color: #f9a01b !important;
padding-left: 10px !important;
}
.logo-img {
float:right;
width: 28%;
height: 28%;
}
</style>
""",
        unsafe_allow_html=True
    )
    # Read the logo through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(LOGO_IMAGE, "rb") as logo_file:
        logo_b64 = base64.b64encode(logo_file.read()).decode()
    st.markdown(
        f"""
<div class="container">
<img class="logo-img" src="data:image/png;base64,{logo_b64}">
<p class="logo-text">GeOspaCy</p>
</div>
""",
        unsafe_allow_html=True
    )
def set_side_menu():
    """Build the sidebar (spaCy model picker and entity-label checkboxes).

    Reads the ``model`` and ``type`` URL query parameters to pre-select the
    widgets, then stores the user's choices in the module-level globals
    ``model``, ``types`` and ``gpe_selected`` / ``loc_selected`` /
    ``rse_selected``.

    NOTE: ``set_side_menu`` is defined a second time further down in this
    file; that later definition rebinds the name, so this version is not the
    one that ``main`` ends up calling.
    """
    global gpe_selected, loc_selected, rse_selected, model, types
    types = ""
    params = st.query_params

    st.sidebar.markdown("## Spacy Model")
    st.sidebar.markdown("You can **select** the values of the *spacy model* from Dropdown.")
    models = ['en_core_web_sm', 'en_core_web_md', 'en_core_web_lg', 'en_core_web_trf']
    fallback_model = 'en_core_web_sm'
    # ``st.query_params`` maps each key to a plain string, so the value must
    # be used as-is: indexing it with ``[0]`` yields only its first character.
    requested_model = params["model"] if "model" in params else fallback_model
    if requested_model not in models:
        # An unknown model name in the URL falls back to the default instead
        # of raising ValueError from ``list.index``.
        requested_model = fallback_model
    model = st.sidebar.selectbox('Spacy Model', models, index=models.index(requested_model))

    st.sidebar.markdown("## Spatial Entity Labels")
    st.sidebar.markdown("**Mark** the Spatial Entities you want to extract?")
    # The ``type`` parameter is a string of flag letters: g, l and/or r.
    tpes = params["type"] if "type" in params else ""
    gpe = st.sidebar.checkbox('GPE', value="g" in tpes)
    loc = st.sidebar.checkbox('LOC', value="l" in tpes)
    rse = st.sidebar.checkbox('RSE', value="r" in tpes)

    # NOTE(review): the *_selected globals are only ever assigned their
    # default values here and are never cleared for an unticked box, so the
    # entity filter keeps all three labels; only ``types`` reflects the
    # actual ticks. Confirm whether that is intended.
    if gpe:
        gpe_selected = "GPE"
        types += "g"
    if loc:
        loc_selected = "LOC"
        types += "l"
    if rse:
        rse_selected = "RSE"
        types += "r"
def set_input():
    """Show the text input area and the "Extract" button.

    The text area is pre-filled from the ``text`` URL query parameter when
    one is present.

    Returns:
        The current content of the text area if the "Extract" button was
        pressed on this run, otherwise ``None``.
    """
    params = st.query_params
    # ``st.query_params`` values are plain strings; indexing with ``[0]``
    # would pre-fill the box with only the first character of the text.
    default_text = params["text"] if "text" in params else ""
    text = st.text_area("✍️ **Please input your text here:**", default_text)
    # Submit button
    if st.button("Extract"):
        return text
    return None
def set_selected_entities(doc):
    """Keep only the entities whose label is one of the selected labels.

    Replaces ``doc.ents`` in place with the entities whose ``label_`` equals
    the module-level ``gpe_selected``, ``loc_selected`` or ``rse_selected``
    value, and returns the same ``doc``.
    """
    wanted_labels = (gpe_selected, loc_selected, rse_selected)
    doc.ents = [span for span in doc.ents if span.label_ in wanted_labels]
    return doc
def extract_spatial_entities(text):
    """Extract spatial entities from ``text`` and render the results.

    Steps:
      1. Parse ``text`` with ``en_core_web_md`` plus the ``spatial_pipeline``
         component (added after ``ner``).
      2. Re-run the pipeline on each sentence separately, keep the selected
         entities, and shift their token indices back to document positions.
      3. Build a fresh ``Doc`` with the same tokens, mark sentence starts,
         attach the collected entities and copy over their ``rse_id`` values.
      4. Save that ``Doc`` to ``saved_doc.spacy``, render it with displaCy,
         then show the entity table and the LLM composition table.

    NOTE(review): the pipeline is loaded on every call and is always
    ``en_core_web_md``, regardless of the model chosen in the sidebar.

    Args:
        text: the raw input text.
    """
    nlp = spacy.load("en_core_web_md")
    nlp.add_pipe("spatial_pipeline", after="ner")
    doc = nlp(text)

    # Per-sentence processing
    sent_ents = []
    sent_rse_id = []
    offset = 0  # token index in ``doc`` at which the current sentence starts
    sent_start_positions = [0]
    doc_copy = doc.copy()
    for sent in doc.sents:
        sent_doc = set_selected_entities(nlp(sent.text))
        for ent in sent_doc.ents:
            sent_rse_id.append(ent._.rse_id)
            # assumes re-tokenising the sentence text yields the same tokens
            # as in the full document, so a plain offset shift is enough
            sent_ents.append(
                Span(doc, ent.start + offset, ent.end + offset, label=ent.label_)
            )
        offset += len(sent)
        sent_start_positions.append(offset)

    # Build a new Doc carrying the same tokens and whitespace.
    final_doc = Doc(
        nlp.vocab,
        words=[token.text for token in doc],
        spaces=[bool(token.whitespace_) for token in doc],
    )
    for start in sent_start_positions:
        # The last recorded position equals the token count and is skipped.
        if start < len(final_doc):
            final_doc[start].is_sent_start = True
    final_doc.set_ents(sent_ents)
    # Entities and ids were collected in the same order, so pair them up
    # directly; this also avoids rebuilding ``final_doc.ents`` per iteration.
    for ent, rse_id in zip(final_doc.ents, sent_rse_id):
        ent._.rse_id = rse_id

    doc = final_doc
    doc.to_disk("saved_doc.spacy")

    # Highlight only the labels ticked in the sidebar.
    highlight_ents = [
        label
        for flag, label in (("g", "GPE"), ("l", "LOC"), ("r", "RSE"))
        if flag in types
    ]
    render_options = {"ents": highlight_ents, "colors": colors}
    markup = displacy.render(doc, style="ent", options=render_options)
    markup = markup.replace("\n", "")
    st.write(HTML_WRAPPER.format(markup), unsafe_allow_html=True)

    show_spatial_ent_table(doc, text)
    show_sentence_selector_table(doc_copy)
def show_sentence_selector_table(doc_copy):
    """Show the LLM-generated spatial composition with a link to use it.

    Sends the document text to ``llm_coding.llmapi``, serialises the result
    as JSON, and renders a one-row HTML table containing that JSON and a
    link to the ``Locate`` page carrying the text and the JSON as
    percent-encoded query parameters.

    Args:
        doc_copy: a spaCy ``Doc``; only its ``text`` is used.
    """
    text = doc_copy.text
    st.markdown("**______________________________________________________________________________________**")
    st.markdown("**LLM-generated Spatial Composition**")

    combo_obj = llm_coding.llmapi(text)
    combo_str = json.dumps(combo_obj)
    combo_encoded = urllib.parse.quote(combo_str)
    text_encoded = urllib.parse.quote(text)
    url = f"{BASE_URL}Locate?mode=geocombo&text={text_encoded}&combo={combo_encoded}"

    rows = [{
        # The table is rendered as raw HTML, so the LLM output is escaped to
        # keep any "<", ">" or "&" in it from being interpreted as markup.
        'LLM Output': f'<pre>{html.escape(combo_str, quote=False)}</pre>',
        'Action': f'<a target="_self" href="{url}">Use this spatial composition</a>'
    }]
    df = pd.DataFrame(rows)

    custom_style = """
<style>
table {
text-align: left !important;
}
th, td {
text-align: left !important;
}
</style>
"""
    st.markdown(custom_style, unsafe_allow_html=True)
    # escape=False keeps the <pre>/<a> tags built above intact.
    st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
def show_spatial_ent_table(doc, text):
    """Render a table of the document's entities with map / GeoJSON links.

    Each row shows a running number, the entity text, its label, and two
    links to the ``Locate`` page (``map=true`` and ``geojson=true``) that
    carry the selected types, the model name, the input text and the
    entity's ``rse_id`` as query parameters. Nothing is rendered when the
    document has no entities.

    Args:
        doc: spaCy ``Doc`` whose ``ents`` are listed.
        text: the original input text, forwarded in the links.
    """
    if not doc.ents:
        return

    st.markdown("**______________________________________________________________________________________**")
    st.markdown("**Spatial Entities List**")

    # Percent-encode the free-form values so characters such as "&", "#" or
    # quotes in the user's text cannot break the link or the href attribute.
    shared_query = (
        "&type=" + types
        + "&model=" + urllib.parse.quote(model)
        + "&text=" + urllib.parse.quote(text)
    )

    rows = []  # one dict per table row
    for number, ent in enumerate(doc.ents, start=1):
        entity_query = shared_query + "&entity=" + urllib.parse.quote(ent._.rse_id)
        url_map = BASE_URL + "Locate?map=true" + entity_query
        url_json = BASE_URL + "Locate?geojson=true" + entity_query
        rows.append({
            'Sr.': number,
            # The table is emitted as raw HTML, so text taken from the user's
            # input is escaped before it is inserted.
            'entity': html.escape(ent.text),
            'label': html.escape(ent.label_),
            'Map': f'<a target="_self" href="{url_map}">View</a>',
            'GEOJson': f'<a target="_self" href="{url_json}">View</a>'
        })

    df = pd.DataFrame(rows)
    # escape=False keeps the <a> tags built above intact.
    st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
def set_header():
    """Render the "SpatialParse" page title.

    Emits a CSS block followed by an HTML fragment containing the title,
    both through ``st.markdown`` with raw HTML enabled.

    NOTE: this definition replaces the ``set_header`` defined earlier in
    this file, so it is the one ``main`` calls.
    """
    header_css = """
<style>
.container {
display: flex;
}
.logo-text {
font-weight:700 !important;
font-size:50px !important;
color: #52aee3 !important;
padding-left: 10px !important;
}
.logo-img {
float:right;
width: 10%;
height: 10%;
}
</style>
"""
    header_html = """
<div class="container">
<p class="logo-text">SpatialParse</p>
</div>
"""
    for fragment in (header_css, header_html):
        st.markdown(fragment, unsafe_allow_html=True)
def set_side_menu():
    """Build the sidebar: deployment method, model picker, entity labels.

    Reads the ``model`` and ``type`` URL query parameters to pre-select the
    widgets, then stores the user's choices in the module-level globals
    ``model``, ``types`` and ``gpe_selected`` / ``loc_selected`` /
    ``rse_selected``.

    NOTE: this definition replaces the ``set_side_menu`` defined earlier in
    this file, so it is the one ``main`` calls.
    """
    global gpe_selected, loc_selected, rse_selected, model, types
    types = ""
    params = st.query_params

    st.sidebar.markdown("## Deployment Method")
    st.sidebar.markdown("You can select the deployment method for the model.")
    deployment_options = ["API", "Local deployment"]
    use_local_model = st.sidebar.radio("Choose deployment method:", deployment_options, index=0) == "Local deployment"
    if use_local_model:
        # The widget is shown, but nothing in this function reads the path
        # the user enters.
        st.sidebar.text_input("Enter local model path:", "")

    st.sidebar.markdown("## LLM Model")
    st.sidebar.markdown("You can **select** different *LLM model* powered by API.")
    models = ['Llama-3-8B', 'Mistral-7B-0.3', 'Gemma-2-10B', 'GPT-4o', 'Gemini Pro', 'Deepseek-R1', 'en_core_web_sm', 'en_core_web_md', 'en_core_web_lg', 'en_core_web_trf']
    fallback_model = 'GPT-4o'
    # ``st.query_params`` maps each key to a plain string, so the value must
    # be used as-is: indexing it with ``[0]`` yields only its first character.
    requested_model = params["model"] if "model" in params else fallback_model
    if requested_model not in models:
        # An unknown model name in the URL falls back to the default instead
        # of raising ValueError from ``list.index``.
        requested_model = fallback_model
    model = st.sidebar.selectbox('LLM Model', models, index=models.index(requested_model))

    st.sidebar.markdown("## Spatial Entity Labels")
    st.sidebar.markdown("Please **Mark** the Spatial Entities you want to extract.")
    # The ``type`` parameter is a string of flag letters: g, l and/or r.
    tpes = params["type"] if "type" in params else ""
    gpe = st.sidebar.checkbox('GPE', value="g" in tpes)
    loc = st.sidebar.checkbox('LOC', value="l" in tpes)
    st.sidebar.markdown("### Relative Spatial Entity:")
    rse = st.sidebar.checkbox('RSE', value="r" in tpes)

    # NOTE(review): the *_selected globals are only ever assigned their
    # default values here and are never cleared for an unticked box, so the
    # entity filter keeps all three labels; only ``types`` reflects the
    # actual ticks. Confirm whether that is intended.
    if gpe:
        gpe_selected = "GPE"
        types += "g"
    if loc:
        loc_selected = "LOC"
        types += "l"
    if rse:
        rse_selected = "RSE"
        types += "r"
def main():
    """Run one pass of the app: header, sidebar, input, then extraction.

    Extraction runs on the text returned by ``set_input`` when it is not
    ``None``; otherwise it runs on ``st.session_state.text`` if a ``text``
    entry exists in the session state.
    """
    set_header()
    set_side_menu()

    submitted = set_input()
    if submitted is not None:
        extract_spatial_entities(submitted)
        return

    # Fall back to text stored in the session state, if any.
    if "text" in st.session_state:
        extract_spatial_entities(st.session_state.text)
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()