seung275 committed
Commit 67bccaa · verified · 1 Parent(s): 71b066d

Upload 5 files
model/AnomalyGPT_models.py ADDED
@@ -0,0 +1,73 @@
import torch
from torch import nn
import numpy as np
# from datas.dataset_3d import *
from torch.nn import functional as F


class Normalize(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)


class LinearLayer(nn.Module):
    def __init__(self, dim_in, dim_out, k):
        super(LinearLayer, self).__init__()
        self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for i in range(k)])

    def forward(self, tokens):
        for i in range(len(tokens)):
            if len(tokens[i].shape) == 3:
                tokens[i] = tokens[i].transpose(0, 1)
                tokens[i] = self.fc[i](tokens[i][:, 1:, :])
            else:
                B, C, H, W = tokens[i].shape
                tokens[i] = self.fc[i](tokens[i].view(B, C, -1).permute(0, 2, 1).contiguous())
        return tokens


class PromptLearner(nn.Module):
    def __init__(self, dim_in, dim_out) -> None:
        super().__init__()
        self.meta_net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in * 4, kernel_size=3, padding=1),
            # nn.BatchNorm2d(dim_in * 4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 112 * 112

            nn.Conv2d(dim_in * 4, dim_in * 16, kernel_size=3, padding=1),
            # nn.BatchNorm2d(dim_in * 16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 56 * 56

            nn.Conv2d(dim_in * 16, dim_in * 64, kernel_size=3, padding=1),
            # nn.BatchNorm2d(dim_in * 64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28 * 28

            nn.Conv2d(dim_in * 64, dim_in * 256, kernel_size=3, padding=1),
            # nn.BatchNorm2d(dim_in * 256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 14 * 14

            nn.Conv2d(dim_in * 256, dim_in * 1024, kernel_size=3, padding=1),
            # nn.BatchNorm2d(dim_in * 1024),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 7 * 7

            nn.Conv2d(dim_in * 1024, dim_out, kernel_size=5, padding=0),
            # nn.BatchNorm2d(dim_out),
            # nn.ReLU(inplace=True),
        )
        self.base_prompts = nn.Parameter(torch.randn((9, dim_out)), requires_grad=True)

    def forward(self, input):
        B, C, H, W = input.shape
        img_prompts = self.meta_net(input)
        # print(input.shape, img_prompts.shape)
        img_prompts = img_prompts.reshape(B, 4096, 9).transpose(-2, -1)
        output = torch.cat([self.base_prompts.expand(B, -1, -1), img_prompts], dim=1)
        return output
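The two modules above are small enough to sanity-check on their own. The snippet below is a minimal shape-check sketch and is not part of the commit; it assumes the 224 x 224 single-channel input implied by the 112/56/28/14/7 pooling comments, a batch of 16 x 16 feature maps for LinearLayer, and that the repo root is on PYTHONPATH.

# Hypothetical shape-check sketch (not part of the uploaded files).
import torch
from model.AnomalyGPT_models import LinearLayer, PromptLearner

decoder = LinearLayer(dim_in=1280, dim_out=1024, k=4)
feats = [torch.randn(2, 1280, 16, 16) for _ in range(4)]  # four B x C x H x W feature maps
tokens = decoder(feats)
print(tokens[0].shape)  # torch.Size([2, 256, 1024]) -- each map flattened to B x (H*W) x 1024

learner = PromptLearner(dim_in=1, dim_out=4096)
mask = torch.randn(2, 1, 224, 224)  # B x 1 x 224 x 224 input
prompts = learner(mask)
print(prompts.shape)  # torch.Size([2, 18, 4096]) -- 9 learned base prompts + 9 image-conditioned prompts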
model/__init__.py ADDED
@@ -0,0 +1,11 @@
from .agent import DeepSpeedAgent
from .openllama import OpenLLAMAPEFTModel
# from .openllama_CLIP import OpenLLAMAPEFTModel_CLIP
from .ImageBind import models

def load_model(args):
    agent_name = args['models'][args['model']]['agent_name']
    model_name = args['models'][args['model']]['model_name']
    model = globals()[model_name](**args)
    agent = globals()[agent_name](model, args)
    return agent
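load_model resolves both class names through globals(), so the configuration only needs to carry string names. A hypothetical minimal args layout is sketched below; only the keys actually read above are grounded in this file, and everything OpenLLAMAPEFTModel and DeepSpeedAgent consume via **args is omitted.

# Hypothetical config sketch for load_model; key values are illustrative, not from the repo.
args = {
    'model': 'openllama_peft',
    'models': {
        'openllama_peft': {
            'model_name': 'OpenLLAMAPEFTModel',  # looked up via globals()
            'agent_name': 'DeepSpeedAgent',      # looked up via globals()
        }
    },
    # ...plus the checkpoint paths, LoRA settings, etc. that the model and agent read from **args
}
# agent = load_model(args)  # would only work once those remaining keys are supplied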
model/agent.py ADDED
@@ -0,0 +1,81 @@
from header import *

class DeepSpeedAgent:

    def __init__(self, model, args):
        super(DeepSpeedAgent, self).__init__()
        self.args = args
        self.model = model
        self.load_stage_1_parameters(args["delta_ckpt_path"])

        for name, param in self.model.named_parameters():
            param.requires_grad = False

        for name, param in self.model.image_decoder.named_parameters():
            param.requires_grad = True

        for name, param in self.model.prompt_learner.named_parameters():
            param.requires_grad = True

        # load config parameters of deepspeed
        ds_params = json.load(open(self.args['ds_config_path']))
        ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
        ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
        self.ds_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=self.model,
            model_parameters=self.model.parameters(),
            config_params=ds_params,
            dist_init_required=True,
            args=types.SimpleNamespace(**args)
        )

    @torch.no_grad()
    def predict(self, batch):
        self.model.eval()
        string = self.model.generate_one_sample(batch)
        return string

    def train_model(self, batch, current_step=0, pbar=None):
        self.ds_engine.module.train()
        loss, mle_acc = self.ds_engine(batch)

        self.ds_engine.backward(loss)
        self.ds_engine.step()
        pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
        pbar.update(1)
        if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
            elapsed = pbar.format_dict['elapsed']
            rate = pbar.format_dict['rate']
            remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
            remaining = str(datetime.timedelta(seconds=remaining))
            logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')

        mle_acc *= 100
        return mle_acc

    def save_model(self, path, current_step):
        # only save trainable model parameters
        param_grad_dic = {
            k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters()
        }
        state_dict = self.ds_engine.module.state_dict()
        checkpoint = OrderedDict()
        for k, v in self.ds_engine.module.named_parameters():
            if v.requires_grad:
                print(k)
                checkpoint[k] = v
        torch.save(checkpoint, f'{path}/pytorch_model.pt')
        # save tokenizer
        self.model.llama_tokenizer.save_pretrained(path)
        # save configuration
        self.model.llama_model.config.save_pretrained(path)
        print(f'[!] save model into {path}')

    def load_stage_1_parameters(self, path):
        delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
        self.model.load_state_dict(delta_ckpt, strict=False)
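The training script that drives this agent is not part of the upload, but the methods above imply a simple step loop. The sketch below is an assumption-heavy illustration: the args keys, train_loader, and save_step used here are placeholders, not code from this commit.

# Hypothetical driver loop for DeepSpeedAgent (no train.py is included in this upload).
from tqdm import tqdm
from model import load_model

agent = load_model(args)                      # the DeepSpeed engine is built inside __init__
pbar = tqdm(total=args['total_steps'])
for step, batch in enumerate(train_loader):   # train_loader: assumed DataLoader of batches
    agent.train_model(batch, current_step=step, pbar=pbar)
    if step > 0 and step % args['save_step'] == 0:
        agent.save_model(args['save_path'], step)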
model/modeling_llama.py ADDED
@@ -0,0 +1,755 @@
# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if query_embeds is not None:
            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
            batch_size, seq_length, _ = inputs_embeds.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            query_embeds=query_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
                query_embeds = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "query_embeds": query_embeds,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
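The functional change relative to the upstream modeling_llama.py this file is based on is the optional query_embeds argument, which LlamaModel.forward concatenates in front of the token embeddings and which LlamaForCausalLM and prepare_inputs_for_generation pass through. The sketch below is not part of the commit; it uses a tiny randomly initialized config (no pretrained weights) and assumes the repo root is on PYTHONPATH, just to show the effect on sequence length.

# Sketch: query_embeds are simply prepended, so 3 query vectors + 5 tokens -> 8 positions.
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from model.modeling_llama import LlamaModel

cfg = LlamaConfig(vocab_size=128, hidden_size=32, intermediate_size=64,
                  num_hidden_layers=2, num_attention_heads=4,
                  max_position_embeddings=64, pad_token_id=0)
tiny = LlamaModel(cfg).eval()

input_ids = torch.randint(1, 128, (1, 5))   # 5 text tokens
query_embeds = torch.randn(1, 3, 32)        # 3 injected embeddings (e.g. projected image features)
out = tiny(input_ids=input_ids, query_embeds=query_embeds, return_dict=True)
print(out.last_hidden_state.shape)          # torch.Size([1, 8, 32])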
model/openllama.py ADDED
@@ -0,0 +1,755 @@
1
+ from header import *
2
+ import torch.nn.functional as F
3
+ from .ImageBind import *
4
+ from .ImageBind import data
5
+ from .modeling_llama import LlamaForCausalLM
6
+ from .AnomalyGPT_models import LinearLayer, PromptLearner
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+ from utils.loss import FocalLoss, BinaryDiceLoss
9
+ import kornia as K
10
+
11
+ import torch
12
+ from torch.nn.utils import rnn
13
+ from transformers import AutoConfig, AutoModelForCausalLM
14
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
15
+
16
+ CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
17
+ 'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum']
18
+
19
+ prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
20
+ prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
21
+
22
+ prompt_state = [prompt_normal, prompt_abnormal]
23
+ prompt_templates = ['a photo of a {}.', 'a photo of the {}.']
24
+ # prompt_templates = [
25
+ # 'a cropped photo of the {}.', 'a cropped photo of a {}.', 'a close-up photo of a {}.', 'a close-up photo of the {}.',
26
+ # 'a bright photo of the {}.', 'a bright photo of a {}.', 'a dark photo of a {}.', 'a dark photo of the {}.',
27
+ # 'a dark photo of the {}.', 'a dark photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a jpeg corrupted photo of the {}.',
28
+ # 'a blurry photo of the {}.', 'a blurry photo of a {}.', 'a photo of a {}.', 'a photo of the {}.',
29
+ # 'a photo of the small {}.', 'a photo of a small {}.', 'a photo of the large {}.', 'a photo of a large {}.',
30
+ # 'a photo of the {} for visual insprction.', 'a photo of a {} for visual insprction.',
31
+ # 'a photo of the {} for anomaly detection.', 'a photo of a {} for anomaly detection.'
32
+ # ]
33
+ objs = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
34
+ 'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum', 'macaroni1', 'macaroni2','pcb1', 'pcb2', 'pcb3', 'pcb4', 'capsules']
35
+
36
+ prompt_sentences = {}
37
+
38
+ for obj in objs:
39
+ prompt_sentence_obj = []
40
+ for i in range(len(prompt_state)):
41
+ prompted_state = [state.format(obj) for state in prompt_state[i]]
42
+ prompted_sentence = []
43
+ for s in prompted_state:
44
+ for template in prompt_templates:
45
+ prompted_sentence.append(template.format(s))
46
+ prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.cuda.current_device())#torch.cuda.current_device())
47
+ prompt_sentence_obj.append(prompted_sentence)
48
+ prompt_sentences[obj] = prompt_sentence_obj
49
+
50
+
51
+
52
+ def encode_text_with_prompt_ensemble(model, obj, device):
53
+
54
+ global prompt_sentences
55
+ normal_sentences = []
56
+ abnormal_sentences = []
57
+ for idx in range(len(obj)):
58
+ sentence = prompt_sentences[obj[idx].replace('_', ' ')]
59
+ normal_sentences.append(sentence[0])
60
+ abnormal_sentences.append(sentence[1])
61
+
62
+ normal_sentences = torch.cat(normal_sentences).to(device)
63
+ abnormal_sentences = torch.cat(abnormal_sentences).to(device)
64
+
65
+ class_embeddings_normal = model({ModalityType.TEXT: normal_sentences})[ModalityType.TEXT][0]
66
+ class_embeddings_abnormal = model({ModalityType.TEXT: abnormal_sentences})[ModalityType.TEXT][0]
67
+ # class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
68
+
69
+ class_embeddings_normal = class_embeddings_normal.reshape((len(obj), len(prompt_templates) * len(prompt_normal), 1024))
70
+ class_embeddings_normal = class_embeddings_normal.mean(dim=1, keepdim=True)
71
+ class_embeddings_normal = class_embeddings_normal / class_embeddings_normal.norm(dim=-1, keepdim=True)
72
+
73
+ class_embeddings_abnormal = class_embeddings_abnormal.reshape((len(obj), len(prompt_templates) * len(prompt_abnormal), 1024))
74
+ class_embeddings_abnormal = class_embeddings_abnormal.mean(dim=1, keepdim=True)
75
+ class_embeddings_abnormal = class_embeddings_abnormal / class_embeddings_abnormal.norm(dim=-1, keepdim=True)
76
+
77
+ text_features = torch.cat([class_embeddings_normal, class_embeddings_abnormal], dim=1)
78
+
79
+ return text_features
80
+
81
+
82
+
83
+ class StoppingCriteriaSub(StoppingCriteria):
84
+
85
+ def __init__(self, stops = [], encounters=1):
86
+ super().__init__()
87
+ self.stops = stops
88
+ self.ENCOUNTERS = encounters
89
+
90
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
91
+ stop_count = 0
92
+ for stop in self.stops:
93
+ stop_count = (stop == input_ids[0]).sum().item()
94
+ if stop_count >= self.ENCOUNTERS:
95
+ return True
96
+ return False
97
+
98
+ def build_one_instance(tokenizer, conversation):
99
+ text_list = []
100
+ turn_num = len(conversation)
101
+ input_ids, target_ids = [], []
102
+ for i in range(turn_num):
103
+ turn = conversation[i]
104
+ role = turn['from']
105
+ if i == 0: # the first human turn
106
+ assert role == 'human'
107
+ text = turn['value'] + '\n### Assistant:'
108
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
109
+ input_ids += one_input_id
110
+ target_ids += [-100]*len(one_input_id) # do not perform loss regression on human prompt
111
+ else:
112
+ if role == 'human':
113
+ text = 'Human: ' + turn['value'] + '\n### Assistant:'
114
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
115
+ input_ids += one_input_id
116
+ target_ids += [-100]*len(one_input_id)
117
+ elif role == 'gpt':
118
+ text = turn['value'] + '\n###'
119
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
120
+ input_ids += one_input_id
121
+ target_ids += one_input_id
122
+ else:
123
+ raise Exception('Wrong Role!!!')
124
+ text_list.append(text)
125
+ assert len(input_ids) == len(target_ids)
126
+ return text_list, input_ids, target_ids
127
+
128
+ def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len):
129
+ batch_input_ids, batch_target_ids = [], []
130
+ for conversation in batch_of_conversations:
131
+ _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation)
132
+ batch_input_ids.append(torch.LongTensor(one_input_ids))
133
+ batch_target_ids.append(torch.LongTensor(one_target_ids))
134
+ input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
135
+ target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
136
+ assert input_ids.size() == target_ids.size()
137
+ input_ids = input_ids[:,:max_tgt_len]
138
+ target_ids = target_ids[:,:max_tgt_len]
139
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
140
+ assert attention_mask.size() == input_ids.size()
141
+ return input_ids, target_ids, attention_mask.long()
142
+
143
+ def find_first_file_in_directory(directory_path):
144
+ try:
145
+ file_list = os.listdir(directory_path)
146
+ for item in file_list:
147
+ item_path = os.path.join(directory_path, item)
148
+ if os.path.isfile(item_path):
149
+ return item_path
150
+ return None
151
+
152
+ except OSError as e:
153
+ print(f"Error while accessing directory: {e}")
154
+ return None
155
+
156
+
157
+ PROMPT_START = '### Human: <Img>'
158
+ class OpenLLAMAPEFTModel(nn.Module):
159
+
160
+ '''LoRA for LLaMa model'''
161
+
162
+ def __init__(self, **args):
163
+ super(OpenLLAMAPEFTModel, self).__init__()
164
+ self.args = args
165
+ imagebind_ckpt_path = args['imagebind_ckpt_path']
166
+ vicuna_ckpt_path = args['vicuna_ckpt_path']
167
+ max_tgt_len = args['max_tgt_len']
168
+ stage = args['stage']
169
+
170
+ self.device = torch.cuda.current_device()
171
+
172
+ print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
173
+
174
+ self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
175
+ self.visual_encoder.to(torch.float16).to(self.device)
176
+ imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
177
+ self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
178
+
179
+
180
+ self.iter = 0
181
+
182
+ self.image_decoder = LinearLayer(1280, 1024, 4).to(torch.float16).to(self.device)
183
+
184
+ self.prompt_learner = PromptLearner(1, 4096).to(torch.float16).to(self.device)
185
+
186
+ self.loss_focal = FocalLoss()
187
+ self.loss_dice = BinaryDiceLoss()
188
+
189
+
190
+ # free vision encoder
191
+ for name, param in self.visual_encoder.named_parameters():
192
+ param.requires_grad = False
193
+ self.visual_encoder.eval()
194
+ print ('Visual encoder initialized.')
195
+
196
+ print (f'Initializing language decoder from {vicuna_ckpt_path} ...')
197
+
198
+ # add the lora module
199
+ peft_config = LoraConfig(
200
+ task_type=TaskType.CAUSAL_LM,
201
+ inference_mode=False,
202
+ r=self.args['lora_r'],
203
+ lora_alpha=self.args['lora_alpha'],
204
+ lora_dropout=self.args['lora_dropout'],
205
+ target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
206
+ )
207
+
208
+ # config = AutoConfig.from_pretrained(vicuna_ckpt_path)
209
+ # with init_empty_weights():
210
+ # self.llama_model = AutoModelForCausalLM.from_config(config)
211
+
212
+ # # device_map = infer_auto_device_map(self.llama_model, no_split_module_classes=["OPTDecoderLayer"], dtype="float16")
213
+ # # print(device_map)
214
+ device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10.self_attn': 0, 'model.layers.10.mlp.gate_proj': 0, 'model.layers.10.mlp.down_proj': 'cpu', 'model.layers.10.mlp.up_proj': 'cpu', 'model.layers.10.mlp.act_fn': 'cpu', 'model.layers.10.input_layernorm': 'cpu', 'model.layers.10.post_attention_layernorm': 'cpu', 'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu', 'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu', 'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu', 'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu', 'model.layers.23': 'cpu', 'model.layers.24': 'disk', 'model.layers.25': 'disk', 'model.layers.26': 'disk', 'model.layers.27': 'disk', 'model.layers.28': 'disk', 'model.layers.29': 'disk', 'model.layers.30': 'disk', 'model.layers.31.self_attn': 'disk', 'model.layers.31.mlp.gate_proj': 'disk', 'model.layers.31.mlp.down_proj': 'disk', 'model.layers.31.mlp.up_proj': 'disk', 'model.layers.31.mlp.act_fn': 'disk', 'model.layers.31.input_layernorm': 'disk', 'model.layers.31.post_attention_layernorm': 'disk', 'model.norm': 'disk', 'lm_head': 'disk'}
215
+ # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
216
+ # # self.llama_model.to(torch.float16)
217
+ # # try:
218
+ self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True)
219
+ # # except:
220
+ # pass
221
+ # finally:
222
+ # print(self.llama_model.hf_device_map)
223
+ self.llama_model = get_peft_model(self.llama_model, peft_config)
224
+ # delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
225
+ # self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
+ self.llama_model.print_trainable_parameters()
227
+
228
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16)
229
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
230
+ self.llama_tokenizer.padding_side = "right"
231
+ print ('Language decoder initialized.')
232
+
233
+ self.llama_proj = nn.Linear(
234
+ self.visual_hidden_size, self.llama_model.config.hidden_size
235
+ ).to(torch.float16).to(self.device)
236
+
237
+ self.max_tgt_len = max_tgt_len
238
+
239
+
240
+
241
+ def rot90_img(self,x,k):
242
+ # k is 0,1,2,3
243
+ degreesarr = [0., 90., 180., 270., 360]
244
+ degrees = torch.tensor(degreesarr[k]).to(self.llama_model.dtype).to(self.device)
245
+ x = K.geometry.transform.rotate(x, angle = degrees, padding_mode='reflection')
246
+ return x
247
+
248
+ def encode_video(self, video_paths):
249
+ inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
250
+ # convert into visual dtype
251
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
252
+ with torch.no_grad():
253
+ embeddings = self.visual_encoder(inputs)
254
+ video_embeds = embeddings[ModalityType.VISION][0] # bsz x 1024
255
+ inputs_llama = self.llama_proj(video_embeds).unsqueeze(1) # bsz x 1 x llama_size
256
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
257
+ return inputs_llama, atts_llama
258
+
259
+ def encode_audio(self, audio_paths):
260
+ inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
261
+ # convert into visual dtype
262
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
263
+ with torch.no_grad():
264
+ embeddings = self.visual_encoder(inputs)
265
+ audio_embeds = embeddings[ModalityType.AUDIO][0] # bsz x 1024
266
+ inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1) # bsz x 1 x llama_size
267
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
268
+ return inputs_llama, atts_llama
269
+
270
+ def encode_thermal(self, thermal_paths):
271
+ inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
272
+ # convert into visual dtype
273
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
274
+ with torch.no_grad():
275
+ embeddings = self.visual_encoder(inputs)
276
+ image_embeds = embeddings['thermal'][0] # bsz x 1024
277
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
278
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
279
+ return inputs_llama, atts_llama
280
+
281
+ def encode_image(self, image_paths):
282
+ inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
283
+ # convert into visual dtype
284
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
285
+ with torch.no_grad():
286
+ embeddings = self.visual_encoder(inputs)
287
+ image_embeds = embeddings['vision'][0] # bsz x 1024
288
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
289
+ patch_tokens = self.image_decoder(patch_features) # bsz x h*w x 1024
290
+
291
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
292
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
293
+ return inputs_llama, atts_llama, patch_tokens
294
+
295
+ def encode_image_for_web_demo(self, image_paths):
296
+ inputs = {ModalityType.VISION: data.load_and_transform_vision_data_for_web_demo(image_paths, self.device)}
297
+ # convert into visual dtype
298
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
299
+ with torch.no_grad():
300
+ embeddings = self.visual_encoder(inputs)
301
+ image_embeds = embeddings['vision'][0] # bsz x 1024
302
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
303
+ patch_tokens = self.image_decoder(patch_features) # bsz x h*w x 1024
304
+
305
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
306
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
307
+ return inputs_llama, atts_llama, patch_tokens
308
+
309
+ def encode_image_for_one_shot(self, image_paths):
310
+ inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
311
+ # convert into visual dtype
312
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
313
+ with torch.no_grad():
314
+ embeddings = self.visual_encoder(inputs)
315
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
316
+ for i in range(len(patch_features)):
317
+ patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :]
318
+
319
+ return patch_features
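+ # Each per-layer feature map is moved to batch-first and its leading class token is dropped,
+ # leaving the 256 (16 x 16) patch tokens that are later matched against a normal reference.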
320
+
321
+ def encode_image_for_one_shot_from_tensor(self, image_tensors):
322
+ if not isinstance(image_tensors, list):
323
+ image_tensors = [image_tensors]
324
+ inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
325
+ # convert into visual dtype
326
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
327
+ with torch.no_grad():
328
+ embeddings = self.visual_encoder(inputs)
329
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
330
+ for i in range(len(patch_features)):
331
+ patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :]
332
+
333
+ return patch_features
334
+
335
+ def encode_image_for_one_shot_with_aug(self, image_paths):
336
+ image_tensors = data.load_and_transform_vision_data(image_paths, self.device).to(self.llama_model.dtype)
337
+ B,C,H,W = image_tensors.shape
338
+ # print(B,C,H,W)
339
+
340
+ rotated_images = torch.zeros((4, B, C, H, W)).to(self.llama_model.dtype).to(self.device)
341
+
342
+
343
+ for j, degree in enumerate([0, 1, 2, 3]):
344
+ rotated_img = self.rot90_img(image_tensors, degree)
345
+ # store the rotated batch for this rotation index
346
+ rotated_images[j] = rotated_img
347
+
348
+ image_tensors = rotated_images.transpose(0,1).reshape(B * 4, C, H, W)
349
+
350
+ inputs = {ModalityType.VISION: image_tensors}
351
+ # dtype was already cast when the rotated batch was built above, so the dict is passed through unchanged
352
+ inputs = {key: inputs[key] for key in inputs}
353
+ with torch.no_grad():
354
+ embeddings = self.visual_encoder(inputs)
355
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
356
+ for i in range(len(patch_features)):
357
+ patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :].reshape(B,4,256,1280).reshape(B, 4 * 256, 1280)
358
+
359
+ return patch_features
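+ # Sketch of the resulting reference bank: the normal image is rotated by 0/90/180/270 degrees,
+ # so every layer contributes B x (4 * 256) x 1280 patch tokens for the nearest-neighbour lookup.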
360
+
361
+ def encode_image_from_tensor(self, image_tensors):
362
+ if not isinstance(image_tensors, list):
363
+ image_tensors = [image_tensors]
364
+ inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
365
+ # convert into visual dtype
366
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
367
+ with torch.no_grad():
368
+ embeddings = self.visual_encoder(inputs)
369
+ image_embeds = embeddings['vision'][0] # bsz x 1024
370
+ patch_features = embeddings['vision'][1] # bsz x h*w x 1280
371
+ patch_tokens = self.image_decoder(patch_features)
372
+
373
+
374
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
375
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
376
+ return inputs_llama, atts_llama, patch_tokens
377
+
378
+ def encode_image_from_tensor_no_patch(self, image_tensors):
379
+ if not isinstance(image_tensors, list):
380
+ image_tensors = [image_tensors]
381
+ inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
382
+ # convert into visual dtype
383
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
384
+ with torch.no_grad():
385
+ embeddings = self.visual_encoder(inputs)
386
+ image_embeds = embeddings['vision'][0] # bsz x 1024
387
+
388
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
389
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
390
+ return inputs_llama, atts_llama
391
+
392
+
393
+
394
+ def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, anomaly_embedding = None):
395
+ '''
396
+ input_ids, target_ids, attention_mask: bsz x s2
397
+ '''
398
+ input_ids = input_ids.to(self.device) # bsz x s2
399
+ target_ids = target_ids.to(self.device) # bsz x s2
400
+ attention_mask = attention_mask.to(self.device) # bsz x s2
401
+
402
+ batch_size = img_embeds.shape[0]
403
+ p_before = PROMPT_START
404
+ p_before_tokens = self.llama_tokenizer(p_before,
405
+ return_tensors="pt", add_special_tokens=False).to(self.device)
406
+ # the PEFT-wrapped model needs the nested .model.model path to reach embed_tokens
407
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
408
+
409
+ p_middle = '</Img> '
410
+ p_middle_tokens = self.llama_tokenizer(p_middle,
411
+ return_tensors="pt", add_special_tokens=False).to(self.device)
412
+ # the PEFT-wrapped model needs the nested .model.model path to reach embed_tokens
413
+ p_middle_embeds = self.llama_model.model.model.embed_tokens(p_middle_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
414
+
415
+
416
+ p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
417
+ bos = torch.ones([batch_size, 1],
418
+ dtype=p_before_tokens.input_ids.dtype,
419
+ device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
420
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
421
+
422
+
423
+
424
+ if anomaly_embedding is not None:
425
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_middle_embeds, anomaly_embedding, p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim
426
+ # create targets
427
+ empty_targets = (
428
+ torch.ones([batch_size, 1+p_before_embeds.size()[1]+1+p_middle_embeds.size()[1] + anomaly_embedding.size()[1]], # 1 (bos) + s1 (prefix) + 1 (image token) + s_mid (</Img>) + anomaly prompt length
429
+ dtype=torch.long).to(self.device).fill_(-100)
430
+ ) # bsz x (1 + s1 + 1)
431
+ targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + 1 + s2)
432
+ assert inputs_embeds.size()[1] == targets.size()[1]
433
+
434
+ atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size()[1]+1+p_middle_embeds.size()[1] + anomaly_embedding.size()[1]], dtype=torch.long).to(self.device) # bsz x (1 + s1 +1)
435
+ attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
436
+ assert attention_mask.size() == targets.size() # bsz x (1 + s1 + 1 + s2)
437
+ return inputs_embeds, targets, attention_mask
438
+ else:
439
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_middle_embeds, p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim
440
+ # create targets
441
+ empty_targets = (
442
+ torch.ones([batch_size, 1+p_before_embeds.size()[1]+1+p_middle_embeds.size()[1]], # 1 (bos) + s1 (prefix) + 1 (image token) + s_mid (</Img>)
443
+ dtype=torch.long).to(self.device).fill_(-100)
444
+ ) # bsz x (1 + s1 + 1)
445
+ targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + 1 + s2)
446
+ assert inputs_embeds.size()[1] == targets.size()[1]
447
+
448
+ atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size()[1]+1+p_middle_embeds.size()[1]], dtype=torch.long).to(self.device) # bsz x (1 + s1 +1)
449
+ attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
450
+ assert attention_mask.size() == targets.size() # bsz x (1 + s1 + 1 + s2)
451
+ return inputs_embeds, targets, attention_mask
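+ # Layout sketch of the sequence assembled above:
+ # [bos] [PROMPT_START prefix] [image token] [</Img>] [anomaly prompt tokens, if given] [conversation tokens]
+ # The target tensor fills everything before the conversation tokens with -100, so only the
+ # tokenised text contributes to the language-modelling loss.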
452
+
453
+
454
+ def forward(self, inputs):
455
+
456
+ if 'masks' in inputs:
457
+
458
+ image_paths = inputs['images'] # despite the name, these are already-transformed image tensors
459
+ img_embeds, _, patch_tokens = self.encode_image_from_tensor(image_paths)
460
+ class_name = inputs['class_names']
461
+
462
+ loss_pixel = 0
463
+ feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object' for _ in class_name], self.device)
464
+
465
+ anomaly_maps = []
466
+ for layer in range(len(patch_tokens)):
467
+ patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
468
+ # print(patch_tokens[layer].shape)
469
+ # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
470
+ anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
471
+ B, L, C = anomaly_map.shape
472
+ H = int(np.sqrt(L))
473
+ anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
474
+ size=224, mode='bilinear', align_corners=True)
475
+ # anomaly_map_no_softmax = anomaly_map
476
+ anomaly_map = torch.softmax(anomaly_map, dim=1)
477
+ anomaly_maps.append(anomaly_map)
478
+ # anomaly_maps_ns.append(anomaly_map_no_softmax)
479
+
480
+ gt = inputs['masks']
481
+ gt = torch.stack(gt, dim=0).to(self.device)
482
+ gt = gt.squeeze()
483
+ # print(gt.max(), gt.min())
484
+ gt = (gt > 0.3).to(gt.dtype)  # binarize the ground-truth mask at a 0.3 threshold
485
+
486
+
487
+ for num in range(len(anomaly_maps)):
488
+ f_loss = self.loss_focal(anomaly_maps[num], gt)
489
+ d_loss = self.loss_dice(anomaly_maps[num][:, 1, :, :], gt)
490
+ loss_pixel = loss_pixel + f_loss + d_loss
491
+
492
+ for num in range(len(anomaly_maps)):
493
+ anomaly_maps[num] = anomaly_maps[num][:,1,:,:]
494
+
495
+ anomaly_map_all = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)
496
+
497
+ if random.randint(0,1) == 0 and len(inputs['img_paths']) == len(image_paths):
498
+
499
+ normal_paths = []
500
+ for path in inputs['img_paths']:
501
+ normal_path = path.replace('test', 'train')
502
+ normal_path = find_first_file_in_directory("/".join(normal_path.split('/')[:-2])+'/good')
503
+ normal_paths.append(normal_path)
504
+
505
+ print(normal_paths)
506
+ query_patch_tokens = self.encode_image_for_one_shot_from_tensor(image_paths)
507
+ normal_patch_tokens = self.encode_image_for_one_shot_with_aug(normal_paths)
508
+ sims = []
509
+ B = len(image_paths)
510
+
511
+ for i in range(len(query_patch_tokens)):
512
+ query_patch_tokens_reshaped = query_patch_tokens[i].view(B,256,1,1280)
513
+ normal_tokens_reshaped = normal_patch_tokens[i].reshape(B,1,-1,1280)
514
+ cosine_similarity_matrix = F.cosine_similarity(query_patch_tokens_reshaped, normal_tokens_reshaped, dim=-1)
515
+ sim_max, _ = torch.max(cosine_similarity_matrix, dim=-1)
516
+ sims.append(sim_max)
517
+
518
+ sim = torch.mean(torch.stack(sims,dim=0), dim=0).reshape(B,1,16,16)
519
+ sim = F.interpolate(sim,size=224, mode='bilinear', align_corners=True)
520
+ anomaly_map_all = 1 - sim # (anomaly_map_all + 1 - sim) / 2
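+ # One-shot branch: for every query patch the maximum cosine similarity to any patch of the
+ # rotation-augmented normal reference is taken, and 1 - similarity replaces the text-based
+ # anomaly map that is fed to the prompt learner.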
521
+
522
+ anomaly_map_prompts = self.prompt_learner(anomaly_map_all)
523
+
524
+ # img_embeds = img_embeds + anomaly_map_prompts
525
+
526
+ output_texts = inputs['texts']
527
+ input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
528
+ inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask, anomaly_map_prompts)
529
+
530
+ outputs = self.llama_model(
531
+ inputs_embeds=inputs_embeds,
532
+ attention_mask=attention_mask,
533
+ return_dict=True,
534
+ labels=targets,
535
+ )
536
+ loss = outputs.loss
537
+
538
+ # loss_l2 = torch.norm(anomaly_map_prompts / 2 , p=2)
539
+ # loss_l2 = nn.MSELoss()(img_embeds_origin, img_embeds)
540
+ # calculate the token accuracy
541
+ chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1]
542
+ # print(self.llama_tokenizer.decode(chosen_tokens[0], skip_special_tokens=True))
543
+ labels = targets[:, 2:]
544
+ gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S]
545
+ valid_mask = (labels != -100).reshape(-1)
546
+ # print(self.llama_tokenizer.decode(chosen_tokens.reshape(-1)[valid_mask], skip_special_tokens=True))
547
+ valid_tokens = gen_acc & valid_mask # [B*S]
548
+ gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
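+ # Token accuracy: logits at position t predict token t + 1 and position 0 corresponds to the
+ # bos embedding, so logits[:, 1:-1] are compared with targets[:, 2:]; positions labelled -100
+ # are excluded via valid_mask.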
549
+
550
+ return loss + loss_pixel, gen_acc
551
+
552
+ else:
553
+
554
+ image_paths = inputs['image_paths']
555
+ img_embeds, _, patch_tokens = self.encode_image_from_tensor(image_paths)
556
+
557
+ output_texts = inputs['output_texts']
558
+
559
+ c_name = 'object'
560
+ for name in CLASS_NAMES:
561
+ if name in output_texts:
562
+ c_name = name
563
+ break
564
+
565
+ feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object'] * len(image_paths), self.device)
566
+
567
+ anomaly_maps = []
568
+ for layer in range(len(patch_tokens)):
569
+ patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
570
+ # print(patch_tokens[layer].shape)
571
+ # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
572
+ anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
573
+ B, L, C = anomaly_map.shape
574
+ H = int(np.sqrt(L))
575
+ anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
576
+ size=224, mode='bilinear', align_corners=True)
577
+ # anomaly_map_no_softmax = anomaly_map
578
+ anomaly_map = torch.softmax(anomaly_map, dim=1)
579
+ anomaly_maps.append(anomaly_map)
580
+
581
+ for num in range(len(anomaly_maps)):
582
+ anomaly_maps[num] = anomaly_maps[num][:,1,:,:]
583
+
584
+ anomaly_map_all = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)
585
+
586
+ anomaly_map_prompts = self.prompt_learner(anomaly_map_all)
587
+
588
+ # img_embeds = img_embeds + anomaly_map_prompts
589
+
590
+ input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
591
+ inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask, anomaly_map_prompts)
592
+
593
+ outputs = self.llama_model(
594
+ inputs_embeds=inputs_embeds,
595
+ attention_mask=attention_mask,
596
+ return_dict=True,
597
+ labels=targets,
598
+ )
599
+ loss = outputs.loss
600
+ # calculate the token accuracy
601
+ chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1]
602
+ labels = targets[:, 2:]
603
+ gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S]
604
+ valid_mask = (labels != -100).reshape(-1)
605
+ valid_tokens = gen_acc & valid_mask # [B*S]
606
+ gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
607
+
608
+ return loss, gen_acc
609
+
610
+
611
+ def extract_multimodal_feature(self, inputs, web_demo):
612
+ features = []
613
+ if inputs['image_paths']:
614
+
615
+ prompt = inputs['prompt']
616
+ c_name = 'object'
617
+ for name in CLASS_NAMES:
618
+ if name in prompt:
619
+ c_name = name
620
+ break
621
+
622
+ if not web_demo:
623
+ image_embeds, _, patch_tokens = self.encode_image(inputs['image_paths'])
624
+ feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, [c_name], self.device)
625
+ else:
626
+ image_embeds, _, patch_tokens = self.encode_image_for_web_demo(inputs['image_paths'])
627
+ feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object'], self.device)
628
+
629
+ anomaly_maps = []
630
+ for layer in range(len(patch_tokens)):
631
+ patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
632
+ # print(patch_tokens[layer].shape)
633
+ # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
634
+ anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
635
+ B, L, C = anomaly_map.shape
636
+ H = int(np.sqrt(L))
637
+ # anomaly_map = anomaly_map.to(torch.float16)
638
+ anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
639
+ size=224, mode='bilinear', align_corners=True)
640
+ # anomaly_map = anomaly_map.to(torch.bfloat16)
641
+ anomaly_map = torch.softmax(anomaly_map, dim=1)
642
+ anomaly_maps.append(anomaly_map[:,1,:,:])
643
+
644
+ anomaly_map_ret = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)
645
+ # anomaly_map_all = anomaly_map_ret.unsqueeze(1).repeat((1,3,1,1))
646
+ # anomaly_map_feature, _, _ = self.encode_image_from_tensor(anomaly_map_all)
647
+ # image_embeds = anomaly_map_feature + image_embeds
648
+ if inputs['normal_img_paths']:
649
+ query_patch_tokens = self.encode_image_for_one_shot(inputs['image_paths'])
650
+ if any('mvtec' in p for p in inputs['normal_img_paths']): # rotation-augment the normal reference when it comes from MVTec-AD
651
+ normal_patch_tokens = self.encode_image_for_one_shot_with_aug(inputs['normal_img_paths'])
652
+ else:
653
+ normal_patch_tokens = self.encode_image_for_one_shot(inputs['normal_img_paths'])
654
+ sims = []
655
+
656
+ for i in range(len(query_patch_tokens)):
657
+ query_patch_tokens_reshaped = query_patch_tokens[i].view(256,1,1280)
658
+ normal_tokens_reshaped = normal_patch_tokens[i].reshape(1,-1,1280)
659
+ cosine_similarity_matrix = F.cosine_similarity(query_patch_tokens_reshaped, normal_tokens_reshaped, dim=2)
660
+ sim_max, _ = torch.max(cosine_similarity_matrix, dim=1)
661
+ sims.append(sim_max)
662
+
663
+ sim = torch.mean(torch.stack(sims,dim=0), dim=0).reshape(1,1,16,16)
664
+ # anomaly_map = anomaly_map.to(torch.float16)
665
+ sim = F.interpolate(sim,size=224, mode='bilinear', align_corners=True)
666
+ # anomaly_map = anomaly_map.to(torch.bfloat16)
667
+ anomaly_map_ret = 1 - sim # (anomaly_map_ret + 1 - sim) / 2
668
+
669
+
670
+ features.append(image_embeds)
671
+ if inputs['audio_paths']:
672
+ audio_embeds, _ = self.encode_audio(inputs['audio_paths'])
673
+ features.append(audio_embeds)
674
+ if inputs['video_paths']:
675
+ video_embeds, _ = self.encode_video(inputs['video_paths'])
676
+ features.append(video_embeds)
677
+ if inputs['thermal_paths']:
678
+ thermal_embeds, _ = self.encode_thermal(inputs['thermal_paths'])
679
+ features.append(thermal_embeds)
680
+
681
+ feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0)
682
+ return feature_embeds, anomaly_map_ret
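+ # When several modalities are supplied, their single-token embeddings are summed into one
+ # feature vector; the per-pixel anomaly map is returned alongside it and later surfaced as
+ # pixel_output by generate.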
683
+
684
+ def prepare_generation_embedding(self, inputs, web_demo):
685
+ prompt = inputs['prompt']
686
+ # if len(inputs['modality_embeds']) == 1:
687
+ # feature_embeds = inputs['modality_embeds'][0]
688
+ # else:
689
+ feature_embeds, anomaly_map = self.extract_multimodal_feature(inputs, web_demo)
690
+ # print(anomaly_map.shape)
691
+ inputs['modality_embeds'].append(feature_embeds)
692
+
693
+ batch_size = feature_embeds.shape[0]
694
+ p_before = PROMPT_START
695
+ p_before_tokens = self.llama_tokenizer(p_before,
696
+ return_tensors="pt", add_special_tokens=False).to(self.device)
697
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
698
+
699
+ p_middle = '</Img> '
700
+ p_middle_tokens = self.llama_tokenizer(p_middle,
701
+ return_tensors="pt", add_special_tokens=False).to(self.device)
702
+ # the PEFT-wrapped model needs the nested .model.model path to reach embed_tokens
703
+ p_middle_embeds = self.llama_model.model.model.embed_tokens(p_middle_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
704
+
705
+ # self.prompt_learner.eval()
706
+ anomaly_map_prompts = self.prompt_learner(anomaly_map)
707
+
708
+
709
+
710
+
711
+ text = prompt + '\n### Assistant:'
712
+ p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
713
+ p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
714
+ bos = torch.ones([batch_size, 1],
715
+ dtype=p_before_tokens.input_ids.dtype,
716
+ device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
717
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
718
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_middle_embeds, anomaly_map_prompts, p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim
719
+
720
+ return inputs_embeds, anomaly_map
721
+
722
+ def generate(self, inputs, web_demo=False):
723
+ '''
724
+ inputs = {
725
+ 'image_paths': optional,
726
+ 'audio_paths': optional
727
+ 'video_paths': optional
728
+ 'thermal_paths': optional
729
+ 'mode': generation mode,
730
+ 'prompt': human input prompt,
731
+ 'max_tgt_len': generation length,
732
+ 'top_p': top_p,
733
+ 'temperature': temperature
734
+ 'modality_embeds': None or torch.tensor
735
+ 'modality_cache': save the image cache
736
+ }
737
+ '''
738
+ # self.prompt_learner.eval()
739
+ # self.llama_model.eval()
740
+ # self.llama_proj.eval()
741
+ # self.image_decoder.eval()
742
+ # self.llama_tokenizer.eval()
743
+ input_embeds, pixel_output = self.prepare_generation_embedding(inputs, web_demo)
744
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
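+ # Generation halts once the hard-coded stop id is produced (intended to catch the turn
+ # separator of the Vicuna-style prompt); the final two tokens are then dropped before decoding.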
745
+ outputs = self.llama_model.generate(
746
+ inputs_embeds=input_embeds,
747
+ max_new_tokens=inputs['max_tgt_len'],
748
+ top_p=inputs['top_p'],
749
+ temperature=inputs['temperature'],
750
+ do_sample=True,
751
+ use_cache=True,
752
+ stopping_criteria=stopping_criteria,
753
+ )
754
+ output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
755
+ return output_text, pixel_output