Spaces:
Runtime error
Runtime error
Commit
·
bf89172
1
Parent(s):
9cd412c
Add more logic to clip embeds
Browse files
app.py
CHANGED
@@ -59,8 +59,32 @@ make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
|
|
59 |
def run_all(prompt, steps, n_images, weight, clip_guided):
|
60 |
import random
|
61 |
seed = int(random.randint(0, 2147483647))
|
62 |
-
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float()
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def cfg_model_fn(x, t):
|
65 |
"""The CFG wrapper function."""
|
66 |
n = x.shape[0]
|
|
|
59 |
def run_all(prompt, steps, n_images, weight, clip_guided):
|
60 |
import random
|
61 |
seed = int(random.randint(0, 2147483647))
|
62 |
+
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float()#.cuda()
|
63 |
+
|
64 |
+
if(clip_guided):
|
65 |
+
prompts = [prompt]
|
66 |
+
def parse_prompt(prompt):
|
67 |
+
if prompt.startswith('http://') or prompt.startswith('https://'):
|
68 |
+
vals = prompt.rsplit(':', 2)
|
69 |
+
vals = [vals[0] + ':' + vals[1], *vals[2:]]
|
70 |
+
else:
|
71 |
+
vals = prompt.rsplit(':', 1)
|
72 |
+
vals = vals + ['', '1'][len(vals):]
|
73 |
+
return vals[0], float(vals[1])
|
74 |
+
|
75 |
+
for prompt in prompts:
|
76 |
+
txt, weight = parse_prompt(prompt)
|
77 |
+
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
|
78 |
+
weights.append(weight)
|
79 |
+
|
80 |
+
target_embeds = torch.cat(target_embeds)
|
81 |
+
weights = torch.tensor(weights, device=device)
|
82 |
+
if weights.sum().abs() < 1e-3:
|
83 |
+
raise RuntimeError('The weights must not sum to 0.')
|
84 |
+
weights /= weights.sum().abs()
|
85 |
+
clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
|
86 |
+
clip_embed = target_embed.repeat([n_images, 1])
|
87 |
+
|
88 |
def cfg_model_fn(x, t):
|
89 |
"""The CFG wrapper function."""
|
90 |
n = x.shape[0]
|