blanchon committed on
Commit 3c3ffca · 1 Parent(s): 31bf3ec
Files changed (1)
  1. app-fast.py +6 -4
app-fast.py CHANGED
@@ -42,18 +42,20 @@ RESOLUTION_OPTIONS: list[str] = [
 
 # Using AOBaseConfig instance (torchao >= 0.10.0)
 quant_config = Int4WeightOnlyConfig(group_size=128)
-quantization_config = TransformersTorchAoConfig(quant_type=quant_config)
+quantization_config = TransformersTorchAoConfig(
+    quant_type=quant_config, dtype=torch.bfloat16
+)
 
 tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
 text_encoder = AutoModelForCausalLM.from_pretrained(
     LLAMA_MODEL_NAME,
-    torch_dtype=torch.float16,
+    torch_dtype=torch.bfloat16,
     low_cpu_mem_usage=True,
     device_map="auto",
     output_hidden_states=True,
     output_attentions=True,
     quantization_config=quantization_config,
-).to("cuda")
+).to("cuda", torch.bfloat16)
 
 quantization_config = DiffusersTorchAoConfig("int8wo")
 transformer = HiDreamImageTransformer2DModel.from_pretrained(
@@ -61,7 +63,7 @@ transformer = HiDreamImageTransformer2DModel.from_pretrained(
     subfolder="transformer",
     quantization_config=quantization_config,
     torch_dtype=torch.bfloat16,
-).to("cuda")
+).to("cuda", dtype=torch.float16)
 
 scheduler = MODEL_CONFIGS["scheduler"](
     num_train_timesteps=1000,