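# Gradio demo for Step1X-Edit (Hugging Face Space, ZeroGPU): downloads the
# stepfun-ai/Step1X-Edit weights and serves an instruction-based image-editing UI.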
import argparse
import datetime
import json
import itertools
import math
import os
import spaces
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import snapshot_download
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torchvision.transforms import functional as F
from tqdm import tqdm

import sampling
from modules.autoencoder import AutoEncoder
from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
from modules.model_edit import Step1XParams, Step1XEdit

print("TORCH_CUDA", torch.cuda.is_available())
def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
    if Path(ckpt_path).suffix == ".safetensors":
        state_dict = load_file(ckpt_path, device)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")

    missing, unexpected = model.load_state_dict(
        state_dict, strict=strict, assign=assign
    )
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    return model
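
# Build the three pipeline components: the Qwen2.5-VL-7B embedder used for
# (prompt, reference image) conditioning, the image AutoEncoder, and the Step1XEdit
# DiT. AE and DiT are constructed on the meta device and materialised by
# load_state_dict(..., assign=True) so their weights are only allocated once.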
def load_models(
    dit_path=None,
    ae_path=None,
    qwen2vl_model_path=None,
    device="cuda",
    max_length=256,
    dtype=torch.bfloat16,
):
    qwen2vl_encoder = Qwen2VLEmbedder(
        qwen2vl_model_path,
        device=device,
        max_length=max_length,
        dtype=dtype,
    )

    with torch.device("meta"):
        ae = AutoEncoder(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        )

        step1x_params = Step1XParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
        )
        dit = Step1XEdit(step1x_params)

    ae = load_state_dict(ae, ae_path)
    dit = load_state_dict(dit, dit_path)

    dit = dit.to(device=device, dtype=dtype)
    ae = ae.to(device=device, dtype=torch.float32)

    return ae, dit, qwen2vl_encoder
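
# Wrapper that owns the three models and implements the edit pipeline:
# prepare (pack tokens + encode prompt) -> denoise -> decode with the AE.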
class ImageGenerator:
    def __init__(
        self,
        dit_path=None,
        ae_path=None,
        qwen2vl_model_path=None,
        device="cuda",
        max_length=640,
        dtype=torch.bfloat16,
    ) -> None:
        self.device = torch.device(device)
        self.ae, self.dit, self.llm_encoder = load_models(
            dit_path=dit_path,
            ae_path=ae_path,
            qwen2vl_model_path=qwen2vl_model_path,
            max_length=max_length,
            dtype=dtype,
        )
        self.ae = self.ae.to(device=self.device, dtype=torch.float32)
        self.dit = self.dit.to(device=self.device, dtype=dtype)
        self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)

    def to_cuda(self):
        self.ae.to(device='cuda', dtype=torch.float32)
        self.dit.to(device='cuda', dtype=torch.bfloat16)
        self.llm_encoder.to(device='cuda', dtype=torch.bfloat16)
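
    # Pack the noisy latent and the reference latent into 2x2 patch tokens, build the
    # (batch, tokens, 3) positional ids for each, encode the prompt with Qwen2.5-VL,
    # and concatenate image and reference tokens along the sequence dimension.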
    def prepare(self, prompt, img, ref_image, ref_image_raw):
        bs, _, h, w = img.shape
        bs, _, ref_h, ref_w = ref_image.shape

        assert h == ref_h and w == ref_w

        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
        elif bs >= 1 and isinstance(prompt, str):
            prompt = [prompt] * bs

        img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        ref_img = rearrange(
            ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2
        )
        if img.shape[0] == 1 and bs > 1:
            img = repeat(img, "1 ... -> bs ...", bs=bs)
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)

        img_ids = torch.zeros(h // 2, w // 2, 3)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
        ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
        ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
        ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)

        if isinstance(prompt, str):
            prompt = [prompt]

        txt, mask = self.llm_encoder(prompt, ref_image_raw)
        txt_ids = torch.zeros(bs, txt.shape[1], 3)

        img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
        img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)

        return {
            "img": img,
            "mask": mask,
            "img_ids": img_ids.to(img.device),
            "llm_embedding": txt.to(img.device),
            "txt_ids": txt_ids.to(img.device),
        }
    def process_diff_norm(self, diff_norm, k):
        pow_result = torch.pow(diff_norm, k)
        result = torch.where(
            diff_norm > 1.0,
            pow_result,
            torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
        )
        return result
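
    # Euler-style integration over consecutive timestep pairs. With cfg_guidance != -1
    # the batch holds [cond, uncond] copies; for t_curr above timesteps_truncate the
    # guided update is rescaled by the norm of (cond - uncond) via process_diff_norm.
    # Only the first half of the token sequence (the editable latent, not the
    # reference tokens) is advanced at each step.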
    def denoise(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        llm_embedding: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: list[float],
        cfg_guidance: float = 4.5,
        mask=None,
        show_progress=False,
        timesteps_truncate=1.0,
    ):
        if show_progress:
            pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
        else:
            pbar = itertools.pairwise(timesteps)
        for t_curr, t_prev in pbar:
            if img.shape[0] == 1 and cfg_guidance != -1:
                img = torch.cat([img, img], dim=0)
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )
            txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
            pred = self.dit(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
            )
            if cfg_guidance != -1:
                cond, uncond = (
                    pred[0 : pred.shape[0] // 2, :],
                    pred[pred.shape[0] // 2 :, :],
                )
                if t_curr > timesteps_truncate:
                    diff = cond - uncond
                    diff_norm = torch.norm(diff, dim=(2), keepdim=True)
                    pred = uncond + cfg_guidance * (
                        cond - uncond
                    ) / self.process_diff_norm(diff_norm, k=0.4)
                else:
                    pred = uncond + cfg_guidance * (cond - uncond)
            tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
            img_input_length = img.shape[1] // 2
            img = torch.cat(
                [
                    tem_img[:, :img_input_length],
                    img[: img.shape[0] // 2, img_input_length:],
                ],
                dim=1,
            )
        return img[:, : img.shape[1] // 2]
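
    # Undo the 2x2 patch packing: (b, h*w, c*4) tokens back to a (b, c, H, W) latent.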
    def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 16),
            w=math.ceil(width / 16),
            ph=2,
            pw=2,
        )
    def load_image(self, image):
        # Accept numpy arrays, PIL images, tensors, or file paths and return a
        # (1, C, H, W) float tensor in [0, 1].
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            image = image.unsqueeze(0)
            return image
        elif isinstance(image, Image.Image):
            image = F.to_tensor(image.convert("RGB"))
            image = image.unsqueeze(0)
            return image
        elif isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, str):
            image = F.to_tensor(Image.open(image).convert("RGB"))
            image = image.unsqueeze(0)
            return image
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
    def output_process_image(self, resize_img, image_size):
        res_image = resize_img.resize(image_size)
        return res_image
    def input_process_image(self, img, img_size=512):
        # Resize so the total pixel count is roughly img_size**2 while keeping the
        # aspect ratio, then snap both sides down to multiples of 16.
        w, h = img.size
        r = w / h

        if w > h:
            w_new = math.ceil(math.sqrt(img_size * img_size * r))
            h_new = math.ceil(w_new / r)
        else:
            h_new = math.ceil(math.sqrt(img_size * img_size / r))
            w_new = math.ceil(h_new * r)
        h_new = math.ceil(h_new) // 16 * 16
        w_new = math.ceil(w_new) // 16 * 16

        img_resized = img.resize((w_new, h_new))
        return img_resized, img.size
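
    # Full edit pipeline: resize the reference image, encode it with the AE, draw a
    # seeded noise latent, run the denoiser on [prompt, negative_prompt], then decode
    # and resize the result back to the original resolution.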
    def generate_image(
        self,
        prompt,
        negative_prompt,
        ref_images,
        num_steps,
        cfg_guidance,
        seed,
        num_samples=1,
        init_image=None,
        image2image_strength=0.0,
        show_progress=False,
        size_level=512,
    ):
        assert num_samples == 1, "num_samples > 1 is not supported yet."
        ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)

        width, height = ref_images_raw.width, ref_images_raw.height

        ref_images_raw = self.load_image(ref_images_raw)
        ref_images_raw = ref_images_raw.to(self.device)
        ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)

        seed = int(seed)
        seed = torch.Generator(device="cpu").seed() if seed < 0 else seed

        t0 = time.perf_counter()

        if init_image is not None:
            init_image = self.load_image(init_image)
            init_image = init_image.to(self.device)
            init_image = torch.nn.functional.interpolate(init_image, (height, width))
            init_image = self.ae.encode(init_image * 2 - 1)

        x = torch.randn(
            num_samples,
            16,
            height // 8,
            width // 8,
            device=self.device,
            dtype=torch.bfloat16,
            generator=torch.Generator(device=self.device).manual_seed(seed),
        )

        timesteps = sampling.get_schedule(
            num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
        )

        if init_image is not None:
            t_idx = int((1 - image2image_strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        x = torch.cat([x, x], dim=0)
        ref_images = torch.cat([ref_images, ref_images], dim=0)
        ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
        inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)

        x = self.denoise(
            **inputs,
            cfg_guidance=cfg_guidance,
            timesteps=timesteps,
            show_progress=show_progress,
            timesteps_truncate=1.0,
        )

        x = self.unpack(x.float(), height, width)
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            x = self.ae.decode(x)
        x = x.clamp(-1, 1)
        x = x.mul(0.5).add(0.5)

        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s.")

        images_list = []
        for img in x.float():
            images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
        return images_list
# Model repository to pull the Step1X-Edit weights from.
model_repo = "stepfun-ai/Step1X-Edit"
# Local directory the weights are saved to.
model_path = "./model_weights"
os.makedirs(model_path, exist_ok=True)

# Download the full model snapshot (all files).
snapshot_download(
    repo_id=model_repo,
    local_dir=model_path,
    local_dir_use_symlinks=False,  # copy files instead of symlinking them
)
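
# Instantiate the pipeline from the downloaded weights; the Qwen2.5-VL encoder is
# fetched directly from the Hub by model id.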
image_edit = ImageGenerator(
    ae_path=os.path.join(model_path, 'vae.safetensors'),
    dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
    qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
    max_length=640,
)
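
# This Space runs on ZeroGPU, so the GPU-bound work has to run inside a function
# wrapped with @spaces.GPU (hence the `spaces` import above). The duration below is
# an assumed budget, not a value taken from the original script.
@spaces.GPU(duration=120)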
def inference(prompt, ref_images, seed, size_level):
    start_time = time.time()
    if seed == -1:
        import random
        random_seed = random.randint(0, 2**32 - 1)
    else:
        random_seed = seed

    image_edit.to_cuda()

    image = image_edit.generate_image(
        prompt,
        negative_prompt="",
        ref_images=ref_images.convert('RGB'),
        num_samples=1,
        num_steps=28,
        cfg_guidance=6.0,
        seed=random_seed,
        show_progress=True,
        size_level=size_level,
    )[0]

    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    return image, random_seed
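
# Gradio UI: edit instruction, input image, and seed/size controls on the left;
# the generated image and the seed actually used on the right.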
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Step1X-Edit
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Edit Instruction",
                value='Remove the person from the image.',
            )
            init_image = gr.Image(label="Input Image", type='pil')
            random_seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
            size_level = gr.Number(label="Size Level (recommend 512, 768, or 1024; min 512)", value=512, minimum=512)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type='pil', image_mode='RGB')
            output_random_seed = gr.Textbox(label="Used Seed", lines=5)

    generate_btn.click(
        fn=inference,
        inputs=[
            prompt,
            init_image,
            random_seed,
            size_level,
        ],
        outputs=[output_image, output_random_seed],
    )

demo.launch()