Spaces:
Running
Running
File size: 3,489 Bytes
62098f3 409d83c 62098f3 c4c458c 62098f3 409d83c 62098f3 d371388 62098f3 c4c458c 62098f3 c4c458c 11d5f76 732cf88 62098f3 c4c458c d371388 732cf88 c4c458c 62098f3 e559684 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gc

import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
# --- Page setup -------------------------------------------------------------
# Kept as the very first Streamlit call (older Streamlit releases require
# set_page_config to precede every other st.* command).
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# --- Sidebar: model controls ------------------------------------------------
# Everything created inside this context is placed in the sidebar. Each button
# is True only on the rerun triggered by its own click.
with st.sidebar:
    st.title("Model Controls")
    model_name = st.text_input(
        "Enter Model Name",
        value="karthikeyan-r/slm-custom-model_6k",
    )
    load_model_button = st.button("Load Model")
    clear_conversation_button = st.button("Clear Conversation")
    clear_model_button = st.button("Clear Model")

# --- Main area --------------------------------------------------------------
st.title("Chat Conversation UI")
# Session state survives reruns within one browser session, so each key is
# seeded with its default only if it has not been set yet.
for _state_key, _default in (
    ("model", None),
    ("tokenizer", None),
    ("qa_pipeline", None),
    ("conversation", []),  # list of (speaker, message) tuples
    ("user_input", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
# Load Model.
# The model, tokenizer and pipeline are built into plain variables first and
# copied into session state only after all three succeed. If any step fails
# (for example, a bad repo id or a tokenizer download error), the previously
# loaded trio stays intact instead of a new model being paired with a stale
# tokenizer or pipeline.
if load_model_button:
    repo_id = model_name.strip()  # tolerate stray whitespace typed in the box
    if not repo_id:
        st.error("Error loading model: please enter a model name.")
    else:
        with st.spinner("Loading model..."):
            try:
                # transformers' pipeline() device index: 0 = first CUDA GPU, -1 = CPU.
                device = 0 if torch.cuda.is_available() else -1
                loaded_model = T5ForConditionalGeneration.from_pretrained(
                    repo_id, cache_dir="./model_cache"
                )
                loaded_tokenizer = T5Tokenizer.from_pretrained(
                    repo_id, cache_dir="./model_cache"
                )
                loaded_pipeline = pipeline(
                    "text2text-generation",
                    model=loaded_model,
                    tokenizer=loaded_tokenizer,
                    device=device,
                )
            except Exception as e:  # UI boundary: surface any failure to the user
                st.error(f"Error loading model: {e}")
            else:
                # Publish all three together so they always belong to one model.
                st.session_state["model"] = loaded_model
                st.session_state["tokenizer"] = loaded_tokenizer
                st.session_state["qa_pipeline"] = loaded_pipeline
                st.success("Model loaded successfully and ready!")
# Clear Model.
# Drop every session reference to the model objects, then collect garbage and
# return PyTorch's cached-but-unused CUDA memory to the driver. Without this,
# the caching allocator keeps the GPU memory reserved even after the model is
# unreachable.
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    gc.collect()  # free anything kept alive only by reference cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.success("Model cleared.")
# Layout for chat.
# Both containers are created up front to fix their on-page order (conversation
# above input), but the conversation is rendered LAST so it reflects whatever
# this run changed. Previously it was drawn before the Send/Clear handlers ran,
# so a new exchange or a clear only showed up one rerun later.
chat_container = st.container()
input_container = st.container()

# Clear Conversation (sidebar button).
# Handled before rendering so the empty history is visible immediately. The
# message is emitted at top level, so it appears below both containers.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# Input Area: only offered once a pipeline has been loaded.
with input_container:
    if st.session_state["qa_pipeline"] is not None:
        # clear_on_submit empties the text box after each send. The previous
        # approach (resetting st.session_state["user_input"]) never touched the
        # widget, whose value lives under its own "chat_input" key, so the last
        # query stayed in the box.
        with st.form("chat_form", clear_on_submit=True):
            user_input = st.text_input("Enter your query:", key="chat_input")
            send_button_clicked = st.form_submit_button("Send")
        if send_button_clicked and user_input.strip():
            with st.spinner("Generating response..."):
                try:
                    response = st.session_state["qa_pipeline"](
                        f"Q: {user_input}", max_length=400
                    )
                    generated_text = response[0]["generated_text"]
                except Exception as e:  # UI boundary: show the error, keep history
                    st.error(f"Error generating response: {e}")
                else:
                    # Record both turns only when generation succeeded.
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))

# Chat Conversation Display, rendered into the container reserved above.
with chat_container:
    st.subheader("Conversation")
    for speaker, message in st.session_state["conversation"]:
        st.markdown(f"**{speaker}:** {message}")
|