"""FastAPI service that serves text generation from a Hugging Face causal LM.

The tokenizer and model are loaded once at import time. Load progress and any
load failure are recorded in ``app_status`` so the HTTP endpoints can report
them instead of crashing.
"""

import logging
import os
import threading
import time

import torch
from fastapi import FastAPI, Request, Response, status
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

app = FastAPI()

# Create cache directory
cache_dir = "./model_cache"
os.makedirs(cache_dir, exist_ok=True)

# Track app status
app_status = {
    "status": "initializing",
    "model_name": "distilgpt2",
    "model_loaded": False,
    "tokenizer_loaded": False,
    "startup_time": time.time(),
    "errors": [],
}

# Load model and tokenizer once at startup
model_name = "distilgpt2"

# Bind both names up front so they exist (as None) even when loading fails.
tokenizer = None
model = None

# The handler used to run generation directly on the event loop, so two
# generations never overlapped. Generation now runs on a worker thread; this
# lock keeps it one-at-a-time.
_generation_lock = threading.Lock()

try:
    # Try to load tokenizer
    app_status["status"] = "loading_tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    app_status["tokenizer_loaded"] = True

    # Try to load model - with from_tf=True
    # NOTE(review): from_tf=True presumably requires TensorFlow to be
    # installed in addition to PyTorch; confirm it is still needed.
    app_status["status"] = "loading_model"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, from_tf=True, cache_dir=cache_dir
    )
    app_status["model_loaded"] = True

    app_status["status"] = "ready"
except Exception as e:
    # Top-level boundary: keep the app importable and serving / and /status,
    # and expose the failure through app_status.
    error_msg = f"Error loading model or tokenizer: {str(e)}"
    app_status["status"] = "limited_functionality"
    app_status["errors"].append(error_msg)
    print(error_msg)


class PromptRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text the model continues.
    prompt: str
    # Passed to ``model.generate`` as ``max_new_tokens``.
    max_new_tokens: int = 50


def _generate(prompt: str, max_new_tokens: int) -> str:
    """Tokenize *prompt*, sample a continuation, and return the decoded text.

    Blocking; intended to be called from a worker thread.
    """
    with _generation_lock:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            # Passed explicitly so generate() does not have to fall back to
            # the EOS token (and warn) on every call.
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/generate")
async def generate_text(req: PromptRequest, response: Response):
    """Generate a continuation of ``req.prompt``.

    Returns ``{"generated_text": ...}`` on success. Responds with 503 and the
    recorded load errors when the model is not ready, and with 500 and the
    exception text when generation itself fails.
    """
    if app_status["status"] != "ready":
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {
            "error": "Model not ready",
            "status": app_status["status"],
            "details": app_status["errors"],
        }

    try:
        # Run the blocking generation off the event loop so other requests
        # (e.g. / and /status) are still served while it is in progress.
        generated = await run_in_threadpool(
            _generate, req.prompt, req.max_new_tokens
        )
    except Exception as e:
        logger.exception("Text generation failed")
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(e)}
    return {"generated_text": generated}


@app.get("/")
async def root():
    """Liveness probe: confirm the API responds and report the load status."""
    return {"message": "API is responding", "status": app_status["status"]}
@app.get("/status")
async def get_status():
    # Seconds elapsed since the recorded startup timestamp.
    elapsed = time.time() - app_status["startup_time"]

    # Copy the tracked fields through unchanged, in the order clients see them.
    report = {
        field: app_status[field]
        for field in ("status", "model_name", "model_loaded", "tokenizer_loaded")
    }
    report["uptime_seconds"] = elapsed
    report["errors"] = app_status["errors"]
    return report