Skip to content

Commit

Permalink
Merge branch 'master' into cpu_docker_tcmalloc
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-sizov committed Jul 7, 2023
2 parents c33c131 + 67da04e commit 022b843
Show file tree
Hide file tree
Showing 25 changed files with 1,050 additions and 45 deletions.
1 change: 1 addition & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ pipeline {
steps {
unit_distributed_linux('pytorch', 'cpu')
}
when { expression { false } }
}
}
post {
Expand Down
26 changes: 26 additions & 0 deletions graphbolt/include/graphbolt/csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const;

/**
 * @brief Sample negative edges by randomly choosing negative
 * source-destination pairs according to a uniform distribution. For each edge
 * ``(u, v)``, it is supposed to generate `negative_ratio` pairs of negative
 * edges ``(u, v')``, where ``v'`` is drawn uniformly at random from the
 * half-open range ``[0, max_node_id)``.
 *
 * @param node_pairs A tuple of two 1D tensors that represent the source and
 * destination of positive edges, with 'positive' indicating that these edges
 * are present in the graph. It's important to note that within the context of
 * a heterogeneous graph, the ids in these tensors signify heterogeneous ids.
 * Only the source tensor is read by the implementation; the destination
 * tensor is ignored.
 * @param negative_ratio The ratio of the number of negative samples to
 * positive samples.
 * @param max_node_id Exclusive upper bound of the sampled destination ids,
 * i.e. the number of candidate nodes rather than the largest valid id. It
 * should correspond to the number of nodes of a specific type.
 *
 * @return A tuple consisting of two 1D tensors represents the source and
 * destination of negative edges, each of length
 * ``num_positive_edges * negative_ratio``. The source tensor is the whole
 * positive source tensor repeated `negative_ratio` times back to back (so the
 * order is ``u0, u1, ..., u0, u1, ...``), and the destination tensor is
 * created with the tensor options of the positive source tensor. In the
 * context of a heterogeneous graph, both the input nodes and the selected
 * nodes are represented by heterogeneous IDs. Note that negative refers to
 * false negatives, which means the edge could be present or not present in
 * the graph: sampled pairs are not checked against the existing edges.
 */
std::tuple<torch::Tensor, torch::Tensor> SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
int64_t negative_ratio, int64_t max_node_id) const;

/**
* @brief Copy the graph to shared memory.
* @param shared_memory_name The name of the shared memory.
Expand Down
34 changes: 24 additions & 10 deletions graphbolt/src/csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,29 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
if (type_per_edge_.has_value())
if (type_per_edge_.has_value()) {
subgraph_type_per_edge =
torch::index_select(type_per_edge_.value(), 0, picked_eids);
}
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}

/**
 * Uniform negative sampling. Every positive source node is paired with
 * `negative_ratio` destination ids drawn uniformly from [0, max_node_id).
 * The positive destinations are not consulted, and the sampled pairs are not
 * checked against the edges of the graph.
 */
std::tuple<torch::Tensor, torch::Tensor>
CSCSamplingGraph::SampleNegativeEdgesUniform(
    const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
    int64_t negative_ratio, int64_t max_node_id) const {
  // Only the sources of the positive edges are needed; the destinations of
  // the negative edges are generated from scratch below.
  const torch::Tensor& positive_src = std::get<0>(node_pairs);
  const int64_t num_negatives = positive_src.size(0) * negative_ratio;

  // `max_node_id` is the exclusive upper bound of the sampled ids. The result
  // inherits dtype/device from the positive sources.
  torch::Tensor negative_dst = torch::randint(
      0, max_node_id, {num_negatives}, positive_src.options());
  // Tile the whole source tensor `negative_ratio` times so that both outputs
  // have `num_negatives` entries.
  torch::Tensor negative_src = positive_src.repeat(negative_ratio);
  return std::make_tuple(negative_src, negative_dst);
}

