diff --git a/packages/pow/pow.go b/packages/pow/pow.go new file mode 100644 index 0000000000000000000000000000000000000000..035086ecef88b65a712687d648fafae1c046a54d --- /dev/null +++ b/packages/pow/pow.go @@ -0,0 +1,151 @@ +package pow + +import ( + "context" + "encoding/binary" + "errors" + "hash" + "math" + "math/big" + "sync" + "sync/atomic" +) + +// errors returned by the PoW +var ( + ErrCancelled = errors.New("canceled") + ErrDone = errors.New("done") +) + +// Hash identifies a cryptographic hash function that is implemented in another package. +type Hash interface { + // Size returns the length, in bytes, of a digest resulting from the given hash function. + Size() int + // New returns a new hash.Hash calculating the given hash function. + New() hash.Hash +} + +// The Worker provides PoW functionality using an arbitrary hash function. +type Worker struct { + hash Hash + numWorkers int +} + +// New creates a new PoW based on the provided hash. +// The optional numWorkers specifies how many go routines are used to mine. +func New(hash Hash, numWorkers ...int) *Worker { + w := &Worker{ + hash: hash, + numWorkers: 1, + } + if len(numWorkers) > 0 { + w.numWorkers = numWorkers[0] + } + return w +} + +// Mine performs the PoW. +// It appends the 8-byte nonce to the provided msg and tries to find a nonce +// until the target number of leading zeroes is reached. +// The computation can be be canceled using the provided ctx. +func (w *Worker) Mine(ctx context.Context, msg []byte, target int) (uint64, error) { + var ( + done uint32 + counter uint64 + wg sync.WaitGroup + results = make(chan uint64, w.numWorkers) + closing = make(chan struct{}) + ) + + // stop when the context has been canceled + go func() { + select { + case <-ctx.Done(): + atomic.StoreUint32(&done, 1) + case <-closing: + return + } + }() + + workerWidth := math.MaxUint64 / uint64(w.numWorkers) + for i := 0; i < w.numWorkers; i++ { + startNonce := uint64(i) * workerWidth + wg.Add(1) + go func() { + defer wg.Done() + + nonce, workerErr := w.worker(msg, startNonce, target, &done, &counter) + if workerErr != nil { + return + } + atomic.StoreUint32(&done, 1) + results <- nonce + }() + } + wg.Wait() + close(results) + close(closing) + + nonce, ok := <-results + if !ok { + return 0, ErrCancelled + } + return nonce, nil +} + +// LeadingZeros returns the number of leading zeros in the digest of the given data. +func (w *Worker) LeadingZeros(data []byte) (int, error) { + digest, err := w.sum(data) + if err != nil { + return 0, err + } + asAnInt := new(big.Int).SetBytes(digest) + return 8*w.hash.Size() - asAnInt.BitLen(), nil +} + +// LeadingZerosWithNonce returns the number of leading zeros in the digest +// after the provided 8-byte nonce is appended to msg. +func (w *Worker) LeadingZerosWithNonce(msg []byte, nonce uint64) (int, error) { + buf := make([]byte, len(msg)+8) + copy(buf, msg) + binary.BigEndian.PutUint64(buf[len(msg):], nonce) + + return w.LeadingZeros(buf) +} + +func (w *Worker) worker(msg []byte, startNonce uint64, target int, done *uint32, counter *uint64) (uint64, error) { + buf := make([]byte, len(msg)+8) + copy(buf, msg) + asAnInt := new(big.Int) + + for nonce := startNonce; ; { + if atomic.LoadUint32(done) != 0 { + break + } + atomic.AddUint64(counter, 1) + + // write nonce in the buffer + binary.BigEndian.PutUint64(buf[len(msg):], nonce) + + digest, err := w.sum(buf) + if err != nil { + return 0, err + } + asAnInt.SetBytes(digest) + leadingZeros := 8*w.hash.Size() - asAnInt.BitLen() + if leadingZeros >= target { + return nonce, nil + } + + nonce++ + } + return 0, ErrDone +} + +func (w *Worker) sum(data []byte) ([]byte, error) { + h := w.hash.New() + if _, err := h.Write(data); err != nil { + return nil, err + } + return h.Sum(nil), nil +} diff --git a/packages/pow/pow_test.go b/packages/pow/pow_test.go new file mode 100644 index 0000000000000000000000000000000000000000..33e360161da68b9d68f214bfb1131e4be3937601 --- /dev/null +++ b/packages/pow/pow_test.go @@ -0,0 +1,78 @@ +package pow + +import ( + "context" + "crypto" + "math" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "golang.org/x/crypto/sha3" // required by crypto.SHA3_256 +) + +const ( + workers = 2 + target = 10 +) + +var testWorker = New(crypto.SHA3_256, workers) + +func TestWorker_Work(t *testing.T) { + nonce, err := testWorker.Mine(context.Background(), nil, target) + require.NoError(t, err) + difficulty, err := testWorker.LeadingZerosWithNonce(nil, nonce) + assert.GreaterOrEqual(t, difficulty, target) + assert.NoError(t, err) +} + +func TestWorker_Validate(t *testing.T) { + tests := []*struct { + msg []byte + nonce uint64 + expLeadingZeros int + expErr error + }{ + {msg: nil, nonce: 0, expLeadingZeros: 1, expErr: nil}, + {msg: nil, nonce: 13176245766944605079, expLeadingZeros: 29, expErr: nil}, + {msg: make([]byte, 1024), nonce: 0, expLeadingZeros: 4, expErr: nil}, + } + + w := &Worker{hash: crypto.SHA3_256} + for _, tt := range tests { + zeros, err := w.LeadingZerosWithNonce(tt.msg, tt.nonce) + assert.Equal(t, tt.expLeadingZeros, zeros) + assert.Equal(t, tt.expErr, err) + } +} + +func TestWorker_Cancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var err error + go func() { + _, err = testWorker.Mine(ctx, nil, math.MaxInt32) + }() + time.Sleep(10 * time.Millisecond) + cancel() + + assert.Eventually(t, func() bool { return err == ErrCancelled }, time.Second, 10*time.Millisecond) +} + +func BenchmarkWorker(b *testing.B) { + var ( + buf = make([]byte, 1024) + done uint32 + counter uint64 + ) + go func() { + _, _ = testWorker.worker(buf, 0, math.MaxInt32, &done, &counter) + }() + b.ResetTimer() + for atomic.LoadUint64(&counter) < uint64(b.N) { + } + atomic.StoreUint32(&done, 1) +}