# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import copy
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
class BiaffineAttention(torch.nn.Module):
    """Biaffine scoring layer usable as a binary relation classifier.

    PyTorch version of the biaffine attention operator described in
    "End-to-end neural relation extraction using deep biaffine attention"
    (https://arxiv.org/abs/1812.11275). The output is the sum of two terms:
    a bias-free bilinear form over the two inputs, and a biased linear map
    applied to their concatenation along the last dimension.

    Args:
        in_features (int): Size of the feature (last) dimension of each input.
        out_features (int): Size of the feature (last) dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*`
          means any number of additional dimensions.
        - x_2: `(N, *, in_features)`, same layout as `x_1`.
        - Output: `(N, *, out_features)`.

    Examples:
        >>> batch_size, in_features, out_features = 32, 100, 4
        >>> biaffine_attention = BiaffineAttention(in_features, out_features)
        >>> x_1 = torch.randn(batch_size, in_features)
        >>> x_2 = torch.randn(batch_size, in_features)
        >>> output = biaffine_attention(x_1, x_2)
        >>> print(output.size())
        torch.Size([32, 4])
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        # The attribute names below become the parameter prefixes in the
        # module's state_dict ("bilinear.*", "linear.*").
        self.bilinear = torch.nn.Bilinear(
            in_features, in_features, out_features, bias=False
        )
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

        self.reset_parameters()

    def forward(self, x_1, x_2):
        """Return the biaffine score of `x_1` against `x_2`."""
        pairwise_term = self.bilinear(x_1, x_2)
        joined_inputs = torch.cat((x_1, x_2), dim=-1)
        return pairwise_term + self.linear(joined_inputs)

    def reset_parameters(self):
        """Re-initialise both sub-layers with their own default schemes."""
        for layer in (self.bilinear, self.linear):
            layer.reset_parameters()
class REDecoder(nn.Module):
    """Relation-extraction head that classifies candidate entity pairs.

    For every document in a batch the decoder enumerates candidate
    (head, tail) entity pairs, builds a representation for each side from
    the hidden state at the entity's start index concatenated with an
    embedding of the entity's label, projects both sides through separate
    feed-forward stacks and scores the pair with a two-class
    `BiaffineAttention` classifier.

    Args:
        config: Object exposing `hidden_size` and `hidden_dropout_prob`.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        dropout_p = config.hidden_dropout_prob

        # Three embedding rows: entity labels are used directly as indices.
        self.entity_emb = nn.Embedding(3, hidden, scale_grad_by_freq=True)

        # Template stack; head and tail each receive an independent deep copy
        # (identical initial weights, parameters not shared).
        template = nn.Sequential(
            nn.Linear(hidden * 2, hidden),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout_p),
        )
        self.ffnn_head = copy.deepcopy(template)
        self.ffnn_tail = copy.deepcopy(template)

        self.rel_classifier = BiaffineAttention(hidden // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        """Build the labelled candidate relation list for each document.

        Candidates are all index pairs `(i, j)` where entity `i` has label 1
        and entity `j` has label 2. If there are none, the single pair
        `(0, 1)` is used so the list is never empty. Candidates present in
        the gold `relations` get label 1 and are listed first; the remaining
        candidates get label 0.

        Note:
            `entities` is modified in place: any document with two or fewer
            entities has its entry replaced by a two-entity placeholder.

        Args:
            relations: Per-document dicts with parallel "head" and "tail"
                lists of entity indices (the gold relations).
            entities: Per-document dicts with parallel "start", "end" and
                "label" lists.

        Returns:
            A tuple `(new_relations, entities)` where `new_relations` is a
            list of dicts with parallel "head", "tail" and "label" lists, and
            `entities` is the (possibly modified) input list.
        """
        new_relations = []
        for b, doc_relations in enumerate(relations):
            if len(entities[b]["start"]) <= 2:
                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}

            labels = entities[b]["label"]
            candidates = {
                (head_idx, tail_idx)
                for head_idx, head_label in enumerate(labels)
                for tail_idx, tail_label in enumerate(labels)
                if head_label == 1 and tail_label == 2
            }
            if not candidates:
                candidates = {(0, 1)}

            gold_pairs = set(zip(doc_relations["head"], doc_relations["tail"]))
            negatives = candidates - gold_pairs
            # Gold pairs that are not valid candidates are dropped.
            positives = {pair for pair in gold_pairs if pair in candidates}

            ordered_pairs = list(positives) + list(negatives)
            n_positive = len(positives)
            relation_per_doc = {
                "head": [head_idx for head_idx, _ in ordered_pairs],
                "tail": [tail_idx for _, tail_idx in ordered_pairs],
                "label": [1] * n_positive + [0] * (len(ordered_pairs) - n_positive),
            }
            assert len(relation_per_doc["head"]) != 0
            new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        """Collect the candidate pairs whose arg-max class is 1.

        Args:
            logits: Score tensor; `argmax(-1)` is taken and iterated along
                the first dimension, one entry per candidate pair.
            relations: Dict with parallel "head" and "tail" index lists for
                one document.
            entities: Dict with parallel "start", "end" and "label" lists for
                the same document.

        Returns:
            A list of dicts, one per predicted relation, with keys
            "head_id", "head", "head_type", "tail_id", "tail", "tail_type"
            and "type" (always 1).
        """
        pred_relations = []
        predicted_classes = logits.argmax(-1)
        for pair_idx, pred_label in enumerate(predicted_classes):
            if pred_label == 1:
                head_id = relations["head"][pair_idx]
                tail_id = relations["tail"][pair_idx]
                pred_relations.append(
                    {
                        "head_id": head_id,
                        "head": (entities["start"][head_id], entities["end"][head_id]),
                        "head_type": entities["label"][head_id],
                        "tail_id": tail_id,
                        "tail": (entities["start"][tail_id], entities["end"][tail_id]),
                        "tail_type": entities["label"][tail_id],
                        "type": 1,
                    }
                )
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        """Score candidate relations for every document in the batch.

        Args:
            hidden_states: Rank-3 tensor; the first dimension is the batch
                and the second is indexed by entity start positions.
            entities: Per-document entity dicts (see `build_relation`).
            relations: Per-document gold relation dicts (see
                `build_relation`).

        Returns:
            A tuple `(loss, all_pred_relations)`: the sum of the
            per-document cross-entropy losses, and one list of predicted
            relations per document.
        """
        # The three-way unpack also enforces that the input is rank 3.
        batch_size, _, _ = hidden_states.size()
        device = hidden_states.device

        relations, entities = self.build_relation(relations, entities)

        loss = 0
        all_pred_relations = []
        for b in range(batch_size):
            doc_relations = relations[b]
            doc_entities = entities[b]

            head_ids = torch.tensor(doc_relations["head"], device=device)
            tail_ids = torch.tensor(doc_relations["tail"], device=device)
            targets = torch.tensor(doc_relations["label"], device=device)
            start_positions = torch.tensor(doc_entities["start"], device=device)
            entity_labels = torch.tensor(doc_entities["label"], device=device)

            token_states = hidden_states[b]

            # Each side: hidden state at the entity start + label embedding.
            head_label_repr = self.entity_emb(entity_labels[head_ids])
            tail_label_repr = self.entity_emb(entity_labels[tail_ids])
            head_repr = torch.cat(
                (token_states[start_positions[head_ids]], head_label_repr), dim=-1
            )
            tail_repr = torch.cat(
                (token_states[start_positions[tail_ids]], tail_label_repr), dim=-1
            )

            # Head is projected before tail.
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            logits = self.rel_classifier(heads, tails)

            loss += self.loss_fct(logits, targets)
            all_pred_relations.append(
                self.get_predicted_relations(logits, doc_relations, doc_entities)
            )
        return loss, all_pred_relations