package buffconn

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/iotaledger/hive.go/events"
	"go.uber.org/atomic"
)

// Wire-format and timing parameters of the length-prefixed message protocol.
const (
	// MaxMessageSize is the maximum payload size of a single message in bytes.
	// It bounds the message accepted by Write and the length accepted from an
	// incoming header; the 4-byte header itself is not counted.
	MaxMessageSize = 4096
	// IOTimeout is the deadline for writing one complete message (header and
	// payload) and for reading a message payload once its header has arrived.
	// Waiting for the next header is not subject to any deadline.
	IOTimeout = 4 * time.Second

	// headerSize is the length of the message header, which holds the payload
	// length as a big-endian uint32.
	headerSize = 4 // size of the header: uint32
)

// Errors returned by the BufferedConnection.
//
// ErrInvalidHeader is returned when an incoming header does not have the
// expected size or announces a payload larger than MaxMessageSize.
// ErrInsufficientBuffer is returned when an incoming message is longer than the
// buffer supplied to receive it.
var (
	ErrInvalidHeader      = errors.New("invalid message header")
	ErrInsufficientBuffer = errors.New("insufficient buffer")
)

// BufferedConnectionEvents contains all the events triggered by a BufferedConnection.
//
// ReceiveMessage is triggered by Read with the payload of every non-empty
// message. The payload slice aliases a buffer that Read reuses for the next
// message, so a handler must copy it if it retains the data. NOTE(review): this
// assumes Trigger runs handlers synchronously; confirm against hive.go events.
//
// Close is triggered in a separate goroutine the first time Close is called,
// whether or not closing the underlying connection succeeded.
type BufferedConnectionEvents struct {
	ReceiveMessage *events.Event
	Close          *events.Event
}

// BufferedConnection is a wrapper for sending and reading messages with a buffer.
// It wraps a net.Conn and frames every message with a 4-byte big-endian length
// header. It must be created with NewBufferedConnection; the zero value has no
// connection and nil counters and is not usable.
//
// Field notes:
//   - incomingHeaderBuffer is reused for every incoming header, so concurrent
//     calls to Read are presumably unsafe (NOTE(review): confirm callers use a
//     single reader goroutine).
//   - closeOnce makes Close idempotent.
//   - bytesRead / bytesWritten count raw bytes, headers included. They are
//     uint32 values and wrap around after 4 GiB of traffic.
type BufferedConnection struct {
	Events BufferedConnectionEvents

	conn                 net.Conn
	incomingHeaderBuffer []byte
	closeOnce            sync.Once

	bytesRead    *atomic.Uint32
	bytesWritten *atomic.Uint32
}

// NewBufferedConnection wraps conn in a BufferedConnection with freshly created
// events, a header scratch buffer, and zeroed byte counters.
func NewBufferedConnection(conn net.Conn) *BufferedConnection {
	bc := &BufferedConnection{
		conn:                 conn,
		incomingHeaderBuffer: make([]byte, headerSize),
		bytesRead:            atomic.NewUint32(0),
		bytesWritten:         atomic.NewUint32(0),
	}
	bc.Events.ReceiveMessage = events.NewEvent(events.ByteSliceCaller)
	bc.Events.Close = events.NewEvent(events.CallbackCaller)
	return bc
}

// Close shuts down the underlying connection. Only the first call has an
// effect; it returns the error from closing the connection and fires the Close
// event. Later calls do nothing and return nil.
// Any blocked Read or Write operations will be unblocked and return errors.
func (c *BufferedConnection) Close() error {
	var closeErr error
	c.closeOnce.Do(func() {
		closeErr = c.conn.Close()
		// Fire the event asynchronously so handlers that call back into this
		// connection cannot deadlock against Close.
		go c.Events.Close.Trigger()
	})
	return closeErr
}

// LocalAddr reports the local network address of the wrapped connection.
func (c *BufferedConnection) LocalAddr() net.Addr {
	underlying := c.conn
	return underlying.LocalAddr()
}

// RemoteAddr reports the network address of the peer on the wrapped connection.
func (c *BufferedConnection) RemoteAddr() net.Addr {
	underlying := c.conn
	return underlying.RemoteAddr()
}

// BytesRead reports how many raw bytes, message headers included, have been
// received on this connection. The counter is a uint32 and wraps around after
// 4 GiB.
func (c *BufferedConnection) BytesRead() uint32 {
	total := c.bytesRead.Load()
	return total
}

// BytesWritten reports how many raw bytes, message headers included, have been
// handed to the underlying connection. The counter is a uint32 and wraps around
// after 4 GiB.
func (c *BufferedConnection) BytesWritten() uint32 {
	total := c.bytesWritten.Load()
	return total
}

