Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,10 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import fastapi
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
-
from fastapi.responses import HTMLResponse, FileResponse
|
5 |
from fastapi import FastAPI, Request, Form, UploadFile, File
|
6 |
import os
|
7 |
import time
|
@@ -10,6 +13,13 @@ import json
|
|
10 |
import shutil
|
11 |
import uvicorn
|
12 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Setup logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
@@ -31,47 +41,19 @@ if html_template.exists() and not static_html.exists():
|
|
31 |
# Mount static files
|
32 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
33 |
|
34 |
-
#
|
35 |
-
SESSIONS = {}
|
36 |
-
|
37 |
-
def generate_session_id():
|
38 |
-
"""Generate a unique session ID."""
|
39 |
-
import uuid
|
40 |
-
return str(uuid.uuid4())
|
41 |
-
|
42 |
def mock_transcribe(audio_bytes):
|
43 |
"""Mock function to simulate speech-to-text."""
|
44 |
-
# In production, this would use Whisper
|
45 |
logger.info("Transcribing audio...")
|
46 |
-
time.sleep(
|
47 |
return "This is a mock transcription of the audio."
|
48 |
|
49 |
-
def mock_agent_response(text, session_id="default"):
|
50 |
-
"""Mock function to simulate agent reasoning."""
|
51 |
-
# In production, this would use a real LLM
|
52 |
-
logger.info(f"Processing query: {text}")
|
53 |
-
time.sleep(1.5) # Simulate processing time
|
54 |
-
|
55 |
-
# Simple keyword-based responses
|
56 |
-
if "5g" in text.lower():
|
57 |
-
return "5G is the fifth generation of cellular networks, offering higher speeds, lower latency, and more capacity than previous generations."
|
58 |
-
elif "telecom" in text.lower():
|
59 |
-
return "Telecommunications (telecom) refers to the exchange of information over significant distances by electronic means."
|
60 |
-
elif "webrtc" in text.lower():
|
61 |
-
return "WebRTC (Web Real-Time Communication) is a free, open-source project that enables web browsers and mobile applications to have real-time communication via simple APIs."
|
62 |
-
else:
|
63 |
-
return "I'm an AI assistant specialized in telecom topics. Feel free to ask me about 5G, network technologies, or telecommunications in general."
|
64 |
-
|
65 |
def mock_synthesize_speech(text):
|
66 |
"""Mock function to simulate text-to-speech."""
|
67 |
-
# In production, this would use a real TTS engine
|
68 |
logger.info("Synthesizing speech...")
|
69 |
time.sleep(0.5) # Simulate processing time
|
70 |
|
71 |
# Create a dummy audio file
|
72 |
-
import numpy as np
|
73 |
-
from scipy.io.wavfile import write
|
74 |
-
|
75 |
sample_rate = 22050
|
76 |
duration = 2 # seconds
|
77 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
@@ -83,9 +65,6 @@ def mock_synthesize_speech(text):
|
|
83 |
with open(output_file, "rb") as f:
|
84 |
audio_bytes = f.read()
|
85 |
|
86 |
-
# Clean up
|
87 |
-
os.remove(output_file)
|
88 |
-
|
89 |
return audio_bytes
|
90 |
|
91 |
# Routes for the API
|
@@ -94,6 +73,15 @@ async def root():
|
|
94 |
"""Serve the main UI."""
|
95 |
return FileResponse("static/index.html")
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
@app.post("/api/transcribe")
|
98 |
async def transcribe(file: UploadFile = File(...)):
|
99 |
"""Transcribe audio to text."""
|
@@ -103,17 +91,24 @@ async def transcribe(file: UploadFile = File(...)):
|
|
103 |
return {"transcription": text}
|
104 |
except Exception as e:
|
105 |
logger.error(f"Transcription error: {str(e)}")
|
106 |
-
return
|
|
|
|
|
|
|
107 |
|
108 |
@app.post("/api/query")
|
109 |
async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
|
110 |
"""Process a text query with the agent."""
|
111 |
try:
|
112 |
-
response =
|
|
|
113 |
return {"response": response}
|
114 |
except Exception as e:
|
115 |
logger.error(f"Query error: {str(e)}")
|
116 |
-
return
|
|
|
|
|
|
|
117 |
|
118 |
@app.post("/api/speak")
|
119 |
async def speak(text: str = Form(...)):
|
@@ -127,19 +122,135 @@ async def speak(text: str = Form(...)):
|
|
127 |
)
|
128 |
except Exception as e:
|
129 |
logger.error(f"Speech synthesis error: {str(e)}")
|
130 |
-
return
|
|
|
|
|
|
|
131 |
|
132 |
@app.post("/api/session")
|
133 |
async def create_session():
|
134 |
"""Create a new session."""
|
135 |
-
|
136 |
-
|
|
|
137 |
return {"session_id": session_id}
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
# Gradio interface
|
140 |
with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
|
141 |
gr.Markdown("# AGI Telecom POC Demo")
|
142 |
-
gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept
|
143 |
|
144 |
with gr.Row():
|
145 |
with gr.Column():
|
@@ -165,10 +276,13 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
|
|
165 |
|
166 |
# Status and info
|
167 |
status_output = gr.Textbox(label="Status", value="Ready")
|
|
|
168 |
|
169 |
# Link components with functions
|
170 |
def update_session():
|
171 |
-
|
|
|
|
|
172 |
status = f"Created new session: {new_id}"
|
173 |
return new_id, status
|
174 |
|
@@ -189,7 +303,7 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
|
|
189 |
text = mock_transcribe(audio_bytes)
|
190 |
|
191 |
# Get response
|
192 |
-
response =
|
193 |
|
194 |
# Synthesize
|
195 |
audio_bytes = mock_synthesize_speech(response)
|
@@ -210,7 +324,7 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
|
|
210 |
)
|
211 |
|
212 |
query_btn.click(
|
213 |
-
lambda text, session:
|
214 |
inputs=[text_input, session_id],
|
215 |
outputs=[response_output]
|
216 |
)
|
@@ -227,6 +341,16 @@ with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as in
|
|
227 |
inputs=[audio_input, session_id],
|
228 |
outputs=[transcription_output, response_output, audio_output, status_output]
|
229 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
# Mount Gradio app
|
232 |
app = gr.mount_gradio_app(app, interface, path="/gradio")
|
|
|
1 |
+
"""
|
2 |
+
Main FastAPI application integrating all components with Hugging Face Inference Endpoint.
|
3 |
+
"""
|
4 |
import gradio as gr
|
5 |
import fastapi
|
6 |
from fastapi.staticfiles import StaticFiles
|
7 |
+
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
|
8 |
from fastapi import FastAPI, Request, Form, UploadFile, File
|
9 |
import os
|
10 |
import time
|
|
|
13 |
import shutil
|
14 |
import uvicorn
|
15 |
from pathlib import Path
|
16 |
+
from typing import Dict, List, Optional, Any
|
17 |
+
import io
|
18 |
+
import numpy as np
|
19 |
+
from scipy.io.wavfile import write
|
20 |
+
|
21 |
+
# Import our modules
|
22 |
+
from local_llm import run_llm, run_llm_with_memory, clear_memory, get_memory_sessions, get_model_info, test_endpoint
|
23 |
|
24 |
# Setup logging
|
25 |
logging.basicConfig(level=logging.INFO)
|
|
|
41 |
# Mount static files
|
42 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
43 |
|
44 |
+
# Helper functions for mock implementations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def mock_transcribe(audio_bytes):
|
46 |
"""Mock function to simulate speech-to-text."""
|
|
|
47 |
logger.info("Transcribing audio...")
|
48 |
+
time.sleep(0.5) # Simulate processing time
|
49 |
return "This is a mock transcription of the audio."
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def mock_synthesize_speech(text):
|
52 |
"""Mock function to simulate text-to-speech."""
|
|
|
53 |
logger.info("Synthesizing speech...")
|
54 |
time.sleep(0.5) # Simulate processing time
|
55 |
|
56 |
# Create a dummy audio file
|
|
|
|
|
|
|
57 |
sample_rate = 22050
|
58 |
duration = 2 # seconds
|
59 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
|
|
65 |
with open(output_file, "rb") as f:
|
66 |
audio_bytes = f.read()
|
67 |
|
|
|
|
|
|
|
68 |
return audio_bytes
|
69 |
|
70 |
# Routes for the API
|
|
|
73 |
"""Serve the main UI."""
|
74 |
return FileResponse("static/index.html")
|
75 |
|
76 |
+
@app.get("/health")
|
77 |
+
async def health_check():
|
78 |
+
"""Health check endpoint."""
|
79 |
+
endpoint_status = test_endpoint()
|
80 |
+
return {
|
81 |
+
"status": "ok",
|
82 |
+
"endpoint": endpoint_status
|
83 |
+
}
|
84 |
+
|
85 |
@app.post("/api/transcribe")
|
86 |
async def transcribe(file: UploadFile = File(...)):
|
87 |
"""Transcribe audio to text."""
|
|
|
91 |
return {"transcription": text}
|
92 |
except Exception as e:
|
93 |
logger.error(f"Transcription error: {str(e)}")
|
94 |
+
return JSONResponse(
|
95 |
+
status_code=500,
|
96 |
+
content={"error": f"Failed to transcribe audio: {str(e)}"}
|
97 |
+
)
|
98 |
|
99 |
@app.post("/api/query")
|
100 |
async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
|
101 |
"""Process a text query with the agent."""
|
102 |
try:
|
103 |
+
response = run_llm_with_memory(input_text, session_id=session_id)
|
104 |
+
logger.info(f"Query: {input_text[:30]}... Response: {response[:30]}...")
|
105 |
return {"response": response}
|
106 |
except Exception as e:
|
107 |
logger.error(f"Query error: {str(e)}")
|
108 |
+
return JSONResponse(
|
109 |
+
status_code=500,
|
110 |
+
content={"error": f"Failed to process query: {str(e)}"}
|
111 |
+
)
|
112 |
|
113 |
@app.post("/api/speak")
|
114 |
async def speak(text: str = Form(...)):
|
|
|
122 |
)
|
123 |
except Exception as e:
|
124 |
logger.error(f"Speech synthesis error: {str(e)}")
|
125 |
+
return JSONResponse(
|
126 |
+
status_code=500,
|
127 |
+
content={"error": f"Failed to synthesize speech: {str(e)}"}
|
128 |
+
)
|
129 |
|
130 |
@app.post("/api/session")
|
131 |
async def create_session():
|
132 |
"""Create a new session."""
|
133 |
+
import uuid
|
134 |
+
session_id = str(uuid.uuid4())
|
135 |
+
clear_memory(session_id)
|
136 |
return {"session_id": session_id}
|
137 |
|
138 |
+
@app.delete("/api/session/{session_id}")
|
139 |
+
async def delete_session(session_id: str):
|
140 |
+
"""Delete a session."""
|
141 |
+
success = clear_memory(session_id)
|
142 |
+
if success:
|
143 |
+
return {"message": f"Session {session_id} cleared"}
|
144 |
+
else:
|
145 |
+
return JSONResponse(
|
146 |
+
status_code=404,
|
147 |
+
content={"error": f"Session {session_id} not found"}
|
148 |
+
)
|
149 |
+
|
150 |
+
@app.get("/api/sessions")
|
151 |
+
async def list_sessions():
|
152 |
+
"""List all active sessions."""
|
153 |
+
return {"sessions": get_memory_sessions()}
|
154 |
+
|
155 |
+
@app.get("/api/model_info")
|
156 |
+
async def model_info():
|
157 |
+
"""Get information about the model."""
|
158 |
+
return get_model_info()
|
159 |
+
|
160 |
+
@app.post("/api/complete")
|
161 |
+
async def complete_flow(
|
162 |
+
request: Request,
|
163 |
+
audio_file: UploadFile = File(None),
|
164 |
+
text_input: str = Form(None),
|
165 |
+
session_id: str = Form("default")
|
166 |
+
):
|
167 |
+
"""
|
168 |
+
Complete flow: audio to text to agent to speech.
|
169 |
+
"""
|
170 |
+
try:
|
171 |
+
# If audio file provided, transcribe it
|
172 |
+
if audio_file:
|
173 |
+
audio_bytes = await audio_file.read()
|
174 |
+
text_input = mock_transcribe(audio_bytes)
|
175 |
+
logger.info(f"Transcribed input: {text_input[:30]}...")
|
176 |
+
|
177 |
+
# Process with agent
|
178 |
+
if not text_input:
|
179 |
+
return JSONResponse(
|
180 |
+
status_code=400,
|
181 |
+
content={"error": "No input provided"}
|
182 |
+
)
|
183 |
+
|
184 |
+
response = run_llm_with_memory(text_input, session_id=session_id)
|
185 |
+
logger.info(f"Agent response: {response[:30]}...")
|
186 |
+
|
187 |
+
# Synthesize speech
|
188 |
+
audio_bytes = mock_synthesize_speech(response)
|
189 |
+
|
190 |
+
# Save audio to a temporary file
|
191 |
+
import tempfile
|
192 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
193 |
+
temp_file.write(audio_bytes)
|
194 |
+
temp_file.close()
|
195 |
+
|
196 |
+
# Generate URL for audio
|
197 |
+
host = request.headers.get("host", "localhost")
|
198 |
+
scheme = request.headers.get("x-forwarded-proto", "http")
|
199 |
+
audio_url = f"{scheme}://{host}/audio/{os.path.basename(temp_file.name)}"
|
200 |
+
|
201 |
+
return {
|
202 |
+
"input": text_input,
|
203 |
+
"response": response,
|
204 |
+
"audio_url": audio_url
|
205 |
+
}
|
206 |
+
|
207 |
+
except Exception as e:
|
208 |
+
logger.error(f"Complete flow error: {str(e)}")
|
209 |
+
return JSONResponse(
|
210 |
+
status_code=500,
|
211 |
+
content={"error": f"Failed to process: {str(e)}"}
|
212 |
+
)
|
213 |
+
|
214 |
+
@app.get("/audio/{filename}")
|
215 |
+
async def get_audio(filename: str):
|
216 |
+
"""
|
217 |
+
Serve temporary audio files.
|
218 |
+
"""
|
219 |
+
try:
|
220 |
+
# Ensure filename only contains safe characters
|
221 |
+
import re
|
222 |
+
if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
|
223 |
+
return JSONResponse(
|
224 |
+
status_code=400,
|
225 |
+
content={"error": "Invalid filename"}
|
226 |
+
)
|
227 |
+
|
228 |
+
temp_dir = tempfile.gettempdir()
|
229 |
+
file_path = os.path.join(temp_dir, filename)
|
230 |
+
|
231 |
+
if not os.path.exists(file_path):
|
232 |
+
return JSONResponse(
|
233 |
+
status_code=404,
|
234 |
+
content={"error": "File not found"}
|
235 |
+
)
|
236 |
+
|
237 |
+
return FileResponse(
|
238 |
+
file_path,
|
239 |
+
media_type="audio/wav",
|
240 |
+
filename=filename
|
241 |
+
)
|
242 |
+
|
243 |
+
except Exception as e:
|
244 |
+
logger.error(f"Audio serving error: {str(e)}")
|
245 |
+
return JSONResponse(
|
246 |
+
status_code=500,
|
247 |
+
content={"error": f"Failed to serve audio: {str(e)}"}
|
248 |
+
)
|
249 |
+
|
250 |
# Gradio interface
|
251 |
with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
|
252 |
gr.Markdown("# AGI Telecom POC Demo")
|
253 |
+
gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept using a Hugging Face Inference Endpoint.")
|
254 |
|
255 |
with gr.Row():
|
256 |
with gr.Column():
|
|
|
276 |
|
277 |
# Status and info
|
278 |
status_output = gr.Textbox(label="Status", value="Ready")
|
279 |
+
endpoint_status = gr.Textbox(label="Endpoint Status", value="Checking endpoint connection...")
|
280 |
|
281 |
# Link components with functions
|
282 |
def update_session():
|
283 |
+
import uuid
|
284 |
+
new_id = str(uuid.uuid4())
|
285 |
+
clear_memory(new_id)
|
286 |
status = f"Created new session: {new_id}"
|
287 |
return new_id, status
|
288 |
|
|
|
303 |
text = mock_transcribe(audio_bytes)
|
304 |
|
305 |
# Get response
|
306 |
+
response = run_llm_with_memory(text, session)
|
307 |
|
308 |
# Synthesize
|
309 |
audio_bytes = mock_synthesize_speech(response)
|
|
|
324 |
)
|
325 |
|
326 |
query_btn.click(
|
327 |
+
lambda text, session: run_llm_with_memory(text, session),
|
328 |
inputs=[text_input, session_id],
|
329 |
outputs=[response_output]
|
330 |
)
|
|
|
341 |
inputs=[audio_input, session_id],
|
342 |
outputs=[transcription_output, response_output, audio_output, status_output]
|
343 |
)
|
344 |
+
|
345 |
+
# Check endpoint on load
|
346 |
+
def check_endpoint():
|
347 |
+
status = test_endpoint()
|
348 |
+
if status["status"] == "connected":
|
349 |
+
return f"✅ Connected to endpoint: {status['message']}"
|
350 |
+
else:
|
351 |
+
return f"❌ Error connecting to endpoint: {status['message']}"
|
352 |
+
|
353 |
+
gr.on_load(lambda: gr.update(value=check_endpoint()), outputs=endpoint_status)
|
354 |
|
355 |
# Mount Gradio app
|
356 |
app = gr.mount_gradio_app(app, interface, path="/gradio")
|