listen2you003 committed
Commit d2c3bde · 1 Parent(s): 6b3d669

force cuda set

Files changed (1): app.py +33 -28
app.py CHANGED
@@ -122,6 +122,11 @@ class ImageGenerator:
         self.ae = self.ae.to(device=self.device, dtype=torch.float32)
         self.dit = self.dit.to(device=self.device, dtype=dtype)
         self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
+
+    def to_cuda(self):
+        self.ae.to(device='cuda', dtype=torch.float32)
+        self.dit.to(device='cuda', dtype=torch.bfloat16)
+        self.llm_encoder.to(device='cuda', dtype=torch.bfloat16)
 
     def prepare(self, prompt, img, ref_image, ref_image_raw):
         bs, _, h, w = img.shape
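The added to_cuda method is the heart of the "force cuda set" change: on ZeroGPU Spaces, CUDA only exists inside a @spaces.GPU-decorated call, so models are built on CPU at import time and pushed to the GPU per request. A minimal sketch of that pattern, assuming the `spaces` package of a ZeroGPU Space and a hypothetical TinyModel standing in for the AE/DiT/encoder trio:

import spaces
import torch


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for the models in app.py.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = TinyModel()  # built on CPU at import time; no CUDA available yet


@spaces.GPU(duration=60)  # on ZeroGPU, CUDA exists only inside this call
def run(x: torch.Tensor) -> torch.Tensor:
    model.to(device='cuda', dtype=torch.bfloat16)  # the "force cuda set" step
    return model(x.to(device='cuda', dtype=torch.bfloat16)).cpu()

Note that unlike tensors, torch.nn.Module.to() mutates the module in place, so the bare self.ae.to(...) calls in to_cuda are effective without reassignment.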
@@ -377,10 +382,32 @@ class ImageGenerator:
         return images_list
 
 
+# Model repo ID (e.g. "bert-base-uncased")
+model_repo = "stepfun-ai/Step1X-Edit"
+# Local save path
+model_path = "./model_weights"
+os.makedirs(model_path, exist_ok=True)
+
+
+# Download the model (all files)
+snapshot_download(
+    repo_id=model_repo,
+    local_dir=model_path,
+    local_dir_use_symlinks=False  # avoid symlinks
+)
+
+
+image_edit = ImageGenerator(
+    ae_path=os.path.join(model_path, 'vae.safetensors'),
+    dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
+    qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
+    max_length=640,
+)
+
 
 
 @spaces.GPU(duration=240)
-def inference(prompt, ref_images, seed, size_level, infer_func=None):
+def inference(prompt, ref_images, seed, size_level):
     start_time = time.time()
 
     if seed == -1:
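One caveat on the download block above: in recent versions of huggingface_hub, local_dir_use_symlinks is deprecated and ignored, since passing local_dir alone already materializes real files rather than symlinks into the cache. A hedged sketch of the equivalent call under that assumption:

from huggingface_hub import snapshot_download

# Downloads the full repo into ./model_weights as real files; the return
# value is the local folder path.
local_path = snapshot_download(
    repo_id="stepfun-ai/Step1X-Edit",
    local_dir="./model_weights",
)
print(local_path)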
@@ -389,7 +416,11 @@ def inference(prompt, ref_images, seed, size_level, infer_func=None):
     else:
         random_seed = seed
 
-    image = infer_func(
+    image_edit.to_cuda()
+
+    inference_func = image_edit.generate_image
+
+    image = inference_func(
         prompt,
         negative_prompt="",
         ref_images=ref_images.convert('RGB'),
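The seed handling bracketing this hunk follows a common convention: -1 requests a fresh random seed, anything else is used verbatim, and inference returns the seed it used so results can be reproduced. A minimal sketch, assuming random.randint as the generator (the actual generator in app.py is not visible in this diff):

import random

def resolve_seed(seed: int) -> int:
    # -1 means "pick a fresh random seed"; anything else is kept as-is.
    return random.randint(0, 2**32 - 1) if seed == -1 else seed

print(resolve_seed(-1))  # a new seed each call
print(resolve_seed(42))  # always 42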
@@ -404,32 +435,6 @@ def inference(prompt, ref_images, seed, size_level, infer_func=None):
     print(f"Time taken: {time.time() - start_time:.2f} seconds")
     return image, random_seed
 
-
-# Model repo ID (e.g. "bert-base-uncased")
-model_repo = "stepfun-ai/Step1X-Edit"
-# Local save path
-model_path = "./model_weights"
-os.makedirs(model_path, exist_ok=True)
-
-
-# Download the model (all files)
-snapshot_download(
-    repo_id=model_repo,
-    local_dir=model_path,
-    local_dir_use_symlinks=False  # avoid symlinks
-)
-
-
-image_edit = ImageGenerator(
-    ae_path=os.path.join(model_path, 'vae.safetensors'),
-    dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
-    qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
-    max_length=640,
-)
-
-inference_func = image_edit.generate_image
-
-# inference_func = prepare_infer_func()
 with gr.Blocks() as demo:
     gr.Markdown(
         """
 