import streamlit as st
from utils import (
    load_model,
    load_finetuned_model,
    generate_response,
    get_hf_token
)
import os
import json
from datetime import datetime

st.set_page_config(page_title="Gemma Chat", layout="wide")

# -------------------------------
# 💡 Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("🌙 Dark Mode", value=False)
if dark_mode:
    # Target .stApp: Streamlit injects this style inside the app's DOM,
    # where a bare `body` selector does not reliably restyle the page.
    st.markdown(
        """
        <style>
        .stApp { background-color: #1e1e1e; color: #ffffff; }
        .stTextInput, .stTextArea, .stSelectbox, .stSlider { color: #ffffff !important; }
        </style>
        """,
        unsafe_allow_html=True
    )

st.title("💬 Chat with Gemma Model")

# -------------------------------
# 🔍 Model Source Selection
# -------------------------------
model_source = st.sidebar.radio("🔍 Select Model Source", ["Local (.pt)", "Hugging Face"])

# -------------------------------
# 📥 Dynamic Model List
# -------------------------------
if model_source == "Local (.pt)":
    model_dir = "models"
    os.makedirs(model_dir, exist_ok=True)
    local_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    if local_models:
        selected_model = st.sidebar.selectbox("🛠️ Select Local Model", local_models)
        model_path = os.path.join(model_dir, selected_model)
    else:
        st.warning("⚠️ No fine-tuned models found. Fine-tune a model first.")
        st.stop()
else:
    hf_models = [
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.sidebar.selectbox("🛠️ Select Hugging Face Model", hf_models)
    model_path = None
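
# Note: the "-pt" checkpoints are pretrained (base) variants, while "-it"
# models are instruction-tuned; the "-it" ones are better suited to chat.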

# -------------------------------
# 📥 Model Loading
# -------------------------------
hf_token = get_hf_token()
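
# Suggestion (assumes utils does not already cache): wrapping load_model with
# st.cache_resource would avoid re-downloading and re-initializing the model
# on every Streamlit rerun.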
if model_source == "Local (.pt)":
    tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)  # Base model first
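    # The selected .pt checkpoint is assumed to hold fine-tuned weights that
    # load_finetuned_model applies on top of the base model loaded above.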
    model = load_finetuned_model(model, model_path)
    if model:
        st.success(f"✅ Local fine-tuned model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load local model.")
        st.stop()
else:
    tokenizer, model = load_model(selected_model, hf_token)
    if model:
        st.success(f"✅ Hugging Face model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load Hugging Face model.")
        st.stop()

# -------------------------------
# ⚙️ Model Configuration Panel
# -------------------------------
st.sidebar.header("⚙️ Model Configuration")
temperature = st.sidebar.slider("🔥 Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("🎯 Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("🔁 Repetition Penalty", 0.5, 2.0, 1.0, 0.1)
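
# NOTE: generate_response() below is called with only (prompt, model,
# tokenizer, max_length), so these three sliders are not currently wired
# into generation. If utils.generate_response accepts sampling parameters
# (an assumption about that module), pass them through, e.g.:
#   generate_response(prompt, model, tokenizer, max_length,
#                     temperature=temperature, top_p=top_p,
#                     repetition_penalty=repetition_penalty)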

# -------------------------------
# 💬 Chat Interface
# -------------------------------
if "conversation" not in st.session_state:
    st.session_state.conversation = []

prompt = st.text_area("💬 Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("📏 Max Response Length", min_value=50, max_value=1000, value=300, step=50)
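# max_length is forwarded to generate_response as-is; whether it caps total
# tokens or only newly generated tokens depends on the utils implementation.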

# -------------------------------
# 🔄 Response Generation
# -------------------------------
def stream_response():
    """
    Generates a response for the current prompt and appends both the user
    message and the model reply to the conversation history. Despite the
    name, the response is generated in full, not streamed token by token.
    """
    response = generate_response(prompt, model, tokenizer, max_length)
    if response:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state.conversation.append({"sender": "👤 You", "message": prompt, "timestamp": timestamp})
        st.session_state.conversation.append({"sender": "🤖 AI", "message": response, "timestamp": timestamp})
        return response
    else:
        st.error("❌ Failed to generate response.")
        return None

# -------------------------------
# 🎯 Conversation Controls
# -------------------------------
col1, col2, col3 = st.columns([1, 1, 1])
if col1.button("🚀 Generate (CTRL+Enter)", help="Use CTRL + Enter to generate"):
    stream_response()
if col2.button("🗑️ Clear Conversation"):
    st.session_state.conversation = []

# Export & Import
if col3.download_button("📥 Export Chat", json.dumps(st.session_state.conversation, indent=4), "chat_history.json"):
    st.success("✅ Chat exported successfully!")
uploaded_file = st.file_uploader("📤 Import Conversation", type=["json"])
if uploaded_file is not None:
    st.session_state.conversation = json.load(uploaded_file)
    st.success("✅ Conversation imported successfully!")

# -------------------------------
# 🛠️ Display Conversation
# -------------------------------
st.subheader("📜 Conversation History")
for msg in st.session_state.conversation:
    with st.container():
        # Two trailing spaces before \n force a Markdown line break
        st.markdown(f"**{msg['sender']}**  \n🕒 {msg['timestamp']}")
        st.write(msg['message'])