karthikeyan-r committed
Commit c2d1e89 · verified · 1 Parent(s): f0b46d5

Update app.py

Files changed (1):
  1. app.py +84 -12
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
+from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
 import torch
 
 # Streamlit app setup
@@ -7,7 +7,12 @@ st.set_page_config(page_title="Chat", layout="wide")
 
 # Sidebar: Model controls
 st.sidebar.title("Model Controls")
-model_name = st.sidebar.text_input("Enter Model Name", value="karthikeyan-r/slm-custom-model_6k")
+model_options = {
+    "1": "karthikeyan-r/slm-custom-model_6k",
+    "2": "karthikeyan-r/calculation_model"
+}
+
+model_choice = st.sidebar.selectbox("Select Model", options=list(model_options.values()))
 load_model_button = st.sidebar.button("Load Model")
 clear_conversation_button = st.sidebar.button("Clear Conversation")
 clear_model_button = st.sidebar.button("Clear Model")
@@ -31,15 +36,39 @@ if "user_input" not in st.session_state:
 if load_model_button:
     with st.spinner("Loading model..."):
         try:
-            device = 0 if torch.cuda.is_available() else -1
-            st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir="./model_cache")
-            st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_name, cache_dir="./model_cache")
-            st.session_state["qa_pipeline"] = pipeline(
-                "text2text-generation",
-                model=st.session_state["model"],
-                tokenizer=st.session_state["tokenizer"],
-                device=device
-            )
+            # Load the selected model
+            if model_choice == model_options["1"]:
+                # Load the T5 model for general QA (slm-custom-model_6k)
+                device = 0 if torch.cuda.is_available() else -1
+                st.session_state["model"] = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
+                st.session_state["tokenizer"] = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
+                st.session_state["qa_pipeline"] = pipeline(
+                    "text2text-generation",
+                    model=st.session_state["model"],
+                    tokenizer=st.session_state["tokenizer"],
+                    device=device
+                )
+            elif model_choice == model_options["2"]:
+                # Load the calculation model (calculation_model)
+                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
+                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
+
+                # Add special tokens if not present
+                if tokenizer.pad_token is None:
+                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+                    model.resize_token_embeddings(len(tokenizer))
+
+                if tokenizer.eos_token is None:
+                    tokenizer.add_special_tokens({'eos_token': '[EOS]'})
+                    model.resize_token_embeddings(len(tokenizer))
+
+                # Update configuration
+                model.config.pad_token_id = tokenizer.pad_token_id
+                model.config.eos_token_id = tokenizer.eos_token_id
+
+                st.session_state["model"] = model
+                st.session_state["tokenizer"] = tokenizer
+                st.session_state["qa_pipeline"] = None  # Calculation model doesn't use text2text pipeline
             st.success("Model loaded successfully and ready!")
         except Exception as e:
             st.error(f"Error loading model: {e}")
@@ -74,7 +103,7 @@ if st.session_state["qa_pipeline"]:
     if user_input:
         with st.spinner("Generating response..."):
             try:
-                # Generate the model response
+                # Generate the model response for general QA (T5 model)
                 response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=400)
                 generated_text = response[0]["generated_text"]
 
@@ -89,6 +118,49 @@ if st.session_state["qa_pipeline"]:
             display_conversation()
         except Exception as e:
             st.error(f"Error generating response: {e}")
+else:
+    # Handle user input for the calculation model (calculation_model)
+    if st.session_state["model"] and model_choice == model_options["2"]:
+        user_input = st.text_input(
+            "Enter your query for calculation:",
+            value=st.session_state["user_input"],
+            key="calculation_input",
+        )
+        if st.button("Send Calculation", key="send_calculation_button"):
+            if user_input:
+                with st.spinner("Generating response..."):
+                    try:
+                        # Generate the model response for the calculation model
+                        inputs = st.session_state["tokenizer"](f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
+                        input_ids = inputs.input_ids
+                        attention_mask = inputs.attention_mask
+
+                        output = st.session_state["model"].generate(
+                            input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            max_length=50,
+                            pad_token_id=st.session_state["tokenizer"].pad_token_id,
+                            eos_token_id=st.session_state["tokenizer"].eos_token_id,
+                            do_sample=False
+                        )
+
+                        decoded_output = st.session_state["tokenizer"].decode(output[0], skip_special_tokens=True)
+                        if "Output:" in decoded_output:
+                            answer = decoded_output.split("Output:")[-1].strip()
+                        else:
+                            answer = decoded_output.strip()
+
+                        # Update the conversation
+                        st.session_state["conversation"].append(("You", user_input))
+                        st.session_state["conversation"].append(("Model", answer))
+
+                        # Clear the input field after submission
+                        st.session_state["user_input"] = ""
+
+                        # Rerender the conversation immediately
+                        display_conversation()
+                    except Exception as e:
+                        st.error(f"Error generating response: {e}")
 
 # Clear Conversation
 if clear_conversation_button:
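
For quick sanity checks outside Streamlit, the two inference paths this commit wires into the UI can be reproduced in a short script. The sketch below mirrors the prompt formats, decoding settings, and special-token handling from the diff above; the checkpoint names come from the model selector, while the helper names (ask_qa_model, ask_calculation_model) and the sample query are illustrative, not part of the commit.

# sketch.py - minimal, Streamlit-free reproduction of the two paths in app.py.
# Helper names and the sample query are illustrative; prompts, decoding
# settings, and token handling follow the committed code.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

def ask_qa_model(question: str) -> str:
    """Path 1: T5 QA model served through a text2text-generation pipeline."""
    device = 0 if torch.cuda.is_available() else -1
    model = T5ForConditionalGeneration.from_pretrained("karthikeyan-r/slm-custom-model_6k")
    tokenizer = T5Tokenizer.from_pretrained("karthikeyan-r/slm-custom-model_6k")
    qa = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=device)
    return qa(f"Q: {question}", max_length=400)[0]["generated_text"]

def ask_calculation_model(query: str) -> str:
    """Path 2: causal LM queried with the app's Input/Output prompt format."""
    tokenizer = AutoTokenizer.from_pretrained("karthikeyan-r/calculation_model")
    model = AutoModelForCausalLM.from_pretrained("karthikeyan-r/calculation_model")

    # Same guard as the commit: if the checkpoint ships without pad/eos tokens,
    # add them and grow the embedding matrix so the new ids are addressable.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    inputs = tokenizer(f"Input: {query}\nOutput:", return_tensors="pt",
                       padding=True, truncation=True)
    output = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy decoding, as in the app
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # The causal LM echoes the prompt, so keep only the text after "Output:".
    return decoded.split("Output:")[-1].strip() if "Output:" in decoded else decoded.strip()

if __name__ == "__main__":
    print(ask_calculation_model("23 + 19"))  # illustrative query

One design note the diff relies on: calling model.resize_token_embeddings(len(tokenizer)) immediately after add_special_tokens is what keeps the newly added pad/eos ids inside the embedding table; without it, generate() would index past the end of the matrix for those tokens.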