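"""Gradio demo for fossil leaf classification.

Classifies fossil images with ResNet50V2- and BEiT-based triplet models,
computes Sobol/RISE/HSIC/Saliency explanation maps, and retrieves the closest
images in the reference dataset by embedding similarity.
"""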
import os

from env import config_env

config_env()

import dotenv

dotenv.load_dotenv()

import pathlib

import gradio as gr
import tensorflow as tf
from huggingface_hub import snapshot_download

from closest_sample import get_images
from explanations import explain
from inference_beit import get_triplet_model_beit
from inference_resnet import get_triplet_model
from inference_sam import segmentation_sam
# Download the example images and the reference dataset from the Hub on first run.
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),
                      repo_type='dataset', local_dir='images')

if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    if token is None:
        print("Warning: a READ_TOKEN environment variable is needed to authenticate with the Hub.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')
def get_model(model_name):
    """Build the requested backbone and load its classification weights."""
    if model_name == 'Mummified 170':
        n_classes = 170
        model = get_triplet_model(input_shape=(600, 600, 3),
                                  embedding_units=256,
                                  embedding_depth=2,
                                  backbone_class=tf.keras.applications.ResNet50V2,
                                  nb_classes=n_classes,
                                  load_weights=False,
                                  finer_model=True,
                                  backbone_name='Resnet50v2')
        model.load_weights('model_classification/mummified-170.h5')
    elif model_name == 'Rock 170':
        n_classes = 171
        model = get_triplet_model(input_shape=(600, 600, 3),
                                  embedding_units=256,
                                  embedding_depth=2,
                                  backbone_class=tf.keras.applications.ResNet50V2,
                                  nb_classes=n_classes,
                                  load_weights=False,
                                  finer_model=True,
                                  backbone_name='Resnet50v2')
        model.load_weights('model_classification/rock-170.h5')
    elif model_name == 'Fossils 142':
        n_classes = 142
        model = get_triplet_model_beit(input_shape=(384, 384, 3),
                                       embedding_units=256,
                                       embedding_depth=2,
                                       n_classes=n_classes)
        model.load_weights('model_classification/fossil-142.h5')
    else:
        raise ValueError(f"Model name '{model_name}' is not recognized")
    return model, n_classes
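# get_model() rebuilds the network and reloads weights from disk on every call.
# A minimal memoization sketch (stdlib functools; the cache size is an arbitrary
# assumption) would avoid that cost on repeated requests:
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=3)
#   def get_model_cached(model_name):
#       return get_model(model_name)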
def segment_image(input_image):
    img = segmentation_sam(input_image)
    return img
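# Note: segmentation is not currently wired into the UI; the SAM column and its
# "Segment Image" button are commented out in the layout below.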
def classify_image(input_image, model_name):
    #segmented_image = segment_image(input_image)
    if model_name == 'Rock 170':
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)
    elif model_name == 'Mummified 170':
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)
    elif model_name == 'Fossils 142':
        from inference_beit import inference_resnet_finer_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_beit(input_image, model, size=384, n_classes=n_classes)
    return None
def get_embeddings(input_image, model_name):
    if model_name == 'Rock 170':
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)
    elif model_name == 'Mummified 170':
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)
    elif model_name == 'Fossils 142':
        from inference_beit import inference_resnet_embedding_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_beit(input_image, model, size=384, n_classes=n_classes)
    return None
def find_closest(input_image, model_name):
    embedding = get_embeddings(input_image, model_name)
    classes, paths = get_images(embedding)
    return classes, paths
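# get_images() (from closest_sample) is assumed to rank a precomputed embedding
# bank by distance. A minimal stand-in under that assumption, with `bank` as an
# (N, D) array of reference embeddings:
#
#   import numpy as np
#
#   def nearest_indices(embedding, bank, k=5):
#       dists = np.linalg.norm(bank - embedding, axis=1)  # L2 distance to each row
#       return np.argsort(dists)[:k]                      # indices of the k closest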
def explain_image(input_image, model_name):
    model, n_classes = get_model(model_name)
    size = 384 if model_name == 'Fossils 142' else 600
    # explain() returns 20 heatmaps: five each for Sobol, RISE, HSIC and Saliency,
    # in that order, matching the output components wired up below.
    exp_list = explain(model, input_image, size=size, n_classes=n_classes)
    print('Explanations computed.')
    return tuple(exp_list[:20])
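# Sobol and RISE estimate importance by perturbing the input (black-box methods),
# HSIC measures dependence between input regions and the prediction, and
# Saliency uses the gradient of the class score with respect to the input pixels.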
# Minimalist theme for the UI.
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

    with gr.Tab("Florissant Fossils"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input")
                classify_image_button = gr.Button("Classify Image")

            # with gr.Column():
            #     # segmented_image = gr.outputs.Image(label="SAM output", type='numpy')
            #     segmented_image = gr.Image(label="Segmented Image", type='numpy')
            #     segment_button = gr.Button("Segment Image")
            #     # classify_segmented_button = gr.Button("Classify Segmented Image")

            with gr.Column():
                model_name = gr.Dropdown(
                    ["Mummified 170", "Rock 170", "Fossils 142"],
                    multiselect=False,
                    value="Fossils 142",  # default option
                    label="Model",
                    interactive=True,
                )
                class_predicted = gr.Label(label='Class Predicted', num_top_classes=10)

        with gr.Row():
            paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
            samples = [[path.as_posix()] for path in paths if 'fossils' in str(path)][:19]
            examples_fossils = gr.Examples(samples, inputs=input_image, examples_per_page=10,
                                           label='Fossil examples from the dataset')
            samples = [[path.as_posix()] for path in paths if 'leaves' in str(path)][:19]
            examples_leaves = gr.Examples(samples, inputs=input_image, examples_per_page=5,
                                          label='Leaf examples from the dataset')
# with gr.Accordion("Using Diffuser"):
# with gr.Column():
# prompt = gr.Textbox(lines=1, label="Prompt")
# output_image = gr.Image(label="Output")
# generate_button = gr.Button("Generate Leave")
# with gr.Column():
# class_predicted2 = gr.Label(label='Class Predicted from diffuser')
# classify_button = gr.Button("Classify Image")
with gr.Accordion("Explanations "):
gr.Markdown("Computing Explanations from the model")
with gr.Column():
with gr.Row():
#original_input = gr.Image(label="Original Frame")
#saliency = gr.Image(label="saliency")
#gradcam = gr.Image(label='integraged gradients')
#guided_gradcam = gr.Image(label='gradcam')
#guided_backprop = gr.Image(label='guided backprop')
sobol1 = gr.Image(label = 'Sobol1')
sobol2= gr.Image(label = 'Sobol2')
sobol3= gr.Image(label = 'Sobol3')
sobol4= gr.Image(label = 'Sobol4')
sobol5= gr.Image(label = 'Sobol5')
with gr.Row():
rise1 = gr.Image(label = 'Rise1')
rise2 = gr.Image(label = 'Rise2')
rise3 = gr.Image(label = 'Rise3')
rise4 = gr.Image(label = 'Rise4')
rise5 = gr.Image(label = 'Rise5')
with gr.Row():
hsic1 = gr.Image(label = 'HSIC1')
hsic2 = gr.Image(label = 'HSIC2')
hsic3 = gr.Image(label = 'HSIC3')
hsic4 = gr.Image(label = 'HSIC4')
hsic5 = gr.Image(label = 'HSIC5')
with gr.Row():
saliency1 = gr.Image(label = 'Saliency1')
saliency2 = gr.Image(label = 'Saliency2')
saliency3 = gr.Image(label = 'Saliency3')
saliency4 = gr.Image(label = 'Saliency4')
saliency5 = gr.Image(label = 'Saliency5')
generate_explanations = gr.Button("Generate Explanations")
        with gr.Accordion('Closest Images'):
            gr.Markdown("Finding the closest images in the dataset")
            with gr.Row():
                gallery = gr.Gallery(label="Closest Images", show_label=False, elem_id="gallery",
                                     columns=[5], rows=[1], height='auto',
                                     allow_preview=True, preview=None)
            find_closest_btn = gr.Button("Find Closest Images")
        # segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
        classify_image_button.click(classify_image, inputs=[input_image, model_name],
                                    outputs=class_predicted)
        generate_explanations.click(explain_image, inputs=[input_image, model_name],
                                    outputs=[sobol1, sobol2, sobol3, sobol4, sobol5,
                                             rise1, rise2, rise3, rise4, rise5,
                                             hsic1, hsic2, hsic3, hsic4, hsic5,
                                             saliency1, saliency2, saliency3, saliency4, saliency5])
        def update_outputs(input_image, model_name):
            labels, images = find_closest(input_image, model_name)
            # Pair each retrieved image with its label so the gallery shows captions.
            return [(images[i], labels[i]) for i in range(5)]

        find_closest_btn.click(fn=update_outputs, inputs=[input_image, model_name], outputs=[gallery])

        # classify_segmented_button.click(classify_image, inputs=[segmented_image, model_name], outputs=class_predicted)
demo.queue()  # manage multiple incoming requests

if os.getenv('SYSTEM') == 'spaces':
    demo.launch(width='40%', auth=(os.environ.get('USERNAME'), os.environ.get('PASSWORD')))
else:
    demo.launch()