import gradio as gr
from PIL import Image
import numpy as np
import cv2
from lang_sam import LangSAM
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer

# Load the LangSAM model
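# Note: constructing LangSAM may download the GroundingDINO/SAM checkpoint weights on first run
# and runs much faster when a CUDA GPU is available.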
model = LangSAM()  # Use the default model or specify custom checkpoint: LangSAM("<model_type>", "<path/to/checkpoint>")

# Function to apply color matching based on reference image
def apply_color_matching(source_img_np, ref_img_np):
    # Initialize ColorMatcher
    cm = ColorMatcher()
    
    # Apply color matching
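    # 'mkl' is the Monge-Kantorovich linearisation transfer; color_matcher also offers
    # other methods (e.g. 'hm', 'reinhard'), depending on the installed version.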
    img_res = cm.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
    
    # Normalize the result
    img_res = Normalizer(img_res).uint8_norm()
    
    return img_res

# Function to extract sky and apply color matching using a reference image
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    # Use LangSAM to predict segmentation masks for the text prompt
    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    # If nothing matching the prompt was found, return the original image unchanged
    if len(masks) == 0:
        return image_pil

    # Take the first predicted mask and convert it from a torch tensor to a NumPy array
    masks_np = masks[0].cpu().numpy()

    # Convert the mask to a binary format and create a mask image
    sky_mask = (masks_np > 0).astype(np.uint8) * 255  # Ensure it's a binary mask
    
    # Convert the PIL image to an RGB NumPy array for processing (drops any alpha channel)
    img_np = np.array(image_pil.convert("RGB"))
    
    # Convert sky mask to 3-channel format to blend with the original image
    sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])
    
    # Extract the sky region
    sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)
    
    # Convert the reference image to an RGB NumPy array
    ref_img_np = np.array(reference_image_pil.convert("RGB"))
    
    # Apply color matching using the reference image to the extracted sky region
    sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)
    
    # Combine the color-matched sky region back into the original image
    result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)
    
    # Convert the result back to PIL Image for final output
    result_img_pil = Image.fromarray(result_img_np)
    
    return result_img_pil

# Gradio Interface
def gradio_interface():
    # Gradio function to be called on input
    def process_image(source_img, ref_img):
        # Extract sky and apply color matching using reference image
        result_img_pil = extract_and_color_match_sky(source_img, ref_img)
        return result_img_pil

    # Define Gradio input components
    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image")  # Second input for reference image
    ]

    # Define Gradio output component
    outputs = gr.Image(type="pil", label="Resulting Image")

    # Launch Gradio app
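    # share=True additionally exposes a temporary public URL; drop it to serve locally only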
    gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Sky Extraction and Color Matching").launch(share=True)

# Run the Gradio Interface
if __name__ == "__main__":
    gradio_interface()