annikwag committed on
Commit
c966f4d
·
verified ·
1 Parent(s): fb7eabb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -374
app.py CHANGED
@@ -2,239 +2,70 @@ import streamlit as st
2
  import requests
3
  import pandas as pd
4
  import re
 
 
 
 
 
5
  from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
6
  from appStore.prep_utils import create_documents, get_client
7
  from appStore.embed import hybrid_embed_chunks
8
  from appStore.search import hybrid_search
9
- from appStore.region_utils import load_region_data, clean_country_code, get_country_name, get_regions
10
- #from appStore.tfidf_extraction import extract_top_keywords # TF-IDF part commented out
11
- from torch import cuda
12
- import json
13
- from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  st.set_page_config(page_title="SEARCH IATI", layout='wide')
16
 
17
 
18
  ###########################################
19
- # Helper functions for data processing
20
- ###########################################
21
-
22
- # New helper: Truncate a text to a given (approximate) token count.
23
- def truncate_to_tokens(text, max_tokens):
24
- tokens = text.split() # simple approximation
25
- if len(tokens) > max_tokens:
26
- return " ".join(tokens[:max_tokens])
27
- return text
28
-
29
- # Build a context string for a single result using title, objectives and description.
30
- def build_context_for_result(res):
31
- metadata = res.payload.get('metadata', {})
32
- # Compute title if not already present.
33
- title = metadata.get("title", compute_title(metadata))
34
- objective = metadata.get("objective", "")
35
- # Use description.en if available; otherwise use description.de.
36
- desc_en = metadata.get("description.en", "").strip()
37
- desc_de = metadata.get("description.de", "").strip()
38
- description = desc_en if desc_en != "" else desc_de
39
- return f"{title}\n{objective}\n{description}"
40
-
41
- # Updated highlight: return HTML that makes the matched query red and bold.
42
- def highlight_query(text, query):
43
- pattern = re.compile(re.escape(query), re.IGNORECASE)
44
- return pattern.sub(lambda m: f"<span style='color:red; font-weight:bold;'>{m.group(0)}</span>", text)
45
-
46
- # Helper: Format project id (e.g., "201940485" -> "2019.4048.5")
47
- def format_project_id(pid):
48
- s = str(pid)
49
- if len(s) > 5:
50
- return s[:4] + "." + s[4:-1] + "." + s[-1]
51
- return s
52
-
53
- # Helper: Compute title from metadata using name.en (or name.de if empty)
54
- def compute_title(metadata):
55
- name_en = metadata.get("name.en", "").strip()
56
- name_de = metadata.get("name.de", "").strip()
57
- base = name_en if name_en else name_de
58
- pid = metadata.get("id", "")
59
- if base and pid:
60
- return f"{base} [{format_project_id(pid)}]"
61
- return base or "No Title"
62
-
63
- # Load CRS lookup CSV and define a lookup function.
64
- crs_lookup = pd.read_csv("docStore/crs5_codes.csv") # Assumes columns: "code" and "new_crs_value"
65
- def lookup_crs_value(crs_key):
66
- # Ensure the input is a string and clean it.
67
- key_clean = re.sub(r'\.0$', '', str(crs_key).strip())
68
- # Compare against the CSV codes as strings.
69
- row = crs_lookup[crs_lookup["code"].astype(str) == key_clean]
70
- if not row.empty:
71
- # If a column named "new_crs_value" exists, use it.
72
- if "new_crs_value" in row.columns:
73
- try:
74
- return re.sub(r'\.0$', '', str(int(float(row.iloc[0]["new_crs_value"]))))
75
- except Exception:
76
- return re.sub(r'\.0$', '', str(row.iloc[0]["new_crs_value"]))
77
- else:
78
- # Otherwise, use the "name" column as the lookup value.
79
- return re.sub(r'\.0$', '', str(row.iloc[0]["name"]).strip())
80
- return ""
81
-
82
-
83
  ###########################################
84
- # RAG Answer function (Change 1 & 2 & 3)
85
- ###########################################
86
- # ToDo move functions to utils and model specifications to config file!
87
- # Configuration for the dedicated model
88
- # https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud # 4k token
89
  DEDICATED_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
90
  DEDICATED_ENDPOINT = "https://nwea79x4q1clc89l.eu-west-1.aws.endpoints.huggingface.cloud"
91
- # Write access token from the settings
92
  WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
93
 
