Skip to content

Commit

Permalink
fix sequence_pool_cvm show click bug
Browse files Browse the repository at this point in the history
  • Loading branch information
xymyeah committed Jul 4, 2024
1 parent 9632d59 commit 7c011dd
Showing 1 changed file with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_kernel.kps
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,36 @@ static __device__ inline void memset_value_float(float* lm, int size, float valu
mfence_lm();
}

// showclick_quant: optional write-back of the two leading accumulator slots
// (presumably the show and click counters, going by the `sum_show_clk` name)
// from the T2 accumulator into the T output buffer.
//
// The bool template argument chooses the behaviour at compile time:
//   false -> cvm_offset == 0: nothing is written, `sum` is left untouched.
//   true  -> cvm_offset > 0 : sum[0] and sum[1] are overwritten.

// Primary template, used when the bool argument is false (cvm_offset == 0).
template <typename T, bool copy_enabled, typename T2>
struct showclick_quant {
  static __device__ inline void copy(T* sum, T2* sum_show_clk) {
    // Intentionally empty: the output keeps whatever it already holds.
  }
};

// Specialization for the bool argument being true (cvm_offset > 0).
template <typename T, typename T2>
struct showclick_quant<T, true, T2> {
  // Copies the first two entries of `sum_show_clk` into `sum`, converting each
  // value through float on the way.
  static __device__ inline void copy(T* sum, T2* sum_show_clk) {
    for (int slot = 0; slot < 2; ++slot) {
      sum[slot] = static_cast<float>(sum_show_clk[slot]);
    }
  }
};

// KERNEL_SHOW_CLICK_COPY(kernel, copy_enabled, T, T2, args...)
//
// Turns the runtime condition `copy_enabled` into a compile-time template
// argument: calls kernel<T, true, T2>::copy(args) when the condition holds and
// kernel<T, false, T2>::copy(args) otherwise. Call sites in this file pass
// `cvm_offset > 0` as the condition.
//
// The expansion is wrapped in do { ... } while (0) so the macro behaves as a
// single statement: it needs a trailing ';' and is safe inside an unbraced
// if/else. Standard __VA_ARGS__ replaces the GNU-only named variadic form.
#define KERNEL_SHOW_CLICK_COPY(kernel, copy_enabled, T, T2, ...) \
  do {                                                            \
    if (copy_enabled) {                                           \
      kernel<T, true, T2>::copy(__VA_ARGS__);                     \
    } else {                                                      \
      kernel<T, false, T2>::copy(__VA_ARGS__);                    \
    }                                                             \
  } while (0)

// normal
// need_filter:false && quant_ratio_valid:false
template <typename T, bool need_filter, bool embed_threshold_filter,
bool embedx_concate_filter, bool quant_ratio_valid, typename T2>
struct pooling_engine {
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -66,7 +90,7 @@ struct pooling_engine {
template <typename T, bool need_filter, bool embed_threshold_filter,
bool embedx_concate_filter, bool quant_ratio_valid, typename T2>
struct pooling_engine_with_large_dim {
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -104,7 +128,7 @@ struct pooling_engine_with_large_dim {
// need_filter:true && embed_threshold_filter:true && embedx_concate_filter:false
template <typename T, bool quant_ratio_valid, typename T2>
struct pooling_engine<T, true, true, false, quant_ratio_valid, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -177,17 +201,17 @@ struct pooling_engine<T, true, true, false, quant_ratio_valid, T2>{
sum_show_clk[1] += local_x[1];
}

// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

// quantization with embedding-threshold filtering
// need_filter:true && embed_threshold_filter:true && embedx_concate_filter:false
template <typename T, bool quant_ratio_valid, typename T2>
struct pooling_engine_with_large_dim<T, true, true, false, quant_ratio_valid, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -262,17 +286,17 @@ struct pooling_engine_with_large_dim<T, true, true, false, quant_ratio_valid, T2
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];
}
// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

// quantization with filtering (no embedding-threshold filter)
// need_filter:true && embed_threshold_filter:false
template <typename T, bool embedx_concate_filter, bool quant_ratio_valid, typename T2>
struct pooling_engine<T, true, false, embedx_concate_filter, quant_ratio_valid, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -327,17 +351,17 @@ struct pooling_engine<T, true, false, embedx_concate_filter, quant_ratio_valid,
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];
}
// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

// quantization with filtering (no embedding-threshold filter)
// need_filter:true && embed_threshold_filter:false
template <typename T, bool embedx_concate_filter, bool quant_ratio_valid, typename T2>
struct pooling_engine_with_large_dim<T, true, false, embedx_concate_filter, quant_ratio_valid, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -393,17 +417,17 @@ struct pooling_engine_with_large_dim<T, true, false, embedx_concate_filter, quan
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];
}
// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

