Foundation Neural-Network Quantum State trained on the Ising model in a disordered transverse field on a chain with $L$ sites. The Hamiltonian (assuming periodic boundary conditions) is given by:

$$
\hat{H} = -J\sum_{i=1}^{L} \hat{S}_i^z \hat{S}_{i+1}^z - \sum_{i=1}^{L} h_i \hat{S}_i^x \,,
$$

where $h_i$ is the on-site transverse magnetic field at the $i$-th site. In the disordered case, $h_i$ varies randomly along the chain, drawn independently and identically from the uniform distribution on the interval $[0, h_0]$.
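
For concreteness, a single disorder realization can be drawn with JAX following the same convention as the sampling code below, where the fields are first drawn uniformly in $[0, 1)$ and then rescaled by $h_0$ (a minimal sketch):

import jax

h0, L = 1.0, 32                              # field intensity and chain length
key = jax.random.key(0)
h = jax.random.uniform(key, shape=(L,))      # raw draws, uniform in [0, 1)
fields = h0 * h                              # on-site fields h_i, i.i.d. uniform in [0, h0)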

Several values of the external field intensity $h_0$ are available (check the different revisions).
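
The available revisions can also be listed programmatically with huggingface_hub (a minimal sketch; branch names follow the L=..._h=... convention used in the code below):

from huggingface_hub import list_repo_refs

refs = list_repo_refs("nqs-models/ising_disorder_fnqs")
print([branch.name for branch in refs.branches])  # e.g. "L=32_h=1.0"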

The architecture has been trained on $R=2000$ different disorder realizations for a fixed value of $h_0$, using a total batch size of $M=10000$ samples.

The computation has been distributed over 4 A100-64GB GPUs for about two hours.

How to Get Started with the Model

Use the code below to get started with the model. In particular, we sample from the architecture for a fixed disorder realization using NetKet.

from functools import partial
import numpy as np

import jax
import jax.numpy as jnp
import netket as nk
import math
import flax
from flax.training import checkpoints

from netket.operator.spin import sigmax,sigmaz

flax.config.update('flax_use_orbax_checkpointing', False)

h0 = 1.0  # fix the value of the external field
L = 32
revision = f"L={L}_h={h0}"  # check the revisions for the available values of h0 and L

from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_disorder_fnqs", 
                                   trust_remote_code=True, 
                                   revision=revision, 
                                   )
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

lattice = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)

J = -1.0/math.e
key = jax.random.key(0)
h = jax.random.uniform(key, shape=(L,))  # disorder realization, uniform in [0, 1); rescaled by h0 below

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
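# Build the Hamiltonian: on-site transverse fields (raw draws rescaled by h0) plus nearest-neighbour zz couplings with periodic boundaries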
hamiltonian = sum([(-h[i]*h0)*sigmax(hilbert,i) for i in range(L)])
hamiltonian += sum([J*sigmaz(hilbert,i)*sigmaz(hilbert,(i+1)%L) for i in range(L)])

action = nk.sampler.rules.LocalRule()  # Metropolis rule that flips one spin at a time
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert, 
                                       rule=action, 
                                       n_chains=10000, 
                                       n_sweeps=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler, 
                        apply_fun=partial(wf.__call__, coups=h), 
                        sampler_seed=subkey,
                        n_samples=10000, 
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=10000)

from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/ising_disorder_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')  # stored spin configurations (some NetKet versions require a float dtype instead)
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)  # warm-start the Markov chains from the stored configurations

import time
# Sample the model
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)  # Monte Carlo estimate of the energy on the current samples
    vstate.sample()                 # draw a fresh batch of configurations for the next iteration

    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)

The time per sweep is 0.5s, evaluated on a single A100-40GB GPU.

Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as:

wf = FlaxAutoModel.from_pretrained("nqs-models/ising_disorder_fnqs",
                                   trust_remote_code=True,
                                   revision=revision,
                                   return_z=True)

z = wf(wf.params, samples, h)
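
For instance, assuming the representation is returned per input configuration, its shape can be inspected directly:

print(z.shape)  # leading dimension is expected to match the number of input configurations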

Training Hyperparameters

Number of layers: 6
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12
Patch size: 4

Total number of parameters: 326124
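
As a quick sanity check (reusing wf and nk from the snippet above), the reported total can be compared with the parameter count of the loaded checkpoint for this revision:

N_params = nk.jax.tree_size(wf.params)
assert N_params == 326124, N_params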

Contacts

Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])
