Skip to content

Commit

Permalink
Index trampoline payments by hash and secret
Browse files Browse the repository at this point in the history
We need to group incoming HTLCs together by payment_hash and payment_secret,
otherwise we will reject valid payments that are split into multiple distinct
trampoline parts (same payment_hash but different payment_secret).

Fixes #1723
  • Loading branch information
t-bast committed Apr 15, 2021
1 parent eb834e2 commit 6adc24a
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC}
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.payment.IncomingPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment.OutgoingPacket.Upstream
import fr.acinq.eclair.payment._
Expand Down Expand Up @@ -75,13 +76,29 @@ object NodeRelay {
}
}

def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
def apply(nodeParams: NodeParams,
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
register: ActorRef,
relayId: UUID,
nodeRelayPacket: NodeRelayPacket,
paymentSecret: ByteVector32,
outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
Behaviors.setup { context =>
val paymentHash = nodeRelayPacket.add.paymentHash
val totalAmountIn = nodeRelayPacket.outerPayload.totalAmount
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
paymentHash_opt = Some(paymentHash))) {
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
context.log.info("relaying payment relayId={}", relayId)
val mppFsmAdapters = {
context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded)
}.toClassic
val incomingPaymentHandler = context.actorOf(MultiPartPaymentFSM.props(nodeParams, paymentHash, totalAmountIn, mppFsmAdapters))
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, paymentSecret, context, outgoingPaymentFactory)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nodeRelayPacket.nextPacket, incomingPaymentHandler)
}
}

Expand Down Expand Up @@ -144,66 +161,37 @@ class NodeRelay private(nodeParams: NodeParams,
register: ActorRef,
relayId: UUID,
paymentHash: ByteVector32,
paymentSecret: ByteVector32,
context: ActorContext[NodeRelay.Command],
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) {

import NodeRelay._

private val mppFsmAdapters = {
context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded)
}.toClassic
private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic

def apply(): Behavior[Command] =
Behaviors.receiveMessagePartial {
// We make sure we receive all payment parts before forwarding to the next trampoline node.
case Relay(IncomingPacket.NodeRelayPacket(add, outer, inner, next)) => outer.paymentSecret match {
case None =>
// TODO: @pm: maybe those checks should be done later in the flow (by the mpp FSM?)
context.log.warn("rejecting htlcId={}: missing payment secret", add.id)
rejectHtlc(add.id, add.channelId, add.amountMsat)
stopping()
case Some(secret) =>
import akka.actor.typed.scaladsl.adapter._
context.log.info("relaying payment relayId={}", relayId)
val mppFsm = context.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, outer.totalAmount, mppFsmAdapters))
context.log.debug("forwarding incoming htlc to the payment FSM")
mppFsm ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add)
receiving(Queue(add), secret, inner, next, mppFsm)
}
}

/**
* We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next
* trampoline node and forward the payment.
*
* @param htlcs received incoming HTLCs for this set.
* @param secret all incoming HTLCs in this set must have the same secret to protect against probing / fee theft.
* @param nextPayload relay instructions (should be identical across HTLCs in this set).
* @param nextPacket trampoline onion to relay to the next trampoline node.
* @param handler actor handling the aggregation of the incoming HTLC set.
*/
private def receiving(htlcs: Queue[UpdateAddHtlc], secret: ByteVector32, nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] =
private def receiving(htlcs: Queue[UpdateAddHtlc], nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] =
Behaviors.receiveMessagePartial {
case Relay(IncomingPacket.NodeRelayPacket(add, outer, _, _)) => outer.paymentSecret match {
// TODO: @pm: maybe those checks should be done by the mpp FSM?
case None =>
context.log.warn("rejecting htlcId={}: missing payment secret", add.id)
context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
Behaviors.same
case Some(incomingSecret) if incomingSecret != secret =>
context.log.warn("rejecting htlcId={}: payment secret doesn't match other HTLCs in the set", add.id)
case Some(incomingSecret) if incomingSecret != paymentSecret =>
context.log.warn("rejecting htlc #{} from channel {}: payment secret doesn't match other HTLCs in the set", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
Behaviors.same
case Some(incomingSecret) if incomingSecret == secret =>
context.log.debug("forwarding incoming htlc to the payment FSM")
case Some(incomingSecret) if incomingSecret == paymentSecret =>
context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", add.id, add.channelId)
handler ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add)
receiving(htlcs :+ add, secret, nextPayload, nextPacket, handler)
receiving(htlcs :+ add, nextPayload, nextPacket, handler)
}
case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) =>
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
Expand Down Expand Up @@ -267,14 +255,20 @@ class NodeRelay private(nodeParams: NodeParams,
* Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us.
*/
private def stopping(): Behavior[Command] = {
parent ! NodeRelayer.RelayComplete(context.self, paymentHash)
parent ! NodeRelayer.RelayComplete(context.self, paymentHash, paymentSecret)
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case Stop => Behaviors.stopped
}
}
}

