|
""" |
|
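This module generates an interactive 3D surface plot of detectability, distortion, and Euclidean distance, highlighting the "sweet spot" that maximizes a composite score.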
|
""" |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
from scipy.interpolate import griddata |
|
|
|
def _min_max_normalize(values):
    """Rescale a 1-D float array linearly onto the [0, 1] range.

    A constant array (max == min) has no spread to rescale. It is mapped to all
    zeros rather than to NaN from a 0/0 division, so that metric simply does not
    influence the composite score.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lowest) / span


def _interpolate_surface(x, y, z, grid_size, method):
    """Resample scattered ``(x, y) -> z`` samples onto a regular grid.

    The grid spans ``[x.min(), x.max()]`` by ``[y.min(), y.max()]`` with
    ``grid_size`` points per axis.

    Returns:
        tuple: ``(x_grid, y_grid, z_grid)``, each of shape ``(grid_size, grid_size)``.

    Raises:
        ValueError: If ``griddata`` rejects the input (for example Qhull cannot
            triangulate collinear points with ``method="linear"`` or ``"cubic"``),
            or if the interpolated surface contains no finite value at all.
    """
    x_grid, y_grid = np.meshgrid(
        np.linspace(x.min(), x.max(), grid_size),
        np.linspace(y.min(), y.max(), grid_size),
    )
    try:
        z_grid = griddata((x, y), z, (x_grid, y_grid), method=method)
    except (ValueError, RuntimeError) as err:  # scipy's QhullError subclasses RuntimeError
        raise ValueError(
            "griddata could not generate a valid interpolation. Check your input data."
        ) from err
    # "linear"/"cubic" fill grid points outside the samples' convex hull with NaN.
    # A surface that is NaN everywhere cannot be plotted meaningfully.
    if not np.isfinite(z_grid).any():
        raise ValueError(
            "griddata could not generate a valid interpolation. Check your input data."
        )
    return x_grid, y_grid, z_grid


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_size=30, method="nearest", contour_size=0.1):
    """Build a 3D surface plot of Euclidean distance over detectability and distortion.

    Each metric is min-max normalized to [0, 1]. A composite score is then
    computed per sample:

        composite_score = norm_detectability - (norm_distortion + norm_euclidean)

    This rewards high detectability and penalizes distortion and Euclidean
    distance. The sample with the highest composite score (the first one on
    ties) is the "sweet spot" and is drawn as a red marker on top of the
    surface. A metric whose values are all identical normalizes to zeros and
    does not affect the score.

    The Euclidean distances are resampled onto a regular
    ``grid_size`` x ``grid_size`` grid over the detectability/distortion range
    with ``scipy.interpolate.griddata``. The result is drawn as a surface with
    z-contour lines.

    Args:
        detectability_val (list or array): Detectability scores, one per sample.
        distortion_val (list or array): Distortion scores, one per sample.
        euclidean_val (list or array): Euclidean distances, one per sample.
        grid_size (int, keyword-only): Number of grid points per axis for the
            interpolated surface. Defaults to 30.
        method (str, keyword-only): ``griddata`` interpolation method:
            ``"nearest"`` (default), ``"linear"`` or ``"cubic"``.
            ``"linear"`` and ``"cubic"`` need non-collinear
            (detectability, distortion) points and leave NaN gaps outside
            their convex hull.
        contour_size (float, keyword-only): Spacing between z-contour lines.
            Defaults to 0.1.

    Returns:
        plotly.graph_objects.Figure: Figure containing the surface (with
        z-contours) and a marker labelled "Sweet Spot".

    Raises:
        ValueError: If the inputs are empty, not 1-D, or of different lengths;
            if ``griddata`` cannot interpolate the data; or if the interpolated
            surface is entirely NaN.

    Example:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()

    Notes:
        - The default ``"nearest"`` interpolation produces a piecewise-constant
          (stepped) surface. It is used because it also works when the sample
          points are collinear, as in the example above.
        - The surface uses plotly's "Plasma" colorscale, which is perceptually
          uniform.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    # Validate up front: mismatched lengths would otherwise broadcast silently
    # (e.g. a length-1 input) or surface later as an unrelated IndexError.
    if detectability.ndim != 1 or detectability.size == 0:
        raise ValueError("detectability_val must be a non-empty 1-D sequence.")
    if distortion.shape != detectability.shape or euclidean.shape != detectability.shape:
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have the same length."
        )

    # Maximize detectability while minimizing distortion and distance.
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    sweet_spot_index = int(np.argmax(composite_score))
    sweet_spot_detectability = float(detectability[sweet_spot_index])
    sweet_spot_distortion = float(distortion[sweet_spot_index])
    sweet_spot_euclidean = float(euclidean[sweet_spot_index])

    x_grid, y_grid, z_grid = _interpolate_surface(
        detectability, distortion, euclidean, grid_size, method
    )

    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale="Plasma",
    ))

    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode="markers+text",
        marker=dict(size=10, color="red", symbol="circle"),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title="Detectability Score",
            yaxis_title="Distortion Score",
            zaxis_title="Euclidean Distance",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )

    return fig
|
|
|
if __name__ == "__main__":
    # Demo run: five paired observations, then render the resulting figure.
    sample_detectability = [0.1, 0.3, 0.5, 0.7, 0.9]
    sample_distortion = [0.2, 0.4, 0.6, 0.8, 1.0]
    sample_euclidean = [0.5, 0.3, 0.2, 0.4, 0.6]

    demo_figure = gen_three_D_plot(
        sample_detectability,
        sample_distortion,
        sample_euclidean,
    )
    demo_figure.show()