Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from data.prefix_instruction import get_image_prompt, get_task_instruction, get_layout_instruction, get_content_instruction | |
import random | |
from PIL import Image | |
from .gradio_tasks import dense_prediction_data | |
from degradation_utils import add_degradation | |
import numpy as np | |
# Names of the supported degradations. Each name is used both as the
# image-type key handed to `get_image_prompt` and as the `deg_type`
# argument of `add_degradation` in `process_image_restoration_tasks`.
degradation_list = [
    # blur
    "blur",
    "compression",
    "SRx2",
    "SRx4",
    "pixelate",
    "Defocus",
    "GaussianBlur",
    # sharpen
    "oversharpen",
    # noise
    "GaussianNoise",
    "PoissonNoise",
    "SPNoise",
    # mosaic
    "mosaic",
    # contrast
    "contrast_strengthen",
    "contrast_weaken",
    # quantization
    "quantization",
    "JPEG",
    # light
    "brighten",
    "darken",
    "LowLight",
    # color
    "saturate_strengthen",
    "saturate_weaken",
    "gray",
    "ColorDistortion",
    # infilling
    "Inpainting",
    # rotate
    "rotate180",
    # other
    "Barrel",
    "Pincushion",
    "Elastic",
    # special effects
    "Rain",
    "Frost",
]

# One task per degradation: the degraded image is the condition column and
# "target" (the clean image) is the last column.
image_restoration = [
    {"name": degradation, "image_type": [degradation, "target"]}
    for degradation in degradation_list
]

# Task names wrapped in single-element lists, one row per task; the task
# handler reads the name back from index 0 of such a row.
image_restoration_text = [[task["name"]] for task in image_restoration]
def process_image_restoration_tasks(x):
    """Assemble the prompts, layout and example images for a restoration task.

    Args:
        x: Sequence whose first element is the task name, i.e. one of the
            names in ``degradation_list`` (a row of
            ``image_restoration_text`` has this shape).

    Returns:
        A flat list ``[mask, grid_h, grid_w, layout_prompt, task_prompt,
        content_prompt, upsampling_noise, steps, *images]``. The images are
        appended row by row (one row per sampled example), and within a row
        in the order given by the task's ``image_type``: the degraded image
        first, then the clean target. ``upsampling_noise`` and ``steps`` are
        always ``None``.

    Raises:
        ValueError: If ``x[0]`` does not name a known restoration task, or
            (from ``random.randint``) if ``dense_prediction_data`` holds
            fewer than two samples.
    """
    task_name = x[0]
    task = next(
        (candidate for candidate in image_restoration if candidate['name'] == task_name),
        None,
    )
    if task is None:
        # Previously this fell through to `return outputs` and failed with
        # an opaque UnboundLocalError.
        raise ValueError(f"Unknown image restoration task: {task_name!r}")

    image_type = task['image_type']

    # One "[IMAGEn] <prompt>" fragment per grid column; the last column is
    # the target, all preceding columns are conditions.
    image_prompt_list = [get_image_prompt(kind)[0] for kind in image_type]
    image_prompt_list = [
        f"[IMAGE{idx+1}] {image_prompt}"
        for idx, image_prompt in enumerate(image_prompt_list)
    ]
    condition_prompt = ", ".join(image_prompt_list[:-1])
    target_prompt = image_prompt_list[-1]
    task_prompt = get_task_instruction(condition_prompt, target_prompt)

    # sample examples: 2 or 3 distinct rows (randint needs >= 2 available)
    valid_data = dense_prediction_data
    n_samples = random.randint(2, min(len(valid_data), 3))
    images = random.sample(valid_data, k=n_samples)

    rets = []
    for image in images:
        for kind in image_type:
            if kind == "target":
                rets.append(Image.open(image["target"]))
            else:
                # The degraded input is synthesized from the clean target;
                # the second return value of add_degradation is unused.
                deg_image, _ = add_degradation(
                    np.array(Image.open(image["target"])), deg_type=kind
                )
                rets.append(deg_image)

    # The content description comes from the last sampled row.
    content_prompt = get_content_instruction() + images[-1]['prompt']

    grid_h = n_samples
    grid_w = len(image_type)
    # Default mask flags only the last column (the target) with 1.
    mask = task.get('mask', [0 for _ in range(grid_w - 1)] + [1])
    layout_prompt = get_layout_instruction(grid_w, grid_h)

    upsampling_noise = None
    steps = None
    return [
        mask,
        grid_h,
        grid_w,
        layout_prompt,
        task_prompt,
        content_prompt,
        upsampling_noise,
        steps,
    ] + rets