nishantgaurav23 committed
Commit 47a776e · verified · 1 Parent(s): 411e1ce

Upload app.py

Files changed (1)
  app.py +581 -0
app.py ADDED
@@ -0,0 +1,581 @@
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import json
import logging
import pickle
import re
import sys

import requests
import streamlit as st
import torch
import torch.nn.functional as F
from llama_cpp import Llama
from tqdm import tqdm

#from dotenv import load_dotenv
from embedding_processor import SentenceTransformerRetriever, process_data
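
# Dependency note (inferred from the imports above; not an official manifest):
# this app assumes streamlit, torch, requests, llama-cpp-python, tqdm, and
# sentence-transformers (used via embedding_processor) are installed, e.g.
# through the Space's requirements.txt.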

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

# Create necessary directories at startup
for directory in ['models', 'ESPN_data', 'embeddings_cache']:
    os.makedirs(directory, exist_ok=True)

# Load environment variables
#load_dotenv()


@st.cache_data
def load_from_drive(file_id: str):
    """Load a pickle file directly from Google Drive."""
    try:
        url = f"https://drive.google.com/uc?id={file_id}&export=download"
        session = requests.Session()
        response = session.get(url, stream=True)

        # Large files trigger Drive's virus-scan interstitial; the cookie
        # carries a confirmation token we append to the URL to bypass it.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        content = response.content
        logging.info(f"Successfully downloaded {len(content)} bytes")
        return pickle.loads(content)

    except Exception as e:
        logging.error(f"Error loading file from Drive: {str(e)}")
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
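
# Shape of the pickled cache (inferred from how initialize_rag_pipeline below
# consumes it; the exact types are defined in embedding_processor):
#
#   cache_data = {
#       'documents':  [...],   # list of text chunks, one per document
#       'embeddings': tensor,  # document embedding matrix, aligned by row
#   }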

# Hugging Face API configuration (unused now that inference runs locally)

# API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
# headers = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN')}"}
# model_name = 'mistralai/Mistral-7B-v0.1'


class RAGPipeline:

    def __init__(self, data_folder: str, k: int = 5):
        try:
            self.data_folder = data_folder
            self.k = k  # number of documents to retrieve
            self.retriever = SentenceTransformerRetriever()
            self.documents = []
            self.device = torch.device("cpu")
            self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
            self.llm = None
            self.initialize_model()  # Using the class method

        except Exception as e:
            logging.error(f"Error in RAGPipeline initialization: {str(e)}")
            raise

    @st.cache_resource
    def initialize_model(_self):  # Leading underscore tells Streamlit not to hash this argument
        """Initialize the model with proper error handling and verification."""
        try:
            if not os.path.exists(_self.model_path):
                st.info("Downloading model... This may take a while.")
                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
                _self.download_file_with_progress(direct_url, _self.model_path)

            # Verify the file exists and has content
            if not os.path.exists(_self.model_path):
                raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")

            if os.path.getsize(_self.model_path) < 1000000:  # Less than 1MB
                os.remove(_self.model_path)
                raise ValueError("Downloaded model file is too small, likely corrupted")

            llm_config = {
                "n_ctx": 2048,      # context window in tokens
                "n_threads": 4,     # CPU threads
                "n_batch": 512,     # prompt-processing batch size
                "n_gpu_layers": 0,  # CPU-only inference
                "verbose": False
            }

            _self.llm = Llama(model_path=_self.model_path, **llm_config)
            st.success("Model loaded successfully!")

        except Exception as e:
            st.error(f"Error initializing model: {str(e)}")
            raise
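
    # Note: the Q4_K_M GGUF of Mistral-7B is a multi-gigabyte download (roughly
    # 4GB), so the 1MB check above is only a sanity guard against grossly
    # truncated files, not an integrity check.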

    def download_file_with_progress(self, url: str, filename: str):
        """Download a file with a progress bar using requests."""
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                progress_bar.update(size)

    def query_model(self, prompt: str) -> str:
        """Query the local Llama model instead of the hosted API."""
        try:
            if self.llm is None:
                raise RuntimeError("Model not initialized")

            # Generate a completion with the local Llama model
            response = self.llm(
                prompt,
                max_tokens=512,
                temperature=0.4,
                top_p=0.95,
                echo=False,                 # don't echo the prompt back
                stop=["Question:", "\n\n"]  # stop before the model invents a new question
            )

            # Check and extract the response text
            if response and 'choices' in response and len(response['choices']) > 0:
                text = response['choices'][0].get('text', '').strip()
                return text
            else:
                raise ValueError("No valid response generated")

        except Exception as e:
            logging.error(f"Error in query_model: {str(e)}")
            raise
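
    # For reference: llama-cpp-python returns an OpenAI-style completion dict,
    # which is why query_model reads response['choices'][0]['text']. Roughly:
    #
    #   {"choices": [{"text": " ...generated answer..."}], ...}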

    def preprocess_query(self, query: str) -> str:
        """Clean and prepare the query."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)  # collapse runs of whitespace
        return query

    def process_query(self, query: str, placeholder) -> str:
        try:
            # Preprocess query
            query = self.preprocess_query(query)

            # Show retrieval status
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            # Embed the query and rank documents by cosine similarity
            query_embedding = self.retriever.encode([query])
            similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
            scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))

            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Update status
            status.write("💭 Generating response...")

            # Prepare context and prompt
            context = "\n".join(relevant_docs[:3])  # Use the top 3 most relevant docs
            prompt = f"""Context information is below:
{context}

Given the context above, please answer the following question:
{query}

Guidelines:
- If you cannot answer based on the context, say so politely
- Keep the response concise and focused
- Only include sports-related information
- No dates or timestamps in the response
- Use clear, natural language

Answer:"""

            # Generate response
            response_placeholder = placeholder.empty()

            try:
                response_text = self.query_model(prompt)
                if response_text:
                    final_response = self.postprocess_response(response_text)
                    response_placeholder.markdown(final_response)
                    return final_response
                else:
                    message = "No relevant answer found. Please try rephrasing your question."
                    response_placeholder.warning(message)
                    return message

            except Exception as e:
                logging.error(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            logging.error(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message

    def postprocess_response(self, response: str) -> str:
        """Clean up the generated response."""
        response = response.strip()
        response = re.sub(r'\s+', ' ', response)
        # Strip ISO-style timestamps the model may echo from the source data
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response
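
    # Hypothetical example of postprocess_response: an input like
    #   "Real Madrid won 2-1.   2024-10-05 18:30:00+00:00"
    # has its whitespace collapsed and the trailing timestamp stripped,
    # leaving "Real Madrid won 2-1." (plus a stray trailing space).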

    # Earlier version of process_query that called the hosted HF Inference API
    # instead of the local model, kept for reference:
    #
    # def process_query(self, query: str, placeholder) -> str:
    #     try:
    #         # Preprocess query
    #         query = self.preprocess_query(query)
    #
    #         # Show retrieval status
    #         status = placeholder.empty()
    #         status.write("🔍 Finding relevant information...")
    #
    #         # Get embeddings and search using tensor operations
    #         query_embedding = self.retriever.encode([query])
    #         similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
    #         scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
    #
    #         # Print search results for debugging
    #         print("\nSearch Results:")
    #         for idx, score in zip(indices.tolist(), scores.tolist()):
    #             print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
    #
    #         relevant_docs = [self.documents[idx] for idx in indices.tolist()]
    #
    #         # Update status
    #         status.write("💭 Generating response...")
    #
    #         # Prepare context and prompt
    #         context = "\n".join(relevant_docs[:3])  # Only use top 3 most relevant docs
    #         prompt = f"""Answer this question using the given context. Be specific and detailed.
    #
    # Context: {context}
    #
    # Question: {query}
    #
    # Answer (provide a complete, detailed response):"""
    #
    #         # Generate response
    #         response_placeholder = placeholder.empty()
    #
    #         try:
    #             response = requests.post(
    #                 model_name,
    #                 #headers=headers,
    #                 json={
    #                     "inputs": prompt,
    #                     "parameters": {
    #                         "max_new_tokens": 1024,
    #                         "temperature": 0.5,
    #                         "top_p": 0.9,
    #                         "top_k": 50,
    #                         "repetition_penalty": 1.03,
    #                         "do_sample": True
    #                     }
    #                 },
    #                 timeout=30
    #             ).json()
    #
    #             if response and isinstance(response, list) and len(response) > 0:
    #                 generated_text = response[0].get('generated_text', '').strip()
    #                 if generated_text:
    #                     # Find and extract only the answer part
    #                     if "Answer:" in generated_text:
    #                         answer_part = generated_text.split("Answer:")[-1].strip()
    #                     elif "Answer (provide a complete, detailed response):" in generated_text:
    #                         answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
    #                     else:
    #                         answer_part = generated_text.strip()
    #
    #                     # Clean up the answer
    #                     answer_part = answer_part.replace("Context:", "").replace("Question:", "")
    #
    #                     final_response = self.postprocess_response(answer_part)
    #                     response_placeholder.markdown(final_response)
    #                     return final_response
    #
    #             message = "No relevant answer found. Please try rephrasing your question."
    #             response_placeholder.warning(message)
    #             return message
    #
    #         except Exception as e:
    #             print(f"Generation error: {str(e)}")
    #             message = "Had some trouble generating the response. Please try again."
    #             response_placeholder.warning(message)
    #             return message
    #
    #     except Exception as e:
    #         print(f"Process error: {str(e)}")
    #         message = "Something went wrong. Please try again with a different question."
    #         placeholder.warning(message)
    #         return message


def check_environment():
    """Check if the environment is properly set up."""
    # if not headers['Authorization']:
    #     st.error("HUGGINGFACE_API_KEY environment variable not set!")
    #     st.stop()
    #     return False

    try:
        import torch
        import sentence_transformers
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False


# @st.cache_resource
# def initialize_rag_pipeline():
#     """Initialize the RAG pipeline once"""
#     data_folder = "ESPN_data"
#     return RAGPipeline(data_folder)

def check_space_requirements():
    """Check whether we're running on an HF Space and have the necessary resources."""
    try:
        # Check if we're on an HF Space
        is_space = os.environ.get('SPACE_ID') is not None

        if is_space:
            # Check disk space
            disk_space = os.statvfs('/')
            free_space_gb = (disk_space.f_frsize * disk_space.f_bavail) / (1024**3)

            if free_space_gb < 10:  # Need at least 10GB free
                st.warning(f"Low disk space: {free_space_gb:.1f}GB free")

            # Check if the model exists
            model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
            if not os.path.exists(model_path):
                st.info("Model will be downloaded on first run")

            # Check if embeddings exist
            if not os.path.exists('embeddings_cache/embeddings.pkl'):
                st.info("Embeddings will be loaded from Drive")

        return True

    except Exception as e:
        logging.error(f"Space requirements check failed: {str(e)}")
        return False
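
# Free-space arithmetic above: statvfs' f_frsize is the fragment size in bytes
# and f_bavail the number of blocks available to unprivileged users, so their
# product divided by 1024**3 gives free space in GiB.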


@st.cache_resource(show_spinner=False)
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once."""
    try:
        # First check/create the necessary directories
        for directory in ['models', 'ESPN_data', 'embeddings_cache']:
            os.makedirs(directory, exist_ok=True)

        # Load embeddings from Drive
        drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
        with st.spinner("Loading embeddings from Google Drive..."):
            cache_data = load_from_drive(drive_file_id)
            if cache_data is None:
                st.error("Failed to load embeddings from Google Drive")
                st.stop()

        # Initialize the pipeline (the model is loaded through __init__)
        data_folder = "ESPN_data"
        rag = RAGPipeline(data_folder)

        # Store documents and embeddings
        rag.documents = cache_data['documents']
        rag.retriever.store_embeddings(cache_data['embeddings'])

        st.success("System initialized successfully!")
        return rag

    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
        st.error(f"Failed to initialize the system: {str(e)}")
        raise


def main():
    try:
        # Page config (must be the first Streamlit call, before any st.error/st.info)
        st.set_page_config(
            page_title="The Sport Chatbot",
            page_icon="🏆",
            layout="wide"
        )

        # Environment check
        if not check_environment() or not check_space_requirements():
            return

        # Session state for initialization status
        if 'initialized' not in st.session_state:
            st.session_state.initialized = False

        # Improved CSS styling
        st.markdown("""
            <style>
            /* Container styling */
            .block-container {
                padding-top: 2rem;
                padding-bottom: 2rem;
            }

            /* Text input styling */
            .stTextInput > div > div > input {
                width: 100%;
            }

            /* Button styling */
            .stButton > button {
                width: 200px;
                margin: 0 auto;
                display: block;
                background-color: #FF4B4B;
                color: white;
                border-radius: 5px;
                padding: 0.5rem 1rem;
            }

            /* Title styling */
            .main-title {
                text-align: center;
                padding: 1rem 0;
                font-size: 3rem;
                color: #1F1F1F;
            }

            .sub-title {
                text-align: center;
                padding: 0.5rem 0;
                font-size: 1.5rem;
                color: #4F4F4F;
            }

            /* Description styling */
            .description {
                text-align: center;
                color: #666666;
                padding: 0.5rem 0;
                font-size: 1.1rem;
                line-height: 1.6;
                margin-bottom: 1rem;
            }

            /* Answer container styling */
            .stMarkdown {
                max-width: 100%;
            }

            /* Streamlit default overrides */
            .st-emotion-cache-16idsys p {
                font-size: 1.1rem;
                line-height: 1.6;
            }

            /* Container for main content */
            .main-content {
                max-width: 1200px;
                margin: 0 auto;
                padding: 0 1rem;
            }
            </style>
        """, unsafe_allow_html=True)

        # Header section with improved styling
        st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
        st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
        st.markdown("""
            <p class='description'>
                Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
                With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
            </p>
            <p class='description'>
                Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
            </p>
        """, unsafe_allow_html=True)

        # Add some spacing
        st.markdown("<br>", unsafe_allow_html=True)

        # Initialize the pipeline
        if not st.session_state.initialized:
            try:
                with st.spinner("Loading resources..."):
                    # Create necessary directories
                    for directory in ['models', 'ESPN_data', 'embeddings_cache']:
                        os.makedirs(directory, exist_ok=True)

                    # Initialize the RAG pipeline
                    st.session_state.rag = initialize_rag_pipeline()
                    st.session_state.initialized = True

                st.success("System initialized successfully!")
            except Exception as e:
                logging.error(f"Initialization error: {str(e)}")
                st.error("Unable to initialize the system. Please check if all required files are present.")
                st.stop()

        # Create a centered layout with side margins
        col1, col2, col3 = st.columns([1, 6, 1])

        with col2:
            # Query input
            query = st.text_input("What would you like to know about sports?")

            # Centered button
            if st.button("Get Answer"):
                if query:
                    response_placeholder = st.empty()
                    try:
                        # Get a response from the RAG pipeline
                        response = st.session_state.rag.process_query(query, response_placeholder)
                        logging.info(f"Generated response: {response}")
                    except Exception as e:
                        logging.error(f"Query processing error: {str(e)}")
                        response_placeholder.warning("Unable to process your question. Please try again.")
                else:
                    st.warning("Please enter a question!")

        # Footer with improved styling
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("""
            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
                Powered by ESPN Data & Mistral AI 🚀<br>
                <small>Running on Hugging Face Spaces</small>
            </p>
        """, unsafe_allow_html=True)

    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")