import os

import torch
import gradio as gr
import spaces
from accelerate.utils import set_seed
from diffusion_pipeline.sd35_pipeline import StableDiffusion3Pipeline, FlowMatchEulerInverseScheduler
from diffusion_pipeline.sdxl_pipeline import StableDiffusionXLPipeline
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
from diffusers import FlowMatchEulerDiscreteScheduler, DDIMInverseScheduler, DDIMScheduler
from huggingface_hub import login

login(token=os.getenv('HF_TOKEN'))
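# HF_TOKEN must be provided as a Space secret: stabilityai/stable-diffusion-3.5-large
# is a gated model, so the token needs to belong to an account with access granted.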

device = torch.device('cuda')


# Load both models once at import time so they are not reloaded on every request
def load_models():
    # Load the SD3.5 transformer with 4-bit NF4 quantization
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
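    # NF4 stores the transformer weights in 4 bits while computing in bfloat16,
    # which is what lets the large SD3.5 transformer fit in GPU memory alongside
    # SDXL; the quality impact of NF4 quantization is typically small.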
    model_nf4 = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        subfolder="transformer",
        quantization_config=nf4_config,
        torch_dtype=torch.bfloat16
    )
    pipe_sd35 = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        transformer=model_nf4,
        torch_dtype=torch.bfloat16,
    )
    pipe_sd35.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_sd35.scheduler.config)
    inverse_scheduler_sd35 = FlowMatchEulerInverseScheduler.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        subfolder='scheduler'
    )
    pipe_sd35.inv_scheduler = inverse_scheduler_sd35
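    # The inverse scheduler steps the sampler in the opposite direction
    # (latent -> noise); judging by where inv_cfg is passed below, the zigzag
    # and z-core methods use it to interleave inversion and denoising steps.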
    # Load the SDXL pipeline in fp16
    pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True
    ).to("cuda")
    pipe_sdxl.scheduler = DDIMScheduler.from_config(pipe_sdxl.scheduler.config)
    inverse_scheduler_sdxl = DDIMInverseScheduler.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder='scheduler'
    )
    pipe_sdxl.inv_scheduler = inverse_scheduler_sdxl
    return pipe_sd35, pipe_sdxl


pipe_sd35, pipe_sdxl = load_models()

# On ZeroGPU Spaces the `spaces` package requires GPU work to run inside a
# function decorated with @spaces.GPU; the decorator is a no-op elsewhere.
@spaces.GPU
def generate_image(
    model_name,
    seed,
    num_steps,
    guidance_scale,
    inv_cfg,
    w2s_guidance,
    end_timesteps,
    prompt,
    method,
    size,
):
    try:
        # Generate an image from the given parameters
        torch.cuda.empty_cache()
        # Gradio sliders may deliver floats, so cast the integer-valued inputs
        seed = int(seed)
        num_steps = int(num_steps)
        end_timesteps = int(end_timesteps)
        size = int(size)
        set_seed(seed)
        # Select the appropriate pipeline; keep the initial latents in the same
        # dtype each pipeline was loaded with (bf16 for SD3.5, fp16 for SDXL)
        if model_name == 'sd35':
            pipe = pipe_sd35
            dtype = torch.bfloat16
        elif model_name == 'sdxl':
            pipe = pipe_sdxl
            dtype = torch.float16
        else:
            raise ValueError("Invalid model name")
        # enable_model_cpu_offload manages device placement itself, moving
        # submodules to the GPU only while they run (an explicit pipe.to(device)
        # would be redundant and can conflict with the offload hooks)
        pipe.enable_model_cpu_offload()
        # Fetch the CoRe^2 noise-model weights from the Hub
        os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
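        # A lighter-weight alternative (a sketch, not what this Space ships) is
        # huggingface_hub.hf_hub_download, which caches and verifies files itself:
        #   from huggingface_hub import hf_hub_download
        #   for f in ('weights/sd35_noise_model.pth', 'weights/sdxl_noise_model.pth'):
        #       hf_hub_download(repo_id='sst12345/CoRe2', filename=f, local_dir='./weights')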
        # Load the CoRe^2 noise (refine) model; only core and z-core need it
        if method in ('core', 'z-core'):
            from diffusion_pipeline.refine_model import PromptSD35Net, PromptSDXLNet
            from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
            if model_name == 'sd35':
                refine_model = PromptSD35Net()
                # The sd35 checkpoint expects rank-64 LoRA with 28 adapters
                replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
                lora_true(refine_model, lora_idx=0)
                checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
                refine_model.load_state_dict(checkpoint)
            elif model_name == 'sdxl':
                refine_model = PromptSDXLNet()
                # The sdxl checkpoint expects rank-48 LoRA with 50 adapters
                replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
                lora_true(refine_model, lora_idx=0)
                checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
                refine_model.load_state_dict(checkpoint)
            refine_model = refine_model.to(torch.bfloat16)
            refine_model = refine_model.to(device)
            print("Loaded LoRA weights successfully")
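        # Judging by the call signatures, replace_linear_with_lora swaps the
        # network's Linear layers for multi-adapter LoRA layers, and lora_true
        # appears to activate one adapter (index 0 here) before inference.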
        # Latent shape depends on the backbone: the SDXL VAE has 4 latent
        # channels, the SD3.5 VAE has 16; both downsample spatially by 8x
        if model_name == 'sdxl':
            shape = (1, 4, size // 8, size // 8)
        else:
            shape = (1, 16, size // 8, size // 8)
        start_latents = torch.randn(shape, dtype=dtype).to(device)
        # Dispatch to the selected sampling method
        if model_name == 'sdxl':
            if method == 'core':
                output = pipe.core(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_steps,
                    latents=start_latents,
                    return_dict=False,
                    refine_model=refine_model,
                    lora_true=lora_true,
                    end_timesteps=end_timesteps,
                    w2s_guidance=w2s_guidance)[0][0]
            elif method == 'zigzag':
                output = pipe.zigzag(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    latents=start_latents,
                    return_dict=False,
                    num_inference_steps=num_steps,
                    inv_cfg=inv_cfg)[0][0]
            elif method == 'z-core':
                output = pipe.z_core(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_steps,
                    latents=start_latents,
                    return_dict=False,
                    refine_model=refine_model,
                    lora_true=lora_true,
                    end_timesteps=end_timesteps,
                    w2s_guidance=w2s_guidance,
                    inv_cfg=inv_cfg)[0][0]
            elif method == 'standard':
                output = pipe(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    latents=start_latents,
                    return_dict=False,
                    num_inference_steps=num_steps)[0][0]
            else:
                raise ValueError("Invalid method")
        else:
            # SD3.5 additionally takes max_sequence_length for its T5 text encoder
            if method == 'core':
                output = pipe.core(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_steps,
                    latents=start_latents,
                    max_sequence_length=512,
                    return_dict=False,
                    refine_model=refine_model,
                    lora_true=lora_true,
                    end_timesteps=end_timesteps,
                    w2s_guidance=w2s_guidance)[0][0]
            elif method == 'zigzag':
                output = pipe.zigzag(
                    prompt=prompt,
                    max_sequence_length=512,
                    guidance_scale=guidance_scale,
                    latents=start_latents,
                    return_dict=False,
                    num_inference_steps=num_steps,
                    inv_cfg=inv_cfg)[0][0]
            elif method == 'z-core':
                output = pipe.z_core(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_steps,
                    latents=start_latents,
                    return_dict=False,
                    max_sequence_length=512,
                    refine_model=refine_model,
                    lora_true=lora_true,
                    end_timesteps=end_timesteps,
                    w2s_guidance=w2s_guidance)[0][0]
            elif method == 'standard':
                output = pipe(
                    prompt=prompt,
                    guidance_scale=guidance_scale,
                    latents=start_latents,
                    return_dict=False,
                    max_sequence_length=512,
                    num_inference_steps=num_steps)[0][0]
            else:
                raise ValueError("Invalid method")
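        # Method summary, as far as the parameter names indicate:
        #   standard - the pipeline's plain sampler
        #   core     - weak-to-strong refinement via the noise model (w2s_guidance)
        #   zigzag   - alternates denoising and inversion steps (inv_cfg)
        #   z-core   - zigzag sampling combined with core refinement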
        # Save the generated image and return its path for Gradio to display
        output_path = f'{model_name}_{method}.png'
        output.save(output_path)
        return output_path
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
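
# A quick sanity check outside the UI (hypothetical values):
#   path = generate_image('sdxl', seed=1, num_steps=50, guidance_scale=5.5,
#                         inv_cfg=-1, w2s_guidance=2.5, end_timesteps=50,
#                         prompt='a photo of an astronaut riding a horse',
#                         method='core', size=1024)
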
if __name__ == '__main__':
    # Build the Gradio interface
    iface = gr.Interface(
        fn=generate_image,
        inputs=[
            gr.Dropdown(choices=['sdxl', 'sd35'], value='sdxl', label="Model"),
            gr.Slider(minimum=1, maximum=1000000, value=1, step=1, label="Seed"),
            gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Inference Steps"),
            gr.Slider(minimum=1, maximum=10, value=5.5, label="CFG"),
            gr.Slider(minimum=-10, maximum=10, value=-1, label="Inverse CFG"),
            gr.Slider(minimum=1, maximum=3.5, value=2.5, label="W2S Guidance"),
            gr.Slider(minimum=1, maximum=100, value=50, step=1, label="End Timesteps"),
            gr.Textbox(label="Prompt"),
            gr.Dropdown(choices=['standard', 'core', 'zigzag', 'z-core'], value='core', label="Method"),
            # step=8 keeps the latent grid integral (size is divided by 8)
            gr.Slider(minimum=1024, maximum=2048, value=1024, step=8, label="Size"),
        ],
        outputs=gr.Image(type="filepath"),  # generate_image returns a file path
        title="Image Generation with CoRe^2",
    )
    iface.launch(share=True)
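    # Note: share=True only matters when running this script locally; on a
    # Hugging Face Space the app is already public and Gradio ignores the flag.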