[Feature] Build pretrain scheme for self-supervised learning #235

Merged · 8 commits · May 30, 2021
3 changes: 2 additions & 1 deletion cogdl/tasks/unsupervised_node_classification.py
@@ -40,6 +40,7 @@ def __init__(self, args, dataset=None, model=None):
super(UnsupervisedNodeClassification, self).__init__(args)
dataset = build_dataset(args) if dataset is None else dataset

self.dataset = dataset
self.data = dataset[0]

self.num_nodes = self.data.y.shape[0]
@@ -103,7 +104,7 @@ def save_emb(self, embs):

def train(self):
if self.trainer is not None:
return self.trainer.fit(self.model, self.data)
return self.trainer.fit(self.model, self.dataset)
if self.load_emb_path is None:
if "gcc" in self.model_name:
features_matrix = self.model.train(self.data)
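Note: UnsupervisedNodeClassification now keeps the dataset wrapper and hands it, rather than the bare graph, to the trainer. A minimal sketch of what a trainer can then pull out of it (the attribute and hook names follow the self_supervised_trainer diff below; this is not the full CogDL implementation):

def fit(self, model, dataset):
    data = dataset.data                  # underlying graph, as used below
    loss_fn = dataset.get_loss_fn()      # task-specific loss hook
    evaluator = dataset.get_evaluator()  # task-specific metric hook
    ...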
1 change: 1 addition & 0 deletions cogdl/trainers/__init__.py
@@ -54,6 +54,7 @@ def build_trainer(args):
"neighborsampler": "cogdl.trainers.sampled_trainer",
"clustergcn": "cogdl.trainers.sampled_trainer",
"random_cluster": "cogdl.trainers.sampled_trainer",
"self_supervised": "cogdl.trainers.self_supervised_trainer",
"self_auxiliary_task_pretrain": "cogdl.trainers.self_auxiliary_task_trainer",
"self_auxiliary_task_joint": "cogdl.trainers.self_auxiliary_task_trainer",
"m3s": "cogdl.trainers.m3s_trainer",
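Note: registering "self_supervised" here lets build_trainer resolve --trainer self_supervised lazily. A simplified sketch of the dispatch pattern, with SUPPORTED_TRAINERS and TRAINER_REGISTRY as stand-in names (the actual identifiers in cogdl/trainers/__init__.py may differ):

import importlib

def build_trainer(args):
    # import the module that defines the trainer; its
    # @register_trainer("self_supervised") decorator adds the class
    # to the registry under the same key
    importlib.import_module(SUPPORTED_TRAINERS[args.trainer])
    return TRAINER_REGISTRY[args.trainer].build_trainer_from_args(args)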
176 changes: 95 additions & 81 deletions cogdl/trainers/self_auxiliary_task_trainer.py

Large diffs are not rendered by default.

74 changes: 49 additions & 25 deletions cogdl/trainers/self_supervised_trainer.py
@@ -1,12 +1,15 @@
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from .base_trainer import BaseTrainer
from . import register_trainer


@register_trainer("self_supervised")
class SelfSupervisedTrainer(BaseTrainer):
def __init__(self, args):
super(SelfSupervisedTrainer, self).__init__()
@@ -18,21 +21,31 @@ def __init__(self, args):
self.save_dir = args.save_dir
self.load_emb_path = args.load_emb_path
self.lr = args.lr
self.sampling = args.sampling
self.sample_size = args.sample_size

@staticmethod
def add_args(parser: argparse.ArgumentParser):
"""Add trainer-specific arguments to the parser."""
# fmt: off
parser.add_argument("--sampling", action="store_true")
parser.add_argument("--sample-size", type=int, default=20000)
# fmt: on

@classmethod
def build_trainer_from_args(cls, args):
return cls(args)

def fit(self, model, data):
def fit(self, model, dataset):
data = dataset.data
data.add_remaining_self_loops()
self.data = data

if self.load_emb_path is not None:
embeds = np.load(self.load_emb_path)
embeds = torch.from_numpy(embeds).to(self.device)
return self.evaluate(embeds)

self.data.to(self.device)

best = 1e9
cnt_wait = 0
optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.0)
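Note: the two new arguments turn on per-epoch subgraph sampling for graphs too large to train on whole. A hedged usage sketch in the style of the tests below (build_args_from_dict is assumed to be the same helper the tests import; every other required argument is elided):

from cogdl.utils import build_args_from_dict
from cogdl.trainers.self_supervised_trainer import SelfSupervisedTrainer

args = build_args_from_dict({
    "sampling": True,      # sample a random subgraph each epoch
    "sample_size": 20000,  # number of nodes kept per sample
    # ... lr, save_dir, load_emb_path and the remaining args omitted ...
})
trainer = SelfSupervisedTrainer.build_trainer_from_args(args)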
@@ -42,39 +55,47 @@ def fit(self, model, data):

model.train()
for epoch in epoch_iter:
optimizer.zero_grad()
with self.data.local_graph():
if self.sampling:
idx = np.random.choice(np.arange(self.data.num_nodes), self.sample_size, replace=False)
self.data = data.subgraph(idx)

loss = model.node_classification_loss(data)
epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item(): .4f}")
self.data.to(self.device)
optimizer.zero_grad()

if loss < best:
best = loss
cnt_wait = 0
else:
cnt_wait += 1
loss = model.node_classification_loss(self.data)
epoch_iter.set_description(f"Epoch: {epoch:03d}, Loss: {loss.item(): .4f}")

if cnt_wait == self.patience:
print("Early stopping!")
break
if loss < best:
best = loss
cnt_wait = 0
else:
cnt_wait += 1

loss.backward()
optimizer.step()
if cnt_wait == self.patience:
print("Early stopping!")
break

loss.backward()
optimizer.step()

self.data = data
self.data.to(self.device)
with torch.no_grad():
embeds = model.embed(data)
embeds = model.embed(self.data)
self.save_embed(embeds)

return self.evaluate(embeds)
return self.evaluate(embeds, dataset.get_loss_fn(), dataset.get_evaluator())

def evaluate(self, embeds):
def evaluate(self, embeds, loss_fn=None, evaluator=None):
nclass = int(torch.max(self.data.y) + 1)
opt = {
"idx_train": self.data.train_mask.to(self.device),
"idx_val": self.data.val_mask.to(self.device),
"idx_test": self.data.test_mask.to(self.device),
"num_classes": nclass,
}
result = LogRegTrainer().train(embeds, self.data.y.to(self.device), opt)
result = LogRegTrainer().train(embeds, self.data.y.to(self.device), opt, loss_fn, evaluator)
print(f"TestAcc: {result: .4f}")
return dict(Acc=result)
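Note: when sampling is enabled, each epoch trains on a freshly sampled induced subgraph, and data.local_graph() keeps that temporary subgraph from leaking into later epochs. The sampling step in isolation (a simplified sketch; Graph.subgraph is assumed to accept an array of node indices, as the diff uses it):

import numpy as np

# choose `sample_size` distinct nodes uniformly at random ...
idx = np.random.choice(np.arange(data.num_nodes), sample_size, replace=False)
# ... and compute this epoch's loss on the induced subgraph only
loss = model.node_classification_loss(data.subgraph(idx))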

@@ -105,7 +126,7 @@ def forward(self, seq):


class LogRegTrainer(object):
def train(self, data, labels, opt):
def train(self, data, labels, opt, loss_fn=None, evaluator=None):
device = data.device
idx_train = opt["idx_train"].to(device)
idx_test = opt["idx_test"].to(device)
@@ -120,7 +141,7 @@ def train(self, data, labels, opt):
test_lbls = labels[idx_test]
tot = 0

xent = nn.CrossEntropyLoss()
xent = nn.CrossEntropyLoss() if loss_fn is None else loss_fn

for _ in range(50):
log = LogReg(nhid, nclass).to(device)
@@ -138,7 +159,10 @@ def train(self, data, labels, opt):
optimizer.step()

logits = log(test_embs)
preds = torch.argmax(logits, dim=1)
acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
tot += acc.item()
if evaluator is None:
preds = torch.argmax(logits, dim=1)
acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
else:
acc = evaluator(logits, test_lbls)
tot += acc
return tot / 50
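Note: LogRegTrainer now takes the dataset-supplied loss and evaluator, falling back to cross-entropy and plain accuracy when they are None. The callables are defined elsewhere in CogDL; as used here they need roughly this shape (a hedged sketch, not the actual implementations):

import torch
from torch import nn

def default_loss(logits, labels):
    # drop-in replacement for the nn.CrossEntropyLoss() fallback
    return nn.functional.cross_entropy(logits, labels)

def default_evaluator(logits, labels):
    # must return a plain number: the caller accumulates it into `tot`
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).float().mean().item()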
4 changes: 4 additions & 0 deletions tests/models/ssl/test_contrastive_models.py
@@ -19,6 +19,8 @@ def get_default_args():
"training_percents": [0.1],
"activation": "relu",
"residual": False,
"sampling": False,
"sample_size": 20,
"norm": None,
}
return build_args_from_dict(default_dict)
@@ -43,6 +45,8 @@ def get_unsupervised_nn_args():
"task": "unsupervised_node_classification",
"checkpoint": False,
"load_emb_path": None,
"sampling": False,
"sample_size": 20,
"training_percents": [0.1],
}
return build_args_from_dict(default_dict)
4 changes: 3 additions & 1 deletion tests/models/ssl/test_generative_models.py
@@ -31,6 +31,8 @@ def get_default_args():
"checkpoint": False,
"num_layers": 2,
"activation": "relu",
"dropedge_rate": 0,
"agc_eval": False,
"residual": False,
"norm": None,
}
@@ -188,7 +190,7 @@ def test_m3s():


if __name__ == "__main__":
test_supergat()
# test_supergat()
test_m3s()
test_edgemask()
test_edgemask_pt_ft()
32 changes: 19 additions & 13 deletions tests/tasks/test_encode_paper.py
@@ -2,24 +2,30 @@


def test_encode_paper():
tokenizer, model = oagbert("oagbert-v2")
title = 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding'
abstract = 'We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation...'
authors = ['Jacob Devlin', 'Ming-Wei Chang', 'Kenton Lee', 'Kristina Toutanova']
venue = 'north american chapter of the association for computational linguistics'
affiliations = ['Google']
concepts = ['language model', 'natural language inference', 'question answering']
tokenizer, model = oagbert("oagbert-v2-test")
title = "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
abstract = "We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation..."
authors = ["Jacob Devlin", "Ming-Wei Chang", "Kenton Lee", "Kristina Toutanova"]
venue = "north american chapter of the association for computational linguistics"
affiliations = ["Google"]
concepts = ["language model", "natural language inference", "question answering"]
# encode paper
paper_info = model.encode_paper(
title=title, abstract=abstract, venue=venue, authors=authors, concepts=concepts, affiliations=affiliations, reduction="max"
title=title,
abstract=abstract,
venue=venue,
authors=authors,
concepts=concepts,
affiliations=affiliations,
reduction="max",
)

assert len(paper_info) == 5
assert paper_info['text'][0]['type'] == 'TEXT'
assert len(paper_info['authors']) == 4
assert len(paper_info['venue'][0]['token_ids']) == 9
assert tuple(paper_info['text'][0]['sequence_output'].shape) == (43, 768)
assert len(paper_info['text'][0]['pooled_output']) == 768
assert paper_info["text"][0]["type"] == "TEXT"
assert len(paper_info["authors"]) == 4
assert len(paper_info["venue"][0]["token_ids"]) == 9
assert tuple(paper_info["text"][0]["sequence_output"].shape) == (43, 768)
assert len(paper_info["text"][0]["pooled_output"]) == 768


if __name__ == "__main__":