# WASP / app.py
# Last commit by amanpreet7: "Update app.py" (f1b121c, verified)
import streamlit as st
import requests
import tempfile
import shutil
from kaggle.api.kaggle_api_extended import KaggleApi
import os
import time
import json
import subprocess
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
from googleapiclient.errors import HttpError
def setup_kaggle_api():
    """Create and authenticate a Kaggle API client from environment credentials.

    Requires ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` to be set in the
    environment. On any failure an error is shown in the Streamlit UI and the
    script run is halted with ``st.stop()``.

    Returns:
        An authenticated ``KaggleApi`` instance.
    """
    try:
        # Check explicitly so a missing credential is reported by the KeyError
        # handler below with the variable's name, before the client is built.
        for var_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
            if var_name not in os.environ:
                raise KeyError(var_name)
        api = KaggleApi()
        api.authenticate()
        return api
    except KeyError as e:
        st.error(f"Missing environment variable: {e}")
        st.stop()
    except Exception as e:
        st.error(f" API setup failed: {e}")
        st.stop()
def setup_drive_service():
    """Build an authenticated Google Drive v3 service from OAuth user credentials.

    The ``GOOGLE_SERVICE_ACCOUNT`` environment variable must contain a JSON
    object with ``access_token``, ``refresh_token``, ``client_id``,
    ``client_secret``, ``token_uri`` and ``scopes``. Despite the variable's
    name, these are OAuth *user* credentials (``google.oauth2.credentials``),
    not a service-account key. On any failure an error is shown in the
    Streamlit UI and the script run is halted with ``st.stop()``.

    Returns:
        A ``googleapiclient`` Drive v3 resource.
    """
    try:
        credentials_json = os.environ.get("GOOGLE_SERVICE_ACCOUNT")
        if credentials_json is None:
            # Report a missing variable as such instead of letting
            # json.loads(None) fail with an unrelated TypeError.
            raise KeyError("GOOGLE_SERVICE_ACCOUNT")
        credentials_dict = json.loads(credentials_json)
        credentials = Credentials(
            token=credentials_dict["access_token"],
            refresh_token=credentials_dict["refresh_token"],
            client_id=credentials_dict["client_id"],
            client_secret=credentials_dict["client_secret"],
            token_uri=credentials_dict["token_uri"],
            scopes=credentials_dict["scopes"],
        )
        # NOTE(review): no expiry is passed to Credentials, so google-auth may
        # never report the token as expired here; confirm whether refreshing is
        # otherwise left to the HTTP layer.
        if credentials.expired and credentials.refresh_token:
            # refresh() needs a google-auth transport. The plain ``requests``
            # library's Request class is not callable as one.
            from google.auth.transport.requests import Request as GoogleAuthRequest
            credentials.refresh(GoogleAuthRequest())
        drive_service = build('drive', 'v3', credentials=credentials)
        return drive_service
    except KeyError as e:
        st.error(f"Missing GOOGLE_SERVICE_ACCOUNT environment variable or credential field: {e}")
        st.stop()
    except json.JSONDecodeError as e:
        st.error(f"Invalid JSON in GOOGLE_SERVICE_ACCOUNT: {e}")
        st.stop()
    except Exception as e:
        st.error(f"Google Drive auth failed: {e}")
        st.stop()
def upload_to_drive(drive_service, file_path, title, folder_id=None):
    """Upload a local file to a Drive folder and make it readable via its link.

    Args:
        drive_service: Drive v3 service (see ``setup_drive_service``).
        file_path: path of the local file to upload.
        title: name the file gets in Drive.
        folder_id: destination folder. When ``None``, the ``DRIVE_FOLDER_ID``
            environment variable is used, falling back to the app's default
            folder.

    Returns:
        ``(file_id, shareable_link)`` for the uploaded file.

    Raises:
        HttpError, Exception: re-raised after being reported in the UI.
    """
    if folder_id is None:
        folder_id = os.environ.get("DRIVE_FOLDER_ID", "1T6v7Iqc90-NA-F3I-HeHDSvEaIyFibKd")
    media = None
    try:
        file_metadata = {'name': title, 'parents': [folder_id]}
        media = MediaFileUpload(file_path, resumable=True)
        created = drive_service.files().create(
            body=file_metadata, media_body=media, fields='id, name'
        ).execute()
        file_id = created['id']
        # "anyone / reader" turns the file into a link-shareable, read-only file.
        drive_service.permissions().create(
            fileId=file_id,
            body={'type': 'anyone', 'role': 'reader'},
            fields='id'
        ).execute()
        shareable_link = f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
        return file_id, shareable_link
    except HttpError as e:
        st.error(f"Drive upload failed: {e}")
        raise
    except Exception as e:
        st.error(f"Unexpected error during upload: {e}")
        raise
    finally:
        # MediaFileUpload opens the source file. Close it deterministically so
        # the caller can delete the local file even if the upload failed.
        if media is not None:
            media.stream().close()
