Damian Stewart committed
Commit 6dc9635 · 1 Parent(s): 5329ade

batching sample generation and cancellation support

Files changed (2)
  1. app.py +23 -10
  2. train.py +36 -19
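
Note: the cancellation half of this commit swaps a module-level boolean for a `multiprocessing.Semaphore(0)`. The UI releases the semaphore once per cancel request, and the training/validation/sampling loops poll it with a non-blocking `acquire`, so each request is consumed exactly once by whichever loop sees it first. A minimal sketch of that pattern follows (the names `worker` and `request_cancel` are illustrative, not from this repo):

import multiprocessing
import time

# Starts at 0: acquire(block=False) returns False until someone release()s.
should_cancel = multiprocessing.Semaphore(0)

def worker(steps=100):
    for step in range(steps):
        if should_cancel.acquire(block=False):  # consume one cancel request
            print("cancel requested, bailing")
            return None
        time.sleep(0.1)  # stand-in for one training step
    return "finished"

def request_cancel():
    should_cancel.release()  # one release per cancel request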
app.py CHANGED
@@ -76,7 +76,7 @@ class Demo:
                 label="Seed",
                 value=42
             )
-            with gr.Row(scale=1):
+            with gr.Row():
                 self.img_width_infr = gr.Slider(
                     label="Image width",
                     minimum=256,
@@ -92,7 +92,7 @@ class Demo:
                     step=64
                 )

-            with gr.Row(scale=1):
+            with gr.Row():
                 self.model_dropdown = gr.Dropdown(
                     label="ESD Model",
                     choices= list(model_map.keys()),
@@ -152,6 +152,15 @@ class Demo:
                     info="Image size for training, should match the model's native image size"
                 )

+                self.train_sample_batch_size_input = gr.Slider(
+                    value=1,
+                    step=1,
+                    minimum=1,
+                    maximum=32,
+                    label="Sample generation batch size",
+                    info="Batch size for sample generation, larger needs more VRAM"
+                )
+
                 self.prompt_input = gr.Text(
                     placeholder="Enter prompt...",
                     label="Prompt to Erase",
@@ -313,6 +322,7 @@ class Demo:
                 self.train_use_gradient_checkpointing_input,
                 self.train_seed_input,
                 self.train_save_every_input,
+                self.train_sample_batch_size_input,
                 self.train_validation_prompts,
                 self.train_sample_positive_prompts,
                 self.train_sample_negative_prompts,
@@ -322,7 +332,8 @@ class Demo:
             )
             self.train_cancel_button.click(self.cancel_training,
                                            inputs=[],
-                                           outputs=[self.train_cancel_button])
+                                           outputs=[self.train_cancel_button],
+                                           cancels=[train_event])

             self.export_button.click(self.export, inputs = [
                 self.model_dropdown_export,
@@ -340,12 +351,14 @@ class Demo:
         return [self.model_dropdown.update(choices=list(model_map.keys()), value=current_model_name)]

     def cancel_training(self):
-        train.training_should_cancel = True
-        return [gr.update(value="Cancelling...", interactive=False)]
+        if self.training:
+            training_should_cancel.release()
+            print("cancellation requested...")
+        return [gr.update(value="Cancelling...", interactive=True)]

     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
               use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
-              seed=-1, save_every=-1,
+              seed=-1, save_every=-1, sample_batch_size=1,
               validation_prompts: str=None, sample_positive_prompts: str=None, sample_negative_prompts: str=None, validate_every_n_steps=-1,
               pbar=gr.Progress(track_tqdm=True)):
         """
@@ -373,8 +386,6 @@ class Demo:
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]

-        train.training_should_cancel = False
-
         print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
         print(f" {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
         print(f" {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
@@ -403,8 +414,8 @@ class Demo:
                 break
             # repeat until a not-in-use path is found

-        validation_prompts = [] if validation_prompts is None else validation_prompts.split('\n')
-        sample_positive_prompts = [] if sample_positive_prompts is None else sample_positive_prompts.split('\n')
+        validation_prompts = [] if validation_prompts is None else [p for p in validation_prompts.split('\n') if len(p)>0]
+        sample_positive_prompts = [] if sample_positive_prompts is None else [p for p in sample_positive_prompts.split('\n') if len(p)>0]
         sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts.split('\n')
         print(f"validation prompts: {validation_prompts}")
         print(f"sample positive prompts: {sample_positive_prompts}")
@@ -413,9 +424,11 @@ class Demo:
         try:
             self.training = True
             self.train_cancel_button.update(interactive=True)
+            batch_size = 1 # other batch sizes are non-functional
             save_path = train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
                               use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
                               seed=int(seed), save_every_n_steps=int(save_every),
+                              batch_size=int(batch_size), sample_batch_size=int(sample_batch_size),
                               validate_every_n_steps=validate_every_n_steps, validation_prompts=validation_prompts,
                               sample_positive_prompts=sample_positive_prompts, sample_negative_prompts=sample_negative_prompts)
             if save_path is None:
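
Note on the `cancels=[train_event]` wiring above: Gradio event listeners accept a `cancels` argument that aborts another queued event, which requires keeping a handle to the train button's click event (e.g. `train_event = self.train_button.click(...)`, not shown in this hunk) and running with the queue enabled. A minimal, self-contained sketch of the same wiring (the `slow_job`/`Start`/`Stop` names are hypothetical, assuming a Gradio release that supports `cancels`):

import time
import gradio as gr

def slow_job(progress=gr.Progress()):
    for _ in progress.tqdm(range(50)):
        time.sleep(0.2)  # stand-in for a training step
    return "done"

with gr.Blocks() as demo:
    out = gr.Textbox()
    start = gr.Button("Start")
    stop = gr.Button("Stop")
    job_event = start.click(slow_job, inputs=[], outputs=[out])
    # cancels= kills the queued event; a cooperative flag (the semaphore in
    # train.py) is still needed so in-flight GPU work can unwind cleanly.
    stop.click(lambda: "cancelled", inputs=[], outputs=[out], cancels=[job_event])

demo.queue().launch()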
train.py CHANGED
@@ -1,5 +1,6 @@
 import os.path
 import random
+import multiprocessing

 from accelerate.utils import set_seed
 from diffusers import StableDiffusionPipeline
@@ -15,7 +16,7 @@ from isolate_rng import isolate_rng
 from memory_efficiency import MemoryEfficiencyWrapper
 from torch.utils.tensorboard import SummaryWriter

-training_should_cancel = False
+training_should_cancel = multiprocessing.Semaphore(0)

 def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
              validation_embeddings: torch.FloatTensor,
@@ -24,8 +25,11 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
              logger: SummaryWriter, use_amp: bool,
              global_step: int,
              validation_seed: int = 555,
+             batch_size: int = 1,
+             sample_batch_size: int = 1 # might need to be smaller than batch_size
              ):
     print("validating...")
+    assert batch_size == 1, "batch_size != 1 not implemented"
     with isolate_rng(include_cuda=True), torch.no_grad():
         set_seed(validation_seed)
         criteria = torch.nn.MSELoss()
@@ -33,14 +37,14 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
         val_count = 5

         nsteps=50
-        num_validation_prompts = validation_embeddings.shape[0] // 2
+        num_validation_batches = validation_embeddings.shape[0] // (batch_size*2)

-        for i in tqdm(range(num_validation_prompts)):
-            if training_should_cancel:
+        for i in tqdm(range(num_validation_batches)):
+            if training_should_cancel.acquire(block=False):
                 print("cancel requested, bailing")
                 return
             accumulated_loss = None
-            this_validation_embeddings = validation_embeddings[i*2:i*2+2]
+            this_validation_embeddings = validation_embeddings[i*batch_size*2:(i+1)*batch_size*2]
             for j in range(val_count):
                 iteration = random.randint(1, nsteps)
                 diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)
@@ -55,12 +59,11 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
                 loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
                 accumulated_loss = (accumulated_loss or 0) + loss.item()
             logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
-            pbar.step()

-        num_samples = sample_embeddings.shape[0] // 2
-        for i in tqdm(range(0, num_samples)):
-            print(f'making sample {i}...')
-            if training_should_cancel:
+        num_sample_batches = sample_embeddings.shape[0] // (sample_batch_size*2)
+        for i in tqdm(range(0, num_sample_batches)):
+            print(f'making sample batch {i}...')
+            if training_should_cancel.acquire(block=False):
                 print("cancel requested, bailing")
                 return
             with finetuner:
@@ -72,10 +75,16 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
                                 safety_checker=None,
                                 feature_extractor=None,
                                 requires_safety_checker=False)
-                images = pipeline(prompt_embeds=sample_embeddings[i*2+1:i*2+2], negative_prompt_embeds=sample_embeddings[i*2:i*2+1],
+                batch_start = (i * sample_batch_size)*2
+                next_batch_start = batch_start + sample_batch_size*2
+                batch_negative_prompt_embeds = torch.cat([sample_embeddings[k:k+1] for k in range(batch_start, next_batch_start, 2)])
+                batch_prompt_embeds = torch.cat([sample_embeddings[k+1:k+2] for k in range(batch_start, next_batch_start, 2)])
+                images = pipeline(prompt_embeds=batch_prompt_embeds,
+                                  negative_prompt_embeds=batch_negative_prompt_embeds,
                                   num_inference_steps=50)
-                image_tensor = transforms.ToTensor()(images.images[0])
-                logger.add_image(f"samples/{i}", img_tensor=image_tensor, global_step=global_step)
+                for j in range(sample_batch_size):
+                    image_tensor = transforms.ToTensor()(images.images[j])
+                    logger.add_image(f"samples/{i*sample_batch_size+j}", img_tensor=image_tensor, global_step=global_step)

         """
         with finetuner, torch.cuda.amp.autocast(enabled=use_amp):
@@ -90,6 +99,7 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,

 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
           use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
+          batch_size=1, sample_batch_size=1,
           save_every_n_steps=-1, validate_every_n_steps=-1,
           validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):

@@ -101,8 +111,6 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
     neutral_latents = None
     positive_latents = None

-    global training_should_cancel
-
     nsteps = 50
     print(f"using img_size of {img_size}")
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
@@ -135,6 +143,13 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
     validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
     sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)

+    for i, validation_prompt in enumerate(validation_prompts):
+        logger.add_text(f"val/{i}", f"validation prompt: \"{validation_prompt}\"")
+    for i in range(len(sample_positive_prompts)):
+        positive_prompt = sample_positive_prompts[i]
+        negative_prompt = "" if i >= len(sample_negative_prompts) else sample_negative_prompts[i]
+        logger.add_text(f"sample/{i}", f"sample prompt: \"{positive_prompt}\", negative: \"{negative_prompt}\"")
+
     #if use_amp:
     #    diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)

@@ -151,14 +166,15 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                  validation_embeddings=validation_embeddings,
                  sample_embeddings=sample_embeddings,
                  neutral_embeddings=neutral_text_embeddings,
-                 logger=logger, use_amp=False, global_step=0)
+                 logger=logger, use_amp=False, global_step=0,
+                 batch_size=batch_size, sample_batch_size=sample_batch_size)

     prev_losses = []
     start_loss = None
     max_prev_loss_count = 10
     try:
         for i in pbar:
-            if training_should_cancel:
+            if training_should_cancel.acquire(block=False):
                 print("cancel requested, bailing")
                 return None

@@ -210,7 +226,8 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                          validation_embeddings=validation_embeddings,
                          sample_embeddings=sample_embeddings,
                          neutral_embeddings=neutral_text_embeddings,
-                         logger=logger, use_amp=False, global_step=i)
+                         logger=logger, use_amp=False, global_step=i,
+                         batch_size=batch_size, sample_batch_size=sample_batch_size)
             torch.save(finetuner.state_dict(), save_path)
             return save_path
     finally:
@@ -220,7 +237,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations

 def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
     diffuser.set_scheduler_timesteps(nsteps)
-    latents = diffuser.get_initial_latents(1, n_prompts=1)
+    latents = diffuser.get_initial_latents(len(text_embeddings)//2, n_prompts=1)
     latents_steps, _ = diffuser.diffusion(
         latents,
         text_embeddings,
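
Note on the batched slicing above: `get_cond_and_uncond_embeddings` appears to return prompts interleaved in pairs, with the unconditional/negative embedding at even indices and the conditional/positive one at odd indices. That layout is why the sample loop strides by 2 and why `get_initial_latents(len(text_embeddings)//2, ...)` allocates one latent per pair. A small sketch of the layout and the batch slicing (the shapes are illustrative, not the repo's real embedding dims):

import torch

num_prompts, sample_batch_size = 4, 2
# Interleaved pairs, as inferred from validate(): [uncond_0, cond_0, uncond_1, cond_1, ...]
embeddings = torch.randn(num_prompts * 2, 77, 768)

for b in range(num_prompts // sample_batch_size):
    batch_start = b * sample_batch_size * 2
    next_batch_start = batch_start + sample_batch_size * 2
    negs = torch.cat([embeddings[k:k+1] for k in range(batch_start, next_batch_start, 2)])
    poss = torch.cat([embeddings[k+1:k+2] for k in range(batch_start, next_batch_start, 2)])
    assert negs.shape[0] == poss.shape[0] == sample_batch_size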