tommytracx commited on
Commit
db8e1eb
·
verified ·
1 Parent(s): 9f5d5d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -44
app.py CHANGED
@@ -1,7 +1,10 @@
 
 
 
1
  import gradio as gr
2
  import fastapi
3
  from fastapi.staticfiles import StaticFiles
4
- from fastapi.responses import HTMLResponse, FileResponse
5
  from fastapi import FastAPI, Request, Form, UploadFile, File
6
  import os
7
  import time
@@ -10,6 +13,13 @@ import json
10
  import shutil
11
  import uvicorn
12
  from pathlib import Path
 
 
 
 
 
 
 
13
 
14
  # Setup logging
15
  logging.basicConfig(level=logging.INFO)
@@ -31,47 +41,19 @@ if html_template.exists() and not static_html.exists():
31
  # Mount static files
32
  app.mount("/static", StaticFiles(directory="static"), name="static")
33
 
34
- # Mock data and functions to simulate the real implementation
35
- SESSIONS = {}
36
-
37
- def generate_session_id():
38
- """Generate a unique session ID."""
39
- import uuid
40
- return str(uuid.uuid4())
41
-
42
  def mock_transcribe(audio_bytes):
43
  """Mock function to simulate speech-to-text."""
44
- # In production, this would use Whisper
45
  logger.info("Transcribing audio...")
46
- time.sleep(1) # Simulate processing time
47
  return "This is a mock transcription of the audio."
48
 
49
- def mock_agent_response(text, session_id="default"):
50
- """Mock function to simulate agent reasoning."""
51
- # In production, this would use a real LLM
52
- logger.info(f"Processing query: {text}")
53
- time.sleep(1.5) # Simulate processing time
54
-
55
- # Simple keyword-based responses
56
- if "5g" in text.lower():
57
- return "5G is the fifth generation of cellular networks, offering higher speeds, lower latency, and more capacity than previous generations."
58
- elif "telecom" in text.lower():
59
- return "Telecommunications (telecom) refers to the exchange of information over significant distances by electronic means."
60
- elif "webrtc" in text.lower():
61
- return "WebRTC (Web Real-Time Communication) is a free, open-source project that enables web browsers and mobile applications to have real-time communication via simple APIs."
62
- else:
63
- return "I'm an AI assistant specialized in telecom topics. Feel free to ask me about 5G, network technologies, or telecommunications in general."
64
-
65
  def mock_synthesize_speech(text):
66
  """Mock function to simulate text-to-speech."""
67
- # In production, this would use a real TTS engine
68
  logger.info("Synthesizing speech...")
69
  time.sleep(0.5) # Simulate processing time
70
 
71
  # Create a dummy audio file
72
- import numpy as np
73
- from scipy.io.wavfile import write
74
-
75
  sample_rate = 22050
76
  duration = 2 # seconds
77
  t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
@@ -83,9 +65,6 @@ def mock_synthesize_speech(text):
83
  with open(output_file, "rb") as f:
84
  audio_bytes = f.read()
85
 
86
- # Clean up
87
- os.remove(output_file)
88
-
89
  return audio_bytes
90
 
91
  # Routes for the API
@@ -94,6 +73,15 @@ async def root():
94
  """Serve the main UI."""
95
  return FileResponse("static/index.html")
96
 
 
 
 
 
 
 
 
 
 
97
  @app.post("/api/transcribe")
98
  async def transcribe(file: UploadFile = File(...)):
99
  """Transcribe audio to text."""
@@ -103,17 +91,24 @@ async def transcribe(file: UploadFile = File(...)):
103
  return {"transcription": text}
104
  except Exception as e:
105
  logger.error(f"Transcription error: {str(e)}")
106
- return {"error": f"Failed to transcribe audio: {str(e)}"}
 
 
 
107
 
108
  @app.post("/api/query")
109
  async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
110
  """Process a text query with the agent."""
111
  try:
112
- response = mock_agent_response(input_text, session_id)
 
113
  return {"response": response}
114
  except Exception as e:
115
  logger.error(f"Query error: {str(e)}")
116
- return {"error": f"Failed to process query: {str(e)}"}
 
 
 
117
 
118
  @app.post("/api/speak")
119
  async def speak(text: str = Form(...)):
@@ -127,19 +122,135 @@ async def speak(text: str = Form(...)):
127
  )
128
  except Exception as e:
129
  logger.error(f"Speech synthesis error: {str(e)}")
130
- return {"error": f"Failed to synthesize speech: {str(e)}"}
 
 
 
131
 
132
  @app.post("/api/session")
133
  async def create_session():
134
  """Create a new session."""
135
- session_id = generate_session_id()
136
- SESSIONS[session_id] = {"created_at": time.time()}
 
137
  return {"session_id": session_id}
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # Gradio interface
140
  with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
141
  gr.Markdown("# AGI Telecom POC Demo")
142
- gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept. The full interface is available via the direct API.")
143
 
144
  with gr.Row():
145
  with gr.Column():
@@ -165,10 +276,13 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
165
 
166
  # Status and info
167
  status_output = gr.Textbox(label="Status", value="Ready")
 
168
 
