|
import io
import os
import sqlite3
import tempfile
import uuid
from datetime import datetime
from pathlib import Path

import torch
import torchaudio
import whisper
from flask import Flask, jsonify, render_template, request, session
from pydub import AudioSegment
from transformers import pipeline
|
|
|
app = Flask(__name__)

# The secret key signs the session cookie that stores each visitor's anonymous
# user_id. Anyone who knows the key can forge sessions, so set FLASK_SECRET_KEY
# in any real deployment; the hard-coded fallback only keeps local runs working
# (and keeps existing sessions valid when the variable is unset).
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'your-very-secret-key-12345')
|
|
|
|
|
|
|
def get_db_connection():
    """Open a new connection to the chat-history SQLite database.

    The database file is ``instance/chats.db`` relative to the current working
    directory; the ``instance`` directory is created on demand. Rows are
    returned as :class:`sqlite3.Row`, so columns can be accessed by name.
    The caller owns the connection and must close it.
    """
    db_dir = Path('instance')
    db_dir.mkdir(exist_ok=True)

    connection = sqlite3.connect(str(db_dir.joinpath('chats.db')))
    connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
|
def init_db():
    """Create the database schema if it does not exist yet.

    Two tables are created:
      * ``chats``    - one row per chat: id, user_id, creation time, title;
      * ``messages`` - individual messages, linked to ``chats`` via ``chat_id``.

    Safe to call repeatedly: every statement uses ``CREATE TABLE IF NOT EXISTS``.
    """
    ddl_statements = (
        '''
        CREATE TABLE IF NOT EXISTS chats (
            chat_id TEXT PRIMARY KEY,
            user_id TEXT,
            created_at TEXT,
            title TEXT
        )
        ''',
        '''
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            chat_id TEXT,
            sender TEXT,
            content TEXT,
            timestamp TEXT,
            FOREIGN KEY(chat_id) REFERENCES chats(chat_id)
        )
        ''',
    )

    conn = get_db_connection()
    try:
        for statement in ddl_statements:
            conn.execute(statement)
        conn.commit()
    finally:
        conn.close()


# Ensure the schema exists at import time, before any request is handled.
init_db()
|
|
|
|
|
# Maps labels produced by the text emotion classifier (used in /analyze) to the
# strings shown to the user. The CEDR-based model reports the neutral class as
# "no_emotion" and also has a "fear" class; without these entries both fell
# through to the "unknown" fallback. "neutral" is kept for compatibility with
# models that use that label name.
emotion_map = {
    'joy': '😊 Радость',
    'neutral': '😐 Нейтрально',
    'no_emotion': '😐 Нейтрально',
    'anger': '😠 Злость',
    'sadness': '😢 Грусть',
    'surprise': '😲 Удивление',
    'fear': '😨 Страх',
}
|
|
|
|
|
# Device index for transformers pipelines: first GPU if available, else CPU (-1).
_PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1


def _load_model(description, loader):
    """Call ``loader()`` and return the model, or ``None`` if it raises.

    Each model is loaded through its own call so that a failure in one
    (missing weights, no network, out of memory) no longer disables the
    others; endpoints check their own model for ``None``.
    """
    try:
        return loader()
    except Exception as e:
        print(f"Ошибка загрузки моделей ({description}): {e}")
        return None


# Whisper speech-to-text model used by transcribe_audio().
speech_to_text_model = _load_model("whisper", lambda: whisper.load_model("small"))

# Text emotion classifier; top_k=None asks the pipeline for scores of all labels.
text_classifier = _load_model(
    "text-classification",
    lambda: pipeline(
        "text-classification",
        model="cointegrated/rubert-tiny2-cedr-emotion-detection",
        top_k=None,
        device=_PIPELINE_DEVICE,
    ),
)