94
- def get_rag_answer(query, top_results):
95
- # Build context from each top result using title, objective, and description.
96
- context = "\n\n".join([build_context_for_result(res) for res in top_results])
97
- # Truncate context to 11500 tokens (approximation)
98
- context = truncate_to_tokens(context, 11500)
99
- # Improved prompt with role instruction and formatting instruction.
100
- prompt = (
101
- "You are a project portfolio adviser at the development cooperation GIZ. "
102
- "Using the following context, answer the question in English precisely. "
103
- "Ensure that any project title mentioned in your answer is wrapped in ** (markdown bold). "
104
- "Only output the final answer below, without repeating the context or question.\n\n"
105
- f"Context:\n{context}\n\n"
106
- f"Question: {query}\n\n"
107
- "Answer:"
108
- )
109
- headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"}
110
- payload = {
111
- "inputs": prompt,
112
- "parameters": {"max_new_tokens": 300}
113
- }
114
- response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload)
115
- if response.status_code == 200:
116
- result = response.json()
117
- answer = result[0]["generated_text"]
118
- if "Answer:" in answer:
119
- answer = answer.split("Answer:")[-1].strip()
120
- return answer
121
- else:
122
- return f"Error in generating answer: {response.text}"
123
 
124
  ###########################################
125
- # CRS Options using lookup (Change 7)
126
  ###########################################
127
- @st.cache_data(show_spinner=False)
128
- def get_crs_options(_client, collection_name):
129
- # Optionally clear the cache if needed: st.cache_data.clear()
130
- results = hybrid_search(_client, "", collection_name)
131
- all_results = results[0] + results[1]
132
- crs_set = set()
133
- for res in all_results:
134
- metadata = res.payload.get('metadata', {})
135
- raw_crs_key = metadata.get("crs_key", "")
136
- # Convert to string and remove trailing ".0"
137
- crs_key_clean = re.sub(r'\.0$', '', str(raw_crs_key).strip())
138
- if crs_key_clean:
139
- # Ensure lookup input is clean
140
- lookup_input = crs_key_clean
141
- new_value_raw = lookup_crs_value(lookup_input)
142
- # Convert lookup return value to string and remove trailing ".0"
143
- new_value_clean = re.sub(r'\.0$', '', str(new_value_raw).strip())
144
- crs_combined = f"{crs_key_clean}: {new_value_clean}"
145
- crs_set.add(crs_combined)
146
- return sorted(crs_set)
147
-
148
-
149
  @st.cache_data
150
  def load_project_data():
151
- # Load your full project DataFrame using your processing function.
 
 
152
  return process_giz_worldwide()
153
 
154
- # Load the project data (cached)
155
  project_data = load_project_data()
156
 
157
- # Convert the 'total_project' column to numeric (dropping errors) and compute min and max.
158
- # The budget is assumed to be in euros, so we convert to million euros.
159
  budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
160
  min_budget_val = float(budget_series.min() / 1e6)
161
  max_budget_val = float(budget_series.max() / 1e6)
162
 
163
  ###########################################
164
- # Revised filter_results with budget filtering (Change 7 & 9)
165
  ###########################################
166
- def parse_budget(value):
167
- try:
168
- return float(value)
169
- except:
170
- return 0.0
171
-
172
- def filter_results(results, country_filter, region_filter, end_year_range, crs_filter, budget_filter):
173
- filtered = []
174
- for r in results:
175
- metadata = r.payload.get('metadata', {})
176
- country = metadata.get('country', "[]")
177
- year_str = metadata.get('end_year')
178
- if year_str:
179
- extracted = extract_year(year_str)
180
- try:
181
- end_year_val = int(extracted) if extracted != "Unknown" else 0
182
- except ValueError:
183
- end_year_val = 0
184
- else:
185
- end_year_val = 0
186
- if country.strip().startswith("["):
187
- try:
188
- parsed_country = json.loads(country.replace("'", '"'))
189
- if isinstance(parsed_country, str):
190
- country_list = [parsed_country]
191
- else:
192
- country_list = parsed_country
193
- c_list = [clean_country_code(code) for code in country_list if len(clean_country_code(code)) == 2]
194
- except json.JSONDecodeError:
195
- c_list = []
196
- else:
197
- c_list = [clean_country_code(country)]
198
-
199
- # After obtaining and cleaning the country codes into c_list:
200
- resolved_names = [get_country_name(code, region_df) for code in c_list]
201
-
202
- # And similarly for regions, if you have a mapping (e.g., iso_code_to_sub_region),
203
- # check if any of the codes in c_list has a matching region:
204
- country_in_region = any(
205
- iso_code_to_sub_region.get(code, "Not allocated") == region_filter
206
- for code in c_list
207
- )
208
-
209
-
210
- crs_key = metadata.get("crs_key", "").strip()
211
- # Convert crs_key to a string and remove trailing ".0"
212
- crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
213
- # Lookup using the cleaned key
214
- new_crs_value = lookup_crs_value(crs_key_clean)
215
- # Clean the lookup return value similarly
216
- new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value).strip())
217
- crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else ""
218
-
219
-
220
- # Enforce CRS filter only if specified.
221
- if crs_filter != "All/Not allocated" and crs_combined:
222
- if crs_filter != crs_combined:
223
- continue
224
-
225
- # Budget filtering: parse total_project value.
226
- budget_value = parse_budget(metadata.get('total_project', "0"))
227
- # Only keep results with budget >= budget_filter (in million euros, so multiply by 1e6)
228
- if budget_value < (budget_filter * 1e6):
229
- continue
230
-
231
- year_ok = True if end_year_val == 0 else (end_year_range[0] <= end_year_val <= end_year_range[1])
232
-
233
- if ((country_filter == "All/Not allocated" or (country_filter in resolved_names))
234
- and (region_filter == "All/Not allocated" or country_in_region)
235
- and year_ok):
236
- filtered.append(r)
237
- return filtered
238
 
