Allow multiple values for ip-whitelist

Allow the value of ip-whitelist to consist of multiple space-separated
addresses or CIDRs.

Updates #290
This commit is contained in:
Cameron Moore 2019-01-02 16:50:23 -06:00
parent 753734428f
commit f056f94305
2 changed files with 44 additions and 25 deletions

View file

@ -168,41 +168,36 @@ func CheckScalrSignature(headers map[string]interface{}, body []byte, signingKey
func CheckIPWhitelist(remoteAddr string, ipRange string) (bool, error) {
// Extract IP address from remote address.
ip := remoteAddr
// IPv6 addresses will likely be surrounded by [].
ip := strings.Trim(remoteAddr, " []")
if strings.LastIndex(remoteAddr, ":") != -1 {
ip = remoteAddr[0:strings.LastIndex(remoteAddr, ":")]
if i := strings.LastIndex(ip, ":"); i != -1 {
ip = ip[:i]
}
ip = strings.TrimSpace(ip)
// IPv6 addresses will likely be surrounded by [], so don't forget to remove those.
if strings.HasPrefix(ip, "[") && strings.HasSuffix(ip, "]") {
ip = ip[1 : len(ip)-1]
}
parsedIP := net.ParseIP(strings.TrimSpace(ip))
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false, fmt.Errorf("invalid IP address found in remote address '%s'", remoteAddr)
}
// Extract IP range in CIDR form. If a single IP address is provided, turn it into CIDR form.
for _, r := range strings.Fields(ipRange) {
// Extract IP range in CIDR form. If a single IP address is provided, turn it into CIDR form.
ipRange = strings.TrimSpace(ipRange)
if !strings.Contains(r, "/") {
r = r + "/32"
}
if !strings.Contains(ipRange, "/") {
ipRange = ipRange + "/32"
_, cidr, err := net.ParseCIDR(r)
if err != nil {
return false, err
}
if cidr.Contains(parsedIP) {
return true, nil
}
}
_, cidr, err := net.ParseCIDR(ipRange)
if err != nil {
return false, err
}
return cidr.Contains(parsedIP), nil
return false, nil
}
// ReplaceParameter replaces parameter value with the passed value in the passed map

View file

@ -108,6 +108,30 @@ func TestCheckScalrSignature(t *testing.T) {
}
}
var checkIPWhitelistTests = []struct {
addr string
ipRange string
expect bool
ok bool
}{
{"[ 10.0.0.1:1234 ] ", " 10.0.0.1 ", true, true},
{"[ 10.0.0.1:1234 ] ", " 10.0.0.0 ", false, true},
{"[ 10.0.0.1:1234 ] ", " 10.0.0.1 10.0.0.1 ", true, true},
{"[ 10.0.0.1:1234 ] ", " 10.0.0.0/31 ", true, true},
{" [2001:db8:1:2::1:1234] ", " 2001:db8:1::/48 ", true, true},
{" [2001:db8:1:2::1:1234] ", " 2001:db8:1::/48 2001:db8:1::/64", true, true},
{" [2001:db8:1:2::1:1234] ", " 2001:db8:1::/64 ", false, true},
}
func TestCheckIPWhitelist(t *testing.T) {
for _, tt := range checkIPWhitelistTests {
result, err := CheckIPWhitelist(tt.addr, tt.ipRange)
if (err == nil) != tt.ok || result != tt.expect {
t.Errorf("ip whitelist test failed {%q, %q}:\nwant {expect:%#v, ok:%#v},\ngot {result:%#v, ok:%#v}", tt.addr, tt.ipRange, tt.expect, tt.ok, result, err)
}
}
}
var extractParameterTests = []struct {
s string
params interface{}
@ -129,7 +153,7 @@ var extractParameterTests = []struct {
{"a.501.b", map[string]interface{}{"a": []interface{}{map[string]interface{}{"b": "y"}, map[string]interface{}{"b": "z"}}}, "", false}, // non-existent slice index
{"a.502.b", map[string]interface{}{"a": []interface{}{}}, "", false}, // non-existent slice index
{"a.b.503", map[string]interface{}{"a": map[string]interface{}{"b": []interface{}{"x", "y", "z"}}}, "", false}, // trailing, non-existent slice index
{"a.b", interface{}("a"), "", false}, // non-map, non-slice input
{"a.b", interface{}("a"), "", false}, // non-map, non-slice input
}
func TestExtractParameter(t *testing.T) {