File size: 3,237 Bytes
2d6061e 2590cf7 2d6061e 43af989 499970a 2d6061e 71a5076 2d6061e 66863fc 71a5076 2d6061e 66863fc 2d6061e 66863fc 2d6061e 66863fc 2d6061e 614d206 2d6061e aad2eee 2d6061e 22deffa 2d6061e a4a0425 2d6061e 442366c 2d6061e 9ae65e8 2d6061e 51bd190 687f0bf 66863fc 22deffa 2d6061e 520f6d6 614d206 2d6061e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login
# Hugging Face access token, read once from the environment. It is used both to
# log this process into the Hub and as the bearer token for inference requests.
API_TOKEN = os.getenv("HF_READ_TOKEN")
login(token=API_TOKEN)

# Hosted inference endpoint for the Open Diffusion v1 model, plus the auth
# header sent with every request made by query().
API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Blocked-word list used to screen prompts: the "en.txt" file of the dataset,
# fetched with the credentials stored by login() above (use_auth_token=True).
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
# Keep only the "text" column of the "train" split: one entry per line of en.txt.
word_list = word_list_dataset["train"]["text"]
def query(prompt, is_negative=False, steps=7, cfg_scale=7, seed=None, num_images=4):
    """Generate ``num_images`` images for ``prompt`` via the hosted inference API.

    The prompt is first screened against the module-level ``word_list``: any
    whole-word, case-insensitive match aborts with a user-facing ``gr.Error``.
    Then one POST request per image is sent to ``API_URL``.

    Args:
        prompt: User prompt text; ``", 8k"`` is appended before sending.
            ``None`` is treated as an empty prompt.
        is_negative: Forwarded verbatim as the ``"is_negative"`` payload field.
            The UI wires the negative-prompt textbox here, so in practice this
            is a string. NOTE(review): confirm the endpoint interprets this key.
        steps: Forwarded as the ``"steps"`` payload field.
        cfg_scale: Forwarded as the ``"cfg_scale"`` payload field.
        seed: Seed reused for every image; when ``None`` each request gets its
            own random seed drawn from [-1, 2147483647].
            NOTE(review): -1 is inside that range; confirm how the API treats it.
        num_images: Number of sequential API requests (one image each).

    Returns:
        A list of PIL images, one per request.

    Raises:
        gr.Error: If the prompt contains a blocked word, if an HTTP request
            fails or returns an error status, or if a response body cannot be
            decoded as an image.
    """
    prompt = prompt or ""

    for raw_word in word_list:
        word = raw_word.strip()
        if not word:
            # An empty entry would produce the pattern r"\b\b", which matches
            # almost any prompt and would block every request.
            continue
        # List entries are literal words, not regex syntax, so escape them;
        # match case-insensitively so "BadWord" cannot slip past "badword".
        if re.search(rf"\b{re.escape(word)}\b", prompt, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        try:
            # Bounded wait so a stalled endpoint cannot hang the worker forever.
            response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
            # Error statuses (e.g. model still loading) carry a JSON body, not
            # an image; surface them instead of letting PIL fail on them.
            response.raise_for_status()
        except requests.RequestException as err:
            raise gr.Error(f"Image generation request failed: {err}") from err
        try:
            image = Image.open(io.BytesIO(response.content))
        except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
            raise gr.Error("The inference API did not return an image. Please try again.") from err
        images.append(image)
    return images
# Custom stylesheet passed to gr.Blocks(css=...) below: sets the page font and
# sizes / horizontally centres the results gallery (elem_id="gallery").
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
"""
# Build the Gradio page: a title banner, a 2-column image gallery, and a
# prompt / negative-prompt input row with a Generate button that calls query().
with gr.Blocks(css=css) as demo:
    # Centered page title.
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Open Diffusion 1.0 Demo
</h1>
</div>
</div>
"""
    )
    with gr.Group():
        with gr.Box():
            # Output area: generated images laid out in a 2-column grid.
            with gr.Row():
                with gr.Column():
                    gallery_output = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
            # Input area: prompt textboxes stacked in a column, button beside them
            # (the button's rounding only on its right-hand corners suggests it
            # is meant to sit flush against the textbox column).
            with gr.Row(elem_id="prompt-container"):
                with gr.Column():
                    text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1, elem_id="prompt-text-input")
                    negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative prompt", max_lines=1, elem_id="negative-prompt-text-input")
                text_button = gr.Button("Generate").style(margin=False, rounded=(False, True, True, False), full_width=False)
    # Event listeners must be registered inside the Blocks context. The negative
    # prompt text is passed as query()'s second positional arg (is_negative).
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

demo.launch(show_api=False)