karthikeyan-r committed (verified)
Commit d7d1a81 · 1 Parent(s): 2caa37b

Update app.py

Files changed (1):
  1. app.py +112 -102
app.py CHANGED
Old version (lines removed by this commit are prefixed "-"):

@@ -1,26 +1,32 @@
  import streamlit as st
- from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
  import torch

- # Streamlit app setup
  st.set_page_config(page_title="Chat", layout="wide")

- # Sidebar: Model controls
  st.sidebar.title("Model Controls")
  model_options = {
      "1": "karthikeyan-r/calculation_model_11k",
      "2": "karthikeyan-r/slm-custom-model_6k"
  }

- model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
  load_model_button = st.sidebar.button("Load Model")
  clear_conversation_button = st.sidebar.button("Clear Conversation")
  clear_model_button = st.sidebar.button("Clear Model")

- # Main UI
- st.title("Chat Conversation UI")
-
- # Session states
  if "model" not in st.session_state:
      st.session_state["model"] = None
  if "tokenizer" not in st.session_state:

@@ -28,144 +34,148 @@ if "tokenizer" not in st.session_state:
  if "qa_pipeline" not in st.session_state:
      st.session_state["qa_pipeline"] = None
  if "conversation" not in st.session_state:
      st.session_state["conversation"] = []
- if "user_input" not in st.session_state:
-     st.session_state["user_input"] = ""

- # Load Model
  if load_model_button:
      with st.spinner("Loading model..."):
          try:
-             # Load the selected model
              if model_choice == model_options["1"]:
-
-                 # Load the calculation model (calculation_model)
                  tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                  model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
-
-                 # Add special tokens if not present
                  if tokenizer.pad_token is None:
                      tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                      model.resize_token_embeddings(len(tokenizer))
-
                  if tokenizer.eos_token is None:
                      tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                      model.resize_token_embeddings(len(tokenizer))

-                 # Update configuration
                  model.config.pad_token_id = tokenizer.pad_token_id
                  model.config.eos_token_id = tokenizer.eos_token_id

                  st.session_state["model"] = model
                  st.session_state["tokenizer"] = tokenizer
-                 st.session_state["qa_pipeline"] = None  # Calculation model doesn't use text2text pipeline
-
              elif model_choice == model_options["2"]:
-                 # Load the T5 model for general QA (slm-custom-model_6k)
                  device = 0 if torch.cuda.is_available() else -1
-                 st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
-                 st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
-                 st.session_state["qa_pipeline"] = pipeline(
                      "text2text-generation",
-                     model=st.session_state["model"],
-                     tokenizer=st.session_state["tokenizer"],
                      device=device
                  )
              st.success("Model loaded successfully and ready!")
          except Exception as e:
              st.error(f"Error loading model: {e}")

- # Clear Model
  if clear_model_button:
      st.session_state["model"] = None
      st.session_state["tokenizer"] = None
      st.session_state["qa_pipeline"] = None
      st.success("Model cleared.")

- # Chat Conversation Display
- def display_conversation():
-     """Display the chat conversation dynamically."""
-     st.subheader("Conversation")
-     for idx, (speaker, message) in enumerate(st.session_state["conversation"]):
-         if speaker == "You":
-             st.markdown(f"**You:** {message}")
-         else:
-             st.markdown(f"**Model:** {message}")

- display_conversation()

- # Input Area
  if st.session_state["qa_pipeline"]:
-     user_input = st.text_input(
-         "Enter your query:",
-         value=st.session_state["user_input"],  # Use session state for persistence
-         key="chat_input",
-     )
-     if st.button("Send", key="send_button"):
-         if user_input:
              with st.spinner("Generating response..."):
                  try:
-                     # Generate the model response for general QA (T5 model)
                      response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
                      generated_text = response[0]["generated_text"]

-                     # Update the conversation
-                     st.session_state["conversation"].append(("You", user_input))
-                     st.session_state["conversation"].append(("Model", generated_text))

-                     # Clear the input field after submission
-                     st.session_state["user_input"] = ""

-                     # Rerender the conversation immediately
-                     display_conversation()
                  except Exception as e:
-                     st.error(f"Error generating response: {e}")
  else:
-     # Handle user input for the calculation model (calculation_model)
-     if st.session_state["model"] and model_choice == model_options["1"]:
-         user_input = st.text_input(
-             "Enter your query for calculation:",
-             value=st.session_state["user_input"],
-             key="calculation_input",
-         )
-         if st.button("Submit", key="send_calculation_button"):
-             if user_input:
-                 with st.spinner("Generating response..."):
-                     try:
-                         # Generate the model response for the calculation model
-                         inputs = st.session_state["tokenizer"](f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
-                         input_ids = inputs.input_ids
-                         attention_mask = inputs.attention_mask
-
-                         output = st.session_state["model"].generate(
-                             input_ids=input_ids,
-                             attention_mask=attention_mask,
-                             max_length=250,
-                             pad_token_id=st.session_state["tokenizer"].pad_token_id,
-                             eos_token_id=st.session_state["tokenizer"].eos_token_id,
-                             do_sample=False
-                         )
-
-                         decoded_output = st.session_state["tokenizer"].decode(output[0], skip_special_tokens=True)
-                         if "Output:" in decoded_output:
-                             answer = decoded_output.split("Output:")[-1].strip()
-                         else:
-                             answer = decoded_output.strip()
-
-                         # Update the conversation
-                         st.session_state["conversation"].append(("You", user_input))
-                         st.session_state["conversation"].append(("Model", answer))
-
-                         # Clear the input field after submission
-                         st.session_state["user_input"] = ""
-
-                         # Rerender the conversation immediately
-                         display_conversation()
-                     except Exception as e:
-                         st.error(f"Error generating response: {e}")
-
- # Clear Conversation
- if clear_conversation_button:
-     st.session_state["conversation"] = []
-     st.session_state["user_input"] = ""  # Clear input field
-     st.success("Conversation cleared.")
 
New version (lines added by this commit are prefixed "+"):

  import streamlit as st
+ from transformers import (
+     T5ForConditionalGeneration,
+     T5Tokenizer,
+     pipeline,
+     AutoTokenizer,
+     AutoModelForCausalLM
+ )
  import torch

+ # ----- Streamlit page config -----
  st.set_page_config(page_title="Chat", layout="wide")

+ # ----- Sidebar: Model controls -----
  st.sidebar.title("Model Controls")
  model_options = {
      "1": "karthikeyan-r/calculation_model_11k",
      "2": "karthikeyan-r/slm-custom-model_6k"
  }

+ model_choice = st.sidebar.selectbox(
+     "Select Model",
+     options=list(model_options.values())
+ )
  load_model_button = st.sidebar.button("Load Model")
  clear_conversation_button = st.sidebar.button("Clear Conversation")
  clear_model_button = st.sidebar.button("Clear Model")

+ # ----- Session States -----
  if "model" not in st.session_state:
      st.session_state["model"] = None
  if "tokenizer" not in st.session_state:
      st.session_state["tokenizer"] = None
  if "qa_pipeline" not in st.session_state:
      st.session_state["qa_pipeline"] = None
  if "conversation" not in st.session_state:
+     # We'll store conversation as a list of dicts, e.g. [{"role": "user"/"assistant", "content": "..."}]
      st.session_state["conversation"] = []

+ # ----- Load Model -----
  if load_model_button:
      with st.spinner("Loading model..."):
          try:
              if model_choice == model_options["1"]:
+                 # Load the calculation model
                  tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                  model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
+
+                 # Add special tokens if needed
                  if tokenizer.pad_token is None:
                      tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                      model.resize_token_embeddings(len(tokenizer))
                  if tokenizer.eos_token is None:
                      tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                      model.resize_token_embeddings(len(tokenizer))

                  model.config.pad_token_id = tokenizer.pad_token_id
                  model.config.eos_token_id = tokenizer.eos_token_id

                  st.session_state["model"] = model
                  st.session_state["tokenizer"] = tokenizer
+                 st.session_state["qa_pipeline"] = None  # Not needed for calculation model
+
              elif model_choice == model_options["2"]:
+                 # Load the T5 model for general QA
                  device = 0 if torch.cuda.is_available() else -1
+                 model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
+                 tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
+                 qa_pipe = pipeline(
                      "text2text-generation",
+                     model=model,
+                     tokenizer=tokenizer,
                      device=device
                  )
+                 st.session_state["model"] = model
+                 st.session_state["tokenizer"] = tokenizer
+                 st.session_state["qa_pipeline"] = qa_pipe
+
              st.success("Model loaded successfully and ready!")
          except Exception as e:
              st.error(f"Error loading model: {e}")

+ # ----- Clear Model -----
  if clear_model_button:
      st.session_state["model"] = None
      st.session_state["tokenizer"] = None
      st.session_state["qa_pipeline"] = None
      st.success("Model cleared.")

+ # ----- Clear Conversation -----
+ if clear_conversation_button:
+     st.session_state["conversation"] = []
+     st.success("Conversation cleared.")

+ # ----- Display Chat Conversation -----
+ st.title("Chat Conversation UI")

+ # Loop through existing conversation in session_state and display it
+ for message in st.session_state["conversation"]:
+     if message["role"] == "user":
+         with st.chat_message("user"):
+             st.write(message["content"])
+     else:
+         with st.chat_message("assistant"):
+             st.write(message["content"])
+
+ # ----- Chat Input Logic -----
+ # If we have a T5 pipeline (general QA model):
  if st.session_state["qa_pipeline"]:
+     # Use the new Streamlit chat input
+     user_input = st.chat_input("Enter your query:")
+     if user_input:
+         # 1) Save user message
+         st.session_state["conversation"].append({"role": "user", "content": user_input})
+
+         # 2) Generate response
+         with st.chat_message("assistant"):
              with st.spinner("Generating response..."):
                  try:
                      response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
                      generated_text = response[0]["generated_text"]
+                 except Exception as e:
+                     generated_text = f"Error: {str(e)}"
+
+             st.write(generated_text)

+         # 3) Save assistant message
+         st.session_state["conversation"].append({"role": "assistant", "content": generated_text})

+ # If we have the calculation model loaded (model_options["1"]):
+ elif st.session_state["model"] and (model_choice == model_options["1"]):
+     user_input = st.chat_input("Enter your query for calculation:")
+     if user_input:
+         # 1) Save user message
+         st.session_state["conversation"].append({"role": "user", "content": user_input})
+
+         # 2) Generate response
+         with st.chat_message("assistant"):
+             with st.spinner("Generating response..."):
+                 try:
+                     tokenizer = st.session_state["tokenizer"]
+                     model = st.session_state["model"]
+
+                     inputs = tokenizer(
+                         f"Input: {user_input}\nOutput:",
+                         return_tensors="pt",
+                         padding=True,
+                         truncation=True
+                     )
+                     input_ids = inputs.input_ids
+                     attention_mask = inputs.attention_mask
+
+                     output = model.generate(
+                         input_ids=input_ids,
+                         attention_mask=attention_mask,
+                         max_length=250,
+                         pad_token_id=tokenizer.pad_token_id,
+                         eos_token_id=tokenizer.eos_token_id,
+                         do_sample=False
+                     )
+
+                     decoded_output = tokenizer.decode(
+                         output[0],
+                         skip_special_tokens=True
+                     )
+                     # Extract answer after 'Output:' if present
+                     if "Output:" in decoded_output:
+                         answer = decoded_output.split("Output:")[-1].strip()
+                     else:
+                         answer = decoded_output.strip()
                  except Exception as e:
+                     answer = f"Error: {str(e)}"
+
+             st.write(answer)
+
+         # 3) Save assistant message
+         st.session_state["conversation"].append({"role": "assistant", "content": answer})
  else:
+     # If no model is loaded at all
+     st.info("No model is loaded. Please select a model and click 'Load Model' from the sidebar.")
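
The substance of this commit is the switch from st.text_input / st.button plus a manual display_conversation() helper to Streamlit's chat primitives, with the conversation stored as role/content dicts. A minimal self-contained sketch of that pattern, assuming Streamlit >= 1.24 (where st.chat_input and st.chat_message are available); the echo reply is a hypothetical stand-in for the model call:

import streamlit as st

if "conversation" not in st.session_state:
    st.session_state["conversation"] = []

# Replay the stored history on every rerun; st.chat_message renders a chat bubble.
for message in st.session_state["conversation"]:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# st.chat_input returns the submitted text once, then None on later reruns,
# so no manual "clear the input field" bookkeeping is needed.
user_input = st.chat_input("Say something")
if user_input:
    st.session_state["conversation"].append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    reply = f"Echo: {user_input}"  # stand-in for the pipeline/generate call
    with st.chat_message("assistant"):
        st.write(reply)
    st.session_state["conversation"].append({"role": "assistant", "content": reply})

Run it with: streamlit run app.py. Because st.chat_input clears itself after each submit, the old "user_input" session key and the explicit display_conversation() re-render become unnecessary, which is exactly what this commit removes.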