File size: 798 Bytes
9bf2007
 
dcd2d54
3e6fc0f
9bf2007
5ed2b9f
e48a0c0
e5e2748
21e7dd1
 
 
 
 
3e6fc0f
9bf2007
5ed2b9f
21e7dd1
 
 
 
 
 
 
9bf2007
21e7dd1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import transformers
import torch

from fastapi import FastAPI

from transformers import AutoModelForCausalLM, AutoTokenizer

# ASGI application instance; serve with uvicorn/hypercorn etc.
app = FastAPI()

# Model and tokenizer are populated once by the startup hook (init_model)
# and then shared, read-only, across all requests.
MODEL = None
TOKENIZER = None


@app.get("/")
def llama():
    text = "Hi, my name is "
    inputs = TOKENIZER(text, return_tensors="pt")
    outputs = MODEL.generate(**inputs, max_new_tokens=64)
    tresponse = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    print(tresponse)

    return tresponse


@app.on_event("startup")
def init_model():
    global MODEL
    if not MODEL:
        print("loading model")
        TOKENIZER = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-v1.0")
        MODEL = AutoModelForCausalLM.from_pretrained("Upstage/SOLAR-10.7B-v1.0", device_map="auto", torch_dtype=torch.float16,)
        print("loaded model")