Spaces: Running on Zero
lzyhha committed
Commit · 6dd0ec6
Parent(s): 7bf1b5d
flashattn
visualcloze.py +16 -16
visualcloze.py
CHANGED
@@ -90,26 +90,26 @@ class VisualClozeModel:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
 
-        # Initialize model
-        print("Initializing model...")
-        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
+        # # Initialize model
+        # print("Initializing model...")
+        # self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
 
-        # Initialize VAE
-        print("Initializing VAE...")
-        self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
-        self.ae.requires_grad_(False)
+        # # Initialize VAE
+        # print("Initializing VAE...")
+        # self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
+        # self.ae.requires_grad_(False)
 
-        # Initialize text encoders
-        print("Initializing text encoders...")
-        self.t5 = load_t5(self.device, max_length=self.max_length)
-        self.clip = load_clip(self.device)
+        # # Initialize text encoders
+        # print("Initializing text encoders...")
+        # self.t5 = load_t5(self.device, max_length=self.max_length)
+        # self.clip = load_clip(self.device)
 
-        self.model.eval().to(self.device, dtype=self.dtype)
+        # self.model.eval().to(self.device, dtype=self.dtype)
 
-        # Load model weights
-        ckpt = torch.load(model_path)
-        self.model.load_state_dict(ckpt, strict=False)
-        del ckpt
+        # # Load model weights
+        # ckpt = torch.load(model_path)
+        # self.model.load_state_dict(ckpt, strict=False)
+        # del ckpt
 
         # Initialize sampler
         transport = create_transport(
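For context, this hunk comments out all of the heavy loading in the constructor (flow model, FLUX VAE, T5, CLIP, and the checkpoint weights); the diff does not show where, or whether, that loading happens afterwards. Purely as an illustrative sketch and not the committed code, the same calls could be deferred into a lazy helper along these lines. The class name VisualClozeModelLazy and the _ensure_models method are hypothetical, and load_flow_model, load_t5, and load_clip are assumed to be the repo's own helpers already imported in visualcloze.py:

# Illustrative sketch only, not the committed code: the loading steps removed
# from __init__ above, collected into a hypothetical lazy helper.
import torch
from diffusers import AutoencoderKL


class VisualClozeModelLazy:
    def __init__(self, model_name, model_path, precision, lora_rank, max_length):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        self.model_name, self.model_path = model_name, model_path
        self.lora_rank, self.max_length = lora_rank, max_length
        self.model = None  # nothing heavy is loaded at construction time

    def _ensure_models(self):
        # Run the initialization removed from __init__ on first use instead.
        if self.model is not None:
            return
        self.model = load_flow_model(self.model_name, device=self.device, lora_rank=self.lora_rank)
        self.ae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype
        ).to(self.device)
        self.ae.requires_grad_(False)
        self.t5 = load_t5(self.device, max_length=self.max_length)
        self.clip = load_clip(self.device)
        self.model.eval().to(self.device, dtype=self.dtype)
        ckpt = torch.load(self.model_path)
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt

With a pattern like this, the Space pays the load cost only on the first call rather than at import time; whether this commit actually adopts such a pattern is not visible in the hunk above.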