diff --git a/.gitattributes b/.gitattributes
index 7fe70d7f07c494ee23600b490a1167607f3a08ca..7ebbda59fc572adcf72d0ba8bf6de67a0e57cc30 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1 +1,3 @@
*.json filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/README.md b/.github/README.md
index cd764debf7f948d469d69b7dafd3abb5e9886f4c..ba8d603ea1af032ac91fa32ae3456797fab4f5c5 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -18,12 +18,12 @@ MLIP Arena leverages modern pythonic workflow orchestrator [Prefect](https://www
## Announcement
-- **[April 8, 2025]** [๐ **MLIP Arena accepted as an ICLR AI4Mat Spotlight!** ๐](https://openreview.net/forum?id=ysKfIavYQE#discussion) Huge thanks to all co-authors for their contributions!
+- **[April 8, 2025]** [๐ **MLIP Arena is accepted as an ICLR AI4Mat Spotlight!** ๐](https://openreview.net/forum?id=ysKfIavYQE#discussion) Huge thanks to all co-authors for their contributions!
## Installation
-### From PyPI (without model running capability)
+### From PyPI (prefect workflow only, without pretrained models)
```bash
pip install mlip-arena
diff --git a/examples/eos_bulk/CHGNet.parquet b/examples/eos_bulk/CHGNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..43dad7c081ddbc284dc6e6b2c2852465debfa320
--- /dev/null
+++ b/examples/eos_bulk/CHGNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68871d694e93a3c3e7e272b9cbd87d3757e3bc689f30f3189db232d76e629c07
+size 429910
diff --git a/examples/eos_bulk/CHGNet_processed.parquet b/examples/eos_bulk/CHGNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..68784499609a2c80692fa0f654172600134c7782
--- /dev/null
+++ b/examples/eos_bulk/CHGNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6fbea63f9035e376bb5ac7db38175102ab3f96a0f8758cc3e9931424f829ac0
+size 357919
diff --git a/examples/eos_bulk/M3GNet.parquet b/examples/eos_bulk/M3GNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..0ce0cb0cd28f9e7203404427fe68d6b04edb859c
--- /dev/null
+++ b/examples/eos_bulk/M3GNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53dde465b5e10edd677f131f8a531e3dfc36303dd7ec7b9df0060c19847494d9
+size 427419
diff --git a/examples/eos_bulk/M3GNet_processed.parquet b/examples/eos_bulk/M3GNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..1f4c816367ce534b6cb07da03b6bec28035e125c
--- /dev/null
+++ b/examples/eos_bulk/M3GNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18ea51bf19c5e011e170a3229bbd63ce675a725c364e0cffb71de95459f8629e
+size 379859
diff --git a/examples/eos_bulk/MACE-MP(M).parquet b/examples/eos_bulk/MACE-MP(M).parquet
new file mode 100644
index 0000000000000000000000000000000000000000..7287426ae925a64ec1400a9f641848be301ada49
--- /dev/null
+++ b/examples/eos_bulk/MACE-MP(M).parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff9769eeb83042129767aeff975eb04dee8efae12e96fbd46cd3039eeda26705
+size 427896
diff --git a/examples/eos_bulk/MACE-MP(M)_processed.parquet b/examples/eos_bulk/MACE-MP(M)_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..bef6529f23c192c7741131de38aee08a72817434
--- /dev/null
+++ b/examples/eos_bulk/MACE-MP(M)_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f492125c87400fee013d32904ee97df49a07d321bd393f9335c7bf4258fe159
+size 371004
diff --git a/examples/eos_bulk/MACE-MPA.parquet b/examples/eos_bulk/MACE-MPA.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..ae18b9d947a17edb51ce0fecd2e6cc066d74d0a5
--- /dev/null
+++ b/examples/eos_bulk/MACE-MPA.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53fcd188baddd4d5e797c5aa3de1b4368db711ebd29b7877cfe224856ba9d171
+size 428888
diff --git a/examples/eos_bulk/MACE-MPA_processed.parquet b/examples/eos_bulk/MACE-MPA_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..4be6db0ab599fd1e4a49d4bc77c0418240166cca
--- /dev/null
+++ b/examples/eos_bulk/MACE-MPA_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34bd9ae08656e374263820774f49e91a964aa8b2aeade4150cf62cfc08bb37f6
+size 365289
diff --git a/examples/eos_bulk/MatterSim.parquet b/examples/eos_bulk/MatterSim.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..1ac19b614e60ddafbd3a5e0ed2599ac3e7daf401
--- /dev/null
+++ b/examples/eos_bulk/MatterSim.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6717650b97782de6f90e4473075410fe4540279eb39338d2234d3c9399079b3
+size 389586
diff --git a/examples/eos_bulk/MatterSim_processed.parquet b/examples/eos_bulk/MatterSim_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..85ee3a7982f54b27ec9696b3ac1018450d120918
--- /dev/null
+++ b/examples/eos_bulk/MatterSim_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cca22b5db67dae59602adfb8c42a80da90b39c4e95af89ef918813351b422119
+size 320962
diff --git a/examples/eos_bulk/ORBv2.parquet b/examples/eos_bulk/ORBv2.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..2cd571ac07630e9883eb0c9e0d854ed646a18d9b
--- /dev/null
+++ b/examples/eos_bulk/ORBv2.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae13c9af1ae7fafe2a42ed4c47e2ba0f036abfa64a87ca517b92d89c62fcbfd9
+size 427105
diff --git a/examples/eos_bulk/ORBv2_processed.parquet b/examples/eos_bulk/ORBv2_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..c8775daf947ddb809288bcbc648b18530e0c9936
--- /dev/null
+++ b/examples/eos_bulk/ORBv2_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c90e61645c83f2452bcd90ef9132c46266896217fe1dea9cca8e0d124d73821a
+size 227929
diff --git a/examples/eos_bulk/SevenNet.parquet b/examples/eos_bulk/SevenNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..5a3bb7cc25bd7752f35f3fbad0154d411b270b97
--- /dev/null
+++ b/examples/eos_bulk/SevenNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64be88ec2632cdabf79daa01acb2cf2ef19fef0557813df5502c4f71ec566f4e
+size 428341
diff --git a/examples/eos_bulk/SevenNet_processed.parquet b/examples/eos_bulk/SevenNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..c32745504ba858eba7902004ec8eec54e4a159c0
--- /dev/null
+++ b/examples/eos_bulk/SevenNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cc03d9af93c001f3fa441b50f058506e13c0aa7cb3d329275e68f5ed80dc3e6
+size 364846
diff --git a/examples/eos_bulk/preprocessing.py b/examples/eos_bulk/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..08789185c32c110dd37ce3576bf8779073615069
--- /dev/null
+++ b/examples/eos_bulk/preprocessing.py
@@ -0,0 +1,12 @@
+import json
+
+from ase.db import connect
+from pymatgen.core import Structure
+
+with open("wbm_structures.json") as f:
+ structs = json.load(f)
+
+with connect("wbm_structures.db") as db:
+ for id, s in structs.items():
+ atoms = Structure.from_dict(s).to_ase_atoms(msonable=False)
+ db.write(atoms, wbm_id=id)
diff --git a/examples/eos_bulk/run.py b/examples/eos_bulk/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0b9dabbeaff7feef12381b851a7f1fe1423286
--- /dev/null
+++ b/examples/eos_bulk/run.py
@@ -0,0 +1,135 @@
+import functools
+from pathlib import Path
+
+import pandas as pd
+from ase.db import connect
+from dask.distributed import Client
+from dask_jobqueue import SLURMCluster
+from prefect import Task, flow, task
+from prefect.client.schemas.objects import TaskRun
+from prefect.states import State
+from prefect_dask import DaskTaskRunner
+
+from mlip_arena.models import REGISTRY, MLIPEnum
+from mlip_arena.tasks.eos import run as EOS
+from mlip_arena.tasks.optimize import run as OPT
+from mlip_arena.tasks.utils import get_calculator
+
+
+@task
+def load_wbm_structures():
+ """
+ Load the WBM structures from a ASE DB file.
+ """
+ with connect("../wbm_structures.db") as db:
+ for row in db.select():
+ yield row.toatoms(add_additional_information=True)
+
+
+def save_result(
+ tsk: Task,
+ run: TaskRun,
+ state: State,
+ model_name: str,
+ id: str,
+):
+ result = run.state.result()
+
+ assert isinstance(result, dict)
+
+ result["method"] = model_name
+ result["id"] = id
+ result.pop("atoms", None)
+
+ fpath = Path(f"{model_name}")
+ fpath.mkdir(exist_ok=True)
+
+ fpath = fpath / f"{result['id']}.pkl"
+
+ df = pd.DataFrame([result])
+ df.to_pickle(fpath)
+
+
+@task
+def eos_bulk(atoms, model):
+
+ calculator = get_calculator(
+ model
+ ) # avoid sending entire model over prefect and select freer GPU
+
+ result = OPT.with_options(
+ refresh_cache=True,
+ )(
+ atoms,
+ calculator,
+ optimizer="FIRE",
+ criterion=dict(
+ fmax=0.1,
+ ),
+ )
+
+ return EOS.with_options(
+ refresh_cache=True,
+ on_completion=[functools.partial(
+ save_result,
+ model_name=model.name,
+ id=atoms.info["key_value_pairs"]["wbm_id"],
+ )],
+ )(
+ atoms=result["atoms"],
+ calculator=calculator,
+ optimizer="FIRE",
+ npoints=21,
+ max_abs_strain=0.2,
+ concurrent=False
+ )
+
+
+@flow
+def run_all():
+ futures = []
+ for atoms in load_wbm_structures():
+ for model in MLIPEnum:
+ if "eos_bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
+ continue
+ result = eos_bulk.submit(atoms, model)
+ futures.append(result)
+ return [f.result(raise_on_failure=False) for f in futures]
+
+
+nodes_per_alloc = 1
+gpus_per_alloc = 1
+ntasks = 1
+
+cluster_kwargs = dict(
+ cores=4,
+ memory="64 GB",
+ shebang="#!/bin/bash",
+ account="m3828",
+ walltime="00:50:00",
+ job_mem="0",
+ job_script_prologue=[
+ "source ~/.bashrc",
+ "module load python",
+ "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
+ ],
+ job_directives_skip=["-n", "--cpus-per-task", "-J"],
+ job_extra_directives=[
+ "-J eos_bulk",
+ "-q regular",
+ f"-N {nodes_per_alloc}",
+ "-C gpu",
+ f"-G {gpus_per_alloc}",
+ "--exclusive",
+ ],
+)
+
+cluster = SLURMCluster(**cluster_kwargs)
+print(cluster.job_script())
+cluster.adapt(minimum_jobs=20, maximum_jobs=40)
+client = Client(cluster)
+
+run_all.with_options(
+ task_runner=DaskTaskRunner(address=client.scheduler.address),
+ log_prints=True,
+)()
diff --git a/examples/eos_bulk/summary.csv b/examples/eos_bulk/summary.csv
new file mode 100644
index 0000000000000000000000000000000000000000..2a3759045a4d1812bf07abe3dca5280112023da8
--- /dev/null
+++ b/examples/eos_bulk/summary.csv
@@ -0,0 +1,8 @@
+model,rank,rank-aggregation,energy-diff-flip-times,tortuosity,spearman-compression-energy,spearman-compression-derivative,spearman-tension-energy,missing
+MACE-MPA,1,6,1.0370741482965933,1.005455197941088,-0.9993684338373716,0.9963320580555048,0.993186372745491,2
+MACE-MP(M),2,16,1.042211055276382,1.008986842539345,-0.999329983249581,0.9941160347190496,0.9915857612939804,5
+MatterSim,3,18,1.045135406218656,1.0060900449752808,-0.99734962463147,0.9927904926901917,0.9880977115916667,3
+CHGNet,4,22,1.1053159478435306,1.014753469076796,-0.9964985866690981,0.9929971733381963,0.9866417434120545,3
+SevenNet,5,27,1.1093279839518555,1.0186969977862483,-0.9981277164827815,0.9889121911188109,0.9859580417030127,3
+M3GNet,6,33,1.1748743718592964,1.0175007963267957,-0.9963209989340641,0.9897426526572255,0.9801690217498693,5
+ORBv2,7,42,1.3162134944612287,1.0374718753890275,-0.9918459519667977,0.9701425127407,0.9637462235649547,7
diff --git a/examples/eos_bulk/summary.tex b/examples/eos_bulk/summary.tex
new file mode 100644
index 0000000000000000000000000000000000000000..a27cbc037c58d00d65372882dd35956573888da5
--- /dev/null
+++ b/examples/eos_bulk/summary.tex
@@ -0,0 +1,13 @@
+\begin{tabular}{lrrrrrrrl}
+\toprule
+model & rank & rank-aggregation & energy-diff-flip-times & tortuosity & spearman-compression-energy & spearman-compression-derivative & spearman-tension-energy & missing \\
+\midrule
+MACE-MPA & 1 & 6 & 1.037074 & 1.005455 & -0.999368 & 0.996332 & 0.993186 & 2 \\
+MACE-MP(M) & 2 & 16 & 1.042211 & 1.008987 & -0.999330 & 0.994116 & 0.991586 & 5 \\
+MatterSim & 3 & 18 & 1.045135 & 1.006090 & -0.997350 & 0.992790 & 0.988098 & 3 \\
+CHGNet & 4 & 22 & 1.105316 & 1.014753 & -0.996499 & 0.992997 & 0.986642 & 3 \\
+SevenNet & 5 & 27 & 1.109328 & 1.018697 & -0.998128 & 0.988912 & 0.985958 & 3 \\
+M3GNet & 6 & 33 & 1.174874 & 1.017501 & -0.996321 & 0.989743 & 0.980169 & 5 \\
+ORBv2 & 7 & 42 & 1.316213 & 1.037472 & -0.991846 & 0.970143 & 0.963746 & 7 \\
+\bottomrule
+\end{tabular}
diff --git a/examples/wbm_ev/ALIGNN.parquet b/examples/wbm_ev/ALIGNN.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..34cbb20eed3a2d47dd00221e4490019419962768
--- /dev/null
+++ b/examples/wbm_ev/ALIGNN.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b84592b56c667f49e510f382c07f2dd4105df71468c2198c3958b2d0066202b
+size 425244
diff --git a/examples/wbm_ev/ALIGNN_processed.parquet b/examples/wbm_ev/ALIGNN_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..fba6b6e80cfe1164bc81dbf660b09628d1fe6834
--- /dev/null
+++ b/examples/wbm_ev/ALIGNN_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1d3f3d2992c02464fdc5a4d58ca05de553dffab355b08e99a46d3b6d2495d11
+size 368547
diff --git a/examples/wbm_ev/CHGNet.parquet b/examples/wbm_ev/CHGNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..9437015c0dc08d2b508481c8ec170a6759dc9f1c
--- /dev/null
+++ b/examples/wbm_ev/CHGNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06342370be572819441a9c706f3e70555c6ac0bf75d0fdaa35f2f574c9f600cd
+size 424462
diff --git a/examples/wbm_ev/CHGNet_processed.parquet b/examples/wbm_ev/CHGNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..f2862d06a12547350abb505c1b222422e7c831eb
--- /dev/null
+++ b/examples/wbm_ev/CHGNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50bbb0ed4ba3e8af06c2a90e927182132be1f988643dd7b6844d39fe4dd1084c
+size 357683
diff --git a/examples/wbm_ev/M3GNet.parquet b/examples/wbm_ev/M3GNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..27f064dc4d9be72439dc73e462c7a7d56f59d6f1
--- /dev/null
+++ b/examples/wbm_ev/M3GNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:841aaa082db265939b3a3ada6f0d6901e65cb942277b074ca55cbdd7730dde75
+size 411741
diff --git a/examples/wbm_ev/M3GNet_processed.parquet b/examples/wbm_ev/M3GNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..d717d2d19b8fbfb4d60b09aa4727aef9f2e7c790
--- /dev/null
+++ b/examples/wbm_ev/M3GNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:411b2f619314dafb34349a65d985b7b881cb21687ecb9a853da29fc21d6fa714
+size 357786
diff --git a/examples/wbm_ev/MACE-MP(M).parquet b/examples/wbm_ev/MACE-MP(M).parquet
new file mode 100644
index 0000000000000000000000000000000000000000..64273d7e076572a111355f3b18b5a0b28486bcb4
--- /dev/null
+++ b/examples/wbm_ev/MACE-MP(M).parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1551823daef888914de951f2610ee6ffdfd2b0d6f33e1e293614534cdd217196
+size 409083
diff --git a/examples/wbm_ev/MACE-MP(M)_processed.parquet b/examples/wbm_ev/MACE-MP(M)_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..6eb4f2c917674a8dfbb46690c1e84a28fc00e826
--- /dev/null
+++ b/examples/wbm_ev/MACE-MP(M)_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:877cd9e9407f9402fc01510d051ad19db815e9794c87ee8402f306e2afafd45e
+size 359765
diff --git a/examples/wbm_ev/MACE-MPA.parquet b/examples/wbm_ev/MACE-MPA.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..59e4bf3153fc1f86534d7f999f6921b6ae4b571a
--- /dev/null
+++ b/examples/wbm_ev/MACE-MPA.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eedf48b2478811a1dca46eb50d607004dca99f54c81909f331c550317f14cd19
+size 407912
diff --git a/examples/wbm_ev/MACE-MPA_processed.parquet b/examples/wbm_ev/MACE-MPA_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..0901c8ac5757f02d83e33ea00616dd40822e08c2
--- /dev/null
+++ b/examples/wbm_ev/MACE-MPA_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee5167ef38acf130548c5f5cbf78fedd14b7de7df47c72d771997f4e13302c0b
+size 356642
diff --git a/examples/wbm_ev/MatterSim.parquet b/examples/wbm_ev/MatterSim.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..c8852d9e741a98f59f3bef04a4f6e020d90ebd08
--- /dev/null
+++ b/examples/wbm_ev/MatterSim.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3d587fd71a817968a513b727a174c353c97552bc7674e5d3a4108e2b6233556
+size 408998
diff --git a/examples/wbm_ev/MatterSim_processed.parquet b/examples/wbm_ev/MatterSim_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..a02f2d8041f349d80fe820ce4e4bbe6eae116edd
--- /dev/null
+++ b/examples/wbm_ev/MatterSim_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:107a30bb5541a861e4141be0db3de210ca2229a22d975f732f107f0e6afdeb0f
+size 356292
diff --git a/examples/wbm_ev/ORBv2.parquet b/examples/wbm_ev/ORBv2.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..a7c753891453cb6a880ebce6d267c0d2d5c14984
--- /dev/null
+++ b/examples/wbm_ev/ORBv2.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ee9f2322096fbeb103a85b0735ed0d547f93e34f1c113af3619baf35c7acbc3
+size 415496
diff --git a/examples/wbm_ev/ORBv2_processed.parquet b/examples/wbm_ev/ORBv2_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..98fbe425b1b5fea8f50aea98e65ba80c847b30b7
--- /dev/null
+++ b/examples/wbm_ev/ORBv2_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01525c5b2e37b8d930c94ee2dc257b8d7c36b24111206ba2efc499a8bc172fcd
+size 357949
diff --git a/examples/wbm_ev/SevenNet.parquet b/examples/wbm_ev/SevenNet.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..b053a0e5f2cc8d69480b94d6a47762f252029f2e
--- /dev/null
+++ b/examples/wbm_ev/SevenNet.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6850c333c0b754b942efcfac11739ae199a9f3c816da4e0e6b26bc9a037a0524
+size 410197
diff --git a/examples/wbm_ev/SevenNet_processed.parquet b/examples/wbm_ev/SevenNet_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..25af1c02248907c47d28c055ef3f985b839df296
--- /dev/null
+++ b/examples/wbm_ev/SevenNet_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83d9c45216940262575bc7386c62f8b79bb3ac5e63b8ab33ca4307976bb9796f
+size 358345
diff --git a/examples/wbm_ev/eqV2(OMat).parquet b/examples/wbm_ev/eqV2(OMat).parquet
new file mode 100644
index 0000000000000000000000000000000000000000..9c77ac538e8078edd73e6deca2498da8e41e236c
--- /dev/null
+++ b/examples/wbm_ev/eqV2(OMat).parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7521597cd3189c0a3cea18ae98d9310bfc6becd50fca5e6f2c97af5e69b2596d
+size 414251
diff --git a/examples/wbm_ev/eqV2(OMat)_processed.parquet b/examples/wbm_ev/eqV2(OMat)_processed.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..8dc514281d099eafda5cd18770caa45ff573c43d
--- /dev/null
+++ b/examples/wbm_ev/eqV2(OMat)_processed.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50f318dce825e939a5c24aa8789acb2952865e0c682462f99cfed204ea9fba64
+size 356693
diff --git a/examples/wbm_ev/run.py b/examples/wbm_ev/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d78ebeae7db1927fff689e7bcab7a6d403c2ab
--- /dev/null
+++ b/examples/wbm_ev/run.py
@@ -0,0 +1,163 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from ase.db import connect
+from dask.distributed import Client
+from dask_jobqueue import SLURMCluster
+from prefect import flow, task
+from prefect.runtime import task_run
+from prefect_dask import DaskTaskRunner
+from prefect.cache_policies import INPUTS, TASK_SOURCE
+
+from mlip_arena.models import REGISTRY, MLIPEnum
+from mlip_arena.tasks.utils import get_calculator
+
+
+@task
+def load_wbm_structures():
+ """
+ Load the WBM structures from an ASE database file.
+
+ Reads structures from 'wbm_structures.db' and yields them as ASE Atoms objects
+ with additional metadata preserved from the database.
+
+ Yields:
+ ase.Atoms: Individual atomic structures from the WBM database with preserved
+ metadata in the .info dictionary.
+ """
+ with connect("../wbm_structures.db") as db:
+ for row in db.select():
+ yield row.toatoms(add_additional_information=True)
+
+@task(
+ name="E-V Scan",
+ task_run_name=lambda: f"{task_run.task_name}: {task_run.parameters['atoms'].get_chemical_formula()} - {task_run.parameters['model'].name}",
+ cache_policy=TASK_SOURCE + INPUTS,
+)
+def ev_scan(atoms, model):
+ """
+ Perform an energy-volume scan for a given model and atomic structure.
+
+ This function applies uniaxial strain to the structure in all three dimensions,
+ maintaining the fractional coordinates of atoms, and computes the energy at each
+ deformation point using the specified model.
+
+ Args:
+ atoms: ASE atoms object containing the structure to analyze.
+ model: MLIPEnum model to use for the energy calculations.
+
+ Returns:
+ dict: Results dictionary containing:
+ - method (str): The name of the model used
+ - id (str): The WBM ID of the structure
+ - eos (dict): Energy of state data with:
+ - volumes (list): Volume of the unit cell at each strain point
+ - energies (list): Computed potential energy at each strain point
+
+ Note:
+ The strain range is fixed at ยฑ20% with 21 evenly spaced points.
+ Results are also saved as a JSON file in a directory named after the model.
+ """
+ calculator = get_calculator(
+ model
+ ) # avoid sending entire model over prefect and select freer GPU
+
+ wbm_id = atoms.info["key_value_pairs"]["wbm_id"]
+
+ c0 = atoms.get_cell()
+ max_abs_strain = 0.2
+ npoints = 21
+ volumes = []
+ energies = []
+ for uniaxial_strain in np.linspace(-max_abs_strain, max_abs_strain, npoints):
+ cloned = atoms.copy()
+ scale_factor = uniaxial_strain + 1
+ cloned.set_cell(c0 * scale_factor, scale_atoms=True)
+ cloned.calc = calculator
+ volumes.append(cloned.get_volume())
+ energies.append(cloned.get_potential_energy())
+
+ data = {
+ "method": model.name,
+ "id": wbm_id,
+ "eos": {
+ "volumes": volumes, "energies": energies
+ }
+ }
+
+ fpath = Path(f"{model.name}") / f"{wbm_id}.json"
+ fpath.parent.mkdir(exist_ok=True)
+
+ df = pd.DataFrame([data])
+ df.to_json(fpath)
+
+ return df
+
+
+@flow
+def submit_tasks():
+ """
+ Create and submit energy-volume scan tasks for subsampled WBM structures and applicable models.
+
+ This flow function:
+ 1. Loads all structures from the WBM database
+ 2. Iterates through available models in MLIPEnum
+ 3. Filters models based on their capability to handle the 'wbm_ev' GPU task
+ 4. Submits parallel ev_scan tasks for all valid (structure, model) combinations
+ 5. Collects and returns results from all tasks
+
+ Returns:
+ list: Results from all executed tasks (successful or failed)
+ """
+ futures = []
+ for atoms in load_wbm_structures():
+ for model in MLIPEnum:
+ if "wbm_ev" not in REGISTRY[model.name].get("gpu-tasks", []):
+ continue
+ try:
+ result = ev_scan.submit(atoms, model)
+ except Exception as e:
+ print(f"Failed to submit task for {model.name}: {e}")
+ continue
+ futures.append(result)
+ return [f.result(raise_on_failure=False) for f in futures]
+
+
+nodes_per_alloc = 1
+gpus_per_alloc = 1
+ntasks = 1
+
+cluster_kwargs = dict(
+ cores=1,
+ memory="64 GB",
+ processes=1,
+ shebang="#!/bin/bash",
+ account="m3828",
+ walltime="00:30:00",
+ # job_mem="0",
+ job_script_prologue=[
+ "source ~/.bashrc",
+ "module load python",
+ "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
+ ],
+ job_directives_skip=["-n", "--cpus-per-task", "-J"],
+ job_extra_directives=[
+ "-J wbm_ev",
+ "-q debug",
+ f"-N {nodes_per_alloc}",
+ "-C gpu",
+ f"-G {gpus_per_alloc}",
+ "--exclusive",
+ ],
+)
+
+cluster = SLURMCluster(**cluster_kwargs)
+print(cluster.job_script())
+cluster.adapt(minimum_jobs=2, maximum_jobs=2)
+client = Client(cluster)
+
+submit_tasks.with_options(
+ task_runner=DaskTaskRunner(address=client.scheduler.address),
+ log_prints=True,
+)()
diff --git a/examples/wbm_ev/summary.csv b/examples/wbm_ev/summary.csv
new file mode 100644
index 0000000000000000000000000000000000000000..f3ebaa0c0bb4f23ba0b43aaa62351e72d7eaba05
--- /dev/null
+++ b/examples/wbm_ev/summary.csv
@@ -0,0 +1,10 @@
+model,rank,rank-aggregation,energy-diff-flip-times,tortuosity,spearman-compression-energy,spearman-compression-derivative,spearman-tension-energy,missing
+MACE-MPA,1,11,1.0,1.000675741122765,-0.9983393939393939,0.9993090909090908,0.9987181818181818,0
+CHGNet,2,14,1.0,1.0006287770651048,-0.9982787878787878,0.9439636363636364,0.999090909090909,0
+MatterSim,3,19,1.009,1.000567338639546,-0.9980969696969696,0.9997090909090908,0.9937541835359507,0
+eqV2(OMat),4,22,1.035,1.0008346292192054,-0.9982060606060604,0.9972242424242423,0.9986454545454545,0
+M3GNet,5,24,1.002,1.0020010929112253,-0.9975878787878787,0.997442424242424,0.9964676571137886,0
+ORBv2,6,29,1.058,1.004064906459821,-0.9977696969696969,0.970751515151515,0.9976,0
+SevenNet,7,33,1.034,1.0100246177550205,-0.9951636363636364,0.9465575757575757,0.9947048195608054,0
+MACE-MP(M),8,35,1.121,1.0807128149289842,-0.9438060606060605,0.9011878787878788,0.9987454545454546,0
+ALIGNN,9,46,3.909,1.3756517739089669,-0.8892069391323368,0.7602706775644651,0.862085379002138,0
diff --git a/examples/wbm_ev/summary.tex b/examples/wbm_ev/summary.tex
new file mode 100644
index 0000000000000000000000000000000000000000..4d95a41c99be575d8445ad9c3d740dad6290007b
--- /dev/null
+++ b/examples/wbm_ev/summary.tex
@@ -0,0 +1,15 @@
+\begin{tabular}{lrrrrrrrl}
+\toprule
+model & rank & rank-aggregation & energy-diff-flip-times & tortuosity & spearman-compression-energy & spearman-compression-derivative & spearman-tension-energy & missing \\
+\midrule
+MACE-MPA & 1 & 11 & 1.000000 & 1.000676 & -0.998339 & 0.999309 & 0.998718 & 0 \\
+CHGNet & 2 & 14 & 1.000000 & 1.000629 & -0.998279 & 0.943964 & 0.999091 & 0 \\
+MatterSim & 3 & 19 & 1.009000 & 1.000567 & -0.998097 & 0.999709 & 0.993754 & 0 \\
+eqV2(OMat) & 4 & 22 & 1.035000 & 1.000835 & -0.998206 & 0.997224 & 0.998645 & 0 \\
+M3GNet & 5 & 24 & 1.002000 & 1.002001 & -0.997588 & 0.997442 & 0.996468 & 0 \\
+ORBv2 & 6 & 29 & 1.058000 & 1.004065 & -0.997770 & 0.970752 & 0.997600 & 0 \\
+SevenNet & 7 & 33 & 1.034000 & 1.010025 & -0.995164 & 0.946558 & 0.994705 & 0 \\
+MACE-MP(M) & 8 & 35 & 1.121000 & 1.080713 & -0.943806 & 0.901188 & 0.998745 & 0 \\
+ALIGNN & 9 & 46 & 3.909000 & 1.375652 & -0.889207 & 0.760271 & 0.862085 & 0 \\
+\bottomrule
+\end{tabular}
diff --git a/examples/wbm_structures.db b/examples/wbm_structures.db
new file mode 100644
index 0000000000000000000000000000000000000000..e1f15c8aa5b8d8fd27b2a9b9ce833deed609f476
--- /dev/null
+++ b/examples/wbm_structures.db
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc387c7787c21e7ff2ab80d5428c60b9e817c9b37f53e03c0f5e5e72dc44fe88
+size 782336
diff --git a/mlip_arena/models/externals/mattersim.py b/mlip_arena/models/externals/mattersim.py
index dab0660a8ecf39683cd5b8639a6e6e6d83017121..28b78a23164720dc149b3b5a10fec854b9f47bc5 100644
--- a/mlip_arena/models/externals/mattersim.py
+++ b/mlip_arena/models/externals/mattersim.py
@@ -5,9 +5,7 @@ from pathlib import Path
import yaml
from mattersim.forcefield import MatterSimCalculator
-from ase import Atoms
from mlip_arena.models.utils import get_freer_device
-# from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
REGISTRY = yaml.safe_load(f)
@@ -24,6 +22,14 @@ class MatterSim(MatterSimCalculator):
load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
)
+ def get_potential_energy(self, atoms=None, force_consistent=False):
+ return float(
+ super().get_potential_energy(
+ atoms=atoms,
+ force_consistent=force_consistent,
+ )
+ ) # mattersim return numpy float instead of python float
+
def __getstate__(self):
state = self.__dict__.copy()
diff --git a/mlip_arena/models/registry.yaml b/mlip_arena/models/registry.yaml
index c2a9ca2cfbf535022499df264fb2a5721ffb0b08..e656c16f7c2065324252c4f77d37553700b23d85 100644
--- a/mlip_arena/models/registry.yaml
+++ b/mlip_arena/models/registry.yaml
@@ -15,6 +15,8 @@ MACE-MP(M):
- homonuclear-diatomics
- stability
- combustion
+ - eos_bulk
+ - wbm_ev
github: https://github.com/ACEsuit/mace
doi: https://arxiv.org/abs/2401.00096
date: 2023-12-29
@@ -38,6 +40,8 @@ CHGNet:
- homonuclear-diatomics
- stability
- combustion
+ - eos_bulk
+ - wbm_ev
github: https://github.com/CederGroupHub/chgnet
doi: https://doi.org/10.1038/s42256-023-00716-3
date: 2023-02-28
@@ -61,6 +65,8 @@ M3GNet:
- homonuclear-diatomics
- combustion
- stability
+ - eos_bulk
+ - wbm_ev
github: https://github.com/materialsvirtuallab/matgl
doi: https://doi.org/10.1038/s43588-022-00349-3
date: 2022-02-05
@@ -87,6 +93,8 @@ MatterSim:
- homonuclear-diatomics
- stability
- combustion
+ - eos_bulk
+ - wbm_ev
github: https://github.com/microsoft/mattersim
doi: https://arxiv.org/abs/2405.04967
date: 2024-12-05
@@ -113,6 +121,8 @@ ORBv2:
- homonuclear-diatomics
- combustion
- stability
+ - eos_bulk
+ - wbm_ev
github: https://github.com/orbital-materials/orb-models
doi: https://arxiv.org/abs/2410.22570
date: 2024-10-15
@@ -136,6 +146,8 @@ SevenNet:
- homonuclear-diatomics
- stability
- combustion
+ - eos_bulk
+ - wbm_ev
github: https://github.com/MDIL-SNU/SevenNet
doi: https://doi.org/10.1021/acs.jctc.4c00190
date: 2024-07-11
@@ -161,6 +173,7 @@ eqV2(OMat):
- eos_alloy
gpu-tasks:
- homonuclear-diatomics
+ - wbm_ev
prediction: EFS
nvt: true
npt: false # https://github.com/FAIR-Chem/fairchem/issues/888, https://github.com/atomind-ai/mlip-arena/issues/17
@@ -184,6 +197,8 @@ MACE-MPA:
gpu-tasks:
- homonuclear-diatomics
- stability
+ - eos_bulk
+ - wbm_ev
github: https://github.com/ACEsuit/mace
doi: https://arxiv.org/abs/2401.00096
date: 2024-12-09
@@ -336,6 +351,7 @@ ALIGNN:
gpu-tasks:
- homonuclear-diatomics
- stability
+ - wbm_ev
# - combustion
prediction: EFS
nvt: true
diff --git a/mlip_arena/tasks/optimize.py b/mlip_arena/tasks/optimize.py
index 0c6ccd4a29816368c334df43137eb0791d8f870d..c0d5f478b0b055e3ae8598460000e24ad52ce127 100644
--- a/mlip_arena/tasks/optimize.py
+++ b/mlip_arena/tasks/optimize.py
@@ -78,7 +78,7 @@ def run(
filter_kwargs = filter_kwargs or {}
optimizer_kwargs = optimizer_kwargs or {}
- criterion = criterion or {}
+ criterion = criterion or dict(steps=1000)
if symmetry:
atoms.set_constraint(FixSymmetry(atoms))
diff --git a/mlip_arena/tasks/registry.yaml b/mlip_arena/tasks/registry.yaml
index f51e5ca707a05bed2bcd08beb790d1059d13bc73..d35035f9abfdac42070cc766c9db73608eb2a343 100644
--- a/mlip_arena/tasks/registry.yaml
+++ b/mlip_arena/tasks/registry.yaml
@@ -4,6 +4,18 @@ Homonuclear diatomics:
task-layout: wide
rank-page: homonuclear-diatomics
last-update: 2024-09-19
+Energy-volume scans:
+ category: Fundamentals
+ task-page: wbm_ev
+ task-layout: wide
+ rank-page: wbm_ev
+ last-update: 2025-04-29
+Equation of state:
+ category: Fundamentals
+ task-page: eos_bulk
+ task-layout: wide
+ rank-page: eos_bulk
+ last-update: 2025-04-29
Combustion:
category: Molecular Dynamics
task-page: combustion
diff --git a/serve/app.py b/serve/app.py
index 24c0b68b0618fb32505dfbe2408fd1e0bf5baeb7..9eb20febd584b37f164cd58f2fbea568c8b7ea6d 100644
--- a/serve/app.py
+++ b/serve/app.py
@@ -15,6 +15,8 @@ nav[""].append(leaderboard)
wide_pages, centered_pages = [], []
for task in TASKS:
+ if TASKS[task]['task-page'] is None:
+ continue
page = st.Page(
f"tasks/{TASKS[task]['task-page']}.py", title=task, icon=":material/target:"
)
@@ -50,10 +52,10 @@ else:
)
-st.toast(
- "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care. Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.",
- icon="๐",
-)
+# st.toast(
+# "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care. Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.",
+# icon="๐",
+# )
st.sidebar.page_link(
"https://github.com/atomind-ai/mlip-arena", label="GitHub Repository", icon=":material/code:"
diff --git a/serve/leaderboard.py b/serve/leaderboard.py
index 335d93b6ddb2fd6ebd107fe2cf86fbf9d3e9ffd2..da46f1da3dd5731ab02d946e7f3a4d06d1527a46 100644
--- a/serve/leaderboard.py
+++ b/serve/leaderboard.py
@@ -59,10 +59,10 @@ s = table.style.background_gradient(
cmap="PuRd", subset=["Element Coverage"], vmin=0, vmax=120
)
-st.warning(
- "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care.",
- icon="โ ๏ธ",
-)
+# st.warning(
+# "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care.",
+# icon="โ ๏ธ",
+# )
st.info(
"Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.",
icon="๐ค",
@@ -117,11 +117,12 @@ for task in TASKS:
task_module = importlib.import_module(f"ranks.{TASKS[task]['rank-page']}")
- st.page_link(
- f"tasks/{TASKS[task]['task-page']}.py",
- label="Go to the associated task page",
- icon=":material/link:",
- )
+ if TASKS[task]['task-page'] is not None:
+ st.page_link(
+ f"tasks/{TASKS[task]['task-page']}.py",
+ label="Go to the associated task page",
+ icon=":material/link:",
+ )
# Call the function from the imported module
if hasattr(task_module, "render"):
diff --git a/serve/ranks/eos_bulk.py b/serve/ranks/eos_bulk.py
new file mode 100644
index 0000000000000000000000000000000000000000..962ce83e1be52606c40efe3ed0e5788b07b6519e
--- /dev/null
+++ b/serve/ranks/eos_bulk.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+
+import pandas as pd
+import streamlit as st
+
+DATA_DIR = Path("examples/eos_bulk")
+
+
+table = pd.read_csv(DATA_DIR / "summary.csv")
+
+
+
+table = table.rename(
+ columns={
+ "model": "Model",
+ "rank": "Rank",
+ "rank-aggregation": "Rank aggr.",
+ "energy-diff-flip-times": "Derivative flips",
+ "tortuosity": "Tortuosity",
+ "spearman-compression-energy": "Spearman's coeff. (compression)",
+ "spearman-tension-energy": "Spearman's coeff. (tension)",
+ "spearman-compression-derivative": "Spearman's coeff. (compression derivative)",
+ "missing": "Missing",
+ },
+)
+
+table.set_index("Model", inplace=True)
+
+s = (
+ table.style.background_gradient(
+ cmap="Blues",
+ subset=["Rank", "Rank aggr."],
+ ).background_gradient(
+ cmap="Reds",
+ subset=[
+ "Spearman's coeff. (compression)",
+ ],
+ ).background_gradient(
+ cmap="Reds_r",
+ subset=[
+ "Spearman's coeff. (tension)",
+ "Spearman's coeff. (compression derivative)",
+ ],
+ ).background_gradient(
+ cmap="RdPu",
+ subset=["Tortuosity", "Derivative flips"],
+ ).format(
+ "{:.5f}",
+ subset=[
+ "Spearman's coeff. (compression)",
+ "Spearman's coeff. (tension)",
+ "Spearman's coeff. (compression derivative)",
+ "Tortuosity",
+ "Derivative flips",
+ ],
+ )
+)
+
+def render():
+ st.dataframe(
+ s,
+ use_container_width=True,
+ )
diff --git a/serve/ranks/wbm_ev.py b/serve/ranks/wbm_ev.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e4359bde727ca2d456edd11582f3149632060d
--- /dev/null
+++ b/serve/ranks/wbm_ev.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+
+import pandas as pd
+import streamlit as st
+
+DATA_DIR = Path("examples/wbm_ev")
+
+
+table = pd.read_csv(DATA_DIR / "summary.csv")
+
+
+
+table = table.rename(
+ columns={
+ "model": "Model",
+ "rank": "Rank",
+ "rank-aggregation": "Rank aggr.",
+ "energy-diff-flip-times": "Derivative flips",
+ "tortuosity": "Tortuosity",
+ "spearman-compression-energy": "Spearman's coeff. (compression)",
+ "spearman-tension-energy": "Spearman's coeff. (tension)",
+ "spearman-compression-derivative": "Spearman's coeff. (compression derivative)",
+ "missing": "Missing",
+ },
+)
+
+table.set_index("Model", inplace=True)
+
+s = (
+ table.style.background_gradient(
+ cmap="Blues",
+ subset=["Rank", "Rank aggr."],
+ ).background_gradient(
+ cmap="Reds",
+ subset=[
+ "Spearman's coeff. (compression)",
+ ],
+ ).background_gradient(
+ cmap="Reds_r",
+ subset=[
+ "Spearman's coeff. (tension)",
+ "Spearman's coeff. (compression derivative)",
+ ],
+ ).background_gradient(
+ cmap="RdPu",
+ subset=["Tortuosity", "Derivative flips"],
+ ).format(
+ "{:.5f}",
+ subset=[
+ "Spearman's coeff. (compression)",
+ "Spearman's coeff. (tension)",
+ "Spearman's coeff. (compression derivative)",
+ "Tortuosity",
+ "Derivative flips",
+ ],
+ )
+)
+
+def render():
+ st.dataframe(
+ s,
+ use_container_width=True,
+ )
diff --git a/serve/tasks/eos_bulk.py b/serve/tasks/eos_bulk.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c126f0103214f865d79aae1ee846fcf5f54668
--- /dev/null
+++ b/serve/tasks/eos_bulk.py
@@ -0,0 +1,242 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import plotly.colors as pcolors
+import plotly.graph_objects as go
+import streamlit as st
+from ase.db import connect
+from scipy import stats
+
+from mlip_arena.models import REGISTRY as MODELS
+
+DATA_DIR = Path("examples/eos_bulk")
+
+st.markdown("""
+# Equation of state (EOS)
+""")
+
+# Control panels at the top
+st.markdown("### Methods")
+methods_container = st.container(border=True)
+
+valid_models = [
+ model
+ for model, metadata in MODELS.items()
+ if Path(__file__).stem in metadata.get("gpu-tasks", [])
+]
+
+# Model selection
+selected_models = methods_container.multiselect(
+ "Select Models",
+ options=valid_models,
+ default=valid_models
+)
+
+# Visualization settings
+st.markdown("### Visualization Settings")
+vis = st.container(border=True)
+
+# Column settings
+ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=2)
+
+# Color palette selection
+all_attributes = dir(pcolors.qualitative)
+color_palettes = {
+ attr: getattr(pcolors.qualitative, attr)
+ for attr in all_attributes
+ if isinstance(getattr(pcolors.qualitative, attr), list)
+}
+color_palettes.pop("__all__", None)
+
+palette_names = list(color_palettes.keys())
+palette_name = vis.selectbox("Color sequence", options=palette_names, index=22)
+color_sequence = color_palettes[palette_name]
+
+# Stop execution if no models selected
+if not selected_models:
+ st.warning("Please select at least one model to visualize.")
+ st.stop()
+
+
+def load_wbm_structures():
+ """
+ Load the WBM structures from a ASE DB file.
+ """
+ with connect(DATA_DIR.parent / "wbm_structures.db") as db:
+ for row in db.select():
+ yield row.toatoms(add_additional_information=True)
+
+
+@st.cache_data
+def generate_dataframe(model_name):
+ fpath = DATA_DIR / f"{model_name}.parquet"
+ if not fpath.exists():
+ return pd.DataFrame() # Return empty dataframe instead of using continue
+
+ df_raw_results = pd.read_parquet(fpath)
+
+ df_analyzed = pd.DataFrame(
+ columns=[
+ "model",
+ "structure",
+ "formula",
+ "volume-ratio",
+ "energy-delta-per-atom",
+ "energy-diff-flip-times",
+ "tortuosity",
+ "spearman-compression-energy",
+ "spearman-compression-derivative",
+ "spearman-tension-energy",
+ "missing",
+ ]
+ )
+
+ for wbm_struct in load_wbm_structures():
+ structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
+
+ try:
+ results = df_raw_results.loc[df_raw_results["id"] == structure_id]
+ results = results["eos"].values[0]
+ es = np.array(results["energies"])
+ vols = np.array(results["volumes"])
+ vol0 = wbm_struct.get_volume()
+
+ indices = np.argsort(vols)
+ vols = vols[indices]
+ es = es[indices]
+
+ imine = len(es) // 2
+ # min_center_val = np.min(es[imid - 1 : imid + 2])
+ # imine = np.where(es == min_center_val)[0][0]
+ emin = es[imine]
+
+ interpolated_volumes = [
+ (vols[i] + vols[i + 1]) / 2 for i in range(len(vols) - 1)
+ ]
+ ediff = np.diff(es)
+ ediff_sign = np.sign(ediff)
+ mask = ediff_sign != 0
+ ediff = ediff[mask]
+ ediff_sign = ediff_sign[mask]
+ ediff_flip = np.diff(ediff_sign) != 0
+
+ etv = np.sum(np.abs(np.diff(es)))
+
+ data = {
+ "model": model_name,
+ "structure": structure_id,
+ "formula": wbm_struct.get_chemical_formula(),
+ "missing": False,
+ "volume-ratio": vols / vol0,
+ "energy-delta-per-atom": (es - emin) / len(wbm_struct),
+ "energy-diff-flip-times": np.sum(ediff_flip).astype(int),
+ "tortuosity": etv / (abs(es[0] - emin) + abs(es[-1] - emin)),
+ "spearman-compression-energy": stats.spearmanr(
+ vols[:imine], es[:imine]
+ ).statistic,
+ "spearman-compression-derivative": stats.spearmanr(
+ interpolated_volumes[:imine], ediff[:imine]
+ ).statistic,
+ "spearman-tension-energy": stats.spearmanr(
+ vols[imine:], es[imine:]
+ ).statistic,
+ }
+
+ except Exception:
+ data = {
+ "model": model_name,
+ "structure": structure_id,
+ "formula": wbm_struct.get_chemical_formula(),
+ "missing": True,
+ "volume-ratio": None,
+ "energy-delta-per-atom": None,
+ "energy-diff-flip-times": None,
+ "tortuosity": None,
+ "spearman-compression-energy": None,
+ "spearman-compression-derivative": None,
+ "spearman-tension-energy": None,
+ }
+
+ df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
+
+ return df_analyzed
+
+
+@st.cache_data
+def get_plots(selected_models):
+ """Generate one plot per model with all structures (legend disabled for each structure)."""
+ figs = []
+
+ for model_name in selected_models:
+
+ fpath = DATA_DIR / f"{model_name}_processed.parquet"
+ if not fpath.exists():
+ df = generate_dataframe(model_name)
+ else:
+ df = pd.read_parquet(fpath)
+
+ if len(df) == 0:
+ continue
+
+ fig = go.Figure()
+ valid_structures = []
+ for i, (_, row) in enumerate(df.iterrows()):
+ structure_id = row["structure"]
+ formula = row.get("formula", "")
+ if isinstance(row["volume-ratio"], list | np.ndarray) and isinstance(
+ row["energy-delta-per-atom"], list | np.ndarray
+ ):
+ vol_strain = row["volume-ratio"]
+ energy_delta = row["energy-delta-per-atom"]
+ color = color_sequence[i % len(color_sequence)]
+ fig.add_trace(
+ go.Scatter(
+ x=vol_strain,
+ y=energy_delta,
+ mode="lines",
+ name=f"{structure_id}",
+ showlegend=False,
+ line=dict(color=color),
+ hoverlabel=dict(bgcolor=color, font=dict(color="black")),
+ hovertemplate=(
+ structure_id + "
"
+ "Formula: " + str(formula) + "
"
+ "Volume ratio V/Vโ: %{x:.3f}
"
+ "ฮEnergy: %{y:.3f} eV/atom
"
+ ""
+ ),
+
+ )
+ )
+ valid_structures.append(structure_id)
+
+ # if valid_structures:
+ fig.update_layout(
+ title=f"{model_name} ({len(valid_structures)} / {len(df)} structures)",
+ xaxis_title="Volume ratio V/Vโ",
+ yaxis_title="Relative energy ฮE (eV/atom)",
+ height=500,
+ showlegend=False, # Disable legend for the whole plot
+ yaxis=dict(range=[-0.1, 1]), # Set y-axis limits
+ )
+ fig.add_vline(x=1, line_dash="dash", line_color="gray", opacity=0.7)
+ figs.append((model_name, fig, valid_structures))
+
+ return figs
+
+
+# Generate all plots
+all_plots = get_plots(selected_models)
+
+# Display plots in the specified column layout
+if all_plots:
+ for i, (model_name, fig, structures) in enumerate(all_plots):
+ if i % ncols == 0:
+ cols = st.columns(ncols)
+ cols[i % ncols].plotly_chart(fig, use_container_width=True)
+
+ # Display number of structures in this plot
+ # cols[i % ncols].caption(f"{len(structures)} / 1000 structures")
+else:
+ st.warning("No data available for the selected models.")
diff --git a/serve/tasks/wbm_ev.py b/serve/tasks/wbm_ev.py
new file mode 100644
index 0000000000000000000000000000000000000000..a292ba1893ed8a48df2e2e1397fad40936f8ece7
--- /dev/null
+++ b/serve/tasks/wbm_ev.py
@@ -0,0 +1,243 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import plotly.colors as pcolors
+import plotly.graph_objects as go
+import streamlit as st
+from ase.db import connect
+from scipy import stats
+
+from mlip_arena.models import REGISTRY as MODELS
+
+DATA_DIR = Path("examples/wbm_ev")
+
+st.markdown("""
+# Energy-volume scans
+""")
+
+# Control panels at the top
+st.markdown("### Methods")
+methods_container = st.container(border=True)
+
+# Get valid models that support wbm_ev
+valid_models = [
+ model
+ for model, metadata in MODELS.items()
+ if Path(__file__).stem in metadata.get("gpu-tasks", [])
+]
+
+# Model selection
+selected_models = methods_container.multiselect(
+ "Select Models",
+ options=valid_models,
+ default=valid_models
+)
+
+# Visualization settings
+st.markdown("### Visualization Settings")
+vis = st.container(border=True)
+
+# Column settings
+ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=2)
+
+# Color palette selection
+all_attributes = dir(pcolors.qualitative)
+color_palettes = {
+ attr: getattr(pcolors.qualitative, attr)
+ for attr in all_attributes
+ if isinstance(getattr(pcolors.qualitative, attr), list)
+}
+color_palettes.pop("__all__", None)
+
+palette_names = list(color_palettes.keys())
+palette_name = vis.selectbox("Color sequence", options=palette_names, index=22)
+color_sequence = color_palettes[palette_name]
+
+# Stop execution if no models selected
+if not selected_models:
+ st.warning("Please select at least one model to visualize.")
+ st.stop()
+
+
+def load_wbm_structures():
+ """
+ Load the WBM structures from a ASE DB file.
+ """
+ with connect(DATA_DIR.parent / "wbm_structures.db") as db:
+ for row in db.select():
+ yield row.toatoms(add_additional_information=True)
+
+
+@st.cache_data
+def generate_dataframe(model_name):
+ fpath = DATA_DIR / f"{model_name}.parquet"
+ if not fpath.exists():
+ return pd.DataFrame() # Return empty dataframe instead of using continue
+
+ df_raw_results = pd.read_parquet(fpath)
+
+ df_analyzed = pd.DataFrame(
+ columns=[
+ "model",
+ "structure",
+ "formula",
+ "volume-ratio",
+ "energy-delta-per-atom",
+ "energy-diff-flip-times",
+ "tortuosity",
+ "spearman-compression-energy",
+ "spearman-compression-derivative",
+ "spearman-tension-energy",
+ "missing",
+ ]
+ )
+
+ for wbm_struct in load_wbm_structures():
+ structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
+
+ try:
+ results = df_raw_results.loc[df_raw_results["id"] == structure_id]
+ results = results["eos"].values[0]
+ es = np.array(results["energies"])
+ vols = np.array(results["volumes"])
+ vol0 = wbm_struct.get_volume()
+
+ indices = np.argsort(vols)
+ vols = vols[indices]
+ es = es[indices]
+
+ imine = len(es) // 2
+ # min_center_val = np.min(es[imid - 1 : imid + 2])
+ # imine = np.where(es == min_center_val)[0][0]
+ emin = es[imine]
+
+ interpolated_volumes = [
+ (vols[i] + vols[i + 1]) / 2 for i in range(len(vols) - 1)
+ ]
+ ediff = np.diff(es)
+ ediff_sign = np.sign(ediff)
+ mask = ediff_sign != 0
+ ediff = ediff[mask]
+ ediff_sign = ediff_sign[mask]
+ ediff_flip = np.diff(ediff_sign) != 0
+
+ etv = np.sum(np.abs(np.diff(es)))
+
+ data = {
+ "model": model_name,
+ "structure": structure_id,
+ "formula": wbm_struct.get_chemical_formula(),
+ "missing": False,
+ "volume-ratio": vols / vol0,
+ "energy-delta-per-atom": (es - emin) / len(wbm_struct),
+ "energy-diff-flip-times": np.sum(ediff_flip).astype(int),
+ "tortuosity": etv / (abs(es[0] - emin) + abs(es[-1] - emin)),
+ "spearman-compression-energy": stats.spearmanr(
+ vols[:imine], es[:imine]
+ ).statistic,
+ "spearman-compression-derivative": stats.spearmanr(
+ interpolated_volumes[:imine], ediff[:imine]
+ ).statistic,
+ "spearman-tension-energy": stats.spearmanr(
+ vols[imine:], es[imine:]
+ ).statistic,
+ }
+
+ except Exception:
+ data = {
+ "model": model_name,
+ "structure": structure_id,
+ "formula": wbm_struct.get_chemical_formula(),
+ "missing": True,
+ "volume-ratio": None,
+ "energy-delta-per-atom": None,
+ "energy-diff-flip-times": None,
+ "tortuosity": None,
+ "spearman-compression-energy": None,
+ "spearman-compression-derivative": None,
+ "spearman-tension-energy": None,
+ }
+
+ df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True)
+
+ return df_analyzed
+
+
+@st.cache_data
+def get_plots(selected_models):
+ """Generate one plot per model with all structures (legend disabled for each structure)."""
+ figs = []
+
+ for model_name in selected_models:
+
+ fpath = DATA_DIR / f"{model_name}_processed.parquet"
+ if not fpath.exists():
+ df = generate_dataframe(model_name)
+ else:
+ df = pd.read_parquet(fpath)
+
+ if len(df) == 0:
+ continue
+
+ fig = go.Figure()
+ valid_structures = []
+ for i, (_, row) in enumerate(df.iterrows()):
+ structure_id = row["structure"]
+ formula = row.get("formula", "")
+ if isinstance(row["volume-ratio"], list | np.ndarray) and isinstance(
+ row["energy-delta-per-atom"], list | np.ndarray
+ ):
+ vol_strain = row["volume-ratio"]
+ energy_delta = row["energy-delta-per-atom"]
+ color = color_sequence[i % len(color_sequence)]
+ fig.add_trace(
+ go.Scatter(
+ x=vol_strain,
+ y=energy_delta,
+ mode="lines",
+ name=f"{structure_id}",
+ showlegend=False,
+ line=dict(color=color),
+ hoverlabel=dict(bgcolor=color, font=dict(color="black")),
+ hovertemplate=(
+ structure_id + "
"
+ "Formula: " + str(formula) + "
"
+ "Volume ratio V/Vโ: %{x:.3f}
"
+ "ฮEnergy: %{y:.3f} eV/atom
"
+ ""
+ ),
+
+ )
+ )
+ valid_structures.append(structure_id)
+
+ # if valid_structures:
+ fig.update_layout(
+ title=f"{model_name} ({len(valid_structures)} / {len(df)} structures)",
+ xaxis_title="Volume ratio V/Vโ",
+ yaxis_title="Relative energy E - Eโ (eV/atom)",
+ height=500,
+ showlegend=False, # Disable legend for the whole plot
+ yaxis=dict(range=[-1, 15]), # Set y-axis limits
+ )
+ fig.add_vline(x=1, line_dash="dash", line_color="gray", opacity=0.7)
+ figs.append((model_name, fig, valid_structures))
+
+ return figs
+
+
+# Generate all plots
+all_plots = get_plots(selected_models)
+
+# Display plots in the specified column layout
+if all_plots:
+ for i, (model_name, fig, structures) in enumerate(all_plots):
+ if i % ncols == 0:
+ cols = st.columns(ncols)
+ cols[i % ncols].plotly_chart(fig, use_container_width=True)
+
+ # Display number of structures in this plot
+ # cols[i % ncols].caption(f"{len(structures)} / 1000 structures")
+else:
+ st.warning("No data available for the selected models.")