nishantgaurav23 committed on
Commit f644e57 · verified · 1 Parent(s): d3429ce

Upload app_new.py

Files changed (1):
app_new.py (+492, -0)
app_new.py ADDED
@@ -0,0 +1,492 @@
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Callable
import glob
from tqdm import tqdm
import pickle
import torch.nn.functional as F
from llama_cpp import Llama
import streamlit as st
import functools
from datetime import datetime
import re
import time
import requests

# CPU-only app: every component below is constructed with an explicit
# device="cpu", so no GPU is ever required.

# Logging configuration: a global switch plus per-function toggles,
# consumed by the @log_function decorator below.
LOGGING_CONFIG = {
    'enabled': True,
    'functions': {
        'encode': True,
        'store_embeddings': True,
        'search': True,
        'load_and_process_csvs': True,
        'process_query': True
    }
}

@st.cache_data
def load_from_drive(file_id: str):
    """Load a pickle file directly from Google Drive."""
    try:
        # Direct-download URL for Google Drive
        url = f"https://drive.google.com/uc?id={file_id}&export=download"

        # First request to obtain the confirmation token (needed for large files)
        session = requests.Session()
        response = session.get(url, stream=True)

        # Check whether Drive asks us to confirm the download
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                # Add the confirmation parameter to the URL and retry
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        # Download the content and unpickle it
        content = response.content
        print(f"Successfully downloaded {len(content)} bytes")
        return pickle.loads(content)

    except Exception as e:
        print(f"Detailed error: {str(e)}")  # Helps with debugging
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
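
# Hedged usage sketch: this helper is not referenced anywhere else in this
# file; if used, the call would look like the following (the Drive file id
# is hypothetical):
#   cache = load_from_drive("1AbCdEfGhIjK")  # hypothetical file id
#   if cache is not None:
#       print(f"Cache keys: {list(cache.keys())}")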

def log_function(func: Callable) -> Callable:
    """Decorator that logs function inputs and outputs when enabled."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
            return func(*args, **kwargs)

        # Distinguish bound methods (log the class name) from module-level functions
        if args and hasattr(args[0], '__class__'):
            class_name = args[0].__class__.__name__
        else:
            class_name = func.__module__

        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        log_args = args[1:] if class_name != func.__module__ else args

        def format_arg(arg):
            # Summarize bulky arguments instead of dumping them wholesale
            if isinstance(arg, torch.Tensor):
                return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
            elif isinstance(arg, list):
                return f"List(len={len(arg)})"
            elif isinstance(arg, str) and len(arg) > 100:
                return f"String(len={len(arg)}): {arg[:100]}..."
            return arg

        formatted_args = [format_arg(arg) for arg in log_args]
        formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}

        print(f"\n{'='*80}")
        print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
        print("INPUTS:")
        print(f"  args: {formatted_args}")
        print(f"  kwargs: {formatted_kwargs}")

        result = func(*args, **kwargs)

        formatted_result = format_arg(result)
        print("OUTPUT:")
        print(f"  {formatted_result}")
        print(f"{'='*80}\n")

        return result
    return wrapper
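
# Hedged usage note: logging is driven entirely by LOGGING_CONFIG above, so
# individual functions can be silenced without touching the decorator:
#   LOGGING_CONFIG['functions']['encode'] = False  # silence encode() only
#   LOGGING_CONFIG['enabled'] = False              # silence all logging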

def check_environment():
    """Check that all required packages are importable."""
    try:
        import numpy as np
        import torch
        import sentence_transformers
        import llama_cpp
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False

@st.cache_resource
def initialize_model():
    """Initialize the Llama model once per process."""
    # llama_cpp loads a local GGUF file (not a Hugging Face repo id like
    # "mistralai/Mistral-7B-v0.1"), so point model_path at the quantized
    # model on disk.
    model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
    if not os.path.exists(model_path):
        st.error(f"Model file {model_path} not found!")
        st.stop()

    llm_config = {
        "n_ctx": 2048,       # context window (tokens)
        "n_threads": 4,      # CPU threads
        "n_batch": 512,      # prompt-processing batch size
        "n_gpu_layers": 0,   # CPU-only
        "verbose": False
    }

    return Llama(model_path=model_path, **llm_config)


class SentenceTransformerRetriever:
    # Instance caching is handled by the @st.cache_resource on
    # initialize_rag_pipeline(); caching __init__ itself would skip
    # attribute setup on cache hits.
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        # Force CPU device and suppress warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.device = torch.device("cpu")
            self.model = SentenceTransformer(model_name, device="cpu")
            self.doc_embeddings = None
            self.cache_dir = cache_dir
            self.cache_file = "embeddings.pkl"
            os.makedirs(cache_dir, exist_ok=True)

    def get_cache_path(self, data_folder: str = None) -> str:
        # data_folder is accepted for API symmetry but unused: the cache
        # lives at a fixed location.
        return os.path.join(self.cache_dir, self.cache_file)

    @log_function
    def save_cache(self, data_folder: str, cache_data: dict):
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            os.remove(cache_path)
        with open(cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        print(f"Cache saved at: {cache_path}")

    @log_function
    def load_cache(self, data_folder: str = None) -> dict:
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print(f"Loading cache from: {cache_path}")
                return pickle.load(f)
        return None

    @log_function
    def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
        embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
        # L2-normalize so cosine similarity reduces to a dot product
        return F.normalize(embeddings, p=2, dim=1)

    @log_function
    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings

    @log_function
    def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
        if self.doc_embeddings is None:
            raise ValueError("No document embeddings stored!")

        # Compute similarities between the query and every document
        similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)

        # Get top-k scores and indices (clamped to the corpus size)
        k = min(k, len(documents))
        scores, indices = torch.topk(similarities, k=k)

        # Log similarity statistics
        print("\nSimilarity Stats:")
        print(f"Max similarity: {similarities.max().item():.4f}")
        print(f"Mean similarity: {similarities.mean().item():.4f}")
        print(f"Selected similarities: {scores.tolist()}")

        return indices.cpu(), scores.cpu()
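
# Hedged sketch (define-only; never called by the app): end-to-end retrieval
# over a toy corpus. Because encode() L2-normalizes embeddings, the cosine
# similarity computed in search() is equivalent to a dot product.
def _demo_retrieval():
    retriever = SentenceTransformerRetriever()
    docs = ["ice hockey scores", "baseball standings", "soccer fixtures"]
    retriever.store_embeddings(retriever.encode(docs))
    query_embedding = retriever.encode(["latest hockey results"])
    indices, scores = retriever.search(query_embedding, k=2, documents=docs)
    return [docs[i] for i in indices.tolist()], scores.tolist()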


class RAGPipeline:
    def __init__(self, data_folder: str, k: int = 5):
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        self.documents = []
        self.device = torch.device("cpu")
        self.llm = initialize_model()

    # The pipeline instance is cached by initialize_rag_pipeline(), so this
    # runs once per process and populates self.documents / embeddings.
    @log_function
    def load_and_process_csvs(self):
        cache_data = self.retriever.load_cache(self.data_folder)
        if cache_data is not None:
            self.documents = cache_data['documents']
            self.retriever.store_embeddings(cache_data['embeddings'])
            return

        csv_files = glob.glob(os.path.join(self.data_folder, "*.csv"))
        all_documents = []

        for csv_file in tqdm(csv_files, desc="Reading CSV files"):
            try:
                df = pd.read_csv(csv_file)
                # Flatten each row into one whitespace-joined document string
                texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
                all_documents.extend(texts)
            except Exception as e:
                print(f"Error processing file {csv_file}: {e}")
                continue

        self.documents = all_documents
        embeddings = self.retriever.encode(all_documents)
        self.retriever.store_embeddings(embeddings)

        cache_data = {
            'embeddings': embeddings,
            'documents': self.documents
        }
        self.retriever.save_cache(self.data_folder, cache_data)

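    # Hedged illustration of the row flattening above (not executed here):
    #   pd.DataFrame({"team": ["Bruins"], "score": [3]}).apply(
    #       lambda x: " ".join(x.astype(str)), axis=1).tolist()
    #   # -> ["Bruins 3"]
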
    def preprocess_query(self, query: str) -> str:
        """Clean and prepare the query."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)  # collapse runs of whitespace
        return query

    def postprocess_response(self, response: str) -> str:
        """Clean up the generated response."""
        response = response.strip()
        response = re.sub(r'\s+', ' ', response)
        # Strip ISO-style timestamps the model sometimes echoes from the data
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response

    @log_function
    def process_query(self, query: str, placeholder) -> str:
        try:
            # Preprocess query
            query = self.preprocess_query(query)

            # Show retrieval status
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            # Retrieve relevant documents
            query_embedding = self.retriever.encode([query])
            indices, scores = self.retriever.search(query_embedding, self.k, self.documents)

            # Print search results for debugging
            print("\nSearch Results:")
            for idx, score in zip(indices.tolist(), scores.tolist()):
                print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")

            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Update status
            status.write("💭 Generating response...")

            # Prepare context and prompt
            context = "\n".join(relevant_docs)
            prompt = f"""Context information is below:
{context}

Given the context above, please answer the following question:
{query}

Guidelines:
- If you cannot answer based on the context, say so politely
- Keep the response concise and focused
- Only include sports-related information
- No dates or timestamps in the response
- Use clear, natural language

Answer:"""

            # Generate response
            response_placeholder = placeholder.empty()
            generated_text = ""

            try:
                response = self.llm(
                    prompt,
                    max_tokens=512,
                    temperature=0.4,
                    top_p=0.95,
                    echo=False,
                    stop=["Question:", "\n\n"]
                )

                if response and 'choices' in response and len(response['choices']) > 0:
                    generated_text = response['choices'][0].get('text', '').strip()

                    if generated_text:
                        final_response = self.postprocess_response(generated_text)
                        response_placeholder.markdown(final_response)
                        return final_response
                    else:
                        message = "No relevant answer found. Please try rephrasing your question."
                        response_placeholder.warning(message)
                        return message
                else:
                    message = "Unable to generate response. Please try again."
                    response_placeholder.warning(message)
                    return message

            except Exception as e:
                print(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            print(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
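
# Hedged sketch (define-only): the completion-call shape used in
# process_query, shown in isolation. llama-cpp-python returns a dict of
# the form {'choices': [{'text': ...}, ...]}.
def _demo_llm_call(llm: Llama) -> str:
    out = llm("Answer briefly: what is a hat-trick?", max_tokens=32,
              temperature=0.4, top_p=0.95, echo=False)
    return out['choices'][0]['text'].strip()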


@st.cache_resource
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once per process."""
    data_folder = "ESPN_data"  # Update this path as needed
    rag = RAGPipeline(data_folder)
    rag.load_and_process_csvs()
    return rag
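
# Hedged note: the pipeline expects an ESPN_data/ folder of CSV files
# relative to the working directory, and writes its embeddings cache to
# embeddings_cache/embeddings.pkl (see SentenceTransformerRetriever).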

def main():
    # Environment check
    if not check_environment():
        return

    # Page config
    st.set_page_config(
        page_title="The Sport Chatbot",
        page_icon="🏆",
        layout="wide"  # wide layout gives the answer area more space
    )

    # Improved CSS styling
    st.markdown("""
        <style>
        /* Container styling */
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }

        /* Text input styling */
        .stTextInput > div > div > input {
            width: 100%;
        }

        /* Button styling */
        .stButton > button {
            width: 200px;
            margin: 0 auto;
            display: block;
            background-color: #FF4B4B;
            color: white;
            border-radius: 5px;
            padding: 0.5rem 1rem;
        }

        /* Title styling */
        .main-title {
            text-align: center;
            padding: 1rem 0;
            font-size: 3rem;
            color: #1F1F1F;
        }

        .sub-title {
            text-align: center;
            padding: 0.5rem 0;
            font-size: 1.5rem;
            color: #4F4F4F;
        }

        /* Description styling */
        .description {
            text-align: center;
            color: #666666;
            padding: 0.5rem 0;
            font-size: 1.1rem;
            line-height: 1.6;
            margin-bottom: 1rem;
        }

        /* Answer container styling */
        .stMarkdown {
            max-width: 100%;
        }

        /* Streamlit default overrides (version-specific class name) */
        .st-emotion-cache-16idsys p {
            font-size: 1.1rem;
            line-height: 1.6;
        }

        /* Container for main content */
        .main-content {
            max-width: 1200px;
            margin: 0 auto;
            padding: 0 1rem;
        }
        </style>
    """, unsafe_allow_html=True)

    # Header section with improved styling
    st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
    st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
    st.markdown("""
        <p class='description'>
        Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
        With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
        </p>
        <p class='description'>
        Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
        </p>
    """, unsafe_allow_html=True)

    # Add some spacing
    st.markdown("<br>", unsafe_allow_html=True)

    # Initialize the pipeline
    try:
        with st.spinner("Loading resources..."):
            rag = initialize_rag_pipeline()
    except Exception as e:
        print(f"Initialization error: {str(e)}")
        st.error("Unable to initialize the system. Please check if all required files are present.")
        st.stop()

    # Create columns to center the main content
    col1, col2, col3 = st.columns([1, 6, 1])

    with col2:
        # Query input
        query = st.text_input("What would you like to know about sports?")

        # Centered button
        if st.button("Get Answer"):
            if query:
                response_placeholder = st.empty()
                try:
                    response = rag.process_query(query, response_placeholder)
                    print(f"Generated response: {response}")
                except Exception as e:
                    print(f"Query processing error: {str(e)}")
                    response_placeholder.warning("Unable to process your question. Please try again.")
            else:
                st.warning("Please enter a question!")

    # Footer with improved styling
    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("""
        <p style='text-align: center; color: #666666; padding: 1rem 0;'>
        Powered by ESPN Data & Mistral AI 🚀
        </p>
    """, unsafe_allow_html=True)

if __name__ == "__main__":
    main()