from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
# Create an instance of the FastAPI class
app = FastAPI()

# Load the model and tokenizer once at startup; reloading the 7B-parameter
# checkpoint on every request would dominate the response time.
device = "cpu"
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
model.to(device)

# Define a route for the /llm endpoint
@app.get("/llm")
async def generate_text():
    # Prompt in the Mistral instruction format; the special tokens are spelled
    # out in the string itself, so add_special_tokens=False avoids adding a
    # duplicate <s> during tokenization.
    text = """<s>[INST] What is your favourite condiment? [/INST]
"""
    encodeds = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    model_inputs = encodeds.to(device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)
    print(decoded[0])
    return {"message": decoded[0]}
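
# A minimal local entry point, sketched here as an assumption: uvicorn is the
# usual way to serve a FastAPI app, and the host/port below are illustrative
# choices, not taken from the original file. Once running, the endpoint can be
# queried with, e.g.:  curl http://localhost:8000/llm
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)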