# Header captured from the hosting page's file view (appears to be Hugging Face):
# author annikwag, commit 296a14f "Update app.py", file size 15.4 kB.
import configparser
import html
import json
import re
from datetime import datetime

import pandas as pd
import requests
import streamlit as st
from torch import cuda
# Import existing modules from appStore
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
load_region_data,
clean_country_code,
get_country_name,
get_regions,
get_country_name_and_region_mapping
)
# TF-IDF part (excluded from the app for now)
# from appStore.tfidf_extraction import extract_top_keywords
# Import helper modules
from appStore.rag_utils import (
highlight_query,
get_rag_answer,
compute_title
)
from appStore.filter_utils import (
parse_budget,
filter_results,
get_crs_options
)
from appStore.crs_utils import lookup_crs_value
###########################################
# Model Config
###########################################
# Model settings live in the [MODEL] section of this file (path is resolved
# relative to the current working directory).
MODEL_CONFIG_PATH = 'model_params.cfg'

config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did read; without this check a missing file would only surface as
# a confusing NoSectionError on the first config.get() below.
if not config.read(MODEL_CONFIG_PATH):
    raise FileNotFoundError(
        f"Model configuration file '{MODEL_CONFIG_PATH}' could not be read "
        "(path is relative to the current working directory)."
    )

# Retrieve model parameters
DEDICATED_MODEL = config.get('MODEL', 'DEDICATED_MODEL')
DEDICATED_ENDPOINT = config.get('MODEL', 'DEDICATED_ENDPOINT')

# Access token read from Streamlit secrets; it is passed alongside
# DEDICATED_ENDPOINT to get_rag_answer() further down.
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]

# Configure the page before any element is rendered (Streamlit requires
# set_page_config to come first — strictly the first command in older versions).
st.set_page_config(page_title="SEARCH IATI", layout='wide')
###########################################
# Cache the project data
###########################################
@st.cache_data
def load_project_data():
    """Return the processed GIZ worldwide project table as a pandas DataFrame.

    Thin wrapper around ``process_giz_worldwide()``; the ``st.cache_data``
    decorator memoizes the result so Streamlit reruns reuse it instead of
    reprocessing the data.
    """
    processed = process_giz_worldwide()
    return processed
project_data = load_project_data()

# Bounds for the budget slider, expressed in million euros. pd.to_numeric
# coerces non-numeric budget entries to NaN, which are dropped before taking
# the extremes.
_EUROS_PER_MILLION = 1e6
budget_series = (
    pd.to_numeric(project_data['total_project'], errors='coerce')
    .dropna()
)
min_budget_val, max_budget_val = (
    float(extreme / _EUROS_PER_MILLION)
    for extreme in (budget_series.min(), budget_series.max())
)
###########################################
# Prepare region data
###########################################
# Region lookup table shipped in the docStore folder (path relative to the
# working directory).
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)

###########################################
# Get device
###########################################
# Prefer the GPU when torch reports CUDA is available; otherwise use the CPU.
# NOTE(review): `device` is not referenced anywhere else in the visible code.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
###########################################
# Streamlit App Layout
###########################################
# Text of the "About" expander (rendered as Markdown).
_ABOUT_TEXT = """
This app is a prototype for testing purposes using publicly available project data
from the German International Cooperation Society (GIZ) as of 23rd February 2025.
**Please do NOT enter sensitive or personal information.**
**Note**: The answers are AI-generated and may be wrong or misleading.
"""

# Header row: a wide title column (8 parts) next to a narrow "About" column (2 parts).
col_title, col_about = st.columns([8, 2])
col_title.markdown(
    "<h1 style='text-align:center;'>GIZ Project Search (PROTOTYPE)</h1>",
    unsafe_allow_html=True,
)
with col_about.expander("ℹ️ About"):
    st.markdown(_ABOUT_TEXT, unsafe_allow_html=True)

# Free-text question that drives the search further down
var = st.text_input("Enter Question")
###########################################
# Create or load the embeddings collection
###########################################
collection_name = "giz_worldwide"
client = get_client()
# (A debug `print(client.get_collections())` used to live here; it issued an
# extra request to the vector store on every Streamlit rerun and only wrote to
# server stdout, so it was removed.)

# One-off (re)indexing of the collection — run manually when the source data
# changes. `del_if_exists=True` presumably drops an existing collection first.
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)

