Klayand commited on
Commit
7168bc5
·
1 Parent(s): 82d04a4

update GPU duration

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -68,7 +68,7 @@ def load_models():
68
 
69
  pipe_sd35, pipe_sdxl = load_models()
70
 
71
- @spaces.GPU
72
  def generate_image(
73
  model_name,
74
  seed,
@@ -97,7 +97,6 @@ def generate_image(
97
 
98
  pipe.to(device)
99
  pipe.enable_model_cpu_offload()
100
- # os.makedirs('./weights', exist_ok=True)
101
  os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
102
  # TODO: load noise model
103
  if method == 'core' or method == 'z-core':
@@ -105,7 +104,7 @@ def generate_image(
105
  from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
106
 
107
  if model_name == 'sd35':
108
- refine_model = PromptSD35Net().to(device)
109
  replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
110
  lora_true(refine_model, lora_idx=0)
111
  checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
@@ -117,10 +116,9 @@ def generate_image(
117
  checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
118
  refine_model.load_state_dict(checkpoint)
119
 
120
- print("Load Lora Success")
121
- refine_model = refine_model.to(device)
122
  refine_model = refine_model.to(torch.bfloat16)
123
-
 
124
  # 根据模型类型设置形状
125
  if model_name == 'sdxl':
126
  shape = (1, 4, size // 8, size // 8)
 
68
 
69
  pipe_sd35, pipe_sdxl = load_models()
70
 
71
+ @spaces.GPU(duration=360)
72
  def generate_image(
73
  model_name,
74
  seed,
 
97
 
98
  pipe.to(device)
99
  pipe.enable_model_cpu_offload()
 
100
  os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
101
  # TODO: load noise model
102
  if method == 'core' or method == 'z-core':
 
104
  from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
105
 
106
  if model_name == 'sd35':
107
+ refine_model = PromptSD35Net()
108
  replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
109
  lora_true(refine_model, lora_idx=0)
110
  checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
 
116
  checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
117
  refine_model.load_state_dict(checkpoint)
118
 
 
 
119
  refine_model = refine_model.to(torch.bfloat16)
120
+ refine_model = refine_model.to(device)
121
+ print("Load Lora Success")
122
  # 根据模型类型设置形状
123
  if model_name == 'sdxl':
124
  shape = (1, 4, size // 8, size // 8)