# zhtwbloomdemo / app.py
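# A minimal Streamlit demo of Traditional-Chinese text generation with
# the ckip-joint/bloom-1b1-zh causal language model.
# Run locally with: streamlit run app.py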
# ------------------- LIBRARIES -------------------- #
import logging
import os

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --------------------- HELPER --------------------- #
def C(text, color="yellow"):
    """Wrap `text` in ANSI escape codes so it prints colored in logs."""
    color_dict: dict = dict(
        red="\033[01;31m",
        green="\033[01;32m",
        yellow="\033[01;33m",
        blue="\033[01;34m",
        magenta="\033[01;35m",
        cyan="\033[01;36m",
    )
    color_dict[None] = "\033[0m"  # ANSI reset code
    # Unknown color names fall back to the reset code rather than
    # interpolating the string "None" into the output.
    return (
        f"{color_dict.get(color, color_dict[None])}"
        f"{text}{color_dict[None]}")
# ------------------ ENVIRONMENT ------------------- #
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(C(f"[INFO] device = {device}"))
# ------------------ INITIALIZE -------------------- #
@st.cache(
    suppress_st_warning=True,  # st.balloons() runs inside the cached fn
    allow_output_mutation=True,  # skip hashing the returned model/tokenizer
)
def model_init():
    # (Import kept for the commented-out config example below.)
    from transformers import GenerationConfig
    # generation_config, unused_kwargs = GenerationConfig.from_pretrained(
    #     "ckip-joint/bloom-1b1-zh",
    #     max_new_tokens=200,
    #     return_unused_kwargs=True)
    tokenizer = AutoTokenizer.from_pretrained(
        "ckip-joint/bloom-1b1-zh")
    model = AutoModelForCausalLM.from_pretrained(
        "ckip-joint/bloom-1b1-zh",
        # Ref.: Eric, Thanks!
        # torch_dtype="auto",
        # device_map="auto",
        # Ref. for `half`: Chan-Jan, Thanks!
    ).eval().to(device)
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model

tokenizer, model = model_init()
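# Note: `st.cache` is deprecated in later Streamlit releases; assuming
# Streamlit >= 1.18, an equivalent sketch would swap the decorator for:
#
#     @st.cache_resource
#     def model_init(): ...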
try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")
    # =================== INFERENCE ==================== #
    if prompt:
        st.balloons()
        with torch.no_grad():
            # Batch size is 1, so unpack the single generated sequence.
            [texts_out] = model.generate(
                **tokenizer(
                    prompt, return_tensors="pt",
                ).to(device),
                max_new_tokens=200,
            )
        # Note: special tokens are kept in the output; pass
        # skip_special_tokens=True to tokenizer.decode() to drop them.
        output_text = tokenizer.decode(texts_out)
        st.balloons()
        st.markdown(output_text)
except Exception as err:
    st.write(str(err))
    st.snow()
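# A possible way to wire up the commented-out GenerationConfig above
# (a sketch, not the shipped path; `inputs` stands for the tokenized
# prompt dict built in the inference block):
#
#     from transformers import GenerationConfig
#     generation_config = GenerationConfig.from_pretrained(
#         "ckip-joint/bloom-1b1-zh", max_new_tokens=200)
#     [texts_out] = model.generate(
#         **inputs, generation_config=generation_config)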