현진_app.py 우울모델 추가
Browse files
app.py
CHANGED
@@ -3,11 +3,15 @@ import re
|
|
3 |
import time
|
4 |
import requests
|
5 |
import numpy as np
|
|
|
|
|
|
|
6 |
from fastapi import FastAPI, HTTPException
|
7 |
from pydantic import BaseModel
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
-
import os
|
10 |
from typing import Optional, List,Dict
|
|
|
|
|
11 |
|
12 |
|
13 |
#####################################
|
@@ -1147,6 +1151,39 @@ def chat_response(user_input, mode="emotion", max_retries=5):
|
|
1147 |
return "🚨 모델 로딩이 너무 오래 걸립니다. 잠시 후 다시 시도하세요."
|
1148 |
|
1149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1150 |
#####################################
|
1151 |
# 6) FastAPI Endpoint
|
1152 |
#####################################
|
@@ -1202,6 +1239,19 @@ class ChatOrRecommendRequest(BaseModel):
|
|
1202 |
# (5) 자동 분기 엔드포인트
|
1203 |
@app.post("/chat_or_recommend")
|
1204 |
def chat_or_recommend(req: ChatOrRecommendRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1205 |
user_text = req.user_input
|
1206 |
mode = req.mode.lower()
|
1207 |
|
|
|
3 |
import time
|
4 |
import requests
|
5 |
import numpy as np
|
6 |
+
import torch
|
7 |
+
import joblib
|
8 |
+
import xgboost as xgb
|
9 |
from fastapi import FastAPI, HTTPException
|
10 |
from pydantic import BaseModel
|
11 |
from sentence_transformers import SentenceTransformer
|
|
|
12 |
from typing import Optional, List,Dict
|
13 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
14 |
+
|
15 |
|
16 |
|
17 |
#####################################
|
|
|
1151 |
return "🚨 모델 로딩이 너무 오래 걸립니다. 잠시 후 다시 시도하세요."
|
1152 |
|
1153 |
|
1154 |
+
#우울분류 모델 추가
|
1155 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1156 |
+
|
1157 |
+
tokenizer = BertTokenizer.from_pretrained("monologg/kobert")
|
1158 |
+
bert_model = BertForSequenceClassification.from_pretrained("monologg/kobert", num_labels=2)
|
1159 |
+
bert_model.load_state_dict(torch.load("emotion_bert_model.pth", map_location=device))
|
1160 |
+
bert_model.to(device)
|
1161 |
+
bert_model.eval()
|
1162 |
+
|
1163 |
+
xgb_model = joblib.load("xgboost_model.pkl")
|
1164 |
+
vectorizer = joblib.load("tfidf_vectorizer.pkl")
|
1165 |
+
|
1166 |
+
def predict_depression(text: str):
|
1167 |
+
encoding = tokenizer(text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
|
1168 |
+
input_ids = encoding["input_ids"].to(device)
|
1169 |
+
attention_mask = encoding["attention_mask"].to(device)
|
1170 |
+
with torch.no_grad():
|
1171 |
+
outputs = bert_model(input_ids, attention_mask=attention_mask)
|
1172 |
+
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
1173 |
+
kobert_score = probabilities[0][1].item()
|
1174 |
+
text_vec = vectorizer.transform([text])
|
1175 |
+
xgb_proba = xgb_model.predict_proba(text_vec)[0][1]
|
1176 |
+
kobert_score = max(0.35, min(kobert_score, 0.88))
|
1177 |
+
xgb_proba = max(0.3, min(xgb_proba, 0.83))
|
1178 |
+
combined_score = (kobert_score * 0.55) + (xgb_proba * 0.45)
|
1179 |
+
if combined_score > 0.78:
|
1180 |
+
label = "상담 권장"
|
1181 |
+
elif combined_score > 0.65:
|
1182 |
+
label = "관심 필요"
|
1183 |
+
else:
|
1184 |
+
label = "정상"
|
1185 |
+
return combined_score, label
|
1186 |
+
|
1187 |
#####################################
|
1188 |
# 6) FastAPI Endpoint
|
1189 |
#####################################
|
|
|
1239 |
# (5) 자동 분기 엔드포인트
|
1240 |
@app.post("/chat_or_recommend")
|
1241 |
def chat_or_recommend(req: ChatOrRecommendRequest):
|
1242 |
+
depression_score, depression_label = predict_depression(req.user_input)
|
1243 |
+
if depression_label == "상담 권장":
|
1244 |
+
counseling_response = (
|
1245 |
+
"입력하신 메시지에서 심각한 우울 신호가 감지되었습니다.\n"
|
1246 |
+
"전문 상담을 받으실 것을 강력히 권장드립니다.\n"
|
1247 |
+
"빠른 시일 내에 전문가와 상담하시길 바랍니다."
|
1248 |
+
)
|
1249 |
+
return {
|
1250 |
+
"mode": "counseling",
|
1251 |
+
"response": counseling_response,
|
1252 |
+
"depression_score": round(depression_score, 4),
|
1253 |
+
"depression_label": depression_label
|
1254 |
+
}
|
1255 |
user_text = req.user_input
|
1256 |
mode = req.mode.lower()
|
1257 |
|