File size: 3,957 Bytes
3ae8d58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
'''
!git clone https://huggingface.co/spaces/radames/SPIGA-face-alignment-headpose-estimator
!cp -r SPIGA-face-alignment-headpose-estimator/SPIGA .
!pip install -r SPIGA/requirements.txt
!pip install datasets
!huggingface-cli login
'''
from pred_color import *
import gradio as gr

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
import torch
from diffusers.utils import load_image

# Hugging Face Hub repositories: the face ControlNet trained for this demo and
# the Chinese Stable Diffusion base model it is attached to.
controlnet_model_name_or_path = "svjack/ControlNet-Face-Zh"
base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"

# Both models are loaded with the library's default dtype (no torch_dtype is
# passed); half precision was tried and left disabled.
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
)

# Replace the base model's scheduler with UniPC, reusing the same scheduler
# configuration, for a faster sampler.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Run on the GPU when one is available; otherwise the pipeline stays where
# from_pretrained put it (CPU). Model CPU offload is intentionally not enabled.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# Gradio examples, one [image path, prompt] row each. The prompt means
# "a man wearing a hat and grey clothes".
example_sample = [
    [
        "Protector_Cromwell_style.png",
        "戴帽子穿灰色衣服的男子",
    ],
]

from PIL import Image
def pred_func(image, prompt, num_inference_steps=50, seed=0):
    """Generate an image from a face's SPIGA segmentation map and a text prompt.

    ``single_pred_features`` (from ``pred_color``) is run on the input image.
    When it returns a dict, that dict's ``"spiga_seg"`` entry is used as the
    ControlNet conditioning image for the module-level ``pipe``.

    Args:
        image: Value delivered by the Gradio ``'image'`` input. Presumably a
            numpy array, though it is passed straight to
            ``single_pred_features``, which may also accept other forms.
            TODO confirm.
        prompt: Text prompt for the (Chinese) Taiyi base model.
        num_inference_steps: Denoising steps for the pipeline (default 50).
        seed: Value given to ``torch.manual_seed`` so runs are reproducible
            (default 0).

    Returns:
        ``(control_image, generated_image)`` on success. Returns
        ``(None, None)`` when ``single_pred_features`` does not return a dict
        (presumably no usable face was found), leaving both Gradio outputs
        empty.
    """
    out = single_pred_features(image)
    if not isinstance(out, dict):
        # Return one value per Gradio output component. A bare ``None`` cannot
        # be distributed over the two outputs of the interface.
        return None, None

    # Only the segmentation map conditions the pipeline. The input photo is
    # used solely for the feature extraction above.
    control_image = out["spiga_seg"]

    generator = torch.manual_seed(seed)
    generated = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        image=control_image,
    ).images[0]
    return control_image, generated


# Web UI: a face image and a prompt in; the SPIGA segmentation map used as the
# ControlNet condition and the generated image out. The Interface is bound to
# `demo` rather than `gr`, so the gradio module name is not shadowed.
# NOTE(review): `.style()` is deprecated in later Gradio 3.x and removed in
# 4.x; on newer Gradio pass `height=512` to `gr.Image` directly.
demo = gr.Interface(
    fn=pred_func,
    inputs=["image", "text"],
    outputs=[
        gr.Image(label="control image").style(height=512),
        gr.Image(label="output").style(height=512),
    ],
    examples=example_sample or None,
)
demo.launch(share=False)

if __name__ == "__main__":
    # Intentionally a no-op. The triple-quoted string below is disabled scratch
    # code kept for reference; it is evaluated as a bare string and never run.
    # It contains two pieces:
    #   1. Sample generations from local conditioning images
    #      (./conditioning_image_1.png, ./conditioning_image_2.png) with fixed
    #      seed 0 and 50 steps. The prompts translate to: "middle-aged man
    #      wearing glasses", "bald man wearing blue clothes", "beautiful woman
    #      with golden hair", and "man in a green sweatshirt".
    #   2. Uploading the trained ControlNet weights and config from
    #      TSD_save_only/ to the svjack/ControlNet-Face-Zh model repo via
    #      huggingface_hub.HfApi.upload_file.
    # NOTE(review): the Gradio app is launched at module level above, so when
    # this file runs as a script this guard is only reached after launch()
    # returns.
    '''
    control_image = load_image("./conditioning_image_1.png")
    prompt = "戴眼镜的中年男子"
    # generate image
    generator = torch.manual_seed(0)
    image = pipe(
         prompt, num_inference_steps=50, generator=generator, image=control_image
    ).images[0]
    image

    control_image = load_image("./conditioning_image_1.png")
    prompt = "穿蓝色衣服的秃头男子"
    # generate image
    generator = torch.manual_seed(0)
    image = pipe(
         prompt, num_inference_steps=50, generator=generator, image=control_image
    ).images[0]
    image

    control_image = load_image("./conditioning_image_2.png")
    prompt = "金色头发的美丽女子"
    # generate image
    generator = torch.manual_seed(0)
    image = pipe(
         prompt, num_inference_steps=50, generator=generator, image=control_image
    ).images[0]
    image

    control_image = load_image("./conditioning_image_2.png")
    prompt = "绿色运动衫的男子"
    # generate image
    generator = torch.manual_seed(0)
    image = pipe(
         prompt, num_inference_steps=50, generator=generator, image=control_image
    ).images[0]
    image

    from huggingface_hub import HfApi
    hf_api = HfApi()

    hf_api.upload_file(
        path_or_fileobj = "TSD_save_only/diffusion_pytorch_model.bin",
        path_in_repo = "diffusion_pytorch_model.bin",
        repo_id = "svjack/ControlNet-Face-Zh",
        repo_type = "model",
    )

    hf_api.upload_file(
        path_or_fileobj = "TSD_save_only/config.json",
        path_in_repo = "config.json",
        repo_id = "svjack/ControlNet-Face-Zh",
        repo_type = "model",
    )
    '''
    pass