awacke1 commited on
Commit
67a63df
·
verified ·
1 Parent(s): 43d9435

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -79
app.py CHANGED
@@ -35,6 +35,8 @@ from gradio_client import Client, handle_file
35
  import tempfile
36
  from PIL import Image
37
  import io
 
 
38
 
39
 
40
 
@@ -812,48 +814,207 @@ def get_video_html(video_path, width="100%"):
812
 
813
  # *********
814
 
 
 
 
 
 
 
 
 
 
 
815
 
 
 
 
816
 
 
 
 
 
 
 
 
 
 
 
817
 
818
-
819
- def generate_video_from_image(image_data, seed=None, motion_bucket_id=127, fps_id=6):
820
- """Generate video from image using Stable Video Diffusion"""
821
  try:
822
- # Create a temporary file to save the uploaded image
 
 
 
 
 
 
 
 
 
 
 
 
823
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
824
- # If image_data is a PIL Image, convert to bytes
825
- if isinstance(image_data, Image.Image):
826
- img_byte_arr = io.BytesIO()
827
- image_data.save(img_byte_arr, format='PNG')
828
- temp_img.write(img_byte_arr.getvalue())
829
- else:
830
- temp_img.write(image_data)
831
- temp_img_path = temp_img.name
832
 
833
- # Initialize the Gradio client
834
- client = Client("awacke1/stable-video-diffusion")
 
 
 
 
 
 
835
 
836
  # Get random seed if none provided
837
  if seed is None:
838
- seed = client.predict(api_name="/get_random_value")
839
-
840
- # Generate video
841
- result = client.predict(
842
- image=temp_img_path,
843
- seed=seed,
844
- randomize_seed=True,
845
- motion_bucket_id=motion_bucket_id,
846
- fps_id=fps_id,
847
- api_name="/video"
848
- )
849
 
850
- # result[0] contains Dict with video path
851
- video_path = result[0]['video']
852
- return video_path, result[1] # Return video path and used seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853
 
854
  except Exception as e:
855
- st.error(f"Error generating video: {str(e)}")
856
  return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
  # Add this to the 'Show as Run AI' section in your main function,
859
  # right after the "🤖 Run AI" button:
@@ -1180,56 +1341,6 @@ def main():
1180
  except Exception as e:
1181
  st.error(f"Error processing deletion: {str(e)}")
1182
 
1183
- elif selected_view == 'Show as Code Editor - Old':
1184
- Label = '#### 💻 Code editor view'
1185
- st.markdown(Label)
1186
- total_docs = len(documents)
1187
- doc = documents[st.session_state.current_index]
1188
- # st.markdown(f"#### Document ID: {doc.get('id', '')}")
1189
- doc_str = st.text_area("Edit Document",
1190
- value=json.dumps(doc, indent=2),
1191
- height=300,
1192
- key=f'code_editor_{st.session_state.current_index}')
1193
-
1194
- col_prev, col_next = st.columns([1, 1])
1195
- with col_prev:
1196
- if st.button("⬅️ Previous", key='prev_code'):
1197
- if st.session_state.current_index > 0:
1198
- st.session_state.current_index -= 1
1199
- st.rerun()
1200
- with col_next:
1201
- if st.button("➡️ Next", key='next_code'):
1202
- if st.session_state.current_index < total_docs - 1:
1203
- st.session_state.current_index += 1
1204
- st.rerun()
1205
-
1206
- col_save, col_delete = st.columns([1, 1])
1207
- with col_save:
1208
- if st.button("💾 Save Changes", key=f'save_button_{st.session_state.current_index}'):
1209
- try:
1210
- updated_doc = json.loads(doc_str)
1211
- response = container.upsert_item(body=updated_doc)
1212
- if response:
1213
- st.success(f"Document {updated_doc['id']} saved successfully.")
1214
- st.session_state.selected_document_id = updated_doc['id']
1215
- st.rerun()
1216
- except Exception as e:
1217
- st.error(f"Error saving document: {str(e)}")
1218
-
1219
- with col_delete:
1220
- if st.button("🗑️ Delete", key=f'delete_button_{st.session_state.current_index}'):
1221
- try:
1222
- current_doc = json.loads(doc_str)
1223
- # Direct deletion using container method with id and partition key
1224
- delete = container.delete_item(current_doc["id"], current_doc["id"])
1225
- if delete:
1226
- st.success(f"Document {current_doc['id']} deleted successfully.")
1227
- if st.session_state.current_index > 0:
1228
- st.session_state.current_index -= 1
1229
- st.rerun()
1230
- except Exception as e:
1231
- st.error(f"Error deleting document: {str(e)}")
1232
-
1233
 
