Add graph apis #40809
// Note(daisiming): If using buffer hashtable, we must ensure the number of
// nodes of
// the input graph should be no larger than maximum(int32).
AddInput("HashTable_Value",
I suggest not putting an underscore in the middle of the name.
"X": x,
"Neighbors": neighbors,
"Count": count,
"HashTable_Value": None,
If the API does not allow these two values to be passed in, when would these two Dispensable inputs ever be used?
c53da45
to
ab4f462
Compare
… add_graph_apis
should be the same with `x`.
count (Tensor): The neighbor count of the input nodes `x`. And the
data type should be int32.
value_buffer (Tensor|None): Value buffer for hashtable. The data type should
Please make the descriptions of value buffer and index buffer consistent with the description of `name`: describe them either as None or as optional.
"""
if flag_buffer_hashtable:
    if value_buffer is None or index_buffer is None:
        raise ValueError(f"`value_buffer` and `index_buffer` should not"
Has the extra GPU memory used by this buffer design actually been measured?
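I did look at this in experiments. The extra GPU memory depends on the number of graph nodes, and the buffers are filled with values in the int32 range, so in practice they do not consume much memory. My main concern is a graph whose node count exceeds the int32 maximum, where the buffer method may no longer be suitable, so users can also reindex without the buffers.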
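The buffer trade-off described in the reply above can be sketched in plain Python. This is a minimal illustration, not the PR's CUDA implementation: `make_buffer`, `reindex_with_buffer`, and `buffer_bytes` are hypothetical helpers, a single dense buffer stands in for the `value_buffer`/`index_buffer` pair, and the memory estimate assumes two int32 buffers with one entry per graph node.

```python
from array import array

INT32_MAX = 2**31 - 1


def make_buffer(num_nodes):
    """Allocate a dense buffer with one slot per graph node (hypothetical helper)."""
    if num_nodes > INT32_MAX:
        # Mirrors the concern in the thread: the buffer method assumes node
        # ids fit in int32.
        raise ValueError("buffer method requires num_nodes <= int32 max")
    # -1 marks a node that has not been assigned a new id yet.
    return array("i", [-1]) * num_nodes


def reindex_with_buffer(x, neighbors, value_buffer):
    """Map node ids to consecutive ids, input nodes first, then new neighbors."""
    out_nodes = []
    for node in list(x) + list(neighbors):
        if value_buffer[node] == -1:
            value_buffer[node] = len(out_nodes)
            out_nodes.append(node)
    reindex_src = [value_buffer[n] for n in neighbors]
    # Reset only the touched slots so the buffer can be reused by a later call.
    for node in out_nodes:
        value_buffer[node] = -1
    return reindex_src, out_nodes


def buffer_bytes(num_nodes, num_buffers=2, itemsize=4):
    """Rough memory estimate, assuming int32 buffers sized by node count."""
    return num_nodes * num_buffers * itemsize
```

Under these assumptions a graph with one million nodes needs about 8 MB for the two buffers, which matches the claim that the overhead scales with the node count rather than the edge count.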
thrust::transform(
    output_count, output_count + bs, output_count, MaxFunctor(sample_size));
}
int total_sample_num = thrust::reduce(output_count, output_count + bs);
I have a question here: if sample size is < 0, the sample size variable does not look well controlled and the result would be somewhat arbitrary.
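If sample_size < 0, all neighbors are sampled by default, so there is indeed some arbitrariness. The main reason for this design is that it supports some PGL APIs that return the neighbors directly.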
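The `sample_size < 0` behaviour described above can be sketched on a CSC graph in plain Python. This is an illustrative sketch only: `sample_neighbors` is a hypothetical function, the `row`/`colptr` layout is assumed from the kernel arguments quoted in this review, and `random.sample` stands in for whatever sampling the real kernels perform.

```python
import random


def sample_neighbors(row, colptr, x, sample_size=-1, seed=None):
    """Sample up to sample_size neighbors for each node in x (CSC layout)."""
    rng = random.Random(seed)
    out, out_count = [], []
    for node in x:
        # Neighbors of `node` are stored in row[colptr[node]:colptr[node + 1]].
        neigh = row[colptr[node]:colptr[node + 1]]
        if sample_size < 0 or len(neigh) <= sample_size:
            # A negative sample_size means "return every neighbor".
            picked = list(neigh)
        else:
            picked = rng.sample(neigh, sample_size)
        out.extend(picked)
        out_count.append(len(picked))
    return out, out_count


# Node 0 has neighbors [1, 2], node 1 has [0, 2, 3], node 2 has [0], node 3 has none.
row = [1, 2, 0, 2, 3, 0]
colptr = [0, 2, 5, 6, 6]
all_neighbors, counts = sample_neighbors(row, colptr, [1, 3, 0], sample_size=-1)
```

Here the per-node counts are capped at `sample_size` whenever it is non-negative, and the total output length is the sum of the counts, which corresponds to the reduction over `output_count` in the quoted snippet.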
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
SampleKernel<T,
It looks like the ordinary sampler has no hash table version and is not a buffer version.
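Yes, there are two separate sampling versions. fisher_yates sampling depends on a buffer with as many entries as there are edges, so it uses more GPU memory; that is why the original sampling method is kept as well.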
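To show why Fisher-Yates sampling needs an edge-sized buffer, here is a minimal Python sketch of a partial Fisher-Yates shuffle over one node's neighbor segment. It is not the PR's `FisherYatesSampleKernel`: `fisher_yates_sample` and the `perm` buffer are hypothetical names, and the buffer is assumed to be a writable copy of the edge array.

```python
import random


def fisher_yates_sample(perm, begin, end, k, rng):
    """Partially shuffle perm[begin:end] in place and return the first k items."""
    n = end - begin
    k = min(k, n)
    for i in range(k):
        # Swap position i with a uniformly chosen position in [i, n).
        j = rng.randrange(i, n)
        perm[begin + i], perm[begin + j] = perm[begin + j], perm[begin + i]
    return perm[begin:begin + k]


row = [1, 2, 0, 2, 3, 0]   # edge array of a small CSC graph
perm = list(row)           # scratch buffer with one entry per edge
rng = random.Random(0)
# Sample 2 of the 3 neighbors stored in positions 2..4.
picked = fisher_yates_sample(perm, 2, 5, 2, rng)
```

The shuffle only permutes entries inside the segment, so the sample is drawn without replacement and the rest of the buffer is untouched. The cost is that the scratch buffer has to be as long as the edge array, which is the memory overhead mentioned in the reply.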
}

template <typename T>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
In UVA mode this memory access pattern does not look very efficient. Please change it later so that one warp performs the access.
OK, I will look into changing it in the next PR.
ReindexSrcOutput<T><<<grid, block, 0, dev_ctx.stream()>>>(
    thrust::raw_pointer_cast(src_outputs), num_edges, hashtable_value);

ResetBufferHashTable<T, Context>(dev_ctx,
I am curious why a kernel is written to do the reset. Why not reset the block of GPU memory directly?
row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs);
out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data);
It looks like the copy is not needed here either. You could use ResetHolder or ShareDataWith directly; note that std::shared_ptr<phi::Allocation> holder_ is a shared_ptr.
I will fix these together in the next PR.
out_count->Resize({bs});
int* out_count_data = dev_ctx.template Alloc<int>(out_count);
std::copy(output_count.begin(), output_count.end(), out_count_data);
}
Same as above.
std::copy(dst.begin(), dst.end(), reindex_dst_data);
out_nodes->Resize({static_cast<int>(unique_nodes.size())});
T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
std::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data);
Please check whether these copies are needed.
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
Has this been tested in static graph mode?
The unit tests include static graph tests, so it should be OK.
LGTM
PR types
New features
PR changes
APIs
Describe
Add graph_sample_neighbors API and graph_reindex API.