# Speech emotion classifier applied to uploaded audio.
audio_classifier = _load_model(
    "audio-classification",
    lambda: pipeline(
        "audio-classification",
        model="superb/hubert-large-superb-er",
        device=_PIPELINE_DEVICE,
    ),
)
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe a Russian-language audio file to text with Whisper.

    Args:
        audio_path: path to the audio file to transcribe.

    Returns:
        The recognized text, or ``None`` when the Whisper model is not loaded
        or transcription raises (the error is printed).
    """
    if speech_to_text_model:
        try:
            return speech_to_text_model.transcribe(audio_path, language="ru")["text"]
        except Exception as e:
            print(f"Ошибка преобразования аудио в текст: {e}")
    return None
|
|
|
|
|
@app.route("/")
def index():
    """Render the main page listing the visitor's chats, newest first.

    A visitor without a session ``user_id`` is assigned a fresh random one.
    """
    if 'user_id' not in session:
        session['user_id'] = str(uuid.uuid4())
    user_id = session['user_id']

    conn = get_db_connection()
    try:
        user_chats = conn.execute(
            "SELECT chat_id, title FROM chats WHERE user_id = ? ORDER BY created_at DESC",
            (user_id,),
        ).fetchall()
    finally:
        conn.close()

    return render_template("index.html", chats=user_chats)
|
|
|
|
|
@app.route("/get_chats")
def get_chats():
    """Return the visitor's chats as a JSON list of ``{chat_id, title}``.

    Chats are ordered newest first; a visitor with no session ``user_id``
    gets an empty list.
    """
    if 'user_id' not in session:
        return jsonify([])

    conn = get_db_connection()
    try:
        cursor = conn.execute(
            "SELECT chat_id, title FROM chats WHERE user_id = ? ORDER BY created_at DESC",
            (session['user_id'],),
        )
        payload = [dict(row) for row in cursor]
    finally:
        conn.close()

    return jsonify(payload)
|
|
|
|
|
@app.route("/start_chat", methods=["POST"])
def start_chat():
    """Create a new, empty chat owned by the session user.

    Assigns a session ``user_id`` first if the visitor has none. The title
    embeds the server's local time, while ``created_at`` is filled by SQLite's
    ``datetime('now')`` (UTC). Returns JSON ``{"chat_id": ..., "title": ...}``.
    """
    if 'user_id' not in session:
        session['user_id'] = str(uuid.uuid4())

    new_chat = {
        "chat_id": str(uuid.uuid4()),
        "title": "Новый чат " + datetime.now().strftime("%d.%m %H:%M"),
    }

    conn = get_db_connection()
    try:
        conn.execute(
            "INSERT INTO chats (chat_id, user_id, created_at, title) VALUES (?, ?, datetime('now'), ?)",
            (new_chat["chat_id"], session['user_id'], new_chat["title"]),
        )
        conn.commit()
    finally:
        conn.close()

    return jsonify(new_chat)
|
|
|
|
|
@app.route("/load_chat/<chat_id>", methods=["GET"])
def load_chat(chat_id):
    """Return a chat's title and messages in the order they were saved.

    Responses:
        200 ``{"messages": [{"sender": ..., "content": ...}, ...], "title": str}``
        404 ``{"error": "Chat not found"}`` if no chat has this ``chat_id``
        500 ``{"error": ...}`` if a database error occurs
    """
    conn = get_db_connection()
    try:
        # A single lookup both checks that the chat exists and fetches its title.
        chat = conn.execute(
            "SELECT title FROM chats WHERE chat_id = ?", (chat_id,)
        ).fetchone()
        if chat is None:
            return jsonify({"error": "Chat not found"}), 404

        # `timestamp` comes from datetime('now') and has one-second resolution,
        # so messages saved in quick succession tie; the AUTOINCREMENT `id`
        # breaks ties in insertion order.
        messages = conn.execute(
            "SELECT sender, content FROM messages WHERE chat_id = ? "
            "ORDER BY timestamp ASC, id ASC",
            (chat_id,),
        ).fetchall()

        return jsonify({
            "messages": [dict(msg) for msg in messages],
            # Fall back for NULL/empty titles, which previously leaked out as null.
            "title": chat['title'] or "Без названия",
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        conn.close()
|
|
|
|
|
@app.route("/save_message", methods=["POST"])
def save_message():
    """Append one message to a chat.

    Expects a JSON object with non-empty ``chat_id``, ``sender`` and
    ``content``. Returns ``{"status": "success"}``; 400 if the body is not a
    JSON object or a field is missing/empty; 500 with the error text if the
    insert fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising, so such requests take the 400 branch below. A JSON array or
    # scalar body previously crashed on data.get().
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not all(
        data.get(field) for field in ("chat_id", "sender", "content")
    ):
        return jsonify({"error": "Missing parameters"}), 400

    conn = get_db_connection()
    try:
        conn.execute(
            "INSERT INTO messages (chat_id, sender, content, timestamp) VALUES (?, ?, ?, datetime('now'))",
            (data['chat_id'], data['sender'], data['content'])
        )
        conn.commit()
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        conn.close()
|
|
|
|
|
@app.route("/analyze", methods=["POST"])
def analyze_text():
    """Classify the emotion expressed in a piece of text.

    Expects JSON ``{"text": str}``. Returns ``{"emotion": str, "confidence": float}``
    for the highest-scoring label (display text looked up in ``emotion_map``).
    Responds 400 for a missing/empty/non-string text and 500 if the model is
    not loaded or inference fails.
    """
    if not text_classifier:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing/malformed body yields None instead of raising;
    # a null or non-string "text" previously crashed on .strip().
    data = request.get_json(silent=True)
    raw_text = data.get("text") if isinstance(data, dict) else None
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return jsonify({"error": "Empty text"}), 400

    try:
        predictions = text_classifier(text)
        # For a single string the pipeline may return the per-label
        # {"label", "score"} dicts either flat or wrapped in an outer list
        # (depends on how top_k is supplied / the transformers version).
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]
        top_prediction = max(predictions, key=lambda p: p["score"])
        return jsonify({
            "emotion": emotion_map.get(top_prediction["label"], "❓ Неизвестно"),
            "confidence": round(top_prediction["score"], 2)
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
# Raw labels from the speech-emotion classifier that /analyze_audio reports,
# mapped to the text shown to the user. Any other label is ignored.
_AUDIO_EMOTION_TEXT = {
    'hap': '😊 Радость',
    'sad': '😢 Грусть',
    'ang': '😠 Злость',
    'neu': '😐 Нейтрально',
}


def _convert_to_wav(raw_bytes, dest_path):
    """Decode uploaded audio with pydub and write it as 16 kHz mono 16-bit PCM WAV."""
    audio = AudioSegment.from_file(io.BytesIO(raw_bytes))
    audio = audio.set_frame_rate(16000).set_channels(1)
    audio.export(dest_path, format="wav", codec="pcm_s16le")


def _dominant_audio_emotion(predictions):
    """Pick the highest-scoring recognised label from audio classifier output.

    ``predictions`` is the list of ``{"label": ..., "score": ...}`` dicts
    returned by the audio pipeline. Returns ``(display_text, score)``, or
    ``None`` when no label is in ``_AUDIO_EMOTION_TEXT``.
    """
    known = [
        (item['label'].lower(), item['score'])
        for item in predictions
        if item['label'].lower() in _AUDIO_EMOTION_TEXT
    ]
    if not known:
        return None
    label, score = max(known, key=lambda pair: pair[1])
    return _AUDIO_EMOTION_TEXT[label], score


@app.route('/analyze_audio', methods=['POST'])
def analyze_audio():
    """Transcribe an uploaded recording and classify the emotion in the voice.

    Expects a multipart upload in the ``audio`` field. Returns
    ``{"emotion": str, "confidence": float, "transcribed_text": str}``;
    400 when no file is sent, 500 if a model is not loaded or processing fails.
    """
    if not audio_classifier or not speech_to_text_model:
        return jsonify({"error": "Model not loaded"}), 500

    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file'}), 400

    temp_path = None
    try:
        # A unique temp file per request: the former fixed "temp_audio.wav"
        # let concurrent requests overwrite each other's audio.
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # only the path is needed; pydub reopens it for writing

        _convert_to_wav(request.files['audio'].read(), temp_path)
        transcribed_text = transcribe_audio(temp_path)
        dominant = _dominant_audio_emotion(audio_classifier(temp_path))

        # No recognised label: report "unknown" instead of failing on max([]).
        emotion_text, confidence = dominant if dominant else ('неизвестно', 0.0)

        return jsonify({
            'emotion': emotion_text,
            'confidence': round(confidence, 2),
            'transcribed_text': transcribed_text if transcribed_text else "Не удалось распознать текст"
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove the temp file on every path, including decode/inference errors
        # (previously it was only deleted on success). Cleanup is best-effort.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))