"""Streamlit chat UI for a custom-trained T5 text2text model.

The sidebar lets the user choose a Hugging Face model name, load it, clear it,
and clear the conversation history. Once a model is loaded, each query is sent
to a ``text2text-generation`` pipeline as ``"Q: <query>"`` and the exchange is
appended to the read-only conversation shown below the input box.
"""

import gc

import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline

MODEL_CACHE_DIR = "./model_cache"
DEFAULT_MODEL_NAME = "karthikeyan-r/slm-custom-model_6k"
MAX_GENERATION_LENGTH = 300

# Session-state keys that hold the loaded model objects.
_MODEL_KEYS = ("model", "tokenizer", "qa_pipeline")


def _init_session_state() -> None:
    """Create the session-state entries this app reads, if they are missing."""
    for key in _MODEL_KEYS:
        if key not in st.session_state:
            st.session_state[key] = None
    if "conversation" not in st.session_state:
        # List of (speaker, message) tuples; speaker is "You" or "Model".
        st.session_state["conversation"] = []


def _load_model(name: str) -> None:
    """Load model *name* and its tokenizer, and build a generation pipeline.

    Failures are reported in the UI rather than raised. Session state is only
    updated after every step succeeded, so a failed load never leaves a new
    model paired with an old tokenizer/pipeline.
    """
    name = name.strip()
    if not name:
        st.error("Please enter a model name before loading.")
        return
    with st.spinner("Loading model..."):
        try:
            # pipeline() device index: 0 = first CUDA device, -1 = CPU.
            device = 0 if torch.cuda.is_available() else -1
            model = T5ForConditionalGeneration.from_pretrained(
                name, cache_dir=MODEL_CACHE_DIR
            )
            tokenizer = T5Tokenizer.from_pretrained(name, cache_dir=MODEL_CACHE_DIR)
            qa_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                device=device,
            )
        except Exception as e:  # UI boundary: show any load failure to the user
            st.error(f"Error loading model: {e}")
            return
    st.session_state["model"] = model
    st.session_state["tokenizer"] = tokenizer
    st.session_state["qa_pipeline"] = qa_pipeline
    st.success("Model loaded successfully and ready!")


def _clear_model() -> None:
    """Drop the loaded model objects and release the memory they held."""
    for key in _MODEL_KEYS:
        st.session_state[key] = None
    # Dropping references alone may leave the weights alive until a later GC
    # pass, and PyTorch keeps freed CUDA blocks cached; release both now.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.success("Model cleared.")


def _clear_conversation() -> None:
    """Reset the conversation history."""
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")


def _render_chat_input(show_debug: bool) -> None:
    """Show the query box and, on Send, generate and record a response.

    Args:
        show_debug: When True, echo the raw query and response for debugging.
    """
    user_input = st.text_input("Enter your query:", key="chat_input")
    if not st.button("Send"):
        return
    query = user_input.strip()
    if not query:
        return
    if show_debug:
        st.write(f"Debug: Query - {query}")
    with st.spinner("Generating response..."):
        try:
            response = st.session_state["qa_pipeline"](
                f"Q: {query}", max_length=MAX_GENERATION_LENGTH
            )
            generated_text = response[0]["generated_text"]
        except Exception as e:  # UI boundary: show any generation failure
            st.error(f"Error generating response: {e}")
            return
    if show_debug:
        st.write(f"Debug: Response - {generated_text}")
    # Only record the exchange once a response was produced.
    conversation = st.session_state["conversation"]
    conversation.append(("You", query))
    conversation.append(("Model", generated_text))


def _render_conversation() -> None:
    """Render the conversation history as read-only text areas."""
    for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
        if speaker == "You":
            st.text_area(f"You ({idx}):", message, key=f"you_{idx}", disabled=True)
        else:
            st.text_area(
                f"Model ({idx}):", message, key=f"model_{idx}", disabled=True
            )


# Streamlit app setup -- kept as the first Streamlit call of the script.
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value=DEFAULT_MODEL_NAME)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
show_debug = st.sidebar.checkbox("Show debug output", value=False)

# Main UI
st.title("Custom Trained Model Chat Conversation UI")

_init_session_state()

if load_model_button:
    _load_model(model_name)
if clear_model_button:
    _clear_model()
# Handle clearing before rendering so the cleared history is not shown
# again on the same run.
if clear_conversation_button:
    _clear_conversation()

if st.session_state["qa_pipeline"] is not None:
    _render_chat_input(show_debug)

_render_conversation()