Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import JSONResponse | |
import firebase_admin | |
from firebase_admin import credentials, firestore | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from pydantic import BaseModel | |
import os | |
from huggingface_hub import login | |
import traceback | |
import logging | |
# Root logging setup, applied once at import time: INFO and above, each line
# prefixed with a timestamp and the level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Initialize the Firebase Admin SDK from the service-account file
# "firebase_config.json" and create a Firestore client.
# Startup failures are logged rather than raised, so the app still starts;
# in that case `db` stays None and handlers must account for it.
db = None
try:
    cred = credentials.Certificate("firebase_config.json")
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    logging.error(f"Error initializing Firebase: {e}")
else:
    logging.info("Firebase initialized successfully.")
# Log in to the Hugging Face Hub (only if needed).
def _login_to_hugging_face(token):
    """Authenticate with the Hugging Face Hub; log, but do not raise, on failure."""
    try:
        login(token=token)  # keyword argument, as huggingface_hub.login expects
    except Exception as exc:
        logging.error(f"Error logging into Hugging Face: {exc}")
        return
    logging.info("Hugging Face login successful.")


# The login is attempted only when the HF_TOKEN environment variable is set
# to a non-empty value.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    _login_to_hugging_face(HF_TOKEN)
# Load the tokenizer and seq2seq model once at import time, outside any request
# handler, so every request reuses them.
# Failures are only logged. Either name may remain None: the tokenizer can be
# loaded even when the model is not. The advice handler checks both.
_MODEL_ID = "VietAI/vit5-large"
tokenizer = None
ai_model = None
try:
    # No token is passed explicitly. Presumably any credentials cached by the
    # Hugging Face login step are picked up automatically — confirm if the
    # model ever becomes gated.
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    ai_model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
except Exception as e:
    logging.error(f"Error loading tokenizer/model: {e}")
else:
    logging.info("Tokenizer and model loaded successfully.")
# FastAPI application object.
# NOTE(review): no route decorators referencing `app` are visible in this file;
# confirm the handler functions are actually registered as endpoints.
app = FastAPI()
# Request models
class UpdateBiosRequest(BaseModel):
    """Request body for replacing a user's bios text."""

    user_id: str  # document ID in the "user_profiles" Firestore collection
    bios: str  # new bios text, written to the profile's "bios" field
class FocusHistoryRequest(BaseModel):
    """Request body for appending one focus session to a user's history."""

    user_id: str  # document ID in the "user_profiles" Firestore collection
    time_start: str  # session start; stored verbatim, format is not validated here
    total_time: int  # session length; the advice prompt renders it as minutes ("phút")
class AIRequest(BaseModel):
    """Request body for the personalized-advice endpoint."""

    user_id: str  # document ID in the "user_profiles" Firestore collection
    # Optional bios override. The advice handler falls back to stored profile
    # data when this is empty, so it defaults to "" and clients may omit it.
    bios: str = ""
# API: update a user's bios
async def update_bios(request: UpdateBiosRequest):
    """Write ``request.bios`` to the user's profile document.

    The write uses ``merge=True`` on ``user_profiles/{user_id}``, so other fields
    (such as ``focus_history``) are preserved and a missing profile is created.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app``.

    Returns:
        dict: ``{"message": "Cập nhật bios thành công"}`` ("bios updated").

    Raises:
        HTTPException: 500 if Firebase is not initialized or the write fails.
    """
    # Checked outside the try so this explicit error is not re-wrapped below.
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    try:
        user_ref = db.collection("user_profiles").document(request.user_id)
        user_ref.set({"bios": request.bios}, merge=True)
    except Exception as e:
        logging.error(f"Error in /update_bios: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error updating bios: {e}") from e
    return {"message": "Cập nhật bios thành công"}
# API: append a focus-history entry
async def add_focus_history(request: FocusHistoryRequest):
    """Append one ``{"time_start", "total_time"}`` entry to the user's history.

    The read-append-write of ``focus_history`` runs inside a Firestore
    transaction. Concurrent appends for the same user are therefore retried
    rather than silently overwriting each other, which a plain get() + set()
    allowed. ``merge=True`` keeps the profile's other fields and creates the
    document if it does not exist.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app``.

    Returns:
        dict: ``{"message": "Thêm lịch sử focus thành công"}`` ("focus history added").

    Raises:
        HTTPException: 500 if Firebase is not initialized or any Firestore call fails.
    """
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")

    entry = {"time_start": request.time_start, "total_time": request.total_time}

    @firestore.transactional
    def _append_entry(transaction, user_ref):
        # Re-read inside the transaction; Firestore re-runs this on contention.
        snapshot = user_ref.get(transaction=transaction)
        data = snapshot.to_dict() or {}  # to_dict() is None for a missing doc
        # `or []` also tolerates a stored null value for the field.
        focus_history = data.get("focus_history") or []
        focus_history.append(entry)
        transaction.set(user_ref, {"focus_history": focus_history}, merge=True)

    try:
        user_ref = db.collection("user_profiles").document(request.user_id)
        _append_entry(db.transaction(), user_ref)
    except Exception as e:
        logging.error(f"Error in /add_focus_history: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error adding focus history: {e}") from e
    return {"message": "Thêm lịch sử focus thành công"}