239
  ###########################################
240
  # Get device
@@ -242,318 +73,297 @@ def filter_results(results, country_filter, region_filter, end_year_range, crs_f
242
  device = 'cuda' if cuda.is_available() else 'cpu'
243
 
244
  ###########################################
245
- # App heading and About button (Change 5 & 6)
246
  ###########################################
247
- col_title, col_about = st.columns([8,2])
248
  with col_title:
249
  st.markdown("<h1 style='text-align:center;'>GIZ Project Database (PROTOTYPE)</h1>", unsafe_allow_html=True)
250
  with col_about:
251
  with st.expander("ℹ️ About"):
252
  st.markdown(
253
  """
254
- This app is a prototype for testing purposes using publicly available project data from the German International Cooperation Society (GIZ) as of 23rd February 2025.
 
255
  **Please do NOT enter sensitive or personal information.**
256
  **Note**: The answers are AI-generated and may be wrong or misleading.
257
- """, unsafe_allow_html=True)
 
258
 
259
-
260
- ###########################################
261
- # Query input and budget slider (Change 9)
262
- ###########################################
263
  var = st.text_input("Enter Question")
264
 
265
-
266
- ###########################################
267
- # Load region lookup CSV
268
- ###########################################
269
- region_lookup_path = "docStore/regions_lookup.csv"
270
- region_df = load_region_data(region_lookup_path)
271
-
272
-
273
  ###########################################
274
- # Create the embeddings collection and save
275
  ###########################################
276
- # The steps below need to be performed only once and should then be commented out to avoid any unnecessary compute overrun
277
- ##### First we process and create the chunks for the relevant data source
278
- #chunks = process_giz_worldwide()
279
- ##### Convert to langchain documents
280
- #temp_doc = create_documents(chunks,'chunks')
281
- ##### Embed and store docs, check if collection exist then you need to update the collection
282
  collection_name = "giz_worldwide"
283
- #hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
284
-
285
- ###########################################
286
- # Hybrid Search and Filters Setup
287
- ###########################################
288
  client = get_client()
289
  print(client.get_collections())
 
 
 
 
 
 
290
  max_end_year = get_max_end_year(client, collection_name)
291
  _, unique_sub_regions = get_regions(region_df)
292
 
293
- @st.cache_data
294
- def get_country_name_and_region_mapping(_client, collection_name, region_df):
295
- results = hybrid_search(_client, "", collection_name)
296
- country_set = set()
297
- for res in results[0] + results[1]:
298
- country = res.payload.get('metadata', {}).get('country', "[]")
299
- if country.strip().startswith("["):
300
- try:
301
- parsed_country = json.loads(country.replace("'", '"'))
302
- if isinstance(parsed_country, str):
303
- country_list = [parsed_country]
304
- else:
305
- country_list = parsed_country
306
- except json.JSONDecodeError:
307
- country_list = []
308
- else:
309
- country_list = [country.strip()]
310
- two_digit_codes = [clean_country_code(code) for code in country_list if len(clean_country_code(code)) == 2]
311
- country_set.update(two_digit_codes)
312
- country_name_to_code = {}
313
- iso_code_to_sub_region = {}
314
- for code in country_set:
315
- name = get_country_name(code, region_df)
316
- sub_region_row = region_df[region_df['alpha-2'] == code]
317
- sub_region = sub_region_row['sub-region'].values[0] if not sub_region_row.empty else "Not allocated"
318
- country_name_to_code[name] = code
319
- iso_code_to_sub_region[code] = sub_region
320
- return country_name_to_code, iso_code_to_sub_region
321
-
322
- client = get_client()
323
- country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df)
324
  unique_country_names = sorted(country_name_mapping.keys())
325
 
