diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 5ae6e3a538..95802f0231 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -24,6 +24,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC} import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.payment.IncomingPacket.NodeRelayPacket import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.OutgoingPacket.Upstream import fr.acinq.eclair.payment._ @@ -75,13 +76,29 @@ object NodeRelay { } } - def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] = + def apply(nodeParams: NodeParams, + parent: akka.actor.typed.ActorRef[NodeRelayer.Command], + register: ActorRef, + relayId: UUID, + nodeRelayPacket: NodeRelayPacket, + paymentSecret: ByteVector32, + outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] = Behaviors.setup { context => + val paymentHash = nodeRelayPacket.add.paymentHash + val totalAmountIn = nodeRelayPacket.outerPayload.totalAmount Behaviors.withMdc(Logs.mdc( category_opt = Some(Logs.LogCategory.PAYMENT), parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment paymentHash_opt = Some(paymentHash))) { - new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)() + context.log.info("relaying payment relayId={}", relayId) + val mppFsmAdapters = { + context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived) + context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed) + context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded) + }.toClassic + val incomingPaymentHandler = context.actorOf(MultiPartPaymentFSM.props(nodeParams, paymentHash, totalAmountIn, mppFsmAdapters)) + new NodeRelay(nodeParams, parent, register, relayId, paymentHash, paymentSecret, context, outgoingPaymentFactory) + .receiving(Queue.empty, nodeRelayPacket.innerPayload, nodeRelayPacket.nextPacket, incomingPaymentHandler) } } @@ -144,66 +161,37 @@ class NodeRelay private(nodeParams: NodeParams, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, + paymentSecret: ByteVector32, context: ActorContext[NodeRelay.Command], outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) { import NodeRelay._ - private val mppFsmAdapters = { - context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived) - context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed) - context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded) - }.toClassic - private val payFsmAdapters = { - context.messageAdapter[PreimageReceived](WrappedPreimageReceived) - context.messageAdapter[PaymentSent](WrappedPaymentSent) - context.messageAdapter[PaymentFailed](WrappedPaymentFailed) - }.toClassic - - def apply(): Behavior[Command] = - Behaviors.receiveMessagePartial { - // We make sure we receive all payment parts before forwarding to the next trampoline node. - case Relay(IncomingPacket.NodeRelayPacket(add, outer, inner, next)) => outer.paymentSecret match { - case None => - // TODO: @pm: maybe those checks should be done later in the flow (by the mpp FSM?) - context.log.warn("rejecting htlcId={}: missing payment secret", add.id) - rejectHtlc(add.id, add.channelId, add.amountMsat) - stopping() - case Some(secret) => - import akka.actor.typed.scaladsl.adapter._ - context.log.info("relaying payment relayId={}", relayId) - val mppFsm = context.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, outer.totalAmount, mppFsmAdapters)) - context.log.debug("forwarding incoming htlc to the payment FSM") - mppFsm ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add) - receiving(Queue(add), secret, inner, next, mppFsm) - } - } - /** * We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next * trampoline node and forward the payment. * * @param htlcs received incoming HTLCs for this set. - * @param secret all incoming HTLCs in this set must have the same secret to protect against probing / fee theft. * @param nextPayload relay instructions (should be identical across HTLCs in this set). * @param nextPacket trampoline onion to relay to the next trampoline node. * @param handler actor handling the aggregation of the incoming HTLC set. */ - private def receiving(htlcs: Queue[UpdateAddHtlc], secret: ByteVector32, nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] = + private def receiving(htlcs: Queue[UpdateAddHtlc], nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] = Behaviors.receiveMessagePartial { case Relay(IncomingPacket.NodeRelayPacket(add, outer, _, _)) => outer.paymentSecret match { + // TODO: @pm: maybe those checks should be done by the mpp FSM? case None => - context.log.warn("rejecting htlcId={}: missing payment secret", add.id) + context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", add.id, add.channelId) rejectHtlc(add.id, add.channelId, add.amountMsat) Behaviors.same - case Some(incomingSecret) if incomingSecret != secret => - context.log.warn("rejecting htlcId={}: payment secret doesn't match other HTLCs in the set", add.id) + case Some(incomingSecret) if incomingSecret != paymentSecret => + context.log.warn("rejecting htlc #{} from channel {}: payment secret doesn't match other HTLCs in the set", add.id, add.channelId) rejectHtlc(add.id, add.channelId, add.amountMsat) Behaviors.same - case Some(incomingSecret) if incomingSecret == secret => - context.log.debug("forwarding incoming htlc to the payment FSM") + case Some(incomingSecret) if incomingSecret == paymentSecret => + context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", add.id, add.channelId) handler ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add) - receiving(htlcs :+ add, secret, nextPayload, nextPacket, handler) + receiving(htlcs :+ add, nextPayload, nextPacket, handler) } case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) => context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure) @@ -267,7 +255,7 @@ class NodeRelay private(nodeParams: NodeParams, * Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us. */ private def stopping(): Behavior[Command] = { - parent ! NodeRelayer.RelayComplete(context.self, paymentHash) + parent ! NodeRelayer.RelayComplete(context.self, paymentHash, paymentSecret) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { case Stop => Behaviors.stopped @@ -275,6 +263,12 @@ class NodeRelay private(nodeParams: NodeParams, } } + private val payFsmAdapters = { + context.messageAdapter[PreimageReceived](WrappedPreimageReceived) + context.messageAdapter[PaymentSent](WrappedPaymentSent) + context.messageAdapter[PaymentFailed](WrappedPaymentFailed) + }.toClassic + private def relay(upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload, packetOut: OnionRoutingPacket): ActorRef = { val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, Nil) val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) @@ -322,7 +316,7 @@ class NodeRelay private(nodeParams: NodeParams, } private def rejectExtraHtlc(add: UpdateAddHtlc): Unit = { - context.log.warn("rejecting extra htlcId={}", add.id) + context.log.warn("rejecting extra htlc #{} from channel {}", add.id, add.channelId) rejectHtlc(add.id, add.channelId, add.amountMsat) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index 737e08f1a6..05a8425c58 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -19,7 +19,10 @@ package fr.acinq.eclair.payment.relay import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.eclair.channel.CMD_FAIL_HTLC +import fr.acinq.eclair.db.PendingRelayDb import fr.acinq.eclair.payment._ +import fr.acinq.eclair.wire.protocol.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{Logs, NodeParams} import java.util.UUID @@ -29,16 +32,16 @@ import java.util.UUID */ /** - * The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer). It - * doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a child - * actor of type [[NodeRelay]]. + * The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer). + * It doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a + * child actor of type [[NodeRelay]]. */ object NodeRelayer { // @formatter:off sealed trait Command case class Relay(nodeRelayPacket: IncomingPacket.NodeRelayPacket) extends Command - case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32) extends Command + case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32, paymentSecret: ByteVector32) extends Command private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command // @formatter:on @@ -48,34 +51,47 @@ object NodeRelayer { case _: GetPendingPayments => Logs.mdc() } + case class PaymentKey(paymentHash: ByteVector32, paymentSecret: ByteVector32) + /** - * @param children a map of current in-process payments, indexed by payment hash and purposefully *not* by payment id, - * because that is how we aggregate payment parts (when the incoming payment uses MPP). + * @param children a map of pending payments. We must index by both payment hash and payment secret because we may + * need to independently relay multiple parts of the same payment using distinct payment secrets. + * NB: the payment secret used here is different from the invoice's payment secret and ensures we can + * group together HTLCs that the previous trampoline node sent in the same MPP. */ - def apply(nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, children: Map[ByteVector32, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = + def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = Behaviors.setup { context => Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) { Behaviors.receiveMessage { case Relay(nodeRelayPacket) => - import nodeRelayPacket.add.paymentHash - children.get(paymentHash) match { - case Some(handler) => - context.log.debug("forwarding incoming htlc to existing handler") - handler ! NodeRelay.Relay(nodeRelayPacket) - Behaviors.same + val htlcIn = nodeRelayPacket.add + nodeRelayPacket.outerPayload.paymentSecret match { + case Some(paymentSecret) => + val childKey = PaymentKey(htlcIn.paymentHash, paymentSecret) + children.get(childKey) match { + case Some(handler) => + context.log.debug("forwarding incoming htlc #{} from channel {} to existing handler", htlcIn.id, htlcIn.channelId) + handler ! NodeRelay.Relay(nodeRelayPacket) + Behaviors.same + case None => + val relayId = UUID.randomUUID() + context.log.debug(s"spawning a new handler with relayId=$relayId") + val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, childKey.paymentSecret, outgoingPaymentFactory), relayId.toString) + context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId) + handler ! NodeRelay.Relay(nodeRelayPacket) + apply(nodeParams, register, outgoingPaymentFactory, children + (childKey -> handler)) + } case None => - val relayId = UUID.randomUUID() - context.log.debug(s"spawning a new handler with relayId=$relayId") - val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register) - val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString) - context.log.debug("forwarding incoming htlc to new handler") - handler ! NodeRelay.Relay(nodeRelayPacket) - apply(nodeParams, router, register, children + (paymentHash -> handler)) + context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", htlcIn.id, htlcIn.channelId) + val failureMessage = IncorrectOrUnknownPaymentDetails(htlcIn.amountMsat, nodeParams.currentBlockHeight) + val cmd = CMD_FAIL_HTLC(htlcIn.id, Right(failureMessage), commit = true) + PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, htlcIn.channelId, cmd) + Behaviors.same } - case RelayComplete(childHandler, paymentHash) => + case RelayComplete(childHandler, paymentHash, paymentSecret) => // we do a back-and-forth between parent and child before stopping the child to prevent a race condition childHandler ! NodeRelay.Stop - apply(nodeParams, router, register, children - paymentHash) + apply(nodeParams, register, outgoingPaymentFactory, children - PaymentKey(paymentHash, paymentSecret)) case GetPendingPayments(replyTo) => replyTo ! children Behaviors.same diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index 9527bcc2d8..c834fdae83 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -55,7 +55,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner") private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer") - private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, router, register)).onFailure(SupervisorStrategy.resume), name = "node-relayer") + private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register))).onFailure(SupervisorStrategy.resume), name = "node-relayer") def receive: Receive = { case RelayForward(add) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 4ccd867039..6d9619d8ab 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -30,6 +30,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.OutgoingPacket.Upstream import fr.acinq.eclair.payment.PaymentRequest.{ExtraHop, PaymentRequestFeatures} import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment @@ -37,8 +38,8 @@ import fr.acinq.eclair.router.Router.RouteRequest import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TestConstants, nodeFee, randomBytes, randomBytes32, randomKey} +import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike -import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax import java.util.UUID @@ -53,77 +54,138 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ - case class FixtureParam(nodeParams: NodeParams, nodeRelayer: ActorRef[NodeRelay.Command], parent: TestProbe[NodeRelayer.Command], router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) + case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { + def createNodeRelay(packetIn: IncomingPacket.NodeRelayPacket, paymentSecret: ByteVector32 = incomingSecret, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { + val parent = TestProbe[NodeRelayer.Command]("parent-relayer") + val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this) + val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, paymentSecret, outgoingPaymentFactory)) + (nodeRelay, parent) + } + } + + case class FakeOutgoingPaymentFactory(f: FixtureParam) extends NodeRelay.OutgoingPaymentFactory { + override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { + f.mockPayFSM.ref ! cfg + f.mockPayFSM.ref.toClassic + } + } + + case class RealOutgoingPaymentFactory(f: FixtureParam) extends NodeRelay.OutgoingPaymentFactory { + override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { + val outgoingPayFSM = NodeRelay.SimpleOutgoingPaymentFactory(f.nodeParams, f.router.ref.toClassic, f.register.ref.toClassic).spawnOutgoingPayFSM(context, cfg, multiPart) + f.mockPayFSM.ref ! outgoingPayFSM + outgoingPayFSM + } + } override def withFixture(test: OneArgTest): Outcome = { val nodeParams = TestConstants.Bob.nodeParams.copy(multiPartPaymentExpiry = 5 seconds) - val parent = TestProbe[NodeRelayer.Command]("parent-relayer") val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") system.eventStream ! EventStream.Subscribe(eventListener.ref) val mockPayFSM = TestProbe[Any]("pay-fsm") - val outgoingPaymentFactory = if (test.tags.contains("mock-fsm")) { - new NodeRelay.OutgoingPaymentFactory { - override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { - mockPayFSM.ref ! cfg - mockPayFSM.ref.toClassic - } - } - } else { - new NodeRelay.OutgoingPaymentFactory { - override def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): akka.actor.ActorRef = { - val outgoingPayFSM = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router.ref.toClassic, register.ref.toClassic).spawnOutgoingPayFSM(context, cfg, multiPart) - mockPayFSM.ref ! outgoingPayFSM - outgoingPayFSM - } - } - } - val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, paymentHash, outgoingPaymentFactory)) - withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelay, parent, router, register, mockPayFSM, eventListener))) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener))) + } + + test("create child handlers for new payments") { f => + import f._ + val probe = TestProbe[Any] + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f))) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + probe.expectMessage(Map.empty) + + val paymentNoSecret = createPartialIncomingPacket(randomBytes32, randomBytes32).copy(outerPayload = Onion.createSinglePartPayload(incomingAmount, CltvExpiry(500000))) + parentRelayer ! NodeRelayer.Relay(paymentNoSecret) + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId === paymentNoSecret.add.channelId) + assert(fwd.message === CMD_FAIL_HTLC(paymentNoSecret.add.id, Right(IncorrectOrUnknownPaymentDetails(paymentNoSecret.add.amountMsat, nodeParams.currentBlockHeight)), commit = true)) + + val (paymentHash1, paymentSecret1) = (randomBytes32, randomBytes32) + val payment1 = createPartialIncomingPacket(paymentHash1, paymentSecret1) + parentRelayer ! NodeRelayer.Relay(payment1) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] + assert(pending1.keySet === Set(PaymentKey(paymentHash1, paymentSecret1))) + + val (paymentHash2, paymentSecret2) = (randomBytes32, randomBytes32) + val payment2 = createPartialIncomingPacket(paymentHash2, paymentSecret2) + parentRelayer ! NodeRelayer.Relay(payment2) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + val pending2 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] + assert(pending2.keySet === Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2))) + + val payment3a = createPartialIncomingPacket(paymentHash1, paymentSecret2) + parentRelayer ! NodeRelayer.Relay(payment3a) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + val pending3 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] + assert(pending3.keySet === Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2), PaymentKey(paymentHash1, paymentSecret2))) + + val payment3b = createPartialIncomingPacket(paymentHash1, paymentSecret2) + parentRelayer ! NodeRelayer.Relay(payment3b) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + val pending4 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] + assert(pending4.keySet === Set(PaymentKey(paymentHash1, paymentSecret1), PaymentKey(paymentHash2, paymentSecret2), PaymentKey(paymentHash1, paymentSecret2))) + + register.expectNoMessage(100 millis) } - test("stop child handler when relay is complete") { f => + test("stop child handlers when relay is complete") { f => import f._ val probe = TestProbe[Any] + val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f) { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) } { - val (paymentHash1, child1) = (randomBytes32, TestProbe[NodeRelay.Command]) - val (paymentHash2, child2) = (randomBytes32, TestProbe[NodeRelay.Command]) - val children = Map(paymentHash1 -> child1.ref, paymentHash2 -> child2.ref) - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic, children)) + val (paymentHash1, paymentSecret1, child1) = (randomBytes32, randomBytes32, TestProbe[NodeRelay.Command]) + val (paymentHash2, paymentSecret2, child2) = (randomBytes32, randomBytes32, TestProbe[NodeRelay.Command]) + val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, children)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(children) - parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1) + parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1, paymentSecret1) child1.expectMessage(NodeRelay.Stop) - parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1) + parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1, paymentSecret1) + child1.expectMessage(NodeRelay.Stop) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + probe.expectMessage(Map(PaymentKey(paymentHash2, paymentSecret2) -> child2.ref)) + } + { + val paymentHash = randomBytes32 + val (paymentSecret1, child1) = (randomBytes32, TestProbe[NodeRelay.Command]) + val (paymentSecret2, child2) = (randomBytes32, TestProbe[NodeRelay.Command]) + val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, children)) + parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) + probe.expectMessage(children) + + parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash, paymentSecret1) child1.expectMessage(NodeRelay.Stop) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) - probe.expectMessage(children - paymentHash1) + probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref)) } { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory)) parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) - val pending1 = probe.expectMessageType[Map[ByteVector32, ActorRef[NodeRelay.Command]]] + val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending1.size === 1) - assert(pending1.head._1 === paymentHash) + assert(pending1.head._1 === PaymentKey(paymentHash, incomingSecret)) - parentRelayer ! NodeRelayer.RelayComplete(pending1.head._2, paymentHash) + parentRelayer ! NodeRelayer.RelayComplete(pending1.head._2, paymentHash, incomingSecret) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) - val pending2 = probe.expectMessageType[Map[ByteVector32, ActorRef[NodeRelay.Command]]] + val pending2 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] assert(pending2.size === 1) - assert(pending2.head._1 === paymentHash) + assert(pending2.head._1 === pending1.head._1) assert(pending2.head._2 !== pending1.head._2) } } @@ -131,6 +193,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail to relay when incoming multi-part payment times out") { f => import f._ + val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) // Receive a partial upstream multi-part payment. incomingMultiPart.dropRight(1).foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) // after a while the payment times out @@ -148,6 +211,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail all extraneous multi-part incoming HTLCs") { f => import f._ + val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) // We send all the parts of a mpp incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) // and then one extra @@ -167,9 +231,10 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("fail all additional incoming HTLCs once already relayed out", Tag("mock-fsm")) { f => + test("fail all additional incoming HTLCs once already relayed out") { f => import f._ + val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) // Receive a complete upstream multi-part payment, which we relay out. incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) @@ -194,7 +259,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl // Receive new HTLC with different details, but for the same payment hash. val i2 = IncomingPacket.NodeRelayPacket( UpdateAddHtlc(randomBytes32, Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket), - Onion.createSinglePartPayload(1500 msat, CltvExpiry(499990), Some(randomBytes32)), + Onion.createSinglePartPayload(1500 msat, CltvExpiry(499990), Some(incomingSecret)), Onion.createNodeRelayPayload(1250 msat, outgoingExpiry, outgoingNodeId), nextTrampolinePacket) nodeRelayer ! NodeRelay.Relay(i2) @@ -213,6 +278,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val p = createValidIncomingPacket(2000000 msat, 2000000 msat, CltvExpiry(500000), outgoingAmount, outgoingExpiry).copy( outerPayload = Onion.createSinglePartPayload(2000000 msat, CltvExpiry(500000)) // missing outer payment secret ) + val (nodeRelayer, _) = f.createNodeRelay(p) nodeRelayer ! NodeRelay.Relay(p) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -230,6 +296,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val p2 = createValidIncomingPacket(1000000 msat, 3000000 msat, CltvExpiry(500000), 2500000 msat, outgoingExpiry).copy( outerPayload = Onion.createMultiPartPayload(1000000 msat, 3000000 msat, CltvExpiry(500000), randomBytes32) ) + val (nodeRelayer, _) = f.createNodeRelay(p1) nodeRelayer ! NodeRelay.Relay(p1) nodeRelayer ! NodeRelay.Relay(p2) @@ -247,6 +314,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val expiryIn = CltvExpiry(500000) // not ok (delta = 100) val expiryOut = CltvExpiry(499900) val p = createValidIncomingPacket(2000000 msat, 2000000 msat, expiryIn, 1000000 msat, expiryOut) + val (nodeRelayer, _) = f.createNodeRelay(p) nodeRelayer ! NodeRelay.Relay(p) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -262,6 +330,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val expiryIn = CltvExpiry(500000) val expiryOut = CltvExpiry(300000) // not ok (chain heigh = 400000) val p = createValidIncomingPacket(2000000 msat, 2000000 msat, expiryIn, 1000000 msat, expiryOut) + val (nodeRelayer, _) = f.createNodeRelay(p) nodeRelayer ! NodeRelay.Relay(p) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -281,6 +350,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(2000000 msat, 3000000 msat, expiryIn1, 2100000 msat, expiryOut), createValidIncomingPacket(1000000 msat, 3000000 msat, expiryIn2, 2100000 msat, expiryOut) ) + val (nodeRelayer, _) = f.createNodeRelay(p.head) p.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) p.foreach { p => @@ -296,6 +366,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ val p = createValidIncomingPacket(2000000 msat, 2000000 msat, CltvExpiry(500000), 1999000 msat, CltvExpiry(490000)) + val (nodeRelayer, _) = f.createNodeRelay(p) nodeRelayer ! NodeRelay.Relay(p) val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -312,6 +383,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(2000000 msat, 3000000 msat, CltvExpiry(500000), 2999000 msat, CltvExpiry(400000)), createValidIncomingPacket(1000000 msat, 3000000 msat, CltvExpiry(500000), 2999000 msat, CltvExpiry(400000)) ) + val (nodeRelayer, _) = f.createNodeRelay(p.head) p.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) p.foreach { p => @@ -323,10 +395,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("fail to relay because outgoing balance isn't sufficient (low fees)", Tag("mock-fsm")) { f => + test("fail to relay because outgoing balance isn't sufficient (low fees)") { f => import f._ // Receive an upstream multi-part payment. + val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head) incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) mockPayFSM.expectMessageType[SendPaymentConfig] @@ -353,6 +426,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl createValidIncomingPacket(outgoingAmount, outgoingAmount * 2, CltvExpiry(500000), outgoingAmount, outgoingExpiry), createValidIncomingPacket(outgoingAmount, outgoingAmount * 2, CltvExpiry(500000), outgoingAmount, outgoingExpiry), ) + val (nodeRelayer, _) = f.createNodeRelay(incoming.head, useRealPaymentFactory = true) incoming.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] @@ -372,6 +446,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ // Receive an upstream multi-part payment. + val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head, useRealPaymentFactory = true) incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] @@ -394,6 +469,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import f._ // Receive an upstream multi-part payment. + val (nodeRelayer, _) = f.createNodeRelay(incomingMultiPart.head, useRealPaymentFactory = true) incomingMultiPart.foreach(p => nodeRelayer ! NodeRelay.Relay(p)) val payFSM = mockPayFSM.expectMessageType[akka.actor.ActorRef] @@ -414,11 +490,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("compute route params") { f => import f._ - // Receive an upstream multi-part payment. + // Receive an upstream payment. + val (nodeRelayer, _) = f.createNodeRelay(incomingSinglePart, useRealPaymentFactory = true) nodeRelayer ! NodeRelay.Relay(incomingSinglePart) val routeRequest = router.expectMessageType[RouteRequest] - val routeParams = routeRequest.routeParams.get val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, outgoingAmount) assert(routeParams.maxFeePct === 0) // should be disabled @@ -426,10 +502,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl assert(routeParams.routeMaxCltv === incomingSinglePart.add.cltvExpiry - outgoingExpiry - nodeParams.expiryDelta) // we apply our cltv delta } - test("relay incoming multi-part payment", Tag("mock-fsm")) { f => + test("relay incoming multi-part payment") { f => import f._ // Receive an upstream multi-part payment. + val (nodeRelayer, parent) = f.createNodeRelay(incomingMultiPart.head) incomingMultiPart.dropRight(1).foreach(p => nodeRelayer ! NodeRelay.Relay(p)) mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment @@ -464,10 +541,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("relay incoming single-part payment", Tag("mock-fsm")) { f => + test("relay incoming single-part payment") { f => import f._ // Receive an upstream single-part payment. + val (nodeRelayer, parent) = f.createNodeRelay(incomingSinglePart) nodeRelayer ! NodeRelay.Relay(incomingSinglePart) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] @@ -492,16 +570,18 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("relay to non-trampoline recipient supporting multi-part", Tag("mock-fsm")) { f => + test("relay to non-trampoline recipient supporting multi-part") { f => import f._ // Receive an upstream multi-part payment. val hints = List(List(ExtraHop(outgoingNodeId, ShortChannelId(42), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val features = PaymentRequestFeatures(VariableLengthOnion.optional, PaymentSecret.mandatory, BasicMultiPartPayment.optional) val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(outgoingAmount * 3), paymentHash, randomKey, "Some invoice", CltvExpiryDelta(18), extraHops = hints, features = Some(features)) - incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming.copy(innerPayload = Onion.createNodeRelayToNonTrampolinePayload( + val incomingPayments = incomingMultiPart.map(incoming => incoming.copy(innerPayload = Onion.createNodeRelayToNonTrampolinePayload( incoming.innerPayload.amountToForward, outgoingAmount * 3, outgoingExpiry, outgoingNodeId, pr - )))) + ))) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) @@ -532,15 +612,17 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("relay to non-trampoline recipient without multi-part", Tag("mock-fsm")) { f => + test("relay to non-trampoline recipient without multi-part") { f => import f._ // Receive an upstream multi-part payment. val hints = List(List(ExtraHop(outgoingNodeId, ShortChannelId(42), feeBase = 10 msat, feeProportionalMillionths = 1, cltvExpiryDelta = CltvExpiryDelta(12)))) val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(outgoingAmount), paymentHash, randomKey, "Some invoice", CltvExpiryDelta(18), extraHops = hints, features = Some(PaymentRequestFeatures())) - incomingMultiPart.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming.copy(innerPayload = Onion.createNodeRelayToNonTrampolinePayload( + val incomingPayments = incomingMultiPart.map(incoming => incoming.copy(innerPayload = Onion.createNodeRelayToNonTrampolinePayload( incoming.innerPayload.amountToForward, incoming.innerPayload.amountToForward, outgoingExpiry, outgoingNodeId, pr - )))) + ))) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(_.add))) @@ -638,4 +720,14 @@ object NodeRelayerSpec { nextTrampolinePacket) } + def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): IncomingPacket.NodeRelayPacket = { + val (expiryIn, expiryOut) = (CltvExpiry(500000), CltvExpiry(490000)) + val amountIn = incomingAmount / 2 + IncomingPacket.NodeRelayPacket( + UpdateAddHtlc(randomBytes32, Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket), + Onion.createMultiPartPayload(amountIn, incomingAmount, expiryIn, paymentSecret), + Onion.createNodeRelayPayload(outgoingAmount, expiryOut, outgoingNodeId), + nextTrampolinePacket) + } + } \ No newline at end of file