alethanhson committed on
Commit
04817a7
·
1 Parent(s): 63c4f82
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import base64
2
+ # import io
3
+ # import logging
4
+ # from typing import List, Optional
5
+
6
+ # import torch
7
+ # import torchaudio
8
+ # import uvicorn
9
+ # from fastapi import FastAPI, HTTPException
10
+ # from fastapi.middleware.cors import CORSMiddleware
11
+ # from pydantic import BaseModel
12
+
13
+ # from generator import load_csm_1b, Segment
14
+ # import gradio as gr
15
+
16
+ # logging.basicConfig(level=logging.INFO)
17
+ # logger = logging.getLogger(__name__)
18
+
19
+ # app = FastAPI(
20
+ # title="CSM 1B API",
21
+ # description="API for Sesame's Conversational Speech Model",
22
+ # version="1.0.0",
23
+ # )
24
+
25
+ # app.add_middleware(
26
+ # CORSMiddleware,
27
+ # allow_origins=["*"],
28
+ # allow_credentials=True,
29
+ # allow_methods=["*"],
30
+ # allow_headers=["*"],
31
+ # )
32
+
33
+ # generator = None
34
+
35
+ # class SegmentRequest(BaseModel):
36
+ # speaker: int
37
+ # text: str
38
+ # audio_base64: Optional[str] = None
39
+
40
+ # class GenerateAudioRequest(BaseModel):
41
+ # text: str
42
+ # speaker: int
43
+ # context: List[SegmentRequest] = []
44
+ # max_audio_length_ms: float = 10000
45
+ # temperature: float = 0.9
46
+ # topk: int = 50
47
+
48
+ # class AudioResponse(BaseModel):
49
+ # audio_base64: str
50
+ # sample_rate: int
51
+
52
+ # @app.on_event("startup")
53
+ # async def startup_event():
54
+ # global generator
55
+ # logger.info("Loading CSM 1B model...")
56
+
57
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
58
+ # if device == "cpu":
59
+ # logger.info("Loading CSM 1B model...")
60
+ # logger.warning("GPU not available. Using CPU, performance may be slow!")
61
+ # logger.info(f"Using device: {device}")
62
+ # try:
63
+ # generator = load_csm_1b(device=device)
64
+ # logger.info(f"Model loaded successfully on device: {device}")
65
+ # except Exception as e:
66
+ # logger.error(f"Could not load model: {str(e)}")
67
+ # raise e
68
+
69
+ # @app.post("/generate-audio", response_model=AudioResponse)
70
+ # async def generate_audio(request: GenerateAudioRequest):
71
+ # global generator
72
+
73
+ # if generator is None:
74
+ # raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
75
+
76
+ # try:
77
+ # context_segments = []
78
+ # for segment in request.context:
79
+ # if segment.audio_base64:
80
+ # audio_bytes = base64.b64decode(segment.audio_base64)
81
+ # audio_buffer = io.BytesIO(audio_bytes)
82
+
83
+ # audio_tensor, sample_rate = torchaudio.load(audio_buffer)
84
+ # audio_tensor = torchaudio.functional.resample(
85
+ # audio_tensor.squeeze(0),
86
+ # orig_freq=sample_rate,
87
+ # new_freq=generator.sample_rate
88
+ # )
89
+ # else:
90
+ # audio_tensor = torch.zeros(0, dtype=torch.float32)
91
+
92
+ # context_segments.append(
93
+ # Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
94
+ # )
95
+
96
+ # audio = generator.generate(
97
+ # text=request.text,
98
+ # speaker=request.speaker,
99
+ # context=context_segments,
100
+ # max_audio_length_ms=request.max_audio_length_ms,
101
+ # temperature=request.temperature,
102
+ # topk=request.topk,
103
+ # )
104
+
105
+ # buffer = io.BytesIO()
106
+ # torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
107
+ # # torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
108
+ # buffer.seek(0)
109
+ # # audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
110
+
111
+ # return AudioResponse(
112
+ # content=buffer.read(),
113
+ # media_type="audio/wav",
114
+ # headers={"Content-Disposition": "attachment; filename=audio.wav"}
115
+ # )
116
+
117
+ # except Exception as e:
118
+ # logger.error(f"error when building audio: {str(e)}")
119
+ # raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}")
120
+
121
+ # @app.get("/health")
122
+ # async def health_check():
123
+ # if generator is None:
124
+ # return {"status": "not_ready", "message": "Model is loading"}
125
+ # return {"status": "ready", "message": "API is ready to serve"}
126
+
127
+
128
+ import gradio as gr
129
+
130
+ def greet(name):
131
+ return "Hello " + name + "!!"
132
+
133
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
134
+ demo.launch()