"""Gradio demo for retinal artery/vein segmentation with PGNet.

The script downloads a checkpoint from the Hugging Face Hub, runs patch-wise
inference on an uploaded fundus image and displays the colour-coded result
returned by ``AVclassifiation``.
"""
# NOTE: the original import order is retained on purpose; only the two
# standard-library imports at the end are new.
import torch
import gradio as gr
from PIL import Image
import cv2
from AV.models.network import PGNet
from AV.Tools.AVclassifiation import AVclassifiation
from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, \
    recompone_overlap, kill_border
from AV.config import config_test_general as cfg
import torch.autograd as autograd
import numpy as np
import os
from datetime import datetime
from huggingface_hub import hf_hub_download
import uuid
from contextlib import suppress

# Access token for the Hugging Face Hub; None when the variable is not set.
hf_token = os.environ.get("HF_token")


def creatMask(Image, threshold=5):
    """Create the field-of-view mask of an image.

    Pixels whose intensity is ``>= threshold`` are treated as foreground. Only
    the contour with the largest area is kept; it is filled and then smoothed
    with a morphological opening.

    Args:
        Image: image array. A 3-D array is converted to gray with
            ``cv2.COLOR_BGR2GRAY``; a 2-D array is thresholded directly.
            (The parameter name shadows ``PIL.Image`` inside this function; it
            is kept unchanged so existing callers keep working.)
        threshold: minimum intensity of a foreground pixel (default 5).

    Returns:
        tuple: ``(ResultImg, Mask)`` where ``ResultImg`` is a copy of the input
        with every pixel outside the mask set to 255 and ``Mask`` is a
        ``uint8`` array with values 0 / 255. When no contour is found the mask
        is all zeros.
    """
    if len(Image.shape) == 3:
        gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
        foreground = gray >= threshold
    else:
        foreground = Image >= threshold
    foreground = np.uint8(foreground)

    # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x and
    # (image, contours, hierarchy) in 3.x; index -2 is the contour list in all.
    contours = cv2.findContours(foreground, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
    if len(contours) > 0:
        # Keep the largest blob only. np.argmax would raise on an empty list,
        # hence the guard above.
        areas = [cv2.contourArea(contour) for contour in contours]
        cv2.drawContours(Mask, contours, int(np.argmax(areas)), 1, -1)

    ResultImg = Image.copy()
    # A scalar broadcasts over every channel, so one statement serves both
    # single-channel and multi-channel inputs.
    ResultImg[Mask == 0] = 255

    Mask[Mask > 0] = 255
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
    return ResultImg, Mask


def shift_rgb(img, *args):
    """Add a constant offset to each channel of ``img`` through a lookup table.

    Args:
        img: image array; values are mapped through a 256-entry LUT.
        *args: one shift per channel. For a 2-D image every shift is applied
            to the original image in turn, so only the last one is reflected
            in the result (behaviour of the original implementation).

    Returns:
        Array with the same shape and dtype as ``img``. When no shift is given
        an unmodified copy of ``img`` is returned.
    """
    if not args:
        # np.empty_like would otherwise be returned uninitialised.
        return img.copy()

    max_value = 255
    result_img = np.empty_like(img)
    for channel, shift in enumerate(args):
        lut = np.arange(0, max_value + 1).astype("float32")
        lut += shift
        lut = np.clip(lut, 0, max_value).astype(img.dtype)
        if len(img.shape) == 2:
            print('=========grey image=======')
            result_img = cv2.LUT(img, lut)
        else:
            result_img[..., channel] = cv2.LUT(img[..., channel], lut)
    return result_img


def CAM(x, img_path, rate=0.8, ind=0):
    """Shift the channel means of ``x`` towards those of a reference image.

    Args:
        x: image array indexed as ``x[:, :, channel]``; cast to ``uint8``.
        img_path: path of the reference image whose per-channel means are the
            target of the shift.
        rate: fraction of the mean difference that is applied (default 0.8).
        ind: unused; kept for interface compatibility.

    Returns:
        The shifted image, with every pixel outside the field-of-view mask
        (``creatMask`` with threshold 10) set to 0.
    """
    x = np.uint8(x)
    _, fov_mask = creatMask(x, threshold=10)

    # Close the file handle deterministically. convert("RGB") guarantees three
    # channels for grayscale / palette references, which the channel indexing
    # below requires.
    with Image.open(img_path) as reference_file:
        reference = np.array(reference_file.convert("RGB"))

    shifts = [
        int((np.mean(reference[:, :, channel]) - np.mean(x[:, :, channel])) * rate)
        for channel in range(3)
    ]
    y = shift_rgb(x, *shifts)
    y[fov_mask == 0, :] = 0
    return y


def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
                           config=None, output_dir='', evaluate_metrics=False):
    """Run PGNet on a single image and return the colour-coded A/V map.

    Args:
        net: state dict passed to ``PGNet.load_state_dict`` (``strict=False``).
        use_cuda: move the network and its inputs to the GPU.
        dataset: path of the image to segment.
        is_kill_border: forwarded to ``GetResult_out_big``.
        input_ch: number of input channels of the network.
        config: configuration object; ``use_global_semantic``,
            ``use_centerness``, ``centerness_map_size`` and ``use_resize`` are
            read here.
        output_dir: forwarded to ``AVclassifiation``.
        evaluate_metrics: unused; kept for interface compatibility.

    Returns:
        The value returned by ``AVclassifiation``.

    Raises:
        ValueError: if the image cannot be read by ``cv2.imread``.
    """
    n_classes = 3
    Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
                num_classes=n_classes, use_cuda=use_cuda, pretrained=False,
                centerness=config.use_centerness,
                centerness_map_size=config.centerness_map_size)
    Net.load_state_dict(net, strict=False)
    if use_cuda:
        Net.cuda()
    Net.eval()

    image_basename = dataset
    image0 = cv2.imread(image_basename)
    if image0 is None:
        raise ValueError(f"Could not read image: {image_basename!r}")
    test_image_height, test_image_width = image0.shape[0], image0.shape[1]

    if config.use_resize:
        shortest = min(test_image_height, test_image_width)
        longest = max(test_image_height, test_image_width)
        scaling = None
        if shortest <= 256:
            # Small images: upscale so that the shortest side becomes 512.
            scaling = 512 / shortest
        elif longest >= 2048:
            # Large images: rescale so that the longest side becomes 2048.
            scaling = 2048 / longest
        if scaling is not None:
            test_image_width = int(test_image_width * scaling)
            test_image_height = int(test_image_height * scaling)

    prediction_shape = (1, 1, test_image_height, test_image_width)
    ArteryPredAll = np.zeros(prediction_shape, np.float32)
    VeinPredAll = np.zeros(prediction_shape, np.float32)
    VesselPredAll = np.zeros(prediction_shape, np.float32)

    ArteryPred, VeinPred, VesselPred, *_ = GetResult_out_big(
        Net, 0,
        use_cuda=use_cuda,
        dataset=image_basename,
        is_kill_border=is_kill_border,
        config=config,
        resize_w_h=(test_image_width, test_image_height),
    )
    ArteryPredAll[0, :, :, :] = ArteryPred
    VeinPredAll[0, :, :, :] = VeinPred
    VesselPredAll[0, :, :, :] = VesselPred

    image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1,
                                  image_basename)
    return image_color


