awacke1 commited on
Commit
b0d4a94
·
verified ·
1 Parent(s): 67a63df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -175
app.py CHANGED
@@ -36,9 +36,7 @@ import tempfile
36
  from PIL import Image
37
  import io
38
  import requests
39
-
40
-
41
-
42
 
43
 
44
 
@@ -814,136 +812,164 @@ def get_video_html(video_path, width="100%"):
814
 
815
  # *********
816
 
817
- def resize_image_for_video(image_data, max_size=(1024, 1024)):
818
- """Resize image to be compatible with video generation"""
819
  try:
 
 
820
  # Convert bytes to PIL Image if needed
821
  if isinstance(image_data, bytes):
822
  img = Image.open(io.BytesIO(image_data))
823
  elif isinstance(image_data, Image.Image):
824
  img = image_data
825
  else:
826
- raise ValueError("Unsupported image data type")
827
-
 
 
828
  # Convert to RGB if necessary
829
  if img.mode != 'RGB':
 
830
  img = img.convert('RGB')
831
-
832
- # Calculate new size maintaining aspect ratio
833
- ratio = min(max_size[0] / img.size[0], max_size[1] / img.size[1])
834
- new_size = tuple(int(dim * ratio) for dim in img.size)
835
 
836
- # Resize image
837
- resized_img = img.resize(new_size, Image.Resampling.LANCZOS)
838
- return resized_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
  except Exception as e:
840
- st.error(f"Error resizing image: {str(e)}")
841
  return None
842
 
843
  def generate_video_from_image(image_data, seed=None, motion_bucket_id=127, fps_id=6, max_retries=3):
844
- """Generate video from image using Stable Video Diffusion with improved error handling"""
845
- temp_files = [] # Keep track of temporary files
846
  try:
847
- # Create progress bar
848
  progress_bar = st.progress(0)
849
  status_text = st.empty()
850
 
851
- status_text.text("Preparing image...")
 
852
  progress_bar.progress(10)
853
-
854
- # Resize image
855
- resized_img = resize_image_for_video(image_data)
856
- if resized_img is None:
857
  return None, None
858
-
859
- # Save resized image to temporary file
 
 
 
 
860
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
861
  temp_files.append(temp_img.name)
862
- resized_img.save(temp_img.name, format='PNG')
863
-
 
 
 
 
 
864
  status_text.text("Connecting to video generation service...")
865
- progress_bar.progress(20)
866
-
867
- # Initialize the Gradio client with error handling
868
  client = Client(
869
  "awacke1/stable-video-diffusion",
870
- hf_token=os.environ.get("HUGGINGFACE_TOKEN") # Add your token if needed
871
  )
872
-
873
- # Get random seed if none provided
874
  if seed is None:
875
- try:
876
- seed = client.predict(api_name="/get_random_value")
877
- except Exception as e:
878
- st.warning(f"Could not get random seed, using default. Error: {str(e)}")
879
- seed = int(time.time()) # Use timestamp as fallback
880
-
881
- status_text.text("Generating video...")
882
  progress_bar.progress(40)
883
-
884
- # Attempt video generation with retries
885
- error = None
886
  for attempt in range(max_retries):
887
  try:
888
  status_text.text(f"Generating video (attempt {attempt + 1}/{max_retries})...")
889
  progress_bar.progress(40 + (attempt * 20))
890
-
891
- # First try to resize the image using the API
892
- try:
893
- resized_result = client.predict(
894
- image=temp_img.name,
895
- api_name="/resize_image"
896
- )
897
- if resized_result:
898
- temp_files.append(resized_result)
899
- input_image = resized_result
900
- else:
901
- input_image = temp_img.name
902
- except Exception as e:
903
- st.warning(f"Image resize API failed, using original image. Error: {str(e)}")
904
- input_image = temp_img.name
905
-
906
- # Generate video
907
  result = client.predict(
908
- image=input_image,
909
  seed=seed,
910
- randomize_seed=True,
911
  motion_bucket_id=motion_bucket_id,
912
  fps_id=fps_id,
913
  api_name="/video"
914
  )
915
-
 
916
  if result and isinstance(result, tuple) and len(result) >= 1:
