Skip to content

Commit

Permalink
Add graph apis (#40809)
Browse files Browse the repository at this point in the history
* Add graph_reindex API

* add graph_sample_neighbors api

* Add buffer

* delete VLOG

* delete thrust::copy for output

* add ShareDataWith

* delete graph_reindex hashtable output

* add graph_reindex dispensable

* add reindex unittest, move memset to cuda kernel, change api

* fix conflict

* add reindex buffer for gpu version note

* fix conflicts for op_func_generator

* Add fisher_yates sampling, add dispensable, change infermeta

* add dtype for edge_id

* fix rocm ci and static check ci

* add unittest

* fix unittest

* fix unittest

* fix bug
  • Loading branch information
DesmonDay authored Apr 2, 2022
1 parent 36f97cd commit b0398c8
Show file tree
Hide file tree
Showing 20 changed files with 2,210 additions and 0 deletions.
77 changes: 77 additions & 0 deletions paddle/fluid/operators/graph_reindex_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

// Operator shell for graph_reindex: shape inference is delegated to the phi
// GraphReindexInferMeta functor (see registration below); this class only
// selects the kernel type.
class GraphReindexOP : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The kernel's data type follows the dtype of the "X" input tensor.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

// Declares the inputs, outputs, and attributes of graph_reindex.
class GraphReindexOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The destination nodes of the input graph.");
    AddInput("Neighbors", "The neighbor nodes of the destination nodes `X`.");
    AddInput("Count", "The number of neighbor nodes of each destination node.");
    // Note(daisiming): If using buffer hashtable, we must ensure the number of
    // nodes of the input graph should be no larger than maximum(int32).
    // Both hashtable buffers are optional: they are only supplied when
    // flag_buffer_hashtable is true.
    AddInput("HashTable_Value",
             "One of the buffer tensor of hashtable for reindex")
        .AsDispensable();
    AddInput("HashTable_Index",
             "One of the buffer tensor of hashtable for reindex")
        .AsDispensable();
    AddAttr<bool>("flag_buffer_hashtable",
                  "Define whether using the buffer hashtable.")
        .SetDefault(false);
    AddOutput("Reindex_Src",
              "The source node index of graph edges after reindex.");
    AddOutput("Reindex_Dst",
              "The destination node index of graph edges after reindex.");
    AddOutput("Out_Nodes", "The original index of graph nodes before reindex");

    AddComment(R"DOC(
Graph Reindex operator.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// Reuse phi::GraphReindexInferMeta as the shape-inference functor for the
// fluid operator.
DECLARE_INFER_SHAPE_FUNCTOR(graph_reindex, GraphReindexInferShapeFunctor,
                            PD_INFER_META(phi::GraphReindexInferMeta));

// graph_reindex has no gradient, so both grad-op makers are empty.
REGISTER_OPERATOR(
    graph_reindex, ops::GraphReindexOP, ops::GraphReindexOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    GraphReindexInferShapeFunctor);
82 changes: 82 additions & 0 deletions paddle/fluid/operators/graph_sample_neighbors_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

// Operator shell for graph_sample_neighbors: shape inference is delegated to
// the phi GraphSampleNeighborsInferMeta functor (see registration below).
class GraphSampleNeighborsOP : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The kernel's data type follows the dtype of the "Row" input tensor.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
        ctx.device_context());
  }
};

