import logging
import os
import sys

import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")

from utils.utils import print_trainable_params_percentage, setup_for_distributed

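# A minimal usage sketch (an assumption, not a command taken from this repo): the script
# is driven by fairseq's generation parser plus the extra flags added in cli_main() below.
# The data path, checkpoint path, user dir, and task name are placeholders; substitute
# the ones used by your setup.
#
#   python evaluate.py /path/to/test.tsv \
#       --path=/path/to/checkpoint_best.pt \
#       --user-dir=ofa_module \
#       --task=caption \
#       --gen-subset=test \
#       --batch-size=16 \
#       --beam=5 \
#       --fp16 \
#       --ema-eval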


def apply_half(t):
    # Cast fp32 tensors to fp16; leave other dtypes (e.g. integer token ids) untouched.
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig, **kwargs):
    utils.import_user_module(cfg.common)

    # Adjust console printing depending on whether this is the master rank.
    setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))

    reset_logging()

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix the seed for reproducible stochastic decoding.
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # --model-overrides is passed as a Python-dict string and applied on top of the
    # arguments stored in the checkpoint.
    overrides = eval(cfg.common_eval.model_overrides)
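    # For example (an illustration, not taken from this repo's scripts), one might pass
    #   --model-overrides "{'data': '../../dataset/caption_data', 'bpe_dir': '../../utils/BPE'}"
    # so that dataset paths baked into the checkpoint are redirected at evaluation time.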

    if cfg.task._name == "vqa_gen":
        # For VQA, choose between beam-search decoding and scoring every candidate answer.
        overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"

    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    strict = kwargs['strict']
    logger.info('loading checkpoint with strict={}'.format(strict))

if kwargs["zero_shot"]: |
|
for arg_name, arg_val in overrides.items(): |
|
cfg.task[arg_name] = arg_val |
|
|
|
|
|
if hasattr(cfg.task, "add_caption"): |
|
cfg.task.add_caption = False |
|
print("cfg.task", cfg.task) |
|
task = tasks.setup_task(cfg.task) |
|
|
|
|
|
|
|
models, saved_cfg = checkpoint_utils.load_model_ensemble( |
|
utils.split_paths(cfg.common_eval.path), |
|
arg_overrides=overrides, |
|
task=task, |
|
suffix=cfg.checkpoint.checkpoint_suffix, |
|
strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict), |
|
num_shards=cfg.checkpoint.checkpoint_shard_count, |
|
) |
|
for m in models: |
|
m.encoder.sample_patch_num = 776 |
|
saved_cfg.task = cfg.task |
|
|
|
else: |
|
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( |
|
utils.split_paths(cfg.common_eval.path), |
|
arg_overrides=overrides, |
|
suffix=cfg.checkpoint.checkpoint_suffix, |
|
strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict), |
|
num_shards=cfg.checkpoint.checkpoint_shard_count, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
    kwargs['evaluate_cfg'] = cfg.task

    # Load the evaluation split.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Optionally swap in EMA weights, then prepare each model for inference (fp16/GPU).
    for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
        if kwargs['ema_eval']:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
        model.eval()
        print("use fp16", use_fp16)
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    generator = task.build_generator(models, cfg.generation)

    results = []
    # Running accumulators for a single scalar metric (allocated on GPU).
    score_sum = torch.FloatTensor([0]).cuda()
    score_cnt = torch.FloatTensor([0]).cuda()

    # Per-metric accumulators, used when the eval step returns several score lists.
    score_sum_list = []
    score_cnt_list = []
    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
        with torch.no_grad():
            if kwargs["zero_shot"] and kwargs['noconstraints']:
                result, scores = zero_shot_step(task, generator, models, sample)
            else:
                result, scores = eval_step(task, generator, models, sample, **kwargs)

        # Determine whether `scores` is a flat list of per-sample scalars or a list of
        # per-metric score lists.
        scalar = False
        if isinstance(scores, list) and not isinstance(scores[0], list):
            try:
                sum(scores[0])
            except TypeError:
                scalar = True

        if isinstance(scores, list) and not scalar:
            # Multi-metric case: `result` is (metric_names, per-sample results) and
            # `scores` holds one score list per metric.
            names = result[0]
            result = result[1]
            if len(score_sum_list) == 0:
                score_sum_list = [torch.FloatTensor([0]).cuda() for _ in range(len(scores))]
                score_cnt_list = [torch.FloatTensor([0]).cuda() for _ in range(len(scores))]
            for i in range(len(scores)):
                score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
                score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
        else:
            score_sum += sum(scores) if scores is not None else 0
            score_cnt += len(scores) if scores is not None else 0
        results += result
        progress.log({"sentences": sample["nsentences"]})

    # Merge the per-batch results and report the final metrics (one call per metric
    # in the multi-metric case).
    if len(score_sum_list) > 0:
        print(names, len(score_sum_list))
        for i in range(len(score_sum_list)):
            print(names[i])
            merge_results(task, cfg, logger, score_cnt_list[i], score_sum_list[i], results)
    else:
        merge_results(task, cfg, logger, score_cnt, score_sum, results)


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument("--ema-eval", action='store_true',
                        help="Use EMA weights for evaluation.")
    parser.add_argument("--beam-search-vqa-eval", action='store_true',
                        help="Use beam search for VQA evaluation (faster inference but sub-optimal results); "
                             "otherwise score every answer in the candidate set, which is slower but gives the best results.")
    parser.add_argument("--zero-shot", action='store_true',
                        help="Zero-shot evaluation: build the task from the evaluation config and --model-overrides "
                             "instead of the task config stored in the checkpoint.")
    parser.add_argument("--strict", action='store_false',
                        help="Disable strict state-dict matching when loading checkpoints "
                             "(store_false: passing this flag sets strict=False).")
    parser.add_argument("--noconstraints", action='store_true',
                        help="With --zero-shot, decode with zero_shot_step instead of eval_step.")
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(
        cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval,
        zero_shot=args.zero_shot, strict=args.strict, noconstraints=args.noconstraints
    )


if __name__ == "__main__":
    cli_main()