Skip to content
Snippets Groups Projects
Commit 64aa79c0 authored by Hans Moog's avatar Hans Moog
Browse files

Refactor: refactored state machine of gossip protocol

parent 282a620b
No related branches found
No related tags found
No related merge requests found
......@@ -7,4 +7,6 @@ var (
ErrInvalidAuthenticationMessage = errors.Wrap(errors.New("protocol error"), "invalid authentication message")
ErrInvalidIdentity = errors.Wrap(errors.New("protocol error"), "invalid identity message")
ErrInvalidStateTransition = errors.New("protocol error: invalid state transition message")
ErrSendFailed = errors.Wrap(errors.New("protocol error"), "failed to send message")
ErrInvalidSendParam = errors.New("invalid parameter passed to send")
)
package gossip
import (
"github.com/iotaledger/goshimmer/packages/errors"
"github.com/iotaledger/goshimmer/packages/events"
"github.com/iotaledger/goshimmer/packages/identity"
"github.com/iotaledger/goshimmer/packages/network"
......@@ -68,6 +69,6 @@ func connectionCaller(handler interface{}, params ...interface{}) { handler.(fun
func peerCaller(handler interface{}, params ...interface{}) { handler.(func(*Peer))(params[0].(*Peer)) }
func errorCaller(handler interface{}, params ...interface{}) { handler.(func(error))(params[0].(error)) }
func errorCaller(handler interface{}, params ...interface{}) { handler.(func(errors.IdentifiableError))(params[0].(errors.IdentifiableError)) }
func transactionCaller(handler interface{}, params ...interface{}) { handler.(func(*transaction.Transaction))(params[0].(*transaction.Transaction)) }
......@@ -74,7 +74,7 @@ func manageConnection(plugin *node.Plugin, neighbor *Peer) {
}))
if dialed {
go newProtocol(conn).init()
go newProtocol(conn).Init()
}
// wait for shutdown or
......
package gossip
import (
"github.com/iotaledger/goshimmer/packages/accountability"
"github.com/iotaledger/goshimmer/packages/errors"
"github.com/iotaledger/goshimmer/packages/events"
"github.com/iotaledger/goshimmer/packages/identity"
"github.com/iotaledger/goshimmer/packages/network"
"strconv"
)
// region interfaces ///////////////////////////////////////////////////////////////////////////////////////////////////
// region constants and variables //////////////////////////////////////////////////////////////////////////////////////
type protocolState interface {
Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError)
var DEFAULT_PROTOCOL = protocolDefinition{
version: 1,
initializer: protocolV1,
}
// endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
//region protocol //////////////////////////////////////////////////////////////////////////////////////////////////////
// region protocol /////////////////////////////////////////////////////////////////////////////////////////////////////
type protocol struct {
Conn *network.ManagedConnection
Neighbor *Peer
Version int
CurrentState protocolState
SendState protocolState
ReceivingState protocolState
Events protocolEvents
}
func newProtocol(conn *network.ManagedConnection) *protocol {
protocol := &protocol{
Conn: conn,
CurrentState: &versionState{},
Events: protocolEvents{
ReceiveVersion: events.NewEvent(intCaller),
ReceiveIdentification: events.NewEvent(identityCaller),
ReceiveConnectionAccepted: events.NewEvent(events.CallbackCaller),
ReceiveConnectionRejected: events.NewEvent(events.CallbackCaller),
Error: events.NewEvent(errorCaller),
},
}
return protocol
}
func (protocol *protocol) sendVersion() {
protocol.Conn.Write([]byte{1})
}
func (protocol *protocol) sendIdentification() {
if signature, err := accountability.OWN_ID.Sign(accountability.OWN_ID.Identifier); err == nil {
protocol.Conn.Write(accountability.OWN_ID.Identifier)
protocol.Conn.Write(signature)
}
}
func (protocol *protocol) rejectConnection() {
protocol.Conn.Write([]byte{0})
protocol.Conn.Close()
}
func (protocol *protocol) acceptConnection() {
protocol.Conn.Write([]byte{1})
}
protocol.SendState = &versionState{protocol: protocol}
protocol.ReceivingState = &versionState{protocol: protocol}
func (protocol *protocol) init() {
//region setup event handlers
onReceiveIdentification := events.NewClosure(func(identity *identity.Identity) {
if protocol.Neighbor == nil {
protocol.rejectConnection()
} else {
protocol.acceptConnection()
return protocol
}
})
onReceiveData := events.NewClosure(protocol.parseData)
var onClose *events.Closure // define var first so we can use it in the closure
func (protocol *protocol) Init() {
// setup event handlers
onReceiveData := events.NewClosure(protocol.Receive)
var onClose *events.Closure
onClose = events.NewClosure(func() {
protocol.Conn.Events.ReceiveData.Detach(onReceiveData)
protocol.Conn.Events.Close.Detach(onClose)
})
//endregion
// region register event handlers
protocol.Events.ReceiveIdentification.Attach(onReceiveIdentification)
protocol.Conn.Events.ReceiveData.Attach(onReceiveData)
protocol.Conn.Events.Close.Attach(onClose)
//endregion
//region send initial handshake
protocol.sendVersion()
protocol.sendIdentification()
//endregion
// send protocol version
if err := protocol.Send(DEFAULT_PROTOCOL.version); err != nil {
return
}
// initialize default protocol
if err := DEFAULT_PROTOCOL.initializer(protocol); err != nil {
protocol.SendState = nil
_ = protocol.Conn.Close()
protocol.Events.Error.Trigger(err)
return
}
// start reading from the connection
protocol.Conn.Read(make([]byte, 1000))
_, _ = protocol.Conn.Read(make([]byte, 1000))
}
func (protocol *protocol) parseData(data []byte) {
func (protocol *protocol) Receive(data []byte) {
offset := 0
length := len(data)
for offset < length && protocol.CurrentState != nil {
if readBytes, err := protocol.CurrentState.Consume(protocol, data, offset, length); err != nil {
for offset < length && protocol.ReceivingState != nil {
if readBytes, err := protocol.ReceivingState.Receive(data, offset, length); err != nil {
Events.Error.Trigger(err)
protocol.Neighbor.InitiatedConn.Close()
_ = protocol.Conn.Close()
return
} else {
......@@ -113,19 +94,39 @@ func (protocol *protocol) parseData(data []byte) {
}
}
// endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (protocol *protocol) Send(data interface{}) errors.IdentifiableError {
if protocol.SendState != nil {
if err := protocol.SendState.Send(data); err != nil {
protocol.SendState = nil
_ = protocol.Conn.Close()
protocol.Events.Error.Trigger(err)
return err
}
}
return nil
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region versionState /////////////////////////////////////////////////////////////////////////////////////////////////
type versionState struct{}
type versionState struct {
protocol *protocol
}
func (state *versionState) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
func (state *versionState) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
switch data[offset] {
case 1:
protocol := state.protocol
protocol.Version = 1
protocol.Events.ReceiveVersion.Trigger(1)
protocol.CurrentState = newIndentificationStateV1()
protocol.ReceivingState = newIndentificationStateV1(protocol)
return 1, nil
......@@ -134,4 +135,37 @@ func (state *versionState) Consume(protocol *protocol, data []byte, offset int,
}
}
// endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *versionState) Send(param interface{}) errors.IdentifiableError {
if version, ok := param.(byte); ok {
switch version {
case VERSION_1:
protocol := state.protocol
if _, err := protocol.Conn.Write([]byte{version}); err != nil {
return ErrSendFailed.Derive(err, "failed to send version byte")
}
protocol.SendState = newIndentificationStateV1(protocol)
return nil
}
}
return ErrInvalidSendParam.Derive("passed in parameter is not a valid version byte")
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region types and interfaces /////////////////////////////////////////////////////////////////////////////////////////
type protocolState interface {
Send(param interface{}) errors.IdentifiableError
Receive(data []byte, offset int, length int) (int, errors.IdentifiableError)
}
type protocolDefinition struct {
version byte
initializer func(*protocol) errors.IdentifiableError
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -2,28 +2,58 @@ package gossip
import (
"bytes"
"github.com/iotaledger/goshimmer/packages/accountability"
"github.com/iotaledger/goshimmer/packages/byteutils"
"github.com/iotaledger/goshimmer/packages/errors"
"github.com/iotaledger/goshimmer/packages/events"
"github.com/iotaledger/goshimmer/packages/identity"
"github.com/iotaledger/goshimmer/packages/transaction"
"strconv"
)
//region indentificationStateV1 ////////////////////////////////////////////////////////////////////////////////////////
// region protocolV1 ///////////////////////////////////////////////////////////////////////////////////////////////////
func protocolV1(protocol *protocol) errors.IdentifiableError {
if err := protocol.Send(accountability.OWN_ID); err != nil {
return err
}
onReceiveIdentification := events.NewClosure(func(identity *identity.Identity) {
if protocol.Neighbor == nil {
if err := protocol.Send(CONNECTION_REJECT); err != nil {
return
}
} else {
if err := protocol.Send(CONNECTION_ACCEPT); err != nil {
return
}
}
})
protocol.Events.ReceiveIdentification.Attach(onReceiveIdentification)
return nil
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region indentificationStateV1 ///////////////////////////////////////////////////////////////////////////////////////
type indentificationStateV1 struct {
protocol *protocol
buffer []byte
offset int
}
func newIndentificationStateV1() *indentificationStateV1 {
func newIndentificationStateV1(protocol *protocol) *indentificationStateV1 {
return &indentificationStateV1{
protocol: protocol,
buffer: make([]byte, MARSHALLED_IDENTITY_TOTAL_SIZE),
offset: 0,
}
}
func (state *indentificationStateV1) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
func (state *indentificationStateV1) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
bytesRead := byteutils.ReadAvailableBytesToBuffer(state.buffer, state.offset, data, offset, length)
state.offset += bytesRead
......@@ -31,6 +61,8 @@ func (state *indentificationStateV1) Consume(protocol *protocol, data []byte, of
if receivedIdentity, err := unmarshalIdentity(state.buffer); err != nil {
return bytesRead, ErrInvalidAuthenticationMessage.Derive(err, "invalid authentication message")
} else {
protocol := state.protocol
if neighbor, exists := GetNeighbor(receivedIdentity.StringIdentifier); exists {
protocol.Neighbor = neighbor
} else {
......@@ -39,7 +71,7 @@ func (state *indentificationStateV1) Consume(protocol *protocol, data []byte, of
protocol.Events.ReceiveIdentification.Trigger(receivedIdentity)
protocol.CurrentState = newacceptanceStateV1()
protocol.ReceivingState = newacceptanceStateV1(protocol)
state.offset = 0
}
}
......@@ -47,6 +79,27 @@ func (state *indentificationStateV1) Consume(protocol *protocol, data []byte, of
return bytesRead, nil
}
func (state *indentificationStateV1) Send(param interface{}) errors.IdentifiableError {
if id, ok := param.(*identity.Identity); ok {
if signature, err := id.Sign(id.Identifier); err == nil {
protocol := state.protocol
if _, err := protocol.Conn.Write(id.Identifier); err != nil {
return ErrSendFailed.Derive(err, "failed to send identifier")
}
if _, err := protocol.Conn.Write(signature); err != nil {
return ErrSendFailed.Derive(err, "failed to send signature")
}
protocol.SendState = newacceptanceStateV1(protocol)
return nil
}
}
return ErrInvalidSendParam.Derive("passed in parameter is not a valid identity")
}
func unmarshalIdentity(data []byte) (*identity.Identity, error) {
identifier := data[MARSHALLED_IDENTITY_START:MARSHALLED_IDENTITY_END]
......@@ -61,31 +114,33 @@ func unmarshalIdentity(data []byte) (*identity.Identity, error) {
}
}
//endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
//region acceptanceStateV1 /////////////////////////////////////////////////////////////////////////////////////////////
// region acceptanceStateV1 ////////////////////////////////////////////////////////////////////////////////////////////
type acceptanceStateV1 struct {}
type acceptanceStateV1 struct {
protocol *protocol
}
func newacceptanceStateV1() *acceptanceStateV1 {
return &acceptanceStateV1{}
func newacceptanceStateV1(protocol *protocol) *acceptanceStateV1 {
return &acceptanceStateV1{protocol: protocol}
}
func (state *acceptanceStateV1) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
func (state *acceptanceStateV1) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
protocol := state.protocol
switch data[offset] {
case 0:
protocol.Events.ReceiveConnectionRejected.Trigger()
protocol.Conn.Close()
_ = protocol.Conn.Close()
protocol.CurrentState = nil
break
protocol.ReceivingState = nil
case 1:
protocol.Events.ReceiveConnectionAccepted.Trigger()
protocol.CurrentState = newDispatchStateV1()
break
protocol.ReceivingState = newDispatchStateV1(protocol)
default:
return 1, ErrInvalidStateTransition.Derive("invalid acceptance state transition (" + strconv.Itoa(int(data[offset])) + ")")
......@@ -94,59 +149,135 @@ func (state *acceptanceStateV1) Consume(protocol *protocol, data []byte, offset
return 1, nil
}
//endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *acceptanceStateV1) Send(param interface{}) errors.IdentifiableError {
if responseType, ok := param.(byte); ok {
switch responseType {
case CONNECTION_REJECT:
protocol := state.protocol
if _, err := protocol.Conn.Write([]byte{CONNECTION_REJECT}); err != nil {
return ErrSendFailed.Derive(err, "failed to send reject message")
}
_ = protocol.Conn.Close()
protocol.SendState = nil
return nil
//region dispatchStateV1 ///////////////////////////////////////////////////////////////////////////////////////////////
case CONNECTION_ACCEPT:
protocol := state.protocol
if _, err := protocol.Conn.Write([]byte{CONNECTION_ACCEPT}); err != nil {
return ErrSendFailed.Derive(err, "failed to send accept message")
}
type dispatchStateV1 struct {}
protocol.SendState = newDispatchStateV1(protocol)
func newDispatchStateV1() *dispatchStateV1 {
return &dispatchStateV1{}
return nil
}
}
return ErrInvalidSendParam.Derive("passed in parameter is not a valid acceptance byte")
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *dispatchStateV1) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
// region dispatchStateV1 //////////////////////////////////////////////////////////////////////////////////////////////
type dispatchStateV1 struct {
protocol *protocol
}
func newDispatchStateV1(protocol *protocol) *dispatchStateV1 {
return &dispatchStateV1{
protocol: protocol,
}
}
func (state *dispatchStateV1) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
switch data[0] {
case 0:
protocol := state.protocol
protocol.Events.ReceiveConnectionRejected.Trigger()
protocol.Neighbor.InitiatedConn.Close()
protocol.CurrentState = nil
_ = protocol.Conn.Close()
protocol.ReceivingState = nil
case 1:
protocol.CurrentState = newTransactionStateV1()
break
protocol := state.protocol
protocol.ReceivingState = newTransactionStateV1(protocol)
case 2:
protocol.CurrentState = newRequestStateV1()
break
protocol := state.protocol
protocol.ReceivingState = newRequestStateV1(protocol)
default:
return 1, ErrInvalidStateTransition.Derive("invalid dispatch state transition (" + strconv.Itoa(int(data[offset])) + ")")
}
return 1, nil
}
//endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *dispatchStateV1) Send(param interface{}) errors.IdentifiableError {
if dispatchByte, ok := param.(byte); ok {
switch dispatchByte {
case DISPATCH_DROP:
protocol := state.protocol
_ = protocol.Conn.Close()
protocol.SendState = nil
return nil
case DISPATCH_TRANSACTION:
protocol := state.protocol
protocol.SendState = newTransactionStateV1(protocol)
return nil
case DISPATCH_REQUEST:
protocol := state.protocol
//region transactionStateV1 ////////////////////////////////////////////////////////////////////////////////////////////
protocol.SendState = newTransactionStateV1(protocol)
return nil
}
}
return ErrInvalidSendParam.Derive("passed in parameter is not a valid dispatch byte")
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region transactionStateV1 ///////////////////////////////////////////////////////////////////////////////////////////
type transactionStateV1 struct {
protocol *protocol
buffer []byte
offset int
}
func newTransactionStateV1() *transactionStateV1 {
func newTransactionStateV1(protocol *protocol) *transactionStateV1 {
return &transactionStateV1{
buffer: make([]byte, transaction.MARSHALLED_TOTAL_SIZE),
offset: 0,
}
}
func (state *transactionStateV1) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
func (state *transactionStateV1) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
bytesRead := byteutils.ReadAvailableBytesToBuffer(state.buffer, state.offset, data, offset, length)
state.offset += bytesRead
if state.offset == transaction.MARSHALLED_TOTAL_SIZE {
protocol := state.protocol
transactionData := make([]byte, transaction.MARSHALLED_TOTAL_SIZE)
copy(transactionData, state.buffer)
......@@ -154,36 +285,55 @@ func (state *transactionStateV1) Consume(protocol *protocol, data []byte, offset
go processTransactionData(transactionData)
protocol.CurrentState = newDispatchStateV1()
protocol.ReceivingState = newDispatchStateV1(protocol)
state.offset = 0
}
return bytesRead, nil
}
//endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *transactionStateV1) Send(param interface{}) errors.IdentifiableError {
return nil
}
//region requestStateV1 ////////////////////////////////////////////////////////////////////////////////////////////////
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region requestStateV1 ///////////////////////////////////////////////////////////////////////////////////////////////
type requestStateV1 struct {
buffer []byte
offset int
}
func newRequestStateV1() *requestStateV1 {
func newRequestStateV1(protocol *protocol) *requestStateV1 {
return &requestStateV1{
buffer: make([]byte, 1),
offset: 0,
}
}
func (state *requestStateV1) Consume(protocol *protocol, data []byte, offset int, length int) (int, errors.IdentifiableError) {
func (state *requestStateV1) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
return 0, nil
}
//endregion ////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (state *requestStateV1) Send(param interface{}) errors.IdentifiableError {
return nil
}
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
// region constants and variables //////////////////////////////////////////////////////////////////////////////////////
const (
VERSION_1 = byte(1)
CONNECTION_REJECT = byte(0)
CONNECTION_ACCEPT = byte(1)
DISPATCH_DROP = byte(0)
DISPATCH_TRANSACTION = byte(1)
DISPATCH_REQUEST = byte(2)
MARSHALLED_IDENTITY_START = 0
MARSHALLED_IDENTITY_SIGNATURE_START = MARSHALLED_IDENTITY_END
......@@ -195,3 +345,5 @@ const (
MARSHALLED_IDENTITY_TOTAL_SIZE = MARSHALLED_IDENTITY_SIGNATURE_END
)
// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -2,6 +2,7 @@ package gossip
import (
"github.com/iotaledger/goshimmer/packages/daemon"
"github.com/iotaledger/goshimmer/packages/errors"
"github.com/iotaledger/goshimmer/packages/events"
"github.com/iotaledger/goshimmer/packages/identity"
"github.com/iotaledger/goshimmer/packages/network"
......@@ -16,6 +17,11 @@ func configureServer(plugin *node.Plugin) {
TCPServer.Events.Connect.Attach(events.NewClosure(func(conn *network.ManagedConnection) {
protocol := newProtocol(conn)
// print protocol errors
protocol.Events.Error.Attach(events.NewClosure(func(err errors.IdentifiableError) {
plugin.LogFailure(err.Error())
}))
// store connection in neighbor if its a neighbor calling
protocol.Events.ReceiveIdentification.Attach(events.NewClosure(func(identity *identity.Identity) {
if protocol.Neighbor != nil {
......@@ -34,7 +40,7 @@ func configureServer(plugin *node.Plugin) {
}
}))
go protocol.init()
go protocol.Init()
}))
daemon.Events.Shutdown.Attach(events.NewClosure(func() {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment