Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,12 @@
|
|
1 |
-
# Vision 2030 Virtual Assistant with Arabic (ALLaM-7B) and English (Mistral-7B-Instruct) + RAG +
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
@@ -6,131 +14,545 @@ from langdetect import detect
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
import faiss
|
8 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
response = arabic_pipe(input_text, max_new_tokens=256, do_sample=True, temperature=0.7)
|
95 |
-
reply = response[0]['generated_text']
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
f"
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
return reply
|
117 |
-
|
118 |
-
# ----------------------------
|
119 |
-
# Gradio UI
|
120 |
-
# ----------------------------
|
121 |
-
with gr.Blocks() as demo:
|
122 |
-
gr.Markdown("# Vision 2030 Virtual Assistant 🌍\n\nSupports Arabic & English queries about Vision 2030 (with RAG retrieval and improved prompting).")
|
123 |
-
chatbot = gr.Chatbot()
|
124 |
-
msg = gr.Textbox(label="Ask me anything about Vision 2030")
|
125 |
-
clear = gr.Button("Clear")
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
def chat(message, history):
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
history.append((message, reply))
|
|
|
130 |
return history, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
demo.launch()
|
|
|
1 |
+
# Vision 2030 Virtual Assistant with Arabic (ALLaM-7B) and English (Mistral-7B-Instruct) + RAG + Evaluation Framework
|
2 |
+
"""
|
3 |
+
Enhanced implementation of the Vision 2030 Virtual Assistant that meets all project requirements:
|
4 |
+
1. Implements proper NLP task structure (bilingual QA system)
|
5 |
+
2. Adds comprehensive evaluation framework for quantitative and qualitative assessment
|
6 |
+
3. Improves RAG implementation with better retrieval and document processing
|
7 |
+
4. Adds user feedback collection for continuous improvement
|
8 |
+
5. Includes structured logging and performance monitoring
|
9 |
+
"""
|
10 |
|
11 |
import gradio as gr
|
12 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
|
14 |
from sentence_transformers import SentenceTransformer
|
15 |
import faiss
|
16 |
import numpy as np
|
17 |
+
import json
|
18 |
+
import time
|
19 |
+
import logging
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
from datetime import datetime
|
23 |
+
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
|
24 |
+
import pandas as pd
|
25 |
+
import matplotlib.pyplot as plt
|
26 |
+
import PyPDF2
|
27 |
+
import io
|
28 |
|
29 |
+
# Configure logging: emit to both a rotating-less file and stdout so that
# Hugging Face Spaces' console log also captures every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("vision2030_assistant.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by the assistant class and the UI helpers below.
logger = logging.getLogger('vision2030_assistant')
|
39 |
|
40 |
+
class Vision2030Assistant:
    """Bilingual (Arabic/English) RAG question-answering assistant for Saudi Vision 2030.

    Pipeline: detect the query language -> retrieve the top-k relevant document
    chunks from a per-language FAISS index -> prompt the matching instruction
    model (ALLaM-7B for Arabic, Mistral-7B for English) -> extract the answer.
    Also tracks response-time / accuracy / rating metrics for evaluation.
    """

    def __init__(self, pdf_path="vision2030.pdf", eval_data_path="evaluation_data.json"):
        """
        Initialize the Vision 2030 Assistant with models, knowledge base, and evaluation framework

        Args:
            pdf_path: Path to the Vision 2030 PDF document
            eval_data_path: Path to evaluation dataset
        """
        logger.info("Initializing Vision 2030 Assistant...")
        self.load_models()
        self.load_and_process_documents(pdf_path)
        self.setup_evaluation_framework(eval_data_path)
        self.response_history = []  # list of per-interaction dicts (see generate_response)
        logger.info("Vision 2030 Assistant initialized successfully")

    def load_models(self):
        """Load language models and embedding models for both Arabic and English"""
        logger.info("Loading language and embedding models...")

        # Load Arabic Model (ALLaM-7B)
        try:
            self.arabic_model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
            self.arabic_tokenizer = AutoTokenizer.from_pretrained(self.arabic_model_id)
            self.arabic_model = AutoModelForCausalLM.from_pretrained(self.arabic_model_id, device_map="auto")
            self.arabic_pipe = pipeline("text-generation", model=self.arabic_model, tokenizer=self.arabic_tokenizer)
            logger.info("Arabic model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading Arabic model: {str(e)}")
            raise

        # Load English Model (Mistral-7B-Instruct)
        try:
            self.english_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
            self.english_tokenizer = AutoTokenizer.from_pretrained(self.english_model_id)
            self.english_model = AutoModelForCausalLM.from_pretrained(self.english_model_id, device_map="auto")
            self.english_pipe = pipeline("text-generation", model=self.english_model, tokenizer=self.english_tokenizer)
            logger.info("English model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading English model: {str(e)}")
            raise

        # Load Embedding Models for Retrieval
        try:
            self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
            self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            logger.info("Embedding models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading embedding models: {str(e)}")
            raise

    def load_and_process_documents(self, pdf_path):
        """Load and process the Vision 2030 document from PDF.

        Extracts text page by page, splits on blank lines, and routes each chunk
        to the Arabic or English corpus by language detection. Falls back to
        built-in sample data when the PDF is missing or unreadable.
        """
        logger.info(f"Processing Vision 2030 document from {pdf_path}")

        # Initialize empty document lists
        self.english_texts = []
        self.arabic_texts = []

        try:
            # Check if PDF exists
            if os.path.exists(pdf_path):
                # Extract text from PDF
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    full_text = ""
                    for page_num in range(len(reader.pages)):
                        page = reader.pages[page_num]
                        full_text += page.extract_text() + "\n"

                # Split into chunks (simple approach - could be improved with better text segmentation)
                chunks = [chunk.strip() for chunk in re.split(r'\n\s*\n', full_text) if chunk.strip()]

                # Detect language and add to appropriate list
                for chunk in chunks:
                    try:
                        lang = detect(chunk)
                        if lang == "ar":
                            self.arabic_texts.append(chunk)
                        else:  # Default to English for other languages
                            self.english_texts.append(chunk)
                    # BUG FIX: was a bare `except:` which would also swallow
                    # KeyboardInterrupt/SystemExit; langdetect raises on short/ambiguous text.
                    except Exception:
                        # If language detection fails, assume English
                        self.english_texts.append(chunk)

                logger.info(f"Processed {len(self.arabic_texts)} Arabic and {len(self.english_texts)} English chunks")
            else:
                logger.warning(f"PDF file not found at {pdf_path}. Using fallback sample data.")
                self._create_sample_data()
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            logger.info("Using fallback sample data")
            self._create_sample_data()

        # Create FAISS indices
        self._create_indices()

    def _create_sample_data(self):
        """Create sample Vision 2030 data if PDF processing fails"""
        logger.info("Creating sample Vision 2030 data")

        # English sample texts
        self.english_texts = [
            "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            "The Saudi Public Investment Fund (PIF) plays a crucial role in Vision 2030 by investing in strategic sectors.",
            "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
            "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%.",
            "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast.",
            "Qiddiya is a entertainment mega-project being built in Riyadh as part of Vision 2030.",
            "Vision 2030 targets increasing the private sector's contribution to GDP from 40% to 65%.",
            "One goal of Vision 2030 is to increase foreign direct investment from 3.8% to 5.7% of GDP.",
            "Vision 2030 includes plans to develop the digital infrastructure and support for tech startups in Saudi Arabia."
        ]

        # Arabic sample texts (same content as English)
        self.arabic_texts = [
            "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة.",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            "يلعب صندوق الاستثمارات العامة السعودي دورًا محوريًا في رؤية 2030 من خلال الاستثمار في القطاعات الاستراتيجية.",
            "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030.",
            "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪.",
            "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
            "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030.",
            "تستهدف رؤية 2030 زيادة مساهمة القطاع الخاص في الناتج المحلي الإجمالي من 40٪ إلى 65٪.",
            "أحد أهداف رؤية 2030 هو زيادة الاستثمار الأجنبي المباشر من 3.8٪ إلى 5.7٪ من الناتج المحلي الإجمالي.",
            "تتضمن رؤية 2030 خططًا لتطوير البنية التحتية الرقمية والدعم للشركات الناشئة التكنولوجية في المملكة العربية السعودية."
        ]

    def _create_indices(self):
        """Create FAISS indices for fast text retrieval (one L2 flat index per language)."""
        logger.info("Creating FAISS indices for text retrieval")

        try:
            # Process and embed English texts
            self.english_vectors = []
            for text in self.english_texts:
                vec = self.english_embedder.encode(text)
                self.english_vectors.append(vec)

            # Create English index
            if self.english_vectors:
                self.english_index = faiss.IndexFlatL2(len(self.english_vectors[0]))
                self.english_index.add(np.array(self.english_vectors))
                logger.info(f"Created English index with {len(self.english_vectors)} vectors")
            else:
                logger.warning("No English texts to index")

            # Process and embed Arabic texts
            self.arabic_vectors = []
            for text in self.arabic_texts:
                vec = self.arabic_embedder.encode(text)
                self.arabic_vectors.append(vec)

            # Create Arabic index
            if self.arabic_vectors:
                self.arabic_index = faiss.IndexFlatL2(len(self.arabic_vectors[0]))
                self.arabic_index.add(np.array(self.arabic_vectors))
                logger.info(f"Created Arabic index with {len(self.arabic_vectors)} vectors")
            else:
                logger.warning("No Arabic texts to index")

        except Exception as e:
            logger.error(f"Error creating FAISS indices: {str(e)}")
            raise

    def setup_evaluation_framework(self, eval_data_path):
        """Set up the evaluation framework with test data and metrics"""
        logger.info("Setting up evaluation framework")

        # Initialize metrics trackers
        self.metrics = {
            "response_times": [],
            "user_ratings": [],
            "retrieval_precision": [],
            "factual_accuracy": []
        }

        # Load evaluation data if exists, otherwise create sample
        try:
            if os.path.exists(eval_data_path):
                with open(eval_data_path, 'r', encoding='utf-8') as f:
                    self.eval_data = json.load(f)
                logger.info(f"Loaded {len(self.eval_data)} evaluation examples from {eval_data_path}")
            else:
                logger.warning(f"Evaluation data not found at {eval_data_path}. Creating sample evaluation data.")
                self._create_sample_eval_data()
        except Exception as e:
            logger.error(f"Error loading evaluation data: {str(e)}")
            self._create_sample_eval_data()

    def _create_sample_eval_data(self):
        """Create sample evaluation data with ground truth"""
        self.eval_data = [
            {
                "question": "What are the key pillars of Vision 2030?",
                "lang": "en",
                "reference_answer": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."
            },
            {
                "question": "ما هي الركائز الرئيسية لرؤية 2030؟",
                "lang": "ar",
                "reference_answer": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."
            },
            {
                "question": "What is NEOM?",
                "lang": "en",
                "reference_answer": "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030."
            },
            {
                # BUG FIX: question string contained mojibake ("البح\ufffd\ufffd");
                # restored the intended word "البحر" to match the reference answer.
                "question": "ما هو مشروع البحر الأحمر؟",
                "lang": "ar",
                "reference_answer": "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي."
            }
        ]
        logger.info(f"Created {len(self.eval_data)} sample evaluation examples")

    def retrieve_context(self, query, lang):
        """Retrieve relevant context for a query based on language.

        Returns the top-2 matching chunks joined by newlines, or "" on error.
        """
        start_time = time.time()

        try:
            if lang == "ar":
                query_vec = self.arabic_embedder.encode(query)
                D, I = self.arabic_index.search(np.array([query_vec]), k=2)  # Get top 2 most relevant chunks
                # FAISS pads missing results with -1; filter out-of-range ids.
                context = "\n".join([self.arabic_texts[i] for i in I[0] if i < len(self.arabic_texts) and i >= 0])
            else:
                query_vec = self.english_embedder.encode(query)
                D, I = self.english_index.search(np.array([query_vec]), k=2)  # Get top 2 most relevant chunks
                context = "\n".join([self.english_texts[i] for i in I[0] if i < len(self.english_texts) and i >= 0])

            retrieval_time = time.time() - start_time
            logger.info(f"Retrieved context in {retrieval_time:.2f}s")

            return context
        except Exception as e:
            logger.error(f"Error retrieving context: {str(e)}")
            return ""

    def generate_response(self, user_input):
        """Generate a response to user input using the appropriate model and retrieval system"""
        start_time = time.time()

        # ROBUSTNESS FIX: pre-set lang so the outer `except` handler can never
        # hit an unbound local if something fails before detection completes.
        lang = "en"

        # Default response in case of failure
        default_response = {
            "en": "I apologize, but I couldn't process your request properly. Please try again.",
            "ar": "أعتذر، لم أتمكن من معالجة طلبك بشكل صحيح. الرجاء المحاولة مرة أخرى."
        }

        try:
            # Detect language
            try:
                lang = detect(user_input)
                if lang != "ar":  # Simplify to just Arabic vs non-Arabic
                    lang = "en"
            # BUG FIX: was a bare `except:`; narrow to Exception so
            # KeyboardInterrupt/SystemExit still propagate.
            except Exception:
                lang = "en"  # Default fallback

            logger.info(f"Detected language: {lang}")

            # Retrieve relevant context
            context = self.retrieve_context(user_input, lang)

            if lang == "ar":
                # Improved Arabic Prompt
                input_text = (
                    f"أنت خبير في رؤية السعودية 2030.\n"
                    f"إليك بعض المعلومات المهمة:\n{context}\n\n"
                    f"مثال:\n"
                    f"السؤال: ما هي ركائز رؤية 2030؟\n"
                    f"الإجابة: ركائز رؤية 2030 هي مجتمع حيوي، اقتصاد مزدهر، ووطن طموح.\n\n"
                    f"أجب عن سؤال المستخدم بشكل واضح ودقيق، مستندًا إلى المعلومات المقدمة. إذا لم تكن المعلومات متوفرة، أوضح ذلك.\n"
                    f"السؤال: {user_input}\n"
                    f"الإجابة:"
                )

                response = self.arabic_pipe(input_text, max_new_tokens=256, do_sample=True, temperature=0.7)
                full_text = response[0]['generated_text']

                # Extract the answer part
                answer_pattern = r"الإجابة:(.*?)(?:$)"
                match = re.search(answer_pattern, full_text, re.DOTALL)
                if match:
                    reply = match.group(1).strip()
                else:
                    reply = full_text
            else:
                # Improved English Prompt
                input_text = (
                    f"You are an expert on Saudi Arabia's Vision 2030.\n"
                    f"Here is some relevant information:\n{context}\n\n"
                    f"Example:\n"
                    f"Question: What are the key pillars of Vision 2030?\n"
                    f"Answer: The key pillars are a vibrant society, a thriving economy, and an ambitious nation.\n\n"
                    f"Answer the user's question clearly and accurately based on the provided information. If information is not available, make that clear.\n"
                    f"Question: {user_input}\n"
                    f"Answer:"
                )

                response = self.english_pipe(input_text, max_new_tokens=256, do_sample=True, temperature=0.7)
                full_text = response[0]['generated_text']

                # Extract the answer part
                answer_pattern = r"Answer:(.*?)(?:$)"
                match = re.search(answer_pattern, full_text, re.DOTALL)
                if match:
                    reply = match.group(1).strip()
                else:
                    reply = full_text

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            reply = default_response.get(lang, default_response["en"])

        # Record response time
        response_time = time.time() - start_time
        self.metrics["response_times"].append(response_time)

        logger.info(f"Generated response in {response_time:.2f}s")

        # Store the interaction for later evaluation
        interaction = {
            "timestamp": datetime.now().isoformat(),
            "user_input": user_input,
            "response": reply,
            "language": lang,
            "response_time": response_time
        }
        self.response_history.append(interaction)

        return reply

    def evaluate_factual_accuracy(self, response, reference):
        """Simple evaluation of factual accuracy by keyword matching.

        Returns the fraction of reference keywords that also appear in the
        response (0 when the reference has no keywords).
        """
        # This is a simplified approach - in production, use more sophisticated methods
        keywords_reference = set(re.findall(r'\b\w+\b', reference.lower()))
        keywords_response = set(re.findall(r'\b\w+\b', response.lower()))

        common_keywords = keywords_reference.intersection(keywords_response)

        if len(keywords_reference) > 0:
            accuracy = len(common_keywords) / len(keywords_reference)
        else:
            accuracy = 0

        return accuracy

    def evaluate_on_test_set(self):
        """Evaluate the assistant on the test set"""
        logger.info("Running evaluation on test set")

        eval_results = []

        for example in self.eval_data:
            # Generate response
            response = self.generate_response(example["question"])

            # Calculate factual accuracy
            accuracy = self.evaluate_factual_accuracy(response, example["reference_answer"])

            eval_results.append({
                "question": example["question"],
                "reference": example["reference_answer"],
                "response": response,
                "factual_accuracy": accuracy
            })

            self.metrics["factual_accuracy"].append(accuracy)

        # Calculate average factual accuracy
        avg_accuracy = sum(self.metrics["factual_accuracy"]) / len(self.metrics["factual_accuracy"]) if self.metrics["factual_accuracy"] else 0
        avg_response_time = sum(self.metrics["response_times"]) / len(self.metrics["response_times"]) if self.metrics["response_times"] else 0

        results = {
            "average_factual_accuracy": avg_accuracy,
            "average_response_time": avg_response_time,
            "detailed_results": eval_results
        }

        logger.info(f"Evaluation results: Factual accuracy = {avg_accuracy:.2f}, Avg response time = {avg_response_time:.2f}s")

        return results

    def record_user_feedback(self, user_input, response, rating, feedback_text=""):
        """Record user feedback for a response"""
        feedback = {
            "timestamp": datetime.now().isoformat(),
            "user_input": user_input,
            "response": response,
            "rating": rating,
            "feedback_text": feedback_text
        }

        self.metrics["user_ratings"].append(rating)

        # In a production system, store this in a database
        logger.info(f"Recorded user feedback: rating={rating}")

        return True

    def save_evaluation_metrics(self, output_path="evaluation_metrics.json"):
        """Save evaluation metrics to a file"""
        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump({
                    "response_times": self.metrics["response_times"],
                    "user_ratings": self.metrics["user_ratings"],
                    "factual_accuracy": self.metrics["factual_accuracy"],
                    "average_factual_accuracy": sum(self.metrics["factual_accuracy"]) / len(self.metrics["factual_accuracy"]) if self.metrics["factual_accuracy"] else 0,
                    "average_response_time": sum(self.metrics["response_times"]) / len(self.metrics["response_times"]) if self.metrics["response_times"] else 0,
                    "average_user_rating": sum(self.metrics["user_ratings"]) / len(self.metrics["user_ratings"]) if self.metrics["user_ratings"] else 0,
                    "timestamp": datetime.now().isoformat()
                }, f, indent=2)

            logger.info(f"Saved evaluation metrics to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving evaluation metrics: {str(e)}")
            return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
