blanchon committed
Commit 18bbde3 · 1 Parent(s): f5ce168
Files changed (1)
  1. app-fast.py +11 -12
app-fast.py CHANGED
@@ -39,31 +39,30 @@ RESOLUTION_OPTIONS: list[str] = [
     "832 x 1248 (Portrait)",
 ]
 
+device = torch.device("cuda")
 
 quant_config = Int4WeightOnlyConfig(group_size=128)
-quantization_config = TransformersTorchAoConfig(
-    quant_type=quant_config, dtype=torch.bfloat16
-)
+quantization_config = TransformersTorchAoConfig(quant_type=quant_config)
 
 tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
 text_encoder = AutoModelForCausalLM.from_pretrained(
     LLAMA_MODEL_NAME,
-    torch_dtype="auto",
-    low_cpu_mem_usage=True,
-    device_map="auto",
     output_hidden_states=True,
     output_attentions=True,
+    low_cpu_mem_usage=True,
     quantization_config=quantization_config,
-)
+    torch_dtype=torch.bfloat16,  # Explicitly set dtype
+    device_map="auto",  # Still use auto, but ensure device consistency
+).to(device)  # Move model to the correct device after loading
 
 quantization_config = DiffusersTorchAoConfig("int8wo")
 transformer = HiDreamImageTransformer2DModel.from_pretrained(
     MODEL_PATH,
     subfolder="transformer",
-    device_map="auto",
     quantization_config=quantization_config,
-    torch_dtype="auto",
-)
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+).to(device)
 
 scheduler = MODEL_CONFIGS["scheduler"](
     num_train_timesteps=1000,
@@ -76,8 +75,8 @@ pipe = HiDreamImagePipeline.from_pretrained(
     scheduler=scheduler,
     tokenizer_4=tokenizer,
     text_encoder_4=text_encoder,
-    torch_dtype="auto",
-)
+    torch_dtype=torch.bfloat16,
+).to(device)
 
 pipe.transformer = transformer
 
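For context, the pipeline assembled in this diff is what app-fast.py later calls for generation. Below is a minimal sketch of how the loaded pipe would be invoked; the prompt text, step count, guidance value, and the resolution-parsing helper are illustrative assumptions rather than code from this Space, and the keyword arguments follow the standard diffusers pipeline call signature.

import torch

def parse_resolution(option: str) -> tuple[int, int]:
    # Hypothetical helper: "832 x 1248 (Portrait)" -> (height=1248, width=832)
    width, height = (int(part) for part in option.split("(")[0].split("x"))
    return height, width

height, width = parse_resolution("832 x 1248 (Portrait)")
image = pipe(  # `pipe` and `device` are defined in app-fast.py above
    prompt="A photo of an astronaut riding a horse",
    height=height,
    width=width,
    num_inference_steps=28,   # assumed step count
    guidance_scale=5.0,       # assumed guidance scale
    generator=torch.Generator(device).manual_seed(0),
).images[0]
image.save("output.png")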