Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
|
3 |
import torch
|
4 |
|
5 |
-
#
|
6 |
st.set_page_config(page_title="Hugging Face Chat", layout="wide")
|
7 |
|
8 |
# Sidebar: Model controls
|
@@ -13,10 +13,9 @@ clear_conversation_button = st.sidebar.button("Clear Conversation")
|
|
13 |
clear_model_button = st.sidebar.button("Clear Model")
|
14 |
|
15 |
# Main UI
|
16 |
-
st.title("Chat Conversation UI")
|
17 |
-
st.write("Start a conversation with your Hugging Face model.")
|
18 |
|
19 |
-
#
|
20 |
if "model" not in st.session_state:
|
21 |
st.session_state["model"] = None
|
22 |
if "tokenizer" not in st.session_state:
|
@@ -30,21 +29,15 @@ if "conversation" not in st.session_state:
|
|
30 |
if load_model_button:
|
31 |
with st.spinner("Loading model..."):
|
32 |
try:
|
33 |
-
# Set up device
|
34 |
device = 0 if torch.cuda.is_available() else -1
|
35 |
-
|
36 |
-
# Load model and tokenizer
|
37 |
st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
|
38 |
st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
|
39 |
-
|
40 |
-
# Initialize pipeline
|
41 |
st.session_state["qa_pipeline"] = pipeline(
|
42 |
"text2text-generation",
|
43 |
model=st.session_state["model"],
|
44 |
tokenizer=st.session_state["tokenizer"],
|
45 |
device=device
|
46 |
)
|
47 |
-
|
48 |
st.success("Model loaded successfully and ready!")
|
49 |
except Exception as e:
|
50 |
st.error(f"Error loading model: {e}")
|
@@ -61,10 +54,12 @@ if st.session_state["qa_pipeline"]:
|
|
61 |
user_input = st.text_input("Enter your query:", key="chat_input")
|
62 |
if st.button("Send"):
|
63 |
if user_input:
|
|
|
64 |
with st.spinner("Generating response..."):
|
65 |
try:
|
66 |
response = st.session_state["qa_pipeline"](user_input, max_length=300)
|
67 |
generated_text = response[0]["generated_text"]
|
|
|
68 |
st.session_state["conversation"].append(("You", user_input))
|
69 |
st.session_state["conversation"].append(("Model", generated_text))
|
70 |
except Exception as e:
|
@@ -73,10 +68,9 @@ if st.session_state["qa_pipeline"]:
|
|
73 |
# Display conversation
|
74 |
for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
|
75 |
if speaker == "You":
|
76 |
-
st.text_area(f"You:", message, key=f"you_{idx}", disabled=True)
|
77 |
else:
|
78 |
-
st.text_area(f"Model:", message, key=f"model_{idx}", disabled=True)
|
79 |
-
|
80 |
|
81 |
# Clear Conversation
|
82 |
if clear_conversation_button:
|
|
|
2 |
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
|
3 |
import torch
|
4 |
|
5 |
+
# Streamlit app setup
|
6 |
st.set_page_config(page_title="Hugging Face Chat", layout="wide")
|
7 |
|
8 |
# Sidebar: Model controls
|
|
|
13 |
clear_model_button = st.sidebar.button("Clear Model")
|
14 |
|
15 |
# Main UI
|
16 |
+
st.title("Custom Trained Model Chat Conversation UI")
|
|
|
17 |
|
18 |
+
# Session states
|
19 |
if "model" not in st.session_state:
|
20 |
st.session_state["model"] = None
|
21 |
if "tokenizer" not in st.session_state:
|
|
|
29 |
if load_model_button:
|
30 |
with st.spinner("Loading model..."):
|
31 |
try:
|
|
|
32 |
device = 0 if torch.cuda.is_available() else -1
|
|
|
|
|
33 |
st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
|
34 |
st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
|
|
|
|
|
35 |
st.session_state["qa_pipeline"] = pipeline(
|
36 |
"text2text-generation",
|
37 |
model=st.session_state["model"],
|
38 |
tokenizer=st.session_state["tokenizer"],
|
39 |
device=device
|
40 |
)
|
|
|
41 |
st.success("Model loaded successfully and ready!")
|
42 |
except Exception as e:
|
43 |
st.error(f"Error loading model: {e}")
|
|
|
54 |
user_input = st.text_input("Enter your query:", key="chat_input")
|
55 |
if st.button("Send"):
|
56 |
if user_input:
|
57 |
+
st.write(f"Debug: Query - {user_input}") # Debugging
|
58 |
with st.spinner("Generating response..."):
|
59 |
try:
|
60 |
response = st.session_state["qa_pipeline"](user_input, max_length=300)
|
61 |
generated_text = response[0]["generated_text"]
|
62 |
+
st.write(f"Debug: Response - {generated_text}") # Debugging
|
63 |
st.session_state["conversation"].append(("You", user_input))
|
64 |
st.session_state["conversation"].append(("Model", generated_text))
|
65 |
except Exception as e:
|
|
|
68 |
# Display conversation
|
69 |
for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
|
70 |
if speaker == "You":
|
71 |
+
st.text_area(f"You ({idx}):", message, key=f"you_{idx}", disabled=True)
|
72 |
else:
|
73 |
+
st.text_area(f"Model ({idx}):", message, key=f"model_{idx}", disabled=True)
|
|
|
74 |
|
75 |
# Clear Conversation
|
76 |
if clear_conversation_button:
|