diff --git a/iptables/iptables.go b/iptables/iptables.go index c58db9c..90ccbef 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -114,7 +114,7 @@ func RemoveExistingChain(name string, table Table) error { } // Add forwarding rule to 'filter' table and corresponding nat rule to 'nat' table -func (c *Chain) Forward(action Action, ip net.IP, port int, proto, dest_addr string, dest_port int) error { +func (c *Chain) Forward(action Action, ip net.IP, port int, proto, destAddr string, destPort int) error { daddr := ip.String() if ip.IsUnspecified() { // iptables interprets "0.0.0.0" as "0.0.0.0/32", whereas we @@ -128,7 +128,7 @@ func (c *Chain) Forward(action Action, ip net.IP, port int, proto, dest_addr str "--dport", strconv.Itoa(port), "!", "-i", c.Bridge, "-j", "DNAT", - "--to-destination", net.JoinHostPort(dest_addr, strconv.Itoa(dest_port))); err != nil { + "--to-destination", net.JoinHostPort(destAddr, strconv.Itoa(destPort))); err != nil { return err } else if len(output) != 0 { return &ChainError{Chain: "FORWARD", Output: output} @@ -138,14 +138,25 @@ func (c *Chain) Forward(action Action, ip net.IP, port int, proto, dest_addr str "!", "-i", c.Bridge, "-o", c.Bridge, "-p", proto, - "-d", dest_addr, - "--dport", strconv.Itoa(dest_port), + "-d", destAddr, + "--dport", strconv.Itoa(destPort), "-j", "ACCEPT"); err != nil { return err } else if len(output) != 0 { return &ChainError{Chain: "FORWARD", Output: output} } + if output, err := Raw("-t", string(Nat), string(action), "POSTROUTING", + "-p", proto, + "-s", destAddr, + "-d", destAddr, + "--dport", strconv.Itoa(destPort), + "-j", "MASQUERADE"); err != nil { + return err + } else if len(output) != 0 { + return &ChainError{Chain: "FORWARD", Output: output} + } + return nil } @@ -156,8 +167,8 @@ func (c *Chain) Link(action Action, ip1, ip2 net.IP, port int, proto string) err "-i", c.Bridge, "-o", c.Bridge, "-p", proto, "-s", ip1.String(), - "--dport", strconv.Itoa(port), "-d", ip2.String(), + "--dport", strconv.Itoa(port), "-j", "ACCEPT"); err != nil { return err } else if len(output) != 0 { @@ -167,8 +178,8 @@ func (c *Chain) Link(action Action, ip1, ip2 net.IP, port int, proto string) err "-i", c.Bridge, "-o", c.Bridge, "-p", proto, "-s", ip2.String(), - "--dport", strconv.Itoa(port), "-d", ip1.String(), + "--sport", strconv.Itoa(port), "-j", "ACCEPT"); err != nil { return err } else if len(output) != 0 { @@ -206,18 +217,17 @@ func (c *Chain) Output(action Action, args ...string) error { } func (c *Chain) Remove() error { + // Ignore errors - This could mean the chains were never set up if c.Table == Nat { - // Ignore errors - This could mean the chains were never set up c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL") c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", "127.0.0.0/8") c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL") // Created in versions <= 0.1.6 c.Prerouting(Delete) c.Output(Delete) - - Raw("-t", string(Nat), "-F", c.Name) - Raw("-t", string(Nat), "-X", c.Name) } + Raw("-t", string(c.Table), "-F", c.Name) + Raw("-t", string(c.Table), "-X", c.Name) return nil } diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go new file mode 100644 index 0000000..8aaf429 --- /dev/null +++ b/iptables/iptables_test.go @@ -0,0 +1,204 @@ +package iptables + +import ( + "net" + "os/exec" + "strconv" + "strings" + "testing" +) + +const chainName = "DOCKERTEST" + +var natChain *Chain +var filterChain *Chain + +func TestNewChain(t *testing.T) { + var err error + + natChain, err = NewChain(chainName, "lo", Nat) + if err != nil { + t.Fatal(err) + } + + filterChain, err = NewChain(chainName, "lo", Filter) + if err != nil { + t.Fatal(err) + } +} + +func TestForward(t *testing.T) { + ip := net.ParseIP("192.168.1.1") + port := 1234 + dstAddr := "172.17.0.1" + dstPort := 4321 + proto := "tcp" + + err := natChain.Forward(Insert, ip, port, proto, dstAddr, dstPort) + if err != nil { + t.Fatal(err) + } + + dnatRule := []string{natChain.Name, + "-t", string(natChain.Table), + "!", "-i", filterChain.Bridge, + "-d", ip.String(), + "-p", proto, + "--dport", strconv.Itoa(port), + "-j", "DNAT", + "--to-destination", dstAddr + ":" + strconv.Itoa(dstPort), + } + + if !Exists(dnatRule...) { + t.Fatalf("DNAT rule does not exist") + } + + filterRule := []string{filterChain.Name, + "-t", string(filterChain.Table), + "!", "-i", filterChain.Bridge, + "-o", filterChain.Bridge, + "-d", dstAddr, + "-p", proto, + "--dport", strconv.Itoa(dstPort), + "-j", "ACCEPT", + } + + if !Exists(filterRule...) { + t.Fatalf("filter rule does not exist") + } + + masqRule := []string{"POSTROUTING", + "-t", string(natChain.Table), + "-d", dstAddr, + "-s", dstAddr, + "-p", proto, + "--dport", strconv.Itoa(dstPort), + "-j", "MASQUERADE", + } + + if !Exists(masqRule...) { + t.Fatalf("MASQUERADE rule does not exist") + } +} + +func TestLink(t *testing.T) { + var err error + + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.2") + port := 1234 + proto := "tcp" + + err = filterChain.Link(Append, ip1, ip2, port, proto) + if err != nil { + t.Fatal(err) + } + + rule1 := []string{filterChain.Name, + "-t", string(filterChain.Table), + "-i", filterChain.Bridge, + "-o", filterChain.Bridge, + "-p", proto, + "-s", ip1.String(), + "-d", ip2.String(), + "--dport", strconv.Itoa(port), + "-j", "ACCEPT"} + + if !Exists(rule1...) { + t.Fatalf("rule1 does not exist") + } + + rule2 := []string{filterChain.Name, + "-t", string(filterChain.Table), + "-i", filterChain.Bridge, + "-o", filterChain.Bridge, + "-p", proto, + "-s", ip2.String(), + "-d", ip1.String(), + "--sport", strconv.Itoa(port), + "-j", "ACCEPT"} + + if !Exists(rule2...) { + t.Fatalf("rule2 does not exist") + } +} + +func TestPrerouting(t *testing.T) { + args := []string{ + "-i", "lo", + "-d", "192.168.1.1"} + + err := natChain.Prerouting(Insert, args...) + if err != nil { + t.Fatal(err) + } + + rule := []string{"PREROUTING", + "-t", string(Nat), + "-j", natChain.Name} + + rule = append(rule, args...) + + if !Exists(rule...) { + t.Fatalf("rule does not exist") + } + + delRule := append([]string{"-D"}, rule...) + if _, err = Raw(delRule...); err != nil { + t.Fatal(err) + } +} + +func TestOutput(t *testing.T) { + args := []string{ + "-o", "lo", + "-d", "192.168.1.1"} + + err := natChain.Output(Insert, args...) + if err != nil { + t.Fatal(err) + } + + rule := []string{"OUTPUT", + "-t", string(natChain.Table), + "-j", natChain.Name} + + rule = append(rule, args...) + + if !Exists(rule...) { + t.Fatalf("rule does not exist") + } + + delRule := append([]string{"-D"}, rule...) + if _, err = Raw(delRule...); err != nil { + t.Fatal(err) + } +} + +func TestCleanup(t *testing.T) { + var err error + var rules []byte + + // Cleanup filter/FORWARD first otherwise output of iptables-save is dirty + link := []string{"-t", string(filterChain.Table), + string(Delete), "FORWARD", + "-o", filterChain.Bridge, + "-j", filterChain.Name} + if _, err = Raw(link...); err != nil { + t.Fatal(err) + } + filterChain.Remove() + + err = RemoveExistingChain(chainName, Nat) + if err != nil { + t.Fatal(err) + } + + rules, err = exec.Command("iptables-save").Output() + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(rules), chainName) { + t.Fatalf("Removing chain failed. %s found in iptables-save", chainName) + } +}