piyushgrover committed on
Commit
7396aab
·
1 Parent(s): f4882bc
Files changed (4)
  1. config.py +173 -0
  2. constants.py +2 -0
  3. models/vision_projector_model.py +44 -0
  4. utils.py +151 -0
config.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ from transformers import PretrainedConfig, BitsAndBytesConfig
+ import math
+ from typing import Optional
+
+ class VisionProjectorConfig(PretrainedConfig):
+     def __init__(
+         self,
+         input_dim=768,
+         hidden_dim=256,
+         num_tokens=1,
+         output_dim=2560,
+         **kwargs
+     ):
+         # super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.output_dim = output_dim
+         self.num_tokens = num_tokens
+         self.kwargs = kwargs
+
+
+ class CustomPhiConfig(PretrainedConfig):
+     model_type = "phi-msft"
+     attribute_map = {
+         "max_position_embeddings": "n_positions",
+         "hidden_size": "n_embd",
+         "num_attention_heads": "n_head",
+         "num_hidden_layers": "n_layer",
+     }
+
+     def __init__(
+         self,
+         vocab_size: int = 51200,
+         n_positions: int = 2048,
+         n_embd: int = 2560,
+         n_layer: int = 32,
+         n_inner: Optional[int] = None,
+         n_head: int = 32,
+         n_head_kv: Optional[int] = None,
+         rotary_dim: Optional[int] = 32,
+         activation_function: Optional[str] = "gelu_new",
+         flash_attn: bool = False,
+         flash_rotary: bool = False,
+         fused_dense: bool = False,
+         attn_pdrop: float = 0.0,
+         embd_pdrop: float = 0.0,
+         resid_pdrop: float = 0.1,
+         layer_norm_epsilon: float = 1e-05,
+         initializer_range: float = 0.02,
+         tie_word_embeddings: bool = False,
+         pad_vocab_size_multiple: int = 64,
+         **kwargs
+     ) -> None:
+         self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_inner = n_inner
+         self.n_head = n_head
+         self.n_head_kv = n_head_kv
+         self.rotary_dim = min(rotary_dim, n_embd // n_head)
+         self.activation_function = activation_function
+         self.flash_attn = flash_attn
+         self.flash_rotary = flash_rotary
+         self.fused_dense = fused_dense
+         self.attn_pdrop = attn_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.resid_pdrop = resid_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+ class CLIPVisionToPhiConfig(PretrainedConfig):
+     def __init__(self,
+                  vision_projector_config: VisionProjectorConfig,
+                  phi_config: CustomPhiConfig,
+                  **kwargs
+                  ):
+
+         # super().__init__(**kwargs)
+
+         self.vision_projector_config = vision_projector_config
+         self.phi_config = phi_config
+         self.tokenizer = kwargs.get('tokenizer')
+         self.freeze_phi_model = True
+
+
+ '''
+ phi_config_obj = CustomPhiConfig(
+     **{
+         "_name_or_path": "microsoft/phi-2",
+         "architectures": [
+             "PhiForCausalLM"
+         ],
+         "auto_map": {
+             "AutoConfig": "configuration_phi.PhiConfig",
+             "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
+         },
+         "img_processor": None,
+         "model_type": "phi-msft",
+         "torch_dtype": "float16",
+         "transformers_version": "4.35.2"
+     }
+
+ )
+
+ '''
+ from peft import LoraConfig
+
+ bnb_config = BitsAndBytesConfig(  # 4-bit NF4 quantization (QLoRA-style) for the phi-2 backbone
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.float16
+ )
+
+ peft_config = LoraConfig(  # LoRA adapters on the attention and MLP projections
+     lora_alpha=16,
+     lora_dropout=0.1,
+     r=64,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=[
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "dense",
+         "fc1",
+         "fc2"
+     ]
+ )
+
+ class MultiInstructModelConfig(PretrainedConfig):
+     def __init__(self,
+                  vision_projector_config: Optional[VisionProjectorConfig] = None,
+                  **kwargs
+                  ):
+
+         self.vision_projector_config = vision_projector_config
+         self.quantization_config = bnb_config
+
+         self.peft_config = peft_config
+
+         self.tokenizer = kwargs.get('tokenizer')
+         self.freeze_vision_projector = True
+
+
+ extra = dict(
+     num_epochs=1,
+     resume=False,
+     data_dir='../data',
+     checkpoint_dir='../checkpoints',
+     max_seqlen=80,
+     batch_size=2,
+     live_image_processing=True,
+     vision_projector_file='/Users/piyushgrover/Downloads/old_vt_proj/vp_ckpt_0.pth',
+     validation_phase=False
+ )
+
+ qlora_config = dict(  # settings for the QLoRA fine-tuning stage
+     num_steps=1000,
+     max_seqlen=512,
+     max_caption_len=100,
+     batch_size=8,
+     micro_batch_size=2,
+     data_dir='../data',
+     output_dir="./results",
+     vision_model=True,
+     vision_projector_file='models/vision_projector/vp_ckpt_0.pth',
+     resume=False
+ )
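
For orientation, a minimal sketch of how the classes defined in config.py might be wired together. The tokenizer name and the call pattern are illustrative assumptions; the commit itself only defines the configs:

    from transformers import AutoTokenizer
    from config import (VisionProjectorConfig, CustomPhiConfig,
                        CLIPVisionToPhiConfig, MultiInstructModelConfig)

    # Assumed tokenizer source; phi-2 is implied by the phi-msft config above
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    # Project 768-dim vision features (the default input_dim) to phi-2's 2560-dim embeddings
    vp_config = VisionProjectorConfig(input_dim=768, output_dim=2560, num_tokens=1)
    phi_config = CustomPhiConfig()  # defaults mirror microsoft/phi-2

    stage1_config = CLIPVisionToPhiConfig(vp_config, phi_config, tokenizer=tokenizer)
    stage2_config = MultiInstructModelConfig(vision_projector_config=vp_config, tokenizer=tokenizer)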
constants.py ADDED
@@ -0,0 +1,2 @@
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
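
Both sentinels follow common multimodal-LLM conventions: IGNORE_INDEX (-100) matches the default ignore_index of PyTorch's cross-entropy loss, so label positions set to it contribute nothing to the loss, and IMAGE_TOKEN_INDEX (-200) marks where projected image features are to be spliced into a token sequence (see tokenizer_image_token in utils.py). A small, repo-independent illustration of the IGNORE_INDEX behaviour:

    import torch
    import torch.nn as nn

    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # -100 is also the default
    logits = torch.randn(4, 51200)                    # 4 positions, phi-2-sized vocab
    labels = torch.tensor([12, -100, 7, -100])        # only positions 0 and 2 are scored
    print(loss_fn(logits, labels))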
models/vision_projector_model.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ import torch.nn as nn
+
+ from config import VisionProjectorConfig
+
+ '''
+ class VisionProjector(nn.Module):
+
+     def __init__(self, config: VisionProjectorConfig):
+         super().__init__()
+         self.config = config
+         self.input_dim = config.input_dim
+         self.hidden_dim = config.hidden_dim
+         self.output_dim = config.output_dim
+         self.num_tokens = config.num_tokens
+
+         self.pre_norm = nn.LayerNorm(self.input_dim)
+
+         self.proj = nn.Sequential(
+             nn.GELU(),
+             nn.Linear(self.input_dim, self.num_tokens * self.output_dim)
+         )
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         x = self.proj(x)
+         x = x.reshape((-1, self.num_tokens, self.output_dim))
+         return x
+
+ '''
+
+ class VisionProjector(nn.Module):
+
+     def __init__(self, config: VisionProjectorConfig):
+         super().__init__()
+         self.config = config
+         self.input_dim = config.input_dim
+         self.output_dim = config.output_dim
+
+         self.proj = nn.Linear(self.input_dim, self.output_dim)
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
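
A brief, illustrative forward pass through the active VisionProjector. The batch size and patch count are arbitrary; only the 768 -> 2560 projection comes from the defaults in config.py:

    import torch
    from config import VisionProjectorConfig
    from models.vision_projector_model import VisionProjector

    projector = VisionProjector(VisionProjectorConfig(input_dim=768, output_dim=2560))

    clip_features = torch.randn(2, 49, 768)   # e.g. per-patch CLIP vision features
    phi_embeds = projector(clip_features)     # nn.Linear maps the last dim: 768 -> 2560
    print(phi_embeds.shape)                   # torch.Size([2, 49, 2560])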
utils.py ADDED
@@ -0,0 +1,151 @@
+ from constants import *
+ import torch
+ import torch.nn.functional as F
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+     def insert_separator(X, sep):
+         return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+     input_ids = []
+     offset = 0
+     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+         offset = 1
+         input_ids.append(prompt_chunks[0][0])
+     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+         input_ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == 'pt':
+             return torch.tensor(input_ids, dtype=torch.long)
+         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+     return input_ids
+
+
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+     Args:
+         logits: logits distribution of shape (batch size x vocabulary size)
+         top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
+         top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+     """
+     top_k = min(top_k, logits.size(-1))  # Safety check
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)  # map the mask back to vocabulary order (batch-safe)
+         logits[indices_to_remove] = filter_value
+     return logits
+
+ '''
+ def get_image_feature_for_vision_projector(image_url):
+     image_url = 'http://images.cocodataset.org/%s/%s' % (self.directory, self.image_indices_json[image_index])
+
+     image = Image.open(requests.get(image_url, stream=True).raw)
+     inputs = self.processor(images=image, return_tensors="pt")
+     x = self.model(**inputs, output_hidden_states=True)
+     image_feature = x.hidden_states[-2][:, 1:].squeeze(0).cpu().detach()
+ '''
+
+
+ def generate_output(model, tokenizer, length, input_ids=None, image_features=None, inputs_embeds=None, labels=None,
+                     temperature=1, top_k=0, top_p=0.0):
+     if inputs_embeds is None and (image_features is None or input_ids is None):
+         print("image_features or input_ids missing.. returning")
+         return
+
+     ie_size = inputs_embeds.size(1) - 1  # last position of the image+prompt prefix; later logits belong to generated tokens
+     inputs = inputs_embeds
+     predicted_tokens = []  # torch.tensor([[]]).to(device)
+
+     label_size = labels.size(1) if labels is not None else 0  # labels may be None at inference time
+     out = {}
+     if labels is None:
+         with torch.no_grad():
+             for idx in range(length):
+                 outputs = model.phi_model(inputs_embeds=inputs)
+                 logits = outputs['logits']
+                 next_token_logits = logits[:, -1, :] / temperature  # Apply temperature
+
+                 filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k,
+                                                         top_p=top_p)  # Apply top-k and/or top-p
+                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)  # Sample
+
+                 predicted_tokens.append(next_token)
+                 next_token_embed = model.text_embedding(next_token)
+                 inputs = torch.cat((inputs, next_token_embed), dim=1)
+
+             predicted_tokens = torch.cat([x.unsqueeze(1) for x in predicted_tokens], dim=1)
+             out['pred'] = predicted_tokens
+             out['logits'] = logits[:, ie_size:, :]
+
+             return out
+     else:
+         # traverse_len = labels.size(1) - inputs_embeds.size(1)
+         for idx in range(length):
+             outputs = model.phi_model(inputs_embeds=inputs)
+             logits = outputs['logits']
+             next_token_logits = logits[:, -1, :] / temperature  # Apply temperature
+
+             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k,
+                                                     top_p=top_p)  # Apply top-k and/or top-p
+             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)  # Sample
+
+             predicted_tokens.append(next_token)
+
+             tf_token = labels[:, idx:idx + 1].to(device)
+             tf_token_embed = model.text_embedding(tf_token)
+
+             inputs = torch.cat((inputs, tf_token_embed), dim=1)  # Teacher-force the label token as the next input
+
+         predicted_tokens = torch.cat([x.unsqueeze(1) for x in predicted_tokens], dim=1).to(device)
+         # predicted_token_logits = torch.cat([x.unsqueeze(1) for x in predicted_token_logits], dim=1).to(device)
+
+         out = dict(pred=predicted_tokens,
+                    logits=logits)
+
+         labels = labels.contiguous().type(torch.LongTensor).to(device)
+
+         logits = logits[:, ie_size:ie_size + label_size, :].contiguous()
+
+         loss = model.loss(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+         out['loss'] = loss
+
+         # model.train()
+
+         return out
+
+
+ def generate_with_logits(logits, temperature=1, top_k=0, top_p=0.0):
+     predicted_tokens = []
+
+     for idx in range(logits.size(1)):
+         next_token_logits = logits[:, idx, :] / temperature  # Apply temperature
+
+         filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k,
+                                                 top_p=top_p)  # Apply top-k and/or top-p
+         next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)  # Sample
+
+         predicted_tokens.append(next_token)
+
+     predicted_tokens = torch.cat([x.unsqueeze(1) for x in predicted_tokens], dim=1).to(device)
+
+     out = dict(pred=predicted_tokens,
+                logits=logits)
+     return out
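
Finally, a hedged usage sketch for the utilities above. The prompt string and the phi-2 tokenizer are illustrative assumptions; generate_output additionally needs the multimodal model (with phi_model, text_embedding and loss attributes), which is outside this commit:

    import torch
    from transformers import AutoTokenizer
    from utils import tokenizer_image_token, generate_with_logits

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)  # assumed tokenizer

    # '<image>' is replaced by IMAGE_TOKEN_INDEX (-200) in the returned ids
    ids = tokenizer_image_token("Describe this image: <image>", tokenizer, return_tensors='pt')
    print(ids)  # -200 sits where '<image>' was

    # Sample one token per position from a (batch, seq_len, vocab) logits tensor
    logits = torch.randn(1, 5, 51200)
    sampled = generate_with_logits(logits, temperature=0.8, top_k=50)
    print(sampled['pred'].shape)  # torch.Size([1, 5, 1])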