File size: 4,431 Bytes
94b4030
 
 
 
 
 
 
0033c7f
94b4030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695b9f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np

# CLIP backbone used for image-text matching; loaded once, when this module
# is first executed.
CLIP_BACKBONE = "RN50x64"
device = "cpu"
model, preprocess = clip.load(CLIP_BACKBONE, device=device)


def img_process(img1, img2, location_width, location_height, size_width, size_height,
                threshold=240, canvas_size=(600, 400)):
    """Composite an object image onto a background image and return the result.

    The background ``img2`` is converted to RGBA and resized to ``canvas_size``.
    The object ``img1`` is resized to ``(size_width, size_height)`` and pasted
    with its top-left corner at ``(location_width, location_height)``.

    If the object image carries its own transparency (an RGBA/LA/PA image, or a
    palette image with a transparency entry), that alpha channel is used as the
    paste mask. Otherwise the object is assumed to sit on a white background:
    every pixel whose R, G and B values all exceed ``threshold`` is made
    transparent before pasting.

    Args:
        img1: Path (or file object) of the object image.
        img2: Path (or file object) of the background image.
        location_width: x coordinate, in pixels, of the object's top-left corner.
        location_height: y coordinate, in pixels, of the object's top-left corner.
        size_width: Width, in pixels, the object is resized to; must be > 0.
        size_height: Height, in pixels, the object is resized to; must be > 0.
        threshold: Channel value above which a pixel is treated as white
            background (only used for images without their own alpha).
        canvas_size: ``(width, height)`` the background is resized to.

    Returns:
        The composited ``PIL.Image.Image`` in RGBA mode.

    Raises:
        ValueError: If an image is missing or the requested size is not positive.
    """
    if img1 is None or img2 is None:
        raise ValueError("both an object image and a background image are required")

    # UI widgets may deliver floats (e.g. 50.0); PIL needs ints for sizes/boxes.
    width, height = int(size_width), int(size_height)
    if width <= 0 or height <= 0:
        raise ValueError(f"object size must be positive, got {width}x{height}")
    position = (int(location_width), int(location_height))

    # Context managers close the underlying files; convert/resize return
    # fully-loaded copies, so the results stay valid after the files close.
    with Image.open(img2) as background_file:
        canvas = background_file.convert("RGBA").resize(canvas_size)

    with Image.open(img1) as source:
        has_alpha = source.mode in ("RGBA", "LA", "PA") or (
            source.mode == "P" and "transparency" in source.info
        )
        if has_alpha:
            obj = source.convert("RGBA").resize((width, height))
        else:
            # Convert to RGB before resizing so palette/greyscale inputs are
            # resampled smoothly, then add an alpha channel for masking.
            obj = source.convert("RGB").resize((width, height)).convert("RGBA")

    if not has_alpha:
        pixels = np.array(obj)  # writable copy, one row per image row, 4 channels
        is_white = (pixels[..., :3] > threshold).all(axis=-1)
        pixels[is_white, 3] = 0  # knock out the white background
        obj = Image.fromarray(pixels)

    canvas.paste(obj, position, mask=obj.getchannel("A"))
    return canvas

def _photo_prompt(attr, obj):
    """Return the prompt "a photo of <attr> <obj>", skipping empty words."""
    words = ("a photo of", (attr or "").strip(), (obj or "").strip())
    return " ".join(word for word in words if word)


