abdull4h committed
Commit 1f683db · verified · 1 Parent(s): 72544b8

Update app.py

Files changed (1)
  1. app.py +733 -113
app.py CHANGED
@@ -1,138 +1,758 @@
  import os
- import gradio as gr
  import spaces

- # Create a simple version for quick testing
- def main():
-     with gr.Blocks(title="Vision 2030 Assistant - Debugging") as interface:
-         gr.Markdown("# Vision 2030 Assistant - System Check")
-         gr.Markdown("This interface tests your system configuration to ensure all components are working.")
-
-         # Check files tab
-         with gr.Tab("File Check"):
-             gr.Markdown("### Check PDF Files and Directory Structure")
-             check_btn = gr.Button("Check Files")
-             files_output = gr.Textbox(label="Files Status", lines=10)

-             def check_files():
-                 results = []
-                 # List files
-                 results.append("Files in directory:")
-                 files = os.listdir(".")
-                 results.append("\n".join(files))

-                 # Check PDFs
-                 results.append("\nPDF Status:")
-                 for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
-                     if os.path.exists(pdf_file):
-                         size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
-                         results.append(f"{pdf_file}: Found ({size:.2f} MB)")
-                     else:
-                         results.append(f"{pdf_file}: Not found")
-
-                 return "\n".join(results)

-             check_btn.click(check_files, inputs=[], outputs=[files_output])

-         # Check dependencies tab
-         with gr.Tab("Dependency Check"):
-             gr.Markdown("### Check Required Dependencies")
-             dep_btn = gr.Button("Check Dependencies")
-             dep_output = gr.Textbox(label="Dependency Status", lines=20)

-             @spaces.GPU
-             def check_dependencies():
                  results = []

-                 # Basic dependencies
-                 for lib_name in [
-                     "torch",
-                     "transformers",
-                     "sentencepiece",
-                     "accelerate",
-                     "langchain",
-                     "langchain_community",
-                     "PyPDF2"
-                 ]:
                      try:
-                         module = __import__(lib_name)
                          if hasattr(module, "__version__"):
-                             results.append(f"✓ {lib_name}: {module.__version__}")
                          else:
-                             results.append(f"✓ {lib_name}: Installed (no version info)")
                      except ImportError:
-                         results.append(f"✗ {lib_name}: Not installed")
-                     except Exception as e:
-                         results.append(f"? {lib_name}: Error - {str(e)}")

-                 # Test GPU access
                  try:
                      import torch
-                     results.append(f"\nGPU status:")
                      results.append(f"CUDA available: {torch.cuda.is_available()}")
                      if torch.cuda.is_available():
-                         results.append(f"CUDA device count: {torch.cuda.device_count()}")
-                         results.append(f"CUDA current device: {torch.cuda.current_device()}")
-                         results.append(f"CUDA device name: {torch.cuda.get_device_name(0)}")
-                 except Exception as e:
-                     results.append(f"GPU status check error: {str(e)}")
-
-                 return "\n".join(results)
-
-             dep_btn.click(check_dependencies, inputs=[], outputs=[dep_output])
-
-         # Test tokenizer tab
-         with gr.Tab("Model Check"):
-             gr.Markdown("### Test Model Loading")
-             model_btn = gr.Button("Test Tokenizer Only")
-             model_output = gr.Textbox(label="Model Status", lines=15)
-
-             @spaces.GPU
-             def test_tokenizer():
-                 results = []

-                 try:
-                     results.append("Testing AutoTokenizer...")
-                     from transformers import AutoTokenizer
-
-                     # Check if accelerate is available
-                     try:
-                         import accelerate
-                         results.append(f"Accelerate version: {accelerate.__version__}")
-                     except ImportError:
-                         results.append("Accelerate is not installed")
-
-                     # Check if sentencepiece is available
-                     try:
-                         import sentencepiece
-                         results.append("SentencePiece is installed")
-                     except ImportError:
-                         results.append("SentencePiece is not installed")
-
-                     # Try loading just the tokenizer
-                     model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
-                     tokenizer = AutoTokenizer.from_pretrained(
-                         model_name,
-                         trust_remote_code=True,
-                         use_fast=False
-                     )
-                     results.append(f"✓ Successfully loaded tokenizer for {model_name}")
-
-                     # Test tokenization
-                     tokens = tokenizer("Hello, this is a test", return_tensors="pt")
-                     results.append(f"✓ Tokenizer works properly")
-                     results.append(f"Input IDs: {tokens.input_ids.shape}")
-
-                 except Exception as e:
-                     results.append(f"✗ Error: {str(e)}")
-                     import traceback
-                     results.append(traceback.format_exc())

                  return "\n".join(results)

-             model_btn.click(test_tokenizer, inputs=[], outputs=[model_output])
-
-     interface.launch()

  if __name__ == "__main__":
-     main()
  import os
+ import re
+ import json
+ import torch
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+ from pathlib import Path
  import spaces

+ # PDF processing
+ import PyPDF2
+
+ # LLM and embeddings
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+
+ # RAG components
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.schema import Document
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ # Arabic text processing
+ import arabic_reshaper
+ from bidi.algorithm import get_display
+
+ # Evaluation
+ from rouge_score import rouge_scorer
+ import sacrebleu
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from collections import defaultdict
+
+ # Gradio for the interface
+ import gradio as gr
+
+ # Helper functions
+ def safe_tokenize(text):
+     """Pure regex tokenizer with no NLTK dependency"""
+     if not text:
+         return []
+     # Replace punctuation with spaces around them
+     text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)
+     # Split on whitespace and filter empty strings
+     return [token for token in re.split(r'\s+', text.lower()) if token]
+
+ def detect_language(text):
+     """Detect if text is primarily Arabic or English"""
+     # Simple heuristic: count Arabic characters
+     arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
+     is_arabic = len(arabic_chars) > len(text) * 0.5
+     return "arabic" if is_arabic else "english"
+
+ # Evaluation metrics
+ def calculate_bleu(prediction, reference):
+     """Calculate BLEU score without any NLTK dependency"""
+     # Tokenize texts using our own tokenizer
+     pred_tokens = safe_tokenize(prediction.lower())
+     ref_tokens = [safe_tokenize(reference.lower())]
+
+     # If either is empty, return 0
+     if not pred_tokens or not ref_tokens[0]:
+         return {"bleu_1": 0, "bleu_2": 0, "bleu_4": 0}
+
+     # Get n-grams function
+     def get_ngrams(tokens, n):
+         return [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
+
+     # Calculate precision for each n-gram level
+     precisions = []
+     for n in range(1, 5):  # 1-gram to 4-gram
+         if len(pred_tokens) < n:
+             precisions.append(0)
+             continue

+         pred_ngrams = get_ngrams(pred_tokens, n)
+         ref_ngrams = get_ngrams(ref_tokens[0], n)
+
+         # Count matches
+         matches = sum(1 for ng in pred_ngrams if ng in ref_ngrams)
+
+         # Calculate precision
+         if pred_ngrams:
+             precisions.append(matches / len(pred_ngrams))
+         else:
+             precisions.append(0)
+
+     # Return BLEU scores
+     return {
+         "bleu_1": precisions[0],
+         "bleu_2": (precisions[0] * precisions[1]) ** 0.5 if len(precisions) > 1 else 0,
+         "bleu_4": (precisions[0] * precisions[1] * precisions[2] * precisions[3]) ** 0.25 if len(precisions) > 3 else 0
+     }
+
+ def calculate_meteor(prediction, reference):
+     """Simple word overlap metric as METEOR alternative"""
+     # Tokenize with our custom tokenizer
+     pred_tokens = set(safe_tokenize(prediction.lower()))
+     ref_tokens = set(safe_tokenize(reference.lower()))
+
+     # Calculate Jaccard similarity as METEOR alternative
+     if not pred_tokens or not ref_tokens:
+         return 0
+
+     intersection = len(pred_tokens.intersection(ref_tokens))
+     union = len(pred_tokens.union(ref_tokens))
+
+     return intersection / union if union > 0 else 0
+
+ def calculate_f1_precision_recall(prediction, reference):
+     """Calculate word-level F1, precision, and recall with custom tokenizer"""
+     # Tokenize with our custom tokenizer
+     pred_tokens = set(safe_tokenize(prediction.lower()))
+     ref_tokens = set(safe_tokenize(reference.lower()))
+
+     # Calculate overlap
+     common = pred_tokens.intersection(ref_tokens)
+
+     # Calculate precision, recall, F1
+     precision = len(common) / len(pred_tokens) if pred_tokens else 0
+     recall = len(common) / len(ref_tokens) if ref_tokens else 0
+     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
+
+     return {'precision': precision, 'recall': recall, 'f1': f1}
+
+ def evaluate_retrieval_quality(contexts, query, language):
+     """Evaluate the quality of retrieved contexts"""
+     # This is a placeholder implementation
+     return {
+         'language_match_ratio': 1.0,
+         'source_diversity': len(set([ctx.get('source', '') for ctx in contexts])) / max(1, len(contexts)),
+         'mrr': 1.0
+     }
+
+ # PDF Processing and Vector Store
+ def simple_process_pdfs(pdf_paths):
+     """Process PDF documents and return document objects"""
+     documents = []
+
+     print(f"Processing PDFs: {pdf_paths}")
+
+     for pdf_path in pdf_paths:
+         try:
+             if not os.path.exists(pdf_path):
+                 print(f"Warning: {pdf_path} does not exist")
+                 continue

+             print(f"Processing {pdf_path}...")
+             text = ""
+             with open(pdf_path, 'rb') as file:
+                 reader = PyPDF2.PdfReader(file)
+                 for page in reader.pages:
+                     page_text = page.extract_text()
+                     if page_text:  # If we got text from this page
+                         text += page_text + "\n\n"

+             if text.strip():  # If we got some text
+                 doc = Document(
+                     page_content=text,
+                     metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
+                 )
+                 documents.append(doc)
+                 print(f"Successfully processed: {pdf_path}")
+             else:
+                 print(f"Warning: No text extracted from {pdf_path}")
+         except Exception as e:
+             print(f"Error processing {pdf_path}: {e}")
+             import traceback
+             traceback.print_exc()
+
+     print(f"Processed {len(documents)} PDF documents")
+     return documents
+
+ def create_vector_store(documents):
+     """Split documents into chunks and create a FAISS vector store"""
+     # Text splitter for breaking documents into chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=50,
+         separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
+     )
+
+     # Split documents into chunks
+     chunks = []
+     for doc in documents:
+         doc_chunks = text_splitter.split_text(doc.page_content)
+         # Preserve metadata for each chunk
+         chunks.extend([
+             Document(page_content=chunk, metadata=doc.metadata)
+             for chunk in doc_chunks
+         ])
+
+     print(f"Created {len(chunks)} chunks from {len(documents)} documents")
+
+     # Create a proper embedding function for LangChain
+     embedding_function = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+     )
+
+     # Create FAISS index
+     vector_store = FAISS.from_documents(
+         chunks,
+         embedding_function
+     )
+
+     return vector_store
+
+ # Model Loading and RAG System
+ @spaces.GPU
+ def load_model_and_tokenizer():
+     """Load the ALLaM-7B model and tokenizer with error handling"""
+     model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
+     print(f"Loading model: {model_name}")
+
+     try:
+         # Load tokenizer with correct settings
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             use_fast=False
+         )
+
+         # Load model with appropriate settings for ALLaM
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             device_map="auto",
+         )
+
+         print("Model loaded successfully!")
+         return model, tokenizer
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         import traceback
+         traceback.print_exc()
+         raise Exception(f"Failed to load model: {e}")
+
+ def retrieve_context(query, vector_store, top_k=5):
+     """Retrieve most relevant document chunks for a given query"""
+     # Search the vector store using similarity search
+     results = vector_store.similarity_search_with_score(query, k=top_k)
+
+     # Format the retrieved contexts
+     contexts = []
+     for doc, score in results:
+         contexts.append({
+             "content": doc.page_content,
+             "source": doc.metadata.get("source", "Unknown"),
+             "relevance_score": score
+         })
+
+     return contexts
+
+ @spaces.GPU
+ def generate_response(query, contexts, model, tokenizer, language="auto"):
+     """Generate a response using retrieved contexts with ALLaM-specific formatting"""
+     # Auto-detect language if not specified
+     if language == "auto":
+         language = detect_language(query)
+
+     # Format the prompt based on language
+     if language == "arabic":
+         instruction = (
+             "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
+             "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
+         )
+     else:  # english
+         instruction = (
+             "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
+             "If you don't know the answer, honestly say you don't know."
+         )
+
+     # Combine retrieved contexts
+     context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
+
+     # Format the prompt for ALLaM instruction format
+     prompt = f"""<s>[INST] {instruction}
+
+ Context:
+ {context_text}
+
+ Question: {query} [/INST]</s>"""
+
+     try:
+         # Generate response with appropriate parameters for ALLaM
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         # Generate with appropriate parameters
+         outputs = model.generate(
+             inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             repetition_penalty=1.1
+         )
+
+         # Decode the response
+         full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

+         # Extract just the answer part (after the instruction)
+         response = full_output.split("[/INST]")[-1].strip()
+
+         # If response is empty for some reason, return the full output
+         if not response:
+             response = full_output

+         return response
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         # Fallback response
+         return "I apologize, but I encountered an error while generating a response."
+
+ # Assistant Class
+ class Vision2030Assistant:
+     def __init__(self, model, tokenizer, vector_store):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.vector_store = vector_store
+         self.conversation_history = []
+
+     def answer(self, user_query):
+         """Process a user query and return a response with sources"""
+         # Detect language
+         language = detect_language(user_query)
+
+         # Add user query to conversation history
+         self.conversation_history.append({"role": "user", "content": user_query})
+
+         # Get the full conversation context
+         conversation_context = "\n".join([
+             f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
+             for msg in self.conversation_history[-6:]  # Keep last 3 turns (6 messages)
+         ])
+
+         # Enhance query with conversation context for better retrieval
+         enhanced_query = f"{conversation_context}\n{user_query}"
+
+         # Retrieve relevant contexts
+         contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)
+
+         # Generate response
+         response = generate_response(user_query, contexts, self.model, self.tokenizer, language)
+
+         # Add response to conversation history
+         self.conversation_history.append({"role": "assistant", "content": response})
+
+         # Also return sources for transparency
+         sources = [ctx.get("source", "Unknown") for ctx in contexts]
+         unique_sources = list(set(sources))
+
+         return response, unique_sources, contexts
+
+     def reset_conversation(self):
+         """Reset the conversation history"""
+         self.conversation_history = []
+         return "Conversation has been reset."
+
+ # Comprehensive evaluation dataset
+ comprehensive_evaluation_data = [
+     # === Overview ===
+     {
+         "query": "ما هي رؤية السعودية 2030؟",
+         "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
+         "category": "overview",
+         "language": "arabic"
+     },
+     {
+         "query": "What is Saudi Vision 2030?",
+         "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
+         "category": "overview",
+         "language": "english"
+     },
+
+     # === Economic Goals ===
+     {
+         "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
+         "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
+         "category": "economic",
+         "language": "arabic"
+     },
+     {
+         "query": "What are the economic goals of Vision 2030?",
+         "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, reducing unemployment from 11.6% to 7%.",
+         "category": "economic",
+         "language": "english"
+     },
+
+     # === Social Goals ===
+     {
+         "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
+         "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
+         "category": "social",
+         "language": "arabic"
+     },
+     {
+         "query": "How does Vision 2030 aim to improve quality of life?",
+         "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
+         "category": "social",
+         "language": "english"
+     }
+ ]
+
+ # Gradio Interface
+ def initialize_system():
+     """Initialize the Vision 2030 Assistant system"""
+     # Define paths for PDF files in the root directory
+     pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]
+
+     # Process PDFs and create vector store
+     vector_store_dir = "vector_stores"
+     os.makedirs(vector_store_dir, exist_ok=True)
+
+     if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
+         print("Loading existing vector store...")
+         embedding_function = HuggingFaceEmbeddings(
+             model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+         )
+         vector_store = FAISS.load_local(vector_store_dir, embedding_function)
+     else:
+         print("Creating new vector store...")
+         documents = simple_process_pdfs(pdf_files)
+         if not documents:
+             raise ValueError("No documents were processed successfully. Cannot continue.")
+         vector_store = create_vector_store(documents)
+         vector_store.save_local(vector_store_dir)
+
+     # Load model and tokenizer
+     model, tokenizer = load_model_and_tokenizer()
+
+     # Initialize assistant
+     assistant = Vision2030Assistant(model, tokenizer, vector_store)
+
+     return assistant
+
+ def evaluate_response(query, response, reference):
+     """Evaluate a single response against a reference"""
+     # Calculate metrics
+     rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+     rouge_scores = rouge.score(response, reference)
+
+     bleu_scores = calculate_bleu(response, reference)
+     meteor = calculate_meteor(response, reference)
+     word_metrics = calculate_f1_precision_recall(response, reference)
+
+     # Format results
+     evaluation_results = {
+         "ROUGE-1": f"{rouge_scores['rouge1'].fmeasure:.4f}",
+         "ROUGE-2": f"{rouge_scores['rouge2'].fmeasure:.4f}",
+         "ROUGE-L": f"{rouge_scores['rougeL'].fmeasure:.4f}",
+         "BLEU-1": f"{bleu_scores['bleu_1']:.4f}",
+         "BLEU-4": f"{bleu_scores['bleu_4']:.4f}",
+         "METEOR": f"{meteor:.4f}",
+         "Word Precision": f"{word_metrics['precision']:.4f}",
+         "Word Recall": f"{word_metrics['recall']:.4f}",
+         "Word F1": f"{word_metrics['f1']:.4f}"
+     }
+
+     return evaluation_results
+
+ @spaces.GPU
+ def run_evaluation_on_sample(assistant, sample_index=0):
+     """Run evaluation on a selected sample from the evaluation dataset"""
+     if sample_index < 0 or sample_index >= len(comprehensive_evaluation_data):
+         return "Invalid sample index", "", "", {}
+
+     # Get the sample
+     sample = comprehensive_evaluation_data[sample_index]
+     query = sample["query"]
+     reference = sample["reference"]
+     category = sample["category"]
+     language = sample["language"]
+
+     # Reset conversation and get response
+     assistant.reset_conversation()
+     response, sources, contexts = assistant.answer(query)
+
+     # Evaluate response
+     evaluation_results = evaluate_response(query, response, reference)
+
+     return query, response, reference, evaluation_results, sources, category, language
+
+ def qualitative_evaluation_interface(assistant=None):
+     """Create a Gradio interface for qualitative evaluation"""
+
+     # If assistant is None, create a simplified interface
+     if assistant is None:
+         with gr.Blocks(title="Vision 2030 Assistant - Initialization Error") as interface:
+             gr.Markdown("# Vision 2030 Assistant - Initialization Error")
+             gr.Markdown("There was an error initializing the assistant. Please check the logs for details.")
+             gr.Textbox(label="Status", value="System initialization failed")
+         return interface
+
+     sample_options = [f"{i+1}. {item['query'][:50]}..." for i, item in enumerate(comprehensive_evaluation_data)]
+
+     with gr.Blocks(title="Vision 2030 Assistant - Qualitative Evaluation") as interface:
+         gr.Markdown("# Vision 2030 Assistant - Qualitative Evaluation")
+         gr.Markdown("This interface allows you to evaluate the Vision 2030 Assistant on predefined samples or your own queries.")
+
+         with gr.Tab("Sample Evaluation"):
+             gr.Markdown("### Evaluate the assistant on predefined samples")
+
+             sample_dropdown = gr.Dropdown(
+                 choices=sample_options,
+                 label="Select a sample query",
+                 value=sample_options[0] if sample_options else None
+             )
+
+             eval_button = gr.Button("Evaluate Sample")
+
+             with gr.Row():
+                 with gr.Column():
+                     sample_query = gr.Textbox(label="Query")
+                     sample_category = gr.Textbox(label="Category")
+                     sample_language = gr.Textbox(label="Language")
+
+                 with gr.Column():
+                     sample_response = gr.Textbox(label="Assistant Response")
+                     sample_reference = gr.Textbox(label="Reference Answer")
+                     sample_sources = gr.Textbox(label="Sources Used")
+
+             with gr.Row():
+                 metrics_display = gr.JSON(label="Evaluation Metrics")
+
+         with gr.Tab("Custom Evaluation"):
+             gr.Markdown("### Evaluate the assistant on your own query")
+
+             custom_query = gr.Textbox(
+                 lines=3,
+                 placeholder="Enter your question about Saudi Vision 2030...",
+                 label="Your Query"
+             )
+
+             custom_reference = gr.Textbox(
+                 lines=3,
+                 placeholder="Enter a reference answer (optional)...",
+                 label="Reference Answer (Optional)"
+             )
+
+             custom_eval_button = gr.Button("Get Response and Evaluate")
+
+             custom_response = gr.Textbox(label="Assistant Response")
+             custom_sources = gr.Textbox(label="Sources Used")
+
+             custom_metrics = gr.JSON(
+                 label="Evaluation Metrics (if reference provided)",
+                 visible=True
+             )
+
+         with gr.Tab("Conversation Mode"):
+             gr.Markdown("### Have a conversation with the Vision 2030 Assistant")
+
+             chatbot = gr.Chatbot(label="Conversation")
+
+             conv_input = gr.Textbox(
+                 placeholder="Ask about Saudi Vision 2030...",
+                 label="Your message"
+             )
+
+             with gr.Row():
+                 conv_button = gr.Button("Send")
+                 reset_button = gr.Button("Reset Conversation")
+
+             conv_sources = gr.Textbox(label="Sources Used")
+
+         # Sample evaluation event handlers
+         def handle_sample_selection(selection):
+             if not selection:
+                 return "", "", "", "", "", "", ""
+
+             # Extract index from the selection string
+             try:
+                 index = int(selection.split(".")[0]) - 1
+                 query, response, reference, metrics, sources, category, language = run_evaluation_on_sample(assistant, index)
+                 sources_str = ", ".join(sources)
+                 return query, response, reference, metrics, sources_str, category, language
+             except Exception as e:
+                 print(f"Error in handle_sample_selection: {e}")
+                 import traceback
+                 traceback.print_exc()
+                 return f"Error processing selection: {e}", "", "", {}, "", "", ""
+
+         eval_button.click(
+             handle_sample_selection,
+             inputs=[sample_dropdown],
+             outputs=[sample_query, sample_response, sample_reference, metrics_display,
+                      sample_sources, sample_category, sample_language]
+         )
+
+         sample_dropdown.change(
+             handle_sample_selection,
+             inputs=[sample_dropdown],
+             outputs=[sample_query, sample_response, sample_reference, metrics_display,
+                      sample_sources, sample_category, sample_language]
+         )
+
+         # Custom evaluation event handlers
+         @spaces.GPU
+         def handle_custom_evaluation(query, reference):
+             if not query:
+                 return "Please enter a query", "", {}
+
+             # Reset conversation to ensure clean state
+             assistant.reset_conversation()
+
+             # Get response
+             response, sources, _ = assistant.answer(query)
+             sources_str = ", ".join(sources)
+
+             # Evaluate if reference is provided
+             metrics = {}
+             if reference:
+                 metrics = evaluate_response(query, response, reference)
+
+             return response, sources_str, metrics
+
+         custom_eval_button.click(
+             handle_custom_evaluation,
+             inputs=[custom_query, custom_reference],
+             outputs=[custom_response, custom_sources, custom_metrics]
+         )
+
+         # Conversation mode event handlers
+         @spaces.GPU
+         def handle_conversation(message, history):
+             if not message:
+                 return history, "", ""
+
+             # Get response
+             response, sources, _ = assistant.answer(message)
+             sources_str = ", ".join(sources)
+
+             # Update history
+             history = history + [[message, response]]
+
+             return history, "", sources_str
+
+         def reset_conv():
+             result = assistant.reset_conversation()
+             return [], result, ""
+
+         conv_button.click(
+             handle_conversation,
+             inputs=[conv_input, chatbot],
+             outputs=[chatbot, conv_input, conv_sources]
+         )
+
+         reset_button.click(
+             reset_conv,
+             inputs=[],
+             outputs=[chatbot, conv_input, conv_sources]
+         )
+
+     return interface
+
+ # Main function to run in Hugging Face Space
+ def main():
+     # Start with a loading interface
+     with gr.Blocks(title="Vision 2030 Assistant - Loading") as loading_interface:
+         gr.Markdown("# Vision 2030 Assistant")
+         gr.Markdown("System is initializing. This may take a few minutes...")
+         loading_status = gr.Textbox(value="Loading system...", label="Status")
+
+     interface = loading_interface.queue()
+
+     # Initialize the system
+     try:
+         print("Starting system initialization...")
+         assistant = initialize_system()
+
+         print("Creating interface...")
+         full_interface = qualitative_evaluation_interface(assistant)
+
+         print("System ready!")
+         # Will replace the loading interface
+         return full_interface
+
+     except Exception as e:
+         print(f"Error during initialization: {e}")
+         import traceback
+         traceback.print_exc()
+
+         # Create a simple error interface
+         with gr.Blocks(title="Vision 2030 Assistant - Error") as error_interface:
+             gr.Markdown("# Vision 2030 Assistant - Initialization Error")
+             gr.Markdown("There was an error initializing the assistant.")
+
+             # Display error details
+             gr.Textbox(
+                 value=f"Error: {str(e)}",
+                 label="Error Details",
+                 lines=5
+             )
+
+             # Show potential solutions
+             gr.Markdown("## Potential Solutions")
+             gr.Markdown("""
+             1. Check that all dependencies are installed:
+                - sentencepiece
+                - accelerate
+                - transformers
+                - langchain and langchain-community
+
+             2. Verify PDF files are accessible and in the correct location
+
+             3. Check GPU memory is sufficient for loading the model
+             """)
+
+             # Add a button to check system
+             def check_system():
                  results = []

+                 # Check dependencies
+                 for lib in ["torch", "transformers", "sentencepiece", "accelerate"]:
                      try:
+                         module = __import__(lib)
                          if hasattr(module, "__version__"):
+                             results.append(f"✓ {lib}: {module.__version__}")
                          else:
+                             results.append(f"✓ {lib}: Installed")
                      except ImportError:
+                         results.append(f"✗ {lib}: Not installed")

+                 # Check GPU
                  try:
                      import torch
                      results.append(f"CUDA available: {torch.cuda.is_available()}")
                      if torch.cuda.is_available():
+                         results.append(f"GPU: {torch.cuda.get_device_name(0)}")
+                         results.append(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+                 except:
+                     results.append("Could not check GPU status")

+                 # Check PDFs
+                 for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
+                     if os.path.exists(pdf_file):
+                         size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
+                         results.append(f"{pdf_file}: Found ({size:.2f} MB)")
+                     else:
+                         results.append(f"{pdf_file}: Not found")

                  return "\n".join(results)

+             check_btn = gr.Button("Run System Check")
+             system_status = gr.Textbox(label="System Status", lines=15)
+             check_btn.click(check_system, inputs=[], outputs=[system_status])
+
+         return error_interface

  if __name__ == "__main__":
+     demo = main()
+     demo.launch()
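
Note: the pure-Python helpers this commit introduces (safe_tokenize, detect_language, and the BLEU/METEOR/word-F1 metrics) have no model, GPU, or vector-store dependency, so they can be smoke-tested outside the Space. A minimal sketch, assuming this revision is saved as app.py on the import path and the packages imported at its top are installed; the example strings below are hypothetical:

    # Hypothetical smoke test for this commit's metric helpers (not part of the commit).
    from app import safe_tokenize, detect_language, calculate_bleu, calculate_f1_precision_recall

    prediction = "Vision 2030 aims to diversify the Saudi economy."
    reference = "Saudi Vision 2030 is a plan to diversify the Saudi economy."

    print(detect_language(prediction))            # -> "english"
    print(safe_tokenize("What is Vision 2030?"))  # -> ['what', 'is', 'vision', '2030', '?']
    print(calculate_bleu(prediction, reference))  # -> dict with bleu_1, bleu_2, bleu_4
    print(calculate_f1_precision_recall(prediction, reference))  # -> dict with precision, recall, f1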