huathedev commited on
Commit
8becd50
·
1 Parent(s): 7b3d409

Update pages/01_🦷 Segment.py

Browse files
Files changed (1) hide show
  1. pages/01_🦷 Segment.py +42 -7
pages/01_🦷 Segment.py CHANGED
@@ -1,27 +1,25 @@
 
1
  import shutil
2
- import os
3
 
 
4
  import numpy as np
5
  from sklearn import neighbors
6
  from scipy.spatial import distance_matrix
7
  from pygco import cut_from_graph
 
8
  import open3d as o3d
9
  import matplotlib.pyplot as plt
10
  import matplotlib.colors as mcolors
11
  from stqdm import stqdm
12
  import json
13
-
14
- import pyvista as pv
15
  from stpyvista import stpyvista
16
-
17
  import torch
18
  import torch.nn as nn
19
- import torch.nn.functional as F
20
  from torch.autograd import Variable
21
-
22
  import streamlit as st
 
23
 
24
- from streamlit import session_state as session
25
  from PIL import Image
26
 
27
  class TeethApp:
@@ -896,6 +894,43 @@ class Segment(TeethApp):
896
  if segment:
897
  segmentation_main("ZOUIF2W4_upper.obj")
898
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
899
 
900
 
901
  elif inputs == "Upload Scan":
 
1
+ from streamlit import session_state as session
2
  import shutil
 
3
 
4
+ import os
5
  import numpy as np
6
  from sklearn import neighbors
7
  from scipy.spatial import distance_matrix
8
  from pygco import cut_from_graph
9
+ import streamlit_ext as ste
10
  import open3d as o3d
11
  import matplotlib.pyplot as plt
12
  import matplotlib.colors as mcolors
13
  from stqdm import stqdm
14
  import json
 
 
15
  from stpyvista import stpyvista
 
16
  import torch
17
  import torch.nn as nn
 
18
  from torch.autograd import Variable
19
+ import torch.nn.functional as F
20
  import streamlit as st
21
+ import pyvista as pv
22
 
 
23
  from PIL import Image
24
 
25
  class TeethApp:
 
894
  if segment:
895
  segmentation_main("ZOUIF2W4_upper.obj")
896
 
897
+ # Load the JSON file
898
+ with open('ZOUIF2W4_upper.json', 'r') as file:
899
+ labels_data = json.load(file)
900
+
901
+ # Assuming labels_data['labels'] is a list of labels
902
+ labels = labels_data['labels']
903
+
904
+ # Make sure the number of labels matches the number of vertices or faces
905
+ assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
906
+
907
+ # If labels correspond to vertices
908
+ if len(labels) == mesh.n_points:
909
+ mesh.point_data['Labels'] = labels
910
+ # If labels correspond to faces
911
+ elif len(labels) == mesh.n_cells:
912
+ mesh.cell_data['Labels'] = labels
913
+
914
+ # Create a pyvista plotter
915
+ plotter = pv.Plotter()
916
+
917
+ cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
918
+
919
+ colors = cmap(np.linspace(0, 1, 27)) # Generate colors
920
+
921
+ # Convert colors to a format acceptable by PyVista
922
+ colormap = mcolors.ListedColormap(colors)
923
+
924
+ # Add the mesh to the plotter with labels as a scalar field
925
+ #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
926
+ plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
927
+
928
+ # Show the plot
929
+ #plotter.show()
930
+ ## Send to streamlit
931
+ with st.expander("Ground Truth - scroll for zoom", expanded=False):
932
+ stpyvista(plotter)
933
+
934
 
935
 
936
  elif inputs == "Upload Scan":