MohamedRashad committed (verified)
Commit 9c8b547 · Parent(s): ad5c840

Update app.py

Files changed (1): app.py (+20 -44)
app.py CHANGED
@@ -14,7 +14,7 @@ from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode
 from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
 from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
 from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
-from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
+from diffusers_helper.memory import cpu, gpu, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
 from diffusers_helper.thread_utils import AsyncStream, async_run
 from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
 from transformers import SiglipImageProcessor, SiglipVisionModel
@@ -22,12 +22,6 @@ from diffusers_helper.clip_vision import hf_clip_vision_encode
 from diffusers_helper.bucket_tools import find_nearest_bucket


-free_mem_gb = get_cuda_free_memory_gb(gpu)
-high_vram = free_mem_gb > 60
-
-print(f'Free VRAM {free_mem_gb} GB')
-print(f'High-VRAM Mode: {high_vram}')
-
 text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
 text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
 tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
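Aside: the removed probe is easy to reproduce with plain PyTorch if a VRAM check is ever needed again. A minimal sketch, assuming a CUDA device; get_free_vram_gb is an illustrative name, and only the 60 GB threshold comes from the removed code:

    import torch

    def get_free_vram_gb(device=None) -> float:
        # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        return free_bytes / (1024 ** 3)

    high_vram = get_free_vram_gb() > 60  # threshold used by the removed branch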
@@ -45,9 +39,8 @@ text_encoder_2.eval()
 image_encoder.eval()
 transformer.eval()

-if not high_vram:
-    vae.enable_slicing()
-    vae.enable_tiling()
+vae.enable_slicing()
+vae.enable_tiling()

 transformer.high_quality_fp32_output_for_inference = True
 print('transformer.high_quality_fp32_output_for_inference = True')
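Note: enable_slicing and enable_tiling are the standard diffusers autoencoder memory savers. Slicing decodes a batch one sample at a time, and tiling splits each frame into overlapping spatial tiles, so enabling them unconditionally trades a little decode speed for a much lower peak-VRAM footprint on every GPU.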
@@ -64,16 +57,8 @@ text_encoder_2.requires_grad_(False)
 image_encoder.requires_grad_(False)
 transformer.requires_grad_(False)

-if not high_vram:
-    # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
-    DynamicSwapInstaller.install_model(transformer, device=gpu)
-    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
-else:
-    text_encoder.to(gpu)
-    text_encoder_2.to(gpu)
-    image_encoder.to(gpu)
-    vae.to(gpu)
-    transformer.to(gpu)
+DynamicSwapInstaller.install_model(transformer, device=gpu)
+DynamicSwapInstaller.install_model(text_encoder, device=gpu)

 stream = AsyncStream()

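For comparison, the Hugging Face mechanism that the removed comment measures DynamicSwapInstaller against can be applied with accelerate. A minimal sketch of that baseline, assuming accelerate is installed; this is not what the app runs:

    from accelerate import cpu_offload

    # Parameters stay on CPU; forward hooks stream each submodule to the GPU
    # just in time and release it again afterwards.
    cpu_offload(transformer, execution_device=gpu)
    cpu_offload(text_encoder, execution_device=gpu)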
@@ -91,19 +76,16 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))

     try:
-        # Clean GPU
-        if not high_vram:
-            unload_complete_models(
-                text_encoder, text_encoder_2, image_encoder, vae, transformer
-            )
+        unload_complete_models(
+            text_encoder, text_encoder_2, image_encoder, vae, transformer
+        )

         # Text encoding

         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

-        if not high_vram:
-            fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
-            load_model_as_complete(text_encoder_2, target_device=gpu)
+        fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
+        load_model_as_complete(text_encoder_2, target_device=gpu)

         llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

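load_model_as_complete and unload_complete_models come from diffusers_helper.memory. A plausible reading of their contract, offered as a sketch under that assumption rather than the actual implementation:

    import torch

    def load_model_as_complete_sketch(model, target_device):
        # Bring the entire model onto the target GPU before one heavy call.
        model.to(target_device)

    def unload_complete_models_sketch(*models):
        # Push models back to CPU and release cached allocator blocks.
        for m in models:
            m.to('cpu')
        torch.cuda.empty_cache()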
@@ -132,8 +114,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

-        if not high_vram:
-            load_model_as_complete(vae, target_device=gpu)
+        load_model_as_complete(vae, target_device=gpu)

         start_latent = vae_encode(input_image_pt, vae)

@@ -141,8 +122,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

-        if not high_vram:
-            load_model_as_complete(image_encoder, target_device=gpu)
+        load_model_as_complete(image_encoder, target_device=gpu)

         image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
         image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
@@ -193,9 +173,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
             clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
             clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

-            if not high_vram:
-                unload_complete_models()
-                move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
+            unload_complete_models()
+            move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

             if use_teacache:
                 transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
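move_model_to_device_with_memory_preservation is the budgeted variant of a plain .to(gpu): it keeps onloading pieces of the transformer only while more than preserved_memory_gb of VRAM would remain free. An illustrative sketch of that policy, reusing the hypothetical get_free_vram_gb helper sketched earlier (the real logic lives in diffusers_helper.memory):

    def move_with_memory_preservation_sketch(model, target_device, preserved_memory_gb):
        for module in model.children():
            if get_free_vram_gb(target_device) <= preserved_memory_gb:
                break  # stop onloading; remaining submodules stay on CPU
            module.to(target_device)
        torch.cuda.empty_cache()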
@@ -257,9 +236,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
             total_generated_latent_frames += int(generated_latents.shape[2])
             history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

-            if not high_vram:
-                offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
-                load_model_as_complete(vae, target_device=gpu)
+            offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
+            load_model_as_complete(vae, target_device=gpu)

             real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

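offload_model_from_device_for_memory_preservation reads as the inverse operation: after each sampling section it moves transformer weights back to the CPU until roughly preserved_memory_gb (8 GB here) of VRAM is free again, which makes room to load the VAE in full for decoding.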
@@ -272,8 +250,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
             current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
             history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

-            if not high_vram:
-                unload_complete_models()
+            unload_complete_models()

             output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')

@@ -288,10 +265,9 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
     except:
         traceback.print_exc()

-        if not high_vram:
-            unload_complete_models(
-                text_encoder, text_encoder_2, image_encoder, vae, transformer
-            )
+        unload_complete_models(
+            text_encoder, text_encoder_2, image_encoder, vae, transformer
+        )

     stream.output_queue.push(('end', None))
     return
 