abdull4h committed · verified
Commit a8285e4 · 1 Parent(s): 1f683db

Update app.py

Files changed (1):
  1. app.py +341 -643

app.py CHANGED
@@ -1,42 +1,15 @@
  import os
  import re
  import json
- import torch
- import numpy as np
- import pandas as pd
  from tqdm import tqdm
  from pathlib import Path
  import spaces
-
- # PDF processing
- import PyPDF2
-
- # LLM and embeddings
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from sentence_transformers import SentenceTransformer
-
- # RAG components
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import FAISS
- from langchain.schema import Document
- from langchain.embeddings import HuggingFaceEmbeddings
-
- # Arabic text processing
- import arabic_reshaper
- from bidi.algorithm import get_display
-
- # Evaluation
- from rouge_score import rouge_scorer
- import sacrebleu
- from sklearn.metrics import accuracy_score, precision_recall_fscore_support
- import matplotlib.pyplot as plt
- import seaborn as sns
- from collections import defaultdict
-
- # Gradio for the interface
  import gradio as gr

- # Helper functions
  def safe_tokenize(text):
      """Pure regex tokenizer with no NLTK dependency"""
      if not text:
@@ -53,315 +26,6 @@ def detect_language(text):
      is_arabic = len(arabic_chars) > len(text) * 0.5
      return "arabic" if is_arabic else "english"

- # Evaluation metrics
- def calculate_bleu(prediction, reference):
-     """Calculate BLEU score without any NLTK dependency"""
-     # Tokenize texts using our own tokenizer
-     pred_tokens = safe_tokenize(prediction.lower())
-     ref_tokens = [safe_tokenize(reference.lower())]
-
-     # If either is empty, return 0
-     if not pred_tokens or not ref_tokens[0]:
-         return {"bleu_1": 0, "bleu_2": 0, "bleu_4": 0}
-
-     # Get n-grams function
-     def get_ngrams(tokens, n):
-         return [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
-
-     # Calculate precision for each n-gram level
-     precisions = []
-     for n in range(1, 5):  # 1-gram to 4-gram
-         if len(pred_tokens) < n:
-             precisions.append(0)
-             continue
-
-         pred_ngrams = get_ngrams(pred_tokens, n)
-         ref_ngrams = get_ngrams(ref_tokens[0], n)
-
-         # Count matches
-         matches = sum(1 for ng in pred_ngrams if ng in ref_ngrams)
-
-         # Calculate precision
-         if pred_ngrams:
-             precisions.append(matches / len(pred_ngrams))
-         else:
-             precisions.append(0)
-
-     # Return BLEU scores
-     return {
-         "bleu_1": precisions[0],
-         "bleu_2": (precisions[0] * precisions[1]) ** 0.5 if len(precisions) > 1 else 0,
-         "bleu_4": (precisions[0] * precisions[1] * precisions[2] * precisions[3]) ** 0.25 if len(precisions) > 3 else 0
-     }
-
- def calculate_meteor(prediction, reference):
-     """Simple word overlap metric as METEOR alternative"""
-     # Tokenize with our custom tokenizer
-     pred_tokens = set(safe_tokenize(prediction.lower()))
-     ref_tokens = set(safe_tokenize(reference.lower()))
-
-     # Calculate Jaccard similarity as METEOR alternative
-     if not pred_tokens or not ref_tokens:
-         return 0
-
-     intersection = len(pred_tokens.intersection(ref_tokens))
-     union = len(pred_tokens.union(ref_tokens))
-
-     return intersection / union if union > 0 else 0
-
- def calculate_f1_precision_recall(prediction, reference):
-     """Calculate word-level F1, precision, and recall with custom tokenizer"""
-     # Tokenize with our custom tokenizer
-     pred_tokens = set(safe_tokenize(prediction.lower()))
-     ref_tokens = set(safe_tokenize(reference.lower()))
-
-     # Calculate overlap
-     common = pred_tokens.intersection(ref_tokens)
-
-     # Calculate precision, recall, F1
-     precision = len(common) / len(pred_tokens) if pred_tokens else 0
-     recall = len(common) / len(ref_tokens) if ref_tokens else 0
-     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
-
-     return {'precision': precision, 'recall': recall, 'f1': f1}
-
- def evaluate_retrieval_quality(contexts, query, language):
-     """Evaluate the quality of retrieved contexts"""
-     # This is a placeholder implementation
-     return {
-         'language_match_ratio': 1.0,
-         'source_diversity': len(set([ctx.get('source', '') for ctx in contexts])) / max(1, len(contexts)),
-         'mrr': 1.0
-     }
-
- # PDF Processing and Vector Store
- def simple_process_pdfs(pdf_paths):
-     """Process PDF documents and return document objects"""
-     documents = []
-
-     print(f"Processing PDFs: {pdf_paths}")
-
-     for pdf_path in pdf_paths:
-         try:
-             if not os.path.exists(pdf_path):
-                 print(f"Warning: {pdf_path} does not exist")
-                 continue
-
-             print(f"Processing {pdf_path}...")
-             text = ""
-             with open(pdf_path, 'rb') as file:
-                 reader = PyPDF2.PdfReader(file)
-                 for page in reader.pages:
-                     page_text = page.extract_text()
-                     if page_text:  # If we got text from this page
-                         text += page_text + "\n\n"
-
-             if text.strip():  # If we got some text
-                 doc = Document(
-                     page_content=text,
-                     metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
-                 )
-                 documents.append(doc)
-                 print(f"Successfully processed: {pdf_path}")
-             else:
-                 print(f"Warning: No text extracted from {pdf_path}")
-         except Exception as e:
-             print(f"Error processing {pdf_path}: {e}")
-             import traceback
-             traceback.print_exc()
-
-     print(f"Processed {len(documents)} PDF documents")
-     return documents
-
- def create_vector_store(documents):
-     """Split documents into chunks and create a FAISS vector store"""
-     # Text splitter for breaking documents into chunks
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=500,
-         chunk_overlap=50,
-         separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
-     )
-
-     # Split documents into chunks
-     chunks = []
-     for doc in documents:
-         doc_chunks = text_splitter.split_text(doc.page_content)
-         # Preserve metadata for each chunk
-         chunks.extend([
-             Document(page_content=chunk, metadata=doc.metadata)
-             for chunk in doc_chunks
-         ])
-
-     print(f"Created {len(chunks)} chunks from {len(documents)} documents")
-
-     # Create a proper embedding function for LangChain
-     embedding_function = HuggingFaceEmbeddings(
-         model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-     )
-
-     # Create FAISS index
-     vector_store = FAISS.from_documents(
-         chunks,
-         embedding_function
-     )
-
-     return vector_store
-
- # Model Loading and RAG System
- @spaces.GPU
- def load_model_and_tokenizer():
-     """Load the ALLaM-7B model and tokenizer with error handling"""
-     model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
-     print(f"Loading model: {model_name}")
-
-     try:
-         # Load tokenizer with correct settings
-         tokenizer = AutoTokenizer.from_pretrained(
-             model_name,
-             trust_remote_code=True,
-             use_fast=False
-         )
-
-         # Load model with appropriate settings for ALLaM
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.bfloat16,
-             trust_remote_code=True,
-             device_map="auto",
-         )
-
-         print("Model loaded successfully!")
-         return model, tokenizer
-
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         import traceback
-         traceback.print_exc()
-         raise Exception(f"Failed to load model: {e}")
-
- def retrieve_context(query, vector_store, top_k=5):
-     """Retrieve most relevant document chunks for a given query"""
-     # Search the vector store using similarity search
-     results = vector_store.similarity_search_with_score(query, k=top_k)
-
-     # Format the retrieved contexts
-     contexts = []
-     for doc, score in results:
-         contexts.append({
-             "content": doc.page_content,
-             "source": doc.metadata.get("source", "Unknown"),
-             "relevance_score": score
-         })
-
-     return contexts
-
- @spaces.GPU
- def generate_response(query, contexts, model, tokenizer, language="auto"):
-     """Generate a response using retrieved contexts with ALLaM-specific formatting"""
-     # Auto-detect language if not specified
-     if language == "auto":
-         language = detect_language(query)
-
-     # Format the prompt based on language
-     if language == "arabic":
-         instruction = (
-             "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
-             "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
-         )
-     else:  # english
-         instruction = (
-             "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
-             "If you don't know the answer, honestly say you don't know."
-         )
-
-     # Combine retrieved contexts
-     context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
-
-     # Format the prompt for ALLaM instruction format
-     prompt = f"""<s>[INST] {instruction}
-
- Context:
- {context_text}
-
- Question: {query} [/INST]</s>"""
-
-     try:
-         # Generate response with appropriate parameters for ALLaM
-         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-         # Generate with appropriate parameters
-         outputs = model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=512,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True,
-             repetition_penalty=1.1
-         )
-
-         # Decode the response
-         full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         # Extract just the answer part (after the instruction)
-         response = full_output.split("[/INST]")[-1].strip()
-
-         # If response is empty for some reason, return the full output
-         if not response:
-             response = full_output
-
-         return response
-
-     except Exception as e:
-         print(f"Error during generation: {e}")
-         # Fallback response
-         return "I apologize, but I encountered an error while generating a response."
-
- # Assistant Class
- class Vision2030Assistant:
-     def __init__(self, model, tokenizer, vector_store):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.vector_store = vector_store
-         self.conversation_history = []
-
-     def answer(self, user_query):
-         """Process a user query and return a response with sources"""
-         # Detect language
-         language = detect_language(user_query)
-
-         # Add user query to conversation history
-         self.conversation_history.append({"role": "user", "content": user_query})
-
-         # Get the full conversation context
-         conversation_context = "\n".join([
-             f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
-             for msg in self.conversation_history[-6:]  # Keep last 3 turns (6 messages)
-         ])
-
-         # Enhance query with conversation context for better retrieval
-         enhanced_query = f"{conversation_context}\n{user_query}"
-
-         # Retrieve relevant contexts
-         contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)
-
-         # Generate response
-         response = generate_response(user_query, contexts, self.model, self.tokenizer, language)
-
-         # Add response to conversation history
-         self.conversation_history.append({"role": "assistant", "content": response})
-
-         # Also return sources for transparency
-         sources = [ctx.get("source", "Unknown") for ctx in contexts]
-         unique_sources = list(set(sources))
-
-         return response, unique_sources, contexts
-
-     def reset_conversation(self):
-         """Reset the conversation history"""
-         self.conversation_history = []
-         return "Conversation has been reset."
-
  # Comprehensive evaluation dataset
  comprehensive_evaluation_data = [
      # === Overview ===
@@ -407,352 +71,386 @@ comprehensive_evaluation_data = [
      }
  ]

- # Gradio Interface
- def initialize_system():
-     """Initialize the Vision 2030 Assistant system"""
-     # Define paths for PDF files in the root directory
-     pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]
-
-     # Process PDFs and create vector store
-     vector_store_dir = "vector_stores"
-     os.makedirs(vector_store_dir, exist_ok=True)
-
-     if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
-         print("Loading existing vector store...")
-         embedding_function = HuggingFaceEmbeddings(
-             model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-         )
-         vector_store = FAISS.load_local(vector_store_dir, embedding_function)
-     else:
-         print("Creating new vector store...")
-         documents = simple_process_pdfs(pdf_files)
-         if not documents:
-             raise ValueError("No documents were processed successfully. Cannot continue.")
-         vector_store = create_vector_store(documents)
-         vector_store.save_local(vector_store_dir)
-
-     # Load model and tokenizer
-     model, tokenizer = load_model_and_tokenizer()
-
-     # Initialize assistant
-     assistant = Vision2030Assistant(model, tokenizer, vector_store)
-
-     return assistant
-
- def evaluate_response(query, response, reference):
-     """Evaluate a single response against a reference"""
-     # Calculate metrics
-     rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
-     rouge_scores = rouge.score(response, reference)
-
-     bleu_scores = calculate_bleu(response, reference)
-     meteor = calculate_meteor(response, reference)
-     word_metrics = calculate_f1_precision_recall(response, reference)
-
-     # Format results
-     evaluation_results = {
-         "ROUGE-1": f"{rouge_scores['rouge1'].fmeasure:.4f}",
-         "ROUGE-2": f"{rouge_scores['rouge2'].fmeasure:.4f}",
-         "ROUGE-L": f"{rouge_scores['rougeL'].fmeasure:.4f}",
-         "BLEU-1": f"{bleu_scores['bleu_1']:.4f}",
-         "BLEU-4": f"{bleu_scores['bleu_4']:.4f}",
-         "METEOR": f"{meteor:.4f}",
-         "Word Precision": f"{word_metrics['precision']:.4f}",
-         "Word Recall": f"{word_metrics['recall']:.4f}",
-         "Word F1": f"{word_metrics['f1']:.4f}"
-     }
-
-     return evaluation_results
-
- @spaces.GPU
- def run_evaluation_on_sample(assistant, sample_index=0):
-     """Run evaluation on a selected sample from the evaluation dataset"""
-     if sample_index < 0 or sample_index >= len(comprehensive_evaluation_data):
-         return "Invalid sample index", "", "", {}
-
-     # Get the sample
-     sample = comprehensive_evaluation_data[sample_index]
-     query = sample["query"]
-     reference = sample["reference"]
-     category = sample["category"]
-     language = sample["language"]
-
-     # Reset conversation and get response
-     assistant.reset_conversation()
-     response, sources, contexts = assistant.answer(query)
-
-     # Evaluate response
-     evaluation_results = evaluate_response(query, response, reference)
-
-     return query, response, reference, evaluation_results, sources, category, language
-
- def qualitative_evaluation_interface(assistant=None):
-     """Create a Gradio interface for qualitative evaluation"""
-
-     # If assistant is None, create a simplified interface
-     if assistant is None:
-         with gr.Blocks(title="Vision 2030 Assistant - Initialization Error") as interface:
-             gr.Markdown("# Vision 2030 Assistant - Initialization Error")
-             gr.Markdown("There was an error initializing the assistant. Please check the logs for details.")
-             gr.Textbox(label="Status", value="System initialization failed")
-         return interface
-
-     sample_options = [f"{i+1}. {item['query'][:50]}..." for i, item in enumerate(comprehensive_evaluation_data)]
-
-     with gr.Blocks(title="Vision 2030 Assistant - Qualitative Evaluation") as interface:
-         gr.Markdown("# Vision 2030 Assistant - Qualitative Evaluation")
-         gr.Markdown("This interface allows you to evaluate the Vision 2030 Assistant on predefined samples or your own queries.")

-         with gr.Tab("Sample Evaluation"):
-             gr.Markdown("### Evaluate the assistant on predefined samples")
-
-             sample_dropdown = gr.Dropdown(
-                 choices=sample_options,
-                 label="Select a sample query",
-                 value=sample_options[0] if sample_options else None
-             )
-
-             eval_button = gr.Button("Evaluate Sample")
-
-             with gr.Row():
-                 with gr.Column():
-                     sample_query = gr.Textbox(label="Query")
-                     sample_category = gr.Textbox(label="Category")
-                     sample_language = gr.Textbox(label="Language")
-
-                 with gr.Column():
-                     sample_response = gr.Textbox(label="Assistant Response")
-                     sample_reference = gr.Textbox(label="Reference Answer")
-                     sample_sources = gr.Textbox(label="Sources Used")

-             with gr.Row():
-                 metrics_display = gr.JSON(label="Evaluation Metrics")
-
-         with gr.Tab("Custom Evaluation"):
-             gr.Markdown("### Evaluate the assistant on your own query")
-
-             custom_query = gr.Textbox(
-                 lines=3,
-                 placeholder="Enter your question about Saudi Vision 2030...",
-                 label="Your Query"
              )

-             custom_reference = gr.Textbox(
-                 lines=3,
-                 placeholder="Enter a reference answer (optional)...",
-                 label="Reference Answer (Optional)"
              )

-             custom_eval_button = gr.Button("Get Response and Evaluate")
-
-             custom_response = gr.Textbox(label="Assistant Response")
-             custom_sources = gr.Textbox(label="Sources Used")
-
-             custom_metrics = gr.JSON(
-                 label="Evaluation Metrics (if reference provided)",
-                 visible=True
-             )
-
-         with gr.Tab("Conversation Mode"):
-             gr.Markdown("### Have a conversation with the Vision 2030 Assistant")
-
-             chatbot = gr.Chatbot(label="Conversation")
-
-             conv_input = gr.Textbox(
-                 placeholder="Ask about Saudi Vision 2030...",
-                 label="Your message"
              )

-             with gr.Row():
-                 conv_button = gr.Button("Send")
-                 reset_button = gr.Button("Reset Conversation")

-             conv_sources = gr.Textbox(label="Sources Used")
-
-         # Sample evaluation event handlers
-         def handle_sample_selection(selection):
-             if not selection:
-                 return "", "", "", "", "", "", ""
-
-             # Extract index from the selection string
-             try:
-                 index = int(selection.split(".")[0]) - 1
-                 query, response, reference, metrics, sources, category, language = run_evaluation_on_sample(assistant, index)
-                 sources_str = ", ".join(sources)
-                 return query, response, reference, metrics, sources_str, category, language
-             except Exception as e:
-                 print(f"Error in handle_sample_selection: {e}")
-                 import traceback
-                 traceback.print_exc()
-                 return f"Error processing selection: {e}", "", "", {}, "", "", ""
-
-         eval_button.click(
-             handle_sample_selection,
-             inputs=[sample_dropdown],
-             outputs=[sample_query, sample_response, sample_reference, metrics_display,
-                      sample_sources, sample_category, sample_language]
-         )
-
-         sample_dropdown.change(
-             handle_sample_selection,
-             inputs=[sample_dropdown],
-             outputs=[sample_query, sample_response, sample_reference, metrics_display,
-                      sample_sources, sample_category, sample_language]
-         )

-         # Custom evaluation event handlers
-         @spaces.GPU
-         def handle_custom_evaluation(query, reference):
-             if not query:
-                 return "Please enter a query", "", {}

-             # Reset conversation to ensure clean state
-             assistant.reset_conversation()

-             # Get response
-             response, sources, _ = assistant.answer(query)
-             sources_str = ", ".join(sources)

-             # Evaluate if reference is provided
-             metrics = {}
-             if reference:
-                 metrics = evaluate_response(query, response, reference)

-             return response, sources_str, metrics
-
-         custom_eval_button.click(
-             handle_custom_evaluation,
-             inputs=[custom_query, custom_reference],
-             outputs=[custom_response, custom_sources, custom_metrics]
-         )
-
-         # Conversation mode event handlers
-         @spaces.GPU
-         def handle_conversation(message, history):
-             if not message:
-                 return history, "", ""

-             # Get response
-             response, sources, _ = assistant.answer(message)
-             sources_str = ", ".join(sources)

-             # Update history
-             history = history + [[message, response]]

-             return history, "", sources_str
-
-         def reset_conv():
-             result = assistant.reset_conversation()
-             return [], result, ""
-
-         conv_button.click(
-             handle_conversation,
-             inputs=[conv_input, chatbot],
-             outputs=[chatbot, conv_input, conv_sources]
-         )
-
-         reset_button.click(
-             reset_conv,
-             inputs=[],
-             outputs=[chatbot, conv_input, conv_sources]
-         )
-
-     return interface

- # Main function to run in Hugging Face Space
  def main():
-     # Start with a loading interface
-     with gr.Blocks(title="Vision 2030 Assistant - Loading") as loading_interface:
-         gr.Markdown("# Vision 2030 Assistant")
-         gr.Markdown("System is initializing. This may take a few minutes...")
-         loading_status = gr.Textbox(value="Loading system...", label="Status")

-     interface = loading_interface.queue()
-
-     # Initialize the system
-     try:
-         print("Starting system initialization...")
-         assistant = initialize_system()
-
-         print("Creating interface...")
-         full_interface = qualitative_evaluation_interface(assistant)
-
-         print("System ready!")
-         # Will replace the loading interface
-         return full_interface

-     except Exception as e:
-         print(f"Error during initialization: {e}")
-         import traceback
-         traceback.print_exc()

-         # Create a simple error interface
-         with gr.Blocks(title="Vision 2030 Assistant - Error") as error_interface:
-             gr.Markdown("# Vision 2030 Assistant - Initialization Error")
-             gr.Markdown("There was an error initializing the assistant.")
-
-             # Display error details
-             gr.Textbox(
-                 value=f"Error: {str(e)}",
-                 label="Error Details",
-                 lines=5
-             )

-             # Show potential solutions
-             gr.Markdown("## Potential Solutions")
-             gr.Markdown("""
-             1. Check that all dependencies are installed:
-                - sentencepiece
-                - accelerate
-                - transformers
-                - langchain and langchain-community
-
-             2. Verify PDF files are accessible and in the correct location
-
-             3. Check GPU memory is sufficient for loading the model
-             """)

-             # Add a button to check system
-             def check_system():
-                 results = []
-
-                 # Check dependencies
-                 for lib in ["torch", "transformers", "sentencepiece", "accelerate"]:
-                     try:
-                         module = __import__(lib)
-                         if hasattr(module, "__version__"):
-                             results.append(f"✓ {lib}: {module.__version__}")
-                         else:
-                             results.append(f"✓ {lib}: Installed")
-                     except ImportError:
-                         results.append(f"✗ {lib}: Not installed")
-
-                 # Check GPU
-                 try:
-                     import torch
-                     results.append(f"CUDA available: {torch.cuda.is_available()}")
-                     if torch.cuda.is_available():
-                         results.append(f"GPU: {torch.cuda.get_device_name(0)}")
-                         results.append(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
-                 except:
-                     results.append("Could not check GPU status")
-
-                 # Check PDFs
                  for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
                      if os.path.exists(pdf_file):
                          size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
-                         results.append(f"{pdf_file}: Found ({size:.2f} MB)")
                      else:
-                         results.append(f"{pdf_file}: Not found")

-                 return "\n".join(results)

-             check_btn = gr.Button("Run System Check")
-             system_status = gr.Textbox(label="System Status", lines=15)
-             check_btn.click(check_system, inputs=[], outputs=[system_status])

-         return error_interface

  if __name__ == "__main__":
      demo = main()
      demo.launch()
  import os
  import re
  import json
  from tqdm import tqdm
  from pathlib import Path
  import spaces
  import gradio as gr

+ # WARNING: Don't import torch, CUDA, or other GPU-related modules at the top level.
+ # They must ONLY be imported inside functions decorated with @spaces.GPU.
+
+ # Helper functions that don't use GPU
  def safe_tokenize(text):
      """Pure regex tokenizer with no NLTK dependency"""
      if not text:

      is_arabic = len(arabic_chars) > len(text) * 0.5
      return "arabic" if is_arabic else "english"

  # Comprehensive evaluation dataset
  comprehensive_evaluation_data = [
      # === Overview ===

      }
  ]

+ # RAG Service class
+ class Vision2030Service:
+     def __init__(self):
+         self.initialized = False
+         self.model = None
+         self.tokenizer = None
+         self.vector_store = None
+         self.conversation_history = []
+
+     @spaces.GPU
+     def initialize(self):
+         """Initialize the system - ALL GPU operations must happen here"""
+         if self.initialized:
+             return True
+
+         try:
+             # Import all GPU-dependent libraries only inside this function
+             import torch
+             import PyPDF2
+             from transformers import AutoTokenizer, AutoModelForCausalLM
+             from sentence_transformers import SentenceTransformer
+             from langchain.text_splitter import RecursiveCharacterTextSplitter
+             from langchain_community.vectorstores import FAISS
+             from langchain.schema import Document
+             from langchain.embeddings import HuggingFaceEmbeddings
+
+             # Define paths for PDF files
+             pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]
+
+             # Process PDFs and create vector store
+             vector_store_dir = "vector_stores"
+             os.makedirs(vector_store_dir, exist_ok=True)
+
+             if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
+                 print("Loading existing vector store...")
+                 embedding_function = HuggingFaceEmbeddings(
+                     model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+                 )
+                 self.vector_store = FAISS.load_local(vector_store_dir, embedding_function)
+             else:
+                 print("Creating new vector store...")
+                 # Process PDFs
+                 documents = []
+                 for pdf_path in pdf_files:
+                     if not os.path.exists(pdf_path):
+                         print(f"Warning: {pdf_path} does not exist")
+                         continue
+
+                     print(f"Processing {pdf_path}...")
+                     text = ""
+                     with open(pdf_path, 'rb') as file:
+                         reader = PyPDF2.PdfReader(file)
+                         for page in reader.pages:
+                             page_text = page.extract_text()
+                             if page_text:
+                                 text += page_text + "\n\n"
+
+                     if text.strip():
+                         doc = Document(
+                             page_content=text,
+                             metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
+                         )
+                         documents.append(doc)
+
+                 if not documents:
+                     raise ValueError("No documents were processed successfully.")
+
+                 # Split into chunks
+                 text_splitter = RecursiveCharacterTextSplitter(
+                     chunk_size=500,
+                     chunk_overlap=50,
+                     separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
+                 )
+
+                 chunks = []
+                 for doc in documents:
+                     doc_chunks = text_splitter.split_text(doc.page_content)
+                     chunks.extend([
+                         Document(page_content=chunk, metadata=doc.metadata)
+                         for chunk in doc_chunks
+                     ])
+
+                 # Create vector store
+                 embedding_function = HuggingFaceEmbeddings(
+                     model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+                 )
+                 self.vector_store = FAISS.from_documents(chunks, embedding_function)
+                 self.vector_store.save_local(vector_store_dir)
+
+             # Load model
+             model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 model_name,
+                 trust_remote_code=True,
+                 use_fast=False
              )

+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.bfloat16,
+                 trust_remote_code=True,
+                 device_map="auto",
              )

+             self.initialized = True
+             return True
+
+         except Exception as e:
+             import traceback
+             print(f"Initialization error: {e}")
+             print(traceback.format_exc())
+             return False
+
+     @spaces.GPU
+     def retrieve_context(self, query, top_k=5):
+         """Retrieve contexts from vector store"""
+         # GPU work stays inside this @spaces.GPU function to avoid CUDA init in the main process
+
+         if not self.initialized:
+             return []
+
+         try:
+             results = self.vector_store.similarity_search_with_score(query, k=top_k)

+             contexts = []
+             for doc, score in results:
+                 contexts.append({
+                     "content": doc.page_content,
+                     "source": doc.metadata.get("source", "Unknown"),
+                     "relevance_score": score
+                 })

+             return contexts
+         except Exception as e:
+             print(f"Error retrieving context: {e}")
+             return []
+
+     @spaces.GPU
+     def generate_response(self, query, contexts, language="auto"):
+         """Generate response using the model"""
+         # Import must be inside the function to avoid CUDA init in the main process
+         import torch
+
+         if not self.initialized or self.model is None or self.tokenizer is None:
+             return "I'm still initializing. Please try again in a moment."

+         try:
+             # Auto-detect language if not specified
+             if language == "auto":
+                 language = detect_language(query)
+
+             # Format the prompt based on language
+             if language == "arabic":
+                 instruction = (
+                     "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
+                     "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
+                 )
+             else:  # english
+                 instruction = (
+                     "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
+                     "If you don't know the answer, honestly say you don't know."
+                 )
+
+             # Combine retrieved contexts
+             context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
+
+             # Format the prompt for ALLaM instruction format
+             prompt = f"""<s>[INST] {instruction}
+
+ Context:
+ {context_text}
+
+ Question: {query} [/INST]</s>"""

+             # Generate response
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

+             outputs = self.model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_new_tokens=512,
+                 temperature=0.7,
+                 top_p=0.9,
+                 do_sample=True,
+                 repetition_penalty=1.1
              )

+             # Decode the response
+             full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

+             # Extract just the answer part (after the instruction)
+             response = full_output.split("[/INST]")[-1].strip()
+
+             # If response is empty for some reason, return the full output
+             if not response:
+                 response = full_output
+
+             return response
+
+         except Exception as e:
+             import traceback
+             print(f"Error generating response: {e}")
+             print(traceback.format_exc())
+             return "Sorry, I encountered an error while generating a response."
+
+     @spaces.GPU
+     def answer_question(self, query):
+         """Process a user query and return a response with sources"""
+         if not self.initialized:
+             if not self.initialize():
+                 return "System initialization failed. Please check the logs.", []

+         try:
+             # Add user query to conversation history
+             self.conversation_history.append({"role": "user", "content": query})

+             # Get the full conversation context
+             conversation_context = "\n".join([
+                 f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
+                 for msg in self.conversation_history[-6:]  # Keep last 3 turns
+             ])

+             # Enhance query with conversation context
+             enhanced_query = f"{conversation_context}\n{query}"

+             # Retrieve relevant contexts
+             contexts = self.retrieve_context(enhanced_query, top_k=5)

+             # Generate response
+             response = self.generate_response(query, contexts)

+             # Add response to conversation history
+             self.conversation_history.append({"role": "assistant", "content": response})

+             # Get sources
+             sources = [ctx.get("source", "Unknown") for ctx in contexts]
+             unique_sources = list(set(sources))

+             return response, unique_sources
+         except Exception as e:
+             import traceback
+             print(f"Error answering question: {e}")
+             print(traceback.format_exc())
+             return f"Sorry, I encountered an error: {str(e)}", []
+
+     def reset_conversation(self):
+         """Reset the conversation history"""
+         self.conversation_history = []
+         return "Conversation has been reset."

+ # Main function with Gradio UI
  def main():
+     # Create the Vision 2030 service
+     service = Vision2030Service()

+     # Build the Gradio interface
+     with gr.Blocks(title="Vision 2030 Assistant") as demo:
+         gr.Markdown("# Vision 2030 Assistant")
+         gr.Markdown("Ask questions about Saudi Vision 2030 in English or Arabic")

+         with gr.Tab("Chat"):
+             chatbot = gr.Chatbot()
+             msg = gr.Textbox(label="Your question", placeholder="Ask about Vision 2030...")
+             clear = gr.Button("Clear History")
+
+             @spaces.GPU
+             def respond(message, history):
+                 if not message:
+                     return history, ""
+
+                 response, sources = service.answer_question(message)
+                 sources_text = ", ".join(sources) if sources else "No specific sources"
+
+                 # Format the response to include sources
+                 full_response = f"{response}\n\nSources: {sources_text}"
+
+                 return history + [[message, full_response]], ""
+
+             def reset_chat():
+                 service.reset_conversation()
+                 return [], "Conversation history has been reset."
+
+             msg.submit(respond, [msg, chatbot], [chatbot, msg])
+             clear.click(reset_chat, None, [chatbot, msg])

+         with gr.Tab("System Status"):
+             init_btn = gr.Button("Initialize System")
+             status_box = gr.Textbox(label="Status", value="System not initialized")

+             @spaces.GPU
+             def initialize_system():
+                 success = service.initialize()
+                 if success:
+                     return "System initialized successfully!"
+                 else:
+                     return "System initialization failed. Check logs for details."

+             init_btn.click(initialize_system, None, status_box)

+             # PDF check section
+             gr.Markdown("### PDF Status")
+             pdf_btn = gr.Button("Check PDF Files")
+             pdf_status = gr.Textbox(label="PDF Files")

+             def check_pdfs():
+                 result = []
                  for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
                      if os.path.exists(pdf_file):
                          size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
+                         result.append(f"{pdf_file}: Found ({size:.2f} MB)")
                      else:
+                         result.append(f"{pdf_file}: Not found")
+                 return "\n".join(result)
+
+             pdf_btn.click(check_pdfs, None, pdf_status)
+
+             # System check section
+             gr.Markdown("### Dependencies")
+             sys_btn = gr.Button("Check Dependencies")
+             sys_status = gr.Textbox(label="Dependencies Status")
+
+             @spaces.GPU
+             def check_dependencies():
+                 result = []
+
+                 # Safe imports inside a GPU-decorated function
+                 try:
+                     import torch
+                     result.append(f"✓ PyTorch: {torch.__version__}")
+                 except ImportError:
+                     result.append("✗ PyTorch: Not installed")
+
+                 try:
+                     import transformers
+                     result.append(f"✓ Transformers: {transformers.__version__}")
+                 except ImportError:
+                     result.append("✗ Transformers: Not installed")

+                 try:
+                     import sentencepiece
+                     result.append("✓ SentencePiece: Installed")
+                 except ImportError:
+                     result.append("✗ SentencePiece: Not installed")
+
+                 try:
+                     import accelerate
+                     result.append(f"✓ Accelerate: {accelerate.__version__}")
+                 except ImportError:
+                     result.append("✗ Accelerate: Not installed")
+
+                 try:
+                     import langchain
+                     result.append(f"✓ LangChain: {langchain.__version__}")
+                 except ImportError:
+                     result.append("✗ LangChain: Not installed")
+
+                 try:
+                     import langchain_community
+                     result.append(f"✓ LangChain Community: {langchain_community.__version__}")
+                 except ImportError:
+                     result.append("✗ LangChain Community: Not installed")
+
+                 return "\n".join(result)

+             sys_btn.click(check_dependencies, None, sys_status)

+         with gr.Tab("Sample Questions"):
+             gr.Markdown("### Sample Questions to Try")
+
+             sample_questions = []
+             for item in comprehensive_evaluation_data:
+                 sample_questions.append(item["query"])
+
+             questions_md = "\n".join([f"- {q}" for q in sample_questions])
+             gr.Markdown(questions_md)
+
+     return demo

  if __name__ == "__main__":
      demo = main()
+     demo.queue()
      demo.launch()
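
The structural idea behind this commit: on ZeroGPU Spaces the main process must stay CUDA-free, so torch, transformers, and every model call are deferred into functions decorated with @spaces.GPU, and the heavyweight load happens lazily on the first GPU call instead of at import time. A minimal, self-contained sketch of that pattern follows; the model ID is taken from the diff above, while the variable names and generation settings are illustrative, not the app's exact code.

    # Sketch of the deferred-import ZeroGPU pattern (assumes the spaces,
    # gradio, transformers, and torch packages are installed on a GPU Space).
    import spaces
    import gradio as gr

    _model = None
    _tokenizer = None

    @spaces.GPU
    def generate(prompt: str) -> str:
        """All GPU work lives here; torch is imported lazily so the
        main (CPU-only) process never initializes CUDA."""
        global _model, _tokenizer
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        if _model is None:
            # Model ID from the diff above; loaded on first call, then cached.
            name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
            _tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True, use_fast=False)
            _model = AutoModelForCausalLM.from_pretrained(
                name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
            )

        inputs = _tokenizer(prompt, return_tensors="pt").to(_model.device)
        outputs = _model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
        return _tokenizer.decode(outputs[0], skip_special_tokens=True)

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Prompt")
        out = gr.Textbox(label="Response")
        box.submit(generate, box, out)

    demo.queue()
    demo.launch()

Keeping module scope import-light also makes the Space start quickly: the UI comes up immediately, and the model cost is paid only when a decorated callback first runs.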