from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()

# Initialize the model (we'll use a small model for this example)
model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

class GenerateRequest(BaseModel):
    prompt: str

@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        input_ids = tokenizer.encode(request.prompt, return_tensors="pt")
        # Run inference without tracking gradients. GPT-Neo has no pad token,
        # so we reuse the EOS token id to silence the generation warning.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=100,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    return {"message": "Model server is running"}
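
# Usage sketch (hypothetical, not part of the server above): assuming the
# server code is saved as main.py, start it with `uvicorn main:app` and then
# query the /generate endpoint. This client assumes the `requests` library is
# installed and the server is listening on the default http://localhost:8000.
import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Once upon a time"},
)
response.raise_for_status()  # surface any 500 returned by the endpoint
print(response.json()["generated_text"])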