File size: 3,218 Bytes
62098f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch

# Basic page configuration for the Streamlit app.
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: model management controls.
st.sidebar.title("Model Controls")
model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# Main page header.
st.title("Chat Conversation UI")
st.write("Start a conversation with your Hugging Face model.")

# Seed session state so later handlers can assume every key exists.
# "conversation" is a list of (speaker, message) tuples.
_SESSION_DEFAULTS = {
    "model": None,
    "tokenizer": None,
    "qa_pipeline": None,
    "conversation": [],
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Load Model: fetch the T5 model + tokenizer and build a text2text pipeline,
# storing all three in session state so they survive Streamlit reruns.
if load_model_button:
    if not model_name.strip():
        # Guard: an empty/whitespace name would otherwise surface as a
        # confusing validation error from deep inside from_pretrained().
        st.error("Please enter a model name before loading.")
    else:
        with st.spinner("Loading model..."):
            try:
                # Use the first CUDA GPU when available, else CPU (-1).
                device = 0 if torch.cuda.is_available() else -1

                # Load model and tokenizer (cached locally in ./model_cache).
                st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(
                    model_name, cache_dir="./model_cache"
                )
                st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(
                    model_name, cache_dir="./model_cache"
                )

                # Wrap model + tokenizer in a generation pipeline.
                st.session_state["qa_pipeline"] = pipeline(
                    "text2text-generation",
                    model=st.session_state["model"],
                    tokenizer=st.session_state["tokenizer"],
                    device=device,
                )

                st.success("Model loaded successfully and ready!")
            except Exception as e:
                # Broad catch is deliberate: hub/network/config errors all
                # surface to the user instead of crashing the app.
                st.error(f"Error loading model: {e}")

# Clear Model: drop every model-related object from session state so the
# memory can be reclaimed; the chat UI below hides itself once the
# pipeline is gone.
if clear_model_button:
    for _slot in ("model", "tokenizer", "qa_pipeline"):
        st.session_state[_slot] = None
    st.success("Model cleared.")

# Chat input and output — only rendered once a pipeline has been loaded.
if st.session_state["qa_pipeline"]:
    user_input = st.text_input("Enter your query:", key="chat_input")
    if st.button("Send"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
                    # Run generation and append both sides of the exchange
                    # to the conversation history.
                    response = st.session_state["qa_pipeline"](user_input, max_length=300)
                    generated_text = response[0]["generated_text"]
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))
                except Exception as e:
                    st.error(f"Error generating response: {e}")

    # Display conversation. Widget keys include the turn index: keying on
    # the message text alone (the previous approach) raises a duplicate
    # widget-key error whenever the same text appears twice in the chat.
    for turn, (speaker, message) in enumerate(st.session_state["conversation"]):
        if speaker == "You":
            st.text_area("You:", message, key=f"msg_{turn}_you", disabled=True)
        else:
            st.text_area("Model:", message, key=f"msg_{turn}_model", disabled=True)

# Clear Conversation: wipe the chat history from session state.
# NOTE(review): this handler runs *after* the conversation has already been
# rendered above, so the cleared view only appears on the next rerun or
# interaction — consider moving this block before the display loop.
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")