File size: 5,288 Bytes
060ac52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Plot a 3D surface of Euclidean distance over detectability and distortion scores, marking the "sweet spot".
"""
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def _min_max_normalize(values):
    """Return *values* rescaled linearly onto [0, 1].

    A constant array (max == min) has no spread to rescale, so it maps to all
    zeros instead of producing NaNs from a division by zero.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     grid_size=30, contour_size=0.1, method="nearest"):
    """
    Generate a 3D surface plot of Euclidean distance over detectability and
    distortion, with the "sweet spot" (best composite score) marked in red.

    Each metric is min-max normalized to [0, 1] and combined into a composite
    score ``norm_detectability - (norm_distortion + norm_euclidean)``. Higher
    detectability is rewarded; higher distortion and Euclidean distance are
    penalized. The sample with the largest composite score is the sweet spot.
    Ties resolve to the first such sample, as with ``np.argmax``.

    The Euclidean distances are resampled onto a regular
    ``grid_size x grid_size`` grid that spans the detectability and distortion
    ranges, using ``scipy.interpolate.griddata``. The result is drawn as a
    surface with z-contour lines.

    Args:
        detectability_val (array-like): Detectability scores, one per sample.
        distortion_val (array-like): Distortion scores, one per sample.
        euclidean_val (array-like): Euclidean distances, one per sample.
        grid_size (int): Number of grid points along each horizontal axis.
            Defaults to 30.
        contour_size (float): Spacing between z-contour lines. Defaults to 0.1.
        method (str): ``griddata`` interpolation method (``'nearest'``,
            ``'linear'`` or ``'cubic'``). Defaults to ``'nearest'``, which
            works for any point layout. ``'linear'`` and ``'cubic'`` need
            points that are not all collinear, and they leave NaN holes
            outside the points' convex hull.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure containing the surface
        (with z-contours) and a marker trace for the sweet spot.

    Raises:
        ValueError: If the inputs are empty or have different lengths. Also
            raised if ``griddata`` cannot interpolate the data, e.g. a
            degenerate point set for ``'linear'``/``'cubic'``, or a result
            that is entirely NaN.

    Example:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()

    Notes:
        - A metric whose values are all identical carries no ranking
          information. It is normalized to all zeros rather than NaN.
        - The surface uses the "Plasma" colorscale.
    """
    interpolation_error = "griddata could not generate a valid interpolation. Check your input data."

    # Flatten so that nested lists or 2-D arrays are treated as flat samples.
    detectability = np.asarray(detectability_val, dtype=float).ravel()
    distortion = np.asarray(distortion_val, dtype=float).ravel()
    euclidean = np.asarray(euclidean_val, dtype=float).ravel()

    if not detectability.size == distortion.size == euclidean.size:
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have the same length "
            f"(got {detectability.size}, {distortion.size}, {euclidean.size})."
        )
    if detectability.size == 0:
        raise ValueError("At least one data point is required.")

    # Composite score: maximize detectability, minimize distortion and Euclidean distance.
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    sweet_spot_index = np.argmax(composite_score)

    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability/distortion ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_size),
        np.linspace(distortion.min(), distortion.max(), grid_size),
    )

    try:
        z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method=method)
    except (ValueError, RuntimeError) as err:
        # griddata signals bad input with ValueError. Qhull failures on
        # degenerate point sets ('linear'/'cubic') raise QhullError, which
        # subclasses RuntimeError.
        raise ValueError(interpolation_error) from err

    # griddata never returns None. A result that is entirely NaN (e.g. no grid
    # point inside the convex hull) is the real "no valid interpolation" case.
    if np.isnan(z_grid).all():
        raise ValueError(interpolation_error)

    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": euclidean.min(),
                "end": euclidean.max(),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma'
    ))

    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig

if __name__ == "__main__":
    # Five hand-picked samples: each list holds one metric for the same five
    # points, in the same order.
    demo_inputs = {
        "detectability_val": [0.1, 0.3, 0.5, 0.7, 0.9],
        "distortion_val": [0.2, 0.4, 0.6, 0.8, 1.0],
        "euclidean_val": [0.5, 0.3, 0.2, 0.4, 0.6],
    }

    # Build the figure and display it with Plotly's default renderer.
    gen_three_D_plot(**demo_inputs).show()