From 0ae3befec027b5239b2f39dfb1dbae543d3e765b Mon Sep 17 00:00:00 2001
From: Hans Moog <hm@mkjc.net>
Date: Fri, 12 Jun 2020 19:48:46 +0200
Subject: [PATCH] Fix: fixes a race condition in solidification

---
 .../valuetransfers/packages/tangle/tangle.go  | 27 ++++++++++---------
 .../tangle/tangle_concurrency_test.go         | 12 ++++-----
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/dapps/valuetransfers/packages/tangle/tangle.go b/dapps/valuetransfers/packages/tangle/tangle.go
index dae0647a..9657b8e8 100644
--- a/dapps/valuetransfers/packages/tangle/tangle.go
+++ b/dapps/valuetransfers/packages/tangle/tangle.go
@@ -1164,7 +1164,7 @@ func (tangle *Tangle) processSolidificationStackEntry(solidificationStack *list.
 	}
 
 	// book the solid entities
-	transactionBooked, _, decisionPending, bookingErr := tangle.book(solidificationStackEntry.Retain())
+	transactionBooked, payloadBooked, decisionPending, bookingErr := tangle.book(solidificationStackEntry.Retain())
 	if bookingErr != nil {
 		tangle.Events.Error.Trigger(bookingErr)
 
@@ -1177,9 +1177,12 @@ func (tangle *Tangle) processSolidificationStackEntry(solidificationStack *list.
 	// trigger events and schedule check of approvers / consumers
 	if transactionBooked {
 		tangle.Events.TransactionBooked.Trigger(solidificationStackEntry.CachedTransaction, solidificationStackEntry.CachedTransactionMetadata, decisionPending)
+
+		tangle.ForEachConsumers(currentTransaction, tangle.createValuePayloadFutureConeIterator(solidificationStack, processedPayloads))
+	}
+	if payloadBooked {
+		tangle.ForeachApprovers(currentPayload.ID(), tangle.createValuePayloadFutureConeIterator(solidificationStack, processedPayloads))
 	}
-	tangle.ForEachConsumers(currentTransaction, tangle.createValuePayloadFutureConeIterator(solidificationStack, processedPayloads))
-	tangle.ForeachApprovers(currentPayload.ID(), tangle.createValuePayloadFutureConeIterator(solidificationStack, processedPayloads))
 }
 
 func (tangle *Tangle) book(entitiesToBook *valuePayloadPropagationStackEntry) (transactionBooked bool, payloadBooked bool, decisionPending bool, err error) {
@@ -1333,6 +1336,14 @@ func (tangle *Tangle) bookPayload(cachedPayload *payload.CachedPayload, cachedPa
 		return
 	}
 
+	branchBranchID := tangle.payloadBranchID(valueObject.BranchID())
+	trunkBranchID := tangle.payloadBranchID(valueObject.TrunkID())
+	transactionBranchID := transactionMetadata.BranchID()
+
+	if branchBranchID == branchmanager.UndefinedBranchID || trunkBranchID == branchmanager.UndefinedBranchID || transactionBranchID == branchmanager.UndefinedBranchID {
+		return
+	}
+
 	// abort if the payload has been marked as solid before
 	if !valueObjectMetadata.setSolid(true) {
 		return
@@ -1341,16 +1352,6 @@ func (tangle *Tangle) bookPayload(cachedPayload *payload.CachedPayload, cachedPa
 	// trigger event if payload became solid
 	tangle.Events.PayloadSolid.Trigger(cachedPayload, cachedPayloadMetadata)
 
-	branchBranchID := tangle.payloadBranchID(valueObject.BranchID())
-	trunkBranchID := tangle.payloadBranchID(valueObject.TrunkID())
-	transactionBranchID := transactionMetadata.BranchID()
-
-	if branchBranchID == branchmanager.UndefinedBranchID ||
-		trunkBranchID == branchmanager.UndefinedBranchID ||
-		transactionBranchID == branchmanager.UndefinedBranchID {
-		return
-	}
-
 	cachedAggregatedBranch, err := tangle.BranchManager().AggregateBranches([]branchmanager.BranchID{branchBranchID, trunkBranchID, transactionBranchID}...)
 	if err != nil {
 		return
diff --git a/dapps/valuetransfers/packages/tangle/tangle_concurrency_test.go b/dapps/valuetransfers/packages/tangle/tangle_concurrency_test.go
index 3155f447..34ec76e2 100644
--- a/dapps/valuetransfers/packages/tangle/tangle_concurrency_test.go
+++ b/dapps/valuetransfers/packages/tangle/tangle_concurrency_test.go
@@ -341,16 +341,16 @@ func TestReverseTransactionSolidification(t *testing.T) {
 			// check if outputs are found in database
 			transactions[i].Outputs().ForEach(func(address address.Address, balances []*balance.Balance) bool {
 				cachedOutput := tangle.TransactionOutput(transaction.NewOutputID(address, transactions[i].ID()))
-				assert.Truef(t, cachedOutput.Consume(func(output *Output) {
+				require.Truef(t, cachedOutput.Consume(func(output *Output) {
 					// only the last outputs in chain should not be spent
 					if i+txChains >= countTotal {
-						assert.Equalf(t, 0, output.ConsumerCount(), "the output should not be spent")
+						require.Equalf(t, 0, output.ConsumerCount(), "the output should not be spent")
 					} else {
-						assert.Equalf(t, 1, output.ConsumerCount(), "the output should be spent")
+						require.Equalf(t, 1, output.ConsumerCount(), "the output should be spent")
 					}
-					assert.Equal(t, []*balance.Balance{balance.New(balance.ColorIOTA, 1)}, output.Balances())
-					assert.Equalf(t, branchmanager.MasterBranchID, output.BranchID(), "the output was booked into the wrong branch")
-					assert.Truef(t, output.Solid(), "the output is not solid")
+					require.Equal(t, []*balance.Balance{balance.New(balance.ColorIOTA, 1)}, output.Balances())
+					require.Equalf(t, branchmanager.MasterBranchID, output.BranchID(), "the output was booked into the wrong branch")
+					require.Truef(t, output.Solid(), "the output is not solid")
 				}), "output not found in database for tx %s", transactions[i])
 				return true
 			})
-- 
GitLab