Skip to content
Snippets Groups Projects
Unverified Commit e0fc94a5 authored by Wolfgang Welz's avatar Wolfgang Welz Committed by GitHub
Browse files

Feat: Add netutil package (#159)

* feat: add netutil package

* use IsTemporaryError

* remove unused import
parent 8b3d0c77
Branches
Tags
No related merge requests found
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"container/list" "container/list"
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"time" "time"
...@@ -12,6 +11,7 @@ import ( ...@@ -12,6 +11,7 @@ import (
"github.com/iotaledger/goshimmer/packages/autopeering/peer" "github.com/iotaledger/goshimmer/packages/autopeering/peer"
pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto" pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto"
"github.com/iotaledger/goshimmer/packages/autopeering/transport" "github.com/iotaledger/goshimmer/packages/autopeering/transport"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/logger" "github.com/iotaledger/hive.go/logger"
) )
...@@ -260,7 +260,7 @@ func (s *Server) readLoop() { ...@@ -260,7 +260,7 @@ func (s *Server) readLoop() {
for { for {
b, fromAddr, err := s.trans.ReadFrom() b, fromAddr, err := s.trans.ReadFrom()
if nerr, ok := err.(net.Error); ok && nerr.Temporary() { if netutil.IsTemporaryError(err) {
// ignore temporary read errors. // ignore temporary read errors.
s.log.Debugw("temporary read error", "err", err) s.log.Debugw("temporary read error", "err", err)
continue continue
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"sync" "sync"
"github.com/iotaledger/goshimmer/packages/autopeering/peer" "github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/logger" "github.com/iotaledger/hive.go/logger"
"github.com/iotaledger/hive.go/network" "github.com/iotaledger/hive.go/network"
) )
...@@ -114,11 +115,12 @@ func (n *Neighbor) readLoop() { ...@@ -114,11 +115,12 @@ func (n *Neighbor) readLoop() {
for { for {
_, err := n.ManagedConnection.Read(b) _, err := n.ManagedConnection.Read(b)
if nerr, ok := err.(net.Error); ok && nerr.Temporary() { if netutil.IsTemporaryError(err) {
// ignore temporary read errors. // ignore temporary read errors.
n.log.Debugw("temporary read error", "err", err) n.log.Debugw("temporary read error", "err", err)
continue continue
} else if err != nil { }
if err != nil {
// return from the loop on all other errors // return from the loop on all other errors
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
n.log.Warnw("read error", "err", err) n.log.Warnw("read error", "err", err)
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/iotaledger/goshimmer/packages/autopeering/peer" "github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/autopeering/peer/service" "github.com/iotaledger/goshimmer/packages/autopeering/peer/service"
pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto" pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/backoff" "github.com/iotaledger/hive.go/backoff"
"go.uber.org/zap" "go.uber.org/zap"
) )
...@@ -294,10 +295,11 @@ func (t *TCP) listenLoop() { ...@@ -294,10 +295,11 @@ func (t *TCP) listenLoop() {
for { for {
conn, err := t.listener.AcceptTCP() conn, err := t.listener.AcceptTCP()
if err, ok := err.(net.Error); ok && err.Temporary() { if netutil.IsTemporaryError(err) {
t.log.Debugw("temporary read error", "err", err) t.log.Debugw("temporary read error", "err", err)
continue continue
} else if err != nil { }
if err != nil {
// return from the loop on all other errors // return from the loop on all other errors
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
t.log.Warnw("listen error", "err", err) t.log.Warnw("listen error", "err", err)
......
// Package netutil provides utility functions extending the stdnet package.
package netutil
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"time"
)
var (
errInvalidData = errors.New("invalid data received")
)
// IsIPv4 returns true if ip is an IPv4 address.
func IsIPv4(ip net.IP) bool {
return ip.To4() != nil
}
// GetPublicIP queries the ipify API for the public IP address.
func GetPublicIP(preferIPv6 bool) (net.IP, error) {
var url string
if preferIPv6 {
url = "https://api6.ipify.org"
} else {
url = "https://api.ipify.org"
}
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("get failed: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read failed: %w", err)
}
// the body only consists of the ip address
ip := net.ParseIP(string(body))
if ip == nil {
return nil, fmt.Errorf("not an IP: %s", body)
}
return ip, nil
}
// IsTemporaryError checks whether the given error should be considered temporary.
func IsTemporaryError(err error) bool {
tempErr, ok := err.(interface {
Temporary() bool
})
return ok && tempErr.Temporary()
}
// CheckUDP checks whether data send to remote is received at local, otherwise an error is returned.
// If checkAddress is set, it checks whether the IP address that was on the packet matches remote.
// If checkPort is set, it checks whether the port that was on the packet matches remote.
func CheckUDP(local, remote *net.UDPAddr, checkAddress bool, checkPort bool) error {
conn, err := net.ListenUDP("udp", local)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
}
defer conn.Close()
nonce := generateNonce()
_, err = conn.WriteTo(nonce, remote)
if err != nil {
return fmt.Errorf("write failed: %w", err)
}
err = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
return fmt.Errorf("set timeout failed: %w", err)
}
p := make([]byte, len(nonce)+1)
n, from, err := conn.ReadFrom(p)
if err != nil {
return fmt.Errorf("read failed: %w", err)
}
if n != len(nonce) || !bytes.Equal(p[:n], nonce) {
return errInvalidData
}
udpAddr := from.(*net.UDPAddr)
if checkAddress && udpAddr.IP.Equal(remote.IP) {
return fmt.Errorf("IP changed: %s", udpAddr.IP)
}
if checkPort && udpAddr.Port != remote.Port {
return fmt.Errorf("port changed: %d", udpAddr.Port)
}
return nil
}
func generateNonce() []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, rand.Uint64())
return b
}
package netutil
import (
"errors"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsIPv4(t *testing.T) {
tests := []struct {
in net.IP
out bool
}{
{nil, false},
{net.IPv4zero, true},
{net.IPv6zero, false},
{net.ParseIP("127.0.0.1"), true},
{net.IPv6loopback, false},
{net.ParseIP("8.8.8.8"), true},
{net.ParseIP("2001:4860:4860::8888"), false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
assert.Equal(t, IsIPv4(tt.in), tt.out)
})
}
}
func TestIsTemporaryError(t *testing.T) {
tests := []struct {
in error
out bool
}{
{nil, false},
{errors.New("errorString"), false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
assert.Equal(t, IsTemporaryError(tt.in), tt.out)
})
}
}
func TestCheckUDP(t *testing.T) {
local, err := getLocalUDPAddr()
require.NoError(t, err)
assert.NoError(t, CheckUDP(local, local, true, true))
invalid := &net.UDPAddr{
IP: local.IP,
Port: local.Port - 1,
Zone: local.Zone,
}
assert.Error(t, CheckUDP(local, invalid, false, false))
}
func getLocalUDPAddr() (*net.UDPAddr, error) {
addr, err := net.ResolveUDPAddr("udp", ":0")
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
return conn.LocalAddr().(*net.UDPAddr), conn.Close()
}
...@@ -3,14 +3,12 @@ package local ...@@ -3,14 +3,12 @@ package local
import ( import (
"crypto/ed25519" "crypto/ed25519"
"encoding/base64" "encoding/base64"
"fmt"
"io/ioutil"
"net" "net"
"net/http"
"strconv" "strconv"
"sync" "sync"
"github.com/iotaledger/goshimmer/packages/autopeering/peer" "github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/goshimmer/packages/parameter" "github.com/iotaledger/goshimmer/packages/parameter"
"github.com/iotaledger/hive.go/logger" "github.com/iotaledger/hive.go/logger"
) )
...@@ -29,7 +27,7 @@ func configureLocal() *peer.Local { ...@@ -29,7 +27,7 @@ func configureLocal() *peer.Local {
} }
if ip.IsUnspecified() { if ip.IsUnspecified() {
log.Info("Querying public IP ...") log.Info("Querying public IP ...")
myIp, err := getPublicIP(isIPv4(ip)) myIp, err := netutil.GetPublicIP(!netutil.IsIPv4(ip))
if err != nil { if err != nil {
log.Fatalf("Error querying public IP: %s", err) log.Fatalf("Error querying public IP: %s", err)
} }
...@@ -69,37 +67,6 @@ func configureLocal() *peer.Local { ...@@ -69,37 +67,6 @@ func configureLocal() *peer.Local {
return local return local
} }
func isIPv4(ip net.IP) bool {
return ip.To4() != nil
}
func getPublicIP(ipv4 bool) (net.IP, error) {
var url string
if ipv4 {
url = "https://api.ipify.org"
} else {
url = "https://api6.ipify.org"
}
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// the body only consists of the ip address
ip := net.ParseIP(string(body))
if ip == nil {
return nil, fmt.Errorf("not an IP: %s", body)
}
return ip, nil
}
func GetInstance() *peer.Local { func GetInstance() *peer.Local {
once.Do(func() { instance = configureLocal() }) once.Do(func() { instance = configureLocal() })
return instance return instance
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment