# SpatialParseback / 1_SpatialParse.py
# Shunfeng Zheng
# Upload 83 files
# 4c425e5 verified
import streamlit as st
from spacy import displacy
import spacy
import geospacy
from PIL import Image
import base64
import sys
import pandas as pd
import en_core_web_md
from spacy.tokens import Span, Doc, Token
from utils import geoutil
from utils import llm_coding
import urllib.parse
import json
# Highlight colour used for each spatial entity label in the displaCy rendering.
colors = {
    "GPE": "#43c6fc",
    "LOC": "#fd9720",
    "RSE": "#a6e22d",
}

# Default displaCy options: highlight every label that has a colour.
options = {"ents": list(colors), "colors": colors}

# <div> wrapper; rendered displaCy markup is substituted in with str.format().
HTML_WRAPPER = (
    '<div style="overflow-x: auto; border: none solid #a6e22d; '
    'border-radius: 0.25rem; padding: 1rem">{}</div>'
)

# Module-level state that set_side_menu() reassigns through ``global``.
model = ""
gpe_selected, loc_selected, rse_selected = "GPE", "LOC", "RSE"
# Concatenation of the flags "g", "l" and "r" for the ticked labels.
types = ""

# Prefix for the generated "Locate?..." links; empty gives relative links.
# A local server prefix would be "http://localhost:8080/".
BASE_URL = ""
def set_header():
    """Render the "GeOspaCy" header: inline CSS, then a logo image and title.

    The logo file ``tetis-1.png`` is opened relative to the current working
    directory and embedded in the page as a base64 ``data:`` URI.

    NOTE(review): a second ``set_header`` is defined further down in this
    file; at import time that later definition replaces this one.

    Raises:
        OSError: if the logo file cannot be opened.
    """
    LOGO_IMAGE = "tetis-1.png"

    # The markup below is runtime text handed to st.markdown, so it is kept
    # flush-left exactly as written rather than indented with the code.
    st.markdown(
        """
<style>
.container {
display: flex;
}
.logo-text {
font-weight:700 !important;
font-size:50px !important;
color: #f9a01b !important;
padding-left: 10px !important;
}
.logo-img {
float:right;
width: 28%;
height: 28%;
}
</style>
""",
        unsafe_allow_html=True,
    )

    # Read the logo through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(LOGO_IMAGE, "rb") as logo_file:
        logo_b64 = base64.b64encode(logo_file.read()).decode()

    st.markdown(
        f"""
<div class="container">
<img class="logo-img" src="data:image/png;base64,{logo_b64}">
<p class="logo-text">GeOspaCy</p>
</div>
""",
        unsafe_allow_html=True,
    )
def set_side_menu():
    """Build the sidebar: spaCy model picker and spatial-label checkboxes.

    The optional ``model`` and ``type`` URL query parameters pre-select the
    widgets. The chosen model is stored in the module-level ``model`` and the
    ticked labels are recorded in ``types`` as the flags ``g``/``l``/``r``.

    NOTE(review): a second ``set_side_menu`` is defined further down in this
    file; at import time that later definition replaces this one.
    """
    global gpe_selected, loc_selected, rse_selected, model, types
    types = ""
    params = st.query_params

    def query_value(key, fallback=""):
        # st.query_params yields plain strings, so indexing the value with
        # [0] would keep only its first character. List-like values (as the
        # older experimental query-param API produced) are still accepted.
        if key not in params:
            return fallback
        raw = params[key]
        if isinstance(raw, str):
            return raw
        return raw[0] if raw else fallback

    st.sidebar.markdown("## Spacy Model")
    st.sidebar.markdown("You can **select** the values of the *spacy model* from Dropdown.")
    models = ['en_core_web_sm', 'en_core_web_md', 'en_core_web_lg', 'en_core_web_trf']
    requested_model = query_value("model", 'en_core_web_sm')
    if requested_model not in models:
        # An unknown model name in the URL falls back to the default instead
        # of raising ValueError from list.index().
        requested_model = 'en_core_web_sm'
    model = st.sidebar.selectbox('Spacy Model', models, index=models.index(requested_model))

    st.sidebar.markdown("## Spatial Entity Labels")
    st.sidebar.markdown("**Mark** the Spatial Entities you want to extract?")
    preselected = query_value("type")
    gpe = st.sidebar.checkbox('GPE', value="g" in preselected)
    loc = st.sidebar.checkbox('LOC', value="l" in preselected)
    rse = st.sidebar.checkbox('RSE', value="r" in preselected)

    # NOTE(review): the *_selected globals are only ever assigned their own
    # label here and are never cleared when a box is unticked; only ``types``
    # reflects the current checkbox state.
    if gpe:
        gpe_selected = "GPE"
        types += "g"
    if loc:
        loc_selected = "LOC"
        types += "l"
    if rse:
        rse_selected = "RSE"
        types += "r"