# Upper bound for the "Project End Year" slider.
max_end_year = get_max_end_year(client, collection_name)

# Only the sub-region list is needed; the first value returned is ignored.
_, unique_sub_regions = get_regions(region_df)

# country_name_mapping: country name -> ISO code
# iso_code_to_sub_region: ISO code -> sub-region name
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name,
)
unique_country_names = sorted(country_name_mapping.keys())
###########################################
# Filter Controls
###########################################
# Option prepended to every dropdown; the selected string is passed verbatim to
# filter_results(), which presumably treats it as "no filter".
ALL_OPTION = "All/Not allocated"
# Earliest selectable project end year.
MIN_END_YEAR = 2010

col1, col2, col3, col4, col5 = st.columns(5)

with col1:
    region_filter = st.selectbox("Region", [ALL_OPTION] + sorted(unique_sub_regions))
    if region_filter == ALL_OPTION:
        filtered_country_names = unique_country_names
    else:
        # Sorted so the Country dropdown is ordered like the unfiltered list.
        filtered_country_names = sorted(
            name
            for name, code in country_name_mapping.items()
            if iso_code_to_sub_region.get(code) == region_filter
        )

with col2:
    country_filter = st.selectbox("Country", [ALL_OPTION] + filtered_country_names)

with col3:
    current_year = datetime.now().year
    # Default to roughly the last four years, clamped into the slider's range so
    # the default value can never fall outside [MIN_END_YEAR, max_end_year].
    default_start_year = min(max(current_year - 4, MIN_END_YEAR), max_end_year)
    end_year_range = st.slider(
        "Project End Year",
        min_value=MIN_END_YEAR,
        max_value=max_end_year,
        value=(default_start_year, max_end_year),
    )

with col4:
    crs_options = [ALL_OPTION] + get_crs_options(client, collection_name)
    crs_filter = st.selectbox("CRS", crs_options)

with col5:
    min_budget = st.slider(
        "Minimum Project Budget (Million €)",
        min_value=min_budget_val,
        max_value=max_budget_val,
        value=min_budget_val,
    )

# NOTE(review): the indentation of this file was lost, so whether this checkbox
# originally sat inside `with col5:` cannot be recovered; it is placed below the
# filter columns here — confirm against the repository.
show_exact_matches = st.checkbox("Show only exact matches", value=False)
###########################################
# Main Search / Results
###########################################
# Number of hits displayed per search mode; the same hits feed the RAG answer.
TOP_K_RESULTS = 10
# Hits whose page_content is shorter than this many characters are dropped.
MIN_PAGE_CONTENT_CHARS = 5
# Minimum semantic score a hit needs to be kept (0.0 keeps every non-negative score).
SEMANTIC_SCORE_THRESHOLD = 0.0
# Number of description words shown before the "Show more" expander.
PREVIEW_WORD_COUNT = 90

_TRAILING_DOT_ZERO = re.compile(r'\.0$')


def format_currency(value):
    """
    Format a numeric or numeric-string value as whole euros with thousands separators.

    The fractional part is truncated. Values that cannot be converted
    (e.g. "Unknown", None, NaN, infinity) are returned unchanged.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        return value


def _as_text(value):
    """Return ``value`` as a stripped string, mapping None to "" (tolerates non-str metadata)."""
    return "" if value is None else str(value).strip()


def _has_content(hit):
    """True when the hit's page_content has at least MIN_PAGE_CONTENT_CHARS characters."""
    return len(hit.payload["page_content"]) >= MIN_PAGE_CONTENT_CHARS


def _apply_filters(hits):
    """Apply the user's region/country/end-year/CRS/budget selections to ``hits``."""
    return filter_results(
        hits,
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name,
    )


def _sector_label(metadata):
    """Return "<crs_key>: <crs_value>" for a hit, or "Unknown" when it has no CRS key."""
    # Strip a trailing ".0" (e.g. a float-formatted code such as "15110.0").
    crs_key = _TRAILING_DOT_ZERO.sub('', _as_text(metadata.get("crs_key")))
    if not crs_key:
        return "Unknown"
    crs_value = _TRAILING_DOT_ZERO.sub('', str(lookup_crs_value(crs_key)))
    return f"{crs_key}: {crs_value}"


