Spaces:
Runtime error
Runtime error
Commit
·
449a298
1
Parent(s):
ab9e9c4
Cutouts function
Browse files
app.py
CHANGED
@@ -23,6 +23,27 @@ from huggingface_hub import hf_hub_download
|
|
23 |
from CLIP import clip
|
24 |
from diffusion import get_model, sampling, utils
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
27 |
model = get_model('cc12m_1_cfg')()
|
28 |
_, side_y, side_x = model.shape
|
|
|
23 |
from CLIP import clip
|
24 |
from diffusion import get_model, sampling, utils
|
25 |
|
26 |
+
class MakeCutouts(nn.Module):
|
27 |
+
def __init__(self, cut_size, cutn, cut_pow=1.):
|
28 |
+
super().__init__()
|
29 |
+
self.cut_size = cut_size
|
30 |
+
self.cutn = cutn
|
31 |
+
self.cut_pow = cut_pow
|
32 |
+
|
33 |
+
def forward(self, input):
|
34 |
+
sideY, sideX = input.shape[2:4]
|
35 |
+
max_size = min(sideX, sideY)
|
36 |
+
min_size = min(sideX, sideY, self.cut_size)
|
37 |
+
cutouts = []
|
38 |
+
for _ in range(self.cutn):
|
39 |
+
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
|
40 |
+
offsetx = torch.randint(0, sideX - size + 1, ())
|
41 |
+
offsety = torch.randint(0, sideY - size + 1, ())
|
42 |
+
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
|
43 |
+
cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
|
44 |
+
cutouts.append(cutout)
|
45 |
+
return torch.cat(cutouts)
|
46 |
+
|
47 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
48 |
model = get_model('cc12m_1_cfg')()
|
49 |
_, side_y, side_x = model.shape
|