Pretrained Vision Transformer Neural Quantum State on the $J_1$-$J_2$ Heisenberg model on a $10\times10$ square lattice. The frustration ratio is set to $J_2/J_1 = 0.5$.
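For reference, the Hamiltonian is the standard one,

$$H = J_1 \sum_{\langle i,j\rangle} \mathbf{S}_i \cdot \mathbf{S}_j + J_2 \sum_{\langle\langle i,j\rangle\rangle} \mathbf{S}_i \cdot \mathbf{S}_j,$$

where $\langle i,j\rangle$ and $\langle\langle i,j\rangle\rangle$ run over nearest- and next-nearest-neighbor pairs of the square lattice, respectively.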

| Revision | Variational energy per site | Time per sweep | Description |
| --- | --- | --- | --- |
| main | -0.497505103 | 41s | Plain ViT with translation invariance among patches |
| symm_t | -0.49760546 | 166s | ViT with translational symmetry |
| symm_trxy_ising | -0.497676335 | 3317s | ViT with translational, point group and Sz inversion symmetries |

The time per sweep is evaluated on a single A100-40GB GPU.

The architecture has been trained by distributing the computation over 40 A100-64GB GPUs for about four days.

Citation

https://www.nature.com/articles/s42005-024-01732-4

How to Get Started with the Model

Use the code below to get started with the model. In particular, we sample the model using NetKet.

import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints

# Use the legacy flax checkpointing backend to read the provided files
flax.config.update('flax_use_orbax_checkpointing', False)

# Load the model from HuggingFace
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)

N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# 10x10 square lattice with periodic boundary conditions and
# next-nearest-neighbor connectivity
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)

# Spin-1/2 Hilbert space restricted to the zero-magnetization sector
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)

# J1-J2 Heisenberg Hamiltonian with J2/J1 = 0.5
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, 0.5],
                                     sign_rule=[False, False]).to_jax_operator() # No Marshall sign rule

# Metropolis sampler with exchange moves, which preserve the total magnetization
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16384,
                                        sweep_size=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)

# Monte Carlo variational state wrapping the pretrained wavefunction
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=wf.__call__,
                        sampler_seed=subkey,
                        n_samples=16384,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16384)

# Overwrite samples with already thermalized ones
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(Οƒ=samples)

# Sample the model: the factor 1/4 converts NetKet's Pauli-matrix convention
# to spin-1/2 operators, giving the variational energy per site
for _ in range(10):
    E = vstate.expect(hamiltonian)
    print("Mean: ", E.mean.real / lattice.n_nodes / 4)
    vstate.sample()

The expected output is:

Number of parameters = 434760
Mean: -0.4975034481394982
Mean: -0.4975697817150899
Mean: -0.49753878662981793
Mean: -0.49749150331671876
Mean: -0.4975093308123018
Mean: -0.49755810175173776
Mean: -0.49753726455462444
Mean: -0.49748956161946795
Mean: -0.497479875901942
Mean: -0.49752966071413424
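NetKet's expect returns a statistics object, so the Monte Carlo error bar can be printed alongside the mean with the same conventions as above:

# Statistical error of the mean, per site and in spin-1/2 units
E = vstate.expect(hamiltonian)
print("Error: ", E.error_of_mean / lattice.n_nodes / 4)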

The fully translationally invariant wavefunction can also be downloaded using:

wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_t")

Use revision="symm_trxy_ising" for a wavefunction that also includes the point group and the Sz inversion symmetries, as shown below.
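For example:

wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_trxy_ising")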

Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as:

wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, return_z=True)

z = wf(wf.params, samples)

Starting from the vector $z$, a fully connected network can be trained to fine-tune the model on a different value of the ratio $J_2/J_1$. See https://doi.org/10.1103/PhysRevResearch.6.023057 for more information.
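A minimal sketch of such a head, assuming z is a real-valued array of shape (batch, features); the layer width and the summed scalar output are illustrative choices, not the setup of the reference:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical fine-tuning head (illustrative): a small fully connected
# network mapping the frozen hidden representation z to one scalar
# log-amplitude correction per configuration.
class FineTuningHead(nn.Module):
    width: int = 72  # illustrative width, not taken from the paper

    @nn.compact
    def __call__(self, z):
        x = nn.gelu(nn.Dense(self.width)(z))
        # Sum over features to obtain a scalar output per configuration
        return jnp.sum(nn.Dense(self.width)(x), axis=-1)

# Usage sketch: keep wf.params frozen and train only the head's parameters
head = FineTuningHead()
head_params = head.init(jax.random.PRNGKey(1), z)
log_psi_correction = head.apply(head_params, z)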

Note: the hidden representation is well defined only for the non-symmetrized model.

Training Hyperparameters

Number of layers: 8
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12

Total number of parameters: 434760
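To see how these hyperparameters translate into tensor shapes, the parameter tree of the loaded model can be inspected directly (a quick sketch reusing the wf object loaded above):

import jax
# Print the shape of every parameter tensor of the pretrained model
print(jax.tree_util.tree_map(lambda p: p.shape, wf.params))
print('Total parameters:', nk.jax.tree_size(wf.params))  # 434760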

Model Card Contact

Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])
