# API / app.py
# Author: FlameF0X — commit c2f4964 ("Update app.py")
from fastapi import FastAPI, Request, Response, status
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import time
# FastAPI application instance; endpoints below attach to this object.
app = FastAPI()

# Local cache directory for downloaded model/tokenizer files.
os.makedirs("./model_cache", exist_ok=True)

# Module-level health record shared by every endpoint: tracks the loading
# pipeline state, what has loaded so far, when the process started, and any
# startup errors so they can be surfaced via /status.
app_status = dict(
    status="initializing",
    model_name="distilgpt2",
    model_loaded=False,
    tokenizer_loaded=False,
    startup_time=time.time(),
    errors=[],
)
# Load model and tokenizer once at import time so every request reuses them.
model_name = "distilgpt2"  # change this to your own model
try:
    # Tokenizer first: it is small and fails fast on a bad model id.
    app_status["status"] = "loading_tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
    app_status["tokenizer_loaded"] = True
    # Then the model weights (the slow part).
    app_status["status"] = "loading_model"
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
    # FIX: put the model in inference mode — transformers loads models in
    # training mode, so without this dropout stays active and generations
    # are noisier/slower than they should be.
    model.eval()
    app_status["model_loaded"] = True
    app_status["status"] = "ready"
except Exception as e:
    # Record the failure so /status and /generate can report it instead of
    # letting the whole app crash at import.
    error_msg = f"Error loading model or tokenizer: {str(e)}"
    app_status["status"] = "error"
    app_status["errors"].append(error_msg)
    print(error_msg)
class PromptRequest(BaseModel):
    """Request body for POST /generate."""
    prompt: str  # text for the model to continue
    max_new_tokens: int = 50  # cap on tokens generated per request
@app.post("/generate")
async def generate_text(req: PromptRequest, response: Response):
    """Generate a text completion for ``req.prompt``.

    Returns 503 while the model is not ready (still loading or failed to
    load), 500 if generation itself raises, and otherwise a JSON body with
    the decoded ``generated_text``.
    """
    if app_status["status"] != "ready":
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {"error": "Model not ready", "status": app_status["status"]}
    try:
        inputs = tokenizer(req.prompt, return_tensors="pt")
        # FIX: inference only — disable autograd bookkeeping; without this,
        # generate() builds gradient state, wasting memory and time.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
            )
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated}
    except Exception as e:
        # Surface the failure to the caller rather than crashing the worker.
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(e)}
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API is up and reports loading state."""
    payload = {"message": "API is running", "status": app_status["status"]}
    return payload
@app.get("/status")
async def get_status():
    """Detailed health report: loading state, component flags, uptime, errors."""
    report = {
        key: app_status[key]
        for key in ("status", "model_name", "model_loaded", "tokenizer_loaded")
    }
    # Uptime is computed on demand from the recorded process start time.
    report["uptime_seconds"] = time.time() - app_status["startup_time"]
    report["errors"] = app_status["errors"]
    return report