Skip to content
Snippets Groups Projects
Unverified Commit f4711320 authored by Wolfgang Welz's avatar Wolfgang Welz Committed by GitHub
Browse files

Feat: Add PoW for messages (#511)


* add pow and nonce to the message

* add pow plugin

* fix pow tests

* fix linter warnings

* fix typo

* fix panic when messagelayer is disabled

* improve logging

* fix typo

* make the pow check a byte filter

* decrease the test PoW difficulty even further

* expose the pow filter and use it in the pow plugin

* Apply suggestions from code review

* make length of nonce a const

Co-authored-by: default avatarLuca Moser <moser.luca@gmail.com>
parent 14fce79a
No related branches found
No related tags found
No related merge requests found
Showing
with 490 additions and 145 deletions
...@@ -60,6 +60,11 @@ ...@@ -60,6 +60,11 @@
"disablePlugins": [], "disablePlugins": [],
"enablePlugins": [] "enablePlugins": []
}, },
"pow": {
"difficulty": 22,
"numThreads": 1,
"timeout": "1m"
},
"webapi": { "webapi": {
"auth": { "auth": {
"password": "goshimmer", "password": "goshimmer",
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/iotaledger/hive.go/crypto/ed25519" "github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/identity"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/iotaledger/goshimmer/dapps/valuetransfers/packages/address" "github.com/iotaledger/goshimmer/dapps/valuetransfers/packages/address"
...@@ -46,7 +45,6 @@ func ExamplePayload() { ...@@ -46,7 +45,6 @@ func ExamplePayload() {
) )
// 3. build actual transaction (the base layer creates this and wraps the ontology provided payload) // 3. build actual transaction (the base layer creates this and wraps the ontology provided payload)
localIdentity := identity.GenerateLocalIdentity()
tx := message.New( tx := message.New(
// trunk in "network tangle" ontology (filled by tipSelector) // trunk in "network tangle" ontology (filled by tipSelector)
message.EmptyId, message.EmptyId,
...@@ -54,17 +52,23 @@ func ExamplePayload() { ...@@ -54,17 +52,23 @@ func ExamplePayload() {
// branch in "network tangle" ontology (filled by tipSelector) // branch in "network tangle" ontology (filled by tipSelector)
message.EmptyId, message.EmptyId,
// issuer of the transaction (signs automatically)
localIdentity,
// the time when the transaction was created // the time when the transaction was created
time.Now(), time.Now(),
// public key of the issuer
ed25519.PublicKey{},
// the ever increasing sequence number of this transaction // the ever increasing sequence number of this transaction
0, 0,
// payload // payload
valuePayload, valuePayload,
// nonce to check PoW
0,
// signature
ed25519.Signature{},
) )
fmt.Println(tx) fmt.Println(tx)
......
...@@ -29,7 +29,7 @@ func TestSignatureFilter(t *testing.T) { ...@@ -29,7 +29,7 @@ func TestSignatureFilter(t *testing.T) {
// create helper instances // create helper instances
seed := wallet.NewSeed() seed := wallet.NewSeed()
messageFactory := messagefactory.New(mapdb.NewMapDB(), identity.GenerateLocalIdentity(), tipselector.New(), []byte("sequenceKey")) messageFactory := messagefactory.New(mapdb.NewMapDB(), []byte("sequenceKey"), identity.GenerateLocalIdentity(), tipselector.New())
// 1. test value message without signatures // 1. test value message without signatures
{ {
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/iotaledger/hive.go/crypto/ed25519" "github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/marshalutil" "github.com/iotaledger/hive.go/marshalutil"
"github.com/iotaledger/hive.go/objectstorage" "github.com/iotaledger/hive.go/objectstorage"
"github.com/iotaledger/hive.go/stringify" "github.com/iotaledger/hive.go/stringify"
...@@ -20,38 +19,35 @@ type Message struct { ...@@ -20,38 +19,35 @@ type Message struct {
objectstorage.StorableObjectFlags objectstorage.StorableObjectFlags
// core properties (get sent over the wire) // core properties (get sent over the wire)
trunkId Id trunkID Id
branchId Id branchID Id
issuerPublicKey ed25519.PublicKey issuerPublicKey ed25519.PublicKey
issuingTime time.Time issuingTime time.Time
sequenceNumber uint64 sequenceNumber uint64
payload payload.Payload payload payload.Payload
bytes []byte nonce uint64
bytesMutex sync.RWMutex
signature ed25519.Signature signature ed25519.Signature
signatureMutex sync.RWMutex
// derived properties // derived properties
id *Id id *Id
idMutex sync.RWMutex idMutex sync.RWMutex
contentId *ContentId contentId *ContentId
contentIdMutex sync.RWMutex contentIdMutex sync.RWMutex
bytes []byte
// only stored on the machine of the signer bytesMutex sync.RWMutex
issuerLocalIdentity *identity.LocalIdentity
} }
// New creates a new message with the details provided by the issuer. // New creates a new message with the details provided by the issuer.
func New(trunkMessageId Id, branchMessageId Id, localIdentity *identity.LocalIdentity, issuingTime time.Time, sequenceNumber uint64, payload payload.Payload) (result *Message) { func New(trunkID Id, branchID Id, issuingTime time.Time, issuerPublicKey ed25519.PublicKey, sequenceNumber uint64, payload payload.Payload, nonce uint64, signature ed25519.Signature) (result *Message) {
return &Message{ return &Message{
trunkId: trunkMessageId, trunkID: trunkID,
branchId: branchMessageId, branchID: branchID,
issuerPublicKey: localIdentity.PublicKey(), issuerPublicKey: issuerPublicKey,
issuingTime: issuingTime, issuingTime: issuingTime,
sequenceNumber: sequenceNumber, sequenceNumber: sequenceNumber,
payload: payload, payload: payload,
nonce: nonce,
issuerLocalIdentity: localIdentity, signature: signature,
} }
} }
...@@ -113,10 +109,12 @@ func StorableObjectFromKey(key []byte, optionalTargetObject ...*Message) (result ...@@ -113,10 +109,12 @@ func StorableObjectFromKey(key []byte, optionalTargetObject ...*Message) (result
// VerifySignature verifies the signature of the message. // VerifySignature verifies the signature of the message.
func (message *Message) VerifySignature() bool { func (message *Message) VerifySignature() bool {
msgBytes := message.Bytes() msgBytes := message.Bytes()
message.signatureMutex.RLock() signature := message.Signature()
valid := message.issuerPublicKey.VerifySignature(msgBytes[:len(msgBytes)-ed25519.SignatureSize], message.Signature())
message.signatureMutex.RUnlock() contentLength := len(msgBytes) - len(signature)
return valid content := msgBytes[:contentLength]
return message.issuerPublicKey.VerifySignature(content, signature)
} }
// ID returns the id of the message which is made up of the content id and trunk/branch ids. // ID returns the id of the message which is made up of the content id and trunk/branch ids.
...@@ -145,12 +143,12 @@ func (message *Message) Id() (result Id) { ...@@ -145,12 +143,12 @@ func (message *Message) Id() (result Id) {
// TrunkID returns the id of the trunk message. // TrunkID returns the id of the trunk message.
func (message *Message) TrunkId() Id { func (message *Message) TrunkId() Id {
return message.trunkId return message.trunkID
} }
// BranchID returns the id of the branch message. // BranchID returns the id of the branch message.
func (message *Message) BranchId() Id { func (message *Message) BranchId() Id {
return message.branchId return message.branchID
} }
// IssuerPublicKey returns the public key of the message issuer. // IssuerPublicKey returns the public key of the message issuer.
...@@ -168,26 +166,21 @@ func (message *Message) SequenceNumber() uint64 { ...@@ -168,26 +166,21 @@ func (message *Message) SequenceNumber() uint64 {
return message.sequenceNumber return message.sequenceNumber
} }
// Signature returns the signature of the message.
func (message *Message) Signature() ed25519.Signature {
message.signatureMutex.RLock()
defer message.signatureMutex.RUnlock()
if message.signature == ed25519.EmptySignature {
// unlock the signatureMutex so Bytes() can write the Signature
message.signatureMutex.RUnlock()
message.Bytes()
message.signatureMutex.RLock()
}
return message.signature
}
// Payload returns the payload of the message. // Payload returns the payload of the message.
func (message *Message) Payload() payload.Payload { func (message *Message) Payload() payload.Payload {
return message.payload return message.payload
} }
// Payload returns the payload of the message.
func (message *Message) Nonce() uint64 {
return message.nonce
}
// Signature returns the signature of the message.
func (message *Message) Signature() ed25519.Signature {
return message.signature
}
// ContentId returns the content id of the message which is made up of all the // ContentId returns the content id of the message which is made up of all the
// parts of the message minus the trunk and branch ids. // parts of the message minus the trunk and branch ids.
func (message *Message) ContentId() (result ContentId) { func (message *Message) ContentId() (result ContentId) {
...@@ -215,8 +208,8 @@ func (message *Message) ContentId() (result ContentId) { ...@@ -215,8 +208,8 @@ func (message *Message) ContentId() (result ContentId) {
func (message *Message) calculateId() Id { func (message *Message) calculateId() Id {
return blake2b.Sum512( return blake2b.Sum512(
marshalutil.New(IdLength + IdLength + payload.IdLength). marshalutil.New(IdLength + IdLength + payload.IdLength).
WriteBytes(message.trunkId.Bytes()). WriteBytes(message.trunkID.Bytes()).
WriteBytes(message.branchId.Bytes()). WriteBytes(message.branchID.Bytes()).
WriteBytes(message.ContentId().Bytes()). WriteBytes(message.ContentId().Bytes()).
Bytes(), Bytes(),
) )
...@@ -247,17 +240,13 @@ func (message *Message) Bytes() []byte { ...@@ -247,17 +240,13 @@ func (message *Message) Bytes() []byte {
// marshal result // marshal result
marshalUtil := marshalutil.New() marshalUtil := marshalutil.New()
marshalUtil.WriteBytes(message.trunkId.Bytes()) marshalUtil.WriteBytes(message.trunkID.Bytes())
marshalUtil.WriteBytes(message.branchId.Bytes()) marshalUtil.WriteBytes(message.branchID.Bytes())
marshalUtil.WriteBytes(message.issuerPublicKey.Bytes()) marshalUtil.WriteBytes(message.issuerPublicKey.Bytes())
marshalUtil.WriteTime(message.issuingTime) marshalUtil.WriteTime(message.issuingTime)
marshalUtil.WriteUint64(message.sequenceNumber) marshalUtil.WriteUint64(message.sequenceNumber)
marshalUtil.WriteBytes(message.payload.Bytes()) marshalUtil.WriteBytes(message.payload.Bytes())
marshalUtil.WriteUint64(message.nonce)
message.signatureMutex.Lock()
message.signature = message.issuerLocalIdentity.Sign(marshalUtil.Bytes())
message.signatureMutex.Unlock()
marshalUtil.WriteBytes(message.signature.Bytes()) marshalUtil.WriteBytes(message.signature.Bytes())
message.bytes = marshalUtil.Bytes() message.bytes = marshalUtil.Bytes()
...@@ -270,10 +259,10 @@ func (message *Message) UnmarshalObjectStorageValue(data []byte) (consumedBytes ...@@ -270,10 +259,10 @@ func (message *Message) UnmarshalObjectStorageValue(data []byte) (consumedBytes
marshalUtil := marshalutil.New(data) marshalUtil := marshalutil.New(data)
// parse information // parse information
if message.trunkId, err = ParseId(marshalUtil); err != nil { if message.trunkID, err = ParseId(marshalUtil); err != nil {
return return
} }
if message.branchId, err = ParseId(marshalUtil); err != nil { if message.branchID, err = ParseId(marshalUtil); err != nil {
return return
} }
if message.issuerPublicKey, err = ed25519.ParsePublicKey(marshalUtil); err != nil { if message.issuerPublicKey, err = ed25519.ParsePublicKey(marshalUtil); err != nil {
...@@ -288,6 +277,9 @@ func (message *Message) UnmarshalObjectStorageValue(data []byte) (consumedBytes ...@@ -288,6 +277,9 @@ func (message *Message) UnmarshalObjectStorageValue(data []byte) (consumedBytes
if message.payload, err = payload.Parse(marshalUtil); err != nil { if message.payload, err = payload.Parse(marshalUtil); err != nil {
return return
} }
if message.nonce, err = marshalUtil.ReadUint64(); err != nil {
return
}
if message.signature, err = ed25519.ParseSignature(marshalUtil); err != nil { if message.signature, err = ed25519.ParseSignature(marshalUtil); err != nil {
return return
} }
...@@ -311,19 +303,22 @@ func (message *Message) ObjectStorageValue() []byte { ...@@ -311,19 +303,22 @@ func (message *Message) ObjectStorageValue() []byte {
return message.Bytes() return message.Bytes()
} }
func (message *Message) Update(other objectstorage.StorableObject) { // Update updates the object with the values of another object.
// Since a Message is immutable, this function is not implemented and panics.
func (message *Message) Update(objectstorage.StorableObject) {
panic("messages should never be overwritten and only stored once to optimize IO") panic("messages should never be overwritten and only stored once to optimize IO")
} }
func (message *Message) String() string { func (message *Message) String() string {
return stringify.Struct("Message", return stringify.Struct("Message",
stringify.StructField("id", message.Id()), stringify.StructField("id", message.Id()),
stringify.StructField("trunkMessageId", message.TrunkId()), stringify.StructField("trunkId", message.TrunkId()),
stringify.StructField("branchMessageId", message.BranchId()), stringify.StructField("branchId", message.BranchId()),
stringify.StructField("issuer", message.IssuerPublicKey()), stringify.StructField("issuer", message.IssuerPublicKey()),
stringify.StructField("issuingTime", message.IssuingTime()), stringify.StructField("issuingTime", message.IssuingTime()),
stringify.StructField("sequenceNumber", message.SequenceNumber()), stringify.StructField("sequenceNumber", message.SequenceNumber()),
stringify.StructField("payload", message.Payload()), stringify.StructField("payload", message.Payload()),
stringify.StructField("nonce", message.Nonce()),
stringify.StructField("signature", message.Signature()), stringify.StructField("signature", message.Signature()),
) )
} }
......
...@@ -2,27 +2,44 @@ package messagefactory ...@@ -2,27 +2,44 @@ package messagefactory
import ( import (
"fmt" "fmt"
"sync"
"time" "time"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/tipselector" "github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/identity" "github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/kvstore" "github.com/iotaledger/hive.go/kvstore"
) )
const storeSequenceInterval = 100 const storeSequenceInterval = 100
// A TipSelector selects two tips, branch and trunk, for a new message to attach to.
type TipSelector interface {
Tips() (trunk message.Id, branch message.Id)
}
// A Worker performs the PoW for the provided message in serialized byte form.
type Worker interface {
DoPOW([]byte) (nonce uint64, err error)
}
// ZeroWorker is a PoW worker that always returns 0 as the nonce.
var ZeroWorker = WorkerFunc(func([]byte) (uint64, error) { return 0, nil })
// MessageFactory acts as a factory to create new messages. // MessageFactory acts as a factory to create new messages.
type MessageFactory struct { type MessageFactory struct {
Events *Events Events *Events
sequence *kvstore.Sequence sequence *kvstore.Sequence
localIdentity *identity.LocalIdentity localIdentity *identity.LocalIdentity
tipSelector *tipselector.TipSelector selector TipSelector
worker Worker
workerMutex sync.RWMutex
} }
// New creates a new message factory. // New creates a new message factory.
func New(store kvstore.KVStore, localIdentity *identity.LocalIdentity, tipSelector *tipselector.TipSelector, sequenceKey []byte) *MessageFactory { func New(store kvstore.KVStore, sequenceKey []byte, localIdentity *identity.LocalIdentity, selector TipSelector) *MessageFactory {
sequence, err := kvstore.NewSequence(store, sequenceKey, storeSequenceInterval) sequence, err := kvstore.NewSequence(store, sequenceKey, storeSequenceInterval)
if err != nil { if err != nil {
panic(fmt.Sprintf("could not create message sequence number: %v", err)) panic(fmt.Sprintf("could not create message sequence number: %v", err))
...@@ -32,10 +49,18 @@ func New(store kvstore.KVStore, localIdentity *identity.LocalIdentity, tipSelect ...@@ -32,10 +49,18 @@ func New(store kvstore.KVStore, localIdentity *identity.LocalIdentity, tipSelect
Events: newEvents(), Events: newEvents(),
sequence: sequence, sequence: sequence,
localIdentity: localIdentity, localIdentity: localIdentity,
tipSelector: tipSelector, selector: selector,
worker: ZeroWorker,
} }
} }
// SetWorker sets the PoW worker to be used for the messages.
func (m *MessageFactory) SetWorker(worker Worker) {
m.workerMutex.Lock()
defer m.workerMutex.Unlock()
m.worker = worker
}
// IssuePayload creates a new message including sequence number and tip selection and returns it. // IssuePayload creates a new message including sequence number and tip selection and returns it.
// It also triggers the MessageConstructed event once it's done, which is for example used by the plugins to listen for // It also triggers the MessageConstructed event once it's done, which is for example used by the plugins to listen for
// messages that shall be attached to the tangle. // messages that shall be attached to the tangle.
...@@ -46,16 +71,30 @@ func (m *MessageFactory) IssuePayload(payload payload.Payload) *message.Message ...@@ -46,16 +71,30 @@ func (m *MessageFactory) IssuePayload(payload payload.Payload) *message.Message
return nil return nil
} }
trunkMessageId, branchMessageId := m.tipSelector.Tips() trunkID, branchID := m.selector.Tips()
issuingTime := time.Now()
issuerPublicKey := m.localIdentity.PublicKey()
// do the PoW
nonce, err := m.doPOW(trunkID, branchID, issuingTime, issuerPublicKey, sequenceNumber, payload)
if err != nil {
m.Events.Error.Trigger(fmt.Errorf("pow failed: %w", err))
return nil
}
// create the signature
signature := m.sign(trunkID, branchID, issuingTime, issuerPublicKey, sequenceNumber, payload, nonce)
msg := message.New( msg := message.New(
trunkMessageId, trunkID,
branchMessageId, branchID,
m.localIdentity, issuingTime,
time.Now(), issuerPublicKey,
sequenceNumber, sequenceNumber,
payload, payload,
nonce,
signature,
) )
m.Events.MessageConstructed.Trigger(msg) m.Events.MessageConstructed.Trigger(msg)
return msg return msg
} }
...@@ -66,3 +105,37 @@ func (m *MessageFactory) Shutdown() { ...@@ -66,3 +105,37 @@ func (m *MessageFactory) Shutdown() {
m.Events.Error.Trigger(fmt.Errorf("could not release message sequence number: %w", err)) m.Events.Error.Trigger(fmt.Errorf("could not release message sequence number: %w", err))
} }
} }
func (m *MessageFactory) doPOW(trunkID message.Id, branchID message.Id, issuingTime time.Time, key ed25519.PublicKey, seq uint64, payload payload.Payload) (uint64, error) {
// create a dummy message to simplify marshaling
dummy := message.New(trunkID, branchID, issuingTime, key, seq, payload, 0, ed25519.EmptySignature).Bytes()
m.workerMutex.RLock()
defer m.workerMutex.RUnlock()
return m.worker.DoPOW(dummy)
}
func (m *MessageFactory) sign(trunkID message.Id, branchID message.Id, issuingTime time.Time, key ed25519.PublicKey, seq uint64, payload payload.Payload, nonce uint64) ed25519.Signature {
// create a dummy message to simplify marshaling
dummy := message.New(trunkID, branchID, issuingTime, key, seq, payload, nonce, ed25519.EmptySignature)
dummyBytes := dummy.Bytes()
contentLength := len(dummyBytes) - len(dummy.Signature())
return m.localIdentity.Sign(dummyBytes[:contentLength])
}
// The TipSelectorFunc type is an adapter to allow the use of ordinary functions as tip selectors.
type TipSelectorFunc func() (message.Id, message.Id)
// Tips calls f().
func (f TipSelectorFunc) Tips() (message.Id, message.Id) {
return f()
}
// The WorkerFunc type is an adapter to allow the use of ordinary functions as a PoW performer.
type WorkerFunc func([]byte) (uint64, error)
// DoPOW calls f(msg).
func (f WorkerFunc) DoPOW(msg []byte) (uint64, error) {
return f(msg)
}
package messagefactory package messagefactory
import ( import (
"encoding" "context"
"crypto"
"crypto/ed25519"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
...@@ -9,21 +11,27 @@ import ( ...@@ -9,21 +11,27 @@ import (
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/tipselector" "github.com/iotaledger/goshimmer/packages/pow"
"github.com/iotaledger/hive.go/kvstore/mapdb"
"github.com/iotaledger/hive.go/events" "github.com/iotaledger/hive.go/events"
"github.com/iotaledger/hive.go/identity" "github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/kvstore/mapdb"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
_ "golang.org/x/crypto/blake2b"
) )
const ( const (
sequenceKey = "seq" sequenceKey = "seq"
targetPOW = 10
totalMessages = 2000 totalMessages = 2000
) )
func TestMessageFactory_BuildMessage(t *testing.T) { func TestMessageFactory_BuildMessage(t *testing.T) {
msgFactory := New(mapdb.NewMapDB(), identity.GenerateLocalIdentity(), tipselector.New(), []byte(sequenceKey)) msgFactory := New(
mapdb.NewMapDB(),
[]byte(sequenceKey),
identity.GenerateLocalIdentity(),
TipSelectorFunc(func() (message.Id, message.Id) { return message.EmptyId, message.EmptyId }),
)
defer msgFactory.Shutdown() defer msgFactory.Shutdown()
// keep track of sequence numbers // keep track of sequence numbers
...@@ -36,8 +44,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) { ...@@ -36,8 +44,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) {
})) }))
t.Run("CheckProperties", func(t *testing.T) { t.Run("CheckProperties", func(t *testing.T) {
data := []byte("TestCheckProperties") p := payload.NewData([]byte("TestCheckProperties"))
var p payload.Payload = NewMockPayload(data)
msg := msgFactory.IssuePayload(p) msg := msgFactory.IssuePayload(p)
assert.NotNil(t, msg.TrunkId()) assert.NotNil(t, msg.TrunkId())
...@@ -47,8 +54,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) { ...@@ -47,8 +54,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) {
assert.InDelta(t, time.Now().UnixNano(), msg.IssuingTime().UnixNano(), 100000000) assert.InDelta(t, time.Now().UnixNano(), msg.IssuingTime().UnixNano(), 100000000)
// check payload // check payload
assert.Same(t, p, msg.Payload()) assert.Equal(t, p, msg.Payload())
assert.Equal(t, data, msg.Payload().Bytes())
// check total events and sequence number // check total events and sequence number
assert.EqualValues(t, 1, countEvents) assert.EqualValues(t, 1, countEvents)
...@@ -62,8 +68,8 @@ func TestMessageFactory_BuildMessage(t *testing.T) { ...@@ -62,8 +68,8 @@ func TestMessageFactory_BuildMessage(t *testing.T) {
for i := 1; i < totalMessages; i++ { for i := 1; i < totalMessages; i++ {
t.Run("test", func(t *testing.T) { t.Run("test", func(t *testing.T) {
t.Parallel() t.Parallel()
data := []byte("TestCheckProperties")
var p payload.Payload = NewMockPayload(data) p := payload.NewData([]byte("TestParallelCreation"))
msg := msgFactory.IssuePayload(p) msg := msgFactory.IssuePayload(p)
assert.NotNil(t, msg.TrunkId()) assert.NotNil(t, msg.TrunkId())
...@@ -73,8 +79,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) { ...@@ -73,8 +79,7 @@ func TestMessageFactory_BuildMessage(t *testing.T) {
assert.InDelta(t, time.Now().UnixNano(), msg.IssuingTime().UnixNano(), 100000000) assert.InDelta(t, time.Now().UnixNano(), msg.IssuingTime().UnixNano(), 100000000)
// check payload // check payload
assert.Same(t, p, msg.Payload()) assert.Equal(t, p, msg.Payload())
assert.Equal(t, data, msg.Payload().Bytes())
sequenceNumbers.Store(msg.SequenceNumber(), true) sequenceNumbers.Store(msg.SequenceNumber(), true)
}) })
...@@ -104,28 +109,27 @@ func TestMessageFactory_BuildMessage(t *testing.T) { ...@@ -104,28 +109,27 @@ func TestMessageFactory_BuildMessage(t *testing.T) {
assert.EqualValues(t, totalMessages, countSequence) assert.EqualValues(t, totalMessages, countSequence)
} }
type MockPayload struct { func TestMessageFactory_POW(t *testing.T) {
data []byte msgFactory := New(
encoding.BinaryMarshaler mapdb.NewMapDB(),
encoding.BinaryUnmarshaler []byte(sequenceKey),
} identity.GenerateLocalIdentity(),
TipSelectorFunc(func() (message.Id, message.Id) { return message.EmptyId, message.EmptyId }),
func NewMockPayload(data []byte) *MockPayload { )
return &MockPayload{data: data} defer msgFactory.Shutdown()
}
func (m *MockPayload) Bytes() []byte { worker := pow.New(crypto.BLAKE2b_512, 1)
return m.data
}
func (m *MockPayload) Type() payload.Type { msgFactory.SetWorker(WorkerFunc(func(msgBytes []byte) (uint64, error) {
return payload.Type(0) content := msgBytes[:len(msgBytes)-ed25519.SignatureSize-8]
} return worker.Mine(context.Background(), content, targetPOW)
}))
func (m *MockPayload) String() string { msg := msgFactory.IssuePayload(payload.NewData([]byte("test")))
return string(m.data) msgBytes := msg.Bytes()
} content := msgBytes[:len(msgBytes)-ed25519.SignatureSize-8]
func (m *MockPayload) Unmarshal(bytes []byte) error { zeroes, err := worker.LeadingZerosWithNonce(content, msg.Nonce())
panic("implement me") assert.GreaterOrEqual(t, zeroes, targetPOW)
assert.NoError(t, err)
} }
package builtinfilters
import (
"crypto/ed25519"
"errors"
"fmt"
"sync"
"github.com/iotaledger/goshimmer/packages/pow"
"github.com/iotaledger/hive.go/async"
"github.com/iotaledger/hive.go/autopeering/peer"
)
var (
// ErrInvalidPOWDifficultly is returned when the nonce of a message does not fulfill the PoW difficulty.
ErrInvalidPOWDifficultly = errors.New("invalid PoW")
// ErrMessageTooSmall is returned when the message does not contain enough data for the PoW.
ErrMessageTooSmall = errors.New("message too small")
)
// PowFilter is a message bytes filter validating the PoW nonce.
type PowFilter struct {
worker *pow.Worker
difficulty int
workerPool async.WorkerPool
mu sync.Mutex
acceptCallback func([]byte, *peer.Peer)
rejectCallback func([]byte, error, *peer.Peer)
}
// NewPowFilter creates a new PoW bytes filter.
func NewPowFilter(worker *pow.Worker, difficulty int) *PowFilter {
return &PowFilter{
worker: worker,
difficulty: difficulty,
}
}
// Filter checks whether the given bytes pass the PoW validation and calls the corresponding callback.
func (f *PowFilter) Filter(msgBytes []byte, p *peer.Peer) {
f.workerPool.Submit(func() {
if err := f.validate(msgBytes); err != nil {
f.reject(msgBytes, err, p)
return
}
f.accept(msgBytes, p)
})
}
// OnAccept registers the given callback as the acceptance function of the filter.
func (f *PowFilter) OnAccept(callback func([]byte, *peer.Peer)) {
f.mu.Lock()
defer f.mu.Unlock()
f.acceptCallback = callback
}
// OnReject registers the given callback as the rejection function of the filter.
func (f *PowFilter) OnReject(callback func([]byte, error, *peer.Peer)) {
f.mu.Lock()
defer f.mu.Unlock()
f.rejectCallback = callback
}
// Shutdown shuts down the filter.
func (f *PowFilter) Shutdown() {
f.workerPool.ShutdownGracefully()
}
func (f *PowFilter) accept(msgBytes []byte, p *peer.Peer) {
f.mu.Lock()
defer f.mu.Unlock()
if f.acceptCallback != nil {
f.acceptCallback(msgBytes, p)
}
}
func (f *PowFilter) reject(msgBytes []byte, err error, p *peer.Peer) {
f.mu.Lock()
defer f.mu.Unlock()
if f.rejectCallback != nil {
f.rejectCallback(msgBytes, err, p)
}
}
func (f *PowFilter) validate(msgBytes []byte) error {
content, err := powData(msgBytes)
if err != nil {
return err
}
zeros, err := f.worker.LeadingZeros(content)
if err != nil {
return err
}
if zeros < f.difficulty {
return fmt.Errorf("%w: leading zeros %d for difficulty %d", ErrInvalidPOWDifficultly, zeros, f.difficulty)
}
return nil
}
// powData returns the bytes over which PoW should be computed.
func powData(msgBytes []byte) ([]byte, error) {
contentLength := len(msgBytes) - ed25519.SignatureSize
if contentLength < pow.NonceBytes {
return nil, ErrMessageTooSmall
}
return msgBytes[:contentLength], nil
}
package builtinfilters
import (
"context"
"crypto"
"errors"
"testing"
"time"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload"
"github.com/iotaledger/goshimmer/packages/pow"
"github.com/iotaledger/hive.go/autopeering/peer"
"github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
_ "golang.org/x/crypto/blake2b" // required by crypto.BLAKE2b_512
)
var (
testPayload = payload.NewData([]byte("test"))
testPeer *peer.Peer = nil
testWorker = pow.New(crypto.BLAKE2b_512, 1)
testDifficulty = 10
)
func TestPowFilter_Filter(t *testing.T) {
filter := NewPowFilter(testWorker, testDifficulty)
defer filter.Shutdown()
// set callbacks
m := &callbackMock{}
filter.OnAccept(m.Accept)
filter.OnReject(m.Reject)
t.Run("reject small message", func(t *testing.T) {
m.On("Reject", mock.Anything, mock.MatchedBy(func(err error) bool { return errors.Is(err, ErrMessageTooSmall) }), testPeer)
filter.Filter(nil, testPeer)
})
msg := newTestMessage(0)
msgBytes := msg.Bytes()
t.Run("reject invalid nonce", func(t *testing.T) {
m.On("Reject", msgBytes, mock.MatchedBy(func(err error) bool { return errors.Is(err, ErrInvalidPOWDifficultly) }), testPeer)
filter.Filter(msgBytes, testPeer)
})
nonce, err := testWorker.Mine(context.Background(), msgBytes[:len(msgBytes)-len(msg.Signature())-pow.NonceBytes], testDifficulty)
require.NoError(t, err)
msgPOW := newTestMessage(nonce)
msgPOWBytes := msgPOW.Bytes()
t.Run("accept valid nonce", func(t *testing.T) {
zeroes, err := testWorker.LeadingZeros(msgPOWBytes[:len(msgPOWBytes)-len(msgPOW.Signature())])
require.NoError(t, err)
require.GreaterOrEqual(t, zeroes, testDifficulty)
m.On("Accept", msgPOWBytes, testPeer)
filter.Filter(msgPOWBytes, testPeer)
})
filter.Shutdown()
m.AssertExpectations(t)
}
type callbackMock struct{ mock.Mock }
func (m *callbackMock) Accept(msg []byte, p *peer.Peer) { m.Called(msg, p) }
func (m *callbackMock) Reject(msg []byte, err error, p *peer.Peer) { m.Called(msg, err, p) }
func newTestMessage(nonce uint64) *message.Message {
return message.New(message.EmptyId, message.EmptyId, time.Time{}, ed25519.PublicKey{}, 0, testPayload, nonce, ed25519.Signature{})
}
...@@ -11,7 +11,7 @@ type BytesFilter interface { ...@@ -11,7 +11,7 @@ type BytesFilter interface {
Filter(bytes []byte, peer *peer.Peer) Filter(bytes []byte, peer *peer.Peer)
// OnAccept registers the given callback as the acceptance function of the filter. // OnAccept registers the given callback as the acceptance function of the filter.
OnAccept(callback func(bytes []byte, peer *peer.Peer)) OnAccept(callback func(bytes []byte, peer *peer.Peer))
// OnAccept registers the given callback as the rejection function of the filter. // OnReject registers the given callback as the rejection function of the filter.
OnReject(callback func(bytes []byte, err error, peer *peer.Peer)) OnReject(callback func(bytes []byte, err error, peer *peer.Peer))
// Shutdown shuts down the filter. // Shutdown shuts down the filter.
Shutdown() Shutdown()
......
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/events" "github.com/iotaledger/hive.go/events"
"github.com/iotaledger/hive.go/identity"
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
...@@ -14,8 +14,7 @@ import ( ...@@ -14,8 +14,7 @@ import (
) )
func BenchmarkMessageParser_ParseBytesSame(b *testing.B) { func BenchmarkMessageParser_ParseBytesSame(b *testing.B) {
localIdentity := identity.GenerateLocalIdentity() msgBytes := newTestMessage("Test").Bytes()
msgBytes := message.New(message.EmptyId, message.EmptyId, localIdentity, time.Now(), 0, payload.NewData([]byte("Test"))).Bytes()
msgParser := New() msgParser := New()
b.ResetTimer() b.ResetTimer()
...@@ -29,9 +28,8 @@ func BenchmarkMessageParser_ParseBytesSame(b *testing.B) { ...@@ -29,9 +28,8 @@ func BenchmarkMessageParser_ParseBytesSame(b *testing.B) {
func BenchmarkMessageParser_ParseBytesDifferent(b *testing.B) { func BenchmarkMessageParser_ParseBytesDifferent(b *testing.B) {
messageBytes := make([][]byte, b.N) messageBytes := make([][]byte, b.N)
localIdentity := identity.GenerateLocalIdentity()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
messageBytes[i] = message.New(message.EmptyId, message.EmptyId, localIdentity, time.Now(), 0, payload.NewData([]byte("Test"+strconv.Itoa(i)))).Bytes() messageBytes[i] = newTestMessage("Test" + strconv.Itoa(i)).Bytes()
} }
msgParser := New() msgParser := New()
...@@ -46,8 +44,7 @@ func BenchmarkMessageParser_ParseBytesDifferent(b *testing.B) { ...@@ -46,8 +44,7 @@ func BenchmarkMessageParser_ParseBytesDifferent(b *testing.B) {
} }
func TestMessageParser_ParseMessage(t *testing.T) { func TestMessageParser_ParseMessage(t *testing.T) {
localIdentity := identity.GenerateLocalIdentity() msg := newTestMessage("Test")
msg := message.New(message.EmptyId, message.EmptyId, localIdentity, time.Now(), 0, payload.NewData([]byte("Test")))
msgParser := New() msgParser := New()
msgParser.Parse(msg.Bytes(), nil) msgParser.Parse(msg.Bytes(), nil)
...@@ -58,3 +55,7 @@ func TestMessageParser_ParseMessage(t *testing.T) { ...@@ -58,3 +55,7 @@ func TestMessageParser_ParseMessage(t *testing.T) {
msgParser.Shutdown() msgParser.Shutdown()
} }
func newTestMessage(payloadString string) *message.Message {
return message.New(message.EmptyId, message.EmptyId, time.Now(), ed25519.PublicKey{}, 0, payload.NewData([]byte(payloadString)), 0, ed25519.Signature{})
}
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/payload"
"github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/events" "github.com/iotaledger/hive.go/events"
"github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/kvstore/mapdb" "github.com/iotaledger/hive.go/kvstore/mapdb"
) )
...@@ -20,11 +20,9 @@ func BenchmarkTangle_AttachMessage(b *testing.B) { ...@@ -20,11 +20,9 @@ func BenchmarkTangle_AttachMessage(b *testing.B) {
return return
} }
testIdentity := identity.GenerateLocalIdentity()
messageBytes := make([]*message.Message, b.N) messageBytes := make([]*message.Message, b.N)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
messageBytes[i] = message.New(message.EmptyId, message.EmptyId, testIdentity, time.Now(), 0, payload.NewData([]byte("some data"))) messageBytes[i] = newTestMessage("some data")
messageBytes[i].Bytes() messageBytes[i].Bytes()
} }
...@@ -73,10 +71,8 @@ func TestTangle_AttachMessage(t *testing.T) { ...@@ -73,10 +71,8 @@ func TestTangle_AttachMessage(t *testing.T) {
fmt.Println("REMOVED:", messageId) fmt.Println("REMOVED:", messageId)
})) }))
localIdentity1 := identity.GenerateLocalIdentity() newMessageOne := newTestMessage("some data")
localIdentity2 := identity.GenerateLocalIdentity() newMessageTwo := newTestMessage("some other data")
newMessageOne := message.New(message.EmptyId, message.EmptyId, localIdentity1, time.Now(), 0, payload.NewData([]byte("some data")))
newMessageTwo := message.New(newMessageOne.Id(), newMessageOne.Id(), localIdentity2, time.Now(), 0, payload.NewData([]byte("some other data")))
messageTangle.AttachMessage(newMessageTwo) messageTangle.AttachMessage(newMessageTwo)
...@@ -86,3 +82,7 @@ func TestTangle_AttachMessage(t *testing.T) { ...@@ -86,3 +82,7 @@ func TestTangle_AttachMessage(t *testing.T) {
messageTangle.Shutdown() messageTangle.Shutdown()
} }
func newTestMessage(payloadString string) *message.Message {
return message.New(message.EmptyId, message.EmptyId, time.Now(), ed25519.PublicKey{}, 0, payload.NewData([]byte(payloadString)), 0, ed25519.Signature{})
}
...@@ -4,10 +4,12 @@ import ( ...@@ -4,10 +4,12 @@ import (
"runtime" "runtime"
"sync" "sync"
"testing" "testing"
"time"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/messagefactory"
"github.com/iotaledger/goshimmer/plugins/messagelayer"
"github.com/iotaledger/hive.go/async" "github.com/iotaledger/hive.go/async"
"github.com/iotaledger/hive.go/identity" "github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/kvstore/mapdb"
"github.com/panjf2000/ants/v2" "github.com/panjf2000/ants/v2"
...@@ -19,11 +21,11 @@ func BenchmarkVerifyDataMessages(b *testing.B) { ...@@ -19,11 +21,11 @@ func BenchmarkVerifyDataMessages(b *testing.B) {
var pool async.WorkerPool var pool async.WorkerPool
pool.Tune(runtime.NumCPU() * 2) pool.Tune(runtime.NumCPU() * 2)
localIdentity := identity.GenerateLocalIdentity() factory := messagefactory.New(mapdb.NewMapDB(), []byte(messagelayer.DBSequenceNumber), identity.GenerateLocalIdentity(), messagefactory.TipSelectorFunc(func() (message.Id, message.Id) { return message.EmptyId, message.EmptyId }))
messages := make([][]byte, b.N) messages := make([][]byte, b.N)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
messages[i] = message.New(message.EmptyId, message.EmptyId, localIdentity, time.Now(), 0, payload.NewData([]byte("some data"))).Bytes() messages[i] = factory.IssuePayload(payload.NewData([]byte("some data"))).Bytes()
} }
b.ResetTimer() b.ResetTimer()
...@@ -45,18 +47,16 @@ func BenchmarkVerifyDataMessages(b *testing.B) { ...@@ -45,18 +47,16 @@ func BenchmarkVerifyDataMessages(b *testing.B) {
func BenchmarkVerifySignature(b *testing.B) { func BenchmarkVerifySignature(b *testing.B) {
pool, _ := ants.NewPool(80, ants.WithNonblocking(false)) pool, _ := ants.NewPool(80, ants.WithNonblocking(false))
localIdentity := identity.GenerateLocalIdentity() factory := messagefactory.New(mapdb.NewMapDB(), []byte(messagelayer.DBSequenceNumber), identity.GenerateLocalIdentity(), messagefactory.TipSelectorFunc(func() (message.Id, message.Id) { return message.EmptyId, message.EmptyId }))
messages := make([]*message.Message, b.N) messages := make([]*message.Message, b.N)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
messages[i] = message.New(message.EmptyId, message.EmptyId, localIdentity, time.Now(), 0, payload.NewData([]byte("test"))) messages[i] = factory.IssuePayload(payload.NewData([]byte("test")))
messages[i].Bytes() messages[i].Bytes()
} }
var wg sync.WaitGroup
b.ResetTimer() b.ResetTimer()
var wg sync.WaitGroup
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
wg.Add(1) wg.Add(1)
......
package test package test
import ( import (
"fmt"
"testing" "testing"
"time" "time"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/messagefactory" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/messagefactory"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/tipselector" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/tipselector"
"github.com/iotaledger/goshimmer/plugins/messagelayer" "github.com/iotaledger/goshimmer/plugins/messagelayer"
"github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/iotaledger/hive.go/identity" "github.com/iotaledger/hive.go/identity"
"github.com/iotaledger/hive.go/kvstore/mapdb" "github.com/iotaledger/hive.go/kvstore/mapdb"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
...@@ -31,14 +31,28 @@ func TestMessage_StorableObjectFromKey(t *testing.T) { ...@@ -31,14 +31,28 @@ func TestMessage_StorableObjectFromKey(t *testing.T) {
assert.Equal(t, key, messageFromKey.(*message.Message).Id()) assert.Equal(t, key, messageFromKey.(*message.Message).Id())
} }
func TestMessage_VerifySignature(t *testing.T) {
keyPair := ed25519.GenerateKeyPair()
pl := payload.NewData([]byte("test"))
unsigned := message.New(message.EmptyId, message.EmptyId, time.Time{}, keyPair.PublicKey, 0, pl, 0, ed25519.Signature{})
assert.False(t, unsigned.VerifySignature())
unsignedBytes := unsigned.Bytes()
signature := keyPair.PrivateKey.Sign(unsignedBytes[:len(unsignedBytes)-ed25519.SignatureSize])
signed := message.New(message.EmptyId, message.EmptyId, time.Time{}, keyPair.PublicKey, 0, pl, 0, signature)
assert.True(t, signed.VerifySignature())
}
func TestMessage_MarshalUnmarshal(t *testing.T) { func TestMessage_MarshalUnmarshal(t *testing.T) {
msgFactory := messagefactory.New(mapdb.NewMapDB(), identity.GenerateLocalIdentity(), tipselector.New(), []byte(messagelayer.DBSequenceNumber)) msgFactory := messagefactory.New(mapdb.NewMapDB(), []byte(messagelayer.DBSequenceNumber), identity.GenerateLocalIdentity(), tipselector.New())
defer msgFactory.Shutdown() defer msgFactory.Shutdown()
testMessage := msgFactory.IssuePayload(payload.NewData([]byte("sth"))) testMessage := msgFactory.IssuePayload(payload.NewData([]byte("test")))
assert.Equal(t, true, testMessage.VerifySignature()) assert.Equal(t, true, testMessage.VerifySignature())
fmt.Print(testMessage) t.Log(testMessage)
restoredMessage, err, _ := message.FromBytes(testMessage.Bytes()) restoredMessage, err, _ := message.FromBytes(testMessage.Bytes())
if assert.NoError(t, err, err) { if assert.NoError(t, err, err) {
...@@ -48,6 +62,7 @@ func TestMessage_MarshalUnmarshal(t *testing.T) { ...@@ -48,6 +62,7 @@ func TestMessage_MarshalUnmarshal(t *testing.T) {
assert.Equal(t, testMessage.IssuerPublicKey(), restoredMessage.IssuerPublicKey()) assert.Equal(t, testMessage.IssuerPublicKey(), restoredMessage.IssuerPublicKey())
assert.Equal(t, testMessage.IssuingTime().Round(time.Second), restoredMessage.IssuingTime().Round(time.Second)) assert.Equal(t, testMessage.IssuingTime().Round(time.Second), restoredMessage.IssuingTime().Round(time.Second))
assert.Equal(t, testMessage.SequenceNumber(), restoredMessage.SequenceNumber()) assert.Equal(t, testMessage.SequenceNumber(), restoredMessage.SequenceNumber())
assert.Equal(t, testMessage.Nonce(), restoredMessage.Nonce())
assert.Equal(t, testMessage.Signature(), restoredMessage.Signature()) assert.Equal(t, testMessage.Signature(), restoredMessage.Signature())
assert.Equal(t, true, restoredMessage.VerifySignature()) assert.Equal(t, true, restoredMessage.VerifySignature())
} }
......
...@@ -4,7 +4,7 @@ import ( ...@@ -4,7 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/iotaledger/hive.go/identity" "github.com/iotaledger/hive.go/crypto/ed25519"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/message" "github.com/iotaledger/goshimmer/packages/binary/messagelayer/message"
...@@ -21,8 +21,7 @@ func Test(t *testing.T) { ...@@ -21,8 +21,7 @@ func Test(t *testing.T) {
assert.Equal(t, message.EmptyId, branch1) assert.Equal(t, message.EmptyId, branch1)
// create a message and attach it // create a message and attach it
localIdentity1 := identity.GenerateLocalIdentity() message1 := newTestMessage(trunk1, branch1, "testmessage")
message1 := message.New(trunk1, branch1, localIdentity1, time.Now(), 0, payload.NewData([]byte("testmessage")))
tipSelector.AddTip(message1) tipSelector.AddTip(message1)
// check if the tip shows up in the tip count // check if the tip shows up in the tip count
...@@ -34,17 +33,15 @@ func Test(t *testing.T) { ...@@ -34,17 +33,15 @@ func Test(t *testing.T) {
assert.Equal(t, message1.Id(), branch2) assert.Equal(t, message1.Id(), branch2)
// create a 2nd message and attach it // create a 2nd message and attach it
localIdentity2 := identity.GenerateLocalIdentity() message2 := newTestMessage(message.EmptyId, message.EmptyId, "testmessage")
message2 := message.New(message.EmptyId, message.EmptyId, localIdentity2, time.Now(), 0, payload.NewData([]byte("testmessage")))
tipSelector.AddTip(message2) tipSelector.AddTip(message2)
// check if the tip shows up in the tip count // check if the tip shows up in the tip count
assert.Equal(t, 2, tipSelector.TipCount()) assert.Equal(t, 2, tipSelector.TipCount())
// attach a message to our two tips // attach a message to our two tips
localIdentity3 := identity.GenerateLocalIdentity()
trunk3, branch3 := tipSelector.Tips() trunk3, branch3 := tipSelector.Tips()
message3 := message.New(trunk3, branch3, localIdentity3, time.Now(), 0, payload.NewData([]byte("testmessage"))) message3 := newTestMessage(trunk3, branch3, "testmessage")
tipSelector.AddTip(message3) tipSelector.AddTip(message3)
// check if the tip shows replaces the current tips // check if the tip shows replaces the current tips
...@@ -53,3 +50,7 @@ func Test(t *testing.T) { ...@@ -53,3 +50,7 @@ func Test(t *testing.T) {
assert.Equal(t, message3.Id(), trunk4) assert.Equal(t, message3.Id(), trunk4)
assert.Equal(t, message3.Id(), branch4) assert.Equal(t, message3.Id(), branch4)
} }
func newTestMessage(trunk, branch message.Id, payloadString string) *message.Message {
return message.New(trunk, branch, time.Now(), ed25519.PublicKey{}, 0, payload.NewData([]byte(payloadString)), 0, ed25519.Signature{})
}
...@@ -17,6 +17,9 @@ var ( ...@@ -17,6 +17,9 @@ var (
ErrDone = errors.New("done") ErrDone = errors.New("done")
) )
// NonceBytes specifies the number of bytes required for the nonce.
const NonceBytes = 8
// Hash identifies a cryptographic hash function that is implemented in another package. // Hash identifies a cryptographic hash function that is implemented in another package.
type Hash interface { type Hash interface {
// Size returns the length, in bytes, of a digest resulting from the given hash function. // Size returns the length, in bytes, of a digest resulting from the given hash function.
...@@ -38,7 +41,7 @@ func New(hash Hash, numWorkers ...int) *Worker { ...@@ -38,7 +41,7 @@ func New(hash Hash, numWorkers ...int) *Worker {
hash: hash, hash: hash,
numWorkers: 1, numWorkers: 1,
} }
if len(numWorkers) > 0 { if len(numWorkers) > 0 && numWorkers[0] > 0 {
w.numWorkers = numWorkers[0] w.numWorkers = numWorkers[0]
} }
return w return w
...@@ -106,15 +109,15 @@ func (w *Worker) LeadingZeros(data []byte) (int, error) { ...@@ -106,15 +109,15 @@ func (w *Worker) LeadingZeros(data []byte) (int, error) {
// LeadingZerosWithNonce returns the number of leading zeros in the digest // LeadingZerosWithNonce returns the number of leading zeros in the digest
// after the provided 8-byte nonce is appended to msg. // after the provided 8-byte nonce is appended to msg.
func (w *Worker) LeadingZerosWithNonce(msg []byte, nonce uint64) (int, error) { func (w *Worker) LeadingZerosWithNonce(msg []byte, nonce uint64) (int, error) {
buf := make([]byte, len(msg)+8) buf := make([]byte, len(msg)+NonceBytes)
copy(buf, msg) copy(buf, msg)
binary.BigEndian.PutUint64(buf[len(msg):], nonce) putUint64(buf[len(msg):], nonce)
return w.LeadingZeros(buf) return w.LeadingZeros(buf)
} }
func (w *Worker) worker(msg []byte, startNonce uint64, target int, done *uint32, counter *uint64) (uint64, error) { func (w *Worker) worker(msg []byte, startNonce uint64, target int, done *uint32, counter *uint64) (uint64, error) {
buf := make([]byte, len(msg)+8) buf := make([]byte, len(msg)+NonceBytes)
copy(buf, msg) copy(buf, msg)
asAnInt := new(big.Int) asAnInt := new(big.Int)
...@@ -125,7 +128,7 @@ func (w *Worker) worker(msg []byte, startNonce uint64, target int, done *uint32, ...@@ -125,7 +128,7 @@ func (w *Worker) worker(msg []byte, startNonce uint64, target int, done *uint32,
atomic.AddUint64(counter, 1) atomic.AddUint64(counter, 1)
// write nonce in the buffer // write nonce in the buffer
binary.BigEndian.PutUint64(buf[len(msg):], nonce) putUint64(buf[len(msg):], nonce)
digest, err := w.sum(buf) digest, err := w.sum(buf)
if err != nil { if err != nil {
...@@ -149,3 +152,7 @@ func (w *Worker) sum(data []byte) ([]byte, error) { ...@@ -149,3 +152,7 @@ func (w *Worker) sum(data []byte) ([]byte, error) {
} }
return h.Sum(nil), nil return h.Sum(nil), nil
} }
func putUint64(b []byte, v uint64) {
binary.LittleEndian.PutUint64(b, v)
}
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
_ "golang.org/x/crypto/sha3" // required by crypto.SHA3_256 _ "golang.org/x/crypto/blake2b" // required by crypto.BLAKE2b_512
) )
const ( const (
...@@ -18,7 +18,7 @@ const ( ...@@ -18,7 +18,7 @@ const (
target = 10 target = 10
) )
var testWorker = New(crypto.SHA3_256, workers) var testWorker = New(crypto.BLAKE2b_512, workers)
func TestWorker_Work(t *testing.T) { func TestWorker_Work(t *testing.T) {
nonce, err := testWorker.Mine(context.Background(), nil, target) nonce, err := testWorker.Mine(context.Background(), nil, target)
...@@ -36,11 +36,11 @@ func TestWorker_Validate(t *testing.T) { ...@@ -36,11 +36,11 @@ func TestWorker_Validate(t *testing.T) {
expErr error expErr error
}{ }{
{msg: nil, nonce: 0, expLeadingZeros: 1, expErr: nil}, {msg: nil, nonce: 0, expLeadingZeros: 1, expErr: nil},
{msg: nil, nonce: 13176245766944605079, expLeadingZeros: 29, expErr: nil}, {msg: nil, nonce: 4611686018451317632, expLeadingZeros: 28, expErr: nil},
{msg: make([]byte, 1024), nonce: 0, expLeadingZeros: 4, expErr: nil}, {msg: make([]byte, 10240), nonce: 0, expLeadingZeros: 1, expErr: nil},
} }
w := &Worker{hash: crypto.SHA3_256} w := &Worker{hash: crypto.BLAKE2b_512}
for _, tt := range tests { for _, tt := range tests {
zeros, err := w.LeadingZerosWithNonce(tt.msg, tt.nonce) zeros, err := w.LeadingZerosWithNonce(tt.msg, tt.nonce)
assert.Equal(t, tt.expLeadingZeros, zeros) assert.Equal(t, tt.expLeadingZeros, zeros)
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/iotaledger/goshimmer/plugins/messagelayer" "github.com/iotaledger/goshimmer/plugins/messagelayer"
"github.com/iotaledger/goshimmer/plugins/metrics" "github.com/iotaledger/goshimmer/plugins/metrics"
"github.com/iotaledger/goshimmer/plugins/portcheck" "github.com/iotaledger/goshimmer/plugins/portcheck"
"github.com/iotaledger/goshimmer/plugins/pow"
"github.com/iotaledger/goshimmer/plugins/profiling" "github.com/iotaledger/goshimmer/plugins/profiling"
"github.com/iotaledger/goshimmer/plugins/sync" "github.com/iotaledger/goshimmer/plugins/sync"
...@@ -31,6 +32,7 @@ var PLUGINS = node.Plugins( ...@@ -31,6 +32,7 @@ var PLUGINS = node.Plugins(
profiling.Plugin, profiling.Plugin,
database.Plugin, database.Plugin,
autopeering.Plugin, autopeering.Plugin,
pow.Plugin,
messagelayer.Plugin, messagelayer.Plugin,
gossip.Plugin, gossip.Plugin,
issuer.Plugin, issuer.Plugin,
......
...@@ -44,7 +44,7 @@ func configure(*node.Plugin) { ...@@ -44,7 +44,7 @@ func configure(*node.Plugin) {
Tangle = tangle.New(store) Tangle = tangle.New(store)
// Setup MessageFactory (behavior + logging)) // Setup MessageFactory (behavior + logging))
MessageFactory = messagefactory.New(database.Store(), local.GetInstance().LocalIdentity(), TipSelector, []byte(DBSequenceNumber)) MessageFactory = messagefactory.New(database.Store(), []byte(DBSequenceNumber), local.GetInstance().LocalIdentity(), TipSelector)
MessageFactory.Events.MessageConstructed.Attach(events.NewClosure(Tangle.AttachMessage)) MessageFactory.Events.MessageConstructed.Attach(events.NewClosure(Tangle.AttachMessage))
MessageFactory.Events.Error.Attach(events.NewClosure(func(err error) { MessageFactory.Events.Error.Attach(events.NewClosure(func(err error) {
log.Errorf("internal error in message factory: %v", err) log.Errorf("internal error in message factory: %v", err)
......
package pow
import (
"time"
flag "github.com/spf13/pflag"
)
const (
// CfgPOWDifficulty defines the config flag of the PoW difficulty.
CfgPOWDifficulty = "pow.difficulty"
// CfgPOWNumThreads defines the config flag of the number of threads used to do the PoW.
CfgPOWNumThreads = "pow.numThreads"
// CfgPOWTimeout defines the config flag for the PoW timeout.
CfgPOWTimeout = "pow.timeout"
)
func init() {
flag.Int(CfgPOWDifficulty, 22, "PoW difficulty")
flag.Int(CfgPOWNumThreads, 1, "number of threads used to do the PoW")
flag.Duration(CfgPOWTimeout, time.Minute, "PoW timeout")
}
package pow
import (
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/messagefactory"
"github.com/iotaledger/goshimmer/packages/binary/messagelayer/messageparser/builtinfilters"
"github.com/iotaledger/goshimmer/plugins/messagelayer"
"github.com/iotaledger/hive.go/logger"
"github.com/iotaledger/hive.go/node"
)
// PluginName is the name of the PoW plugin.
const PluginName = "PoW"
var (
// Plugin is the plugin instance of the PoW plugin.
Plugin = node.NewPlugin(PluginName, node.Enabled, run)
)
func run(*node.Plugin) {
// assure that the logger is available
log := logger.NewLogger(PluginName)
if node.IsSkipped(messagelayer.Plugin) {
log.Infof("%s is disabled; skipping %s\n", messagelayer.PluginName, PluginName)
return
}
// assure that the PoW worker is initialized
worker := Worker()
messagelayer.MessageParser.AddBytesFilter(builtinfilters.NewPowFilter(worker, difficulty))
messagelayer.MessageFactory.SetWorker(messagefactory.WorkerFunc(DoPOW))
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment