Damian Stewart committed
Commit 0002379 · 1 Parent(s): ac5ee04

support for different base models

Files changed (3)
  1. StableDiffuser.py +29 -52
  2. app.py +56 -15
  3. train.py +7 -6
StableDiffuser.py CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import traceback
 
 import torch
 from baukit import TraceDict
@@ -36,71 +37,68 @@ def default_parser():
 class StableDiffuser(torch.nn.Module):
 
     def __init__(self,
-                 scheduler='LMS'
+                 scheduler='LMS',
+                 repo_id_or_path="CompVis/stable-diffusion-v1-4",
                  ):
 
         super().__init__()
 
         # Load the autoencoder model which will be used to decode the latents into image space.
         self.vae = AutoencoderKL.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="vae")
+            repo_id_or_path, subfolder="vae")
 
         # Load the tokenizer and text encoder to tokenize and encode the text.
         self.tokenizer = CLIPTokenizer.from_pretrained(
-            "openai/clip-vit-large-patch14")
+            repo_id_or_path, subfolder="tokenizer")
         self.text_encoder = CLIPTextModel.from_pretrained(
-            "openai/clip-vit-large-patch14")
+            repo_id_or_path, subfolder="text_encoder")
 
         # The UNet model for generating the latents.
         self.unet = UNet2DConditionModel.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="unet")
-
-        self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
-        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")
+            repo_id_or_path, subfolder="unet")
+
+        try:
+            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(repo_id_or_path, subfolder="feature_extractor")
+            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id_or_path, subfolder="safety_checker")
+        except Exception as error:
+            print(f"caught exception {error} making feature extractor / safety checker")
+            self.feature_extractor = None
+            self.safety_checker = None
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
         elif scheduler == 'DDIM':
-            self.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+            self.scheduler = DDIMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
         elif scheduler == 'DDPM':
-            self.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+            self.scheduler = DDPMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
 
         self.eval()
 
     def get_noise(self, batch_size, img_size, generator=None):
-
         param = list(self.parameters())[0]
-
         return torch.randn(
             (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
             generator=generator).type(param.dtype).to(param.device)
 
     def add_noise(self, latents, noise, step):
-
         return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))
 
     def text_tokenize(self, prompts):
-
         return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
     def text_detokenize(self, tokens):
-
         return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]
 
     def text_encode(self, tokens):
-
         return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]
 
     def decode(self, latents):
-
         return self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample
 
     def encode(self, tensors):
-
         return self.vae.encode(tensors).latent_dist.mode() * 0.18215
 
     def to_image(self, image):
-
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
         images = (image * 255).round().astype("uint8")
@@ -112,25 +110,16 @@ class StableDiffuser(torch.nn.Module):
         self.scheduler.set_timesteps(n_steps, device=self.unet.device)
 
     def get_initial_latents(self, n_imgs, img_size, n_prompts, generator=None):
-
         noise = self.get_noise(n_imgs, img_size, generator=generator).repeat(n_prompts, 1, 1, 1)
-
         latents = noise * self.scheduler.init_noise_sigma
-
         return latents
 
-    def get_text_embeddings(self, prompts, n_imgs):
-
+    def get_text_embeddings(self, prompts, negative_prompts, n_imgs):
         text_tokens = self.text_tokenize(prompts)
-
         text_embeddings = self.text_encode(text_tokens)
-
-        unconditional_tokens = self.text_tokenize([""] * len(prompts))
-
+        unconditional_tokens = self.text_tokenize(negative_prompts)
         unconditional_embeddings = self.text_encode(unconditional_tokens)
-
         text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
-
         return text_embeddings
 
     def predict_noise(self,
@@ -174,9 +163,7 @@ class StableDiffuser(torch.nn.Module):
         trace = None
 
         for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):
-
             if trace_args:
-
                 trace = TraceDict(self, **trace_args)
 
             noise_pred = self.predict_noise(
@@ -189,17 +176,13 @@ class StableDiffuser(torch.nn.Module):
             output = self.scheduler.step(noise_pred, self.scheduler.timesteps[iteration], latents)
 
             if trace_args:
-
                 trace.close()
-
                 trace_steps.append(trace)
 
             latents = output.prev_sample
 
             if return_steps or iteration == end_iteration - 1:
-
                 output = output.pred_original_sample if pred_x0 else latents
-
                 if return_steps:
                     latents_steps.append(output.cpu())
                 else:
@@ -210,6 +193,7 @@ class StableDiffuser(torch.nn.Module):
     @torch.no_grad()
     def __call__(self,
                  prompts,
+                 negative_prompts,
                  img_size=512,
                  n_steps=50,
                  n_imgs=1,
@@ -221,17 +205,12 @@ class StableDiffuser(torch.nn.Module):
         assert 0 <= n_steps <= 1000
 
         if not isinstance(prompts, list):
-
             prompts = [prompts]
 
         self.set_scheduler_timesteps(n_steps)
-
         latents = self.get_initial_latents(n_imgs, img_size, len(prompts), generator=generator)
-
-        text_embeddings = self.get_text_embeddings(prompts,n_imgs=n_imgs)
-
+        text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
         end_iteration = end_iteration or n_steps
-
         latents_steps, trace_steps = self.diffusion(
             latents,
             text_embeddings,
@@ -242,19 +221,18 @@ class StableDiffuser(torch.nn.Module):
         latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
         images_steps = [self.to_image(latents) for latents in latents_steps]
 
-        for i in range(len(images_steps)):
-            self.safety_checker = self.safety_checker.float()
-            safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
-            )
-
-            images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
+        if self.safety_checker is not None:
+            for i in range(len(images_steps)):
+                self.safety_checker = self.safety_checker.float()
+                safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
+                image, has_nsfw_concept = self.safety_checker(
+                    images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
+                )
+                images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
 
         images_steps = list(zip(*images_steps))
 
         if trace_steps:
-
             return images_steps, trace_steps
 
         return images_steps
@@ -263,7 +241,6 @@ class StableDiffuser(torch.nn.Module):
 if __name__ == '__main__':
 
     parser = default_parser()
-
     args = parser.parse_args()
 
     diffuser = StableDiffuser(seed=args.seed, scheduler='DDIM').to(torch.device(args.device)).half()
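
With these changes, callers choose the base weights and must pass a negative prompt. A minimal usage sketch (not part of the commit; the prompt strings are placeholders, everything else follows the new signatures above):

import torch
from StableDiffuser import StableDiffuser

# All sub-models (vae, tokenizer, text_encoder, unet, scheduler) now load from one
# diffusers-layout repo id or local path instead of the hard-coded SD v1.4 / CLIP ids.
diffuser = StableDiffuser(
    scheduler='DDIM',
    repo_id_or_path="CompVis/stable-diffusion-v1-4",
).to('cuda').eval().half()

# __call__ now takes negative_prompts as the second positional argument.
images = diffuser(
    "a lighthouse at dusk",   # prompts (placeholder)
    "",                       # negative_prompts; an empty prompt matches the old unconditional behaviour for a single prompt
    img_size=512,
    n_steps=50,
    generator=torch.manual_seed(42),
)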
app.py CHANGED
@@ -1,20 +1,27 @@
 import gradio as gr
 import torch
+import os
 from finetuning import FineTunedModel
 from StableDiffuser import StableDiffuser
 from train import train
 
 import os
-model_map = {'Van Gogh' : 'models/vangogh.pt',
+model_map = {'Van Gogh': 'models/vangogh.pt',
              'Pablo Picasso': 'models/pablopicasso.pt',
-             'Car' : 'models/car.pt',
+             'Car': 'models/car.pt',
              'Garbage Truck': 'models/garbagetruck.pt',
              'French Horn': 'models/frenchhorn.pt',
-             'Kilian Eng' : 'models/kilianeng.pt',
-             'Thomas Kinkade' : 'models/thomaskinkade.pt',
-             'Tyler Edlin' : 'models/tyleredlin.pt',
+             'Kilian Eng': 'models/kilianeng.pt',
+             'Thomas Kinkade': 'models/thomaskinkade.pt',
+             'Tyler Edlin': 'models/tyleredlin.pt',
              'Kelly McKernan': 'models/kellymckernan.pt',
              'Rembrandt': 'models/rembrandt.pt' }
+for model_file in os.listdir('models'):
+    path = 'models/' + model_file
+    if any([existing_path == path for existing_path in model_map.values()]):
+        continue
+    model_map[model_file] = path
+
 
 ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
 SPACE_ID = os.getenv('SPACE_ID')
@@ -31,8 +38,6 @@ class Demo:
         self.training = False
         self.generating = False
 
-        self.diffuser = StableDiffuser(scheduler='DDIM').to('cuda').eval().half()
-
         with gr.Blocks() as demo:
             self.layout()
             demo.queue(concurrency_count=5).launch()
@@ -64,6 +69,9 @@ class Demo:
                     label="Prompt",
                     info="Prompt to generate"
                 )
+                self.negative_prompt_input_infr = gr.Text(
+                    label="Negative prompt"
+                )
 
                 with gr.Row():
 
@@ -78,6 +86,19 @@ class Demo:
                        label="Seed",
                        value=42
                    )
+                    self.img_size_infr = gr.Slider(
+                        label="Image size",
+                        minimum=256,
+                        maximum=1024,
+                        value=512,
+                        step=64
+                    )
+
+                    self.base_repo_id_or_path_input_infr = gr.Text(
+                        label="Base model",
+                        value="CompVis/stable-diffusion-v1-4",
+                        info="Path or huggingface repo id of the base model that this edit was done against"
+                    )
 
             with gr.Column(scale=2):
 
@@ -108,6 +129,21 @@ class Demo:
 
             with gr.Column(scale=3):
 
+                self.train_model_input = gr.Text(
+                    label="Model to Edit",
+                    value="CompVis/stable-diffusion-v1-4",
+                    info="Path or huggingface repo id of the model to edit"
+                )
+
+                self.train_img_size_input = gr.Slider(
+                    value=512,
+                    step=64,
+                    minimum=256,
+                    maximum=1024,
+                    label="Image Size",
+                    info="Image size for training, should match the model's native image size"
+                )
+
                 self.prompt_input = gr.Text(
                     placeholder="Enter prompt...",
                     label="Prompt to Erase",
@@ -156,8 +192,11 @@ class Demo:
 
         self.infr_button.click(self.inference, inputs = [
             self.prompt_input_infr,
+            self.negative_prompt_input_infr,
             self.seed_infr,
-            self.model_dropdown
+            self.img_size_infr,
+            self.model_dropdown,
+            self.base_repo_id_or_path_input_infr
             ],
             outputs=[
                 self.image_new,
@@ -165,6 +204,8 @@ class Demo:
             ]
         )
         self.train_button.click(self.train, inputs = [
+            self.train_model_input,
+            self.train_img_size_input,
             self.prompt_input,
             self.train_method_input,
             self.neg_guidance_input,
@@ -174,7 +215,7 @@ class Demo:
             outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
 
-    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
+    def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
@@ -200,7 +241,7 @@ class Demo:
 
         self.training = True
 
-        train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
+        train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
 
         self.training = False
 
@@ -211,22 +252,21 @@ class Demo:
        return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
 
 
-    def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
+    def inference(self, prompt, negative_prompt, seed, img_size, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
 
         seed = seed or 42
-
         generator = torch.manual_seed(seed)
-
         model_path = model_map[model_name]
-
         checkpoint = torch.load(model_path)
 
+        self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
         finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
-
         torch.cuda.empty_cache()
 
         images = self.diffuser(
             prompt,
+            negative_prompt,
+            img_size=img_size,
             n_steps=50,
             generator=generator
         )
@@ -242,6 +282,7 @@ class Demo:
 
         images = self.diffuser(
             prompt,
+            negative_prompt,
            n_steps=50,
            generator=generator
        )
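
Since the shared diffuser is no longer built in the constructor, each inference request now constructs a StableDiffuser against the user-supplied base model before loading the selected edit checkpoint. A rough standalone equivalent of that path (illustrative only; the checkpoint path and prompt are placeholders, and how the finetuned weights are applied during generation is not shown in this hunk):

import torch
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel

base_repo_id_or_path = "CompVis/stable-diffusion-v1-4"   # value of the new "Base model" textbox
checkpoint = torch.load("models/vangogh.pt")             # placeholder entry from model_map

# Rebuild the diffuser per request against the chosen base model, then load the edit.
diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval().half()
torch.cuda.empty_cache()

images = diffuser(
    "a painting in the tested style",   # prompt from the UI (placeholder)
    "",                                 # negative prompt from the new textbox
    img_size=512,                       # value of the new "Image size" slider
    n_steps=50,
    generator=torch.manual_seed(42),
)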
train.py CHANGED
@@ -3,11 +3,11 @@ from finetuning import FineTunedModel
 import torch
 from tqdm import tqdm
 
-def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
+def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
 
     nsteps = 50
 
-    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
+    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
     diffuser.train()
 
     finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
@@ -28,17 +28,16 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
 
     torch.cuda.empty_cache()
 
+    print(f"using img_size of {img_size}")
+
     for i in pbar:
-
         with torch.no_grad():
-
             diffuser.set_scheduler_timesteps(nsteps)
-
             optimizer.zero_grad()
 
             iteration = torch.randint(1, nsteps - 1, (1,)).item()
 
-            latents = diffuser.get_initial_latents(1, 512, 1)
+            latents = diffuser.get_initial_latents(1, img_size, 1)
 
             with finetuner:
 
@@ -80,6 +79,8 @@ if __name__ == '__main__':
 
     parser = argparse.ArgumentParser()
 
+    parser.add_argument("--repo_id_or_path", required=True)
+    parser.add_argument("--img_size", type=int, required=False, default=512)
     parser.add_argument('--prompt', required=True)
     parser.add_argument('--modules', required=True)
     parser.add_argument('--freeze_modules', nargs='+', required=True)
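
For orientation, train() (and the CLI via --repo_id_or_path and --img_size) now takes the base model and image size ahead of the existing arguments. A sketch of the updated call order, mirroring the call site in app.py; every value other than the two new leading arguments is a placeholder, since their exact form is defined elsewhere in the repo:

from train import train

# Placeholders for arguments unchanged by this commit.
prompt = "Van Gogh"
modules = "unet"            # module name pattern(s) to finetune (placeholder)
frozen = []                 # modules to keep frozen (placeholder)
iterations, neg_guidance, lr = 1000, 1.0, 1e-5   # placeholders
save_path = "models/vangogh.pt"

train(
    "CompVis/stable-diffusion-v1-4",   # repo_id_or_path: new first argument (--repo_id_or_path)
    512,                               # img_size: new second argument (--img_size, default 512)
    prompt, modules, frozen,
    iterations, neg_guidance, lr, save_path,
)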