PochiBot / app.py
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI()
# Load model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "nikravan/glm-4vq"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # this checkpoint ships custom model code, like the tokenizer
)
class Query(BaseModel):
    question: str
@app.post("/predict")
def predict(data: Query):
    # Tokenize the incoming question and move it to the model's device
    inputs = tokenizer(data.question, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_length=200)
    # Decode the generated tokens, dropping special tokens from the reply
    return {"answer": tokenizer.decode(outputs[0], skip_special_tokens=True)}
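A minimal client sketch for exercising the /predict route. Assumptions not specified in this repo: the app is served locally with "uvicorn app:app --port 8000", and the requests package is installed; the file name example_client.py is hypothetical.

# example_client.py (hypothetical helper, separate from app.py)
import requests

# Call the /predict endpoint defined above; host and port are assumptions.
resp = requests.post(
    "http://localhost:8000/predict",
    json={"question": "Hello, who are you?"},  # matches the Query schema
    timeout=120,  # generation can be slow, especially on CPU
)
resp.raise_for_status()
print(resp.json()["answer"])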