# ai_focus_new / main.py
# Last commit by tuan243: "Update main.py"
# Commit 205ba60 (verified)
import logging
import os
import threading
import traceback

import firebase_admin
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from firebase_admin import credentials, firestore
from huggingface_hub import login
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Configure logging: INFO and above, with a timestamp and level on each line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Load Firebase. `db` is left as None when initialisation fails, so any code
# that uses it must check for None first.
db = None
try:
    cred = credentials.Certificate("firebase_config.json")
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as init_error:
    logging.error(f"Error initializing Firebase: {init_error}")
else:
    logging.info("Firebase initialized successfully.")
# Log in to Hugging Face, but only when an HF_TOKEN is set in the environment.
# A failed login is logged and start-up continues.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as login_error:
        logging.error(f"Error logging into Hugging Face: {login_error}")
    else:
        logging.info("Hugging Face login successful.")
# Hugging Face checkpoint used for /ai_personal_advice.
MODEL_NAME = "VietAI/vit5-large"

# Load the tokenizer and model once at import time so every request reuses
# them. If loading fails, `ai_model` (and possibly `tokenizer`) stays None and
# request handlers must check before use. No explicit token is passed here;
# access relies on the Hugging Face login above, if one happened.
tokenizer = None
ai_model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    ai_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    logging.info("Tokenizer and model loaded successfully.")
except Exception as load_error:
    logging.error(f"Error loading tokenizer/model: {load_error}")
# The FastAPI application; the route handlers register on it via decorators.
app = FastAPI()


# ---- Request bodies --------------------------------------------------------
# (Plain comments rather than docstrings: pydantic would copy a class
# docstring into the OpenAPI schema.)

# Body of POST /update_bios.
class UpdateBiosRequest(BaseModel):
    user_id: str  # document id in the "user_profiles" collection
    bios: str     # new bios text to store


# Body of POST /add_focus_history: one focus session.
class FocusHistoryRequest(BaseModel):
    user_id: str     # document id in the "user_profiles" collection
    time_start: str  # stored as-is; its format is not validated here
    total_time: int  # shown as minutes ("phút") in the advice prompt


# Body of POST /ai_personal_advice.
class AIRequest(BaseModel):
    user_id: str  # document id in the "user_profiles" collection
    bios: str     # bios text supplied by the client
# API: update a user's bios
@app.post("/update_bios")
async def update_bios(request: UpdateBiosRequest):
    """Create or update the bios stored on a user's profile.

    Only the ``bios`` field is written (merge), so other profile data such as
    the focus history is preserved. Responds with HTTP 500 if Firebase is
    unavailable or the write fails.
    """
    # Without this guard a failed Firebase start-up surfaces as a confusing
    # "'NoneType' object has no attribute 'collection'" error.
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    try:
        user_ref = db.collection("user_profiles").document(request.user_id)
        # The Firestore client is synchronous: run the write in a worker
        # thread so the event loop keeps serving other requests.
        await run_in_threadpool(user_ref.set, {"bios": request.bios}, merge=True)
    except Exception as e:
        logging.error(f"Error in /update_bios: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error updating bios: {e}") from e
    return {"message": "Cập nhật bios thành công"}
# API: append an entry to a user's focus history
@firestore.transactional
def _append_focus_entry(transaction, user_ref, entry):
    """Append *entry* to the profile's ``focus_history`` inside *transaction*."""
    # Reading through the transaction makes Firestore retry this function if
    # the document changes concurrently, so simultaneous appends cannot
    # overwrite each other (the previous get/append/set could lose entries).
    snapshot = user_ref.get(transaction=transaction)
    data = snapshot.to_dict() or {}
    # A missing or null field counts as an empty history; copy the list so a
    # retried attempt never reuses state from an earlier attempt.
    focus_history = list(data.get("focus_history") or [])
    focus_history.append(entry)
    transaction.set(user_ref, {"focus_history": focus_history}, merge=True)


@app.post("/add_focus_history")
async def add_focus_history(request: FocusHistoryRequest):
    """Append one focus session to the user's focus history.

    The profile document is created if it does not exist yet. Responds with
    HTTP 500 if Firebase is unavailable or the write fails.
    """
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    entry = {"time_start": request.time_start, "total_time": request.total_time}
    try:
        user_ref = db.collection("user_profiles").document(request.user_id)
        # Synchronous Firestore I/O: keep it off the event loop.
        await run_in_threadpool(_append_focus_entry, db.transaction(), user_ref, entry)
    except Exception as e:
        logging.error(f"Error in /add_focus_history: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error adding focus history: {e}") from e
    return {"message": "Thêm lịch sử focus thành công"}
