# MEMTrack / app.py: Gradio app for bacteria detection and tracking in uploaded videos
import os
import sys
import cv2
import glob
import shutil
import gdown
import zipfile
# import spaces
import time
import random
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
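# Make the MEMTrack source modules importable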
sys.path.insert(1, "MEMTrack/src")
from data_prep_utils import process_data
from data_feature_gen import create_train_data, create_test_data
from inferenceBacteriaRetinanet_Motility_v2 import run_inference
from GenerateTrackingData import gen_tracking_data
from Tracking import track_bacteria
from TrackingAnalysis import analyse_tracking
from GenerateVideo import gen_tracking_video
def find_and_return_csv_files(folder_path, search_pattern):
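    """Return the CSV files in folder_path whose filenames start with search_pattern."""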
search_pattern = f"{folder_path}/{search_pattern}*.csv"
csv_files = list(glob.glob(search_pattern))
return csv_files
def read_video(video, raw_frame_dir, progress=gr.Progress()):
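    """Extract grayscale frames from the input video into a randomly named
    sub-directory of raw_frame_dir and return that sub-directory's name."""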
# read video and save frames
video_dir = str(random.randint(111111111, 999999999))
images_dir = "Images without Labels"
frames_dir = os.path.join(raw_frame_dir, video_dir, images_dir)
os.makedirs(frames_dir, exist_ok=True)
count = 0
frames = []
cap = cv2.VideoCapture(video)
    total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)  # total frame count, guarded so the progress ratio below never divides by zero
processed_frames = 0
while cap.isOpened():
ret, frame = cap.read()
        if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
frame_path = os.path.join(frames_dir, f"{count}.jpg")
cv2.imwrite(frame_path, frame)
frames.append(frame)
count += 1
processed_frames += 1
print(f"Processing frame {processed_frames}")
progress(processed_frames / total_frames, desc=f"Reading frame {processed_frames}/{total_frames}")
cap.release()
return video_dir
def download_and_unzip_google_drive_file(file_id, output_path, unzip_path):
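    """Download a zip archive from Google Drive by file id and extract it into unzip_path."""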
    # Use the direct-download endpoint so gdown can fetch the archive without the large-file confirmation page
    url = f"https://drive.usercontent.google.com/download?id={file_id}&export=download&confirm=t"
    gdown.download(url, output_path, quiet=False)
with zipfile.ZipFile(output_path, 'r') as zip_ref:
zip_ref.extractall(unzip_path)
# @spaces.GPU()
def doo(video, progress=gr.Progress()):
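    """Run the MEMTrack pipeline on an uploaded video: download model weights, extract
    frames, generate features, run per-motility bacteria detection, track the detections,
    and return the tracked video along with the tracking CSV files."""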
# download and unzip models
file_id = '1agsLD5HV_VmDNpDhjHXTCAVmGUm2IQ6p'
output_path = 'models.zip'
unzip_path = './'
download_and_unzip_google_drive_file(file_id, output_path, unzip_path)
# Initialize paths and variables
    raw_frame_dir = "raw_data/"  # Path to raw videos before processing (same format as sample data)
final_data_dir = "data" # root directory to store processed videos
out_sub_dir = "bacteria" # sub directory to store processed videos
target_data_sub_dir = os.path.join(final_data_dir, out_sub_dir)
    feature_dir = "DataFeatures"  # directory to store generated features
test_video_list = ["video1"] # list of videos to generate features for
exp_name = "collagen_motility_inference" # name of experiment
feature_data_path = os.path.join(feature_dir, exp_name)
    # Paths to saved models
    no_motility_model_path = "models/motility/no/collagen_optical_flow_median_bkg_more_data_90k/"
    low_motility_model_path = "models/motility/low/collagen_optical_flow_median_bkg_more_data_90k/"
    mid_motility_model_path = "models/motility/mid/collagen_optical_flow_median_bkg_more_data_90k/"
    high_motility_model_path = "models/motility/high/collagen_optical_flow_median_bkg_more_data_90k/"

    # Clear previous results and data
    if os.path.exists(final_data_dir):
        shutil.rmtree(final_data_dir)
    if os.path.exists(raw_frame_dir):
        shutil.rmtree(raw_frame_dir)
    if os.path.exists(feature_dir):
        shutil.rmtree(feature_dir)

    # Read video and store frames separately for the object detection model
    video_dir = read_video(video, raw_frame_dir, progress=progress)

    # Process raw frames and store them in the expected format
    progress(1 / 3, desc="Processing Frames 1/3")
    video_num = process_data(video_dir, raw_frame_dir, final_data_dir, out_sub_dir)
    progress(3 / 3, desc="Processing Frames 3/3")

    # Generate features from the raw frames for the object detector model
    progress(1 / 3, desc="Generating Features 1/3")
    create_test_data(target_data_sub_dir, feature_dir, exp_name, test_video_list)
    progress(3 / 3, desc="Features Generated 3/3")

    # Run object detection for each motility class
    progress(1 / 3, desc="Loading Models 1/3")
    for video_num in [1]:
        # Generate testing files for all motilities
        run_inference(video_num=video_num, output_dir=no_motility_model_path,
                      annotations_test="All", test_dir=feature_data_path, register_dataset=True)
        progress(3 / 3, desc="Models Loaded 3/3")
        run_inference(video_num=video_num, output_dir=mid_motility_model_path,
                      annotations_test="Motility-mid", test_dir=feature_data_path, register_dataset=False)
        progress(1 / 3, desc="Running Bacteria Detection 1/3")
        run_inference(video_num=video_num, output_dir=high_motility_model_path,
                      annotations_test="Motility-high", test_dir=feature_data_path, register_dataset=False)
        progress(2 / 3, desc="Running Bacteria Detection 2/3")
        run_inference(video_num=video_num, output_dir=low_motility_model_path,
                      annotations_test="Motility-low", test_dir=feature_data_path, register_dataset=False)
        progress(3 / 3, desc="Running Bacteria Detection 3/3")

    # Tracking
    progress(0 / 3, desc="Tracking 0/3")
    for video_num in [1]:
        gen_tracking_data(video_num=video_num, data_path=feature_data_path, filter_thresh=0.3)
        progress(1 / 3, desc="Tracking 1/3")
        track_bacteria(video_num=video_num, max_age=35, max_interpolation=35, data_path=feature_data_path)
        progress(2 / 3, desc="Tracking 2/3")
        folder_path = analyse_tracking(video_num=video_num, data_feature_path=feature_data_path, data_root_path=final_data_dir, plot=True)
        progress(3 / 3, desc="Tracking 3/3")
output_video = gen_tracking_video(video_num=video_num, fps=60, data_path=feature_data_path)
final_video = os.path.basename(output_video)
shutil.copy(output_video, final_video)
print(output_video)
print(final_video)
search_pattern = "TrackedRawData"
tracking_preds = find_and_return_csv_files(folder_path, search_pattern)
return final_video, tracking_preds #str(tmpname) + '.mp4'
examples = [['./sample_videos/control_4_h264.mp4']]
title = "🎞️ MEMTrack Bacteria Tracking Video Tool"
description = "Upload a video or select an example to track. <br><br> If the input video does not play in the browser, make sure it is in a browser-compatible format; the output is generated regardless of in-browser playback. Refer: https://colab.research.google.com/drive/1U5pX_9iaR_T8knVV7o4ftKdDoGndCdEM?usp=sharing"
iface = gr.Interface(
fn=doo,
inputs=gr.Video(label="Input Video"),
outputs=[
gr.Video(label="Tracked Video"),
gr.File(label="CSV Data")
],
examples=examples,
title=title,
description=description
)
if __name__ == "__main__":
iface.launch(share=True)