Update pages/01_🦷 Segment.py
Browse files- pages/01_🦷 Segment.py +42 -7
pages/01_🦷 Segment.py
CHANGED
@@ -1,27 +1,25 @@
|
|
|
|
1 |
import shutil
|
2 |
-
import os
|
3 |
|
|
|
4 |
import numpy as np
|
5 |
from sklearn import neighbors
|
6 |
from scipy.spatial import distance_matrix
|
7 |
from pygco import cut_from_graph
|
|
|
8 |
import open3d as o3d
|
9 |
import matplotlib.pyplot as plt
|
10 |
import matplotlib.colors as mcolors
|
11 |
from stqdm import stqdm
|
12 |
import json
|
13 |
-
|
14 |
-
import pyvista as pv
|
15 |
from stpyvista import stpyvista
|
16 |
-
|
17 |
import torch
|
18 |
import torch.nn as nn
|
19 |
-
import torch.nn.functional as F
|
20 |
from torch.autograd import Variable
|
21 |
-
|
22 |
import streamlit as st
|
|
|
23 |
|
24 |
-
from streamlit import session_state as session
|
25 |
from PIL import Image
|
26 |
|
27 |
class TeethApp:
|
@@ -896,6 +894,43 @@ class Segment(TeethApp):
|
|
896 |
if segment:
|
897 |
segmentation_main("ZOUIF2W4_upper.obj")
|
898 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
899 |
|
900 |
|
901 |
elif inputs == "Upload Scan":
|
|
|
1 |
+
from streamlit import session_state as session
|
2 |
import shutil
|
|
|
3 |
|
4 |
+
import os
|
5 |
import numpy as np
|
6 |
from sklearn import neighbors
|
7 |
from scipy.spatial import distance_matrix
|
8 |
from pygco import cut_from_graph
|
9 |
+
import streamlit_ext as ste
|
10 |
import open3d as o3d
|
11 |
import matplotlib.pyplot as plt
|
12 |
import matplotlib.colors as mcolors
|
13 |
from stqdm import stqdm
|
14 |
import json
|
|
|
|
|
15 |
from stpyvista import stpyvista
|
|
|
16 |
import torch
|
17 |
import torch.nn as nn
|
|
|
18 |
from torch.autograd import Variable
|
19 |
+
import torch.nn.functional as F
|
20 |
import streamlit as st
|
21 |
+
import pyvista as pv
|
22 |
|
|
|
23 |
from PIL import Image
|
24 |
|
25 |
class TeethApp:
|
|
|
894 |
if segment:
|
895 |
segmentation_main("ZOUIF2W4_upper.obj")
|
896 |
|
897 |
+
# Load the JSON file
|
898 |
+
with open('ZOUIF2W4_upper.json', 'r') as file:
|
899 |
+
labels_data = json.load(file)
|
900 |
+
|
901 |
+
# Assuming labels_data['labels'] is a list of labels
|
902 |
+
labels = labels_data['labels']
|
903 |
+
|
904 |
+
# Make sure the number of labels matches the number of vertices or faces
|
905 |
+
assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
|
906 |
+
|
907 |
+
# If labels correspond to vertices
|
908 |
+
if len(labels) == mesh.n_points:
|
909 |
+
mesh.point_data['Labels'] = labels
|
910 |
+
# If labels correspond to faces
|
911 |
+
elif len(labels) == mesh.n_cells:
|
912 |
+
mesh.cell_data['Labels'] = labels
|
913 |
+
|
914 |
+
# Create a pyvista plotter
|
915 |
+
plotter = pv.Plotter()
|
916 |
+
|
917 |
+
cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
|
918 |
+
|
919 |
+
colors = cmap(np.linspace(0, 1, 27)) # Generate colors
|
920 |
+
|
921 |
+
# Convert colors to a format acceptable by PyVista
|
922 |
+
colormap = mcolors.ListedColormap(colors)
|
923 |
+
|
924 |
+
# Add the mesh to the plotter with labels as a scalar field
|
925 |
+
#plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
|
926 |
+
plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
|
927 |
+
|
928 |
+
# Show the plot
|
929 |
+
#plotter.show()
|
930 |
+
## Send to streamlit
|
931 |
+
with st.expander("Ground Truth - scroll for zoom", expanded=False):
|
932 |
+
stpyvista(plotter)
|
933 |
+
|
934 |
|
935 |
|
936 |
elif inputs == "Upload Scan":
|