annikwag committed on
Commit
47177b9
verified
1 Parent(s): 755183b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -104
app.py CHANGED
@@ -10,10 +10,11 @@ from torch import cuda
10
  import json
11
  from datetime import datetime
12
 
13
- # get the device to be used either gpu or cpu
14
  device = 'cuda' if cuda.is_available() else 'cpu'
15
 
16
- st.set_page_config(page_title="SEARCH IATI", layout='wide')
 
17
  st.title("GIZ Project Database (PROTOTYPE)")
18
  var = st.text_input("Enter Search Query")
19
 
@@ -22,11 +23,14 @@ region_lookup_path = "docStore/regions_lookup.csv"
22
  region_df = load_region_data(region_lookup_path)
23
 
24
  #################### Create the embeddings collection and save ######################
25
- # Uncomment these lines to process and embed your data only once.
26
- # chunks = process_giz_worldwide()
27
- # temp_doc = create_documents(chunks, 'chunks')
 
 
 
28
  collection_name = "giz_worldwide"
29
- # hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
30
 
31
  ################### Hybrid Search ######################################################
32
  client = get_client()
@@ -43,7 +47,6 @@ _, unique_sub_regions = get_regions(region_df)
43
  def get_country_name_and_region_mapping(_client, collection_name, region_df):
44
  results = hybrid_search(_client, "", collection_name)
45
  country_set = set()
46
-
47
  for res in results[0] + results[1]:
48
  countries = res.payload.get('metadata', {}).get('countries', "[]")
49
  try:
@@ -77,11 +80,11 @@ col1, col2, col3, col4 = st.columns([1, 1, 1, 4])
77
 
78
  # Region filter
79
  with col1:
80
- region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
81
 
82
  # Dynamically filter countries based on selected region
83
  if region_filter == "All/Not allocated":
84
- filtered_country_names = unique_country_names
85
  else:
86
  filtered_country_names = [
87
  name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter
@@ -89,23 +92,25 @@ else:
89
 
90
  # Country filter
91
  with col2:
92
- country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
93
-
94
- # ToDo: Add year filter later if needed (currently commented out)
95
- # with col3:
96
- # current_year = datetime.now().year
97
- # default_start_year = current_year - 5
98
- # end_year_range = st.slider(
99
- # "Project End Year",
100
- # min_value=2010,
101
- # max_value=max_end_year,
102
- # value=(default_start_year, max_end_year),
103
- # )
 
 
104
 
105
  # Checkbox to control whether to show only exact matches
106
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
107
 
108
- def filter_results(results, country_filter, region_filter):
109
  filtered = []
110
  for r in results:
111
  metadata = r.payload.get('metadata', {})
@@ -136,97 +141,133 @@ def filter_results(results, country_filter, region_filter):
136
  else:
137
  countries_in_region = c_list
138
 
 
139
  if (
140
  (country_filter == "All/Not allocated" or selected_iso_code in c_list)
141
  and (region_filter == "All/Not allocated" or countries_in_region)
 
142
  ):
143
  filtered.append(r)
144
  return filtered
145
 
146
  # Run the search
147
- results = hybrid_search(client, var, collection_name, limit=500)
 
 
 
 
148
  semantic_all = results[0]
149
  lexical_all = results[1]
150
 
151
- # Filter out very short content
152
- semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
153
- lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
 
 
 
 
154
 
155
- # Apply a threshold to SEMANTIC results (score >= 0.0)
156
- semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
157
 
158
- filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter)
159
- filtered_lexical = filter_results(lexical_all, country_filter, region_filter)
 
160
 
161
  filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
162
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
163
 
164
- # Define a helper function to format currency values
165
- def format_currency(value):
166
- try:
167
- # Convert to float then int for formatting (assumes whole numbers)
168
- return f"€{int(float(value)):,}"
169
- except (ValueError, TypeError):
170
- return value
171
 
172
- # Display Results
 
173
  if show_exact_matches:
 
174
  st.write(f"Showing **Top 15 Lexical Search results** for query: {var}")
175
 
 
 
176
  query_substring = var.strip().lower()
