openlamm committed
Commit 4fc00f2 · 1 Parent(s): 8478a70

Update model/openlamm.py

Files changed (1)
  1. model/openlamm.py +1 -44
model/openlamm.py CHANGED
@@ -203,7 +203,7 @@ class LAMMPEFTModel(nn.Module):
             target_modules=self.args['lora_target_modules']
         )
 
-        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, cache_dir='~/.cache/')
+        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
 
@@ -221,39 +221,6 @@ class LAMMPEFTModel(nn.Module):
         self.system_header = system_header
         self.device = torch.cuda.current_device()
 
-    # def encode_video(self, video_paths):
-    #     inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
-    #     # convert into visual dtype
-    #     inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
-    #     with torch.no_grad():
-    #         embeddings = self.visual_encoder(inputs)
-    #         video_embeds = embeddings[ModalityType.VISION]  # bsz x 1024
-    #     inputs_llama = self.llama_proj(video_embeds).unsqueeze(1)  # bsz x 1 x llama_size
-    #     atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
-    #     return inputs_llama, atts_llama
-
-    # def encode_audio(self, audio_paths):
-    #     inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
-    #     # convert into visual dtype
-    #     inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
-    #     with torch.no_grad():
-    #         embeddings = self.visual_encoder(inputs)
-    #         audio_embeds = embeddings[ModalityType.AUDIO]  # bsz x 1024
-    #     inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1)  # bsz x 1 x llama_size
-    #     atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
-    #     return inputs_llama, atts_llama
-
-    # def encode_thermal(self, thermal_paths):
-    #     inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
-    #     # convert into visual dtype
-    #     inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
-    #     with torch.no_grad():
-    #         embeddings = self.visual_encoder(inputs)
-    #         image_embeds = embeddings['thermal']  # bsz x 1024
-    #     inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
-    #     atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
-    #     return inputs_llama, atts_llama
-
     def encode_image(self, image_paths):
         """encode images to llama inputs
 
@@ -279,16 +246,6 @@ class LAMMPEFTModel(nn.Module):
 
     def my_encode_image(self, images):
         """encoder loaded image objects"""
-        # if self.encoder_pretrain == 'imagebind':
-        #     inputs = {ModalityType.VISION: data.transform_vision_data(images, self.device)}
-        #     # convert into visual dtype
-        #     inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
-        #     with torch.no_grad():
-        #         embeddings = self.visual_encoder(inputs)
-        #         image_embeds = embeddings['vision']  # bsz x 1024
-        #     inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
-        #     atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
-        #     return inputs_llama, atts_llama
         if self.encoder_pretrain == 'clip':
             inputs = data.transform_vision_data(images, self.device)  # bsz x 3 x 224 x 224
             inputs_llama = self.clip_encode_image(inputs)  # bsz x 1/256 x llama_size
 
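In short, the commit drops the hard-coded cache_dir='~/.cache/' from LlamaForCausalLM.from_pretrained, letting transformers fall back to its default Hugging Face cache location (controllable via environment variables such as HF_HOME) instead of pinning a path in code, and it deletes the commented-out ImageBind-era encoders (video, audio, thermal, and the imagebind branch of my_encode_image). A minimal sketch of the loading path as it stands after this commit is below; vicuna_ckpt_path and the LoRA hyperparameters are placeholders for what the repository actually reads from self.args, not values taken from this diff.

# Minimal sketch of the post-commit loading path, not the full LAMMPEFTModel.__init__.
# vicuna_ckpt_path, target_modules and the LoRA hyperparameters are assumed placeholders.
from transformers import LlamaForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

vicuna_ckpt_path = "path/to/vicuna-checkpoint"  # hypothetical path

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,                # assumed values; the repo takes these from self.args
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# No explicit cache_dir any more: from_pretrained uses the standard
# Hugging Face cache rather than the previously hard-coded '~/.cache/'.
llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)

# Wrap the base model with LoRA adapters so only the adapter weights train.
llama_model = get_peft_model(llama_model, peft_config)
llama_model.print_trainable_parameters()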