File size: 13,788 Bytes
81b1a0e
509862d
6bafd2d
6284dc0
 
509862d
 
 
 
 
 
 
 
 
 
 
e797135
509862d
6bafd2d
 
 
 
 
63d9326
 
 
 
 
 
 
 
 
 
 
 
6bafd2d
 
 
 
 
 
 
 
5a7c5d5
 
 
 
 
 
6bafd2d
5a7c5d5
6bafd2d
5a7c5d5
6bafd2d
5a7c5d5
6bafd2d
5a7c5d5
6bafd2d
5a7c5d5
6bafd2d
 
 
5a7c5d5
 
 
6bafd2d
 
509862d
 
53ff575
81b1a0e
509862d
 
 
d967d62
509862d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8da09d2
132dae6
1ba1ac4
132dae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ba1ac4
 
 
 
 
 
6b21c48
1ba1ac4
 
 
 
 
 
 
 
1f5deb3
1ba1ac4
5a7c5d5
 
6fc9c48
5a7c5d5
 
 
 
 
 
 
 
6fc9c48
 
 
 
 
 
 
1ba1ac4
 
509862d
6bafd2d
132dae6
 
 
 
 
 
 
 
509862d
132dae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3847cbf
 
509862d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132dae6
3847cbf
509862d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3847cbf
509862d
6fc9c48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import gradio as gr
import torch
import spaces
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import moviepy.editor as mp
from PIL import Image
import numpy as np
import tempfile
import time
import os
import shutil
import ffmpeg
from concurrent.futures import ThreadPoolExecutor
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts

