mhamilton723 commited on
Commit
af372ca
·
verified ·
1 Parent(s): aa938bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -17,10 +17,14 @@ from os.path import join
17
 
18
  if __name__ == "__main__":
19
 
20
- os.environ['TORCH_HOME'] = '/tmp/.cache'
21
- os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
22
- sample_images_dir = "/tmp/samples"
23
- # sample_videos_dir = "samples"
 
 
 
 
24
 
25
 
26
  def download_video(url, save_path):
@@ -63,7 +67,6 @@ if __name__ == "__main__":
63
  video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
64
  height=480)
65
  video_output3 = gr.Video(label="Visual Features", height=480)
66
- video_output4 = gr.Video(label="Audio Features", height=480)
67
 
68
  models = {o: torch.hub.load("mhamilton723/DenseAV", o) for o in options}
69
 
@@ -146,7 +149,7 @@ if __name__ == "__main__":
146
  temp_video_path_3,
147
  temp_video_path_4,
148
  )
149
- return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
150
 
151
  return temp_video_path_1, temp_video_path_2, temp_video_path_3
152
 
@@ -180,9 +183,10 @@ if __name__ == "__main__":
180
  video_output3.render()
181
 
182
  submit_button.click(fn=process_video, inputs=[video_input, model_option],
183
- outputs=[video_output1, video_output2])
184
 
185
- # demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
186
 
187
- # demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
188
- demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
 
 
17
 
18
  if __name__ == "__main__":
19
 
20
+ mode = "hf"
21
+
22
+ if mode == "local":
23
+ sample_videos_dir = "samples"
24
+ else:
25
+ os.environ['TORCH_HOME'] = '/tmp/.cache'
26
+ os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
27
+ sample_images_dir = "/tmp/samples"
28
 
29
 
30
  def download_video(url, save_path):
 
67
  video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
68
  height=480)
69
  video_output3 = gr.Video(label="Visual Features", height=480)
 
70
 
71
  models = {o: torch.hub.load("mhamilton723/DenseAV", o) for o in options}
72
 
 
149
  temp_video_path_3,
150
  temp_video_path_4,
151
  )
152
+ # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
153
 
154
  return temp_video_path_1, temp_video_path_2, temp_video_path_3
155
 
 
183
  video_output3.render()
184
 
185
  submit_button.click(fn=process_video, inputs=[video_input, model_option],
186
+ outputs=[video_output1, video_output2, video_output3])
187
 
 
188
 
189
+ if mode == "local":
190
+ demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
191
+ else:
192
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)