# fossil_app / inference_beit.py
# Yuxiang Wang
# explanations,closest sample
# c5343e6
# raw
# history blame
# 6.46 kB
import tensorflow as tf
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all
# up front.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    # TensorFlow requires memory growth to be configured the same way on
    # every visible GPU, so apply it to all of them rather than only the
    # first device.
    for gpu_device in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu_device, True)
else:
    # No GPU is visible: report the (empty) device list.
    print(f"TensorFlow device: {gpu_devices}")
import os
import numpy as np
import keras
from PIL import Image
import keras_cv
from keras_cv_attention_models import beit
import matplotlib.pyplot as plt
# Preprocessing configuration.
# TODO: `class_names` is not defined in the visible part of this file; it has
# to be provided before the next line can run.
num_classes = len(class_names)
# Shorthand for tf.data.AUTOTUNE (not referenced in the visible part of the file).
AUTO = tf.data.AUTOTUNE
# RandAugment layer configured for pixel values in [-1, 1], the range that
# normalize() produces; three augmentations per image at magnitude 0.5.
rand_augment = keras_cv.layers.RandAugment(value_range = (-1, 1), augmentations_per_image = 3, magnitude=0.5)
# Side length, in pixels, of the square images produced by resize(), pi()
# and parse_img().
SIZE = 384
# Debug placeholder; only mentioned in a commented-out `global` statement in
# parse_img().
debug = None
def augmentations(x, crop_size=22, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
    """Randomly crop and colour-jitter an image batch, then apply ResNetV2 preprocessing.

    The batch is cast to float32, cropped to a random 100x100 window,
    jittered in brightness, contrast, saturation and hue, resized to
    128x128, clipped to [0, 255] and finally passed through
    ``tf.keras.applications.resnet_v2.preprocess_input``.

    Args:
        x: Image batch. The crop shape ``(batch, 100, 100, 3)`` implies a
            4-D tensor with three channels.
        crop_size: Accepted but not used; the crop window is fixed at
            100x100.
        brightness: Maximum brightness delta.
        contrast: Contrast factor is drawn from ``[1 - contrast, 1 + contrast]``.
        saturation: Saturation factor is drawn from
            ``[1 - saturation, 1 + saturation]``.
        hue: Maximum hue delta.

    Returns:
        The augmented, preprocessed batch.
    """
    batch = tf.cast(x, tf.float32)

    # Keep the batch dimension and crop only spatially.
    crop_shape = (tf.shape(batch)[0], 100, 100, 3)
    batch = tf.image.random_crop(batch, crop_shape)

    # Colour jitter; the order of the random ops is kept as-is.
    batch = tf.image.random_brightness(batch, max_delta=brightness)
    contrast_lower, contrast_upper = 1.0 - contrast, 1 + contrast
    batch = tf.image.random_contrast(batch, lower=contrast_lower, upper=contrast_upper)
    saturation_lower, saturation_upper = 1.0 - saturation, 1.0 + saturation
    batch = tf.image.random_saturation(batch, lower=saturation_lower, upper=saturation_upper)
    batch = tf.image.random_hue(batch, max_delta=hue)

    batch = tf.image.resize(batch, (128, 128))
    # Bring the jittered values back into the 8-bit range before preprocessing.
    batch = tf.clip_by_value(batch, 0.0, 255.0)
    return tf.keras.applications.resnet_v2.preprocess_input(batch)
def pad_gt(x):
    """Pad the last two dimensions of ``x`` up to ``sam.image_encoder.img_size``.

    Padding is added only after the existing data (the leading pad amounts
    are 0), so the original content keeps its position.

    NOTE(review): ``sam`` and ``F`` are not defined in the visible part of
    this file; ``F`` is presumably ``torch.nn.functional`` -- confirm.

    Args:
        x: Array/tensor whose ``shape[-2:]`` is ``(height, width)``.

    Returns:
        The padded result of ``F.pad``.
    """
    height, width = x.shape[-2:]
    pad_bottom = sam.image_encoder.img_size - height
    pad_right = sam.image_encoder.img_size - width
    return F.pad(x, (0, pad_right, 0, pad_bottom))
def preprocess(img):
    """Convert an image into the tensor expected by ``sam`` and report its transformed size.

    NOTE(review): ``predictor``, ``sam`` and ``torch`` are not defined or
    imported in the visible part of this file -- confirm where they come
    from.

    Args:
        img: Anything ``np.array`` accepts; it is converted to ``uint8``.

    Returns:
        A ``(tensor, shape)`` pair: the output of ``sam.preprocess`` and the
        spatial shape of the image after ``predictor.transform.apply_image``.
    """
    pixels = np.array(img).astype(np.uint8)
    transformed = predictor.transform.apply_image(pixels)
    intermediate_shape = transformed.shape

    # Move to the GPU, reorder the axes with permute(2, 0, 1) and add a
    # leading batch axis.
    tensor = torch.as_tensor(transformed).cuda()
    tensor = tensor.permute(2, 0, 1).contiguous()[None]
    tensor = sam.preprocess(tensor)

    # Keep only the two spatial entries of the transformed shape; any other
    # rank is returned untouched.
    spatial_slices = {3: slice(0, 2), 4: slice(1, 3)}
    spatial = spatial_slices.get(len(intermediate_shape))
    if spatial is not None:
        intermediate_shape = intermediate_shape[spatial]
    return tensor, intermediate_shape
def normalize(img):
    """Rescale ``img`` linearly so that its minimum maps to -1 and its maximum to 1.

    A constant image (every value equal) has no range to stretch; it is
    mapped to -1 everywhere instead of producing NaNs from a 0/0 division.

    Args:
        img: Image tensor.

    Returns:
        The rescaled tensor with values in [-1, 1].
    """
    shifted = img - tf.math.reduce_min(img)
    peak = tf.math.reduce_max(shifted)
    scaled = shifted / peak
    # When peak == 0 the division above is 0/0; replace the result with 0 so
    # the output stays finite.
    scaled = tf.where(peak > 0, scaled, tf.zeros_like(scaled))
    return scaled * 2.0 - 1.0
def smooth_mask(mask, ds=20):
    """Smooth ``mask`` by bicubic down-sampling to ``ds`` x ``ds`` and up-sampling back.

    Args:
        mask: Mask tensor; its first two dimensions are used as the size to
            restore (assumes an unbatched ``[rows, cols, channels]`` layout
            -- TODO confirm against callers).
        ds: Side length of the intermediate low-resolution mask.

    Returns:
        The smoothed mask, resized back to the first two dimensions of the
        input.
    """
    dims = tf.shape(mask)
    original_size = (dims[0], dims[1])
    coarse = tf.image.resize(mask, (ds, ds), method="bicubic")
    return tf.image.resize(coarse, original_size, method="bicubic")
def resize(img):
    """Resize ``img`` to ``SIZE`` x ``SIZE`` with bicubic interpolation.

    This is the default resize used for the crops built in ``pi``.
    """
    target_size = (SIZE, SIZE)
    return tf.image.resize(img, target_size, method="bicubic")
def pi(img, mask):
    """Mask an image with its smoothed mask and resize it to ``SIZE`` x ``SIZE``.

    The mask is smoothed with ``smooth_mask``, averaged over its last axis
    and thresholded at 0.1; pixels below the threshold are zeroed. Several
    additional crops (bounding-box "anchors" and zooms around the mask
    maximum) are computed as well, but only the directly resized masked
    image is returned -- the other crops are currently unused.

    Args:
        img: Image tensor indexed as ``[axis0, axis1, channel]``.
        mask: Mask tensor for the same image; its last axis is averaged away.

    Returns:
        The masked image resized to ``(SIZE, SIZE)`` as float32.
    """
    img = tf.cast(img, tf.float32)
    shape = tf.shape(img)
    # tf.where yields int64 indices, so the image extents are cast to int64
    # to be usable together with the anchor coordinates below.
    extent_0, extent_1 = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)

    mask = smooth_mask(mask)
    mask = tf.reduce_mean(mask, -1)
    # Zero out every pixel whose smoothed mask value is not above 0.1.
    img = img * tf.cast(mask > 0.1, tf.float32)[:, :, None]

    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # Building 2 anchors: the bounding box of the mask region above 0.15 and
    # a tighter box shrunk by a quarter of the extent on each side.
    anchors = tf.where(mask > 0.15)
    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
    anchor_ymax = tf.math.reduce_max(anchors[:, 1])

    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])

        delta_x = (anchor_xmax - anchor_xmin) // 4
        delta_y = (anchor_ymax - anchor_ymin) // 4
        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
                           anchor_ymin+delta_y:anchor_ymax-delta_y]
        img_anchor_2 = resize(img_anchor_2)
    else:
        # Bounding box too small to crop: fall back to the whole image.
        img_anchor_1 = img_resize
        img_anchor_2 = img_pad

    # Building the anchors max: zooms centred on the first location where the
    # mask reaches its maximum, clamped to the image bounds.
    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]

    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, extent_0),
                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, extent_1)]
    img_max_zoom1 = resize(img_max_zoom1)

    # Only the clamped crop is built here: an unclamped slice would wrap
    # around whenever the maximum lies within SIZE // 2 of the border.
    # NOTE: img_max_zoom2 is left at its cropped size; it is not passed
    # through resize().
    img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, extent_0),
                        tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, extent_1)]

    return tf.cast(img_resize, tf.float32)
def parse_img(element, split, randaugment, maskaugment=True):
    """Load one JPEG and its label as a ``(image, one-hot label)`` pair.

    The image is decoded, rescaled to [-1, 1] with ``normalize`` and
    resized with padding to ``SIZE`` x ``SIZE``. No mask is loaded and
    ``pi`` is not applied.

    Args:
        element: Indexable pair; ``element[0]`` is the image path and
            ``element[1]`` is the class id as a numeric string.
        split: Unused.
        randaugment: Unused.
        maskaugment: Unused.

    Returns:
        A tuple ``(image, label)``: the float32 image and the int32 one-hot
        label of length ``num_classes``.
    """
    path, class_id = element[0], element[1]

    data = tf.io.read_file(path)
    # Force three channels so grayscale JPEGs come out with the same shape
    # as RGB ones (load_img decodes with channels=3 as well).
    img = tf.io.decode_jpeg(data, channels=3)
    img = normalize(img)

    # The class id arrives as a string; parse it before one-hot encoding.
    class_id = tf.cast(tf.strings.to_number(class_id), tf.int32)
    label = tf.one_hot(class_id, num_classes)

    img = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
    return tf.cast(img, tf.float32), tf.cast(label, tf.int32)
# Target image side length in pixels (same value as assigned near the top of
# the file).
SIZE = 384
# Width and height used by load_img(); both are equal to SIZE.
wsize=hsize=SIZE
def resize_images(batch_x, width=224, height=224):
    """Resize an image or image batch to ``height`` x ``width``.

    Args:
        batch_x: Image tensor accepted by ``tf.image.resize``.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The resized tensor.
    """
    # tf.image.resize takes its target size as (new_height, new_width).
    return tf.image.resize(batch_x, (height, width))