177
  lexical_substring_filtered = []
178
  for r in lexical_all:
179
- if query_substring in r.payload["page_content"].lower():
 
 
 
180
  lexical_substring_filtered.append(r)
181
 
182
- filtered_lexical = filter_results(lexical_substring_filtered, country_filter, region_filter)
 
 
 
 
 
183
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
184
 
 
185
  if not filtered_lexical_no_dupe:
186
  st.write('No exact matches, consider unchecking "Show only exact matches"')
187
  else:
 
188
  for res in filtered_lexical_no_dupe[:15]:
189
- metadata = res.payload.get('metadata', {})
190
- # Get title and id; do not format as a link.
191
- project_name = metadata.get('project_name', 'Project Link')
192
- proj_id = metadata.get('id', 'Unknown')
193
- st.markdown(f"#### {project_name} [{proj_id}]")
194
-
195
- # Build snippet from objectives and descriptions.
196
- objectives = metadata.get("objectives", "")
197
- desc_de = metadata.get("description.de", "")
198
- desc_en = metadata.get("description.en", "")
199
- description = desc_de if desc_de else desc_en
200
- full_snippet = f"Objective: {objectives} Description: {description}"
201
- preview_limit = 400 # preview limit in characters
202
- preview_snippet = full_snippet if len(full_snippet) <= preview_limit else full_snippet[:preview_limit] + "..."
203
- # Using HTML to add a tooltip with the full snippet text.
204
- st.markdown(f'<span title="{full_snippet}">{preview_snippet}</span>', unsafe_allow_html=True)
205
-
206
- # Keywords remain the same.
207
  full_text = res.payload['page_content']
 
 
 
 
 
 
 
208
  top_keywords = extract_top_keywords(full_text, top_n=5)
209
  if top_keywords:
210
  st.markdown(f"_{' · '.join(top_keywords)}_")
211
 
212
- # Metadata: get client, duration and budget details.
 
 
213
  client_name = metadata.get('client', 'Unknown Client')
214
  start_year = metadata.get('start_year', None)
215
  end_year = metadata.get('end_year', None)
216
- total_volume = metadata.get('total_volume', "Unknown")
217
- total_project = metadata.get('total_project', "Unknown")
218
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  start_year_str = f"{int(round(float(start_year)))}" if start_year else "Unknown"
220
  end_year_str = f"{int(round(float(end_year)))}" if end_year else "Unknown"
221
-
222
- formatted_project_budget = format_currency(total_project)
223
- formatted_total_volume = format_currency(total_volume)
224
-
225
- additional_text = (
226
- f"Commissioned by **{client_name}**\n"
227
- f"Projekt duration **{start_year_str}-{end_year_str}**\n"
228
- f"Budget: Project: **{formatted_project_budget}**, total volume: **{formatted_total_volume}**"
229
- )
 
 
 
 
 
230
  st.markdown(additional_text)
231
  st.divider()
232
 
@@ -236,51 +277,73 @@ else:
236
  if not filtered_semantic_no_dupe:
237
  st.write("No relevant results found.")
238
  else:
 
239
  for res in filtered_semantic_no_dupe[:15]:
240
- metadata = res.payload.get('metadata', {})
241
- project_name = metadata.get('project_name', 'Project Link')
242
- proj_id = metadata.get('id', 'Unknown')
243
- st.markdown(f"#### {project_name} [{proj_id}]")
244
-
245
- # Build snippet from objectives and descriptions.
246
- objectives = metadata.get("objectives", "")
247
- desc_de = metadata.get("description.de", "")
248
- desc_en = metadata.get("description.en", "")
249
- description = desc_de if desc_de else desc_en
250
- full_snippet = f"Objective: {objectives} Description: {description}"
251
- preview_limit = 400
252
- preview_snippet = full_snippet if len(full_snippet) <= preview_limit else full_snippet[:preview_limit] + "..."
253
- st.markdown(f'<span title="{full_snippet}">{preview_snippet}</span>', unsafe_allow_html=True)
254
 
255
- # Keywords
256
  full_text = res.payload['page_content']
 
 
 
 
 
 
 
