import os
import torch
import collections
import logging
from tqdm import tqdm, trange
import json
import bs4
from os import path as osp
from bs4 import BeautifulSoup as bs
# from transformers.models.bert.tokenization_bert import BasicTokenizer, whitespace_tokenize
from torch.utils.data import Dataset
import networkx as nx
from lxml import etree
import pickle
# from transformers.tokenization_bert import BertTokenizer
from transformers import BertTokenizer
import argparse

tags_dict = {
    'a': 0, 'abbr': 1, 'acronym': 2, 'address': 3, 'altGlyph': 4, 'altGlyphDef': 5, 'altGlyphItem': 6,
    'animate': 7, 'animateColor': 8, 'animateMotion': 9, 'animateTransform': 10, 'applet': 11, 'area': 12,
    'article': 13, 'aside': 14, 'audio': 15, 'b': 16, 'base': 17, 'basefont': 18, 'bdi': 19, 'bdo': 20,
    'bgsound': 21, 'big': 22, 'blink': 23, 'blockquote': 24, 'body': 25, 'br': 26, 'button': 27, 'canvas': 28,
    'caption': 29, 'center': 30, 'circle': 31, 'cite': 32, 'clipPath': 33, 'code': 34, 'col': 35, 'colgroup': 36,
    'color-profile': 37, 'content': 38, 'cursor': 39, 'data': 40, 'datalist': 41, 'dd': 42, 'defs': 43, 'del': 44,
    'desc': 45, 'details': 46, 'dfn': 47, 'dialog': 48, 'dir': 49, 'div': 50, 'dl': 51, 'dt': 52, 'ellipse': 53,
    'em': 54, 'embed': 55, 'feBlend': 56, 'feColorMatrix': 57, 'feComponentTransfer': 58, 'feComposite': 59,
    'feConvolveMatrix': 60, 'feDiffuseLighting': 61, 'feDisplacementMap': 62, 'feDistantLight': 63, 'feFlood': 64,
    'feFuncA': 65, 'feFuncB': 66, 'feFuncG': 67, 'feFuncR': 68, 'feGaussianBlur': 69, 'feImage': 70, 'feMerge': 71,
    'feMergeNode': 72, 'feMorphology': 73, 'feOffset': 74, 'fePointLight': 75, 'feSpecularLighting': 76,
    'feSpotLight': 77, 'feTile': 78, 'feTurbulence': 79, 'fieldset': 80, 'figcaption': 81, 'figure': 82,
    'filter': 83, 'font-face-format': 84, 'font-face-name': 85, 'font-face-src': 86, 'font-face-uri': 87,
    'font-face': 88, 'font': 89, 'footer': 90, 'foreignObject': 91, 'form': 92, 'frame': 93, 'frameset': 94,
    'g': 95, 'glyph': 96, 'glyphRef': 97, 'h1': 98, 'h2': 99, 'h3': 100, 'h4': 101, 'h5': 102, 'h6': 103,
    'head': 104, 'header': 105, 'hgroup': 106, 'hkern': 107, 'hr': 108, 'html': 109, 'i': 110, 'iframe': 111,
    'image': 112, 'img': 113, 'input': 114, 'ins': 115, 'kbd': 116, 'keygen': 117, 'label': 118, 'legend': 119,
    'li': 120, 'line': 121, 'linearGradient': 122, 'link': 123, 'main': 124, 'map': 125, 'mark': 126, 'marker': 127,
    'marquee': 128, 'mask': 129, 'math': 130, 'menu': 131, 'menuitem': 132, 'meta': 133, 'metadata': 134,
    'meter': 135, 'missing-glyph': 136, 'mpath': 137, 'nav': 138, 'nobr': 139, 'noembed': 140, 'noframes': 141,
    'noscript': 142, 'object': 143, 'ol': 144, 'optgroup': 145, 'option': 146, 'output': 147, 'p': 148,
    'param': 149, 'path': 150, 'pattern': 151, 'picture': 152, 'plaintext': 153, 'polygon': 154, 'polyline': 155,
    'portal': 156, 'pre': 157, 'progress': 158, 'q': 159, 'radialGradient': 160, 'rb': 161, 'rect': 162, 'rp': 163,
    'rt': 164, 'rtc': 165, 'ruby': 166, 's': 167, 'samp': 168, 'script': 169, 'section': 170, 'select': 171,
    'set': 172, 'shadow': 173, 'slot': 174, 'small': 175, 'source': 176, 'spacer': 177, 'span': 178, 'stop': 179,
    'strike': 180, 'strong': 181, 'style': 182, 'sub': 183, 'summary': 184, 'sup': 185, 'svg': 186, 'switch': 187,
    'symbol': 188, 'table': 189, 'tbody': 190, 'td': 191, 'template': 192, 'text': 193, 'textPath': 194,
    'textarea': 195, 'tfoot': 196, 'th': 197, 'thead': 198, 'time': 199, 'title': 200, 'tr': 201, 'track': 202,
    'tref': 203, 'tspan': 204, 'tt': 205, 'u': 206, 'ul': 207, 'use': 208, 'var': 209, 'video': 210, 'view': 211,
    'vkern': 212, 'wbr': 213, 'xmp': 214}


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


# ---------- copied ! --------------
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join([w for w in doc_tokens[new_start:(new_end + 1)]
                                  if w[0] != '<' or w[-1] != '>'])
            if text_span == tok_answer_text:
                return new_start, new_end

    return input_start, input_end

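
# Illustrative sketch (not used by the pipeline): how `_improve_answer_span` tightens a span to the
# sub-tokens that exactly match the tokenized answer. `_WhitespaceTokenizer` is a stand-in for a real
# PLM tokenizer, introduced only to keep the example self-contained.
def _example_improve_answer_span():
    class _WhitespaceTokenizer(object):
        def tokenize(self, text):
            return text.split()

    doc_tokens = ['<div>', 'the', 'quick', 'brown', 'fox', '</div>']
    # the initial span covers "the quick brown fox"; the annotated answer is only "brown fox"
    start, end = _improve_answer_span(doc_tokens, 1, 4, _WhitespaceTokenizer(), "brown fox")
    assert (start, end) == (3, 4)
    return start, end
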
class StrucDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (*torch.Tensor): tensors that have the same size of the first dimension.
        page_ids (list): the corresponding page ids of the input features.
        cnn_feature_dir (str): the directory where the cnn features are stored.
        token_to_tag (torch.Tensor): the mapping from each token to its corresponding tag id.
    """

    def __init__(self, *tensors, pad_id=0, all_expended_attention_mask=None, all_graph_names=None,
                 all_token_to_tag=None, page_ids=None, attention_width=None, has_tree_attention_bias=False):
        tensors = tuple(tensor for tensor in tensors)
        assert all(len(tensors[0]) == len(tensor) for tensor in tensors)
        if all_expended_attention_mask is not None:
            assert len(tensors[0]) == len(all_expended_attention_mask)
            tensors += (all_expended_attention_mask,)
        self.tensors = tensors
        self.page_ids = page_ids
        self.all_graph_names = all_graph_names
        self.all_token_to_tag = all_token_to_tag
        self.pad_id = pad_id
        self.attention_width = attention_width
        self.has_tree_attention_bias = has_tree_attention_bias

    def __getitem__(self, index):
        output = [tensor[index] for tensor in self.tensors]
        input_id = output[0]
        attention_mask = output[1]

        if self.attention_width is not None or self.has_tree_attention_bias:
            assert self.all_graph_names is not None, ("For non-empty attention_width / tree rel pos, "
                                                      "graph names must be sent in!")

        if self.all_graph_names is not None:
            assert self.all_token_to_tag is not None
            graph_name = self.all_graph_names[index]
            token_to_tag = self.all_token_to_tag[index]
            with open(graph_name, "rb") as f:
                node_pairs_lengths = pickle.load(f)
            # node_pairs_lengths = dict(nx.all_pairs_shortest_path_length(graph))
            seq_len = len(token_to_tag)

            if self.has_tree_attention_bias:
                mat = [[0] * seq_len for _ in range(seq_len)]
            else:
                mat = None
            if self.attention_width is not None:
                # clone so that the per-pair edits below do not write through the expanded view
                emask = attention_mask.expand(seq_len, seq_len).clone()
            else:
                emask = None

            for nid in range(seq_len):
                if input_id[nid] == self.pad_id:
                    break
                for anid in range(nid + 1, seq_len):
                    if input_id[anid] == self.pad_id:
                        break
                    x_tid4nid = token_to_tag[nid]
                    x_tid4anid = token_to_tag[anid]
                    if x_tid4nid == x_tid4anid:
                        continue
                    try:
                        xx = node_pairs_lengths[x_tid4nid]  # x_tid4nid in valid tid list, or == -1
                    except KeyError:
                        # x_tid4nid out of bound, like `question`, `sep` or `cls`
                        xx = node_pairs_lengths[-1]
                        x_tid4nid = -1
                    try:
                        dis = xx[x_tid4anid]  # x_tid4anid in valid tid list, or == -1
                    except KeyError:
                        # x_tid4anid out of bound, like `question`, `sep` or `cls`
                        dis = xx[-1]
                        x_tid4anid = -1
                    # xx = node_pairs_lengths.get(tid4nid, node_pairs_lengths[-1])
                    # dis = xx.get(tid4anid, xx[-1])
                    if self.has_tree_attention_bias:
                        # signed tree distance: the sign encodes which of the two tags comes first
                        if x_tid4nid < x_tid4anid:
                            mat[nid][anid] = dis
                            mat[anid][nid] = -dis
                        else:
                            mat[nid][anid] = -dis
                            mat[anid][nid] = dis
                    if self.attention_width is not None and dis > self.attention_width:
                        # tags farther apart in the DOM tree than attention_width cannot attend to each other
                        emask[nid][anid] = 0
                        emask[anid][nid] = 0

            if self.attention_width is not None:
                output.append(emask)
            if self.has_tree_attention_bias:
                t_mat = torch.tensor(mat, dtype=torch.long)
                output.append(t_mat)

        return tuple(item for item in output)

    def __len__(self):
        return len(self.tensors[0])

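
# Illustrative sketch (not part of the original pipeline): without graph names or an attention width,
# a StrucDataset behaves like a plain TensorDataset and can be fed straight into a DataLoader. The
# batch size, sampler and dummy tensor shapes below are assumptions made only for this example.
def _example_struc_dataset_loader():
    from torch.utils.data import DataLoader, SequentialSampler

    input_ids = torch.zeros(4, 8, dtype=torch.long)
    attention_mask = torch.ones(4, 8, dtype=torch.long)
    segment_ids = torch.zeros(4, 8, dtype=torch.long)
    dataset = StrucDataset(input_ids, attention_mask, segment_ids)
    loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=2)
    for batch in loader:
        batch_input_ids, batch_attention_mask, batch_segment_ids = batch
        # each element has shape (batch_size, seq_len)
    return loader
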
def get_xpath4tokens(html_fn: str, unique_tids: set):
    xpath_map = {}
    tree = etree.parse(html_fn, etree.HTMLParser())
    nodes = tree.xpath('//*')
    for node in nodes:
        tid = node.attrib.get("tid")
        if int(tid) in unique_tids:
            xpath_map[int(tid)] = tree.getpath(node)
    xpath_map[len(nodes)] = "/html"
    xpath_map[len(nodes) + 1] = "/html"
    return xpath_map


def get_xpath_and_treeid4tokens(html_code, unique_tids, max_depth):
    unknown_tag_id = len(tags_dict)
    pad_tag_id = unknown_tag_id + 1
    max_width = 1000
    width_pad_id = 1001

    pad_x_tag_seq = [pad_tag_id] * max_depth
    pad_x_subs_seq = [width_pad_id] * max_depth
    pad_x_box = [0, 0, 0, 0]
    pad_tree_id_seq = [width_pad_id] * max_depth

    def xpath_soup(element):
        xpath_tags = []
        xpath_subscripts = []
        tree_index = []
        child = element if element.name else element.parent
        for parent in child.parents:  # type: bs4.element.Tag
            siblings = parent.find_all(child.name, recursive=False)
            para_siblings = parent.find_all(True, recursive=False)
            xpath_tags.append(child.name)
            xpath_subscripts.append(
                0 if 1 == len(siblings) else next(i for i, s in enumerate(siblings, 1) if s is child))
            tree_index.append(next(i for i, s in enumerate(para_siblings, 0) if s is child))
            child = parent
        xpath_tags.reverse()
        xpath_subscripts.reverse()
        tree_index.reverse()
        return xpath_tags, xpath_subscripts, tree_index

    xpath_tag_map = {}
    xpath_subs_map = {}
    tree_id_map = {}

    for tid in unique_tids:
        element = html_code.find(attrs={'tid': tid})
        if element is None:
            xpath_tags = pad_x_tag_seq
            xpath_subscripts = pad_x_subs_seq
            tree_index = pad_tree_id_seq
            xpath_tag_map[tid] = xpath_tags
            xpath_subs_map[tid] = xpath_subscripts
            tree_id_map[tid] = tree_index
            continue

        xpath_tags, xpath_subscripts, tree_index = xpath_soup(element)
        assert len(xpath_tags) == len(xpath_subscripts)
        assert len(xpath_tags) == len(tree_index)

        if len(xpath_tags) > max_depth:
            xpath_tags = xpath_tags[-max_depth:]
            xpath_subscripts = xpath_subscripts[-max_depth:]
            # tree_index = tree_index[-max_depth:]

        xpath_tags = [tags_dict.get(name, unknown_tag_id) for name in xpath_tags]
        xpath_subscripts = [min(i, max_width) for i in xpath_subscripts]
        tree_index = [min(i, max_width) for i in tree_index]
        # we do not append them to max depth here
        xpath_tags += [pad_tag_id] * (max_depth - len(xpath_tags))
        xpath_subscripts += [width_pad_id] * (max_depth - len(xpath_subscripts))
        # tree_index += [width_pad_id] * (max_depth - len(tree_index))

        xpath_tag_map[tid] = xpath_tags
        xpath_subs_map[tid] = xpath_subscripts
        tree_id_map[tid] = tree_index

    return xpath_tag_map, xpath_subs_map, tree_id_map

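
# Illustrative sketch (not part of the original pipeline): what `get_xpath_and_treeid4tokens` returns
# for a tiny tid-annotated document. The HTML snippet and the chosen max_depth are assumptions made
# only for this example.
def _example_xpath_extraction():
    snippet = '<html tid="0"><body tid="1"><div tid="2"><p tid="3">hello</p></div></body></html>'
    html_code = bs(snippet, "html.parser")
    xpath_tag_map, xpath_subs_map, tree_id_map = get_xpath_and_treeid4tokens(
        html_code, unique_tids={3}, max_depth=50)
    # the <p> node sits on the path html > body > div > p, so its tag sequence starts with the ids of
    # those four tags and is padded with the pad tag id (216) up to max_depth
    assert xpath_tag_map[3][:4] == [tags_dict['html'], tags_dict['body'], tags_dict['div'], tags_dict['p']]
    return xpath_tag_map, xpath_subs_map, tree_id_map
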
# ---------- copied ! --------------
def _check_is_max_context(doc_spans, cur_span_index, position):
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index

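
# Illustrative sketch (not part of the original pipeline): with two overlapping doc spans
# (start 0, length 4) and (start 2, length 4), token position 3 has more surrounding context in the
# second span, so only that span is marked as the max-context span for it.
def _example_check_is_max_context():
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = [_DocSpan(start=0, length=4), _DocSpan(start=2, length=4)]
    assert _check_is_max_context(doc_spans, 0, 3) is False
    assert _check_is_max_context(doc_spans, 1, 3) is True
    return doc_spans
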
class SRCExample(object):
    r"""
    The Container for SRC Examples.

    Arguments:
        doc_tokens (list[str]): the original tokens of the HTML file before dividing into sub-tokens.
        qas_id (str): the id of the corresponding question.
        tag_num (int): the total tag number in the corresponding HTML file, including the additional 'yes' and 'no'.
        question_text (str): the text of the corresponding question.
        orig_answer_text (str): the answer text provided by the dataset.
        all_doc_tokens (list[str]): the sub-tokens of the corresponding HTML file.
        start_position (int): the position where the answer starts in the all_doc_tokens.
        end_position (int): the position where the answer ends in the all_doc_tokens; NOTE that the answer tokens
            include the token at end_position.
        tok_to_orig_index (list[int]): the mapping from sub-tokens (all_doc_tokens) to origin tokens (doc_tokens).
        orig_to_tok_index (list[int]): the mapping from origin tokens (doc_tokens) to sub-tokens (all_doc_tokens).
        tok_to_tags_index (list[int]): the mapping from sub-tokens (all_doc_tokens) to the id of the deepest tag it
            belongs to.
    """

    # the difference between T-PLM and H-PLM is just that H-PLM additionally inserts the opening and
    # closing tag tokens into the original tokens and the further-tokenized tokens
    def __init__(self,
                 doc_tokens,
                 qas_id,
                 tag_num,  # an opening/closing tag pair is counted as one tag
                 question_text=None,
                 html_code=None,
                 orig_answer_text=None,
                 start_position=None,  # in all_doc_tokens
                 end_position=None,  # in all_doc_tokens
                 tok_to_orig_index=None,
                 orig_to_tok_index=None,
                 all_doc_tokens=None,
                 tok_to_tags_index=None,
                 xpath_tag_map=None,
                 xpath_subs_map=None,
                 xpath_box=None,
                 tree_id_map=None,
                 visible_matrix=None,
                 ):
        self.doc_tokens = doc_tokens
        self.qas_id = qas_id
        self.tag_num = tag_num
        self.question_text = question_text
        self.html_code = html_code
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.tok_to_orig_index = tok_to_orig_index
        self.orig_to_tok_index = orig_to_tok_index
        self.all_doc_tokens = all_doc_tokens
        self.tok_to_tags_index = tok_to_tags_index
        self.xpath_tag_map = xpath_tag_map
        self.xpath_subs_map = xpath_subs_map
        self.xpath_box = xpath_box
        self.tree_id_map = tree_id_map
        self.visible_matrix = visible_matrix

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """
        s = ""
        s += "qas_id: %s" % self.qas_id
        s += ", question_text: %s" % (self.question_text)
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position:
            s += ", start_position: %d" % self.start_position
        if self.end_position:
            s += ", end_position: %d" % self.end_position
        """
        s = "[INFO]\n"
        s += f"qas_id ({type(self.qas_id)}): {self.qas_id}\n"
        s += f"tag_num ({type(self.tag_num)}): {self.tag_num}\n"
        s += f"question_text ({type(self.question_text)}): {self.question_text}\n"
        s += f"html_code ({type(self.html_code)}): {self.html_code}\n"
        s += f"orig_answer_text ({type(self.orig_answer_text)}): {self.orig_answer_text}\n"
        s += f"start_position ({type(self.start_position)}): {self.start_position}\n"
        s += f"end_position ({type(self.end_position)}): {self.end_position}\n"
        s += f"tok_to_orig_index ({type(self.tok_to_orig_index)}): {self.tok_to_orig_index}\n"
        s += f"orig_to_tok_index ({type(self.orig_to_tok_index)}): {self.orig_to_tok_index}\n"
        s += f"all_doc_tokens ({type(self.all_doc_tokens)}): {self.all_doc_tokens}\n"
        s += f"tok_to_tags_index ({type(self.tok_to_tags_index)}): {self.tok_to_tags_index}\n"
        s += f"xpath_tag_map ({type(self.xpath_tag_map)}): {self.xpath_tag_map}\n"
        s += f"xpath_subs_map ({type(self.xpath_subs_map)}): {self.xpath_subs_map}\n"
        s += f"tree_id_map ({type(self.tree_id_map)}): {self.tree_id_map}\n"
        return s


class InputFeatures(object):
    r"""
    The Container for the Features of Input Doc Spans.

    Arguments:
        unique_id (int): the unique id of the input doc span.
        example_index (int): the index of the corresponding SRC Example of the input doc span.
        page_id (str): the id of the corresponding web page of the question.
        doc_span_index (int): the index of the doc span among all the doc spans which correspond to the same
            SRC Example.
        tokens (list[str]): the sub-tokens of the input sequence, including the cls token, the sep tokens, and the
            sub-tokens of the question and the HTML file.
        token_to_orig_map (dict[int, int]): the mapping from the HTML file's sub-tokens in the input sequence
            (tokens) to the origin tokens (doc_tokens in the corresponding SRC Example).
        token_is_max_context (dict[int, bool]): whether the current doc span contains the max pre- and post-context
            for each of the HTML file's sub-tokens.
        input_ids (list[int]): the ids of the sub-tokens in the input sequence (tokens).
        input_mask (list[int]): use 0/1 to distinguish the input sequence from paddings.
        segment_ids (list[int]): use 0/1 to distinguish the question and the HTML file.
        paragraph_len (int): the length of the HTML file's sub-tokens.
        start_position (int): the position where the answer starts in the input sequence (0 if the answer is not
            fully in the input sequence).
        end_position (int): the position where the answer ends in the input sequence; NOTE that the answer tokens
            include the token at end_position (0 if the answer is not fully in the input sequence).
        token_to_tag_index (list[int]): the mapping from sub-tokens of the input sequence to the id of the deepest
            tag it belongs to.
        is_impossible (bool): whether the answer is NOT fully contained in the doc span.
    """

    def __init__(self,
                 unique_id,
                 example_index,
                 page_id,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 paragraph_len,
                 start_position=None,
                 end_position=None,
                 token_to_tag_index=None,
                 is_impossible=None,
                 xpath_tags_seq=None,
                 xpath_subs_seq=None,
                 xpath_box_seq=None,
                 extended_attention_mask=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.page_id = page_id
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.paragraph_len = paragraph_len
        self.start_position = start_position
        self.end_position = end_position
        self.token_to_tag_index = token_to_tag_index
        self.is_impossible = is_impossible
        self.xpath_tags_seq = xpath_tags_seq
        self.xpath_subs_seq = xpath_subs_seq
        self.xpath_box_seq = xpath_box_seq
        self.extended_attention_mask = extended_attention_mask


def html_escape(html):
    r"""
    Replace the special entity expressions in the html file with the corresponding punctuation.
    """
    html = html.replace('&quot;', '"')
    html = html.replace('&amp;', '&')
    html = html.replace('&lt;', '<')
    html = html.replace('&gt;', '>')
    html = html.replace('&nbsp;', ' ')
    return html

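
# Illustrative sketch (not part of the original pipeline): the escape helper maps the HTML entities
# produced by the page dump back to plain punctuation before the text is split into tokens.
def _example_html_escape():
    assert html_escape('Tom &amp; Jerry &lt;b&gt;&quot;hi&quot;&nbsp;there&lt;/b&gt;') == 'Tom & Jerry <b>"hi" there</b>'
    return html_escape('&amp;')
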
def read_squad_examples(args, input_file, root_dir, is_training, tokenizer, simplify=False, max_depth=50,
                        split_flag="n-eon", attention_width=None):
    r"""
    Pre-process the data in json format into SRC Examples.

    Arguments:
        input_file (str): the inputting data file in json format.
        root_dir (str): the root directory of the raw WebSRC dataset, which contains the HTML files.
        is_training (bool): True if processing the training set, else False.
        tokenizer (Tokenizer): the tokenizer for PLM in use.
        simplify (bool): when set to True, the returned Examples only contain the document tokens, the id of the
            question-answers, and the total tag number in the corresponding html files.
        max_depth (int): the maximum depth of the xpath sequences.
        split_flag (str): how tag-boundary tokens are inserted into the document tokens, choice: ['n-eon', 'y-eon',
            'y-sep', 'y-cls', 'y-hplm'] (cf. the original T-PLM / H-PLM / V-PLM methods).
        attention_width (int): the maximum tree distance between two tags that may still attend to each other;
            None disables the tag-level visibility mask.
    Returns:
        list[SRCExample]: the resulting SRC Examples, containing all the information needed for the feature
            generation process, except when the argument simplify is set to True;
        set[str]: all the tag names that appeared in the processed dataset, e.g. <div>, <img/>, </p>, etc.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)["data"]

    pad_tree_id_seq = [1001] * max_depth

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    def html_to_text_list(h):
        tag_num, text_list = 0, []
        for element in h.descendants:
            if (type(element) == bs4.element.NavigableString) and (element.strip()):
                text_list.append(element.strip())
            if type(element) == bs4.element.Tag:
                tag_num += 1
        return text_list, tag_num + 2  # + 2 because we treat the additional 'yes' and 'no' as two special tags.

    def html_to_text(h):
        tag_list = set()
        for element in h.descendants:
            if type(element) == bs4.element.Tag:
                element.attrs = {}
                temp = str(element).split()
                tag_list.add(temp[0])
                tag_list.add(temp[-1])
        return html_escape(str(h)), tag_list

    def adjust_offset(offset, text):
        text_list = text.split()
        cnt, adjustment = 0, []
        for t in text_list:
            if not t:
                continue
            if t[0] == '<' and t[-1] == '>':
                adjustment.append(offset.index(cnt))
            else:
                cnt += 1
        add = 0
        adjustment.append(len(offset))
        for i in range(len(offset)):
            while i >= adjustment[add]:
                add += 1
            offset[i] += add
        return offset

    def e_id_to_t_id(e_id, html):
        t_id = 0
        for element in html.descendants:
            if type(element) == bs4.element.NavigableString and element.strip():
                t_id += 1
            if type(element) == bs4.element.Tag:
                if int(element.attrs['tid']) == e_id:
                    break
        return t_id

    def calc_num_from_raw_text_list(t_id, l):
        n_char = 0
        for i in range(t_id):
            n_char += len(l[i]) + 1
        return n_char

    def word_to_tag_from_text(tokens, h):
        cnt, w_t, path = -1, [], []
        unique_tids = set()
        for t in tokens[0:-2]:
            if len(t) < 2:
                w_t.append(path[-1])
                unique_tids.add(path[-1])
                continue
            if t[0] == '<' and t[-2] == '/':
                # self-closing tag
                cnt += 1
                w_t.append(cnt)
                unique_tids.add(cnt)
                continue
            if t[0] == '<' and t[1] != '/':
                # opening tag: enter a new (deeper) node
                cnt += 1
                path.append(cnt)
            w_t.append(path[-1])
            unique_tids.add(path[-1])
            if t[0] == '<' and t[1] == '/':
                # closing tag: leave the current node
                del path[-1]
        # the trailing 'no' and 'yes' tokens are mapped to the two additional special tags
        w_t.append(cnt + 1)
        unique_tids.add(cnt + 1)
        w_t.append(cnt + 2)
        unique_tids.add(cnt + 2)
        assert len(w_t) == len(tokens)
        assert len(path) == 0, print(h)
        return w_t, unique_tids

    def word_tag_offset(html):
        cnt, w_t, t_w, tags, tags_tids = 0, [], [], [], []
        for element in html.descendants:
            if type(element) == bs4.element.Tag:
                content = ' '.join(list(element.strings)).split()
                t_w.append({'start': cnt, 'len': len(content)})
                tags.append('<' + element.name + '>')
                tags_tids.append(element['tid'])
            elif type(element) == bs4.element.NavigableString and element.strip():
                text = element.split()
                tid = element.parent['tid']
                ind = tags_tids.index(tid)
                for _ in text:
                    w_t.append(ind)
                    cnt += 1
        assert cnt == len(w_t)
        w_t.append(len(t_w))
        w_t.append(len(t_w) + 1)
        return w_t

    def subtoken_tag_offset(html, s_tok):
        w_t = word_tag_offset(html)
        s_t = []
        unique_tids = set()
        for i in range(len(s_tok)):
            s_t.append(w_t[s_tok[i]])
            unique_tids.add(w_t[s_tok[i]])
        return s_t, unique_tids

    def subtoken_tag_offset_plus_eon(html, s_tok, all_doc_tokens):
        w_t = word_tag_offset(html)
        s_t = []
        unique_tids = set()
        offset = 0
        for i in range(len(s_tok)):
            # '<eon>' is the end-of-node marker token inserted by the `y-eon` split mode
            if all_doc_tokens[i] not in ('<eon>', tokenizer.sep_token, tokenizer.cls_token):
                s_t.append(w_t[s_tok[i] - offset])
                unique_tids.add(w_t[s_tok[i] - offset])
            else:
                prev_tid = s_t[-1]
                s_t.append(prev_tid)
                offset += 1
        return s_t, unique_tids

    def check_visible(path1, path2, attention_width):
        i = 0
        j = 0
        dis = 0
        lp1 = len(path1)
        lp2 = len(path2)
        while i < lp1 and j < lp2 and path1[i] == path2[j]:
            i += 1
            j += 1
        if i < lp1 and j < lp2:
            dis += lp1 - i + lp2 - j
        else:
            if i == lp1:
                dis += lp2 - j
            else:
                dis += lp1 - i
        if dis <= attention_width:
            return True
        return False

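    # For example (illustrative values only): two tags whose tree_index paths are [0, 1, 2] and
    # [0, 1, 5] share the prefix [0, 1], so their tree distance is (3 - 2) + (3 - 2) = 2; with
    # attention_width >= 2 they remain visible to each other, otherwise they are masked apart.
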
    def from_tids_to_box(html_fn, unique_tids, json_fn):
        sorted_ids = sorted(unique_tids)
        with open(json_fn, 'r') as f:
            data = json.load(f)
        orig_width, orig_height = data['2']['rect']['width'], data['2']['rect']['height']
        orig_x, orig_y = data['2']['rect']['x'], data['2']['rect']['y']
        return_dict = {}
        for id in sorted_ids:
            if str(id) in data:
                x, y, width, height = data[str(id)]['rect']['x'], data[str(id)]['rect']['y'], \
                                      data[str(id)]['rect']['width'], data[str(id)]['rect']['height']
                resize_x = (x - orig_x) * 1000 // orig_width
                resize_y = (y - orig_y) * 1000 // orig_height
                resize_width = width * 1000 // orig_width
                resize_height = height * 1000 // orig_height
                # if not (resize_x <= 1000 and resize_y <= 1000):
                #     print('before', x, y, width, height)
                #     print('after', resize_x, resize_y, resize_width, resize_height)
                #     print('file name ', html_fn)
                #     # exit(0)
                if resize_x < 0 or resize_y < 0 or resize_width < 0 or resize_height < 0:
                    # meaningless
                    return_dict[id] = [0, 0, 0, 0]
                else:
                    return_dict[id] = [int(resize_x), int(resize_y),
                                       int(resize_x + resize_width), int(resize_y + resize_height)]
            else:
                return_dict[id] = [0, 0, 0, 0]
        return return_dict

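    # For example (illustrative numbers only): with a page rect of width 1280 and height 800 taken
    # from entry '2' of the layout json, whose origin is at (0, 0), a node at x=128, y=80 with
    # width=256 and height=40 is normalised onto the 0-1000 grid as [100, 100, 300, 150],
    # i.e. [x1, y1, x2, y2].
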
    def get_visible_matrix(unique_tids, tree_id_map, attention_width):
        if attention_width is None:
            return None
        unique_tids_list = list(unique_tids)
        visible_matrix = collections.defaultdict(list)
        for i in range(len(unique_tids_list)):
            if tree_id_map[unique_tids_list[i]] == pad_tree_id_seq:
                visible_matrix[unique_tids_list[i]] = list()
                continue
            visible_matrix[unique_tids_list[i]].append(unique_tids_list[i])
            for j in range(i + 1, len(unique_tids_list)):
                if check_visible(tree_id_map[unique_tids_list[i]], tree_id_map[unique_tids_list[j]],
                                 attention_width):
                    visible_matrix[unique_tids_list[i]].append(unique_tids_list[j])
                    visible_matrix[unique_tids_list[j]].append(unique_tids_list[i])
        return visible_matrix

    examples = []
    all_tag_list = set()
    total_num = sum([len(entry["websites"]) for entry in input_data])
    with tqdm(total=total_num, desc="Converting websites to examples") as t:
        for entry in input_data:
            # print('entry', entry)
            domain = entry["domain"]
            for website in entry["websites"]:
                # print('website', website)
                # Generate Doc Tokens
                page_id = website["page_id"]
                # print('page_id', page_id)
                curr_dir = osp.join(root_dir, domain, page_id[0:2], 'processed_data')
                html_fn = osp.join(curr_dir, page_id + '.html')
                json_fn = osp.join(curr_dir, page_id + '.json')
                # print('html', html_fn)
                html_file = open(html_fn).read()
                html_code = bs(html_file, "html.parser")
                raw_text_list, tag_num = html_to_text_list(html_code)
                # print(raw_text_list)
                # print(tag_num)
                # exit(0)
                doc_tokens = []
                char_to_word_offset = []
                # print(split_flag)  # n-eon
                # exit(0)

                if split_flag in ["y-eon", "y-sep", "y-cls"]:
                    prev_is_whitespace = True
                    for i, doc_string in enumerate(raw_text_list):
                        for c in doc_string:
                            if is_whitespace(c):
                                prev_is_whitespace = True
                            else:
                                if prev_is_whitespace:
                                    doc_tokens.append(c)
                                else:
                                    doc_tokens[-1] += c
                                prev_is_whitespace = False
                            char_to_word_offset.append(len(doc_tokens) - 1)
                        if i < len(raw_text_list) - 1:
                            prev_is_whitespace = True
                            char_to_word_offset.append(len(doc_tokens) - 1)
                            if split_flag == "y-eon":
                                doc_tokens.append('<eon>')  # end-of-node marker token
                            elif split_flag == "y-sep":
                                doc_tokens.append(tokenizer.sep_token)
                            elif split_flag == "y-cls":
                                doc_tokens.append(tokenizer.cls_token)
                            else:
                                raise ValueError("Split flag should be `y-eon` or `y-sep` or `y-cls`")
                            prev_is_whitespace = True
                elif split_flag == "n-eon" or split_flag == "y-hplm":
                    page_text = ' '.join(raw_text_list)
                    prev_is_whitespace = True
                    for c in page_text:
                        if is_whitespace(c):
                            prev_is_whitespace = True
                        else:
                            if prev_is_whitespace:
                                doc_tokens.append(c)
                            else:
                                doc_tokens[-1] += c
                            prev_is_whitespace = False
                        char_to_word_offset.append(len(doc_tokens) - 1)

                doc_tokens.append('no')
                char_to_word_offset.append(len(doc_tokens) - 1)
                doc_tokens.append('yes')
                char_to_word_offset.append(len(doc_tokens) - 1)

                if split_flag == "y-hplm":
                    real_text, tag_list = html_to_text(bs(html_file))
                    all_tag_list = all_tag_list | tag_list
                    char_to_word_offset = adjust_offset(char_to_word_offset, real_text)
                    doc_tokens = real_text.split()
                    doc_tokens.append('no')
                    doc_tokens.append('yes')
                    doc_tokens = [i for i in doc_tokens if i]
                else:
                    tag_list = []

                assert len(doc_tokens) == char_to_word_offset[-1] + 1, (len(doc_tokens), char_to_word_offset[-1])

                if simplify:
                    for qa in website["qas"]:
                        qas_id = qa["id"]
                        example = SRCExample(doc_tokens=doc_tokens, qas_id=qas_id, tag_num=tag_num)
                        examples.append(example)
                    t.update(1)
                else:
                    # Tokenize all doc tokens; tag tokens found in tag_list (e.g. opening/closing tags)
                    # are kept as single sub-tokens
                    tok_to_orig_index = []
                    orig_to_tok_index = []
                    all_doc_tokens = []
                    for (i, token) in enumerate(doc_tokens):
                        orig_to_tok_index.append(len(all_doc_tokens))
                        if token in tag_list:
                            sub_tokens = [token]
                        else:
                            sub_tokens = tokenizer.tokenize(token)
                        for sub_token in sub_tokens:
                            tok_to_orig_index.append(i)
                            all_doc_tokens.append(sub_token)

                    # Generate extra information for features
                    if split_flag in ["y-eon", "y-sep", "y-cls"]:
                        tok_to_tags_index, unique_tids = subtoken_tag_offset_plus_eon(html_code, tok_to_orig_index,
                                                                                      all_doc_tokens)
                    elif split_flag == "n-eon":
                        tok_to_tags_index, unique_tids = subtoken_tag_offset(html_code, tok_to_orig_index)
                    elif split_flag == "y-hplm":
                        tok_to_tags_index, unique_tids = word_to_tag_from_text(all_doc_tokens, html_code)
                    else:
                        raise ValueError("Unsupported split_flag!")

                    xpath_tag_map, xpath_subs_map, tree_id_map = get_xpath_and_treeid4tokens(html_code, unique_tids,
                                                                                             max_depth=max_depth)
                    # tree_id_map: neither truncated nor padded

                    xpath_box = from_tids_to_box(html_fn, unique_tids, json_fn)

                    assert tok_to_tags_index[-1] == tag_num - 1, (tok_to_tags_index[-1], tag_num - 1)

                    # we get the tag-level visibility information (used for the extended attention mask) here
                    visible_matrix = get_visible_matrix(unique_tids, tree_id_map, attention_width=attention_width)

                    # Process each qa; this mainly calculates the answer position
                    for qa in website["qas"]:
                        qas_id = qa["id"]
                        question_text = qa["question"]
                        start_position = None
                        end_position = None
                        orig_answer_text = None

                        if is_training:
                            if len(qa["answers"]) != 1:
                                raise ValueError(
                                    "For training, each question should have exactly 1 answer.")
                            answer = qa["answers"][0]
                            orig_answer_text = answer["text"]
                            if answer["element_id"] == -1:
                                num_char = len(char_to_word_offset) - 2
                            else:
                                num_char = calc_num_from_raw_text_list(e_id_to_t_id(answer["element_id"], html_code),
                                                                       raw_text_list)
                            answer_offset = num_char + answer["answer_start"]
                            answer_length = len(orig_answer_text) if answer["element_id"] != -1 else 1
                            start_position = char_to_word_offset[answer_offset]
                            end_position = char_to_word_offset[answer_offset + answer_length - 1]

                            # Only add answers where the text can be exactly recovered from the
                            # document. If this CAN'T happen it's likely due to weird Unicode
                            # stuff so we will just skip the example.
                            #
                            # Note that this means for training mode, every example is NOT
                            # guaranteed to be preserved.
                            actual_text = " ".join([w for w in doc_tokens[start_position:(end_position + 1)]
                                                    if (w[0] != '<' or w[-1] != '>')
                                                    and w != '<eon>'
                                                    and w != tokenizer.sep_token
                                                    and w != tokenizer.cls_token])
                            cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
                            if actual_text.find(cleaned_answer_text) == -1:
                                logging.warning("Could not find answer of question %s: '%s' vs. '%s'",
                                                qa['id'], actual_text, cleaned_answer_text)
                                continue

                        example = SRCExample(
                            doc_tokens=doc_tokens,
                            qas_id=qas_id,
                            tag_num=tag_num,
                            question_text=question_text,
                            html_code=html_code,
                            orig_answer_text=orig_answer_text,
                            start_position=start_position,
                            end_position=end_position,
                            tok_to_orig_index=tok_to_orig_index,
                            orig_to_tok_index=orig_to_tok_index,
                            all_doc_tokens=all_doc_tokens,
                            tok_to_tags_index=tok_to_tags_index,
                            xpath_tag_map=xpath_tag_map,
                            xpath_subs_map=xpath_subs_map,
                            xpath_box=xpath_box,
                            tree_id_map=tree_id_map,
                            visible_matrix=visible_matrix
                        )
                        examples.append(example)
                        if args.web_num_features != 0:
                            if len(examples) >= args.web_num_features:
                                return examples, all_tag_list
                    t.update(1)
    return examples, all_tag_list

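
# Illustrative sketch (not part of the original pipeline): a minimal call to `read_squad_examples`.
# The file paths are placeholders and the tokenizer choice is an assumption; `web_num_features=0`
# means "convert every example" (see the early-return check inside the function).
def _example_read_squad_examples():
    args = argparse.Namespace(web_num_features=0)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    examples, tag_list = read_squad_examples(
        args,
        input_file="data/websrc/dev.json",  # placeholder path
        root_dir="data/websrc",             # placeholder path
        is_training=False,
        tokenizer=tokenizer,
        simplify=False,
        max_depth=50,
        split_flag="n-eon",
        attention_width=None,
    )
    return examples, tag_list
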
def load_and_cache_examples(args, tokenizer, max_depth=50, evaluate=False, output_examples=False):
    r"""
    Load and process the raw data.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_file = args.web_eval_file if evaluate else args.web_train_file
    cached_features_file = os.path.join(args.cache_dir, 'cached_{}_{}_{}_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        "markuplm",
        str(args.max_seq_length),
        str(max_depth),
        args.web_num_features,
        args.model_type))
    if not os.path.exists(os.path.dirname(cached_features_file)):
        os.makedirs(os.path.dirname(cached_features_file))

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        print("Loading features from cached file %s" % cached_features_file)
        features = torch.load(cached_features_file)
        if output_examples:
            examples, tag_list = read_squad_examples(args, input_file=input_file, root_dir=args.web_root_dir,
                                                     is_training=not evaluate, tokenizer=tokenizer, simplify=True,
                                                     max_depth=max_depth)
        else:
            examples = None
    else:
        print("Creating features from dataset file at %s" % input_file)
        examples, _ = read_squad_examples(args, input_file=input_file, root_dir=args.web_root_dir,
                                          is_training=not evaluate, tokenizer=tokenizer, simplify=False,
                                          max_depth=max_depth)
        features = convert_examples_to_features(examples=examples,
                                                tokenizer=tokenizer,
                                                max_seq_length=args.max_seq_length,
                                                doc_stride=args.doc_stride,
                                                max_query_length=args.max_query_length,
                                                is_training=not evaluate,
                                                cls_token=tokenizer.cls_token,
                                                sep_token=tokenizer.sep_token,
                                                pad_token=tokenizer.pad_token_id,
                                                sequence_a_segment_id=0,
                                                sequence_b_segment_id=0,
                                                max_depth=max_depth)
        if args.local_rank in [-1, 0] and args.web_save_features:
            print("Saving features into cached file %s" % cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache
        torch.distributed.barrier()

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_xpath_tags_seq = torch.tensor([f.xpath_tags_seq for f in features], dtype=torch.long)
    all_xpath_subs_seq = torch.tensor([f.xpath_subs_seq for f in features], dtype=torch.long)
    all_xpath_box_seq = torch.tensor([f.xpath_box_seq for f in features], dtype=torch.long)

    if evaluate:
        all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = StrucDataset(all_input_ids, all_input_mask, all_segment_ids, all_feature_index,
                               all_xpath_tags_seq, all_xpath_subs_seq, all_xpath_box_seq)
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = StrucDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_xpath_tags_seq, all_xpath_subs_seq,
                               all_start_positions, all_end_positions, all_xpath_box_seq)

    if output_examples:
        dataset = (dataset, examples, features)
    return dataset

def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training,
                                 cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                 sequence_a_segment_id=0, sequence_b_segment_id=1,
                                 cls_token_segment_id=0, pad_token_segment_id=0,
                                 mask_padding_with_zero=True, max_depth=50):
    r"""
    Converting the SRC Examples further into the features for all the input doc spans.

    Arguments:
        examples (list[SRCExample]): the list of SRC Examples to process.
        tokenizer (Tokenizer): the tokenizer for PLM in use.
        max_seq_length (int): the max length of the total sub-token sequence, including the question, cls token,
            sep tokens, and documents; if the length of the input is bigger than max_seq_length, the input will be
            cut into several doc spans.
        doc_stride (int): the stride length when the input is cut into several doc spans.
        max_query_length (int): the max length of the sub-token sequence of the questions; the question will be
            truncated if it is longer than max_query_length.
        is_training (bool): True if processing the training set, else False.
        cls_token (str): the cls token in use, default is '[CLS]'.
        sep_token (str): the sep token in use, default is '[SEP]'.
        pad_token (int): the id of the padding token in use when the total sub-token length is smaller than
            max_seq_length, default is 0, which corresponds to the '[PAD]' token.
        sequence_a_segment_id: the segment id for the first sequence (the question), default is 0.
        sequence_b_segment_id: the segment id for the second sequence (the HTML file), default is 1.
        cls_token_segment_id: the segment id for the cls token, default is 0.
        pad_token_segment_id: the segment id for the padding tokens, default is 0.
        mask_padding_with_zero: determines the pattern of the returned input mask; 0 for padding tokens and 1 for
            others when True, and vice versa.
    Returns:
        list[InputFeatures]: the resulting input features for all the input doc spans.
    """
    pad_x_tag_seq = [216] * max_depth
    pad_x_subs_seq = [1001] * max_depth
    pad_x_box = [0, 0, 0, 0]
    pad_tree_id_seq = [1001] * max_depth

    unique_id = 1000000000
    features = []
    for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
        xpath_tag_map = example.xpath_tag_map
        xpath_subs_map = example.xpath_subs_map
        xpath_box = example.xpath_box
        tree_id_map = example.tree_id_map
        visible_matrix = example.visible_matrix

        query_tokens = tokenizer.tokenize(example.question_text)
        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_start_position = None
        tok_end_position = None
        if is_training:
            tok_start_position = example.orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = example.orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(example.all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                example.all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(example.all_doc_tokens):
            length = len(example.all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(example.all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

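        # For example (illustrative numbers only): with 250 document sub-tokens,
        # max_tokens_for_doc = 100 and doc_stride = 50, the loop above yields the spans
        # (start=0, length=100), (50, 100), (100, 100) and (150, 100).
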
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            token_to_tag_index = []

            # CLS token at the beginning
            tokens.append(cls_token)
            segment_ids.append(cls_token_segment_id)
            token_to_tag_index.append(example.tag_num)

            # Query
            tokens += query_tokens
            segment_ids += [sequence_a_segment_id] * len(query_tokens)
            token_to_tag_index += [example.tag_num] * len(query_tokens)

            # SEP token
            tokens.append(sep_token)
            segment_ids.append(sequence_a_segment_id)
            token_to_tag_index.append(example.tag_num)

            # Paragraph
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = example.tok_to_orig_index[split_token_index]
                token_to_tag_index.append(example.tok_to_tags_index[split_token_index])
                is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(example.all_doc_tokens[split_token_index])
                segment_ids.append(sequence_b_segment_id)
            paragraph_len = doc_span.length

            # SEP token
            tokens.append(sep_token)
            segment_ids.append(sequence_b_segment_id)
            token_to_tag_index.append(example.tag_num)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(pad_token)
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                token_to_tag_index.append(example.tag_num)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(token_to_tag_index) == max_seq_length

            span_is_impossible = False
            start_position = None
            end_position = None
            if is_training:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    span_is_impossible = True
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            '''
            if 10 < example_index < 20:
                print("*** Example ***")
                # print("page_id: %s" % (example.qas_id[:-5]))
                # print("token_to_tag_index :%s" % token_to_tag_index)
                # print(len(token_to_tag_index))
                # print("unique_id: %s" % (unique_id))
                # print("example_index: %s" % (example_index))
                # print("doc_span_index: %s" % (doc_span_index))
                # print("tokens: %s" % " ".join(tokens))
                print("tokens: %s" % " ".join([
                    "%d:%s" % (x, y) for (x, y) in enumerate(tokens)
                ]))
                # print("token_to_orig_map: %s" % " ".join([
                #     "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
                # print(len(token_to_orig_map))
                # print("token_is_max_context: %s" % " ".join([
                #     "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
                # ]))
                # print(len(token_is_max_context))
                # print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                # print(len(input_ids))
                # print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                # print(len(input_mask))
                # print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                # print(len(segment_ids))
                print(f"original answer: {example.orig_answer_text}")
                if is_training and span_is_impossible:
                    print("impossible example")
                if is_training and not span_is_impossible:
                    answer_text = " ".join(tokens[start_position:(end_position + 1)])
                    print("start_position: %d" % (start_position))
                    print("end_position: %d" % (end_position))
                    print("answer: %s" % (answer_text))
            '''

            # print('token_to_tag_index', token_to_tag_index)
            # print('xpath_tag_map', xpath_tag_map)
            # exit(0)
            xpath_tags_seq = [xpath_tag_map.get(tid, pad_x_tag_seq) for tid in token_to_tag_index]  # ok
            xpath_subs_seq = [xpath_subs_map.get(tid, pad_x_subs_seq) for tid in token_to_tag_index]  # ok
            xpath_box_seq = [xpath_box.get(tid, pad_x_box) for tid in token_to_tag_index]
            # print(xpath_box_seq)
            # exit(0)

            # we need to get the extended_attention_mask
            if visible_matrix is not None:
                extended_attention_mask = []
                for tid in token_to_tag_index:
                    if tid == example.tag_num:
                        extended_attention_mask.append(input_mask)
                    else:
                        visible_tids = visible_matrix[tid]
                        if len(visible_tids) == 0:
                            extended_attention_mask.append(input_mask)
                            continue
                        visible_per_token = []
                        for i, other_tid in enumerate(token_to_tag_index):
                            if other_tid == example.tag_num and input_mask[i] == (1 if mask_padding_with_zero else 0):
                                visible_per_token.append(1 if mask_padding_with_zero else 0)
                            elif other_tid in visible_tids:
                                visible_per_token.append(1 if mask_padding_with_zero else 0)
                            else:
                                visible_per_token.append(0 if mask_padding_with_zero else 1)
                        extended_attention_mask.append(visible_per_token)  # should be (max_seq_len * max_seq_len)
            else:
                extended_attention_mask = None

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    page_id=example.qas_id[:-5],
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    paragraph_len=paragraph_len,
                    start_position=start_position,
                    end_position=end_position,
                    token_to_tag_index=token_to_tag_index,
                    is_impossible=span_is_impossible,
                    xpath_tags_seq=xpath_tags_seq,
                    xpath_subs_seq=xpath_subs_seq,
                    xpath_box_seq=xpath_box_seq,
                    extended_attention_mask=extended_attention_mask,
                ))
            unique_id += 1

    return features

def get_websrc_dataset(args, tokenizer, evaluate=False, output_examples=False):
    if not evaluate:
        websrc_dataset = load_and_cache_examples(args, tokenizer, evaluate=evaluate, output_examples=False)
        return websrc_dataset
    else:
        dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=evaluate,
                                                              output_examples=True)
        return dataset, examples, features

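
# Illustrative sketch (not part of the original pipeline): the argument fields this module reads when
# building the WebSRC dataset. All values below are assumptions made only for this example; in the
# original setup they come from the training script's argument parser.
def _example_get_websrc_dataset():
    args = argparse.Namespace(
        web_train_file="data/websrc/train.json",  # placeholder path
        web_eval_file="data/websrc/dev.json",     # placeholder path
        web_root_dir="data/websrc",               # placeholder path
        cache_dir="cache",
        model_type="markuplm",
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        web_num_features=0,
        web_save_features=True,
        overwrite_cache=False,
        local_rank=-1,
    )
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_dataset = get_websrc_dataset(args, tokenizer, evaluate=False)
    eval_dataset, eval_examples, eval_features = get_websrc_dataset(args, tokenizer, evaluate=True,
                                                                    output_examples=True)
    return train_dataset, eval_dataset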