// quantization without filtering
// need_filter:false && quant_ratio_valid:true
template <typename T, bool embed_threshold_filter, bool embedx_concate_filter, typename T2>
struct pooling_engine<T, false, embed_threshold_filter, embedx_concate_filter, true, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -452,17 +476,17 @@ struct pooling_engine<T, false, embed_threshold_filter, embedx_concate_filter, t
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];
}
// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

// quantization without filtering
// need_filter:false && quant_ratio_valid:true
template <typename T, bool embed_threshold_filter, bool embedx_concate_filter, typename T2>
struct pooling_engine_with_large_dim<T, false, embed_threshold_filter, embedx_concate_filter, true, T2>{
static __device__ void sum_pooling(T* local_x,
static __device__ inline void sum_pooling(T* local_x,
T* sum,
T2* sum_show_clk,
int len,
Expand Down Expand Up @@ -512,9 +536,9 @@ struct pooling_engine_with_large_dim<T, false, embed_threshold_filter, embedx_co
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];
}
// only for cvm_offset > 0
mfence_lm();
sum[0] = (float)sum_show_clk[0];
sum[1] = (float)sum_show_clk[1];
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, sum, sum_show_clk);
}
};

Expand Down Expand Up @@ -794,18 +818,17 @@ struct do_sum_pooling_and_cvm<T, use_cvm, clk_filter, need_filter, quant_ratio_v
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];

// only for cvm_offset > 0
mfence_lm();
local_result[0] = (float)sum_show_clk[0];
local_result[1] = (float)sum_show_clk[1];

mfence_lm();
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, local_result, sum_show_clk);

// second: cvm
int cur_y_index = seqid * embedx_concate_size * out_dim_size + (embedx_concate_size - 1) * out_dim_size;
cvm_engine<T, true, use_cvm, clk_filter, T2>::concat_cvm(local_result,
out_dim_size, dim_start_offset,
cur_y_index,
cur_y);
mfence();
} else {
// first: sum pool
// copy
Expand All @@ -831,18 +854,17 @@ struct do_sum_pooling_and_cvm<T, use_cvm, clk_filter, need_filter, quant_ratio_v
vstore_lm_float32x16(local_result, v_temp1);
vstore_lm_float32x16(local_result + 16, v_temp2);

// only for cvm_offset > 0
mfence_lm();
local_result[0] = local_x[0];
local_result[1] = local_x[1];

mfence_lm();
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T, local_result, local_x);

// second: cvm
int cur_y_index = seqid * embedx_concate_size * out_dim_size + concate_index * out_dim_size;
cvm_engine<T, true, use_cvm, clk_filter, T2>::concat_cvm(local_result,
out_dim_size, dim_start_offset,
cur_y_index,
cur_y);
mfence();
concate_index += 1;
}
}
Expand All @@ -851,8 +873,8 @@ struct do_sum_pooling_and_cvm<T, use_cvm, clk_filter, need_filter, quant_ratio_v
for (int i = concate_index; i < embedx_concate_size; i++) {
memset_value_float(local_result, local_result_len, padding_value);
int cur_y_index = seqid * embedx_concate_size * out_dim_size + i * out_dim_size;
mfence();
LM2GM_ASYNC(local_result, cur_y + cur_y_index, out_dim_size * sizeof(T));
mfence();
}
}
};
Expand Down Expand Up @@ -943,18 +965,17 @@ struct do_sum_pooling_and_cvm_with_large_dim<T, use_cvm, clk_filter, need_filter
sum_show_clk[0] += local_x[0];
sum_show_clk[1] += local_x[1];

// only for cvm_offset > 0
mfence_lm();
local_result[0] = (float)sum_show_clk[0];
local_result[1] = (float)sum_show_clk[1];

mfence_lm();
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T2, local_result, sum_show_clk);

// second: cvm
int cur_y_index = seqid * embedx_concate_size * out_dim_size + (embedx_concate_size - 1) * out_dim_size;
cvm_engine<T, true, use_cvm, clk_filter, T2>::concat_cvm(local_result,
out_dim_size, dim_start_offset,
cur_y_index,
cur_y);
mfence();
} else {
// first: sum pool

Expand Down Expand Up @@ -984,18 +1005,17 @@ struct do_sum_pooling_and_cvm_with_large_dim<T, use_cvm, clk_filter, need_filter
vstore_lm_float32x16(local_result + k + 16, v_temp2);
}

// only for cvm_offset > 0
mfence_lm();
local_result[0] = local_x[0];
local_result[1] = local_x[1];

mfence_lm();
KERNEL_SHOW_CLICK_COPY(showclick_quant, cvm_offset > 0, T, T, local_result, local_x);

// second: cvm
int cur_y_index = seqid * embedx_concate_size * out_dim_size + concate_index * out_dim_size;
cvm_engine<T, true, use_cvm, clk_filter, T2>::concat_cvm(local_result,
out_dim_size, dim_start_offset,
cur_y_index,
cur_y);
mfence();
concate_index += 1;
}
}
Expand Down

0 comments on commit 7c011dd

Please sign in to comment.