piyushgrover committed
Commit f414499 · 1 Parent(s): 7396aab

added app files

Files changed (3)
  1. app.py +220 -0
  2. config.py +2 -144
  3. requirement.txt +12 -0
app.py ADDED
@@ -0,0 +1,220 @@
+ import gradio as gr
+ import os
+ import time
+ from PIL import Image
+ import torch
+ import whisperx
+
+ from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
+ from models.vision_projector_model import VisionProjector
+ from config import VisionProjectorConfig, app_config as cfg
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # CLIP vision encoder and image processor for image inputs
+ clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Vision projector maps CLIP image features into the phi-2 embedding space
+ vision_projector = VisionProjector(VisionProjectorConfig())
+ ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
+ vision_projector.load_state_dict(ckpt['model_state_dict'])
+
+ phi_base_model = AutoModelForCausalLM.from_pretrained(
+     'microsoft/phi-2',
+     low_cpu_mem_usage=True,
+     return_dict=True,
+     torch_dtype=torch.float32,
+     trust_remote_code=True
+     # device_map=device_map,
+ )
+
+ # Merge the trained LoRA adapter into the phi-2 base model
+ from peft import PeftModel
+ phi_new_model = "models/phi_adapter"
+ phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
+ phi_model = phi_model.merge_and_unload()
+
+ # WhisperX ASR model for audio transcription
+ audi_model = whisperx.load_model("large-v2", device, compute_type='float16')
+
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2', trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.unk_token
+
+
+ ### app functions ###
+ context_added = False
+ context = None
+ context_type = ''
+ query = ''
+
+
+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)
+
+
+ def add_text(history, text):
+     global context, context_type, context_added, query
+     context_added = False
+     if not context_type and '</context>' not in text:
+         # No context yet: echo the message, ask for a context, and return early
+         # so the text is not appended a second time below
+         history += [(text, None)]
+         history += [(None, "**Please add context (upload image/audio or enter text followed by </context>)**")]
+         return history, gr.Textbox(value="", interactive=False)
+     elif not context_type:
+         context_type = 'text'
+         context_added = True
+         text = text.replace('</context>', ' ')
+         context = text
+     else:
+         if '</context>' in text:
+             context_type = 'text'
+             context_added = True
+             text = text.replace('</context>', ' ')
+             context = text
+         elif context_type in ['text', 'image']:
+             query = 'Human### ' + text + '\n' + 'AI### '
+
+     history = history + [(text, None)]
+
+     return history, gr.Textbox(value="", interactive=False)
+
+
+ def add_file(history, file):
+     global context_added, context, context_type
+     context_added = False
+     context_type = ''
+     context = None
+
+     history = history + [((file.name,), None)]
+     history += [("Building context...", None)]
+     # Encode the image with CLIP, take the penultimate hidden state,
+     # and project it into the phi-2 embedding space to use as context
+     image = Image.open(file)
+     inputs = clip_processor(images=image, return_tensors="pt")
+
+     x = clip_model(**inputs, output_hidden_states=True)
+     image_features = x.hidden_states[-2]
+
+     context = vision_projector(image_features)
+     context_type = 'image'
+     context_added = True
+
+     return history
+
+
+ def audio_file(history, audio_file):
+     global context, context_type, context_added, query
+
+     if audio_file:
+         history = history + [((audio_file,), None)]
+         context_added = False
+
+         # Transcribe the audio with WhisperX and align the segments
+         audio = whisperx.load_audio(audio_file)
+         result = audi_model.transcribe(audio, batch_size=1)
+
+         model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+         result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+         text = result["segments"][0]["text"]
+
+         resp = "🗣" + "_" + text.strip() + "_"
+         history += [(resp, None)]
+
+         # Use the transcription as a text context
+         context_type = 'text'
+         context_added = True
+         context = text
+
+     return history
+
+
+ def bot(history):
+     global context, context_added, query, context_type
+     if context_added:
+         response = "**Please proceed with your queries**"
+         context_added = False
+         query = ''
+     else:
+         if context_type == 'image':
+             # Prepend the projected image embeddings to the query embeddings
+             query_ids = tokenizer.encode(query)
+             query_ids = torch.tensor(query_ids, dtype=torch.int32).unsqueeze(0)
+             query_embeds = phi_model.get_input_embeddings()(query_ids)
+             inputs_embeds = torch.cat([context, query_embeds], dim=1)
+             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
+                                      bos_token_id=tokenizer.bos_token_id)
+             response = tokenizer.decode(out[0], skip_special_tokens=True)
+         elif context_type in ['text', 'audio']:
+             input_text = context + query
+
+             input_tokens = tokenizer.encode(input_text)
+             input_ids = torch.tensor(input_tokens, dtype=torch.int32).unsqueeze(0)
+             inputs_embeds = phi_model.get_input_embeddings()(input_ids)
+             out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
+                                      bos_token_id=tokenizer.bos_token_id)
+             response = tokenizer.decode(out[0], skip_special_tokens=True)
+         else:
+             response = "**Please provide a valid context**"
+
+     # Stream the response into the last chat turn, one character at a time
+     if len(history[-1]) > 1:
+         history[-1][1] = ""
+         for character in response:
+             history[-1][1] += character
+             time.sleep(0.05)
+             yield history
+
+
+ def clear_fn():
+     global context_added, context_type, context, query
+     context_added = False
+     context_type = ''
+     context = None
+     query = ''
+
+     return {
+         chatbot: None
+     }
+
+
+ with gr.Blocks() as app:
+     gr.Markdown(
+         """
+         # ContextGPT - A Multimodal Chatbot
+         ### Upload an image or an audio clip to add context, and then ask your questions.
+         ### You can also enter text followed by \</context\> to set a text context.
+         """
+     )
+
+     chatbot = gr.Chatbot(
+         [],
+         elem_id="chatbot",
+         bubble_full_width=False
+     )
+
+     with gr.Row():
+         aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
+                        show_share_button=True)
+         btn = gr.UploadButton("📷", file_types=["image"])
+
+     with gr.Row():
+         txt = gr.Textbox(
+             scale=4,
+             show_label=False,
+             placeholder="Press enter to send ",
+             container=False,
+         )
+
+     with gr.Row():
+         clear = gr.Button("Clear")
+
+     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+     file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+
+     chatbot.like(print_like_dislike, None, None)
+     clear.click(clear_fn, None, chatbot, queue=False)
+
+     aud.stop_recording(audio_file, [chatbot, aud], [chatbot], queue=False).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+     aud.upload(audio_file, [chatbot, aud], [chatbot], queue=False).then(
+         bot, chatbot, chatbot, api_name="bot_response"
+     )
+
+ app.queue()
+ app.launch()
config.py CHANGED
@@ -20,154 +20,12 @@ class VisionProjectorConfig(PretrainedConfig):
          self.kwargs = kwargs
  
  
