|
""" |
|
Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. |
|
|
|
Copyright 2024. UChicago Argonne, LLC. This software was produced |
|
under U.S. Government contract DE-AC02-06CH11357 for Argonne National |
|
Laboratory (ANL), which is operated by UChicago Argonne, LLC for the |
|
U.S. Department of Energy. The U.S. Government has rights to use, |
|
reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR |
|
UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR |
|
ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is |
|
modified to produce derivative works, such modified software should |
|
be clearly marked, so as not to confuse it with the version available |
|
from ANL. |
|
|
|
Additionally, redistribution and use in source and binary forms, with |
|
or without modification, are permitted provided that the following |
|
conditions are met: |
|
|
|
* Redistributions of source code must retain the above copyright |
|
notice, this list of conditions and the following disclaimer. |
|
|
|
* Redistributions in binary form must reproduce the above copyright |
|
notice, this list of conditions and the following disclaimer in |
|
the documentation and/or other materials provided with the |
|
distribution. |
|
|
|
* Neither the name of UChicago Argonne, LLC, Argonne National |
|
Laboratory, ANL, the U.S. Government, nor the names of its |
|
contributors may be used to endorse or promote products derived |
|
from this software without specific prior written permission. |
|
|
|
THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS |
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS |
|
FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago |
|
Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
|
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
|
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
|
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT |
|
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN |
|
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
|
POSSIBILITY OF SUCH DAMAGE. |
|
""" |
|
|
|
|
|
|
|
import marimo |
|
|
|
__generated_with = "0.10.2" |
|
app = marimo.App(width="medium") |
|
|
|
|
|
@app.cell |
|
def _(__file__): |
|
import sys |
|
from math import floor, ceil, acos, pi |
|
import numpy as np |
|
import plotly.express as px |
|
import marimo as mo |
|
|
|
from mapstorch.io import read_dataset |
|
from mapstorch.util import PeriodicTableWidget |
|
from mapstorch.default import ( |
|
default_fitting_elems, |
|
unsupported_elements, |
|
supported_elements_mapping, |
|
) |
|
|
|
return ( |
|
PeriodicTableWidget, |
|
acos, |
|
ceil, |
|
default_fitting_elems, |
|
floor, |
|
mo, |
|
np, |
|
pi, |
|
px, |
|
read_dataset, |
|
supported_elements_mapping, |
|
sys, |
|
unsupported_elements, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(mo): |
|
dataset = mo.ui.file_browser( |
|
filetypes=[".h5", ".h50", ".h51", ".h52", ".h53", ".h54", ".h55"], |
|
multiple=False, |
|
restrict_navigation=True |
|
) |
|
mo.md(f"Please select the dataset file (h5 file) \n{dataset}") |
|
return (dataset,) |
|
|
|
|
|
@app.cell |
|
def _(mo): |
|
int_spec_path = mo.ui.dropdown( |
|
["MAPS/int_spec"], |
|
value="MAPS/int_spec", |
|
label="Integrated spectrum location", |
|
) |
|
elem_path = mo.ui.dropdown( |
|
["MAPS/channel_names"], |
|
value="MAPS/channel_names", |
|
label="Energy channel names location", |
|
) |
|
dataset_button = mo.ui.run_button(label="Load") |
|
mo.hstack( |
|
[int_spec_path, elem_path, dataset_button], justify="start", gap=1 |
|
).right() |
|
return dataset_button, elem_path, int_spec_path |
|
|
|
|
|
@app.cell |
|
def _(int_spec_og, mo): |
|
energy_range = mo.ui.range_slider( |
|
start=0, |
|
stop=int_spec_og.shape[-1] - 1, |
|
step=1, |
|
label="Energy range", |
|
value=[50, 1450], |
|
full_width=True, |
|
) |
|
return (energy_range,) |
|
|
|
|
|
@app.cell |
|
def _(energy_range, int_spec_og, mo, peaks): |
|
incident_energy_slider = mo.ui.slider( |
|
start=6, |
|
stop=18, |
|
step=0.01, |
|
value=12, |
|
label="Incident Energy (keV)", |
|
full_width=True, |
|
) |
|
compton_peak_value = ( |
|
(int_spec_og.shape[-1] - 1) // 2 if len(peaks) < 8 else peaks[-2] |
|
) |
|
compton_peak_slider = mo.ui.slider( |
|
start=0, |
|
stop=int_spec_og.shape[-1] - 1, |
|
step=1, |
|
value=compton_peak_value, |
|
label="Compton Peak Position", |
|
full_width=True, |
|
) |
|
elastic_peak_value = ( |
|
(int_spec_og.shape[-1] - 1) // 1.9 if len(peaks) < 8 else peaks[-1] |
|
) |
|
elastic_peak_slider = mo.ui.slider( |
|
start=0, |
|
stop=int_spec_og.shape[-1] - 1, |
|
step=1, |
|
value=elastic_peak_value, |
|
label="Elastic Peak Position", |
|
full_width=True, |
|
) |
|
mo.vstack( |
|
[incident_energy_slider, energy_range, compton_peak_slider, elastic_peak_slider] |
|
) |
|
return ( |
|
compton_peak_slider, |
|
compton_peak_value, |
|
elastic_peak_slider, |
|
elastic_peak_value, |
|
incident_energy_slider, |
|
) |
|
|
|
|
|
@app.cell |
|
def _( |
|
compton_peak_slider, |
|
elastic_peak_slider, |
|
int_spec, |
|
int_spec_log, |
|
mo, |
|
np, |
|
peaks, |
|
): |
|
from plotly.subplots import make_subplots |
|
import plotly.graph_objects as go |
|
|
|
int_spec_fig = make_subplots(rows=2, cols=1) |
|
|
|
|
|
int_spec_fig.append_trace( |
|
go.Scatter( |
|
x=np.arange(len(int_spec)), y=int_spec, mode="lines", name="Photon counts" |
|
), |
|
row=1, |
|
col=1, |
|
) |
|
int_spec_fig.append_trace( |
|
go.Scatter( |
|
x=np.arange(len(int_spec)), y=int_spec_log, mode="lines", name="Log scale" |
|
), |
|
row=2, |
|
col=1, |
|
) |
|
int_spec_fig.append_trace( |
|
go.Scatter( |
|
x=peaks, |
|
y=int_spec[peaks], |
|
mode="markers", |
|
marker_color="#00cc96", |
|
showlegend=False, |
|
), |
|
row=1, |
|
col=1, |
|
) |
|
int_spec_fig.append_trace( |
|
go.Scatter( |
|
x=peaks, |
|
y=int_spec_log[peaks], |
|
mode="markers", |
|
name="Peaks", |
|
marker_color="#00cc96", |
|
), |
|
row=2, |
|
col=1, |
|
) |
|
|
|
|
|
int_spec_fig.add_vline( |
|
x=compton_peak_slider.value, line_width=1, line_color="#ab63fa" |
|
) |
|
int_spec_fig.add_vline( |
|
x=elastic_peak_slider.value, line_width=1, line_color="#ffa15a" |
|
) |
|
|
|
int_spec_fig.update_layout(showlegend=False) |
|
|
|
mo.ui.plotly(int_spec_fig) |
|
return go, int_spec_fig, make_subplots |
|
|
|
|
|
@app.cell |
|
def _(configs, elem_selection, mo, param_selection): |
|
control_panel = mo.accordion( |
|
{ |
|
"Elements": elem_selection.center(), |
|
"Parameters": param_selection, |
|
"Configs": configs, |
|
}, |
|
multiple=True, |
|
) |
|
control_panel_shown = True |
|
control_panel |
|
return control_panel, control_panel_shown |
|
|
|
|
|
@app.cell |
|
def _(control_panel_shown, mo): |
|
run_button = mo.ui.run_button() |
|
run_button.right() if control_panel_shown else None |
|
return (run_button,) |
|
|
|
|
|
@app.cell |
|
def _( |
|
device_selection, |
|
elem_checkboxes, |
|
energy_range, |
|
fit_indices_list, |
|
init_amp_checkbox, |
|
int_spec_og, |
|
iter_slider, |
|
loss_selection, |
|
mo, |
|
optimizer_selection, |
|
param_checkboxes, |
|
run_button, |
|
use_snip_checkbox, |
|
use_step_checkbox, |
|
): |
|
mo.stop(not run_button.value) |
|
from mapstorch.opt import fit_spec |
|
|
|
n_iter = iter_slider.value |
|
with mo.status.progress_bar(total=n_iter) as bar: |
|
fitted_tensors, fitted_spec, fitted_bkg, loss_trace = fit_spec( |
|
int_spec_og, |
|
energy_range.value, |
|
elements_to_fit=[k for k, v in elem_checkboxes.items() if v.value], |
|
fitting_params=[k for k, v in param_checkboxes.items() if v[0].value], |
|
init_param_vals={ |
|
k: float(v[2].value) for k, v in param_checkboxes.items() if v[0].value |
|
}, |
|
fixed_param_vals={ |
|
k: float(v[2].value) |
|
for k, v in param_checkboxes.items() |
|
if v[0].value and v[1].value |
|
}, |
|
indices=fit_indices_list[-1], |
|
tune_params=True, |
|
init_amp=init_amp_checkbox.value, |
|
use_snip=use_snip_checkbox.value, |
|
use_step=use_step_checkbox.value, |
|
use_tail=False, |
|
loss=loss_selection.value, |
|
optimizer=optimizer_selection.value, |
|
n_iter=n_iter, |
|
progress_bar=False, |
|
device=device_selection.value, |
|
status_updator=bar, |
|
) |
|
return ( |
|
bar, |
|
fit_spec, |
|
fitted_bkg, |
|
fitted_spec, |
|
fitted_tensors, |
|
loss_trace, |
|
n_iter, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(fitted_bkg, fitted_spec, go, int_spec, make_subplots, mo, np, px): |
|
fit_labels = ["experiment", "background", "fitted"] |
|
fit_fig = make_subplots(rows=2, cols=1) |
|
spec_x = np.linspace(0, int_spec.size - 1, int_spec.size) |
|
|
|
for i, spec in enumerate([int_spec, fitted_bkg, fitted_spec + fitted_bkg]): |
|
fit_fig.add_trace( |
|
go.Scatter( |
|
x=spec_x, |
|
y=spec, |
|
mode="lines", |
|
name=fit_labels[i], |
|
line=dict(color=px.colors.qualitative.Plotly[i]), |
|
), |
|
row=1, |
|
col=1, |
|
) |
|
spec_log = np.log10(np.clip(spec, 0, None) + 1) |
|
fit_fig.add_trace( |
|
go.Scatter( |
|
x=spec_x, |
|
y=spec_log, |
|
mode="lines", |
|
showlegend=False, |
|
line=dict(color=px.colors.qualitative.Plotly[i]), |
|
), |
|
row=2, |
|
col=1, |
|
) |
|
|
|
mo.ui.plotly(fit_fig) |
|
return fit_fig, fit_labels, i, spec, spec_log, spec_x |
|
|
|
|
|
@app.cell |
|
def _(elem_checkboxes, fitted_tensors, go, make_subplots, mo): |
|
target_elems = [k for k, v in elem_checkboxes.items() if v.value] |
|
amps = {p: fitted_tensors[p].item() for p in target_elems} |
|
amps = dict(sorted(amps.items(), key=lambda item: item[1])) |
|
|
|
amp_fig = make_subplots(rows=1, cols=2) |
|
|
|
amp_fig.add_trace( |
|
go.Bar( |
|
x=[10**v for v in amps.values()], |
|
y=list(amps.keys()), |
|
orientation="h", |
|
name="Photon counts", |
|
), |
|
row=1, |
|
col=1, |
|
) |
|
amp_fig.add_trace(go.Bar(x=list(amps.values()), name="Log scale"), row=1, col=2) |
|
amp_fig.update_yaxes(showticklabels=False, row=1, col=2) |
|
amp_fig.update_layout(showlegend=False) |
|
|
|
results_shown = True |
|
|
|
mo.ui.plotly(amp_fig) |
|
return amp_fig, amps, results_shown, target_elems |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.cell |
|
def _(confirm_range_button, elem_peak_indices, fit_indices_list, mo): |
|
mo.stop(not confirm_range_button.value) |
|
|
|
fit_indices_list[-1] = elem_peak_indices |
|
mo.callout( |
|
"Target ranges have been updated. Please select parameters to re-run the fitting process.", |
|
kind="success", |
|
) |
|
return |
|
|
|
|
|
@app.cell |
|
def _(dataset, elem_checkboxes, fitted_tensors, mo, params_record): |
|
import pandas as pd |
|
import datetime |
|
|
|
for par, l in params_record.items(): |
|
if par == "elements": |
|
l.append(",".join([k for k, v in elem_checkboxes.items() if v.value])) |
|
else: |
|
l.append(fitted_tensors[par].item()) |
|
today = datetime.date.today() |
|
today_string = today.strftime("%Y-%m-%d") |
|
table_label = dataset.value[0].name + " parameter tuning record " + today_string |
|
params_table = mo.ui.table( |
|
pd.DataFrame(params_record), |
|
selection="single", |
|
label=table_label, |
|
show_download=False, |
|
) |
|
params_table |
|
return ( |
|
datetime, |
|
l, |
|
par, |
|
params_table, |
|
pd, |
|
table_label, |
|
today, |
|
today_string, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(load_params_button, mo, params_table, save_params_button): |
|
( |
|
mo.hstack([load_params_button]).right() |
|
if len(params_table.value) > 0 |
|
else None |
|
) |
|
return |
|
|
|
|
|
@app.cell |
|
def _(mo): |
|
load_params_button = mo.ui.button(label="Load selected parameters and re-run") |
|
save_params_button = mo.ui.run_button(label="Generate override params file") |
|
return load_params_button, save_params_button |
|
|
|
|
|
@app.cell |
|
def _(mo, params_table, save_params_button): |
|
mo.stop(not save_params_button.value) |
|
from mapstorch.io import write_override_params_file |
|
|
|
write_override_params_file( |
|
"maps_fit_parameters_override.txt", |
|
param_values={ |
|
"COHERENT_SCT_ENERGY": params_table.value.iloc[0][ |
|
"COHERENT_SCT_ENERGY" |
|
].item(), |
|
"ENERGY_OFFSET": params_table.value.iloc[0]["ENERGY_OFFSET"].item(), |
|
"ENERGY_SLOPE": params_table.value.iloc[0]["ENERGY_SLOPE"].item(), |
|
"ENERGY_QUADRATIC": params_table.value.iloc[0]["ENERGY_QUADRATIC"].item(), |
|
"COMPTON_ANGLE": params_table.value.iloc[0]["COMPTON_ANGLE"].item(), |
|
"COMPTON_FWHM_CORR": params_table.value.iloc[0]["COMPTON_FWHM_CORR"].item(), |
|
"COMPTON_HI_F_TAIL": params_table.value.iloc[0]["COMPTON_HI_F_TAIL"].item(), |
|
"COMPTON_F_TAIL": params_table.value.iloc[0]["COMPTON_F_TAIL"].item(), |
|
"FWHM_FANOPRIME": params_table.value.iloc[0]["FWHM_FANOPRIME"].item(), |
|
"FWHM_OFFSET": params_table.value.iloc[0]["FWHM_OFFSET"].item(), |
|
"F_TAIL_OFFSET": params_table.value.iloc[0]["F_TAIL_OFFSET"].item(), |
|
"KB_F_TAIL_OFFSET": params_table.value.iloc[0]["KB_F_TAIL_OFFSET"].item(), |
|
}, |
|
elements=params_table.value.iloc[0]["elements"].split(","), |
|
) |
|
return (write_override_params_file,) |
|
|
|
|
|
@app.cell |
|
def _( |
|
elem_peak_shapes, |
|
fit_labels, |
|
fitted_bkg, |
|
fitted_spec, |
|
focus_target_switch, |
|
go, |
|
int_spec, |
|
make_subplots, |
|
np, |
|
px, |
|
spec_x, |
|
): |
|
range_fig = make_subplots(rows=2, cols=1) |
|
|
|
for iii, specc in enumerate([int_spec, fitted_bkg, fitted_spec + fitted_bkg]): |
|
range_fig.add_trace( |
|
go.Scatter( |
|
x=spec_x, |
|
y=specc, |
|
mode="lines", |
|
name=fit_labels[iii], |
|
line=dict(color=px.colors.qualitative.Plotly[iii]), |
|
), |
|
row=1, |
|
col=1, |
|
) |
|
specc_log = np.log10(np.clip(specc, 0, None) + 1) |
|
range_fig.add_trace( |
|
go.Scatter( |
|
x=spec_x, |
|
y=specc_log, |
|
mode="lines", |
|
showlegend=False, |
|
line=dict(color=px.colors.qualitative.Plotly[iii]), |
|
), |
|
row=2, |
|
col=1, |
|
) |
|
|
|
if focus_target_switch.value: |
|
range_fig.update_layout(shapes=elem_peak_shapes, overwrite=True) |
|
return iii, range_fig, specc, specc_log |
|
|
|
|
|
@app.cell |
|
def _(mo): |
|
focus_target_switch = mo.ui.switch(label="Focus on target elements", value=False) |
|
energy_level_slider = mo.ui.slider( |
|
start=1, stop=6, step=1, value=1, label="Energy levels" |
|
) |
|
confirm_range_button = mo.ui.run_button(label="Load target ranges") |
|
return confirm_range_button, energy_level_slider, focus_target_switch |
|
|
|
|
|
@app.cell |
|
def _(fit_indices_list, focus_target_switch): |
|
if not focus_target_switch.value: |
|
fit_indices_list[-1] = None |
|
return |
|
|
|
|
|
@app.cell |
|
def _(elem_checkboxes, energy_level_slider, energy_range, fitted_tensors): |
|
from mapstorch.util import get_peak_ranges |
|
|
|
plot_elems = [k for k, v in elem_checkboxes.items() if v.value] |
|
elem_peak_indices = [] |
|
elem_peak_shapes = [] |
|
for ii, ee in enumerate(plot_elems): |
|
peak_rg = get_peak_ranges( |
|
[ee], |
|
fitted_tensors["COHERENT_SCT_ENERGY"].item(), |
|
fitted_tensors["COMPTON_ANGLE"].item(), |
|
fitted_tensors["ENERGY_OFFSET"].item(), |
|
fitted_tensors["ENERGY_SLOPE"].item(), |
|
fitted_tensors["ENERGY_QUADRATIC"].item(), |
|
energy_range.value, |
|
) |
|
alpha = 0.2 |
|
for p_n, r in peak_rg.items(): |
|
if ( |
|
p_n in ["COMPTON_AMPLITUDE", "COHERENT_SCT_AMPLITUDE"] |
|
or int(p_n[-1]) < energy_level_slider.value |
|
): |
|
elem_peak_indices += list(range(*r)) |
|
elem_peak_shapes.append( |
|
dict( |
|
type="rect", |
|
x0=r[0], |
|
x1=r[1], |
|
y0=0, |
|
y1=1, |
|
xref="x", |
|
yref="paper", |
|
fillcolor="yellow", |
|
opacity=alpha, |
|
layer="below", |
|
line_width=0, |
|
) |
|
) |
|
return ( |
|
alpha, |
|
ee, |
|
elem_peak_indices, |
|
elem_peak_shapes, |
|
get_peak_ranges, |
|
ii, |
|
p_n, |
|
peak_rg, |
|
plot_elems, |
|
r, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(param_checkbox_vals, param_default_vals, params_table): |
|
if len(params_table.value) > 0: |
|
for pp in params_table.value: |
|
if pp in param_checkbox_vals: |
|
param_checkbox_vals[pp] = float(params_table.value[pp].item()) |
|
else: |
|
for pp in param_checkbox_vals: |
|
param_checkbox_vals[pp] = float(param_default_vals[pp]) |
|
return (pp,) |
|
|
|
|
|
@app.cell |
|
def _( |
|
PeriodicTableWidget, |
|
default_fitting_elems, |
|
elems, |
|
mo, |
|
unsupported_elements, |
|
): |
|
initial_selected_elems = set() |
|
for e in default_fitting_elems: |
|
if e in elems and not e in ["COHERENT_SCT_AMPLITUDE", "COMPTON_AMPLITUDE"]: |
|
initial_selected_elems.add(e.split("_")[0]) |
|
|
|
elem_selection = mo.ui.anywidget( |
|
PeriodicTableWidget( |
|
states=1, |
|
initial_selected={se: 0 for se in initial_selected_elems}, |
|
initial_disabled=unsupported_elements, |
|
) |
|
) |
|
return e, elem_selection, initial_selected_elems |
|
|
|
|
|
@app.cell |
|
def _( |
|
default_fitting_elems, |
|
elem_selection, |
|
mo, |
|
supported_elements_mapping, |
|
): |
|
selected_lines = ["COHERENT_SCT_AMPLITUDE", "COMPTON_AMPLITUDE", "Si_Si"] |
|
for select_e in elem_selection.selected_elements: |
|
for l_ in supported_elements_mapping[select_e]: |
|
if l_ == "K": |
|
selected_lines.append(select_e) |
|
else: |
|
selected_lines.append(select_e + "_" + l_) |
|
elem_checkboxes = {} |
|
for edf in default_fitting_elems: |
|
if edf in selected_lines: |
|
elem_checkboxes[edf] = mo.ui.checkbox(label=edf, value=True) |
|
else: |
|
elem_checkboxes[edf] = mo.ui.checkbox(label=edf, value=False) |
|
return edf, elem_checkboxes, l_, select_e, selected_lines |
|
|
|
|
|
@app.cell |
|
def _(default_fitting_params, load_params_button, mo, param_checkbox_vals): |
|
load_params_button |
|
|
|
param_checkboxes = {} |
|
for p in default_fitting_params: |
|
param_checkboxes[p] = [ |
|
mo.ui.checkbox(label=p, value=True), |
|
mo.ui.checkbox(label="Fix"), |
|
mo.ui.text(value=str(param_checkbox_vals[p])), |
|
] |
|
|
|
param_selection = mo.vstack( |
|
[ |
|
mo.hstack(param_checkboxes[p], justify="start", gap=0) |
|
for p in default_fitting_params |
|
] |
|
) |
|
return p, param_checkboxes, param_selection |
|
|
|
|
|
@app.cell |
|
def _(device_list, mo): |
|
init_amp_checkbox = mo.ui.checkbox(label="Initialize amplitudes", value=True) |
|
use_snip_checkbox = mo.ui.checkbox(label="Use SNIP background", value=True) |
|
use_step_checkbox = mo.ui.checkbox(label="Modify peaks with step", value=True) |
|
model_options = mo.hstack( |
|
[init_amp_checkbox, use_snip_checkbox, use_step_checkbox], |
|
justify="start", |
|
gap=5, |
|
) |
|
iter_slider = mo.ui.slider( |
|
value=500, start=100, stop=3000, step=50, label="number of iterations" |
|
) |
|
loss_selection = mo.ui.dropdown(["mse", "l1"], value="mse", label="loss") |
|
optimizer_selection = mo.ui.dropdown( |
|
["adam", "adamw"], value="adam", label="optimizer" |
|
) |
|
device_selection = mo.ui.dropdown(device_list, value="cpu", label="device") |
|
opt_options = mo.hstack( |
|
[device_selection, loss_selection, optimizer_selection, iter_slider], |
|
justify="start", |
|
gap=2, |
|
) |
|
configs = mo.vstack([model_options, opt_options]) |
|
return ( |
|
configs, |
|
device_selection, |
|
init_amp_checkbox, |
|
iter_slider, |
|
loss_selection, |
|
model_options, |
|
opt_options, |
|
optimizer_selection, |
|
use_snip_checkbox, |
|
use_step_checkbox, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(): |
|
import torch |
|
|
|
device_list = ["cpu"] |
|
if torch.cuda.is_available(): |
|
device_list.append("cuda") |
|
return device_list, torch |
|
|
|
|
|
@app.cell |
|
def _(): |
|
fit_indices_list = [None] |
|
return (fit_indices_list,) |
|
|
|
|
|
@app.cell |
|
def _( |
|
acos, |
|
compton_peak_slider, |
|
elastic_peak_slider, |
|
incident_energy_slider, |
|
pi, |
|
): |
|
from mapstorch.default import default_param_vals, default_fitting_params |
|
from copy import copy |
|
|
|
coherent_sct_energy = incident_energy_slider.value |
|
energy_slope = coherent_sct_energy / elastic_peak_slider.value |
|
compton_energy = energy_slope * compton_peak_slider.value |
|
try: |
|
compton_angle = ( |
|
acos(1 - 511 * (1 / compton_energy - 1 / coherent_sct_energy)) * 180 / pi |
|
) |
|
except: |
|
compton_angle = default_param_vals["COMPTON_ANGLE"] |
|
param_default_vals = copy(default_param_vals) |
|
param_default_vals["COHERENT_SCT_ENERGY"] = coherent_sct_energy |
|
param_default_vals["ENERGY_SLOPE"] = energy_slope |
|
param_default_vals["COMPTON_ANGLE"] = compton_angle |
|
param_checkbox_vals = copy(param_default_vals) |
|
params_record = {p: [] for p in default_fitting_params + ["elements"]} |
|
return ( |
|
coherent_sct_energy, |
|
compton_angle, |
|
compton_energy, |
|
copy, |
|
default_fitting_params, |
|
default_param_vals, |
|
energy_slope, |
|
param_checkbox_vals, |
|
param_default_vals, |
|
params_record, |
|
) |
|
|
|
|
|
@app.cell |
|
def _(int_spec): |
|
from scipy.signal import find_peaks |
|
|
|
peaks, _ = find_peaks(int_spec, prominence=int_spec.max() / 200) |
|
return find_peaks, peaks |
|
|
|
|
|
@app.cell |
|
def _(energy_range, int_spec_og, np): |
|
int_spec = int_spec_og[energy_range.value[0] : energy_range.value[1] + 1] |
|
int_spec_log = np.log10(np.clip(int_spec, 0, None) + 1) |
|
return int_spec, int_spec_log |
|
|
|
|
|
@app.cell |
|
def _(dataset, dataset_button, elem_path, int_spec_path, mo, read_dataset): |
|
mo.stop(not dataset_button.value) |
|
dataset_dict = read_dataset( |
|
dataset.value[0].path, |
|
fit_elem_key=elem_path.value, |
|
int_spec_key=int_spec_path.value, |
|
) |
|
int_spec_og = dataset_dict["int_spec"] |
|
elems = dataset_dict["elems"] |
|
return dataset_dict, elems, int_spec_og |
|
|
|
|
|
if __name__ == "__main__": |
|
app.run() |
|
|