326
- # Layout filter columns
 
 
327
  col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
 
328
  with col1:
329
  region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
 
330
  if region_filter == "All/Not allocated":
331
  filtered_country_names = unique_country_names
332
  else:
333
- filtered_country_names = [name for name, code in country_name_mapping.items() if iso_code_to_sub_region.get(code) == region_filter]
 
 
 
 
334
  with col2:
335
  country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
 
336
  with col3:
337
  current_year = datetime.now().year
338
  default_start_year = current_year - 4
339
- end_year_range = st.slider("Project End Year", min_value=2010, max_value=max_end_year, value=(default_start_year, max_end_year))
 
 
 
 
 
 
340
  with col4:
341
  crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
342
  crs_filter = st.selectbox("CRS", crs_options)
343
- with col5:
344
- # Now use these values as the slider range:
345
  min_budget = st.slider(
346
  "Minimum Project Budget (Million €)",
347
  min_value=min_budget_val,
348
  max_value=max_budget_val,
349
- value=min_budget_val)
350
-
351
 
352
- # Checkbox for exact matches
353
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
354
 
 
 
 
355
  if not var.strip():
356
  st.info("Please enter a question to see results.")
357
  else:
358
-
359
- ###########################################
360
- # Run the search and apply filters
361
- ###########################################
362
  results = hybrid_search(client, var, collection_name, limit=500)
363
- semantic_all = results[0]
364
- lexical_all = results[1]
 
365
  semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
366
  lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
 
 
367
  semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
368
-
369
- # Pass the budget filter (min_budget) into filter_results
370
- filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range, crs_filter, min_budget)
371
- filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range, crs_filter, min_budget)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
373
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
374
-
375
  def format_currency(value):
 
 
 
376
  try:
377
  return f"€{int(float(value)):,}"
378
  except (ValueError, TypeError):
379
  return value
380
-
381
- ###########################################
382
- # Display Results (Lexical and Semantic)
383
- ###########################################
384
- # --- Lexical Results Branch ---
385
  if show_exact_matches:
 
386
  st.write("Showing **Top 15 Lexical Search results**")
387
  query_substring = var.strip().lower()
388
- lexical_substring_filtered = [r for r in lexical_all if query_substring in r.payload["page_content"].lower()]
389
- filtered_lexical = filter_results(lexical_substring_filtered, country_filter, region_filter, end_year_range, crs_filter, min_budget)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
391
  if not filtered_lexical_no_dupe:
392
  st.write('No exact matches, consider unchecking "Show only exact matches"')
393
  else:
394
  top_results = filtered_lexical_no_dupe[:10]
395
- rag_answer = get_rag_answer(var, top_results)
396
- # Use the query as heading; increase size and center it.
397
  st.markdown(f"<h2 style='text-align:center; font-size:1.5em;'>{var}</h2>", unsafe_allow_html=True)
398
  st.write(rag_answer)
399
  st.divider()
 
 
400
  for res in top_results:
401
  metadata = res.payload.get('metadata', {})
402
  if "title" not in metadata:
403
  metadata["title"] = compute_title(metadata)
404
- # Highlight query matches in title (rendered with HTML)
 
405
  title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
406
  st.markdown(f"#### {title_html}", unsafe_allow_html=True)
407
- # Build snippet from objectives and description
 
408
  objective = metadata.get("objective", "None")
409
  desc_en = metadata.get("description.en", "").strip()
410
  desc_de = metadata.get("description.de", "").strip()
411
- description = desc_en if desc_en != "" else desc_de
412
  if not description:
413
  description = "No project description available"
414
- full_snippet = f"{description}"
415
- words = full_snippet.split()
416
  preview_word_count = 90
417
  preview_text = " ".join(words[:preview_word_count])
418
  remainder_text = " ".join(words[preview_word_count:])
419
- # If the preview text is empty, set a default message.
420
- if not preview_text.strip():
421
- preview_text = "No project description available"
422
- # Create two columns: left for description, right for additional details.
423
  col_left, col_right = st.columns(2)
424
  with col_left:
425
  st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
426
  if remainder_text:
427
  with st.expander("Show more"):
428
  st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
 
429
  with col_right:
430
- # Format additional text with line breaks using <br>
431
- start_year = metadata.get('start_year', None)
432
- end_year = metadata.get('end_year', None)
433
- start_year_str = extract_year(start_year) if start_year else "Unknown"
434
- end_year_str = extract_year(end_year) if end_year else "Unknown"
435
  total_project = metadata.get('total_project', "Unknown")
436
  total_volume = metadata.get('total_volume', "Unknown")
437
  formatted_project_budget = format_currency(total_project)
