akshatOP committed on
Commit fa48fc0 · 1 Parent(s): 2f356af
Files changed (1)
  1. app.py +22 -37

app.py CHANGED

@@ -1,5 +1,6 @@
  from fastapi import FastAPI, File, UploadFile, Response
- from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from transformers import AutoTokenizer
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
  from llama_cpp import Llama
  import torch
@@ -10,74 +10,59 @@ from pydantic import BaseModel

  app = FastAPI()

- # Load TTS model
- TTS_MODEL_PATH = "./models/tts_model"
- if os.path.exists(TTS_MODEL_PATH):
-     tts_model = AutoModelForSpeechSeq2Seq.from_pretrained(TTS_MODEL_PATH)
-     tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_PATH)
+ # Load models
+ if os.path.exists("./models/tts_model"):
+     tts_model = ParlerTTSForConditionalGeneration.from_pretrained("./models/tts_model")
+     tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
  else:
-     tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("suno/bark")  # Replace with an actual TTS model
-     tts_tokenizer = AutoTokenizer.from_pretrained("suno/bark")
+     tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
+     tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

- # Load SST model
- SST_MODEL_PATH = "./models/sst_model"
- if os.path.exists(SST_MODEL_PATH):
-     sst_model = Wav2Vec2ForCTC.from_pretrained(SST_MODEL_PATH)
-     sst_processor = Wav2Vec2Processor.from_pretrained(SST_MODEL_PATH)
+ # SST and LLM loading remains unchanged
+ if os.path.exists("./models/sst_model"):
+     sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
+     sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
  else:
      sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
      sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

- # Load Llama.cpp model
- LLM_MODEL_PATH = "./models/llama.gguf"
- if os.path.exists(LLM_MODEL_PATH):
-     llm = Llama(model_path=LLM_MODEL_PATH)  # Corrected usage
+ if os.path.exists("./models/llama.gguf"):
+     llm = Llama("./models/llama.gguf")
  else:
      raise FileNotFoundError("Please upload llama.gguf to models/ directory")

- # Request models
+ # Request models and endpoints remain unchanged
  class TTSRequest(BaseModel):
      text: str

  class LLMRequest(BaseModel):
      prompt: str

- # API Endpoints
  @app.post("/tts")
  async def tts_endpoint(request: TTSRequest):
      text = request.text
      inputs = tts_tokenizer(text, return_tensors="pt")
-
      with torch.no_grad():
-         output_ids = tts_model.generate(**inputs)
-
-     # Convert model output to speech (assuming Bark-like model)
-     audio = output_ids.squeeze().cpu().numpy()
-
+         audio = tts_model.generate(**inputs)
+     audio = audio.squeeze().cpu().numpy()
      buffer = io.BytesIO()
-     sf.write(buffer, audio, samplerate=22050, format="WAV")  # Ensure correct sample rate
+     sf.write(buffer, audio, 22050, format="WAV")
      buffer.seek(0)
-
      return Response(content=buffer.getvalue(), media_type="audio/wav")

  @app.post("/sst")
  async def sst_endpoint(file: UploadFile = File(...)):
      audio_bytes = await file.read()
      audio, sr = sf.read(io.BytesIO(audio_bytes))
-
      inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
-
      with torch.no_grad():
-         logits = sst_model(inputs.input_values).logits  # Ensure .logits is detached
-         predicted_ids = torch.argmax(logits, dim=-1)
+         logits = sst_model(inputs.input_values).logits
+         predicted_ids = torch.argmax(logits, dim=-1)
-
      transcription = sst_processor.batch_decode(predicted_ids)[0]
-
      return {"text": transcription}

  @app.post("/llm")
  async def llm_endpoint(request: LLMRequest):
      prompt = request.prompt
-     output = llm(prompt, max_tokens=50)  # Corrected llama.cpp API call
-
-     return {"text": output["choices"][0]["text"] if "choices" in output else output["content"]}
+     output = llm(prompt, max_tokens=50)
+     return {"text": output["choices"][0]["text"]}
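A note on the new `/tts` endpoint: Parler-TTS (distributed as the standalone `parler_tts` package rather than inside `transformers`) conditions generation on two inputs, a tokenized voice description passed as `input_ids` and the tokenized text passed as `prompt_input_ids`, so `generate(**inputs)` with only the tokenized text will not drive the model as intended; the output rate should also come from the model config (44.1 kHz for parler-tts-mini-v1) rather than a hard-coded 22050. A minimal sketch of the corrected handler, assuming the models loaded above; the `description` string is an illustrative placeholder:

```python
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    # Parler-TTS takes a voice description as input_ids and the text
    # to speak as prompt_input_ids (placeholder description below).
    description = "A calm female voice with clear articulation."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids
    prompt_input_ids = tts_tokenizer(request.text, return_tensors="pt").input_ids

    with torch.no_grad():
        audio = tts_model.generate(input_ids=input_ids,
                                   prompt_input_ids=prompt_input_ids)
    audio = audio.squeeze().cpu().numpy()

    buffer = io.BytesIO()
    # Use the model's own sample rate (44100 for parler-tts-mini-v1)
    # instead of hard-coding 22050.
    sf.write(buffer, audio, tts_model.config.sampling_rate, format="WAV")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="audio/wav")
```

Tokenizing the description once at startup would also avoid re-encoding the same voice prompt on every request.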
 
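Separately, the `/sst` endpoint forwards whatever sample rate the upload happens to carry, but facebook/wav2vec2-base-960h expects 16 kHz mono audio and `Wav2Vec2Processor` does not resample for you. A sketch of the handler with a resampling step, assuming `torchaudio` is available as an extra dependency (it is not imported in this app):

```python
import torchaudio.functional as AF  # assumed extra dependency

@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))

    # Collapse stereo uploads to mono; wav2vec2 expects a 1-D waveform.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # wav2vec2-base-960h was trained on 16 kHz speech; resample anything else.
    if sr != 16000:
        audio = AF.resample(torch.from_numpy(audio).float(), sr, 16000).numpy()
        sr = 16000

    inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = sst_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = sst_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}
```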
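Finally, a quick client-side smoke test of the three endpoints; the base URL is an assumption (whatever host/port uvicorn serves on) and `requests` is an extra client dependency:

```python
import requests

BASE = "http://localhost:8000"  # assumed; match your uvicorn host/port

# /llm: JSON in, JSON out
print(requests.post(f"{BASE}/llm", json={"prompt": "Say hi."}).json())

# /tts: JSON in, WAV bytes out
wav = requests.post(f"{BASE}/tts", json={"text": "Hello there."}).content
with open("out.wav", "wb") as f:
    f.write(wav)

# /sst: multipart file upload in, JSON transcription out
with open("out.wav", "rb") as f:
    print(requests.post(f"{BASE}/sst", files={"file": f}).json())
```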