def set_input():
    """Show the input text area and the "Extract" button.

    The text area is pre-filled from the ``text`` URL query parameter when it
    is present, otherwise it starts empty.

    Returns:
        The text-area contents if "Extract" was pressed on this script run,
        otherwise ``None``.
    """
    params = st.query_params

    default_text = ""
    if "text" in params:
        raw = params["text"]
        # st.query_params yields plain strings; indexing with [0] would keep
        # only the first character. List-like values are still accepted.
        if isinstance(raw, str):
            default_text = raw
        elif raw:
            default_text = raw[0]

    text = st.text_area("✍️ **Please input your text here:**", default_text)

    # Only hand the text back when the user explicitly submits it.
    if st.button("Extract"):
        return text
    return None
def set_selected_entities(doc):
    """Keep only the entities of ``doc`` whose label is currently selected.

    The selected labels are read from the module-level ``gpe_selected``,
    ``loc_selected`` and ``rse_selected``. ``doc.ents`` is overwritten in
    place and the same ``doc`` object is returned.
    """
    wanted_labels = (gpe_selected, loc_selected, rse_selected)
    kept = []
    for span in doc.ents:
        if span.label_ in wanted_labels:
            kept.append(span)
    doc.ents = kept
    return doc
def extract_spatial_entities(text):
    """Extract spatial entities from ``text`` and render the results.

    Steps, in order:
      1. Load ``en_core_web_md``, add the ``spatial_pipeline`` component after
         ``ner`` and parse the whole text.
      2. Re-parse every sentence on its own, keep the selected entity labels
         (``set_selected_entities``) and shift each entity's token indices by
         the sentence's offset in the full document.
      3. Rebuild a fresh ``Doc`` holding only those entities, restore their
         ``rse_id`` values and write it to ``saved_doc.spacy``.
      4. Render the displaCy highlighting, the entity table and the
         LLM-composition table.

    Args:
        text: Raw user text. Blank or whitespace-only input shows a warning
            and nothing else is rendered.

    NOTE(review): the pipeline is reloaded on every call and always uses
    ``en_core_web_md``, regardless of the model chosen in the sidebar.
    NOTE(review): ``spatial_pipeline`` and the ``rse_id`` span extension are
    presumably registered by the ``geospacy`` import; confirm.
    """
    # Robustness: there is nothing to parse, save or send to the LLM.
    if not text or not text.strip():
        st.warning("Please enter some text before extracting spatial entities.")
        return

    nlp = spacy.load("en_core_web_md")
    nlp.add_pipe("spatial_pipeline", after="ner")
    doc = nlp(text)
    # Untouched copy of the full parse; its text feeds the LLM table below.
    doc_copy = doc.copy()

    # Process sentence by sentence, collecting entities re-based onto the
    # token indices of the full document.
    sent_ents = []
    sent_rse_ids = []
    sent_starts = []
    offset = 0
    for sent in doc.sents:
        sent_starts.append(offset)
        sent_doc = set_selected_entities(nlp(sent.text))
        for ent in sent_doc.ents:
            sent_rse_ids.append(ent._.rse_id)
            sent_ents.append(
                Span(doc, ent.start + offset, ent.end + offset, label=ent.label_)
            )
        # Assumes re-tokenising the sentence text yields the same tokens as
        # the sentence span inside ``doc`` (TODO confirm).
        offset += len(sent)

    # Create a new Doc with the same tokens but only the collected entities.
    final_doc = Doc(
        nlp.vocab,
        words=[token.text for token in doc],
        spaces=[token.whitespace_ for token in doc],
    )
    # Every recorded start is the index of an existing token, because each
    # sentence is a non-empty span of ``doc`` and ``final_doc`` has the same
    # tokens.
    for start in sent_starts:
        final_doc[start].is_sent_start = True
    final_doc.set_ents(sent_ents)
    # Entities were collected in document order, which is the order in which
    # ``final_doc.ents`` yields them, so the ids line up pairwise.
    for ent, rse_id in zip(final_doc.ents, sent_rse_ids):
        ent._.rse_id = rse_id

    # NOTE(review): presumably read back by another page of the app; confirm.
    final_doc.to_disk("saved_doc.spacy")

    # Highlight only the labels whose flag is present in ``types``.
    flag_to_label = (("g", "GPE"), ("l", "LOC"), ("r", "RSE"))
    highlight_ents = [label for flag, label in flag_to_label if flag in types]
    render_options = {"ents": highlight_ents, "colors": colors}
    html = displacy.render(final_doc, style="ent", options=render_options)
    html = html.replace("\n", "")
    st.write(HTML_WRAPPER.format(html), unsafe_allow_html=True)

    show_spatial_ent_table(final_doc, text)
    show_sentence_selector_table(doc_copy)