# API: fetch a user's stored data
async def get_user_data(user_id: str):
    """Return the stored bios and focus history for ``user_id``.

    A missing profile, or missing fields, yields defaults instead of an error:
    bios ``"Chưa có bios."`` ("no bios yet") and an empty ``focus_history`` list.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app``.

    Raises:
        HTTPException: 500 if Firebase is not initialized or the read fails.
    """
    # Checked outside the try so this explicit error is not re-wrapped below.
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")
    try:
        user_doc = db.collection("user_profiles").document(user_id).get()
        data = user_doc.to_dict() or {}  # to_dict() is None for a missing doc
        return {
            "bios": data.get("bios", "Chưa có bios."),
            "focus_history": data.get("focus_history", []),
        }
    except Exception as e:
        logging.error(f"Error in /get_user_data: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error getting user data: {e}") from e
# API: AI-generated personal advice

# Endings that already close a sentence; "." is appended only if none is present.
_SENTENCE_ENDINGS = (".", "!", "?", "…")


def _format_focus_history(focus_history):
    """Render focus entries as ``"- <time_start>: <total_time> phút"`` lines.

    Non-dict entries are skipped. Missing keys render as ``"?"`` instead of
    raising KeyError, so one malformed stored entry cannot break the endpoint.
    """
    lines = []
    for entry in focus_history or []:
        if not isinstance(entry, dict):
            continue
        lines.append(f"- {entry.get('time_start', '?')}: {entry.get('total_time', '?')} phút")
    return "\n".join(lines)


def _build_advice_prompt(bios, focus_text):
    """Build the Vietnamese prompt: user info, focus history, then the request.

    Built from explicit lines so the source indentation never leaks into the
    prompt, as it would with an indented triple-quoted f-string.
    """
    return (
        "Thông tin người dùng:\n"
        f"- Bios: {bios}\n"
        "- Lịch sử focus:\n"
        f"{focus_text}\n"
        "Hãy tư vấn cách cải thiện hiệu suất làm việc dựa trên thông tin trên."
    )


def _generate_advice(prompt):
    """Run the seq2seq model on ``prompt`` and return the decoded text (blocking)."""
    # The tokenizer returns a mapping (input ids plus attention mask); an overly
    # long prompt is truncated to 1024 tokens.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    output_ids = ai_model.generate(**inputs, max_new_tokens=500)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def _finish_sentence(text):
    """Strip ``text`` and end it with a period unless it already ends a sentence.

    Empty output stays empty instead of becoming ".".
    """
    text = text.strip()
    if text and not text.endswith(_SENTENCE_ENDINGS):
        text += "."
    return text


async def ai_personal_advice(request: AIRequest):
    """Generate productivity advice from the user's bios and focus history.

    A non-empty ``request.bios`` always takes precedence. Otherwise the stored
    profile bios is used: ``"Chưa có bios."`` if the profile lacks one, or
    ``"Không có thông tin người dùng."`` if no profile exists.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app``.

    Returns:
        dict: ``{"advice": <generated text>}``.

    Raises:
        HTTPException: 500 if the model or Firebase is unavailable, or if reading
            the profile or generating fails.
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    # These checks sit outside the try so their HTTPExceptions reach the client
    # as-is, instead of being caught and re-wrapped by the generic handler.
    if tokenizer is None or ai_model is None:
        logging.error("Tokenizer or AI model not loaded.")
        raise HTTPException(status_code=500, detail="Tokenizer or AI model not loaded.")
    if db is None:
        logging.error("Firebase not initialized.")
        raise HTTPException(status_code=500, detail="Firebase not initialized.")

    try:
        user_doc = db.collection("user_profiles").document(request.user_id).get()
        if user_doc.exists:
            data = user_doc.to_dict() or {}
            stored_bios = data.get("bios", "Chưa có bios.")
            focus_history = data.get("focus_history", [])
        else:
            logging.warning(f"User profile not found for user_id: {request.user_id}")
            stored_bios = "Không có thông tin người dùng."
            focus_history = []
        bios = request.bios if request.bios else stored_bios

        prompt = _build_advice_prompt(bios, _format_focus_history(focus_history))
        # generate() is CPU-bound and can take seconds. Running it in the
        # threadpool keeps the event loop free to serve other requests.
        advice = await run_in_threadpool(_generate_advice, prompt)
        return {"advice": _finish_sentence(advice)}
    except Exception as e:
        error_message = f"Error in /ai_personal_advice: {e}"
        logging.error(error_message)
        logging.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=error_message) from e
# Home page
async def home():
    """Return a static JSON welcome message.

    NOTE(review): no route decorator is visible for this handler; confirm it is
    registered on ``app``.
    """
    welcome = {"message": "Welcome to the Recommendation API!"}
    return JSONResponse(content=welcome)