diff options
author | David Crawshaw <crawshaw@tailscale.com> | 2019-04-17 09:41:25 -0400 |
---|---|---|
committer | David Crawshaw <david@zentus.com> | 2020-03-31 09:32:52 +1100 |
commit | a38504e3994268ee9aee5cb1d42bab82ef5ede76 (patch) | |
tree | f5a3d792e13f33a1c454ee786324c54ad4513511 | |
parent | d5d70756bbf34474645d4863cd49ffce82e35261 (diff) | |
download | wireguard-go-a38504e3994268ee9aee5cb1d42bab82ef5ede76.tar.gz wireguard-go-a38504e3994268ee9aee5cb1d42bab82ef5ede76.zip |
wgcfg: new config package
Based on types and config parser from wireguard-windows.
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
-rw-r--r-- | wgcfg/config.go | 78 | ||||
-rw-r--r-- | wgcfg/ip.go | 128 | ||||
-rw-r--r-- | wgcfg/key.go | 240 | ||||
-rw-r--r-- | wgcfg/key_test.go | 107 | ||||
-rw-r--r-- | wgcfg/name.go | 49 | ||||
-rw-r--r-- | wgcfg/parser.go | 397 | ||||
-rw-r--r-- | wgcfg/parser_test.go | 127 | ||||
-rw-r--r-- | wgcfg/writer.go | 75 |
8 files changed, 1201 insertions, 0 deletions
diff --git a/wgcfg/config.go b/wgcfg/config.go new file mode 100644 index 0000000..2b5e714 --- /dev/null +++ b/wgcfg/config.go @@ -0,0 +1,78 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +// Package wgcfg has types and a parser for representing WireGuard config. +package wgcfg + +import ( + "fmt" + "strings" +) + +// Config is a wireguard configuration. +type Config struct { + Name string + PrivateKey PrivateKey + Addresses []CIDR + ListenPort uint16 + MTU uint16 + DNS []IP + Peers []Peer +} + +type Peer struct { + PublicKey Key + PresharedKey SymmetricKey + AllowedIPs []CIDR + Endpoints []Endpoint + PersistentKeepalive uint16 +} + +type Endpoint struct { + Host string + Port uint16 +} + +func (e *Endpoint) String() string { + if strings.IndexByte(e.Host, ':') > 0 { + return fmt.Sprintf("[%s]:%d", e.Host, e.Port) + } + return fmt.Sprintf("%s:%d", e.Host, e.Port) +} + +func (e *Endpoint) IsEmpty() bool { + return len(e.Host) == 0 +} + +// Copy makes a deep copy of Config. +// The result aliases no memory with the original. +func (cfg Config) Copy() Config { + res := cfg + if res.Addresses != nil { + res.Addresses = append([]CIDR{}, res.Addresses...) + } + if res.DNS != nil { + res.DNS = append([]IP{}, res.DNS...) + } + peers := make([]Peer, 0, len(res.Peers)) + for _, peer := range res.Peers { + peers = append(peers, peer.Copy()) + } + res.Peers = peers + return res +} + +// Copy makes a deep copy of Peer. +// The result aliases no memory with the original. +func (peer Peer) Copy() Peer { + res := peer + if res.AllowedIPs != nil { + res.AllowedIPs = append([]CIDR{}, res.AllowedIPs...) + } + if res.Endpoints != nil { + res.Endpoints = append([]Endpoint{}, res.Endpoints...) + } + return res +} diff --git a/wgcfg/ip.go b/wgcfg/ip.go new file mode 100644 index 0000000..ecf5faf --- /dev/null +++ b/wgcfg/ip.go @@ -0,0 +1,128 @@ +package wgcfg + +import ( + "fmt" + "net" +) + +// IP is an IPv4 or an IPv6 address. +// +// Internally the address is always represented in its IPv6 form. +// IPv4 addresses use the IPv4-in-IPv6 syntax. +type IP struct { + Addr [16]byte +} + +func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } + +func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) } +func (ip *IP) Is6() bool { return !ip.Is4() } +func (ip *IP) Is4() bool { + return ip.Addr[0] == 0 && ip.Addr[1] == 0 && + ip.Addr[2] == 0 && ip.Addr[3] == 0 && + ip.Addr[4] == 0 && ip.Addr[5] == 0 && + ip.Addr[6] == 0 && ip.Addr[7] == 0 && + ip.Addr[8] == 0 && ip.Addr[9] == 0 && + ip.Addr[10] == 0xff && ip.Addr[11] == 0xff +} +func (ip *IP) To4() []byte { + if ip.Is4() { + return ip.Addr[12:16] + } else { + return nil + } +} +func (ip *IP) Equal(x *IP) bool { + if ip == nil || x == nil { + return false + } + // TODO: this isn't hard, write a more efficient implementation. + return ip.IP().Equal(x.IP()) +} + +func (ip IP) MarshalText() ([]byte, error) { + return []byte(ip.String()), nil +} + +func (ip *IP) UnmarshalText(text []byte) error { + parsedIP := ParseIP(string(text)) + if parsedIP == nil { + return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text)) + } + *ip = *parsedIP + return nil +} + +func IPv4(b0, b1, b2, b3 byte) (ip IP) { + ip.Addr[10], ip.Addr[11] = 0xff, 0xff // IPv4-in-IPv6 prefix + ip.Addr[12] = b0 + ip.Addr[13] = b1 + ip.Addr[14] = b2 + ip.Addr[15] = b3 + return ip +} + +// ParseIP parses the string representation of an address into an IP. +// +// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". +// If the string is not a valid IP address, ParseIP returns nil. +func ParseIP(s string) *IP { + netIP := net.ParseIP(s) + if netIP == nil { + return nil + } + ip := new(IP) + copy(ip.Addr[:], netIP.To16()) + return ip +} + +// CIDR is a compact IP address and subnet mask. +type CIDR struct { + IP IP + Mask uint8 // 0-32 for IsIPv4, 4-128 for IsIPv6 +} + +// ParseCIDR parses CIDR notation into a CIDR type. +// Typical CIDR strings look like "192.168.1.0/24". +func ParseCIDR(s string) (cidr *CIDR, err error) { + netIP, netAddr, err := net.ParseCIDR(s) + if err != nil { + return nil, err + } + cidr = new(CIDR) + copy(cidr.IP.Addr[:], netIP.To16()) + ones, _ := netAddr.Mask.Size() + cidr.Mask = uint8(ones) + + return cidr, nil +} + +func (r CIDR) String() string { return r.IPNet().String() } + +func (r *CIDR) IPNet() *net.IPNet { + bits := 128 + if r.IP.Is4() { + bits = 32 + } + return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} +} +func (r *CIDR) Contains(ip *IP) bool { + if r == nil || ip == nil { + return false + } + // TODO: this isn't hard, write a more efficient implementation. + return r.IPNet().Contains(ip.IP()) +} + +func (r CIDR) MarshalText() ([]byte, error) { + return []byte(r.String()), nil +} + +func (r *CIDR) UnmarshalText(text []byte) error { + cidr, err := ParseCIDR(string(text)) + if err != nil { + return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) + } + *r = *cidr + return nil +} diff --git a/wgcfg/key.go b/wgcfg/key.go new file mode 100644 index 0000000..1597203 --- /dev/null +++ b/wgcfg/key.go @@ -0,0 +1,240 @@ +package wgcfg + +import ( + "bytes" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" +) + +const KeySize = 32 + +// Key is curve25519 key. +// It is used by WireGuard to represent public and preshared keys. +type Key [KeySize]byte + +// NewPresharedKey generates a new random key. +func NewPresharedKey() (*Key, error) { + var k [KeySize]byte + _, err := rand.Read(k[:]) + if err != nil { + return nil, err + } + return (*Key)(&k), nil +} + +func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } + +func ParseHexKey(s string) (Key, error) { + b, err := hex.DecodeString(s) + if err != nil { + return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} + } + if len(b) != KeySize { + return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} + } + + var key Key + copy(key[:], b) + return key, nil +} + +func ParsePrivateHexKey(v string) (PrivateKey, error) { + k, err := ParseHexKey(v) + if err != nil { + return PrivateKey{}, err + } + pk := PrivateKey(k) + if pk.IsZero() { + // Do not clamp a zero key, pass the zero through + // (much like NaN propagation) so that IsZero reports + // a useful result. + return pk, nil + } + pk.clamp() + return pk, nil +} + +func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k Key) String() string { return "pub:" + k.Base64()[:8] } +func (k Key) HexString() string { return hex.EncodeToString(k[:]) } +func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *Key) ShortString() string { + if k.IsZero() { + return "[empty]" + } + long := k.String() + if len(long) < 10 { + return "invalid" + } + return "[" + long[0:4] + "ā¦" + long[len(long)-5:len(long)-1] + "]" +} + +func (k *Key) IsZero() bool { + if k == nil { + return true + } + var zeros Key + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 +} + +func (k *Key) MarshalJSON() ([]byte, error) { + if k == nil { + return []byte("null"), nil + } + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `"%x"`, k[:]) + return buf.Bytes(), nil +} + +func (k *Key) UnmarshalJSON(b []byte) error { + if k == nil { + return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") + } + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return errors.New("wgcfg.Key: UnmarshalJSON not given a string") + } + b = b[1 : len(b)-1] + key, err := ParseHexKey(string(b)) + if err != nil { + return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (a *Key) LessThan(b *Key) bool { + for i := range a { + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } + return false +} + +// PrivateKey is curve25519 key. +// It is used by WireGuard to represent private keys. +type PrivateKey [KeySize]byte + +// NewPrivateKey generates a new curve25519 secret key. +// It conforms to the format described on https://cr.yp.to/ecdh.html. +func NewPrivateKey() (PrivateKey, error) { + k, err := NewPresharedKey() + if err != nil { + return PrivateKey{}, err + } + k[0] &= 248 + k[31] = (k[31] & 127) | 64 + return (PrivateKey)(*k), nil +} + +func ParsePrivateKey(b64 string) (*PrivateKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + return (*PrivateKey)(k), err +} + +func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } + +func (k *PrivateKey) IsZero() bool { + pk := Key(*k) + return pk.IsZero() +} + +func (k *PrivateKey) clamp() { + k[0] &= 248 + k[31] = (k[31] & 127) | 64 +} + +// Public computes the public key matching this curve25519 secret key. +func (k *PrivateKey) Public() Key { + pk := Key(*k) + if pk.IsZero() { + panic("Tried to generate emptyPrivateKey.Public()") + } + var p [KeySize]byte + curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) + return (Key)(p) +} + +func (k PrivateKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, `privkey:%x`, k[:]) + return buf.Bytes(), nil +} + +func (k *PrivateKey) UnmarshalText(b []byte) error { + s := string(b) + if !strings.HasPrefix(s, `privkey:`) { + return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") + } + s = strings.TrimPrefix(s, `privkey:`) + key, err := ParseHexKey(s) + if err != nil { + return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { + apk := (*[KeySize]byte)(&pub) + ask := (*[KeySize]byte)(&k) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} + +func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { + k, err := enc.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func ParseSymmetricKey(b64 string) (SymmetricKey, error) { + k, err := parseKeyBase64(base64.StdEncoding, b64) + if err != nil { + return SymmetricKey{}, err + } + return SymmetricKey(*k), nil +} + +func ParseSymmetricHexKey(s string) (SymmetricKey, error) { + b, err := hex.DecodeString(s) + if err != nil { + return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} + } + if len(b) != chacha20poly1305.KeySize { + return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} + } + var key SymmetricKey + copy(key[:], b) + return key, nil +} + +// SymmetricKey is a chacha20poly1305 key. +// It is used by WireGuard to represent pre-shared symmetric keys. +type SymmetricKey [chacha20poly1305.KeySize]byte + +func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } +func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } +func (k SymmetricKey) Equal(k2 SymmetricKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go new file mode 100644 index 0000000..0b82d5f --- /dev/null +++ b/wgcfg/key_test.go @@ -0,0 +1,107 @@ +package wgcfg + +import ( + "bytes" + "testing" +) + +func TestKeyBasics(t *testing.T) { + k1, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + + b, err := k1.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(k1[:], k2[:]) { + t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) + } + if b1, b2 := k1.String(), k2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + }) + + t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { + k2 := new(PrivateKey) + if err := k2.UnmarshalText(b); err == nil { + t.Fatalf("successfully decoded key as private key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to NewPresharedKey should make a new key. + k3, err := NewPresharedKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(k1[:], k3[:]) { + t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := k1.String(), k3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + }) +} +func TestPrivateKeyBasics(t *testing.T) { + pri, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + b, err := pri.MarshalText() + if err != nil { + t.Fatal(err) + } + + t.Run("JSON round-trip", func(t *testing.T) { + // should preserve the keys + pri2 := new(PrivateKey) + if err := pri2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + if !bytes.Equal(pri[:], pri2[:]) { + t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) + } + if b1, b2 := pri.String(), pri2.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { + t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) + } + }) + + t.Run("JSON incompatible with Key", func(t *testing.T) { + k2 := new(Key) + if err := k2.UnmarshalJSON(b); err == nil { + t.Fatalf("successfully decoded private key as key") + } + }) + + t.Run("second key", func(t *testing.T) { + // A second call to New should make a new key. + pri3, err := NewPrivateKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(pri[:], pri3[:]) { + t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) + } + // Check for obvious comparables to make sure we are not generating bad strings somewhere. + if b1, b2 := pri.String(), pri3.String(); b1 == b2 { + t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) + } + if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { + t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) + } + }) +} diff --git a/wgcfg/name.go b/wgcfg/name.go new file mode 100644 index 0000000..28bc0f0 --- /dev/null +++ b/wgcfg/name.go @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "regexp" + "strings" +) + +var reservedNames = []string{ + "CON", "PRN", "AUX", "NUL", + "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", + "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", +} + +const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" + +var allowedNameFormat *regexp.Regexp + +func init() { + allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") +} + +func isReserved(name string) bool { + if len(name) == 0 { + return false + } + for _, reserved := range reservedNames { + if strings.EqualFold(name, reserved) { + return true + } + } + return false +} + +func hasSpecialChars(name string) bool { + return strings.ContainsAny(name, specialChars) +} + +func TunnelNameIsValid(name string) bool { + // Aside from our own restrictions, let's impose the Windows restrictions first + if isReserved(name) || hasSpecialChars(name) { + return false + } + return allowedNameFormat.MatchString(name) +} diff --git a/wgcfg/parser.go b/wgcfg/parser.go new file mode 100644 index 0000000..45a6057 --- /dev/null +++ b/wgcfg/parser.go @@ -0,0 +1,397 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: ā%sā", e.why, e.offender) +} + +func parseEndpoints(s string) ([]Endpoint, error) { + var eps []Endpoint + vals := strings.Split(s, ",") + for _, val := range vals { + e, err := parseEndpoint(val) + if err != nil { + return nil, err + } + eps = append(eps, *e) + } + return eps, nil +} + +func parseEndpoint(s string) (*Endpoint, error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return nil, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return nil, &ParseError{"Invalid endpoint host", host} + } + port, err := parsePort(portStr) + if err != nil { + return nil, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return nil, err + } + } else { + return nil, err + } + host = host[1 : len(host)-1] + } + return &Endpoint{host, uint16(port)}, nil +} + +func parseMTU(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 576 || m > 65535 { + return 0, &ParseError{"Invalid MTU", s} + } + return uint16(m), nil +} + +func parsePort(s string) (uint16, error) { + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid port", s} + } + return uint16(m), nil +} + +func parsePersistentKeepalive(s string) (uint16, error) { + if s == "off" { + return 0, nil + } + m, err := strconv.Atoi(s) + if err != nil { + return 0, err + } + if m < 0 || m > 65535 { + return 0, &ParseError{"Invalid persistent keepalive", s} + } + return uint16(m), nil +} + +func parseKeyHex(s string) (*Key, error) { + k, err := hex.DecodeString(s) + if err != nil { + return nil, &ParseError{"Invalid key: " + err.Error(), s} + } + if len(k) != KeySize { + return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} + } + var key Key + copy(key[:], k) + return &key, nil +} + +func parseBytesOrStamp(s string) (uint64, error) { + b, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s} + } + return b, nil +} + +func splitList(s string) ([]string, error) { + var out []string + for _, split := range strings.Split(s, ",") { + trim := strings.TrimSpace(split) + if len(trim) == 0 { + return nil, &ParseError{"Two commas in a row", s} + } + out = append(out, trim) + } + return out, nil +} + +type parserState int + +const ( + inInterfaceSection parserState = iota + inPeerSection + notInASection +) + +func (c *Config) maybeAddPeer(p *Peer) { + if p != nil { + c.Peers = append(c.Peers, *p) + } +} + +func FromWgQuick(s string, name string) (*Config, error) { + if !TunnelNameIsValid(name) { + return nil, &ParseError{"Tunnel name is not valid", name} + } + lines := strings.Split(s, "\n") + parserState := notInASection + conf := Config{Name: name} + sawPrivateKey := false + var peer *Peer + for _, line := range lines { + pound := strings.IndexByte(line, '#') + if pound >= 0 { + line = line[:pound] + } + line = strings.TrimSpace(line) + lineLower := strings.ToLower(line) + if len(line) == 0 { + continue + } + if lineLower == "[interface]" { + conf.maybeAddPeer(peer) + parserState = inInterfaceSection + continue + } + if lineLower == "[peer]" { + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + continue + } + if parserState == notInASection { + return nil, &ParseError{"Line must occur in a section", line} + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + if parserState == inInterfaceSection { + switch key { + case "privatekey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + sawPrivateKey = true + case "listenport": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "mtu": + m, err := parseMTU(val) + if err != nil { + return nil, err + } + conf.MTU = m + case "address": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + conf.Addresses = append(conf.Addresses, *a) + } + case "dns": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a := ParseIP(address) + if a == nil { + return nil, &ParseError{"Invalid IP address", address} + } + conf.DNS = append(conf.DNS, *a) + } + default: + return nil, &ParseError{"Invalid key for [Interface] section", key} + } + } else if parserState == inPeerSection { + switch key { + case "publickey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "presharedkey": + k, err := ParseKey(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "allowedips": + addresses, err := splitList(val) + if err != nil { + return nil, err + } + for _, address := range addresses { + a, err := ParseCIDR(address) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + } + case "persistentkeepalive": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for [Peer] section", key} + } + } + } + conf.maybeAddPeer(peer) + + if !sawPrivateKey { + return nil, &ParseError{"An interface must have a private key", "[none specified]"} + } + for _, p := range conf.Peers { + if p.PublicKey.IsZero() { + return nil, &ParseError{"All peers must have public keys", "[none specified]"} + } + } + + return &conf, nil +} + +// TODO(apenwarr): This is incompatibe with current Device.IpcSetOperation. +// It duplicates all the parser stuff in there, but is missing some +// keywords. Nothing useful seems to need it anymore. +func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { + lines := strings.Split(s, "\n") + parserState := inInterfaceSection + conf := Config{ + Name: existingConfig.Name, + Addresses: existingConfig.Addresses, + DNS: existingConfig.DNS, + MTU: existingConfig.MTU, + } + var peer *Peer + for _, line := range lines { + if len(line) == 0 { + continue + } + equals := strings.IndexByte(line, '=') + if equals < 0 { + return nil, &ParseError{"Invalid config key is missing an equals separator", line} + } + key, val := line[:equals], line[equals+1:] + if len(val) == 0 { + return nil, &ParseError{"Key must have a value", line} + } + switch key { + case "public_key": + conf.maybeAddPeer(peer) + peer = &Peer{} + parserState = inPeerSection + case "errno": + if val == "0" { + continue + } else { + return nil, &ParseError{"Error in getting configuration", val} + } + } + if parserState == inInterfaceSection { + switch key { + case "private_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + conf.PrivateKey = PrivateKey(*k) + case "listen_port": + p, err := parsePort(val) + if err != nil { + return nil, err + } + conf.ListenPort = p + case "fwmark": + // Ignored for now. + + default: + return nil, &ParseError{"Invalid key for interface section", key} + } + } else if parserState == inPeerSection { + switch key { + case "public_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PublicKey = *k + case "preshared_key": + k, err := parseKeyHex(val) + if err != nil { + return nil, err + } + peer.PresharedKey = SymmetricKey(*k) + case "protocol_version": + if val != "1" { + return nil, &ParseError{"Protocol version must be 1", val} + } + case "allowed_ip": + a, err := ParseCIDR(val) + if err != nil { + return nil, err + } + peer.AllowedIPs = append(peer.AllowedIPs, *a) + case "persistent_keepalive_interval": + p, err := parsePersistentKeepalive(val) + if err != nil { + return nil, err + } + peer.PersistentKeepalive = p + case "endpoint": + eps, err := parseEndpoints(val) + if err != nil { + return nil, err + } + peer.Endpoints = eps + default: + return nil, &ParseError{"Invalid key for peer section", key} + } + } + } + conf.maybeAddPeer(peer) + + return &conf, nil +} diff --git a/wgcfg/parser_test.go b/wgcfg/parser_test.go new file mode 100644 index 0000000..d0df537 --- /dev/null +++ b/wgcfg/parser_test.go @@ -0,0 +1,127 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "reflect" + "runtime" + "testing" +) + +const testInput = ` +[Interface] +Address = 10.192.122.1/24 +Address = 10.10.0.1/16 +PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= +ListenPort = 51820 #comments don't matter + +[Peer] +PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= +Endpoint = 192.95.5.67:1234 +AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 + +[Peer] +PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = [2607:5300:60:6b0::c05f:543]:2468 +AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 +PersistentKeepalive = 100 + +[Peer] +PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= +PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= +Endpoint = test.wireguard.com:18981 +AllowedIPs = 10.10.10.230/32` + +func noError(t *testing.T, err error) bool { + if err == nil { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error at %s:%d: %#v", fn, line, err) + return false +} + +func equal(t *testing.T, expected, actual interface{}) bool { + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func lenTest(t *testing.T, actualO interface{}, expected int) bool { + actual := reflect.ValueOf(actualO).Len() + if reflect.DeepEqual(expected, actual) { + return true + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) + return false +} +func contains(t *testing.T, list, element interface{}) bool { + listValue := reflect.ValueOf(list) + for i := 0; i < listValue.Len(); i++ { + if reflect.DeepEqual(listValue.Index(i).Interface(), element) { + return true + } + } + _, fn, line, _ := runtime.Caller(1) + t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) + return false +} + +func TestFromWgQuick(t *testing.T) { + conf, err := FromWgQuick(testInput, "test") + if noError(t, err) { + + lenTest(t, conf.Addresses, 2) + contains(t, conf.Addresses, CIDR{IPv4(10, 10, 0, 1), 16}) + contains(t, conf.Addresses, CIDR{IPv4(10, 192, 122, 1), 24}) + equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.PrivateKey.String()) + equal(t, uint16(51820), conf.ListenPort) + + lenTest(t, conf.Peers, 3) + lenTest(t, conf.Peers[0].AllowedIPs, 2) + equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoints[0]) + equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.Base64()) + + lenTest(t, conf.Peers[1].AllowedIPs, 2) + equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoints[0]) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.Base64()) + equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) + + lenTest(t, conf.Peers[2].AllowedIPs, 1) + equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoints[0]) + equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.Base64()) + equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.Base64()) + } +} + +func TestParseEndpoint(t *testing.T) { + _, err := parseEndpoint("[192.168.42.0:]:51880") + if err == nil { + t.Error("Error was expected") + } + e, err := parseEndpoint("192.168.42.0:51880") + if noError(t, err) { + equal(t, "192.168.42.0", e.Host) + equal(t, uint16(51880), e.Port) + } + e, err = parseEndpoint("test.wireguard.com:18981") + if noError(t, err) { + equal(t, "test.wireguard.com", e.Host) + equal(t, uint16(18981), e.Port) + } + e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") + if noError(t, err) { + equal(t, "2607:5300:60:6b0::c05f:543", e.Host) + equal(t, uint16(2468), e.Port) + } + _, err = parseEndpoint("[::::::invalid:18981") + if err == nil { + t.Error("Error was expected") + } +} diff --git a/wgcfg/writer.go b/wgcfg/writer.go new file mode 100644 index 0000000..aafb2a7 --- /dev/null +++ b/wgcfg/writer.go @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg + +import ( + "errors" + "fmt" + "net" + "strings" +) + +func (conf *Config) ToUAPI() (string, error) { + output := new(strings.Builder) + fmt.Fprintf(output, "private_key=%s\n", conf.PrivateKey.HexString()) + + if conf.ListenPort > 0 { + fmt.Fprintf(output, "listen_port=%d\n", conf.ListenPort) + } + + output.WriteString("replace_peers=true\n") + + for _, peer := range conf.Peers { + fmt.Fprintf(output, "public_key=%s\n", peer.PublicKey.HexString()) + fmt.Fprintf(output, "protocol_version=1\n") + fmt.Fprintf(output, "replace_allowed_ips=true\n") + + if !peer.PresharedKey.IsZero() { + fmt.Fprintf(output, "preshared_key = %s\n", peer.PresharedKey.String()) + } + + if len(peer.AllowedIPs) > 0 { + for _, address := range peer.AllowedIPs { + fmt.Fprintf(output, "allowed_ip=%s\n", address.String()) + } + } + + if len(peer.Endpoints) > 0 { + var reps []string + for _, ep := range peer.Endpoints { + ips, err := net.LookupIP(ep.Host) + if err != nil { + return "", err + } + var ip net.IP + for _, iterip := range ips { + iterip = iterip.To4() + if iterip != nil { + ip = iterip + break + } + if ip == nil { + ip = iterip + } + } + if ip == nil { + return "", errors.New("Unable to resolve IP address of endpoint") + } + resolvedEndpoint := Endpoint{ip.String(), ep.Port} + reps = append(reps, resolvedEndpoint.String()) + } + fmt.Fprintf(output, "endpoint=%s\n", strings.Join(reps, ",")) + } else { + fmt.Fprint(output, "endpoint=\n") + } + + // Note: this needs to come *after* endpoint definitions, + // because setting it will trigger a handshake to all + // already-defined endpoints. + fmt.Fprintf(output, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive) + } + return output.String(), nil +} |