def show_sentence_selector_table(doc_copy):
    """Show the LLM-generated spatial composition with a link to apply it.

    The text of ``doc_copy`` is sent to ``llm_coding.llmapi``; the result is
    serialised to JSON and shown in a one-row HTML table next to a
    ``Locate?mode=geocombo`` link that carries the text and the JSON as
    percent-encoded query parameters.

    Args:
        doc_copy: Parsed document; only its ``.text`` attribute is used.
    """
    # Function-scope import so this block does not depend on the file header.
    from html import escape

    text = doc_copy.text
    st.markdown("**______________________________________________________________________________________**")
    st.markdown("**LLM-generated Spatial Composition**")

    combo_obj = llm_coding.llmapi(text)
    combo_str = json.dumps(combo_obj)
    combo_encoded = urllib.parse.quote(combo_str)
    text_encoded = urllib.parse.quote(text)
    url = f"{BASE_URL}Locate?mode=geocombo&text={text_encoded}&combo={combo_encoded}"

    rows = [{
        # The table is written with escape=False / unsafe_allow_html, so the
        # LLM output must be escaped here or any "<", ">" or "&" it contains
        # would be interpreted as markup. Quotes are left readable.
        'LLM Output': f'<pre>{escape(combo_str, quote=False)}</pre>',
        'Action': f'<a target="_self" href="{url}">Use this spatial composition</a>'
    }]
    df = pd.DataFrame(rows)

    # Runtime CSS text; kept flush-left exactly as written.
    custom_style = """
<style>
table {
text-align: left !important;
}
th, td {
text-align: left !important;
}
</style>
"""
    st.markdown(custom_style, unsafe_allow_html=True)
    st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
def show_spatial_ent_table(doc, text):
    """Render an HTML table of the entities in ``doc`` with map/GeoJSON links.

    Each row links to ``Locate?map=true`` and ``Locate?geojson=true`` and
    passes the current ``types``, ``model``, the input ``text`` and the
    entity's ``rse_id`` as query parameters. Nothing is rendered when ``doc``
    has no entities.

    Args:
        doc: Document whose ``ents`` are listed.
        text: The original input text, forwarded in the links.
    """
    # Function-scope import so this block does not depend on the file header.
    from html import escape

    if not doc.ents:
        return

    st.markdown("**______________________________________________________________________________________**")
    st.markdown("**Spatial Entities List**")

    def encode(value):
        # Percent-encode query values so characters such as "&", "#", quotes
        # or spaces in the user text cannot break the URL or the surrounding
        # href attribute.
        return urllib.parse.quote(str(value), safe="")

    shared_query = (
        "type=" + encode(types)
        + "&model=" + encode(model)
        + "&text=" + encode(text)
    )

    rows = []
    for number, ent in enumerate(doc.ents, start=1):
        # NOTE(review): rse_id is assumed to be set on every entity; a
        # missing id would be sent as the literal string "None" - confirm.
        query = shared_query + "&entity=" + encode(ent._.rse_id)
        url_map = BASE_URL + "Locate?map=true&" + query
        url_json = BASE_URL + "Locate?geojson=true&" + query
        rows.append({
            'Sr.': number,
            # The table is emitted with escape=False, so user-derived text
            # has to be escaped by hand.
            'entity': escape(ent.text),
            'label': escape(ent.label_),
            'Map': f'<a target="_self" href="{url_map}">View</a>',
            'GEOJson': f'<a target="_self" href="{url_json}">View</a>'
        })

    df = pd.DataFrame(rows)
    st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)