// Declares the inputs, outputs, and attributes of graph_sample_neighbors
// (GraphSAGE-style neighbor sampling over a CSC graph).
class GraphSampleNeighborsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Row",
             "One of the components of the CSC format of the input graph.");
    AddInput("Col_Ptr",
             "One of the components of the CSC format of the input graph.");
    AddInput("X", "The input center nodes index tensor.");
    // Eids is only needed when return_eids is true; Perm_Buffer only when
    // flag_perm_buffer is true.
    AddInput("Eids", "The edge ids of the input graph.").AsDispensable();
    AddInput("Perm_Buffer", "Permutation buffer for fisher-yates sampling.")
        .AsDispensable();
    AddOutput("Out", "The neighbors of input nodes X after sampling.");
    AddOutput("Out_Count",
              "The number of sample neighbors of input nodes respectively.");
    AddOutput("Out_Eids", "The eids of the sample edges");
    // The two sentences below are adjacent string literals that concatenate
    // into one comment. The original code separated them with a comma, which
    // made the second sentence a third argument to AddAttr and bound it to
    // the trailing `bool generated` parameter (const char* -> bool), silently
    // dropping half of the comment and marking the attribute as generated.
    AddAttr<int>(
        "sample_size",
        "The sample size of graph sample neighbors method. "
        "Set default value as -1, means return all neighbors of nodes.")
        .SetDefault(-1);
    AddAttr<bool>("return_eids",
                  "Whether to return the eid of the sample edges.")
        .SetDefault(false);
    // A separator is added between the concatenated literals: without it the
    // rendered comment read "...in GPUSet default value...".
    AddAttr<bool>("flag_perm_buffer",
                  "Using the permutation for fisher-yates sampling in GPU. "
                  "Set default value as false, means not using it.")
        .SetDefault(false);
    AddComment(R"DOC(
Graph Learning Sampling Neighbors operator, for graphsage sampling method.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// Reuse phi::GraphSampleNeighborsInferMeta as the shape-inference functor
// for the fluid operator.
DECLARE_INFER_SHAPE_FUNCTOR(graph_sample_neighbors,
                            GraphSampleNeighborsInferShapeFunctor,
                            PD_INFER_META(phi::GraphSampleNeighborsInferMeta));

// graph_sample_neighbors has no gradient, so both grad-op makers are empty.
REGISTER_OPERATOR(
    graph_sample_neighbors, ops::GraphSampleNeighborsOP,
    ops::GraphSampleNeighborsOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    GraphSampleNeighborsInferShapeFunctor);
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/op_function_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
{"crf_decoding", {"Emission", "Transition", "Label", "Length"}},
{"chunk_eval", {"Inference", "Label", "SeqLength"}},
{"graph_reindex",
{"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}},
{"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}},
};

// NOTE(zhiqiu): Like op_ins_map.
Expand Down
97 changes: 97 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,103 @@ void WhereInferMeta(const MetaTensor& condition,
out->share_meta(x);
}

// Infers output shapes/dtypes for graph_reindex. Every input must be 1D, or
// 2D with a trailing dimension of 1. The three outputs are 1D with a dynamic
// (-1) length: edge counts and the number of unique nodes are data-dependent
// and only known after the kernel runs.
void GraphReindexInferMeta(const MetaTensor& x,
                           const MetaTensor& neighbors,
                           const MetaTensor& count,
                           paddle::optional<const MetaTensor&> hashtable_value,
                           paddle::optional<const MetaTensor&> hashtable_index,
                           bool flag_buffer_hashtable,
                           MetaTensor* reindex_src,
                           MetaTensor* reindex_dst,
                           MetaTensor* out_nodes) {
  // tensor_name is taken by const reference: the original by-value parameter
  // copied the string on every check (clang-tidy
  // performance-unnecessary-value-param).
  auto GraphReindexShapeCheck = [](const phi::DDim& dims,
                                   const std::string& tensor_name) {
    if (dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          dims[1],
          1,
          phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
                                       "is 2D, but we get %d",
                                       tensor_name,
                                       dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          dims.size(),
          1,
          phi::errors::InvalidArgument(
              "The %s should be 1D, when it is not 2D, but we get %d",
              tensor_name,
              dims.size()));
    }
  };

  GraphReindexShapeCheck(x.dims(), "X");
  GraphReindexShapeCheck(neighbors.dims(), "Neighbors");
  GraphReindexShapeCheck(count.dims(), "Count");
  // The hashtable buffers are dispensable; they are only present (and thus
  // only checked) when the buffer-hashtable path is enabled.
  if (flag_buffer_hashtable) {
    GraphReindexShapeCheck(hashtable_value->dims(), "HashTable_Value");
    GraphReindexShapeCheck(hashtable_index->dims(), "HashTable_Index");
  }

  reindex_src->set_dims({-1});
  reindex_src->set_dtype(neighbors.dtype());
  reindex_dst->set_dims({-1});
  reindex_dst->set_dtype(neighbors.dtype());
  out_nodes->set_dims({-1});
  out_nodes->set_dtype(x.dtype());
}

// Infers output shapes/dtypes for graph_sample_neighbors. Every input must be
// 1D, or 2D with a trailing dimension of 1. Outputs are 1D with a dynamic
// (-1) length, since the number of sampled edges is data-dependent.
// out_eids is only populated when return_eids is true.
void GraphSampleNeighborsInferMeta(
    const MetaTensor& row,
    const MetaTensor& col_ptr,
    const MetaTensor& x,
    paddle::optional<const MetaTensor&> eids,
    paddle::optional<const MetaTensor&> perm_buffer,
    int sample_size,
    bool return_eids,
    bool flag_perm_buffer,
    MetaTensor* out,
    MetaTensor* out_count,
    MetaTensor* out_eids) {
  // GSN: GraphSampleNeighbors
  // tensor_name is taken by const reference: the original by-value parameter
  // copied the string on every check (clang-tidy
  // performance-unnecessary-value-param).
  auto GSNShapeCheck = [](const phi::DDim& dims,
                          const std::string& tensor_name) {
    if (dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          dims[1],
          1,
          phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
                                       "is 2D, but we get %d",
                                       tensor_name,
                                       dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          dims.size(),
          1,
          phi::errors::InvalidArgument(
              "The %s should be 1D, when it is not 2D, but we get %d",
              tensor_name,
              dims.size()));
    }
  };

  GSNShapeCheck(row.dims(), "Row");
  GSNShapeCheck(col_ptr.dims(), "Col_Ptr");
  GSNShapeCheck(x.dims(), "X");
  // Eids and Perm_Buffer are dispensable inputs; only validate them when the
  // corresponding flag says they were provided.
  if (return_eids) {
    GSNShapeCheck(eids->dims(), "Eids");
    out_eids->set_dims({-1});
    out_eids->set_dtype(row.dtype());
  }
  if (flag_perm_buffer) {
    GSNShapeCheck(perm_buffer->dims(), "Perm_Buffer");
  }

  out->set_dims({-1});
  out->set_dtype(row.dtype());
  out_count->set_dims({-1});
  out_count->set_dtype(DataType::INT32);
}

void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box,
const MetaTensor& gt_label,
Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& y,
MetaTensor* out);

// Shape/dtype inference for the graph_reindex op (implemented in
// multiary.cc). All three outputs are 1D tensors of dynamic length.
void GraphReindexInferMeta(const MetaTensor& x,
                           const MetaTensor& neighbors,
                           const MetaTensor& count,
                           paddle::optional<const MetaTensor&> hashtable_value,
                           paddle::optional<const MetaTensor&> hashtable_index,
                           bool flag_buffer_hashtable,
                           MetaTensor* reindex_src,
                           MetaTensor* reindex_dst,
                           MetaTensor* out_nodes);

// Shape/dtype inference for the graph_sample_neighbors op (implemented in
// multiary.cc). out_eids is only populated when return_eids is true.
void GraphSampleNeighborsInferMeta(
    const MetaTensor& row,
    const MetaTensor& col_ptr,
    const MetaTensor& x,
    paddle::optional<const MetaTensor&> eids,
    paddle::optional<const MetaTensor&> perm_buffer,
    int sample_size,
    bool return_eids,
    bool flag_perm_buffer,
    MetaTensor* out,
    MetaTensor* out_count,
    MetaTensor* out_eids);

void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box,
const MetaTensor& gt_label,
Expand Down
84 changes: 84 additions & 0 deletions paddle/phi/kernels/cpu/graph_reindex_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/graph_reindex_kernel.h"

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void GraphReindexKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& neighbors,
                        const DenseTensor& count,
                        paddle::optional<const DenseTensor&> hashtable_value,
                        paddle::optional<const DenseTensor&> hashtable_index,
                        bool flag_buffer_hashtable,
                        DenseTensor* reindex_src,
                        DenseTensor* reindex_dst,
                        DenseTensor* out_nodes) {
  // CPU reindex: maps the original (possibly sparse) node ids in `x` and
  // `neighbors` onto the compact range [0, num_unique). Outputs:
  //   reindex_src - reindexed id of each edge's source (the neighbors),
  //   reindex_dst - reindexed id of each edge's destination (the centers),
  //   out_nodes   - original node ids in reindex order (centers first).
  // NOTE(review): hashtable_value / hashtable_index / flag_buffer_hashtable
  // are unused here; per the op maker's note the buffer hashtable appears to
  // be a GPU-only optimization -- confirm against the GPU kernel.
  const T* x_data = x.data<T>();
  const T* neighbors_data = neighbors.data<T>();
  const int* count_data = count.data<int>();
  const int bs = x.dims()[0];
  const int num_edges = neighbors.dims()[0];

  // Centers are inserted first so their reindexed ids are 0..bs-1.
  std::unordered_map<T, T> node_map;
  std::vector<T> unique_nodes;
  unique_nodes.reserve(bs);
  int reindex_id = 0;
  for (int i = 0; i < bs; i++) {
    T node = x_data[i];
    unique_nodes.emplace_back(node);
    node_map[node] = reindex_id++;
  }

  // Allocate the edge outputs up front and write into them directly; the
  // original implementation filled temporary vectors and then copied them.
  reindex_src->Resize({num_edges});
  T* reindex_src_data = dev_ctx.template Alloc<T>(reindex_src);
  reindex_dst->Resize({num_edges});
  T* reindex_dst_data = dev_ctx.template Alloc<T>(reindex_dst);

  // Reindex Src: assign ids to unseen neighbor nodes on the fly.
  for (int i = 0; i < num_edges; i++) {
    T node = neighbors_data[i];
    if (node_map.find(node) == node_map.end()) {
      unique_nodes.emplace_back(node);
      node_map[node] = reindex_id++;
    }
    reindex_src_data[i] = node_map[node];
  }

  // Reindex Dst: edges of center i all share the id of x_data[i]; the hash
  // lookup is loop-invariant, so hoist it out of the inner repetition
  // (the original looked it up once per edge).
  int cnt = 0;
  for (int i = 0; i < bs; i++) {
    const T dst_id = node_map[x_data[i]];
    std::fill_n(reindex_dst_data + cnt, count_data[i], dst_id);
    cnt += count_data[i];
  }

  out_nodes->Resize({static_cast<int>(unique_nodes.size())});
  T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
  std::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data);
}

} // namespace phi

// Register the CPU kernel for int32 and int64 node-id dtypes.
PD_REGISTER_KERNEL(
    graph_reindex, CPU, ALL_LAYOUT, phi::GraphReindexKernel, int, int64_t) {}
Loading

0 comments on commit b0398c8

Please sign in to comment.