lzyhha commited on
Commit
6dd0ec6
·
1 Parent(s): 7bf1b5d
Files changed (1) hide show
  1. visualcloze.py +16 -16
visualcloze.py CHANGED
@@ -90,26 +90,26 @@ class VisualClozeModel:
90
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
91
  self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
92
 
93
- # Initialize model
94
- print("Initializing model...")
95
- self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
96
 
97
- # Initialize VAE
98
- print("Initializing VAE...")
99
- self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
100
- self.ae.requires_grad_(False)
101
 
102
- # Initialize text encoders
103
- print("Initializing text encoders...")
104
- self.t5 = load_t5(self.device, max_length=self.max_length)
105
- self.clip = load_clip(self.device)
106
 
107
- self.model.eval().to(self.device, dtype=self.dtype)
108
 
109
- # Load model weights
110
- ckpt = torch.load(model_path)
111
- self.model.load_state_dict(ckpt, strict=False)
112
- del ckpt
113
 
114
  # Initialize sampler
115
  transport = create_transport(
 
90
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
91
  self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
92
 
93
+ # # Initialize model
94
+ # print("Initializing model...")
95
+ # self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
96
 
97
+ # # Initialize VAE
98
+ # print("Initializing VAE...")
99
+ # self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
100
+ # self.ae.requires_grad_(False)
101
 
102
+ # # Initialize text encoders
103
+ # print("Initializing text encoders...")
104
+ # self.t5 = load_t5(self.device, max_length=self.max_length)
105
+ # self.clip = load_clip(self.device)
106
 
107
+ # self.model.eval().to(self.device, dtype=self.dtype)
108
 
109
+ # # Load model weights
110
+ # ckpt = torch.load(model_path)
111
+ # self.model.load_state_dict(ckpt, strict=False)
112
+ # del ckpt
113
 
114
  # Initialize sampler
115
  transport = create_transport(