def _project_details_markdown(metadata, objective):
    """Build the details column: objective, client, duration, budget, country, sector, contact."""
    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
    formatted_project_budget = format_currency(metadata.get('total_project', "Unknown"))
    formatted_total_volume = format_currency(metadata.get('total_volume', "Unknown"))
    details = (
        f"**Objective:** {objective}<br>"
        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
        f"**Project duration:** {start_year_str}-{end_year_str}<br>"
        f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}<br>"
        f"**Country:** {metadata.get('country', 'Unknown')}<br>"
        f"**Sector:** {_sector_label(metadata)}"
    )
    contact = _as_text(metadata.get("contact"))
    # NOTE(review): both "[email protected]" literals look like Cloudflare
    # email-obfuscation residue from the web page this file was copied from (the
    # display string was an f-string without any placeholder); restore the real
    # comparison/display strings from the repository.
    if contact and contact.lower() != "[email protected]":
        details += "<br>**Contact:** [email protected]"
    return details


def _render_result(res, query, *, exact_mode):
    """
    Render one search hit: title, description preview (with "Show more"), and details.

    In exact-match mode the title and objective go through highlight_query and a
    missing objective is shown as "None"; in semantic mode both are shown as-is and
    a missing objective is shown as empty.
    """
    metadata = res.payload.get('metadata', {})
    if "title" not in metadata:
        metadata["title"] = compute_title(metadata)

    if exact_mode:
        st.markdown(f"#### {highlight_query(metadata['title'], query)}", unsafe_allow_html=True)
        objective = highlight_query(metadata.get("objective", "None"), query)
    else:
        st.markdown(f"#### {metadata['title']}")
        objective = metadata.get('objective', '')

    # English description first, German as fallback.
    description = (
        _as_text(metadata.get("description.en"))
        or _as_text(metadata.get("description.de"))
        or "No project description available"
    )
    words = description.split()
    preview_text = " ".join(words[:PREVIEW_WORD_COUNT])
    remainder_text = " ".join(words[PREVIEW_WORD_COUNT:])

    # NOTE(review): highlight_query output is rendered as raw HTML; confirm it
    # escapes the text it wraps.
    col_left, col_right = st.columns(2)
    with col_left:
        st.markdown(highlight_query(preview_text, query), unsafe_allow_html=True)
        if remainder_text:
            with st.expander("Show more"):
                st.markdown(highlight_query(remainder_text, query), unsafe_allow_html=True)
    with col_right:
        st.markdown(_project_details_markdown(metadata, objective), unsafe_allow_html=True)
    st.divider()


if not var.strip():
    st.info("Please enter a question to see results.")
else:
    # 1) Hybrid search: entry 0 holds semantic hits, entry 1 lexical hits.
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all = [r for r in results[0] if _has_content(r)]
    lexical_all = [r for r in results[1] if _has_content(r)]

    # The query is user input and is embedded in raw-HTML headings below.
    safe_query = html.escape(var)

    if show_exact_matches:
        # 2a) Lexical hits that literally contain the query (case-insensitive).
        st.write(f"Showing **Top {TOP_K_RESULTS} Lexical Search results**")
        query_substring = var.strip().lower()
        exact_hits = [
            r for r in lexical_all
            if query_substring in r.payload["page_content"].lower()
        ]
        top_results = remove_duplicates(_apply_filters(exact_hits))[:TOP_K_RESULTS]
        if not top_results:
            st.write('No exact matches, consider unchecking "Show only exact matches"')
        else:
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            st.markdown(
                f"<h3 style='text-align:center; font-size:1.2em; font-style: italic;'>{safe_query}</h3>",
                unsafe_allow_html=True,
            )
            st.write(rag_answer)
            st.divider()
            for res in top_results:
                _render_result(res, var, exact_mode=True)
    else:
        # 2b) Semantic hits at or above the score threshold.
        semantic_hits = [r for r in semantic_all if r.score >= SEMANTIC_SCORE_THRESHOLD]
        top_results = remove_duplicates(_apply_filters(semantic_hits))[:TOP_K_RESULTS]
        if not top_results:
            st.write("No relevant results found.")
        else:
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            st.markdown(
                f"<h2 style='text-align:center; font-size:2.5em;'>{safe_query}</h2>",
                unsafe_allow_html=True,
            )
            st.write(rag_answer)
            st.divider()
            st.write(f"Showing **Top {TOP_K_RESULTS} Semantic Search results**")
            for res in top_results:
                _render_result(res, var, exact_mode=False)