Model Description

This model is a DiT (Diffusion Transformer) trained from scratch on the WikiArt dataset https://huggingface.co/datasets/Artificio/WikiArt. It is designed to generate art images conditioned on art genre and art style.

Model Architecture

The model largely mirrors the classic DiT architecture described in the paper Scalable Diffusion Models with Transformers, with slight modifications:

  1. Replaced ImageNet class embeddings with WikiArt genre and style embeddings;
  2. Used post-norm instead of pre-norm;
  3. Omitted the final linear layer;
  4. Replaced the fixed 2D sin-cos positional embedding with a learned positional embedding;
  5. Models predict noise only and do not learn sigma;
  6. Set patch_size=2 for all model variants;
  7. Variants differ in size settings; see modeling_dit_wikiart.py in this repository for details.
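Modification 2 above (post-norm) can be sketched as follows. This is a hypothetical illustration, not the repository's actual block (see modeling_dit_wikiart.py for that): post-norm applies LayerNorm after the residual addition, whereas the original DiT applies it before attention and MLP.

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Illustrative post-norm transformer block (hypothetical sketch;
    the real implementation lives in modeling_dit_wikiart.py and also
    carries genre/style conditioning, which is omitted here)."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-norm: normalize AFTER the residual addition.
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.mlp(x))
        return x
```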

The model has three variants:

  • S: small, num_blocks=8, hidden_size=384, num_heads=6, total_params=20M;
  • B: base, num_blocks=12, hidden_size=640, num_heads=10, total_params=90M;
  • L: large, num_blocks=16, hidden_size=896, num_heads=14, total_params=234M.
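As a sanity check, the reported totals line up with a standard DiT block budget of roughly 18·hidden_size² parameters per block (≈4h² attention + ≈8h² MLP + ≈6h² adaLN modulation). The 18h² figure is an assumption based on the DiT design, not taken from the repository:

```python
# Rough parameter estimate per variant, assuming ~18 * hidden_size^2
# parameters per DiT block (attention ~4h^2, MLP ~8h^2, modulation ~6h^2).
# Embeddings and patch projections are ignored, so this slightly undercounts.
def approx_params(num_blocks: int, hidden_size: int) -> int:
    return 18 * hidden_size**2 * num_blocks

for name, blocks, h in [("S", 8, 384), ("B", 12, 640), ("L", 16, 896)]:
    print(f"{name}: ~{approx_params(blocks, h) / 1e6:.0f}M")
# S ≈ 21M, B ≈ 88M, L ≈ 231M — close to the reported 20M / 90M / 234M.
```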

Training Procedure

  • dataset: all model variants were trained on the ~103K-image WikiArt dataset, with horizontal-flip data augmentation.
  • optimizer: AdamW with default settings.
  • learning rate: linear warmup over the first 1% of steps to a peak of 3e-4, followed by cosine decay to zero over the remaining steps.
  • epochs and batch size:
    • S: 96 epochs with batch size of 176,
    • B: 120 epochs with batch size of 192,
    • L: 144 epochs with batch size of 192
  • device:
    • S: single RTX 4060ti 16G for 24 hrs,
    • B: single RTX 4060ti 16G for 90 hrs,
    • L: single RTX 4090D 24G for 48 hrs, followed by single RTX 4060ti 16G for 100 hrs.
  • loss curve: all variants saw a sharp drop in the first epoch, from above 1.0000 to around 0.2000, followed by a much slower decrease to about 0.1600 by the 20th epoch. Final losses: DiT-S 0.1590, DiT-B 0.1525, DiT-L 0.1510. Training was stable, with no loss spikes.
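The learning-rate schedule described above (1% linear warmup to 3e-4, then cosine decay to zero) can be sketched as a step-indexed function. The function name and exact warmup rounding are assumptions; only the schedule shape and peak come from this card:

```python
import math

def lr_at(step: int, total_steps: int, peak_lr: float = 3e-4,
          warmup_frac: float = 0.01) -> float:
    """Linear warmup over the first warmup_frac of steps to peak_lr,
    then cosine decay to zero over the remaining steps."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        # Linear ramp: reaches peak_lr at the end of warmup.
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay from peak_lr down to zero.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```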

Performance and Limitations

  • The models demonstrate a basic ability to understand genres and styles and to produce visually appealing paintings (at first glance).
  • Limitations include:
    • Failure to capture complex structures such as human faces and buildings;
    • Occasional mode collapse when asked to generate genres or styles rarely seen in the dataset, e.g. the minimalism style or the ukiyo-e genre;
    • Resolution limited to 256x256;
    • Trained only on the WikiArt dataset, therefore unable to generate out-of-scope images.

How to use it

To use the model, install the huggingface_hub library and download modeling_dit_wikiart.py from "Files and versions" for the model definition. You can then run the model with the following code:

import torch
from modeling_dit_wikiart import DiTWikiartModel

model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
model.eval()

num_samples = 8
noisy_latents = torch.randn(num_samples, 4, 32, 32)
with torch.no_grad():
    predicted_noise = model(noisy_latents)
print(predicted_noise.shape)

The model is paired with the stabilityai/sd-vae-ft-ema VAE, which decodes the 32x32 latents into 256x256 images.
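The snippet above runs only a single noise prediction; generating an image requires iterating a reverse-diffusion loop over the latents and then decoding with the paired VAE. Below is a minimal DDPM ancestral-sampling sketch for an epsilon-prediction model, assuming a standard linear beta schedule; the actual training schedule and the model's conditioning inputs may differ, so treat this as an illustration, not the repository's pipeline:

```python
import torch

@torch.no_grad()
def ddpm_sample(model, shape=(8, 4, 32, 32), num_steps=1000):
    """Minimal DDPM ancestral sampling for a model that predicts noise only.
    Assumes a standard linear beta schedule (an assumption; the repo's
    training schedule is not stated in this card)."""
    betas = torch.linspace(1e-4, 0.02, num_steps)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(shape)  # start from pure Gaussian noise
    for t in reversed(range(num_steps)):
        eps = model(x)  # epsilon prediction (sigma is not learned)
        coef = betas[t] / torch.sqrt(1.0 - alpha_bars[t])
        mean = (x - coef * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = mean + torch.sqrt(betas[t]) * torch.randn_like(x)
        else:
            x = mean
    return x  # final latents; decode with stabilityai/sd-vae-ft-ema
```

The returned latents would then be passed to the VAE decoder (e.g. diffusers' AutoencoderKL loaded from stabilityai/sd-vae-ft-ema) to obtain 256x256 images; the exact latent scaling factor used during training is not stated in this card.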
