# Source: uploaded via huggingface_hub by yumikimi381 (commit daf0288).
from typing import List, Tuple, Dict
import torch
from torch import Tensor, nn
from torchtext.vocab import Vocab
import tokenizers as tk
from ..utils import pred_token_within_range, subsequent_mask
from ..vocab import (
HTML_TOKENS,
TASK_TOKENS,
RESERVED_TOKENS,
BBOX_TOKENS,
)
# Decoding constraints: tokens each task head is allowed / forbidden to emit.
VALID_HTML_TOKEN = ["<eos>", *HTML_TOKENS]
INVALID_CELL_TOKEN = ["<sos>", "<pad>", "<empty>", "<sep>", *TASK_TOKENS, *RESERVED_TOKENS]
VALID_BBOX_TOKEN = [
    "<eos>",
    *BBOX_TOKENS,
]  # image size will be addressed after instantiation
class Batch:
    """Wrap up a batch of training samples with different training targets.

    The input is not a torch tensor.
    Shape of the image (src): B, S, E
    Shape of the text (tgt): B, N, S, E (N includes 1 table detection,
    1 structure, 1 cell, and multiple bbox)
    Reshape text to (B * N, S, E) and inflate the image to match the shape
    of the text.

    Args:
    ----
    device: gpu id
    target: any combination of "table", "html", "cell", "bbox"
    vocab: vocabulary mapping tokens to ids, shared by all text targets
    obj: raw sample - obj[0] is the image tensor; obj[1] is a dict with
        "filename" plus one list of tokenized sequences per requested target
    """

    def __init__(
        self,
        device: torch.device,
        target: str,
        vocab: Vocab,
        obj: List,
    ) -> None:
        self.device = device
        self.image = obj[0].to(device)
        self.name = obj[1]["filename"]
        self.target = target
        self.vocab = vocab
        self.image_size = self.image.shape[-1]
        if "table" in target:
            raise NotImplementedError
        if "html" in target:
            # ids the html head is allowed to emit; everything else is
            # suppressed in inference()
            self.valid_html_token = [vocab.token_to_id(i) for i in VALID_HTML_TOKEN]
            (
                self.html_src,
                self.html_tgt,
                # NOTE: "casual" misspelling of "causal" kept for caller
                # compatibility.
                self.html_casual_mask,
                self.html_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["html"])
        if "cell" in target:
            # ids the cell head must never emit
            self.invalid_cell_token = [vocab.token_to_id(i) for i in INVALID_CELL_TOKEN]
            (
                self.cell_src,
                self.cell_tgt,
                self.cell_casual_mask,
                self.cell_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["cell"])
        if "bbox" in target:
            (
                self.bbox_src,
                self.bbox_tgt,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["bbox"])

    def _prepare_transformer_input(
        self, seq: List[tk.Encoding]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Turn tokenized sequences into (src, tgt, causal mask, padding mask).

        src is the sequence without its last token (decoder input) and tgt
        the sequence without its first token (prediction target).
        """
        tmp = [i.ids for i in seq]
        tmp = torch.tensor(tmp, dtype=torch.int32)
        src = tmp[:, :-1].to(self.device)
        # Fix: cast and move in one step; .type(torch.LongTensor) forced an
        # intermediate CPU tensor before the device transfer.
        tgt = tmp[:, 1:].to(self.device, dtype=torch.long)
        casual_mask = subsequent_mask(src.shape[-1]).to(self.device)
        tmp = [i.attention_mask[:-1] for i in seq]  # padding mask
        tmp = torch.tensor(tmp, dtype=torch.bool)
        padding_mask = (~tmp).to(self.device)  # True marks padding positions
        return src, tgt, casual_mask, padding_mask

    def _inference_one_task(
        self, model, memory, src, casual_mask, padding_mask, use_ddp
    ):
        """Decode one task head, unwrapping .module when DDP-wrapped."""
        net = model.module if use_ddp else model
        out = net.decode(memory, src, casual_mask, padding_mask)
        return net.generator(out)

    def inference(
        self,
        model: nn.Module,
        criterion: nn.Module,
        criterion_bbox: nn.Module = None,
        loss_weights: dict = None,
        use_ddp: bool = True,
    ) -> Tuple[Dict, Dict]:
        """Run the decoder for every requested target and compute losses.

        Args:
        ----
        model: encoder-decoder model (optionally DDP-wrapped)
        criterion: token-level loss used for the html and cell heads
        criterion_bbox: loss for the bbox head; required when "bbox" in target
        loss_weights: per-task weights for the total loss; when omitted,
            each computed task is weighted 1.0
        use_ddp: whether model is wrapped in DistributedDataParallel

        Returns:
        -------
        (loss dict with per-task entries plus "total", prediction dict)
        """
        pred = dict()
        loss = dict(table=0, html=0, cell=0, bbox=0)
        if use_ddp:
            memory = model.module.encode(self.image)
        else:
            memory = model.encode(self.image)
        # inference + suppress invalid logits + compute loss
        if "html" in self.target:
            out_html = self._inference_one_task(
                model,
                memory,
                self.html_src,
                self.html_casual_mask,
                self.html_padding_mask,
                use_ddp,
            )
            pred["html"] = pred_token_within_range(
                out_html, white_list=self.valid_html_token
            ).permute(0, 2, 1)
            loss["html"] = criterion(pred["html"], self.html_tgt)
        if "cell" in self.target:
            out_cell = self._inference_one_task(
                model,
                memory,
                self.cell_src,
                self.cell_casual_mask,
                self.cell_padding_mask,
                use_ddp,
            )
            pred["cell"] = pred_token_within_range(
                out_cell, black_list=self.invalid_cell_token
            ).permute(0, 2, 1)
            loss["cell"] = criterion(pred["cell"], self.cell_tgt)
        if "bbox" in self.target:
            assert criterion_bbox is not None
            out_bbox = self._inference_one_task(
                model,
                memory,
                self.bbox_src,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
                use_ddp,
            )
            pred["bbox"] = out_bbox.permute(0, 2, 1)
            loss["bbox"] = criterion_bbox(pred["bbox"], self.bbox_tgt)
        # Fix: loss_weights defaults to None, which previously crashed on
        # .items(); fall back to weighting each computed task equally.
        if loss_weights is None:
            loss_weights = {k: 1.0 for k in loss if k in self.target}
        total = 0.0
        for k, v in loss_weights.items():
            total += loss[k] * v
        loss["total"] = total
        return loss, pred
def configure_optimizer_weight_decay(
    model: nn.Module, weight_decay: float
) -> List[Dict]:
    """Split parameters into two optimizer groups: with and without weight decay.

    Biases, normalization/embedding weights, and any parameter name listed by
    ``model.no_weight_decay()`` (if the model defines it) are excluded from
    weight decay.

    Args:
    ----
    model: model whose parameters are being grouped
    weight_decay: decay coefficient applied to the decaying group

    Returns:
    -------
    Two param-group dicts consumable by a torch optimizer.
    """
    weight_decay_blacklist = (nn.LayerNorm, nn.BatchNorm2d, nn.Embedding)
    # Fix: skip_list must exist even when the model has no no_weight_decay();
    # previously this path raised NameError at `pn in skip_list`.
    skip_list = set()
    if hasattr(model, "no_weight_decay"):
        skip_list = model.no_weight_decay()
    no_decay = set()
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
            if pn.endswith("bias"):
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, weight_decay_blacklist):
                no_decay.add(fpn)
            elif pn in skip_list:
                no_decay.add(fpn)
    param_dict = {pn: p for pn, p in model.named_parameters()}
    # Everything not explicitly excluded decays.
    decay = param_dict.keys() - no_decay
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay,
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0,
        },
    ]
    return optim_groups
def turn_off_beit_grad(model: nn.Module):
    """Freeze BEiT pretrained weights."""
    # Only the visual-encoder pieces are frozen; decoder heads keep training.
    for frozen_part in (model.encoder, model.backbone, model.pos_embed):
        for weight in frozen_part.parameters():
            weight.requires_grad = False
def turn_on_beit_grad(model: nn.Module):
    """Unfreeze every parameter of the model (inverse of turn_off_beit_grad)."""
    model.requires_grad_(True)