def set_header():
    """Render the "SpatialParse" page title together with its inline CSS.

    This definition replaces the ``set_header`` defined earlier in the file;
    no logo image is rendered here.
    """
    # Both fragments are runtime text for st.markdown and are kept flush-left
    # exactly as written.
    header_css = """
<style>
.container {
display: flex;
}
.logo-text {
font-weight:700 !important;
font-size:50px !important;
color: #52aee3 !important;
padding-left: 10px !important;
}
.logo-img {
float:right;
width: 10%;
height: 10%;
}
</style>
"""
    header_html = """
<div class="container">
<p class="logo-text">SpatialParse</p>
</div>
"""
    for fragment in (header_css, header_html):
        st.markdown(fragment, unsafe_allow_html=True)
def set_side_menu():
    """Build the sidebar: deployment method, model picker and label checkboxes.

    The optional ``model`` and ``type`` URL query parameters pre-select the
    widgets. The chosen model is stored in the module-level ``model`` and the
    ticked labels are recorded in ``types`` as the flags ``g``/``l``/``r``.

    This definition replaces the ``set_side_menu`` defined earlier in the
    file.
    """
    global gpe_selected, loc_selected, rse_selected, model, types
    types = ""
    params = st.query_params

    def query_value(key, fallback=""):
        # st.query_params yields plain strings, so indexing the value with
        # [0] would keep only its first character. List-like values (as the
        # older experimental query-param API produced) are still accepted.
        if key not in params:
            return fallback
        raw = params[key]
        if isinstance(raw, str):
            return raw
        return raw[0] if raw else fallback

    st.sidebar.markdown("## Deployment Method")
    st.sidebar.markdown("You can select the deployment method for the model.")
    deployment_options = ["API", "Local deployment"]
    use_local_model = st.sidebar.radio("Choose deployment method:", deployment_options, index=0) == "Local deployment"
    if use_local_model:
        # The widget is shown, but the entered path is not used anywhere in
        # this function.
        st.sidebar.text_input("Enter local model path:", "")

    st.sidebar.markdown("## LLM Model")
    st.sidebar.markdown("You can **select** different *LLM model* powered by API.")
    models = ['Llama-3-8B', 'Mistral-7B-0.3', 'Gemma-2-10B', 'GPT-4o', 'Gemini Pro', 'Deepseek-R1', 'en_core_web_sm', 'en_core_web_md', 'en_core_web_lg', 'en_core_web_trf']
    requested_model = query_value("model", 'GPT-4o')
    if requested_model not in models:
        # An unknown model name in the URL falls back to the default instead
        # of raising ValueError from list.index().
        requested_model = 'GPT-4o'
    model = st.sidebar.selectbox('LLM Model', models, index=models.index(requested_model))

    st.sidebar.markdown("## Spatial Entity Labels")
    st.sidebar.markdown("Please **Mark** the Spatial Entities you want to extract.")
    preselected = query_value("type")
    gpe = st.sidebar.checkbox('GPE', value="g" in preselected)
    loc = st.sidebar.checkbox('LOC', value="l" in preselected)
    st.sidebar.markdown("### Relative Spatial Entity:")
    rse = st.sidebar.checkbox('RSE', value="r" in preselected)

    # NOTE(review): the *_selected globals are only ever assigned their own
    # label here and are never cleared when a box is unticked; only ``types``
    # reflects the current checkbox state.
    if gpe:
        gpe_selected = "GPE"
        types += "g"
    if loc:
        loc_selected = "LOC"
        types += "l"
    if rse:
        rse_selected = "RSE"
        types += "r"
def main():
    """Page entry point: draw header, sidebar and input, then run extraction.

    Extraction runs on the submitted text when "Extract" was pressed;
    otherwise it runs on ``st.session_state.text`` if that key exists.
    """
    set_header()
    set_side_menu()

    submitted = set_input()
    if submitted is not None:
        extract_spatial_entities(submitted)
        return

    # No fresh submission: fall back to text kept in the session, if any.
    if "text" in st.session_state:
        extract_spatial_entities(st.session_state.text)
# Run the page when this file is executed as a script.
if __name__ == "__main__":
    main()