import torch


class SMOLLm_VISION_ImageCaptioning(torch.nn.Module):
    """Captions images by prepending a projected CLIP image embedding to the LLM's token embeddings."""

    def __init__(self, llm_model, hidden_dim):
        super().__init__()
        self.llm_model = llm_model
        # Project the 768-d CLIP image feature (e.g. CLIP ViT-L/14) to the LLM hidden size
        # (e.g. 960 for SmolLM-360M)
        self.fc = torch.nn.Linear(768, hidden_dim)
        self.act = torch.nn.GELU()

    def forward(self, images, input_ids, att):
        # Project the precomputed CLIP image features into the LLM embedding space
        image_features = self.act(self.fc(images))

        # Embed the caption tokens with the (frozen) LLM embedding table
        with torch.no_grad():
            text_embeds = self.llm_model.get_input_embeddings()(input_ids)

        # Prepend the image embedding as a single extra "token" in front of the text
        combined_inputs = torch.cat([image_features.unsqueeze(1).float(), text_embeds], dim=1)

        # Extend the attention mask with a 1 for the prepended image position so its
        # length matches the combined sequence
        image_att = torch.ones(images.shape[0], 1, device=att.device, dtype=att.dtype)
        attention_mask = torch.cat([image_att, att], dim=-1)

        outputs = self.llm_model(inputs_embeds=combined_inputs, attention_mask=attention_mask)
        # Drop the logit at the image position so the outputs align with the text tokens
        return outputs.logits[:, 1:, :], combined_inputs


class SmoLLM_processor():
    """Wraps a CLIP vision model and its processor to extract image features."""

    def __init__(self, image_model, image_processor):
        self.image_model = image_model
        self.image_processor = image_processor

    def get_features(self, image):
        inputs = self.image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.image_model.get_image_features(
                **inputs.to(self.image_model.device)
            ).squeeze()
        return image_features
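

# Minimal usage sketch (not part of the original code): it assumes SmolLM-360M
# (hidden size 960) as the language model and CLIP ViT-L/14 (768-d image features)
# as the vision encoder. The checkpoint names, the "example.jpg" path, and the
# prompt below are illustrative assumptions, not values from the original.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor
    from PIL import Image

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    llm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M").to(device)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    model = SMOLLm_VISION_ImageCaptioning(llm, hidden_dim=960).to(device)
    processor = SmoLLM_processor(clip_model, clip_processor)

    # Extract a single 768-d image feature and add a batch dimension
    image = Image.open("example.jpg")  # placeholder path
    image_features = processor.get_features(image).unsqueeze(0)  # (1, 768)

    # Tokenize a caption prefix and run one forward pass
    tokenized = tokenizer("a photo of", return_tensors="pt").to(device)
    logits, _ = model(image_features, tokenized["input_ids"], tokenized["attention_mask"])
    print(logits.shape)  # (1, seq_len, vocab_size)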