257
  top_keywords = extract_top_keywords(full_text, top_n=5)
258
  if top_keywords:
259
  st.markdown(f"_{' · '.join(top_keywords)}_")
260
 
 
 
 
261
  client_name = metadata.get('client', 'Unknown Client')
262
  start_year = metadata.get('start_year', None)
263
  end_year = metadata.get('end_year', None)
264
- total_volume = metadata.get('total_volume', "Unknown")
265
- total_project = metadata.get('total_project', "Unknown")
266
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  start_year_str = extract_year(start_year) if start_year else "Unknown"
268
  end_year_str = extract_year(end_year) if end_year else "Unknown"
269
-
270
- formatted_project_budget = format_currency(total_project)
271
- formatted_total_volume = format_currency(total_volume)
272
-
273
- additional_text = (
274
- f"Commissioned by **{client_name}**\n"
275
- f"Projekt duration **{start_year_str}-{end_year_str}**\n"
276
- f"Budget: Project: **{formatted_project_budget}**, total volume: **{formatted_total_volume}**"
277
- )
 
 
 
 
 
278
  st.markdown(additional_text)
279
  st.divider()
280
 
281
- # Uncomment the following lines if you need to debug by listing raw results.
282
- # for i in results:
283
- # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
284
- # st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
285
- # st.write(i.page_content)
286
- # st.divider()
 
10
  import json
11
  from datetime import datetime
12
 
13
+ # get the device to be used either gpu or cpu
14
  device = 'cuda' if cuda.is_available() else 'cpu'
15
 
16
+
17
+ st.set_page_config(page_title="SEARCH IATI",layout='wide')
18
  st.title("GIZ Project Database (PROTOTYPE)")
19
  var = st.text_input("Enter Search Query")
20
 
 
23
  region_df = load_region_data(region_lookup_path)
24
 
25
  #################### Create the embeddings collection and save ######################
26
+ # The steps below need to be performed only once and should then be commented out to avoid unnecessary compute overrun
27
+ ##### First we process and create the chunks for the relevant data source
28
+ chunks = process_giz_worldwide()
29
+ ##### Convert to langchain documents
30
+ temp_doc = create_documents(chunks,'chunks')
31
+ ##### Embed and store docs; if the collection already exists, it needs to be updated
32
  collection_name = "giz_worldwide"
33
+ hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
34
 
35
  ################### Hybrid Search ######################################################
36
  client = get_client()
 
47
  def get_country_name_and_region_mapping(_client, collection_name, region_df):
48
  results = hybrid_search(_client, "", collection_name)
49
  country_set = set()
 
50
  for res in results[0] + results[1]:
51
  countries = res.payload.get('metadata', {}).get('countries', "[]")
52
  try:
 
80
 
81
  # Region filter
82
  with col1:
83
+ region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions)) # Display region names
84
 
85
  # Dynamically filter countries based on selected region
86
  if region_filter == "All/Not allocated":
87
+ filtered_country_names = unique_country_names # Show all countries if no region is selected
88
  else:
89
  filtered_country_names = [
90
  name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter
 
92
 
93
  # Country filter
94
  with col2:
95
+ country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names) # Display filtered country names
96
+
97
+ # Year range slider
98
+ with col3:
99
+ current_year = datetime.now().year
100
+ default_start_year = current_year - 5
101
+
102
+ # 3) The max_value is now the actual max end_year from collection
103
+ end_year_range = st.slider(
104
+ "Project End Year",
105
+ min_value=2010,
106
+ max_value=max_end_year,
107
+ value=(default_start_year, max_end_year),
108
+ )
109
 
110
  # Checkbox to control whether to show only exact matches
111
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
112
 
113
+ def filter_results(results, country_filter, region_filter, end_year_range):
114
  filtered = []
115
  for r in results:
116
  metadata = r.payload.get('metadata', {})
 
141
  else:
142
  countries_in_region = c_list
143
 
