diff --git a/flair/models/dependency_parser_model.py b/flair/models/dependency_parser_model.py
index ae31619726..e89dad9f13 100644
--- a/flair/models/dependency_parser_model.py
+++ b/flair/models/dependency_parser_model.py
@@ -1,38 +1,36 @@
 from pathlib import Path
-from typing import List, Optional, Union, Dict, Tuple
+from typing import List, Optional, Union, Tuple
+
+import sklearn
 import torch
 import torch.nn
-
-from torch.utils.data import Dataset
 from torch.nn.modules.rnn import apply_permutation
 from torch.nn.utils.rnn import PackedSequence
-import sklearn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+from torch.utils.data import Dataset
 
 import flair.nn
 from flair.data import Dictionary, Sentence, Token, Label, DataPoint
 from flair.datasets import DataLoader, SentenceDataset
 from flair.embeddings import TokenEmbeddings
-from flair.training_utils import Result, store_embeddings
 from flair.nn.dropout import LockedDropout, WordDropout
+from flair.training_utils import Result, store_embeddings
 from flair.visual.tree_printer import tree_printer
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
-
 
 class DependencyParser(flair.nn.Model):
-    def __init__(
-        self,
-        token_embeddings: TokenEmbeddings,
-        relations_dictionary: Dictionary,
-        tag_type: str = 'dependency',
-        lstm_hidden_size: int = 400,
-        mlp_arc_units: int = 500,
-        mlp_rel_units: int = 100,
-        lstm_layers: int = 3,
-        mlp_dropout: float = 0.33,
-        lstm_dropout: float = 0.33,
-        word_dropout: float = 0.05,
-    ):
+    def __init__(self,
+                 token_embeddings: TokenEmbeddings,
+                 relations_dictionary: Dictionary,
+                 tag_type: str = 'dependency',
+                 lstm_hidden_size: int = 400,
+                 mlp_arc_units: int = 500,
+                 mlp_rel_units: int = 100,
+                 lstm_layers: int = 3,
+                 mlp_dropout: float = 0.33,
+                 lstm_dropout: float = 0.33,
+                 word_dropout: float = 0.05,
+                 ):
         """
         Initializes a DependencyParser
        The model is based on biaffine dependency parser :cite:
        "Dozat T. & Manning C. Deep biaffine attention for neural dependency parsing."
@@ -48,8 +46,9 @@ def __init__(
         :param lstm_dropout: dropout probability in LSTM
         :param word_dropout: word dropout probability
         """
-
+
         super(DependencyParser, self).__init__()
+
         self.token_embeddings = token_embeddings
         self.relations_dictionary: Dictionary = relations_dictionary
         self.lstm_hidden_size = lstm_hidden_size
@@ -58,27 +57,29 @@ def __init__(
         self.lstm_layers = lstm_layers
         self.lstm_dropout = lstm_dropout
         self.mlp_dropout = mlp_dropout
+
         self.use_word_dropout: bool = word_dropout > 0
         if self.use_word_dropout:
             self.word_dropout = WordDropout(dropout_rate=word_dropout)
+
         self.tag_type = tag_type
         self.lstm_input_dim: int = self.token_embeddings.embedding_length
-
+
         self.lstm = BiLSTM(input_size=self.lstm_input_dim,
                            hidden_size=self.lstm_hidden_size,
                            num_layers=self.lstm_layers,
                            dropout=self.lstm_dropout)
 
-        self.mlp_arc_h = MLP(n_in=self.lstm_hidden_size*2,
+        self.mlp_arc_h = MLP(n_in=self.lstm_hidden_size * 2,
                              n_hidden=self.mlp_arc_units,
                              dropout=self.mlp_dropout)
-        self.mlp_arc_d = MLP(n_in=self.lstm_hidden_size*2,
+        self.mlp_arc_d = MLP(n_in=self.lstm_hidden_size * 2,
                              n_hidden=self.mlp_arc_units,
                              dropout=self.mlp_dropout)
 
-        self.mlp_rel_h = MLP(n_in=self.lstm_hidden_size*2,
+        self.mlp_rel_h = MLP(n_in=self.lstm_hidden_size * 2,
                              n_hidden=self.mlp_rel_units,
                              dropout=self.mlp_dropout)
-        self.mlp_rel_d = MLP(n_in=self.lstm_hidden_size*2,
+        self.mlp_rel_d = MLP(n_in=self.lstm_hidden_size * 2,
                              n_hidden=self.mlp_rel_units,
                              dropout=self.mlp_dropout)
@@ -115,12 +116,12 @@ def forward(self, sentences: List[Sentence]):
                 all_embs.append(t)
 
         sentence_tensor = torch.cat(all_embs).view([batch_size,
                                                     seq_len,
-                                                    self.token_embeddings.embedding_length,])
-
+                                                    self.token_embeddings.embedding_length, ])
+
         # Main model implementation drops words and tags (independently), instead, we use word dropout!
         if self.use_word_dropout:
             sentence_tensor = self.word_dropout(sentence_tensor)
-
+
         x = pack_padded_sequence(sentence_tensor, lengths, True, False)
         x, _ = self.lstm(x)
@@ -139,19 +140,17 @@ def forward(self, sentences: List[Sentence]):
         score_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
 
         return score_arc, score_rel
-
-    def forward_loss(self,
-                     data_points: List[Sentence]) -> torch.tensor:
-
+    def forward_loss(self, data_points: List[Sentence]) -> torch.tensor:
+
         score_arc, score_rel = self.forward(data_points)
         loss_arc, loss_rel = self._calculate_loss(score_arc, score_rel, data_points)
         main_loss = loss_arc + loss_rel
 
         return main_loss
-
-    def _calculate_loss(self, score_arc: torch.tensor,
+    def _calculate_loss(self,
+                        score_arc: torch.tensor,
                         score_relation: torch.tensor,
                         data_points: List[Sentence]) -> Tuple[float, float]:
@@ -159,11 +158,11 @@ def _calculate_loss(self, score_arc: torch.tensor,
 
         arc_loss = 0.0
         rel_loss = 0.0
-
+
         for sen_id, sen in enumerate(data_points):
             sen_len = lengths[sen_id]
-
-            arc_labels = [token.head_id - 1 if token.head_id != 0 else token.idx - 1
+
+            arc_labels = [token.head_id - 1 if token.head_id != 0 else token.idx - 1
                           for token in sen.tokens]
             arc_labels = torch.tensor(arc_labels, dtype=torch.int64, device=flair.device)
             arc_loss += torch.nn.functional.cross_entropy(
@@ -171,13 +170,13 @@ def _calculate_loss(self, score_arc: torch.tensor,
 
             rel_labels = [self.relations_dictionary.get_idx_for_item(token.get_tag(self.tag_type).value)
                           for token in sen.tokens]
-
+
             rel_labels = torch.tensor(rel_labels, dtype=torch.int64, device=flair.device)
             score_rel = score_relation[sen_id][torch.arange(len(arc_labels)), arc_labels]
             rel_loss += torch.nn.functional.cross_entropy(score_rel, rel_labels)
 
         return arc_loss, rel_loss
-
+
     def predict(self,
                 sentences: Union[List[Sentence], Sentence],
                 mini_batch_size: int = 32,
@@ -205,19 +204,19 @@ def predict(self,
                 score_arc, score_rel = self.forward(batch)
                 arc_prediction, relation_prediction = self._obtain_labels_(score_arc, score_rel)
 
-                for sentnce_index, (sentence, sent_tags, sent_arcs) in enumerate(zip(batch, relation_prediction, arc_prediction)):
+                for sentnce_index, (sentence, sent_tags, sent_arcs) in enumerate(
+                        zip(batch, relation_prediction, arc_prediction)):
+
                     for token_index, (token, tag, head_id) in enumerate(zip(sentence.tokens, sent_tags, sent_arcs)):
-                        token.add_tag(self.tag_type,
-                                      tag,
-                                      score_rel[sentnce_index][token_index])
-
+                        token.add_tag(self.tag_type, tag, score_rel[sentnce_index][token_index])
+
                         token.head_id = int(head_id)
 
                 if print_tree:
                     tree_printer(sentence, self.tag_type)
                     print("-" * 50)
                 store_embeddings(batch, storage_mode=embedding_storage_mode)
-
+
     def evaluate(
         self,
         data_points: Union[List[DataPoint], Dataset],
@@ -228,13 +227,13 @@ def evaluate(
         num_workers: int = 8,
         main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
         gold_label_dictionary: Optional[Dictionary] = None,
-    ) -> Result:
-
+        **kwargs,
+    ) -> Result:
+
         if not isinstance(data_points, Dataset):
             data_points = SentenceDataset(data_points)
-        data_loader = DataLoader(data_points,
-                                 batch_size=mini_batch_size,
-                                 num_workers=num_workers)
+
+        data_loader = DataLoader(data_points, batch_size=mini_batch_size, num_workers=num_workers)
 
         lines: List[str] = ["token gold_tag gold_arc predicted_tag predicted_arc\n"]
@@ -253,9 +252,9 @@ def evaluate(
                 score_arc, score_rel = self.forward(batch)
                 loss_arc, loss_rel = self._calculate_loss(score_arc, score_rel, batch)
                 arc_prediction, relation_prediction = self._obtain_labels_(score_arc, score_rel)
-
+
                 parsing_metric(arc_prediction, relation_prediction, batch, gold_label_type)
-
+
                 eval_loss_arc += loss_arc
                 eval_loss_rel += loss_rel
@@ -263,8 +262,7 @@
                 for (token, arc, tag) in zip(sentence.tokens, arcs, sent_tags):
                     token: Token = token
                     token.add_tag_label("predicted", Label(tag))
-                    token.add_tag_label("predicted_head_id",
-                                        Label(str(int(arc))))
+                    token.add_tag_label("predicted_head_id", Label(str(int(arc))))
 
                     # append both to file for evaluation
                     eval_line = "{} {} {} {} {}\n".format(token.text,
@@ -276,14 +274,11 @@ def evaluate(
                 lines.append("\n")
 
             for sentence in batch:
-
                 gold_tags = [token.get_tag(gold_label_type).value for token in sentence.tokens]
                 predicted_tags = [tag.tag for tag in sentence.get_spans("predicted")]
 
-                y_pred += [self.relations_dictionary.get_idx_for_item(tag)
-                           for tag in predicted_tags]
-                y_true += [self.relations_dictionary.get_idx_for_item(tag)
-                           for tag in gold_tags]
+                y_pred += [self.relations_dictionary.get_idx_for_item(tag) for tag in predicted_tags]
+                y_true += [self.relations_dictionary.get_idx_for_item(tag) for tag in gold_tags]
 
             store_embeddings(batch, embedding_storage_mode)
@@ -325,26 +320,24 @@ def evaluate(
             log_header=log_header,
             detailed_results=detailed_result,
             classification_report=classification_report_dict,
-            loss=eval_loss_rel+eval_loss_arc
+            loss=eval_loss_rel + eval_loss_arc
         )
 
         return result
 
-    def _obtain_labels_(self,
-                        score_arc: torch.tensor,
-                        score_rel: torch.tensor) -> Tuple[List[List[int]],
-                                                          List[List[str]]]:
-
+    def _obtain_labels_(self, score_arc: torch.tensor, score_rel: torch.tensor) -> Tuple[List[List[int]],
+                                                                                          List[List[str]]]:
+
         arc_prediction: torch.tensor = score_arc.argmax(-1)
         relation_prediction: torch.tensor = score_rel.argmax(-1)
         relation_prediction = relation_prediction.gather(-1, arc_prediction.unsqueeze(-1)).squeeze(-1)
 
-        arc_prediction = [[arc+1 if token_index != arc else 0 for token_index, arc in enumerate(batch)]
+        arc_prediction = [[arc + 1 if token_index != arc else 0 for token_index, arc in enumerate(batch)]
                           for batch in arc_prediction]
         relation_prediction = [[self.relations_dictionary.get_item_for_index(rel_tag_idx) for rel_tag_idx in batch]
                                for batch in relation_prediction]
 
         return arc_prediction, relation_prediction
-
+
     def _get_state_dict(self):
         model_state = {
             "state_dict": self.state_dict(),
@@ -362,16 +355,15 @@
 
     @staticmethod
     def _init_model_with_state_dict(state):
-        model = DependencyParser(
-            token_embeddings=state["token_embeddings"],
-            relations_dictionary=state["relations_dictionary"],
-            lstm_hidden_size=state["lstm_hidden_size"],
-            mlp_arc_units=state["mlp_arc_units"],
-            mlp_rel_units=state["mlp_rel_units"],
-            lstm_layers=state["lstm_layers"],
-            mlp_dropout=state["mlp_dropout"],
-            lstm_dropout=state["lstm_dropout"],
-        )
+        model = DependencyParser(token_embeddings=state["token_embeddings"],
+                                 relations_dictionary=state["relations_dictionary"],
+                                 lstm_hidden_size=state["lstm_hidden_size"],
+                                 mlp_arc_units=state["mlp_arc_units"],
+                                 mlp_rel_units=state["mlp_rel_units"],
+                                 lstm_layers=state["lstm_layers"],
+                                 mlp_dropout=state["mlp_dropout"],
+                                 lstm_dropout=state["lstm_dropout"],
+                                 )
         model.load_state_dict(state["state_dict"])
 
         return model
@@ -383,12 +375,12 @@ def label_type(self):
 
 class BiLSTM(torch.nn.Module):
     def __init__(
-        self,
-        input_size: int,
-        hidden_size: int,
-        num_layers: int= 1,
-        dropout: float= 0.0
-    ):
+            self,
+            input_size: int,
+            hidden_size: int,
+            num_layers: int = 1,
+            dropout: float = 0.0
+    ):
         """
         Initializes a VariationalBiLSTM
@@ -418,9 +410,9 @@ def __init__(
 
     def __repr__(self):
         st = "input:{} , hidden_size:{}, num_of_layers:{}, dropout_rate:{}".format(self.input_size,
-                                                                                   self.hidden_size,
-                                                                                   self.num_layers,
-                                                                                   self.dropout)
+                                                                                    self.hidden_size,
+                                                                                    self.num_layers,
+                                                                                    self.dropout)
         return f"{self.__class__.__name__}({st})"
 
     def reset_parameters(self):
@@ -517,11 +509,11 @@ def forward(self, sequence, hx=None):
 
 class Biaffine(torch.nn.Module):
     def __init__(
-        self,
-        n_in,
-        n_out=1,
-        bias_x=True,
-        bias_y=True):
+            self,
+            n_in,
+            n_out=1,
+            bias_x=True,
+            bias_y=True):
         """
         :param n_in: size of input
         :param n_out: number of channels
@@ -542,9 +534,9 @@ def __init__(
 
     def extra_repr(self):
         st = "n_in:{}, n_out:{}, bias_x:{}, bias_x:{}".format(self.n_in,
-                                                              self.n_out,
-                                                              self.bias_x,
-                                                              self.bias_y)
+                                                               self.n_out,
+                                                               self.bias_x,
+                                                               self.bias_y)
         return st
 
     def reset_parameters(self):
@@ -563,12 +555,11 @@ def forward(self, x, y):
 
 class MLP(torch.nn.Module):
-    def __init__(
-        self,
-        n_in,
-        n_hidden,
-        dropout=0
-    ):
+    def __init__(self,
+                 n_in: int,
+                 n_hidden: int,
+                 dropout: float = 0.0,
+                 ):
         super(MLP, self).__init__()
 
         self.linear = torch.nn.Linear(n_in, n_hidden)
@@ -592,7 +583,7 @@ def forward(self, x):
 
 class ParsingMetric:
     def __init__(self, epsilon=1e-8):
-
+
         self.eps = epsilon
         self.total = 0.0
         self.correct_arcs = 0.0
@@ -607,10 +598,13 @@ def __call__(self,
         for batch_indx, batch in enumerate(sentences):
             self.total += len(batch.tokens)
             for token_indx, token in enumerate(batch.tokens):
+
                 if arc_prediction[batch_indx][token_indx] == token.head_id:
                     self.correct_arcs += 1
-                if relation_prediction[batch_indx][token_indx] == token.get_tag(tag_type).value:
-                    self.correct_rels += 1
+
+                    # if head AND deprel correct, augment correct_rels score
+                    if relation_prediction[batch_indx][token_indx] == token.get_tag(tag_type).value:
+                        self.correct_rels += 1
 
     def get_las(self) -> float:
         return self.correct_rels / (self.total + self.eps)
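
Note on the last hunk: ParsingMetric now increments correct_rels only when the predicted head is also correct, so get_las() reports a proper labeled attachment score (LAS) rather than a bare label-accuracy. The snippet below is a minimal, self-contained sketch of that counting rule in plain Python; the helper name attachment_scores and the example tokens are hypothetical and not part of the flair API:

# Sketch of the UAS/LAS counting rule that ParsingMetric uses after this change.
# attachment_scores() and the example data below are illustrative only.

def attachment_scores(pred_heads, pred_rels, gold_heads, gold_rels, eps=1e-8):
    """Return (uas, las) for one sentence given per-token head indices and relation labels."""
    total = len(gold_heads)
    correct_arcs = 0
    correct_rels = 0
    for p_head, p_rel, g_head, g_rel in zip(pred_heads, pred_rels, gold_heads, gold_rels):
        if p_head == g_head:
            correct_arcs += 1
            # a relation only counts if the head is also correct (labeled attachment)
            if p_rel == g_rel:
                correct_rels += 1
    return correct_arcs / (total + eps), correct_rels / (total + eps)


# "She reads books": heads are 1-based indices of each token's head (0 = root).
gold_heads, gold_rels = [2, 0, 2], ["nsubj", "root", "obj"]
pred_heads, pred_rels = [2, 0, 1], ["nsubj", "root", "obj"]  # wrong head for "books"

uas, las = attachment_scores(pred_heads, pred_rels, gold_heads, gold_rels)
print(f"UAS={uas:.2f} LAS={las:.2f}")
# UAS=0.67 LAS=0.67; with the old counting, the correct "obj" label would still
# have been counted despite the wrong head, yielding LAS=1.00 > UAS.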