import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch
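
# Dependencies (assumed): streamlit (a recent version providing st.rerun),
# transformers, torch, and sentencepiece (required by T5Tokenizer).
# Run with: streamlit run app.py  (assuming this file is saved as app.py)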

# Streamlit app setup
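# Note: set_page_config must be the first Streamlit command in the script.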
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: Model controls
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main UI
st.title("Chat Conversation UI")

# Session states
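# st.session_state persists values across Streamlit's script reruns.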
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
if "user_input" not in st.session_state:
    st.session_state["user_input"] = ""

# Load Model
if load_model_button:
    with st.spinner("Loading model..."):
        try:
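            # transformers' pipeline() takes a CUDA device index, or -1 for CPU.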
            device = 0 if torch.cuda.is_available() else -1
            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
            st.session_state["qa_pipeline"] = pipeline(
                "text2text-generation",
                model=st.session_state["model"],
                tokenizer=st.session_state["tokenizer"],
                device=device
            )
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")

# Clear Model
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")

# Chat Conversation Display
st.subheader("Conversation")
for speaker, message in st.session_state["conversation"]:
    if speaker == "You":
        st.markdown(f"**You:** {message}")
    else:
        st.markdown(f"**Model:** {message}")

# Input Area
if st.session_state["qa_pipeline"]:
    def handle_send():
        # Button callbacks run before the script reruns, which is the only
        # point where Streamlit allows overwriting the widget's own
        # session-state entry ("chat_input") -- this is what clears the box.
        # (Assigning to a widget key after the widget is instantiated in the
        # main script raises a StreamlitAPIException.)
        query = st.session_state["chat_input"].strip()
        if not query:
            return
        try:
            response = st.session_state["qa_pipeline"](f"Q: {query}", max_length=400)
            st.session_state["conversation"].append(("You", query))
            st.session_state["conversation"].append(("Model", response[0]["generated_text"]))
            st.session_state["chat_input"] = ""  # clear the input for the next turn
        except Exception as e:
            st.session_state["last_error"] = f"Error generating response: {e}"

    st.text_input("Enter your query:", key="chat_input")
    st.button("Send", key="send_button", on_click=handle_send)
    # Surface any error recorded by the callback, then discard it.
    if "last_error" in st.session_state:
        st.error(st.session_state.pop("last_error"))

# Clear Conversation
if clear_conversation_button:
    st.session_state["conversation"] = []
    # st.rerun() (st.experimental_rerun in older Streamlit) restarts the
    # script immediately, so the conversation rendered above is redrawn
    # empty; an st.success() placed before it would never be displayed.
    st.rerun()