package gossip

import (
	"strconv"
	"sync"

	"github.com/iotaledger/goshimmer/packages/errors"
	"github.com/iotaledger/goshimmer/packages/network"
	"github.com/iotaledger/hive.go/events"
)

// region constants and variables //////////////////////////////////////////////////////////////////////////////////////

// DEFAULT_PROTOCOL is the protocol definition used by Init: its version byte is sent to the peer first and its
// initializer is then run against the protocol instance.
var DEFAULT_PROTOCOL = protocolDefinition{
	version:     VERSION_1,
	initializer: protocolV1,
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region protocol /////////////////////////////////////////////////////////////////////////////////////////////////////

// protocol bundles the per-connection protocol state: the managed connection, the received version, the current
// send / receive states, the handshake flags and the events that get triggered while talking to the peer.
type protocol struct {
	Conn                      *network.ManagedConnection
	Neighbor                  *Neighbor
	Version                   byte
	sendHandshakeCompleted    bool
	receiveHandshakeCompleted bool
	SendState                 protocolState
	ReceivingState            protocolState
	Events                    protocolEvents
	sendMutex                 sync.Mutex // held by Send for the duration of a send
	handshakeMutex            sync.Mutex // held while the handshake flags are evaluated in Init's accept handler
}

// newProtocol creates a protocol for the given connection. Both the sending and the receiving side start in the
// versionState; the handshake flags start out as false (their zero value).
func newProtocol(conn *network.ManagedConnection) *protocol {
	protocolEvts := protocolEvents{
		ReceiveVersion:            events.NewEvent(intCaller),
		ReceiveIdentification:     events.NewEvent(identityCaller),
		ReceiveConnectionAccepted: events.NewEvent(events.CallbackCaller),
		ReceiveConnectionRejected: events.NewEvent(events.CallbackCaller),
		ReceiveTransactionData:    events.NewEvent(dataCaller),
		HandshakeCompleted:        events.NewEvent(events.CallbackCaller),
		Error:                     events.NewEvent(errorCaller),
	}

	p := &protocol{
		Conn:   conn,
		Events: protocolEvts,
	}
	p.SendState = &versionState{protocol: p}
	p.ReceivingState = &versionState{protocol: p}

	return p
}

// Init wires the protocol to its connection and kicks off the handshake:
//
//  1. attaches handlers for incoming data, connection close and the "connection accepted" event,
//  2. sends the version byte of DEFAULT_PROTOCOL,
//  3. runs the initializer of DEFAULT_PROTOCOL,
//  4. issues a read on the connection.
//
// If sending the version fails, Init simply returns (send has already closed the connection and triggered the
// Error event). If the initializer fails, the send state is cleared, the connection is closed and the Error event
// is triggered.
func (p *protocol) Init() {
	// handler: forward incoming bytes to the receiving state
	dataHandler := events.NewClosure(p.Receive)

	// handler: mark the receiving half of the handshake as done and announce completion once the sending half
	// is done as well
	acceptedHandler := events.NewClosure(func() {
		p.handshakeMutex.Lock()
		defer p.handshakeMutex.Unlock()

		p.receiveHandshakeCompleted = true
		if p.sendHandshakeCompleted {
			p.Events.HandshakeCompleted.Trigger()
		}
	})

	// handler: detach everything that was attached below (declared first because the closure refers to itself)
	var closeHandler *events.Closure
	closeHandler = events.NewClosure(func() {
		p.Conn.Events.ReceiveData.Detach(dataHandler)
		p.Conn.Events.Close.Detach(closeHandler)
		p.Events.ReceiveConnectionAccepted.Detach(acceptedHandler)
	})

	// register event handlers
	p.Conn.Events.ReceiveData.Attach(dataHandler)
	p.Conn.Events.Close.Attach(closeHandler)
	p.Events.ReceiveConnectionAccepted.Attach(acceptedHandler)

	// send protocol version
	if p.Send(DEFAULT_PROTOCOL.version) != nil {
		return
	}

	// initialize default protocol
	if initErr := DEFAULT_PROTOCOL.initializer(p); initErr != nil {
		p.SendState = nil
		_ = p.Conn.Close()
		p.Events.Error.Trigger(initErr)

		return
	}

	// start reading from the connection (result and error are intentionally discarded here)
	_, _ = p.Conn.Read(make([]byte, 1000))
}

// Receive feeds data to the current receiving state until all bytes are consumed or no receiving state is left.
// Each state reports how many bytes it consumed; on error the Error event is triggered, the connection is closed
// and the remaining bytes are dropped.
//
// NOTE(review): the error is reported on the package-level Events.Error, whereas Init and send report on the
// per-protocol Events.Error - confirm that this difference is intended.
func (p *protocol) Receive(data []byte) {
	length := len(data)

	for offset := 0; offset < length && p.ReceivingState != nil; {
		readBytes, err := p.ReceivingState.Receive(data, offset, length)
		if err != nil {
			Events.Error.Trigger(err)
			_ = p.Conn.Close()

			return
		}

		offset += readBytes
	}
}

// Send passes data to the current send state while holding sendMutex. It returns the error reported by the send
// state, or nil if sending succeeded or no send state is set.
func (p *protocol) Send(data interface{}) errors.IdentifiableError {
	p.sendMutex.Lock()
	defer p.sendMutex.Unlock()

	return p.send(data)
}

// send is the lock-free core of Send. If there is no send state it does nothing. If the send state reports an
// error, the send state is cleared, the connection is closed, the Error event is triggered and the error is
// returned.
func (p *protocol) send(data interface{}) errors.IdentifiableError {
	if p.SendState == nil {
		return nil
	}

	err := p.SendState.Send(data)
	if err == nil {
		return nil
	}

	p.SendState = nil
	_ = p.Conn.Close()
	p.Events.Error.Trigger(err)

	return err
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region versionState /////////////////////////////////////////////////////////////////////////////////////////////////

// versionState is the initial protocolState of both directions: it exchanges the single version byte.
type versionState struct {
	protocol *protocol
}

// Receive consumes exactly one byte at data[offset] and interprets it as the protocol version. For version 1 it
// stores the version, triggers ReceiveVersion and advances the receiving state to the V1 identification state.
// Any other value yields an ErrInvalidStateTransition; one byte is reported as consumed in either case.
func (s *versionState) Receive(data []byte, offset int, length int) (int, errors.IdentifiableError) {
	if data[offset] != 1 {
		return 1, ErrInvalidStateTransition.Derive("invalid version state transition (" + strconv.Itoa(int(data[offset])) + ")")
	}

	p := s.protocol
	p.Version = 1
	p.Events.ReceiveVersion.Trigger(1)
	p.ReceivingState = newIndentificationStateV1(p)

	return 1, nil
}

// Send expects param to be the byte VERSION_1, writes it to the connection and advances the send state to the V1
// identification state. A failed write yields ErrSendFailed; any other param yields ErrInvalidSendParam.
func (s *versionState) Send(param interface{}) errors.IdentifiableError {
	version, isByte := param.(byte)
	if !isByte || version != VERSION_1 {
		return ErrInvalidSendParam.Derive("passed in parameter is not a valid version byte")
	}

	p := s.protocol
	if _, err := p.Conn.Write([]byte{version}); err != nil {
		return ErrSendFailed.Derive(err, "failed to send version byte")
	}
	p.SendState = newIndentificationStateV1(p)

	return nil
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region types and interfaces /////////////////////////////////////////////////////////////////////////////////////////

// protocolState is one step of the protocol: Send handles an outgoing value, Receive handles incoming bytes
// starting at offset and returns the number of bytes it consumed.
type protocolState interface {
	Send(param interface{}) errors.IdentifiableError
	Receive(data []byte, offset int, length int) (int, errors.IdentifiableError)
}

// protocolDefinition pairs a version byte with the function that initializes a protocol for that version.
type protocolDefinition struct {
	version     byte
	initializer func(*protocol) errors.IdentifiableError
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////