1234
 
1235
 
@@ -1257,6 +1368,12 @@ def main():
1257
 
1258
  # Save and AI operations columns
1259
 
 
 
 
 
 
 
1260
  if st.button("🤖 Run AI", key=f'run_with_ai_button_{idx}'):
1261
  # Your existing AI processing code here
1262
  values_with_space = []
 
35
  import tempfile
36
  from PIL import Image
37
  import io
38
+ import requests
39
+
40
 
41
 
42
 
 
814
 
815
  # *********
816
 
817
+ def resize_image_for_video(image_data, max_size=(1024, 1024)):
818
+ """Resize image to be compatible with video generation"""
819
+ try:
820
+ # Convert bytes to PIL Image if needed
821
+ if isinstance(image_data, bytes):
822
+ img = Image.open(io.BytesIO(image_data))
823
+ elif isinstance(image_data, Image.Image):
824
+ img = image_data
825
+ else:
826
+ raise ValueError("Unsupported image data type")
827
 
828
+ # Convert to RGB if necessary
829
+ if img.mode != 'RGB':
830
+ img = img.convert('RGB')
831
 
832
+ # Calculate new size maintaining aspect ratio
833
+ ratio = min(max_size[0] / img.size[0], max_size[1] / img.size[1])
834
+ new_size = tuple(int(dim * ratio) for dim in img.size)
835
+
836
+ # Resize image
837
+ resized_img = img.resize(new_size, Image.Resampling.LANCZOS)
838
+ return resized_img
839
+ except Exception as e:
840
+ st.error(f"Error resizing image: {str(e)}")
841
+ return None
842
 
843
+ def generate_video_from_image(image_data, seed=None, motion_bucket_id=127, fps_id=6, max_retries=3):
844
+ """Generate video from image using Stable Video Diffusion with improved error handling"""
845
+ temp_files = [] # Keep track of temporary files
846
  try:
847
+ # Create progress bar
848
+ progress_bar = st.progress(0)
849
+ status_text = st.empty()
850
+
851
+ status_text.text("Preparing image...")
852
+ progress_bar.progress(10)
853
+
854
+ # Resize image
855
+ resized_img = resize_image_for_video(image_data)
856
+ if resized_img is None:
857
+ return None, None
858
+
859
+ # Save resized image to temporary file
860
  with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
861
+ temp_files.append(temp_img.name)
862
+ resized_img.save(temp_img.name, format='PNG')
 
 
 
 
 
 
863
 
864
+ status_text.text("Connecting to video generation service...")
865
+ progress_bar.progress(20)
866
+
867
+ # Initialize the Gradio client with error handling
868
+ client = Client(
869
+ "awacke1/stable-video-diffusion",
870
+ hf_token=os.environ.get("HUGGINGFACE_TOKEN") # Add your token if needed
871
+ )
872
 
873
  # Get random seed if none provided
874
  if seed is None:
875
+ try:
876
+ seed = client.predict(api_name="/get_random_value")
877
+ except Exception as e:
878
+ st.warning(f"Could not get random seed, using default. Error: {str(e)}")
879
+ seed = int(time.time()) # Use timestamp as fallback
880
+
881
+ status_text.text("Generating video...")
882
+ progress_bar.progress(40)
 
 
 
883
 
884
+ # Attempt video generation with retries
885
+ error = None
886
+ for attempt in range(max_retries):
887
+ try:
888
+ status_text.text(f"Generating video (attempt {attempt + 1}/{max_retries})...")
889
+ progress_bar.progress(40 + (attempt * 20))
890
+
891
+ # First try to resize the image using the API
892
+ try:
893
+ resized_result = client.predict(
894
+ image=temp_img.name,
895
+ api_name="/resize_image"
896
+ )
897
+ if resized_result:
898
+ temp_files.append(resized_result)
899
+ input_image = resized_result
900
+ else:
901
+ input_image = temp_img.name
902
+ except Exception as e:
903
+ st.warning(f"Image resize API failed, using original image. Error: {str(e)}")
904
+ input_image = temp_img.name
905
+
906
+ # Generate video
907
+ result = client.predict(
908
+ image=input_image,
909
+ seed=seed,
910
+ randomize_seed=True,
911
+ motion_bucket_id=motion_bucket_id,
912
+ fps_id=fps_id,
913
+ api_name="/video"
914
+ )
915
+
916
+ if result and isinstance(result, tuple) and len(result) >= 1:
917
+ video_path = result[0].get('video') if isinstance(result[0], dict) else None
918
+ if video_path:
919
+ status_text.text("Video generated successfully!")
920
+ progress_bar.progress(100)
921
+ return video_path, seed
922
+
923
+ error = f"Invalid result format on attempt {attempt + 1}"
924
+ time.sleep(2 ** attempt) # Exponential backoff
925
+ except Exception as e:
926
+ error = str(e)
927
+ st.warning(f"Attempt {attempt + 1} failed: {error}")
928
+ time.sleep(2 ** attempt) # Exponential backoff
929
+
930
+ raise Exception(f"Failed after {max_retries} attempts. Last error: {error}")
931
 
