File size: 5,288 Bytes
060ac52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""
This file contains the code to generate a 3D surface plot of Euclidean distance over detectability and distortion scores.
"""
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
def _min_max_normalize(values):
    """Linearly rescale a non-empty 1-D float array into the [0, 1] range.

    Args:
        values (numpy.ndarray): Non-empty 1-D array of finite floats.

    Returns:
        numpy.ndarray: ``(values - min) / (max - min)``. When every entry is
        identical the range is zero, so an all-zeros array is returned instead
        of dividing by zero (which would yield NaN and corrupt the composite
        score and the ``argmax`` computed from it).
    """
    low = values.min()
    spread = values.max() - low
    if spread == 0:
        return np.zeros_like(values, dtype=float)
    return (values - low) / spread


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     grid_resolution=30, contour_size=0.1):
    """
    Generate a 3D surface plot of Euclidean distance over detectability and
    distortion, highlighting the "sweet spot" that maximizes a composite score.

    The three input sequences are paired element-wise: sample ``i`` is
    ``(detectability_val[i], distortion_val[i], euclidean_val[i])``. Each metric
    is min-max normalized to [0, 1] and the three are combined into a composite
    score. The sample with the highest composite score is the "sweet spot" and
    is drawn as a red marker on top of the surface.

    The surface is obtained by resampling the Euclidean distances onto a regular
    ``grid_resolution`` x ``grid_resolution`` grid spanning the observed
    detectability/distortion ranges, using ``griddata`` with the
    nearest-neighbour method. Z-contour lines are drawn every ``contour_size``
    between the minimum and maximum observed Euclidean distance.

    Args:
        detectability_val (list or array): Detectability scores.
        distortion_val (list or array): Distortion scores; same length as
            ``detectability_val``.
        euclidean_val (list or array): Euclidean distances; same length as
            ``detectability_val``.
        grid_resolution (int, optional): Number of grid points along each of
            the x (detectability) and y (distortion) axes of the interpolated
            surface. Defaults to 30.
        contour_size (float, optional): Spacing between z-contour lines, in the
            units of ``euclidean_val``. Defaults to 0.1.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure containing the surface
        (with z-contour lines) and a marker for the sweet spot.

    Raises:
        ValueError: If any input is empty, not one-dimensional, or contains NaN
            or infinite values, or if the three inputs differ in length.

    Example:
        # Example of usage:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()  # Displays the plot with Plotly's default renderer

    Notes:
        - The composite score is calculated as:
          `composite_score = norm_detectability - (norm_distortion + norm_euclidean)`,
          where the goal is to maximize detectability and minimize distortion and
          Euclidean distance.
        - A metric whose values are all identical cannot be min-max scaled; it is
          normalized to all zeros, so it does not influence the sweet spot.
        - Ties in the composite score resolve to the first such sample.
        - `griddata` is called with `method='nearest'`, so each grid node takes
          the Euclidean distance of the closest input sample. This works even for
          collinear inputs (such as the example above), where linear
          interpolation would fail.
        - The surface uses the "Plasma" colorscale, which is perceptually uniform.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    # Validate up front: mismatched lengths would otherwise broadcast silently
    # (e.g. a length-1 array against length-N ones), and NaN values would be
    # selected as the sweet spot because argmax returns the first NaN.
    for name, values in (("detectability_val", detectability),
                         ("distortion_val", distortion),
                         ("euclidean_val", euclidean)):
        if values.ndim != 1 or values.size == 0:
            raise ValueError(f"{name} must be a non-empty 1-D sequence of numbers.")
        if not np.all(np.isfinite(values)):
            raise ValueError(f"{name} must contain only finite values (no NaN/inf).")
    if not (detectability.size == distortion.size == euclidean.size):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have the same length."
        )

    # Normalize each metric to [0, 1] so all three contribute on a common scale.
    norm_detectability = _min_max_normalize(detectability)
    norm_distortion = _min_max_normalize(distortion)
    norm_euclidean = _min_max_normalize(euclidean)

    # Composite score: reward detectability, penalize distortion and distance.
    composite_score = norm_detectability - (norm_distortion + norm_euclidean)

    # The sweet spot is the input sample with the highest composite score.
    sweet_spot_index = int(np.argmax(composite_score))
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability/distortion ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_resolution),
        np.linspace(distortion.min(), distortion.max(), grid_resolution),
    )

    # Nearest-neighbour needs no triangulation, so it handles collinear
    # samples and leaves no NaN holes outside the samples' convex hull.
    z_grid = griddata((detectability, distortion), euclidean,
                      (x_grid, y_grid), method='nearest')

    # Surface with z-contour lines coloured by the Plasma colorscale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma',
    ))

    # Red marker (with label) at the sweet-spot sample.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    # Axis titles and zero margins.
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig
if __name__ == "__main__":
    # Example input data: element i of each list describes the same sample.
    detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
    distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
    euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
    # Build the figure from the example data.
    fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
    # Display the figure with Plotly's default renderer.
    fig.show()