lzyhha commited on
Commit
24d1968
·
1 Parent(s): e2a1c5c
Files changed (1) hide show
  1. visualcloze.py +16 -16
visualcloze.py CHANGED
@@ -91,26 +91,26 @@ class VisualClozeModel:
91
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
  self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
93
 
94
- # # Initialize model
95
- # print("Initializing model...")
96
- # self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
97
 
98
- # # Initialize VAE
99
- # print("Initializing VAE...")
100
- # self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
101
- # self.ae.requires_grad_(False)
102
 
103
- # # Initialize text encoders
104
- # print("Initializing text encoders...")
105
- # self.t5 = load_t5(self.device, max_length=self.max_length)
106
- # self.clip = load_clip(self.device)
107
 
108
- # self.model.eval().to(self.device, dtype=self.dtype)
109
 
110
- # # Load model weights
111
- # ckpt = torch.load(model_path)
112
- # self.model.load_state_dict(ckpt, strict=False)
113
- # del ckpt
114
 
115
  # Initialize sampler
116
  transport = create_transport(
 
91
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
  self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
93
 
94
+ # Initialize model
95
+ print("Initializing model...")
96
+ self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
97
 
98
+ # Initialize VAE
99
+ print("Initializing VAE...")
100
+ self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
101
+ self.ae.requires_grad_(False)
102
 
103
+ # Initialize text encoders
104
+ print("Initializing text encoders...")
105
+ self.t5 = load_t5(self.device, max_length=self.max_length)
106
+ self.clip = load_clip(self.device)
107
 
108
+ self.model.eval().to(self.device, dtype=self.dtype)
109
 
110
+ # Load model weights
111
+ ckpt = torch.load(model_path)
112
+ self.model.load_state_dict(ckpt, strict=False)
113
+ del ckpt
114
 
115
  # Initialize sampler
116
  transport = create_transport(