import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy  # used to deep-copy the default MediaPipe drawing styles before recoloring them
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from diffusers import (
FlaxControlNetModel,
FlaxStableDiffusionControlNetPipeline,
)
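
# The two ControlNet variants differ only in how the control image is colored:
# for the "Hand Encoding" model, the palm landmarks are tinted differently for the
# right hand (skin tone) and the left hand (near-white) so the model can tell the hands apart.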
right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)
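

# Draw detected hand landmarks either on top of the source photo (overlap=True)
# or on a black canvas of the same size (overlap=False), optionally using the
# left/right palm color encoding defined above.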
def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness
    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
        ])

        if hand_encoding:
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            if handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image


def generate_annotation(img, overlap=False, hand_encoding=False):
    """Run MediaPipe hand-landmark detection on `img` (an RGB numpy array) and
    return the annotated control image as a numpy array."""
    # STEP 1: Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # STEP 2: Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # STEP 3: Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # STEP 4: Visualize the detection result.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result,
                                              overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image
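

# Example (hypothetical variable names): for an RGB uint8 array `img`,
#   control_image = generate_annotation(img)                # landmarks on a black canvas (fed to ControlNet)
#   preview_image = generate_annotation(img, overlap=True)  # landmarks drawn over the photo (shown to the user)

# Choose which ControlNet checkpoint to load: "Standard" uses plain MediaPipe landmark
# colors, "Hand Encoding" uses the recolored palm landmarks defined above.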
model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models: one trained on standard MediaPipe landmarks, and one with different (but similar) coloring of the palm landmarks to distinguish the left and right hands.")
model_type = "Standard"  # the radio choice is not wired into the UI yet; the model is fixed at load time
if model_type == "Standard":
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        revision="non-ema",
        from_pt=True,
        controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
        controlnet_revision=None,
        controlnet_from_pt=False,
    )
if model_type == "Hand Encoding":
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        revision="non-ema",
        from_pt=True,
        controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
        controlnet_revision=None,
        controlnet_from_pt=False,
    )
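
# Load the ControlNet weights and the Stable Diffusion v1.5 pipeline in Flax.
# float32 is used here; switching to jnp.bfloat16 (hinted at in the inline comments)
# would roughly halve parameter memory, which may help on TPU.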
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=args.revision,
    from_pt=args.from_pt,
)
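
# Replicate the parameters across all local devices and split the PRNG key so each
# device gets its own seed; the pipeline call below runs with jit=True (pmap-style),
# producing one image per device.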
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)
rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())
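
# infer() builds the control image from the uploaded photo, tokenizes the prompts,
# shards everything across devices, and returns
# [overlap preview, control image, *generated images].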
def infer(prompt, negative_prompt, image):
    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    if model_type == "Standard":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
    if model_type == "Hand Encoding":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)

    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, batch, ...) leading axes into a single batch axis.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]
    return [overlap_image, annotated_image] + results
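
# Gradio UI: prompt/negative-prompt inputs and an image upload on the left,
# an output gallery on the right, plus cached examples.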
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")
    gr.Markdown("""
- Standard model: [https://huggingface.co/Vincent-luo/controlnet-hands](https://huggingface.co/Vincent-luo/controlnet-hands)
- Hand-encoding model: [https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/](https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/)
- Standard dataset: [https://huggingface.co/datasets/MakiPan/hagrid250k-blip2](https://huggingface.co/datasets/MakiPan/hagrid250k-blip2)
- Hand-encoding dataset: [https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k](https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k)
- Standard preprocessing script: [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py)
- Hand-encoding preprocessing script: [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py)
""")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            input_image = gr.Image(label="Input Image")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')

    gr.Examples(
        examples=[
            [
                "a woman is making an ok sign in front of a painting",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example.png"
            ],
            [
                "a man with his hands up in the air making a rock sign",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example1.png"
            ],
            [
                "a man is making a thumbs up gesture",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example2.png"
            ],
            [
                "a woman is holding up her hand in front of a window",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example3.png"
            ],
            [
                "a man with his finger on his lips",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example4.png"
            ],
        ],
        inputs=[prompt_input, negative_prompt, input_image],
        outputs=[output_image],
        fn=infer,
        cache_examples=True,
    )

    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()