438
  formatted_total_volume = format_currency(total_volume)
439
  country_raw = metadata.get('country', "Unknown")
440
- country_value = metadata.get('country', "").strip()
441
- if country_value.startswith("["):
442
- try:
443
- c_list = json.loads(country_value.replace("'", '"'))
444
- except json.JSONDecodeError:
445
- c_list = []
446
- else:
447
- c_list = [country_value]
448
- matched_country = []
449
- for code in c_list:
450
- cleaned = clean_country_code(code)
451
- if len(cleaned) == 2:
452
- resolved_name = get_country_name(cleaned, region_df)
453
- if resolved_name.upper() != cleaned.upper():
454
- matched_country.append(resolved_name)
455
  crs_key = metadata.get("crs_key", "").strip()
456
- # Convert to string and remove trailing ".0"
457
- crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
458
- new_crs_value = lookup_crs_value(crs_key_clean)
459
- new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
460
- crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
461
- client_name = metadata.get('client', 'Unknown Client')
462
- contact = metadata.get("contact", "").strip()
463
- objective_highlighted = highlight_query(objective, var) if var.strip() else objective
464
  additional_text = (
465
- f"**Objective:** {objective_highlighted}<br>"
466
- f"**Commissioned by:** {client_name}<br>"
467
  f"**Projekt duration:** {start_year_str}-{end_year_str}<br>"
468
  f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}<br>"
469
  f"**Country:** {country_raw}<br>"
470
- f"**Sector:** {crs_combined}"
471
  )
 
472
  if contact and contact.lower() != "[email protected]":
473
  additional_text += f"<br>**Contact:** [email protected]"
 
474
  st.markdown(additional_text, unsafe_allow_html=True)
 
475
  st.divider()
476
-
477
- # --- Semantic Results Branch ---
478
  else:
 
479
  if not filtered_semantic_no_dupe:
480
  st.write("No relevant results found.")
481
  else:
482
  top_results = filtered_semantic_no_dupe[:10]
483
- rag_answer = get_rag_answer(var, top_results)
484
  st.markdown(f"<h2 style='text-align:center; font-size:2.5em;'>{var}</h2>", unsafe_allow_html=True)
485
  st.write(rag_answer)
486
  st.divider()
487
  st.write("Showing **Top 15 Semantic Search results**")
 
488
  for res in top_results:
489
  metadata = res.payload.get('metadata', {})
490
  if "title" not in metadata:
491
  metadata["title"] = compute_title(metadata)
 
492
  st.markdown(f"#### {metadata['title']}")
493
- objective = metadata.get("objective", "")
494
  desc_en = metadata.get("description.en", "").strip()
495
  desc_de = metadata.get("description.de", "").strip()
496
- description = desc_en if desc_en != "" else desc_de
497
  if not description:
498
- description = "No project description available"
499
- full_snippet = f"{description}"
500
- words = full_snippet.split()
501
  preview_word_count = 90
502
  preview_text = " ".join(words[:preview_word_count])
503
  remainder_text = " ".join(words[preview_word_count:])
504
- # If the preview text is empty, set a default message.
505
- if not preview_text.strip():
506
- preview_text = "No project description available"
507
- # Create two columns: left for full description (with expander) and right for additional details.
508
  col_left, col_right = st.columns(2)
509
  with col_left:
510
  st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
511
  if remainder_text:
512
  with st.expander("Show more"):
513
  st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
 
514
  with col_right:
515
- start_year = metadata.get('start_year', None)
516
- end_year = metadata.get('end_year', None)
517
- start_year_str = extract_year(start_year) if start_year else "Unknown"
518
- end_year_str = extract_year(end_year) if end_year else "Unknown"
519
  total_project = metadata.get('total_project', "Unknown")
520
  total_volume = metadata.get('total_volume', "Unknown")
521
  formatted_project_budget = format_currency(total_project)
522
  formatted_total_volume = format_currency(total_volume)
523
  country_raw = metadata.get('country', "Unknown")
524
- country_value = metadata.get('country', "").strip()
525
- if country_value.startswith("["):
526
- try:
527
- c_list = json.loads(country_value.replace("'", '"'))
528
- except json.JSONDecodeError:
529
- c_list = []
530
- else:
531
- c_list = [country_value]
532
- matched_country = []
533
- for code in c_list:
534
- cleaned = clean_country_code(code)
535
- if len(cleaned) == 2:
536
- resolved_name = get_country_name(cleaned, region_df)
537
- if resolved_name.upper() != cleaned.upper():
538
- matched_country.append(resolved_name)
539
  crs_key = metadata.get("crs_key", "").strip()
