Spaces:
Build error
Build error
update GPU duration
Browse files
app.py
CHANGED
@@ -68,7 +68,7 @@ def load_models():
|
|
68 |
|
69 |
pipe_sd35, pipe_sdxl = load_models()
|
70 |
|
71 |
-
@spaces.GPU
|
72 |
def generate_image(
|
73 |
model_name,
|
74 |
seed,
|
@@ -97,7 +97,6 @@ def generate_image(
|
|
97 |
|
98 |
pipe.to(device)
|
99 |
pipe.enable_model_cpu_offload()
|
100 |
-
# os.makedirs('./weights', exist_ok=True)
|
101 |
os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
|
102 |
# TODO: load noise model
|
103 |
if method == 'core' or method == 'z-core':
|
@@ -105,7 +104,7 @@ def generate_image(
|
|
105 |
from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
|
106 |
|
107 |
if model_name == 'sd35':
|
108 |
-
refine_model = PromptSD35Net()
|
109 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
110 |
lora_true(refine_model, lora_idx=0)
|
111 |
checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
|
@@ -117,10 +116,9 @@ def generate_image(
|
|
117 |
checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
|
118 |
refine_model.load_state_dict(checkpoint)
|
119 |
|
120 |
-
print("Load Lora Success")
|
121 |
-
refine_model = refine_model.to(device)
|
122 |
refine_model = refine_model.to(torch.bfloat16)
|
123 |
-
|
|
|
124 |
# 根据模型类型设置形状
|
125 |
if model_name == 'sdxl':
|
126 |
shape = (1, 4, size // 8, size // 8)
|
|
|
68 |
|
69 |
pipe_sd35, pipe_sdxl = load_models()
|
70 |
|
71 |
+
@spaces.GPU(duration=360)
|
72 |
def generate_image(
|
73 |
model_name,
|
74 |
seed,
|
|
|
97 |
|
98 |
pipe.to(device)
|
99 |
pipe.enable_model_cpu_offload()
|
|
|
100 |
os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
|
101 |
# TODO: load noise model
|
102 |
if method == 'core' or method == 'z-core':
|
|
|
104 |
from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
|
105 |
|
106 |
if model_name == 'sd35':
|
107 |
+
refine_model = PromptSD35Net()
|
108 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
109 |
lora_true(refine_model, lora_idx=0)
|
110 |
checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
|
|
|
116 |
checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
|
117 |
refine_model.load_state_dict(checkpoint)
|
118 |
|
|
|
|
|
119 |
refine_model = refine_model.to(torch.bfloat16)
|
120 |
+
refine_model = refine_model.to(device)
|
121 |
+
print("Load Lora Success")
|
122 |
# 根据模型类型设置形状
|
123 |
if model_name == 'sdxl':
|
124 |
shape = (1, 4, size // 8, size // 8)
|