from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import logging
import torch
import os
from TTS.api import TTS
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect

# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
add_safe_globals([XttsConfig])

# ✅ Monkey-patch torch.load to always use weights_only=False
# (works around PyTorch 2.6+, where weights_only defaults to True; acceptable
# here only because the XTTS checkpoints being loaded are trusted local files)
_original_torch_load = torch.load
def patched_torch_load(*args, **kwargs):
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)
torch.load = patched_torch_load

logging.basicConfig(level=logging.DEBUG)

# Initialize FastAPI
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
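# Note: allow_origins=["*"] accepts requests from any origin; a deployed
# service would normally restrict this to the known front-end origin(s).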

# Load TTS model from local files
try:
    model_dir = "/app/models/xtts_v2"
    config_path = os.path.join(model_dir, "config.json")
    # When providing config_path, TTS might expect the directory for model_path
    tts = TTS(model_path=model_dir, config_path=config_path).to("cuda" if torch.cuda.is_available() else "cpu")
    print("XTTS v2 model loaded successfully from local files.")
except Exception as e:
    print(f"Error loading XTTS v2 model from local files: {e}")
    print("Falling back to loading by model name (license might be required).")
    tts = TTS("tts_models/multilingual/multi-dataset-xtts_v2").to("cuda" if torch.cuda.is_available() else "cpu")

# Load sentiment models
# (the tokenizer must come from the same checkpoint as the model it feeds,
#  so both are loaded from UBC-NLP/MARBERT here)
arabic_model_name = "UBC-NLP/MARBERT"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(arabic_model_name)
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# Input class for POST body
class Message(BaseModel):
    text: str

# Language detection
def detect_language_safely(text):
    try:
        if any('\u0600' <= c <= '\u06FF' for c in text):
            return "ar"
        return detect(text)
    except Exception:
        return "ar" if any('\u0600' <= c <= '\u06FF' for c in text) else "en"

# Sentiment to emotion mapping
def map_sentiment_to_emotion(sentiment, language="en"):
    if language == "ar":
        return "happy" if sentiment == "positive" else "sad" if sentiment == "negative" else "neutral"
    return "happy" if "positive" in sentiment.lower() else "sad" if "negative" in sentiment.lower() else "neutral"

# Simple Arabic sentiment analysis
def arabic_sentiment_analysis(text):
    pos_words = ["سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا"]
    neg_words = ["حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف"]
    pos_count = sum(1 for word in pos_words if word in text.lower())
    neg_count = sum(1 for word in neg_words if word in text.lower())

    if pos_count > neg_count:
        return "positive"
    elif neg_count > pos_count:
        return "negative"
    else:
        try:
            inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            outputs = sentiment_model(**inputs)
            sentiment_class = torch.argmax(outputs.logits).item()
            return ["negative", "neutral", "positive"][sentiment_class]
        except Exception:
            return "neutral"

# Main TTS endpoint
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    text = msg.text
    language = detect_language_safely(text)
    emotion = "neutral"

    if language == "en":
        try:
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        except Exception:
            pass
    else:
        try:
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
        except Exception:
            pass

    output_filename = "output.wav"  # single shared output file; concurrent requests will overwrite it
    try:
        tts.tts_to_file(
            text=text,
            file_path=output_filename,
            emotion=emotion,
            speaker_wav="/app/audio/speaker_reference.wav",  # reference voice used by XTTS for voice cloning
            language=language
        )
        return {
            "status": "success",
            "audio_file": output_filename,
            "url": "/audio"
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

# ✅ Serve the audio file
@app.get("/audio")
def get_audio():
    return FileResponse("output.wav", media_type="audio/wav", filename="output.wav")

# Serve static files (the web front end) from the 'web' directory.
# Mounted last so the API routes defined above take precedence over this catch-all mount.
app.mount("/", StaticFiles(directory="web", html=True), name="static")
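
# Minimal local entry point and example requests (a sketch: assumes this module
# is saved as main.py, uvicorn is installed, and port 8000 is free; in Docker
# the server is usually launched by the image's CMD/ENTRYPOINT instead):
#
#   python main.py
#   curl -X POST http://localhost:8000/text-to-speech/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello there"}'
#   curl -o output.wav http://localhost:8000/audio
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)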