540
- # Convert to string and remove trailing ".0"
541
- crs_key_clean = re.sub(r'\.0$', '', str(crs_key))
542
- new_crs_value = lookup_crs_value(crs_key_clean)
543
- new_crs_value_clean = re.sub(r'\.0$', '', str(new_crs_value))
544
- crs_combined = f"{crs_key_clean}: {new_crs_value_clean}" if crs_key_clean else "Unknown"
545
- client_name = metadata.get('client', 'Unknown Client')
546
- contact = metadata.get("contact", "").strip()
547
  additional_text = (
548
- f"**Objective:** {objective}<br>"
549
- f"**Commissioned by:** {client_name}<br>"
550
  f"**Projekt duration:** {start_year_str}-{end_year_str}<br>"
551
  f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}<br>"
552
  f"**Country:** {country_raw}<br>"
553
- #f"**Country:** {', '.join(matched_country)}<br>"
554
- f"**Sector:** {crs_combined}"
555
  )
 
556
  if contact and contact.lower() != "[email protected]":
557
  additional_text += f"<br>**Contact:** [email protected]"
 
558
  st.markdown(additional_text, unsafe_allow_html=True)
559
- st.divider()
 
 
2
  import requests
3
  import pandas as pd
4
  import re
5
+ import json
6
+ from datetime import datetime
7
+ from torch import cuda
8
+
9
+ # Import existing modules from appStore
10
  from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
11
  from appStore.prep_utils import create_documents, get_client
12
  from appStore.embed import hybrid_embed_chunks
13
  from appStore.search import hybrid_search
14
+ from appStore.region_utils import (
15
+ load_region_data,
16
+ clean_country_code,
17
+ get_country_name,
18
+ get_regions,
19
+ get_country_name_and_region_mapping
20
+ )
21
+ # TF-IDF part (excluded from the app for now)
22
+ # from appStore.tfidf_extraction import extract_top_keywords
23
+
24
+ # Import your new helper modules
25
+ from appStore.rag_utils import (
26
+ highlight_query,
27
+ get_rag_answer,
28
+ compute_title
29
+ )
30
+ from appStore.filter_utils import (
31
+ parse_budget,
32
+ filter_results,
33
+ get_crs_options
34
+ )
35
 
36
  st.set_page_config(page_title="SEARCH IATI", layout='wide')
37
 
38
 
39
  ###########################################
40
+ # Global / Model Config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  ###########################################
 
 
 
 
 
42
  DEDICATED_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
43
  DEDICATED_ENDPOINT = "https://nwea79x4q1clc89l.eu-west-1.aws.endpoints.huggingface.cloud"
 
44
  WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  ###########################################
48
+ # Cache the project data
49
  ###########################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  @st.cache_data
51
  def load_project_data():
52
+ """
53
+ Load and process the GIZ worldwide data, returning a pandas DataFrame.
54
+ """
55
  return process_giz_worldwide()
56
 
 
57
  project_data = load_project_data()
58
 
59
+ # Determine min and max budgets in million euros
 
60
  budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
61
  min_budget_val = float(budget_series.min() / 1e6)
62
  max_budget_val = float(budget_series.max() / 1e6)
63
 
64
  ###########################################
65
+ # Prepare region data
66
  ###########################################
