From 50381055ad427790b5761426604be1f5a16f5d29 Mon Sep 17 00:00:00 2001
From: Wolfgang Welz <welzwo@gmail.com>
Date: Fri, 2 Aug 2019 14:01:22 +0200
Subject: [PATCH] Improve error handling in BadgerDB singleton

---
 packages/database/badger_instance.go | 77 +++++++++++++++++-----------
 packages/database/database.go        | 48 ++++++++---------
 2 files changed, 70 insertions(+), 55 deletions(-)

diff --git a/packages/database/badger_instance.go b/packages/database/badger_instance.go
index 50b5d8ba..6cf834f7 100644
--- a/packages/database/badger_instance.go
+++ b/packages/database/badger_instance.go
@@ -6,48 +6,63 @@ import (
 
 	"github.com/dgraph-io/badger"
 	"github.com/dgraph-io/badger/options"
+	"github.com/pkg/errors"
 )
 
 var instance *badger.DB
+var once sync.Once
 
-var openLock sync.Mutex
-
-func GetBadgerInstance() (result *badger.DB, err error) {
-	openLock.Lock()
+// Returns whether the given file or directory exists.
+func exists(path string) (bool, error) {
+	_, err := os.Stat(path)
+	if err == nil {
+		return true, nil
+	}
+	if os.IsNotExist(err) {
+		return false, nil
+	}
+	return false, err
+}
 
-	if instance == nil {
-		directory := *DIRECTORY.Value
+func checkDir(dir string) error {
+	exists, err := exists(dir)
+	if err != nil {
+		return err
+	}
 
-		if _, osErr := os.Stat(directory); os.IsNotExist(osErr) {
-			if osErr := os.Mkdir(directory, 0700); osErr != nil {
-				err = osErr
+	if !exists {
+		return os.Mkdir(dir, 0700)
+	}
+	return nil
+}
 
-				return
-			}
-		} else if osErr != nil {
-			err = osErr
+func createDB() (*badger.DB, error) {
+	directory := *DIRECTORY.Value
+	if err := checkDir(directory); err != nil {
+		return nil, errors.Wrap(err, "Could not check directory")
+	}
 
-			return
-		}
+	opts := badger.DefaultOptions(directory)
+	opts.Logger = &logger{}
+	opts.Truncate = true
+	opts.TableLoadingMode = options.MemoryMap
 
-		opts := badger.DefaultOptions(directory)
-		opts.Logger = &logger{}
-		opts.Truncate = true
-		opts.TableLoadingMode = options.MemoryMap
+	db, err := badger.Open(opts)
+	if err != nil {
+		return nil, errors.Wrap(err, "Could not open new DB")
+	}
 
-		db, badgerErr := badger.Open(opts)
-		if badgerErr != nil {
-			err = badgerErr
+	return db, nil
+}
 
-			return
+func GetBadgerInstance() *badger.DB {
+	once.Do(func() {
+		db, err := createDB()
+		if err != nil {
+			// errors should cause a panic to avoid singleton deadlocks
+			panic(err)
 		}
-
 		instance = db
-	}
-
-	openLock.Unlock()
-
-	result = instance
-
-	return
+	})
+	return instance
 }
diff --git a/packages/database/database.go b/packages/database/database.go
index 4a0e9657..573d0dd4 100644
--- a/packages/database/database.go
+++ b/packages/database/database.go
@@ -6,12 +6,10 @@ import (
 	"github.com/dgraph-io/badger"
 )
 
-var databasesByName = make(map[string]*databaseImpl)
-var getLock sync.Mutex
+var dbMap = make(map[string]*prefixDb)
+var mu sync.Mutex
 
-var ErrKeyNotFound = badger.ErrKeyNotFound
-
-type databaseImpl struct {
+type prefixDb struct {
 	db       *badger.DB
 	name     string
 	prefix   []byte
@@ -19,30 +17,32 @@ type databaseImpl struct {
 }
 
 func Get(name string) (Database, error) {
-	getLock.Lock()
-	defer getLock.Unlock()
-
-	if database, exists := databasesByName[name]; exists {
-		return database, nil
+	// avoid locking if it's a clean hit
+	if db, exists := dbMap[name]; exists {
+		return db, nil
 	}
 
-	badgerInstance, err := GetBadgerInstance()
-	if err != nil {
-		return nil, err
+	mu.Lock()
+	defer mu.Unlock()
+
+	// needs to be re-checked after locking
+	if db, exists := dbMap[name]; exists {
+		return db, nil
 	}
 
-	database := &databaseImpl{
-		db:     badgerInstance,
+	badger := GetBadgerInstance()
+	db := &prefixDb{
+		db:     badger,
 		name:   name,
 		prefix: []byte(name + "_"),
 	}
 
-	databasesByName[name] = database
+	dbMap[name] = db
 
-	return databasesByName[name], nil
+	return db, nil
 }
 
-func (this *databaseImpl) Set(key []byte, value []byte) error {
+func (this *prefixDb) Set(key []byte, value []byte) error {
 	if err := this.db.Update(func(txn *badger.Txn) error { return txn.Set(append(this.prefix, key...), value) }); err != nil {
 		return err
 	}
@@ -50,7 +50,7 @@ func (this *databaseImpl) Set(key []byte, value []byte) error {
 	return nil
 }
 
-func (this *databaseImpl) Contains(key []byte) (bool, error) {
+func (this *prefixDb) Contains(key []byte) (bool, error) {
 	err := this.db.View(func(txn *badger.Txn) error {
 		_, err := txn.Get(append(this.prefix, key...))
 		if err != nil {
@@ -60,14 +60,14 @@ func (this *databaseImpl) Contains(key []byte) (bool, error) {
 		return nil
 	})
 
-	if err == ErrKeyNotFound {
+	if err == badger.ErrKeyNotFound {
 		return false, nil
 	} else {
 		return err == nil, err
 	}
 }
 
-func (this *databaseImpl) Get(key []byte) ([]byte, error) {
+func (this *prefixDb) Get(key []byte) ([]byte, error) {
 	var result []byte = nil
 
 	err := this.db.View(func(txn *badger.Txn) error {
@@ -86,7 +86,7 @@ func (this *databaseImpl) Get(key []byte) ([]byte, error) {
 	return result, err
 }
 
-func (this *databaseImpl) Delete(key []byte) error {
+func (this *prefixDb) Delete(key []byte) error {
 	err := this.db.Update(func(txn *badger.Txn) error {
 		err := txn.Delete(append(this.prefix, key...))
 		return err
@@ -94,10 +94,10 @@ func (this *databaseImpl) Delete(key []byte) error {
 	return err
 }
 
-func (this *databaseImpl) ForEach(consumer func([]byte, []byte)) error {
+func (this *prefixDb) ForEach(consumer func([]byte, []byte)) error {
 	err := this.db.View(func(txn *badger.Txn) error {
 		iteratorOptions := badger.DefaultIteratorOptions
-		iteratorOptions.Prefix = this.prefix
+		iteratorOptions.Prefix = this.prefix // filter by prefix
 
 		// create an iterator the default options
 		it := txn.NewIterator(iteratorOptions)
-- 
GitLab