nishantgaurav23 commited on
Commit
c72316a
·
verified ·
1 Parent(s): 092e93c

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -384
app.py DELETED
@@ -1,384 +0,0 @@
1
- import os
2
- import warnings
3
- warnings.filterwarnings("ignore", category=UserWarning)
4
-
5
- import streamlit as st
6
- import torch
7
- import torch.nn.functional as F
8
- import re
9
- import requests
10
- from dotenv import load_dotenv
11
- from embedding_processor import SentenceTransformerRetriever, process_data
12
- import pickle
13
-
14
- import os
15
- import warnings
16
- import json # Add this import
17
-
18
-
19
-
20
- # Load environment variables
21
- load_dotenv()
22
-
23
# Loads the pre-computed embeddings/documents pickle shared on Google Drive.
@st.cache_data
def load_from_drive(file_id: str):
    """Download and unpickle a publicly shared Google Drive file.

    Args:
        file_id: The Drive file id (the ``id=`` query parameter of a share link).

    Returns:
        The unpickled object, or ``None`` on any download/unpickle failure
        (the error is also surfaced in the Streamlit UI via ``st.error``).
    """
    try:
        # Direct-download endpoint for publicly shared Drive files.
        url = f"https://drive.google.com/uc?id={file_id}&export=download"

        # Use a session so cookies set by the first response are replayed.
        session = requests.Session()
        response = session.get(url, stream=True)

        # Large files trigger a "can't scan for viruses" interstitial; Drive
        # signals it with a download_warning cookie whose value must be sent
        # back as a confirm parameter.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        # Fail loudly on HTTP errors instead of feeding an error page to
        # pickle (the original skipped this check).
        response.raise_for_status()

        content = response.content
        print(f"Successfully downloaded {len(content)} bytes")
        # SECURITY: pickle.loads executes arbitrary code from the payload.
        # Acceptable only because the file id is hard-coded to a trusted
        # artifact owned by the app author — never point this at user input.
        return pickle.loads(content)

    except Exception as e:
        print(f"Detailed error: {str(e)}")  # console detail for debugging
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
53
-
54
# Hugging Face Inference API configuration.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
# Read the token from the environment (populated from .env by load_dotenv).
# BUG FIX: the original f"Bearer HF_TOKEN" had no placeholder, so the literal
# text "Bearer HF_TOKEN" was sent as the Authorization header.
HF_TOKEN = os.getenv("HF_TOKEN", "")
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
58
-
59
-
60
class RAGPipeline:
    """Retrieval-augmented generation over a local document corpus.

    Embeds the corpus once at construction time, then answers queries by
    retrieving the top-k most similar documents and prompting the hosted
    Mistral model with them as context.
    """

    def __init__(self, data_folder: str, k: int = 3):  # small k keeps retrieval fast
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        # process_data returns the corpus plus pre-computed embeddings
        # (presumably cached on disk — see embedding_processor).
        cache_data = process_data(data_folder)
        self.documents = cache_data['documents']
        self.retriever.store_embeddings(cache_data['embeddings'])

    # Alternative API call with streaming
    def query_model(self, payload):
        """Query the Hugging Face API with streaming and collect the result.

        Args:
            payload: JSON-serializable request body; a ``"parameters"`` dict
                is created if missing (the original indexed it unconditionally
                and raised ``KeyError``).

        Returns:
            A list shaped like the non-streaming API response:
            ``[{"generated_text": <full concatenated text>}]``.

        Raises:
            requests.exceptions.RequestException: on any HTTP failure.
        """
        try:
            # Guard against a payload without a "parameters" key.
            payload.setdefault("parameters", {})["stream"] = True

            response = requests.post(
                API_URL,
                headers=headers,
                json=payload,
                stream=True
            )
            response.raise_for_status()

            # Concatenate the streamed chunks into a single string.
            full_response = ""
            for line in response.iter_lines():
                if line:
                    try:
                        json_response = json.loads(line)
                        if isinstance(json_response, list) and len(json_response) > 0:
                            chunk_text = json_response[0].get('generated_text', '')
                            if chunk_text:
                                full_response += chunk_text
                    except json.JSONDecodeError as e:
                        # Skip malformed chunks rather than abort the stream.
                        print(f"Error decoding JSON: {e}")
                        continue

            return [{"generated_text": full_response}]

        except requests.exceptions.RequestException as e:
            print(f"API request failed: {str(e)}")
            raise

    def preprocess_query(self, query: str) -> str:
        """Lowercase, trim, and collapse internal whitespace in the query."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query

    def postprocess_response(self, response: str) -> str:
        """Trim, collapse whitespace, and strip ISO-style timestamps."""
        response = response.strip()
        response = re.sub(r'\s+', ' ', response)
        # Drop "YYYY-MM-DD HH:MM:SS[+TZ]" artifacts the model may echo.
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response

    def process_query(self, query: str, placeholder) -> str:
        """Retrieve context for *query*, generate an answer, and render it.

        Args:
            query: Raw user question.
            placeholder: A Streamlit container used for status/answer output.

        Returns:
            The final answer text, or a user-facing error/fallback message.
            Never raises — all failures are converted to warning messages.
        """
        try:
            # Preprocess query
            query = self.preprocess_query(query)

            # Show retrieval status
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            # Get embeddings and search using tensor operations
            query_embedding = self.retriever.encode([query])
            similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
            scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))

            # Print search results for debugging
            print("\nSearch Results:")
            for idx, score in zip(indices.tolist(), scores.tolist()):
                print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")

            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Update status
            status.write("💭 Generating response...")

            # Prepare context and prompt
            context = "\n".join(relevant_docs[:3])  # Only use top 3 most relevant docs
            prompt = f"""Answer this question using the given context. Be specific and detailed.

