import torch
import gradio as gr
from pipeline_controlnet_sd_xl_raw import StableDiffusionXLControlNetRAWPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from torchvision import transforms
from PIL import Image
import traceback
# ========== 1. Load Models ==========
pipe = StableDiffusionXLControlNetRAWPipeline.from_pretrained(
    "wencheng256/DiffusionRAW",
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Offload idle submodules to CPU to reduce GPU memory use (requires accelerate).
pipe.enable_model_cpu_offload()
# ========== 2. Utility function: tensor -> PIL ==========
def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
    """Convert a CHW tensor to a PIL image, moving it to CPU/float32 and clamping to [0, 1]."""
    if img_tensor.is_cuda:
        img_tensor = img_tensor.cpu()
    if img_tensor.dtype != torch.float32:
        img_tensor = img_tensor.float()
    img_tensor = img_tensor.clamp(0, 1)
    return transforms.ToPILImage()(img_tensor)
# ========== 3. Load a .pth file ==========
def load_pth_data(pth_path):
    # Load on CPU so the demo works regardless of the device the tensors were saved from.
    data = torch.load(pth_path, map_location="cpu")
    rgb_tensor = data["rgb"]
    raw_tensor = data["raw"]
    mask_tensor = data["mask"]
    cond_tensor = data["condition"]
    # Previews: crop to the first 448 rows; reverse the sRGB channel order for display.
    raw_image_pil = tensor_to_pil(raw_tensor[0][:, :448])
    rgb_tensor_pil = tensor_to_pil(torch.flip(rgb_tensor[0], dims=[0])[:, :448])
    # Invert the mask for display in the sketchpad.
    mask_image_pil = tensor_to_pil(1 - mask_tensor[0])
    return rgb_tensor_pil, raw_image_pil, mask_image_pil, raw_tensor, mask_tensor, cond_tensor
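# NOTE: the on-disk schema of the data*.pth files is inferred from the code
# above, not documented: each file is a torch.save'd dict of batched tensors
# with four keys (shapes are illustrative assumptions, not verified):
#   {
#       "rgb":       ...,  # e.g. [1, 3, H, W] sRGB tensor, values in [0, 1]
#       "raw":       ...,  # batched RAW tensor, passed as `image`
#       "mask":      ...,  # batched single-channel mask, passed as `mask_image`
#       "condition": ...,  # batched conditioning tensor, passed as `control_image`
#   }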
# ========== 4. Inference function ==========
def infer_fn(prompt, mask_edited, raw_tensor_state, mask_tensor_state, cond_tensor_state):
    try:
        # Newer Gradio sketch components return a dict of layers; use the composite.
        if isinstance(mask_edited, dict):
            mask_edited = mask_edited["composite"]
        mask_edited_tensor = transforms.ToTensor()(mask_edited)
        # Keep a single channel, add a batch dim, and invert the sketch back
        # (the mask preview was inverted for display in load_pth_data).
        mask_edited_tensor = 1 - mask_edited_tensor[:1].unsqueeze(0).half()
        raw_t = raw_tensor_state.half()
        cond_t = cond_tensor_state.half()
        generator = torch.manual_seed(0)
        result = pipe(
            prompt=prompt,
            num_inference_steps=20,
            generator=generator,
            image=raw_t,
            mask_image=mask_edited_tensor,
            control_image=cond_t,
        ).images[0]
        return tensor_to_pil(result)
    except Exception as e:
        traceback.print_exc()
        # A plain string is not a valid gr.Image value; raise a Gradio error
        # so the failure is surfaced in the UI instead.
        raise gr.Error("Error occurred during inference. Please check the terminal logs!") from e
# ========== 5. Build Gradio App ==========
def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# DiffusionRAW")
        pth_options = ["./data1.pth", "./data2.pth", "./data3.pth"]
        pth_selector = gr.Dropdown(choices=pth_options, value=pth_options[0], label="Select a PTH file")
        load_button = gr.Button("Load")
        with gr.Row():
            raw_display = gr.Image(label="Raw Image", interactive=False)
            rgb_display = gr.Image(label="sRGB Image", interactive=False)
            mask_editor = gr.Sketchpad(
                label="Mask (Sketch)",
                interactive=True,
                width=512,
                height=512,
            )
        # Hidden state carrying the raw/mask/condition tensors between callbacks.
        raw_tensor_state = gr.State()
        mask_tensor_state = gr.State()
        cond_tensor_state = gr.State()
        load_button.click(
            fn=load_pth_data,
            inputs=[pth_selector],
            outputs=[
                rgb_display,
                raw_display,
                mask_editor,
                raw_tensor_state,
                mask_tensor_state,
                cond_tensor_state,
            ],
        )
        prompt_input = gr.Textbox(label="Prompt", value="A RAW image.", lines=1)
        generate_button = gr.Button("Generate")
        output_image = gr.Image(label="Output")
        generate_button.click(
            fn=infer_fn,
            inputs=[
                prompt_input,
                mask_editor,
                raw_tensor_state,
                mask_tensor_state,
                cond_tensor_state,
            ],
            outputs=[output_image],
        )
    return demo
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
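# To smoke-test the UI without the released .pth files, one could fabricate a
# stand-in file with the same four keys. The shapes below are guesses, not the
# real data layout, and may not produce meaningful outputs:
#
#   dummy = torch.rand(1, 3, 448, 448)
#   torch.save({"rgb": dummy, "raw": dummy.clone(),
#               "mask": torch.ones(1, 1, 448, 448),
#               "condition": dummy.clone()}, "./data1.pth")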