917
- video_path = result[0].get('video') if isinstance(result[0], dict) else None
918
- if video_path:
919
- status_text.text("Video generated successfully!")
920
- progress_bar.progress(100)
921
- return video_path, seed
922
-
923
- error = f"Invalid result format on attempt {attempt + 1}"
 
924
  time.sleep(2 ** attempt) # Exponential backoff
 
925
  except Exception as e:
926
- error = str(e)
927
- st.warning(f"Attempt {attempt + 1} failed: {error}")
928
- time.sleep(2 ** attempt) # Exponential backoff
929
-
930
- raise Exception(f"Failed after {max_retries} attempts. Last error: {error}")
931
-
932
  except Exception as e:
933
  st.error(f"Error in video generation: {str(e)}")
934
  return None, None
 
935
  finally:
936
- # Cleanup temporary files
937
  for temp_file in temp_files:
938
  try:
939
  if os.path.exists(temp_file):
940
  os.unlink(temp_file)
 
941
  except Exception as e:
942
- st.warning(f"Error cleaning up temporary file {temp_file}: {str(e)}")
943
 
944
- # Add this to your main Streamlit interface, in the appropriate section:
945
  def add_video_generation_ui(container):
946
- """Add video generation UI components"""
947
  st.markdown("### 🎥 Video Generation")
948
 
949
  col1, col2 = st.columns([2, 1])
@@ -952,120 +978,99 @@ def add_video_generation_ui(container):
952
  uploaded_image = st.file_uploader(
953
  "Upload Image for Video Generation 🖼️",
954
  type=['png', 'jpg', 'jpeg'],
955
- help="Upload an image to generate a video from"
956
  )
957
 
958
  with col2:
959
- st.markdown("#### Parameters")
960
  motion_bucket_id = st.slider(
961
  "Motion Intensity 🌊",
962
  min_value=1,
963
  max_value=255,
964
  value=127,
965
- help="Controls the amount of motion in the generated video"
966
  )
967
  fps_id = st.slider(
968
  "Frames per Second 🎬",
969
  min_value=1,
970
  max_value=30,
971
  value=6,
972
- help="Controls the smoothness of the generated video"
973
  )
 
 
 
 
 
 
 
 
974
 
975
  if uploaded_image:
976
- st.image(uploaded_image, caption="Preview of uploaded image", use_column_width=True)
977
-
978
- if st.button("🎥 Generate Video", help="Click to start video generation"):
979
- with st.spinner("Processing your video... This may take a few minutes 🎬"):
980
- image_bytes = uploaded_image.read()
981
- video_path, used_seed = generate_video_from_image(
982
- image_bytes,
983
- motion_bucket_id=motion_bucket_id,
984
- fps_id=fps_id
985
- )
986
-
987
- if video_path:
988
- # Save video locally
989
- video_filename = f"generated_video_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
990
- try:
991
- shutil.copy(video_path, video_filename)
992
-
993
- # Display the generated video
994
- st.success(f"Video generated successfully! Seed: {used_seed}")
995
- st.video(video_filename)
996
-
997
- # Save to Cosmos DB
998
- if container:
999
- video_record = {
1000
- "id": generate_unique_id(),
1001
- "type": "generated_video",
1002
- "filename": video_filename,
1003
- "seed": used_seed,
1004
- "motion_bucket_id": motion_bucket_id,
1005
- "fps_id": fps_id,
1006
- "timestamp": datetime.now().isoformat()
1007
- }
1008
- success, message = insert_record(container, video_record)
1009
- if success:
1010
- st.success("Video record saved to database!")
1011
- else:
1012
- st.error(f"Error saving video record: {message}")
1013
- except Exception as e:
1014
- st.error(f"Error saving video: {str(e)}")
1015
- else:
1016
- st.error("Failed to generate video. Please try again with different parameters.")
1017
-
1018
-
1019
- # Add this to the 'Show as Run AI' section in your main function,
1020
- # right after the "🤖 Run AI" button:
1021
-
1022
- # Add image upload and video generation
1023
- st.image_uploader = st.file_uploader("Upload Image for Video Generation 🖼️", type=['png', 'jpg', 'jpeg'])
1024
- st.video_gen_params = {
1025
- 'motion_bucket_id': st.slider("Motion Intensity 🌊", 1, 255, 127),
1026
- 'fps_id': st.slider("Frames per Second 🎬", 1, 30, 6)
1027
- }
1028
-
1029
- if st.image_uploader is not None:
1030
- if st.button("🎥 Generate Video"):
1031
- with st.spinner("Generating video... 🎬"):
1032
- # Read uploaded image
1033
- image_bytes = st.image_uploader.read()
1034
-
1035
- # Generate video
1036
- video_path, used_seed = generate_video_from_image(
1037
- image_bytes,
1038
- motion_bucket_id=st.video_gen_params['motion_bucket_id'],
1039
- fps_id=st.video_gen_params['fps_id']
1040
- )
1041
 