144
+ # Filtering
145
  if (
146
  (country_filter == "All/Not allocated" or selected_iso_code in c_list)
147
  and (region_filter == "All/Not allocated" or countries_in_region)
148
+ and (end_year_range[0] <= end_year_val <= end_year_range[1])
149
  ):
150
  filtered.append(r)
151
  return filtered
152
 
153
  # Run the search
154
+
155
+ # 1) Adjust limit so we get more than 15 results
156
+ results = hybrid_search(client, var, collection_name, limit=500) # fetch well over 15 so enough remain after filtering
157
+
158
+ # results is a tuple: (semantic_results, lexical_results)
159
  semantic_all = results[0]
160
  lexical_all = results[1]
161
 
162
+ # 2) Filter out content < 20 chars (as intermediate fix to problem that e.g. super short paragraphs with few chars get high similarity score)
163
+ semantic_all = [
164
+ r for r in semantic_all if len(r.payload["page_content"]) >= 20
165
+ ]
166
+ lexical_all = [
167
+ r for r in lexical_all if len(r.payload["page_content"]) >= 20
168
+ ]
169
 
170
+ # 2) Apply a threshold to SEMANTIC results (score >= 0.4)
171
+ semantic_thresholded = [r for r in semantic_all if r.score >= 0.4]
172
 
173
+ # 2) Filter the entire sets
174
+ filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range)
175
+ filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range)
176
 
177
  filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
178
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
179
 
 
 
 
 
 
 
 
180
 
181
+ # 3) Retrieve top 15 *after* filtering
182
+ # Check user preference
183
  if show_exact_matches:
184
+ # 1) Display heading
185
  st.write(f"Showing **Top 15 Lexical Search results** for query: {var}")
186
 
187
+ # 2) Do a simple substring check (case-insensitive)
188
+ # We'll create a new list lexical_substring_filtered
189
  query_substring = var.strip().lower()
190
  lexical_substring_filtered = []
191
  for r in lexical_all:
192
+ # page_content in lowercase
193
+ page_text_lower = r.payload["page_content"].lower()
194
+ # Keep this result only if the query substring is found
195
+ if query_substring in page_text_lower:
196
  lexical_substring_filtered.append(r)
197
 
198
+ # 3) Now apply your region/country/year filter on that new list
199
+ filtered_lexical = filter_results(
200
+ lexical_substring_filtered, country_filter, region_filter, end_year_range
201
+ )
202
+
203
+ # 4) Remove duplicates
204
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
205
 
206
+ # 5) If empty after substring + filters + dedupe, show a custom message
207
  if not filtered_lexical_no_dupe:
208
  st.write('No exact matches, consider unchecking "Show only exact matches"')
209
  else:
210
+ # 6) Display the first 15 matching results
211
  for res in filtered_lexical_no_dupe[:15]:
212
+ project_name = res.payload['metadata'].get('project_name', 'Project Link')
213
+ url = res.payload['metadata'].get('url', '#')
214
+ st.markdown(f"#### [{project_name}]({url})")
215
+
216
+ # Snippet logic (80 words)
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  full_text = res.payload['page_content']
218
+ words = full_text.split()
219
+ preview_word_count = 80
220
+ preview_text = " ".join(words[:preview_word_count])
221
+ remainder_text = " ".join(words[preview_word_count:])
222
+ st.write(preview_text + ("..." if remainder_text else ""))
223
+
224
+ # Keywords
225
  top_keywords = extract_top_keywords(full_text, top_n=5)
226
  if top_keywords:
227
  st.markdown(f"_{' · '.join(top_keywords)}_")
228
 
229
+ # Metadata
230
+ metadata = res.payload.get('metadata', {})
231
+ countries = metadata.get('countries', "[]")
232
  client_name = metadata.get('client', 'Unknown Client')
233
  start_year = metadata.get('start_year', None)
234
  end_year = metadata.get('end_year', None)
