File size: 1,848 Bytes
12863e1
2cf6be0
 
af84433
 
 
05d3d42
987f112
 
 
 
2cf6be0
af84433
 
 
6b0d828
af84433
 
 
 
 
2cf6be0
 
12863e1
0fc6bc9
4095388
af84433
 
987f112
0fc6bc9
2cf6be0
fa0ee64
 
 
 
fb01197
0fc6bc9
f0f8ecd
262e1d3
0fc6bc9
ad9ba71
0fc6bc9
2cf6be0
fa0ee64
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import spaces
import gradio as gr
import torch
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import rembg
from io import BytesIO
import PIL.Image as Image
import cv2
import numpy

# Model sources: the SDXL base repository and the 4-step SDXL-Lightning UNet
# checkpoint that replaces the base UNet weights.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"

# Build the UNet from the base repo's "unet" config, then move it to the GPU
# in half precision.
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet = unet.to("cuda", torch.float16)

# Fetch the Lightning checkpoint and load it into the UNet. The state dict is
# passed straight through (not bound to a name) so it does not outlive the call.
unet.load_state_dict(
    load_file(hf_hub_download(repo_id=repo, filename=ckpt), device="cuda")
)

# Assemble the fp16 SDXL pipeline around the customised UNet and place it on
# the GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")

# Re-create the scheduler from the pipeline's own scheduler config, switching
# the timestep spacing to "trailing".
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Function to generate an image from text using diffusion
@spaces.GPU
def generate_image(prompt):
    """Generate a single image for a text description.

    Fixed style keywords ("no background, side view, minimalist shot") are
    appended to the description before it is passed to the module-level
    ``pipe``.

    Args:
        prompt: Description text from the UI textbox. An empty or ``None``
            value falls back to the style keywords alone.

    Returns:
        The first image produced by the pipeline call.
    """
    style_suffix = "no background, side view, minimalist shot"
    description = (prompt or "").strip()

    # Join with ", " so the style keywords stay separate from the user's text;
    # plain concatenation fused the last word of the description with "no"
    # (e.g. "red sneakerno background").
    if description:
        styled_prompt = f"{description}, {style_suffix}"
    else:
        styled_prompt = style_suffix

    # Four inference steps, matching the "4step" checkpoint loaded at module
    # level; guidance_scale=0 is kept exactly as in the original call.
    result = pipe(styled_prompt, num_inference_steps=4, guidance_scale=0)
    return result.images[0]

_TITLE = "Shoe Generator"

# UI layout: a prompt textbox and a button on the left, the generated image on
# the right.
# The title must be passed by keyword: the first positional parameter of
# gr.Blocks is not `title`, so `gr.Blocks(_TITLE)` never set the page title.
with gr.Blocks(title=_TITLE) as ShoeGen:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Enter a description of a shoe")
            # neg_prompt = gr.Textbox(label="Enter a negative prompt", value="low quality, watermark, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, closed eyes, text, logo")
            button_gen = gr.Button("Generate Image")
        with gr.Column():
            image = gr.Image(label="Generated Image", show_download_button=True)

    # Clicking the button sends the textbox value to generate_image and shows
    # the returned image in the image component.
    button_gen.click(generate_image, inputs=[prompt], outputs=[image])

ShoeGen.launch()