from fastapi import FastAPI, Request, Response, status
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import time

app = FastAPI()

# Create cache directory
cache_dir = "./model_cache"
os.makedirs(cache_dir, exist_ok=True)

# Track app status
app_status = {
    "status": "initializing",
    "model_name": "distilgpt2",
    "model_loaded": False,
    "tokenizer_loaded": False,
    "startup_time": time.time(),
    "errors": []
}

# Load model and tokenizer once at startup
model_name = "distilgpt2"
try:
    # Try to load tokenizer
    app_status["status"] = "loading_tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    app_status["tokenizer_loaded"] = True

    # Try to load model with from_tf=True (loads the TensorFlow checkpoint,
    # which requires TensorFlow to be installed)
    app_status["status"] = "loading_model"
    model = AutoModelForCausalLM.from_pretrained(model_name, from_tf=True, cache_dir=cache_dir)
    app_status["model_loaded"] = True
    app_status["status"] = "ready"
except Exception as e:
    error_msg = f"Error loading model or tokenizer: {str(e)}"
    app_status["status"] = "limited_functionality"
    app_status["errors"].append(error_msg)
    print(error_msg)

class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50

@app.post("/generate")
async def generate_text(req: PromptRequest, response: Response):
    if app_status["status"] != "ready":
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {"error": "Model not ready", "status": app_status["status"], "details": app_status["errors"]}
    try:
        inputs = tokenizer(req.prompt, return_tensors="pt")
        # Sample up to max_new_tokens with temperature/top-p sampling
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
        )
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated}
    except Exception as e:
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(e)}

@app.get("/")
async def root():
    return {"message": "API is responding", "status": app_status["status"]}
@app.get("/status")
async def get_status():
# Calculate uptime
uptime = time.time() - app_status["startup_time"]
return {
"status": app_status["status"],
"model_name": app_status["model_name"],
"model_loaded": app_status["model_loaded"],
"tokenizer_loaded": app_status["tokenizer_loaded"],
"uptime_seconds": uptime,
"errors": app_status["errors"]
} |
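
# Example usage (a minimal sketch, assuming this file is saved as main.py and
# uvicorn is installed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl http://localhost:8000/status
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_new_tokens": 30}'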