piyushgrover committed (verified)
Commit fe18df5 · 1 parent: 8f13c5d

Update app.py

Files changed (1): app.py (+136 -66)
app.py CHANGED
@@ -45,34 +45,40 @@ tokenizer.pad_token = tokenizer.unk_token
### app functions ##

context_added = False
+query_added = False
context = None
context_type = ''
query = ''
-
+bot_active = False

def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)


def add_text(history, text):
-    global context, context_type, context_added, query
+    global context, context_type, context_added, query, query_added
    context_added = False
    if not context_type and '</context>' not in text:
-        history += text
-        history += "**Please add context (upload image/audio or enter text followed by </context>"
-    elif not context_type:
-        context_type = 'text'
+        context = "**Please add context (upload image/audio or enter text followed by \</context\>"
+        context_type = 'error'
        context_added = True
-        text = text.replace('</context>', ' ')
-        context = text
-    else:
-        if '</context>' in text:
-            context_type = 'text'
-            context_added = True
-            text = text.replace('</context>', ' ')
-            context = text
-        elif context_type in ['text', 'image']:
-            query = 'Human### ' + text + '\n' + 'AI### '
+        query_added = False
+
+    elif '</context>' in text:
+        context_type = 'text'
+        context_added = True
+        text = text.replace('</context>', ' ')
+        context = text
+        query_added = False
+    elif context_type in ['[text]', '[image]', '[audio]']:
+        query = 'Human### ' + text + '\n' + 'AI### '
+        query_added = True
+        context_added = False
+    else:
+        query_added = False
+        context_added = True
+        context = 'error'
+        context = "**Please provide a valid context**"

    history = history + [(text, None)]

@@ -80,59 +86,104 @@ def add_text(history, text):


def add_file(history, file):
-    global context_added, context, context_type
-    context_added = False
-    context_type = ''
-    context = None
-
-    history = history + [((file.name,), None)]
-    history += [("Building context...", None)]
-    image = Image.open(file)
-    inputs = clip_processor(images=image, return_tensors="pt")
-
-    x = clip_model(**inputs, output_hidden_states=True)
-    image_features = x.hidden_states[-2]
-
-    context = vision_projector(image_features)
+    global context_added, context, context_type, query_added
+
+    context = file
    context_type = 'image'
    context_added = True
+    query_added = False
+
+    history = history + [((file.name,), None)]

    return history


-def audio_file(history, audio_file):
-    global context, context_type, context_added, query
-
-    if audio_file:
-        history = history + [((audio_file,), None)]
-        context_added = False
-
-        audio = whisperx.load_audio(audio_file)
-        result = audi_model.transcribe(audio, batch_size=1)
-
-        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
-
-        text = result["segments"][0]["text"]
-
-        resp = "🗣" + "_" + text.strip() + "_"
-        history += [(resp, None)]
-
-        context_type = 'text'
-        context_added = True
-        context = text
-
-    return history
+def audio_upload(history, audio_file):
+    global context, context_type, context_added, query, query_added
+
+    if audio_file:
+        context_added = True
+        context_type = 'audio'
+        context = audio_file
+        query_added = False
+        history = history + [((audio_file,), None)]
+
+    else:
+        pass
+
+    return history
+
+
+def preprocess_fn(history):
+    global context, context_added, query, context_type, query_added
+
+    if context_added:
+        if context_type == 'image':
+            image = Image.open(context)
+            inputs = clip_processor(images=image, return_tensors="pt")
+
+            x = clip_model(**inputs, output_hidden_states=True)
+            image_features = x.hidden_states[-2]
+
+            context = vision_projector(image_features)
+
+        elif context_type == 'audio':
+            audio_file = context
+            audio = whisperx.load_audio(audio_file)
+            result = audi_model.transcribe(audio, batch_size=1)
+
+            error = False
+            if result.get('language', None) and result.get('segments', None):
+                try:
+                    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+                    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+                except Exception as e:
+                    error = True
+
+                print(result.get('language', None))
+                if not error and result.get('segments', []) and len(result["segments"]) > 0 and result["segments"][0].get('text', None):
+                    text = result["segments"][0].get('text', '')
+                    print(text)
+                    context_type = 'audio'
+                    context_added = True
+                    context = text
+                    query_added = False
+                    print(context)
+                else:
+                    error = True
+            else:
+                error = True
+
+            if error:
+                context_type = 'error'
+                context_added = True
+                context = "**Please provide a valid audio file / context**"
+                query_added = False
+
+    print("Here")
+    return history


def bot(history):
-    global context, context_added, query, context_type
+    global context, context_added, query, context_type, query_added, bot_active
+
+    response = ''
    if context_added:
-        response = "**Please proceed with your queries**"
        context_added = False
-        query = ''
-    else:
-        if context_type == 'image':
+        if context_type == 'error':
+            response = context
+            query = ''
+
+        elif context_type in ['image', 'audio', 'text']:
+            response = ''
+            if context_type == 'audio':
+                response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
+
+            response += "**Please proceed with your queries**"
+            query = ''
+            context_type = '[' + context_type + ']'
+    elif query_added:
+        query_added = False
+        if context_type == '[image]':
            query_ids = tokenizer.encode(query)
            query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
            query_embeds = phi_model.get_input_embeddings()(query_ids)
@@ -140,7 +191,7 @@ def bot(history):
            out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
-        elif context_type in ['text', 'audio']:
+        elif context_type in ['[text]', '[audio]']:
            input_text = context + query

            input_tokens = tokenizer.encode(input_text)
@@ -150,22 +201,30 @@ def bot(history):
                                     bos_token_id=tokenizer.bos_token_id)
            response = tokenizer.decode(out[0], skip_special_tokens=True)
        else:
+            query = ''
            response = "**Please provide a valid context**"

-    if len(history[-1]) > 1:
-        history[-1][1] = ""
-        for character in response:
-            history[-1][1] += character
-            time.sleep(0.05)
-            yield history
+    if response:
+        bot_active = True
+        if history and len(history[-1]) > 1:
+            history[-1][1] = ""
+            for character in response:
+                history[-1][1] += character
+                time.sleep(0.05)
+                yield history
+
+        time.sleep(0.5)
+        bot_active = False


def clear_fn():
-    global context_added, context_type, context, query
+    global context_added, context_type, context, query, query_added
    context_added = False
    context_type = ''
    context = None
    query = ''
+    query_added = False

    return {
        chatbot: None
@@ -177,7 +236,7 @@ with gr.Blocks() as app:
    """
    # ContextGPT - A Multimodel chatbot
    ### Upload image or audio to add a context. And then ask questions.
-    ### You can also enter text followed by \</context\> to set the context in text format.
+    ### You can also enter text followed by \</context\> to set the context.
    """
    )

@@ -187,11 +246,6 @@ with gr.Blocks() as app:
        bubble_full_width=False
    )

-    with gr.Row():
-        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
-                       show_share_button=True)
-        btn = gr.UploadButton("📷", file_types=["image"])
-
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
@@ -200,26 +254,42 @@ with gr.Blocks() as app:
            container=False,
        )

+    with gr.Row():
+        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
+                       show_share_button=True)
+        btn = gr.UploadButton("📷", file_types=["image"])
+
    with gr.Row():
        clear = gr.Button("Clear")

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+        preprocess_fn, chatbot, chatbot
+    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
+
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
-        bot, chatbot, chatbot
+        preprocess_fn, chatbot, chatbot
+    ).then(
+        bot, chatbot, chatbot, api_name="bot_response"
    )

    chatbot.like(print_like_dislike, None, None)
    clear.click(clear_fn, None, chatbot, queue=False)

-    aud.change(audio_file, [chatbot, aud], [chatbot], queue=False).then(
+    aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
+        preprocess_fn, chatbot, chatbot
+    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
-    '''aud.upload(audio_file, [chatbot, aud], [chatbot], queue=False).then(
-        bot, chatbot, chatbot, api_name="bot_response"
-    )'''
+
+    aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
+        preprocess_fn, chatbot, chatbot
+    ).then(
+        bot, chatbot, chatbot, api_name="bot_response"
+    )

app.queue()
app.launch()
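
For orientation, below is a minimal, self-contained sketch of the Gradio event-chaining pattern the updated wiring relies on (each input event runs submit, then preprocess_fn, then bot via .then()). The stub functions and the demo object here are illustrative placeholders, not code from this commit, and assume the tuple-style Chatbot history that app.py uses.

import gradio as gr


def add_text(history, text):
    # Record the user turn; the assistant slot is filled by respond() below.
    history = (history or []) + [(text, None)]
    return history, ""


def preprocess(history):
    # Stand-in for the context-building step (the CLIP / whisperx work in app.py).
    return history


def respond(history):
    # Fill the assistant side of the last turn.
    history[-1] = (history[-1][0], "ok")
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    txt = gr.Textbox()
    txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        preprocess, chatbot, chatbot
    ).then(
        respond, chatbot, chatbot
    )

demo.queue()
demo.launch()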