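"""ChatGPT-124M: a small Streamlit chat front-end for a GPT-124M causal language model.

The sidebar exposes the usual decoding controls (max length, top-k, top-p,
temperature), and each response is replayed character by character to mimic
streaming. The checkpoint named by MODEL_NAME is loaded via Hugging Face
transformers.
"""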
# Importing necessary libraries
import time
import random
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set page title
st.set_page_config(
    page_title="ChatGPT-124M",
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🤖",
)
# Title
st.title("🤖 ChatGPT-124M")
# --- Initialize Session State with Defaults ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "max_length" not in st.session_state:
st.session_state.max_length = 50
if "do_sample" not in st.session_state:
st.session_state.do_sample = True
if "top_k" not in st.session_state:
st.session_state.top_k = 5
if "top_p" not in st.session_state:
st.session_state.top_p = 0.95
if "temperature" not in st.session_state:
st.session_state.temperature = 0.9
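# The sidebar widgets below bind to these session-state keys, so they start
# from these defaults and their values persist across Streamlit reruns.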
# Display previous chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Load the model and tokenizer once; st.cache_resource keeps them across reruns
# instead of reloading the weights on every user interaction
MODEL_NAME = "GPT_124M"

@st.cache_resource
def load_model(name):
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
    return tokenizer, model

tokenizer, model = load_model(MODEL_NAME)
# Function to convert a string into a generator (fake stream output)
def string_to_generator(text):
    """Yields text one character at a time for a streaming effect."""
    for char in text:
        time.sleep(0.005)
        yield char
# --- UI Controls for Generation Parameters ---
st.sidebar.header("⚙️ Generation Settings")
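# Quick reference for the sampling knobs below:
#   top-k: sample only from the k highest-probability tokens
#   top-p: sample from the smallest token set whose cumulative probability exceeds p
#   temperature: values < 1.0 sharpen the distribution, values > 1.0 flatten it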
# Slider for max length (1 to 100)
max_length = st.sidebar.slider(
    "Max Length", min_value=1, max_value=100, key="max_length"
)
# Toggle for `do_sample`; the sampling parameters below are active only when it is on
do_sample = st.sidebar.toggle("Enable Sampling", key="do_sample")
# Slider for top k (1 to 100)
top_k = st.sidebar.slider(
    "Top-K", min_value=1, max_value=100, disabled=not do_sample, key="top_k"
)
# Slider for top p (0 to 1)
top_p = st.sidebar.slider(
    "Top-P",
    min_value=0.0,
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="top_p",
)
# Slider for temperature (0.01 to 1; zero is invalid because logits are divided by it)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.01,
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="temperature",
)
# Reset generation settings and chat history to their defaults
if st.sidebar.button("Reset"):
    for st_key in [
        "messages",
        "do_sample",
        "max_length",
        "top_k",
        "top_p",
        "temperature",
    ]:
        del st.session_state[st_key]
    st.rerun()
# List of dynamic loading messages
loading_messages = [
    "Generating your response, please wait...",
    "Working on your response...",
    "Processing, this will just take a moment...",
    "Creating your response, hold on...",
    "Loading your answer, please be patient...",
]
# --- Chat Input ---
if prompt := st.chat_input(
    "The Earth revolves around", max_chars=400, key="chat_input"
):
    # Save user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    # Generate response
    with st.chat_message("assistant"):
        tokens = tokenizer.encode(prompt, return_tensors="pt")
        with st.spinner(random.choice(loading_messages)):
            generated_tokens = model.generate(
                tokens,
                max_length=max_length,
                do_sample=do_sample,
                top_k=top_k if do_sample else 1,
                top_p=top_p if do_sample else 1.0,
                temperature=temperature if do_sample else 1.0,
                # GPT-2-style tokenizers define no pad token; falling back to EOS
                # silences the "pad_token_id not set" warning (assumes EOS exists)
                pad_token_id=tokenizer.eos_token_id,
            )
        # `generate` returns a batch of sequences; decode the first (and only) one
        response_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # st.write_stream consumes the generator and returns the full string
        response = st.write_stream(string_to_generator(response_text))
    # Save bot response
    st.session_state.messages.append({"role": "assistant", "content": response})
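# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py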