arthur-qiu committed
Commit f833804 · Parent: cca304f

fix filter

Files changed (2):
  1. scale_attention.py (+6 -3)
  2. scale_attention_turbo.py (+6 -3)
scale_attention.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
 count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
 attn_output = torch.where(count>0, value/count, value)

+ attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
 gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
 attention_mask=attention_mask,
 **cross_attention_kwargs,
 )
- attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)

 gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output = gaussian_local + (attn_output_global - gaussian_global)
- attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+ attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')

 elif fourg_window:
 norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
 count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
 attn_output = torch.where(count>0, value/count, value)

+ attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
 gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

 value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(

 attn_output_global = torch.where(count>0, value/count, value)

+ attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
 gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output = gaussian_local + (attn_output_global - gaussian_global)
- attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+ attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')

 else:
 attn_output = self.attn1(
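
The change in both files is the same: gaussian_filter presumably operates on a channel-first feature map (conv2d-style bh d h w layout), so the local and global attention outputs are rearranged from channel-last (bh h w d, or the bh (h w) d token layout) into bh d h w before filtering, and back to bh (h w) d afterward. Below is a minimal, hypothetical sketch of what a conv-based Gaussian filter with that layout assumption looks like; the function name and implementation are illustrative only and may differ from the repo's actual gaussian_filter.

import torch
import torch.nn.functional as F

def gaussian_filter_sketch(x, kernel_size=3, sigma=1.0):
    # Illustrative stand-in for a conv-based Gaussian blur (hypothetical name).
    # Assumes channel-first input of shape (bh, d, h, w), which is why the
    # commit rearranges 'bh h w d' -> 'bh d h w' before calling the filter.
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - (kernel_size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()                                   # normalized 1D Gaussian
    kernel2d = torch.outer(g, g)                      # separable 2D kernel
    channels = x.shape[1]
    weight = kernel2d.expand(channels, 1, kernel_size, kernel_size).contiguous()
    # Depthwise conv: each of the d channels is blurred independently over (h, w).
    return F.conv2d(x, weight, padding=kernel_size // 2, groups=channels)

If the filter is conv-based like this, feeding it a channel-last bh h w d tensor would blur over the wrong axes (treating h as channels), which is the layout mismatch the added rearranges avoid.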
scale_attention_turbo.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
 count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
 attn_output = torch.where(count>0, value/count, value)

+ attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
 gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
 attention_mask=attention_mask,
 **cross_attention_kwargs,
 )
- attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)

 gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output = gaussian_local + (attn_output_global - gaussian_global)
- attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+ attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')

 elif fourg_window:
 norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
 count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
 attn_output = torch.where(count>0, value/count, value)

+ attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
 gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

 value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(

 attn_output_global = torch.where(count>0, value/count, value)

+ attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
 gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

 attn_output = gaussian_local + (attn_output_global - gaussian_global)
- attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+ attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')

 else:
 attn_output = self.attn1(
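
For reference, the surrounding logic in both files blends the two filtered tensors so that the low-frequency band comes from the local (windowed) attention and the high-frequency residual comes from the global attention. A self-contained restatement of that pattern is sketched below; the helper name and the assumption that both inputs arrive as channel-last (bh, h, w, d) windows are mine, not the repo's.

from einops import rearrange

def blend_local_global(attn_local, attn_global, gaussian_filter, kernel_size, sigma=1.0):
    # Hypothetical helper mirroring the pattern in this commit (not a repo API).
    # Both inputs: channel-last attention outputs of shape (bh, h, w, d).
    local_dhw = rearrange(attn_local, 'bh h w d -> bh d h w')
    global_dhw = rearrange(attn_global, 'bh h w d -> bh d h w')
    low_local = gaussian_filter(local_dhw, kernel_size=kernel_size, sigma=sigma)
    low_global = gaussian_filter(global_dhw, kernel_size=kernel_size, sigma=sigma)
    # Low frequencies from the local pass, high frequencies from the global pass.
    blended = low_local + (global_dhw - low_global)
    # Back to the (bh, h*w, d) token layout expected by the rest of the block.
    return rearrange(blended, 'bh d h w -> bh (h w) d')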