Spaces:
Running
on
Zero
Running
on
Zero
""" | |
concept_attention_kwargs cases: | |
- case 1: concept_attention_kwargs is None | |
- Don't do concept attention | |
- case 2: concept_attention_kwargs is not None | |
- Mandate that concept_attention_layers, timesteps, and concepts are in concept_attention_kwargs and are not None | |
""" | |
import torch | |
from matplotlib import cm | |
from diffusers import FluxPipeline | |
from concept_attention.diffusers.flux import FluxWithConceptAttentionPipeline, FluxTransformer2DModelWithConceptAttention | |
if __name__ == "__main__": | |
transformer = FluxTransformer2DModelWithConceptAttention.from_pretrained( | |
"black-forest-labs/FLUX.1-schnell", | |
torch_dtype=torch.bfloat16, | |
subfolder="transformer" | |
) | |
pipe = FluxWithConceptAttentionPipeline.from_pretrained( | |
"black-forest-labs/FLUX.1-schnell", | |
transformer=transformer, | |
torch_dtype=torch.bfloat16 | |
) | |
pipe.enable_model_cpu_offload() | |
prompt = "A cat on the grass" | |
out = pipe( | |
prompt=prompt, | |
guidance_scale=0., | |
height=1024, | |
width=1024, | |
num_inference_steps=4, | |
max_sequence_length=256, | |
concept_attention_kwargs={ | |
"layers": list(range(18)), | |
"timesteps": list(range(3, 4)), | |
"concepts": ["cat", "grass", "sky", "background", "dog"] | |
}, | |
# output_type="latent" | |
) | |
image = out.images[0] | |
concept_attention_maps = out.concept_attention_maps[0] | |
image.save("image.png") | |
# Pull out and save the concept attention maps | |
for i, attention_map in enumerate(concept_attention_maps): | |
attention_map.save(f"images/attention_map_{i}.png") | |