169
  # Link components with functions
170
  def update_session():
171
- new_id = generate_session_id()
 
 
172
  status = f"Created new session: {new_id}"
173
  return new_id, status
174
 
@@ -189,7 +303,7 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
189
  text = mock_transcribe(audio_bytes)
190
 
191
  # Get response
192
- response = mock_agent_response(text, session)
193
 
194
  # Synthesize
195
  audio_bytes = mock_synthesize_speech(response)
@@ -210,7 +324,7 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
210
  )
211
 
212
  query_btn.click(
213
- lambda text, session: mock_agent_response(text, session),
214
  inputs=[text_input, session_id],
215
  outputs=[response_output]
216
  )
@@ -227,6 +341,16 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
227
  inputs=[audio_input, session_id],
228
  outputs=[transcription_output, response_output, audio_output, status_output]
229
  )
 
 
 
 
 
 
 
 
 
 
230
 
231
  # Mount Gradio app
232
  app = gr.mount_gradio_app(app, interface, path="/gradio")
 
1
+ """
2
+ Main FastAPI application integrating all components with Hugging Face Inference Endpoint.
3
+ """
4
  import gradio as gr
5
  import fastapi
6
  from fastapi.staticfiles import StaticFiles
7
+ from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
8
  from fastapi import FastAPI, Request, Form, UploadFile, File
9
  import os
10
  import time
 
13
  import shutil
14
  import uvicorn
15
  from pathlib import Path
16
+ from typing import Dict, List, Optional, Any
17
+ import io
18
+ import numpy as np
19
+ from scipy.io.wavfile import write
20
+
21
+ # Import our modules
22
+ from local_llm import run_llm, run_llm_with_memory, clear_memory, get_memory_sessions, get_model_info, test_endpoint
23
 
24
  # Setup logging
25
  logging.basicConfig(level=logging.INFO)
 
41
  # Mount static files
42
  app.mount("/static", StaticFiles(directory="static"), name="static")
43
 
44
+ # Helper functions for mock implementations
 
 
 
 
 
 
 
45
  def mock_transcribe(audio_bytes):
46
  """Mock function to simulate speech-to-text."""
 
47
  logger.info("Transcribing audio...")
48
+ time.sleep(0.5) # Simulate processing time
49
  return "This is a mock transcription of the audio."
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def mock_synthesize_speech(text):
52
  """Mock function to simulate text-to-speech."""
 
53
  logger.info("Synthesizing speech...")
54
  time.sleep(0.5) # Simulate processing time
55
 
56
  # Create a dummy audio file
 
 
 
57
  sample_rate = 22050
58
  duration = 2 # seconds
59
  t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
 
65
  with open(output_file, "rb") as f:
66
  audio_bytes = f.read()
67
 
 
 
 
68
  return audio_bytes
69
 
70
  # Routes for the API
 
73
  """Serve the main UI."""
74
  return FileResponse("static/index.html")
75
 
76
+ @app.get("/health")
77
+ async def health_check():
78
+ """Health check endpoint."""
79
+ endpoint_status = test_endpoint()
80
+ return {
81
+ "status": "ok",
82
+ "endpoint": endpoint_status
83
+ }
84
+
85
  @app.post("/api/transcribe")
86
  async def transcribe(file: UploadFile = File(...)):
87
  """Transcribe audio to text."""
 
91
  return {"transcription": text}
92
  except Exception as e:
93
  logger.error(f"Transcription error: {str(e)}")
94
+ return JSONResponse(
95
+ status_code=500,
96
+ content={"error": f"Failed to transcribe audio: {str(e)}"}
97
+ )
98
 
99
  @app.post("/api/query")
100
  async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
101
  """Process a text query with the agent."""
102
  try:
103
+ response = run_llm_with_memory(input_text, session_id=session_id)
104
+ logger.info(f"Query: {input_text[:30]}... Response: {response[:30]}...")
105
  return {"response": response}
106
  except Exception as e:
107
  logger.error(f"Query error: {str(e)}")
108
+ return JSONResponse(
109
+ status_code=500,
110
+ content={"error": f"Failed to process query: {str(e)}"}
111
+ )
112
 
113
  @app.post("/api/speak")
114
  async def speak(text: str = Form(...)):
 
122
  )
123
  except Exception as e:
124
  logger.error(f"Speech synthesis error: {str(e)}")
125
+ return JSONResponse(
126
+ status_code=500,
127
+ content={"error": f"Failed to synthesize speech: {str(e)}"}
128
+ )
129
 
130
  @app.post("/api/session")
131
  async def create_session():
132
  """Create a new session."""
133
+ import uuid
134
+ session_id = str(uuid.uuid4())
135
+ clear_memory(session_id)
136
  return {"session_id": session_id}
137
 
