Create chatapp.py
chatapp.py  ADDED  +104 -0
@@ -0,0 +1,104 @@
import streamlit as st
import os
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Set up the Streamlit app layout
st.title("HuggingFace Model Chat App")
st.write("Enter a HuggingFace model name, set parameters, and chat with the model!")

# Sidebar for model selection and parameters
with st.sidebar:
    st.header("Model and Parameters")
    model_name = st.text_input("HuggingFace Model Name (e.g., meta-llama/Llama-2-7b-chat-hf)", value="meta-llama/Llama-2-7b-chat-hf")
    system_prompt = st.text_area("System Prompt", value="You are a helpful AI assistant.", height=100)
    temperature = st.slider("Temperature (Randomness)", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    load_model_button = st.button("Load Model")

# Initialize session state for model, tokenizer, and chat history
if 'model' not in st.session_state:
    st.session_state.model = None
    st.session_state.tokenizer = None
    st.session_state.chat_history = []
    st.session_state.model_loaded = False

# Function to save query and response to a markdown file
def save_to_md(query, response):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_{timestamp}.md"
    os.makedirs("chat_history", exist_ok=True)
    with open(os.path.join("chat_history", filename), "w", encoding="utf-8") as f:
        f.write(f"# Chat Log\n\n**Query:** {query}\n\n**Response:** {response}\n")
    return filename

# Load model and tokenizer when the button is clicked
if load_model_button:
    with st.spinner("Loading model... This may take a while."):
        try:
            # `token` replaces the deprecated `use_auth_token` argument
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=os.getenv("HUGGINGFACE_TOKEN")
            )
            st.session_state.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                # use float16 only when a GPU is available; fall back to float32 on CPU
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
                token=os.getenv("HUGGINGFACE_TOKEN")
            )
            st.session_state.model_loaded = True
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            st.session_state.model_loaded = False

# Chat interface
if st.session_state.model_loaded:
    st.header("Chat with the Model")
    user_input = st.text_area("Your Message", height=100)
    send_button = st.button("Send")

    # Display chat history
    st.subheader("Chat History")
    for chat in st.session_state.chat_history:
        st.markdown(f"**You:** {chat['query']}")
        st.markdown(f"**Model:** {chat['response']}")
        st.markdown(f"**Saved as:** {chat['filename']}")
        st.markdown("---")

    # Process user input and generate a response
    if send_button and user_input:
        try:
            # Prepare input with the system prompt
            full_input = f"{system_prompt}\n\nUser: {user_input}\nAssistant: "
            # Send the input tensors to the same device the model was placed on
            inputs = st.session_state.tokenizer(full_input, return_tensors="pt").to(st.session_state.model.device)

            # Generate response
            with st.spinner("Generating response..."):
                outputs = st.session_state.model.generate(
                    **inputs,
                    # max_new_tokens caps only the generated text, so the prompt
                    # length does not eat into the "Max Response Length" budget
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=st.session_state.tokenizer.eos_token_id
                )
            response = st.session_state.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract only the assistant's response
            response = response.split("Assistant: ")[-1].strip()

            # Save to markdown
            filename = save_to_md(user_input, response)

            # Update chat history
            st.session_state.chat_history.append({
                "query": user_input,
                "response": response,
                "filename": filename
            })

            # Rerun to update the display
            st.rerun()
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
else:
    st.info("Please load a model to start chatting.")
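For reference, a minimal sketch of the supporting Space files this app would likely need in order to run (package names are inferred from the imports above, not taken from the commit; `accelerate` is required by transformers when `device_map="auto"` is used, and `HUGGINGFACE_TOKEN` would need to be set as a Space secret to access gated models such as Llama-2):

requirements.txt
    streamlit
    torch
    transformers
    accelerate

To try it locally: streamlit run chatapp.py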