67
+ region_lookup_path = "docStore/regions_lookup.csv"
68
+ region_df = load_region_data(region_lookup_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  ###########################################
71
  # Get device
 
73
  device = 'cuda' if cuda.is_available() else 'cpu'
74
 
75
  ###########################################
76
+ # Streamlit App Layout
77
  ###########################################
78
+ col_title, col_about = st.columns([8, 2])
79
  with col_title:
80
  st.markdown("<h1 style='text-align:center;'>GIZ Project Database (PROTOTYPE)</h1>", unsafe_allow_html=True)
81
  with col_about:
82
  with st.expander("ℹ️ About"):
83
  st.markdown(
84
  """
85
+ This app is a prototype for testing purposes using publicly available project data
86
+ from the German International Cooperation Society (GIZ) as of 23rd February 2025.
87
  **Please do NOT enter sensitive or personal information.**
88
  **Note**: The answers are AI-generated and may be wrong or misleading.
89
+ """, unsafe_allow_html=True
90
+ )
91
 
92
+ # Main query input
 
 
 
93
  var = st.text_input("Enter Question")
94
 
 
 
 
 
 
 
 
 
95
  ###########################################
96
+ # Create or load the embeddings collection
97
  ###########################################
 
 
 
 
 
 
98
  collection_name = "giz_worldwide"
 
 
 
 
 
99
  client = get_client()
100
  print(client.get_collections())
101
+
102
+ # If needed, once only:
103
+ # chunks = process_giz_worldwide()
104
+ # temp_doc = create_documents(chunks, 'chunks')
105
+ # hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
106
+
107
  max_end_year = get_max_end_year(client, collection_name)
108
  _, unique_sub_regions = get_regions(region_df)
109
 
110
+ # Build country->code and code->region mapping
111
+ country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
112
+ client,
113
+ collection_name,
114
+ region_df,
115
+ hybrid_search,
116
+ clean_country_code,
117
+ get_country_name
118
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  unique_country_names = sorted(country_name_mapping.keys())
120
 
121
+ ###########################################
122
+ # Filter Controls
123
+ ###########################################
124
  col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
125
+
126
  with col1:
127
  region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
128
+
129
  if region_filter == "All/Not allocated":
130
  filtered_country_names = unique_country_names
131
  else:
132
+ filtered_country_names = [
133
+ name for name, code in country_name_mapping.items()
134
+ if iso_code_to_sub_region.get(code) == region_filter
135
+ ]
136
+
137
  with col2:
138
  country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
139
+
140
  with col3:
141
  current_year = datetime.now().year
142
  default_start_year = current_year - 4
143
+ end_year_range = st.slider(
144
+ "Project End Year",
145
+ min_value=2010,
146
+ max_value=max_end_year,
147
+ value=(default_start_year, max_end_year)
148
+ )
149
+
150
  with col4:
151
  crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
152
  crs_filter = st.selectbox("CRS", crs_options)
153
+
154
+ with col5:
155
  min_budget = st.slider(
156
  "Minimum Project Budget (Million €)",
157
  min_value=min_budget_val,
158
  max_value=max_budget_val,
159
+ value=min_budget_val
160
+ )
161
 
 
162
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
163
 
164
+ ###########################################
165
+ # Main Search / Results
166
+ ###########################################
167
  if not var.strip():
168
  st.info("Please enter a question to see results.")
169
  else:
170
+ # 1) Perform hybrid search
 
 
 
171
  results = hybrid_search(client, var, collection_name, limit=500)
172
+ semantic_all, lexical_all = results[0], results[1]
173
+
174
+ # Filter out short pages
175
  semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
176
  lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
177
+
178
+ # Apply threshold to semantic results if desired
179
  semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
180
+
181
+ # 2) Filter results based on the user’s selections
182
+ filtered_semantic = filter_results(
183
+ semantic_thresholded,
184
+ country_filter,
185
+ region_filter,
186
+ end_year_range,
187
+ crs_filter,
188
+ min_budget,
189
+ region_df,
190
+ iso_code_to_sub_region,
191
+ clean_country_code,
192
+ get_country_name
193
+ )
194
+ filtered_lexical = filter_results(
195
+ lexical_all,
196
+ country_filter,
197
+ region_filter,
198
+ end_year_range,
199
+ crs_filter,
200
+ min_budget,
201
+ region_df,
202
+ iso_code_to_sub_region,
203
+ clean_country_code,
204
+ get_country_name
205
+ )
206
+
207
+ # Remove duplicates
208
  filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
209
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
210
+
211
  def format_currency(value):
212
+ """
213
+ Format a numeric or string value as currency (EUR) with commas.
214
+ """
215
  try:
216
  return f"€{int(float(value)):,}"
217
  except (ValueError, TypeError):
218
  return value
219
+
220
+ # 3) Display results
 
 
 
221
  if show_exact_matches:
222
+ # Lexical substring match only
223
  st.write("Showing **Top 15 Lexical Search results**")
224
  query_substring = var.strip().lower()
225
+ lexical_substring_filtered = [
226
+ r for r in lexical_all
227
+ if query_substring in r.payload["page_content"].lower()
228
+ ]
229
+ filtered_lexical = filter_results(
230
+ lexical_substring_filtered,
231
+ country_filter,
232
+ region_filter,
233
+ end_year_range,
234
+ crs_filter,
235
+ min_budget,
236
+ region_df,
237
+ iso_code_to_sub_region,
238
+ clean_country_code,
239
+ get_country_name
240
+ )
241
  filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
242
  if not filtered_lexical_no_dupe:
243
  st.write('No exact matches, consider unchecking "Show only exact matches"')
244
  else:
245
  top_results = filtered_lexical_no_dupe[:10]
246
+ # RAG answer
247
+ rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
248
  st.markdown(f"<h2 style='text-align:center; font-size:1.5em;'>{var}</h2>", unsafe_allow_html=True)
249
  st.write(rag_answer)
250
  st.divider()
251
+
252
+ # Show each result
253
  for res in top_results:
254
  metadata = res.payload.get('metadata', {})
255
  if "title" not in metadata:
256
  metadata["title"] = compute_title(metadata)
257
+
258
+ # Title
259
  title_html = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
260
  st.markdown(f"#### {title_html}", unsafe_allow_html=True)
261
+
262
+ # Description snippet
263
  objective = metadata.get("objective", "None")
264
  desc_en = metadata.get("description.en", "").strip()
265
  desc_de = metadata.get("description.de", "").strip()
266
+ description = desc_en if desc_en else desc_de
267
  if not description:
268
  description = "No project description available"
269
+ words = description.split()
 
270
  preview_word_count = 90
271
  preview_text = " ".join(words[:preview_word_count])
272
  remainder_text = " ".join(words[preview_word_count:])
273
+
 
 
 
274
  col_left, col_right = st.columns(2)
275
  with col_left:
276
  st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
277
  if remainder_text:
278
  with st.expander("Show more"):
279
  st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
280
+
281
  with col_right:
282
+ start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
283
+ end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
 
 
 
284
  total_project = metadata.get('total_project', "Unknown")
285
  total_volume = metadata.get('total_volume', "Unknown")
286
  formatted_project_budget = format_currency(total_project)
287
  formatted_total_volume = format_currency(total_volume)
288
  country_raw = metadata.get('country', "Unknown")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  crs_key = metadata.get("crs_key", "").strip()
290
+
291
+ # Additional text
 
 
 
 
 
 
292
  additional_text = (
293
+ f"**Objective:** {highlight_query(objective, var)}<br>"
294
+ f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
295
  f"**Projekt duration:** {start_year_str}-{end_year_str}<br>"
296
  f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}<br>"
297
  f"**Country:** {country_raw}<br>"
298
+ f"**Sector:** {crs_key if crs_key else 'Unknown'}"
299
  )