Context: {context}

Question: {query}

Answer (provide a complete, detailed response):"""

            # Generate response
            response_placeholder = placeholder.empty()

            try:
                # Check HTTP status before decoding (the original called
                # .json() blindly, masking API errors as "no answer").
                api_response = requests.post(
                    API_URL,
                    headers=headers,
                    json={
                        "inputs": prompt,
                        "parameters": {
                            "max_new_tokens": 1024,
                            "temperature": 0.5,
                            "top_p": 0.9,
                            "top_k": 50,
                            "repetition_penalty": 1.03,
                            "do_sample": True
                        }
                    },
                    timeout=30
                )
                api_response.raise_for_status()
                response = api_response.json()

                if response and isinstance(response, list) and len(response) > 0:
                    generated_text = response[0].get('generated_text', '').strip()
                    if generated_text:
                        # The model echoes the prompt; keep only the answer part.
                        if "Answer:" in generated_text:
                            answer_part = generated_text.split("Answer:")[-1].strip()
                        elif "Answer (provide a complete, detailed response):" in generated_text:
                            answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
                        else:
                            answer_part = generated_text.strip()

                        # Clean up stray prompt labels the model may repeat.
                        answer_part = answer_part.replace("Context:", "").replace("Question:", "")

                        final_response = self.postprocess_response(answer_part)
                        response_placeholder.markdown(final_response)
                        return final_response

                message = "No relevant answer found. Please try rephrasing your question."
                response_placeholder.warning(message)
                return message

            except Exception as e:
                print(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            print(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
208
def check_environment():
    """Verify the API token and required ML packages are available.

    Returns:
        True when everything is present. On failure it reports the problem
        in the UI and halts the script via ``st.stop()`` (which raises, so
        the ``return False`` lines are defensive only).
    """
    # BUG FIX: the original tested `not headers['Authorization']`, which can
    # never be falsy (the header always contains at least "Bearer ...").
    # Check the actual environment variable instead.
    if not os.getenv("HF_TOKEN"):
        st.error("HF_TOKEN environment variable not set!")
        st.stop()
        return False

    try:
        import torch
        import sentence_transformers
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False
223
-
224
@st.cache_resource
def initialize_rag_pipeline():
    """Build the RAG pipeline once per Streamlit process.

    Pre-computed documents and embeddings are fetched from a pickle on
    Google Drive and swapped into the pipeline, replacing whatever the
    RAGPipeline constructor computed from the local data folder.
    """
    data_folder = "ESPN_data"
    drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"

    with st.spinner("Loading embeddings from Google Drive..."):
        cache_data = load_from_drive(drive_file_id)
        if cache_data is None:
            st.error("Failed to load embeddings from Google Drive")
            st.stop()

    # NOTE(review): RAGPipeline.__init__ embeds the local corpus only for the
    # result to be overwritten below — consider a constructor flag that skips
    # that work when cached embeddings are supplied.
    rag = RAGPipeline(data_folder)
    rag.documents = cache_data['documents']
    rag.retriever.store_embeddings(cache_data['embeddings'])
    return rag
246
-
247
def main():
    """Render the Streamlit page and run one question/answer round."""
    # Environment check — check_environment() halts the script itself on
    # failure, so this early return is mostly a safety net.
    if not check_environment():
        return

    # Page config
    st.set_page_config(
        page_title="The Sport Chatbot",
        page_icon="🏆",
        layout="wide"
    )

    # Improved CSS styling
    st.markdown("""
        <style>
        /* Container styling */
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }

        /* Text input styling */
        .stTextInput > div > div > input {
            width: 100%;
        }

        /* Button styling */
        .stButton > button {
            width: 200px;
            margin: 0 auto;
            display: block;
            background-color: #FF4B4B;
            color: white;
            border-radius: 5px;
            padding: 0.5rem 1rem;
        }

        /* Title styling */
        .main-title {
            text-align: center;
            padding: 1rem 0;
            font-size: 3rem;
            color: #1F1F1F;
        }

        .sub-title {
            text-align: center;
            padding: 0.5rem 0;
            font-size: 1.5rem;
            color: #4F4F4F;
        }

        /* Description styling */
        .description {
            text-align: center;
            color: #666666;
            padding: 0.5rem 0;
            font-size: 1.1rem;
            line-height: 1.6;
            margin-bottom: 1rem;
        }

        /* Answer container styling */
        .stMarkdown {
            max-width: 100%;
        }

        /* Streamlit default overrides */
        .st-emotion-cache-16idsys p {
            font-size: 1.1rem;
            line-height: 1.6;
        }

        /* Container for main content */
        .main-content {
            max-width: 1200px;
            margin: 0 auto;
            padding: 0 1rem;
        }
        </style>
    """, unsafe_allow_html=True)

    # Header section with improved styling
    st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
    st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
    st.markdown("""
        <p class='description'>
        Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
        With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
        </p>
        <p class='description'>
        Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
        </p>
    """, unsafe_allow_html=True)

    # Add some spacing
    st.markdown("<br>", unsafe_allow_html=True)

    # Initialize the pipeline (cached across reruns by st.cache_resource).
    try:
        with st.spinner("Loading resources..."):
            rag = initialize_rag_pipeline()
    except Exception as e:
        print(f"Initialization error: {str(e)}")
        st.error("Unable to initialize the system. Please check if all required files are present.")
        st.stop()

    # Create columns for layout with golden ratio
    col1, col2, col3 = st.columns([1, 6, 1])

    with col2:
        # Query input with label styling
        query = st.text_input("What would you like to know about sports?")

        # Centered button
        if st.button("Get Answer"):
            if query:
                # process_query renders into this placeholder and never
                # raises by design; the try is a second line of defense.
                response_placeholder = st.empty()
                try:
                    response = rag.process_query(query, response_placeholder)
                    print(f"Generated response: {response}")
                except Exception as e:
                    print(f"Query processing error: {str(e)}")
                    response_placeholder.warning("Unable to process your question. Please try again.")
            else:
                st.warning("Please enter a question!")

    # Footer with improved styling
    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("""
        <p style='text-align: center; color: #666666; padding: 1rem 0;'>
        Powered by ESPN Data & Mistral AI 🚀
        </p>
    """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()