235
+
236
+ try:
237
+ c_list = json.loads(countries.replace("'", '"'))
238
+ except json.JSONDecodeError:
239
+ c_list = []
240
+
241
+ # Only keep country names if the region lookup (get_country_name)
242
+ # returns something different than the raw code.
243
+ matched_countries = []
244
+ for code in c_list:
245
+ if len(code) == 2:
246
+ resolved_name = get_country_name(code.upper(), region_df)
247
+ # If get_country_name didn't find a match,
248
+ # it typically just returns the same code (like "XX").
249
+ # We'll consider "successfully looked up" if
250
+ # resolved_name != code.upper().
251
+ if resolved_name.upper() != code.upper():
252
+ matched_countries.append(resolved_name)
253
+
254
+ # Format the year range
255
  start_year_str = f"{int(round(float(start_year)))}" if start_year else "Unknown"
256
  end_year_str = f"{int(round(float(end_year)))}" if end_year else "Unknown"
257
+
258
+ # Build the final string
259
+ if matched_countries:
260
+ # We have at least 1 valid country name
261
+ additional_text = (
262
+ f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**, "
263
+ f"**{start_year_str}-{end_year_str}**"
264
+ )
265
+ else:
266
+ # No valid countries found
267
+ additional_text = (
268
+ f"Commissioned by **{client_name}**, **{start_year_str}-{end_year_str}**"
269
+ )
270
+
271
  st.markdown(additional_text)
272
  st.divider()
273
 
 
277
  if not filtered_semantic_no_dupe:
278
  st.write("No relevant results found.")
279
  else:
280
+ # Show the top 15 from filtered_semantic
281
  for res in filtered_semantic_no_dupe[:15]:
282
+ project_name = res.payload['metadata'].get('project_name', 'Project Link')
283
+ url = res.payload['metadata'].get('url', '#')
284
+ st.markdown(f"#### [{project_name}]({url})")
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ # Snippet logic
287
  full_text = res.payload['page_content']
288
+ words = full_text.split()
289
+ preview_word_count = 10
290
+ preview_text = " ".join(words[:preview_word_count])
291
+ remainder_text = " ".join(words[preview_word_count:])
292
+ st.write(preview_text + ("..." if remainder_text else ""))
293
+
294
+ # Keywords
295
  top_keywords = extract_top_keywords(full_text, top_n=5)
296
  if top_keywords:
297
  st.markdown(f"_{' · '.join(top_keywords)}_")
298
 
299
+ # Metadata
300
+ metadata = res.payload.get('metadata', {})
301
+ countries = metadata.get('countries', "[]")
302
  client_name = metadata.get('client', 'Unknown Client')
303
  start_year = metadata.get('start_year', None)
304
  end_year = metadata.get('end_year', None)
305
+
306
+ try:
307
+ c_list = json.loads(countries.replace("'", '"'))
308
+ except json.JSONDecodeError:
309
+ c_list = []
310
+
311
+ # Only keep country names if the region lookup (get_country_name)
312
+ # returns something different than the raw code.
313
+ matched_countries = []
314
+ for code in c_list:
315
+ if len(code) == 2:
316
+ resolved_name = get_country_name(code.upper(), region_df)
317
+ # If get_country_name didn't find a match,
318
+ # it typically just returns the same code (like "XX").
319
+ # We'll consider "successfully looked up" if
320
+ # resolved_name != code.upper().
321
+ if resolved_name.upper() != code.upper():
322
+ matched_countries.append(resolved_name)
323
+
324
+ # Format the year range
325
  start_year_str = extract_year(start_year) if start_year else "Unknown"
326
  end_year_str = extract_year(end_year) if end_year else "Unknown"
327
+
328
+ # Build the final string
329
+ if matched_countries:
330
+ # We have at least 1 valid country name
331
+ additional_text = (
332
+ f"**{', '.join(matched_countries)}**, commissioned by **{client_name}**, "
333
+ f"**{start_year_str}-{end_year_str}**"
334
+ )
335
+ else:
336
+ # No valid countries found
337
+ additional_text = (
338
+ f"Commissioned by **{client_name}**, **{start_year_str}-{end_year_str}**"
339
+ )
340
+
341
  st.markdown(additional_text)
342
  st.divider()
343
 
344
+
345
+ # for i in results:
346
+ # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
347
+ # st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
348
+ # st.write(i.page_content)
349
+ # st.divider()