File size: 5,765 Bytes
8e36162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

import cv2
import numpy as np
import random
import os

def _process_frame(frame, frame_size):
    """Resize a BGR frame to ``frame_size`` and return it as float32 RGB in [0, 1].

    Raises:
        cv2.error: If OpenCV cannot resize or colour-convert the frame.
        ValueError: If the processed frame does not have the expected
            (height, width, 3) shape.
    """
    resized = cv2.resize(frame, frame_size)  # dsize is (width, height)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    expected_shape = (frame_size[1], frame_size[0], 3)
    if rgb.shape != expected_shape:
        raise ValueError(
            f"processed frame has shape {rgb.shape}, expected {expected_shape}"
        )
    return rgb.astype(np.float32) / 255.0


def _fallback_frame(last_good_frame, frame_size):
    """Return a stand-in for a frame that could not be read or processed.

    The most recent successfully processed frame is repeated when one exists;
    otherwise an all-zero (black) frame of the target size is used.
    """
    if last_good_frame is not None:
        return last_good_frame.copy()
    return np.zeros((frame_size[1], frame_size[0], 3), dtype=np.float32)


def extract_video_frames(video_path, n_frames=30, frame_size=(96, 96)):
    """
    Extracts evenly spaced frames from a video, tolerating read/processing errors.

    Frame indices are spread evenly over the frame count reported by OpenCV.
    A frame that cannot be read or processed is replaced by the last good
    frame, or by a zero frame if no good frame has been seen yet, so the
    result always contains exactly ``n_frames`` entries.

    Args:
        video_path (str): Path to the video file.
        n_frames (int): The target number of frames to extract.
        frame_size (tuple): The target (width, height) for each frame.

    Returns:
        np.ndarray: A float32 array of shape (n_frames, height, width, 3) holding
                    RGB pixel values scaled to 0-1. The array is all zeros when
                    the video reports fewer than one frame or when no frame
                    could be read and processed.
                    None is returned if the file does not exist or cannot be opened.

    Raises:
        ValueError: If ``n_frames`` is negative (raised by NumPy when sizing the
                    output). The capture is still released in that case.
    """
    if not os.path.exists(video_path):
        print(f"Error: Video file not found at {video_path}")
        return None

    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released on every exit path,
    # including exceptions raised while sampling.
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        output_shape = (n_frames, frame_size[1], frame_size[0], 3)

        if total_frames < 1:
            print(f"Warning: Video has {total_frames} frames. Cannot extract.")
            # Return array of zeros matching the expected shape
            return np.zeros(output_shape, dtype=np.float32)
        if fps < 1:
            # Sampling is index-based, so an invalid FPS is only worth a warning.
            print(f"Warning: Video has invalid FPS ({fps}). Proceeding, but timing might be off.")

        # Pre-filled with zeros; every slot is overwritten in the loop below.
        frames = np.zeros(output_shape, dtype=np.float32)

        # Evenly spaced indices within [0, total_frames - 1]. When the video is
        # shorter than n_frames the integer indices repeat.
        indices = np.linspace(0, total_frames - 1, n_frames, dtype=int)

        extracted_count = 0
        last_good_frame = None  # Last successfully processed frame
        last_good_index = None  # Source index that produced last_good_frame

        for slot, frame_index in enumerate(indices):
            frame_index = int(frame_index)  # plain int for the OpenCV property setter

            # A repeated index would decode the same frame again; reuse it.
            if last_good_frame is not None and frame_index == last_good_index:
                frames[slot] = last_good_frame
                continue

            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()

            processed_frame = None
            if ret and frame is not None:
                try:
                    processed_frame = _process_frame(frame, frame_size)
                except cv2.error as e:
                    print(f"Warning: OpenCV error processing frame {frame_index}: {e}")
                except Exception as e:
                    # Deliberate best-effort: one bad frame must not abort the video.
                    print(f"Warning: Unexpected error processing frame {frame_index}: {e}")
            else:
                # Read failure (e.g., end of video reached early, corrupted frame)
                print(f"Warning: Failed to read frame at index {frame_index}. Using fallback.")

            if processed_frame is None:
                processed_frame = _fallback_frame(last_good_frame, frame_size)
            else:
                last_good_frame = processed_frame
                last_good_index = frame_index
                extracted_count += 1

            frames[slot] = processed_frame

        if n_frames > 0 and extracted_count == 0:
            # Every slot holds the zero fallback, so `frames` is already all zeros.
            print("Warning: Failed to extract or process any valid frames, returning array of zeros.")

        return frames
    finally:
        cap.release()