def load_img(image_path, gray=False):
    """Read a JPEG file into a float32 RGB tensor of size ``hsize`` x ``wsize``.

    Args:
        image_path: Path of the JPEG file to read.
        gray: When true, the image is converted to grayscale and back to
            three channels, so all channels carry the same values.

    Returns:
        The decoded, resized image as float32 (``convert_image_dtype``
        scales integer pixel values to [0, 1]).
    """
    raw = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(raw, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    if gray:
        img = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(img))
    # tf.image.resize takes its target size as (new_height, new_width).
    return tf.image.resize(img, (hsize, wsize))
# Learning rate for the Adam optimizer created below.
LR = 1e-3
# NOTE(review): `optimizer` is not referenced again in the visible part of
# this file.
optimizer = tf.keras.optimizers.Adam(LR)
# Loss function registered under the name 'cce' when the model is loaded.
cce = tf.keras.losses.categorical_crossentropy
# Hard-coded location of the saved model.
# NOTE(review): the /content/drive prefix looks like a mounted Google Drive
# in Colab -- confirm before running elsewhere.
model_path = '/content/drive/MyDrive/Gg_Fossils_data_shared_copy/Fossils/models/model-13.h5'
# Loading happens at import time; 'cce' is supplied as a custom object.
model = keras.models.load_model(model_path, custom_objects = {'cce': cce})
# Run the model and keep the samples whose top-5 classes include `cid`.
# NOTE(review): `images` and `dataset` are not defined in the visible part of
# this file.
outputs = model.predict(images)
# The top-5 class indices are taken from the second model output.
predictions = tf.math.top_k(outputs[1], k = 5)
cid = 1
dataset = np.array(dataset)
# One entry per sample: `cid` when it is among that sample's top-5 indices,
# otherwise `cid + 10` as a "not matched" marker.
final_predictions = np.array(
    [cid if cid in top_indices else cid + 10 for top_indices in predictions[1]]
)
_matches_cid = final_predictions == cid
images2 = images[_matches_cid]
# The first column of `dataset` is taken as the image path.
image2_paths = dataset[_matches_cid][:,0]
print(images2.shape)
def get_beit_model(input_shape, num_labels, load_weights=False, **kwargs):
    """Placeholder for the BEiT model builder; currently does nothing.

    The remaining parameters of this function are not visible here (the
    signature was elided with ``...``, which is not valid Python), so any
    further keyword arguments are accepted and ignored.

    Args:
        input_shape: Unused by the placeholder body.
        num_labels: Unused by the placeholder body.
        load_weights: Unused by the placeholder body.
        **kwargs: Additional keyword arguments; ignored.

    Returns:
        None.
    """
    return None
def inference_dino(input_image, model_name):
    """Placeholder for DINO inference; the body is empty.

    Args:
        input_image: Unused by the placeholder body.
        model_name: Unused by the placeholder body.

    Returns:
        None.
    """
    return None
def inference_beit_embedding(input_image, model, size=600):
    """Placeholder for BEiT embedding inference; the body is empty.

    Args:
        input_image: Unused by the placeholder body.
        model: Unused by the placeholder body.
        size: Unused by the placeholder body; defaults to 600.

    Returns:
        None.
    """
    return None