From fb972ee03d55a8df98dd63ba8c22ae2381bdba23 Mon Sep 17 00:00:00 2001 From: Shon Feder Date: Wed, 5 Oct 2022 01:07:43 -0400 Subject: [PATCH] Return failure data from pass chain executions (#2186) Co-authored-by: Igor Konnov --- .unreleased/features/2186-rpc-data.md | 2 + build.sbt | 4 +- .../forsyte/apalache/infra/passes/Pass.scala | 52 +++++++++++++++- .../infra/passes/PassChainExecutor.scala | 3 +- .../infra/passes/TestPassChainExecutor.scala | 2 +- .../apalache/tla/tooling/opt/CheckCmd.scala | 4 +- .../apalache/tla/tooling/opt/ParseCmd.scala | 2 +- .../apalache/tla/tooling/opt/TestCmd.scala | 4 +- .../tla/tooling/opt/TranspileCmd.scala | 4 +- .../tla/tooling/opt/TypeCheckCmd.scala | 4 +- project/Dependencies.scala | 1 + .../apalache/shai/v1/CmdExecutorService.scala | 60 +++++++++---------- .../shai/v1/TestCmdExecutorService.scala | 20 +++++-- .../forsyte/apalache/tla/bmcmt/Checker.scala | 21 ++++++- ...tCounterexamplesModelCheckerListener.scala | 12 ++-- .../apalache/tla/bmcmt/Counterexample.scala | 47 +++++++++++++++ .../bmcmt/DumpFilesModelCheckerListener.scala | 30 +++++----- .../tla/bmcmt/ModelCheckerListener.scala | 17 ++---- .../apalache/tla/bmcmt/SeqModelChecker.scala | 51 +++++++--------- .../bmcmt/passes/BoundedCheckerPassImpl.scala | 14 +++-- .../tla/bmcmt/search/SearchState.scala | 17 +++--- .../tla/bmcmt/CrossTestEncodings.scala | 18 +++--- ...unterexamplesSeqModelCheckerListener.scala | 30 +++++----- .../tla/bmcmt/TestSeqModelCheckerTrait.scala | 49 ++++++++------- .../apalache/io/json/JsonDecoder.scala | 15 ++++- .../apalache/io/json/JsonRepresentation.scala | 2 +- .../forsyte/apalache/io/json/JsonToTla.scala | 22 +++++-- .../apalache/io/json/impl/TlaToUJson.scala | 8 +++ .../io/lir/ItfCounterexampleWriter.scala | 10 ++-- .../tla/imp/passes/SanyParserPassImpl.scala | 31 +++++----- .../io/lir/TestItfCounterexampleWriter.scala | 4 +- .../tla/pp/passes/PreproPassPartial.scala | 3 +- .../SourceAwareTypeCheckerListener.scala | 19 ++++++ .../tla/typecheck/TypeCheckerTool.scala | 12 ++-- .../RecordingTypeCheckerListener.scala | 13 +++- .../passes/EtcTypeCheckerPassImpl.scala | 33 +++++----- .../passes/LoggingTypeCheckerListener.scala | 17 ++---- 37 files changed, 413 insertions(+), 244 deletions(-) create mode 100644 .unreleased/features/2186-rpc-data.md create mode 100644 tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Counterexample.scala create mode 100644 tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/SourceAwareTypeCheckerListener.scala diff --git a/.unreleased/features/2186-rpc-data.md b/.unreleased/features/2186-rpc-data.md new file mode 100644 index 0000000000..543098ec31 --- /dev/null +++ b/.unreleased/features/2186-rpc-data.md @@ -0,0 +1,2 @@ +Return JSON with success or failure data from RPC calls to the CmdExecutor +service (see #2186). diff --git a/build.sbt b/build.sbt index e725eefef9..2b4842e643 100644 --- a/build.sbt +++ b/build.sbt @@ -136,7 +136,9 @@ lazy val infra = (project in file("mod-infra")) .settings( testSettings, libraryDependencies ++= Seq( - Deps.commonsIo + Deps.commonsIo, + Deps.ujson, + Deps.upickle, ), ) diff --git a/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/Pass.scala b/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/Pass.scala index 5638b40cd1..d9d5b22f03 100644 --- a/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/Pass.scala +++ b/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/Pass.scala @@ -2,7 +2,10 @@ package at.forsyte.apalache.infra.passes import at.forsyte.apalache.infra.ExitCodes.TExitCode import at.forsyte.apalache.infra.passes.Pass.PassResult -import at.forsyte.apalache.tla.lir.{ModuleProperty, TlaModule} +import at.forsyte.apalache.tla.lir.ModuleProperty +import at.forsyte.apalache.tla.lir.TlaModule +import upickle.default.Writer +import upickle.default.writeJs /** *

An analysis or transformation pass. Instead of explicitly setting a pass' input and output, we interconnect passes @@ -16,6 +19,8 @@ import at.forsyte.apalache.tla.lir.{ModuleProperty, TlaModule} * * @author * Igor Konnov + * @author + * Shon Feder */ trait Pass { @@ -69,8 +74,51 @@ trait Pass { */ def transformations: Set[ModuleProperty.Value] + /** + * Construct a failing pass result + * + * To be called to construct a failing `PassResult` in the event that a pass fails. + * + * @param errorData + * Data providing insights into the reasons for the failure. + * @param exitCode + * The exit code to be used when terminating the program. + * @param f + * An implicit upickle writer than can convert the `errorData` into json. You can import `upickle.default._` to get + * implicits for common datatypes. For an example of defining a custom writer, see + * `at.forsyte.apalache.tla.bmcmt.Counterexample`. + */ + def passFailure[E](errorData: E, exitCode: TExitCode)(implicit f: Writer[E]): PassResult = + Left(Pass.PassFailure(name, writeJs(errorData), exitCode)) } object Pass { - type PassResult = Either[TExitCode, TlaModule] + + import upickle.implicits.key + + /** + * Represents a failing pass + * + * @param passName + * The name of the pass which has failed. + * @param errorData + * Data providing insights into the reasons for the failure. + * @param exitCode + * The exit code to be used when terminating the program. + */ + case class PassFailure( + @key("pass_name") passName: String, + @key("error_data") errorData: ujson.Value, + @key("exit_code") exitCode: TExitCode) {} + + /** Implicit conversions for [[PassFailure]] */ + object PassFailure { + import upickle.default.{macroRW, writeJs, ReadWriter} + + implicit val upickleReadWriter: ReadWriter[PassFailure] = macroRW + + implicit val ujsonView: PassFailure => ujson.Value = writeJs + } + + type PassResult = Either[PassFailure, TlaModule] } diff --git a/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/PassChainExecutor.scala b/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/PassChainExecutor.scala index 630a9b617f..f38d753409 100644 --- a/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/PassChainExecutor.scala +++ b/mod-infra/src/main/scala/at/forsyte/apalache/infra/passes/PassChainExecutor.scala @@ -1,6 +1,5 @@ package at.forsyte.apalache.infra.passes -import at.forsyte.apalache.infra.ExitCodes.TExitCode import at.forsyte.apalache.infra.passes.Pass.PassResult import com.typesafe.scalalogging.LazyLogging import at.forsyte.apalache.tla.lir.{MissingTransformationError, TlaModule, TlaModuleProperties} @@ -22,7 +21,7 @@ import at.forsyte.apalache.infra.AdaptedException */ object PassChainExecutor extends LazyLogging { - type PassResultModule = Either[TExitCode, TlaModule with TlaModuleProperties] + type PassResultModule = Either[Pass.PassFailure, TlaModule with TlaModuleProperties] def run[O <: OptionGroup](toolModule: ToolModule[O]): PassResult = { diff --git a/mod-infra/src/test/scala/at/forsyte/apalache/infra/passes/TestPassChainExecutor.scala b/mod-infra/src/test/scala/at/forsyte/apalache/infra/passes/TestPassChainExecutor.scala index 27a9b62a6c..1ce5dd428c 100644 --- a/mod-infra/src/test/scala/at/forsyte/apalache/infra/passes/TestPassChainExecutor.scala +++ b/mod-infra/src/test/scala/at/forsyte/apalache/infra/passes/TestPassChainExecutor.scala @@ -17,7 +17,7 @@ class TestPassChainExecutor extends AnyFunSuite { if (result) { Right(TlaModule("TestModule", Seq())) } else { - Left(ExitCodes.ERROR) + passFailure(None, ExitCodes.ERROR) } } override def dependencies = deps diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala index b6bf9bb9b6..0650f81274 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala @@ -124,8 +124,8 @@ class CheckCmd(name: String = "check", description: String = "Check a TLA+ speci logger.info("Tuning: " + tuning.toList.map { case (k, v) => s"$k=$v" }.mkString(":")) PassChainExecutor.run(new CheckerModule(options)) match { - case Right(_) => Right(s"Checker reports no error up to computation length ${options.checker.length}") - case Left(code) => Left(code, "Checker has found an error") + case Right(_) => Right(s"Checker reports no error up to computation length ${options.checker.length}") + case Left(failure) => Left(failure.exitCode, "Checker has found an error") } } diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/ParseCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/ParseCmd.scala index c27677ed2b..a4a60a990c 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/ParseCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/ParseCmd.scala @@ -36,7 +36,7 @@ class ParseCmd PassChainExecutor.run(new ParserModule(options)) match { case Right(m) => Right(s"Parsed successfully\nRoot module: ${m.name} with ${m.declarations.length} declarations.") - case Left(code) => Left(code, "Parser has failed") + case Left(failure) => Left(failure.exitCode, "Parser has failed") } } } diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TestCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TestCmd.scala index bab3a37a52..a04e2e8ac3 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TestCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TestCmd.scala @@ -73,8 +73,8 @@ class TestCmd logger.info("Tuning: " + tuning.toList.map { case (k, v) => s"$k=$v" }.mkString(":")) PassChainExecutor.run(new CheckerModule(options)) match { - case Right(_) => Right("No example found") - case Left(code) => Left(code, "Found a violation of the postcondition. Check violation.tla.") + case Right(_) => Right("No example found") + case Left(failure) => Left(failure.exitCode, "Found a violation of the postcondition. Check violation.tla.") } } diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TranspileCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TranspileCmd.scala index 4a94da599b..d027636630 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TranspileCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TranspileCmd.scala @@ -26,8 +26,8 @@ class TranspileCmd extends AbstractCheckerCmd(name = "transpile", description = .getOrElse(TlaExToVMTWriter.outFileName) PassChainExecutor.run(new ReTLAToVMTModule(options)) match { - case Right(_) => Right(s"VMT constraints successfully generated at\n$outFilePath") - case Left(code) => Left(code, "Failed to generate constraints") + case Right(_) => Right(s"VMT constraints successfully generated at\n$outFilePath") + case Left(failure) => Left(failure.exitCode, "Failed to generate constraints") } } } diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TypeCheckCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TypeCheckCmd.scala index 1fb94496fb..9722e20062 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TypeCheckCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/TypeCheckCmd.scala @@ -40,8 +40,8 @@ class TypeCheckCmd logger.info("Type checking " + file) PassChainExecutor.run(new TypeCheckerModule(options)) match { - case Right(_) => Right("Type checker [OK]") - case Left(code) => Left(code, "Type checker [FAILED]") + case Right(_) => Right("Type checker [OK]") + case Left(failure) => Left(failure.exitCode, "Type checker [FAILED]") } } } diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 12d221c0bd..175bf0d9cc 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -35,6 +35,7 @@ object Dependencies { val shapeless = "com.chuusai" %% "shapeless" % "2.3.10" val tla2tools = "org.lamport" % "tla2tools" % "1.7.0-SNAPSHOT" val ujson = "com.lihaoyi" %% "ujson" % "2.0.0" + val upickle = "com.lihaoyi" %% "upickle" % "2.0.0" val z3 = "tools.aqua" % "z3-turnkey" % "4.11.2" val zio = "dev.zio" %% "zio" % zioVersion // Keep up to sync with version in plugins.sbt diff --git a/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala b/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala index f9f5f76347..3fa0b86df2 100644 --- a/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala +++ b/shai/src/main/scala/at/forsyte/apalache/shai/v1/CmdExecutorService.scala @@ -1,25 +1,20 @@ package at.forsyte.apalache.shai.v1 +import scala.util.Try + import com.typesafe.scalalogging.Logger import io.grpc.Status -import scala.util.Failure -import scala.util.Success -import zio.ZIO import zio.ZEnv +import zio.ZIO -import at.forsyte.apalache.infra.passes.PassChainExecutor import at.forsyte.apalache.infra.passes.options.OptionGroup +import at.forsyte.apalache.infra.passes.{Pass, PassChainExecutor} import at.forsyte.apalache.io.ConfigManager import at.forsyte.apalache.io.json.impl.TlaToUJson -import at.forsyte.apalache.io.lir.TlaType1PrinterPredefs.printer // Required as implicit parameter to JsonTlaWRiter -import at.forsyte.apalache.shai.v1.cmdExecutor.Cmd -import at.forsyte.apalache.shai.v1.cmdExecutor.{CmdRequest, CmdResponse, ZioCmdExecutor} +import at.forsyte.apalache.shai.v1.cmdExecutor.{Cmd, CmdRequest, CmdResponse, ZioCmdExecutor} import at.forsyte.apalache.tla.bmcmt.config.CheckerModule import at.forsyte.apalache.tla.imp.passes.ParserModule -import at.forsyte.apalache.tla.lir.TlaModule import at.forsyte.apalache.tla.typecheck.passes.TypeCheckerModule -import scala.util.Try -import at.forsyte.apalache.infra.passes.options.Config /** * Provides the [[CmdExecutorService]] @@ -43,17 +38,31 @@ class CmdExecutorService(logger: Logger) extends ZioCmdExecutor.ZCmdExecutor[ZEn def run(req: CmdRequest): Result[CmdResponse] = for { cmd <- validateCmd(req.cmd) resp <- executeCmd(cmd, req.config) match { - case Right(r) => ZIO.succeed(CmdResponse.Result.Success(r)) - case Left(err) => ZIO.succeed(CmdResponse.Result.Failure(err)) + case Right(r) => ZIO.succeed(CmdResponse.Result.Success(r.toString())) + case Left(err) => ZIO.succeed(CmdResponse.Result.Failure(err.toString())) } } yield CmdResponse(resp) - private def executeCmd(cmd: Cmd, cfgStr: String): Either[String, String] = { - // Convert a Try into an `Either` with `Left` the message from a possible `Failure`. - def convErr[O](v: Try[O]) = v.toEither.left.map(e => e.getMessage()) + // Convert pass error results into the JSON representation + private object Converters { + import ujson._ + + def passErr(err: Pass.PassFailure): ujson.Value = { + Obj("error_type" -> "pass_failure", "data" -> err) + } + + def throwableErr(err: Throwable): ujson.Value = + Obj("error_type" -> "unexpected", + "data" -> Obj("msg" -> err.getMessage(), "stack_trace" -> err.getStackTrace().map(_.toString()).toList)) + + def convErr[O](v: Try[O]): Either[ujson.Value, O] = v.toEither.left.map(throwableErr) + } + + import Converters._ + private def executeCmd(cmd: Cmd, cfgStr: String): Either[ujson.Value, ujson.Value] = { for { - cfg <- parseConfig(cfgStr) + cfg <- convErr(ConfigManager(cfgStr)).map(cfg => cfg.copy(common = cfg.common.copy(command = Some("server")))) toolModule <- { import OptionGroup._ @@ -67,13 +76,11 @@ class CmdExecutorService(logger: Logger) extends ZioCmdExecutor.ZCmdExecutor[ZEn } tlaModule <- - try { PassChainExecutor.run(toolModule).left.map(errCode => s"Command failed with error code: ${errCode}") } + try { PassChainExecutor.run(toolModule).left.map(passErr) } catch { - case e: Throwable => Left(s"Command failed with exception: ${e.getMessage()}") + case err: Throwable => Left(throwableErr(err)) } - - json = jsonOfModule(tlaModule) - } yield s"Command succeeded ${json}" + } yield TlaToUJson(tlaModule) } // Allows us to handle invalid protobuf messages on the ZIO level, before @@ -84,15 +91,4 @@ class CmdExecutorService(logger: Logger) extends ZioCmdExecutor.ZCmdExecutor[ZEn ZIO.fail(Status.INVALID_ARGUMENT.withDescription(msg)) case cmd => ZIO.succeed(cmd) } - - private def parseConfig(data: String): Either[String, Config.ApalacheConfig] = { - ConfigManager(data) match { - case Success(cfg) => Right(cfg.copy(common = cfg.common.copy(command = Some("server")))) - case Failure(err) => Left(s"Invalid configuration data given to command: ${err.getMessage()}") - } - } - - private def jsonOfModule(module: TlaModule): String = { - new TlaToUJson(None).makeRoot(Seq(module)).toString - } } diff --git a/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala b/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala index 0fdcb8afef..d842fdc4bb 100644 --- a/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala +++ b/shai/src/test/scala/at/forsyte/apalache/shai/v1/TestCmdExecutorService.scala @@ -80,27 +80,35 @@ object TestCmdExecutorService extends DefaultRunnableSpec { resp <- s.run(runCmd(Cmd.CHECK, checkableSpec)) } yield assert(resp.result.isSuccess)(isTrue) }, - testM("running check on spec with vioalted invariant fails") { + testM("running check on spec with violated invariant fails") { for { s <- ZIO.service[CmdExecutorService] config = Config.ApalacheConfig(checker = Config.Checker(inv = Some(List("Inv")))) resp <- s.run(runCmd(Cmd.CHECK, checkableSpec, cfg = config)) - // error code 12 indicates counterexamples found - } yield assert(resp.result.failure.get)(containsString("Command failed with error code: 12")) + json = ujson.read(resp.result.failure.get) + } yield { + assert(json("error_type").str)(equalTo("pass_failure")) + assert(json("data")("pass_name").str)(equalTo("BoundedChecker")) + assert(json("data")("error_data")("checking_result").str)(equalTo("violation")) + assert(json("data")("error_data")("counterexamples").arr)(isNonEmpty) + } }, testM("typechecking well-typed spec succeeds") { for { s <- ZIO.service[CmdExecutorService] resp <- s.run(runCmd(Cmd.TYPECHECK, trivialSpec)) - // error code 12 indicates counterexamples found } yield assert(resp.result.isSuccess)(isTrue) }, testM("typechecking ill-typed spec returns an error") { for { s <- ZIO.service[CmdExecutorService] resp <- s.run(runCmd(Cmd.TYPECHECK, illTypedSpec)) - // error code 120 indicates a typechecking error - } yield assert(resp.result.failure.get)(containsString("Command failed with error code: 120")) + json = ujson.read(resp.result.failure.get) + } yield { + assert(json("error_type").str)(equalTo("pass_failure")) + assert(json("data")("pass_name").str)(equalTo("TypeCheckerSnowcat")) + assert(json("data")("error_data").arr)(isNonEmpty) + } }, ) // Create the single shared service for use in our tests, allowing us to run diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Checker.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Checker.scala index 90d240437e..c86ec71a94 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Checker.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Checker.scala @@ -13,13 +13,30 @@ object Checker { val isOk: Boolean } + object CheckerResult { + + import upickle.default.{writer, Writer} + import ujson._ + + implicit val ujsonView: CheckerResult => ujson.Value = { + case Error(nerrors, counterexamples) => + Obj("checking_result" -> "Error", "counterexamples" -> counterexamples, "nerrors" -> nerrors) + case Deadlock(counterexample) => + Obj("checking_result" -> "Deadlock", "counterexample" -> counterexample) + case other => + Obj("checking_result" -> other.toString()) + } + + implicit val upickleWriter: Writer[CheckerResult] = writer[ujson.Value].comap(ujsonView) + } + case class NoError() extends CheckerResult { override def toString: String = "NoError" override val isOk: Boolean = true } - case class Error(nerrors: Int) extends CheckerResult { + case class Error(nerrors: Int, counterexamples: Seq[Counterexample]) extends CheckerResult { override def toString: String = s"Error" override val isOk: Boolean = false @@ -28,7 +45,7 @@ object Checker { /** * An execution cannot be extended. We interpret it as a deadlock. */ - case class Deadlock() extends CheckerResult { + case class Deadlock(counterexample: Option[Counterexample]) extends CheckerResult { override def toString: String = "Deadlock" override val isOk: Boolean = false diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/CollectCounterexamplesModelCheckerListener.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/CollectCounterexamplesModelCheckerListener.scala index 7858d77cd5..8eb9c96162 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/CollectCounterexamplesModelCheckerListener.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/CollectCounterexamplesModelCheckerListener.scala @@ -1,7 +1,7 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.trex.DecodedExecution -import at.forsyte.apalache.tla.lir.{TlaEx, TlaModule} +import at.forsyte.apalache.tla.lir.TlaModule import com.typesafe.scalalogging.LazyLogging import scala.collection.mutable.ListBuffer @@ -12,18 +12,16 @@ import scala.collection.mutable.ListBuffer class CollectCounterexamplesModelCheckerListener extends ModelCheckerListener with LazyLogging { override def onCounterexample( - rootModule: TlaModule, - trace: DecodedExecution, - invViolated: TlaEx, + counterexample: Counterexample, errorIndex: Int): Unit = { - _counterExamples += trace + _counterExamples += counterexample } override def onExample(rootModule: TlaModule, trace: DecodedExecution, exampleIndex: Int): Unit = { // ignore the examples } - private val _counterExamples = ListBuffer.empty[DecodedExecution] + private val _counterExamples = ListBuffer.empty[Counterexample] - def counterExamples: Seq[DecodedExecution] = _counterExamples.toSeq + def counterExamples: Seq[Counterexample] = _counterExamples.toSeq } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Counterexample.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Counterexample.scala new file mode 100644 index 0000000000..e23a78302e --- /dev/null +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/Counterexample.scala @@ -0,0 +1,47 @@ +package at.forsyte.apalache.tla.bmcmt + +import at.forsyte.apalache.tla.lir.TlaEx +import at.forsyte.apalache.tla.lir.TlaModule +import at.forsyte.apalache.io.lir.ItfCounterexampleWriter +import at.forsyte.apalache.tla.bmcmt.trex.DecodedExecution + +/** + * Representation of a counterexample found while model checking + * + * @param rootModule + * The checked TLA+ module. + * @param states + * The states leading up to the invariant violation. + * @param invViolated + * The invariant violation to record in the counterexample. Pass + * - for invariant violations: the negated invariant, + * - for deadlocks: `ValEx(TlaBool(true))`, + * - for trace invariants: the applied, negated trace invariant + * + * @author + * Shon Feder + */ +case class Counterexample(module: TlaModule, states: Counterexample.States, invViolated: TlaEx) + +object Counterexample { + type States = List[(String, Map[String, TlaEx])] + + import upickle.default.{writer, Writer} + + // Defines an implicit view for converting to UJSON + implicit val ujsonView: Counterexample => ujson.Value = { case Counterexample(module, states, _) => + ItfCounterexampleWriter.mkJson(module, states) + } + + // Defines an implicit converter for writing with upickle + implicit val upickleWriter: Writer[Counterexample] = + writer[ujson.Value].comap(ujsonView) + + /** Produce a `Counterexample` from a `trace` (rather than from `states`) */ + def apply(module: TlaModule, trace: DecodedExecution, invViolated: TlaEx): Counterexample = { + // TODO(shonfeder): This conversion seems kind of senseless: we just swap the tuple and convert the transition index to + // a string. Lots depends on this particular format, but it seems like a pretty pointless intermediary structure? + val states = trace.path.map(p => (p._2.toString, p._1)) + Counterexample(module, states, invViolated) + } +} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/DumpFilesModelCheckerListener.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/DumpFilesModelCheckerListener.scala index c8e8fae2e1..02e69675ba 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/DumpFilesModelCheckerListener.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/DumpFilesModelCheckerListener.scala @@ -4,7 +4,7 @@ import at.forsyte.apalache.io.lir.CounterexampleWriter import at.forsyte.apalache.tla.bmcmt.trex.DecodedExecution import at.forsyte.apalache.tla.lir.TypedPredefs.BuilderExAsTyped import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.{BoolT1, TlaEx, TlaModule} +import at.forsyte.apalache.tla.lir.{BoolT1, TlaModule} import com.typesafe.scalalogging.LazyLogging /** @@ -20,28 +20,30 @@ import com.typesafe.scalalogging.LazyLogging */ object DumpFilesModelCheckerListener extends ModelCheckerListener with LazyLogging { - override def onCounterexample( - rootModule: TlaModule, - trace: DecodedExecution, - invViolated: TlaEx, - errorIndex: Int): Unit = { - dump(rootModule, trace, invViolated, errorIndex, "violation") + override def onCounterexample(counterexample: Counterexample, errorIndex: Int): Unit = { + dump(counterexample, errorIndex, "violation") } override def onExample(rootModule: TlaModule, trace: DecodedExecution, exampleIndex: Int): Unit = { - dump(rootModule, trace, tla.bool(true).as(BoolT1), exampleIndex, "example") + val counterexample = Counterexample(rootModule, trace, tla.bool(true).as(BoolT1)) + dump(counterexample, exampleIndex, "example") } private def dump( - rootModule: TlaModule, - trace: DecodedExecution, - invViolated: TlaEx, + counterexample: Counterexample, index: Int, prefix: String): Unit = { - val states = trace.path.map(p => (p._2.toString, p._1)) - def dump(suffix: String): List[String] = { - CounterexampleWriter.writeAllFormats(prefix, suffix, rootModule, invViolated, states) + // TODO(shonfeder): Should the CounterexampleWriter take a Counterexample? + // Would require fixing inter-package dependencies, since it would require + // exposing the Counterexample class to the tla-io project. + CounterexampleWriter.writeAllFormats( + prefix, + suffix, + counterexample.module, + counterexample.invViolated, + counterexample.states, + ) } // for a human user, write the latest (counter)example into ${prefix}.{tla,json,...} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelCheckerListener.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelCheckerListener.scala index 795a6c6573..82e4fd77bb 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelCheckerListener.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelCheckerListener.scala @@ -1,7 +1,7 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.trex.DecodedExecution -import at.forsyte.apalache.tla.lir.{TlaEx, TlaModule} +import at.forsyte.apalache.tla.lir.TlaModule /** * Observe [[SeqModelChecker]]. State changes in model checker state are reported via callbacks. @@ -11,23 +11,14 @@ trait ModelCheckerListener { /** * Call when the model checker encounters a counterexample. * - * @param rootModule - * The checked TLA+ module. - * @param trace - * The counterexample trace. - * @param invViolated - * The invariant violation to record in the counterexample. Pass - * - for invariant violations: the negated invariant, - * - for deadlocks: `ValEx(TlaBool(true))`, - * - for trace invariants: the applied, negated trace invariant + * @param counterexample + * The counterexample to record * @param errorIndex * Number of found error (likely [[search.SearchState.nFoundErrors]]). */ // For more on possible trace invariant violations, see the private method `SeqModelChecker.applyTraceInv` def onCounterexample( - rootModule: TlaModule, - trace: DecodedExecution, - invViolated: TlaEx, + counterexample: Counterexample, errorIndex: Int): Unit /** diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SeqModelChecker.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SeqModelChecker.scala index 5d41cda445..e494075dd8 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SeqModelChecker.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SeqModelChecker.scala @@ -3,9 +3,7 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.Checker._ import at.forsyte.apalache.tla.bmcmt.search.ModelCheckerParams.InvariantMode import at.forsyte.apalache.tla.bmcmt.search.{ModelCheckerParams, SearchState} -import at.forsyte.apalache.tla.bmcmt.trex.{ - ConstrainedTransitionExecutor, DecodedExecution, ExecutionSnapshot, TransitionExecutor, -} +import at.forsyte.apalache.tla.bmcmt.trex.{ConstrainedTransitionExecutor, ExecutionSnapshot, TransitionExecutor} import at.forsyte.apalache.tla.lir.TypedPredefs.TypeTagAsTlaType1 import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir._ @@ -108,24 +106,15 @@ class SeqModelChecker[ExecutorContextT]( /** * Notify all listeners that a counterexample has been found. * - * @param rootModule - * The checked TLA+ module. - * @param trace - * The counterexample trace. - * @param invViolated - * The invariant violation to record in the counterexample. Pass - * - for invariant violations: the negated invariant, - * - for deadlocks: `ValEx(TlaBool(true))`, - * - for trace invariants: the applied, negated trace invariant (see [[SeqModelChecker.applyTraceInv]]). + * @param counterexample + * The counterexample to record * @param errorIndex * Number of found error (likely [[SearchState.nFoundErrors]]). */ private def notifyOnError( - rootModule: TlaModule, - trace: DecodedExecution, - invViolated: TlaEx, + counterexample: Counterexample, errorIndex: Int): Unit = { - listeners.foreach(_.onCounterexample(rootModule, trace, invViolated, errorIndex)) + listeners.foreach(_.onCounterexample(counterexample, errorIndex)) } private def makeStep(isNext: Boolean, transitions: Seq[TlaEx]): Unit = { @@ -143,14 +132,16 @@ class SeqModelChecker[ExecutorContextT]( case Some(true) => () // OK case Some(false) => - if (trex.sat(0).contains(true)) { - notifyOnError(checkerInput.rootModule, trex.decodedExecution(), ValEx(TlaBool(true)), - searchState.nFoundErrors) + val counterexample = if (trex.sat(0).contains(true)) { + val cx = Counterexample(checkerInput.rootModule, trex.decodedExecution(), ValEx(TlaBool(true))) + notifyOnError(cx, searchState.nFoundErrors) logger.error("Found a deadlock.") + Some(cx) } else { logger.error(s"Found a deadlock. No SMT model.") + None } - searchState.onResult(Deadlock()) + searchState.onResult(Deadlock(counterexample)) case None => searchState.onResult(RuntimeError()) @@ -289,14 +280,16 @@ class SeqModelChecker[ExecutorContextT]( if (trex.preparedTransitionNumbers.isEmpty) { if (params.checkForDeadlocks) { - if (trex.sat(0).contains(true)) { - notifyOnError(checkerInput.rootModule, trex.decodedExecution(), ValEx(TlaBool(true)), - searchState.nFoundErrors) + val counterexample = if (trex.sat(0).contains(true)) { + val cx = Counterexample(checkerInput.rootModule, trex.decodedExecution(), ValEx(TlaBool(true))) + notifyOnError(cx, searchState.nFoundErrors) logger.error("Found a deadlock.") + Some(cx) } else { logger.error(s"Found a deadlock. No SMT model.") + None } - searchState.onResult(Deadlock()) + searchState.onResult(Deadlock(counterexample)) } else { val msg = "All executions are shorter than the provided bound." logger.warn(msg) @@ -359,8 +352,9 @@ class SeqModelChecker[ExecutorContextT]( trex.sat(params.smtTimeoutSec) match { case Some(true) => - searchState.onResult(Error(1)) - notifyOnError(checkerInput.rootModule, trex.decodedExecution(), notInv, searchState.nFoundErrors) + val counterexample = Counterexample(checkerInput.rootModule, trex.decodedExecution(), notInv) + searchState.onResult(Error(1, Seq(counterexample))) + notifyOnError(counterexample, searchState.nFoundErrors) logger.info(f"State ${stateNo}: ${kind} invariant ${invNo} violated.") excludePathView() @@ -404,8 +398,9 @@ class SeqModelChecker[ExecutorContextT]( trex.sat(params.smtTimeoutSec) match { case Some(true) => - searchState.onResult(Error(1)) - notifyOnError(checkerInput.rootModule, trex.decodedExecution(), traceInvApp, searchState.nFoundErrors) + val counterexample = Counterexample(checkerInput.rootModule, trex.decodedExecution(), traceInvApp) + searchState.onResult(Error(1, Seq(counterexample))) + notifyOnError(counterexample, searchState.nFoundErrors) val msg = "State %d: trace invariant %s violated.".format(stateNo, invNo) logger.error(msg) excludePathView() diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala index 6e28af0f85..b47dd0baea 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala @@ -89,14 +89,18 @@ class BoundedCheckerPassImpl @Inject() ( case Algorithm.Offline => runOfflineChecker(params, input, tuning, solverConfig) } - if (result) Right(module) else Left(ExitCodes.ERROR_COUNTEREXAMPLE) + if (result.isOk) { + Right(module) + } else { + passFailure(result, ExitCodes.ERROR_COUNTEREXAMPLE) + } } private def runIncrementalChecker( params: ModelCheckerParams, input: CheckerInput, tuning: Map[String, String], - solverConfig: SolverConfig): Boolean = { + solverConfig: SolverConfig): Checker.CheckerResult = { val solverContext: RecordingSolverContext = RecordingSolverContext.createZ3(None, solverConfig) val metricProfilerListener = @@ -131,14 +135,14 @@ class BoundedCheckerPassImpl @Inject() ( val outcome = checker.run() rewriter.dispose() logger.info(s"The outcome is: " + outcome) - outcome.isOk + outcome } private def runOfflineChecker( params: ModelCheckerParams, input: CheckerInput, tuning: Map[String, String], - solverConfig: SolverConfig): Boolean = { + solverConfig: SolverConfig): Checker.CheckerResult = { val solverContext: RecordingSolverContext = RecordingSolverContext.createZ3(None, solverConfig) if (solverConfig.profile) { @@ -166,7 +170,7 @@ class BoundedCheckerPassImpl @Inject() ( val outcome = checker.run() rewriter.dispose() logger.info(s"The outcome is: " + outcome) - outcome.isOk + outcome } /* diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/search/SearchState.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/search/SearchState.scala index ecc96508e9..0981dbf004 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/search/SearchState.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/search/SearchState.scala @@ -1,6 +1,7 @@ package at.forsyte.apalache.tla.bmcmt.search -import at.forsyte.apalache.tla.bmcmt.Checker +import at.forsyte.apalache.tla.bmcmt.{Checker, Counterexample} +import scala.collection.mutable.ListBuffer /** * The search state machine that is implemented by SeqModelChecker. This machine is simple when the model checker fails @@ -17,6 +18,7 @@ class SearchState(params: ModelCheckerParams) { private var _result: CheckerResult = NoError() private var _nFoundErrors: Int = 0 + private val _counterexamples: ListBuffer[Counterexample] = ListBuffer.empty private var _nRunsLeft: Int = if (params.isRandomSimulation) params.nSimulationRuns else 1 @@ -46,7 +48,7 @@ class SearchState(params: ModelCheckerParams) { _result match { case NoError() => if (_nFoundErrors > 0) { - Error(_nFoundErrors) + Error(_nFoundErrors, _counterexamples.toList) } else { NoError() } @@ -77,22 +79,23 @@ class SearchState(params: ModelCheckerParams) { */ def onResult(result: CheckerResult): Unit = { result match { - case Error(_) => + case Error(_, counterexamples) => _nFoundErrors += 1 + _counterexamples.appendAll(counterexamples) if (_nFoundErrors >= params.nMaxErrors) { // go to an error state, as the maximum number of errors has been reached - _result = Error(_nFoundErrors) + _result = Error(_nFoundErrors, _counterexamples.toList) } else { // the search may continue, to discover more errors _result = NoError() } - case Deadlock() => + case Deadlock(counterexample) => if (_nFoundErrors > 0) { // this deadlock is probably caused by exclusion of previous counterexamples, so it may be a false positive - _result = Error(_nFoundErrors) + _result = Error(_nFoundErrors, _counterexamples.toList) } else { - _result = Deadlock() + _result = Deadlock(counterexample) } case _ => diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/CrossTestEncodings.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/CrossTestEncodings.scala index faa009ab8b..f55afdc52b 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/CrossTestEncodings.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/CrossTestEncodings.scala @@ -198,15 +198,15 @@ trait CrossTestEncodings extends AnyFunSuite with Checkers { val checker = new SeqModelChecker(checkerParams, checkerInput, trex, Seq(listener)) // check the outcome - val outcome = checker.run() - if (outcome != Error(1)) { - Left(outcome) - } else { - // extract witness expression from the counterexample - assert(listener.counterExamples.length == 1) // () --(init transition)--> initial state - val cex = listener.counterExamples.head.path - val (binding, _) = cex.last // initial state binding - Right(binding) + checker.run() match { + case Error(1, _) => + // extract witness expression from the counterexample + assert(listener.counterExamples.length == 1) // () --(init transition)--> initial state + val cex = listener.counterExamples.head.states + val (_, binding) = cex.last // initial state binding + Right(binding) + + case outcome => Left(outcome) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCollectCounterexamplesSeqModelCheckerListener.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCollectCounterexamplesSeqModelCheckerListener.scala index 1c6ca6c415..daff197e46 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCollectCounterexamplesSeqModelCheckerListener.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCollectCounterexamplesSeqModelCheckerListener.scala @@ -48,6 +48,11 @@ class TestCollectCounterexamplesModelCheckerListener extends AnyFunSuite { (listener, checker) } + private def assertResultHasNErrors(n: Int, result: Checker.CheckerResult) = assert(result match { + case Error(m, _) if m == n => true + case _ => false + }) + test("finds cex for invariant violation at initial state") { // construct TLA+ module val initTrans = List(mkAssign("x", 2)) @@ -58,16 +63,15 @@ class TestCollectCounterexamplesModelCheckerListener extends AnyFunSuite { // check the outcome val (listener, checker) = getChecker(module, initTrans, nextTrans, inv, 1) - val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, checker.run()) // check the counterexample assert(listener.counterExamples.length == 1) - val cex = listener.counterExamples.head.path + val cex = listener.counterExamples.head.states assert(cex.length == 2) // () --(init transition)--> initial state - assert(cex.forall(_._2 == 0)) // state number - assert(cex.head._1.isEmpty) // empty binding on 0th state - val (binding, _) = cex.last + assert(cex.forall(_._1 == "0")) // state number + assert(cex.head._2.isEmpty) // empty binding on 0th state + val (_, binding) = cex.last assert(binding.contains("x")) val valOfX = binding("x") assert(valOfX.isInstanceOf[ValEx]) @@ -84,16 +88,15 @@ class TestCollectCounterexamplesModelCheckerListener extends AnyFunSuite { // check the outcome val (listener, checker) = getChecker(module, initTrans, nextTrans, inv, 1) - val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, checker.run()) // check the counterexample assert(listener.counterExamples.length == 1) - val cex = listener.counterExamples.head.path + val cex = listener.counterExamples.head.states assert(cex.length == 3) // () --(init transition)--> initial state - assert(cex.forall(_._2 == 0)) // state number - assert(cex.head._1.isEmpty) // empty binding on 0th state - val (binding, _) = cex.last + assert(cex.forall(_._1 == "0")) // state number + assert(cex.head._2.isEmpty) // empty binding on 0th state + val (_, binding) = cex.last assert(binding.contains("x")) val valOfX = binding("x") assert(valOfX.isInstanceOf[ValEx]) @@ -110,8 +113,7 @@ class TestCollectCounterexamplesModelCheckerListener extends AnyFunSuite { // check the outcome val (listener, checker) = getChecker(module, initTrans, nextTrans, inv, 3) - val outcome = checker.run() - assert(Error(3) == outcome) + assertResultHasNErrors(3, checker.run()) // check the counterexample assert(listener.counterExamples.length == 3) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelCheckerTrait.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelCheckerTrait.scala index 81029dd4dd..c66cfe9d8c 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelCheckerTrait.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelCheckerTrait.scala @@ -34,6 +34,11 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { TlaModule("root", List(TlaVarDecl("x")(intTag), TlaVarDecl("y")(intTag))) } + private def assertResultHasNErrors(n: Int, result: Checker.CheckerResult) = assert(result match { + case Error(m, _) if m == n => true + case _ => false + }) + test("Init + Inv => OK") { rewriter: SymbStateRewriter => // x' <- 2 val initTrans = List(mkAssign("x", 2)) @@ -65,7 +70,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("ConstInit + Init => OK") { rewriter: SymbStateRewriter => @@ -105,7 +110,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assert(outcome match { case Error(1, _) => true; case _ => false }) } test("Init, deadlock") { rewriter: SymbStateRewriter => @@ -119,7 +124,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Deadlock() == outcome) + assert(outcome match { case Deadlock(_) => true; case _ => false }) } test("Init, 2 options, OK") { rewriter: SymbStateRewriter => @@ -170,7 +175,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + Inv (before + all-enabled) => ERR") { rewriter: SymbStateRewriter => @@ -192,7 +197,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (before + all-enabled) => ERR") { rewriter: SymbStateRewriter => @@ -220,7 +225,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (before + all-enabled) => OK") { rewriter: SymbStateRewriter => @@ -270,7 +275,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + Inv (after + no-all-enabled) => ERR") { rewriter: SymbStateRewriter => @@ -292,7 +297,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (after + all-enabled) => ERR") { rewriter: SymbStateRewriter => @@ -320,7 +325,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (after + all-enabled) => OK") { rewriter: SymbStateRewriter => @@ -376,7 +381,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (before) => OK") { rewriter: SymbStateRewriter => @@ -432,7 +437,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 10 + ActionInv (after) => OK") { rewriter: SymbStateRewriter => @@ -522,7 +527,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next x 2 (LET-IN) + Inv => ERR") { rewriter: SymbStateRewriter => @@ -549,7 +554,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("determinstic Init + 2 steps (regression)") { rewriter: SymbStateRewriter => @@ -591,7 +596,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Deadlock() == outcome) + assert(outcome match { case Deadlock(_) => true; case _ => false }) } test("Init + Next, 10 steps, OK") { rewriter: SymbStateRewriter => @@ -623,7 +628,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Deadlock() == outcome) + assert(outcome match { case Deadlock(_) => true; case _ => false }) } test("Init + Next + Inv x 10 => OK") { rewriter: SymbStateRewriter => @@ -667,7 +672,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next + Inv x 2 => OK, edge case") { rewriter: SymbStateRewriter => @@ -755,7 +760,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("cInit + Init + Next, 10 steps") { rewriter: SymbStateRewriter => @@ -789,7 +794,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - assert(Error(1) == outcome) + assertResultHasNErrors(1, outcome) } test("Init + Next, 10 steps and filter") { rewriter: SymbStateRewriter => @@ -850,13 +855,7 @@ trait TestSeqModelCheckerTrait extends FixtureAnyFunSuite { val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) val checker = new SeqModelChecker(params, checkerInput, trex) val outcome = checker.run() - outcome match { - case Error(nerrors) => - assert(4 == nerrors) - - case _ => - fail("Expected 4 errors") - } + assertResultHasNErrors(4, outcome) } private def mkAssign(varName: String, value: Int): TlaEx = { diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonDecoder.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonDecoder.scala index 37ce18dcfc..7f058178b9 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonDecoder.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonDecoder.scala @@ -1,6 +1,7 @@ package at.forsyte.apalache.io.json import at.forsyte.apalache.tla.lir.{TlaDecl, TlaEx, TlaModule} +import scala.util.Try /** * A JsonDecoder defines a conversion from a json (as represented by T) to a TLA+ expression/declaration/module @@ -12,5 +13,17 @@ trait JsonDecoder[T <: JsonRepresentation] { def asTlaModule(moduleJson: T): TlaModule def asTlaDecl(declJson: T): TlaDecl def asTlaEx(exJson: T): TlaEx - def fromRoot(rootJson: T): Iterable[TlaModule] + def fromRoot(rootJson: T): Seq[TlaModule] + + /** + * Parse a json representation which holds only a single TLA module. This is our typical, and currently only supported + * use case. + * + * @param json + * A JSON encoding of a TLA Module + * @return + * `Success(m)` if the `json` can be parsed, correctly and it described a single module. Otherwise `Failure(t)` + * where `t` is a `Throwable` describing the error. + */ + def fromSingleModule(json: T): Try[TlaModule] } diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonRepresentation.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonRepresentation.scala index ad11a03645..e0eb66f89f 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonRepresentation.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonRepresentation.scala @@ -6,7 +6,7 @@ package at.forsyte.apalache.io.json */ trait JsonRepresentation { - /** The type of tused to represent JSON */ + /** The type used to represent JSON */ type Value def toString: String diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonToTla.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonToTla.scala index 9d3910beaa..8b0956ece9 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonToTla.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/json/JsonToTla.scala @@ -7,6 +7,9 @@ import at.forsyte.apalache.tla.lir.values._ import convenience.tla import UntypedPredefs._ import at.forsyte.apalache.io.lir.TypeTagReader +import scala.util.Try +import scala.util.Success +import scala.util.Failure /** * A semi-abstraction of a json decoder. It is independent of the concrete JsonRepresentation, resp. ScalaFactory @@ -206,12 +209,19 @@ class JsonToTla[T <: JsonRepresentation]( val versionField = getOrThrow(rootJson, TlaToJson.versionFieldName) val version = scalaFactory.asStr(versionField) val current = JsonVersion.current - if (version != current) + if (version != current) { throw new JsonDeserializationError(s"JSON version is $version, expected $current.") - - val modulesField = getOrThrow(rootJson, "modules") - val modulesObjSeq = scalaFactory.asSeq(modulesField) - - modulesObjSeq.map(asTlaModule) + } else { + val modulesField = getOrThrow(rootJson, "modules") + scalaFactory.asSeq(modulesField).map(asTlaModule) + } } + + override def fromSingleModule(json: T): Try[TlaModule] = for { + modules <- Try(fromRoot(json)) + module <- modules match { + case m +: Nil => Success(m) + case _ => Failure(new JsonDeserializationError(s"JSON included more than one module")) + } + } yield module } diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/json/impl/TlaToUJson.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/json/impl/TlaToUJson.scala index 88aa818c78..5ec7beb6ff 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/json/impl/TlaToUJson.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/json/impl/TlaToUJson.scala @@ -3,6 +3,8 @@ package at.forsyte.apalache.io.json.impl import at.forsyte.apalache.io.json.TlaToJson import at.forsyte.apalache.io.lir.TypeTagPrinter import at.forsyte.apalache.tla.lir.storage.SourceLocator +import at.forsyte.apalache.tla.lir.TlaModule +import at.forsyte.apalache.io.lir.TlaType1PrinterPredefs.printer // Required as implicit parameter to JsonTlaWRiter /** * A json encoder, using the UJson representation @@ -11,3 +13,9 @@ class TlaToUJson( locatorOpt: Option[SourceLocator] = None )(implicit typeTagPrinter: TypeTagPrinter) extends TlaToJson[UJsonRep](UJsonFactory, locatorOpt)(typeTagPrinter) + +object TlaToUJson { + def apply(module: TlaModule): ujson.Value = (new TlaToUJson()).makeRoot(Seq(module)).value + + implicit val ujsonView: TlaModule => ujson.Value = TlaToUJson(_) +} diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/lir/ItfCounterexampleWriter.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/lir/ItfCounterexampleWriter.scala index 9f31391fe0..12e822ee9a 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/lir/ItfCounterexampleWriter.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/lir/ItfCounterexampleWriter.scala @@ -19,6 +19,12 @@ import scala.collection.mutable * Igor Konnov */ class ItfCounterexampleWriter(writer: PrintWriter) extends CounterexampleWriter { + override def write(rootModule: TlaModule, notInvariant: NotInvariant, states: List[NextState]): Unit = { + writer.write(ujson.write(ItfCounterexampleWriter.mkJson(rootModule, states), indent = 2)) + } +} + +object ItfCounterexampleWriter { /** * The minimal value that can be reliably represented with Double in JavaScript. @@ -75,10 +81,6 @@ class ItfCounterexampleWriter(writer: PrintWriter) extends CounterexampleWriter ujson.Obj(rootMap) } - override def write(rootModule: TlaModule, notInvariant: NotInvariant, states: List[NextState]): Unit = { - writer.write(ujson.write(mkJson(rootModule, states), indent = 2)) - } - private def varsToJson(root: TlaModule): ujson.Value = { val names = root.declarations.collect { case TlaVarDecl(name) => ujson.Str(name) diff --git a/tla-io/src/main/scala/at/forsyte/apalache/tla/imp/passes/SanyParserPassImpl.scala b/tla-io/src/main/scala/at/forsyte/apalache/tla/imp/passes/SanyParserPassImpl.scala index 2b129f2886..aa9d8bb341 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/tla/imp/passes/SanyParserPassImpl.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/tla/imp/passes/SanyParserPassImpl.scala @@ -1,7 +1,7 @@ package at.forsyte.apalache.tla.imp.passes import at.forsyte.apalache.infra.ExitCodes -import at.forsyte.apalache.infra.passes.Pass.PassResult +import at.forsyte.apalache.infra.passes.Pass.{PassFailure, PassResult} import at.forsyte.apalache.io.annotations.store._ import at.forsyte.apalache.io.json.impl.{DefaultTagReader, UJsonRep, UJsonToTla} import at.forsyte.apalache.tla.imp.src.SourceStore @@ -16,6 +16,9 @@ import java.io.File import at.forsyte.apalache.infra.passes.options.SourceOption import scala.io.Source import at.forsyte.apalache.infra.passes.options.OptionGroup +import scala.util.Try +import scala.util.Failure +import scala.util.Success /** * Parsing TLA+ code with SANY. @@ -40,21 +43,17 @@ class SanyParserPassImpl @Inject() ( case _ => throw new IllegalArgumentException("loadFromJsonSource called with non Json SourceOption") } - try { - val moduleJson = UJsonRep(ujson.read(readable)) - val modules = new UJsonToTla(Some(sourceStore))(DefaultTagReader).fromRoot(moduleJson) - modules match { - case rMod +: Nil => Right(rMod) - case _ => { - logger.error(s" > Error parsing file ${source}") - Left(ExitCodes.ERROR_SPEC_PARSE) - } - } - } catch { - case e: Exception => + val result = for { + moduleJson <- Try(UJsonRep(ujson.read(readable))) + module <- new UJsonToTla(Some(sourceStore))(DefaultTagReader).fromSingleModule(moduleJson) + } yield module + + result match { + case Success(mod) => Right(mod) + case Failure(err) => logger.error(s" > Error parsing file ${source}") - logger.error(" > " + e.getMessage) - Left(ExitCodes.ERROR_SPEC_PARSE) + logger.error(" > " + err.getMessage) + passFailure(err.getMessage(), ExitCodes.ERROR_SPEC_PARSE) } } @@ -72,7 +71,7 @@ class SanyParserPassImpl @Inject() ( Right(modules.get(rootName).get) } - private def saveLoadedModule(module: TlaModule): Either[ExitCodes.TExitCode, Unit] = { + private def saveLoadedModule(module: TlaModule): Either[PassFailure, Unit] = { // save the output writeOut(writerFactory, module) // write parser output to specified destination, if requested diff --git a/tla-io/src/test/scala/at/forsyte/apalache/io/lir/TestItfCounterexampleWriter.scala b/tla-io/src/test/scala/at/forsyte/apalache/io/lir/TestItfCounterexampleWriter.scala index 4a330f916c..f14be831dc 100644 --- a/tla-io/src/test/scala/at/forsyte/apalache/io/lir/TestItfCounterexampleWriter.scala +++ b/tla-io/src/test/scala/at/forsyte/apalache/io/lir/TestItfCounterexampleWriter.scala @@ -8,7 +8,6 @@ import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner -import java.io.{PrintWriter, StringWriter} import scala.collection.immutable.SortedMap @RunWith(classOf[JUnitRunner]) @@ -25,8 +24,7 @@ class TestItfCounterexampleWriter extends AnyFunSuite { * the expected output as a string */ def compareJson(rootModule: TlaModule, states: List[NextState], expected: String): Unit = { - val writer = new ItfCounterexampleWriter(new PrintWriter(new StringWriter())) - val actualJson = writer.mkJson(rootModule, states) + val actualJson = ItfCounterexampleWriter.mkJson(rootModule, states) // erase the date from the description as it is time dependent actualJson("#meta")("description") = "Created by Apalache" val expectedJson = ujson.read(expected) diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassPartial.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassPartial.scala index 84f8ef014c..a7f8a002ad 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassPartial.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassPartial.scala @@ -71,7 +71,8 @@ abstract class PreproPassPartial( val message = "%s: unsupported expression: %s".format(findLoc(id), errorMessage) logger.error(message) } - Left(ExitCodes.FAILURE_SPEC_EVAL) + val errData = failedIds.map { case (uid, s) => (uid.toString(), s) } + passFailure(errData, ExitCodes.FAILURE_SPEC_EVAL) } protected def executeWithParams( diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/SourceAwareTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/SourceAwareTypeCheckerListener.scala new file mode 100644 index 0000000000..19d3400514 --- /dev/null +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/SourceAwareTypeCheckerListener.scala @@ -0,0 +1,19 @@ +package at.forsyte.apalache.tla.typecheck + +import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} +import at.forsyte.apalache.tla.imp.src.SourceStore +import at.forsyte.apalache.tla.lir.UID + +/** A [[TypeCheckerListener]] that has a source store and and a change listener */ +abstract class SourceAwareTypeCheckerListener(sourceStore: SourceStore, changeListener: ChangeListener) + extends TypeCheckerListener { + + protected def findLoc(id: UID): String = { + val sourceLocator: SourceLocator = SourceLocator(sourceStore.makeSourceMap, changeListener) + + sourceLocator.sourceOf(id) match { + case Some(loc) => loc.toString + case None => "unknown location" + } + } +} diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala index 95f3bcffaa..84b9506483 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala @@ -5,10 +5,11 @@ import at.forsyte.apalache.io.annotations.{Annotation, AnnotationStr, StandardAn import at.forsyte.apalache.io.typecheck.parser.{DefaultType1Parser, Type1ParseError} import at.forsyte.apalache.tla.lir import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.typecheck.etc._ -import at.forsyte.apalache.tla.typecheck.integration.{RecordingTypeCheckerListener, TypeRewriter} import com.typesafe.scalalogging.LazyLogging +import at.forsyte.apalache.tla.typecheck.integration.RecordingTypeCheckerListener +import at.forsyte.apalache.tla.typecheck.integration.TypeRewriter +import at.forsyte.apalache.tla.imp.src.SourceStore /** * The API to the type checker. It first translates a TLA+ module into EtcExpr and then does the type checking. @@ -62,6 +63,8 @@ class TypeCheckerTool(annotationStore: AnnotationStore, inferPoly: Boolean, useR * Check the types in a module and, if the module is well-typed, produce a new module that attaches a type tag to * every expression and declaration in the module. * + * Only used in tests. + * * @param tracker * a transformation tracker that is applied when expressions and declarations are tagged * @param listener @@ -74,11 +77,12 @@ class TypeCheckerTool(annotationStore: AnnotationStore, inferPoly: Boolean, useR * Some(newModule) if module is well-typed; None, otherwise */ def checkAndTag( - tracker: TransformationTracker, + tracker: lir.transformations.TransformationTracker, listener: TypeCheckerListener, defaultTag: UID => TypeTag, module: TlaModule): Option[TlaModule] = { - val recorder = new RecordingTypeCheckerListener() + // The source stores and ChangeListeners for this aren't needed, since it's only run in tests + val recorder = new RecordingTypeCheckerListener(new SourceStore(), new lir.storage.ChangeListener()) if (!check(new MultiTypeCheckerListener(listener, recorder), module)) { None } else { diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala index d6ec753c41..6e89f37c0a 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala @@ -1,8 +1,10 @@ package at.forsyte.apalache.tla.typecheck.integration +import at.forsyte.apalache.tla.lir.storage.ChangeListener +import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.{TlaType1, UID} import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} -import at.forsyte.apalache.tla.typecheck.TypeCheckerListener +import at.forsyte.apalache.tla.typecheck.SourceAwareTypeCheckerListener import scala.collection.mutable @@ -13,13 +15,18 @@ import scala.collection.mutable * @author * Igor Konnov */ -class RecordingTypeCheckerListener extends TypeCheckerListener { +class RecordingTypeCheckerListener(sourceStore: SourceStore, changeListener: ChangeListener) + extends SourceAwareTypeCheckerListener(sourceStore, changeListener) { private val uidToType: mutable.Map[UID, TlaType1] = mutable.Map[UID, TlaType1]() def toMap: Map[UID, TlaType1] = { uidToType.toMap } + private val _errors: mutable.ListBuffer[(String, String)] = mutable.ListBuffer.empty + + def getErrors(): List[(String, String)] = _errors.toList + override def onTypeFound(sourceRef: ExactRef, monotype: TlaType1): Unit = { uidToType += sourceRef.tlaId -> monotype } @@ -33,6 +40,6 @@ class RecordingTypeCheckerListener extends TypeCheckerListener { * the error description */ override def onTypeError(sourceRef: EtcRef, message: String): Unit = { - // ignore + _errors += (findLoc(sourceRef.tlaId) -> message) } } diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala index a39906cc39..6fd90b719b 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala @@ -13,6 +13,9 @@ import at.forsyte.apalache.tla.typecheck.TypeCheckerTool import com.google.inject.Inject import com.typesafe.scalalogging.LazyLogging import at.forsyte.apalache.infra.passes.options.OptionGroup +import at.forsyte.apalache.tla.typecheck.integration.RecordingTypeCheckerListener +import at.forsyte.apalache.tla.typecheck.MultiTypeCheckerListener +import at.forsyte.apalache.tla.typecheck.integration.TypeRewriter class EtcTypeCheckerPassImpl @Inject() ( val options: OptionGroup.HasTypechecker, @@ -47,21 +50,21 @@ class EtcTypeCheckerPassImpl @Inject() ( Untyped } - val listener = new LoggingTypeCheckerListener(sourceStore, changeListener, inferPoly) - val taggedModule = tool.checkAndTag(tracker, listener, defaultTag, tlaModule) - - taggedModule match { - case Some(newModule) => - logger.info(" > Your types are purrfect!") - logger.info(if (isTypeCoverageComplete) " > All expressions are typed" else " > Some expressions are untyped") - writeOut(writerFactory, newModule) - - utils.writeToOutput(newModule, options, writerFactory, logger, sourceStore) - - Right(newModule) - case None => - logger.info(" > Snowcat asks you to fix the types. Meow.") - Left(ExitCodes.ERROR_TYPECHECK) + val loggingListener = new LoggingTypeCheckerListener(sourceStore, changeListener, inferPoly) + val recordingListener = new RecordingTypeCheckerListener(sourceStore, changeListener) + val listener = new MultiTypeCheckerListener(loggingListener, recordingListener) + if (tool.check(listener, tlaModule)) { + val transformer = new TypeRewriter(tracker, defaultTag)(recordingListener.toMap) + val taggedDecls = tlaModule.declarations.map(transformer(_)) + val newModule = TlaModule(tlaModule.name, taggedDecls) + logger.info(" > Your types are purrfect!") + logger.info(if (isTypeCoverageComplete) " > All expressions are typed" else " > Some expressions are untyped") + writeOut(writerFactory, newModule) + utils.writeToOutput(newModule, options, writerFactory, logger, sourceStore) + Right(newModule) + } else { + logger.info(" > Snowcat asks you to fix the types. Meow.") + passFailure(recordingListener.getErrors(), ExitCodes.ERROR_TYPECHECK) } } diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala index 7b503a93be..2e546b678c 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala @@ -1,17 +1,17 @@ package at.forsyte.apalache.tla.typecheck.passes import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} -import at.forsyte.apalache.tla.lir.{TlaType1, UID} +import at.forsyte.apalache.tla.lir.storage.ChangeListener +import at.forsyte.apalache.tla.lir.TlaType1 import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} -import at.forsyte.apalache.tla.typecheck.{TypeCheckerListener, TypingInputException} +import at.forsyte.apalache.tla.typecheck.{SourceAwareTypeCheckerListener, TypingInputException} import com.typesafe.scalalogging.LazyLogging class LoggingTypeCheckerListener( sourceStore: SourceStore, changeListener: ChangeListener, isPolymorphismEnabled: Boolean) - extends TypeCheckerListener with LazyLogging { + extends SourceAwareTypeCheckerListener(sourceStore, changeListener) with LazyLogging { /** * This method is called when the type checker finds the type of an expression. @@ -41,13 +41,4 @@ class LoggingTypeCheckerListener( override def onTypeError(sourceRef: EtcRef, message: String): Unit = { logger.error("[%s]: %s".format(findLoc(sourceRef.tlaId), message)) } - - private def findLoc(id: UID): String = { - val sourceLocator: SourceLocator = SourceLocator(sourceStore.makeSourceMap, changeListener) - - sourceLocator.sourceOf(id) match { - case Some(loc) => loc.toString - case None => "unknown location" - } - } }