textgeneration/question_paper.py
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Loaded once at startup (see init_model below) and reused across requests.
MODEL = None
TOKENIZER = None
@app.get("/")
def llama():
    text = "Hi, my name is "
    inputs = TOKENIZER(text, return_tensors="pt").input_ids
    # The Starling tokenizer may not define a pad token, so fall back to the
    # EOS token rather than passing pad_token_id=None to generate().
    pad_id = TOKENIZER.pad_token_id if TOKENIZER.pad_token_id is not None else TOKENIZER.eos_token_id
    # inference_mode disables gradient tracking, which saves memory and time.
    with torch.inference_mode():
        outputs = MODEL.generate(
            inputs,
            max_length=256,
            pad_token_id=pad_id,
            eos_token_id=TOKENIZER.eos_token_id,
        )
    tresponse = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
    print(tresponse)
    return tresponse
@app.on_event("startup")
def init_model():
    # Load the 7B model once at process startup so individual requests
    # don't pay the (slow) load cost.
    global MODEL
    global TOKENIZER
    if MODEL is None:
        print("loading model")
        TOKENIZER = AutoTokenizer.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        MODEL = AutoModelForCausalLM.from_pretrained("berkeley-nest/Starling-LM-7B-alpha")
        print("loaded model")