From b8514e64de4cd89e8c5396cf0af0ce57c9c231ed Mon Sep 17 00:00:00 2001 From: Gleb Belov Date: Wed, 13 Mar 2024 23:43:49 +1100 Subject: [PATCH] Export NL model: item names #232 Generate names when needed and use for the flat model --- include/mp/flat/problem_flattener.h | 43 ++++++----- include/mp/model-mgr-with-pb.h | 14 ++-- include/mp/nl-reader.h | 16 +++-- include/mp/problem.h | 63 +++++++++++++++- src/expr-writer.h | 107 ++++++++++++++++------------ src/nl-reader.cc | 18 +++-- src/problem.cc | 13 ++-- test/expr-writer-test.cc | 10 +-- 8 files changed, 191 insertions(+), 93 deletions(-) diff --git a/include/mp/flat/problem_flattener.h b/include/mp/flat/problem_flattener.h index 1be73bb51..c2f309730 100644 --- a/include/mp/flat/problem_flattener.h +++ b/include/mp/flat/problem_flattener.h @@ -29,17 +29,19 @@ LinTerms ToLinTerms(const LinearExpr& e) { } /// Write algebraic expression (linear + non-linear.) -template +template void WriteExpr(fmt::Writer &w, const LinearExpr &linear, - NumericExpr nonlinear); + NumericExpr nonlinear, Namer); /// Write logical expression -template -void WriteExpr(fmt::Writer &w, LogicalExpr expr); +template +void WriteExpr(fmt::Writer &w, LogicalExpr expr, VN); /// Write algebraic constraint. -template -void WriteAlgCon(fmt::Writer &w, const AlgCon &con); +template +void WriteAlgCon(fmt::Writer &w, const AlgCon &con, VN); /// ProblemFlattener: it walks and "flattens" most expressions @@ -135,8 +137,6 @@ class ProblemFlattener : protected: /// Convert problem items void ConvertStandardItems() { - CopyItemNames(); - ////////////////////////// Variables ConvertVars(); @@ -176,6 +176,10 @@ class ProblemFlattener : MP_DISPATCH( ConvertLogicalCon( i ) ); } + /// We could have produced variable names + /// when exporting NL model info + CopyItemNames(); + /// Signal we are not flattening anything ifFltCon_ = -1; } @@ -188,11 +192,12 @@ class ProblemFlattener : MiniJSONWriter jw(wrt); jw["NL_COMMON_EXPR_index"] = i; // We don't receive defvar names from AMPL - jw["name"] = "ce" + std::to_string(i+1); + auto vn = GetModel().GetVarNamer(); + jw["name"] = vn.dvname(i); auto ce = GetModel().common_expr(i); fmt::MemoryWriter w2; WriteExpr( - w2, ce.linear_expr(), ce.nonlinear_expr()); + w2, ce.linear_expr(), ce.nonlinear_expr(), vn); jw["printed"] = w2.c_str(); } wrt.write("\n"); // EOL @@ -207,13 +212,13 @@ class ProblemFlattener : { MiniJSONWriter jw(wrt); jw["NL_OBJECTIVE_index"] = i; - if (GetModel().obj_names().size()>i) - jw["name"] = GetModel().obj_names()[i]; + jw["name"] = GetModel().obj_name(i); auto obj = GetModel().obj(i); jw["sense"] = (int)obj.type(); fmt::MemoryWriter w2; WriteExpr( - w2, obj.linear_expr(), obj.nonlinear_expr()); + w2, obj.linear_expr(), obj.nonlinear_expr(), + GetModel().GetVarNamer()); jw["printed"] = w2.c_str(); } wrt.write("\n"); // EOL @@ -230,10 +235,10 @@ class ProblemFlattener : auto con = GetModel().algebraic_con(i); jw["NL_CON_TYPE"] = (con.nonlinear_expr() ? "nonlin" : "lin"); jw["index"] = i; - if (GetModel().con_names().size()>i) - jw["name"] = GetModel().con_names()[i]; + jw["name"] = GetModel().con_name(i); fmt::MemoryWriter w2; - WriteAlgCon(w2, con); + WriteAlgCon( + w2, con, GetModel().GetVarNamer()); jw["printed"] = w2.c_str(); } wrt.write("\n"); // EOL @@ -251,10 +256,10 @@ class ProblemFlattener : jw["NL_CON_TYPE"] = "logical"; int i_actual = GetModel().num_algebraic_cons() + i; jw["index"] = i_actual; - if (GetModel().con_names().size()>i_actual) - jw["name"] = GetModel().con_names()[i_actual]; + jw["name"] = GetModel().con_name(i_actual); fmt::MemoryWriter w2; - WriteExpr(w2, con.expr()); + WriteExpr( + w2, con.expr(), GetModel().GetVarNamer()); jw["printed"] = w2.c_str(); } wrt.write("\n"); // EOL diff --git a/include/mp/model-mgr-with-pb.h b/include/mp/model-mgr-with-pb.h index a60d08f88..542245096 100644 --- a/include/mp/model-mgr-with-pb.h +++ b/include/mp/model-mgr-with-pb.h @@ -123,11 +123,12 @@ class ModelManagerWithProblemBuilder : /// The .row file has cons + objs. void ReadNames(const std::string& namebase) { if (WantNames()) { - NameProvider npv("_svar"); - NameProvider npc("_scon"); + NameProvider npv("_svar", "_sdvar"); + NameProvider npc("_scon", "_slogcon"); if (WantNames()<=2) { npv.ReadNames(namebase + ".col", - GetModel().num_vars()); + GetModel().num_vars() + + GetModel().num_common_exprs()); npc.ReadNames(namebase + ".row", GetModel().num_cons() + GetModel().num_objs()); @@ -135,9 +136,12 @@ class ModelManagerWithProblemBuilder : if (WantNames()>=2 || npv.number_read()+npc.number_read()) { GetModel().SetVarNames( - npv.get_names(GetModel().num_vars())); + npv.get_names( + GetModel().num_vars() + GetModel().num_common_exprs(), + GetModel().num_vars())); GetModel().SetConNames( - npc.get_names(GetModel().num_cons()) ); + npc.get_names(GetModel().num_cons(), + GetModel().num_algebraic_cons()) ); SetObjNames(npc); } } diff --git a/include/mp/nl-reader.h b/include/mp/nl-reader.h index 98bbee778..a27e04904 100644 --- a/include/mp/nl-reader.h +++ b/include/mp/nl-reader.h @@ -2505,8 +2505,10 @@ class NameProvider { fmt::CStringRef gen_name, std::size_t num_items); - /// Construct without reading (generic names can be provided) - NameProvider(fmt::CStringRef gen_name); + /// Construct without reading (generic names can be provided). + /// @param n2: second generic name, + /// used from the 2nd index of name(). + NameProvider(fmt::CStringRef gen_name, fmt::CStringRef n2=""); /// Read names void ReadNames(fmt::CStringRef filename, @@ -2516,15 +2518,19 @@ class NameProvider { size_t number_read() const; /// Returns the name of the item at specified index. - fmt::StringRef name(std::size_t index); + /// @param i2: if >=0, + /// from this index, generic name 2 is used + fmt::StringRef name(std::size_t index, std::size_t i2=-1); /// Return vector of names, length n. /// If number_read() < n, generic names are filled. - std::vector get_names(size_t n); + /// @param i2: if >=0, + /// from this index, generic name 2 is used + std::vector get_names(size_t n, size_t i2=-1); private: std::vector names_; - std::string gen_name_; + std::string gen_name_, gen_name_2_; internal::NameReader reader_; fmt::MemoryWriter writer_; }; diff --git a/include/mp/problem.h b/include/mp/problem.h index 9ff511e41..0ee36ee37 100644 --- a/include/mp/problem.h +++ b/include/mp/problem.h @@ -475,14 +475,73 @@ class BasicProblem : public ExprFactory, public SuffixManager { /** Returns the number of variables. */ int num_vars() const { return static_cast(vars_.size()); } + /// Normal variable name + const std::string& var_name(int i) { + assert(0<=i && i& var_names() { return var_names_; } /** Returns the constraint names (if present). */ const std::vector& con_names() { return con_names_; } /** Returns the objective names (if present). */ const std::vector& obj_names() { return obj_names_; } + /// Variable namer + class VarNamer { + public: + /// Construct + VarNamer(BasicProblem& p) : p_(p) { } + /// Normal var name + const std::string& vname(int i) const + { return p_.var_name(i); } + /// Defined var name + const std::string& dvname(int i) const + { return p_.dvar_name(i); } + private: + BasicProblem& p_; + }; + + /// Obtain variable namer + VarNamer GetVarNamer() { return VarNamer(*this); } + /** Returns the number of objectives. */ int num_objs() const { return static_cast(linear_objs_.size()); } @@ -669,7 +728,7 @@ class BasicProblem : public ExprFactory, public SuffixManager { /// Set name vectors void SetVarNames(std::vector names) { - assert((size_t)num_vars() == names.size()); + assert((size_t)(num_vars() + num_common_exprs()) == names.size()); var_names_ = std::move( names ); } void SetConNames(std::vector names) { diff --git a/src/expr-writer.h b/src/expr-writer.h index f29b9b084..7587d5b7d 100644 --- a/src/expr-writer.h +++ b/src/expr-writer.h @@ -50,6 +50,16 @@ enum Precedence { }; } +/// Default var namer +struct GenericVarNamer { + /// Normal variable name + static std::string vname(int i) + { return "x" + std::to_string(i+1); } + /// Defined variable name + static std::string dvname(int i) + { return "_sdvar[" + std::to_string(i+1) + "]"; } +}; + namespace expr { /// Returns operator precedence for the specified expression kind assuming the /// notation used by ExprWriter. @@ -76,16 +86,19 @@ inline bool IsZero(NumericExpr expr) { /// to fmt::Writer. It takes into account precedence and associativity /// of operators avoiding unnecessary parentheses except for potentially /// confusing cases such as "!x = y" which is written as "!(x = y) instead. -template +template class ExprWriter : - public BasicExprVisitor, void, ExprTypes> { + public BasicExprVisitor< + ExprWriter, void, ExprTypes> { private: fmt::Writer &writer_; int precedence_; + VarNamer vnam_; MP_DEFINE_EXPR_TYPES(ExprTypes); - typedef BasicExprVisitor, void, ExprTypes> Base; + typedef BasicExprVisitor< + ExprWriter, void, ExprTypes> Base; static int precedence(Expr e) { return expr::precedence(e.kind()); } @@ -125,8 +138,8 @@ class ExprWriter : public: /// Construct - explicit ExprWriter(fmt::Writer &w) - : writer_(w), precedence_(prec::UNKNOWN) {} + explicit ExprWriter(fmt::Writer &w, VarNamer vn={}) + : writer_(w), precedence_(prec::UNKNOWN), vnam_(vn) {} /// Visit numeric expr void Visit(NumericExpr e, int precedence = -1) { @@ -141,7 +154,7 @@ class ExprWriter : } void VisitCommonExpr(CommonExpr e) - { writer_ << "ce" << (e.index() + 1); } + { writer_ << vnam_.dvname(e.index()); } void VisitNumericConstant(NumericConstant c) { writer_ << c.value(); } @@ -171,7 +184,7 @@ class ExprWriter : void VisitPLTerm(PLTerm e); void VisitCall(CallExpr e); void VisitVariable(Variable v) - { writer_ << 'x' << (v.index() + 1); } + { writer_ << vnam_.vname(v.index()); } void VisitNot(NotExpr e) { writer_ << '!'; @@ -191,9 +204,9 @@ class ExprWriter : void VisitLogicalConstant(LogicalConstant c) { writer_ << c.value(); } }; -template -ExprWriter::Parenthesizer::Parenthesizer( - ExprWriter &w, Expr e, int prec) +template +ExprWriter::Parenthesizer::Parenthesizer( + ExprWriter &w, Expr e, int prec) : writer_(w), write_paren_(false) { saved_precedence_ = w.precedence_; if (prec == -1) @@ -204,16 +217,16 @@ ExprWriter::Parenthesizer::Parenthesizer( w.precedence_ = precedence(e); } -template -ExprWriter::Parenthesizer::~Parenthesizer() { +template +ExprWriter::Parenthesizer::~Parenthesizer() { writer_.precedence_ = saved_precedence_; if (write_paren_) writer_.writer_ << ')'; } -template +template template -void ExprWriter::WriteArgs( +void ExprWriter::WriteArgs( Iter begin, Iter end, const char *sep, int precedence) { writer_ << '('; if (begin != end) { @@ -226,9 +239,9 @@ void ExprWriter::WriteArgs( writer_ << ')'; } -template +template template -void ExprWriter::WriteBinary(ExprType e) { +void ExprWriter::WriteBinary(ExprType e) { int prec = precedence(e); bool right_associative = prec == prec::EXPONENTIATION; Visit(e.lhs(), prec + (right_associative ? 1 : 0)); @@ -236,8 +249,8 @@ void ExprWriter::WriteBinary(ExprType e) { Visit(e.rhs(), prec + (right_associative ? 0 : 1)); } -template -void ExprWriter::WriteCallArg(Expr arg) { +template +void ExprWriter::WriteCallArg(Expr arg) { if (NumericExpr e = ExprTypes::template Cast(arg)) { Visit(e, prec::UNKNOWN); return; @@ -262,8 +275,8 @@ void ExprWriter::WriteCallArg(Expr arg) { writer_ << "'"; } -template -void ExprWriter::VisitBinaryFunc(BinaryExpr e) { +template +void ExprWriter::VisitBinaryFunc(BinaryExpr e) { writer_ << str(e.kind()) << '('; Visit(e.lhs(), prec::UNKNOWN); writer_ << ", "; @@ -271,8 +284,8 @@ void ExprWriter::VisitBinaryFunc(BinaryExpr e) { writer_ << ')'; } -template -void ExprWriter::VisitIf(IfExpr e) { +template +void ExprWriter::VisitIf(IfExpr e) { writer_ << "if "; Visit(e.condition(), prec::UNKNOWN); writer_ << " then "; @@ -285,8 +298,8 @@ void ExprWriter::VisitIf(IfExpr e) { } } -template -void ExprWriter::VisitSum(SumExpr e) { +template +void ExprWriter::VisitSum(SumExpr e) { writer_ << "("; typename SumExpr::iterator i = e.begin(), end = e.end(); if (i != end) { @@ -299,8 +312,8 @@ void ExprWriter::VisitSum(SumExpr e) { writer_ << ')'; } -template -void ExprWriter::VisitNumberOf(NumberOfExpr e) { +template +void ExprWriter::VisitNumberOf(NumberOfExpr e) { writer_ << "numberof "; typename NumberOfExpr::iterator i = e.begin(); Visit(*i++, prec::UNKNOWN); @@ -308,8 +321,8 @@ void ExprWriter::VisitNumberOf(NumberOfExpr e) { WriteArgs(i, e.end()); } -template -void ExprWriter::VisitPLTerm(PLTerm e) { +template +void ExprWriter::VisitPLTerm(PLTerm e) { writer_ << "<<" << e.breakpoint(0); for (int i = 1, n = e.num_breakpoints(); i < n; ++i) writer_ << ", " << e.breakpoint(i); @@ -324,8 +337,8 @@ void ExprWriter::VisitPLTerm(PLTerm e) { writer_ << "e" << ((ExprTypes::template Cast(arg)).index() + 1); } -template -void ExprWriter::VisitCall(CallExpr e) { +template +void ExprWriter::VisitCall(CallExpr e) { writer_ << e.function().name() << '('; typename CallExpr::iterator i = e.begin(), end = e.end(); if (i != end) { @@ -338,16 +351,16 @@ void ExprWriter::VisitCall(CallExpr e) { writer_ << ')'; } -template -void ExprWriter::VisitLogicalCount(LogicalCountExpr e) { +template +void ExprWriter::VisitLogicalCount(LogicalCountExpr e) { writer_ << str(e.kind()) << ' '; Visit(e.lhs()); writer_ << ' '; WriteArgs(e.rhs()); } -template -void ExprWriter::VisitIteratedLogical(IteratedLogicalExpr e) { +template +void ExprWriter::VisitIteratedLogical(IteratedLogicalExpr e) { // There is no way to produce an AMPL forall/exists expression because // its indexing is not available any more. So we write a count expression // instead with a comment about the original expression. @@ -361,8 +374,8 @@ void ExprWriter::VisitIteratedLogical(IteratedLogicalExpr e) { WriteArgs(e, op, prec); } -template -void ExprWriter::VisitImplication(ImplicationExpr e) { +template +void ExprWriter::VisitImplication(ImplicationExpr e) { Visit(e.condition()); writer_ << " ==> "; Visit(e.then_expr(), prec::IMPLICATION + 1); @@ -375,9 +388,11 @@ void ExprWriter::VisitImplication(ImplicationExpr e) { } /// Write algebraic expression (linear + non-linear.) -template +template void WriteExpr(fmt::Writer &w, const LinearExpr &linear, - NumericExpr nonlinear) { + NumericExpr nonlinear, Namer vnam={}) { bool have_terms = false; for (auto i = linear.begin(), e = linear.end(); i != e; ++i) { double coef = i->coef(); @@ -388,7 +403,7 @@ void WriteExpr(fmt::Writer &w, const LinearExpr &linear, have_terms = true; if (coef != 1) w << coef << " * "; - w << "x" << (i->var_index() + 1); + w << vnam.vname(i->var_index()); } } if (!nonlinear || IsZero(nonlinear)) { @@ -398,25 +413,25 @@ void WriteExpr(fmt::Writer &w, const LinearExpr &linear, } if (have_terms) w << " + "; - ExprWriter(w).Visit(nonlinear); + ExprWriter(w, vnam).Visit(nonlinear); } /// Write logical expression -template -void WriteExpr(fmt::Writer &w, LogicalExpr expr) { - ExprWriter(w).Visit(expr); +template +void WriteExpr(fmt::Writer &w, LogicalExpr expr, VN vnam={}) { + ExprWriter(w, vnam).Visit(expr); } /// Write algebraic constraint. -template +template void WriteAlgCon(fmt::Writer &w, - const AlgCon &con) { + const AlgCon &con, VN vnam) { double inf = INFINITY; double lb = con.lb(), ub = con.ub(); if (lb != ub && lb != -inf && ub != inf) w << lb << " <= "; WriteExpr( - w, con.linear_expr(), con.nonlinear_expr()); + w, con.linear_expr(), con.nonlinear_expr(), vnam); if (lb == ub) w << " = " << lb; else if (ub != inf) diff --git a/src/nl-reader.cc b/src/nl-reader.cc index 68203421c..5172e7940 100644 --- a/src/nl-reader.cc +++ b/src/nl-reader.cc @@ -336,8 +336,9 @@ mp::NameProvider::NameProvider( ReadNames(filename, num_items); } -mp::NameProvider::NameProvider(fmt::CStringRef gen_name) - : gen_name_(gen_name.c_str()) { } +mp::NameProvider::NameProvider( + fmt::CStringRef gen_name, fmt::CStringRef n2) + : gen_name_(gen_name.c_str()), gen_name_2_(n2.c_str()) { } void mp::NameProvider::ReadNames( fmt::CStringRef filename, std::size_t num_items) { @@ -352,7 +353,8 @@ void mp::NameProvider::ReadNames( names_.push_back(last_name.data() + last_name.size() + 1); } -fmt::StringRef mp::NameProvider::name(std::size_t index) { +fmt::StringRef mp::NameProvider::name( + std::size_t index, std::size_t i2) { if (index + 1 < names_.size()) { const char *name = names_[index]; const auto* pos1past = names_[index + 1] - 1; @@ -362,7 +364,10 @@ fmt::StringRef mp::NameProvider::name(std::size_t index) { return fmt::StringRef(name, pos1past - name); } writer_.clear(); - writer_ << gen_name_ << '[' << (index + 1) << ']'; + if (i2>=0 && index>=i2) + writer_ << gen_name_2_ << '[' << (index - i2 + 1) << ']'; + else + writer_ << gen_name_ << '[' << (index + 1) << ']'; return fmt::StringRef(writer_.c_str(), writer_.size()); } @@ -370,11 +375,12 @@ size_t mp::NameProvider::number_read() const { return (names_.size() ? names_.size()-1 : 0); } -std::vector mp::NameProvider::get_names(size_t n) { +std::vector mp::NameProvider::get_names( + size_t n, size_t i2) { std::vector result; result.reserve(n); for (size_t i=0; i -(fmt::Writer &w, const LinearExpr &linear, NumericExpr nonlinear); +void WriteExpr( +fmt::Writer &w, +const LinearExpr &linear, NumericExpr nonlinear, +Problem::VarNamer); template void WriteExpr -(fmt::Writer &w, LogicalExpr expr); +(fmt::Writer &w, LogicalExpr expr, Problem::VarNamer); /// Write algebraic constraint. template -void WriteAlgCon -(fmt::Writer &w, const typename Problem::MutAlgebraicCon &con); +void WriteAlgCon( +fmt::Writer &w, const typename Problem::MutAlgebraicCon &con, +Problem::VarNamer); } // namespace mp diff --git a/test/expr-writer-test.cc b/test/expr-writer-test.cc index 20bc2f5c1..3340fd79c 100644 --- a/test/expr-writer-test.cc +++ b/test/expr-writer-test.cc @@ -321,11 +321,11 @@ TEST_F(ExprWriterTest, WriteVarArgExpr) { TEST_F(ExprWriterTest, WriteSumExpr) { NumericExpr args[] = {MakeVariable(0), MakeVariable(1), MakeConst(42)}; - CHECK_WRITE("/* sum */ (x1 + x2 + 42)", MakeIterated(ex::SUM, args)); + CHECK_WRITE("(x1 + x2 + 42)", MakeIterated(ex::SUM, args)); NumericExpr args2[] = { MakeBinary(ex::ADD, MakeVariable(0), MakeVariable(1)), MakeConst(42) }; - CHECK_WRITE("/* sum */ ((x1 + x2) + 42)", MakeIterated(ex::SUM, args2)); + CHECK_WRITE("((x1 + x2) + 42)", MakeIterated(ex::SUM, args2)); } TEST_F(ExprWriterTest, WriteCountExpr) { @@ -579,12 +579,12 @@ TEST_F(ExprWriterTest, SumExprPrecedence) { auto x1 = MakeVariable(0), x2 = MakeVariable(1), x3 = MakeVariable(2); NumericExpr args1[] = {x2, x3}; NumericExpr args2[] = {x1, MakeIterated(ex::SUM, args1)}; - CHECK_WRITE("/* sum */ (x1 + /* sum */ (x2 + x3))", + CHECK_WRITE("(x1 + (x2 + x3))", MakeIterated(ex::SUM, args2)); NumericExpr args3[] = {x1, MakeBinary(ex::MUL, x2, x3)}; - CHECK_WRITE("/* sum */ (x1 + x2 * x3)", MakeIterated(ex::SUM, args3)); + CHECK_WRITE("(x1 + x2 * x3)", MakeIterated(ex::SUM, args3)); NumericExpr args4[] = {MakeBinary(ex::ADD, x1, x2), x3}; - CHECK_WRITE("/* sum */ ((x1 + x2) + x3)", MakeIterated(ex::SUM, args4)); + CHECK_WRITE("((x1 + x2) + x3)", MakeIterated(ex::SUM, args4)); } TEST_F(ExprWriterTest, CountExprPrecedence) {