private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic

private def relay(upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload, packetOut: OnionRoutingPacket): ActorRef = {
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, Nil)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
Expand Down Expand Up @@ -322,7 +316,7 @@ class NodeRelay private(nodeParams: NodeParams,
}

private def rejectExtraHtlc(add: UpdateAddHtlc): Unit = {
context.log.warn("rejecting extra htlcId={}", add.id)
context.log.warn("rejecting extra htlc #{} from channel {}", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package fr.acinq.eclair.payment.relay
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.CMD_FAIL_HTLC
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.payment._
import fr.acinq.eclair.wire.protocol.IncorrectOrUnknownPaymentDetails
import fr.acinq.eclair.{Logs, NodeParams}

import java.util.UUID
Expand All @@ -29,16 +32,16 @@ import java.util.UUID
*/

/**
* The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer). It
* doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a child
* actor of type [[NodeRelay]].
* The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer).
* It doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a
* child actor of type [[NodeRelay]].
*/
object NodeRelayer {

// @formatter:off
sealed trait Command
case class Relay(nodeRelayPacket: IncomingPacket.NodeRelayPacket) extends Command
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32) extends Command
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32, paymentSecret: ByteVector32) extends Command
private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command
// @formatter:on

Expand All @@ -48,34 +51,47 @@ object NodeRelayer {
case _: GetPendingPayments => Logs.mdc()
}

case class PaymentKey(paymentHash: ByteVector32, paymentSecret: ByteVector32)

/**
* @param children a map of current in-process payments, indexed by payment hash and purposefully *not* by payment id,
* because that is how we aggregate payment parts (when the incoming payment uses MPP).
* @param children a map of pending payments. We must index by both payment hash and payment secret because we may
* need to independently relay multiple parts of the same payment using distinct payment secrets.
* NB: the payment secret used here is different from the invoice's payment secret and ensures we can
* group together HTLCs that the previous trampoline node sent in the same MPP.
*/
def apply(nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, children: Map[ByteVector32, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
Behaviors.receiveMessage {
case Relay(nodeRelayPacket) =>
import nodeRelayPacket.add.paymentHash
children.get(paymentHash) match {
case Some(handler) =>
context.log.debug("forwarding incoming htlc to existing handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
Behaviors.same
val htlcIn = nodeRelayPacket.add
nodeRelayPacket.outerPayload.paymentSecret match {
case Some(paymentSecret) =>
val childKey = PaymentKey(htlcIn.paymentHash, paymentSecret)
children.get(childKey) match {
case Some(handler) =>
context.log.debug("forwarding incoming htlc #{} from channel {} to existing handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
Behaviors.same
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, childKey.paymentSecret, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, register, outgoingPaymentFactory, children + (childKey -> handler))
}
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register)
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc to new handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, router, register, children + (paymentHash -> handler))
context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", htlcIn.id, htlcIn.channelId)
val failureMessage = IncorrectOrUnknownPaymentDetails(htlcIn.amountMsat, nodeParams.currentBlockHeight)
val cmd = CMD_FAIL_HTLC(htlcIn.id, Right(failureMessage), commit = true)
PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, htlcIn.channelId, cmd)
Behaviors.same
}
case RelayComplete(childHandler, paymentHash) =>
case RelayComplete(childHandler, paymentHash, paymentSecret) =>
// we do a back-and-forth between parent and child before stopping the child to prevent a race condition
childHandler ! NodeRelay.Stop
apply(nodeParams, router, register, children - paymentHash)
apply(nodeParams, register, outgoingPaymentFactory, children - PaymentKey(paymentHash, paymentSecret))
case GetPendingPayments(replyTo) =>
replyTo ! children
Behaviors.same
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym

private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner")
private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, router, register)).onFailure(SupervisorStrategy.resume), name = "node-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register))).onFailure(SupervisorStrategy.resume), name = "node-relayer")

def receive: Receive = {
case RelayForward(add) =>
Expand Down
Loading

0 comments on commit 6adc24a

Please sign in to comment.