gleisonnanet committed
Commit 9850062 · 1 Parent(s): fdd143d

Add application file

Files changed (5)
  1. Dockerfile +11 -0
  2. audio.py +62 -0
  3. main.py +44 -0
  4. packages.txt +1 -0
  5. requirements.txt +25 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM python:3.9
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
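For reference, a minimal smoke test of this container, assuming the image has been built and started locally with port 7860 published (e.g. docker run -p 7860:7860 <image>); the local URL is an assumption, not part of the commit:

import requests  # pinned in requirements.txt below

# main.py (added below) serves the interactive docs at "/", and the CMD above
# binds uvicorn to port 7860, so a plain GET should return 200 once the app is up.
resp = requests.get("http://localhost:7860/", timeout=10)
print(resp.status_code)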
audio.py ADDED
@@ -0,0 +1,62 @@
+ import aiofiles
+ import hashlib
+ import os
+ from fastapi import APIRouter
+ from fastapi.responses import JSONResponse, FileResponse
+ from TTS.api import TTS
+ from pydantic import BaseModel
+
+ audio_router = APIRouter()
+
+ # Define the input model
+ class TTSInput(BaseModel):
+     input_text: str = 'olá tia divina'
+     emotion: str = "Happy"
+     language: str
+     speed: float = 1.5
+     key: str
+
+
+ async def generate_audio_file(input_data: TTSInput):
+     model_name = TTS.list_models()[0]
+
+     # Initialize the TTS model
+     tts = TTS(model_name=model_name, model_path="./model")
+
+     # Create a string with the input values for hashing
+     hash_input = f"{input_data.input_text}{input_data.emotion}{input_data.language}{input_data.speed}"
+
+     # Calculate the MD5 hash based on the input values
+     md5_hash = hashlib.md5(hash_input.encode()).hexdigest()
+
+     # Check if the audio file already exists
+     os.makedirs("audio", exist_ok=True)  # ensure the cache directory exists
+     audio_file_path = os.path.join("audio", f"{md5_hash}.wav")
+     if not os.path.exists(audio_file_path):
+         # Generate TTS audio and save to a file
+         tts.tts_to_file(
+             text=input_data.input_text,
+             speaker=tts.speakers[5],
+             language=tts.languages[2],  # note: hard-coded; input_data.language only affects the cache key
+             file_path=audio_file_path,
+             gpu=True,
+             emotion=input_data.emotion,
+             speed=input_data.speed,
+             progress_bar=True
+         )
+
+     return audio_file_path, md5_hash
+
+ @audio_router.post("/generate_audio", response_class=JSONResponse)
+ async def generate_audio(input_data: TTSInput):
+     audio_file, md5_hash = await generate_audio_file(input_data)
+     return {"message": "Audio generated successfully", "audio_file": audio_file, "md5_hash": md5_hash}
+
+ @audio_router.get("/download_audio/{md5_hash}")
+ async def download_audio(md5_hash: str):
+     audio_file = os.path.join("audio", f"{md5_hash}.wav")
+
+     if os.path.exists(audio_file):
+         return FileResponse(audio_file, media_type="audio/wav")
+     else:
+         return JSONResponse(content={"message": "Audio not found"}, status_code=404)
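For context, a sketch of how a client could exercise the two endpoints above with the requests package from requirements.txt; the base URL, language code, and key value are illustrative assumptions rather than values taken from this commit:

import requests

BASE_URL = "http://localhost:7860"  # assumption: container running locally with the port published

# POST /generate_audio with the fields declared on TTSInput;
# "language" and "key" have no defaults, so they must always be sent.
payload = {
    "input_text": "olá tia divina",
    "emotion": "Happy",
    "language": "pt",        # illustrative value
    "speed": 1.5,
    "key": "example-key",    # illustrative value
}
resp = requests.post(f"{BASE_URL}/generate_audio", json=payload)
resp.raise_for_status()
md5_hash = resp.json()["md5_hash"]

# GET /download_audio/{md5_hash} returns the cached WAV file.
audio = requests.get(f"{BASE_URL}/download_audio/{md5_hash}")
with open("output.wav", "wb") as f:
    f.write(audio.content)

Note that key is required by TTSInput but never checked anywhere in audio.py, and input_data.language only influences the cache key, since the speaker and language passed to tts_to_file are hard-coded indices.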
main.py ADDED
@@ -0,0 +1,44 @@
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from audio import audio_router
+ from pydantic_settings import BaseSettings  # BaseSettings moved out of pydantic core in v2
+ from typing import Optional
+
+ class APISettings(BaseSettings):
+     max_input_length: int = 10000
+     config_path: str = "config/config.yaml"
+     version: Optional[str] = None
+
+     class Config:
+         env_file = 'config/.env'
+         env_prefix = 'api_'
+
+ api_settings = APISettings()
+
+ app = FastAPI(title="Text-to-Speech API",
+               docs_url="/",
+               version=api_settings.version if api_settings.version else "dev",
+               description="An API that provides text-to-speech using neural models. "
+                           "Developed by TartuNLP - the NLP research group of the University of Tartu.",
+               terms_of_service="https://www.tartunlp.ai/andmekaitsetingimused",
+               license_info={
+                   "name": "MIT license",
+                   "url": "https://github.com/TartuNLP/text-to-speech-api/blob/main/LICENSE"
+               },
+               contact={
+                   "name": "TartuNLP",
+                   "url": "https://tartunlp.ai",
+                   "email": "[email protected]",
+               })
+
+ app.include_router(audio_router)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["GET", "POST"],
+     allow_headers=["*"],
+ )
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app="main:app", host="0.0.0.0", port=8000, reload=True, timeout_keep_alive=None)
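As a side note, a minimal sketch of how the api_ environment prefix in APISettings is expected to behave, written against pydantic-settings v2 using SettingsConfigDict (the v2-native equivalent of the class Config block above); the version value is purely illustrative:

import os
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict

class ExampleSettings(BaseSettings):
    # Environment variables prefixed with "api_" override these defaults;
    # matching is case-insensitive, so API_VERSION works as well.
    model_config = SettingsConfigDict(env_prefix="api_")

    max_input_length: int = 10000
    version: Optional[str] = None

os.environ["API_VERSION"] = "1.2.3"  # hypothetical value, e.g. supplied via config/.env
settings = ExampleSettings()
print(settings.version)           # -> "1.2.3"
print(settings.max_input_length)  # -> 10000 (default, no override set)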
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ annotated-types==0.5.0
+ anyio==3.7.1
+ certifi==2023.7.22
+ charset-normalizer==3.2.0
+ click==8.1.6
+ fastapi==0.100.1
+ gTTS==2.3.2
+ h11==0.14.0
+ idna==3.4
+ pydantic==2.1.1
+ pydantic_core==2.4.0
+ requests==2.31.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ typing_extensions==4.7.1
+ urllib3==2.0.4
+ uvicorn==0.23.1
+ neon-tts-plugin-coqui==0.7.3a1
+ TTS[all,dev,notebooks]
+ gradio
+ # TTS==0.7.1
+ pydantic-settings  # provides BaseSettings for pydantic v2 (used in main.py)
+ stt
+ torch
+ transformers