def itm(obj, back, location_width, location_height, size_width, size_height,
        is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr):
    """Score a composited image against a positive and a negative text prompt.

    The object image ``obj`` is pasted onto ``back`` via ``img_process``. The
    result is scored by the CLIP ``model`` against two prompts:

    * positive: "a photo of <pos_attr> <pos_obj>"
    * negative: the positive prompt with the object swapped for ``neg_obj``
      when ``is_obj`` is true, and the attribute swapped for ``neg_attr`` when
      ``is_attr`` is true. If both flags are false the two prompts are
      identical, and the scores come out equal.

    Returns:
        A 5-tuple ``(pos_prompt, pos_prob, neg_prompt, neg_prob, image)``.
        The probabilities are a softmax over the two prompts for the single
        image, so they sum to 1. ``image`` is the composited PIL image.
    """
    img1 = img_process(obj, back, location_width, location_height, size_width, size_height)
    # Keep inputs on the same device as the model.
    image = preprocess(img1).unsqueeze(0).to(device)

    pos_prompt = _photo_prompt(pos_attr, pos_obj)
    neg_prompt = _photo_prompt(neg_attr if is_attr else pos_attr,
                               neg_obj if is_obj else pos_obj)
    text = clip.tokenize([pos_prompt, neg_prompt]).to(device)

    with torch.no_grad():
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    print("Label probs:", probs)
    return pos_prompt, probs[0][0], neg_prompt, probs[0][1], img1


# Pre-filled demo rows; each row lists values in the same order as the
# inputs wired to ``itm`` below.
EXAMPLE_ROWS = [
    ["sample/apple.png", "sample/back.jpg", 50, 50, 200, 200,
     True, "apple", "dog", False, "red", "green"],
    ["sample/banana.jpg", "sample/back.jpg", 300, 200, 200, 200,
     True, "bananas", "peaches", False, "yellow", "green"],
]

with gr.Blocks() as demo:
    gr.Markdown("<h1><center>VL-Checklist Demo</center></h1>")
    gr.Markdown("""
    Tips:
    - In this demo, you can change the object and attribute of object in the text prompt, and you can also change the size and location of the object.
    - Please upload an object image with white background.
    - The model we used in the demo is CLIP.
    """)

    # Controls: object image and its placement (left), background and prompts (right).
    with gr.Row():
        with gr.Column():
            img_obj = gr.Image(
                value="sample/apple.png",
                type="filepath",
                label="object_img(Plz input an object with white background)",
            )
            loc_w = gr.Slider(maximum=500, label="location_width", step=1)
            loc_h = gr.Slider(maximum=300, label="location_height", step=1)
            s_w = gr.Number(value=200, precision=0, label="size_width")
            s_h = gr.Number(value=200, precision=0, label="size_height")
            gr.Markdown("Click **Submit** to get the output!")
        with gr.Column():
            img_back = gr.Image(value="sample/back.jpg", type="filepath", label="background_img")
            is_obj = gr.Checkbox(value=True, label="Does negative prompt change the object?")
            pos_obj = gr.Textbox(value="apple", label="positive object")
            neg_obj = gr.Textbox(value="dog", label="negative object")
            is_attr = gr.Checkbox(value=False, label="Does negative prompt change the attribute?")
            pos_attr = gr.Textbox(value="red", label="positive attribute")
            neg_attr = gr.Textbox(value="green", label="negative attribute")

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")

    # Results: composited image (left), prompts and their scores (right).
    with gr.Row():
        with gr.Column():
            img_output = gr.Image(type="pil", label="output_img")
        with gr.Column():
            pos_prom = gr.Textbox(label="Positive prompt")
            pos_s = gr.Textbox(label="Positive score")
            neg_prom = gr.Textbox(label="Negative prompt")
            neg_s = gr.Textbox(label="Negative score")

    # Positional order must match the parameters of ``itm`` and the order of
    # the values it returns.
    itm_inputs = [img_obj, img_back, loc_w, loc_h, s_w, s_h,
                  is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr]
    itm_outputs = [pos_prom, pos_s, neg_prom, neg_s, img_output]

    with gr.Row():
        gr.Examples(
            examples=EXAMPLE_ROWS,
            inputs=itm_inputs,
            outputs=itm_outputs,
            fn=itm,
            cache_examples=True,
        )

    btn.click(fn=itm, inputs=itm_inputs, outputs=itm_outputs)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, so importing this
    # module (e.g. to reuse ``img_process`` or ``itm``) does not launch it.
    demo.launch()