# NOTE(review): the original paste carried "Spaces: Runtime error" headers —
# presumably this app crashed on a Hugging Face Space (likely the unguarded
# globals / ungated model load below). Kept as a comment; verify against the
# Space's logs.
# Third-party dependencies: torch/transformers run the model, FastAPI serves it.
import torch
import transformers
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer

# ASGI application instance (routes may be registered elsewhere in the project).
app = FastAPI()

# Lazily-initialized singletons; populated by init_model() on first use so that
# importing this module does not trigger a multi-GB model download.
MODEL = None
TOKENIZER = None
def llama(text: str = "Hi, my name is ") -> str:
    """Generate a text completion for *text* with the module-level model.

    Args:
        text: Prompt to complete. Defaults to the original hard-coded prompt,
            so existing no-argument callers are unaffected.

    Returns:
        The decoded generation (prompt included), special tokens stripped.

    Raises:
        RuntimeError: if init_model() has not been called yet.
    """
    # Guard: the globals start as None; failing loudly beats a cryptic
    # "'NoneType' object is not callable" from TOKENIZER(...).
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("Model not initialized; call init_model() first.")
    input_ids = TOKENIZER(text, return_tensors="pt").input_ids
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output_ids = MODEL.generate(
            input_ids,
            max_length=256,
            pad_token_id=TOKENIZER.pad_token_id,
            eos_token_id=TOKENIZER.eos_token_id,
        )
    response = TOKENIZER.decode(output_ids[0], skip_special_tokens=True)
    print(response)
    return response
def init_model() -> None:
    """Load the Starling-LM-7B tokenizer and model into the module globals.

    Idempotent: once MODEL is set, subsequent calls are no-ops. The first call
    downloads the weights from the Hugging Face hub (network + disk I/O), so
    this should be invoked once at startup, not per request.
    """
    global MODEL, TOKENIZER
    # Explicit identity check: we care about "not yet loaded", not the
    # truthiness of a model object.
    if MODEL is None:
        print("loading model")
        TOKENIZER = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        MODEL = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        print("loaded model")