# draptic-demo / app.py
import gradio as gr
import numpy as np
import random
import spaces  # required when running on ZeroGPU Spaces
from diffusers import FluxControlPipeline  # control-capable FLUX pipeline; the generic DiffusionPipeline does not accept control_image
from controlnet_aux import CannyDetector
from huggingface_hub import login
import torch
import subprocess
from groq import Groq
import base64
import os
from PIL import Image  # used to load the uploaded file for the Canny detector
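# Authenticate with the Hugging Face Hub so gated model weights can be downloaded
# (HF_API_KEY must be set in the environment).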
login(token=os.environ.get("HF_API_KEY"))
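# Remove any stale ZeroGPU offload cache left over from previous runs.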
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# Load FLUX image generator
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # replace with the model you would like to use
flat_lora_path = "matteomarjanovic/flatsketcher"
canny_lora_path = "black-forest-labs/FLUX.1-Canny-dev-lora"
flat_weights_file = "lora.safetensors"
canny_weights_file = "flux1-canny-dev-lora.safetensors"
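# Canny edge detector that converts the input photo into a control image.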
processor = CannyDetector()
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
pipe = FluxControlPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)  # control-capable pipeline so control_image is accepted in the generation call below
pipe = pipe.to(device)
pipe.load_lora_weights(flat_lora_path, weight_name=flat_weights_file, adapter_name="flat")
pipe.load_lora_weights(canny_lora_path, weight_name=canny_weights_file, adapter_name="canny")
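# Blend the two adapters: the flat-sketch style LoRA and the Canny control LoRA.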
pipe.set_adapters(["flat", "canny"], adapter_weights=[0.7, 0.7])
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def encode_image(image_path):
    """Read an image file from disk and return it as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
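# encode_image feeds the Groq request below, where the string is wrapped
# in a JPEG data URL ("data:image/jpeg;base64,<...>").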
# Optional text-only generation helper, kept for reference:
# @spaces.GPU
# def infer(
#     prompt,
#     progress=gr.Progress(track_tqdm=True),
# ):
#     # seed = random.randint(0, MAX_SEED)
#     # generator = torch.Generator().manual_seed(seed)
#     image = pipe(
#         prompt=prompt,
#         guidance_scale=0.0,
#         num_inference_steps=4,
#         width=1420,
#         height=1080,
#         max_sequence_length=256,
#     ).images[0]
#     return image
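# Main handler: describe the garment with a Groq vision model, then render the
# flat sketch with FLUX guided by a Canny edge map of the uploaded photo.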
@spaces.GPU  # runs this function on ZeroGPU-allocated hardware
def generate_description_fn(
    image,
    progress=gr.Progress(track_tqdm=True),
):
    base64_image = encode_image(image)
    client = Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )
    # Ask the vision model for a one-paragraph flat-sketch description of the garment.
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": """
I want you to imagine what the technical flat sketch of the garment you see in the picture would look like. Both front and back descriptions are mandatory. Describe it to me in rich detail, in one paragraph, and don't add any additional comment.
The style of the result should look somewhat like the following example:
The technical flat sketch of the dress would depict a midi-length, off-the-shoulder design with a smocked bodice and short puff sleeves that have elasticized cuffs. The elastic neckline sits straight across the chest and back, ensuring a secure fit. The bodice transitions into a flowy, tiered skirt with three evenly spaced gathered panels, creating soft volume. The back view mirrors the front, maintaining the smocked fit and tiered skirt without visible closures, suggesting a pullover style. Elasticized areas would be marked with textured lines, while the gathers and drape would be indicated through subtle curved strokes, ensuring clarity in construction details.
"""
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                        },
                    },
                ],
            }
        ],
        model="llama-3.2-11b-vision-preview",
    )
    # Append the flat-sketch LoRA trigger phrase to the generated description.
    prompt = chat_completion.choices[0].message.content + " In the style of FLTSKC"
    # CannyDetector expects a PIL image or numpy array rather than a file path,
    # so the uploaded image is opened first.
    control_image = processor(
        Image.open(image),
        low_threshold=50,
        high_threshold=200,
        detect_resolution=1024,
        image_resolution=1024,
    )
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        guidance_scale=0.0,
        num_inference_steps=4,
        width=1420,
        height=1080,
        max_sequence_length=256,
    ).images[0]
    return prompt, image
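# Custom CSS: centered output column, light background, and brand-colored primary button.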
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}

.gradio-container {
    background-color: oklch(98% 0 0);
}

.btn-primary {
    background-color: #422ad5;
    outline-color: #422ad5;
}
"""
# generated_prompt = ""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # gr.Markdown("# Draptic: from garment image to technical flat sketch")
    with gr.Row():
        with gr.Column(elem_id="col-input-image"):
            # gr.Markdown(" ## Drop your image here")
            input_image = gr.Image(type="filepath")
        with gr.Column(elem_id="col-container"):
            generate_button = gr.Button(
                "Generate flat sketch",
                scale=0,
                variant="primary",
                elem_classes="btn btn-primary",
            )
            result = gr.Image(label="Result", show_label=False)
            # Description panel (populated once a sketch has been generated).
            gr.Markdown("## Description of the garment:")
            generated_prompt = gr.Markdown("")

    gr.on(
        triggers=[generate_button.click],
        fn=generate_description_fn,
        inputs=[input_image],
        outputs=[generated_prompt, result],
    )
if __name__ == "__main__":
    demo.launch()