138
+ @app.delete("/api/session/{session_id}")
139
+ async def delete_session(session_id: str):
140
+ """Delete a session."""
141
+ success = clear_memory(session_id)
142
+ if success:
143
+ return {"message": f"Session {session_id} cleared"}
144
+ else:
145
+ return JSONResponse(
146
+ status_code=404,
147
+ content={"error": f"Session {session_id} not found"}
148
+ )
149
+
150
+ @app.get("/api/sessions")
151
+ async def list_sessions():
152
+ """List all active sessions."""
153
+ return {"sessions": get_memory_sessions()}
154
+
155
+ @app.get("/api/model_info")
156
+ async def model_info():
157
+ """Get information about the model."""
158
+ return get_model_info()
159
+
160
+ @app.post("/api/complete")
161
+ async def complete_flow(
162
+ request: Request,
163
+ audio_file: UploadFile = File(None),
164
+ text_input: str = Form(None),
165
+ session_id: str = Form("default")
166
+ ):
167
+ """
168
+ Complete flow: audio to text to agent to speech.
169
+ """
170
+ try:
171
+ # If audio file provided, transcribe it
172
+ if audio_file:
173
+ audio_bytes = await audio_file.read()
174
+ text_input = mock_transcribe(audio_bytes)
175
+ logger.info(f"Transcribed input: {text_input[:30]}...")
176
+
177
+ # Process with agent
178
+ if not text_input:
179
+ return JSONResponse(
180
+ status_code=400,
181
+ content={"error": "No input provided"}
182
+ )
183
+
184
+ response = run_llm_with_memory(text_input, session_id=session_id)
185
+ logger.info(f"Agent response: {response[:30]}...")
186
+
187
+ # Synthesize speech
188
+ audio_bytes = mock_synthesize_speech(response)
189
+
190
+ # Save audio to a temporary file
191
+ import tempfile
192
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
193
+ temp_file.write(audio_bytes)
194
+ temp_file.close()
195
+
196
+ # Generate URL for audio
197
+ host = request.headers.get("host", "localhost")
198
+ scheme = request.headers.get("x-forwarded-proto", "http")
199
+ audio_url = f"{scheme}://{host}/audio/{os.path.basename(temp_file.name)}"
200
+
201
+ return {
202
+ "input": text_input,
203
+ "response": response,
204
+ "audio_url": audio_url
205
+ }
206
+
207
+ except Exception as e:
208
+ logger.error(f"Complete flow error: {str(e)}")
209
+ return JSONResponse(
210
+ status_code=500,
211
+ content={"error": f"Failed to process: {str(e)}"}
212
+ )
213
+
214
+ @app.get("/audio/{filename}")
215
+ async def get_audio(filename: str):
216
+ """
217
+ Serve temporary audio files.
218
+ """
219
+ try:
220
+ # Ensure filename only contains safe characters
221
+ import re
222
+ if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
223
+ return JSONResponse(
224
+ status_code=400,
225
+ content={"error": "Invalid filename"}
226
+ )
227
+
228
+ temp_dir = tempfile.gettempdir()
229
+ file_path = os.path.join(temp_dir, filename)
230
+
231
+ if not os.path.exists(file_path):
232
+ return JSONResponse(
233
+ status_code=404,
234
+ content={"error": "File not found"}
235
+ )
236
+
237
+ return FileResponse(
238
+ file_path,
239
+ media_type="audio/wav",
240
+ filename=filename
241
+ )
242
+
243
+ except Exception as e:
244
+ logger.error(f"Audio serving error: {str(e)}")
245
+ return JSONResponse(
246
+ status_code=500,
247
+ content={"error": f"Failed to serve audio: {str(e)}"}
248
+ )
249
+
250
  # Gradio interface
251
  with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
252
  gr.Markdown("# AGI Telecom POC Demo")
253
+ gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept using a Hugging Face Inference Endpoint.")
254
 
255
  with gr.Row():
256
  with gr.Column():
 
276
 
277
  # Status and info
278
  status_output = gr.Textbox(label="Status", value="Ready")
279
+ endpoint_status = gr.Textbox(label="Endpoint Status", value="Checking endpoint connection...")
280
 
281
  # Link components with functions
282
  def update_session():
283
+ import uuid
284
+ new_id = str(uuid.uuid4())
285
+ clear_memory(new_id)
286
  status = f"Created new session: {new_id}"
287
  return new_id, status
288
 
 
303
  text = mock_transcribe(audio_bytes)
304
 
305
  # Get response
306
+ response = run_llm_with_memory(text, session)
307
 
308
  # Synthesize
309
  audio_bytes = mock_synthesize_speech(response)
 
324
  )
325
 
326
  query_btn.click(
327
+ lambda text, session: run_llm_with_memory(text, session),
328
  inputs=[text_input, session_id],
329
  outputs=[response_output]
330
  )
 
341
  inputs=[audio_input, session_id],
342
  outputs=[transcription_output, response_output, audio_output, status_output]
343
  )
344
+
345
+ # Check endpoint on load
346
+ def check_endpoint():
347
+ status = test_endpoint()
348
+ if status["status"] == "connected":
349
+ return f"✅ Connected to endpoint: {status['message']}"
350
+ else:
351
+ return f"❌ Error connecting to endpoint: {status['message']}"
352
+
353
+ gr.on_load(lambda: gr.update(value=check_endpoint()), outputs=endpoint_status)
354
 
355
  # Mount Gradio app
356
  app = gr.mount_gradio_app(app, interface, path="/gradio")