diff --git a/packages/autopeering/server/server.go b/packages/autopeering/server/server.go index bd3d65da1323192b9018b8ac8d4cc6436053f31a..74d1bfbda4d911265602e11836bbeeaa624796bd 100644 --- a/packages/autopeering/server/server.go +++ b/packages/autopeering/server/server.go @@ -4,7 +4,6 @@ import ( "container/list" "fmt" "io" - "net" "sync" "time" @@ -12,6 +11,7 @@ import ( "github.com/iotaledger/goshimmer/packages/autopeering/peer" pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto" "github.com/iotaledger/goshimmer/packages/autopeering/transport" + "github.com/iotaledger/goshimmer/packages/netutil" "github.com/iotaledger/hive.go/logger" ) @@ -260,7 +260,7 @@ func (s *Server) readLoop() { for { b, fromAddr, err := s.trans.ReadFrom() - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + if netutil.IsTemporaryError(err) { // ignore temporary read errors. s.log.Debugw("temporary read error", "err", err) continue diff --git a/packages/gossip/neighbor.go b/packages/gossip/neighbor.go index 178105a49785ce9a866c813cba259b29366116b0..6fa7566fb6163994c0d799a63fb60ab7dd836ee0 100644 --- a/packages/gossip/neighbor.go +++ b/packages/gossip/neighbor.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/iotaledger/goshimmer/packages/autopeering/peer" + "github.com/iotaledger/goshimmer/packages/netutil" "github.com/iotaledger/hive.go/logger" "github.com/iotaledger/hive.go/network" ) @@ -114,11 +115,12 @@ func (n *Neighbor) readLoop() { for { _, err := n.ManagedConnection.Read(b) - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + if netutil.IsTemporaryError(err) { // ignore temporary read errors. n.log.Debugw("temporary read error", "err", err) continue - } else if err != nil { + } + if err != nil { // return from the loop on all other errors if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { n.log.Warnw("read error", "err", err) diff --git a/packages/gossip/server/server.go b/packages/gossip/server/server.go index f6fd2f8b8ce853e5ef6e0d5790002e4a57a277fa..bb52798c93bd7e6a43aace85cdae9d41745bc84c 100644 --- a/packages/gossip/server/server.go +++ b/packages/gossip/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/iotaledger/goshimmer/packages/autopeering/peer" "github.com/iotaledger/goshimmer/packages/autopeering/peer/service" pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto" + "github.com/iotaledger/goshimmer/packages/netutil" "github.com/iotaledger/hive.go/backoff" "go.uber.org/zap" ) @@ -294,10 +295,11 @@ func (t *TCP) listenLoop() { for { conn, err := t.listener.AcceptTCP() - if err, ok := err.(net.Error); ok && err.Temporary() { + if netutil.IsTemporaryError(err) { t.log.Debugw("temporary read error", "err", err) continue - } else if err != nil { + } + if err != nil { // return from the loop on all other errors if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { t.log.Warnw("listen error", "err", err) diff --git a/packages/netutil/netutil.go b/packages/netutil/netutil.go new file mode 100644 index 0000000000000000000000000000000000000000..b1137f562a53c70bf55d901a826ac96ce8d8f6b3 --- /dev/null +++ b/packages/netutil/netutil.go @@ -0,0 +1,104 @@ +// Package netutil provides utility functions extending the stdnet package. +package netutil + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net" + "net/http" + "time" +) + +var ( + errInvalidData = errors.New("invalid data received") +) + +// IsIPv4 returns true if ip is an IPv4 address. +func IsIPv4(ip net.IP) bool { + return ip.To4() != nil +} + +// GetPublicIP queries the ipify API for the public IP address. +func GetPublicIP(preferIPv6 bool) (net.IP, error) { + var url string + if preferIPv6 { + url = "https://api6.ipify.org" + } else { + url = "https://api.ipify.org" + } + resp, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("get failed: %w", err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read failed: %w", err) + } + + // the body only consists of the ip address + ip := net.ParseIP(string(body)) + if ip == nil { + return nil, fmt.Errorf("not an IP: %s", body) + } + return ip, nil +} + +// IsTemporaryError checks whether the given error should be considered temporary. +func IsTemporaryError(err error) bool { + tempErr, ok := err.(interface { + Temporary() bool + }) + return ok && tempErr.Temporary() +} + +// CheckUDP checks whether data send to remote is received at local, otherwise an error is returned. +// If checkAddress is set, it checks whether the IP address that was on the packet matches remote. +// If checkPort is set, it checks whether the port that was on the packet matches remote. +func CheckUDP(local, remote *net.UDPAddr, checkAddress bool, checkPort bool) error { + conn, err := net.ListenUDP("udp", local) + if err != nil { + return fmt.Errorf("listen failed: %w", err) + } + defer conn.Close() + + nonce := generateNonce() + _, err = conn.WriteTo(nonce, remote) + if err != nil { + return fmt.Errorf("write failed: %w", err) + } + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err != nil { + return fmt.Errorf("set timeout failed: %w", err) + } + + p := make([]byte, len(nonce)+1) + n, from, err := conn.ReadFrom(p) + if err != nil { + return fmt.Errorf("read failed: %w", err) + } + if n != len(nonce) || !bytes.Equal(p[:n], nonce) { + return errInvalidData + } + udpAddr := from.(*net.UDPAddr) + if checkAddress && udpAddr.IP.Equal(remote.IP) { + return fmt.Errorf("IP changed: %s", udpAddr.IP) + } + if checkPort && udpAddr.Port != remote.Port { + return fmt.Errorf("port changed: %d", udpAddr.Port) + } + + return nil +} + +func generateNonce() []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, rand.Uint64()) + return b +} diff --git a/packages/netutil/netutil_test.go b/packages/netutil/netutil_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f5543f8db0aaff816fc6a187da5844ea9fdc514f --- /dev/null +++ b/packages/netutil/netutil_test.go @@ -0,0 +1,71 @@ +package netutil + +import ( + "errors" + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsIPv4(t *testing.T) { + tests := []struct { + in net.IP + out bool + }{ + {nil, false}, + {net.IPv4zero, true}, + {net.IPv6zero, false}, + {net.ParseIP("127.0.0.1"), true}, + {net.IPv6loopback, false}, + {net.ParseIP("8.8.8.8"), true}, + {net.ParseIP("2001:4860:4860::8888"), false}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { + assert.Equal(t, IsIPv4(tt.in), tt.out) + }) + } +} + +func TestIsTemporaryError(t *testing.T) { + tests := []struct { + in error + out bool + }{ + {nil, false}, + {errors.New("errorString"), false}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { + assert.Equal(t, IsTemporaryError(tt.in), tt.out) + }) + } +} + +func TestCheckUDP(t *testing.T) { + local, err := getLocalUDPAddr() + require.NoError(t, err) + assert.NoError(t, CheckUDP(local, local, true, true)) + + invalid := &net.UDPAddr{ + IP: local.IP, + Port: local.Port - 1, + Zone: local.Zone, + } + assert.Error(t, CheckUDP(local, invalid, false, false)) +} + +func getLocalUDPAddr() (*net.UDPAddr, error) { + addr, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + return conn.LocalAddr().(*net.UDPAddr), conn.Close() +} diff --git a/plugins/autopeering/local/local.go b/plugins/autopeering/local/local.go index ac43be56d9de48ca561fdccc11a240cf27fe75d8..c5ca19a1109dbc0eff75a2f836905624d1d047fe 100644 --- a/plugins/autopeering/local/local.go +++ b/plugins/autopeering/local/local.go @@ -3,14 +3,12 @@ package local import ( "crypto/ed25519" "encoding/base64" - "fmt" - "io/ioutil" "net" - "net/http" "strconv" "sync" "github.com/iotaledger/goshimmer/packages/autopeering/peer" + "github.com/iotaledger/goshimmer/packages/netutil" "github.com/iotaledger/goshimmer/packages/parameter" "github.com/iotaledger/hive.go/logger" ) @@ -29,7 +27,7 @@ func configureLocal() *peer.Local { } if ip.IsUnspecified() { log.Info("Querying public IP ...") - myIp, err := getPublicIP(isIPv4(ip)) + myIp, err := netutil.GetPublicIP(!netutil.IsIPv4(ip)) if err != nil { log.Fatalf("Error querying public IP: %s", err) } @@ -69,37 +67,6 @@ func configureLocal() *peer.Local { return local } -func isIPv4(ip net.IP) bool { - return ip.To4() != nil -} - -func getPublicIP(ipv4 bool) (net.IP, error) { - var url string - if ipv4 { - url = "https://api.ipify.org" - } else { - url = "https://api6.ipify.org" - } - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // the body only consists of the ip address - ip := net.ParseIP(string(body)) - if ip == nil { - return nil, fmt.Errorf("not an IP: %s", body) - } - - return ip, nil -} - func GetInstance() *peer.Local { once.Do(func() { instance = configureLocal() }) return instance