TrajectoryCrafter commited on
Commit
12e9d98
·
1 Parent(s): 0f56e8b
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -8,6 +8,8 @@ import random
8
  from inference import get_parser
9
  from datetime import datetime
10
  import argparse
 
 
11
 
12
  # 解析命令行参数
13
 
@@ -28,6 +30,15 @@ img_examples = [
28
 
29
  max_seed = 2 ** 31
30
 
 
 
 
 
 
 
 
 
 
31
  parser = get_parser() # infer_config.py
32
  opts = parser.parse_args() # default device: 'cuda:0'
33
  opts.weight_dtype = torch.bfloat16
@@ -37,6 +48,8 @@ os.makedirs(opts.save_dir,exist_ok=True)
37
  test_tensor = torch.Tensor([0]).cuda()
38
  opts.device = str(test_tensor.device)
39
 
 
 
40
  CAMERA_MOTION_MODE = ["Basic Camera Trajectory", "Custom Camera Trajectory"]
41
 
42
  def show_traj(mode):
@@ -84,14 +97,13 @@ def trajcrafter_demo(opts):
84
  }
85
  """
86
  image2video = TrajCrafter(opts,gradio=True)
87
- # image2video.run_both = spaces.GPU(image2video.run_both, duration=290) # fixme
88
  with gr.Blocks(analytics_enabled=False, css=css) as trajcrafter_iface:
89
- gr.Markdown("<div align='center'> <h1> TrajectoryCrafter: Redirecting View Trajectory for Monocular Videos via Diffusion Models </span> </h1>")
90
- # # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
91
- # # <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02048'> [ArXiv] </a>\
92
- # # <a style='font-size:18px;color: #000000' href='https://drexubery.github.io/ViewCrafter/'> [Project Page] </a>\
93
- # # <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Drexubery/ViewCrafter'> [Github] </a>\
94
- # # <a style='font-size:18px;color: #000000' href='https://www.youtube.com/watch?v=WGIEmu9eXmU'> [Video] </a> </div>")
95
 
96
 
97
  with gr.Row(equal_height=True):
@@ -277,8 +289,8 @@ def trajcrafter_demo(opts):
277
 
278
  trajcrafter_iface = trajcrafter_demo(opts)
279
  trajcrafter_iface.queue(max_size=10)
280
- # trajcrafter_iface.launch(server_name=args.server_name, max_threads=10, debug=True)
281
- trajcrafter_iface.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False, max_threads=10)
282
 
283
 
284
 
 
8
  from inference import get_parser
9
  from datetime import datetime
10
  import argparse
11
+ import spaces #fixme
12
+ from huggingface_hub import snapshot_download
13
 
14
  # 解析命令行参数
15
 
 
30
 
31
  max_seed = 2 ** 31
32
 
33
+ os.makedirs('./checkpoints/',exist_ok=True)
34
+ def download_model():
35
+ snapshot_download(repo_id="TrajectoryCrafter/TrajectoryCrafter", local_dir="checkpoints/TrajectoryCrafter", local_dir_use_symlinks=False)
36
+ snapshot_download(repo_id="tencent/DepthCrafter", local_dir="checkpoints/DepthCrafter", local_dir_use_symlinks=False)
37
+ snapshot_download(repo_id="stabilityai/stable-video-diffusion-img2vid", local_dir="checkpoints/stable-video-diffusion-img2vid", local_dir_use_symlinks=False)
38
+ snapshot_download(repo_id="alibaba-pai/CogVideoX-Fun-V1.1-5b-InP", local_dir="checkpoints/CogVideoX-Fun-V1.1-5b-InP", local_dir_use_symlinks=False)
39
+ snapshot_download(repo_id="Salesforce/blip2-opt-2.7b", local_dir="checkpoints/blip2-opt-2.7b", local_dir_use_symlinks=False)
40
+ download_model() #fixme
41
+
42
  parser = get_parser() # infer_config.py
43
  opts = parser.parse_args() # default device: 'cuda:0'
44
  opts.weight_dtype = torch.bfloat16
 
48
  test_tensor = torch.Tensor([0]).cuda()
49
  opts.device = str(test_tensor.device)
50
 
51
+
52
+
53
  CAMERA_MOTION_MODE = ["Basic Camera Trajectory", "Custom Camera Trajectory"]
54
 
55
  def show_traj(mode):
 
97
  }
98
  """
99
  image2video = TrajCrafter(opts,gradio=True)
100
+ image2video.run_gradio = spaces.GPU(image2video.run_gradio, duration=300) # fixme
101
  with gr.Blocks(analytics_enabled=False, css=css) as trajcrafter_iface:
102
+ gr.Markdown("<div align='center'> <h1> TrajectoryCrafter: Redirecting View Trajectory for Monocular Videos via Diffusion Models </span> </h1>
103
+ <a style='font-size:18px;color: #FF5DB0' href='https://github.com/TrajectoryCrafter/TrajectoryCrafter'> [Github] </a>\
104
+ # <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02048'> [ArXiv] </a>\
105
+ <a style='font-size:18px;color: #000000' href='https://trajectorycrafter.github.io/'> [Project Page] </a>\
106
+ <a style='font-size:18px;color: #000000' href='https://www.youtube.com/watch?v=dQtHFgyrids'> [Video] </a> </div>")
 
107
 
108
 
109
  with gr.Row(equal_height=True):
 
289
 
290
  trajcrafter_iface = trajcrafter_demo(opts)
291
  trajcrafter_iface.queue(max_size=10)
292
+ trajcrafter_iface.launch()
293
+ # trajcrafter_iface.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False, max_threads=10)
294
 
295
 
296