diff --git a/internal/pkg/container/nsupdate.go b/internal/pkg/container/nsupdate.go index eab31cf..7fc5087 100644 --- a/internal/pkg/container/nsupdate.go +++ b/internal/pkg/container/nsupdate.go @@ -1,14 +1,101 @@ package container import ( - "net" + "fmt" + "strings" "time" cmd "gitea.elkins.co/Networking/ccl/internal/pkg/command" "github.com/miekg/dns" ) -func (c *Container) doReverse(rv string, dn string) error { +func (c *Container) makeDnsClient() *dns.Client { + cli := new(dns.Client) + if c.TSIGName != "" { + cli.TsigSecret = map[string]string{c.TSIGName: c.TSIGKey} + } + return cli +} + +func (c *Container) killDnsReverse(ip string) error { + rv, err := dns.ReverseAddr(ip) + if err != nil { + return err + } + cli := c.makeDnsClient() + + // Determine SOA of old reverse zone + msg := new(dns.Msg) + msg.SetQuestion(rv, dns.TypeSOA) + resp, _, err := cli.Exchange(msg, c.DnsServer) + if err != nil { + return err + } + soa := resp.Ns[0].Header().Name + + // Remove the old PTR + ptr := dns.ANY{ + Hdr: dns.RR_Header{ + Name: rv, + Rrtype: dns.TypePTR, + }, + } + msg = new(dns.Msg) + msg.SetUpdate(soa) + msg.RemoveName([]dns.RR{&ptr}) + if c.TSIGName != "" { + msg.SetTsig(c.TSIGName, dns.HmacSHA256, 300, time.Now().Unix()) + } + resp, _, err = cli.Exchange(msg, c.DnsServer) + if err != nil { + return err + } + return nil +} + +// This is the same code for ipv4 or ipv6 so factor it out +func (c *Container) doDnsReverse(ip string, dn string, rrtype uint16) error { + rv, err := dns.ReverseAddr(ip) + if err != nil { + return err + } + + cli := c.makeDnsClient() + + // Delete any existing PTR + // Strategy: + // 1. find existing RR's of our given type + // 2. for each one, kill any reverse pointers to that address + msg := new(dns.Msg) + msg.SetQuestion(dn, rrtype) + resp, _, err := cli.Exchange(msg, c.DnsServer) + if err != nil { + return fmt.Errorf("error looking up existing RRs for %s type %d: %s", dn, rrtype, err) + } + for i := range resp.Answer { + var ip string + if rrtype == dns.TypeA { + x := (resp.Answer[i]).(*dns.A) + ip = x.A.String() + } else { + x := (resp.Answer[i]).(*dns.AAAA) + ip = x.AAAA.String() + } + if err := c.killDnsReverse(ip); err != nil { + return err + } + } + + // Determine SOA of reverse zone + msg = new(dns.Msg) + msg.SetQuestion(rv, dns.TypeSOA) + resp, _, err = cli.Exchange(msg, c.DnsServer) + if err != nil { + return err + } + soa := resp.Ns[0].Header().Name + + // Update the reverse record ptr := dns.PTR{ Hdr: dns.RR_Header{ Name: rv, @@ -19,20 +106,6 @@ func (c *Container) doReverse(rv string, dn string) error { Ptr: dn, } - cli := new(dns.Client) - if c.TSIGName != "" { - cli.TsigSecret = map[string]string{c.TSIGName: c.TSIGKey} - } - - msg := new(dns.Msg) - msg.SetQuestion(rv, dns.TypeSOA) - resp, _, err := cli.Exchange(msg, c.DnsServer) - if err != nil { - return err - } - soa := resp.Ns[0].Header().Name - - // Update the reverse record msg = new(dns.Msg) msg.SetUpdate(soa) msg.Ns = append(msg.Ns, &ptr) @@ -46,106 +119,133 @@ func (c *Container) doReverse(rv string, dn string) error { return nil } -func (c *Container) NsUpdateCommands() cmd.Set { - if c.DomainName == "" || c.DnsServer == "" { - return cmd.Set{ID: "NSUPDATE_NOT_CONFIGURED"} +func (c *Container) killDnsForward(name string, rrtype uint16) error { + cli := c.makeDnsClient() + + msg := new(dns.Msg) + msg.SetQuestion(name, rrtype) + resp, _, err := cli.Exchange(msg, c.DnsServer) + if err != nil { + return err } + + dn := dns.Fqdn(c.DomainName) + if !strings.HasSuffix(dn, ".") { + dn = dn + "." + } + + if len(resp.Answer) > 0 { + msg := new(dns.Msg) + msg.SetUpdate(dn) + msg.RemoveRRset(resp.Answer) + if _, _, err := cli.Exchange(msg, c.DnsServer); err != nil { + return err + } + } + + return nil +} + +func (c *Container) doDnsForward(rr string) error { + rr_parsed, err := dns.NewRR(rr) + if err != nil { + return err + } + + cli := c.makeDnsClient() + dn := dns.Fqdn(c.DomainName) + if !strings.HasSuffix(dn, ".") { + dn = dn + "." + } + + // Update the forward record + msg := new(dns.Msg) + msg.SetUpdate(dn) + msg.Ns = append(msg.Ns, rr_parsed) + if c.TSIGName != "" { + msg.SetTsig(c.TSIGName, dns.HmacSHA256, 300, time.Now().Unix()) + } + if _, _, err = cli.Exchange(msg, c.DnsServer); err != nil { + return err + } + return nil +} + +func (c *Container) NsUpdateCommands() cmd.Set { + // check that a server is configured + if c.DomainName == "" || c.DnsServer == "" { + return cmd.Set{ID: "NSUPDATE", Commands: []cmd.Command{ + cmd.NewFunc("NSUPDATE_NOT_CONFIGURED", func() error { + return fmt.Errorf("nsupdate command requires `domain_name` and `dns_server` to be configured") + }), + }} + } + + // determine hostname hostname := c.Hostname if c.Hostname == "" { hostname = c.Name } dn := dns.Fqdn(hostname + "." + c.DomainName) + if !strings.HasSuffix(dn, ".") { + dn = dn + "." + } + + // prepare update commands cmds := []cmd.Command{} - // TODO: also iterate over c.IPv6Addresses + // TODO: also iterate over c.IPv6Addresses, if it ever works for i := range c.Networks { n := &c.Networks[i] - if n.IPv6.Bool && !n.IPv6Address.IsUnspecified() { - ad := net.ParseIP(n.IPv6Address.String()) - if ad != nil { - f_6 := func() error { - aaaa := dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dn, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 7200, - }, - AAAA: ad, - } - - rv, err := dns.ReverseAddr(aaaa.AAAA.String()) - if err != nil { - return err - } - - cli := new(dns.Client) - if c.TSIGName != "" { - cli.TsigSecret = map[string]string{c.TSIGName: c.TSIGKey} - } - - // Update the forward record - msg := new(dns.Msg) - msg.SetUpdate(dns.Fqdn(c.DomainName)) - msg.Ns = append(msg.Ns, &aaaa) - - if c.TSIGName != "" { - msg.SetTsig(c.TSIGName, dns.HmacSHA256, 300, time.Now().Unix()) - } - if _, _, err = cli.Exchange(msg, c.DnsServer); err != nil { - return err - } - - if err = c.doReverse(rv, dn); err != nil { - return err - } - return nil + if n.IPv6.Bool && n.IPv6Address != nil && !n.IPv6Address.IsUnspecified() { + f_6 := func() error { + aaaa := dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dn, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 7200, + }, + AAAA: n.IPv6Address, } - cmds = append(cmds, cmd.NewFunc("nsupate6", f_6)) + + if err := c.doDnsReverse(aaaa.AAAA.String(), dn, dns.TypeAAAA); err != nil { + return err + } + if err := c.killDnsForward(dn, dns.TypeAAAA); err != nil { + return err + } + if err := c.doDnsForward(aaaa.String()); err != nil { + return err + } + return nil } + cmds = append(cmds, cmd.NewFunc("nsupate6", f_6)) } - if !n.IPv4Address.IsUnspecified() { - ad := net.ParseIP(n.IPv4Address.String()) - if ad != nil { - f_4 := func() error { - a := dns.A{ - Hdr: dns.RR_Header{ - Name: dn, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 7200, - }, - A: ad, - } - - rv, err := dns.ReverseAddr(a.A.String()) - if err != nil { - return err - } - - cli := new(dns.Client) - if c.TSIGName != "" { - cli.TsigSecret = map[string]string{c.TSIGName: c.TSIGKey} - } - - // Update the forward record - msg := new(dns.Msg) - msg.SetUpdate(dns.Fqdn(c.DomainName)) - msg.Ns = append(msg.Ns, &a) - if c.TSIGName != "" { - msg.SetTsig(c.TSIGName, dns.HmacSHA256, 300, time.Now().Unix()) - } - if _, _, err = cli.Exchange(msg, c.DnsServer); err != nil { - return err - } - - if err = c.doReverse(rv, dn); err != nil { - return err - } - return nil + if n.IPv4Address != nil && !n.IPv4Address.IsUnspecified() { + f_4 := func() error { + a := dns.A{ + Hdr: dns.RR_Header{ + Name: dn, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 7200, + }, + A: n.IPv4Address, } - cmds = append(cmds, cmd.NewFunc("nsupate4", f_4)) + + if err := c.doDnsReverse(a.A.String(), dn, dns.TypeA); err != nil { + return err + } + if err := c.killDnsForward(dn, dns.TypeA); err != nil { + return err + } + if err := c.doDnsForward(a.String()); err != nil { + return err + } + return nil } + cmds = append(cmds, cmd.NewFunc("nsupate4", f_4)) } } return c.newCommandSet("NSUPDATE", cmds)