# Streamlit chat app: load one of two Hugging Face models
# (a causal-LM "calculation" model or a T5 QA model) and chat with it.
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)
# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")

# ----- Sidebar: model selection and control buttons -----
st.sidebar.title("Model Controls")

# Repo ids of the two supported models; keys "1"/"2" are referenced below
# to tell the two load/inference paths apart.
model_options = {
    "1": "karthikeyan-r/calculation_model_11k",
    "2": "karthikeyan-r/slm-custom-model_6k",
}
model_choice = st.sidebar.selectbox(
    "Select Model",
    options=list(model_options.values()),
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
# ----- Session-state defaults -----
# "conversation" is a list of {"role": "user"|"assistant", "content": str} dicts.
for _key, _default in (
    ("model", None),
    ("tokenizer", None),
    ("qa_pipeline", None),
    ("conversation", []),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ----- Load Model -----
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            if model_choice == model_options["1"]:
                # Causal-LM "calculation" model: used via model.generate() directly.
                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
                # Ensure pad/eos tokens exist; resize embeddings only when a
                # token was actually added.
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    model.resize_token_embeddings(len(tokenizer))
                if tokenizer.eos_token is None:
                    tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                    model.resize_token_embeddings(len(tokenizer))
                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                # No pipeline for this model; generation happens manually below.
                st.session_state["qa_pipeline"] = None
            elif model_choice == model_options["2"]:
                # T5 model wrapped in a text2text-generation pipeline for QA.
                device = 0 if torch.cuda.is_available() else -1
                model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                qa_pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                )
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = qa_pipe
            # Seed an empty conversation with a welcome message.
            if not st.session_state["conversation"]:
                st.session_state["conversation"].append({
                    "role": "assistant",
                    "content": "Hello! I’m your assistant. How can I help you today?",
                })
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            # Surface any download/load failure in the UI instead of crashing.
            st.error(f"Error loading model: {e}")
# ----- Clear Model -----
# Drop all loaded artifacts so the UI falls back to the "no model" state.
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")

# ----- Clear Conversation -----
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")
# ----- Title and chat input -----
st.title("Chat Conversation UI")

user_input = None
if st.session_state["qa_pipeline"]:
    # T5 path: the pipeline handles tokenization and generation.
    user_input = st.chat_input("Enter your query:")
    if user_input:
        # 1) Record the user's message.
        st.session_state["conversation"].append({
            "role": "user",
            "content": user_input,
        })
        # 2) Generate the assistant's response.
        try:
            response = st.session_state["qa_pipeline"](
                f"Q: {user_input}", max_length=250
            )
            answer = response[0]["generated_text"]
        except Exception as e:
            answer = f"Error: {str(e)}"
        # 3) Record the assistant's message.
        st.session_state["conversation"].append({
            "role": "assistant",
            "content": answer,
        })
elif st.session_state["model"] and (model_choice == model_options["1"]):
    # Calculation-model path: call model.generate() directly with a
    # prompt shaped like the model's fine-tuning format ("Input: .../Output:").
    user_input = st.chat_input("Enter your query for calculation:")
    if user_input:
        # 1) Record the user's message.
        st.session_state["conversation"].append({
            "role": "user",
            "content": user_input,
        })
        tokenizer = st.session_state["tokenizer"]
        model = st.session_state["model"]
        # 2) Generate the assistant's response.
        try:
            inputs = tokenizer(
                f"Input: {user_input}\nOutput:",
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            output = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=250,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,  # deterministic (greedy) decoding
            )
            decoded_output = tokenizer.decode(
                output[0],
                skip_special_tokens=True,
            )
            # The model may echo the prompt; keep only the text after "Output:".
            if "Output:" in decoded_output:
                answer = decoded_output.split("Output:")[-1].strip()
            else:
                answer = decoded_output.strip()
        except Exception as e:
            answer = f"Error: {str(e)}"
        # 3) Record the assistant's message.
        st.session_state["conversation"].append({
            "role": "assistant",
            "content": answer,
        })
else:
    # No model loaded yet: prompt the user toward the sidebar controls.
    st.info("No model is loaded. Please select a model and click 'Load Model' from the sidebar.")
# ----- Render the conversation history -----
# Any role other than "user" is rendered as the assistant.
for message in st.session_state["conversation"]:
    role = "user" if message["role"] == "user" else "assistant"
    with st.chat_message(role):
        st.write(message["content"])