import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy
import numpy as np

import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2

from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

# Landmark styles: the palm landmark (index 0) gets a different color per hand so
# the "Hand Encoding" model can distinguish left from right in the control image.
right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)


def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    """Draw MediaPipe hand landmarks on top of the input image (overlap=True)
    or on a black canvas of the same size (overlap=False)."""
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness
    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Convert the landmarks to the protobuf format expected by the drawing utils.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])

        if hand_encoding:
            # Color the palm landmark differently for left and right hands.
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            if handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image


def generate_annotation(img, overlap=False, hand_encoding=False):
    """Detect hand landmarks in `img` (numpy array) and return the annotated
    image as a numpy array."""
    # Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # Visualize the detection result.
    annotated_image = draw_landmarks_on_image(
        image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image
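# Minimal usage sketch for generate_annotation, kept commented out so the module
# stays side-effect free on import. "sample.png" is an illustrative filename; the
# "hand_landmarker.task" asset referenced above must be present in the working
# directory.
#
#   sample = np.asarray(Image.open("sample.png").convert("RGB"))
#   control_map = generate_annotation(sample, overlap=False, hand_encoding=False)
#   Image.fromarray(control_map).save("sample_annotation.png")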
# Model selection. The gr.Radio component below mirrors the two preprocessing
# variants, but the choice is fixed at startup because the pipeline weights are
# loaded once at module level.
model_type = gr.Radio(
    ["Standard", "Hand Encoding"],
    label="Model preprocessing",
    info="We developed two models: one trained on standard MediaPipe landmarks, and one with different (but similar) coloring of the palm landmarks to distinguish left from right hands.")

model_type = "Standard"

if model_type == "Standard":
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        revision="non-ema",
        from_pt=True,
        controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
        controlnet_revision=None,
        controlnet_from_pt=False,
    )

if model_type == "Hand Encoding":
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        revision="non-ema",
        from_pt=True,
        controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
        controlnet_revision=None,
        controlnet_from_pt=False,
    )

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=args.revision,
    from_pt=args.from_pt,
)

# Replicate the parameters across all available JAX devices for pmapped inference.
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())


def infer(prompt, negative_prompt, image):
    """Run ControlNet-guided generation. Returns the overlap visualization, the
    bare landmark annotation, and one generated image per device."""
    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    # Build the control image from the detected hand landmarks.
    if model_type == "Standard":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
    if model_type == "Hand Encoding":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)

    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, batch) leading dimensions into a flat list of images.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]
    return [overlap_image, annotated_image] + results
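# Sketch of calling infer directly, without the Gradio UI below. The prompt and
# image name come from the Examples list in the demo and are assumed to be
# available next to this script; the returned list is the overlap visualization,
# the bare annotation, and the generated images.
#
#   outputs = infer(
#       "a woman is making an ok sign in front of a painting",
#       "longbody, lowres, bad anatomy, bad hands",
#       np.asarray(Image.open("example.png").convert("RGB")),
#   )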
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")
    gr.Markdown("""
- Model 1 (standard): [https://huggingface.co/Vincent-luo/controlnet-hands](https://huggingface.co/Vincent-luo/controlnet-hands)
- Model 2 (hand encoding): [https://huggingface.co/MakiPan/controlnet-encoded-hands-130k](https://huggingface.co/MakiPan/controlnet-encoded-hands-130k)
- Dataset 1 (standard): [https://huggingface.co/datasets/MakiPan/hagrid250k-blip2](https://huggingface.co/datasets/MakiPan/hagrid250k-blip2)
- Dataset 2 (hand encoding): [https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k](https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k)
- Preprocessing 1 (standard): [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py)
- Preprocessing 2 (hand encoding): [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py)
""")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            input_image = gr.Image(label="Input Image")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')

    gr.Examples(
        examples=[
            [
                "a woman is making an ok sign in front of a painting",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example.png"
            ],
            [
                "a man with his hands up in the air making a rock sign",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example1.png"
            ],
            [
                "a man is making a thumbs up gesture",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example2.png"
            ],
            [
                "a woman is holding up her hand in front of a window",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example3.png"
            ],
            [
                "a man with his finger on his lips",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example4.png"
            ],
        ],
        inputs=[prompt_input, negative_prompt, input_image],
        outputs=[output_image],
        fn=infer,
        cache_examples=True,
    )

    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()
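# Note: running this script requires the MediaPipe "hand_landmarker.task" asset
# in the working directory and a JAX installation that can see at least one
# accelerator device for the replicated pipeline to run at a usable speed.
# demo.launch(share=True) would additionally expose a temporary public URL.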