karthikeyan-r committed on
Commit
62098f3
·
verified ·
1 Parent(s): c2db41c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
3
+ import torch
4
+
5
# Page-level configuration — must be the first Streamlit call in the script.
st.set_page_config(page_title="Hugging Face Chat", layout="wide")

# Sidebar: model-management controls. Each st.button returns True only on
# the rerun triggered by its own click.
with st.sidebar:
    st.title("Model Controls")
    model_name = st.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
    load_model_button = st.button("Load Model")
    clear_conversation_button = st.button("Clear Conversation")
    clear_model_button = st.button("Clear Model")

# Main panel header.
st.title("Chat Conversation UI")
st.write("Start a conversation with your Hugging Face model.")
18
+
19
# One-time seeding of per-session state. Streamlit re-executes the whole
# script on every interaction, so only keys that are not yet present are
# initialized; a data-driven loop replaces four copy-pasted if-blocks.
_SESSION_DEFAULTS = {
    "model": None,         # loaded T5ForConditionalGeneration, or None
    "tokenizer": None,     # matching T5Tokenizer, or None
    "qa_pipeline": None,   # text2text-generation pipeline, or None
    "conversation": [],    # chat transcript: list of (speaker, message) pairs
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
28
+
29
# Load Model: fetch (or reuse from the local cache) the requested T5
# checkpoint and wire it into a text2text-generation pipeline. Everything
# is stored in session state so it survives Streamlit reruns.
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            # pipeline() expects a CUDA device index, or -1 for CPU.
            device = 0 if torch.cuda.is_available() else -1
            cache_dir = "./model_cache"

            # Model first, then tokenizer — same order as before, so a
            # tokenizer failure leaves session state in the same shape.
            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(
                model_name, cache_dir=cache_dir
            )
            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(
                model_name, cache_dir=cache_dir
            )

            st.session_state["qa_pipeline"] = pipeline(
                "text2text-generation",
                model=st.session_state["model"],
                tokenizer=st.session_state["tokenizer"],
                device=device,
            )
            st.success("Model loaded successfully and ready!")
        except Exception as e:
            # Surface any download/load failure to the user instead of crashing.
            st.error(f"Error loading model: {e}")
51
+
52
# Clear Model: drop all loaded model artifacts from session state so memory
# can be reclaimed and a different checkpoint can be loaded afterwards.
if clear_model_button:
    for slot in ("model", "tokenizer", "qa_pipeline"):
        st.session_state[slot] = None
    st.success("Model cleared.")
58
+
59
# Chat Input and Output — only rendered once a pipeline has been loaded.
if st.session_state["qa_pipeline"]:
    user_input = st.text_input("Enter your query:", key="chat_input")
    if st.button("Send"):
        if user_input:
            with st.spinner("Generating response..."):
                try:
                    response = st.session_state["qa_pipeline"](user_input, max_length=300)
                    generated_text = response[0]["generated_text"]
                    # Append both sides of the exchange to the transcript.
                    st.session_state["conversation"].append(("You", user_input))
                    st.session_state["conversation"].append(("Model", generated_text))
                except Exception as e:
                    st.error(f"Error generating response: {e}")

    # Display conversation. Widget keys are derived from the message's
    # position in the transcript, NOT from its text: identical messages
    # (e.g. the user sending "hi" twice) previously produced duplicate
    # keys and made Streamlit raise a duplicate-element-key error.
    for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
        if speaker == "You":
            st.text_area("You:", message, key=f"msg_{idx}_you", disabled=True)
        else:
            st.text_area("Model:", message, key=f"msg_{idx}_model", disabled=True)
79
+
80
# Clear Conversation: wipe the chat transcript but keep the loaded model.
if clear_conversation_button:
    st.session_state.conversation = []
    st.success("Conversation cleared.")