1042
- if video_path:
1043
- # Save video to local storage
1044
- video_filename = f"generated_video_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
1045
- shutil.copy(video_path, video_filename)
1046
-
1047
- st.success(f"Video generated successfully! Seed used: {used_seed}")
1048
-
1049
- # Display the generated video
1050
- st.video(video_filename)
1051
-
1052
- # Save to Cosmos DB if needed
1053
- if container:
1054
- video_record = {
1055
- "id": generate_unique_id(),
1056
- "type": "generated_video",
1057
- "filename": video_filename,
1058
- "seed": used_seed,
1059
- "motion_bucket_id": st.video_gen_params['motion_bucket_id'],
1060
- "fps_id": st.video_gen_params['fps_id'],
1061
- "timestamp": datetime.now().isoformat()
1062
- }
1063
- success, message = insert_record(container, video_record)
1064
- if success:
1065
- st.success("Video record saved to database")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1066
  else:
1067
- st.error(f"Error saving video record: {message}")
1068
-
 
1069
 
1070
  # ******************************************
1071
 
 
36
  from PIL import Image
37
  import io
38
  import requests
39
+ import numpy as np
 
 
40
 
41
 
42
 
 
812
 
813
  # *********
814
 
815
+ def validate_and_preprocess_image(image_data, target_size=(576, 1024)):
816
+ """Validate and preprocess image for video generation with detailed logging"""
817
  try:
818
+ st.write("Starting image preprocessing...")
819
+
820
  # Convert bytes to PIL Image if needed
821
  if isinstance(image_data, bytes):
822
  img = Image.open(io.BytesIO(image_data))
823
  elif isinstance(image_data, Image.Image):
824
  img = image_data
825
  else:
826
+ raise ValueError(f"Unsupported image data type: {type(image_data)}")
827
+
828
+ st.write(f"Original image size: {img.size}, mode: {img.mode}")
829
+
830
  # Convert to RGB if necessary
831
  if img.mode != 'RGB':
832
+ st.write(f"Converting image from {img.mode} to RGB")
833
  img = img.convert('RGB')
 
 
 
 
834
 
