calm3b / app.py
OzoneAsai's picture
Update app.py
acdbcde
raw
history blame
1.81 kB
print("start to run")
import streamlit as st
import os

# NOTE(review): installing dependencies at runtime is fragile and slow on
# every cold start — prefer a requirements.txt in the Space. Kept as-is for
# backward compatibility with the current deployment.
os.system("pip install torch transformers sentencepiece accelerate")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer once at import time.
# Pick CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BUGFIX: `from_pretrained` has no `device` keyword (unknown kwargs are
# forwarded to the model config and silently ignored), so the original code
# never actually moved the model to the GPU. Move it explicitly with .to().
model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-1b",
    torch_dtype=torch.float16,  # halves memory; note fp16 on CPU is slow
).to(device)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
# Inference helper
def generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty):
    """Generate a sampled continuation of ``input_text`` with the loaded model.

    Args:
        input_text: Prompt string to continue.
        max_new_tokens: Maximum number of tokens to generate beyond the prompt.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling cumulative-probability cutoff.
        repetition_penalty: Penalty (>1.0 discourages) for repeating tokens.

    Returns:
        The decoded text of the full sequence (prompt + continuation),
        with special tokens stripped.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Robustness: many causal-LM tokenizers define no pad token, in which
    # case generate() would warn and fall back to EOS anyway — make that
    # fallback explicit instead of relying on the implicit behavior.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    with torch.no_grad():  # inference only — no gradients needed
        tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=pad_id,
        )
    return tokenizer.decode(tokens[0], skip_special_tokens=True)
# Streamlit app configuration
st.title("Causal Language Modeling")
st.write("AIによる文章生成")

# Generation parameter inputs
input_text = st.text_area("入力テキスト")
max_new_tokens = st.slider("生成する最大トークン数", min_value=1, max_value=512, value=64)
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7)
top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
repetition_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=2.0, value=1.05)

# Run inference and display the result
if st.button("生成"):
    # Guard: an empty prompt produces empty input tensors and an opaque
    # error from generate(); ask the user for text instead of crashing.
    if not input_text.strip():
        st.warning("入力テキストを入力してください")
    else:
        output = generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty)
        st.write("生成されたテキスト:")
        st.write(output)