# Custom Theme Definition
class WhiteTheme(Base):
    """Light Gradio theme: white surfaces, black text, orange primary hue.

    Every ``*_dark`` variable overridden here is given the same value as its
    light counterpart, so the page keeps its white background and black text
    even when dark mode is active.

    NOTE(review): the default ``font_mono`` stack starts with Inter, which is a
    proportional typeface, so "monospace" text will not be fixed-width —
    confirm this is intended.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        )
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        white, black = "white", "black"

        # Light-mode-only surfaces; two of them are tinted from the primary hue.
        light_mode_overrides = dict(
            background_fill_primary="*primary_50",
            background_fill_secondary=white,
            border_color_primary="*primary_300",
        )

        # Values pinned regardless of light/dark mode.
        constant_overrides = dict(
            body_background_fill=white,
            body_background_fill_dark=white,
            block_background_fill=white,
            block_background_fill_dark=white,
            panel_background_fill=white,
            panel_background_fill_dark=white,
            body_text_color=black,
            body_text_color_dark=black,
            block_label_text_color=black,
            block_label_text_color_dark=black,
            block_border_color=white,
            panel_border_color=white,
            input_border_color="lightgray",
            input_background_fill=white,
            input_background_fill_dark=white,
            shadow_drop="none",
        )

        self.set(**light_mode_overrides, **constant_overrides)

# Let PyTorch trade float32 matmul precision for speed ("medium" permits
# lower-precision internal math), and run on the GPU when one is available.
torch.set_float32_matmul_precision("medium")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load both segmentation models once at import time: the full BiRefNet for
# quality and BiRefNet_lite for "Fast Mode". nn.Module.to() returns the module
# itself, so the load and the device move can be chained.
print("Loading models...")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to(device)
birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
).to(device)
print("Models loaded successfully!")

# Model input preprocessing: a fixed 1024x1024 resize (aspect ratio is NOT
# preserved), conversion to a tensor, then normalization with the standard
# ImageNet channel means / standard deviations.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def process_frame(frame, fast_mode=True):
    """
    Cut out the foreground of a single video frame with BiRefNet.

    Args:
        frame: an image array accepted by ``PIL.Image.fromarray`` (in this app,
            a frame from moviepy's ``iter_frames``); it is converted to RGB.
        fast_mode: when truthy use ``birefnet_lite``, otherwise the full
            ``birefnet`` model.

    Returns:
        An RGBA PIL Image at the frame's original size whose alpha channel is
        the predicted foreground mask, or ``None`` if any step failed (the
        error is printed rather than raised).
    """
    try:
        rgb = Image.fromarray(frame).convert('RGB')
        frame_size = rgb.size

        # The network always sees a 1024x1024 resize (aspect ratio is NOT
        # kept); the mask is scaled back to ``frame_size`` afterwards.
        batch = transform_image(rgb).unsqueeze(0).to(device)
        net = birefnet if not fast_mode else birefnet_lite

        with torch.no_grad():
            # The last output is used as the mask logits; sigmoid -> [0, 1].
            probabilities = net(batch)[-1].sigmoid().cpu()

        mask = transforms.ToPILImage()(probabilities[0].squeeze())
        mask = mask.resize(frame_size, Image.BICUBIC)

        cutout = rgb.copy()
        cutout.putalpha(mask)
        return cutout
    except Exception as e:
        # Best-effort per frame: the caller skips frames that come back as None.
        print(f"Error processing frame: {e}")
        return None

def _iter_cutouts(frames, fast_mode, max_workers):
    """Yield ``(index, cutout)`` for every frame, in input order.

    Each frame is submitted to a thread pool running ``process_frame``;
    ``cutout`` is the RGBA image it returned, or ``None`` if processing failed.
    If the caller closes this generator early, futures that have not started
    yet are cancelled, so the pool's shutdown only waits for frames already
    running instead of the whole video.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
        try:
            for i, future in enumerate(futures):
                try:
                    cutout = future.result()
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")
                    cutout = None
                yield i, cutout
        finally:
            # cancel() is a no-op on finished futures; it only drops queued work.
            for future in futures:
                future.cancel()

def _encode_prores_4444(png_dir, mov_path, fps, threads):
    """Encode the contiguous ``frame_%06d.png`` sequence in *png_dir* to a MOV.

    Uses FFmpeg's ``prores_ks`` encoder with the 4444 profile and a pixel
    format that carries alpha. Returns *mov_path* on success, or ``None`` if
    FFmpeg fails (its stderr is printed).
    """
    try:
        stream = ffmpeg.input(
            os.path.join(png_dir, 'frame_%06d.png'),
            pattern_type='sequence',
            framerate=fps,
        )
        stream = ffmpeg.output(
            stream,
            mov_path,
            vcodec='prores_ks',          # ProRes codec
            pix_fmt='yuva444p10le',      # 10-bit 4:4:4:4 pixel format with alpha
            profile='4444',              # ProRes 4444 profile for alpha support
            alpha_bits=16,               # Alpha bit depth
            qscale=1,                    # Quality-oriented quantizer setting
            vendor='ap10',               # Standard ProRes vendor tag
            bits_per_mb=8000,            # High bitrate for quality
            threads=threads,             # Parallel encoding
        )
        ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
    except ffmpeg.Error as e:
        print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
        return None
    print("MOV video created successfully!")
    return mov_path

@spaces.GPU(duration=300)  # 5-minute duration for processing
def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
    """
    Process video to create transparent MOV file using ProRes 4444.
    Maintains original resolution, and the original framerate when fps=0.

    This is a generator yielding 5-tuples
    ``(preview, zip_path, mov_path, background, status)``. While frames are
    processed, ``preview`` is the latest RGBA cutout as a numpy array. The
    final tuple carries the PNG-sequence ZIP path and the MOV path (``None``
    if encoding failed). ``background`` is always ``None``. On failure a
    single tuple containing only the status message is yielded.
    """
    temp_dir = None
    try:
        start_time = time.time()
        # UI sliders may deliver floats; the thread pool and FFmpeg want a positive int.
        max_workers = max(1, int(max_workers))

        video = mp.VideoFileClip(video_path)
        try:
            # Use original video FPS if not specified
            if fps == 0:
                fps = video.fps
            frames = list(video.iter_frames(fps=fps))
        finally:
            # Release moviepy's readers as soon as the frames are decoded.
            video.close()
        total_frames = len(frames)

        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Temporary directory for the PNG sequence
        temp_dir = tempfile.mkdtemp()
        png_dir = os.path.join(temp_dir, "frames")
        os.makedirs(png_dir, exist_ok=True)

        # PNGs are numbered by successful saves, not by source index: FFmpeg's
        # image-sequence reader stops at the first missing number, so a gap left
        # by a failed frame would silently truncate the MOV.
        saved_frames = 0
        cutouts = _iter_cutouts(frames, fast_mode, max_workers)
        try:
            for i, cutout in cutouts:
                if cutout is not None:
                    try:
                        frame_path = os.path.join(png_dir, f"frame_{saved_frames:06d}.png")
                        cutout.save(frame_path, "PNG")
                    except Exception as e:
                        print(f"Error processing frame {i+1}: {e}")
                    else:
                        saved_frames += 1
                        # Only the current frame is needed for the live preview.
                        elapsed_time = time.time() - start_time
                        yield np.array(cutout), None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"

                if (i + 1) % 10 == 0:
                    print(f"Processed {i+1}/{total_frames} frames")
        finally:
            # Close explicitly so queued frames are cancelled promptly on early exit.
            cutouts.close()

        if saved_frames == 0:
            yield None, None, None, None, "Error processing video: no frames could be processed"
            return

        print("Creating output files...")
        # Create permanent output directory
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)
        # One timestamp so the ZIP and MOV from the same run share a suffix.
        stamp = int(time.time())

        # Create ZIP file of PNG sequence
        zip_base = os.path.join(output_dir, f"frames_{stamp}")
        shutil.make_archive(zip_base, 'zip', png_dir)
        zip_path = zip_base + ".zip"

        print("Creating ProRes 4444 MOV...")
        mov_path = _encode_prores_4444(
            png_dir, os.path.join(output_dir, f"video_{stamp}.mov"), fps, max_workers
        )

        print("Processing complete!")
        yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"

    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, None, f"Error processing video: {e}"
    finally:
        # Clean up temporary directory
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")

