Spaces:
Running
Running
File size: 3,665 Bytes
a0dc447 62977e4 a0dc447 f0ab20b a0dc447 f0ab20b a0dc447 b20df59 a0dc447 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
import os
# Device selection: use CUDA with GPU index 0 when it is available,
# otherwise fall back to the CPU with no GPU ids.
use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu_ids = [0]
    device = torch.device('cuda')
else:
    gpu_ids = []
    device = torch.device('cpu')
# Default dataset key. set_dataset() recognises 'DRIVE', 'LES', 'hrf' and
# 'ukbb'; any other value (such as 'all') selects its fallback 'ALL'
# checkpoint and patch settings.
dataset_name = 'all'
dataset = dataset_name  # alias; set_dataset() rebinds both names together
# Training schedule
max_step = 30000  # 30000 for ukbb
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
# Presumably the step from which metrics start being displayed -- confirm
# in the training loop.
first_display_metric_iter = max_step - save_iter  # default: 25000
# Optimizer
lr = 0.0002  # default: 0.0002 (a commented-out variant used 0.00005 for LES)
step_size = 7000  # 7000 for DRIVE; presumably the LR decay period paired with lr_decay_gamma
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False
# ---- Network / discriminator architecture ----
input_nc = 3  # generator input channels (presumably RGB -- confirm)
input_nc_D = input_nc + 3  # discriminator input: input_nc plus 3 extra channels
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
init_type = 'normal'
init_gain = 0.02
use_noise_input_D = False
use_dropout_D = False

# ---- GAN objective ----
use_GAN = True  # default: True
no_lsgan = False
use_sigmoid = no_lsgan  # mirrors no_lsgan
num_classes_D = 1

# ---- Optimizer (Adam) ----
beta1 = 0.5

# ---- Loss weights ----
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5

# ---- Centerness ----
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]
# Pretrained model settings. set_dataset() overwrites the generator
# checkpoint path and step for the selected dataset.
use_pretrained_G = True
use_pretrained_D = False
model_path_pretrained_G = ''
model_step_pretrained_G = 0
# Dataset-dependent placeholders; set_dataset() assigns the real values.
# Every global that set_dataset() rebinds has a default here (patch_size was
# previously missing, so reading it before set_dataset() raised AttributeError).
stride_height = 0
stride_width = 0
patch_size_list = []
patch_size = 0
use_CAM = False
# Input resizing
use_resize = False
resize_w_h = (1920, 512)  # (width, height) judging by the name -- confirm
def set_dataset(name):
    """Select the active dataset and update the dataset-dependent settings.

    Rebinds these module-level globals: ``dataset_name``, ``dataset``,
    ``model_path_pretrained_G``, ``model_step_pretrained_G``,
    ``patch_size_list``, ``patch_size``, ``stride_height``, ``stride_width``,
    ``use_CAM`` and ``use_resize``.

    Every branch assigns ``use_CAM`` and ``use_resize``. Calling this function
    more than once therefore never leaves a previous dataset's flags in
    effect. Previously the DRIVE/LES branch left them untouched, so
    ``set_dataset('all')`` followed by ``set_dataset('DRIVE')`` kept both True.

    Args:
        name: Dataset key. ``'DRIVE'``, ``'LES'``, ``'hrf'`` and ``'ukbb'``
            are recognised (case-sensitive). Any other value (e.g. ``'all'``)
            falls through to the fallback 'ALL' settings.
    """
    # resize_w_h is declared global but is not reassigned here.
    global dataset_name, dataset
    global model_path_pretrained_G, model_step_pretrained_G
    global stride_height, stride_width, patch_size, patch_size_list
    global use_CAM, use_resize, resize_w_h
    dataset_name = name
    dataset = name

    # Pretrained generator checkpoint and patch sizes for each dataset.
    if dataset_name == 'DRIVE':
        model_path_pretrained_G = './AV/log/DRIVE-2023_10_20_08_36_50(6500)'
        model_step_pretrained_G = 6500
        patch_size_list = [64, 128, 256]
    elif dataset_name == 'LES':
        model_path_pretrained_G = './AV/log/LES-2023_09_28_14_04_06(0)'
        model_step_pretrained_G = 0
        patch_size_list = [96, 384, 256]
    elif dataset_name == 'hrf':
        model_path_pretrained_G = './AV/log/HRF-2023_10_19_11_07_31(1500)'
        model_step_pretrained_G = 1500
        patch_size_list = [64, 384, 256]
    elif dataset_name == 'ukbb':
        model_path_pretrained_G = './AV/log/UKBB-2023_11_02_23_22_07(5000)'
        model_step_pretrained_G = 5000
        patch_size_list = [96, 384, 256]
    else:
        model_path_pretrained_G = './AV/log/ALL-2024_09_06_09_17_18(9000)'
        model_step_pretrained_G = 9000
        patch_size_list = [96, 384, 512]
    # patch_size is always the third entry of patch_size_list.
    patch_size = patch_size_list[2]

    # Patch stride and preprocessing flags.
    if dataset_name in ('DRIVE', 'LES'):
        stride_height = 50
        stride_width = 50
        # Reset to the module defaults so a previous call cannot leak its flags.
        use_CAM = False
        use_resize = False
    elif dataset_name in ('ukbb', 'hrf'):
        stride_height = 150
        stride_width = 150
        use_CAM = False
        use_resize = True
    else:
        stride_height = 150
        stride_width = 150
        use_CAM = True
        use_resize = True
# Output / bookkeeping
n_classes = 3  # number of output classes (presumably A, V, vessel -- confirm)
model_step = 0

# Optional auxiliary branches, both disabled here.
use_av_cross = use_high_semantic = False
lambda_high = 1  # presumably the high-semantic loss weight; original note: A,V,Vessel
# Global semantic features in the local branch; keep False for the
# HuggingFace deployment.
use_global_semantic = False
# Warm-up length: 0 steps when starting from a pretrained generator
# (use_pretrained_G), otherwise 5000.
if use_pretrained_G:
    global_warmup_step = 0
else:
    global_warmup_step = 5000
|