Skip to content

Commit

Permalink
[XPU] support fc per channel quant (#9323)
Browse files Browse the repository at this point in the history
  • Loading branch information
newway authored Aug 23, 2022
1 parent 17b6abd commit 261822c
Show file tree
Hide file tree
Showing 17 changed files with 412 additions and 184 deletions.
16 changes: 10 additions & 6 deletions lite/backends/xpu/target_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ void TargetWrapperXPU::MemcpySync(void* dst,

template <typename Tcpu, typename Txpu>
XPUQuantData TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight(
const Tcpu* cpu_data, const DDimLite& dims, bool data_transpose) {
const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t max_ptr_len) {
CHECK(quantizer_.get());
return quantizer_->quant<Tcpu, Txpu>(cpu_data, dims, data_transpose);
return quantizer_->quant<Tcpu, Txpu>(
cpu_data, dims, data_transpose, max_ptr_len);
}

void TargetWrapperXPU::ScatterL3Cache(
Expand Down Expand Up @@ -145,16 +149,16 @@ void TargetWrapperXPU::FreeL3Cache() {

template XPUQuantData
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<float, float>(
const float*, const DDimLite&, bool);
const float*, const DDimLite&, bool, size_t);
template XPUQuantData
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<float, int16_t>(
const float*, const DDimLite&, bool);
const float*, const DDimLite&, bool, size_t);
template XPUQuantData
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<float, int8_t>(
const float*, const DDimLite&, bool);
const float*, const DDimLite&, bool, size_t);
template XPUQuantData
TargetWrapperXPU::ConvertCPUWeightToXPUQuantWeight<int8_t, int8_t>(
const int8_t*, const DDimLite&, bool);
const int8_t*, const DDimLite&, bool, size_t);

// xpu context
LITE_THREAD_LOCAL std::shared_ptr<xdnn::Context> TargetWrapperXPU::tls_raw_ctx_{
Expand Down
3 changes: 2 additions & 1 deletion lite/backends/xpu/target_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class TargetWrapper<TARGET(kXPU)> {
template <typename Tcpu, typename Txpu>
static XPUQuantData ConvertCPUWeightToXPUQuantWeight(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose);
bool data_transpose,
size_t max_ptr_len);

static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_.get() == nullptr) {
Expand Down
34 changes: 22 additions & 12 deletions lite/backends/xpu/xpu_quantizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ template <
void XPUQuantizer::ConvertWithQuant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key) {
size_t hashed_key,
size_t max_ptr_len) {
LOG(FATAL) << "Not support for Tcpu is " << CppTypeToString<Tcpu>();
}

Expand All @@ -123,7 +124,8 @@ template <
void XPUQuantizer::ConvertWithQuant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key) {
size_t hashed_key,
size_t max_ptr_len) {
// transpose
const Tcpu* cpu_ptr = nullptr;
int numel = dims.production();
Expand All @@ -140,7 +142,7 @@ void XPUQuantizer::ConvertWithQuant(const Tcpu* cpu_data,
XPUScratchPadGuard weight_max_guard;
XPUScratchPadGuard quant_weight_guard;
float max_val = paddle::lite::xpu::math::FindMaxAbs(cpu_ptr, numel);
int max_ptr_size = XPUMemory::get_max_ptr_size();
size_t max_ptr_size = max_ptr_len;
std::vector<float> max_vec(max_ptr_size, max_val);
weight_max_guard =
std::move(XPUMemory::MallocScratchPad(max_ptr_size * sizeof(float)));
Expand All @@ -162,11 +164,12 @@ template <typename T>
void XPUQuantizer::ConvertWithoutQuant(const T* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key) {
size_t hashed_key,
size_t max_ptr_len) {
// transpose
const T* cpu_ptr = nullptr;
int numel = dims.production();
int max_ptr_size = XPUMemory::get_max_ptr_size();
size_t max_ptr_size = max_ptr_len;
std::vector<T> transpose_data(numel, 0);
if (data_transpose) {
CHECK(dims.size() == 2) << "Not support: dims.size = " << dims.size();
Expand All @@ -178,8 +181,9 @@ void XPUQuantizer::ConvertWithoutQuant(const T* cpu_data,
}
// copy to XPU
XPUScratchPadGuard weight_max_guard(new XPUScratchPad(nullptr, 0));
if (std::is_same<T, int8_t>::value) {
if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
// prepare max_w space for slim int8 quant
// just allocate buffer, set max value in kernel
weight_max_guard =
std::move(XPUMemory::MallocScratchPad(max_ptr_size * sizeof(float)));
}
Expand All @@ -196,7 +200,8 @@ void XPUQuantizer::ConvertWithoutQuant(const T* cpu_data,
template <typename Tcpu, typename Txpu>
XPUQuantData XPUQuantizer::quant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose) {
bool data_transpose,
size_t max_ptr_len) {
int numel = dims.production();
const std::string cpu_dtype = CppTypeToString<Tcpu>();
const std::string xpu_dtype = CppTypeToString<Txpu>();
Expand All @@ -206,7 +211,8 @@ XPUQuantData XPUQuantizer::quant(const Tcpu* cpu_data,
<< ", precision=" << precision << ", transpose=" << data_transpose
<< ", hashed_key=" << hashed_key;
if (weight_cache_.find(hashed_key) == weight_cache_.end()) {
ConvertWrapper<Tcpu, Txpu>(cpu_data, dims, data_transpose, hashed_key);
ConvertWrapper<Tcpu, Txpu>(
cpu_data, dims, data_transpose, hashed_key, max_ptr_len);
}

float* max_ptr =
Expand All @@ -218,15 +224,19 @@ XPUQuantData XPUQuantizer::quant(const Tcpu* cpu_data,

template XPUQuantData XPUQuantizer::quant<float, float>(const float*,
const DDimLite&,
bool);
bool,
size_t);
template XPUQuantData XPUQuantizer::quant<float, int16_t>(const float*,
const DDimLite&,
bool);
bool,
size_t);
template XPUQuantData XPUQuantizer::quant<float, int8_t>(const float*,
const DDimLite&,
bool);
bool,
size_t);
template XPUQuantData XPUQuantizer::quant<int8_t, int8_t>(const int8_t*,
const DDimLite&,
bool);
bool,
size_t);
} // namespace lite
} // namespace paddle
24 changes: 16 additions & 8 deletions lite/backends/xpu/xpu_quantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class XPUQuantizer {
template <typename Tcpu, typename Txpu>
XPUQuantData quant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose);
bool data_transpose,
size_t max_ptr_len);

