[Feature] Build pretrain scheme for self-supervised learning (#235)
* [Feature] Add a pretrain-finetune setting for the self auxiliary task

* [Test] Modifications for SSL tests

* Fix edge_index

* Fix oagbert-v2 test

Co-authored-by: Yukuo Cen <cenyk1230@qq.com>
icycookies and cenyk1230 authored May 30, 2021
1 parent a93247d commit e738c9f
Showing 7 changed files with 173 additions and 121 deletions.
3 changes: 2 additions & 1 deletion cogdl/tasks/unsupervised_node_classification.py
@@ -40,6 +40,7 @@ def __init__(self, args, dataset=None, model=None):
super(UnsupervisedNodeClassification, self).__init__(args)
dataset = build_dataset(args) if dataset is None else dataset

self.dataset = dataset
self.data = dataset[0]

self.num_nodes = self.data.y.shape[0]
@@ -103,7 +104,7 @@ def save_emb(self, embs):

def train(self):
if self.trainer is not None:
return self.trainer.fit(self.model, self.data)
return self.trainer.fit(self.model, self.dataset)
if self.load_emb_path is None:
if "gcc" in self.model_name:
features_matrix = self.model.train(self.data)
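With this change, `UnsupervisedNodeClassification.train` hands the whole `dataset` to the trainer rather than only `dataset[0]`, so a trainer can reach both the graph and the dataset-level helpers. A minimal sketch of the resulting trainer contract; the `ExampleTrainer` class below is illustrative and not part of this commit:

from cogdl.trainers.base_trainer import BaseTrainer


class ExampleTrainer(BaseTrainer):  # hypothetical trainer, for illustration only
    def fit(self, model, dataset):
        data = dataset.data                  # the graph, formerly passed in directly
        loss_fn = dataset.get_loss_fn()      # dataset-level loss helper
        evaluator = dataset.get_evaluator()  # dataset-level evaluation helper
        # ... train `model` on `data`, then score its embeddings with `evaluator` ...
        return dict(Acc=0.0)                 # placeholder metric dict, mirroring the trainer below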
1 change: 1 addition & 0 deletions cogdl/trainers/__init__.py
@@ -54,6 +54,7 @@ def build_trainer(args):
"neighborsampler": "cogdl.trainers.sampled_trainer",
"clustergcn": "cogdl.trainers.sampled_trainer",
"random_cluster": "cogdl.trainers.sampled_trainer",
"self_supervised": "cogdl.trainers.self_supervised_trainer",
"self_auxiliary_task_pretrain": "cogdl.trainers.self_auxiliary_task_trainer",
"self_auxiliary_task_joint": "cogdl.trainers.self_auxiliary_task_trainer",
"m3s": "cogdl.trainers.m3s_trainer",
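The new `"self_supervised"` entry lets `build_trainer` resolve the trainer added below by name. A hedged usage sketch; the exact argument set is an assumption, and `build_args_from_dict` is the same helper the tests later in this diff rely on:

from cogdl.trainers import build_trainer
from cogdl.utils import build_args_from_dict  # assumed location of the helper used in the tests

# Assumed minimal arguments; a real run also needs the usual task/model/dataset settings.
args = build_args_from_dict({
    "trainer": "self_supervised",  # key registered in the mapping above
    "sampling": True,              # flag introduced in this commit
    "sample_size": 20000,          # flag introduced in this commit
    "lr": 0.001,
    "save_dir": ".",
    "load_emb_path": None,
})
trainer = build_trainer(args)  # resolves "self_supervised" lazily via the module path above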
176 changes: 95 additions & 81 deletions cogdl/trainers/self_auxiliary_task_trainer.py

Large diffs are not rendered by default.
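The auxiliary-task trainer diff itself is collapsed, but the registry entries above point to its two modes: `self_auxiliary_task_joint` optimizes the auxiliary and downstream objectives together, while `self_auxiliary_task_pretrain` implements the pretrain-finetune scheme this commit introduces, training on the auxiliary objective first and fine-tuning on the downstream one afterwards. A conceptual sketch of the difference; the helper names here are assumptions, not the code in that file:

import torch

def auxiliary_task_loss(model, data):
    # Placeholder for the self-supervised objective (e.g., edge masking); assumption only.
    raise NotImplementedError

def joint_training(model, data, optimizer, epochs):
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = model.node_classification_loss(data) + auxiliary_task_loss(model, data)
        loss.backward()
        optimizer.step()

def pretrain_then_finetune(model, data, optimizer, pretrain_epochs, finetune_epochs):
    for _ in range(pretrain_epochs):    # stage 1: self-supervised pretraining
        optimizer.zero_grad()
        auxiliary_task_loss(model, data).backward()
        optimizer.step()
    for _ in range(finetune_epochs):    # stage 2: supervised fine-tuning
        optimizer.zero_grad()
        model.node_classification_loss(data).backward()
        optimizer.step()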

74 changes: 49 additions & 25 deletions cogdl/trainers/self_supervised_trainer.py
@@ -1,12 +1,15 @@
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from .base_trainer import BaseTrainer
from . import register_trainer


@register_trainer("self_supervised")
class SelfSupervisedTrainer(BaseTrainer):
def __init__(self, args):
super(SelfSupervisedTrainer, self).__init__()
@@ -18,21 +21,31 @@ def __init__(self, args):
self.save_dir = args.save_dir
self.load_emb_path = args.load_emb_path
self.lr = args.lr
self.sampling = args.sampling
self.sample_size = args.sample_size

@staticmethod
def add_args(parser: argparse.ArgumentParser):
"""Add trainer-specific arguments to the parser."""
# fmt: off
parser.add_argument("--sampling", action="store_true")
parser.add_argument("--sample-size", type=int, default=20000)
# fmt: on

@classmethod
def build_trainer_from_args(cls, args):
return cls(args)

def fit(self, model, data):
def fit(self, model, dataset):
data = dataset.data
data.add_remaining_self_loops()
self.data = data

if self.load_emb_path is not None:
embeds = np.load(self.load_emb_path)
embeds = torch.from_numpy(embeds).to(self.device)
return self.evaluate(embeds)

self.data.to(self.device)

best = 1e9
cnt_wait = 0
optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.0)
@@ -42,39 +55,47 @@ def fit(self, model, data):

model.train()
for epoch in epoch_iter:
optimizer.zero_grad()
with self.data.local_graph():
if self.sampling:
idx = np.random.choice(np.arange(self.data.num_nodes), self.sample_size, replace=False)
self.data = data.subgraph(idx)

loss = model.node_classification_loss(data)
epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item(): .4f}")
self.data.to(self.device)
optimizer.zero_grad()

if loss < best:
best = loss
cnt_wait = 0
else:
cnt_wait += 1
loss = model.node_classification_loss(self.data)
epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item(): .4f}")

if cnt_wait == self.patience:
print("Early stopping!")
break
if loss < best:
best = loss
cnt_wait = 0
else:
cnt_wait += 1

loss.backward()
optimizer.step()
if cnt_wait == self.patience:
print("Early stopping!")
break

loss.backward()
optimizer.step()

self.data = data
self.data.to(self.device)
with torch.no_grad():
embeds = model.embed(data)
embeds = model.embed(self.data)
self.save_embed(embeds)

return self.evaluate(embeds)
return self.evaluate(embeds, dataset.get_loss_fn(), dataset.get_evaluator())

def evaluate(self, embeds):
def evaluate(self, embeds, loss_fn=None, evaluator=None):
nclass = int(torch.max(self.data.y) + 1)
opt = {
"idx_train": self.data.train_mask.to(self.device),
"idx_val": self.data.val_mask.to(self.device),
"idx_test": self.data.test_mask.to(self.device),
"num_classes": nclass,
}
result = LogRegTrainer().train(embeds, self.data.y.to(self.device), opt)
result = LogRegTrainer().train(embeds, self.data.y.to(self.device), opt, loss_fn, evaluator)
print(f"TestAcc: {result: .4f}")
return dict(Acc=result)

@@ -105,7 +126,7 @@ def forward(self, seq):


class LogRegTrainer(object):
def train(self, data, labels, opt):
def train(self, data, labels, opt, loss_fn=None, evaluator=None):
device = data.device
idx_train = opt["idx_train"].to(device)
idx_test = opt["idx_test"].to(device)
@@ -120,7 +141,7 @@ def train(self, data, labels, opt):
test_lbls = labels[idx_test]
tot = 0

xent = nn.CrossEntropyLoss()
xent = nn.CrossEntropyLoss() if loss_fn is None else loss_fn

for _ in range(50):
log = LogReg(nhid, nclass).to(device)
@@ -138,7 +159,10 @@ def train(self, data, labels, opt):
optimizer.step()

logits = log(test_embs)
preds = torch.argmax(logits, dim=1)
acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
tot += acc.item()
if evaluator is None:
preds = torch.argmax(logits, dim=1)
acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
else:
acc = evaluator(logits, test_lbls)
tot += acc
return tot / 50
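The new `loss_fn`/`evaluator` hooks let a dataset supply its own criterion and metric instead of the hard-coded cross-entropy and argmax accuracy. Whatever is passed as `evaluator` must accept `(logits, labels)` and return a plain number, since the result is summed into `tot` and averaged over the 50 runs. A minimal sketch of a compatible evaluator; the function name is illustrative:

import torch

def example_accuracy_evaluator(logits: torch.Tensor, labels: torch.Tensor) -> float:
    # Matches the `evaluator(logits, test_lbls)` call site above: raw logits in,
    # integer class labels in, a Python float out.
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()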
4 changes: 4 additions & 0 deletions tests/models/ssl/test_contrastive_models.py
@@ -19,6 +19,8 @@ def get_default_args():
"training_percents": [0.1],
"activation": "relu",
"residual": False,
"sampling": False,
"sample_size": 20,
"norm": None,
}
return build_args_from_dict(default_dict)
@@ -43,6 +45,8 @@ def get_unsupervised_nn_args():
"task": "unsupervised_node_classification",
"checkpoint": False,
"load_emb_path": None,
"sampling": False,
"sample_size": 20,
"training_percents": [0.1],
}
return build_args_from_dict(default_dict)
4 changes: 3 additions & 1 deletion tests/models/ssl/test_generative_models.py
@@ -31,6 +31,8 @@ def get_default_args():
"checkpoint": False,
"num_layers": 2,
"activation": "relu",
"dropedge_rate": 0,
"agc_eval": False,
"residual": False,
"norm": None,
}
@@ -188,7 +190,7 @@ def test_m3s():


if __name__ == "__main__":
test_supergat()
# test_supergat()
test_m3s()
test_edgemask()
test_edgemask_pt_ft()
32 changes: 19 additions & 13 deletions tests/tasks/test_encode_paper.py
@@ -2,24 +2,30 @@


def test_encode_paper():
tokenizer, model = oagbert("oagbert-v2")
title = 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding'
abstract = 'We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation...'
authors = ['Jacob Devlin', 'Ming-Wei Chang', 'Kenton Lee', 'Kristina Toutanova']
venue = 'north american chapter of the association for computational linguistics'
affiliations = ['Google']
concepts = ['language model', 'natural language inference', 'question answering']
tokenizer, model = oagbert("oagbert-v2-test")
title = "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
abstract = "We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation..."
authors = ["Jacob Devlin", "Ming-Wei Chang", "Kenton Lee", "Kristina Toutanova"]
venue = "north american chapter of the association for computational linguistics"
affiliations = ["Google"]
concepts = ["language model", "natural language inference", "question answering"]
# encode paper
paper_info = model.encode_paper(
title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, reduction="max"
title=title,
abstract=abstract,
venue=venue,
authors=authors,
concepts=concepts,
affiliations=affiliations,
reduction="max",
)

assert len(paper_info) == 5
assert paper_info['text'][0]['type'] == 'TEXT'
assert len(paper_info['authors']) == 4
assert len(paper_info['venue'][0]['token_ids']) == 9
assert tuple(paper_info['text'][0]['sequence_output'].shape) == (43, 768)
assert len(paper_info['text'][0]['pooled_output']) == 768
assert paper_info["text"][0]["type"] == "TEXT"
assert len(paper_info["authors"]) == 4
assert len(paper_info["venue"][0]["token_ids"]) == 9
assert tuple(paper_info["text"][0]["sequence_output"].shape) == (43, 768)
assert len(paper_info["text"][0]["pooled_output"]) == 768


if __name__ == "__main__":
