# conversationbot / app.py
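"""Streamlit chat UI for two Hugging Face models.

A T5 text2text model (karthikeyan-r/slm-custom-model_6k) answers general
queries through a text2text-generation pipeline, and a causal LM
(karthikeyan-r/calculation_model) answers calculation queries via a direct
tokenize/generate/decode loop.

Run with: streamlit run app.py
Assumed dependencies (not pinned here): streamlit, torch, transformers,
and sentencepiece (needed by the slow T5Tokenizer).
"""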
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
# Streamlit app setup
st.set_page_config(page_title="Chat", layout="wide")
# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/slm-custom-model_6k",
    "2": "karthikeyan-r/calculation_model",
}
model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
# Main UI
st.title("Chat Conversation UI")
# Session states
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
if "user_input" not in st.session_state:
    st.session_state["user_input"] = ""
# Load Model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            if model_choice == model_options["1"]:
                # T5 model for general QA (slm-custom-model_6k)
                device = 0 if torch.cuda.is_available() else -1
                st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                st.session_state["qa_pipeline"] = pipeline(
                    "text2text-generation",
                    model=st.session_state["model"],
                    tokenizer=st.session_state["tokenizer"],
                    device=device,
                )
            elif model_choice == model_options["2"]:
                # Causal LM for calculations (calculation_model)
                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
                # Ensure pad/eos tokens exist; resize embeddings when a token is added
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                    model.resize_token_embeddings(len(tokenizer))
                if tokenizer.eos_token is None:
                    tokenizer.add_special_tokens({"eos_token": "[EOS]"})
                    model.resize_token_embeddings(len(tokenizer))
                # Keep the model config in sync with the tokenizer
                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = None  # this model is driven via generate(), not a pipeline
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")
# Clear Model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")
# Chat Conversation Display
def display_conversation():
    """Render the chat history stored in session state."""
    st.subheader("Conversation")
    for speaker, message in st.session_state["conversation"]:
        if speaker == "You":
            st.markdown(f"**You:** {message}")
        else:
            st.markdown(f"**Model:** {message}")

display_conversation()
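# Two input paths follow: if the T5 pipeline is loaded, queries go through
# qa_pipeline; otherwise the calculation model is driven directly with
# tokenizer() / model.generate() / decode().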
# Input Area
if st.session_state["qa_pipeline"]:
    user_input = st.text_input(
        "Enter your query:",
        value=st.session_state["user_input"],  # session state keeps the draft across reruns
        key="chat_input",
    )
    if st.button("Send", key="send_button"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
                    # Generate the response for general QA with the T5 pipeline
                    response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=400)
                    generated_text = response[0]["generated_text"]
                    # Record the exchange and reset the stored draft
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))
                    st.session_state["user_input"] = ""
                except Exception as e:
                    st.error(f"Error generating response: {e}")
                else:
                    # Rerun so the conversation above re-renders once with the new
                    # messages (Streamlit >= 1.27; older versions: st.experimental_rerun())
                    st.rerun()
else:
    # Input handling for the calculation model (calculation_model)
    if st.session_state["model"] and model_choice == model_options["2"]:
        user_input = st.text_input(
            "Enter your query for calculation:",
            value=st.session_state["user_input"],
            key="calculation_input",
        )
        if st.button("Send Calculation", key="send_calculation_button"):
            if user_input:
                with st.spinner("Generating response..."):
                    try:
                        # Build the prompt and tokenize it
                        inputs = st.session_state["tokenizer"](
                            f"Input: {user_input}\nOutput:",
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                        )
                        # Greedy decoding (do_sample=False) keeps answers deterministic
                        output = st.session_state["model"].generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=50,
                            pad_token_id=st.session_state["tokenizer"].pad_token_id,
                            eos_token_id=st.session_state["tokenizer"].eos_token_id,
                            do_sample=False,
                        )
                        decoded_output = st.session_state["tokenizer"].decode(output[0], skip_special_tokens=True)
                        # Keep only the text after the "Output:" marker, if present
                        if "Output:" in decoded_output:
                            answer = decoded_output.split("Output:")[-1].strip()
                        else:
                            answer = decoded_output.strip()
                        # Record the exchange and reset the stored draft
                        st.session_state["conversation"].append(("You", user_input))
                        st.session_state["conversation"].append(("Model", answer))
                        st.session_state["user_input"] = ""
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
                    else:
                        # Rerun so the conversation re-renders once with the new messages
                        st.rerun()
# Clear Conversation
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.session_state["user_input"] = ""  # clear the input draft as well
    st.success("Conversation cleared.")