diff --git a/.gitattributes b/.gitattributes index 7fe70d7f07c494ee23600b490a1167607f3a08ca..7ebbda59fc572adcf72d0ba8bf6de67a0e57cc30 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,3 @@ *.json filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.db filter=lfs diff=lfs merge=lfs -text diff --git a/.github/README.md b/.github/README.md index cd764debf7f948d469d69b7dafd3abb5e9886f4c..ba8d603ea1af032ac91fa32ae3456797fab4f5c5 100644 --- a/.github/README.md +++ b/.github/README.md @@ -18,12 +18,12 @@ MLIP Arena leverages modern pythonic workflow orchestrator [Prefect](https://www ## Announcement -- **[April 8, 2025]** [๐ŸŽ‰ **MLIP Arena accepted as an ICLR AI4Mat Spotlight!** ๐ŸŽ‰](https://openreview.net/forum?id=ysKfIavYQE#discussion) Huge thanks to all co-authors for their contributions! +- **[April 8, 2025]** [๐ŸŽ‰ **MLIP Arena is accepted as an ICLR AI4Mat Spotlight!** ๐ŸŽ‰](https://openreview.net/forum?id=ysKfIavYQE#discussion) Huge thanks to all co-authors for their contributions! ## Installation -### From PyPI (without model running capability) +### From PyPI (prefect workflow only, without pretrained models) ```bash pip install mlip-arena diff --git a/examples/eos_bulk/CHGNet.parquet b/examples/eos_bulk/CHGNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..43dad7c081ddbc284dc6e6b2c2852465debfa320 --- /dev/null +++ b/examples/eos_bulk/CHGNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68871d694e93a3c3e7e272b9cbd87d3757e3bc689f30f3189db232d76e629c07 +size 429910 diff --git a/examples/eos_bulk/CHGNet_processed.parquet b/examples/eos_bulk/CHGNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..68784499609a2c80692fa0f654172600134c7782 --- /dev/null +++ b/examples/eos_bulk/CHGNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6fbea63f9035e376bb5ac7db38175102ab3f96a0f8758cc3e9931424f829ac0 +size 357919 diff --git a/examples/eos_bulk/M3GNet.parquet b/examples/eos_bulk/M3GNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..0ce0cb0cd28f9e7203404427fe68d6b04edb859c --- /dev/null +++ b/examples/eos_bulk/M3GNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53dde465b5e10edd677f131f8a531e3dfc36303dd7ec7b9df0060c19847494d9 +size 427419 diff --git a/examples/eos_bulk/M3GNet_processed.parquet b/examples/eos_bulk/M3GNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..1f4c816367ce534b6cb07da03b6bec28035e125c --- /dev/null +++ b/examples/eos_bulk/M3GNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18ea51bf19c5e011e170a3229bbd63ce675a725c364e0cffb71de95459f8629e +size 379859 diff --git a/examples/eos_bulk/MACE-MP(M).parquet b/examples/eos_bulk/MACE-MP(M).parquet new file mode 100644 index 0000000000000000000000000000000000000000..7287426ae925a64ec1400a9f641848be301ada49 --- /dev/null +++ b/examples/eos_bulk/MACE-MP(M).parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff9769eeb83042129767aeff975eb04dee8efae12e96fbd46cd3039eeda26705 +size 427896 diff --git a/examples/eos_bulk/MACE-MP(M)_processed.parquet b/examples/eos_bulk/MACE-MP(M)_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..bef6529f23c192c7741131de38aee08a72817434 --- /dev/null +++ b/examples/eos_bulk/MACE-MP(M)_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f492125c87400fee013d32904ee97df49a07d321bd393f9335c7bf4258fe159 +size 371004 diff --git a/examples/eos_bulk/MACE-MPA.parquet b/examples/eos_bulk/MACE-MPA.parquet new file mode 100644 index 0000000000000000000000000000000000000000..ae18b9d947a17edb51ce0fecd2e6cc066d74d0a5 --- /dev/null +++ b/examples/eos_bulk/MACE-MPA.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53fcd188baddd4d5e797c5aa3de1b4368db711ebd29b7877cfe224856ba9d171 +size 428888 diff --git a/examples/eos_bulk/MACE-MPA_processed.parquet b/examples/eos_bulk/MACE-MPA_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..4be6db0ab599fd1e4a49d4bc77c0418240166cca --- /dev/null +++ b/examples/eos_bulk/MACE-MPA_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34bd9ae08656e374263820774f49e91a964aa8b2aeade4150cf62cfc08bb37f6 +size 365289 diff --git a/examples/eos_bulk/MatterSim.parquet b/examples/eos_bulk/MatterSim.parquet new file mode 100644 index 0000000000000000000000000000000000000000..1ac19b614e60ddafbd3a5e0ed2599ac3e7daf401 --- /dev/null +++ b/examples/eos_bulk/MatterSim.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6717650b97782de6f90e4473075410fe4540279eb39338d2234d3c9399079b3 +size 389586 diff --git a/examples/eos_bulk/MatterSim_processed.parquet b/examples/eos_bulk/MatterSim_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..85ee3a7982f54b27ec9696b3ac1018450d120918 --- /dev/null +++ b/examples/eos_bulk/MatterSim_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cca22b5db67dae59602adfb8c42a80da90b39c4e95af89ef918813351b422119 +size 320962 diff --git a/examples/eos_bulk/ORBv2.parquet b/examples/eos_bulk/ORBv2.parquet new file mode 100644 index 0000000000000000000000000000000000000000..2cd571ac07630e9883eb0c9e0d854ed646a18d9b --- /dev/null +++ b/examples/eos_bulk/ORBv2.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae13c9af1ae7fafe2a42ed4c47e2ba0f036abfa64a87ca517b92d89c62fcbfd9 +size 427105 diff --git a/examples/eos_bulk/ORBv2_processed.parquet b/examples/eos_bulk/ORBv2_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c8775daf947ddb809288bcbc648b18530e0c9936 --- /dev/null +++ b/examples/eos_bulk/ORBv2_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c90e61645c83f2452bcd90ef9132c46266896217fe1dea9cca8e0d124d73821a +size 227929 diff --git a/examples/eos_bulk/SevenNet.parquet b/examples/eos_bulk/SevenNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..5a3bb7cc25bd7752f35f3fbad0154d411b270b97 --- /dev/null +++ b/examples/eos_bulk/SevenNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64be88ec2632cdabf79daa01acb2cf2ef19fef0557813df5502c4f71ec566f4e +size 428341 diff --git a/examples/eos_bulk/SevenNet_processed.parquet b/examples/eos_bulk/SevenNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c32745504ba858eba7902004ec8eec54e4a159c0 --- /dev/null +++ b/examples/eos_bulk/SevenNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cc03d9af93c001f3fa441b50f058506e13c0aa7cb3d329275e68f5ed80dc3e6 +size 364846 diff --git a/examples/eos_bulk/preprocessing.py b/examples/eos_bulk/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..08789185c32c110dd37ce3576bf8779073615069 --- /dev/null +++ b/examples/eos_bulk/preprocessing.py @@ -0,0 +1,12 @@ +import json + +from ase.db import connect +from pymatgen.core import Structure + +with open("wbm_structures.json") as f: + structs = json.load(f) + +with connect("wbm_structures.db") as db: + for id, s in structs.items(): + atoms = Structure.from_dict(s).to_ase_atoms(msonable=False) + db.write(atoms, wbm_id=id) diff --git a/examples/eos_bulk/run.py b/examples/eos_bulk/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5e0b9dabbeaff7feef12381b851a7f1fe1423286 --- /dev/null +++ b/examples/eos_bulk/run.py @@ -0,0 +1,135 @@ +import functools +from pathlib import Path + +import pandas as pd +from ase.db import connect +from dask.distributed import Client +from dask_jobqueue import SLURMCluster +from prefect import Task, flow, task +from prefect.client.schemas.objects import TaskRun +from prefect.states import State +from prefect_dask import DaskTaskRunner + +from mlip_arena.models import REGISTRY, MLIPEnum +from mlip_arena.tasks.eos import run as EOS +from mlip_arena.tasks.optimize import run as OPT +from mlip_arena.tasks.utils import get_calculator + + +@task +def load_wbm_structures(): + """ + Load the WBM structures from a ASE DB file. + """ + with connect("../wbm_structures.db") as db: + for row in db.select(): + yield row.toatoms(add_additional_information=True) + + +def save_result( + tsk: Task, + run: TaskRun, + state: State, + model_name: str, + id: str, +): + result = run.state.result() + + assert isinstance(result, dict) + + result["method"] = model_name + result["id"] = id + result.pop("atoms", None) + + fpath = Path(f"{model_name}") + fpath.mkdir(exist_ok=True) + + fpath = fpath / f"{result['id']}.pkl" + + df = pd.DataFrame([result]) + df.to_pickle(fpath) + + +@task +def eos_bulk(atoms, model): + + calculator = get_calculator( + model + ) # avoid sending entire model over prefect and select freer GPU + + result = OPT.with_options( + refresh_cache=True, + )( + atoms, + calculator, + optimizer="FIRE", + criterion=dict( + fmax=0.1, + ), + ) + + return EOS.with_options( + refresh_cache=True, + on_completion=[functools.partial( + save_result, + model_name=model.name, + id=atoms.info["key_value_pairs"]["wbm_id"], + )], + )( + atoms=result["atoms"], + calculator=calculator, + optimizer="FIRE", + npoints=21, + max_abs_strain=0.2, + concurrent=False + ) + + +@flow +def run_all(): + futures = [] + for atoms in load_wbm_structures(): + for model in MLIPEnum: + if "eos_bulk" not in REGISTRY[model.name].get("gpu-tasks", []): + continue + result = eos_bulk.submit(atoms, model) + futures.append(result) + return [f.result(raise_on_failure=False) for f in futures] + + +nodes_per_alloc = 1 +gpus_per_alloc = 1 +ntasks = 1 + +cluster_kwargs = dict( + cores=4, + memory="64 GB", + shebang="#!/bin/bash", + account="m3828", + walltime="00:50:00", + job_mem="0", + job_script_prologue=[ + "source ~/.bashrc", + "module load python", + "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena", + ], + job_directives_skip=["-n", "--cpus-per-task", "-J"], + job_extra_directives=[ + "-J eos_bulk", + "-q regular", + f"-N {nodes_per_alloc}", + "-C gpu", + f"-G {gpus_per_alloc}", + "--exclusive", + ], +) + +cluster = SLURMCluster(**cluster_kwargs) +print(cluster.job_script()) +cluster.adapt(minimum_jobs=20, maximum_jobs=40) +client = Client(cluster) + +run_all.with_options( + task_runner=DaskTaskRunner(address=client.scheduler.address), + log_prints=True, +)() diff --git a/examples/eos_bulk/summary.csv b/examples/eos_bulk/summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..2a3759045a4d1812bf07abe3dca5280112023da8 --- /dev/null +++ b/examples/eos_bulk/summary.csv @@ -0,0 +1,8 @@ +model,rank,rank-aggregation,energy-diff-flip-times,tortuosity,spearman-compression-energy,spearman-compression-derivative,spearman-tension-energy,missing +MACE-MPA,1,6,1.0370741482965933,1.005455197941088,-0.9993684338373716,0.9963320580555048,0.993186372745491,2 +MACE-MP(M),2,16,1.042211055276382,1.008986842539345,-0.999329983249581,0.9941160347190496,0.9915857612939804,5 +MatterSim,3,18,1.045135406218656,1.0060900449752808,-0.99734962463147,0.9927904926901917,0.9880977115916667,3 +CHGNet,4,22,1.1053159478435306,1.014753469076796,-0.9964985866690981,0.9929971733381963,0.9866417434120545,3 +SevenNet,5,27,1.1093279839518555,1.0186969977862483,-0.9981277164827815,0.9889121911188109,0.9859580417030127,3 +M3GNet,6,33,1.1748743718592964,1.0175007963267957,-0.9963209989340641,0.9897426526572255,0.9801690217498693,5 +ORBv2,7,42,1.3162134944612287,1.0374718753890275,-0.9918459519667977,0.9701425127407,0.9637462235649547,7 diff --git a/examples/eos_bulk/summary.tex b/examples/eos_bulk/summary.tex new file mode 100644 index 0000000000000000000000000000000000000000..a27cbc037c58d00d65372882dd35956573888da5 --- /dev/null +++ b/examples/eos_bulk/summary.tex @@ -0,0 +1,13 @@ +\begin{tabular}{lrrrrrrrl} +\toprule +model & rank & rank-aggregation & energy-diff-flip-times & tortuosity & spearman-compression-energy & spearman-compression-derivative & spearman-tension-energy & missing \\ +\midrule +MACE-MPA & 1 & 6 & 1.037074 & 1.005455 & -0.999368 & 0.996332 & 0.993186 & 2 \\ +MACE-MP(M) & 2 & 16 & 1.042211 & 1.008987 & -0.999330 & 0.994116 & 0.991586 & 5 \\ +MatterSim & 3 & 18 & 1.045135 & 1.006090 & -0.997350 & 0.992790 & 0.988098 & 3 \\ +CHGNet & 4 & 22 & 1.105316 & 1.014753 & -0.996499 & 0.992997 & 0.986642 & 3 \\ +SevenNet & 5 & 27 & 1.109328 & 1.018697 & -0.998128 & 0.988912 & 0.985958 & 3 \\ +M3GNet & 6 & 33 & 1.174874 & 1.017501 & -0.996321 & 0.989743 & 0.980169 & 5 \\ +ORBv2 & 7 & 42 & 1.316213 & 1.037472 & -0.991846 & 0.970143 & 0.963746 & 7 \\ +\bottomrule +\end{tabular} diff --git a/examples/wbm_ev/ALIGNN.parquet b/examples/wbm_ev/ALIGNN.parquet new file mode 100644 index 0000000000000000000000000000000000000000..34cbb20eed3a2d47dd00221e4490019419962768 --- /dev/null +++ b/examples/wbm_ev/ALIGNN.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b84592b56c667f49e510f382c07f2dd4105df71468c2198c3958b2d0066202b +size 425244 diff --git a/examples/wbm_ev/ALIGNN_processed.parquet b/examples/wbm_ev/ALIGNN_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..fba6b6e80cfe1164bc81dbf660b09628d1fe6834 --- /dev/null +++ b/examples/wbm_ev/ALIGNN_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1d3f3d2992c02464fdc5a4d58ca05de553dffab355b08e99a46d3b6d2495d11 +size 368547 diff --git a/examples/wbm_ev/CHGNet.parquet b/examples/wbm_ev/CHGNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..9437015c0dc08d2b508481c8ec170a6759dc9f1c --- /dev/null +++ b/examples/wbm_ev/CHGNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06342370be572819441a9c706f3e70555c6ac0bf75d0fdaa35f2f574c9f600cd +size 424462 diff --git a/examples/wbm_ev/CHGNet_processed.parquet b/examples/wbm_ev/CHGNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..f2862d06a12547350abb505c1b222422e7c831eb --- /dev/null +++ b/examples/wbm_ev/CHGNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50bbb0ed4ba3e8af06c2a90e927182132be1f988643dd7b6844d39fe4dd1084c +size 357683 diff --git a/examples/wbm_ev/M3GNet.parquet b/examples/wbm_ev/M3GNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..27f064dc4d9be72439dc73e462c7a7d56f59d6f1 --- /dev/null +++ b/examples/wbm_ev/M3GNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:841aaa082db265939b3a3ada6f0d6901e65cb942277b074ca55cbdd7730dde75 +size 411741 diff --git a/examples/wbm_ev/M3GNet_processed.parquet b/examples/wbm_ev/M3GNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..d717d2d19b8fbfb4d60b09aa4727aef9f2e7c790 --- /dev/null +++ b/examples/wbm_ev/M3GNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:411b2f619314dafb34349a65d985b7b881cb21687ecb9a853da29fc21d6fa714 +size 357786 diff --git a/examples/wbm_ev/MACE-MP(M).parquet b/examples/wbm_ev/MACE-MP(M).parquet new file mode 100644 index 0000000000000000000000000000000000000000..64273d7e076572a111355f3b18b5a0b28486bcb4 --- /dev/null +++ b/examples/wbm_ev/MACE-MP(M).parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1551823daef888914de951f2610ee6ffdfd2b0d6f33e1e293614534cdd217196 +size 409083 diff --git a/examples/wbm_ev/MACE-MP(M)_processed.parquet b/examples/wbm_ev/MACE-MP(M)_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..6eb4f2c917674a8dfbb46690c1e84a28fc00e826 --- /dev/null +++ b/examples/wbm_ev/MACE-MP(M)_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:877cd9e9407f9402fc01510d051ad19db815e9794c87ee8402f306e2afafd45e +size 359765 diff --git a/examples/wbm_ev/MACE-MPA.parquet b/examples/wbm_ev/MACE-MPA.parquet new file mode 100644 index 0000000000000000000000000000000000000000..59e4bf3153fc1f86534d7f999f6921b6ae4b571a --- /dev/null +++ b/examples/wbm_ev/MACE-MPA.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eedf48b2478811a1dca46eb50d607004dca99f54c81909f331c550317f14cd19 +size 407912 diff --git a/examples/wbm_ev/MACE-MPA_processed.parquet b/examples/wbm_ev/MACE-MPA_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..0901c8ac5757f02d83e33ea00616dd40822e08c2 --- /dev/null +++ b/examples/wbm_ev/MACE-MPA_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee5167ef38acf130548c5f5cbf78fedd14b7de7df47c72d771997f4e13302c0b +size 356642 diff --git a/examples/wbm_ev/MatterSim.parquet b/examples/wbm_ev/MatterSim.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c8852d9e741a98f59f3bef04a4f6e020d90ebd08 --- /dev/null +++ b/examples/wbm_ev/MatterSim.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3d587fd71a817968a513b727a174c353c97552bc7674e5d3a4108e2b6233556 +size 408998 diff --git a/examples/wbm_ev/MatterSim_processed.parquet b/examples/wbm_ev/MatterSim_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..a02f2d8041f349d80fe820ce4e4bbe6eae116edd --- /dev/null +++ b/examples/wbm_ev/MatterSim_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:107a30bb5541a861e4141be0db3de210ca2229a22d975f732f107f0e6afdeb0f +size 356292 diff --git a/examples/wbm_ev/ORBv2.parquet b/examples/wbm_ev/ORBv2.parquet new file mode 100644 index 0000000000000000000000000000000000000000..a7c753891453cb6a880ebce6d267c0d2d5c14984 --- /dev/null +++ b/examples/wbm_ev/ORBv2.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ee9f2322096fbeb103a85b0735ed0d547f93e34f1c113af3619baf35c7acbc3 +size 415496 diff --git a/examples/wbm_ev/ORBv2_processed.parquet b/examples/wbm_ev/ORBv2_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..98fbe425b1b5fea8f50aea98e65ba80c847b30b7 --- /dev/null +++ b/examples/wbm_ev/ORBv2_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01525c5b2e37b8d930c94ee2dc257b8d7c36b24111206ba2efc499a8bc172fcd +size 357949 diff --git a/examples/wbm_ev/SevenNet.parquet b/examples/wbm_ev/SevenNet.parquet new file mode 100644 index 0000000000000000000000000000000000000000..b053a0e5f2cc8d69480b94d6a47762f252029f2e --- /dev/null +++ b/examples/wbm_ev/SevenNet.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6850c333c0b754b942efcfac11739ae199a9f3c816da4e0e6b26bc9a037a0524 +size 410197 diff --git a/examples/wbm_ev/SevenNet_processed.parquet b/examples/wbm_ev/SevenNet_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..25af1c02248907c47d28c055ef3f985b839df296 --- /dev/null +++ b/examples/wbm_ev/SevenNet_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83d9c45216940262575bc7386c62f8b79bb3ac5e63b8ab33ca4307976bb9796f +size 358345 diff --git a/examples/wbm_ev/eqV2(OMat).parquet b/examples/wbm_ev/eqV2(OMat).parquet new file mode 100644 index 0000000000000000000000000000000000000000..9c77ac538e8078edd73e6deca2498da8e41e236c --- /dev/null +++ b/examples/wbm_ev/eqV2(OMat).parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7521597cd3189c0a3cea18ae98d9310bfc6becd50fca5e6f2c97af5e69b2596d +size 414251 diff --git a/examples/wbm_ev/eqV2(OMat)_processed.parquet b/examples/wbm_ev/eqV2(OMat)_processed.parquet new file mode 100644 index 0000000000000000000000000000000000000000..8dc514281d099eafda5cd18770caa45ff573c43d --- /dev/null +++ b/examples/wbm_ev/eqV2(OMat)_processed.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50f318dce825e939a5c24aa8789acb2952865e0c682462f99cfed204ea9fba64 +size 356693 diff --git a/examples/wbm_ev/run.py b/examples/wbm_ev/run.py new file mode 100644 index 0000000000000000000000000000000000000000..90d78ebeae7db1927fff689e7bcab7a6d403c2ab --- /dev/null +++ b/examples/wbm_ev/run.py @@ -0,0 +1,163 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +from ase.db import connect +from dask.distributed import Client +from dask_jobqueue import SLURMCluster +from prefect import flow, task +from prefect.runtime import task_run +from prefect_dask import DaskTaskRunner +from prefect.cache_policies import INPUTS, TASK_SOURCE + +from mlip_arena.models import REGISTRY, MLIPEnum +from mlip_arena.tasks.utils import get_calculator + + +@task +def load_wbm_structures(): + """ + Load the WBM structures from an ASE database file. + + Reads structures from 'wbm_structures.db' and yields them as ASE Atoms objects + with additional metadata preserved from the database. + + Yields: + ase.Atoms: Individual atomic structures from the WBM database with preserved + metadata in the .info dictionary. + """ + with connect("../wbm_structures.db") as db: + for row in db.select(): + yield row.toatoms(add_additional_information=True) + +@task( + name="E-V Scan", + task_run_name=lambda: f"{task_run.task_name}: {task_run.parameters['atoms'].get_chemical_formula()} - {task_run.parameters['model'].name}", + cache_policy=TASK_SOURCE + INPUTS, +) +def ev_scan(atoms, model): + """ + Perform an energy-volume scan for a given model and atomic structure. + + This function applies uniaxial strain to the structure in all three dimensions, + maintaining the fractional coordinates of atoms, and computes the energy at each + deformation point using the specified model. + + Args: + atoms: ASE atoms object containing the structure to analyze. + model: MLIPEnum model to use for the energy calculations. + + Returns: + dict: Results dictionary containing: + - method (str): The name of the model used + - id (str): The WBM ID of the structure + - eos (dict): Energy of state data with: + - volumes (list): Volume of the unit cell at each strain point + - energies (list): Computed potential energy at each strain point + + Note: + The strain range is fixed at ยฑ20% with 21 evenly spaced points. + Results are also saved as a JSON file in a directory named after the model. + """ + calculator = get_calculator( + model + ) # avoid sending entire model over prefect and select freer GPU + + wbm_id = atoms.info["key_value_pairs"]["wbm_id"] + + c0 = atoms.get_cell() + max_abs_strain = 0.2 + npoints = 21 + volumes = [] + energies = [] + for uniaxial_strain in np.linspace(-max_abs_strain, max_abs_strain, npoints): + cloned = atoms.copy() + scale_factor = uniaxial_strain + 1 + cloned.set_cell(c0 * scale_factor, scale_atoms=True) + cloned.calc = calculator + volumes.append(cloned.get_volume()) + energies.append(cloned.get_potential_energy()) + + data = { + "method": model.name, + "id": wbm_id, + "eos": { + "volumes": volumes, "energies": energies + } + } + + fpath = Path(f"{model.name}") / f"{wbm_id}.json" + fpath.parent.mkdir(exist_ok=True) + + df = pd.DataFrame([data]) + df.to_json(fpath) + + return df + + +@flow +def submit_tasks(): + """ + Create and submit energy-volume scan tasks for subsampled WBM structures and applicable models. + + This flow function: + 1. Loads all structures from the WBM database + 2. Iterates through available models in MLIPEnum + 3. Filters models based on their capability to handle the 'wbm_ev' GPU task + 4. Submits parallel ev_scan tasks for all valid (structure, model) combinations + 5. Collects and returns results from all tasks + + Returns: + list: Results from all executed tasks (successful or failed) + """ + futures = [] + for atoms in load_wbm_structures(): + for model in MLIPEnum: + if "wbm_ev" not in REGISTRY[model.name].get("gpu-tasks", []): + continue + try: + result = ev_scan.submit(atoms, model) + except Exception as e: + print(f"Failed to submit task for {model.name}: {e}") + continue + futures.append(result) + return [f.result(raise_on_failure=False) for f in futures] + + +nodes_per_alloc = 1 +gpus_per_alloc = 1 +ntasks = 1 + +cluster_kwargs = dict( + cores=1, + memory="64 GB", + processes=1, + shebang="#!/bin/bash", + account="m3828", + walltime="00:30:00", + # job_mem="0", + job_script_prologue=[ + "source ~/.bashrc", + "module load python", + "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena", + ], + job_directives_skip=["-n", "--cpus-per-task", "-J"], + job_extra_directives=[ + "-J wbm_ev", + "-q debug", + f"-N {nodes_per_alloc}", + "-C gpu", + f"-G {gpus_per_alloc}", + "--exclusive", + ], +) + +cluster = SLURMCluster(**cluster_kwargs) +print(cluster.job_script()) +cluster.adapt(minimum_jobs=2, maximum_jobs=2) +client = Client(cluster) + +submit_tasks.with_options( + task_runner=DaskTaskRunner(address=client.scheduler.address), + log_prints=True, +)() diff --git a/examples/wbm_ev/summary.csv b/examples/wbm_ev/summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..f3ebaa0c0bb4f23ba0b43aaa62351e72d7eaba05 --- /dev/null +++ b/examples/wbm_ev/summary.csv @@ -0,0 +1,10 @@ +model,rank,rank-aggregation,energy-diff-flip-times,tortuosity,spearman-compression-energy,spearman-compression-derivative,spearman-tension-energy,missing +MACE-MPA,1,11,1.0,1.000675741122765,-0.9983393939393939,0.9993090909090908,0.9987181818181818,0 +CHGNet,2,14,1.0,1.0006287770651048,-0.9982787878787878,0.9439636363636364,0.999090909090909,0 +MatterSim,3,19,1.009,1.000567338639546,-0.9980969696969696,0.9997090909090908,0.9937541835359507,0 +eqV2(OMat),4,22,1.035,1.0008346292192054,-0.9982060606060604,0.9972242424242423,0.9986454545454545,0 +M3GNet,5,24,1.002,1.0020010929112253,-0.9975878787878787,0.997442424242424,0.9964676571137886,0 +ORBv2,6,29,1.058,1.004064906459821,-0.9977696969696969,0.970751515151515,0.9976,0 +SevenNet,7,33,1.034,1.0100246177550205,-0.9951636363636364,0.9465575757575757,0.9947048195608054,0 +MACE-MP(M),8,35,1.121,1.0807128149289842,-0.9438060606060605,0.9011878787878788,0.9987454545454546,0 +ALIGNN,9,46,3.909,1.3756517739089669,-0.8892069391323368,0.7602706775644651,0.862085379002138,0 diff --git a/examples/wbm_ev/summary.tex b/examples/wbm_ev/summary.tex new file mode 100644 index 0000000000000000000000000000000000000000..4d95a41c99be575d8445ad9c3d740dad6290007b --- /dev/null +++ b/examples/wbm_ev/summary.tex @@ -0,0 +1,15 @@ +\begin{tabular}{lrrrrrrrl} +\toprule +model & rank & rank-aggregation & energy-diff-flip-times & tortuosity & spearman-compression-energy & spearman-compression-derivative & spearman-tension-energy & missing \\ +\midrule +MACE-MPA & 1 & 11 & 1.000000 & 1.000676 & -0.998339 & 0.999309 & 0.998718 & 0 \\ +CHGNet & 2 & 14 & 1.000000 & 1.000629 & -0.998279 & 0.943964 & 0.999091 & 0 \\ +MatterSim & 3 & 19 & 1.009000 & 1.000567 & -0.998097 & 0.999709 & 0.993754 & 0 \\ +eqV2(OMat) & 4 & 22 & 1.035000 & 1.000835 & -0.998206 & 0.997224 & 0.998645 & 0 \\ +M3GNet & 5 & 24 & 1.002000 & 1.002001 & -0.997588 & 0.997442 & 0.996468 & 0 \\ +ORBv2 & 6 & 29 & 1.058000 & 1.004065 & -0.997770 & 0.970752 & 0.997600 & 0 \\ +SevenNet & 7 & 33 & 1.034000 & 1.010025 & -0.995164 & 0.946558 & 0.994705 & 0 \\ +MACE-MP(M) & 8 & 35 & 1.121000 & 1.080713 & -0.943806 & 0.901188 & 0.998745 & 0 \\ +ALIGNN & 9 & 46 & 3.909000 & 1.375652 & -0.889207 & 0.760271 & 0.862085 & 0 \\ +\bottomrule +\end{tabular} diff --git a/examples/wbm_structures.db b/examples/wbm_structures.db new file mode 100644 index 0000000000000000000000000000000000000000..e1f15c8aa5b8d8fd27b2a9b9ce833deed609f476 --- /dev/null +++ b/examples/wbm_structures.db @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc387c7787c21e7ff2ab80d5428c60b9e817c9b37f53e03c0f5e5e72dc44fe88 +size 782336 diff --git a/mlip_arena/models/externals/mattersim.py b/mlip_arena/models/externals/mattersim.py index dab0660a8ecf39683cd5b8639a6e6e6d83017121..28b78a23164720dc149b3b5a10fec854b9f47bc5 100644 --- a/mlip_arena/models/externals/mattersim.py +++ b/mlip_arena/models/externals/mattersim.py @@ -5,9 +5,7 @@ from pathlib import Path import yaml from mattersim.forcefield import MatterSimCalculator -from ase import Atoms from mlip_arena.models.utils import get_freer_device -# from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f: REGISTRY = yaml.safe_load(f) @@ -24,6 +22,14 @@ class MatterSim(MatterSimCalculator): load_path=checkpoint, device=str(device or get_freer_device()), **kwargs ) + def get_potential_energy(self, atoms=None, force_consistent=False): + return float( + super().get_potential_energy( + atoms=atoms, + force_consistent=force_consistent, + ) + ) # mattersim return numpy float instead of python float + def __getstate__(self): state = self.__dict__.copy() diff --git a/mlip_arena/models/registry.yaml b/mlip_arena/models/registry.yaml index c2a9ca2cfbf535022499df264fb2a5721ffb0b08..e656c16f7c2065324252c4f77d37553700b23d85 100644 --- a/mlip_arena/models/registry.yaml +++ b/mlip_arena/models/registry.yaml @@ -15,6 +15,8 @@ MACE-MP(M): - homonuclear-diatomics - stability - combustion + - eos_bulk + - wbm_ev github: https://github.com/ACEsuit/mace doi: https://arxiv.org/abs/2401.00096 date: 2023-12-29 @@ -38,6 +40,8 @@ CHGNet: - homonuclear-diatomics - stability - combustion + - eos_bulk + - wbm_ev github: https://github.com/CederGroupHub/chgnet doi: https://doi.org/10.1038/s42256-023-00716-3 date: 2023-02-28 @@ -61,6 +65,8 @@ M3GNet: - homonuclear-diatomics - combustion - stability + - eos_bulk + - wbm_ev github: https://github.com/materialsvirtuallab/matgl doi: https://doi.org/10.1038/s43588-022-00349-3 date: 2022-02-05 @@ -87,6 +93,8 @@ MatterSim: - homonuclear-diatomics - stability - combustion + - eos_bulk + - wbm_ev github: https://github.com/microsoft/mattersim doi: https://arxiv.org/abs/2405.04967 date: 2024-12-05 @@ -113,6 +121,8 @@ ORBv2: - homonuclear-diatomics - combustion - stability + - eos_bulk + - wbm_ev github: https://github.com/orbital-materials/orb-models doi: https://arxiv.org/abs/2410.22570 date: 2024-10-15 @@ -136,6 +146,8 @@ SevenNet: - homonuclear-diatomics - stability - combustion + - eos_bulk + - wbm_ev github: https://github.com/MDIL-SNU/SevenNet doi: https://doi.org/10.1021/acs.jctc.4c00190 date: 2024-07-11 @@ -161,6 +173,7 @@ eqV2(OMat): - eos_alloy gpu-tasks: - homonuclear-diatomics + - wbm_ev prediction: EFS nvt: true npt: false # https://github.com/FAIR-Chem/fairchem/issues/888, https://github.com/atomind-ai/mlip-arena/issues/17 @@ -184,6 +197,8 @@ MACE-MPA: gpu-tasks: - homonuclear-diatomics - stability + - eos_bulk + - wbm_ev github: https://github.com/ACEsuit/mace doi: https://arxiv.org/abs/2401.00096 date: 2024-12-09 @@ -336,6 +351,7 @@ ALIGNN: gpu-tasks: - homonuclear-diatomics - stability + - wbm_ev # - combustion prediction: EFS nvt: true diff --git a/mlip_arena/tasks/optimize.py b/mlip_arena/tasks/optimize.py index 0c6ccd4a29816368c334df43137eb0791d8f870d..c0d5f478b0b055e3ae8598460000e24ad52ce127 100644 --- a/mlip_arena/tasks/optimize.py +++ b/mlip_arena/tasks/optimize.py @@ -78,7 +78,7 @@ def run( filter_kwargs = filter_kwargs or {} optimizer_kwargs = optimizer_kwargs or {} - criterion = criterion or {} + criterion = criterion or dict(steps=1000) if symmetry: atoms.set_constraint(FixSymmetry(atoms)) diff --git a/mlip_arena/tasks/registry.yaml b/mlip_arena/tasks/registry.yaml index f51e5ca707a05bed2bcd08beb790d1059d13bc73..d35035f9abfdac42070cc766c9db73608eb2a343 100644 --- a/mlip_arena/tasks/registry.yaml +++ b/mlip_arena/tasks/registry.yaml @@ -4,6 +4,18 @@ Homonuclear diatomics: task-layout: wide rank-page: homonuclear-diatomics last-update: 2024-09-19 +Energy-volume scans: + category: Fundamentals + task-page: wbm_ev + task-layout: wide + rank-page: wbm_ev + last-update: 2025-04-29 +Equation of state: + category: Fundamentals + task-page: eos_bulk + task-layout: wide + rank-page: eos_bulk + last-update: 2025-04-29 Combustion: category: Molecular Dynamics task-page: combustion diff --git a/serve/app.py b/serve/app.py index 24c0b68b0618fb32505dfbe2408fd1e0bf5baeb7..9eb20febd584b37f164cd58f2fbea568c8b7ea6d 100644 --- a/serve/app.py +++ b/serve/app.py @@ -15,6 +15,8 @@ nav[""].append(leaderboard) wide_pages, centered_pages = [], [] for task in TASKS: + if TASKS[task]['task-page'] is None: + continue page = st.Page( f"tasks/{TASKS[task]['task-page']}.py", title=task, icon=":material/target:" ) @@ -50,10 +52,10 @@ else: ) -st.toast( - "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care. Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.", - icon="๐Ÿž", -) +# st.toast( +# "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care. Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.", +# icon="๐Ÿž", +# ) st.sidebar.page_link( "https://github.com/atomind-ai/mlip-arena", label="GitHub Repository", icon=":material/code:" diff --git a/serve/leaderboard.py b/serve/leaderboard.py index 335d93b6ddb2fd6ebd107fe2cf86fbf9d3e9ffd2..da46f1da3dd5731ab02d946e7f3a4d06d1527a46 100644 --- a/serve/leaderboard.py +++ b/serve/leaderboard.py @@ -59,10 +59,10 @@ s = table.style.background_gradient( cmap="PuRd", subset=["Element Coverage"], vmin=0, vmax=120 ) -st.warning( - "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care.", - icon="โš ๏ธ", -) +# st.warning( +# "MLIP Arena is currently in **pre-alpha**. The results are not stable. Please interpret them with care.", +# icon="โš ๏ธ", +# ) st.info( "Contributions are welcome. For more information, visit https://github.com/atomind-ai/mlip-arena.", icon="๐Ÿค—", @@ -117,11 +117,12 @@ for task in TASKS: task_module = importlib.import_module(f"ranks.{TASKS[task]['rank-page']}") - st.page_link( - f"tasks/{TASKS[task]['task-page']}.py", - label="Go to the associated task page", - icon=":material/link:", - ) + if TASKS[task]['task-page'] is not None: + st.page_link( + f"tasks/{TASKS[task]['task-page']}.py", + label="Go to the associated task page", + icon=":material/link:", + ) # Call the function from the imported module if hasattr(task_module, "render"): diff --git a/serve/ranks/eos_bulk.py b/serve/ranks/eos_bulk.py new file mode 100644 index 0000000000000000000000000000000000000000..962ce83e1be52606c40efe3ed0e5788b07b6519e --- /dev/null +++ b/serve/ranks/eos_bulk.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import pandas as pd +import streamlit as st + +DATA_DIR = Path("examples/eos_bulk") + + +table = pd.read_csv(DATA_DIR / "summary.csv") + + + +table = table.rename( + columns={ + "model": "Model", + "rank": "Rank", + "rank-aggregation": "Rank aggr.", + "energy-diff-flip-times": "Derivative flips", + "tortuosity": "Tortuosity", + "spearman-compression-energy": "Spearman's coeff. (compression)", + "spearman-tension-energy": "Spearman's coeff. (tension)", + "spearman-compression-derivative": "Spearman's coeff. (compression derivative)", + "missing": "Missing", + }, +) + +table.set_index("Model", inplace=True) + +s = ( + table.style.background_gradient( + cmap="Blues", + subset=["Rank", "Rank aggr."], + ).background_gradient( + cmap="Reds", + subset=[ + "Spearman's coeff. (compression)", + ], + ).background_gradient( + cmap="Reds_r", + subset=[ + "Spearman's coeff. (tension)", + "Spearman's coeff. (compression derivative)", + ], + ).background_gradient( + cmap="RdPu", + subset=["Tortuosity", "Derivative flips"], + ).format( + "{:.5f}", + subset=[ + "Spearman's coeff. (compression)", + "Spearman's coeff. (tension)", + "Spearman's coeff. (compression derivative)", + "Tortuosity", + "Derivative flips", + ], + ) +) + +def render(): + st.dataframe( + s, + use_container_width=True, + ) diff --git a/serve/ranks/wbm_ev.py b/serve/ranks/wbm_ev.py new file mode 100644 index 0000000000000000000000000000000000000000..95e4359bde727ca2d456edd11582f3149632060d --- /dev/null +++ b/serve/ranks/wbm_ev.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import pandas as pd +import streamlit as st + +DATA_DIR = Path("examples/wbm_ev") + + +table = pd.read_csv(DATA_DIR / "summary.csv") + + + +table = table.rename( + columns={ + "model": "Model", + "rank": "Rank", + "rank-aggregation": "Rank aggr.", + "energy-diff-flip-times": "Derivative flips", + "tortuosity": "Tortuosity", + "spearman-compression-energy": "Spearman's coeff. (compression)", + "spearman-tension-energy": "Spearman's coeff. (tension)", + "spearman-compression-derivative": "Spearman's coeff. (compression derivative)", + "missing": "Missing", + }, +) + +table.set_index("Model", inplace=True) + +s = ( + table.style.background_gradient( + cmap="Blues", + subset=["Rank", "Rank aggr."], + ).background_gradient( + cmap="Reds", + subset=[ + "Spearman's coeff. (compression)", + ], + ).background_gradient( + cmap="Reds_r", + subset=[ + "Spearman's coeff. (tension)", + "Spearman's coeff. (compression derivative)", + ], + ).background_gradient( + cmap="RdPu", + subset=["Tortuosity", "Derivative flips"], + ).format( + "{:.5f}", + subset=[ + "Spearman's coeff. (compression)", + "Spearman's coeff. (tension)", + "Spearman's coeff. (compression derivative)", + "Tortuosity", + "Derivative flips", + ], + ) +) + +def render(): + st.dataframe( + s, + use_container_width=True, + ) diff --git a/serve/tasks/eos_bulk.py b/serve/tasks/eos_bulk.py new file mode 100644 index 0000000000000000000000000000000000000000..39c126f0103214f865d79aae1ee846fcf5f54668 --- /dev/null +++ b/serve/tasks/eos_bulk.py @@ -0,0 +1,242 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.colors as pcolors +import plotly.graph_objects as go +import streamlit as st +from ase.db import connect +from scipy import stats + +from mlip_arena.models import REGISTRY as MODELS + +DATA_DIR = Path("examples/eos_bulk") + +st.markdown(""" +# Equation of state (EOS) +""") + +# Control panels at the top +st.markdown("### Methods") +methods_container = st.container(border=True) + +valid_models = [ + model + for model, metadata in MODELS.items() + if Path(__file__).stem in metadata.get("gpu-tasks", []) +] + +# Model selection +selected_models = methods_container.multiselect( + "Select Models", + options=valid_models, + default=valid_models +) + +# Visualization settings +st.markdown("### Visualization Settings") +vis = st.container(border=True) + +# Column settings +ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=2) + +# Color palette selection +all_attributes = dir(pcolors.qualitative) +color_palettes = { + attr: getattr(pcolors.qualitative, attr) + for attr in all_attributes + if isinstance(getattr(pcolors.qualitative, attr), list) +} +color_palettes.pop("__all__", None) + +palette_names = list(color_palettes.keys()) +palette_name = vis.selectbox("Color sequence", options=palette_names, index=22) +color_sequence = color_palettes[palette_name] + +# Stop execution if no models selected +if not selected_models: + st.warning("Please select at least one model to visualize.") + st.stop() + + +def load_wbm_structures(): + """ + Load the WBM structures from a ASE DB file. + """ + with connect(DATA_DIR.parent / "wbm_structures.db") as db: + for row in db.select(): + yield row.toatoms(add_additional_information=True) + + +@st.cache_data +def generate_dataframe(model_name): + fpath = DATA_DIR / f"{model_name}.parquet" + if not fpath.exists(): + return pd.DataFrame() # Return empty dataframe instead of using continue + + df_raw_results = pd.read_parquet(fpath) + + df_analyzed = pd.DataFrame( + columns=[ + "model", + "structure", + "formula", + "volume-ratio", + "energy-delta-per-atom", + "energy-diff-flip-times", + "tortuosity", + "spearman-compression-energy", + "spearman-compression-derivative", + "spearman-tension-energy", + "missing", + ] + ) + + for wbm_struct in load_wbm_structures(): + structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"] + + try: + results = df_raw_results.loc[df_raw_results["id"] == structure_id] + results = results["eos"].values[0] + es = np.array(results["energies"]) + vols = np.array(results["volumes"]) + vol0 = wbm_struct.get_volume() + + indices = np.argsort(vols) + vols = vols[indices] + es = es[indices] + + imine = len(es) // 2 + # min_center_val = np.min(es[imid - 1 : imid + 2]) + # imine = np.where(es == min_center_val)[0][0] + emin = es[imine] + + interpolated_volumes = [ + (vols[i] + vols[i + 1]) / 2 for i in range(len(vols) - 1) + ] + ediff = np.diff(es) + ediff_sign = np.sign(ediff) + mask = ediff_sign != 0 + ediff = ediff[mask] + ediff_sign = ediff_sign[mask] + ediff_flip = np.diff(ediff_sign) != 0 + + etv = np.sum(np.abs(np.diff(es))) + + data = { + "model": model_name, + "structure": structure_id, + "formula": wbm_struct.get_chemical_formula(), + "missing": False, + "volume-ratio": vols / vol0, + "energy-delta-per-atom": (es - emin) / len(wbm_struct), + "energy-diff-flip-times": np.sum(ediff_flip).astype(int), + "tortuosity": etv / (abs(es[0] - emin) + abs(es[-1] - emin)), + "spearman-compression-energy": stats.spearmanr( + vols[:imine], es[:imine] + ).statistic, + "spearman-compression-derivative": stats.spearmanr( + interpolated_volumes[:imine], ediff[:imine] + ).statistic, + "spearman-tension-energy": stats.spearmanr( + vols[imine:], es[imine:] + ).statistic, + } + + except Exception: + data = { + "model": model_name, + "structure": structure_id, + "formula": wbm_struct.get_chemical_formula(), + "missing": True, + "volume-ratio": None, + "energy-delta-per-atom": None, + "energy-diff-flip-times": None, + "tortuosity": None, + "spearman-compression-energy": None, + "spearman-compression-derivative": None, + "spearman-tension-energy": None, + } + + df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True) + + return df_analyzed + + +@st.cache_data +def get_plots(selected_models): + """Generate one plot per model with all structures (legend disabled for each structure).""" + figs = [] + + for model_name in selected_models: + + fpath = DATA_DIR / f"{model_name}_processed.parquet" + if not fpath.exists(): + df = generate_dataframe(model_name) + else: + df = pd.read_parquet(fpath) + + if len(df) == 0: + continue + + fig = go.Figure() + valid_structures = [] + for i, (_, row) in enumerate(df.iterrows()): + structure_id = row["structure"] + formula = row.get("formula", "") + if isinstance(row["volume-ratio"], list | np.ndarray) and isinstance( + row["energy-delta-per-atom"], list | np.ndarray + ): + vol_strain = row["volume-ratio"] + energy_delta = row["energy-delta-per-atom"] + color = color_sequence[i % len(color_sequence)] + fig.add_trace( + go.Scatter( + x=vol_strain, + y=energy_delta, + mode="lines", + name=f"{structure_id}", + showlegend=False, + line=dict(color=color), + hoverlabel=dict(bgcolor=color, font=dict(color="black")), + hovertemplate=( + structure_id + "
" + "Formula: " + str(formula) + "
" + "Volume ratio V/Vโ‚€: %{x:.3f}
" + "ฮ”Energy: %{y:.3f} eV/atom
" + "" + ), + + ) + ) + valid_structures.append(structure_id) + + # if valid_structures: + fig.update_layout( + title=f"{model_name} ({len(valid_structures)} / {len(df)} structures)", + xaxis_title="Volume ratio V/Vโ‚€", + yaxis_title="Relative energy ฮ”E (eV/atom)", + height=500, + showlegend=False, # Disable legend for the whole plot + yaxis=dict(range=[-0.1, 1]), # Set y-axis limits + ) + fig.add_vline(x=1, line_dash="dash", line_color="gray", opacity=0.7) + figs.append((model_name, fig, valid_structures)) + + return figs + + +# Generate all plots +all_plots = get_plots(selected_models) + +# Display plots in the specified column layout +if all_plots: + for i, (model_name, fig, structures) in enumerate(all_plots): + if i % ncols == 0: + cols = st.columns(ncols) + cols[i % ncols].plotly_chart(fig, use_container_width=True) + + # Display number of structures in this plot + # cols[i % ncols].caption(f"{len(structures)} / 1000 structures") +else: + st.warning("No data available for the selected models.") diff --git a/serve/tasks/wbm_ev.py b/serve/tasks/wbm_ev.py new file mode 100644 index 0000000000000000000000000000000000000000..a292ba1893ed8a48df2e2e1397fad40936f8ece7 --- /dev/null +++ b/serve/tasks/wbm_ev.py @@ -0,0 +1,243 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.colors as pcolors +import plotly.graph_objects as go +import streamlit as st +from ase.db import connect +from scipy import stats + +from mlip_arena.models import REGISTRY as MODELS + +DATA_DIR = Path("examples/wbm_ev") + +st.markdown(""" +# Energy-volume scans +""") + +# Control panels at the top +st.markdown("### Methods") +methods_container = st.container(border=True) + +# Get valid models that support wbm_ev +valid_models = [ + model + for model, metadata in MODELS.items() + if Path(__file__).stem in metadata.get("gpu-tasks", []) +] + +# Model selection +selected_models = methods_container.multiselect( + "Select Models", + options=valid_models, + default=valid_models +) + +# Visualization settings +st.markdown("### Visualization Settings") +vis = st.container(border=True) + +# Column settings +ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=2) + +# Color palette selection +all_attributes = dir(pcolors.qualitative) +color_palettes = { + attr: getattr(pcolors.qualitative, attr) + for attr in all_attributes + if isinstance(getattr(pcolors.qualitative, attr), list) +} +color_palettes.pop("__all__", None) + +palette_names = list(color_palettes.keys()) +palette_name = vis.selectbox("Color sequence", options=palette_names, index=22) +color_sequence = color_palettes[palette_name] + +# Stop execution if no models selected +if not selected_models: + st.warning("Please select at least one model to visualize.") + st.stop() + + +def load_wbm_structures(): + """ + Load the WBM structures from a ASE DB file. + """ + with connect(DATA_DIR.parent / "wbm_structures.db") as db: + for row in db.select(): + yield row.toatoms(add_additional_information=True) + + +@st.cache_data +def generate_dataframe(model_name): + fpath = DATA_DIR / f"{model_name}.parquet" + if not fpath.exists(): + return pd.DataFrame() # Return empty dataframe instead of using continue + + df_raw_results = pd.read_parquet(fpath) + + df_analyzed = pd.DataFrame( + columns=[ + "model", + "structure", + "formula", + "volume-ratio", + "energy-delta-per-atom", + "energy-diff-flip-times", + "tortuosity", + "spearman-compression-energy", + "spearman-compression-derivative", + "spearman-tension-energy", + "missing", + ] + ) + + for wbm_struct in load_wbm_structures(): + structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"] + + try: + results = df_raw_results.loc[df_raw_results["id"] == structure_id] + results = results["eos"].values[0] + es = np.array(results["energies"]) + vols = np.array(results["volumes"]) + vol0 = wbm_struct.get_volume() + + indices = np.argsort(vols) + vols = vols[indices] + es = es[indices] + + imine = len(es) // 2 + # min_center_val = np.min(es[imid - 1 : imid + 2]) + # imine = np.where(es == min_center_val)[0][0] + emin = es[imine] + + interpolated_volumes = [ + (vols[i] + vols[i + 1]) / 2 for i in range(len(vols) - 1) + ] + ediff = np.diff(es) + ediff_sign = np.sign(ediff) + mask = ediff_sign != 0 + ediff = ediff[mask] + ediff_sign = ediff_sign[mask] + ediff_flip = np.diff(ediff_sign) != 0 + + etv = np.sum(np.abs(np.diff(es))) + + data = { + "model": model_name, + "structure": structure_id, + "formula": wbm_struct.get_chemical_formula(), + "missing": False, + "volume-ratio": vols / vol0, + "energy-delta-per-atom": (es - emin) / len(wbm_struct), + "energy-diff-flip-times": np.sum(ediff_flip).astype(int), + "tortuosity": etv / (abs(es[0] - emin) + abs(es[-1] - emin)), + "spearman-compression-energy": stats.spearmanr( + vols[:imine], es[:imine] + ).statistic, + "spearman-compression-derivative": stats.spearmanr( + interpolated_volumes[:imine], ediff[:imine] + ).statistic, + "spearman-tension-energy": stats.spearmanr( + vols[imine:], es[imine:] + ).statistic, + } + + except Exception: + data = { + "model": model_name, + "structure": structure_id, + "formula": wbm_struct.get_chemical_formula(), + "missing": True, + "volume-ratio": None, + "energy-delta-per-atom": None, + "energy-diff-flip-times": None, + "tortuosity": None, + "spearman-compression-energy": None, + "spearman-compression-derivative": None, + "spearman-tension-energy": None, + } + + df_analyzed = pd.concat([df_analyzed, pd.DataFrame([data])], ignore_index=True) + + return df_analyzed + + +@st.cache_data +def get_plots(selected_models): + """Generate one plot per model with all structures (legend disabled for each structure).""" + figs = [] + + for model_name in selected_models: + + fpath = DATA_DIR / f"{model_name}_processed.parquet" + if not fpath.exists(): + df = generate_dataframe(model_name) + else: + df = pd.read_parquet(fpath) + + if len(df) == 0: + continue + + fig = go.Figure() + valid_structures = [] + for i, (_, row) in enumerate(df.iterrows()): + structure_id = row["structure"] + formula = row.get("formula", "") + if isinstance(row["volume-ratio"], list | np.ndarray) and isinstance( + row["energy-delta-per-atom"], list | np.ndarray + ): + vol_strain = row["volume-ratio"] + energy_delta = row["energy-delta-per-atom"] + color = color_sequence[i % len(color_sequence)] + fig.add_trace( + go.Scatter( + x=vol_strain, + y=energy_delta, + mode="lines", + name=f"{structure_id}", + showlegend=False, + line=dict(color=color), + hoverlabel=dict(bgcolor=color, font=dict(color="black")), + hovertemplate=( + structure_id + "
" + "Formula: " + str(formula) + "
" + "Volume ratio V/Vโ‚€: %{x:.3f}
" + "ฮ”Energy: %{y:.3f} eV/atom
" + "" + ), + + ) + ) + valid_structures.append(structure_id) + + # if valid_structures: + fig.update_layout( + title=f"{model_name} ({len(valid_structures)} / {len(df)} structures)", + xaxis_title="Volume ratio V/Vโ‚€", + yaxis_title="Relative energy E - Eโ‚€ (eV/atom)", + height=500, + showlegend=False, # Disable legend for the whole plot + yaxis=dict(range=[-1, 15]), # Set y-axis limits + ) + fig.add_vline(x=1, line_dash="dash", line_color="gray", opacity=0.7) + figs.append((model_name, fig, valid_structures)) + + return figs + + +# Generate all plots +all_plots = get_plots(selected_models) + +# Display plots in the specified column layout +if all_plots: + for i, (model_name, fig, structures) in enumerate(all_plots): + if i % ncols == 0: + cols = st.columns(ncols) + cols[i % ncols].plotly_chart(fig, use_container_width=True) + + # Display number of structures in this plot + # cols[i % ncols].caption(f"{len(structures)} / 1000 structures") +else: + st.warning("No data available for the selected models.")