from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

app = FastAPI()

# Create cache directory
os.makedirs("./model_cache", exist_ok=True)

# Load model and tokenizer once at startup
model_name = "distilgpt2"  # change this to your own model
try:
    # Try to load from the local cache first
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=True)
except OSError as e:
    print(f"Error loading model from cache: {e}")
    print("Attempting to download model directly...")
    # If the cache is empty or incomplete, download the weights explicitly
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache")

class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
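
# Illustrative request payload for /generate (not part of the original file);
# max_new_tokens is optional and defaults to 50:
#   {"prompt": "Once upon a time", "max_new_tokens": 50}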

@app.post("/generate")
async def generate_text(req: PromptRequest):
    inputs = tokenizer(req.prompt, return_tensors="pt")
    # Disable gradient tracking for inference
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
        )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated}

@app.get("/")
async def root():
    return {"status": "API is running", "model": model_name}