c10::intrusive_ptr<CSCSamplingGraph>
CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
std::tuple<
Expand Down Expand Up @@ -238,7 +251,7 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
Expand All @@ -251,14 +264,12 @@ inline torch::Tensor UniformPick(
torch::Tensor picked_neighbors;
if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
} else if (replace) {
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
if (replace) {
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
picked_neighbors = torch::randperm(num_neighbors, options);
picked_neighbors = picked_neighbors.slice(0, 0, fanout) + offset;
}
picked_neighbors = torch::randperm(num_neighbors, options);
picked_neighbors = picked_neighbors.slice(0, 0, fanout) + offset;
}
return picked_neighbors;
}
Expand Down Expand Up @@ -286,7 +297,7 @@ inline torch::Tensor UniformPick(
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
Expand Down Expand Up @@ -346,6 +357,9 @@ torch::Tensor PickByEtype(
const auto end = offset + num_neighbors;
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < fanouts.size(),
"Etype values exceed the number of fanouts.");
int64_t fanout = fanouts[etype];
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
Expand Down
3 changes: 3 additions & 0 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ TORCH_LIBRARY(graphbolt, m) {
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
.def(
"sample_negative_edges_uniform",
&CSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory);
m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
Expand Down
2 changes: 2 additions & 0 deletions python/dgl/graphbolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .feature_fetcher import *
from .copy_to import *
from .dataset import *
from .impl import *
from .dataloader import *
from .subgraph_sampler import *


Expand Down
23 changes: 23 additions & 0 deletions python/dgl/graphbolt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1 +1,24 @@
"""Graph Bolt DataLoaders"""

import torch.utils.data


class SingleProcessDataLoader(torch.utils.data.DataLoader):
    """Single process DataLoader.

    Iterates over the data pipeline in the main process.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    """

    # In the single process dataloader case, we don't need to do any
    # modifications to the datapipe, and we just use PyTorch's native
    # dataloader as-is.
    #
    # The exception is that batch_size should be None, since we already
    # have minibatch sampling and collating in MinibatchSampler.
    def __init__(self, datapipe):
        # num_workers=0 keeps all loading in the calling process.
        super().__init__(datapipe, batch_size=None, num_workers=0)
14 changes: 8 additions & 6 deletions python/dgl/graphbolt/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""GraphBolt Dataset."""

from typing import List, Union

from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict

Expand Down Expand Up @@ -29,16 +31,16 @@ class Dataset:
generate a subgraph.
"""

def train_set(self) -> ItemSet or ItemSetDict:
"""Return the training set."""
def train_sets(self) -> Union[List[ItemSet], List[ItemSetDict]]:
    """Return the training sets.

    Returns
    -------
    List[ItemSet] or List[ItemSetDict]
        The training sets.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation; presumably subclasses are
        expected to override it.
    """
    # The annotation is spelled with ``Union`` because ``A or B`` is a boolean
    # expression that evaluates to ``A`` only.
    raise NotImplementedError

def validation_set(self) -> ItemSet or ItemSetDict:
"""Return the validation set."""
def validation_sets(self) -> Union[List[ItemSet], List[ItemSetDict]]:
    """Return the validation sets.

    Returns
    -------
    List[ItemSet] or List[ItemSetDict]
        The validation sets.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation; presumably subclasses are
        expected to override it.
    """
    # The annotation is spelled with ``Union`` because ``A or B`` is a boolean
    # expression that evaluates to ``A`` only.
    raise NotImplementedError

def test_set(self) -> ItemSet or ItemSetDict:
"""Return the test set."""
def test_sets(self) -> Union[List[ItemSet], List[ItemSetDict]]:
    """Return the test sets.

    Returns
    -------
    List[ItemSet] or List[ItemSetDict]
        The test sets.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation; presumably subclasses are
        expected to override it.
    """
    # The annotation is spelled with ``Union`` because ``A or B`` is a boolean
    # expression that evaluates to ``A`` only.
    raise NotImplementedError

def graph(self) -> object:
Expand Down
96 changes: 96 additions & 0 deletions python/dgl/graphbolt/feature_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Feature store for GraphBolt."""
from typing import List, Optional

import numpy as np
import pydantic
import pydantic_yaml

import torch


Expand Down Expand Up @@ -134,3 +140,93 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
f"but got {ids.shape[0]} and {value.shape[0]}."
)
self._tensor[ids] = value


# [TODO] Move code to 'impl/' and separate OnDisk-related code to another file.
class FeatureDataFormatEnum(pydantic_yaml.YamlStrEnum):
    """Enum of feature data format.

    The member values are the strings accepted for the ``format`` field of
    ``OnDiskFeatureData``.
    """

    # Feature file that ``load_feature_stores`` reads with ``torch.load``.
    TORCH = "torch"
    # Feature file that ``load_feature_stores`` reads with ``numpy.load``.
    NUMPY = "numpy"


class FeatureDataDomainEnum(pydantic_yaml.YamlStrEnum):
    """Enum of feature data domain.

    The member values are the strings accepted for the ``domain`` field of
    ``OnDiskFeatureData``; the domain is the first element of the key under
    which ``load_feature_stores`` registers a feature store.
    """

    NODE = "node"
    EDGE = "edge"
    GRAPH = "graph"


class OnDiskFeatureData(pydantic.BaseModel):
    r"""The description of an on-disk feature.

    ``load_feature_stores`` keys each loaded feature store by the tuple
    ``(domain, type, name)`` and uses ``format``, ``path`` and ``in_memory``
    to decide how the data is read.
    """
    # Domain the feature belongs to ("node", "edge" or "graph").
    domain: FeatureDataDomainEnum
    # Optional type name within the domain. The explicit ``None`` default
    # keeps the field optional regardless of how the installed pydantic
    # version treats a bare ``Optional[...]`` annotation.
    type: Optional[str] = None
    # Name of the feature.
    name: str
    # Storage format of the file at ``path`` ("torch" or "numpy").
    format: FeatureDataFormatEnum
    # Location of the feature file.
    path: str
    # Whether the data is loaded into memory. Must be true for the "torch"
    # format; for "numpy", a false value makes the loader memory-map the file.
    in_memory: Optional[bool] = True


def load_feature_stores(feat_data: List[OnDiskFeatureData]):
    r"""Load feature stores from disk.

    Every entry of `feat_data` describes one feature and produces one
    feature store.

    The `format` of an entry must be either "torch" or "numpy". A "torch"
    feature is read with ``torch.load`` and must be loaded in memory. A
    "numpy" feature is read with ``numpy.load``; it is either loaded in
    memory or, when `in_memory` is false, memory-mapped from disk in "r+"
    mode.

    Parameters
    ----------
    feat_data : List[OnDiskFeatureData]
        The description of the feature stores.

    Returns
    -------
    dict
        The loaded feature stores. The keys are ``(domain, type, name)``
        tuples taken from the descriptions, and the values are the feature
        stores.

    Raises
    ------
    AssertionError
        If a "torch" feature is requested with ``in_memory=False``.
    ValueError
        If the format of a feature is neither "torch" nor "numpy".

    Examples
    --------
    >>> import torch
    >>> import numpy as np
    >>> from dgl import graphbolt as gb
    >>> edge_label = torch.tensor([1, 2, 3])
    >>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> torch.save(edge_label, "/tmp/edge_label.pt")
    >>> np.save("/tmp/node_feat.npy", node_feat.numpy())
    >>> feat_data = [
    ...     gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
    ...         name="label", format="torch", path="/tmp/edge_label.pt",
    ...         in_memory=True),
    ...     gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
    ...         format="numpy", path="/tmp/node_feat.npy", in_memory=False),
    ... ]
    >>> feat_stores = gb.load_feature_stores(feat_data)
    >>> sorted(name for _, _, name in feat_stores)
    ['feat', 'label']
    """
    stores = {}
    for item in feat_data:
        store_key = (item.domain, item.type, item.name)
        if item.format == "torch":
            # torch.load always materializes the tensor, so an on-disk
            # (in_memory=False) torch feature cannot be honoured.
            assert item.in_memory, (
                f"Pytorch tensor can only be loaded in memory, "
                f"but the feature {store_key} is loaded on disk."
            )
            tensor = torch.load(item.path)
        elif item.format == "numpy":
            # No mmap mode means a regular in-memory load; "r+" memory-maps
            # the file instead of reading it fully.
            mmap_mode = None if item.in_memory else "r+"
            tensor = torch.as_tensor(np.load(item.path, mmap_mode=mmap_mode))
        else:
            raise ValueError(f"Unknown feature format {item.format}")
        stores[store_key] = TorchBasedFeatureStore(tensor)
    return stores
66 changes: 62 additions & 4 deletions python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,14 @@ def sample_neighbors(
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id)
assert len(fanouts) in [
expected_fanout_len,
1,
], "Fanouts should have the same number of elements as etypes or \
should have a length of 1."
if fanouts.size(0) > 1:
assert (
self.type_per_edge is not None
Expand All @@ -279,10 +287,6 @@ def sample_neighbors(
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
if self.metadata and self.metadata.edge_type_to_id:
assert len(self.metadata.edge_type_to_id) == fanouts.size(
0
), "Fanouts should have the same number of elements as etypes."
if probs_or_mask is not None:
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
Expand All @@ -300,6 +304,60 @@ def sample_neighbors(
nodes, fanouts.tolist(), replace, return_eids, probs_or_mask
)

def sample_negative_edges_uniform(
    self, edge_type, node_pairs, negative_ratio
):
    """
    Sample negative edges by randomly choosing negative source-destination
    pairs according to a uniform distribution. For each edge ``(u, v)``,
    it is supposed to generate `negative_ratio` pairs of negative edges
    ``(u, v')``, where ``v'`` is chosen uniformly from the candidate
    destination nodes: the nodes of the destination type of `edge_type`,
    or all the nodes in the graph when `edge_type` is not given.

    Parameters
    ----------
    edge_type: Tuple[str]
        The type of edges in the provided node_pairs, given as a 3-tuple
        whose last element is the destination node type. Any negative
        edges sampled will also have the same type. If set to None, it
        will be considered as a homogeneous graph.
    node_pairs : Tuple[Tensor]
        A tuple of two 1D tensors that represent the source and destination
        of positive edges, with 'positive' indicating that these edges are
        present in the graph. It's important to note that within the
        context of a heterogeneous graph, the ids in these tensors signify
        heterogeneous ids.
    negative_ratio: int
        The ratio of the number of negative samples to positive samples.

    Returns
    -------
    Tuple[Tensor]
        A tuple consisting of two 1D tensors represents the source and
        destination of negative edges. In the context of a heterogeneous
        graph, both the input nodes and the selected nodes are represented
        by heterogeneous IDs, and the formed edges are of the input type
        `edge_type`. Note that negative refers to false negatives, which
        means the edge could be present or not present in the graph.

    Raises
    ------
    AssertionError
        If `edge_type` is given but the graph has no 'node_type_offset'
        or no 'metadata'.
    """
    if edge_type:
        assert self.node_type_offset is not None, (
            "The 'node_type_offset' array is necessary for performing "
            "negative sampling by edge type."
        )
        # Without metadata the destination type name cannot be mapped to a
        # type id; fail with a clear message instead of an AttributeError
        # on None.
        assert self.metadata is not None, (
            "The graph 'metadata' is necessary for performing "
            "negative sampling by edge type."
        )
        _, _, dst_node_type = edge_type
        dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
        # Number of nodes of the destination type; it is passed on as the
        # exclusive upper bound of the sampled destination ids.
        max_node_id = (
            self.node_type_offset[dst_node_type_id + 1]
            - self.node_type_offset[dst_node_type_id]
        )
    else:
        max_node_id = self.num_nodes
    return self._c_csc_graph.sample_negative_edges_uniform(
        node_pairs,
        negative_ratio,
        max_node_id,
    )

def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory.
Expand Down
2 changes: 2 additions & 0 deletions python/dgl/graphbolt/impl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Implementation of GraphBolt."""
from .ondisk_dataset import *
Loading

0 comments on commit 022b843

Please sign in to comment.