samarth-ht commited on
Commit
45378e6
·
1 Parent(s): ac802d5

dropdown added

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -18,11 +18,15 @@ def process_video(
18
  guidance_scale,
19
  inference_steps,
20
  seed,
 
21
  ):
22
  # Create the temp directory if it doesn't exist
23
  output_dir = Path("./temp")
24
  output_dir.mkdir(parents=True, exist_ok=True)
25
 
 
 
 
26
  # Convert paths to absolute Path objects and normalize them
27
  video_file_path = Path(video_path)
28
  video_path = video_file_path.absolute().as_posix()
@@ -44,7 +48,7 @@ def process_video(
44
  )
45
 
46
  # Parse the arguments
47
- args = create_args(video_path, audio_path, output_path, guidance_scale, seed)
48
 
49
  try:
50
  result = main(
@@ -59,7 +63,7 @@ def process_video(
59
 
60
 
61
  def create_args(
62
- video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int
63
  ) -> argparse.Namespace:
64
  parser = argparse.ArgumentParser()
65
  parser.add_argument("--inference_ckpt_path", type=str, required=True)
@@ -72,7 +76,7 @@ def create_args(
72
  return parser.parse_args(
73
  [
74
  "--inference_ckpt_path",
75
- CHECKPOINT_PATH.absolute().as_posix(),
76
  "--video_path",
77
  video_path,
78
  "--audio_path",
@@ -86,6 +90,12 @@ def create_args(
86
  ]
87
  )
88
 
 
 
 
 
 
 
89
 
90
  # Create Gradio interface
91
  with gr.Blocks(title="SoundImage") as demo:
@@ -99,6 +109,12 @@ with gr.Blocks(title="SoundImage") as demo:
99
 
100
  with gr.Row():
101
  with gr.Column():
 
 
 
 
 
 
102
  video_input = gr.Video(label="Input Video")
103
  audio_input = gr.Audio(label="Input Audio", type="filepath")
104
 
@@ -139,6 +155,7 @@ with gr.Blocks(title="SoundImage") as demo:
139
  guidance_scale,
140
  inference_steps,
141
  seed,
 
142
  ],
143
  outputs=video_output,
144
  )
 
18
  guidance_scale,
19
  inference_steps,
20
  seed,
21
+ checkpoint_file,
22
  ):
23
  # Create the temp directory if it doesn't exist
24
  output_dir = Path("./temp")
25
  output_dir.mkdir(parents=True, exist_ok=True)
26
 
27
+ # Use selected checkpoint or fall back to default
28
+ checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
29
+
30
  # Convert paths to absolute Path objects and normalize them
31
  video_file_path = Path(video_path)
32
  video_path = video_file_path.absolute().as_posix()
 
48
  )
49
 
50
  # Parse the arguments
51
+ args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path)
52
 
53
  try:
54
  result = main(
 
63
 
64
 
65
  def create_args(
66
+ video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int, checkpoint_path: Path
67
  ) -> argparse.Namespace:
68
  parser = argparse.ArgumentParser()
69
  parser.add_argument("--inference_ckpt_path", type=str, required=True)
 
76
  return parser.parse_args(
77
  [
78
  "--inference_ckpt_path",
79
+ checkpoint_path.absolute().as_posix(),
80
  "--video_path",
81
  video_path,
82
  "--audio_path",
 
90
  ]
91
  )
92
 
93
+ # Add this function to get checkpoint files
94
+ def get_checkpoint_files():
95
+ unet_files_dir = Path("unetFiles")
96
+ if not unet_files_dir.exists():
97
+ return []
98
+ return [f.name for f in unet_files_dir.glob("*.pt")]
99
 
100
  # Create Gradio interface
101
  with gr.Blocks(title="SoundImage") as demo:
 
109
 
110
  with gr.Row():
111
  with gr.Column():
112
+ # Add checkpoint selector dropdown
113
+ checkpoint_dropdown = gr.Dropdown(
114
+ choices=get_checkpoint_files(),
115
+ label="Select Checkpoint",
116
+ value=get_checkpoint_files()[0] if get_checkpoint_files() else None
117
+ )
118
  video_input = gr.Video(label="Input Video")
119
  audio_input = gr.Audio(label="Input Audio", type="filepath")
120
 
 
155
  guidance_scale,
156
  inference_steps,
157
  seed,
158
+ checkpoint_dropdown,
159
  ],
160
  outputs=video_output,
161
  )