Update app.py

app.py
CHANGED
@@ -10,10 +10,11 @@ from torch import cuda
 import json
 from datetime import datetime
 
-# get the device to be used
+# get the device to be used, either GPU or CPU
 device = 'cuda' if cuda.is_available() else 'cpu'
 
-
+
+st.set_page_config(page_title="SEARCH IATI", layout='wide')
 st.title("GIZ Project Database (PROTOTYPE)")
 var = st.text_input("Enter Search Query")
 
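Note: placing st.set_page_config above st.title is required rather than cosmetic. Streamlit raises a StreamlitAPIException if set_page_config is not the first Streamlit command executed in the script, so the ordering this hunk introduces is the only one that works:

    import streamlit as st

    # must be the first st.* call in the script
    st.set_page_config(page_title="SEARCH IATI", layout="wide")
    st.title("GIZ Project Database (PROTOTYPE)")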
@@ -22,11 +23,14 @@ region_lookup_path = "docStore/regions_lookup.csv"
 region_df = load_region_data(region_lookup_path)
 
 #################### Create the embeddings collection and save ######################
-#
-
-
+# the steps below need to be performed only once, then commented out to avoid unnecessary compute overruns
+##### First we process and create the chunks for the relevant data source
+chunks = process_giz_worldwide()
+##### Convert to langchain documents
+temp_doc = create_documents(chunks, 'chunks')
+##### Embed and store the docs; if the collection already exists it needs to be updated
 collection_name = "giz_worldwide"
-
+hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
 
 ################### Hybrid Search ######################################################
 client = get_client()
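Note: since these ingestion steps are meant to run exactly once, a guard on collection existence is safer than commenting code in and out by hand. A minimal sketch, assuming the object returned by get_client() exposes qdrant-client's collection_exists() (available in recent qdrant-client releases; the app's wrapper may differ), reusing the helpers from the hunk above:

    client = get_client()
    if not client.collection_exists(collection_name):
        # First run only: chunk, wrap as documents, then embed and store.
        chunks = process_giz_worldwide()
        temp_doc = create_documents(chunks, 'chunks')
        hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name,
                            del_if_exists=False)  # nothing to delete on a first run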
@@ -43,7 +47,6 @@ _, unique_sub_regions = get_regions(region_df)
 def get_country_name_and_region_mapping(_client, collection_name, region_df):
     results = hybrid_search(_client, "", collection_name)
     country_set = set()
-
     for res in results[0] + results[1]:
         countries = res.payload.get('metadata', {}).get('countries', "[]")
         try:
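Note: get_country_name_and_region_mapping enumerates the collection by running hybrid_search with an empty query, which only sees whatever that search happens to return. If the underlying store is Qdrant, the scroll API walks the whole collection deterministically; a sketch under that assumption (helper name hypothetical):

    def iter_all_payloads(client, collection_name, batch=256):
        # Page through every point in the collection via Qdrant's scroll API.
        offset = None
        while True:
            points, offset = client.scroll(collection_name=collection_name,
                                           limit=batch, offset=offset,
                                           with_payload=True)
            for p in points:
                yield p.payload
            if offset is None:
                break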
@@ -77,11 +80,11 @@ col1, col2, col3, col4 = st.columns([1, 1, 1, 4])
 
 # Region filter
 with col1:
-    region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
+    region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))  # Display region names
 
 # Dynamically filter countries based on selected region
 if region_filter == "All/Not allocated":
-    filtered_country_names = unique_country_names
+    filtered_country_names = unique_country_names  # Show all countries if no region is selected
 else:
     filtered_country_names = [
         name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter
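Note: this dependent Region -> Country filter works because Streamlit reruns the whole script on every widget interaction, so filtered_country_names is recomputed from the current region before the Country selectbox below is rendered. The same pattern in miniature:

    import streamlit as st

    options = {"Africa": ["Kenya", "Ghana"], "Asia": ["Nepal", "Lao PDR"]}
    region = st.selectbox("Region", list(options))      # script reruns on change
    country = st.selectbox("Country", options[region])  # sees the updated region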
@@ -89,23 +92,25 @@ else:
 
 # Country filter
 with col2:
-    country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
-
-#
-
-
-
-
-#
-
-
-
-
+    country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)  # Display filtered country names
+
+# Year range slider
+with col3:
+    current_year = datetime.now().year
+    default_start_year = current_year - 5
+
+    # 3) The max_value is now the actual max end_year from the collection
+    end_year_range = st.slider(
+        "Project End Year",
+        min_value=2010,
+        max_value=max_end_year,
+        value=(default_start_year, max_end_year),
+    )
 
 # Checkbox to control whether to show only exact matches
 show_exact_matches = st.checkbox("Show only exact matches", value=False)
 
-def filter_results(results, country_filter, region_filter):
+def filter_results(results, country_filter, region_filter, end_year_range):
     filtered = []
     for r in results:
         metadata = r.payload.get('metadata', {})
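Note: max_end_year is used by the slider but is not defined in any hunk shown here, so it presumably comes from code outside this diff. If it still needed computing, one option is to derive it from the same payload metadata the filters read; a sketch (helper name hypothetical):

    def compute_max_end_year(results, fallback=None):
        # Scan end_year across both result lists; fall back to the current year.
        years = []
        for r in results[0] + results[1]:
            raw = r.payload.get('metadata', {}).get('end_year')
            try:
                years.append(int(float(raw)))
            except (TypeError, ValueError):
                pass
        return max(years) if years else (fallback or datetime.now().year)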
@@ -136,97 +141,133 @@ def filter_results(results, country_filter, region_filter):
         else:
             countries_in_region = c_list
 
+        # Filtering
         if (
             (country_filter == "All/Not allocated" or selected_iso_code in c_list)
             and (region_filter == "All/Not allocated" or countries_in_region)
+            and (end_year_range[0] <= end_year_val <= end_year_range[1])
         ):
             filtered.append(r)
     return filtered
 
 # Run the search
-
+
+# 1) Adjust limit so we get more than 15 results
+results = hybrid_search(client, var, collection_name, limit=500)  # e.g., 100 or 200
+
+# results is a tuple: (semantic_results, lexical_results)
 semantic_all = results[0]
 lexical_all = results[1]
 
-# Filter out
-semantic_all = [
-
+# 2) Filter out content < 20 chars (an intermediate fix: very short paragraphs can score deceptively high on similarity)
+semantic_all = [
+    r for r in semantic_all if len(r.payload["page_content"]) >= 20
+]
+lexical_all = [
+    r for r in lexical_all if len(r.payload["page_content"]) >= 20
+]
 
-# Apply a threshold to SEMANTIC results (score >= 0.
-semantic_thresholded = [r for r in semantic_all if r.score >= 0.
+# 2) Apply a threshold to SEMANTIC results (score >= 0.4)
+semantic_thresholded = [r for r in semantic_all if r.score >= 0.4]
 
-
-
+# 2) Filter the entire sets
+filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range)
+filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range)
 
 filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
 filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
 
-# Define a helper function to format currency values
-def format_currency(value):
-    try:
-        # Convert to float then int for formatting (assumes whole numbers)
-        return f"€{int(float(value)):,}"
-    except (ValueError, TypeError):
-        return value
 
-#
+
+# 3) Retrieve the top 15 *after* filtering
+# Check user preference
 if show_exact_matches:
+    # 1) Display heading
     st.write(f"Showing **Top 15 Lexical Search results** for query: {var}")
 
+    # 2) Do a simple substring check (case-insensitive)
+    # We'll create a new list, lexical_substring_filtered
     query_substring = var.strip().lower()
     lexical_substring_filtered = []
     for r in lexical_all:
-
+        # page_content in lowercase
+        page_text_lower = r.payload["page_content"].lower()
+        # Keep this result only if the query substring is found
+        if query_substring in page_text_lower:
             lexical_substring_filtered.append(r)
 
-
+    # 3) Now apply the region/country/year filter on that new list
+    filtered_lexical = filter_results(
+        lexical_substring_filtered, country_filter, region_filter, end_year_range
+    )
+
+    # 4) Remove duplicates
     filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
 
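Note: the exact-match pass is a plain substring test, so a query like "aid" also matches "maids". If whole-word matching is ever wanted, a word-boundary regex is a drop-in alternative (a sketch, not part of this commit):

    import re

    pattern = re.compile(rf"\b{re.escape(query_substring)}\b", re.IGNORECASE)
    lexical_substring_filtered = [
        r for r in lexical_all if pattern.search(r.payload["page_content"])
    ]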
+    # 5) If empty after substring + filters + dedupe, show a custom message
     if not filtered_lexical_no_dupe:
         st.write('No exact matches, consider unchecking "Show only exact matches"')
     else:
+        # 6) Display the first 15 matching results
         for res in filtered_lexical_no_dupe[:15]:
-
-
-
-
-
-
-            # Build snippet from objectives and descriptions.
-            objectives = metadata.get("objectives", "")
-            desc_de = metadata.get("description.de", "")
-            desc_en = metadata.get("description.en", "")
-            description = desc_de if desc_de else desc_en
-            full_snippet = f"Objective: {objectives} Description: {description}"
-            preview_limit = 400  # preview limit in characters
-            preview_snippet = full_snippet if len(full_snippet) <= preview_limit else full_snippet[:preview_limit] + "..."
-            # Using HTML to add a tooltip with the full snippet text.
-            st.markdown(f'<span title="{full_snippet}">{preview_snippet}</span>', unsafe_allow_html=True)
-
-            # Keywords remain the same.
+            project_name = res.payload['metadata'].get('project_name', 'Project Link')
+            url = res.payload['metadata'].get('url', '#')
+            st.markdown(f"#### [{project_name}]({url})")
+
+            # Snippet logic (80 words)
             full_text = res.payload['page_content']
+            words = full_text.split()
+            preview_word_count = 80
+            preview_text = " ".join(words[:preview_word_count])
+            remainder_text = " ".join(words[preview_word_count:])
+            st.write(preview_text + ("..." if remainder_text else ""))
+
+            # Keywords
             top_keywords = extract_top_keywords(full_text, top_n=5)
             if top_keywords:
                 st.markdown(f"_{' · '.join(top_keywords)}_")
 
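Note: the removed lexical snippet injected full_snippet into a title attribute with unsafe_allow_html=True, which breaks (and is an injection risk) as soon as the text contains quotes or angle brackets. The word-count preview above avoids HTML entirely; if a tooltip is ever reinstated, escaping first is the safe route (sketch):

    import html

    safe_snippet = html.escape(full_snippet, quote=True)
    st.markdown(f'<span title="{safe_snippet}">{preview_snippet}</span>',
                unsafe_allow_html=True)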
-            # Metadata
+            # Metadata
+            metadata = res.payload.get('metadata', {})
+            countries = metadata.get('countries', "[]")
             client_name = metadata.get('client', 'Unknown Client')
             start_year = metadata.get('start_year', None)
             end_year = metadata.get('end_year', None)
-
-
-
+
+            try:
+                c_list = json.loads(countries.replace("'", '"'))
+            except json.JSONDecodeError:
+                c_list = []
+
+            # Only keep country names if the region lookup (get_country_name)
+            # returns something different from the raw code.
+            matched_countries = []
+            for code in c_list:
+                if len(code) == 2:
+                    resolved_name = get_country_name(code.upper(), region_df)
+                    # If get_country_name didn't find a match, it typically
+                    # just returns the same code (like "XX"). We treat the
+                    # lookup as successful only if resolved_name != code.upper().
+                    if resolved_name.upper() != code.upper():
+                        matched_countries.append(resolved_name)
+
+            # Format the year range
             start_year_str = f"{int(round(float(start_year)))}" if start_year else "Unknown"
             end_year_str = f"{int(round(float(end_year)))}" if end_year else "Unknown"
-
-
-
-
-
-
-
-
-
+
+            # Build the final string
+            if matched_countries:
+                # We have at least one valid country name
+                additional_text = (
+                    f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**, "
+                    f"**{start_year_str}-{end_year_str}**"
+                )
+            else:
+                # No valid countries found
+                additional_text = (
+                    f"Commissioned by **{client_name}**, **{start_year_str}-{end_year_str}**"
+                )
+
             st.markdown(additional_text)
             st.divider()
 
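Note: json.loads(countries.replace("'", '"')) fails on any value containing an apostrophe (e.g. "Côte d'Ivoire" inside the stringified list). Since the payload stores a Python-style list literal, ast.literal_eval parses it without the quote rewrite (a sketch, not part of this commit):

    import ast

    try:
        c_list = ast.literal_eval(countries) if countries else []
    except (ValueError, SyntaxError):
        c_list = []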
@@ -236,51 +277,73 @@ else:
     if not filtered_semantic_no_dupe:
         st.write("No relevant results found.")
     else:
+        # Show the top 15 from filtered_semantic
         for res in filtered_semantic_no_dupe[:15]:
-
-
-
-            st.markdown(f"#### {project_name} [{proj_id}]")
-
-            # Build snippet from objectives and descriptions.
-            objectives = metadata.get("objectives", "")
-            desc_de = metadata.get("description.de", "")
-            desc_en = metadata.get("description.en", "")
-            description = desc_de if desc_de else desc_en
-            full_snippet = f"Objective: {objectives} Description: {description}"
-            preview_limit = 400
-            preview_snippet = full_snippet if len(full_snippet) <= preview_limit else full_snippet[:preview_limit] + "..."
-            st.markdown(f'<span title="{full_snippet}">{preview_snippet}</span>', unsafe_allow_html=True)
+            project_name = res.payload['metadata'].get('project_name', 'Project Link')
+            url = res.payload['metadata'].get('url', '#')
+            st.markdown(f"#### [{project_name}]({url})")
 
-            #
+            # Snippet logic
             full_text = res.payload['page_content']
+            words = full_text.split()
+            preview_word_count = 10
+            preview_text = " ".join(words[:preview_word_count])
+            remainder_text = " ".join(words[preview_word_count:])
+            st.write(preview_text + ("..." if remainder_text else ""))
+
+            # Keywords
             top_keywords = extract_top_keywords(full_text, top_n=5)
             if top_keywords:
                 st.markdown(f"_{' · '.join(top_keywords)}_")
 
+            # Metadata
+            metadata = res.payload.get('metadata', {})
+            countries = metadata.get('countries', "[]")
             client_name = metadata.get('client', 'Unknown Client')
             start_year = metadata.get('start_year', None)
             end_year = metadata.get('end_year', None)
-
-
-
+
+            try:
+                c_list = json.loads(countries.replace("'", '"'))
+            except json.JSONDecodeError:
+                c_list = []
+
+            # Only keep country names if the region lookup (get_country_name)
+            # returns something different from the raw code.
+            matched_countries = []
+            for code in c_list:
+                if len(code) == 2:
+                    resolved_name = get_country_name(code.upper(), region_df)
+                    # If get_country_name didn't find a match, it typically
+                    # just returns the same code (like "XX"). We treat the
+                    # lookup as successful only if resolved_name != code.upper().
+                    if resolved_name.upper() != code.upper():
+                        matched_countries.append(resolved_name)
+
+            # Format the year range
             start_year_str = extract_year(start_year) if start_year else "Unknown"
             end_year_str = extract_year(end_year) if end_year else "Unknown"
-
-
-
-
-
-
-
-
-
+
+            # Build the final string
+            if matched_countries:
+                # We have at least one valid country name
+                additional_text = (
+                    f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**, "
+                    f"**{start_year_str}-{end_year_str}**"
+                )
+            else:
+                # No valid countries found
+                additional_text = (
+                    f"Commissioned by **{client_name}**, **{start_year_str}-{end_year_str}**"
+                )
+
             st.markdown(additional_text)
             st.divider()
 
-
-#
-#
-#
-#
-#
+
+# for i in results:
+#     st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
+#     st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
+#     st.write(i.page_content)
+#     st.divider()
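Note: the lexical branch formats years with int(round(float(...))), which raises ValueError on date-like strings, while this semantic branch uses the app's extract_year helper; the two can disagree if such values occur. Routing both branches through one helper would keep them consistent (name and behavior hypothetical):

    def format_year(value):
        # Render a year-ish value ("2023", 2023.0, "2023-06-30") as "2023".
        try:
            return str(int(float(str(value)[:4])))
        except (TypeError, ValueError):
            return "Unknown"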