import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch
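# Note: T5Tokenizer is the SentencePiece-based "slow" tokenizer, so the
# sentencepiece package must be installed alongside transformers and torch.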

# Initialize Streamlit app
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")
st.write("Start a conversation with your Hugging Face model.")

# Initialize session states
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
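# (Streamlit reruns this script on every interaction; st.session_state keeps
# the loaded model and the chat history alive across those reruns.)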

# Load Model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            # Set up device
            device = 0 if torch.cuda.is_available() else -1
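            # (transformers pipeline convention: device=0 is the first CUDA
            # GPU, device=-1 runs on the CPU)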
            # Load model and tokenizer
            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
            # Initialize pipeline
            st.session_state["qa_pipeline"] = pipeline(
                "text2text-generation",
                model=st.session_state["model"],
                tokenizer=st.session_state["tokenizer"],
                device=device,
            )
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")

# Clear Model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")
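    # (This only drops the in-memory references; downloaded weights remain
    # on disk in ./model_cache.)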

# Chat Input and Output
if st.session_state["qa_pipeline"]:
    user_input = st.text_input("Enter your query:", key="chat_input")
    if st.button("Send"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
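                    # The text2text-generation pipeline returns a list of
                    # dicts, each carrying a "generated_text" field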
                    response = st.session_state["qa_pipeline"](user_input, max_length=300)
                    generated_text = response[0]["generated_text"]
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))
                except Exception as e:
                    st.error(f"Error generating response: {e}")

# Display conversation
# Key each widget by its position in the history: keying by message text (as
# before) raises a duplicate-key error when the same message appears twice.
for i, (speaker, message) in enumerate(st.session_state["conversation"]):
    if speaker == "You":
        st.text_area("You:", message, key=f"msg_{i}_you", disabled=True)
    else:
        st.text_area("Model:", message, key=f"msg_{i}_model", disabled=True)

# Clear Conversation
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")
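
# To try this app locally (assuming the script is saved as app.py):
#   streamlit run app.py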