Spaces:
Runtime error
Runtime error
File size: 908 Bytes
9bf2007 dcd2d54 3e6fc0f 9bf2007 5ed2b9f e48a0c0 e5e2748 21e7dd1 3e6fc0f 9bf2007 5ed2b9f 18dd69a 21e7dd1 5bc3efc 21e7dd1 18dd69a 21e7dd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import transformers
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI()
# Module-level singletons: populated once by init_model() at startup and
# then shared (read-only) by request handlers across all requests.
MODEL = None
TOKENIZER = None
@app.get("/")
def llama():
    """Generate a text completion for a fixed prompt with the loaded model.

    Returns:
        The decoded generation (prompt + completion) as a plain string,
        with special tokens stripped.

    NOTE(review): assumes init_model() has already populated the MODEL and
    TOKENIZER globals — FastAPI runs startup handlers before serving.
    """
    text = "Hi, my name is "
    # Encode the prompt to a tensor of input ids for generation.
    inputs = TOKENIZER(text, return_tensors="pt").input_ids
    # max_length=256 caps the TOTAL sequence length (prompt included).
    outputs = MODEL.generate(
        inputs,
        max_length=256,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    # Decode the full generated sequence back to text.
    tresponse = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    print(tresponse)
    return tresponse
@app.on_event("startup")
def init_model():
    """Load the Starling-LM-7B model and tokenizer once at app startup.

    Populates the module-level MODEL and TOKENIZER globals so request
    handlers reuse a single loaded instance instead of reloading per call.
    The guard makes the hook idempotent if startup fires more than once.
    """
    global MODEL
    global TOKENIZER
    if MODEL is None:
        print("loading model")
        TOKENIZER = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        MODEL = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        print("loaded model")