LPX55 committed on
Commit
a7eb958
·
verified ·
1 Parent(s): d944e21

Update mini.py

Browse files
Files changed (1) hide show
  1. mini.py +23 -17
mini.py CHANGED
@@ -3,10 +3,12 @@ import torch
3
  import spaces
4
  from PIL import Image
5
  import os
6
- from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
7
  from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
8
  from flux.transformer_flux_simple import FluxTransformer2DModel
9
  from flux.pipeline_flux_chameleon_og import FluxPipeline
 
 
10
  import torch.nn as nn
11
  import math
12
  import logging
@@ -29,6 +31,9 @@ MODEL_CACHE_DIR = "model_cache"
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
  DTYPE = torch.bfloat16
31
 
 
 
 
32
  # Aspect ratio options
33
  ASPECT_RATIOS = {
34
  "1:1": (1024, 1024),
@@ -81,12 +86,13 @@ tokenizer_two = T5TokenizerFast.from_pretrained(
81
 
82
  # Load larger models to CPU
83
  vae = AutoencoderKL.from_pretrained(
84
- os.path.join(MODEL_CACHE_DIR, "flux/vae")
85
- ).to(DTYPE).to(DEVICE)
86
 
87
  transformer = FluxTransformer2DModel.from_pretrained(
88
- os.path.join(MODEL_CACHE_DIR, "flux/transformer")
89
- ).to(DTYPE).to(DEVICE)
 
90
 
91
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
92
  os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
@@ -95,7 +101,8 @@ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
95
 
96
  # Load Qwen2VL to CPU
97
  qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
98
- os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
 
99
  ).to(DTYPE).cpu()
100
 
101
  # Load connector and embedder
@@ -134,16 +141,6 @@ pipeline = FluxPipeline(
134
  tokenizer=tokenizer,
135
  )
136
 
137
-
138
- # # Move Transformer and VAE to GPU
139
- # logger.info("Moving Transformer and VAE to GPU...")
140
- # transformer.to(DEVICE)
141
- # vae.to(DEVICE)
142
-
143
- # # Update pipeline models
144
- # pipeline.transformer = transformer
145
- # pipeline.vae = vae
146
- # logger.info("Models moved to GPU")
147
  def process_image(image):
148
  """Process image with Qwen2VL model"""
149
  try:
@@ -267,7 +264,16 @@ def generate(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28,
267
  pooled_prompt_embeds = compute_text_embeddings(prompt)
268
  t5_prompt_embeds = compute_t5_text_embeddings(prompt)
269
  logger.info("Text embeddings computed")
270
-
 
 
 
 
 
 
 
 
 
271
 
272
  # Get dimensions
273
  width, height = ASPECT_RATIOS[aspect_ratio]
 
3
  import spaces
4
  from PIL import Image
5
  import os
6
+ from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast, BitsAndBytesConfig
7
  from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
8
  from flux.transformer_flux_simple import FluxTransformer2DModel
9
  from flux.pipeline_flux_chameleon_og import FluxPipeline
10
+ from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
11
+
12
  import torch.nn as nn
13
  import math
14
  import logging
 
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
  DTYPE = torch.bfloat16
33
 
34
+ quant_config = BitsAndBytesConfig(load_in_8bit=True,)
35
+
36
+
37
  # Aspect ratio options
38
  ASPECT_RATIOS = {
39
  "1:1": (1024, 1024),
 
86
 
87
  # Load larger models to CPU
88
  vae = AutoencoderKL.from_pretrained(
89
+ os.path.join(MODEL_CACHE_DIR, "flux/vae"),
90
+ ).to(DTYPE).cpu()
91
 
92
  transformer = FluxTransformer2DModel.from_pretrained(
93
+ os.path.join(MODEL_CACHE_DIR, "flux/transformer"),
94
+ quantization_config=quant_config,
95
+ ).to(DTYPE).cpu()
96
 
97
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
98
  os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
 
101
 
102
  # Load Qwen2VL to CPU
103
  qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
104
+ os.path.join(MODEL_CACHE_DIR, "qwen2-vl"),
105
+ quantization_config=quant_config,
106
  ).to(DTYPE).cpu()
107
 
108
  # Load connector and embedder
 
141
  tokenizer=tokenizer,
142
  )
143
 
 
 
 
 
 
 
 
 
 
 
144
  def process_image(image):
145
  """Process image with Qwen2VL model"""
146
  try:
 
264
  pooled_prompt_embeds = compute_text_embeddings(prompt)
265
  t5_prompt_embeds = compute_t5_text_embeddings(prompt)
266
  logger.info("Text embeddings computed")
267
+
268
+ # Move Transformer and VAE to GPU
269
+ logger.info("Moving Transformer and VAE to GPU...")
270
+ transformer.to(DEVICE)
271
+ vae.to(DEVICE)
272
+
273
+ # Update pipeline models
274
+ pipeline.transformer = transformer
275
+ pipeline.vae = vae
276
+ logger.info("Models moved to GPU")
277
 
278
  # Get dimensions
279
  width, height = ASPECT_RATIOS[aspect_ratio]