private:
template <typename T>
Expand All @@ -49,7 +50,8 @@ class XPUQuantizer {
void ConvertWithoutQuant(const T* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key);
size_t hashed_key,
size_t max_ptr_len);

template <typename Tcpu,
typename Txpu,
Expand All @@ -58,7 +60,8 @@ class XPUQuantizer {
void ConvertWithQuant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key);
size_t hashed_key,
size_t max_ptr_len);

template <typename Tcpu,
typename Txpu,
Expand All @@ -67,7 +70,8 @@ class XPUQuantizer {
void ConvertWithQuant(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key);
size_t hashed_key,
size_t max_ptr_len);

template <typename Tcpu,
typename Txpu,
Expand All @@ -76,8 +80,10 @@ class XPUQuantizer {
void ConvertWrapper(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key) {
ConvertWithQuant<Tcpu, Txpu>(cpu_data, dims, data_transpose, hashed_key);
size_t hashed_key,
size_t max_ptr_len) {
ConvertWithQuant<Tcpu, Txpu>(
cpu_data, dims, data_transpose, hashed_key, max_ptr_len);
}

template <typename Tcpu,
Expand All @@ -87,8 +93,10 @@ class XPUQuantizer {
void ConvertWrapper(const Tcpu* cpu_data,
const DDimLite& dims,
bool data_transpose,
size_t hashed_key) {
ConvertWithoutQuant<Tcpu>(cpu_data, dims, data_transpose, hashed_key);
size_t hashed_key,
size_t max_ptr_len) {
ConvertWithoutQuant<Tcpu>(
cpu_data, dims, data_transpose, hashed_key, max_ptr_len);
}

// cpu data to xpu quant data
Expand Down
53 changes: 53 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,47 @@ class XPUFcFuser : public FuseBase {
output_name = matched.at("mul_out")->arg()->name;
output_node_name = "mul_out";
}
bool per_channel = false;
int weight_scale_size = 1;
auto* op_info = matched.at("mul")->stmt()->op_info();
auto mul_input_y_name = op_info->Input("Y").front();
auto mul_y_shape = scope->FindMutableTensor(mul_input_y_name)->dims();
CHECK_EQ(mul_y_shape.size(), 2) << "mul_y_shape.size: "
<< mul_y_shape.size();
const bool quant = op_info->HasAttr("enable_int8") &&
op_info->GetAttr<bool>("enable_int8");
op_desc.SetAttr<bool>("enable_int8", quant);
// X0_scale is already in op_desc when copy from mul
if (quant) {
CHECK(op_info->HasAttr("Y0_scale")) << "quant model no Y0_scale";
weight_scale_size =
op_info->GetAttr<std::vector<float>>("Y0_scale").size();
CHECK_EQ(weight_scale_size, mul_y_shape[1])
<< "weight_scale_size: " << weight_scale_size
<< ", mul_y_shape:" << mul_y_shape;
CHECK_GE(weight_scale_size, 1) << weight_scale_size;
std::vector<float> weight_max;
if (is_per_tensor(op_info->GetAttr<std::vector<float>>("Y0_scale"))) {
per_channel = false;
VLOG(3) << "xpu fc per tensor";
weight_max.push_back(
op_info->GetAttr<std::vector<float>>("Y0_scale")[0] * 127);
} else {
per_channel = true;
VLOG(3) << "xpu fc per channel, first channel max:"
<< op_info->GetAttr<std::vector<float>>("Y0_scale")[0] * 127
<< ", last channel max: "
<< op_info->GetAttr<std::vector<float>>(
"Y0_scale")[weight_scale_size - 1] *
127;
for (auto wm : op_info->GetAttr<std::vector<float>>("Y0_scale")) {
weight_max.push_back(wm * 127);
}
}
VLOG(3) << "weight_max size:" << weight_max.size();
op_desc.SetAttr<std::vector<float>>("Y0_max", weight_max);
op_desc.SetAttr<bool>("per_channel", per_channel);
}
op_desc.SetOutput("Output", {output_name});
std::map<std::string, int> act_map{{"linear", 0},
{"relu", 1},
Expand Down Expand Up @@ -171,6 +212,18 @@ class XPUFcFuser : public FuseBase {
bool with_bias_;
std::string act_type_;
std::string mul_type_;
bool is_per_tensor(const std::vector<float>& weight_max) {
bool per_tensor = true;
CHECK_GT(weight_max.size(), 0) << "fc channel size: " << weight_max.size();
auto first = weight_max[0];
for (int i = 1; i < weight_max.size(); ++i) {
if (std::abs(first - weight_max[i]) > 1e-6) {
per_tensor = false;
break;
}
}
return per_tensor;
}
};

} // namespace fusion
Expand Down
Loading

0 comments on commit 261822c

Please sign in to comment.