from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import AutoModelForCausalLM, AutoTokenizer import torch # Khởi tạo FastAPI app = FastAPI() # Tải model và tokenizer khi ứng dụng khởi động model_name = "Qwen/Qwen2.5-0.5B" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") # Định nghĩa request body class TextInput(BaseModel): prompt: str max_length: int = 100 # API endpoint để sinh văn bản @app.post("/generate") async def generate_text(input: TextInput): try: # Mã hóa đầu vào inputs = tokenizer(input.prompt, return_tensors="pt").to(model.device) # Sinh văn bản outputs = model.generate( inputs["input_ids"], max_length=input.max_length, num_return_sequences=1, no_repeat_ngram_size=2, do_sample=True, top_k=50, top_p=0.95 ) # Giải mã kết quả generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return {"generated_text": generated_text} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # Endpoint kiểm tra sức khỏe @app.get("/") async def root(): return {"message": "Qwen2.5-0.5B API is running!"}