diff options
Diffstat (limited to 'conn_linux.go')
-rw-r--r-- | conn_linux.go | 85 |
1 files changed, 60 insertions, 25 deletions
diff --git a/conn_linux.go b/conn_linux.go index 83cf1a2..cc1ce2e 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -232,30 +232,32 @@ func (bind *NativeBind) Close() error { return err3 } -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) { var end NativeEndpoint + var tos byte if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, tos, syscall.EAFNOSUPPORT } - n, err := receive6( + n, tos, err := receive6( bind.sock6, buff, &end, ) - return n, &end, err + return n, &end, tos, err } -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) { var end NativeEndpoint + var tos byte if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, tos, syscall.EAFNOSUPPORT } - n, err := receive4( + n, tos, err := receive4( bind.sock4, buff, &end, ) - return n, &end, err + return n, &end, tos, err } func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error { @@ -384,6 +386,15 @@ func create4(port uint16) (int, uint16, error) { return err } + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_RECVTOS, + 1, + ); err != nil { + return err + } + return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) @@ -442,6 +453,15 @@ func create6(port uint16) (int, uint16, error) { return err } + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVTCLASS, + 1, + ); err != nil { + return err + } + return unix.Bind(fd, &addr) }(); err != nil { @@ -452,12 +472,13 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } +type ipTos struct { + tos byte +} + func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header - type ipTos struct { - tos byte - } cmsg := struct { cmsghdr unix.Cmsghdr @@ -505,9 +526,6 @@ func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header - type ipTos struct { - tos byte - } cmsg := struct { cmsghdr unix.Cmsghdr @@ -555,19 +573,21 @@ func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { return err } -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) { // contruct message header var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + cmsghdr2 unix.Cmsghdr + iptos ipTos } size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) if err != nil { - return 0, err + return 0, 0, err } end.isV6 = false @@ -576,7 +596,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { } // update source cache - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { @@ -584,22 +603,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { end.src4().ifindex = cmsg.pktinfo.Ifindex } - return size, nil + tos := byte(0) + if cmsg.cmsghdr2.Level == unix.IPPROTO_IP && + cmsg.cmsghdr2.Type == unix.IP_TOS && + cmsg.cmsghdr2.Len >= 1 { + tos = cmsg.iptos.tos + } + + return size, tos, nil } -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) { // contruct message header var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + cmsghdr2 unix.Cmsghdr + iptos ipTos } size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) if err != nil { - return 0, err + return 0, 0, err } end.isV6 = true @@ -616,7 +644,14 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { end.dst6().ZoneId = cmsg.pktinfo.Ifindex } - return size, nil + tos := byte(0) + if cmsg.cmsghdr2.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr2.Type == unix.IPV6_TCLASS && + cmsg.cmsghdr2.Len >= 1 { + tos = cmsg.iptos.tos + } + + return size, tos, nil } func (bind *NativeBind) routineRouteListener(device *Device) { |