File size: 908 Bytes
9bf2007
 
dcd2d54
3e6fc0f
9bf2007
5ed2b9f
e48a0c0
e5e2748
21e7dd1
 
 
 
3e6fc0f
9bf2007
5ed2b9f
18dd69a
 
 
 
 
 
 
21e7dd1
 
 
 
 
 
 
 
5bc3efc
21e7dd1
 
18dd69a
 
21e7dd1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import transformers
import torch

from fastapi import FastAPI

from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Lazily-populated globals: both are filled in by init_model() on app
# startup and read by the "/" route. None means "not loaded yet".
MODEL = None  # AutoModelForCausalLM once loaded
TOKENIZER = None  # AutoTokenizer once loaded

@app.get("/")
def llama():
    text = "Hi, my name is "
    inputs = TOKENIZER(text, return_tensors="pt").input_ids
    outputs = MODEL.generate(
        inputs,
        max_length=256,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    tresponse = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    print(tresponse)

    return tresponse

@app.on_event("startup")
def init_model():
    global MODEL
    global TOKENIZER
    if not MODEL:
        print("loading model")
        TOKENIZER = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        MODEL = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        print("loaded model")