nishantgaurav23 committed · verified
Commit f30497e · 1 Parent(s): 8d4c5ee

Update app.py

Files changed (1): app.py (+314, -262)
app.py CHANGED
@@ -1,12 +1,14 @@
 import os
 import warnings
+import logging
+import sys
 warnings.filterwarnings("ignore", category=UserWarning)
 
 import numpy as np
 import pandas as pd
 import torch
 from sentence_transformers import SentenceTransformer
-from typing import List, Callable
+from typing import List, Callable, Dict, Optional, Any
 import glob
 from tqdm import tqdm
 import pickle
@@ -19,9 +21,22 @@ import re
 import time
 import requests
 
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+
 # Force CPU device
 torch.device('cpu')
 
+# Create necessary directories
+for directory in ['models', 'ESPN_data', 'embeddings_cache']:
+    os.makedirs(directory, exist_ok=True)
+
 # Logging configuration
 LOGGING_CONFIG = {
     'enabled': True,
@@ -49,36 +64,7 @@ def download_file_with_progress(url: str, filename: str):
         for data in response.iter_content(chunk_size=1024):
             size = file.write(data)
             progress_bar.update(size)
-
-@st.cache_data
-def load_from_drive(file_id: str):
-    """Load pickle file directly from Google Drive"""
-    try:
-        # Direct download URL for Google Drive
-        url = f"https://drive.google.com/uc?id={file_id}&export=download"
-
-        # First request to get the confirmation token
-        session = requests.Session()
-        response = session.get(url, stream=True)
-
-        # Check if we need to confirm download
-        for key, value in response.cookies.items():
-            if key.startswith('download_warning'):
-                # Add confirmation parameter to the URL
-                url = f"{url}&confirm={value}"
-                response = session.get(url, stream=True)
-                break
-
-        # Load the content and convert to pickle
-        content = response.content
-        print(f"Successfully downloaded {len(content)} bytes")
-        return pickle.loads(content)
-
-    except Exception as e:
-        print(f"Detailed error: {str(e)}")  # This will help debug
-        st.error(f"Error loading file from Drive: {str(e)}")
-        return None
-
+
 def log_function(func: Callable) -> Callable:
     """Decorator to log function inputs and outputs"""
     @functools.wraps(func)
@@ -135,48 +121,70 @@ def check_environment():
         st.stop()
     return False
 
-
-
 class SentenceTransformerRetriever:
-    @st.cache_resource
-    def __init__(_self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-        # Force CPU device and suppress warnings
+    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
+        self.device = torch.device("cpu")
+        self.model_name = model_name
+        self.cache_dir = cache_dir
+        self.cache_file = "embeddings.pkl"
+        self.doc_embeddings = None
+        os.makedirs(cache_dir, exist_ok=True)
+        # Initialize model using cached method
+        self.model = self._load_model()
+
+    @st.cache_resource(show_spinner=False)
+    def _load_model(self):
+        """Load and cache the sentence transformer model"""
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            _self.device = torch.device("cpu")
-            _self.model = SentenceTransformer(model_name, device="cpu")
-            _self.doc_embeddings = None
-            _self.cache_dir = cache_dir
-            _self.cache_file = "embeddings.pkl"
-            os.makedirs(cache_dir, exist_ok=True)
+            model = SentenceTransformer(self.model_name, device="cpu")
+            # Verify model is loaded correctly
+            test_embedding = model.encode("test", convert_to_tensor=True)
+            if not isinstance(test_embedding, torch.Tensor):
+                raise ValueError("Model initialization failed")
+            return model
 
     def get_cache_path(self, data_folder: str = None) -> str:
         return os.path.join(self.cache_dir, self.cache_file)
-
 
     @log_function
     def save_cache(self, data_folder: str, cache_data: dict):
-        cache_path = self.get_cache_path()
-        if os.path.exists(cache_path):
-            os.remove(cache_path)
-        with open(cache_path, 'wb') as f:
-            pickle.dump(cache_data, f)
-        print(f"Cache saved at: {cache_path}")
+        try:
+            cache_path = self.get_cache_path()
+            if os.path.exists(cache_path):
+                os.remove(cache_path)
+            with open(cache_path, 'wb') as f:
+                pickle.dump(cache_data, f)
+            logging.info(f"Cache saved at: {cache_path}")
+        except Exception as e:
+            logging.error(f"Error saving cache: {str(e)}")
+            raise
 
     @log_function
     @st.cache_data
-    def load_cache(_self, data_folder: str = None) -> dict:
-        cache_path = _self.get_cache_path()
-        if os.path.exists(cache_path):
-            with open(cache_path, 'rb') as f:
-                print(f"Loading cache from: {cache_path}")
-                return pickle.load(f)
-        return None
-
+    def load_cache(self, data_folder: str = None) -> Optional[Dict]:
+        try:
+            cache_path = self.get_cache_path()
+            if os.path.exists(cache_path):
+                with open(cache_path, 'rb') as f:
+                    logging.info(f"Loading cache from: {cache_path}")
+                    cache_data = pickle.load(f)
+                    if isinstance(cache_data, dict) and 'embeddings' in cache_data and 'documents' in cache_data:
+                        return cache_data
+            logging.warning("Invalid cache format")
+            return None
+        except Exception as e:
+            logging.error(f"Error loading cache: {str(e)}")
+            return None
+
     @log_function
     def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
-        embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
-        return F.normalize(embeddings, p=2, dim=1)
+        try:
+            embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
+            return F.normalize(embeddings, p=2, dim=1)
+        except Exception as e:
+            logging.error(f"Error encoding texts: {str(e)}")
+            raise
 
     @log_function
     def store_embeddings(self, embeddings: torch.Tensor):
@@ -187,23 +195,16 @@ class SentenceTransformerRetriever:
         if self.doc_embeddings is None:
             raise ValueError("No document embeddings stored!")
 
-        # Compute similarities
         similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-
-        # Get top k scores and indices
         k = min(k, len(documents))
         scores, indices = torch.topk(similarities, k=k)
 
-        # Log similarity statistics
-        print(f"\nSimilarity Stats:")
-        print(f"Max similarity: {similarities.max().item():.4f}")
-        print(f"Mean similarity: {similarities.mean().item():.4f}")
-        print(f"Selected similarities: {scores.tolist()}")
+        logging.info(f"\nSimilarity Stats:")
+        logging.info(f"Max similarity: {similarities.max().item():.4f}")
+        logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
+        logging.info(f"Selected similarities: {scores.tolist()}")
 
         return indices.cpu(), scores.cpu()
-
-
-
 
 class RAGPipeline:
     def __init__(self, data_folder: str, k: int = 5):
@@ -212,29 +213,25 @@ class RAGPipeline:
         self.retriever = SentenceTransformerRetriever()
         self.documents = []
         self.device = torch.device("cpu")
-        self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
-        # Initialize model in init
+        self.model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
         self.llm = None
-        self.initialize_model()
-
+        self._initialize_model()
 
-    st.cache_resource
-    def initialize_model(_self):
-        """Initialize the model with proper error handling and verification
-
-        Note: Using _self instead of self for Streamlit caching compatibility
-        """
+    @st.cache_resource(show_spinner=False)
+    def _initialize_model(self):
+        """Initialize the model with proper error handling and verification"""
         try:
-            if not os.path.exists(_self.model_path):
+            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
+
+            if not os.path.exists(self.model_path):
                 direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
-                download_file_with_progress(direct_url, _self.model_path)
+                download_file_with_progress(direct_url, self.model_path)
 
-            # Verify file exists and has content
-            if not os.path.exists(_self.model_path):
-                raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")
+            if not os.path.exists(self.model_path):
+                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
 
-            if os.path.getsize(_self.model_path) < 1000000:  # Less than 1MB
-                os.remove(_self.model_path)
+            if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
+                os.remove(self.model_path)
                 raise ValueError("Downloaded model file is too small, likely corrupted")
 
             llm_config = {
@@ -245,43 +242,73 @@ class RAGPipeline:
                 "verbose": False
             }
 
-            _self.llm = Llama(model_path=_self.model_path, **llm_config)
+            self.llm = Llama(model_path=self.model_path, **llm_config)
             st.success("Model loaded successfully!")
 
         except Exception as e:
+            logging.error(f"Error initializing model: {str(e)}")
             st.error(f"Error initializing model: {str(e)}")
             raise
-
+
+    def check_model_health(self):
+        """Verify that the model is loaded and functioning"""
+        try:
+            if self.llm is None:
+                return False
+
+            # Simple test prompt
+            test_response = self.llm(
+                "Test prompt",
+                max_tokens=10,
+                temperature=0.4,
+                echo=False
+            )
+
+            return isinstance(test_response, dict) and 'choices' in test_response
+        except Exception:
+            return False
+
     @log_function
     @st.cache_data
-    def load_and_process_csvs(_self):
-        cache_data = _self.retriever.load_cache(_self.data_folder)
-        if cache_data is not None:
-            _self.documents = cache_data['documents']
-            _self.retriever.store_embeddings(cache_data['embeddings'])
-            return
-
-        csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
-        all_documents = []
-
-        for csv_file in tqdm(csv_files, desc="Reading CSV files"):
-            try:
-                df = pd.read_csv(csv_file)
-                texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
-                all_documents.extend(texts)
-            except Exception as e:
-                print(f"Error processing file {csv_file}: {e}")
-                continue
-
-        _self.documents = all_documents
-        embeddings = _self.retriever.encode(all_documents)
-        _self.retriever.store_embeddings(embeddings)
-
-        cache_data = {
-            'embeddings': embeddings,
-            'documents': _self.documents
-        }
-        _self.retriever.save_cache(_self.data_folder, cache_data)
+    def load_and_process_csvs(self):
+        try:
+            cache_data = self.retriever.load_cache(self.data_folder)
+            if cache_data is not None:
+                self.documents = cache_data['documents']
+                self.retriever.store_embeddings(cache_data['embeddings'])
+                return
+
+            csv_files = glob.glob(os.path.join(self.data_folder, "*.csv"))
+            if not csv_files:
+                raise FileNotFoundError(f"No CSV files found in {self.data_folder}")
+
+            all_documents = []
+
+            for csv_file in tqdm(csv_files, desc="Reading CSV files"):
+                try:
+                    df = pd.read_csv(csv_file)
+                    texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
+                    all_documents.extend(texts)
+                except Exception as e:
+                    logging.error(f"Error processing file {csv_file}: {e}")
+                    continue
+
+            if not all_documents:
+                raise ValueError("No documents were successfully loaded")
+
+            self.documents = all_documents
+            embeddings = self.retriever.encode(all_documents)
+            self.retriever.store_embeddings(embeddings)
+
+            cache_data = {
+                'embeddings': embeddings,
+                'documents': self.documents
+            }
+            self.retriever.save_cache(self.data_folder, cache_data)
+
+        except Exception as e:
+            logging.error(f"Error in load_and_process_csvs: {str(e)}")
+            raise
 
     def preprocess_query(self, query: str) -> str:
         """Clean and prepare the query"""
@@ -299,6 +326,12 @@ class RAGPipeline:
     @log_function
     def process_query(self, query: str, placeholder) -> str:
         try:
+            # Check if models are properly initialized
+            if self.llm is None:
+                raise RuntimeError("LLM model not initialized")
+            if self.retriever.model is None:
+                raise RuntimeError("Sentence transformer model not initialized")
+
             # Preprocess query
             query = self.preprocess_query(query)
 
@@ -311,9 +344,9 @@ class RAGPipeline:
             indices, scores = self.retriever.search(query_embedding, self.k, self.documents)
 
             # Print search results for debugging
-            print("\nSearch Results:")
+            logging.info("\nSearch Results:")
             for idx, score in zip(indices.tolist(), scores.tolist()):
-                print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
+                logging.info(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
 
             relevant_docs = [self.documents[idx] for idx in indices.tolist()]
 
@@ -339,7 +372,6 @@ class RAGPipeline:
 
             # Generate response
             response_placeholder = placeholder.empty()
-            generated_text = ""
 
             try:
                 response = self.llm(
@@ -368,163 +400,183 @@ class RAGPipeline:
                 return message
 
             except Exception as e:
-                print(f"Generation error: {str(e)}")
+                logging.error(f"Generation error: {str(e)}")
                 message = "Had some trouble generating the response. Please try again."
                 response_placeholder.warning(message)
                 return message
 
         except Exception as e:
-            print(f"Process error: {str(e)}")
+            logging.error(f"Process error: {str(e)}")
             message = "Something went wrong. Please try again with a different question."
             placeholder.warning(message)
             return message
-
 
-
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def initialize_rag_pipeline():
     """Initialize the RAG pipeline once"""
-    data_folder = "ESPN_data"  # Update this path as needed
-    rag = RAGPipeline(data_folder)
-    rag.load_and_process_csvs()
-    return rag
+    try:
+        data_folder = "ESPN_data"
+        if not os.path.exists(data_folder):
+            os.makedirs(data_folder, exist_ok=True)
+
+        rag = RAGPipeline(data_folder)
+        rag.load_and_process_csvs()
+        return rag
+    except Exception as e:
+        logging.error(f"Pipeline initialization error: {str(e)}")
+        st.error("Failed to initialize the system. Please check your data folder and try again.")
+        raise
 
 def main():
-    # Environment check
-    if not check_environment():
-        return
-
-    # Page config
-    st.set_page_config(
-        page_title="The Sport Chatbot",
-        page_icon="🏆",
-        layout="wide"  # Changed back to wide for more space
-    )
-
-    # Improved CSS styling
-    st.markdown("""
-        <style>
-        /* Container styling */
-        .block-container {
-            padding-top: 2rem;
-            padding-bottom: 2rem;
-        }
-
-        /* Text input styling */
-        .stTextInput > div > div > input {
-            width: 100%;
-        }
-
-        /* Button styling */
-        .stButton > button {
-            width: 200px;
-            margin: 0 auto;
-            display: block;
-            background-color: #FF4B4B;
-            color: white;
-            border-radius: 5px;
-            padding: 0.5rem 1rem;
-        }
-
-        /* Title styling */
-        .main-title {
-            text-align: center;
-            padding: 1rem 0;
-            font-size: 3rem;
-            color: #1F1F1F;
-        }
-
-        .sub-title {
-            text-align: center;
-            padding: 0.5rem 0;
-            font-size: 1.5rem;
-            color: #4F4F4F;
-        }
-
-        /* Description styling */
-        .description {
-            text-align: center;
-            color: #666666;
-            padding: 0.5rem 0;
-            font-size: 1.1rem;
-            line-height: 1.6;
-            margin-bottom: 1rem;
-        }
-
-        /* Answer container styling */
-        .stMarkdown {
-            max-width: 100%;
-        }
-
-        /* Streamlit default overrides */
-        .st-emotion-cache-16idsys p {
-            font-size: 1.1rem;
-            line-height: 1.6;
-        }
-
-        /* Container for main content */
-        .main-content {
-            max-width: 1200px;
-            margin: 0 auto;
-            padding: 0 1rem;
-        }
-        </style>
-    """, unsafe_allow_html=True)
-
-    # Header section with improved styling
-    st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
-    st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
-    st.markdown("""
-    <p class='description'>
-        Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
-        With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
-    </p>
-    <p class='description'>
-        Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
-    </p>
-    """, unsafe_allow_html=True)
-
-    # Add some spacing
-    st.markdown("<br>", unsafe_allow_html=True)
-
-
-    # Initialize the pipeline
     try:
-        with st.spinner("Loading resources..."):
-            rag = initialize_rag_pipeline()
-    except Exception as e:
-        print(f"Initialization error: {str(e)}")
-        st.error("Unable to initialize the system. Please check if all required files are present.")
-        st.stop()
+        # Environment check
+        if not check_environment():
+            return
 
-    # Create columns for layout with golden ratio
-    col1, col2, col3 = st.columns([1, 6, 1])
-
-    with col2:
-        # Query input with label styling
-        query = st.text_input("What would you like to know about sports?")
+        # Page config
+        st.set_page_config(
+            page_title="The Sport Chatbot",
+            page_icon="🏆",
+            layout="wide"
+        )
+
+        # Improved CSS styling
+        st.markdown("""
+            <style>
+            /* Container styling */
+            .block-container {
+                padding-top: 2rem;
+                padding-bottom: 2rem;
+            }
+
+            /* Text input styling */
+            .stTextInput > div > div > input {
+                width: 100%;
+            }
+
+            /* Button styling */
+            .stButton > button {
+                width: 200px;
+                margin: 0 auto;
+                display: block;
+                background-color: #FF4B4B;
+                color: white;
+                border-radius: 5px;
+                padding: 0.5rem 1rem;
+            }
+
+            /* Title styling */
+            .main-title {
+                text-align: center;
+                padding: 1rem 0;
+                font-size: 3rem;
+                color: #1F1F1F;
+            }
+
+            .sub-title {
+                text-align: center;
+                padding: 0.5rem 0;
+                font-size: 1.5rem;
+                color: #4F4F4F;
+            }
+
+            /* Description styling */
+            .description {
+                text-align: center;
+                color: #666666;
+                padding: 0.5rem 0;
+                font-size: 1.1rem;
+                line-height: 1.6;
+                margin-bottom: 1rem;
+            }
+
+            /* Answer container styling */
+            .stMarkdown {
+                max-width: 100%;
+            }
+
+            /* Streamlit default overrides */
+            .st-emotion-cache-16idsys p {
+                font-size: 1.1rem;
+                line-height: 1.6;
+            }
+
+            /* Container for main content */
+            .main-content {
+                max-width: 1200px;
+                margin: 0 auto;
+                padding: 0 1rem;
+            }
+            </style>
+        """, unsafe_allow_html=True)
+
+        # Header section
+        st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
+        st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
+        st.markdown("""
+        <p class='description'>
+            Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
+            With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
+        </p>
+        <p class='description'>
+            Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
+        </p>
+        """, unsafe_allow_html=True)
+
+        # Add spacing
+        st.markdown("<br>", unsafe_allow_html=True)
+
+        # Initialize the pipeline
+        try:
+            with st.spinner("Loading resources..."):
+                rag = initialize_rag_pipeline()
+
+            # Add a model health check
+            if not rag.check_model_health():
+                st.error("Model initialization failed. Please try restarting the application.")
+                return
+
+        except Exception as e:
+            logging.error(f"Initialization error: {str(e)}")
+            st.error("Unable to initialize the system. Please check if all required files are present.")
+            return
+
+        # Create columns for layout with golden ratio
+        col1, col2, col3 = st.columns([1, 6, 1])
 
-    # Centered button
-    if st.button("Get Answer"):
-        if query:
-            response_placeholder = st.empty()
-            try:
-                response = rag.process_query(query, response_placeholder)
-                print(f"Generated response: {response}")
-            except Exception as e:
-                print(f"Query processing error: {str(e)}")
-                response_placeholder.warning("Unable to process your question. Please try again.")
-        else:
-            st.warning("Please enter a question!")
-
-    # Footer with improved styling
-    st.markdown("<br><br>", unsafe_allow_html=True)
-    st.markdown("---")
-    st.markdown("""
-    <p style='text-align: center; color: #666666; padding: 1rem 0;'>
-        Powered by ESPN Data & Mistral AI 🚀
-    </p>
-    """, unsafe_allow_html=True)
+        with col2:
+            # Query input with label styling
+            query = st.text_input("What would you like to know about sports?")
+
+            # Centered button
+            if st.button("Get Answer"):
+                if query:
+                    response_placeholder = st.empty()
+                    try:
+                        response = rag.process_query(query, response_placeholder)
+                        logging.info(f"Generated response: {response}")
+                    except Exception as e:
+                        logging.error(f"Query processing error: {str(e)}")
+                        response_placeholder.warning("Unable to process your question. Please try again.")
+                else:
+                    st.warning("Please enter a question!")
+
+            # Footer
+            st.markdown("<br><br>", unsafe_allow_html=True)
+            st.markdown("---")
+            st.markdown("""
+            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
+                Powered by ESPN Data & Mistral AI 🚀
+            </p>
+            """, unsafe_allow_html=True)
+
+    except Exception as e:
+        logging.error(f"Application error: {str(e)}")
+        st.error("An unexpected error occurred. Please check the logs and try again.")
 
 if __name__ == "__main__":
-    main()
+    try:
+        main()
+    except Exception as e:
+        logging.error(f"Application error: {str(e)}")
+        st.error("An unexpected error occurred. Please check the logs and try again.")