Skip to content
Snippets Groups Projects
Unverified Commit e0fc94a5 authored by Wolfgang Welz's avatar Wolfgang Welz Committed by GitHub
Browse files

Feat: Add netutil package (#159)

* feat: add netutil package

* use IsTemporaryError

* remove unused import
parent 8b3d0c77
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,6 @@ import (
"container/list"
"fmt"
"io"
"net"
"sync"
"time"
......@@ -12,6 +11,7 @@ import (
"github.com/iotaledger/goshimmer/packages/autopeering/peer"
pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto"
"github.com/iotaledger/goshimmer/packages/autopeering/transport"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/logger"
)
......@@ -260,7 +260,7 @@ func (s *Server) readLoop() {
for {
b, fromAddr, err := s.trans.ReadFrom()
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
if netutil.IsTemporaryError(err) {
// ignore temporary read errors.
s.log.Debugw("temporary read error", "err", err)
continue
......
......@@ -8,6 +8,7 @@ import (
"sync"
"github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/logger"
"github.com/iotaledger/hive.go/network"
)
......@@ -114,11 +115,12 @@ func (n *Neighbor) readLoop() {
for {
_, err := n.ManagedConnection.Read(b)
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
if netutil.IsTemporaryError(err) {
// ignore temporary read errors.
n.log.Debugw("temporary read error", "err", err)
continue
} else if err != nil {
}
if err != nil {
// return from the loop on all other errors
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
n.log.Warnw("read error", "err", err)
......
......@@ -15,6 +15,7 @@ import (
"github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/autopeering/peer/service"
pb "github.com/iotaledger/goshimmer/packages/autopeering/server/proto"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/hive.go/backoff"
"go.uber.org/zap"
)
......@@ -294,10 +295,11 @@ func (t *TCP) listenLoop() {
for {
conn, err := t.listener.AcceptTCP()
if err, ok := err.(net.Error); ok && err.Temporary() {
if netutil.IsTemporaryError(err) {
t.log.Debugw("temporary read error", "err", err)
continue
} else if err != nil {
}
if err != nil {
// return from the loop on all other errors
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
t.log.Warnw("listen error", "err", err)
......
// Package netutil provides utility functions extending the stdnet package.
package netutil
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"time"
)
var (
errInvalidData = errors.New("invalid data received")
)
// IsIPv4 returns true if ip is an IPv4 address.
func IsIPv4(ip net.IP) bool {
return ip.To4() != nil
}
// GetPublicIP queries the ipify API for the public IP address.
func GetPublicIP(preferIPv6 bool) (net.IP, error) {
var url string
if preferIPv6 {
url = "https://api6.ipify.org"
} else {
url = "https://api.ipify.org"
}
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("get failed: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read failed: %w", err)
}
// the body only consists of the ip address
ip := net.ParseIP(string(body))
if ip == nil {
return nil, fmt.Errorf("not an IP: %s", body)
}
return ip, nil
}
// IsTemporaryError checks whether the given error should be considered temporary.
func IsTemporaryError(err error) bool {
tempErr, ok := err.(interface {
Temporary() bool
})
return ok && tempErr.Temporary()
}
// CheckUDP checks whether data send to remote is received at local, otherwise an error is returned.
// If checkAddress is set, it checks whether the IP address that was on the packet matches remote.
// If checkPort is set, it checks whether the port that was on the packet matches remote.
func CheckUDP(local, remote *net.UDPAddr, checkAddress bool, checkPort bool) error {
conn, err := net.ListenUDP("udp", local)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
}
defer conn.Close()
nonce := generateNonce()
_, err = conn.WriteTo(nonce, remote)
if err != nil {
return fmt.Errorf("write failed: %w", err)
}
err = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
return fmt.Errorf("set timeout failed: %w", err)
}
p := make([]byte, len(nonce)+1)
n, from, err := conn.ReadFrom(p)
if err != nil {
return fmt.Errorf("read failed: %w", err)
}
if n != len(nonce) || !bytes.Equal(p[:n], nonce) {
return errInvalidData
}
udpAddr := from.(*net.UDPAddr)
if checkAddress && udpAddr.IP.Equal(remote.IP) {
return fmt.Errorf("IP changed: %s", udpAddr.IP)
}
if checkPort && udpAddr.Port != remote.Port {
return fmt.Errorf("port changed: %d", udpAddr.Port)
}
return nil
}
func generateNonce() []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, rand.Uint64())
return b
}
package netutil
import (
"errors"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsIPv4(t *testing.T) {
tests := []struct {
in net.IP
out bool
}{
{nil, false},
{net.IPv4zero, true},
{net.IPv6zero, false},
{net.ParseIP("127.0.0.1"), true},
{net.IPv6loopback, false},
{net.ParseIP("8.8.8.8"), true},
{net.ParseIP("2001:4860:4860::8888"), false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
assert.Equal(t, IsIPv4(tt.in), tt.out)
})
}
}
func TestIsTemporaryError(t *testing.T) {
tests := []struct {
in error
out bool
}{
{nil, false},
{errors.New("errorString"), false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
assert.Equal(t, IsTemporaryError(tt.in), tt.out)
})
}
}
func TestCheckUDP(t *testing.T) {
local, err := getLocalUDPAddr()
require.NoError(t, err)
assert.NoError(t, CheckUDP(local, local, true, true))
invalid := &net.UDPAddr{
IP: local.IP,
Port: local.Port - 1,
Zone: local.Zone,
}
assert.Error(t, CheckUDP(local, invalid, false, false))
}
func getLocalUDPAddr() (*net.UDPAddr, error) {
addr, err := net.ResolveUDPAddr("udp", ":0")
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
return conn.LocalAddr().(*net.UDPAddr), conn.Close()
}
......@@ -3,14 +3,12 @@ package local
import (
"crypto/ed25519"
"encoding/base64"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"sync"
"github.com/iotaledger/goshimmer/packages/autopeering/peer"
"github.com/iotaledger/goshimmer/packages/netutil"
"github.com/iotaledger/goshimmer/packages/parameter"
"github.com/iotaledger/hive.go/logger"
)
......@@ -29,7 +27,7 @@ func configureLocal() *peer.Local {
}
if ip.IsUnspecified() {
log.Info("Querying public IP ...")
myIp, err := getPublicIP(isIPv4(ip))
myIp, err := netutil.GetPublicIP(!netutil.IsIPv4(ip))
if err != nil {
log.Fatalf("Error querying public IP: %s", err)
}
......@@ -69,37 +67,6 @@ func configureLocal() *peer.Local {
return local
}
func isIPv4(ip net.IP) bool {
return ip.To4() != nil
}
func getPublicIP(ipv4 bool) (net.IP, error) {
var url string
if ipv4 {
url = "https://api.ipify.org"
} else {
url = "https://api6.ipify.org"
}
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// the body only consists of the ip address
ip := net.ParseIP(string(body))
if ip == nil {
return nil, fmt.Errorf("not an IP: %s", body)
}
return ip, nil
}
func GetInstance() *peer.Local {
once.Do(func() { instance = configureLocal() })
return instance
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment