diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc index 2925a86a127f..98597bd473dc 100644 --- a/src/operator/contrib/dgl_graph.cc +++ b/src/operator/contrib/dgl_graph.cc @@ -1393,5 +1393,196 @@ Example:: .set_attr("FComputeEx", DGLAdjacencyForwardEx) .add_argument("data", "NDArray-or-Symbol", "Input ndarray"); +///////////////////////// Compact subgraphs /////////////////////////// + +struct SubgraphCompactParam : public dmlc::Parameter { + int num_args; + bool return_mapping; + nnvm::Tuple graph_sizes; + DMLC_DECLARE_PARAMETER(SubgraphCompactParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments."); + DMLC_DECLARE_FIELD(return_mapping) + .describe("Return mapping of vid and eid between the subgraph and the parent graph."); + DMLC_DECLARE_FIELD(graph_sizes) + .describe("the number of vertices in each graph."); + } +}; // struct SubgraphCompactParam + +DMLC_REGISTER_PARAMETER(SubgraphCompactParam); + +static inline size_t get_num_graphs(const SubgraphCompactParam ¶ms) { + // Each CSR needs a 1D array to store the original vertex Id for each row. + return params.num_args / 2; +} + +static void CompactSubgraph(const NDArray &csr, const NDArray &vids, + const NDArray &out_csr, size_t graph_size) { + TBlob in_idx_data = csr.aux_data(csr::kIdx); + TBlob in_ptr_data = csr.aux_data(csr::kIndPtr); + const dgl_id_t *indices_in = in_idx_data.dptr(); + const dgl_id_t *indptr_in = in_ptr_data.dptr(); + const dgl_id_t *row_ids = vids.data().dptr(); + size_t num_elems = csr.aux_data(csr::kIdx).shape_.Size(); + // The last element in vids is the actual number of vertices in the subgraph. + CHECK_EQ(vids.shape()[0], in_ptr_data.shape_[0]); + CHECK_EQ((size_t) row_ids[vids.shape()[0] - 1], graph_size); + + // Prepare the Id map from the original graph to the subgraph. + std::unordered_map id_map; + id_map.reserve(graph_size); + for (size_t i = 0; i < graph_size; i++) { + id_map.insert(std::pair(row_ids[i], i)); + CHECK_NE(row_ids[i], -1); + } + + TShape nz_shape(1); + nz_shape[0] = num_elems; + TShape indptr_shape(1); + CHECK_EQ(out_csr.shape()[0], graph_size); + indptr_shape[0] = graph_size + 1; + CHECK_GE(in_ptr_data.shape_[0], indptr_shape[0]); + + out_csr.CheckAndAllocData(nz_shape); + out_csr.CheckAndAllocAuxData(csr::kIdx, nz_shape); + out_csr.CheckAndAllocAuxData(csr::kIndPtr, indptr_shape); + + dgl_id_t *indices_out = out_csr.aux_data(csr::kIdx).dptr(); + dgl_id_t *indptr_out = out_csr.aux_data(csr::kIndPtr).dptr(); + dgl_id_t *sub_eids = out_csr.data().dptr(); + std::copy(indptr_in, indptr_in + indptr_shape[0], indptr_out); + for (int64_t i = 0; i < nz_shape[0]; i++) { + dgl_id_t old_id = indices_in[i]; + auto it = id_map.find(old_id); + CHECK(it != id_map.end()); + indices_out[i] = it->second; + sub_eids[i] = i; + } +} + +static void SubgraphCompactComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + int num_g = get_num_graphs(params); +#pragma omp parallel for + for (int i = 0; i < num_g; i++) { + CompactSubgraph(inputs[i], inputs[i + num_g], outputs[i], params.graph_sizes[i]); + } +} + +static bool SubgraphCompactStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + size_t num_g = get_num_graphs(params); + CHECK_EQ(num_g * 2, in_attrs->size()); + // These are the input subgraphs. + for (size_t i = 0; i < num_g; i++) + CHECK_EQ(in_attrs->at(i), kCSRStorage); + // These are the vertex Ids in the original graph. + for (size_t i = 0; i < num_g; i++) + CHECK_EQ(in_attrs->at(i + num_g), kDefaultStorage); + + bool success = true; + *dispatch_mode = DispatchMode::kFComputeEx; + for (size_t i = 0; i < out_attrs->size(); i++) { + if (!type_assign(&(*out_attrs)[i], mxnet::kCSRStorage)) + success = false; + } + return success; +} + +static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + size_t num_g = get_num_graphs(params); + CHECK_EQ(num_g * 2, in_attrs->size()); + // These are the input subgraphs. + for (size_t i = 0; i < num_g; i++) { + CHECK_EQ(in_attrs->at(i).ndim(), 2U); + CHECK_GE(in_attrs->at(i)[0], params.graph_sizes[i]); + CHECK_GE(in_attrs->at(i)[1], params.graph_sizes[i]); + } + // These are the vertex Ids in the original graph. + for (size_t i = 0; i < num_g; i++) { + CHECK_EQ(in_attrs->at(i + num_g).ndim(), 1U); + CHECK_GE(in_attrs->at(i + num_g)[0], params.graph_sizes[i]); + } + + for (size_t i = 0; i < num_g; i++) { + TShape gshape(2); + gshape[0] = params.graph_sizes[i]; + gshape[1] = params.graph_sizes[i]; + out_attrs->at(i) = gshape; + if (params.return_mapping) + out_attrs->at(i + num_g) = gshape; + } + return true; +} + +static bool SubgraphCompactType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + for (size_t i = 0; i < in_attrs->size(); i++) { + CHECK_EQ(in_attrs->at(i), mshadow::kInt64); + } + for (size_t i = 0; i < out_attrs->size(); i++) { + out_attrs->at(i) = mshadow::kInt64; + } + return true; +} + +NNVM_REGISTER_OP(_contrib_dgl_graph_compact) +.describe(R"code(This operator compacts a CSR matrix generated by +csr_neighbor_uniform_sample and csr_neighbor_non_uniform_sample. +The CSR matrices generated by these two operators may have many empty +rows at the end. This operator removes these empty rows and empty columns. +Example:: + subgs = mx.nd.contrib.csr_neighbor_uniform_sample(csr, seed, num_hops=1, + num_neighbor=2, max_num_vertices=5) + subg_v = subgs[0] + subg = subgs[1] + compacts = mx.nd.contrib.dgl_graph_compact(subg, subg_v, + graph_sizes=(subg_v[-1].asnumpy()[0])) +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs([](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + int num_varray = get_num_graphs(params); + if (params.return_mapping) + return num_varray * 2; + else + return num_varray; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + size_t num_graphs = get_num_graphs(params); + for (size_t i = 0; i < num_graphs; i++) + names.push_back("graph" + std::to_string(i)); + for (size_t i = 0; i < num_graphs; ++i) + names.push_back("varray" + std::to_string(i)); + return names; +}) +.set_attr("FInferStorageType", SubgraphCompactStorageType) +.set_attr("FInferShape", SubgraphCompactShape) +.set_attr("FInferType", SubgraphCompactType) +.set_attr("FComputeEx", SubgraphCompactComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("graph_data", "NDArray-or-Symbol[]", "Input graphs and input vertex Ids.") +.add_arguments(SubgraphCompactParam::__FIELDS__()); + } // namespace op } // namespace mxnet