932
  except Exception as e:
933
+ st.error(f"Error in video generation: {str(e)}")
934
  return None, None
935
+ finally:
936
+ # Cleanup temporary files
937
+ for temp_file in temp_files:
938
+ try:
939
+ if os.path.exists(temp_file):
940
+ os.unlink(temp_file)
941
+ except Exception as e:
942
+ st.warning(f"Error cleaning up temporary file {temp_file}: {str(e)}")
943
+
944
+ # Add this to your main Streamlit interface, in the appropriate section:
945
+ def add_video_generation_ui(container):
946
+ """Add video generation UI components"""
947
+ st.markdown("### 🎥 Video Generation")
948
+
949
+ col1, col2 = st.columns([2, 1])
950
+
951
+ with col1:
952
+ uploaded_image = st.file_uploader(
953
+ "Upload Image for Video Generation 🖼️",
954
+ type=['png', 'jpg', 'jpeg'],
955
+ help="Upload an image to generate a video from"
956
+ )
957
+
958
+ with col2:
959
+ st.markdown("#### Parameters")
960
+ motion_bucket_id = st.slider(
961
+ "Motion Intensity 🌊",
962
+ min_value=1,
963
+ max_value=255,
964
+ value=127,
965
+ help="Controls the amount of motion in the generated video"
966
+ )
967
+ fps_id = st.slider(
968
+ "Frames per Second 🎬",
969
+ min_value=1,
970
+ max_value=30,
971
+ value=6,
972
+ help="Controls the smoothness of the generated video"
973
+ )
974
+
975
+ if uploaded_image:
976
+ st.image(uploaded_image, caption="Preview of uploaded image", use_column_width=True)
977
+
978
+ if st.button("🎥 Generate Video", help="Click to start video generation"):
979
+ with st.spinner("Processing your video... This may take a few minutes 🎬"):
980
+ image_bytes = uploaded_image.read()
981
+ video_path, used_seed = generate_video_from_image(
982
+ image_bytes,
983
+ motion_bucket_id=motion_bucket_id,
984
+ fps_id=fps_id
985
+ )
986
+
987
+ if video_path:
988
+ # Save video locally
989
+ video_filename = f"generated_video_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
990
+ try:
991
+ shutil.copy(video_path, video_filename)
992
+
993
+ # Display the generated video
994
+ st.success(f"Video generated successfully! Seed: {used_seed}")
995
+ st.video(video_filename)
996
+
997
+ # Save to Cosmos DB
998
+ if container:
999
+ video_record = {
1000
+ "id": generate_unique_id(),
1001
+ "type": "generated_video",
1002
+ "filename": video_filename,
1003
+ "seed": used_seed,
1004
+ "motion_bucket_id": motion_bucket_id,
1005
+ "fps_id": fps_id,
1006
+ "timestamp": datetime.now().isoformat()
1007
+ }
1008
+ success, message = insert_record(container, video_record)
1009
+ if success:
1010
+ st.success("Video record saved to database!")
1011
+ else:
1012
+ st.error(f"Error saving video record: {message}")
1013
+ except Exception as e:
1014
+ st.error(f"Error saving video: {str(e)}")
1015
+ else:
1016
+ st.error("Failed to generate video. Please try again with different parameters.")
1017
+
1018
 
1019
  # Add this to the 'Show as Run AI' section in your main function,
1020
  # right after the "🤖 Run AI" button:
 
1341
  except Exception as e:
1342
  st.error(f"Error processing deletion: {str(e)}")
1343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1344
 
1345
 
1346
 
 
1368
 
1369
  # Save and AI operations columns
1370
 
1371
+
1372
+ # Video Generator call - the video generation UI for container:
1373
+ add_video_generation_ui(container)
1374
+
1375
+
1376
+
1377
  if st.button("🤖 Run AI", key=f'run_with_ai_button_{idx}'):
1378
  # Your existing AI processing code here
1379
  values_with_space = []