# LSM/submodules/lang_seg/modules/lseg_module_zs.py
import re
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from argparse import ArgumentParser
import pytorch_lightning as pl
from .lsegmentation_module_zs import LSegmentationModuleZS
from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS
from encoding.models.sseg.base import up_kwargs
import os
import clip
import numpy as np
from scipy import signal
import glob
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
class LSegModuleZS(LSegmentationModuleZS):
    """Zero-shot LSeg Lightning module.

    Builds an ``LSegRNNetZS`` (for the CLIP-ResNet101 backbone) or an
    ``LSegNetZS`` (any other backbone) over the few-shot label list read
    for the chosen dataset, delegating training logic to
    ``LSegmentationModuleZS``.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        """Initialize the base segmentation module and construct the net.

        Args:
            data_path: root directory of the dataset (forwarded to the base class).
            dataset: dataset name; also selects ``label_files/fewshot_<dataset>.txt``.
            batch_size: training batch size (forwarded to the base class).
            base_lr: base learning rate (forwarded to the base class).
            max_epochs: number of training epochs (forwarded to the base class).
            **kwargs: CLI options; must include ``use_pretrained``, ``backbone``,
                ``num_features``, ``aux``, ``arch_option``, ``block_depth``
                and ``activation``.

        Raises:
            ValueError: if ``kwargs["use_pretrained"]`` is not a recognized
                boolean value ('True'/'False' string or bool).
        """
        super(LSegModuleZS, self).__init__(
            data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
        )
        label_list = self.get_labels(dataset)
        self.len_dataloader = len(label_list)

        # "use_pretrained" may arrive as a CLI string or a real bool.
        # Fail loudly on anything else — the original code left the local
        # variable unassigned, producing a confusing NameError below.
        if kwargs["use_pretrained"] in ("False", False):
            use_pretrained = False
        elif kwargs["use_pretrained"] in ("True", True):
            use_pretrained = True
        else:
            raise ValueError(
                "use_pretrained must be 'True'/'False' (or bool), got {!r}".format(
                    kwargs["use_pretrained"]
                )
            )

        # The CLIP-ResNet101 backbone needs the ResNet flavor of the net;
        # every other backbone (e.g. ViT) uses the generic LSegNetZS.
        # Both constructors share the same keyword interface.
        net_cls = LSegRNNetZS if kwargs["backbone"] in ["clip_resnet101"] else LSegNetZS
        self.net = net_cls(
            label_list=label_list,
            backbone=kwargs["backbone"],
            features=kwargs["num_features"],
            aux=kwargs["aux"],
            use_pretrained=use_pretrained,
            arch_option=kwargs["arch_option"],
            block_depth=kwargs["block_depth"],
            activation=kwargs["activation"],
        )

    def get_labels(self, dataset):
        """Read the few-shot class names for *dataset*.

        Expects ``label_files/fewshot_<dataset>.txt`` (relative to the
        working directory) with one label per line.

        Args:
            dataset: dataset name used to build the label-file path.

        Returns:
            list[str]: the stripped label names, in file order.

        Raises:
            AssertionError: if the label file does not exist.
        """
        path = 'label_files/fewshot_{}.txt'.format(dataset)
        assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path)
        # Context manager guarantees the handle is closed even on error;
        # the original open()/close() pair leaked it on exceptions.
        with open(path, 'r') as f:
            labels = [line.strip() for line in f.readlines()]
        print(labels)
        return labels

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend the base-module CLI parser with LSeg-specific options.

        Args:
            parent_parser: the parser built by the caller (e.g. Lightning).

        Returns:
            ArgumentParser: a new parser chaining the base-module options.
        """
        parser = LSegmentationModuleZS.add_model_specific_args(parent_parser)
        parser = ArgumentParser(parents=[parser])
        parser.add_argument(
            "--backbone",
            type=str,
            default="vitb16_384",
            help="backbone network",
        )
        parser.add_argument(
            "--num_features",
            type=int,
            default=256,
            help="number of features that go from encoder to decoder",
        )
        parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
        parser.add_argument(
            "--finetune_weights", type=str, help="load weights to finetune from"
        )
        # NOTE: store_false flags — passing --no-scaleinv sets the dest to False.
        parser.add_argument(
            "--no-scaleinv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )
        parser.add_argument(
            "--no-batchnorm",
            default=False,
            action="store_true",
            help="turn off batchnorm",
        )
        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )
        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )
        # Kept as a string (not store_true) so existing callers that pass
        # "True"/"False" strings keep working; __init__ parses it.
        parser.add_argument(
            "--use_pretrained",
            type=str,
            default="True",
            help="whether use the default model to initialize the model",
        )
        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )
        parser.add_argument(
            "--block_depth",
            type=int,
            default=0,
            help="how many blocks should be used",
        )
        parser.add_argument(
            "--activation",
            choices=['relu', 'lrelu', 'tanh'],
            default="relu",
            help="use which activation to activate the block",
        )
        return parser