// Read runs the receive loop. It blocks until reading fails, for example
// because Close was called or the peer disconnected, and returns that error; it
// never returns nil.
// Each complete, non-empty message triggers the ReceiveMessage event with its
// payload. Empty messages are skipped. The payload slice points into a buffer
// that is overwritten by the next message, so handlers must copy it to retain it.
func (c *BufferedConnection) Read() error {
	msgBuf := make([]byte, MaxMessageSize)
	for {
		size, err := c.readMessage(msgBuf)
		if err != nil {
			return err
		}
		if size == 0 {
			continue
		}
		c.Events.ReceiveMessage.Trigger(msgBuf[:size])
	}
}

// Write sends msg as one framed message: a 4-byte big-endian header holding
// len(msg), followed by msg. The whole frame must be written within IOTimeout.
// If the connection accepts only part of the frame without reporting an error,
// Write keeps sending the remainder until the frame is delivered or the
// connection fails.
//
// On success it returns len(msg). On failure it returns the number of payload
// bytes of msg that were written, never counting the header, so the result is
// always in [0, len(msg)]. The error is returned alongside that count. Every
// byte passed to the connection, header included, is added to BytesWritten.
//
// Write panics if len(msg) exceeds MaxMessageSize.
func (c *BufferedConnection) Write(msg []byte) (int, error) {
	if l := len(msg); l > MaxMessageSize {
		panic(fmt.Sprintf("invalid message length: %d", l))
	}

	buffer := append(newHeader(len(msg)), msg...)

	if err := c.conn.SetWriteDeadline(time.Now().Add(IOTimeout)); err != nil {
		return 0, fmt.Errorf("error while setting timeout: %w", err)
	}

	for bytesWritten := 0; bytesWritten < len(buffer); {
		n, err := c.conn.Write(buffer[bytesWritten:])
		bytesWritten += n
		c.bytesWritten.Add(uint32(n))
		if err != nil {
			// Report only payload bytes. This matches the success case and the
			// io.Writer contract that 0 <= n <= len(p).
			payload := bytesWritten - headerSize
			if payload < 0 {
				payload = 0
			}
			return payload, err
		}
	}
	return len(msg), nil
}

// read fills buffer completely from the underlying connection, issuing as many
// Read calls as needed, and adds every byte received to the bytesRead counter.
//
// Once the buffer is full, read returns len(buffer) and a nil error, even if the
// final Read call also returned an error. io.Reader allows the last bytes to be
// returned together with an error such as io.EOF, and dropping them would
// discard a complete message. A persistent condition like EOF is expected to be
// reported again by the next Read. If the buffer is not filled, read returns the
// number of bytes read so far together with the error.
func (c *BufferedConnection) read(buffer []byte) (int, error) {
	toRead := len(buffer)
	for bytesRead := 0; bytesRead < toRead; {
		n, err := c.conn.Read(buffer[bytesRead:])
		bytesRead += n
		c.bytesRead.Add(uint32(n))
		if err != nil && bytesRead < toRead {
			return bytesRead, err
		}
	}
	return toRead, nil
}

// readMessage reads one framed message into buffer and returns its payload
// length.
//
// Waiting for the header has no deadline, so an idle connection is not an
// error. Once the header has arrived, the payload must follow within IOTimeout.
// It returns ErrInvalidHeader for a malformed or oversized header, and
// ErrInsufficientBuffer if the payload would not fit into buffer.
func (c *BufferedConnection) readMessage(buffer []byte) (int, error) {
	if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
		return 0, fmt.Errorf("error while unsetting timeout: %w", err)
	}
	if _, err := c.read(c.incomingHeaderBuffer); err != nil {
		return 0, err
	}

	size, err := parseHeader(c.incomingHeaderBuffer)
	switch {
	case err != nil:
		return 0, err
	case size > len(buffer):
		return 0, ErrInsufficientBuffer
	}

	deadline := time.Now().Add(IOTimeout)
	if err := c.conn.SetReadDeadline(deadline); err != nil {
		return 0, fmt.Errorf("error while setting timeout: %w", err)
	}
	return c.read(buffer[:size])
}

func newHeader(msgLength int) []byte {
	// the header only consists of the message length
	header := make([]byte, headerSize)
	binary.BigEndian.PutUint32(header, uint32(msgLength))
	return header
}

// parseHeader decodes a message header: the payload length as a big-endian
// uint32. It returns ErrInvalidHeader if header is not exactly headerSize bytes
// long or if the announced length exceeds MaxMessageSize.
func parseHeader(header []byte) (int, error) {
	if len(header) != headerSize {
		return 0, ErrInvalidHeader
	}
	// Bound-check in uint32 before converting. On platforms where int is
	// 32 bits, int(x) for x >= 1<<31 is negative. A negative length would pass
	// the "> MaxMessageSize" check and later be used as a slice bound, panicking
	// on input controlled by the remote peer.
	msgLength := binary.BigEndian.Uint32(header)
	if msgLength > MaxMessageSize {
		return 0, ErrInvalidHeader
	}
	return int(msgLength), nil
}