File size: 2,033 Bytes
5e373a9
 
 
 
 
 
 
 
441ee9c
 
5e373a9
 
 
 
 
 
 
 
 
 
 
 
 
441ee9c
5e373a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441ee9c
5e373a9
 
441ee9c
5e373a9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright 2024 Adobe. All rights reserved.

from run_magicfu import MagicFixup
import os
import pathlib
import torchvision
from torch import autocast
from PIL import Image
import gradio as gr
import numpy as np
import argparse


def sample(original_image, coarse_edit):
    to_tensor = torchvision.transforms.ToTensor()
    with autocast("cuda"):
        w, h = coarse_edit.size
        ref_image_t = to_tensor(original_image.resize((512,512))).half().cuda()
        coarse_edit_t = to_tensor(coarse_edit.resize((512,512))).half().cuda()
        # get mask from coarse
        coarse_edit_mask_t = to_tensor(coarse_edit.resize((512,512))).half().cuda()
        mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop
        coarse_edit_t_rgb = coarse_edit_t[:-1]
        
        out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
        output =  out_rgb.squeeze().cpu().detach().moveaxis(0, -1).float().numpy()
        output = (output * 255.0).astype(np.uint8)
        output_pil = Image.fromarray(output)
        output_pil = output_pil.resize((w, h))
        return output_pil

def file_exists(path):
    """ Check if a file exists and is not a directory. """
    if not os.path.isfile(path):
        raise argparse.ArgumentTypeError(f"{path} is not a valid file.")
    return path

def parse_arguments():
    """ Parses command-line arguments. """
    parser = argparse.ArgumentParser(description="Process images based on provided paths.")
    parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.")

    return parser.parse_args()

demo = gr.Interface(fn=sample, inputs=[gr.Image(type="pil", image_mode='RGB'), gr.Image(type="pil", image_mode='RGBA')], outputs=gr.Image(),
                    examples='examples')
    
if __name__ == "__main__":
    args = parse_arguments()

    # create magic fixup model
    magic_fixup = MagicFixup(model_path=args.checkpoint)
    demo.launch(share=True)