import spaces
import PIL
import torch
import subprocess
import gradio as gr
import os

from typing import Optional
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from clip_interrogator import Interrogator, Config, list_clip_models

from huggingface_hub import snapshot_download


os.makedirs("sdxl_light_caption_output", exist_ok=True)
os.makedirs("sdxl_light_custom_caption_output", exist_ok=True)
snapshot_download(
    repo_id='nickpai/sdxl_light_caption_output',
    local_dir='sdxl_light_caption_output'
)

snapshot_download(
    repo_id='nickpai/sdxl_light_custom_caption_output',
    local_dir='sdxl_light_custom_caption_output'
)


def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
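    """Keep the lightness of `image` and take the color channels from `color_map`.

    Both images are converted to LAB; the L channel of `image` is merged with the
    a/b channels of `color_map`, and the result is converted back to RGB.
    """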
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

    result_rgb = merged_lab.convert('RGB')
    return result_rgb


def remove_unlikely_words(prompt: str) -> str:
    """
    Removes unlikely words from a prompt.

    Args:
        prompt: The text prompt to be cleaned.

    Returns:
        The cleaned prompt with unlikely words removed.
    """
    unlikely_words = []

    # Year phrases (e.g. "1950s", "circa 1950") and their spaced-digit variants ("1 9 5 0")
    a1_list = [f'{i}s' for i in range(1900, 2000)]
    a2_list = [f'{i}' for i in range(1900, 2000)]
    a3_list = [f'year {i}' for i in range(1900, 2000)]
    a4_list = [f'circa {i}' for i in range(1900, 2000)]
    b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
    b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]

    words_list = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    unlikely_words.extend(a1_list)
    unlikely_words.extend(a2_list)
    unlikely_words.extend(a3_list)
    unlikely_words.extend(a4_list)
    unlikely_words.extend(b1_list)
    unlikely_words.extend(b2_list)
    unlikely_words.extend(b3_list)
    unlikely_words.extend(b4_list)
    unlikely_words.extend(words_list)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt


def blip_image_captioning(image: PIL.Image.Image,
                          model_backbone: str,
                          weight_dtype: torch.dtype,
                          device: str,
                          conditional: bool) -> str:
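    """Caption `image` with a Salesforce BLIP model (`model_backbone`) on `device`."""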
    # The captioner is not run in bf16; fall back to fp16 when bf16 is requested
    if weight_dtype == torch.bfloat16:
        weight_dtype = torch.float16

    # Validate the backbone before downloading any weights
    valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
    if model_backbone not in valid_backbones:
        raise ValueError(f"Invalid model backbone '{model_backbone}'. "
                         f"Valid options are: {', '.join(valid_backbones)}")

    processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
    model = BlipForConditionalGeneration.from_pretrained(
        f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)

    if conditional:
        text = "a photography of"
        inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
    else:
        inputs = processor(image, return_tensors="pt").to(device)
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption


@spaces.GPU
def process_image(image_path: str,
                  controlnet_model_name_or_path: str,
                  caption_model_name: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int,
                  num_inference_steps: int,
                  mixed_precision: str,
                  pretrained_model_name_or_path: str,
                  pretrained_vae_model_name_or_path: Optional[str],
                  revision: Optional[str],
                  variant: Optional[str],
                  repo: str,
                  ckpt: str) -> tuple[PIL.Image.Image, str]:
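    """Colorize the image at `image_path` with an SDXL-Lightning ControlNet pipeline.

    A BLIP caption of the grayscale input (plus `positive_prompt`) guides the
    colorization. Returns the colorized image and the caption that was used.
    """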
    generator = torch.manual_seed(seed)

    accelerator = Accelerator(
        mixed_precision=mixed_precision,
    )

    # Map the accelerate mixed-precision setting to a torch dtype
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        pretrained_model_name_or_path
        if pretrained_vae_model_name_or_path is None
        else pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
        revision=revision,
        variant=variant,
    )
    # Build the UNet from the base model config, then load the SDXL-Lightning distilled weights
    unet = UNet2DConditionModel.from_config(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
        variant=variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

    # Only a user-supplied VAE is cast to the working dtype; the default VAE stays in fp32
    if pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    image = PIL.Image.open(image_path)

    pipe, image = accelerator.prepare(pipe, image)
    pipe.safety_checker = None

    # The grayscale version of the input serves as the ControlNet conditioning image
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    if caption_model_name in ("blip-image-captioning-large", "blip-image-captioning-base"):
        caption = blip_image_captioning(control_image, caption_model_name,
                                        weight_dtype, accelerator.device, conditional=True)

    caption = remove_unlikely_words(caption)

    prompt = [positive_prompt + ", " + caption]

    image = pipe(prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
                 image=control_image).images[0]

    # Transfer the predicted colors onto the grayscale input and restore the original size
    result_image = apply_color(control_image, image)
    result_image = result_image.resize(original_size)
    return result_image, caption


def get_image_paths(folder_path):
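    """Return one [path] entry per .jpg/.png file in `folder_path`, as used for the Gradio examples."""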
    image_paths = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            image_paths.append([os.path.join(folder_path, filename)])
    return image_paths


def create_interface():
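    """Build the Gradio interface for the text-guided colorization demo."""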
    controlnet_model_dict = {
        "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
        "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
    }
    images = get_image_paths("example/legacy_images")

    interface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload image",
                     value="example/legacy_images/Hollywood-Sign.jpg",
                     type='filepath'),
            gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
                        value=controlnet_model_dict["sdxl-light-caption-30000"],
                        label="Select ControlNet Model"),
            gr.Dropdown(choices=["blip-image-captioning-large",
                                 "blip-image-captioning-base"],
                        value="blip-image-captioning-large",
                        label="Select Image Captioning Model"),
            gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
            gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
                       label="Negative Prompt", placeholder="Text for negative prompt"),
        ],
        outputs=[
            gr.Image(label="Colorized image",
                     value="example/UUColor_results/Hollywood-Sign.jpeg",
                     format="jpeg"),
            gr.Textbox(label="Captioning Result", show_copy_button=True)
        ],
        examples=images,
        additional_inputs=[
            gr.Slider(0, 1000, 123, label="Seed"),
            gr.Radio(choices=[1, 2, 4, 8],
                     value=8,
                     label="Inference Steps",
                     info="1-step, 2-step, 4-step, or 8-step distilled models"),
            gr.Radio(choices=["no", "fp16", "bf16"],
                     value="fp16",
                     label="Mixed Precision",
                     info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."),
            gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"],
                        value="stabilityai/stable-diffusion-xl-base-1.0",
                        label="Base Model",
                        info="Path to pretrained model or model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="VAE Model",
                        info="Path to an improved VAE to stabilize training. For more details, see https://github.com/huggingface/diffusers/pull/4038."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Revision",
                        info="Revision of the pretrained model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Variant",
                        info="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16."),
            gr.Dropdown(choices=["ByteDance/SDXL-Lightning"],
                        value="ByteDance/SDXL-Lightning",
                        label="Repository",
                        info="Repository on huggingface.co."),
            gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors",
                                 "sdxl_lightning_2step_unet.safetensors",
                                 "sdxl_lightning_4step_unet.safetensors",
                                 "sdxl_lightning_8step_unet.safetensors"],
                        value="sdxl_lightning_8step_unet.safetensors",
                        label="Checkpoint",
                        info="Available checkpoints from the repository. Caution: the checkpoint's N-step value must match the number of inference steps."),
        ],
        title="Text-Guided Image Colorization",
        description="Upload an image and select a model to colorize it.",
        cache_examples=False
    )
    return interface


def main():
    interface = create_interface()
    interface.launch(ssr_mode=False)


if __name__ == "__main__":
    main()