abdull4h committed on
Commit 8f83e1c · verified · 1 Parent(s): c8b0d13

Update app.py

Files changed (1)
  1. app.py +497 -444
app.py CHANGED
@@ -1,101 +1,226 @@
-# Force install sentencepiece
-import sys
-import subprocess
-
-def install_package(package):
-    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
-
-try:
-    import sentencepiece
-    print("SentencePiece is already installed")
-except ImportError:
-    print("Installing SentencePiece...")
-    install_package("sentencepiece==0.1.99")
-    print("SentencePiece installed successfully")
-
-# Import other required libraries
-import gradio as gr
 import os
 import re
 import torch
 import numpy as np
 from pathlib import Path
 import PyPDF2
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 from sentence_transformers import SentenceTransformer
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.schema import Document
 from langchain.embeddings import HuggingFaceEmbeddings
-import spaces
 
-# Global variables to store model state
-model = None
-tokenizer = None
-assistant = None
-model_type = "primary"  # Track if we're using primary or fallback model
 
-# Create the Vision 2030 Assistant class
-class Vision2030Assistant:
-    def __init__(self, model, tokenizer, vector_store, model_type="primary"):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.vector_store = vector_store
-        self.model_type = model_type
-        self.conversation_history = []
-
-    def answer(self, user_query):
-        # Detect language
-        language = detect_language(user_query)
-
-        # Add user query to conversation history
-        self.conversation_history.append({"role": "user", "content": user_query})
-
-        # Get the full conversation context
-        conversation_context = "\n".join([
-            f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
-            for msg in self.conversation_history[-6:]  # Keep last 3 turns (6 messages)
         ])
-
-        # Enhance query with conversation context for better retrieval
-        enhanced_query = f"{conversation_context}\n{user_query}"
-
-        # Retrieve relevant contexts
-        contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)
-
-        # Generate response based on model type
-        if self.model_type == "primary":
-            response = generate_response_primary(user_query, contexts, self.model, self.tokenizer, language)
-        else:
-            response = generate_response_fallback(user_query, contexts, self.model, self.tokenizer, language)
-
-        # Add response to conversation history
-        self.conversation_history.append({"role": "assistant", "content": response})
-
-        # Also return sources for transparency
-        sources = [ctx.get("source", "Unknown") for ctx in contexts]
-        unique_sources = list(set(sources))
-
-        # Format the response with sources
-        if unique_sources:
-            source_text = "\n\nSources: " + ", ".join([os.path.basename(src) for src in unique_sources])
-            response_with_sources = response + source_text
-        else:
-            response_with_sources = response
-
-        return response_with_sources
-
-    def reset_conversation(self):
-        """Reset the conversation history"""
-        self.conversation_history = []
-        return "Conversation has been reset."
-
-# Helper functions
-def detect_language(text):
-    """Detect if text is primarily Arabic or English"""
-    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
-    is_arabic = len(arabic_chars) > len(text) * 0.5
-    return "arabic" if is_arabic else "english"
 
 def retrieve_context(query, vector_store, top_k=5):
     """Retrieve most relevant document chunks for a given query"""
@@ -113,9 +238,8 @@ def retrieve_context(query, vector_store, top_k=5):
 
     return contexts
 
-@spaces.GPU
-def generate_response_primary(query, contexts, model, tokenizer, language="auto"):
-    """Generate a response using ALLaM model"""
     # Auto-detect language if not specified
     if language == "auto":
         language = detect_language(query)
@@ -175,403 +299,332 @@ Question: {query} [/INST]</s>"""
     # Fallback response
     return "I apologize, but I encountered an error while generating a response."
 
-@spaces.GPU
-def generate_response_fallback(query, contexts, model, tokenizer, language="auto"):
-    """Generate a response using the fallback model (BLOOM or mBART)"""
-    # Auto-detect language if not specified
-    if language == "auto":
-        language = detect_language(query)
-
-    # Format the prompt based on language
-    if language == "arabic":
-        system_prompt = (
-            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم السياق التالي للإجابة على السؤال: "
-        )
-    else:
-        system_prompt = (
-            "You are a virtual assistant for Saudi Vision 2030. Use the following context to answer the question: "
-        )
-
-    # Combine retrieved contexts
-    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
-
-    # Format prompt for fallback model (simpler format)
-    prompt = f"{system_prompt}\n\nContext:\n{context_text}\n\nQuestion: {query}\n\nAnswer:"
-
-    try:
-        # Generate with fallback model
-        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(model.device)
-
-        outputs = model.generate(
-            inputs.input_ids,
-            attention_mask=inputs.attention_mask,
-            max_length=inputs.input_ids.shape[1] + 512,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-        # For most models, this is how we extract the response
-        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-
-        # Cleanup and return
-        return response.strip()
-
-    except Exception as e:
-        print(f"Error during fallback generation: {e}")
-        return "I apologize, but I encountered an error while generating a response with the fallback model."
-
-def process_pdf_files(pdf_files):
-    """Process PDF files and create documents"""
-    documents = []
-
-    for pdf_file in pdf_files:
-        try:
-            # Save the uploaded file temporarily
-            temp_path = f"temp_{pdf_file.name}"
-            with open(temp_path, "wb") as f:
-                f.write(pdf_file.read())
-
-            # Extract text
-            text = ""
-            with open(temp_path, 'rb') as file:
-                reader = PyPDF2.PdfReader(file)
-                for page in reader.pages:
-                    page_text = page.extract_text()
-                    if page_text:
-                        text += page_text + "\n\n"
-
-            # Clean up
-            os.remove(temp_path)
-
-            if text.strip():  # If we got some text
-                doc = Document(
-                    page_content=text,
-                    metadata={"source": pdf_file.name, "filename": pdf_file.name}
-                )
-                documents.append(doc)
-                print(f"Successfully processed: {pdf_file.name}")
-            else:
-                print(f"Warning: No text extracted from {pdf_file.name}")
-        except Exception as e:
-            print(f"Error processing {pdf_file.name}: {e}")
-
-    print(f"Processed {len(documents)} PDF documents")
-    return documents
-
-def create_vector_store(documents):
-    """Create a vector store from documents"""
-    # Text splitter for breaking documents into chunks
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=500,
-        chunk_overlap=50,
-        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
-    )
-
-    # Split documents into chunks
-    chunks = []
-    for doc in documents:
-        doc_chunks = text_splitter.split_text(doc.page_content)
-        # Preserve metadata for each chunk
-        chunks.extend([
-            Document(page_content=chunk, metadata=doc.metadata)
-            for chunk in doc_chunks
         ])
-
-    print(f"Created {len(chunks)} chunks from {len(documents)} documents")
-
-    # Create embedding function
-    embedding_function = HuggingFaceEmbeddings(
-        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-    )
-
-    # Create FAISS index
-    vector_store = FAISS.from_documents(chunks, embedding_function)
-    return vector_store
-
-# Attempt to create mock documents if none are available yet
-def create_mock_documents():
-    """Create mock documents about Vision 2030"""
-    documents = []
-
-    # Sample content about Vision 2030 in both languages
-    samples = [
-        {
-            "content": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
-            "source": "vision2030_overview_ar.txt"
-        },
-        {
-            "content": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
-            "source": "vision2030_overview_en.txt"
-        },
-        {
-            "content": "تشمل الأهداف الاقتصادية لرؤية 2030 زيادة مساهمة القطاع الخاص من 40% إلى 65% من الناتج المحلي الإجمالي، ورفع نسبة الصادرات غير النفطية من 16% إلى 50% من الناتج المحلي الإجمالي غير النفطي، وخفض البطالة إلى 7%.",
-            "source": "economic_goals_ar.txt"
-        },
-        {
-            "content": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
-            "source": "economic_goals_en.txt"
-        },
-        {
-            "content": "تركز رؤية 2030 على زيادة مشاركة المرأة في سوق العمل من 22% إلى 30% بحلول عام 2030، مع توفير فرص متساوية في التعليم والعمل.",
-            "source": "women_empowerment_ar.txt"
-        },
-        {
-            "content": "Vision 2030 emphasizes increasing women's participation in the workforce from 22% to 30% by 2030, while providing equal opportunities in education and employment.",
-            "source": "women_empowerment_en.txt"
-        }
-    ]
-
-    # Create documents from samples
-    for sample in samples:
-        doc = Document(
-            page_content=sample["content"],
-            metadata={"source": sample["source"], "filename": sample["source"]}
-        )
-        documents.append(doc)
-
-    print(f"Created {len(documents)} mock documents")
-    return documents
-
-@spaces.GPU
-def load_primary_model():
-    """Load the ALLaM-7B model with error handling"""
-    global model, tokenizer, model_type
-
-    if model is not None and tokenizer is not None and model_type == "primary":
-        return "Primary model (ALLaM-7B) already loaded"
-
-    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
-    print(f"Loading primary model: {model_name}")
-
-    try:
-        # Try to import sentencepiece explicitly first
-        import sentencepiece as spm
-        print("SentencePiece imported successfully")
-
-        # First attempt with AutoTokenizer and explicit trust_remote_code
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-            use_fast=False
-        )
-
-        # Load model with appropriate settings for ALLaM
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-            device_map="auto",
-        )
-
-        model_type = "primary"
-        return "Primary model (ALLaM-7B) loaded successfully!"
-
-    except Exception as e:
-        error_msg = f"Primary model loading failed: {e}"
-        print(error_msg)
-        return error_msg
-
-@spaces.GPU
-def load_fallback_model():
-    """Load the fallback model (BLOOM-7B1) when ALLaM fails"""
-    global model, tokenizer, model_type
-
-    if model is not None and tokenizer is not None and model_type == "fallback":
-        return "Fallback model already loaded"
-
-    try:
-        print("Loading fallback model: BLOOM-7B1...")
-
-        # Use BLOOM model as fallback (it doesn't need SentencePiece)
-        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
-        model = AutoModelForCausalLM.from_pretrained(
-            "bigscience/bloom-7b1",
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            load_in_8bit=True  # Reduce memory usage
-        )
-
-        model_type = "fallback"
-        return "Fallback model (BLOOM-7B1) loaded successfully!"
-    except Exception as e:
-        return f"Fallback model loading failed: {e}"
-
-def load_mbart_model():
-    """Load mBART as a second fallback option"""
-    global model, tokenizer, model_type
-
-    try:
-        print("Loading mBART multilingual model...")
-
-        model_name = "facebook/mbart-large-50-many-to-many-mmt"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            load_in_8bit=True
-        )
-
-        model_type = "mbart"
-        return "mBART multilingual model loaded successfully!"
-    except Exception as e:
-        return f"mBART model loading failed: {e}"
-
-# Gradio Interface Functions
-def process_pdfs(pdf_files):
-    if not pdf_files:
-        return "No files uploaded. Please upload PDF documents about Vision 2030."
-
-    documents = process_pdf_files(pdf_files)
-
-    if not documents:
-        return "Failed to extract text from the uploaded PDFs."
-
-    global assistant, model, tokenizer
-
-    # Ensure model is loaded
-    if model is None or tokenizer is None:
-        return "Please load a model first (primary or fallback) before processing documents."
-
-    # Create vector store
-    vector_store = create_vector_store(documents)
 
     # Initialize assistant
-    assistant = Vision2030Assistant(model, tokenizer, vector_store, model_type)
 
-    return f"Successfully processed {len(documents)} documents. The assistant is ready to use!"
 
-def use_mock_documents():
-    """Use mock documents when no PDFs are available"""
-    documents = create_mock_documents()
-
-    global assistant, model, tokenizer
-
-    # Ensure model is loaded
-    if model is None or tokenizer is None:
-        return "Please load a model first (primary or fallback) before using mock documents."
-
-    # Create vector store
-    vector_store = create_vector_store(documents)
-
-    # Initialize assistant
-    assistant = Vision2030Assistant(model, tokenizer, vector_store, model_type)
-
-    return "Successfully initialized with mock Vision 2030 documents. The assistant is ready for testing!"
-
-@spaces.GPU
-def answer_query(message, history):
-    global assistant
-
-    if assistant is None:
-        return [(message, "Please load a model and process documents first (or use mock documents for testing).")]
-
-    response = assistant.answer(message)
-    history.append((message, response))
-    return history
 
 def reset_chat():
-    global assistant
-
-    if assistant is None:
-        return "No active conversation to reset."
-
-    reset_message = assistant.reset_conversation()
-    return reset_message
 
-def restart_factory():
-    return "Restarting the application... Please reload the page in a few seconds."
 
 # Create Gradio interface
-with gr.Blocks(title="Vision 2030 Virtual Assistant") as demo:
-    gr.Markdown("# Vision 2030 Virtual Assistant")
-    gr.Markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")
-
-    with gr.Tab("Setup"):
-        gr.Markdown("## Step 1: Load a Model")
-        with gr.Row():
-            with gr.Column():
-                primary_btn = gr.Button("Load ALLaM-7B Model (Primary)", variant="primary")
-                primary_output = gr.Textbox(label="Primary Model Status")
-                primary_btn.click(load_primary_model, inputs=[], outputs=primary_output)
-
-            with gr.Column():
-                fallback_btn = gr.Button("Load BLOOM-7B1 (Fallback)", variant="secondary")
-                fallback_output = gr.Textbox(label="Fallback Model Status")
-                fallback_btn.click(load_fallback_model, inputs=[], outputs=fallback_output)
-
-            with gr.Column():
-                mbart_btn = gr.Button("Load mBART (Alternative)", variant="secondary")
-                mbart_output = gr.Textbox(label="mBART Model Status")
-                mbart_btn.click(load_mbart_model, inputs=[], outputs=mbart_output)
-
-        gr.Markdown("## Step 2: Prepare Documents")
-        with gr.Row():
-            with gr.Column():
-                pdf_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDF Documents")
-                process_btn = gr.Button("Process Documents", variant="primary")
-                process_output = gr.Textbox(label="Processing Status")
-                process_btn.click(process_pdfs, inputs=[pdf_files], outputs=process_output)
-
-            with gr.Column():
-                mock_btn = gr.Button("Use Mock Documents (for testing)", variant="secondary")
-                mock_output = gr.Textbox(label="Mock Documents Status")
-                mock_btn.click(use_mock_documents, inputs=[], outputs=mock_output)
-
-        gr.Markdown("## Troubleshooting")
-        restart_btn = gr.Button("Restart Application", variant="secondary")
-        restart_output = gr.Textbox(label="Restart Status")
-        restart_btn.click(restart_factory, inputs=[], outputs=restart_output)
-        restart_btn.click(None, [], None, _js="() => {setTimeout(() => {location.reload()}, 5000)}")
 
-    with gr.Tab("Chat"):
-        chatbot = gr.Chatbot(label="Conversation", height=500)
-
         with gr.Row():
-            message = gr.Textbox(
-                label="Ask a question about Vision 2030 (in Arabic or English)",
-                placeholder="What are the main goals of Vision 2030?",
-                lines=2
-            )
-            submit_btn = gr.Button("Submit", variant="primary")
-
-        reset_btn = gr.Button("Reset Conversation")
-
-        gr.Markdown("### Example Questions")
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("**English Questions:**")
-                en_examples = gr.Examples(
-                    examples=[
-                        "What is Saudi Vision 2030?",
-                        "What are the economic goals of Vision 2030?",
-                        "How does Vision 2030 support women's empowerment?",
-                        "What environmental initiatives are part of Vision 2030?",
-                        "What is the role of the Public Investment Fund in Vision 2030?"
-                    ],
-                    inputs=message
-                )
-
-            with gr.Column():
-                gr.Markdown("**Arabic Questions:**")
-                ar_examples = gr.Examples(
-                    examples=[
-                        "ما هي رؤية السعودية 2030؟",
-                        "ما هي الأهداف الاقتصادية لرؤية 2030؟",
-                        "كيف تدعم رؤية 2030 تمكين المرأة السعودية؟",
-                        "ما هي مبادرات رؤية 2030 للحفاظ على البيئة؟",
-                        "ما هي استراتيجية صندوق الاستثمارات العامة في رؤية 2030؟"
-                    ],
-                    inputs=message
-                )
-
-        reset_output = gr.Textbox(label="Reset Status", visible=False)
-        submit_btn.click(answer_query, inputs=[message, chatbot], outputs=[chatbot])
-        message.submit(answer_query, inputs=[message, chatbot], outputs=[chatbot])
-        reset_btn.click(reset_chat, inputs=[], outputs=[reset_output])
-        reset_btn.click(lambda: None, inputs=[], outputs=[chatbot], postprocess=lambda: [])
 
-# Launch the app
-demo.launch()
 import os
 import re
 import torch
+import gradio as gr
 import numpy as np
 from pathlib import Path
+from tqdm import tqdm
+import json
+import datetime
+
+# PDF processing
 import PyPDF2
+
+# LLM and embeddings
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from sentence_transformers import SentenceTransformer
+
+# RAG components
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.schema import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 
+# Arabic text processing
+import arabic_reshaper
+from bidi.algorithm import get_display
+
+# Evaluation
+from rouge_score import rouge_scorer
+
+# Helper functions
+def detect_language(text):
+    """Detect if text is primarily Arabic or English"""
+    # Simple heuristic: count Arabic characters
+    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
+    is_arabic = len(arabic_chars) > len(text) * 0.5
+    return "arabic" if is_arabic else "english"
+
+def safe_tokenize(text):
+    """Pure regex tokenizer with no NLTK dependency"""
+    if not text:
+        return []
+    # Put spaces around punctuation
+    text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)
+    # Split on whitespace and filter out empty strings
+    return [token for token in re.split(r'\s+', text.lower()) if token]
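For reference, a quick illustration of what safe_tokenize produces (the input string is made up):

    >>> safe_tokenize("What is Vision 2030?")
    ['what', 'is', 'vision', '2030', '?']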
+
+# Evaluation metric functions
+def calculate_bleu(prediction, reference):
+    """Calculate BLEU score without any NLTK dependency"""
+    # Tokenize texts using our own tokenizer
+    pred_tokens = safe_tokenize(prediction.lower())
+    ref_tokens = [safe_tokenize(reference.lower())]
+
+    # If either is empty, return 0
+    if not pred_tokens or not ref_tokens[0]:
+        return {"bleu_1": 0, "bleu_2": 0, "bleu_4": 0}
+
+    # Get n-grams function
+    def get_ngrams(tokens, n):
+        return [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
+
+    # Calculate precision for each n-gram level
+    precisions = []
+    for n in range(1, 5):  # 1-gram to 4-gram
+        if len(pred_tokens) < n:
+            precisions.append(0)
+            continue
+
+        pred_ngrams = get_ngrams(pred_tokens, n)
+        ref_ngrams = get_ngrams(ref_tokens[0], n)
+
+        # Count matches
+        matches = sum(1 for ng in pred_ngrams if ng in ref_ngrams)
+
+        # Calculate precision
+        if pred_ngrams:
+            precisions.append(matches / len(pred_ngrams))
+        else:
+            precisions.append(0)
+
+    # Return BLEU scores
+    return {
+        "bleu_1": precisions[0],
+        "bleu_2": (precisions[0] * precisions[1]) ** 0.5 if len(precisions) > 1 else 0,
+        "bleu_4": (precisions[0] * precisions[1] * precisions[2] * precisions[3]) ** 0.25 if len(precisions) > 3 else 0
+    }
+
+def calculate_meteor(prediction, reference):
+    """Simple word-overlap metric as a METEOR alternative"""
+    # Tokenize with our custom tokenizer
+    pred_tokens = set(safe_tokenize(prediction.lower()))
+    ref_tokens = set(safe_tokenize(reference.lower()))
+
+    # Calculate Jaccard similarity as a METEOR alternative
+    if not pred_tokens or not ref_tokens:
+        return 0
+
+    intersection = len(pred_tokens.intersection(ref_tokens))
+    union = len(pred_tokens.union(ref_tokens))
+
+    return intersection / union if union > 0 else 0
+
+def calculate_f1_precision_recall(prediction, reference):
+    """Calculate word-level F1, precision, and recall with the custom tokenizer"""
+    # Tokenize with our custom tokenizer
+    pred_tokens = set(safe_tokenize(prediction.lower()))
+    ref_tokens = set(safe_tokenize(reference.lower()))
+
+    # Calculate overlap
+    common = pred_tokens.intersection(ref_tokens)
+
+    # Calculate precision, recall, F1
+    precision = len(common) / len(pred_tokens) if pred_tokens else 0
+    recall = len(common) / len(ref_tokens) if ref_tokens else 0
+    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
+
+    return {'precision': precision, 'recall': recall, 'f1': f1}
+
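All three metrics run on the same safe_tokenize output, so they can be compared on the same pair; a minimal sketch (the prediction/reference strings are illustrative):

    pred = "Vision 2030 diversifies the Saudi economy"
    ref = "Vision 2030 aims to diversify the Saudi economy"

    calculate_bleu(pred, ref)                 # dict with bleu_1, bleu_2, bleu_4 n-gram precisions
    calculate_meteor(pred, ref)               # Jaccard overlap of the two token sets, in [0, 1]
    calculate_f1_precision_recall(pred, ref)  # dict with word-level precision, recall, f1

Note that calculate_bleu is plain n-gram precision without a brevity penalty, and calculate_meteor is a set-overlap stand-in rather than the full METEOR algorithm.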
+# Load PDFs and create the vector store
+def process_pdfs(pdf_files):
+    """Process PDF documents and return document objects"""
+    documents = []
+
+    for pdf_path in pdf_files:
+        try:
+            text = ""
+            with open(pdf_path, 'rb') as file:
+                reader = PyPDF2.PdfReader(file)
+                for page in reader.pages:
+                    page_text = page.extract_text()
+                    if page_text:  # If we got text from this page
+                        text += page_text + "\n\n"
+
+            if text.strip():  # If we got some text
+                doc = Document(
+                    page_content=text,
+                    metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
+                )
+                documents.append(doc)
+                print(f"Successfully processed: {pdf_path}")
+            else:
+                print(f"Warning: No text extracted from {pdf_path}")
+        except Exception as e:
+            print(f"Error processing {pdf_path}: {e}")
+
+    print(f"Processed {len(documents)} PDF documents")
+    return documents
+
+def create_vector_store(documents):
+    """Split documents into chunks and create a FAISS vector store"""
+    # Text splitter for breaking documents into chunks
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,
+        chunk_overlap=50,
+        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
+    )
+
+    # Split documents into chunks
+    chunks = []
+    for doc in documents:
+        doc_chunks = text_splitter.split_text(doc.page_content)
+        # Preserve metadata for each chunk
+        chunks.extend([
+            Document(page_content=chunk, metadata=doc.metadata)
+            for chunk in doc_chunks
         ])
+
+    print(f"Created {len(chunks)} chunks from {len(documents)} documents")
+
+    # Create a proper embedding function for LangChain
+    embedding_function = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+
+    # Create FAISS index
+    vector_store = FAISS.from_documents(
+        chunks,
+        embedding_function
+    )
+
+    return vector_store
+
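A minimal end-to-end sketch of the ingestion path defined above; the PDF path is hypothetical, and similarity_search is the standard LangChain FAISS query method:

    docs = process_pdfs(["vision2030_docs/example.pdf"])  # hypothetical input file
    store = create_vector_store(docs)

    # Query the index directly to sanity-check retrieval
    for hit in store.similarity_search("economic goals of Vision 2030", k=3):
        print(hit.metadata["filename"], "->", hit.page_content[:80])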
+def load_model_and_tokenizer():
+    """Load the ALLaM-7B model and tokenizer with error handling"""
+    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
+    print(f"Loading model: {model_name}")
+
+    try:
+        # First attempt with AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            use_fast=False
+        )
+
+        # Load the model with appropriate settings for ALLaM
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,  # Use bfloat16 for better compatibility
+            trust_remote_code=True,
+            device_map="auto",
+        )
+
+        print("Model loaded successfully with AutoTokenizer!")
+
+    except Exception as e:
+        print(f"First loading attempt failed: {e}")
+        print("Trying alternative loading approach...")
+
+        # Try a specific tokenizer class if the first attempt fails
+        from transformers import LlamaTokenizer
+
+        tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+
+        print("Model loaded successfully with LlamaTokenizer!")
+
+    return model, tokenizer
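Downstream code feeds the returned pair into model.generate; a minimal sketch of that pattern (the prompt and sampling settings are illustrative, not the app's actual prompt format):

    model, tokenizer = load_model_and_tokenizer()
    inputs = tokenizer("What is Saudi Vision 2030?", return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,  # illustrative cap on the generated continuation
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    # Decode only the newly generated tokens, as the rest of this file does
    print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))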
 
 
 def retrieve_context(query, vector_store, top_k=5):
     """Retrieve most relevant document chunks for a given query"""
 
     return contexts
 
+def generate_response(query, contexts, model, tokenizer, language="auto"):
+    """Generate a response using retrieved contexts with ALLaM-specific formatting"""
     # Auto-detect language if not specified
     if language == "auto":
         language = detect_language(query)
 
     # Fallback response
     return "I apologize, but I encountered an error while generating a response."
 
+# Assistant class
+class Vision2030Assistant:
+    def __init__(self, model, tokenizer, vector_store):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.vector_store = vector_store
+        self.conversation_history = []
+
+    def answer(self, user_query):
+        """Process a user query and return a response with sources"""
+        # Detect language
+        language = detect_language(user_query)
+
+        # Add user query to conversation history
+        self.conversation_history.append({"role": "user", "content": user_query})
+
+        # Get the full conversation context
+        conversation_context = "\n".join([
+            f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
+            for msg in self.conversation_history[-6:]  # Keep the last 3 turns (6 messages)
         ])
+
+        # Enhance the query with conversation context for better retrieval
+        enhanced_query = f"{conversation_context}\n{user_query}"
+
+        # Retrieve relevant contexts
+        contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)
+
+        # Generate response
+        response = generate_response(user_query, contexts, self.model, self.tokenizer, language)
+
+        # Add response to conversation history
+        self.conversation_history.append({"role": "assistant", "content": response})
+
+        # Also return sources for transparency
+        sources = [ctx.get("source", "Unknown") for ctx in contexts]
+        unique_sources = list(set(sources))
+
+        return response, unique_sources, contexts
+
+    def reset_conversation(self):
+        """Reset the conversation history"""
+        self.conversation_history = []
+        return "Conversation has been reset."
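A minimal sketch of the class in use, assuming a model, tokenizer, and vector store have already been built as above (the query is illustrative):

    assistant = Vision2030Assistant(model, tokenizer, vector_store)
    response, sources, contexts = assistant.answer("What are the economic goals of Vision 2030?")
    print(response)
    print(sources)  # de-duplicated source paths, returned for transparency

    assistant.reset_conversation()  # clears the rolling 3-turn history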
 
+# Sample evaluation data (subset)
+sample_evaluation_data = [
+    {
+        "query": "ما هي رؤية السعودية 2030؟",
+        "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
+        "category": "overview",
+        "language": "arabic"
+    },
+    {
+        "query": "What is Saudi Vision 2030?",
+        "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
+        "category": "overview",
+        "language": "english"
+    },
+    {
+        "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
+        "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
+        "category": "economic",
+        "language": "arabic"
+    },
+    {
+        "query": "What are the economic goals of Vision 2030?",
+        "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
+        "category": "economic",
+        "language": "english"
+    },
+    {
+        "query": "How does Vision 2030 support small and medium enterprises (SMEs)?",
+        "reference": "Vision 2030 supports SMEs by increasing their GDP contribution, facilitating access to funding, and reducing regulatory obstacles.",
+        "category": "economic",
+        "language": "english"
+    }
+]
+
+# Global variables for storing state
+ASSISTANT = None
+MODEL = None
+TOKENIZER = None
+VECTOR_STORE = None
+PDF_PATHS = ["vision2030_docs/saudi_vision203.pdf", "vision2030_docs/saudi_vision2030_ar.pdf"]
+
+# Initialize the ROUGE scorer for evaluation
+rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+
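The rouge_score scorer returns a dict of named tuples keyed by metric name, which is why the code below reads rouge_scores['rouge1'].fmeasure; a minimal sketch (the strings are illustrative):

    scores = rouge_scorer_instance.score("the cat sat on the mat", "the cat sat")
    print(scores['rouge1'].precision, scores['rouge1'].recall, scores['rouge1'].fmeasure)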
+def initialize_system():
+    global MODEL, TOKENIZER, VECTOR_STORE, ASSISTANT
+
+    # Try to load from saved files first
+    if os.path.exists("data/vision2030_vector_store"):
+        print("Loading vector store from saved file...")
+        try:
+            embedding_function = HuggingFaceEmbeddings(
+                model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+            )
+            VECTOR_STORE = FAISS.load_local("data/vision2030_vector_store", embedding_function)
+            print("Vector store loaded successfully!")
+        except Exception as e:
+            print(f"Error loading vector store: {e}")
+            VECTOR_STORE = None
+
+    # If the vector store was not loaded, process the PDFs and create it
+    if VECTOR_STORE is None:
+        print("Processing PDF documents...")
+        vision2030_docs = process_pdfs(PDF_PATHS)
+
+        if not vision2030_docs:
+            return "Error: No documents were processed. Cannot continue."
+
+        print("Creating vector store...")
+        VECTOR_STORE = create_vector_store(vision2030_docs)
+
+        # Save the vector store for future use
+        os.makedirs("data", exist_ok=True)
+        VECTOR_STORE.save_local("data/vision2030_vector_store")
+        print("Vector store saved to data/vision2030_vector_store")
+
+    # Load model and tokenizer
+    print("Loading ALLaM-7B model...")
+    MODEL, TOKENIZER = load_model_and_tokenizer()
 
     # Initialize assistant
+    ASSISTANT = Vision2030Assistant(MODEL, TOKENIZER, VECTOR_STORE)
+    print("Vision 2030 Assistant initialized successfully!")
 
+    return "System initialized and ready!"
 
+def process_query(query, reference=None):
+    """Process a user query and return the response, with evaluation if a reference is provided"""
+    if ASSISTANT is None:
+        return "System not initialized. Please initialize first.", "", "", ""
+
+    # Process the query
+    response, sources, contexts = ASSISTANT.answer(query)
+
+    # Additional details
+    source_text = "\n".join([f"Source: {s}" for s in sources])
+    context_text = "\n\n".join([f"Context {i+1}: {ctx['content'][:200]}..." for i, ctx in enumerate(contexts)])
+
+    # Calculate metrics if a reference is provided
+    metrics_text = ""
+    if reference:
+        # ROUGE scores
+        rouge_scores = rouge_scorer_instance.score(response, reference)
+
+        # BLEU scores
+        bleu_scores = calculate_bleu(response, reference)
+
+        # METEOR score
+        meteor = calculate_meteor(response, reference)
+
+        # F1, precision, recall
+        word_metrics = calculate_f1_precision_recall(response, reference)
+
+        # Format the metrics text
+        metrics_text = f"""
+## Evaluation Metrics:
+- **ROUGE-1**: {rouge_scores['rouge1'].fmeasure:.4f}
+- **ROUGE-L**: {rouge_scores['rougeL'].fmeasure:.4f}
+- **BLEU-1**: {bleu_scores['bleu_1']:.4f}
+- **BLEU-4**: {bleu_scores['bleu_4']:.4f}
+- **METEOR**: {meteor:.4f}
+- **Word F1**: {word_metrics['f1']:.4f}
+- **Word Precision**: {word_metrics['precision']:.4f}
+- **Word Recall**: {word_metrics['recall']:.4f}
+"""
+
+    return response, source_text, context_text, metrics_text
+
+def evaluate_sample(sample_index):
+    """Evaluate a sample from the predefined evaluation dataset"""
+    sample_index = int(sample_index)  # the Gradio slider may deliver a float
+    if sample_index < 0 or sample_index >= len(sample_evaluation_data):
+        return "Invalid sample index", "", "", ""
+
+    sample = sample_evaluation_data[sample_index]
+    query = sample["query"]
+    reference = sample["reference"]
+
+    # Process the query with the reference for evaluation
+    response, source_text, context_text, metrics_text = process_query(query, reference)
+
+    # Add the reference answer to the output
+    reference_text = f"""
+## Reference Answer:
+{reference}
+"""
+
+    return response, source_text, context_text, metrics_text + reference_text
 def reset_chat():
+    """Reset the conversation history"""
+    if ASSISTANT:
+        ASSISTANT.reset_conversation()
+        return "Conversation has been reset."
+    return "System not initialized."
+
+def qualitative_feedback(response, user_feedback, feedback_type):
+    """Save qualitative feedback from users"""
+    try:
+        feedback_data = {
+            "response": response,
+            "user_feedback": user_feedback,
+            "feedback_type": feedback_type,
+            "timestamp": str(datetime.datetime.now())
+        }
+
+        # Ensure the directory exists
+        os.makedirs("feedback", exist_ok=True)
+
+        # Append to the feedback file
+        with open("feedback/user_feedback.jsonl", "a") as f:
+            f.write(json.dumps(feedback_data) + "\n")
+
+        return f"Thank you for your {feedback_type} feedback!"
+    except Exception as e:
+        return f"Error saving feedback: {e}"
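Each submission appends one JSON object per line to feedback/user_feedback.jsonl; an illustrative record (all field values made up):

    {"response": "Saudi Vision 2030 is ...", "user_feedback": "Accurate, but misses the tourism target.", "feedback_type": "Completeness", "timestamp": "2025-01-01 12:00:00.000000"}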
 
 # Create Gradio interface
+with gr.Blocks(title="Vision 2030 Assistant - Qualitative Evaluation") as demo:
+    gr.Markdown("# Vision 2030 Virtual Assistant - Qualitative Evaluation")
+    gr.Markdown("This interface allows you to interact with and evaluate the multilingual Vision 2030 Assistant.")
+
+    with gr.Tab("System Initialization"):
+        init_button = gr.Button("Initialize System")
+        init_output = gr.Textbox(label="Initialization Status")
+
+        init_button.click(initialize_system, inputs=[], outputs=[init_output])
+
+    with gr.Tab("Chat & Evaluation"):
         with gr.Row():
+            with gr.Column(scale=2):
+                query_input = gr.Textbox(label="Ask about Saudi Vision 2030 (in English or Arabic)", lines=3)
+                reference_input = gr.Textbox(label="Reference Answer (Optional - for evaluation)", lines=3)
+
+                with gr.Row():
+                    submit_btn = gr.Button("Submit")
+                    reset_btn = gr.Button("Reset Chat")
+
+                response_output = gr.Textbox(label="Response", lines=6)
+
+                with gr.Accordion("Evaluation Metrics", open=False):
+                    metrics_output = gr.Markdown()
+
+                with gr.Accordion("Retrieved Sources", open=False):
+                    sources_output = gr.Textbox(label="Sources")
+
+                with gr.Accordion("Retrieved Contexts", open=False):
+                    contexts_output = gr.Textbox(label="Contexts", lines=10)
+
+                with gr.Accordion("Qualitative Feedback", open=False):
+                    feedback_text = gr.Textbox(label="Your Feedback", lines=3)
+                    feedback_type = gr.Radio(
+                        ["Correctness", "Relevance", "Fluency", "Completeness", "Other"],
+                        label="Feedback Type"
+                    )
+                    feedback_btn = gr.Button("Submit Feedback")
+                    feedback_output = gr.Textbox(label="Feedback Status")
+
+    with gr.Tab("Sample Evaluation"):
+        sample_index = gr.Slider(0, len(sample_evaluation_data)-1, 0, step=1, label="Sample Index")
+        eval_btn = gr.Button("Evaluate Sample")
+
+        sample_response = gr.Textbox(label="Response", lines=6)
+        sample_metrics = gr.Markdown(label="Metrics & Reference")
+
+        with gr.Accordion("Retrieved Sources", open=False):
+            sample_sources = gr.Textbox(label="Sources")
+
+        with gr.Accordion("Retrieved Contexts", open=False):
+            sample_contexts = gr.Textbox(label="Contexts", lines=10)
+
+    with gr.Tab("About"):
+        gr.Markdown("""
+        ## Vision 2030 Assistant
+
+        This is a multilingual RAG-based conversational agent that uses ALLaM-7B to answer questions about Saudi Vision 2030.
+
+        ### Features:
+        - Supports both Arabic and English queries
+        - Uses Retrieval-Augmented Generation (RAG) for accurate answers
+        - Provides transparent sources for information
+        - Comprehensive evaluation metrics
+
+        ### How to use:
+        1. Initialize the system (first tab)
+        2. Ask questions about Saudi Vision 2030 in the Chat tab
+        3. Optionally provide reference answers for evaluation
+        4. Explore sample evaluations from our test dataset
+
+        ### Evaluation Metrics:
+        - ROUGE: measures n-gram overlap between the response and the reference
+        - BLEU: measures n-gram precision of the response against the reference
+        - METEOR: approximates semantic similarity between the response and the reference
+        - F1/Precision/Recall: word-level comparison metrics
+        """)
+
+    # Set up event handlers
+    submit_btn.click(
+        process_query,
+        inputs=[query_input, reference_input],
+        outputs=[response_output, sources_output, contexts_output, metrics_output]
+    )
+
+    reset_btn.click(
+        reset_chat,
+        inputs=[],
+        outputs=[response_output]
+    )
+
+    eval_btn.click(
+        evaluate_sample,
+        inputs=[sample_index],
+        outputs=[sample_response, sample_sources, sample_contexts, sample_metrics]
+    )
+
+    feedback_btn.click(
+        qualitative_feedback,
+        inputs=[response_output, feedback_text, feedback_type],
+        outputs=[feedback_output]
+    )
 
+# Launch the interface
+if __name__ == "__main__":
+    demo.launch()