piyushgrover committed on
Commit
d3dc36c
·
verified ·
1 Parent(s): 632fd1a

Update app.py

Files changed (1)
  1. app.py +276 -54
app.py CHANGED
@@ -1,74 +1,296 @@
- import torch
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from peft import PeftModel
-
- # ✅ Model and Tokenizer Loading
- model_name = "microsoft/phi-2"
- #device_map = {"": 0}
-
- # Load base model
- base_model = AutoModelForCausalLM.from_pretrained(
-     model_name,
      low_cpu_mem_usage=True,
      return_dict=True,
-     torch_dtype=torch.float16,
-     trust_remote_code=True,
-     device_map="auto",
  )
-
- # Load fine-tuned LoRA weights
- fine_tuned_model_path = "piyushgrover/phi2-qlora-adapter-s18erav3"
- model = PeftModel.from_pretrained(base_model, fine_tuned_model_path)
- model = model.merge_and_unload()  # Merge LoRA weights
-
- # ✅ Load tokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.padding_side = "right"
-
- # ✅ Set up text generation pipeline
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=500, truncation=True)
-
-
- def chat(user_input, history=[]):
-     """Generates a response from the fine-tuned Phi-2 model with conversation memory."""
-     '''
-     # Format conversation history
-     formatted_history = ""
-     for usr, bot in history:
-         formatted_history += f"\n\n### User:\n{usr}\n\n### Assistant:\n{bot}"
-
-     # Append the latest user message
-     prompt = f"{formatted_history}\n\n### User:\n{user_input}\n\n### Assistant:\n"
-
-     # Generate response
-     response = generator(prompt, max_length=128, do_sample=True, truncation=True)
-     answer = response[0]["generated_text"].split("### Assistant:\n")[-1].strip()
-
-     # Append new response to history
-     #history.append((user_input, answer))
-
-     return answer
-     '''
-     prompt = f"\n\n### User:\n{user_input}\n\n### Assistant:\n"
-     response = generator(prompt, max_length=128, do_sample=True, truncation=True)
-     answer = response[0]["generated_text"].split("### Assistant:\n")[-1].strip()
-
-     # Append new response to history
-     # history.append((user_input, answer))
-
-     return answer
-
-
- # ✅ Create Gradio Chat Interface
- chatbot = gr.ChatInterface(
-     fn=chat,
-     title="Fine-Tuned Phi-2 Conversational Chat Assistant",
-     description="🚀 Chat with a fine-tuned Phi-2 model. It remembers the conversation!",
-     theme="compact",
- )
-
- # ✅ Launch App
- if __name__ == "__main__":
-     chatbot.launch(debug=True)
  import gradio as gr
+ import os
+ import time
+ from PIL import Image
+ import torch
+ import whisperx
+
+
+ from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
+ from models.vision_projector_model import VisionProjector
+ from config import VisionProjectorConfig, app_config as cfg
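+ # Model stack used below: CLIP ViT-B/32 encodes images, a trained VisionProjector maps the
+ # CLIP features into phi-2's input-embedding space, a LoRA-adapted phi-2 generates replies,
+ # and whisperx transcribes recorded or uploaded audio into text context.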
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'

+ clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ vision_projector = VisionProjector(VisionProjectorConfig())
+ ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
+ vision_projector.load_state_dict(ckpt['model_state_dict'])
+
+ phi_base_model = AutoModelForCausalLM.from_pretrained(
+     'microsoft/phi-2',
      low_cpu_mem_usage=True,
      return_dict=True,
+     torch_dtype=torch.float32,
+     trust_remote_code=True
+     # device_map=device_map,
  )

+ from peft import PeftModel
+ phi_new_model = "models/phi_adapter"
+ phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
+ phi_model = phi_model.merge_and_unload().to(device)

+ compute_type = 'float32'
+ if device != 'cpu':
+     compute_type = 'float16'

+ audi_model = whisperx.load_model("small", device, compute_type=compute_type)

+ tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
+ tokenizer.pad_token = tokenizer.unk_token


+ ### app functions ##

+ context_added = False
+ query_added = False
+ context = None
+ context_type = ''
+ query = ''
+ bot_active = False

+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)

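+ # Chat protocol: a message ending in </context> (or an uploaded image / audio clip) sets the
+ # conversation context; once a context is set, plain messages are treated as queries against it.
+ # The module-level globals above carry that state between Gradio callbacks.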
+ def add_text(history, text):
+     global context, context_type, context_added, query, query_added
+     context_added = False
+     if not context_type and '</context>' not in text:
+         context = "**Please add context (upload image/audio or enter text followed by \</context\>)**"
+         context_type = 'error'
+         context_added = True
+         query_added = False
+
+     elif '</context>' in text:
+         context_type = 'text'
+         context_added = True
+         text = text.replace('</context>', ' ')
+         context = text
+         query_added = False
+     elif context_type in ['[text]', '[image]', '[audio]']:
+         query = 'Human### ' + text + '\n' + 'AI### '
+         query_added = True
+         context_added = False
+     else:
+         query_added = False
+         context_added = True
+         context_type = 'error'
+         context = "**Please provide a valid context**"
+
+     history = history + [(text, None)]
+
+     return history, gr.Textbox(value="", interactive=False)


+ def add_file(history, file):
+     global context_added, context, context_type, query_added
+
+     context = file
+     context_type = 'image'
+     context_added = True
+     query_added = False
+
+     history = history + [((file.name,), None)]
+
+     return history
+
+
+ def audio_upload(history, audio_file):
+     global context, context_type, context_added, query, query_added
+
+     if audio_file:
+         context_added = True
+         context_type = 'audio'
+         context = audio_file
+         query_added = False
+         history = history + [((audio_file,), None)]
+
+     else:
+         pass
+
+     return history
+
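+ # Convert the raw context into model input: images go through CLIP (penultimate hidden layer)
+ # and the vision projector to produce embedding vectors; audio is transcribed to text with whisperx.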
+ def preprocess_fn(history):
+     global context, context_added, query, context_type, query_added
+
+     if context_added:
+         if context_type == 'image':
+             image = Image.open(context)
+             inputs = clip_processor(images=image, return_tensors="pt")
+
+             x = clip_model(**inputs, output_hidden_states=True)
+             image_features = x.hidden_states[-2]
+
+             context = vision_projector(image_features)
+
+         elif context_type == 'audio':
+             audio_file = context
+             audio = whisperx.load_audio(audio_file)
+             result = audi_model.transcribe(audio, batch_size=1)
+
+             error = False
+             if result.get('language', None) and result.get('segments', None):
+                 try:
+                     model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+                     result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+                 except Exception as e:
+                     error = True
+
+                 print(result.get('language', None))
+                 if not error and result.get('segments', []) and len(result["segments"]) > 0 and result["segments"][0].get('text', None):
+                     text = result["segments"][0].get('text', '')
+                     print(text)
+                     context_type = 'audio'
+                     context_added = True
+                     context = text
+                     query_added = False
+                     print(context)
+                 else:
+                     error = True
+             else:
+                 error = True
+
+             if error:
+                 context_type = 'error'
+                 context_added = True
+                 context = "**Please provide a valid audio file / context**"
+                 query_added = False
+
+     print("Here")
+     return history
+
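+ # Generate the reply: for image contexts, the projected image embeddings are concatenated with
+ # the embedded query and passed to phi-2 via inputs_embeds; text/audio contexts are prepended to
+ # the query as plain text. The response is streamed into the chat window character by character.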
+ def bot(history):
+     global context, context_added, query, context_type, query_added, bot_active
+
+     response = ''
+     if context_added:
+         context_added = False
+         if context_type == 'error':
+             response = context
+             query = ''
+
+         elif context_type in ['image', 'audio', 'text']:
+             response = ''
+             if context_type == 'audio':
+                 response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
+
+             response += "**Please proceed with your queries**"
+             query = ''
+             context_type = '[' + context_type + ']'
+     elif query_added:
+         query_added = False
+         if context_type == '[image]':
+             query_ids = tokenizer.encode(query)
+             query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
+             query_embeds = phi_model.get_input_embeddings()(query_ids)
+             inputs_embeds = torch.cat([context.to(device), query_embeds], dim=1)
+             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
+                                      bos_token_id=tokenizer.bos_token_id)
+             response = tokenizer.decode(out[0], skip_special_tokens=True)
+         elif context_type in ['[text]', '[audio]']:
+             input_text = context + query
+
+             input_tokens = tokenizer.encode(input_text)
+             input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0).to(device)
+             inputs_embeds = phi_model.get_input_embeddings()(input_ids)
+             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
+                                      bos_token_id=tokenizer.bos_token_id)
+             response = tokenizer.decode(out[0], skip_special_tokens=True)
+     else:
+         query = ''
+         response = "**Please provide a valid context**"
+
+     if response:
+         bot_active = True
+         if history and len(history[-1]) > 1:
+             history[-1][1] = ""
+             for character in response:
+                 history[-1][1] += character
+                 time.sleep(0.05)
+                 yield history
+
+     time.sleep(0.5)
+     bot_active = False
+
+
+ def clear_fn():
+     global context_added, context_type, context, query, query_added
+     context_added = False
+     context_type = ''
+     context = None
+     query = ''
+     query_added = False
+
+     return {
+         chatbot: None
+     }
+
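+ # UI: a Chatbot pane, a textbox for text context/queries, an audio recorder/uploader and an image
+ # upload button for setting the context, and a Clear button that resets the stored state.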
235
+ with gr.Blocks() as app:
236
+ gr.Markdown(
237
+ """
238
+ # ContextGPT - A Multimodal chatbot
239
+ ### Upload image or audio to add a context. And then ask questions.
240
+ ### You can also enter text followed by \</context\> to set the context.
241
+ """
242
+ )
243
+
244
+ chatbot = gr.Chatbot(
245
+ [],
246
+ elem_id="chatbot",
247
+ bubble_full_width=False
248
+ )
249
+
250
+ with gr.Row():
251
+ txt = gr.Textbox(
252
+ scale=4,
253
+ show_label=False,
254
+ placeholder="Press enter to send ",
255
+ container=False,
256
+ )
257
+
258
+ with gr.Row():
259
+ aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
260
+ show_share_button=True)
261
+ btn = gr.UploadButton("πŸ“·", file_types=["image"])
262
+
263
+ with gr.Row():
264
+ clear = gr.Button("Clear")
265
+
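+     # Event wiring: every input path (text submit, image upload, audio record/upload) runs the same
+     # chain: add the message to the chat, preprocess the context, then stream the bot reply.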
+     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+         preprocess_fn, chatbot, chatbot
+     ).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+
+     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
+     file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
+         preprocess_fn, chatbot, chatbot
+     ).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+
+     chatbot.like(print_like_dislike, None, None)
+     clear.click(clear_fn, None, chatbot, queue=False)
+
+     aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
+         preprocess_fn, chatbot, chatbot
+     ).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+
+     aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
+         preprocess_fn, chatbot, chatbot
+     ).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+
+ app.queue()
+ app.launch()