From 5c7e2e9cc44d4db2975c86205a3d5a352441504e Mon Sep 17 00:00:00 2001 From: Igor Konnov Date: Mon, 25 Jan 2021 15:56:05 +0100 Subject: [PATCH] Bugfixes in Desugarer and propagation of primes (#483) Co-authored-by: Shon Feder --- UNRELEASED.md | 10 ++- docs/src/apalache/features.md | 16 ++-- test/tla/ExistTuple476.tla | 17 ++++ test/tla/UnchangedExpr471.tla | 22 +++++ test/tla/cli-integration-tests.md | 18 ++++ .../forsyte/apalache/tla/pp/Desugarer.scala | 50 +++++++++-- .../tla/pp/passes/PreproPassImpl.scala | 55 ++++++++---- .../apalache/tla/pp/TestDesugarer.scala | 89 ++++++++++++++++++- .../standard/ModuleByExTransformer.scala | 29 +++++- .../standard/PrimePropagation.scala | 67 ++++++++++---- .../standard/TestPrimePropagation.scala | 68 ++++++++++++++ 11 files changed, 383 insertions(+), 58 deletions(-) create mode 100644 test/tla/ExistTuple476.tla create mode 100644 test/tla/UnchangedExpr471.tla create mode 100644 tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestPrimePropagation.scala diff --git a/UNRELEASED.md b/UNRELEASED.md index 6e16af5d5f..c7dba3c8f6 100644 --- a/UNRELEASED.md +++ b/UNRELEASED.md @@ -12,8 +12,14 @@ DO NOT LEAVE A BLANK LINE BELOW THIS PREAMBLE --> ### Bug fixes - * handling big integers, #450 + * handling big integers, see #450 + * expanding tuples in quantifiers, see #476 + * unfolding UNCHANGED for arbitrary expressions, see #471 + * unfolding UNCHANGED <<>>, see #475 ### Features - * constant simplification over strings, #197 \ No newline at end of file + * constant simplification over strings, see #197 + * propagation of primes inside expressions, + e.g., (f[i])' becomes f'[i'] if both f and i are state variables + diff --git a/docs/src/apalache/features.md b/docs/src/apalache/features.md index d504169c55..e935248d28 100644 --- a/docs/src/apalache/features.md +++ b/docs/src/apalache/features.md @@ -13,9 +13,9 @@ Construct | Supported? | Milestone | Comment ``F(x1, ..., x_n) == exp`` | ✔ / ✖ | - | Every application of `F` is replaced with its body. Recursive operators need [unrolling annotations](./principles.md#recursive-operators). ``f[x ∈ S] == exp`` | ✔ / ✖ | - | Only recursive functions that return integers or Booleans are supported. ``INSTANCE M WITH ...`` | ✔ / ✖ | - | No special treatment for ``~>``, ``\cdot``, ``ENABLED`` -``N(x1, ..., x_n) == INSTANCE M WITH...`` | ✔ / ✖ | - | Parameterized instances are not supported yet, LOCAL operator definitions inside instances may fail to work +``N(x1, ..., x_n) == INSTANCE M WITH...`` | ✔ / ✖ | - | Parameterized instances are not supported ``THEOREM P`` | ✔ / ✖ | - | Parsed but not used -``LOCAL def`` | ✔ | - | Handled by SANY +``LOCAL def`` | ✔ | - | Replaced with local LET-IN definitions ### The constant operators @@ -53,7 +53,7 @@ Operator | Supported? | Milestone | Comment ------------------------|:------------------:|:---------------:|-------------- `f[e]` | ✔ | - | `DOMAIN f` | ✔ | - | -`[ x \in S ↦ e]` | ✔ / ✖ | - | +`[ x \in S ↦ e]` | ✔ | - | `[ S -> T ]` | ✔ | - | Sometimes, the functions sets are expanded `[ f EXCEPT ![e1] = e2 ]` | ✔ | - | @@ -89,7 +89,7 @@ Construct | Supported? | Milestone | Comment `"c1...c_n"` | ✔ | - | A string is always mapped to a unique uninterpreted constant `STRING` | ✖ | - | It is an infinite set. We cannot handle infinite sets. `d1...d_n` | ✔ | - | As long as the SMT solver (Z3) accepts that large number -`d1...d_n.d_n+1...d_m` | ✖ | - | Technical issue. We will implemented upon a user request. +`d1...d_n.d_n+1...d_m` | ✖ | - | Technical issue. We will implement it upon a user request. #### Miscellaneous Constructs @@ -98,14 +98,14 @@ Construct | Supported? | Milestone | Comment `IF p THEN e1 ELSE e2` | ✔ | - | Provided that both e1 and e2 have the same type `CASE p1 -> e1 [] ... [] p_n -> e_n [] OTHER -> e` | ✔ | - | See the comment above `CASE p1 -> e1 [] ... [] p_n -> e_n` | ✖ | - | Introduce the default arm with `OTHER`. -``LET d1 == e1 ... d_n == e_n IN e`` | ✔ / ✖ | `0.7-dev-calls` | All applications of `d1`, ..., `d_n` are replaced with the expressions `e1`, ... `e_n` respectively. LET-definitions without arguments are kept in place. +``LET d1 == e1 ... d_n == e_n IN e`` | ✔ | | All applications of `d1`, ..., `d_n` are replaced with the expressions `e1`, ... `e_n` respectively. LET-definitions without arguments are kept in place. multi-line `/\` and `\/` | ✔ | - | ### The Action Operators Construct | Supported? | Milestone | Comment ------------------------|:------------------:|:---------------:|-------------- -``e'`` | ✔ / ✖ | - | Provided that e is a variable +``e'`` | ✔ | - | ``[A]_e`` | ✖ | - | It does not matter for safety ``< A >_e`` | ✖ | - | ``ENABLED A`` | ✖ | - | @@ -141,8 +141,8 @@ Operator | Supported? | Milestone | Comment Operator | Supported? | Milestone | Comment ------------------------|:------------------:|:---------------:|-------------- ``<<...>>``, ``Head``, ``Tail``, ``Len``, ``SubSeq``, `Append`, `\o`, `f[e]` | ✔ | - | The sequence constructor ``<<...>>`` needs a [type annotation](types-and-annotations.md). -``EXCEPT`` | ✖ | `0.9` | this operator do not seem to be often used -``Seq(S)`` | ✖ | - | We need an upper bound on the length of the sequences. +``EXCEPT`` | ✖ | | If you need it, let us know, issue #324 +``Seq(S)`` | ✖ | - | If you need it, let us know, issue #314 ``SelectSeq`` | ✖ | - | will not be supported in the near future ### FiniteSets diff --git a/test/tla/ExistTuple476.tla b/test/tla/ExistTuple476.tla new file mode 100644 index 0000000000..b68ef11bb2 --- /dev/null +++ b/test/tla/ExistTuple476.tla @@ -0,0 +1,17 @@ +------------------ MODULE ExistTuple476 ---------------------- +(* A regression test for the issue: + /~https://github.com/informalsystems/apalache/issues/476 + *) +EXTENDS Integers + +VARIABLES x, y + +Init == + x = 0 /\ y = 0 + +Next == + \E <> \in (1..2) \X (3..4): + /\ x' = i + /\ y' = j + +============================================================== diff --git a/test/tla/UnchangedExpr471.tla b/test/tla/UnchangedExpr471.tla new file mode 100644 index 0000000000..6aac930196 --- /dev/null +++ b/test/tla/UnchangedExpr471.tla @@ -0,0 +1,22 @@ +------------------------ MODULE UnchangedExpr471 ---------------------------- +(* A regression test for UNCHANGED e: + see issue: /~https://github.com/informalsystems/apalache/issues/471 + *) +EXTENDS Integers + +CONSTANT N + +VARIABLES f, i + +ConstInit == + N' = 3 + +Init == + /\ f = [ j \in 1..3 |-> j * 2] + /\ i = 2 + +Next == + UNCHANGED <> + +============================================================================== + diff --git a/test/tla/cli-integration-tests.md b/test/tla/cli-integration-tests.md index d57be06a1f..fbf2be1c89 100644 --- a/test/tla/cli-integration-tests.md +++ b/test/tla/cli-integration-tests.md @@ -110,6 +110,24 @@ EXITCODE: OK ## running the check command +### check UnchangedExpr471.tla reports no error: regression for issue 471 + +```sh +$ apalache-mc check --cinit=ConstInit --length=1 UnchangedExpr471.tla | sed 's/I@.*//' +... +The outcome is: NoError +... +``` + +### check ExistTuple476.tla reports no error: regression for issue 476 + +```sh +$ apalache-mc check --length=1 ExistTuple476.tla | sed 's/I@.*//' +... +The outcome is: NoError +... +``` + ### check InvSub for SafeMath reports no error: regression for issue 450 ```sh diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala index 366d24b95c..ef548196dd 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala @@ -2,9 +2,10 @@ package at.forsyte.apalache.tla.pp import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience._ -import at.forsyte.apalache.tla.lir.oper.{TlaActionOper, TlaFunOper, TlaOper, TlaSetOper} +import at.forsyte.apalache.tla.lir.oper.{TlaActionOper, TlaBoolOper, TlaFunOper, TlaOper, TlaSetOper} import at.forsyte.apalache.tla.lir.transformations.standard.FlatLanguagePred import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker} + import javax.inject.Singleton /** @@ -45,15 +46,40 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { case OperEx(TlaActionOper.unchanged, args @ _*) => // flatten all tuples, e.g., convert <> >> to [x, y, z] - val flatArgs = flattenTuples(tla.tuple(args.map(transform) :_*)) - // and map every x to x' = x + val flatArgs = flattenTuplesInUnchanged(tla.tuple(args.map(transform) :_*)) + // map every x to x' = x val eqs = flatArgs map { x => tla.eql(tla.prime(x), x) } - tla.and(eqs :_*) + // x' = x /\ y' = y /\ z' = z + eqs match { + case Seq() => + // results from UNCHANGED <<>>, UNCHANGED << << >> >>, etc. + tla.bool(true) + + case Seq(one) => + one + + case _ => + tla.and(eqs: _*) + } case OperEx(TlaSetOper.filter, boundEx, setEx, predEx) => + // rewrite { <> >> \in XYZ: P(x, y, z) } + // to { x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) } OperEx(TlaSetOper.filter, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*) + case OperEx(TlaBoolOper.exists, boundEx, setEx, predEx) => + // rewrite \E <> >> \in XYZ: P(x, y, z) + // to \E x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) + OperEx(TlaBoolOper.exists, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*) + + case OperEx(TlaBoolOper.forall, boundEx, setEx, predEx) => + // rewrite \A <> >> \in XYZ: P(x, y, z) + // to \A x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) + OperEx(TlaBoolOper.forall, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*) + case OperEx(TlaSetOper.map, args @ _*) => + // rewrite { <> >> \in XYZ |-> e(x, y, z) } + // to { x_y_z \in XYZ |-> e(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) val trArgs = args map transform OperEx(TlaSetOper.map, collapseTuplesInMap(trArgs.head, trArgs.tail) :_*) @@ -81,15 +107,21 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { LetInEx( transform( body ), defs map { d => d.copy( body = transform( d.body ) ) } : _* ) } - private def flattenTuples(ex: TlaEx): Seq[TlaEx] = ex match { + private def flattenTuplesInUnchanged(ex: TlaEx): Seq[TlaEx] = ex match { case OperEx(TlaFunOper.tuple, args @ _*) => - args.map(flattenTuples).reduce(_ ++ _) + if (args.isEmpty) { + // Surprisingly, somebody has written UNCHANGED << >>, see issue #475. + Seq() + } else { + args.map(flattenTuplesInUnchanged).reduce(_ ++ _) + } - case NameEx(_) => - Seq(ex) + case ValEx(_) => + Seq() // no point in priming literals case _ => - throw new IllegalArgumentException("Expected a variable or a tuple of variables, found: " + ex) + // in general, UNCHANGED e becomes e' = e + Seq(ex) } private def expandExcept(topFun: TlaEx, accessors: Seq[TlaEx], newValues: Seq[TlaEx]): TlaEx = { diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala index 0ece9e5203..c148e96d05 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala @@ -2,15 +2,22 @@ package at.forsyte.apalache.tla.pp.passes import java.io.File import java.nio.file.Path - import at.forsyte.apalache.infra.passes.{Pass, PassOptions, TlaModuleMixin} import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.lir.TlaModule +import at.forsyte.apalache.tla.lir.{TlaDecl, TlaModule, TlaOperDecl} import at.forsyte.apalache.tla.lir.io.PrettyWriter import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} -import at.forsyte.apalache.tla.lir.transformations.{TlaModuleTransformation, TransformationTracker} +import at.forsyte.apalache.tla.lir.transformations.{ + TlaModuleTransformation, + TransformationTracker +} import at.forsyte.apalache.tla.lir.transformations.standard._ -import at.forsyte.apalache.tla.pp.{Desugarer, Keramelizer, Normalizer, UniqueNameGenerator} +import at.forsyte.apalache.tla.pp.{ + Desugarer, + Keramelizer, + Normalizer, + UniqueNameGenerator +} import com.google.inject.Inject import com.google.inject.name.Named import com.typesafe.scalalogging.LazyLogging @@ -22,14 +29,16 @@ import com.typesafe.scalalogging.LazyLogging * @param tracker transformation tracker * @param nextPass next pass to call */ -class PreproPassImpl @Inject()( val options: PassOptions, - gen: UniqueNameGenerator, - renaming: IncrementalRenaming, - tracker: TransformationTracker, - sourceStore: SourceStore, - changeListener: ChangeListener, - @Named("AfterPrepro") nextPass: Pass with TlaModuleMixin) - extends PreproPass with LazyLogging { +class PreproPassImpl @Inject() ( + val options: PassOptions, + gen: UniqueNameGenerator, + renaming: IncrementalRenaming, + tracker: TransformationTracker, + sourceStore: SourceStore, + changeListener: ChangeListener, + @Named("AfterPrepro") nextPass: Pass with TlaModuleMixin +) extends PreproPass + with LazyLogging { private var outputTlaModule: Option[TlaModule] = None @@ -48,10 +57,12 @@ class PreproPassImpl @Inject()( val options: PassOptions, override def execute(): Boolean = { logger.info(" > Before preprocessing: unique renaming") val input = tlaModule.get + val varSet = input.varDeclarations.map(_.name).toSet val transformationSequence: List[(String, TlaModuleTransformation)] = List( ("Desugarer", ModuleByExTransformer(Desugarer(tracker))), + ("PrimePropagation", createModuleTransformerForPrimePropagation(varSet)), ("UniqueRenamer", renaming.renameInModule), ("Normalizer", ModuleByExTransformer(Normalizer(tracker))), ("Keramelizer", ModuleByExTransformer(Keramelizer(gen, tracker))) @@ -65,7 +76,10 @@ class PreproPassImpl @Inject()( val options: PassOptions, logger.info(s" > $name") val transfomed = xformer(m) // dump the result of preprocessing after every transformation, in case the next one fails - PrettyWriter.write(transfomed, new File(outdir.toFile, s"out-prepro-$name.tla")) + PrettyWriter.write( + transfomed, + new File(outdir.toFile, s"out-prepro-$name.tla") + ) transfomed } @@ -78,13 +92,24 @@ class PreproPassImpl @Inject()( val options: PassOptions, outputTlaModule = Some(afterModule) if (options.getOrElse("general", "debug", false)) { - val sourceLocator = SourceLocator(sourceStore.makeSourceMap, changeListener) + val sourceLocator = + SourceLocator(sourceStore.makeSourceMap, changeListener) outputTlaModule.get.operDeclarations foreach sourceLocator.checkConsistency } true } + private def createModuleTransformerForPrimePropagation(varSet: Set[String]) + : ModuleByExTransformer = { + val cinitName = options.getOrElse("checker", "cinit", "CInit") + "Primed" + val includeAllButConstInit: TlaDecl => Boolean = { + case TlaOperDecl(name, _, _) => cinitName != name + case _ => true + } + ModuleByExTransformer(new PrimePropagation(tracker, varSet), includeAllButConstInit) + } + /** * Get the next pass in the chain. What is the next pass is up * to the module configuration and the pass outcome. @@ -93,7 +118,7 @@ class PreproPassImpl @Inject()( val options: PassOptions, */ override def next(): Option[Pass] = { outputTlaModule map { m => - nextPass.setModule( m ) + nextPass.setModule(m) nextPass } } diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala index 3d3dd729a2..8bbd917480 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala @@ -58,6 +58,15 @@ class TestDesugarer extends FunSuite with BeforeAndAfterEach { assert(expected == sugarFree) } + test("""rewrite UNCHANGED x to x' = x""") { + // input: x + val unchanged = tla.unchanged(tla.name("x")) + val sugarFree = desugarer.transform(unchanged) + // output: x' = x + val expected = tla.eql(tla.prime(tla.name("x")), tla.name("x")) + assert(expected == sugarFree) + } + test("""rewrite UNCHANGED <> >> to x' = x /\ y' = y""") { // input: <> >> val unchanged = tla.unchangedTup(tla.name("x"), tla.tuple(tla.name("y"))) @@ -74,7 +83,7 @@ class TestDesugarer extends FunSuite with BeforeAndAfterEach { test("unfold UNCHANGED <> >> to UNCHANGED <>") { // This is an idiom that was probably introduced by Diego Ongaro in Raft. // There is no added value in this construct, so it is just sugar. - // So, we do the transformation right here. + // We do the transformation right here. val unchanged = tla.unchangedTup(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))) val sugarFree = desugarer.transform(unchanged) val expected = @@ -86,6 +95,36 @@ class TestDesugarer extends FunSuite with BeforeAndAfterEach { assert(expected == sugarFree) } + test("""rewrite UNCHANGED <<>> to TRUE""") { + // this is a regression for issue #375 + // input: << >> + val unchanged = tla.unchangedTup() + val sugarFree = desugarer.transform(unchanged) + // output: TRUE + val expected = tla.bool(true) + assert(expected == sugarFree) + } + + test("""rewrite UNCHANGED << <<>>, <<>> >> to TRUE""") { + // this is a regression for issue #375 + // input: << <<>>, <<>> >> + val unchanged = tla.unchangedTup(tla.unchangedTup(), tla.unchangedTup()) + val sugarFree = desugarer.transform(unchanged) + // output: TRUE + val expected = tla.bool(true) + assert(expected == sugarFree) + } + + test("""rewrite UNCHANGED f[i] to (f[i])' = f[i]""") { + // this is a regression for issue #471 + // input: UNCHANGED f[i] + val app = tla.appFun(tla.name("f"), tla.name("i")) + val sugarFree = desugarer.transform(tla.unchangedTup(app)) + // output: (f[i])' = f[i] + val expected = tla.eql(tla.prime(app), app) + assert(expected == sugarFree) + } + test("simplify tuples in filters") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: { <> >> \in XYZ: x = 3 /\ y = 4 } @@ -132,6 +171,54 @@ class TestDesugarer extends FunSuite with BeforeAndAfterEach { assert(expected == sugarFree) } + test("simplify tuples in existentials") { + // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. + // input: \E <> >> \in XYZ: x = 3 /\ y = 4 } + val filter = + tla.exists( + tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), + tla.name("XYZ"), + tla.and(tla.eql(tla.name("x"), tla.int(3)), + tla.eql(tla.name("y"), tla.int(4)))) + val sugarFree = desugarer.transform(filter) + // output: \E x_y_z \in XYZ: x_y_z[1] = 3 /\ x_y_z[2][1] = 4 } + val expected = + tla.exists( + tla.name("x_y_z"), + tla.name("XYZ"), + tla.and( + tla.eql(tla.appFun(tla.name("x_y_z"), tla.int(1)), tla.int(3)), + tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), + tla.int(1)), + tla.int(4)) + )) //// + assert(expected == sugarFree) + } + + test("simplify tuples in universals") { + // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. + // input: \A <> >> \in XYZ: x = 3 /\ y = 4 } + val filter = + tla.forall( + tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), + tla.name("XYZ"), + tla.and(tla.eql(tla.name("x"), tla.int(3)), + tla.eql(tla.name("y"), tla.int(4)))) + val sugarFree = desugarer.transform(filter) + // output: \A x_y_z \in XYZ: x_y_z[1] = 3 /\ x_y_z[2][1] = 4 } + val expected = + tla.forall( + tla.name("x_y_z"), + tla.name("XYZ"), + tla.and( + tla.eql(tla.appFun(tla.name("x_y_z"), tla.int(1)), tla.int(3)), + tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), + tla.int(1)), + tla.int(4)) + )) //// + assert(expected == sugarFree) + } + test("simplify tuples in functions") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: [<> >> \in XYZ |-> x + y] diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/ModuleByExTransformer.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/ModuleByExTransformer.scala index 74382d27ae..5d95646e45 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/ModuleByExTransformer.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/ModuleByExTransformer.scala @@ -1,14 +1,20 @@ package at.forsyte.apalache.tla.lir.transformations.standard import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TlaModuleTransformation} +import at.forsyte.apalache.tla.lir.transformations.{ + TlaExTransformation, + TlaModuleTransformation +} /** * This transformer uses a TlaExTransformer to modify the bodies of operator declarations inside a module. * * @author Igor Konnov */ -class ModuleByExTransformer(exTrans: TlaExTransformation) extends TlaModuleTransformation { +class ModuleByExTransformer( + exTrans: TlaExTransformation, + applyTo: (TlaDecl => Boolean) = (_ => true) +) extends TlaModuleTransformation { override def apply(mod: TlaModule): TlaModule = { def mapOneDeclaration: TlaDecl => TlaDecl = { case TlaOperDecl(name, params, body) => @@ -20,10 +26,25 @@ class ModuleByExTransformer(exTrans: TlaExTransformation) extends TlaModuleTrans case d => d } - new TlaModule(mod.name, mod.declarations map mapOneDeclaration) + def mapIfIncluded(decl: TlaDecl): TlaDecl = { + if (applyTo(decl)) { + mapOneDeclaration(decl) + } else { + decl + } + } + + new TlaModule(mod.name, mod.declarations map mapIfIncluded) } } object ModuleByExTransformer { - def apply(exTrans: TlaExTransformation): ModuleByExTransformer = new ModuleByExTransformer(exTrans) + def apply(exTrans: TlaExTransformation): ModuleByExTransformer = + new ModuleByExTransformer(exTrans) + + def apply( + exTrans: TlaExTransformation, + include: TlaDecl => Boolean + ): ModuleByExTransformer = + new ModuleByExTransformer(exTrans, include) } diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/PrimePropagation.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/PrimePropagation.scala index c5e5ecccec..52622bae40 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/PrimePropagation.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/PrimePropagation.scala @@ -1,31 +1,60 @@ package at.forsyte.apalache.tla.lir.transformations.standard import at.forsyte.apalache.tla.lir.oper.TlaActionOper -import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.transformations.{ + TlaExTransformation, + TransformationTracker +} +import at.forsyte.apalache.tla.lir.{LetInEx, NameEx, OperEx, TlaEx} /** - * This is a simple reference implementation of an expression transformer. It expands the prime operator, + * A reference implementation of an expression transformer. It expands the prime operator, * that is, when it meets an expression e', it propagates primes inside e. * + * @param stateVars state variables * @param tracker a transformation tracker */ -class PrimePropagation(tracker: TransformationTracker) extends TlaExTransformation { - override def apply(e: TlaEx): TlaEx = transform(primed = false)(e) - - private def transform(primed: Boolean): TlaEx => TlaEx = tracker.trackEx { - case OperEx(TlaActionOper.prime, e) => - transform(primed)(e) - - case OperEx(oper, args @ _*) => - OperEx(oper, args map transform(primed) :_*) - - // TODO: ENABLED and module instances need a special treatment - - case ne @ NameEx(name) => - if (primed) OperEx(TlaActionOper.prime, ne) else ne - - case e => e +class PrimePropagation(tracker: TransformationTracker, stateVars: Set[String]) + extends TlaExTransformation { + + /** + * Propagate primes in the expression to the state variables. + * All names that are different from state variables should subsume prime. + * + * @param expr an expression to propagate primes + * @return the expression where primes are propagated to the level of state variables + */ + override def apply(expr: TlaEx): TlaEx = { + def transform(primeToAdd: Boolean): TlaEx => TlaEx = + tracker.trackEx { + case OperEx(TlaActionOper.prime, e) => + transform(true)(e) + + case OperEx(oper, args @ _*) => + OperEx(oper, args map transform(primeToAdd): _*) + + // TODO: ENABLED and module instances need a special treatment + + case nameEx @ NameEx(name) => + if (primeToAdd && stateVars.contains(name)) { + // add prime to a variable name + OperEx(TlaActionOper.prime, nameEx) + } else { + nameEx + } + + case ex @ LetInEx(body, defs @ _*) => + val newDefs = defs.map(tracker.trackOperDecl { x => + x.copy(body = transform(primeToAdd)(x.body)) + }) + val newBody = transform(primeToAdd)(body) + if (defs == newDefs && body == newBody) ex + else LetInEx(newBody, newDefs: _*) + + case e => e + } + + transform(primeToAdd = false)(expr) } } diff --git a/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestPrimePropagation.scala b/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestPrimePropagation.scala new file mode 100644 index 0000000000..0772287e02 --- /dev/null +++ b/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestPrimePropagation.scala @@ -0,0 +1,68 @@ +package at.forsyte.apalache.tla.lir.transformations.standard + +import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker +import at.forsyte.apalache.tla.lir.{LetInEx, TlaOperDecl} +import org.scalatest.{BeforeAndAfter, FunSuite} + +/** + * Tests of PrimePropagation. + */ +class TestPrimePropagation extends FunSuite with BeforeAndAfter { + import tla._ + + private var transformer: PrimePropagation = _ + + before { + transformer = new PrimePropagation(new IdleTracker, Set("x", "y")) + } + + test("a name should not be primed") { + val input = name("x") + val output = transformer(input) + assert(output == input) + } + + test("a constant should not be primed") { + val input = name("N") + val output = transformer(input) + assert(output == input) + } + + test("a primed variable stays primed") { + val input = prime(name("x")) + val output = transformer(input) + assert(output == input) + } + + test("a primed literal is de-primed") { + val intEx = int(2021) + val input = prime(intEx) + val output = transformer(input) + assert(intEx == output) + } + + test("a primed constant is de-primed") { + val const = name("N") + val input = prime(const) + val output = transformer(input) + assert(const == output) + } + + test("prime is propagated in operator") { + val input = prime(appFun(name("x"), name("y"))) + val output = transformer(input) + val expected = appFun(prime(name("x")), prime(name("y"))) + assert(expected == output) + } + + test("prime is propagated in LET-IN") { + val fooDecl = TlaOperDecl("Foo", List.empty, appFun(name("x"), name("y"))) + val letIn = LetInEx(appOp("Foo"), fooDecl) + val input = prime(letIn) + val output = transformer(input) + val expectedDecl = TlaOperDecl("Foo", List.empty, appFun(prime(name("x")), prime(name("y")))) + val expectedLetIn = LetInEx(appOp("Foo"), expectedDecl) + assert(expectedLetIn == output) + } +}