"""Streamlit chat UI for a small GPT-style (124M-parameter) causal language model."""

import time
import random

import numpy as np
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

st.set_page_config(
    page_title="ChatGPT-124M",
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🤖",
)

st.title("🤖 ChatGPT-124M")
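
# Seed default generation settings in session state so widget values persist
# across Streamlit reruns.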
if "messages" not in st.session_state:
|
|
st.session_state.messages = []
|
|
|
|
if "max_length" not in st.session_state:
|
|
st.session_state.max_length = 50
|
|
|
|
if "do_sample" not in st.session_state:
|
|
st.session_state.do_sample = True
|
|
|
|
if "top_k" not in st.session_state:
|
|
st.session_state.top_k = 5
|
|
|
|
if "top_p" not in st.session_state:
|
|
st.session_state.top_p = 0.95
|
|
|
|
if "temperature" not in st.session_state:
|
|
st.session_state.temperature = 0.9
|
|
|
|
|
|
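
# Replay the conversation history on every rerun.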
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
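
# Load the tokenizer and model; trust_remote_code lets the custom GPT_124M
# model code shipped with the checkpoint be executed.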
MODEL_NAME = "GPT_124M"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)


def string_to_generator(text):
    """Yields text one character at a time for a streaming effect."""
    for char in text:
        time.sleep(0.005)
        yield char


st.sidebar.header("⚙️ Generation Settings")
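
# Sidebar widgets are bound to session-state keys, so their values survive
# reruns and can be cleared by the Reset button below.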
max_length = st.sidebar.slider(
    "Max Length", min_value=1, max_value=100, key="max_length"
)

do_sample = st.sidebar.toggle("Enable Sampling", key="do_sample")

top_k = st.sidebar.slider(
    "Top-K", min_value=1, max_value=100, disabled=not do_sample, key="top_k"
)

top_p = st.sidebar.slider(
    "Top-P",
    min_value=0.0,
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="top_p",
)

temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.01,  # sampling with a temperature of exactly 0 makes generate() raise
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="temperature",
)
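
# "Reset" clears the chat history and restores the default generation settings.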
if st.sidebar.button("Reset"):
|
|
for st_key in [
|
|
"messages",
|
|
"do_sample",
|
|
"max_length",
|
|
"top_k",
|
|
"top_p",
|
|
"temperature",
|
|
]:
|
|
del st.session_state[st_key]
|
|
st.rerun()
|
|
|
|
|
|
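
# Placeholder messages shown in the spinner while the model generates.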
loading_messages = [
    "Generating your response, please wait...",
    "Working on your response...",
    "Processing, this will just take a moment...",
    "Creating your response, hold on...",
    "Loading your answer, please be patient...",
]
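
# Main chat flow: read the prompt, echo it, then generate and stream the reply.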
if prompt := st.chat_input(
    "The Earth revolves around", max_chars=400, key="chat_input"
):
    st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
|
|
st.markdown(prompt)
|
|
|
|
|
|
with st.chat_message("assistant"):
|
|
tokens = tokenizer.encode(prompt, return_tensors="pt")
|
|
|
|
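
        # Generate with the sidebar settings. max_length counts the prompt tokens
        # as well as the new ones; when sampling is disabled, the sampling knobs
        # are neutralised and decoding is effectively greedy.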
        with st.spinner(random.choice(loading_messages)):
            generated_tokens = model.generate(
                tokens,
                max_length=max_length,
                do_sample=do_sample,
                top_k=top_k if do_sample else 1,
                top_p=top_p if do_sample else 1.0,
                temperature=temperature if do_sample else 1.0,
            )

        # generate() returns a batch of sequences; decode the first (and only) one.
        response_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        response = st.write_stream(string_to_generator(response_text))

    st.session_state.messages.append({"role": "assistant", "content": response})