# API: get a user's stored data
@app.get("/get_user_data")
async def get_user_data(user_id: str):
    """Return the bios and focus history stored for ``user_id``.

    Unknown users get the default bios and an empty history rather than an
    error. Responds with HTTP 500 if Firebase is unavailable or the read fails.
    """
    # Explicit guard instead of an AttributeError on None.
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    try:
        user_ref = db.collection("user_profiles").document(user_id)
        # The Firestore client is synchronous; don't block the event loop.
        user_doc = await run_in_threadpool(user_ref.get)
        data = user_doc.to_dict() or {}
        return {
            "bios": data.get("bios", "Chưa có bios."),
            "focus_history": data.get("focus_history", []),
        }
    except Exception as e:
        logging.error(f"Error in /get_user_data: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error getting user data: {e}") from e
# API: AI personal advice

# Serialises model use across worker threads. Hugging Face fast tokenizers
# are known to raise "Already borrowed" when one instance is called from
# several threads at once (truncation settings mutate it). Parallel
# generations would also multiply memory use.
_GENERATION_LOCK = threading.Lock()


def _format_focus_history(focus_history):
    """Render each focus entry as ``- <time_start>: <total_time> phút``.

    Non-dict entries are skipped and missing keys render as ``?`` so that one
    malformed stored record cannot fail the whole request.
    """
    return "\n".join(
        f"- {entry.get('time_start', '?')}: {entry.get('total_time', '?')} phút"
        for entry in focus_history
        if isinstance(entry, dict)
    )


def _generate_advice(prompt):
    """Run the seq2seq model on *prompt* and return the decoded text.

    Blocking (model-bound); call it from a worker thread, not the event loop.
    """
    with _GENERATION_LOCK:
        # Truncate the prompt to 1024 tokens if it is too long.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        output_ids = ai_model.generate(**inputs, max_new_tokens=500)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)


@app.post("/ai_personal_advice")
async def ai_personal_advice(request: AIRequest):
    """Generate productivity advice from the user's bios and focus history.

    A non-empty ``bios`` in the request takes precedence over the stored one,
    including when the user has no profile yet. Responds with HTTP 500 if the
    model or Firebase is unavailable, or if generation fails.
    """
    # These checks sit outside the try block. Otherwise the generic handler
    # below caught the HTTPException and re-wrapped it, mangling the detail.
    if tokenizer is None or ai_model is None:
        logging.error("Tokenizer or AI model not loaded.")
        raise HTTPException(status_code=500, detail="Tokenizer or AI model not loaded.")
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    try:
        user_ref = db.collection("user_profiles").document(request.user_id)
        user_doc = await run_in_threadpool(user_ref.get)
        if user_doc.exists:
            data = user_doc.to_dict() or {}
            bios = request.bios or data.get("bios", "Chưa có bios.")
            focus_history = data.get("focus_history") or []
        else:
            logging.warning(f"User profile not found for user_id: {request.user_id}")
            bios = request.bios or "Không có thông tin người dùng."
            focus_history = []

        focus_text = _format_focus_history(focus_history)
        prompt = (
            "\n"
            "Thông tin người dùng:\n"
            f"- Bios: {bios}\n"
            "- Lịch sử focus:\n"
            f"{focus_text}\n"
            "Hãy tư vấn cách cải thiện hiệu suất làm việc dựa trên thông tin trên.\n"
        )

        # Generation can take seconds; run it in a worker thread so the
        # event loop keeps serving other endpoints meanwhile.
        advice = (await run_in_threadpool(_generate_advice, prompt)).strip()

        # Make sure the answer ends as a complete sentence. Do not turn
        # "?"/"!" endings into "?." or an empty answer into ".".
        if advice and not advice.endswith((".", "!", "?")):
            advice += "."
        return {"advice": advice}
    except Exception as e:
        error_message = f"Error in /ai_personal_advice: {e}"
        logging.error(f"{error_message}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_message) from e
# Home page: returns a static welcome message as JSON. This is a comment
# rather than a docstring so the OpenAPI description stays unchanged.
@app.get("/")
async def home():
    welcome = {"message": "Welcome to the Recommendation API!"}
    return JSONResponse(content=welcome)