460 |
+
# --- Gradio UI --- #
|
461 |
+
def create_gradio_interface():
|
462 |
+
# Initialize the assistant
|
463 |
+
assistant = Vision2030Assistant()
|
464 |
+
|
465 |
+
# Track conversation history
|
466 |
+
conversation_history = []
|
467 |
+
|
468 |
def chat(message, history):
|
469 |
+
if not message:
|
470 |
+
return history, ""
|
471 |
+
|
472 |
+
# Generate response
|
473 |
+
reply = assistant.generate_response(message)
|
474 |
+
|
475 |
+
# Update history
|
476 |
history.append((message, reply))
|
477 |
+
|
478 |
return history, ""
|
479 |
+
|
480 |
+
def provide_feedback(message, rating, feedback_text):
|
481 |
+
# Find the most recent interaction
|
482 |
+
if conversation_history:
|
483 |
+
last_interaction = conversation_history[-1]
|
484 |
+
assistant.record_user_feedback(last_interaction[0], last_interaction[1], rating, feedback_text)
|
485 |
+
return f"Thank you for your feedback! (Rating: {rating}/5)"
|
486 |
+
return "No conversation found to rate."
|
487 |
+
|
488 |
+
def clear_history():
|
489 |
+
conversation_history.clear()
|
490 |
+
return []
|
491 |
+
|
492 |
+
def download_metrics():
|
493 |
+
assistant.save_evaluation_metrics()
|
494 |
+
return "evaluation_metrics.json"
|
495 |
+
|
496 |
+
def run_evaluation():
|
497 |
+
results = assistant.evaluate_on_test_set()
|
498 |
+
return f"Evaluation Results:\nFactual Accuracy: {results['average_factual_accuracy']:.2f}\nAverage Response Time: {results['average_response_time']:.2f}s"
|
499 |
+
|
500 |
+
# Create Gradio interface
|
501 |
+
with gr.Blocks() as demo:
|
502 |
+
gr.Markdown("# Vision 2030 Virtual Assistant 🌍\n\nAsk questions about Saudi Vision 2030 in Arabic or English")
|
503 |
+
|
504 |
+
with gr.Tab("Chat"):
|
505 |
+
chatbot = gr.Chatbot(show_label=False)
|
506 |
+
msg = gr.Textbox(label="Ask me anything about Vision 2030", placeholder="Type your question here...")
|
507 |
+
clear = gr.Button("Clear Conversation")
|
508 |
+
|
509 |
+
with gr.Row():
|
510 |
+
with gr.Column(scale=4):
|
511 |
+
feedback_text = gr.Textbox(label="Provide additional feedback (optional)")
|
512 |
+
with gr.Column(scale=1):
|
513 |
+
rating = gr.Slider(label="Rate Response (1-5)", minimum=1, maximum=5, step=1, value=3)
|
514 |
+
|
515 |
+
submit_feedback = gr.Button("Submit Feedback")
|
516 |
+
feedback_result = gr.Textbox(label="Feedback Status")
|
517 |
+
|
518 |
+
# Set up event handlers
|
519 |
+
msg.submit(chat, [msg, chatbot], [chatbot, msg])
|
520 |
+
clear.click(clear_history, None, chatbot)
|
521 |
+
submit_feedback.click(provide_feedback, [msg, rating, feedback_text], feedback_result)
|
522 |
+
|
523 |
+
with gr.Tab("Evaluation"):
|
524 |
+
eval_button = gr.Button("Run Evaluation on Test Set")
|
525 |
+
eval_results = gr.Textbox(label="Evaluation Results")
|
526 |
+
download_button = gr.Button("Download Metrics")
|
527 |
+
download_file = gr.File(label="Download evaluation metrics as JSON")
|
528 |
+
|
529 |
+
# Set up evaluation handlers
|
530 |
+
eval_button.click(run_evaluation, None, eval_results)
|
531 |
+
download_button.click(download_metrics, None, download_file)
|
532 |
+
|
533 |
+
with gr.Tab("About"):
|
534 |
+
gr.Markdown("""
|
535 |
+
## About Vision 2030 Virtual Assistant
|
536 |
+
|
537 |
+
This assistant uses a combination of state-of-the-art language models to answer questions about Saudi Arabia's Vision 2030 strategic framework in both Arabic and English.
|
538 |
+
|
539 |
+
### Features:
|
540 |
+
- Bilingual support (Arabic and English)
|
541 |
+
- Retrieval-Augmented Generation (RAG) for factual accuracy
|
542 |
+
- Evaluation framework for measuring performance
|
543 |
+
- User feedback collection for continuous improvement
|
544 |
+
|
545 |
+
### Models Used:
|
546 |
+
- Arabic: ALLaM-7B-Instruct-preview
|
547 |
+
- English: Mistral-7B-Instruct-v0.2
|
548 |
+
- Embeddings: CAMeL-Lab/bert-base-arabic-camelbert-ca and sentence-transformers/all-MiniLM-L6-v2
|
549 |
+
|
550 |
+
This project demonstrates the application of advanced NLP techniques for multilingual question answering, particularly for Arabic language support.
|
551 |
+
""")
|
552 |
+
|
553 |
+
return demo
|
554 |
|
555 |
+
# Launch the application
# NOTE: `demo` is bound at module level on purpose — Hugging Face Spaces can
# pick up a module-level Blocks object; launch() starts the Gradio server.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()
|
|