blanchon committed on
Commit 2cece51 · 1 Parent(s): fe2f104
Files changed (1)
  1. app-fast.py +6 -3
app-fast.py CHANGED
@@ -50,7 +50,8 @@ text_encoder = AutoModelForCausalLM.from_pretrained(
     output_attentions=True,
     low_cpu_mem_usage=True,
     quantization_config=quant_config,
-    torch_dtype=torch.float16,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
 )
 
 quant_config = DiffusersBitsAndBytesConfig(
@@ -60,7 +61,8 @@ transformer = HiDreamImageTransformer2DModel.from_pretrained(
     MODEL_PATH,
     subfolder="transformer",
     quantization_config=quant_config,
-    torch_dtype=torch.float16,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
 )
 
 scheduler = MODEL_CONFIGS["scheduler"](
@@ -74,7 +76,8 @@ pipe = HiDreamImagePipeline.from_pretrained(
     scheduler=scheduler,
     tokenizer_4=tokenizer,
     text_encoder_4=text_encoder,
-    torch_dtype=torch.float16,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
 ).to(device)
 
 pipe.transformer = transformer
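
The commit swaps torch.float16 for torch.bfloat16 and adds device_map="auto", letting accelerate place the quantized weights across the available devices. Below is a minimal sketch of that loading pattern for the Llama text encoder alone; the model id and the 4-bit quantization settings are assumptions for illustration and are not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hypothetical model id; the actual text-encoder repo used by app-fast.py is not shown in this diff.
llama_id = "meta-llama/Llama-3.1-8B-Instruct"

# Assumed 4-bit settings; the diff only shows that quant_config is passed, not how it is built.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # keep compute in bfloat16 to match torch_dtype below
)

tokenizer = AutoTokenizer.from_pretrained(llama_id)
text_encoder = AutoModelForCausalLM.from_pretrained(
    llama_id,
    output_attentions=True,
    low_cpu_mem_usage=True,
    quantization_config=quant_config,
    device_map="auto",           # let accelerate map layers onto the available GPU(s)/CPU
    torch_dtype=torch.bfloat16,  # non-quantized modules stay in bfloat16 instead of float16
)

bfloat16 keeps float32's exponent range, so it is generally a safer companion to 4-bit quantized weights than float16, which can overflow in activation-heavy layers.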