Skip to content

Commit

Permalink
[Framework] Update set_model_from_buffer api (PaddlePaddle#10026)
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed Feb 28, 2023
1 parent 5004867 commit 8934039
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 40 deletions.
26 changes: 18 additions & 8 deletions lite/api/light_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,25 @@
namespace paddle {
namespace lite {

void LightPredictor::Build(const std::string& lite_model_file,
bool model_from_memory) {
if (model_from_memory) {
LoadModelNaiveFromMemory(
lite_model_file, scope_.get(), program_desc_.get());
} else {
LoadModelNaiveFromFile(lite_model_file, scope_.get(), program_desc_.get());
}
// Builds the predictor from an optimized lite model file on disk.
// Pipeline: load the naive-buffer model into scope_/program_desc_,
// dequantize any post-training-quantized weights back to fp32,
// (optionally) convert weights to fp16, then construct the runtime
// program and cache the feed/fetch op info.
void LightPredictor::Build(const std::string& lite_model_file) {
  LoadModelNaiveFromFile(lite_model_file, scope_.get(), program_desc_.get());
  // For weight quantization of post training, load the int8/16 weights
  // for optimized model, and dequant it to fp32.
  DequantizeWeight();
#ifdef ENABLE_ARM_FP16
  // fp16 weight conversion — only compiled in when ARM fp16 is enabled.
  WeightFP32ToFP16();
#endif
  BuildRuntimeProgram(program_desc_, use_low_precision_);
  PrepareFeedFetch();
}

void LightPredictor::Build(const char* lite_model_buffer_ptr,
size_t lite_model_buffer_size) {
LoadModelNaiveFromMemory(lite_model_buffer_ptr,
lite_model_buffer_size,
scope_.get(),
program_desc_.get());
// For weight quantization of post training, load the int8/16 weights
// for optimized model, and dequant it to fp32.
DequantizeWeight();
Expand Down
16 changes: 12 additions & 4 deletions lite/api/light_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,20 @@ class LITE_API LightPredictor {
// model file or buffer. `model_from_memory` refers to whether to load the
// model from memory.
LightPredictor(const std::string& lite_model_file,
bool model_from_memory = false,
bool use_low_precision = false) {
use_low_precision_ = use_low_precision;
scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
Build(lite_model_file, model_from_memory);
Build(lite_model_file);
}

// Constructs a predictor directly from an in-memory optimized model buffer
// of `lite_model_buffer_size` bytes, avoiding the std::string copy of the
// older buffer path. `use_low_precision` selects the low-precision runtime.
// NOTE(review): the buffer is only read during Build() here; confirm against
// LoadModelNaiveFromMemory whether it must outlive the constructor.
LightPredictor(const char* lite_model_buffer_ptr,
               size_t lite_model_buffer_size,
               bool use_low_precision = false) {
  use_low_precision_ = use_low_precision;
  scope_ = std::make_shared<Scope>();
  program_desc_ = std::make_shared<cpp::ProgramDesc>();
  Build(lite_model_buffer_ptr, lite_model_buffer_size);
}

// NOTE: This is a deprecated API and will be removed in latter release.
Expand Down Expand Up @@ -118,8 +126,8 @@ class LITE_API LightPredictor {
// would be called in Run().
void CheckInputValid();

void Build(const std::string& lite_model_file,
bool model_from_memory = false);
void Build(const std::string& lite_model_file);
void Build(const char* lite_model_buffer_ptr, size_t lite_model_buffer_size);

// NOTE: This is a deprecated API and will be removed in latter release.
void Build(
Expand Down
36 changes: 22 additions & 14 deletions lite/api/light_api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,29 @@ namespace lite {

void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
// LightPredictor Only support NaiveBuffer backend in publish lib
if (config.lite_model_file().empty()) {
raw_predictor_.reset(new LightPredictor(
config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.is_model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer,
(config.precision_mode() == lite_api::LITE_PRECISION_LOW) ? true
: false));
auto use_low_precision =
config.precision_mode() == lite_api::LITE_PRECISION_LOW ? true : false;
if (config.lite_model_file().empty() && !config.lite_model_buffer_ptr()) {
raw_predictor_.reset(
new LightPredictor(config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.is_model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer,
use_low_precision));
} else if (!config.lite_model_file().empty() &&
!config.is_model_from_memory()) {
raw_predictor_.reset(
new LightPredictor(config.lite_model_file(), use_low_precision));
} else if (!config.lite_model_file().empty() &&
config.is_model_from_memory()) {
raw_predictor_.reset(new LightPredictor(config.lite_model_file().c_str(),
config.lite_model_file().length(),
use_low_precision));
} else {
raw_predictor_.reset(new LightPredictor(
config.lite_model_file(),
config.is_model_from_memory(),
(config.precision_mode() == lite_api::LITE_PRECISION_LOW) ? true
: false));
raw_predictor_.reset(new LightPredictor(config.lite_model_buffer_ptr(),
config.lite_model_buffer_size(),
use_low_precision));
}

mode_ = config.power_mode();
Expand Down
3 changes: 2 additions & 1 deletion lite/api/paddle_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ void MobileConfig::set_model_from_buffer(std::string &&x) {
}

// Points the config at an externally owned model buffer of `length` bytes.
// No copy is made: only the raw pointer and size are stored.
// NOTE(review): unlike the std::string&& overload of this API, the caller
// must keep `buffer` alive at least until the predictor has been built —
// confirm this lifetime contract is stated in the public API docs.
void MobileConfig::set_model_from_buffer(const char *buffer, size_t length) {
  lite_model_buffer_ptr_ = buffer;
  lite_model_buffer_size_ = length;
  model_from_memory_ = true;
}

Expand Down
14 changes: 8 additions & 6 deletions lite/api/paddle_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,11 @@ class LITE_API MobileConfig : public ConfigBase {
bool model_from_memory_{false};
PrecisionMode precision_mode_{LITE_PRECISION_NORMAL};

// model data readed from file or memory buffer in combined format.
// model data read from file in combined format.
std::string lite_model_file_;
// model data read from memory buffer in combined format.
const char* lite_model_buffer_ptr_ = nullptr;
size_t lite_model_buffer_size_{0};

// NOTE: This is a deprecated variable and will be removed in latter release.
std::string model_buffer_;
Expand All @@ -551,16 +554,15 @@ class LITE_API MobileConfig : public ConfigBase {
void set_model_from_buffer(const char* buffer, size_t length);
void set_precision_mode(PrecisionMode mode) { precision_mode_ = mode; }
PrecisionMode precision_mode() const { return precision_mode_; }
// return model data in lite_model_file_, which is in combined format.
// return model file path.
const std::string& lite_model_file() const { return lite_model_file_; }
// return model buffer data, which is in combined format.
const char* lite_model_buffer_ptr() const { return lite_model_buffer_ptr_; }
size_t lite_model_buffer_size() const { return lite_model_buffer_size_; }

// return model_from_memory_, which indicates whether to load model from
// memory buffer.
bool is_model_from_memory() const { return model_from_memory_; }
// note: `model_from_memory` has the same effect as `is_model_from_memory`,
// but is_model_from_memory is recommended and `model_from_memory` will be
// abandoned in v3.0.
bool model_from_memory() const { return model_from_memory_; }

// NOTE: This is a deprecated API and will be removed in latter release.
void set_model_buffer(const char* model_buffer,
Expand Down
6 changes: 6 additions & 0 deletions lite/core/model/base/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ void StringBufferReader::Read(void* dst, size_t size) const {
cur_ += size;
}

void CharBufferReader::Read(void* dst, size_t size) const {
CHECK(dst);
lite::TargetCopy(TargetType::kHost, dst, buf_ + cur_, size);
cur_ += size;
}

} // namespace model_parser
} // namespace lite
} // namespace paddle
18 changes: 18 additions & 0 deletions lite/core/model/base/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,24 @@ class StringBufferReader : public ByteReader {
mutable size_t cur_{0};
};

// ByteReader over a caller-owned contiguous char buffer. Non-owning: the
// reader never frees `buf_`, so the buffer must outlive the reader. The
// read cursor is `mutable` so Read() can advance it through the const
// ByteReader interface.
class CharBufferReader : public ByteReader {
 public:
  // `buffer` must be non-null; `length` is its size in bytes.
  explicit CharBufferReader(const char* buffer, size_t length)
      : buf_(buffer), length_(length) {
    CHECK(buf_);
  }
  ~CharBufferReader() = default;
  // Copies `size` bytes from the cursor into `dst` and advances the
  // cursor (implementation in io.cc).
  void Read(void* dst, size_t size) const override;
  // True once the cursor has consumed the whole buffer.
  bool ReachEnd() const override { return cur_ >= length_; }
  size_t length() const override { return length_; }
  size_t current() const override { return cur_; }

 private:
  const char* buf_;        // caller-owned data; never freed here
  size_t length_;          // total buffer size in bytes
  mutable size_t cur_{0};  // next read offset
};

} // namespace model_parser
} // namespace lite
} // namespace paddle
12 changes: 8 additions & 4 deletions lite/model_parser/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "lite/model_parser/model_parser.h"

#include <algorithm>
#include <fstream>
#include <limits>
Expand All @@ -29,6 +30,7 @@
#include "lite/model_parser/pb/tensor_io.h"
#ifndef LITE_ON_TINY_PUBLISH
#include <cstdio>

#include "lite/model_parser/naive_buffer/combined_params_desc.h"
#include "lite/model_parser/naive_buffer/param_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
Expand Down Expand Up @@ -737,7 +739,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
* topo_size: length of `topo_data`.
* topo_data: contains model's topology data.
* param_data: contains model's params data.
*/
*/

void LoadModelNaiveFromFile(const std::string &filename,
Scope *scope,
Expand Down Expand Up @@ -898,7 +900,8 @@ void LoadModelFbsFromFile(model_parser::BinaryFileReader *reader,
}
}

void LoadModelNaiveFromMemory(const std::string &model_buffer,
void LoadModelNaiveFromMemory(const char *model_buffer,
size_t model_buffer_size,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
Expand All @@ -907,7 +910,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,

// (1)get meta version
uint16_t meta_version;
model_parser::StringBufferReader reader(model_buffer);
model_parser::CharBufferReader reader(model_buffer, model_buffer_size);
reader.Read(&meta_version, sizeof(uint16_t));
VLOG(4) << "Meta_version:" << meta_version;

Expand All @@ -933,6 +936,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
break;
}
}

#ifndef LITE_ON_TINY_PUBLISH
void LoadModelNaiveV0FromMemory(const std::string &model_buffer,
Scope *scope,
Expand Down Expand Up @@ -973,7 +977,7 @@ void LoadModelNaiveV0FromMemory(const std::string &model_buffer,
///////////////////////////////////////////////////////////////////
// Meta_version=1,2
///////////////////////////////////////////////////////////////////
void LoadModelFbsFromMemory(model_parser::StringBufferReader *reader,
void LoadModelFbsFromMemory(model_parser::CharBufferReader *reader,
Scope *scope,
cpp::ProgramDesc *cpp_prog,
uint16_t meta_version) {
Expand Down
5 changes: 3 additions & 2 deletions lite/model_parser/model_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ void LoadModelNaiveFromFile(const std::string& filename,
lite::Scope* scope,
cpp::ProgramDesc* prog);

void LoadModelNaiveFromMemory(const std::string& model_buffer,
void LoadModelNaiveFromMemory(const char* model_buffer,
size_t model_buffer_size,
lite::Scope* scope,
cpp::ProgramDesc* cpp_prog);
void LoadModelFbsFromMemory(model_parser::StringBufferReader* reader,
void LoadModelFbsFromMemory(model_parser::CharBufferReader* reader,
Scope* scope,
cpp::ProgramDesc* cpp_prog,
uint16_t meta_version);
Expand Down
3 changes: 2 additions & 1 deletion lite/model_parser/model_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ TEST(ModelParser, LoadModelNaiveFromMemory) {

auto model_path = std::string(FLAGS_model_dir) + ".saved.nb";
std::string model_buffer = lite::ReadFile(model_path);
LoadModelNaiveFromMemory(model_buffer, &scope, &prog);
LoadModelNaiveFromMemory(
model_buffer.c_str(), model_buffer.length(), &scope, &prog);
}

} // namespace lite
Expand Down

0 comments on commit 8934039

Please sign in to comment.