nishantgaurav23 commited on
Commit
f5d1367
Β·
verified Β·
1 Parent(s): c72316a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +385 -0
app.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ warnings.filterwarnings("ignore", category=UserWarning)
4
+
5
+ import streamlit as st
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import re
9
+ import requests
10
+ #from dotenv import load_dotenv
11
+ from embedding_processor import SentenceTransformerRetriever, process_data
12
+ import pickle
13
+
14
+ import os
15
+ import warnings
16
+ import json # Add this import
17
+
18
+
19
+
20
+ # Load environment variables
21
+ #load_dotenv()
22
+
23
+ # Add the new function here, right after imports and before API configuration
24
+ @st.cache_data
25
+ @st.cache_data
26
+ def load_from_drive(file_id: str):
27
+ """Load pickle file directly from Google Drive"""
28
+ try:
29
+ # Direct download URL for Google Drive
30
+ url = f"https://drive.google.com/uc?id={file_id}&export=download"
31
+
32
+ # First request to get the confirmation token
33
+ session = requests.Session()
34
+ response = session.get(url, stream=True)
35
+
36
+ # Check if we need to confirm download
37
+ for key, value in response.cookies.items():
38
+ if key.startswith('download_warning'):
39
+ # Add confirmation parameter to the URL
40
+ url = f"{url}&confirm={value}"
41
+ response = session.get(url, stream=True)
42
+ break
43
+
44
+ # Load the content and convert to pickle
45
+ content = response.content
46
+ print(f"Successfully downloaded {len(content)} bytes")
47
+ return pickle.loads(content)
48
+
49
+ except Exception as e:
50
+ print(f"Detailed error: {str(e)}") # This will help debug
51
+ st.error(f"Error loading file from Drive: {str(e)}")
52
+ return None
53
+
54
+ # Hugging Face API configuration
55
+
56
+ # API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
57
+ # headers = {"Authorization": f"Bearer HF_TOKEN"}
58
+ model_name = 'mistralai/Mistral-7B-v0.1'
59
+
60
+
61
+ class RAGPipeline:
62
+
63
+ def __init__(self, data_folder: str, k: int = 3): # Reduced k for faster retrieval
64
+ self.data_folder = data_folder
65
+ self.k = k
66
+ self.retriever = SentenceTransformerRetriever()
67
+ cache_data = process_data(data_folder)
68
+ self.documents = cache_data['documents']
69
+ self.retriever.store_embeddings(cache_data['embeddings'])
70
+
71
+
72
+ # Alternative API call with streaming
73
+ def query_model(self, payload):
74
+ """Query the Hugging Face API with streaming"""
75
+ try:
76
+ # Add streaming parameters
77
+ payload["parameters"]["stream"] = True
78
+
79
+ response = requests.post(
80
+ model_name,
81
+ #headers=headers,
82
+ json=payload,
83
+ stream=True
84
+ )
85
+ response.raise_for_status()
86
+
87
+ # Collect the entire response
88
+ full_response = ""
89
+ for line in response.iter_lines():
90
+ if line:
91
+ try:
92
+ json_response = json.loads(line)
93
+ if isinstance(json_response, list) and len(json_response) > 0:
94
+ chunk_text = json_response[0].get('generated_text', '')
95
+ if chunk_text:
96
+ full_response += chunk_text
97
+ except json.JSONDecodeError as e:
98
+ print(f"Error decoding JSON: {e}")
99
+ continue
100
+
101
+ return [{"generated_text": full_response}]
102
+
103
+ except requests.exceptions.RequestException as e:
104
+ print(f"API request failed: {str(e)}")
105
+ raise
106
+
107
+ def preprocess_query(self, query: str) -> str:
108
+ """Clean and prepare the query"""
109
+ query = query.lower().strip()
110
+ query = re.sub(r'\s+', ' ', query)
111
+ return query
112
+
113
+ def postprocess_response(self, response: str) -> str:
114
+ """Clean up the generated response"""
115
+ response = response.strip()
116
+ response = re.sub(r'\s+', ' ', response)
117
+ response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
118
+ return response
119
+
120
+
121
+ def process_query(self, query: str, placeholder) -> str:
122
+ try:
123
+ # Preprocess query
124
+ query = self.preprocess_query(query)
125
+
126
+ # Show retrieval status
127
+ status = placeholder.empty()
128
+ status.write("πŸ” Finding relevant information...")
129
+
130
+ # Get embeddings and search using tensor operations
131
+ query_embedding = self.retriever.encode([query])
132
+ similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
133
+ scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
134
+
135
+ # Print search results for debugging
136
+ print("\nSearch Results:")
137
+ for idx, score in zip(indices.tolist(), scores.tolist()):
138
+ print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
139
+
140
+ relevant_docs = [self.documents[idx] for idx in indices.tolist()]
141
+
142
+ # Update status
143
+ status.write("πŸ’­ Generating response...")
144
+
145
+ # Prepare context and prompt
146
+ context = "\n".join(relevant_docs[:3]) # Only use top 3 most relevant docs
147
+ prompt = f"""Answer this question using the given context. Be specific and detailed.
148
+
149
+ Context: {context}
150
+
151
+ Question: {query}
152
+
153
+ Answer (provide a complete, detailed response):"""
154
+
155
+ # Generate response
156
+ response_placeholder = placeholder.empty()
157
+
158
+ try:
159
+ response = requests.post(
160
+ model_name,
161
+ #headers=headers,
162
+ json={
163
+ "inputs": prompt,
164
+ "parameters": {
165
+ "max_new_tokens": 1024,
166
+ "temperature": 0.5,
167
+ "top_p": 0.9,
168
+ "top_k": 50,
169
+ "repetition_penalty": 1.03,
170
+ "do_sample": True
171
+ }
172
+ },
173
+ timeout=30
174
+ ).json()
175
+
176
+ if response and isinstance(response, list) and len(response) > 0:
177
+ generated_text = response[0].get('generated_text', '').strip()
178
+ if generated_text:
179
+ # Find and extract only the answer part
180
+ if "Answer:" in generated_text:
181
+ answer_part = generated_text.split("Answer:")[-1].strip()
182
+ elif "Answer (provide a complete, detailed response):" in generated_text:
183
+ answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
184
+ else:
185
+ answer_part = generated_text.strip()
186
+
187
+ # Clean up the answer
188
+ answer_part = answer_part.replace("Context:", "").replace("Question:", "")
189
+
190
+ final_response = self.postprocess_response(answer_part)
191
+ response_placeholder.markdown(final_response)
192
+ return final_response
193
+
194
+ message = "No relevant answer found. Please try rephrasing your question."
195
+ response_placeholder.warning(message)
196
+ return message
197
+
198
+ except Exception as e:
199
+ print(f"Generation error: {str(e)}")
200
+ message = "Had some trouble generating the response. Please try again."
201
+ response_placeholder.warning(message)
202
+ return message
203
+
204
+ except Exception as e:
205
+ print(f"Process error: {str(e)}")
206
+ message = "Something went wrong. Please try again with a different question."
207
+ placeholder.warning(message)
208
+ return message
209
+ def check_environment():
210
+ """Check if the environment is properly set up"""
211
+ # if not headers['Authorization']:
212
+ # st.error("HUGGINGFACE_API_KEY environment variable not set!")
213
+ # st.stop()
214
+ # return False
215
+
216
+ try:
217
+ import torch
218
+ import sentence_transformers
219
+ return True
220
+ except ImportError as e:
221
+ st.error(f"Missing required package: {str(e)}")
222
+ st.stop()
223
+ return False
224
+
225
+ # @st.cache_resource
226
+ # def initialize_rag_pipeline():
227
+ # """Initialize the RAG pipeline once"""
228
+ # data_folder = "ESPN_data"
229
+ # return RAGPipeline(data_folder)
230
+
231
+ @st.cache_resource
232
+ def initialize_rag_pipeline():
233
+ """Initialize the RAG pipeline once"""
234
+ data_folder = "ESPN_data"
235
+ drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
236
+
237
+ with st.spinner("Loading embeddings from Google Drive..."):
238
+ cache_data = load_from_drive(drive_file_id)
239
+ if cache_data is None:
240
+ st.error("Failed to load embeddings from Google Drive")
241
+ st.stop()
242
+
243
+ rag = RAGPipeline(data_folder)
244
+ rag.documents = cache_data['documents']
245
+ rag.retriever.store_embeddings(cache_data['embeddings'])
246
+ return rag
247
+
248
+ def main():
249
+ # Environment check
250
+ if not check_environment():
251
+ return
252
+
253
+ # Page config
254
+ st.set_page_config(
255
+ page_title="The Sport Chatbot",
256
+ page_icon="πŸ†",
257
+ layout="wide"
258
+ )
259
+
260
+ # Improved CSS styling
261
+ st.markdown("""
262
+ <style>
263
+ /* Container styling */
264
+ .block-container {
265
+ padding-top: 2rem;
266
+ padding-bottom: 2rem;
267
+ }
268
+
269
+ /* Text input styling */
270
+ .stTextInput > div > div > input {
271
+ width: 100%;
272
+ }
273
+
274
+ /* Button styling */
275
+ .stButton > button {
276
+ width: 200px;
277
+ margin: 0 auto;
278
+ display: block;
279
+ background-color: #FF4B4B;
280
+ color: white;
281
+ border-radius: 5px;
282
+ padding: 0.5rem 1rem;
283
+ }
284
+
285
+ /* Title styling */
286
+ .main-title {
287
+ text-align: center;
288
+ padding: 1rem 0;
289
+ font-size: 3rem;
290
+ color: #1F1F1F;
291
+ }
292
+
293
+ .sub-title {
294
+ text-align: center;
295
+ padding: 0.5rem 0;
296
+ font-size: 1.5rem;
297
+ color: #4F4F4F;
298
+ }
299
+
300
+ /* Description styling */
301
+ .description {
302
+ text-align: center;
303
+ color: #666666;
304
+ padding: 0.5rem 0;
305
+ font-size: 1.1rem;
306
+ line-height: 1.6;
307
+ margin-bottom: 1rem;
308
+ }
309
+
310
+ /* Answer container styling */
311
+ .stMarkdown {
312
+ max-width: 100%;
313
+ }
314
+
315
+ /* Streamlit default overrides */
316
+ .st-emotion-cache-16idsys p {
317
+ font-size: 1.1rem;
318
+ line-height: 1.6;
319
+ }
320
+
321
+ /* Container for main content */
322
+ .main-content {
323
+ max-width: 1200px;
324
+ margin: 0 auto;
325
+ padding: 0 1rem;
326
+ }
327
+ </style>
328
+ """, unsafe_allow_html=True)
329
+
330
+ # Header section with improved styling
331
+ st.markdown("<h1 class='main-title'>πŸ† The Sport Chatbot</h1>", unsafe_allow_html=True)
332
+ st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
333
+ st.markdown("""
334
+ <p class='description'>
335
+ Hey there! πŸ‘‹ I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
336
+ With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
337
+ </p>
338
+ <p class='description'>
339
+ Got any general questions? Feel free to askβ€”I'll do my best to provide answers based on the information I've been trained on!
340
+ </p>
341
+ """, unsafe_allow_html=True)
342
+
343
+ # Add some spacing
344
+ st.markdown("<br>", unsafe_allow_html=True)
345
+
346
+ # Initialize the pipeline
347
+ try:
348
+ with st.spinner("Loading resources..."):
349
+ rag = initialize_rag_pipeline()
350
+ except Exception as e:
351
+ print(f"Initialization error: {str(e)}")
352
+ st.error("Unable to initialize the system. Please check if all required files are present.")
353
+ st.stop()
354
+
355
+ # Create columns for layout with golden ratio
356
+ col1, col2, col3 = st.columns([1, 6, 1])
357
+
358
+ with col2:
359
+ # Query input with label styling
360
+ query = st.text_input("What would you like to know about sports?")
361
+
362
+ # Centered button
363
+ if st.button("Get Answer"):
364
+ if query:
365
+ response_placeholder = st.empty()
366
+ try:
367
+ response = rag.process_query(query, response_placeholder)
368
+ print(f"Generated response: {response}")
369
+ except Exception as e:
370
+ print(f"Query processing error: {str(e)}")
371
+ response_placeholder.warning("Unable to process your question. Please try again.")
372
+ else:
373
+ st.warning("Please enter a question!")
374
+
375
+ # Footer with improved styling
376
+ st.markdown("<br><br>", unsafe_allow_html=True)
377
+ st.markdown("---")
378
+ st.markdown("""
379
+ <p style='text-align: center; color: #666666; padding: 1rem 0;'>
380
+ Powered by ESPN Data & Mistral AI πŸš€
381
+ </p>
382
+ """, unsafe_allow_html=True)
383
+
384
+ if __name__ == "__main__":
385
+ main()