aboutsummaryrefslogtreecommitdiff
path: root/src/conn_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/conn_linux.go')
-rw-r--r--src/conn_linux.go525
1 files changed, 420 insertions, 105 deletions
diff --git a/src/conn_linux.go b/src/conn_linux.go
index a349a9e..cdba74f 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -7,6 +7,7 @@
package main
import (
+ "encoding/binary"
"errors"
"golang.org/x/sys/unix"
"net"
@@ -16,19 +17,229 @@ import (
/* Supports source address caching
*
- * It is important that the endpoint is only updated after the packet content has been authenticated.
- *
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
*/
-type Endpoint struct {
- // source (selected based on dst type)
- // (could use RawSockaddrAny and unsafe)
- srcIPv6 unix.RawSockaddrInet6
- srcIPv4 unix.RawSockaddrInet4
- srcIf4 int32
-
- dst unix.RawSockaddrAny
+type NativeEndpoint struct {
+ src unix.RawSockaddrInet6
+ dst unix.RawSockaddrInet6
+}
+
+type NativeBind struct {
+ sock4 int
+ sock6 int
+}
+
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = NativeBind{}
+
+type IPv4Source struct {
+ src unix.RawSockaddrInet4
+ Ifindex int32
+}
+
+func htons(val uint16) uint16 {
+ var out [unsafe.Sizeof(val)]byte
+ binary.BigEndian.PutUint16(out[:], val)
+ return *((*uint16)(unsafe.Pointer(&out[0])))
+}
+
+func ntohs(val uint16) uint16 {
+ tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
+ return binary.BigEndian.Uint16((*tmp)[:])
+}
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ var end NativeEndpoint
+ addr, err := parseEndpoint(s)
+ if err != nil {
+ return nil, err
+ }
+
+ ipv4 := addr.IP.To4()
+ if ipv4 != nil {
+ dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+ dst.Family = unix.AF_INET
+ dst.Port = htons(uint16(addr.Port))
+ dst.Zero = [8]byte{}
+ copy(dst.Addr[:], ipv4)
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ ipv6 := addr.IP.To16()
+ if ipv6 != nil {
+ zone, err := zoneToUint32(addr.Zone)
+ if err != nil {
+ return nil, err
+ }
+ dst := &end.dst
+ dst.Family = unix.AF_INET6
+ dst.Port = htons(uint16(addr.Port))
+ dst.Flowinfo = 0
+ dst.Scope_id = zone
+ copy(dst.Addr[:], ipv6[:])
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ return nil, errors.New("Failed to recognize IP address format")
+}
+
+func CreateBind(port uint16) (Bind, uint16, error) {
+ var err error
+ var bind NativeBind
+
+ bind.sock6, port, err = create6(port)
+ if err != nil {
+ return nil, port, err
+ }
+
+ bind.sock4, port, err = create4(port)
+ if err != nil {
+ unix.Close(bind.sock6)
+ }
+ return bind, port, err
+}
+
+func (bind NativeBind) SetMark(value uint32) error {
+ err := unix.SetsockoptInt(
+ bind.sock6,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+
+ return unix.SetsockoptInt(
+ bind.sock4,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+}
+
+func closeUnblock(fd int) error {
+ // shutdown to unblock readers
+ unix.Shutdown(fd, unix.SHUT_RD)
+ return unix.Close(fd)
+}
+
+func (bind NativeBind) Close() error {
+ err1 := closeUnblock(bind.sock6)
+ err2 := closeUnblock(bind.sock4)
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive6(
+ bind.sock6,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive4(
+ bind.sock4,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+ nend := end.(*NativeEndpoint)
+ switch nend.dst.Family {
+ case unix.AF_INET6:
+ return send6(bind.sock6, nend, buff)
+ case unix.AF_INET:
+ return send4(bind.sock4, nend, buff)
+ default:
+ return errors.New("Unknown address family of destination")
+ }
+}
+
+func sockaddrToString(addr unix.RawSockaddrInet6) string {
+ var udpAddr net.UDPAddr
+
+ switch addr.Family {
+ case unix.AF_INET6:
+ udpAddr.Port = int(ntohs(addr.Port))
+ udpAddr.IP = addr.Addr[:]
+ return udpAddr.String()
+
+ case unix.AF_INET:
+ ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+ udpAddr.Port = int(ntohs(ptr.Port))
+ udpAddr.IP = net.IPv4(
+ ptr.Addr[0],
+ ptr.Addr[1],
+ ptr.Addr[2],
+ ptr.Addr[3],
+ )
+ return udpAddr.String()
+
+ default:
+ return "<unknown address family>"
+ }
+}
+
+func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
+ switch addr.Family {
+ case unix.AF_INET6:
+ return addr.Addr[:]
+ case unix.AF_INET:
+ ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+ return net.IPv4(
+ ptr.Addr[0],
+ ptr.Addr[1],
+ ptr.Addr[2],
+ ptr.Addr[3],
+ )
+ default:
+ return nil
+ }
+}
+
+func (end *NativeEndpoint) SrcIP() net.IP {
+ return rawAddrToIP(end.src)
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+ return rawAddrToIP(end.dst)
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
+ ptr := unsafe.Pointer(&end.src)
+ arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
+ return arr[:]
+}
+
+func (end *NativeEndpoint) SrcToString() string {
+ return sockaddrToString(end.src)
+}
+
+func (end *NativeEndpoint) DstToString() string {
+ return sockaddrToString(end.dst)
+}
+
+func (end *NativeEndpoint) ClearDst() {
+ end.dst = unix.RawSockaddrInet6{}
+}
+
+func (end *NativeEndpoint) ClearSrc() {
+ end.src = unix.RawSockaddrInet6{}
}
func zoneToUint32(zone string) (uint32, error) {
@@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
-func (end *Endpoint) ClearSrc() {
- end.srcIf4 = 0
- end.srcIPv4 = unix.RawSockaddrInet4{}
- end.srcIPv6 = unix.RawSockaddrInet6{}
-}
+func create4(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
-func (end *Endpoint) Set(s string) error {
- addr, err := parseEndpoint(s)
if err != nil {
- return err
+ return -1, 0, err
}
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
- if err != nil {
+ addr := unix.SockaddrInet4{
+ Port: int(port),
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
return err
}
- ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
- ptr.Family = unix.AF_INET6
- ptr.Port = uint16(addr.Port)
- ptr.Flowinfo = 0
- ptr.Scope_id = zone
- copy(ptr.Addr[:], ipv6[:])
- end.ClearSrc()
- return nil
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_PKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+ }(); err != nil {
+ unix.Close(fd)
}
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
- ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
- ptr.Family = unix.AF_INET
- ptr.Port = uint16(addr.Port)
- ptr.Zero = [8]byte{}
- copy(ptr.Addr[:], ipv4)
- end.ClearSrc()
- return nil
+ return fd, uint16(addr.Port), err
+}
+
+func create6(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET6,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return -1, 0, err
+ }
+
+ // set sockopts and bind
+
+ addr := unix.SockaddrInet6{
+ Port: int(port),
+ }
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVPKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_V6ONLY,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
}
- return errors.New("Failed to recognize IP address format")
+ return fd, uint16(addr.Port), err
}
-func send6(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
@@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo,
+ Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
- Addr: end.srcIPv6.Addr,
- Ifindex: end.srcIPv6.Scope_id,
+ Addr: end.src.Addr,
+ Ifindex: end.src.Scope_id,
},
}
@@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
- sock,
+ uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
+
+ if errno == 0 {
+ return nil
+ }
+
+ // clear src and retry
+
if errno == unix.EINVAL {
end.ClearSrc()
+ cmsg.pktinfo = unix.Inet6Pktinfo{}
+ _, _, errno = unix.Syscall(
+ unix.SYS_SENDMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
}
+
return errno
}
-func send4(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
+ // construct message header
+
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
+ src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
@@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet6Pktinfo,
+ Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
- Spec_dst: end.srcIPv4.Addr,
- Ifindex: end.srcIf4,
+ Spec_dst: src4.src.Addr,
+ Ifindex: src4.Ifindex,
},
}
@@ -156,51 +451,92 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
+ Flags: 0,
}
-
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
- sock,
+ uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
+
+ // clear source and try again
+
if errno == unix.EINVAL {
end.ClearSrc()
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
+ _, _, errno = unix.Syscall(
+ unix.SYS_SENDMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
+ }
+
+ // errno = 0 is still an error instance
+
+ if errno == 0 {
+ return nil
}
+
return errno
}
-func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- // extract underlying file descriptor
+ // contruct message header
- file, err := c.File()
- if err != nil {
- return err
+ var iovec unix.Iovec
+ iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+ iovec.SetLen(len(buff))
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }
+
+ var msghdr unix.Msghdr
+ msghdr.Iov = &iovec
+ msghdr.Iovlen = 1
+ msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
+ msghdr.Namelen = unix.SizeofSockaddrInet4
+ msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
+ msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+ // recvmsg(sock, &mskhdr, 0)
+
+ size, _, errno := unix.Syscall(
+ unix.SYS_RECVMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
+
+ if errno != 0 {
+ return 0, errno
}
- sock := file.Fd()
- // send depending on address family of dst
+ // update source cache
- family := *((*uint16)(unsafe.Pointer(&end.dst)))
- if family == unix.AF_INET {
- return send4(sock, end, buff)
- } else if family == unix.AF_INET6 {
- return send6(sock, end, buff)
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+ src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+ src4.src.Family = unix.AF_INET
+ src4.src.Addr = cmsg.pktinfo.Spec_dst
+ src4.Ifindex = cmsg.pktinfo.Ifindex
}
- return errors.New("Unknown address family of source")
+
+ return int(size), nil
}
-func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- file, err := c.File()
- if err != nil {
- return err, nil, nil
- }
+ // contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
@@ -208,60 +544,39 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
var cmsg struct {
cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo // big enough
+ pktinfo unix.Inet6Pktinfo
}
var msg unix.Msghdr
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
- msg.Namelen = uint32(unix.SizeofSockaddrAny)
+ msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
- _, _, errno := unix.Syscall(
+ // recvmsg(sock, &mskhdr, 0)
+
+ size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
- file.Fd(),
+ uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
0,
)
if errno != 0 {
- return errno, nil, nil
+ return 0, errno
}
+ // update source cache
+
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-
- }
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-
- info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
- println(info)
-
- }
-
- return nil, nil, nil
-}
-
-func setMark(conn *net.UDPConn, value uint32) error {
- if conn == nil {
- return nil
+ end.src.Family = unix.AF_INET6
+ end.src.Addr = cmsg.pktinfo.Addr
+ end.src.Scope_id = cmsg.pktinfo.Ifindex
}
- file, err := conn.File()
- if err != nil {
- return err
- }
-
- return unix.SetsockoptInt(
- int(file.Fd()),
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
+ return int(size), nil
}