# NOTE: Hugging Face Spaces page residue ("Spaces: Running") — not part of the app source.
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# Streamlit page setup
st.set_page_config(page_title="Chat", layout="wide")

# Sidebar: model selection and control buttons
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/slm-custom-model_6k",
    "2": "karthikeyan-r/calculation_model"
}
model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Initialize session-state keys on first run so later code can read them safely.
for _key, _default in (
    ("model", None),
    ("tokenizer", None),
    ("qa_pipeline", None),
    ("conversation", []),
    ("user_input", ""),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Load Model | |
# Load the selected model when the sidebar button is pressed.
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            if model_choice == model_options["1"]:
                # T5 seq2seq model for general QA, wrapped in a text2text pipeline.
                device = 0 if torch.cuda.is_available() else -1  # GPU index, or -1 for CPU
                st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                st.session_state["qa_pipeline"] = pipeline(
                    "text2text-generation",
                    model=st.session_state["model"],
                    tokenizer=st.session_state["tokenizer"],
                    device=device,
                )
            elif model_choice == model_options["2"]:
                # Causal LM used for calculations; invoked via .generate() directly,
                # so no pipeline is built for this branch.
                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
                # Ensure pad/eos tokens exist; resize embeddings only when a token was added.
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    model.resize_token_embeddings(len(tokenizer))
                if tokenizer.eos_token is None:
                    tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                    model.resize_token_embeddings(len(tokenizer))
                # Keep the model config in sync with the tokenizer's special tokens.
                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = None  # calculation model does not use the text2text pipeline
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            # Surface load failures (network, missing weights, OOM) in the UI.
            st.error(f"Error loading model: {e}")
# Clear Model | |
# Drop the loaded model/tokenizer/pipeline from the session so memory can be reclaimed.
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")
# Chat Conversation Display | |
def display_conversation():
    """Render the accumulated chat history as alternating You/Model markdown lines."""
    st.subheader("Conversation")
    # conversation holds (speaker, message) pairs; any speaker other than
    # "You" is rendered with the "Model" label, exactly as before.
    for speaker, message in st.session_state["conversation"]:
        label = "You" if speaker == "You" else "Model"
        st.markdown(f"**{label}:** {message}")
display_conversation()

# Input area. The T5-pipeline path and the calculation-model path use separate
# widgets/keys because they build different prompts and decode differently.
if st.session_state["qa_pipeline"]:
    user_input = st.text_input(
        "Enter your query:",
        value=st.session_state["user_input"],  # persisted across reruns
        key="chat_input",
    )
    if st.button("Send", key="send_button"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
                    # T5 text2text generation over a "Q: ..." prompt.
                    response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=400)
                    generated_text = response[0]["generated_text"]
                    # Record both turns of the exchange.
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))
                    # NOTE(review): this resets "user_input", but the widget above is
                    # keyed "chat_input", so the visible box is not actually cleared.
                    st.session_state["user_input"] = ""
                    # Re-render so the new turn appears without waiting for a rerun
                    # (the history is also drawn again by the top-of-script call).
                    display_conversation()
                except Exception as e:
                    st.error(f"Error generating response: {e}")
else:
    # Calculation model: call generate() directly with an "Input:/Output:" prompt.
    if st.session_state["model"] and model_choice == model_options["2"]:
        user_input = st.text_input(
            "Enter your query for calculation:",
            value=st.session_state["user_input"],
            key="calculation_input",
        )
        if st.button("Send Calculation", key="send_calculation_button"):
            if user_input:
                with st.spinner("Generating response..."):
                    try:
                        inputs = st.session_state["tokenizer"](f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
                        output = st.session_state["model"].generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=50,
                            pad_token_id=st.session_state["tokenizer"].pad_token_id,
                            eos_token_id=st.session_state["tokenizer"].eos_token_id,
                            do_sample=False  # greedy decoding for deterministic answers
                        )
                        decoded_output = st.session_state["tokenizer"].decode(output[0], skip_special_tokens=True)
                        # Strip the echoed prompt: keep only the text after "Output:".
                        if "Output:" in decoded_output:
                            answer = decoded_output.split("Output:")[-1].strip()
                        else:
                            answer = decoded_output.strip()
                        st.session_state["conversation"].append(("You", user_input))
                        st.session_state["conversation"].append(("Model", answer))
                        # NOTE(review): same widget-key mismatch as the QA branch —
                        # the "calculation_input" box is not cleared by this reset.
                        st.session_state["user_input"] = ""
                        display_conversation()
                    except Exception as e:
                        st.error(f"Error generating response: {e}")
# Clear Conversation | |
# Reset the chat history (and the persisted input value) from the sidebar.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.session_state["user_input"] = ""
    st.success("Conversation cleared.")