Spaces: Running on Zero
arthur-qiu committed
Commit · f833804
1 Parent(s): cca304f

fix filter

Files changed:
- scale_attention.py (+6 -3)
- scale_attention_turbo.py (+6 -3)
scale_attention.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
-        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
 
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     elif fourg_window:
         norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(
 
         attn_output_global = torch.where(count>0, value/count, value)
 
+        attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     else:
         attn_output = self.attn1(
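Every hunk makes the same correction: each tensor fed to gaussian_filter is first rearranged into channels-first (bh, d, h, w) layout, then rearranged back to the flattened (bh, (h w), d) token layout afterward. That strongly suggests gaussian_filter blurs over the last two axes, conv2d-style, so the previous channels-last (bh, h, w, d) input would have been smoothed across width and channels instead of height and width. The helper itself is not part of this diff; a minimal sketch of a gaussian_filter with that contract, assuming a depthwise blur built on F.conv2d (the layout follows from the rearranges added here, not from the repository's actual implementation), could look like this:

import torch
import torch.nn.functional as F

def gaussian_filter(x: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    # Hypothetical sketch: depthwise 2-D Gaussian blur over the last two
    # axes; expects channels-first input (B, C, H, W), as the commit implies.
    # Build a 1-D Gaussian and take its outer product to get the 2-D kernel.
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - (kernel_size - 1) / 2
    g = torch.exp(-coords**2 / (2 * sigma**2))
    g = g / g.sum()
    kernel_2d = torch.outer(g, g)
    # One copy of the kernel per channel; groups=C makes the conv depthwise.
    c = x.shape[1]
    weight = kernel_2d.repeat(c, 1, 1, 1)  # (C, 1, k, k)
    # kernel_size = 2*current_scale_num - 1 is odd, so this padding preserves H and W.
    return F.conv2d(x, weight, padding=kernel_size // 2, groups=c)

Under this contract, a channels-last (bh, h, w, d) input would be treated as bh images with h channels, which is exactly the mismatch the commit message ("fix filter") addresses.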
scale_attention_turbo.py CHANGED

@@ -159,6 +159,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output_global = self.attn1(
@@ -167,12 +168,12 @@ def scale_forward(
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
-        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh d h w', h = latent_h)
 
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     elif fourg_window:
         norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
@@ -198,6 +199,7 @@ def scale_forward(
         count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
         attn_output = torch.where(count>0, value/count, value)
 
+        attn_output = rearrange(attn_output, 'bh h w d -> bh d h w')
         gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         value = torch.zeros_like(norm_hidden_states)
@@ -219,10 +221,11 @@ def scale_forward(
 
         attn_output_global = torch.where(count>0, value/count, value)
 
+        attn_output_global = rearrange(attn_output_global, 'bh h w d -> bh d h w')
         gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
 
         attn_output = gaussian_local + (attn_output_global - gaussian_global)
-        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+        attn_output = rearrange(attn_output, 'bh d h w -> bh (h w) d')
 
     else:
         attn_output = self.attn1(
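The turbo variant receives the identical fix. For reference, the layout round trip the commit introduces is a pure permutation-and-reshape, so no information is lost around the filtering step; a small self-contained check (sizes are arbitrary, chosen only for illustration):

import torch
from einops import rearrange

bh, h, w, d = 2, 16, 16, 320           # arbitrary example sizes
tokens = torch.randn(bh, h * w, d)     # flattened attention output, (bh, (h w), d)

# Channels-first for the 2-D filter, as in the patched code:
spatial = rearrange(tokens, 'bh (h w) d -> bh d h w', h=h)
assert spatial.shape == (bh, d, h, w)

# ...gaussian_filter(spatial, kernel_size=..., sigma=1.0) would run here...

# Back to the flattened token layout expected downstream:
restored = rearrange(spatial, 'bh d h w -> bh (h w) d')
assert torch.equal(restored, tokens)   # the round trip is exact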