def delete_from_drive(drive_service, file_id):
    """Best-effort deletion of a Drive file.

    Failures are shown in the UI as an error and otherwise swallowed, so
    callers can use this during cleanup without extra handling.
    """
    try:
        files_resource = drive_service.files()
        delete_request = files_resource.delete(fileId=file_id)
        delete_request.execute()
    except Exception as err:
        st.error(f"Failed to delete file {file_id} from Drive: {err}")
def get_bvh_from_folder(drive_service, folder_id=None):
    """Find a ``.bvh`` file in the Drive working folder.

    Args:
        drive_service: Drive v3 service.
        folder_id: folder to search. When ``None``, the ``DRIVE_FOLDER_ID``
            environment variable is used, falling back to the app's default
            folder.

    Returns:
        ``(file_id, download_url, file_name)`` of the most recently created,
        non-trashed match. Returns ``(None, None, None)`` when there is no match
        or the lookup failed; a failure is also reported in the UI.
    """
    if folder_id is None:
        folder_id = os.environ.get("DRIVE_FOLDER_ID", "1T6v7Iqc90-NA-F3I-HeHDSvEaIyFibKd")
    try:
        # Newest first, so a leftover file from an earlier run does not win
        # over a freshly produced one. Trashed files are ignored.
        query = f"'{folder_id}' in parents and name contains '.bvh' and trashed = false"
        response = drive_service.files().list(
            q=query,
            fields="files(id, name, mimeType)",
            orderBy="createdTime desc",
            pageSize=1,
        ).execute()
        files = response.get('files', [])
        if not files:
            return None, None, None
        bvh_file = files[0]
        bvh_id = bvh_file['id']
        bvh_url = f"https://drive.google.com/uc?id={bvh_id}"
        return bvh_id, bvh_url, bvh_file['name']
    except Exception as e:
        st.error(f"Error checking folder for BVH: {e}")
        return None, None, None
def download_notebook_from_drive(drive_service, temp_dir):
    """Download the conversion notebook from Drive and pin a Python 3 kernel spec.

    The Drive file id comes from the ``NOTEBOOK_FILE_ID`` environment
    variable. The notebook's ``metadata.kernelspec`` and
    ``metadata.language_info`` are overwritten with fixed Python 3 values
    before the notebook is written to
    ``<temp_dir>/video-to-bvh-converter.ipynb``.

    Returns:
        Path of the written notebook.

    Raises:
        KeyError: if ``NOTEBOOK_FILE_ID`` is not set.
        Exception: any download/parse/write error, re-raised after being
            reported in the UI.
    """
    notebook_file_id = os.environ.get("NOTEBOOK_FILE_ID")
    if not notebook_file_id:
        st.error("NOTEBOOK_FILE_ID environment variable not set")
        raise KeyError("NOTEBOOK_FILE_ID not set")
    try:
        raw_notebook = drive_service.files().get_media(fileId=notebook_file_id).execute()
        # json.loads accepts the downloaded bytes directly (UTF-8 detected),
        # so there is no need to round-trip through a file first.
        notebook_content = json.loads(raw_notebook)
        metadata = notebook_content.setdefault('metadata', {})
        metadata['kernelspec'] = {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        }
        metadata['language_info'] = {
            "name": "python",
            "version": "3.10.12",
            "mimetype": "text/x-python",
            "codemirror_mode": {"name": "ipython", "version": 3},
            "pygments_lexer": "ipython3",
            "nbconvert_exporter": "python",
            "file_extension": ".py"
        }
        notebook_path = os.path.join(temp_dir, 'video-to-bvh-converter.ipynb')
        # Notebooks are UTF-8 JSON. Be explicit instead of relying on the
        # platform's default text encoding.
        with open(notebook_path, 'w', encoding='utf-8') as f:
            json.dump(notebook_content, f)
        return notebook_path
    except Exception as e:
        st.error(f"Failed to download notebook from Drive: {e}")
        raise
