[Example] Add GraphMAE2 #429

Merged 1 commit on Apr 24, 2023
19 changes: 14 additions & 5 deletions cogdl/data/data.py
@@ -480,16 +480,17 @@ def __init__(self, x=None, y=None, **kwargs):
         self.x = x
         self.y = y
         self.grb_adj = None
+        num_nodes = x.shape[0] if x is not None else None
 
         for key, item in kwargs.items():
             if key == "num_nodes":
                 self.__num_nodes__ = item
+                num_nodes = item
             elif key == "grb_adj":
                 self.grb_adj = item
             elif not is_read_adj_key(key):
                 self[key] = item
 
-        num_nodes = x.shape[0] if x is not None else None
         if "edge_index_train" in kwargs:
             self._adj_train = Adjacency(num_nodes=num_nodes)
             for key, item in kwargs.items():
@@ -534,14 +535,17 @@ def add_remaining_self_loops(self):
         self._adj_full.add_remaining_self_loops()
         if self._adj_train is not None:
             self._adj_train.add_remaining_self_loops()
+        return self
 
     def padding_self_loops(self):
         self._adj.padding_self_loops()
+        return self
 
     def remove_self_loops(self):
         self._adj_full.remove_self_loops()
         if self._adj_train is not None:
             self._adj_train.remove_self_loops()
+        return self
 
     def row_norm(self):
         self._adj.row_norm()
@@ -790,7 +794,7 @@ def sample_adj(self, batch, size=-1, replace=True):
             if not torch.is_tensor(batch):
                 batch = torch.tensor(batch, dtype=torch.long)
             (row_ptr, col_indices, nodes, edges) = sample_adj_c(
-                self._adj.row_indptr, self.col_indices, batch, size, replace
+                self.row_indptr, self.col_indices, batch, size, replace
             )
         else:
             if torch.is_tensor(batch):
@@ -891,13 +895,18 @@ def subgraph(self, node_idx, keep_order=False):
             val = self.edge_weight.numpy()
             N = self.num_nodes
             self[key] = sp.csr_matrix((val, (row, col)), shape=(N, N))
-        sub_adj = self[key][node_idx, :][:, node_idx]
+        sub_adj = self[key][node_idx, :][:, node_idx].tocoo()
         sub_g = Graph()
-        sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long()
-        sub_g.col_indices = torch.from_numpy(sub_adj.indices).long()
+        # sub_g.row_indptr = torch.from_numpy(sub_adj.indptr).long()
+        # sub_g.col_indices = torch.from_numpy(sub_adj.indices).long()
+        row = torch.from_numpy(sub_adj.row).long()
+        col = torch.from_numpy(sub_adj.col).long()
+        sub_g.edge_index = (row, col)
         sub_g.edge_weight = torch.from_numpy(sub_adj.data)
         sub_g.num_nodes = len(node_idx)
         for key in self.__keys__():
             sub_g[key] = self[key][node_idx]
+        sub_g._adj._to_csr()
         return sub_g.to(self._adj.device)
 
     def edge_subgraph(self, edge_idx, require_idx=True):
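
Two of the changes above are worth spelling out. In `__init__`, the `num_nodes` initialization moves before the kwargs loop, so an explicit `num_nodes` keyword is no longer clobbered afterwards. In `subgraph`, the sliced scipy submatrix is now converted to COO and fed in through `edge_index`, with the CSR form rebuilt afterwards via `_adj._to_csr()`. Here is a minimal sketch of the new extraction path using only scipy and torch on a toy graph (illustrative, not the cogdl API itself):

```python
import numpy as np
import scipy.sparse as sp
import torch

# Toy 4-node adjacency stored as CSR, mirroring self[key] in the diff.
row = np.array([0, 0, 1, 2, 3])
col = np.array([1, 2, 2, 3, 0])
val = np.ones(len(row), dtype=np.float32)
adj = sp.csr_matrix((val, (row, col)), shape=(4, 4))

node_idx = [0, 1, 2]
# Slice rows, then columns, as in the diff; converting to COO exposes
# the surviving edges as explicit (row, col) index arrays.
sub_adj = adj[node_idx, :][:, node_idx].tocoo()

edge_index = (
    torch.from_numpy(sub_adj.row).long(),
    torch.from_numpy(sub_adj.col).long(),
)
edge_weight = torch.from_numpy(sub_adj.data)
print(edge_index)   # (tensor([0, 0, 1]), tensor([1, 2, 2]))
print(edge_weight)  # tensor([1., 1., 1.])
```

The `return self` lines added to the self-loop helpers simply make those calls chainable, e.g. `graph.remove_self_loops().add_remaining_self_loops()`.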
Binary file added examples/graphmae/imgs/compare.png
Binary file added examples/graphmae/imgs/fig.png
87 changes: 87 additions & 0 deletions examples/graphmae2/README.md
@@ -0,0 +1,87 @@
<h1> GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner </h1>

[**CogDL**](/~https://github.com/THUDM/cogdl) implementation of the WWW'23 paper: [GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner](https://arxiv.org/abs/2304.04779).
<img src="asserts/overview.png">

The predecessor of this work, [GraphMAE: Self-Supervised Masked Graph Autoencoders](https://arxiv.org/abs/2205.10803), can be found [here](/~https://github.com/THUDM/cogdl/tree/master/examples/graphmae).

<h2>Dependencies </h2>

* Python >= 3.7
* [PyTorch](https://pytorch.org/) >= 1.9.0
* [cogdl](/~https://github.com/THUDM/cogdl) >= 0.5.3
* pyyaml == 5.4.1


<h2>Quick Start </h2>

For a quick start, you can run the following scripts:

**Node classification**

```bash
sh run_minibatch.sh <dataset_name> <gpu_id> # for mini batch node classification
# example: sh run_minibatch.sh ogbn-arxiv 0
sh run_fullbatch.sh <dataset_name> <gpu_id> # for full batch node classification
# example: sh run_fullbatch.sh cora 0

# Or you could run the code manually:
# for mini batch node classification
python main_large.py --dataset ogbn-arxiv --encoder gat --decoder gat --seed 0 --device 0
# for full batch node classification
python main_full_batch.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0
```

Supported datasets:

* mini batch node classification: `ogbn-arxiv`, `ogbn-products`, `mag-scholar-f`, `ogbn-papers100M`
* full batch node classification: `cora`, `citeseer`, `pubmed`

Run the provided scripts, or add `--use_cfg` to the command, to reproduce the reported results.
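
As a rough sketch of what a `--use_cfg`-style flag does with these YAML files — load `configs/<dataset>.yaml` with pyyaml and override the argparse defaults — the helper name `load_dataset_config` below is hypothetical, not the actual function in `main_large.py`:

```python
import argparse

import yaml  # pyyaml, as pinned in the dependencies above


def load_dataset_config(args: argparse.Namespace, dataset: str) -> argparse.Namespace:
    """Hypothetical helper: override CLI defaults with the tuned values
    from configs/<dataset>.yaml (lr, num_hidden, mask_rate, ...)."""
    with open(f"configs/{dataset}.yaml") as f:
        cfg = yaml.safe_load(f)
    for key, value in cfg.items():
        setattr(args, key, value)
    return args
```

Note that PyYAML keeps the last occurrence of a duplicated key, which matters for the configs below: each lists `num_layers` twice (the two values happen to agree).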

**For large-scale graphs**

Before starting mini-batch training, you need to generate local clusters if you want to use local clustering for training. By default, the program loads datasets from `./data` and saves the generated local clusters to `./lc_ego_graphs`. To generate local clusters, first install [localclustering](/~https://github.com/kfoynt/LocalGraphClustering) and then run:

```bash
python ./datasets/localclustering.py --dataset <your_dataset> --data_dir <path_to_data>
```

We also provide pre-generated local clusters, which can be downloaded [here](https://cloud.tsinghua.edu.cn/d/64f859f389ca43eda472/) and placed in `lc_ego_graphs`.



<h2> Datasets </h2>

During execution, the OGB datasets and the small-scale datasets (Cora, Citeseer, and PubMed) are downloaded automatically.
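
A minimal check that the automatic download works, assuming `build_dataset_from_name` from `cogdl.datasets` (the entry point shown in cogdl's documentation); treat the exact call and cache location as assumptions if your cogdl version differs:

```python
# Sketch only: assumes cogdl's documented dataset entry point.
from cogdl.datasets import build_dataset_from_name

# The first call downloads and preprocesses Cora into cogdl's default
# data directory; subsequent calls load the cached copy.
dataset = build_dataset_from_name("cora")
graph = dataset.data
print(graph.num_nodes)
```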

<h2> Experimental Results </h2>

Experimental results of node classification on large-scale datasets (Accuracy, %):

| | Ogbn-arxiv | Ogbn-products | Mag-Scholar-F | Ogbn-papers100M |
| ------------------ | ------------ | ------------ | ------------ | -------------- |
| MLP | 55.50±0.23 | 61.06±0.08 | 39.11±0.21 | 47.24±0.31 |
| SGC | 66.92±0.08 | 74.87±0.25 | 54.68±0.23 | 63.29±0.19 |
| Random-Init | 68.14±0.02 | 74.04±0.06 | 56.57±0.03 | 61.55±0.12 |
| CCA-SSG | 68.57±0.02 | 75.27±0.05 | 51.55±0.03 | 55.67±0.15 |
| GRACE | 69.34±0.01 | 79.47±0.59 | 57.39±0.02 | 61.21±0.12 |
| BGRL | 70.51±0.03 | 78.59±0.02 | 57.57±0.01 | 62.18±0.15 |
| GGD | - | 75.70±0.40 | - | 63.50±0.50 |
| GraphMAE | 71.03±0.02 | 78.89±0.01 | 58.75±0.03 | 62.54±0.09 |
| **GraphMAE2** | **71.89±0.03** | **81.59±0.02** | **59.24±0.01** | **64.89±0.04** |



<h2> Citing </h2>

If you find this work helpful to your research, please consider citing our paper:

```
@inproceedings{hou2023graphmae2,
  title={GraphMAE2: A Decoding-Enhanced Masked Self-Supervised Graph Learner},
  author={Zhenyu Hou and Yufei He and Yukuo Cen and Xiao Liu and Yuxiao Dong and Evgeny Kharlamov and Jie Tang},
  booktitle={Proceedings of the ACM Web Conference 2023 (WWW'23)},
  year={2023}
}
```
Binary file added examples/graphmae2/asserts/overview.png
27 changes: 27 additions & 0 deletions examples/graphmae2/configs/citeseer.yaml
@@ -0,0 +1,27 @@
lr: 0.0005 # 0.0005
lr_f: 0.025
num_hidden: 1024
num_heads: 4
num_out_heads: 1
num_layers: 2
weight_decay: 1e-4
weight_decay_f: 1e-2
max_epoch: 500
max_epoch_f: 500
mask_rate: 0.5
num_layers: 2
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.1
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.0
optimizer: adam
replace_rate: 0.0
alpha_l: 1
scheduler: True
remask_method: fixed
momentum: 1
lam: 0.1
27 changes: 27 additions & 0 deletions examples/graphmae2/configs/cora.yaml
@@ -0,0 +1,27 @@
lr: 0.001
lr_f: 0.025
num_hidden: 1024
num_heads: 8
num_out_heads: 1
num_layers: 2
weight_decay: 2e-4
weight_decay_f: 1e-4
max_epoch: 2000
max_epoch_f: 300
mask_rate: 0.5
num_layers: 2
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.1
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.0
optimizer: adam
replace_rate: 0.1
alpha_l: 4
scheduler: True
remask_method: fixed
momentum: 0
lam: 0.1
30 changes: 30 additions & 0 deletions examples/graphmae2/configs/mag-scholar-f.yaml
@@ -0,0 +1,30 @@
lr: 0.001
lr_f: 0.001
num_hidden: 1024
num_heads: 8
num_out_heads: 1
num_layers: 4
weight_decay: 0.04
weight_decay_f: 0
max_epoch: 10
max_epoch_f: 1000
batch_size: 512
batch_size_f: 256
mask_rate: 0.5
num_layers: 4
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.2
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.5
optimizer: adamw
alpha_l: 2
scheduler: True
remask_method: random
momentum: 0.996
lam: 0.1
delayed_ema_epoch: 0
num_remasking: 3
30 changes: 30 additions & 0 deletions examples/graphmae2/configs/ogbn-arxiv.yaml
@@ -0,0 +1,30 @@
lr: 0.0025
lr_f: 0.005
num_hidden: 1024
num_heads: 8
num_out_heads: 1
num_layers: 4
weight_decay: 0.06
weight_decay_f: 1e-4
max_epoch: 60
max_epoch_f: 1000
batch_size: 512
batch_size_f: 256
mask_rate: 0.5
num_layers: 4
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.1
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.5
optimizer: adamw
alpha_l: 6
scheduler: True
remask_method: random
momentum: 0.996
lam: 10.0
delayed_ema_epoch: 40
num_remasking: 3
30 changes: 30 additions & 0 deletions examples/graphmae2/configs/ogbn-papers100M.yaml
@@ -0,0 +1,30 @@
lr: 0.001
lr_f: 0.001
num_hidden: 1024
num_heads: 4
num_out_heads: 1
num_layers: 4
weight_decay: 0.05
weight_decay_f: 0
max_epoch: 10
max_epoch_f: 1000
batch_size: 512
batch_size_f: 256
mask_rate: 0.5
num_layers: 4
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.2
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.5
optimizer: adamw
alpha_l: 2
scheduler: True
remask_method: random
momentum: 0.996
lam: 10.0
delayed_ema_epoch: 0
num_remasking: 3
30 changes: 30 additions & 0 deletions examples/graphmae2/configs/ogbn-products.yaml
@@ -0,0 +1,30 @@
lr: 0.002
lr_f: 0.001
num_hidden: 1024
num_heads: 4
num_out_heads: 1
num_layers: 4
weight_decay: 0.04
weight_decay_f: 0
max_epoch: 20
max_epoch_f: 1000
batch_size: 512
batch_size_f: 256
mask_rate: 0.5
num_layers: 4
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.2
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.5
optimizer: adamw
alpha_l: 3
scheduler: True
remask_method: random
momentum: 0.996
lam: 5.0
delayed_ema_epoch: 0
num_remasking: 3
27 changes: 27 additions & 0 deletions examples/graphmae2/configs/pubmed.yaml
@@ -0,0 +1,27 @@
lr: 0.005
lr_f: 0.025
num_hidden: 512
num_heads: 2
num_out_heads: 1
num_layers: 2
weight_decay: 1e-5
weight_decay_f: 5e-4
max_epoch: 2000
max_epoch_f: 500
mask_rate: 0.9
num_layers: 2
encoder: gat
decoder: gat
activation: prelu
attn_drop: 0.1
linear_prob: True
in_drop: 0.2
loss_fn: sce
drop_edge_rate: 0.0
optimizer: adam
replace_rate: 0.0
alpha_l: 4
scheduler: True
remask_method: fixed
momentum: 0.995
lam: 1