"""Streamlit chat UI for a Hugging Face T5 text2text-generation model.

The sidebar loads/clears a model by name and clears the conversation; the
main pane shows the conversation and, once a model is loaded, a query box.
"""
import gc

import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch

MODEL_CACHE_DIR = "./model_cache"

# Streamlit app setup (must be the first Streamlit call in the script)
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Session state defaults. The dict literal is re-evaluated on every script
# run, so each session gets its own fresh conversation list.
_SESSION_DEFAULTS = {
    "model": None,
    "tokenizer": None,
    "qa_pipeline": None,
    "conversation": [],
    "user_input": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Load Model
if load_model_button:
    requested_name = model_name.strip()
    if not requested_name:
        st.error("Error loading model: please enter a model name.")
    else:
        with st.spinner("Loading model..."):
            try:
                device = 0 if torch.cuda.is_available() else -1
                # Build everything into locals first so a failure part-way
                # through cannot leave a mismatched model/tokenizer/pipeline.
                model = T5ForConditionalGeneration.from_pretrained(
                    requested_name, cache_dir=MODEL_CACHE_DIR
                )
                tokenizer = T5Tokenizer.from_pretrained(
                    requested_name, cache_dir=MODEL_CACHE_DIR
                )
                qa_pipeline = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                )
            except Exception as e:
                st.error(f"Error loading model: {e}")
            else:
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = qa_pipeline
                st.success("Model loaded successfully and ready!")

# Clear Model
if clear_model_button:
    for _key in ("model", "tokenizer", "qa_pipeline"):
        st.session_state[_key] = None
    # Dropping the references alone does not hand cached CUDA memory back.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.success("Model cleared.")

# Clear Conversation. This is handled before the conversation is rendered so
# that no rerun is needed. It also runs before the "chat_input" widget is
# created in this run, which is the only point where its state may be reset.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.session_state["user_input"] = ""
    st.session_state["chat_input"] = ""  # Clear input field
    st.success("Conversation cleared.")

# Chat Conversation Display
st.subheader("Conversation")
for speaker, message in st.session_state["conversation"]:
    label = "You" if speaker == "You" else "Model"
    st.markdown(f"**{label}:** {message}")


def _send_message():
    """Generate a reply for the current query and append both turns.

    Used as the Send button's ``on_click`` callback. Callbacks run before
    the script body re-executes, so the keyed text input can be cleared
    here, and the new turns appear in the same run without a rerun. On
    failure the query is kept, and the error is stashed for display.
    """
    user_input = st.session_state.get("chat_input", "")
    qa_pipeline = st.session_state.get("qa_pipeline")
    if not user_input.strip() or qa_pipeline is None:
        return
    with st.spinner("Generating response..."):
        try:
            response = qa_pipeline(f"Q: {user_input}", max_length=400)
            generated_text = response[0]["generated_text"]
        except Exception as e:
            st.session_state["chat_error"] = f"Error generating response: {e}"
            return
    st.session_state["conversation"].append(("You", user_input))
    st.session_state["conversation"].append(("Model", generated_text))
    st.session_state["chat_input"] = ""  # Clear input after submission
    st.session_state["user_input"] = ""


# Input Area
if st.session_state["qa_pipeline"] is not None:
    st.text_input("Enter your query:", key="chat_input")
    st.button("Send", key="send_button", on_click=_send_message)
    if st.session_state.get("chat_error"):
        st.error(st.session_state["chat_error"])
        st.session_state["chat_error"] = None