conversationbot / app.py
karthikeyan-r's picture
Create app.py
62098f3 verified
raw
history blame
3.22 kB
import gc

import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
# Page setup first, then the sidebar controls, then the main-area heading.
st.set_page_config(
    page_title="Hugging Face Chat",
    layout="wide",
)

# Sidebar: a text box for the Hugging Face model id and three action buttons.
with st.sidebar:
    st.title("Model Controls")
    model_name = st.text_input(
        "Enter Model Name",
        value="karthikeyan-r/slm-custom-model_6k",
    )
    load_model_button = st.button("Load Model")
    clear_conversation_button = st.button("Clear Conversation")
    clear_model_button = st.button("Clear Model")

# Main area heading and a one-line usage hint.
st.title("Chat Conversation UI")
st.write("Start a conversation with your Hugging Face model.")
# Initialize session states: only keys that are missing get a default, so
# values already stored in st.session_state are left as they are.
for _state_key, _initial_value in (
    ("model", None),
    ("tokenizer", None),
    ("qa_pipeline", None),
    ("conversation", []),  # list of (speaker, message) tuples
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Load Model
def _load_qa_pipeline(name):
    """Load the T5 model and tokenizer called *name* and wrap them in a pipeline.

    Returns a ``(model, tokenizer, qa_pipeline)`` tuple. Nothing is written to
    session state here, so a failure part-way through (e.g. the tokenizer
    fails after the model loaded) leaves any previously loaded model intact.
    """
    # Use device 0 when CUDA is available; otherwise pass -1 to the pipeline.
    device = 0 if torch.cuda.is_available() else -1
    model = T5ForConditionalGeneration.from_pretrained(name, cache_dir="./model_cache")
    tokenizer = T5Tokenizer.from_pretrained(name, cache_dir="./model_cache")
    qa_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return model, tokenizer, qa_pipeline


if load_model_button:
    requested_name = model_name.strip()
    if not requested_name:
        st.error("Please enter a model name.")
    else:
        with st.spinner("Loading model..."):
            try:
                loaded = _load_qa_pipeline(requested_name)
            except Exception as e:  # UI boundary: report any load failure to the user
                st.error(f"Error loading model: {e}")
            else:
                # Commit all three together so session state is never half-updated.
                (
                    st.session_state["model"],
                    st.session_state["tokenizer"],
                    st.session_state["qa_pipeline"],
                ) = loaded
                st.success("Model loaded successfully and ready!")
# Clear Model: drop every reference to the loaded model, then release memory.
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    # Collect any reference cycles still holding the model. Then return the
    # unused blocks that PyTorch's CUDA caching allocator keeps reserved after
    # tensors are freed.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.success("Model cleared.")
# Chat Input and Output: shown only once a pipeline has been loaded.
if st.session_state["qa_pipeline"] is not None:
    user_input = st.text_input("Enter your query:", key="chat_input")
    # The button is always rendered; empty or whitespace-only queries are ignored
    # rather than sent to the model.
    if st.button("Send") and user_input.strip():
        with st.spinner("Generating response..."):
            try:
                response = st.session_state["qa_pipeline"](user_input, max_length=300)
                generated_text = response[0]["generated_text"]
            except Exception as e:  # UI boundary: show generation errors in the app
                st.error(f"Error generating response: {e}")
            else:
                # Record both turns together, and only after a successful generation.
                st.session_state["conversation"].append(("You", user_input))
                st.session_state["conversation"].append(("Model", generated_text))
# Clear Conversation: handled before rendering, so the cleared history
# disappears in this same run instead of being drawn one last time.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# Display conversation as read-only text areas, one per turn.
for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
    # Adding the turn index to the key keeps keys unique when the same text
    # appears more than once (e.g. a repeated question). Text alone would
    # produce duplicate widget keys.
    if speaker == "You":
        st.text_area("You:", message, key=f"{idx}_{message}_you", disabled=True)
    else:
        st.text_area("Model:", message, key=f"{idx}_{message}_model", disabled=True)