nishantgaurav23 commited on
Commit
d3429ce
·
verified ·
1 Parent(s): f5d1367

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -385
app.py DELETED
@@ -1,385 +0,0 @@
1
- import os
2
- import warnings
3
- warnings.filterwarnings("ignore", category=UserWarning)
4
-
5
- import streamlit as st
6
- import torch
7
- import torch.nn.functional as F
8
- import re
9
- import requests
10
- #from dotenv import load_dotenv
11
- from embedding_processor import SentenceTransformerRetriever, process_data
12
- import pickle
13
-
14
- import os
15
- import warnings
16
- import json # Add this import
17
-
18
-
19
-
20
- # Load environment variables
21
- #load_dotenv()
22
-
23
- # Add the new function here, right after imports and before API configuration
24
- @st.cache_data
25
- @st.cache_data
26
- def load_from_drive(file_id: str):
27
- """Load pickle file directly from Google Drive"""
28
- try:
29
- # Direct download URL for Google Drive
30
- url = f"https://drive.google.com/uc?id={file_id}&export=download"
31
-
32
- # First request to get the confirmation token
33
- session = requests.Session()
34
- response = session.get(url, stream=True)
35
-
36
- # Check if we need to confirm download
37
- for key, value in response.cookies.items():
38
- if key.startswith('download_warning'):
39
- # Add confirmation parameter to the URL
40
- url = f"{url}&confirm={value}"
41
- response = session.get(url, stream=True)
42
- break
43
-
44
- # Load the content and convert to pickle
45
- content = response.content
46
- print(f"Successfully downloaded {len(content)} bytes")
47
- return pickle.loads(content)
48
-
49
- except Exception as e:
50
- print(f"Detailed error: {str(e)}") # This will help debug
51
- st.error(f"Error loading file from Drive: {str(e)}")
52
- return None
53
-
54
- # Hugging Face API configuration
55
-
56
- # API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
57
- # headers = {"Authorization": f"Bearer HF_TOKEN"}
58
- model_name = 'mistralai/Mistral-7B-v0.1'
59
-
60
-
61
- class RAGPipeline:
62
-
63
- def __init__(self, data_folder: str, k: int = 3): # Reduced k for faster retrieval
64
- self.data_folder = data_folder
65
- self.k = k
66
- self.retriever = SentenceTransformerRetriever()
67
- cache_data = process_data(data_folder)
68
- self.documents = cache_data['documents']
69
- self.retriever.store_embeddings(cache_data['embeddings'])
70
-
71
-
72
- # Alternative API call with streaming
73
- def query_model(self, payload):
74
- """Query the Hugging Face API with streaming"""
75
- try:
76
- # Add streaming parameters
77
- payload["parameters"]["stream"] = True
78
-
79
- response = requests.post(
80
- model_name,
81
- #headers=headers,
82
- json=payload,
83
- stream=True
84
- )
85
- response.raise_for_status()
86
-
87
- # Collect the entire response
88
- full_response = ""
89
- for line in response.iter_lines():
90
- if line:
91
- try:
92
- json_response = json.loads(line)
93
- if isinstance(json_response, list) and len(json_response) > 0:
94
- chunk_text = json_response[0].get('generated_text', '')
95
- if chunk_text:
96
- full_response += chunk_text
97
- except json.JSONDecodeError as e:
98
- print(f"Error decoding JSON: {e}")
99
- continue
100
-
101
- return [{"generated_text": full_response}]
102
-
103
- except requests.exceptions.RequestException as e:
104
- print(f"API request failed: {str(e)}")
105
- raise
106
-
107
- def preprocess_query(self, query: str) -> str:
108
- """Clean and prepare the query"""
109
- query = query.lower().strip()
110
- query = re.sub(r'\s+', ' ', query)
111
- return query
112
-
113
- def postprocess_response(self, response: str) -> str:
114
- """Clean up the generated response"""
115
- response = response.strip()
116
- response = re.sub(r'\s+', ' ', response)
117
- response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
118
- return response
119
-
120
-
121
- def process_query(self, query: str, placeholder) -> str:
122
- try:
123
- # Preprocess query
124
- query = self.preprocess_query(query)
125
-
126
- # Show retrieval status
127
- status = placeholder.empty()
128
- status.write("🔍 Finding relevant information...")
129
-
130
- # Get embeddings and search using tensor operations
131
- query_embedding = self.retriever.encode([query])
132
- similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
133
- scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
134
-
135
- # Print search results for debugging
136
- print("\nSearch Results:")
137
- for idx, score in zip(indices.tolist(), scores.tolist()):
138
- print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
139
-
140
- relevant_docs = [self.documents[idx] for idx in indices.tolist()]
141
-
142
- # Update status
143
- status.write("💭 Generating response...")
144
-
145
- # Prepare context and prompt
146
- context = "\n".join(relevant_docs[:3]) # Only use top 3 most relevant docs
147
- prompt = f"""Answer this question using the given context. Be specific and detailed.
148
-
149
- Context: {context}
150
-
151
- Question: {query}
152
-
153
- Answer (provide a complete, detailed response):"""
154
-
155
- # Generate response
156
- response_placeholder = placeholder.empty()
157
-
158
- try:
159
- response = requests.post(
160
- model_name,
161
- #headers=headers,
162
- json={
163
- "inputs": prompt,
164
- "parameters": {
165
- "max_new_tokens": 1024,
166
- "temperature": 0.5,
167
- "top_p": 0.9,
168
- "top_k": 50,
169
- "repetition_penalty": 1.03,
170
- "do_sample": True
171
- }
172
- },
173
- timeout=30
174
- ).json()
175
-
176
- if response and isinstance(response, list) and len(response) > 0:
177
- generated_text = response[0].get('generated_text', '').strip()
178
- if generated_text:
179
- # Find and extract only the answer part
180
- if "Answer:" in generated_text:
181
- answer_part = generated_text.split("Answer:")[-1].strip()
182
- elif "Answer (provide a complete, detailed response):" in generated_text:
183
- answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
184
- else:
185
- answer_part = generated_text.strip()
186
-
187
- # Clean up the answer
188
- answer_part = answer_part.replace("Context:", "").replace("Question:", "")
189
-
190
- final_response = self.postprocess_response(answer_part)
191
- response_placeholder.markdown(final_response)
192
- return final_response
193
-
194
- message = "No relevant answer found. Please try rephrasing your question."
195
- response_placeholder.warning(message)
196
- return message
197
-
198
- except Exception as e:
199
- print(f"Generation error: {str(e)}")
200
- message = "Had some trouble generating the response. Please try again."
201
- response_placeholder.warning(message)
202
- return message
203
-
204
- except Exception as e:
205
- print(f"Process error: {str(e)}")
206
- message = "Something went wrong. Please try again with a different question."
207
- placeholder.warning(message)
208
- return message
209
- def check_environment():
210
- """Check if the environment is properly set up"""
211
- # if not headers['Authorization']:
212
- # st.error("HUGGINGFACE_API_KEY environment variable not set!")
213
- # st.stop()
214
- # return False
215
-
216
- try:
217
- import torch
218
- import sentence_transformers
219
- return True
220
- except ImportError as e:
221
- st.error(f"Missing required package: {str(e)}")
222
- st.stop()
223
- return False
224
-
225
- # @st.cache_resource
226
- # def initialize_rag_pipeline():
227
- # """Initialize the RAG pipeline once"""
228
- # data_folder = "ESPN_data"
229
- # return RAGPipeline(data_folder)
230
-
231
- @st.cache_resource
232
- def initialize_rag_pipeline():
233
- """Initialize the RAG pipeline once"""
234
- data_folder = "ESPN_data"
235
- drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
236
-
237
- with st.spinner("Loading embeddings from Google Drive..."):
238
- cache_data = load_from_drive(drive_file_id)
239
- if cache_data is None:
240
- st.error("Failed to load embeddings from Google Drive")
241
- st.stop()
242
-
243
- rag = RAGPipeline(data_folder)
244
- rag.documents = cache_data['documents']
245
- rag.retriever.store_embeddings(cache_data['embeddings'])
246
- return rag
247
-
248
- def main():
249
- # Environment check
250
- if not check_environment():
251
- return
252
-
253
- # Page config
254
- st.set_page_config(
255
- page_title="The Sport Chatbot",
256
- page_icon="🏆",
257
- layout="wide"
258
- )
259
-
260
- # Improved CSS styling
261
- st.markdown("""
262
- <style>
263
- /* Container styling */
264
- .block-container {
265
- padding-top: 2rem;
266
- padding-bottom: 2rem;
267
- }
268
-
269
- /* Text input styling */
270
- .stTextInput > div > div > input {
271
- width: 100%;
272
- }
273
-
274
- /* Button styling */
275
- .stButton > button {
276
- width: 200px;
277
- margin: 0 auto;
278
- display: block;
279
- background-color: #FF4B4B;
280
- color: white;
281
- border-radius: 5px;
282
- padding: 0.5rem 1rem;
283
- }
284
-
285
- /* Title styling */
286
- .main-title {
287
- text-align: center;
288
- padding: 1rem 0;
289
- font-size: 3rem;
290
- color: #1F1F1F;
291
- }
292
-
293
- .sub-title {
294
- text-align: center;
295
- padding: 0.5rem 0;
296
- font-size: 1.5rem;
297
- color: #4F4F4F;
298
- }
299
-
300
- /* Description styling */
301
- .description {
302
- text-align: center;
303
- color: #666666;
304
- padding: 0.5rem 0;
305
- font-size: 1.1rem;
306
- line-height: 1.6;
307
- margin-bottom: 1rem;
308
- }
309
-
310
- /* Answer container styling */
311
- .stMarkdown {
312
- max-width: 100%;
313
- }
314
-
315
- /* Streamlit default overrides */
316
- .st-emotion-cache-16idsys p {
317
- font-size: 1.1rem;
318
- line-height: 1.6;
319
- }
320
-
321
- /* Container for main content */
322
- .main-content {
323
- max-width: 1200px;
324
- margin: 0 auto;
325
- padding: 0 1rem;
326
- }
327
- </style>
328
- """, unsafe_allow_html=True)
329
-
330
- # Header section with improved styling
331
- st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
332
- st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
333
- st.markdown("""
334
- <p class='description'>
335
- Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
336
- With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
337
- </p>
338
- <p class='description'>
339
- Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
340
- </p>
341
- """, unsafe_allow_html=True)
342
-
343
- # Add some spacing
344
- st.markdown("<br>", unsafe_allow_html=True)
345
-
346
- # Initialize the pipeline
347
- try:
348
- with st.spinner("Loading resources..."):
349
- rag = initialize_rag_pipeline()
350
- except Exception as e:
351
- print(f"Initialization error: {str(e)}")
352
- st.error("Unable to initialize the system. Please check if all required files are present.")
353
- st.stop()
354
-
355
- # Create columns for layout with golden ratio
356
- col1, col2, col3 = st.columns([1, 6, 1])
357
-
358
- with col2:
359
- # Query input with label styling
360
- query = st.text_input("What would you like to know about sports?")
361
-
362
- # Centered button
363
- if st.button("Get Answer"):
364
- if query:
365
- response_placeholder = st.empty()
366
- try:
367
- response = rag.process_query(query, response_placeholder)
368
- print(f"Generated response: {response}")
369
- except Exception as e:
370
- print(f"Query processing error: {str(e)}")
371
- response_placeholder.warning("Unable to process your question. Please try again.")
372
- else:
373
- st.warning("Please enter a question!")
374
-
375
- # Footer with improved styling
376
- st.markdown("<br><br>", unsafe_allow_html=True)
377
- st.markdown("---")
378
- st.markdown("""
379
- <p style='text-align: center; color: #666666; padding: 1rem 0;'>
380
- Powered by ESPN Data & Mistral AI 🚀
381
- </p>
382
- """, unsafe_allow_html=True)
383
-
384
- if __name__ == "__main__":
385
- main()