File size: 832 Bytes
65226cd
37527e9
 
87e455f
65226cd
 
87e455f
65226cd
37527e9
65226cd
37527e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87e455f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from functools import lru_cache

from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer


# Create the FastAPI application instance; `app` is the ASGI entry point
# (e.g. served with `uvicorn module:app`).
app = FastAPI()

# Single checkpoint served by this app, hoisted to a constant so the
# loader and any future endpoints agree on which model is used.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"


@lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the model/tokenizer pair once per process and cache them.

    The previous implementation re-instantiated the 7B model and tokenizer
    inside the request handler on EVERY call, which is extremely slow and
    memory-churning. With ``lru_cache(maxsize=1)`` only the first request
    pays the loading cost; all later requests reuse the same objects.

    Returns:
        (model, tokenizer, device) — model already moved to ``device``.
    """
    device = "cpu"  # CPU-only serving; change here to target another device
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model.to(device)
    return model, tokenizer, device


# Define a route for the /llm endpoint (fixed-prompt text generation)
@app.get("/llm")
async def read_root():
    """Generate a sampled completion for a fixed instruction prompt.

    Returns:
        dict: ``{"message": <decoded generation>}``.

    NOTE(review): ``do_sample=True`` makes the output non-deterministic
    across calls by design.
    """
    model, tokenizer, device = _load_model_and_tokenizer()
    # Mistral-instruct chat-template markers; add_special_tokens=False
    # because the BOS token <s> is already spelled out in the prompt.
    text = """<s>[INST] What is your favourite condiment? [/INST]
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    model_inputs = encoded.to(device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)
    return {"message": decoded[0]}