- class CustomPhiConfig(PretrainedConfig):
-     model_type = "phi-msft"
-     attribute_map = {
-         "max_position_embeddings": "n_positions",
-         "hidden_size": "n_embd",
-         "num_attention_heads": "n_head",
-         "num_hidden_layers": "n_layer",
-     }
- 
-     def __init__(
-         self,
-         vocab_size: int = 51200,
-         n_positions: int = 2048,
-         n_embd: int = 2560,
-         n_layer: int = 32,
-         n_inner: Optional[int] = None,
-         n_head: int = 32,
-         n_head_kv: Optional[int] = None,
-         rotary_dim: Optional[int] = 32,
-         activation_function: Optional[str] = "gelu_new",
-         flash_attn: bool = False,
-         flash_rotary: bool = False,
-         fused_dense: bool = False,
-         attn_pdrop: float = 0.0,
-         embd_pdrop: float = 0.0,
-         resid_pdrop: float = 0.1,
-         layer_norm_epsilon: float = 1e-05,
-         initializer_range: float = 0.02,
-         tie_word_embeddings: bool = False,
-         pad_vocab_size_multiple: int = 64,
-         **kwargs
-     ) -> None:
-         self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
-         self.n_positions = n_positions
-         self.n_embd = n_embd
-         self.n_layer = n_layer
-         self.n_inner = n_inner
-         self.n_head = n_head
-         self.n_head_kv = n_head_kv
-         self.rotary_dim = min(rotary_dim, n_embd // n_head)
-         self.activation_function = activation_function
-         self.flash_attn = flash_attn
-         self.flash_rotary = flash_rotary
-         self.fused_dense = fused_dense
-         self.attn_pdrop = attn_pdrop
-         self.embd_pdrop = embd_pdrop
-         self.resid_pdrop = resid_pdrop
-         self.layer_norm_epsilon = layer_norm_epsilon
-         self.initializer_range = initializer_range
- 
-         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
- 
- 
- class CLIPVisionToPhiConfig(PretrainedConfig):
-     def __init__(self,
-                  vision_projector_config: VisionProjectorConfig,
-                  phi_config: CustomPhiConfig,
-                  **kwargs
-                  ):
- 
-         # super().__init__(**kwargs)
- 
-         self.vision_projector_config = vision_projector_config
-         self.phi_config = phi_config
-         self.tokenizer = kwargs.get('tokenizer')
-         self.freeze_phi_model = True
- 
- 
- '''
- phi_config_obj = CustomPhiConfig(
-     **{
-         "_name_or_path": "microsoft/phi-2",
-         "architectures": [
-             "PhiForCausalLM"
-         ],
-         "auto_map": {
-             "AutoConfig": "configuration_phi.PhiConfig",
-             "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
-         },
-         "img_processor": None,
-         "model_type": "phi-msft",
-         "torch_dtype": "float16",
-         "transformers_version": "4.35.2"
-     }
- 
- )
- 
- '''
- from peft import LoraConfig
- 
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.float16
- )
- 
- peft_config = LoraConfig(
-     lora_alpha=16,
-     lora_dropout=0.1,
-     r=64,
-     bias="none",
-     task_type="CAUSAL_LM",
-     target_modules=[
-         "q_proj",
-         "k_proj",
-         "v_proj",
-         "dense",
-         "fc1",
-         "fc2"
-     ]
- )
- 
- class MultiInstructModelConfig(PretrainedConfig):
-     def __init__(self,
-                  vision_projector_config: Optional[VisionProjectorConfig] = None,
-                  **kwargs
-                  ):
- 
-         self.vision_projector_config = vision_projector_config
-         self.quantization_config = bnb_config
- 
-         self.peft_config = peft_config
- 
-         self.tokenizer = kwargs.get('tokenizer')
-         self.freeze_vision_projector = True
- 
- 
- extra = dict(
-     num_epochs=1,
-     resume=False,
-     data_dir='../data',
-     checkpoint_dir='../checkpoints',
-     max_seqlen=80,
-     batch_size=2,
-     live_image_processing=True,
-     vision_projector_file='/Users/piyushgrover/Downloads/old_vt_proj/vp_ckpt_0.pth',
-     validation_phase=False
- )
- 
- qlora_config = dict(
-     num_steps=1000,
+ app_config = dict(
      max_seqlen=512,
      max_caption_len=100,
-     batch_size=8,
-     micro_batch_size=2,
      data_dir='../data',
      output_dir="./results",
      vision_model=True,
      vision_projector_file='models/vision_projector/vp_ckpt_0.pth',
-     resume=False
+     phi_adapter_dir='models/phi_adapter'
  )
requirement.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ numpy
+ trl
+ transformers
+ accelerate
+ git+https://github.com/huggingface/peft.git
+ datasets
+ bitsandbytes
+ einops
+ wandb
+ git+https://github.com/m-bain/whisperx.git
+ scipy