import streamlit as st
import os
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Set up the Streamlit app layout
st.title("HuggingFace Model Chat App")
st.write("Enter a HuggingFace model name, set parameters, and chat with the model!")

# Sidebar for model selection and parameters
with st.sidebar:
    st.header("Model and Parameters")
    model_name = st.text_input("HuggingFace Model Name (e.g., meta-llama/Llama-2-7b-chat-hf)", value="meta-llama/Llama-2-7b-chat-hf")
    system_prompt = st.text_area("System Prompt", value="You are a helpful AI assistant.", height=100)
    temperature = st.slider("Temperature (Randomness)", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    load_model_button = st.button("Load Model")

# Initialize session state for model, tokenizer, and chat history
if 'model' not in st.session_state:
    st.session_state.model = None
    st.session_state.tokenizer = None
    st.session_state.chat_history = []
    st.session_state.model_loaded = False
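
# Note: Streamlit reruns this script from top to bottom on every interaction,
# so the model, tokenizer, and chat history are kept in st.session_state to survive reruns.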

# Function to save query and response to a markdown file
def save_to_md(query, response):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_{timestamp}.md"
    os.makedirs("chat_history", exist_ok=True)
    with open(os.path.join("chat_history", filename), "w", encoding="utf-8") as f:
        f.write(f"# Chat Log\n\n**Query:** {query}\n\n**Response:** {response}\n")
    return filename
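
# Note: on a Hugging Face Space the container filesystem is ephemeral, so files written to
# chat_history/ are lost on restart unless persistent storage is enabled for the Space.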

# Load model and tokenizer when the button is clicked
if load_model_button:
    with st.spinner("Loading model... This may take a while."):
        try:
            # Gated models (e.g., meta-llama/Llama-2-7b-chat-hf) require a token, read here from the HUGGINGFACE_TOKEN env var
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HUGGINGFACE_TOKEN"))
            st.session_state.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # half precision; assumes a GPU is available
                device_map="auto",
                token=os.getenv("HUGGINGFACE_TOKEN")
            )
            st.session_state.model_loaded = True
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            st.session_state.model_loaded = False

# Chat interface
if st.session_state.model_loaded:
    st.header("Chat with the Model")
    user_input = st.text_area("Your Message", height=100)
    send_button = st.button("Send")

    # Display chat history
    st.subheader("Chat History")
    for chat in st.session_state.chat_history:
        st.markdown(f"**You:** {chat['query']}")
        st.markdown(f"**Model:** {chat['response']}")
        st.markdown(f"**Saved as:** {chat['filename']}")
        st.markdown("---")

    # Process user input and generate a response
    if send_button and user_input:
        try:
            # Prepare input with system prompt
            full_input = f"{system_prompt}\n\nUser: {user_input}\nAssistant: "
            inputs = st.session_state.tokenizer(full_input, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

            # Generate response
            with st.spinner("Generating response..."):
                outputs = st.session_state.model.generate(
                    **inputs,
                    max_new_tokens=max_length,  # cap generated tokens only, so the prompt does not eat into the response budget
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=st.session_state.tokenizer.eos_token_id
                )
            response = st.session_state.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response
            response = response.split("Assistant: ")[-1].strip()

            # Save to markdown
            filename = save_to_md(user_input, response)

            # Update chat history
            st.session_state.chat_history.append({
                "query": user_input,
                "response": response,
                "filename": filename
            })

            # Rerun to update the display
            st.rerun()
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
else:
    st.info("Please load a model to start chatting.")
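
# To run locally (assuming this script is saved as app.py):
#   pip install streamlit transformers torch accelerate
#   export HUGGINGFACE_TOKEN=hf_...   # only needed for gated models
#   streamlit run app.py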