Spaces:
Sleeping
Sleeping
File size: 3,032 Bytes
b795d51 9c7b939 b795d51 9c7b939 a68f3d0 b795d51 9c7b939 b795d51 9c7b939 b795d51 9c7b939 b795d51 9c7b939 b795d51 9c7b939 a68f3d0 9c7b939 a68f3d0 b795d51 9c7b939 b795d51 9c7b939 b795d51 9c7b939 b795d51 9c7b939 b795d51 a68f3d0 b795d51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
from PIL import Image, ImageEnhance
import numpy as np
import cv2
from lang_sam import LangSAM
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
import torch
# Load the LangSAM model once, at import time; extract_and_color_match_sky()
# reuses this module-level instance for every request.
# Per the original note, a custom checkpoint may be given instead of the
# default model: LangSAM("<model_type>", "<path/to/checkpoint>").
model = LangSAM()
# Function to apply color matching based on reference image
def apply_color_matching(source_img_np, ref_img_np, method="mkl"):
    """Transfer the colour distribution of *ref_img_np* onto *source_img_np*.

    Args:
        source_img_np: Source pixels as a NumPy array. The caller in this file
            passes a 3-channel RGB array (``(..., 3)``); other layouts are
            untested here.
        ref_img_np: Reference image as a NumPy array supplying target colours.
        method: Transfer method forwarded to ``ColorMatcher.transfer``.
            Defaults to ``'mkl'``, the value previously hard-coded.

    Returns:
        The colour-matched array converted to 8-bit via
        ``Normalizer(...).uint8_norm()``.
    """
    matched = ColorMatcher().transfer(src=source_img_np, ref=ref_img_np, method=method)
    # Bring the result back to 8-bit so it can be merged with the uint8 source
    # image and turned into a PIL image.
    return Normalizer(matched).uint8_norm()
# Function to extract sky and apply color matching using a reference image
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Recolour the region described by *text_prompt* to match a reference image.

    LangSAM segments the region named by *text_prompt* (the sky by default).
    Only the pixels inside that region are colour-matched against
    *reference_image_pil*, and the matched pixels are written back into a copy
    of the source image.

    Args:
        image_pil: Source ``PIL.Image``; any mode is accepted and converted to RGB.
        reference_image_pil: Reference ``PIL.Image`` providing the target colours;
            converted to RGB.
        text_prompt: Text prompt passed to LangSAM.

    Returns:
        An RGB ``PIL.Image``. If LangSAM returns no mask, or the masks cover no
        pixels, the RGB-converted source image is returned unchanged.
    """
    # Uploads may be RGBA, palette or greyscale. Normalise both images to RGB
    # so the arrays below are (H, W, 3) and channel counts always agree.
    image_rgb = image_pil.convert("RGB")
    reference_rgb = reference_image_pil.convert("RGB")

    masks, _boxes, _phrases, _logits = model.predict(image_rgb, text_prompt)
    if len(masks) == 0:
        # Nothing matched the prompt, so there is no region to recolour.
        return image_rgb

    # Each entry is assumed to be one (H, W) mask per detection, possibly a
    # torch tensor (TODO confirm against the installed lang_sam version).
    # Union all detections so a sky split into several regions is fully covered.
    per_detection = [
        (m.cpu().numpy() if hasattr(m, "cpu") else np.asarray(m)) > 0
        for m in masks
    ]
    sky_mask = np.any(np.stack(per_detection), axis=0)
    if not sky_mask.any():
        return image_rgb

    img_np = np.array(image_rgb)

    # Colour-match only the sky pixels. Matching the whole masked image would
    # feed every zeroed (black) non-sky pixel into the source colour
    # statistics and skew the transfer. The pixels are shaped (N, 1, 3) so
    # they still form an image-like 3-D array for ColorMatcher.
    sky_pixels = img_np[sky_mask].reshape(-1, 1, 3)
    matched = apply_color_matching(sky_pixels, np.array(reference_rgb))

    # Paste the matched pixels back; everything outside the mask is untouched.
    result_img_np = img_np.copy()
    result_img_np[sky_mask] = np.asarray(matched, dtype=np.uint8).reshape(-1, 3)
    return Image.fromarray(result_img_np)
# Gradio Interface
def gradio_interface(share=True):
    """Build the "Sky Extraction and Color Matching" Gradio app and launch it.

    Args:
        share: Forwarded to ``Interface.launch``. Defaults to ``True``, the
            value previously hard-coded.
    """

    def process_image(source_img, ref_img):
        """Gradio callback: recolour the sky of *source_img* to match *ref_img*."""
        # Gradio passes None for an empty image slot. Report it in the UI
        # instead of crashing deep inside the segmentation/matching pipeline.
        if source_img is None or ref_img is None:
            raise gr.Error("Please upload both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    # Two PIL image inputs: the photo to edit, then the colour reference.
    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    ).launch(share=share)
# Entry point: launch the Gradio app only when this file is executed as a
# script, not when it is imported as a module.
def _main():
    """Build and launch the Gradio interface."""
    gradio_interface()


if __name__ == "__main__":
    _main()
|