from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

app = FastAPI()

# Create cache directory
os.makedirs("./model_cache", exist_ok=True)

# Load model and tokenizer once at startup
model_name = "distilgpt2"  # change this to your own model

try:
    # Try to load from the local cache first (no network access)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=True)
except OSError as e:
    print(f"Error loading model from cache: {e}")
    print("Attempting to download model...")
    # Fall back to downloading the model into the cache directory
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache")

model.eval()  # inference only


class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50


@app.post("/generate")
async def generate_text(req: PromptRequest):
    inputs = tokenizer(req.prompt, return_tensors="pt")
    with torch.no_grad():  # no gradients needed for inference
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2-style models define no pad token
        )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated}


@app.get("/")
async def root():
    return {"status": "API is running", "model": model_name}
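
# Usage sketch (assumption: this script is saved as main.py and uvicorn is installed):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request once the server is running:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_new_tokens": 40}'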