sagar007 committed on
Commit b17a402 · verified · 1 Parent(s): 270de0e

Update app.py

Files changed (1)
  1. app.py +105 -94
app.py CHANGED
@@ -8,13 +8,14 @@ from datetime import datetime
 import os
 import subprocess
 import numpy as np
+from typing import List, Dict, Tuple, Any
 
-## Install required dependencies for Kokoro with better error handling
+# Install required dependencies for Kokoro with better error handling
 try:
     subprocess.run(['git', 'lfs', 'install'], check=True)
     if not os.path.exists('Kokoro-82M'):
         subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
-
+
     # Try installing espeak with proper package manager commands
     try:
         # Update package list first
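
Note: the `from typing import ...` line added here backs the type annotations introduced throughout the rest of this diff. Separately, the surrounding setup block shells out to git and apt on every startup; probing first would let it fail fast. A minimal sketch, assuming an `ensure_espeak` helper that is not part of this commit:

    import shutil
    import subprocess

    def ensure_espeak() -> bool:
        """Return True if espeak/espeak-ng is usable, trying apt as a fallback."""
        if shutil.which('espeak') or shutil.which('espeak-ng'):
            return True  # already installed, skip the package manager entirely
        try:
            subprocess.run(['apt-get', 'update'], check=True)
            subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
            return shutil.which('espeak') is not None
        except Exception as e:
            print(f"Warning: could not install espeak: {e}")
            return False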
@@ -32,42 +33,58 @@ except Exception as e:
     print(f"Warning: Initial setup error: {str(e)}")
     print("Continuing with limited functionality...")
 
-# Initialize models and tokenizers
+# --- Initialization (Do this ONCE) ---
 model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 
-# Move model initialization inside a function to prevent CUDA initialization in main process
-def init_models():
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="auto",
-        offload_folder="offload",
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16
-    )
-    return model
+# Initialize DeepSeek model
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    offload_folder="offload",
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.float16
+)
+
+# Initialize Kokoro TTS (with error handling)
+VOICE_CHOICES = {
+    '🇺🇸 Female (Default)': 'af',
+    '🇺🇸 Bella': 'af_bella',
+    '🇺🇸 Sarah': 'af_sarah',
+    '🇺🇸 Nicole': 'af_nicole'
+}
+TTS_ENABLED = False
+TTS_MODEL = None
+VOICEPACK = None
 
-# Initialize Kokoro TTS with better error handling
 try:
-    import sys
-    sys.path.append('Kokoro-82M')
-    from models import build_model
-    from kokoro import generate
-
-    # Don't initialize models/voices in main process for ZeroGPU compatibility
-    VOICE_CHOICES = {
-        '🇺🇸 Female (Default)': 'af',
-        '🇺🇸 Bella': 'af_bella',
-        '🇺🇸 Sarah': 'af_sarah',
-        '🇺🇸 Nicole': 'af_nicole'
-    }
-    TTS_ENABLED = True
+    if os.path.exists('Kokoro-82M'):
+        import sys
+        sys.path.append('Kokoro-82M')
+        from models import build_model  # type: ignore
+        from kokoro import generate  # type: ignore
+
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Correct device handling
+        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
+
+        # Load default voice
+        try:
+            VOICEPACK = torch.load('Kokoro-82M/voices/af.pt', map_location=device, weights_only=True)
+        except Exception as e:
+            print(f"Warning: Could not load default voice: {e}")
+            raise
+
+        TTS_ENABLED = True
+    else:
+        print("Warning: Kokoro-82M directory not found. TTS disabled.")
+
 except Exception as e:
     print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
     TTS_ENABLED = False
 
-def get_web_results(query, max_results=5): # Increased to 5 for better context
+
+def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
     """Get web search results using DuckDuckGo"""
     try:
         with DDGS() as ddgs:
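
Note: this hunk replaces the lazy `init_models()` helper with an eager, import-time `AutoModelForCausalLM.from_pretrained(...)`. The deleted comment documented why loading was deferred: on ZeroGPU Spaces, CUDA must not be initialized in the main process. If eager loading trips that restriction, a cached lazy loader preserves both load-once behavior and deferred CUDA init. A sketch under those assumptions, not code from this commit:

    from functools import lru_cache

    import torch
    from transformers import AutoModelForCausalLM

    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

    @lru_cache(maxsize=1)
    def get_model():
        """Load the LLM on first use; lru_cache makes later calls reuse it."""
        return AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )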
@@ -79,30 +96,27 @@ def get_web_results(query, max_results=5): # Increased to 5 for better context
             "date": result.get("published", "")
         } for result in results]
     except Exception as e:
+        print(f"Error in web search: {e}")
         return []
 
-def format_prompt(query, context):
+def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
     """Format the prompt with web context"""
     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
     return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
 Current Time: {current_time}
-
 Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
-
 Query: {query}
-
 Web Context:
 {context_lines}
-
 Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
 Answer:"""
 
-def format_sources(web_results):
+def format_sources(web_results: List[Dict[str, str]]) -> str:
     """Format sources with more details"""
     if not web_results:
         return "<div class='no-sources'>No sources available</div>"
-
+
     sources_html = "<div class='sources-container'>"
     for i, res in enumerate(web_results, 1):
         title = res["title"] or "Source"
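
Note: the prompt instructs the model to cite sources as [1], [2], etc., but `context_lines` renders the snippets without any indices to cite. A hypothetical numbering helper would ground that instruction; a sketch:

    def numbered_context(context):
        """Render snippets as '[1] Title: snippet' so citations have targets."""
        return '\n'.join(
            f'[{i}] {res["title"]}: {res["snippet"]}'
            for i, res in enumerate(context, 1)
        )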
@@ -120,22 +134,18 @@ def format_sources(web_results):
     sources_html += "</div>"
     return sources_html
 
-# Wrap the answer generation with spaces.GPU decorator
 @spaces.GPU(duration=30)
-def generate_answer(prompt):
+def generate_answer(prompt: str) -> str:
     """Generate answer using the DeepSeek model"""
-    # Initialize model inside the GPU-decorated function
-    model = init_models()
-
     inputs = tokenizer(
-        prompt,
-        return_tensors="pt",
+        prompt,
+        return_tensors="pt",
         padding=True,
         truncation=True,
         max_length=512,
         return_attention_mask=True
     ).to(model.device)
-
+
     outputs = model.generate(
         inputs.input_ids,
         attention_mask=inputs.attention_mask,
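
Note: with `init_models()` gone, `generate_answer` now reads the module-level `model`. Decoding `outputs[0]` returns the prompt plus the completion, which is why the caller later splits on "Answer:". Slicing off the prompt tokens is the usual alternative; a sketch with an assumed helper name:

    def decode_new_tokens(tokenizer, inputs, outputs):
        """Decode only the tokens generated after the prompt."""
        prompt_len = inputs.input_ids.shape[1]  # number of prompt tokens
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)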
@@ -148,74 +158,75 @@ def generate_answer(prompt):
     )
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# Similarly wrap TTS generation with spaces.GPU
 @spaces.GPU(duration=60)
-def generate_speech_with_gpu(text, voice_name='af'):
-    """Generate speech from text using Kokoro TTS model with GPU handling"""
+def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model=TTS_MODEL, voicepack=VOICEPACK) -> Tuple[int, np.ndarray] | None:
+    """Generate speech from text using Kokoro TTS model."""
+
+    if not TTS_ENABLED or tts_model is None:
+        print("TTS is not enabled or model is not loaded.")
+        return None
+
     try:
-        # Initialize TTS model and voice inside GPU function
-        device = 'cuda'
-        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
-        VOICEPACK = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', weights_only=True).to(device)
-
+        # Load voicepack if it hasn't been loaded or if a different voice is requested
+        if voice_name != 'af' or voicepack is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            voicepack = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', map_location=device, weights_only=True)
+
+
         # Clean the text
         clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
         clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
-
-        # Split long text into chunks
+
+        # Split long text into chunks (improved logic)
        max_chars = 1000
         chunks = []
-
         if len(clean_text) > max_chars:
             sentences = clean_text.split('.')
             current_chunk = ""
-
             for sentence in sentences:
-                if len(current_chunk) + len(sentence) < max_chars:
+                if len(current_chunk) + len(sentence) + 1 < max_chars:  # +1 for the dot
                     current_chunk += sentence + "."
                 else:
-                    if current_chunk:
-                        chunks.append(current_chunk)
+                    chunks.append(current_chunk.strip())
                     current_chunk = sentence + "."
-            if current_chunk:
-                chunks.append(current_chunk)
+            if current_chunk:  # Add the last chunk
+                chunks.append(current_chunk.strip())
         else:
             chunks = [clean_text]
-
+
+
         # Generate audio for each chunk
         audio_chunks = []
         for chunk in chunks:
             if chunk.strip():  # Only process non-empty chunks
-                chunk_audio, _ = generate(TTS_MODEL, chunk.strip(), VOICEPACK, lang='a')
+                chunk_audio, _ = generate(tts_model, chunk, voicepack, lang='a')
                 if isinstance(chunk_audio, torch.Tensor):
                     chunk_audio = chunk_audio.cpu().numpy()
                 audio_chunks.append(chunk_audio)
-
-        # Concatenate chunks if we have any
+
+        # Concatenate chunks
         if audio_chunks:
-            if len(audio_chunks) > 1:
-                final_audio = np.concatenate(audio_chunks)
-            else:
-                final_audio = audio_chunks[0]
+            final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
             return (24000, final_audio)
-        return None
-
+        else:
+            return None
+
+
     except Exception as e:
         print(f"Error generating speech: {str(e)}")
         import traceback
         traceback.print_exc()
         return None
-
-def process_query(query, history, selected_voice='af'):
+def process_query(query: str, history: List[List[str]], selected_voice: str = 'af') -> Dict[str, Any]:
     """Process user query with streaming effect"""
     try:
         if history is None:
             history = []
-
+
         # Get web results first
         web_results = get_web_results(query)
         sources_html = format_sources(web_results)
-
+
         current_history = history + [[query, "*Searching...*"]]
         yield {
             answer_output: gr.Markdown("*Searching & Thinking...*"),
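
Note: the new signature binds `tts_model=TTS_MODEL, voicepack=VOICEPACK` as defaults. Python evaluates default values once, when `def` executes, so these capture whatever the globals held at import time and will not track later reassignment. A standalone illustration of the pitfall and the conventional None-default fix (not code from this commit):

    MODEL = None

    def bad(model=MODEL):
        """Default captured at definition time: stays None forever."""
        return model

    def good(model=None):
        """None-default resolved at call time: sees the current global."""
        return model if model is not None else MODEL

    MODEL = "loaded"
    assert bad() is None
    assert good() == "loaded"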
@@ -224,48 +235,48 @@ def process_query(query, history, selected_voice='af'):
             chat_history_display: current_history,
             audio_output: None
         }
-
+
         # Generate answer
         prompt = format_prompt(query, web_results)
         answer = generate_answer(prompt)
         final_answer = answer.split("Answer:")[-1].strip()
-
-        # Generate speech from the answer
+
+        # Update history *before* TTS (important for correct display)
+        updated_history = history + [[query, final_answer]]
+
+        # Generate speech from the answer (only if enabled)
         if TTS_ENABLED:
+            yield {  # Intermediate update before TTS
+                answer_output: gr.Markdown(final_answer),
+                sources_output: gr.HTML(sources_html),
+                search_btn: gr.Button("Generating audio...", interactive=False),
+                chat_history_display: updated_history,
+                audio_output: None
+            }
             try:
-                yield {
-                    answer_output: gr.Markdown(final_answer),
-                    sources_output: gr.HTML(sources_html),
-                    search_btn: gr.Button("Generating audio...", interactive=False),
-                    chat_history_display: history + [[query, final_answer]],
-                    audio_output: None
-                }
-
                 audio = generate_speech_with_gpu(final_answer, selected_voice)
-                if audio is None:
-                    print("Failed to generate audio")
             except Exception as e:
-                print(f"Error in speech generation: {str(e)}")
+                print(f"Error during TTS: {e}")
                 audio = None
         else:
             audio = None
-
-        updated_history = history + [[query, final_answer]]
+
         yield {
             answer_output: gr.Markdown(final_answer),
             sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
             chat_history_display: updated_history,
-            audio_output: audio if audio is not None else gr.Audio(value=None)
+            audio_output: audio if audio is not None else gr.Audio(value=None)  # Ensure valid audio output
         }
+
     except Exception as e:
         error_message = str(e)
         if "GPU quota" in error_message:
-            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
-
+            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
+
         yield {
             answer_output: gr.Markdown(f"Error: {error_message}"),
-            sources_output: gr.HTML(sources_html),
+            sources_output: gr.HTML(sources_html),  # Still show sources on error
             search_btn: gr.Button("Search", interactive=True),
             chat_history_display: history + [[query, f"*Error: {error_message}*"]],
             audio_output: None