835
+ # Calculate aspect ratio
836
+ aspect_ratio = img.size[0] / img.size[1]
837
+ st.write(f"Original aspect ratio: {aspect_ratio:.2f}")
838
+
839
+ # Determine target dimensions maintaining aspect ratio
840
+ if aspect_ratio > target_size[0]/target_size[1]: # Wider than target
841
+ new_width = target_size[0]
842
+ new_height = int(new_width / aspect_ratio)
843
+ else: # Taller than target
844
+ new_height = target_size[1]
845
+ new_width = int(new_height * aspect_ratio)
846
+
847
+ # Ensure dimensions are even numbers
848
+ new_width = (new_width // 2) * 2
849
+ new_height = (new_height // 2) * 2
850
+
851
+ st.write(f"Resizing to: {new_width}x{new_height}")
852
+
853
+ # Resize image using high-quality downsampling
854
+ resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
855
+
856
+ # Create white background image of target size
857
+ final_img = Image.new('RGB', target_size, (255, 255, 255))
858
+
859
+ # Calculate position to paste resized image (center)
860
+ paste_x = (target_size[0] - new_width) // 2
861
+ paste_y = (target_size[1] - new_height) // 2
862
+
863
+ # Paste resized image onto white background
864
+ final_img.paste(resized_img, (paste_x, paste_y))
865
+
866
+ st.write(f"Final image size: {final_img.size}")
867
+
868
+ # Validate final image
869
+ if final_img.size != target_size:
870
+ raise ValueError(f"Final image size {final_img.size} doesn't match target size {target_size}")
871
+
872
+ return final_img
873
+
874
  except Exception as e:
875
+ st.error(f"Error in image preprocessing: {str(e)}")
876
  return None
877
 
878
  def generate_video_from_image(image_data, seed=None, motion_bucket_id=127, fps_id=6, max_retries=3):
879
+ """Generate video from image with improved preprocessing and error handling"""
880
+ temp_files = []
881
  try:
882
+ # Set up progress tracking
883
  progress_bar = st.progress(0)
884
  status_text = st.empty()
885
 
886
+ # Preprocess image
887
+ status_text.text("Preprocessing image...")
888
  progress_bar.progress(10)
889
+
890
+ processed_img = validate_and_preprocess_image(image_data)
891
+ if processed_img is None:
892
+ st.error("Image preprocessing failed")
893
  return None, None
894
+
895
+ # Show preprocessed image
896
+ st.write("Preprocessed image preview:")
897
+ st.image(processed_img, caption="Preprocessed image", use_column_width=True)
898
+
899
+ # Save processed image
900
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
901
  temp_files.append(temp_img.name)
902
+ processed_img.save(temp_img.name, format='PNG', optimize=True)
903
+ st.write(f"Saved preprocessed image to: {temp_img.name}")
904
+
905
+ # Verify file size
906
+ file_size = os.path.getsize(temp_img.name)
907
+ st.write(f"Preprocessed image file size: {file_size/1024:.2f}KB")
908
+
909
  status_text.text("Connecting to video generation service...")
910
+ progress_bar.progress(30)
911
+
912
+ # Initialize client with debug flags
913
  client = Client(
914
  "awacke1/stable-video-diffusion",
915
+ hf_token=os.environ.get("HUGGINGFACE_TOKEN"),
916
  )
917
+
 
918
  if seed is None:
919
+ seed = int(time.time() * 1000) # Use millisecond timestamp as seed
920
+
921
+ status_text.text("Starting video generation...")
 
 
 
 
922
  progress_bar.progress(40)
923
+
 
 
924
  for attempt in range(max_retries):
925
  try:
926
  status_text.text(f"Generating video (attempt {attempt + 1}/{max_retries})...")
927
  progress_bar.progress(40 + (attempt * 20))
928
+
929
+ # Call video generation API
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  result = client.predict(
931
+ image=temp_img.name,
932
  seed=seed,
933
+ randomize_seed=False, # Set to False for reproducibility
934
  motion_bucket_id=motion_bucket_id,
935
  fps_id=fps_id,
936
  api_name="/video"
937
  )
938
+
939
+ # Validate result
940
  if result and isinstance(result, tuple) and len(result) >= 1:
941
+ if isinstance(result[0], dict) and 'video' in result[0]:
942
+ video_path = result[0]['video']
943
+ if os.path.exists(video_path):
944
+ status_text.text("Video generated successfully!")
945
+ progress_bar.progress(100)
946
+ return video_path, seed
947
+
948
+ st.warning(f"Invalid result format on attempt {attempt + 1}: {result}")
949
  time.sleep(2 ** attempt) # Exponential backoff
950
+
951
  except Exception as e:
952
+ st.warning(f"Attempt {attempt + 1} failed: {str(e)}")
953
+ time.sleep(2 ** attempt)
954
+
955
+ raise Exception(f"Failed to generate video after {max_retries} attempts")
956
+
 
957
  except Exception as e:
958
  st.error(f"Error in video generation: {str(e)}")
959
  return None, None
960
+
961
  finally:
962
+ # Cleanup
963
  for temp_file in temp_files:
964
  try:
965
  if os.path.exists(temp_file):
966
  os.unlink(temp_file)
967
+ st.write(f"Cleaned up temporary file: {temp_file}")
968
  except Exception as e:
969
+ st.warning(f"Error cleaning up {temp_file}: {str(e)}")
970
 
 
971
  def add_video_generation_ui(container):
972
+ """Enhanced video generation UI with better error handling and feedback"""
973
  st.markdown("### 🎥 Video Generation")
974
 
975
  col1, col2 = st.columns([2, 1])
 
978
  uploaded_image = st.file_uploader(
979
  "Upload Image for Video Generation 🖼️",
980
  type=['png', 'jpg', 'jpeg'],
981
+ help="Upload a clear, well-lit image. Recommended size: 576x1024 pixels."
982
  )
983
 
984
  with col2:
985
+ st.markdown("#### Generation Parameters")
986
  motion_bucket_id = st.slider(
987
  "Motion Intensity 🌊",
988
  min_value=1,
989
  max_value=255,
990
  value=127,
991
+ help="Lower values create subtle movement, higher values create more dramatic motion"
992
  )
993
  fps_id = st.slider(
994
  "Frames per Second 🎬",
995
  min_value=1,
996
  max_value=30,
997
  value=6,
998
+ help="Higher values create smoother but potentially less stable videos"
999
  )
1000
+
1001
+ # Add advanced options in an expander
1002
+ with st.expander("Advanced Options"):
1003
+ use_custom_seed = st.checkbox("Use Custom Seed")
1004
+ if use_custom_seed:
1005
+ seed = st.number_input("Seed Value", value=int(time.time() * 1000))
1006
+ else:
1007
+ seed = None
1008
 
1009
  if uploaded_image:
1010
+ try:
1011
+ # Preview original image
1012
+ preview_col1, preview_col2 = st.columns(2)
1013
+ with preview_col1:
1014
+ st.write("Original Image:")
1015
+ st.image(uploaded_image, caption="Original", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1016
 
1017
+ # Preview preprocessed image
1018
+ with preview_col2:
1019
+ preprocessed = validate_and_preprocess_image(uploaded_image.read())
1020
+ if preprocessed:
1021
+ st.write("Preprocessed Image:")
1022
+ st.image(preprocessed, caption="Preprocessed", use_column_width=True)
1023
+ except Exception as e:
1024
+ st.error(f"Error previewing image: {str(e)}")
1025
+
1026
+ if st.button("🎥 Generate Video", help="Start video generation process"):
1027
+ try:
1028
+ with st.spinner("Processing your video... This may take a few minutes 🎬"):
1029
+ video_path, used_seed = generate_video_from_image(
1030
+ uploaded_image.read(),
1031
+ seed=seed,
1032
+ motion_bucket_id=motion_bucket_id,
1033
+ fps_id=fps_id
1034
+ )
1035
+
1036
+ if video_path and os.path.exists(video_path):
1037
+ # Save video locally
1038
+ video_filename = f"generated_video_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
1039
+ try:
1040
+ shutil.copy(video_path, video_filename)
1041
+
1042
+ # Display success and video
1043
+ st.success(f"""
1044
+ Video generated successfully! 🎉
1045
+ - Seed: {used_seed}
1046
+ - Motion Intensity: {motion_bucket_id}
1047
+ - FPS: {fps_id}
1048
+ """)
1049
+
1050
+ st.video(video_filename)
1051
+
1052
+ # Save to Cosmos DB
1053
+ if container:
1054
+ video_record = {
1055
+ "id": generate_unique_id(),
1056
+ "type": "generated_video",
1057
+ "filename": video_filename,
1058
+ "seed": used_seed,
1059
+ "motion_bucket_id": motion_bucket_id,
1060
+ "fps_id": fps_id,
1061
+ "timestamp": datetime.now().isoformat()
1062
+ }
1063
+ success, message = insert_record(container, video_record)
1064
+ if success:
1065
+ st.success("Video record saved to database!")
1066
+ else:
1067
+ st.error(f"Error saving video record: {message}")
1068
+ except Exception as e:
1069
+ st.error(f"Error saving video: {str(e)}")
1070
  else:
1071
+ st.error("Video generation failed. Please try again with different parameters.")
1072
+ except Exception as e:
1073
+ st.error(f"Error during video generation process: {str(e)}")
1074
 
1075
  # ******************************************
1076