# conversationbot/app.py
import streamlit as st
from transformers import (
T5ForConditionalGeneration,
T5Tokenizer,
pipeline,
AutoTokenizer,
AutoModelForCausalLM
)
import torch
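# Note: T5Tokenizer needs the `sentencepiece` package when loading a slow
# tokenizer; it should be listed in requirements.txt (assumption: not shown here).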
# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")
# ----- Sidebar: Model controls -----
st.sidebar.title("Model Controls")
model_options = {
"1": "karthikeyan-r/calculation_model_11k",
"2": "karthikeyan-r/slm-custom-model_6k"
}
model_choice = st.sidebar.selectbox(
"Select Model",
options=list(model_options.values())
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
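# st.sidebar.button() returns True only on the rerun triggered by its click,
# so each handler below fires exactly once per press.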
# ----- Session States -----
if "model" not in st.session_state:
st.session_state["model"] = None
if "tokenizer" not in st.session_state:
st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
# We'll store conversation as a list of dicts,
# e.g. [{"role": "assistant", "content": "Hello..."}, {"role": "user", "content": "..."}]
st.session_state["conversation"] = []
# ----- Load Model -----
if load_model_button:
with st.spinner("Loading model..."):
try:
if model_choice == model_options["1"]:
# Load the calculation model
tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
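                # cache_dir points the Hub download cache at a local folder, so
                # repeated loads reuse the already-downloaded weights.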
                # Add any missing special tokens, then resize embeddings once to match
                special_tokens = {}
                if tokenizer.pad_token is None:
                    special_tokens["pad_token"] = "[PAD]"
                if tokenizer.eos_token is None:
                    special_tokens["eos_token"] = "[EOS]"
                if special_tokens:
                    tokenizer.add_special_tokens(special_tokens)
                    model.resize_token_embeddings(len(tokenizer))
                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id
st.session_state["model"] = model
st.session_state["tokenizer"] = tokenizer
st.session_state["qa_pipeline"] = None # Not needed for calculation model
elif model_choice == model_options["2"]:
# Load the T5 model for general QA
device = 0 if torch.cuda.is_available() else -1
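                # pipeline() treats device=0 as the first CUDA GPU and device=-1 as CPU.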
model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
qa_pipe = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
device=device
)
st.session_state["model"] = model
st.session_state["tokenizer"] = tokenizer
st.session_state["qa_pipeline"] = qa_pipe
# If conversation is empty, insert a welcome message
if len(st.session_state["conversation"]) == 0:
st.session_state["conversation"].append({
"role": "assistant",
"content": "Hello! I’m your assistant. How can I help you today?"
})
st.success("Model loaded successfully and ready!")
except Exception as e:
st.error(f"Error loading model: {e}")
# ----- Clear Model -----
if clear_model_button:
st.session_state["model"] = None
st.session_state["tokenizer"] = None
st.session_state["qa_pipeline"] = None
st.success("Model cleared.")
# ----- Clear Conversation -----
if clear_conversation_button:
st.session_state["conversation"] = []
st.success("Conversation cleared.")
# ----- Title -----
st.title("Chat Conversation UI")
user_input = None
if st.session_state["qa_pipeline"]:
# T5 pipeline
user_input = st.chat_input("Enter your query:")
if user_input:
# 1) Save user message
st.session_state["conversation"].append({
"role": "user",
"content": user_input
})
# 2) Generate assistant response
try:
response = st.session_state["qa_pipeline"](
f"Q: {user_input}", max_length=250
)
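            # text2text-generation returns a list of dicts keyed by "generated_text".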
answer = response[0]["generated_text"]
except Exception as e:
answer = f"Error: {str(e)}"
# 3) Append assistant message to conversation
st.session_state["conversation"].append({
"role": "assistant",
"content": answer
})
elif st.session_state["model"] and (model_choice == model_options["1"]):
# Calculation model
user_input = st.chat_input("Enter your query for calculation:")
if user_input:
# 1) Save user message
st.session_state["conversation"].append({
"role": "user",
"content": user_input
})
# 2) Generate assistant response
tokenizer = st.session_state["tokenizer"]
model = st.session_state["model"]
try:
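            # The "Input: ... Output:" prompt presumably mirrors the format the
            # calculation model was fine-tuned on (an assumption; not documented here).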
inputs = tokenizer(
f"Input: {user_input}\nOutput:",
return_tensors="pt",
padding=True,
truncation=True
)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
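            # Greedy decoding (do_sample=False) keeps arithmetic answers deterministic;
            # generation stops at eos_token_id or after max_length tokens.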
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=250,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False
)
decoded_output = tokenizer.decode(
output[0],
skip_special_tokens=True
)
# Extract answer after 'Output:' if present
if "Output:" in decoded_output:
answer = decoded_output.split("Output:")[-1].strip()
else:
answer = decoded_output.strip()
except Exception as e:
answer = f"Error: {str(e)}"
# 3) Append assistant message to conversation
st.session_state["conversation"].append({
"role": "assistant",
"content": answer
})
else:
    # No model is loaded, or the sidebar selection no longer matches the loaded model.
    st.info("No model is loaded. Please select a model and click 'Load Model' in the sidebar.")
# ----- Render conversation (runs after input handling so the newest messages appear) -----
for message in st.session_state["conversation"]:
    with st.chat_message(message["role"]):
        st.write(message["content"])