@spaces.GPU(duration=300)  # Match process_video duration
def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
    """Gradio click handler: validate the upload, then stream process_video's outputs.

    Yields the same 5-tuples as ``process_video``. ``yield from`` forwards
    ``close()`` to ``process_video`` when this generator is closed, so its
    cleanup (``finally``) runs promptly instead of being deferred.

    Raises:
        gr.Error: if no video was uploaded, or if process_video raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        yield from process_video(video, fps, fast_mode, max_workers)
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise gr.Error(f"Error processing video: {str(e)}") from e

# Custom CSS passed to gr.Blocks(css=...):
# - #title (the HTML header) and #submit-button (the "Process Video" button)
#   share an animated pastel gradient driven by the gradient-animation keyframes.
# - The :root rules force Gradio's background/text CSS variables to white/black
#   for the light and dark data-theme values alike, and color-scheme: light is
#   set when the browser prefers dark mode, keeping the page in light styling.
custom_css = """
.title-container {
    text-align: center;
    padding: 10px 0;
}

#title {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-size: 36px;
    font-weight: bold;
    color: #000000;
    padding: 10px;
    border-radius: 10px;
    display: inline-block;
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
}

@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}

#submit-button {
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
    border-radius: 12px;
    color: black;
}

/* Force light mode styles */
:root, :root[data-theme='light'], :root[data-theme='dark'] {
    --body-background-fill: white !important;
    --background-fill-primary: white !important;
    --background-fill-secondary: white !important;
    --block-background-fill: white !important;
    --panel-background-fill: white !important;
    --body-text-color: black !important;
    --block-label-text-color: black !important;
}

/* Additional overrides for dark mode */
@media (prefers-color-scheme: dark) {
    :root {
        color-scheme: light;
    }
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    # Title "{.video}" with a typewriter reveal of "video".
    # gr.HTML inserts its value as innerHTML, where <script> tags never execute,
    # so the reveal is done in CSS: one letter every 150 ms, starting at 150 ms.
    # The letters are real markup, so the title still reads "{.video}" if the
    # styles are not applied. Reduced-motion users see the text immediately.
    gr.HTML('''
        <style>
            #typed-text > span {
                opacity: 0;
                animation: title-typing 0.01s step-end forwards;
            }
            #typed-text > span:nth-child(1) { animation-delay: 0.15s; }
            #typed-text > span:nth-child(2) { animation-delay: 0.30s; }
            #typed-text > span:nth-child(3) { animation-delay: 0.45s; }
            #typed-text > span:nth-child(4) { animation-delay: 0.60s; }
            #typed-text > span:nth-child(5) { animation-delay: 0.75s; }
            @keyframes title-typing {
                to { opacity: 1; }
            }
            @media (prefers-reduced-motion: reduce) {
                #typed-text > span { opacity: 1; animation: none; }
            }
        </style>
        <div class="title-container">
            <div id="title">
                <span>{.</span><span id="typed-text"><span>v</span><span>i</span><span>d</span><span>e</span><span>o</span></span><span>}</span>
            </div>
        </div>
    ''')

    with gr.Row():
        # Left column: input video and processing options.
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                fast_mode_checkbox = gr.Checkbox(
                    label="Fast Mode (Use BiRefNet_lite)",
                    value=True
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")

        # Right column: live preview, downloads and status. The order of these
        # components must match the 5-tuples yielded by process_wrapper.
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_foreground_zip = gr.File(label="Download PNG Sequence (ZIP)")
            output_foreground_video = gr.File(label="Download Video (ProRes 4444 MOV with transparency)")
            output_background = gr.Video(label="Background (Coming Soon)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("""
            ### Output Information
            - MOV file uses ProRes 4444 codec for professional-grade alpha channel
            - Original resolution is maintained; the original framerate is kept when Output FPS is 0
            - PNG sequence provided for maximum compatibility
            """)

    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
        outputs=[preview_image, output_foreground_zip, output_foreground_video,
                 output_background, time_textbox]
    )

if __name__ == "__main__":
    demo.launch(debug=True)