piyushgrover committed on
Commit 632fd1a · verified · 1 Parent(s): b0fb27f

Upload app.py

Files changed (1): app.py (+54 -276)
app.py CHANGED
@@ -1,296 +1,74 @@
- import gradio as gr
- import os
- import time
- from PIL import Image
  import torch
- import whisperx
-
-
- from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
- from models.vision_projector_model import VisionProjector
- from config import VisionProjectorConfig, app_config as cfg
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
- clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

- vision_projector = VisionProjector(VisionProjectorConfig())
- ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
- vision_projector.load_state_dict(ckpt['model_state_dict'])

- phi_base_model = AutoModelForCausalLM.from_pretrained(
-     'microsoft/phi-2',
      low_cpu_mem_usage=True,
      return_dict=True,
-     torch_dtype=torch.float32,
-     trust_remote_code=True
-     # device_map=device_map,
  )

- from peft import PeftModel
- phi_new_model = "models/phi_adapter"
- phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
- phi_model = phi_model.merge_and_unload().to(device)
-
- '''compute_type = 'float32'
- if device != 'cpu':
-     compute_type = 'float16'''
-
- audi_model = whisperx.load_model("small", device, compute_type='float16')
-
- tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
- tokenizer.pad_token = tokenizer.unk_token
-
-
- ### app functions ##
-
- context_added = False
- query_added = False
- context = None
- context_type = ''
- query = ''
- bot_active = False
-
- def print_like_dislike(x: gr.LikeData):
-     print(x.index, x.value, x.liked)
-
-
- def add_text(history, text):
-     global context, context_type, context_added, query, query_added
-     context_added = False
-     if not context_type and '</context>' not in text:
-         context = "**Please add context (upload image/audio or enter text followed by \</context\>"
-         context_type = 'error'
-         context_added = True
-         query_added = False
-
-     elif '</context>' in text:
-         context_type = 'text'
-         context_added = True
-         text = text.replace('</context>', ' ')
-         context = text
-         query_added = False
-     elif context_type in ['[text]', '[image]', '[audio]']:
-         query = 'Human### ' + text + '\n' + 'AI### '
-         query_added = True
-         context_added = False
-     else:
-         query_added = False
-         context_added = True
-         context = 'error'
-         context = "**Please provide a valid context**"
-
-     history = history + [(text, None)]
-
-     return history, gr.Textbox(value="", interactive=False)
-

- def add_file(history, file):
-     global context_added, context, context_type, query_added
-
-     context = file
-     context_type = 'image'
-     context_added = True
-     query_added = False

-     history = history + [((file.name,), None)]

-     return history

- def audio_upload(history, audio_file):
-     global context, context_type, context_added, query, query_added

-     if audio_file:
-         context_added = True
-         context_type = 'audio'
-         context = audio_file
-         query_added = False
-         history = history + [((audio_file,), None)]
-
-     else:
-         pass

-     return history

- def preprocess_fn(history):
-     global context, context_added, query, context_type, query_added
-
-     if context_added:
-         if context_type == 'image':
-             image = Image.open(context)
-             inputs = clip_processor(images=image, return_tensors="pt")

-             x = clip_model(**inputs, output_hidden_states=True)
-             image_features = x.hidden_states[-2]

-             context = vision_projector(image_features)
-
-         elif context_type == 'audio':
-             audio_file = context
-             audio = whisperx.load_audio(audio_file)
-             result = audi_model.transcribe(audio, batch_size=1)

-             error = False
-             if result.get('language', None) and result.get('segments', None):
-                 try:
-                     model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-                     result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
-                 except Exception as e:
-                     error = True
-
-                 print(result.get('language', None))
-                 if not error and result.get('segments', []) and len(result["segments"]) > 0 and result["segments"][0].get('text', None):
-                     text = result["segments"][0].get('text', '')
-                     print(text)
-                     context_type = 'audio'
-                     context_added = True
-                     context = text
-                     query_added = False
-                     print(context)
-                 else:
-                     error = True
-             else:
-                 error = True
-
-             if error:
-                 context_type = 'error'
-                 context_added = True
-                 context = "**Please provide a valid audio file / context**"
-                 query_added = False
-
-     print("Here")
-     return history
-
- def bot(history):
-     global context, context_added, query, context_type, query_added, bot_active
-
-     response = ''
-     if context_added:
-         context_added = False
-         if context_type == 'error':
-             response = context
-             query = ''
-
-         elif context_type in ['image', 'audio', 'text']:
-             response = ''
-             if context_type == 'audio':
-                 response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
-
-             response += "**Please proceed with your queries**"
-             query = ''
-             context_type = '[' + context_type + ']'
-     elif query_added:
-         query_added = False
-         if context_type == '[image]':
-             query_ids = tokenizer.encode(query)
-             query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0).to(device)
-             query_embeds = phi_model.get_input_embeddings()(query_ids)
-             inputs_embeds = torch.cat([context.to(device), query_embeds], dim=1)
-             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
-                                      bos_token_id=tokenizer.bos_token_id)
-             response = tokenizer.decode(out[0], skip_special_tokens=True)
-         elif context_type in ['[text]', '[audio]']:
-             input_text = context + query
-
-             input_tokens = tokenizer.encode(input_text)
-             input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0).to(device)
-             inputs_embeds = phi_model.get_input_embeddings()(input_ids)
-             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
-                                      bos_token_id=tokenizer.bos_token_id)
-             response = tokenizer.decode(out[0], skip_special_tokens=True)
-         else:
-             query = ''
-             response = "**Please provide a valid context**"
-
-     if response:
-         bot_active = True
-         if history and len(history[-1]) > 1:
-             history[-1][1] = ""
-             for character in response:
-                 history[-1][1] += character
-                 time.sleep(0.05)
-                 yield history
-
-     time.sleep(0.5)
-     bot_active = False
-
-
-
- def clear_fn():
-     global context_added, context_type, context, query, query_added
-     context_added = False
-     context_type = ''
-     context = None
-     query = ''
-     query_added = False
-
-     return {
-         chatbot: None
-     }
-
-
- with gr.Blocks() as app:
-     gr.Markdown(
-         """
-         # ContextGPT - A Multimodal chatbot
-         ### Upload image or audio to add a context. And then ask questions.
-         ### You can also enter text followed by \</context\> to set the context.
-         """
-     )
-
-     chatbot = gr.Chatbot(
-         [],
-         elem_id="chatbot",
-         bubble_full_width=False
-     )
-
-     with gr.Row():
-         txt = gr.Textbox(
-             scale=4,
-             show_label=False,
-             placeholder="Press enter to send ",
-             container=False,
-         )
-
-     with gr.Row():
-         aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
-                        show_share_button=True)
-         btn = gr.UploadButton("📷", file_types=["image"])
-
-     with gr.Row():
-         clear = gr.Button("Clear")
-
-     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-         preprocess_fn, chatbot, chatbot
-     ).then(
-         bot, chatbot, chatbot, api_name="bot_response"
-     )
-
-     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
-
-     file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
-         preprocess_fn, chatbot, chatbot
-     ).then(
-         bot, chatbot, chatbot, api_name="bot_response"
-     )
-
-     chatbot.like(print_like_dislike, None, None)
-     clear.click(clear_fn, None, chatbot, queue=False)
-
-     aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
-         preprocess_fn, chatbot, chatbot
-     ).then(
-         bot, chatbot, chatbot, api_name="bot_response"
-     )
-
-     aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False).then(
-         preprocess_fn, chatbot, chatbot
-     ).then(
-         bot, chatbot, chatbot, api_name="bot_response"
-     )

- app.queue()
- app.launch()
  import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from peft import PeftModel

+ # ✅ Model and Tokenizer Loading
+ model_name = "microsoft/phi-2"
+ #device_map = {"": 0}

+ # Load base model
+ base_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
      low_cpu_mem_usage=True,
      return_dict=True,
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     device_map="auto",
  )

+ # Load fine-tuned LoRA weights
+ fine_tuned_model_path = "piyushgrover/phi2-qlora-adapter-s18erav3"
+ model = PeftModel.from_pretrained(base_model, fine_tuned_model_path)
+ model = model.merge_and_unload()  # Merge LoRA weights

+ # ✅ Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"

+ # ✅ Set up text generation pipeline
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=500, truncation=True)

+ def chat(user_input, history=[]):
+     """Generates a response from the fine-tuned Phi-2 model with conversation memory."""
+     '''
+     # Format conversation history
+     formatted_history = ""
+     for usr, bot in history:
+         formatted_history += f"\n\n### User:\n{usr}\n\n### Assistant:\n{bot}"

+     # Append the latest user message
+     prompt = f"{formatted_history}\n\n### User:\n{user_input}\n\n### Assistant:\n"

+     # Generate response
+     response = generator(prompt, max_length=128, do_sample=True, truncation=True)
+     answer = response[0]["generated_text"].split("### Assistant:\n")[-1].strip()

+     # Append new response to history
+     #history.append((user_input, answer))

+     return answer
+     '''
+     prompt = f"\n\n### User:\n{user_input}\n\n### Assistant:\n"
+     response = generator(prompt, max_length=128, do_sample=True, truncation=True)
+     answer = response[0]["generated_text"].split("### Assistant:\n")[-1].strip()

+     # Append new response to history
+     # history.append((user_input, answer))

+     return answer

+ # ✅ Create Gradio Chat Interface
+ chatbot = gr.ChatInterface(
+     fn=chat,
+     title="Fine-Tuned Phi-2 Conversational Chat Assistant",
+     description="🚀 Chat with a fine-tuned Phi-2 model. It remembers the conversation!",
+     theme="compact",
+ )

+ # ✅ Launch App
+ if __name__ == "__main__":
+     chatbot.launch(debug=True)
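
Note: the commented-out branch inside chat() above shows how prior turns were meant to be folded back into the prompt; the committed version only formats the latest user message. Below is a minimal sketch of that history-aware variant, not part of the committed file; it assumes the generator pipeline defined above and uses a hypothetical helper name chat_with_history.

# Sketch only (not in this commit): history-aware prompt building, adapted
# from the commented-out branch of chat(); assumes `generator` is the
# text-generation pipeline created in the new app.py.
def chat_with_history(user_input, history):
    # Replay earlier turns in the same "### User / ### Assistant" format
    formatted_history = ""
    for usr, bot in history:
        formatted_history += f"\n\n### User:\n{usr}\n\n### Assistant:\n{bot}"

    # Append the latest user message and let the model continue as Assistant
    prompt = f"{formatted_history}\n\n### User:\n{user_input}\n\n### Assistant:\n"
    response = generator(prompt, max_length=128, do_sample=True, truncation=True)

    # Keep only the text generated after the final Assistant marker
    return response[0]["generated_text"].split("### Assistant:\n")[-1].strip()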