import streamlit as st
import requests
import pandas as pd
import re
import json
import configparser
from datetime import datetime
from torch import cuda
# Import existing modules from appStore
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
    load_region_data,
    clean_country_code,
    get_country_name,
    get_regions,
    get_country_name_and_region_mapping
)
# TF-IDF part (excluded from the app for now)
# from appStore.tfidf_extraction import extract_top_keywords
# Import helper modules
from appStore.rag_utils import (
    highlight_query,
    get_rag_answer,
    compute_title
)
from appStore.filter_utils import (
    parse_budget,
    filter_results,
    get_crs_options
)
from appStore.crs_utils import lookup_crs_value
###########################################
# Model Config
###########################################
config = configparser.ConfigParser()
config.read('model_params.cfg')
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
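# The [MODEL] section of model_params.cfg is expected to provide the two keys read above.
# Illustrative sketch only (the real config file is not part of this snippet):
# [MODEL]
# DEDICATED_MODEL = meta-llama/Llama-3.1-8B-Instruct
# DEDICATED_ENDPOINT = https://<your-inference-endpoint>
# The endpoint and the write-access token are passed to get_rag_answer() further below.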
st.set_page_config(page_title="SEARCH IATI", layout='wide')
###########################################
# Cache the project data
###########################################
@st.cache_data
def load_project_data():
"""
Load and process the GIZ worldwide data, returning a pandas DataFrame.
"""
return process_giz_worldwide()
project_data = load_project_data()
# Determine min and max budgets in million euros
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
min_budget_val = float(budget_series.min() / 1e6)
max_budget_val = float(budget_series.max() / 1e6)
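# These bounds (in million €) drive the "Minimum Project Budget" slider and the reset defaults below.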
###########################################
# Prepare region data
###########################################
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
###########################################
# Get device
###########################################
device = 'cuda' if cuda.is_available() else 'cpu'
###########################################
# Streamlit App Layout
###########################################
col_title, col_about = st.columns([8, 2])
with col_title:
st.markdown("<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>", unsafe_allow_html=True)
with col_about:
    with st.expander("ℹ️ About"):
        st.markdown(
            """
This app is a prototype for testing purposes using publicly available project data
from the Deutsche Gesellschaft für Internationale Zusammenarbeit (GIZ) as of 23rd February 2025.
**Please do NOT enter sensitive or personal information.**
**Note**: The answers are AI-generated and may be wrong or misleading.
""", unsafe_allow_html=True
        )
# Main query input (with a key so we can reset it)
var = st.text_input("Enter Question", key="query")
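# An empty query short-circuits to an st.info() prompt below instead of running a search.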
###########################################
# Create or load the embeddings collection
###########################################
collection_name = "giz_worldwide"
client = get_client()
print(client.get_collections())
# Uncomment if needed:
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
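# Re-enabling the three lines above re-chunks the project data and re-embeds it into the
# "giz_worldwide" collection (del_if_exists=True replaces any existing collection).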
max_end_year = get_max_end_year(client, collection_name)
_, unique_sub_regions = get_regions(region_df)
# Build country->code and code->region mapping
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name
)
unique_country_names = sorted(country_name_mapping.keys())
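# country_name_mapping maps display name -> ISO code; iso_code_to_sub_region maps ISO code -> sub-region.
# Both are used below to narrow the country dropdown when a region is selected.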
###########################################
# Define reset_filters function using session_state
###########################################
def reset_filters():
st.session_state["region_filter"] = "All/Not allocated"
st.session_state["country_filter"] = "All/Not allocated"
current_year = datetime.now().year
default_start_year = current_year - 4
st.session_state["end_year_range"] = (default_start_year, max_end_year)
st.session_state["crs_filter"] = "All/Not allocated"
st.session_state["min_budget"] = min_budget_val
st.session_state["client_filter"] = "All/Not allocated"
st.session_state["query"] = ""
st.session_state["show_exact_matches"] = False
st.session_state["page"] = 1
###########################################
# Filter Controls - Row 1
###########################################
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
with col1:
    region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions), key="region_filter")
if region_filter == "All/Not allocated":
    filtered_country_names = unique_country_names
else:
    filtered_country_names = [
        name for name, code in country_name_mapping.items()
        if iso_code_to_sub_region.get(code) == region_filter
    ]
with col2:
    country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names, key="country_filter")
with col3:
    current_year = datetime.now().year
    default_start_year = current_year - 4
    end_year_range = st.slider(
        "Project End Year",
        min_value=2010,
        max_value=max_end_year,
        value=(default_start_year, max_end_year),
        key="end_year_range"
    )
with col4:
    crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
    crs_filter = st.selectbox("CRS", crs_options, key="crs_filter")
with col5:
    min_budget = st.slider(
        "Minimum Project Budget (Million €)",
        min_value=min_budget_val,
        max_value=max_budget_val,
        value=min_budget_val,
        key="min_budget"
    )
###########################################
# Filter Controls - Row 2 (Additional Filters)
###########################################
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
with col1_2:
    client_options = sorted(project_data["client"].dropna().unique().tolist())
    client_filter = st.selectbox("Client", ["All/Not allocated"] + client_options, key="client_filter")
with col2_2:
    st.empty()
with col3_2:
    st.empty()
with col4_2:
    st.empty()
with col5_2:
    # Plain reset button (will be moved to row 3 as well)
    st.button("Reset Filters", on_click=reset_filters, key="reset_button_row2")
###########################################
# Filter Controls - Row 3 (Remaining Filter)
###########################################
col1_3, col2_3, col3_3, col4_3, col5_3 = st.columns(5)
with col1_3:
    # Place the "Show only exact matches" checkbox here
    show_exact_matches = st.checkbox("Show only exact matches", key="show_exact_matches")
with col2_3:
    st.empty()
with col3_3:
    st.empty()
with col4_3:
    st.empty()
with col5_3:
    # Right-align a more prominent reset button
    with st.container():
        st.markdown("<div style='text-align: right;'>", unsafe_allow_html=True)
        if st.button("**Reset Filters**", key="reset_button_row3"):
            reset_filters()
        st.markdown("</div>", unsafe_allow_html=True)
###########################################
# Main Search / Results
###########################################
if not var.strip():
    st.info("Please enter a question to see results.")
else:
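    # Pipeline: hybrid search -> length/score filtering -> user filters -> de-duplication
    # -> either the lexical (exact-match) branch or the semantic branch with a RAG answer.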
    # 1) Perform hybrid search
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all, lexical_all = results[0], results[1]
    # Filter out short pages
    semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
    lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
    # Apply threshold to semantic results if desired
    semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
    # 2) Filter results based on the user’s selections
    filtered_semantic = filter_results(
        semantic_thresholded,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name
    )
    filtered_lexical = filter_results(
        lexical_all,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name
    )
    # Additional filter by client
    if client_filter != "All/Not allocated":
        filtered_semantic = [r for r in filtered_semantic if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
        filtered_lexical = [r for r in filtered_lexical if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter]
    # Remove duplicates
    filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
    filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
    def format_currency(value):
        try:
            return f"€{int(float(value)):,}"
        except (ValueError, TypeError):
            return value
    # --- Reprint Query (Right Aligned with "Query:") ---
    st.markdown(f"<div style='text-align: right; font-size:2.1em; font-style: italic; font-weight: bold;'>Query: {var}</div>", unsafe_allow_html=True)
    # 3) Display results
    # Lexical Search Results Branch
    if show_exact_matches:
        st.write("Showing **Top Lexical Search results**")
        query_substring = var.strip().lower()
        lexical_substring_filtered = [
            r for r in filtered_lexical
            if query_substring in r.payload["page_content"].lower()
        ]
        filtered_lexical_no_dupe = remove_duplicates(lexical_substring_filtered)
        if not filtered_lexical_no_dupe:
            st.write('No exact matches, consider unchecking "Show only exact matches"')
        else:
            top_results = filtered_lexical_no_dupe  # Show all matching lexical results
            # --- Pagination (Above Lexical Results) ---
            page_size = 15
            total_results = len(top_results)
            total_pages = (total_results - 1) // page_size + 1
            if "page" not in st.session_state:
                st.session_state.page = 1
            current_page = st.session_state.page
            # Top pagination widget (right aligned, 1/7 width)
            col_pag_top = st.columns([6, 1])[1]
            new_page_top = col_pag_top.selectbox("Select Page", list(range(1, total_pages + 1)), index=current_page - 1, key="page_top")
            st.session_state.page = new_page_top
            start_index = (st.session_state.page - 1) * page_size
            end_index = start_index + page_size
            paged_results = top_results[start_index:end_index]
            for i, res in enumerate(paged_results, start=start_index+1):
                metadata = res.payload.get('metadata', {})
                if "title" not in metadata:
                    metadata["title"] = compute_title(metadata)
                title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
                title_clean = re.sub(r'<a.*?>|</a>', '', title_html)
                # Prepend the result number
                st.markdown(f"#### {i}. **{title_clean}**", unsafe_allow_html=True)
                objective = metadata.get("objective", "None")
                desc_en = metadata.get("description.en", "").strip()
                desc_de = metadata.get("description.de", "").strip()
                description = desc_en if desc_en else desc_de
                if not description:
                    description = "No project description available"
                words = description.split()
                preview_word_count = 90
                preview_text = " ".join(words[:preview_word_count])
                remainder_text = " ".join(words[preview_word_count:])
                col_left, col_right = st.columns(2)
                with col_left:
                    st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
                    if remainder_text:
                        with st.expander("Show more"):
                            st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
                with col_right:
                    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
                    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
                    total_project = metadata.get('total_project', "Unknown")
                    total_volume = metadata.get('total_volume', "Unknown")
                    formatted_project_budget = format_currency(total_project)
                    formatted_total_volume = format_currency(total_volume)
                    country_raw = metadata.get('country', "Unknown")
                    crs_key = metadata.get("crs_key", "").strip()
                    crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
                    new_crs_value = lookup_crs_value(crs_key_clean)
                    new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
                    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
                    # Insert Predecessor/Successor line if available
                    predecessor = metadata.get("predecessor_id", "").strip()
                    successor = metadata.get("successor_id", "").strip()
                    extra_line = ""
                    if predecessor:
                        extra_line += f"<br>**Predecessor Project:** {predecessor}"
                    if successor:
                        extra_line += f"<br>**Successor Project:** {successor}"
                    additional_text = (
                        f"**Objective:** {highlight_query(objective, var)}<br>"
                        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
                        f"**Project duration:** {start_year_str}-{end_year_str}<br>"
                        f"**Budget:** Project: <b>{formatted_project_budget}</b>, Total volume: <b>{formatted_total_volume}</b><br>"
                        + extra_line +
                        f"<br>**Country:** {country_raw}<br>"
                        f"**Sector:** {crs_combined}"
                    )
                    contact = metadata.get("contact", "").strip()
                    if contact and contact.lower() != "[email protected]":
                        additional_text += f"<br>**Contact:** {contact}"
                    st.markdown(additional_text, unsafe_allow_html=True)
                st.divider()
            # Bottom pagination widget
            col_pag_bot = st.columns([6, 1])[1]
            new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)), index=st.session_state.page - 1, key="page_bot")
            st.session_state.page = new_page_bot
    # Semantic Search Results Branch
    else:
        if not filtered_semantic_no_dupe:
            st.write("No relevant results found.")
        else:
            page_size = 15
            total_results = len(filtered_semantic_no_dupe)
            total_pages = (total_results - 1) // page_size + 1
            if "page" not in st.session_state:
                st.session_state.page = 1
            current_page = st.session_state.page
            # Top pagination widget (right aligned, 1/7 width)
            col_pag_top = st.columns([6, 1])[1]
            new_page_top = col_pag_top.selectbox("Select Page", list(range(1, total_pages + 1)), index=current_page - 1, key="page_top_sem")
            st.session_state.page = new_page_top
            start_index = (st.session_state.page - 1) * page_size
            end_index = start_index + page_size
            top_results = filtered_semantic_no_dupe[start_index:end_index]
            # Prominent page info with bold numbers and green highlight if current page is not 1
            page_num = f"<b style='color: green;'>{st.session_state.page}</b>" if st.session_state.page != 1 else f"<b>{st.session_state.page}</b>"
            total_pages_str = f"<b>{total_pages}</b>"
            st.markdown(f"Showing **{len(top_results)}** Semantic Search results (Page {page_num} of {total_pages_str})", unsafe_allow_html=True)
            # --- RAG Answer (Right aligned, bullet points, bold numbers) ---
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            bullet_lines = []
            for line in rag_answer.splitlines():
                if line.strip():
                    # Bold any numbers in the line
                    line_bold = re.sub(r'(\d+)', r'<b>\1</b>', line)
                    bullet_lines.append(f"<li>{line_bold}</li>")
            formatted_rag_answer = "<ul style='text-align: right; list-style-position: inside;'>" + "".join(bullet_lines) + "</ul>"
            st.markdown(formatted_rag_answer, unsafe_allow_html=True)
            st.divider()
            for i, res in enumerate(top_results, start=start_index+1):
                metadata = res.payload.get('metadata', {})
                if "title" not in metadata:
                    metadata["title"] = compute_title(metadata)
                title_clean = re.sub(r'<a.*?>|</a>', '', metadata["title"])
                # Prepend result number and make title bold
                st.markdown(f"#### {i}. **{title_clean}**", unsafe_allow_html=True)
                desc_en = metadata.get("description.en", "").strip()
                desc_de = metadata.get("description.de", "").strip()
                description = desc_en if desc_en else desc_de
                if not description:
                    description = "No project description available"
                words = description.split()
                preview_word_count = 90
                preview_text = " ".join(words[:preview_word_count])
                remainder_text = " ".join(words[preview_word_count:])
                col_left, col_right = st.columns(2)
                with col_left:
                    st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
                    if remainder_text:
                        with st.expander("Show more"):
                            st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
                with col_right:
                    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
                    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
                    total_project = metadata.get('total_project', "Unknown")
                    total_volume = metadata.get('total_volume', "Unknown")
                    formatted_project_budget = format_currency(total_project)
                    formatted_total_volume = format_currency(total_volume)
                    country_raw = metadata.get('country', "Unknown")
                    crs_key = metadata.get("crs_key", "").strip()
                    crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
                    new_crs_value = lookup_crs_value(crs_key_clean)
                    new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
                    crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
                    predecessor = metadata.get("predecessor_id", "").strip()
                    successor = metadata.get("successor_id", "").strip()
                    extra_line = ""
                    if predecessor:
                        extra_line += f"<br>**Predecessor Project:** {predecessor}"
                    if successor:
                        extra_line += f"<br>**Successor Project:** {successor}"
                    additional_text = (
                        f"**Objective:** {metadata.get('objective', '')}<br>"
                        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
                        f"**Project duration:** {start_year_str}-{end_year_str}<br>"
                        f"**Budget:** Project: <b>{formatted_project_budget}</b>, Total volume: <b>{formatted_total_volume}</b><br>"
                        + extra_line +
                        f"<br>**Country:** {country_raw}<br>"
                        f"**Sector:** {crs_combined}"
                    )
                    contact = metadata.get("contact", "").strip()
                    if contact and contact.lower() != "[email protected]":
                        additional_text += f"<br>**Contact:** {contact}"
                    st.markdown(additional_text, unsafe_allow_html=True)
                st.divider()
            # Bottom pagination widget (right aligned, 1/7 width)
            col_pag_bot = st.columns([6, 1])[1]
            new_page_bot = col_pag_bot.selectbox("Select Page", list(range(1, total_pages + 1)), index=st.session_state.page - 1, key="page_bot_sem")
            st.session_state.page = new_page_bot