300
+ contact = metadata.get("contact", "").strip()
301
  if contact and contact.lower() != "[email protected]":
302
  additional_text += f"<br>**Contact:** [email protected]"
303
+
304
  st.markdown(additional_text, unsafe_allow_html=True)
305
+
306
  st.divider()
307
+
 
308
  else:
309
+ # Semantic results
310
  if not filtered_semantic_no_dupe:
311
  st.write("No relevant results found.")
312
  else:
313
  top_results = filtered_semantic_no_dupe[:10]
314
+ rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
315
  st.markdown(f"<h2 style='text-align:center; font-size:2.5em;'>{var}</h2>", unsafe_allow_html=True)
316
  st.write(rag_answer)
317
  st.divider()
318
  st.write("Showing **Top 15 Semantic Search results**")
319
+
320
  for res in top_results:
321
  metadata = res.payload.get('metadata', {})
322
  if "title" not in metadata:
323
  metadata["title"] = compute_title(metadata)
324
+
325
  st.markdown(f"#### {metadata['title']}")
326
+
327
  desc_en = metadata.get("description.en", "").strip()
328
  desc_de = metadata.get("description.de", "").strip()
329
+ description = desc_en if desc_en else desc_de
330
  if not description:
331
+ description = "No project description available"
332
+
333
+ words = description.split()
334
  preview_word_count = 90
335
  preview_text = " ".join(words[:preview_word_count])
336
  remainder_text = " ".join(words[preview_word_count:])
337
+
 
 
 
338
  col_left, col_right = st.columns(2)
339
  with col_left:
340
  st.markdown(highlight_query(preview_text, var), unsafe_allow_html=True)
341
  if remainder_text:
342
  with st.expander("Show more"):
343
  st.markdown(highlight_query(remainder_text, var), unsafe_allow_html=True)
344
+
345
  with col_right:
346
+ start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
347
+ end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
 
 
348
  total_project = metadata.get('total_project', "Unknown")
349
  total_volume = metadata.get('total_volume', "Unknown")
350
  formatted_project_budget = format_currency(total_project)
351
  formatted_total_volume = format_currency(total_volume)
352
  country_raw = metadata.get('country', "Unknown")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  crs_key = metadata.get("crs_key", "").strip()
354
+
 
 
 
 
 
 
355
  additional_text = (
356
+ f"**Objective:** {metadata.get('objective', '')}<br>"
357
+ f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}<br>"
358
  f"**Projekt duration:** {start_year_str}-{end_year_str}<br>"
359
  f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}<br>"
360
  f"**Country:** {country_raw}<br>"
361
+ f"**Sector:** {crs_key if crs_key else 'Unknown'}"
 
362
  )
363
+ contact = metadata.get("contact", "").strip()
364
  if contact and contact.lower() != "[email protected]":
365
  additional_text += f"<br>**Contact:** [email protected]"
366
+
367
  st.markdown(additional_text, unsafe_allow_html=True)
368
+
369
+ st.divider()