Spaces:
Runtime error
Runtime error
File size: 5,069 Bytes
7651129 7aac29c 7651129 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import logging
import torch
import os
from TTS.api import TTS
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect
import soundfile as sf # Import soundfile
# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError.
# The registration runs at import time, i.e. before the XTTS model is loaded
# further down in this module.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
add_safe_globals([XttsConfig])
# ✅ Monkey-patch torch.load to always use weights_only=False
# NOTE(review): with weights_only=False, torch.load falls back to full pickle
# loading, which can execute code embedded in a checkpoint. Only load
# checkpoints from sources you trust while this patch is active.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the real ``torch.load`` with ``weights_only`` forced to ``False``.

    Any ``weights_only`` value supplied by the caller is overridden; every
    other positional and keyword argument is forwarded unchanged.
    """
    forced_kwargs = {**kwargs, "weights_only": False}
    return _original_torch_load(*args, **forced_kwargs)


# Install the wrapper process-wide so libraries that call torch.load pick it up.
torch.load = patched_torch_load
# Root logger emits everything from DEBUG upwards.
logging.basicConfig(level=logging.DEBUG)

# Initialize FastAPI
app = FastAPI()

# CORS is fully open: any origin, any HTTP method, any request header.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load TTS model from local files
# Run on the GPU when torch reports one, otherwise on the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    model_dir = "/app/models/xtts_v2"
    config_path = os.path.join(model_dir, "config.json")
    # When providing config_path, TTS might expect the directory for model_path
    tts = TTS(model_path=model_dir, config_path=config_path).to(_device)
    print("XTTS v2 model loaded successfully from local files.")
except Exception as exc:
    # Local load failed for any reason: report it and fall back to resolving
    # the model by its registry name instead.
    print(f"Error loading XTTS v2 model from local files: {exc}")
    print("Falling back to loading by model name (license might be required).")
    tts = TTS("tts_models/multilingual/multi-dataset-xtts_v2").to(_device)
# Load sentiment models
# The Arabic tokenizer and classifier must come from the SAME checkpoint: a
# tokenizer built on a different vocabulary produces token ids that mean
# something else to the model's embedding table.
# NOTE(review): the classifier used to be loaded from "UBC-NLP/MARBERT" while
# the tokenizer came from arabic_model_name. If MARBERT is the intended model,
# change arabic_model_name so that both loads follow it.
# NOTE(review): confirm the chosen checkpoint ships a fine-tuned sentiment
# head whose class order is negative / neutral / positive, as assumed by
# arabic_sentiment_analysis.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(arabic_model_name)

# English sentiment goes through a ready-made pipeline.
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
# Input class for POST body: a JSON object with one required string field.
class Message(BaseModel):
    text: str  # the text that will be synthesized
# Language detection
def detect_language_safely(text):
    """Return a language code for *text* without raising on detection failure.

    Text containing at least one character in the Arabic Unicode block
    (U+0600-U+06FF) is reported as ``"ar"`` straight away. Otherwise the
    result of ``detect(text)`` is returned, and ``"en"`` is used when
    detection fails (for example on empty or feature-less input).

    Args:
        text: The string to classify.

    Returns:
        ``"ar"``, whatever ``detect`` returns, or ``"en"`` as the fallback.
    """
    if any("\u0600" <= ch <= "\u06FF" for ch in text):
        return "ar"
    try:
        return detect(text)
    except Exception:
        # Arabic was already ruled out above, so English is the only fallback.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        logging.debug("Language detection failed; defaulting to 'en'", exc_info=True)
        return "en"
# Sentiment to emotion mapping
def map_sentiment_to_emotion(sentiment, language="en"):
    """Translate a sentiment label into ``"happy"``, ``"sad"`` or ``"neutral"``.

    For ``language == "ar"`` the label must equal ``"positive"`` or
    ``"negative"`` exactly. For every other language the label is lower-cased
    and searched for the substrings ``"positive"`` / ``"negative"``.
    Anything that matches neither maps to ``"neutral"``.
    """
    if language == "ar":
        # Exact, case-sensitive comparison for Arabic labels.
        if sentiment == "positive":
            return "happy"
        if sentiment == "negative":
            return "sad"
        return "neutral"

    # Case-insensitive substring match for all other languages.
    lowered = sentiment.lower()
    if "positive" in lowered:
        return "happy"
    if "negative" in lowered:
        return "sad"
    return "neutral"
# Simple Arabic sentiment analysis
def arabic_sentiment_analysis(text):
    """Classify Arabic *text* as ``"positive"``, ``"negative"`` or ``"neutral"``.

    A keyword count decides first: each positive/negative keyword that occurs
    as a substring of the text counts once. Only when both counts are equal
    is the transformer classifier consulted; if that step fails for any
    reason, ``"neutral"`` is returned.

    Args:
        text: The text to classify.

    Returns:
        One of ``"positive"``, ``"negative"``, ``"neutral"``.
    """
    pos_words = ("سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا")
    neg_words = ("حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف")

    # Lower-case once instead of once per keyword.
    lowered = text.lower()
    pos_count = sum(1 for word in pos_words if word in lowered)
    neg_count = sum(1 for word in neg_words if word in lowered)

    if pos_count > neg_count:
        return "positive"
    if neg_count > pos_count:
        return "negative"

    # Tie (including "no keyword found"): defer to the model.
    try:
        inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            outputs = sentiment_model(**inputs)
        sentiment_class = torch.argmax(outputs.logits).item()
        # assumes the classifier's class order is negative/neutral/positive — TODO confirm
        return ["negative", "neutral", "positive"][sentiment_class]
    except Exception:
        # Best-effort step: record the failure instead of hiding it, then
        # degrade to a neutral result.
        logging.exception("Arabic sentiment model failed; returning 'neutral'")
        return "neutral"
# Main TTS endpoint
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    """Synthesize ``msg.text`` to ``output.wav`` and report the outcome.

    The language is detected first. For English and Arabic a sentiment-derived
    emotion ("happy" / "sad" / "neutral") is computed; for any other language
    the emotion stays "neutral". Sentiment analysis is best-effort: a failure
    there is logged and synthesis continues with "neutral".

    Returns:
        On success ``{"status": "success", "audio_file": ..., "url": "/audio"}``;
        on failure ``{"status": "error", "message": ...}``.
    """
    text = msg.text
    if not text.strip():
        # Nothing to synthesize; skip language detection and the model call.
        return {"status": "error", "message": "Text must not be empty."}

    language = detect_language_safely(text)
    emotion = "neutral"
    try:
        if language == "en":
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        elif language == "ar":
            # Only Arabic text goes through the Arabic analyzer; other
            # languages keep the neutral default.
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
    except Exception:
        logging.exception("Sentiment analysis failed; using neutral emotion")

    # NOTE(review): every request writes the same file, so overlapping
    # requests can overwrite each other's audio.
    output_filename = "output.wav"
    try:
        tts.tts_to_file(
            text=text,
            file_path=output_filename,
            emotion=emotion,
            speaker_wav="/app/audio/speaker_reference.wav",  # Updated path
            language=language,
        )
    except Exception as e:
        logging.exception("Speech synthesis failed")
        return {"status": "error", "message": str(e)}

    return {
        "status": "success",
        "audio_file": output_filename,
        "url": "/audio",
    }
# ✅ Serve the audio file
@app.get("/audio")
def get_audio():
    """Return ``output.wav`` as an ``audio/wav`` download.

    Raises:
        HTTPException: 404 when ``output.wav`` does not exist yet.
    """
    # Imported here so this handler does not depend on a change to the
    # module-level import list.
    from fastapi import HTTPException

    audio_path = "output.wav"
    if not os.path.isfile(audio_path):
        # Without this check a missing file surfaces as a server error.
        raise HTTPException(status_code=404, detail="No audio has been generated yet.")
    return FileResponse(audio_path, media_type="audio/wav", filename="output.wav")
# Serve static files (your web page) from the 'web' directory.
# This catch-all mount at "/" comes after the API routes defined above.
app.mount("/", StaticFiles(directory="web", html=True), name="static")