Spaces:
Paused
Paused
File size: 11,965 Bytes
db8e1eb 6a500ca c6552d6 6a500ca db8e1eb c6552d6 db8e1eb 6a500ca c6552d6 6a500ca c6552d6 6a500ca c6552d6 6a500ca c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 6a500ca c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 db8e1eb c6552d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 |
"""
Main FastAPI application integrating all components with a Hugging Face Inference Endpoint.
"""
import gradio as gr
import fastapi
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi import FastAPI, Request, Form, UploadFile, File
import os
import time
import logging
import json
import shutil
import uvicorn
from pathlib import Path
from typing import Dict, List, Optional, Any
import io
import numpy as np
from scipy.io.wavfile import write
# Import our modules
from local_llm import run_llm, run_llm_with_memory, clear_memory, get_memory_sessions, get_model_info, test_endpoint
# Setup logging
# Root logger at INFO so the logger.info(...) calls throughout this module are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the FastAPI app
app = FastAPI(title="AGI Telecom POC")
# Create static directory if it doesn't exist
# Relative path: resolved against the process's current working directory.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)
# Copy index.html from templates to static if it doesn't exist
# An existing static/index.html is never overwritten, so later edits to
# templates/index.html are not picked up once a copy is present.
html_template = Path("templates/index.html")
static_html = static_dir / "index.html"
if html_template.exists() and not static_html.exists():
    shutil.copy(html_template, static_html)
# Mount static files
# Serves the contents of ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Helper functions for mock implementations
def mock_transcribe(audio_bytes):
    """Stand-in for speech-to-text.

    The audio payload is ignored; after a short artificial delay a fixed
    placeholder transcription string is returned.
    """
    logger.info("Transcribing audio...")
    simulated_latency_s = 0.5  # pretend the model takes a while
    time.sleep(simulated_latency_s)
    placeholder = "This is a mock transcription of the audio."
    return placeholder
def mock_synthesize_speech(text):
    """Stand-in for text-to-speech.

    ``text`` is ignored. A 2-second, 440 Hz sine tone (amplitude 0.3,
    22050 Hz, float32 samples) is written to ``temp_audio.wav`` in the
    current working directory, and that file's raw bytes are returned.
    """
    logger.info("Synthesizing speech...")
    time.sleep(0.5)  # pretend synthesis takes a while
    rate = 22050
    seconds = 2
    n_samples = int(rate * seconds)
    timeline = np.linspace(0, seconds, n_samples, endpoint=False)
    tone = 0.3 * np.sin(2 * np.pi * 440 * timeline)
    wav_path = "temp_audio.wav"
    write(wav_path, rate, tone.astype(np.float32))
    with open(wav_path, "rb") as handle:
        return handle.read()
# Routes for the API
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main UI page from ``static/index.html``.

    The page only exists if ``templates/index.html`` was present at startup
    (it is copied into ``static/``). When it is missing, return an explicit
    404 HTML page pointing at the Gradio UI instead of handing a non-existent
    path to ``FileResponse``, which would fail while sending the response.
    """
    index_path = Path("static/index.html")
    if not index_path.is_file():
        return HTMLResponse(
            content=(
                "<h1>UI not found</h1>"
                '<p>Try the <a href="/gradio">Gradio interface</a>.</p>'
            ),
            status_code=404,
        )
    return FileResponse("static/index.html")
@app.get("/health")
async def health_check():
    """Report liveness plus the result of probing the inference endpoint.

    Returns ``{"status": "ok", "endpoint": <test_endpoint() result>}``.
    """
    return dict(status="ok", endpoint=test_endpoint())
@app.post("/api/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Turn an uploaded audio file into text.

    Returns ``{"transcription": ...}`` on success, or a 500 JSON error whose
    message embeds the underlying exception.
    """
    try:
        transcription = mock_transcribe(await file.read())
    except Exception as exc:
        logger.error(f"Transcription error: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to transcribe audio: {str(exc)}"},
        )
    return {"transcription": transcription}
@app.post("/api/query")
async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
    """Send ``input_text`` to the memory-backed LLM for ``session_id``.

    Returns ``{"response": ...}``; any failure (including while logging the
    truncated query/response) yields a 500 JSON error.
    """
    try:
        answer = run_llm_with_memory(input_text, session_id=session_id)
        logger.info(f"Query: {input_text[:30]}... Response: {answer[:30]}...")
    except Exception as err:
        logger.error(f"Query error: {str(err)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process query: {str(err)}"},
        )
    return {"response": answer}
@app.post("/api/speak")
async def speak(text: str = Form(...)):
    """Convert ``text`` to speech and return it as a WAV attachment.

    The bytes returned by ``mock_synthesize_speech`` are sent directly.
    Re-reading the fixed file ``temp_audio.wav`` from disk is avoided because
    that path is shared by every request: a concurrent request could rewrite
    (and briefly truncate) it while this response is still being streamed.
    The ``Content-Disposition`` header matches what
    ``FileResponse(..., filename="response.wav")`` would send.
    """
    from fastapi import Response

    try:
        audio_bytes = mock_synthesize_speech(text)
    except Exception as e:
        logger.error(f"Speech synthesis error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to synthesize speech: {str(e)}"}
        )
    return Response(
        content=audio_bytes,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="response.wav"'},
    )
@app.post("/api/session")
async def create_session():
    """Allocate a random UUID4 session id, reset its memory, and return it."""
    import uuid

    new_session = f"{uuid.uuid4()}"
    clear_memory(new_session)
    return {"session_id": new_session}
@app.delete("/api/session/{session_id}")
async def delete_session(session_id: str):
    """Clear the stored memory for ``session_id``.

    Returns a confirmation message, or a 404 JSON error when
    ``clear_memory`` returns a falsy value.
    """
    if not clear_memory(session_id):
        return JSONResponse(
            status_code=404,
            content={"error": f"Session {session_id} not found"},
        )
    return {"message": f"Session {session_id} cleared"}
@app.get("/api/sessions")
async def list_sessions():
    """Return ``get_memory_sessions()`` under the ``"sessions"`` key."""
    active = get_memory_sessions()
    return {"sessions": active}
@app.get("/api/model_info")
async def model_info():
    """Return the result of ``get_model_info()`` unchanged."""
    info = get_model_info()
    return info
@app.post("/api/complete")
async def complete_flow(
    request: Request,
    audio_file: UploadFile = File(None),
    text_input: str = Form(None),
    session_id: str = Form("default")
):
    """
    Complete flow: audio to text to agent to speech.

    If ``audio_file`` is given it is transcribed and replaces ``text_input``.
    The text is sent to the memory-backed LLM for ``session_id``, the reply is
    synthesized to a ``.wav`` file in the system temp directory, and an
    absolute ``/audio/<name>`` URL for it is returned together with the input
    and the reply. Returns a 400 JSON error when there is no input, and a 500
    JSON error on any other failure.
    """
    import tempfile

    try:
        # If audio file provided, transcribe it
        if audio_file:
            audio_bytes = await audio_file.read()
            text_input = mock_transcribe(audio_bytes)
            logger.info(f"Transcribed input: {text_input[:30]}...")
        # Process with agent
        if not text_input:
            return JSONResponse(
                status_code=400,
                content={"error": "No input provided"}
            )
        response = run_llm_with_memory(text_input, session_id=session_id)
        logger.info(f"Agent response: {response[:30]}...")
        # Synthesize speech
        audio_bytes = mock_synthesize_speech(response)
        # Save audio to a temporary file. The context manager closes the handle
        # even if write() raises (the previous manual close() was skipped then).
        # NOTE(review): delete=False and nothing in this module removes these
        # files, so they accumulate in the temp dir; consider a cleanup policy.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_bytes)
        # Generate URL for audio (Host / X-Forwarded-Proto come from the client
        # or proxy, so the URL reflects whatever they send)
        host = request.headers.get("host", "localhost")
        scheme = request.headers.get("x-forwarded-proto", "http")
        audio_url = f"{scheme}://{host}/audio/{os.path.basename(temp_file.name)}"
        return {
            "input": text_input,
            "response": response,
            "audio_url": audio_url
        }
    except Exception as e:
        logger.error(f"Complete flow error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process: {str(e)}"}
        )
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """
    Serve temporary audio files.

    Only ``.wav`` regular files located directly in the system temp directory
    are served. ``filename`` must be a bare name made of ``[A-Za-z0-9_.-]``
    ending in ``.wav``. Returns 400 for a rejected name, 404 when no such file
    exists, and 500 on unexpected errors.
    """
    # ``tempfile`` must be imported here: the module does not import it at top
    # level, so ``tempfile.gettempdir()`` would otherwise raise NameError.
    import re
    import tempfile

    try:
        # fullmatch (unlike match + "$") rejects a trailing newline. Requiring
        # the .wav suffix keeps this endpoint from exposing arbitrary temp files.
        if not re.fullmatch(r"[a-zA-Z0-9_.-]+\.wav", filename):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid filename"}
            )
        temp_dir = tempfile.gettempdir()
        file_path = os.path.join(temp_dir, filename)
        # isfile (not exists) so directories are never handed to FileResponse.
        if not os.path.isfile(file_path):
            return JSONResponse(
                status_code=404,
                content={"error": "File not found"}
            )
        return FileResponse(
            file_path,
            media_type="audio/wav",
            filename=filename
        )
    except Exception as e:
        logger.error(f"Audio serving error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to serve audio: {str(e)}"}
        )
# Gradio interface
with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
    gr.Markdown("# AGI Telecom POC Demo")
    gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept using a Hugging Face Inference Endpoint.")
    with gr.Row():
        with gr.Column():
            # Input components
            audio_input = gr.Audio(label="Voice Input", type="filepath")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...", lines=2)
            # Session management
            session_id = gr.Textbox(label="Session ID", value="default")
            new_session_btn = gr.Button("New Session")
            # Action buttons
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe Audio")
                query_btn = gr.Button("Send Query")
                speak_btn = gr.Button("Speak Response")
        with gr.Column():
            # Output components
            transcription_output = gr.Textbox(label="Transcription", lines=2)
            response_output = gr.Textbox(label="Agent Response", lines=5)
            audio_output = gr.Audio(label="Voice Response", autoplay=True)
            # Status and info
            status_output = gr.Textbox(label="Status", value="Ready")
            endpoint_status = gr.Textbox(label="Endpoint Status", value="Checking endpoint connection...")

    # Link components with functions
    def update_session():
        """Create a fresh UUID4 session id, clear its memory, and report it."""
        import uuid
        new_id = str(uuid.uuid4())
        clear_memory(new_id)
        status = f"Created new session: {new_id}"
        return new_id, status

    new_session_btn.click(
        update_session,
        outputs=[session_id, status_output]
    )

    def _save_response_audio(audio_bytes):
        """Write synthesized WAV bytes to ``temp_response.wav`` and return that path.

        NOTE(review): the path is fixed, so concurrent users overwrite each
        other's file; tolerable while synthesis is mocked.
        """
        temp_file = "temp_response.wav"
        with open(temp_file, "wb") as f:
            f.write(audio_bytes)
        return temp_file

    def process_audio(audio_path, session):
        """Run transcription -> LLM -> synthesis for one recorded clip.

        Returns (transcription, response, audio path or None, status message).
        """
        if not audio_path:
            return "No audio provided", "", None, "Error: No audio input"
        try:
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()
            # Transcribe
            text = mock_transcribe(audio_bytes)
            # Get response
            response = run_llm_with_memory(text, session)
            # Synthesize
            audio_bytes = mock_synthesize_speech(response)
            temp_file = _save_response_audio(audio_bytes)
            return text, response, temp_file, "Processed successfully"
        except Exception as e:
            logger.error(f"Error: {str(e)}")
            return "", "", None, f"Error: {str(e)}"

    def transcribe_audio(audio_path):
        """Transcribe the clip at ``audio_path``; the file is closed before transcribing."""
        if not audio_path:
            return "No audio provided"
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()
        return mock_transcribe(audio_bytes)

    def speak_response(text):
        """Synthesize ``text`` and return the path of the WAV actually written, or None.

        The synthesizer's bytes are saved to the returned path; previously the
        handler returned "temp_response.wav" without writing to it, so stale or
        missing audio was played.
        """
        audio_bytes = mock_synthesize_speech(text)
        if not audio_bytes:
            return None
        return _save_response_audio(audio_bytes)

    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[transcription_output]
    )
    query_btn.click(
        lambda text, session: run_llm_with_memory(text, session),
        inputs=[text_input, session_id],
        outputs=[response_output]
    )
    speak_btn.click(
        speak_response,
        inputs=[response_output],
        outputs=[audio_output]
    )
    # Full process
    audio_input.change(
        process_audio,
        inputs=[audio_input, session_id],
        outputs=[transcription_output, response_output, audio_output, status_output]
    )

    # Check endpoint on load
    def check_endpoint():
        """Format the ``test_endpoint()`` result for the status textbox."""
        status = test_endpoint()
        if status["status"] == "connected":
            return f"✅ Connected to endpoint: {status['message']}"
        else:
            return f"❌ Error connecting to endpoint: {status['message']}"

    # Blocks.load fires once per page load (gr.on_load does not exist and
    # raised AttributeError at import time).
    interface.load(check_endpoint, outputs=endpoint_status)
# Mount Gradio app
app = gr.mount_gradio_app(app, interface, path="/gradio")
# Run the app
if __name__ == "__main__":
    if os.environ.get("SPACE_ID"):
        # Running on HF Spaces (SPACE_ID is set): use $PORT if present, else 7860.
        port = int(os.environ.get("PORT", 7860))
    else:
        # Running locally
        port = 8000
    uvicorn.run(app, host="0.0.0.0", port=port)