Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add error checking for cpp examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jan 11, 2019
1 parent 5282cdd commit 8b7ea17
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 18 deletions.
8 changes: 6 additions & 2 deletions cpp-package/example/alexnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,14 @@ int main(int argc, char const *argv[]) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/googlenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,14 @@ int main(int argc, char const *argv[]) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/inception_bn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ int main(int argc, char const *argv[]) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

// initialize parameters
Xavier xavier = Xavier(Xavier::gaussian, Xavier::in, 2);
Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/lenet_with_mxdataiter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ int main(int argc, char const *argv[]) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/mlp_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ int main(int argc, char** argv) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

auto net = mlp(layers);

Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/mlp_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ int main(int argc, char** argv) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

auto net = mlp(layers);

Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/resnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,14 @@ int main(int argc, char const *argv[]) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

// initialize parameters
Xavier xavier = Xavier(Xavier::gaussian, Xavier::in, 2);
Expand Down
8 changes: 6 additions & 2 deletions cpp-package/example/test_score.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ int main(int argc, char** argv) {
};

auto train_iter = MXDataIter("MNISTIter");
setDataIter(&train_iter, "Train", data_files, batch_size);
if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
return 1;
}

auto val_iter = MXDataIter("MNISTIter");
setDataIter(&val_iter, "Label", data_files, batch_size);
if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
return 1;
}

auto net = mlp(layers);

Expand Down
5 changes: 3 additions & 2 deletions cpp-package/example/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ bool check_datafiles(const std::vector<std::string> &data_files) {
}
}
return true;
}
}

bool setDataIter(MXDataIter *iter , const std::string &useType,
const std::vector<std::string> &data_files, int batch_size) {
if (!check_datafiles(data_files))
if (!check_datafiles(data_files)) {
return false;
}

iter->SetParam("batch_size", batch_size);
iter->SetParam("shuffle", 1);
Expand Down

0 comments on commit 8b7ea17

Please sign in to comment.