File size: 3,237 Bytes
2d6061e
 
 
2590cf7
2d6061e
 
 
43af989
499970a
 
 
2d6061e
 
 
 
 
71a5076
 
2d6061e
66863fc
71a5076
 
 
2d6061e
66863fc
2d6061e
66863fc
2d6061e
 
 
 
 
 
66863fc
2d6061e
 
614d206
2d6061e
aad2eee
2d6061e
 
 
 
22deffa
 
 
2d6061e
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a0425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d6061e
442366c
2d6061e
9ae65e8
2d6061e
51bd190
687f0bf
66863fc
22deffa
 
2d6061e
 
 
520f6d6
614d206
2d6061e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login

# Read the Hugging Face token once; it authenticates the Hub session, the
# Inference API requests (via `headers`), and the word-list download below
# (use_auth_token=True).
API_TOKEN = os.getenv("HF_READ_TOKEN")
login(token=API_TOKEN)

API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": "Bearer {}".format(API_TOKEN)}

# Terms that query() checks prompts against before generating anything.
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
word_list = word_list_dataset["train"]["text"]

def query(prompt, is_negative=False, steps=7, cfg_scale=7, seed=None, num_images=4, timeout=120):
    """Screen *prompt* against ``word_list``, then request images for it.

    Each image comes from a separate POST to ``API_URL``; the requests are
    made one after another.

    Args:
        prompt: Text prompt; ``", 8k"`` is appended before it is sent.
        is_negative: Sent unchanged as the ``is_negative`` payload field. The
            UI passes the negative-prompt textbox value here, so in practice
            it is a string rather than a bool.
        steps: Sent as the ``steps`` payload field.
        cfg_scale: Sent as the ``cfg_scale`` payload field.
        seed: Seed used for every request. When ``None`` a new random seed is
            drawn for each image.
        num_images: How many requests (and therefore images) to make.
        timeout: Per-request timeout in seconds passed to ``requests.post``;
            pass ``None`` to wait without limit.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per request.

    Raises:
        gr.Error: If the prompt contains a filtered word, or if a request
            fails or its response body cannot be decoded as an image.
    """
    for word in word_list:
        # Skip blank entries: an empty term would produce r"\b\b", which
        # matches almost any prompt and would block every request.
        term = (word or "").strip()
        if not term:
            continue
        # re.escape: match list entries literally, so regex metacharacters
        # neither raise re.error nor change what is matched.
        # IGNORECASE: capitalisation must not bypass the filter.
        if re.search(rf"\b{re.escape(term)}\b", prompt, re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
            # An error status carries an error body, not image bytes; surface
            # it as a readable message instead of a PIL decode failure.
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            # Image.open is lazy; decode now so corrupt data fails inside this
            # try block rather than later in the gallery.
            image.load()
        except (requests.RequestException, OSError) as err:
            raise gr.Error(f"Image generation failed: {err}") from err
        images.append(image)

    return images


# Page-level CSS handed to gr.Blocks(css=css). The #gallery rules target the
# gallery component created with elem_id="gallery"; .gradio-container sets
# the font for the whole app.
css = """
        .gradio-container {
            font-family: 'IBM Plex Sans', sans-serif;
        }
        #gallery {
            min-height: 22rem;
            margin-bottom: 15px;
            margin-left: auto;
            margin-right: auto;
            border-bottom-right-radius: .5rem !important;
            border-bottom-left-radius: .5rem !important;
        }
        #gallery>div>.h-full {
            min-height: 20rem;
        }
"""

# Build the page: a title banner, an image gallery, and the prompt inputs with
# a Generate button wired to query(). Nesting order of the `with` blocks below
# determines the layout.
# NOTE(review): gr.Box and Component.style() are Gradio 3.x APIs (removed in
# Gradio 4) — confirm the pinned Gradio version before upgrading.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
            <div style="text-align: center; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
                  Open Diffusion 1.0 Demo
                </h1>
              </div>
            </div>
        """
    )
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    # Output area; elem_id="gallery" ties it to the #gallery CSS rules.
                    gallery_output = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
            with gr.Row(elem_id="prompt-container"):
                with gr.Column():
                    text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1, elem_id="prompt-text-input")
                    negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative prompt", max_lines=1, elem_id="negative-prompt-text-input")
                    text_button = gr.Button("Generate").style(margin=False, rounded=(False, True, True, False), full_width=False)

    # Only the two textbox values are passed; query()'s remaining parameters
    # keep their defaults. The negative prompt lands in query()'s second
    # positional parameter, `is_negative`.
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

# show_api=False: presumably hides Gradio's "Use via API" docs link — see
# gr.Blocks.launch() documentation.
demo.launch(show_api=False)