API / app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()
# Load model and tokenizer once at startup
model_name = "distilbert/distilgpt2" # change this to your own model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
@app.post("/generate")
async def generate_text(req: PromptRequest):
inputs = tokenizer(req.prompt, return_tensors="pt")
outputs = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=True,
temperature=0.8,
top_p=0.95,
)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"generated_text": generated}