def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
                      resize_w_h=None):
    """Predict artery / vein / vessel probability maps for one image.

    The image is optionally resized and pre-processed, split into overlapping
    patches, passed through ``Net`` in small batches and recomposed.

    Args:
        Net: network called as ``Net(patches, global_patches)``; the first of
            its two outputs is used.
        k: unused; kept for interface compatibility.
        use_cuda: move the inputs to the GPU.
        dataset: path of the image to segment.
        is_kill_border: apply ``kill_border`` with the image mask.
        config: configuration object; ``use_resize``, ``patch_size``,
            ``stride_height``, ``stride_width`` and ``use_global_semantic``
            are read here.
        resize_w_h: ``(width, height)`` passed to ``cv2.resize`` when
            ``config.use_resize`` is set.

    Returns:
        A 7-tuple ``(ArteryPred, VeinPred, VesselPred, Mask, ArteryPred,
        VeinPred, VesselPred)``; every element has a leading axis of size 1.

    Raises:
        ValueError: if the image cannot be read by ``cv2.imread``.
    """
    Img0 = cv2.imread(dataset)
    if Img0 is None:
        raise ValueError(f"Could not read image: {dataset!r}")

    _, Mask0 = creatMask(Img0, threshold=-1)
    Mask = np.zeros((Img0.shape[0], Img0.shape[1]), np.float32)
    Mask[Mask0 > 0] = 1

    if config.use_resize:
        Img0 = cv2.resize(Img0, resize_w_h)
        Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)

    height, width = Img0.shape[:2]
    n_classes = 3
    patch_size = config.patch_size
    stride_height = config.stride_height
    stride_width = config.stride_width

    Img = cv2.cvtColor(Img0, cv2.COLOR_BGR2RGB)

    # NOTE(review): the two switches below are read from the module-level
    # ``cfg`` rather than from the ``config`` argument; in this script both
    # refer to the same object - confirm before passing a different config.
    if cfg.dataset == 'all':
        # Apply CLAHE to the L channel in LAB space, then convert back to RGB.
        lab = cv2.cvtColor(Img, cv2.COLOR_RGB2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
        lab_clahe = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
        Img = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)

    if cfg.use_CAM:
        Img = CAM(Img, dataset)

    Img = np.float32(Img / 255.)
    Img_enlarged = paint_border_overlap(Img, patch_size, patch_size, stride_height, stride_width)

    batch_size = 2
    patches_imgs, global_images = extract_ordered_overlap_big(
        Img_enlarged, patch_size, patch_size, stride_height, stride_width)
    patches_imgs = Normalize(np.transpose(patches_imgs, (0, 3, 1, 2)))
    global_images = Normalize(np.transpose(global_images, (0, 3, 1, 2)))

    patchNum = patches_imgs.shape[0]
    max_iter = int(np.ceil(patchNum / float(batch_size)))
    pred_patches = np.zeros((patchNum, n_classes, patch_size, patch_size), np.float32)

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for i in range(max_iter):
            begin_index = i * batch_size
            end_index = (i + 1) * batch_size

            patch_batch = torch.FloatTensor(patches_imgs[begin_index:end_index, :, :, :])
            if use_cuda:
                patch_batch = patch_batch.cuda()

            if config.use_global_semantic:
                global_batch = torch.FloatTensor(global_images[begin_index:end_index, :, :, :])
                if use_cuda:
                    global_batch = global_batch.cuda()
            else:
                # Without global semantics the patch batch doubles as the
                # second input; reusing it keeps both on the same device.
                global_batch = patch_batch

            output_temp, _ = Net(patch_batch, global_batch)
            batch_probabilities = sigmoid(np.float32(output_temp.detach().cpu().numpy()))
            pred_patches[begin_index:end_index, :, :, :] = \
                batch_probabilities[:, :, :patch_size, :patch_size]

    new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
    pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width)
    # Crop the padding added by paint_border_overlap.
    pred_img = pred_img[:, 0:height, 0:width]

    if is_kill_border:
        pred_img = kill_border(pred_img, Mask)

    # Channel order of the network output as used here: 0 artery, 1 vessel, 2 vein.
    ArteryPred = np.float32(pred_img[0, :, :])[np.newaxis, :, :]
    VeinPred = np.float32(pred_img[2, :, :])[np.newaxis, :, :]
    VesselPred = np.float32(pred_img[1, :, :])[np.newaxis, :, :]
    Mask = Mask[np.newaxis, :, :]

    return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred


def out_test(cfg, model_path='', output_dir='', evaluate_metrics=False, img_name='out_test'):
    """Load a checkpoint and segment one image.

    Args:
        cfg: configuration object; ``use_cuda`` and ``input_nc`` are read here
            and the object is forwarded as ``config``.
        model_path: path of the checkpoint loaded with ``torch.load``.
        output_dir: forwarded to ``modelEvalution_out_big``.
        evaluate_metrics: forwarded to ``modelEvalution_out_big``.
        img_name: path of the image to segment.

    Returns:
        The value returned by ``modelEvalution_out_big``.
    """
    device = torch.device("cuda" if cfg.use_cuda else "cpu")
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    net = torch.load(model_path, map_location=device)
    image_color = modelEvalution_out_big(net,
                                         use_cuda=cfg.use_cuda,
                                         dataset=img_name,
                                         input_ch=cfg.input_nc,
                                         config=cfg,
                                         output_dir=output_dir,
                                         evaluate_metrics=evaluate_metrics)
    return image_color


def segment_by_out_test(image, model_name):
    """Gradio callback: segment an uploaded image with the selected model.

    Args:
        image: uploaded PIL image, or ``None`` when nothing was uploaded.
        model_name: model key; the checkpoint ``G_{model_name}.pkl`` is
            downloaded and ``cfg.set_dataset(model_name)`` is applied.

    Returns:
        PIL image built from the array returned by ``out_test``.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    # Validate first so that a missing upload costs neither a model download
    # nor a configuration change.
    if image is None:
        raise gr.Error("请上传一张图像(upload a fundus image)。")

    print("✅ 传到后端的模型名:", model_name)
    model_path = hf_hub_download(
        repo_id="weidai00/RIP-AV-sulab",  # model repository
        filename=f"G_{model_name}.pkl",  # checkpoint file
        repo_type="model",  # required for model repositories
        token=hf_token
    )
    # NOTE(review): this mutates the shared module-level config; concurrent
    # requests selecting different models may interfere - confirm.
    cfg.set_dataset(model_name)

    os.makedirs("./examples", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The random suffix prevents two uploads within the same second from
    # overwriting each other.
    temp_path = f"./examples/tmp_upload_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    image.save(temp_path)
    try:
        image_color = out_test(cfg, model_path=model_path, output_dir='',
                               evaluate_metrics=False, img_name=temp_path)
    finally:
        # Best-effort cleanup so uploads do not accumulate on disk.
        with suppress(OSError):
            os.remove(temp_path)

    return Image.fromarray(image_color)


def gradio_interface():
    """Build the Gradio Blocks UI and launch it (blocks until the server stops)."""
    model_info_md = "\n".join([
        "",
        "### 📘 模型说明",
        "",
        "| 模型(model name) | 数据集(dataset) | patch size |running time |",
        "|------|--------|------------|--------|",
        "| DRIVE | 小分辨率血管图像 | 256 |30s以内|",
        "| HRF | 高分辨率图像(健康、青光眼等)| 256 | 2min以内|",
        "| LES | 视盘中心图像适配 | 256 |2min以内|",
        "| UKBB | UKBB图像 | 256 |2min以内 |",
        "| 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|",
        "",
    ])

    # (label shown in the UI, value passed to segment_by_out_test)
    model_choices = [
        ("1: DRIVE专用模型", "DRIVE"),
        ("2: HRF专用模型", "hrf"),
        ("3: LES专用模型", "LES"),
        ("4: UKBB专用模型", "ukbb"),
        ("5: 通用模型(general)", "all"),
    ]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 👁️ 眼底图像动静脉血管分割(Retinal image artery and vein segmentation)")
        gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。(Upload the retinal image, select a model to start processing, and the results will be generated automatically.)")

        with gr.Row():
            image_input = gr.Image(type="pil", label="📤 上传图像(upload)", height=300)

        with gr.Row():
            with gr.Column():
                model_select = gr.Radio(
                    choices=model_choices,
                    label="🎯 选择模型",
                    value="DRIVE",
                    interactive=True
                )
                submit_btn = gr.Button("🚀 开始分割(RUN)")
            with gr.Column():
                output_image = gr.Image(label="🖼️ 分割结果(Result)")

        gr.Markdown("### 📁 示例图像examples(点击自动加载)")
        gr.Examples(
            examples=[
                ["examples/DRIVE.tif", "DRIVE"],
                ["examples/LES.png", "LES"],
                ["examples/hrf.png", "hrf"],
                ["examples/ukbb.png", "ukbb"],
                ["examples/all.jpg", "all"]
            ],
            inputs=[image_input, model_select],
            label="示例图像",
            examples_per_page=5
        )

        with gr.Accordion("📖 模型说明-Description(点击展开)", open=False):
            gr.Markdown(model_info_md)

        # Wire the button to the segmentation callback.
        submit_btn.click(
            fn=segment_by_out_test,
            inputs=[image_input, model_select],
            outputs=[output_image]
        )

        gr.Markdown("📚 **专用模型引用cite**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation")
        gr.Markdown("📚 **通用模型引用cite**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.")

    demo.queue()
    demo.launch()


if __name__ == '__main__':
    # For a one-off run without the UI, call cfg.set_dataset(<model name>) and
    # out_test(cfg, model_path=..., img_name=<image path>) directly.
    gradio_interface()