import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

# Streamlit app setup
st.set_page_config(page_title="Chat", layout="wide")

# Sidebar: model controls
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/slm-custom-model_6k",
    "2": "karthikeyan-r/calculation_model",
}
model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Session-state defaults
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
if "user_input" not in st.session_state:
    st.session_state["user_input"] = ""

# Load the selected model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            if model_choice == model_options["1"]:
                # T5 encoder-decoder model for general QA (slm-custom-model_6k)
                device = 0 if torch.cuda.is_available() else -1
                st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(
                    model_choice, cache_dir="./model_cache"
                )
                st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(
                    model_choice, cache_dir="./model_cache"
                )
                st.session_state["qa_pipeline"] = pipeline(
                    "text2text-generation",
                    model=st.session_state["model"],
                    tokenizer=st.session_state["tokenizer"],
                    device=device,
                )
            elif model_choice == model_options["2"]:
                # Causal LM for arithmetic queries (calculation_model)
                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")

                # Add special tokens if missing and resize embeddings to match
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                    model.resize_token_embeddings(len(tokenizer))
                if tokenizer.eos_token is None:
                    tokenizer.add_special_tokens({"eos_token": "[EOS]"})
                    model.resize_token_embeddings(len(tokenizer))

                # Keep the model config in sync with the tokenizer
                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id

                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = None  # this model uses generate() directly

            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")

# Clear the loaded model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")


def display_conversation():
    """Render the chat history."""
    st.subheader("Conversation")
    for speaker, message in st.session_state["conversation"]:
        if speaker == "You":
            st.markdown(f"**You:** {message}")
        else:
            st.markdown(f"**Model:** {message}")


display_conversation()

# Input area for the T5 QA model
if st.session_state["qa_pipeline"]:
    user_input = st.text_input(
        "Enter your query:",
        value=st.session_state["user_input"],
        key="chat_input",
    )
    if st.button("Send", key="send_button"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
                    response = st.session_state["qa_pipeline"](
st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=400) generated_text = response[0]["generated_text"] # Update the conversation st.session_state["conversation"].append(("You", user_input)) st.session_state["conversation"].append(("Model", generated_text)) # Clear the input field after submission st.session_state["user_input"] = "" # Rerender the conversation immediately display_conversation() except Exception as e: st.error(f"Error generating response: {e}") else: # Handle user input for the calculation model (calculation_model) if st.session_state["model"] and model_choice == model_options["2"]: user_input = st.text_input( "Enter your query for calculation:", value=st.session_state["user_input"], key="calculation_input", ) if st.button("Send Calculation", key="send_calculation_button"): if user_input: with st.spinner("Generating response..."): try: # Generate the model response for the calculation model inputs = st.session_state["tokenizer"](f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True) input_ids = inputs.input_ids attention_mask = inputs.attention_mask output = st.session_state["model"].generate( input_ids=input_ids, attention_mask=attention_mask, max_length=50, pad_token_id=st.session_state["tokenizer"].pad_token_id, eos_token_id=st.session_state["tokenizer"].eos_token_id, do_sample=False ) decoded_output = st.session_state["tokenizer"].decode(output[0], skip_special_tokens=True) if "Output:" in decoded_output: answer = decoded_output.split("Output:")[-1].strip() else: answer = decoded_output.strip() # Update the conversation st.session_state["conversation"].append(("You", user_input)) st.session_state["conversation"].append(("Model", answer)) # Clear the input field after submission st.session_state["user_input"] = "" # Rerender the conversation immediately display_conversation() except Exception as e: st.error(f"Error generating response: {e}") # Clear Conversation if clear_conversation_button: st.session_state["conversation"] = [] st.session_state["user_input"] = "" # Clear input field st.success("Conversation cleared.")