def push_kaggle_kernel(api, temp_dir, notebook_slug):
    """Push the conversion notebook to Kaggle as a new kernel version.

    Downloads the notebook from Drive into *temp_dir*, copies it to
    ``kernel.ipynb``, writes ``kernel-metadata.json`` next to it and runs
    ``kaggle kernels push -p <temp_dir>``.

    Args:
        api: Kaggle API client. It is not used here, because the push goes
            through the ``kaggle`` CLI.
        temp_dir: scratch directory used as the push folder.
        notebook_slug: ``"<username>/<kernel-slug>"`` identifying the kernel.

    Returns:
        URL of the kernel on kaggle.com.

    Raises:
        Exception: if the download or the push fails. The failure is reported
            in the UI first.
    """
    try:
        drive_service = setup_drive_service()
        local_notebook_path = download_notebook_from_drive(drive_service, temp_dir)
        code_file = "kernel.ipynb"
        kernel_file = os.path.join(temp_dir, code_file)
        shutil.copy(local_notebook_path, kernel_file)
        metadata = {
            "id": notebook_slug,
            "title": "video-to-bvh-converter",
            "code_file": code_file,
            "language": "python",
            "kernel_type": "notebook",
            "enable_gpu": True,
            "enable_internet": True,
            "is_private": True,
            "accelerator": "gpu",
            "gpu_product": "T4x2",
            "competition_sources": [],
            "dataset_sources": ["amanu1234/pipeline"],  # dataset mounted into the kernel
            "kernel_sources": []
        }
        metadata_file = os.path.join(temp_dir, 'kernel-metadata.json')
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f)
        # Argument list (no shell): temp_dir is passed verbatim, so paths with
        # spaces or shell metacharacters cannot break or alter the command.
        result = subprocess.run(
            ["kaggle", "kernels", "push", "-p", temp_dir],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            st.error(f"Kernel push failed: {result.stderr}")
            raise Exception(f"Push failed: {result.stderr}")
        os.remove(local_notebook_path)
        os.remove(kernel_file)
        return f"https://www.kaggle.com/code/{notebook_slug}"
    except Exception as e:
        st.error(f"Failed to push kernel: {str(e)}")
        raise
def check_kernel_exists(api, notebook_slug):
    """Tell whether the authenticated user owns a kernel whose ref is *notebook_slug*.

    Any API failure is reported in the UI and treated as "not found".
    """
    try:
        candidates = api.kernels_list(mine=True, search=notebook_slug)
        return any(candidate.ref == notebook_slug for candidate in candidates)
    except Exception as err:
        st.error(f"Kernel check failed: {err}")
        return False
def download_and_save_bvh(bvh_url, filename, timeout=60):
    """Download a BVH file over HTTP and save it in a fresh temporary directory.

    Args:
        bvh_url: direct-download URL of the file.
        filename: name for the local copy. Only its final path component is
            used, so a name containing directory parts (e.g. ``../x.bvh``)
            cannot escape the temp directory.
        timeout: seconds ``requests`` waits for the server (connect and read)
            before giving up. Without it a stalled server would block forever.

    Returns:
        ``(local_path, content_bytes)`` on HTTP 200, otherwise ``(None, None)``.
        Failures are reported in the UI.
    """
    try:
        response = requests.get(bvh_url, timeout=timeout)
        if response.status_code != 200:
            st.error(f"Failed to download BVH: Status code {response.status_code}")
            return None, None
        # NOTE(review): a Drive "uc" link to a file that is not publicly shared
        # may still answer 200 with an HTML page; consider checking Content-Type.
        safe_name = os.path.basename(filename or "")
        if safe_name in ("", ".", ".."):
            safe_name = "motion_capture.bvh"
        temp_dir = tempfile.mkdtemp()
        bvh_path = os.path.join(temp_dir, safe_name)
        with open(bvh_path, 'wb') as f:
            f.write(response.content)
        return bvh_path, response.content
    except Exception as e:
        st.error(f"Error downloading BVH: {e}")
        return None, None
def _normalize_kernel_status(status_response):
    """Return the status from ``api.kernels_status`` as a lowercase string.

    NOTE(review): depending on the installed kaggle package version, the
    response is presumably either a dict (``{'status': 'running', ...}``) or an
    object whose ``status`` attribute is a string or an enum member. Confirm
    against the pinned version. Both shapes are handled here; anything
    unrecognised yields ``'unknown'``.
    """
    if isinstance(status_response, dict):
        status = status_response.get('status')
    else:
        status = getattr(status_response, 'status', None)
    if status is None:
        return 'unknown'
    # Enum members carry their symbolic name in ``.name`` (e.g. RUNNING).
    # Plain strings have no such attribute and are used as they are.
    return str(getattr(status, 'name', status)).lower()


def _timestamped_bvh_name(bvh_filename):
    """Return *bvh_filename* with ``_YYYYmmdd_HHMMSS`` inserted before its extension.

    Falls back to ``motion_capture_<timestamp>.bvh`` when *bvh_filename* is empty.
    """
    timestamp_str = time.strftime("%Y%m%d_%H%M%S")
    if not bvh_filename:
        return f"motion_capture_{timestamp_str}.bvh"
    base_name, ext = os.path.splitext(bvh_filename)
    return f"{base_name}_{timestamp_str}{ext}"


def process_video(api, drive_service, video_file):
    """Run the video -> BVH pipeline for one uploaded video.

    Steps:
    1. Write the upload to a local temp file and upload it to Drive as
       ``input_video.mp4``.
    2. Push the Kaggle conversion notebook.
    3. Every 10 s, poll the kernel status and the Drive folder until a BVH
       appears, the kernel fails, or 30 minutes pass without progress. The
       timer restarts after each re-push.

    On success the BVH bytes, filename, local path and a timestamp are stored
    in ``st.session_state`` and also returned. Drive working files and local
    temp files are removed on every exit path.

    Args:
        api: authenticated ``KaggleApi`` instance.
        drive_service: Google Drive v3 service.
        video_file: uploaded file object (e.g. a Streamlit ``UploadedFile``).

    Returns:
        A dict with ``bvh_data``, ``bvh_path``, ``bvh_filename`` and
        ``timestamp`` on success, otherwise ``None``.
    """
    video_file_id = None
    bvh_file_id = None
    video_path = None
    temp_dir = None
    try:
        # getvalue() returns the whole upload regardless of the current read
        # position. Fall back to read() for plain file-like objects.
        if hasattr(video_file, 'getvalue'):
            video_bytes = video_file.getvalue()
        else:
            video_bytes = video_file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
            video_path = tmp_file.name  # recorded first so cleanup covers a failed write
            tmp_file.write(video_bytes)
        video_file_id, video_shareable_link = upload_to_drive(drive_service, video_path, "input_video.mp4")
        st.success(f"Video uploaded to Drive: {video_shareable_link}")

        username = os.environ.get("KAGGLE_USERNAME")
        notebook_slug = f"{username}/video-to-bvh-converter"
        temp_dir = tempfile.mkdtemp()
        with st.spinner("Triggering..."):
            push_kaggle_kernel(api, temp_dir, notebook_slug)

        progress_bar = st.progress(0.0)
        progress_text = st.empty()
        with st.spinner("Waiting for video processing..."):
            start_time = time.time()
            execution_started = False
            retry_count = 0
            max_retries = 3
            overall_timeout = 1800  # seconds; restarted after every re-push
            while time.time() - start_time < overall_timeout:
                try:
                    current_status = _normalize_kernel_status(api.kernels_status(notebook_slug))
                    if current_status == 'queued':
                        execution_started = True
                        progress_bar.progress(0.2)
                        progress_text.text("Queued - Waiting for GPU...")
                    elif current_status == 'running':
                        execution_started = True
                        progress_bar.progress(0.4)
                        progress_text.text("Processing video...")

                    # Look for the result before acting on a finished/failed
                    # status, so an available BVH is never discarded by a re-push.
                    bvh_file_id, bvh_url, bvh_filename = get_bvh_from_folder(drive_service)
                    if bvh_url and bvh_filename:
                        bvh_path, bvh_data = download_and_save_bvh(bvh_url, bvh_filename)
                        if bvh_path and bvh_data:
                            timestamped_filename = _timestamped_bvh_name(bvh_filename)
                            progress_bar.progress(1.0)
                            progress_text.text("Complete!")
                            st.success("Motion capture complete! BVH file ready for download.")
                            bvh_timestamp = int(time.time())
                            st.session_state['bvh_data'] = bvh_data
                            st.session_state['bvh_filename'] = timestamped_filename
                            st.session_state['bvh_path'] = bvh_path  # local copy, deleted later by the UI
                            st.session_state['bvh_timestamp'] = bvh_timestamp
                            return {
                                'bvh_data': bvh_data,
                                'bvh_path': bvh_path,
                                'bvh_filename': timestamped_filename,
                                'timestamp': bvh_timestamp,
                            }

                    if current_status in ('complete', 'error') and not execution_started:
                        # Finished or failed before the kernel was ever seen
                        # queued/running: the push presumably did not start a
                        # fresh run (or the status is stale), so push again.
                        # The number of re-pushes is bounded, because every push
                        # creates a new kernel version.
                        if retry_count >= max_retries:
                            st.error(
                                f"Kaggle kernel did not start after {max_retries} retries "
                                f"(last status: {current_status})."
                            )
                            return None
                        retry_count += 1  # counted before pushing so failing pushes also count
                        time.sleep(10)
                        push_kaggle_kernel(api, temp_dir, notebook_slug)
                        start_time = time.time()
                        continue

                    if execution_started and current_status == 'complete':
                        progress_bar.progress(0.8)
                        progress_text.text("Finalizing...")
                    elif execution_started and current_status == 'error':
                        progress_bar.progress(0.6)
                        progress_text.text("Error occurred...")
                        return None
                    time.sleep(10)
                except Exception as e:
                    st.error(f"Status check failed: {str(e)}")
                    time.sleep(10)
        st.error("Timed out waiting for the BVH file.")
        return None
    except Exception as e:
        st.error(f"Processing error: {e}")
        return None
    finally:
        # Single cleanup point for every exit path: success, failure, timeout,
        # or a Streamlit stop/rerun interrupting the run. Drive working files
        # are removed so they cannot be mistaken for the next run's input or
        # output, then the local temp files are removed.
        if video_file_id:
            delete_from_drive(drive_service, video_file_id)
        if bvh_file_id:
            delete_from_drive(drive_service, bvh_file_id)
        if video_path and os.path.exists(video_path):
            os.unlink(video_path)
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
# Dark-theme stylesheet injected at the top of the page.
_APP_CSS = """
<style>
:root {
--bg-color: #1a1a1a;
--card-bg: #252525;
--primary-color: #bb86fc;
--secondary-color: #03dac6;
--error-color: #cf6679;
--text-color: #e0e0e0;
--text-secondary: #a0a0a0;
}
.stApp { background-color: var(--bg-color); }
h1, h2, h3, p, div { color: var(--text-color) !important; }
.card { background-color: var(--card-bg); border-radius: 20px; padding: 2rem; margin: 1rem auto; max-width: 1200px; }
.main-title { font-size: 3.5rem; font-weight: 900; background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)); -webkit-background-clip: text; -webkit-text-fill-color: transparent; text-align: center; }
.subtitle { font-size: 1.3rem; color: var(--text-secondary); text-align: center; }
.section-title { font-size: 1.5rem; font-weight: 700; color: var(--primary-color) !important; }
.stButton > button { background: linear-gradient(135deg, var(--primary-color), #9b59f5); color: #fff !important; border-radius: 12px; padding: 0.8rem 2.5rem; font-weight: 600; font-size: 1.2rem; border: none; width: 100%; }
.stDownloadButton > button { background: linear-gradient(135deg, var(--secondary-color), #02b3a3); color: #fff !important; border-radius: 12px; padding: 0.8rem 2.5rem; font-weight: 600; font-size: 1.2rem; border: none; width: 100%; }
</style>
"""


def _render_service_status(api, drive_service):
    """Show, side by side, whether the Kaggle kernel exists and whether Drive is reachable."""
    status_col1, status_col2 = st.columns(2)
    with status_col1:
        try:
            username = os.environ.get("KAGGLE_USERNAME")
            notebook_slug = f"{username}/video-to-bvh-converter"
            if check_kernel_exists(api, notebook_slug):
                st.success("✅kernel found")
            else:
                st.error("❌kernel not found")
        except Exception as e:
            st.error(f"❌failed: {e}")
    with status_col2:
        try:
            # Authenticated call used purely as a connectivity check. Only a
            # failure is shown.
            drive_service.about().get(fields="user,storageQuota").execute()
        except Exception as e:
            st.error(f"❌Drive check failed: {e}")


def _clear_previous_result():
    """Forget the previous BVH result and delete its local file, if any."""
    for key in ('bvh_data', 'bvh_filename', 'bvh_timestamp'):
        st.session_state.pop(key, None)
    old_path = st.session_state.pop('bvh_path', None)
    if old_path and os.path.exists(old_path):
        try:
            os.remove(old_path)
        except Exception as e:
            st.warning(f"Could not delete previous local file: {e}")


def main():
    """Render the Motion Capture Studio page and drive the video -> BVH workflow."""
    st.set_page_config(
        page_title="Motion Capture Studio | Video to BVH Converter",
        page_icon="🎬",
        layout="wide",
        initial_sidebar_state="collapsed"
    )
    st.markdown(_APP_CSS, unsafe_allow_html=True)
    st.markdown('<h1 class="main-title">Motion Capture Studio</h1>', unsafe_allow_html=True)
    st.markdown('<p class="subtitle">Convert videos to BVH with AI</p>', unsafe_allow_html=True)
    st.markdown('<p class="section-title">(Note: Every Body part should be visible correctly and face should not be covered)</p>', unsafe_allow_html=True)

    api = setup_kaggle_api()
    drive_service = setup_drive_service()

    # NOTE(review): each st.markdown call renders as its own element, so this
    # opening <div> (and the closing one at the end) likely does not wrap the
    # widgets in between.
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h3 class="section-title"></h3>', unsafe_allow_html=True)
    _render_service_status(api, drive_service)

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('<h3 class="section-title">Upload Video</h3>', unsafe_allow_html=True)
        uploaded_file = st.file_uploader("Upload a video", type=['mp4', 'avi', 'mov'])
        if uploaded_file:
            st.session_state['uploaded_file'] = uploaded_file
        else:
            # The uploader is empty (e.g. the user removed the file). Drop the
            # cached copy so the old video is neither previewed nor processed.
            st.session_state.pop('uploaded_file', None)
    with col2:
        st.markdown('<h3 class="section-title">Preview</h3>', unsafe_allow_html=True)
        if preview_file := st.session_state.get('uploaded_file'):
            st.video(preview_file)

    if st.session_state.get('uploaded_file'):
        # Containers fix the on-page order: processing output first, then the
        # start button, then the download button.
        progress_container = st.container()
        button_container = st.container()
        download_container = st.container()

        with button_container:
            st.markdown('<h3 class="section-title">Processing Options</h3>', unsafe_allow_html=True)
            if st.button("Start Motion Capture", key="start_capture_button"):
                _clear_previous_result()
                with progress_container:
                    result = process_video(api, drive_service, st.session_state['uploaded_file'])
                    if not result or 'bvh_data' not in result:
                        st.error("Failed to generate BVH file.")

        # Offer the most recent result (from this run or an earlier one).
        with download_container:
            if 'bvh_data' in st.session_state and 'bvh_filename' in st.session_state:
                timestamp = st.session_state.get('bvh_timestamp', int(time.time()))

                def on_download_complete():
                    """Delete the local BVH copy once the download button is clicked."""
                    bvh_path = st.session_state.get('bvh_path')
                    if bvh_path and os.path.exists(bvh_path):
                        try:
                            os.remove(bvh_path)
                            st.session_state['bvh_path_deleted'] = True
                        except Exception as e:
                            st.warning(f"Failed to delete local BVH file: {e}")

                # st.download_button invokes on_click when clicked. The bytes
                # stay in session_state, so the button keeps working after the
                # local file is removed.
                st.download_button(
                    label="Download BVH",
                    data=st.session_state['bvh_data'],
                    file_name=st.session_state['bvh_filename'],
                    mime="application/octet-stream",
                    key=f"download_saved_{timestamp}",
                    on_click=on_download_complete
                )
    st.markdown('</div>', unsafe_allow_html=True)
# Entry point: render the app when this file is executed as the main script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()