diff --git a/cmd/serve.go b/cmd/serve.go index 952c426..aff7c7c 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,16 +5,18 @@ package cmd import ( "errors" "fmt" - "heckel.io/ntfy/log" "io/fs" "math" "net" + "net/netip" "os" "os/signal" "strings" "syscall" "time" + "heckel.io/ntfy/log" + "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/server" @@ -208,16 +210,14 @@ func execServe(c *cli.Context) error { } // Resolve hosts - visitorRequestLimitExemptIPs := make([]string, 0) + visitorRequestLimitExemptIPs := make([]netip.Prefix, 0) for _, host := range visitorRequestLimitExemptHosts { - ips, err := net.LookupIP(host) + ips, err := parseIPHostPrefix(host) if err != nil { log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) continue } - for _, ip := range ips { - visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) - } + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...) } // Run server @@ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) { } } +func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { + // Try parsing as prefix, e.g. 10.0.1.0/24 + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) + return prefixes, nil + } + // Not a prefix, parse as host or IP (LookupHost passes through an IP as is) + ips, err := net.LookupHost(host) + if err != nil { + return nil, err + } + for _, ipStr := range ips { + ip, err := netip.ParseAddr(ipStr) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error()) + } + prefixes = append(prefixes, prefix.Masked()) + } + } + return +} + func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/cmd/serve_test.go b/cmd/serve_test.go index 3b2b9b0..774166c 100644 --- a/cmd/serve_test.go +++ b/cmd/serve_test.go @@ -2,17 +2,19 @@ package cmd import ( "fmt" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/client" - "heckel.io/ntfy/test" - "heckel.io/ntfy/util" "math/rand" "os" "os/exec" "path/filepath" "testing" "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/client" + "heckel.io/ntfy/test" + "heckel.io/ntfy/util" ) func init() { @@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) { require.Equal(t, "mytopic", m.Topic) } +func TestIP_Host_Parsing(t *testing.T) { + cases := map[string]string{ + "1.1.1.1": "1.1.1.1/32", + "fd00::1234": "fd00::1234/128", + "192.168.0.3/24": "192.168.0.0/24", + "10.1.2.3/8": "10.0.0.0/8", + "201:be93::4a6/21": "201:b800::/21", + } + for q, expectedAnswer := range cases { + ips, err := parseIPHostPrefix(q) + require.Nil(t, err) + assert.Equal(t, 1, len(ips)) + assert.Equal(t, expectedAnswer, ips[0].String()) + } +} + func newEmptyFile(t *testing.T) string { filename := filepath.Join(t.TempDir(), "empty") require.Nil(t, os.WriteFile(filename, []byte{}, 0600)) diff --git a/server/config.go b/server/config.go index e117da8..d8fd429 100644 --- a/server/config.go +++ b/server/config.go @@ -2,6 +2,7 @@ package server import ( "io/fs" + "net/netip" "time" ) @@ -92,7 +93,7 @@ type Config struct { VisitorAttachmentDailyBandwidthLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration - VisitorRequestExemptIPAddrs []string + VisitorRequestExemptIPAddrs []netip.Prefix VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration BehindProxy bool @@ -135,7 +136,7 @@ func NewConfig() *Config { VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, - VisitorRequestExemptIPAddrs: make([]string, 0), + VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, BehindProxy: false, diff --git a/server/message_cache.go b/server/message_cache.go index a2f49e7..f443339 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "strings" + "time" + _ "github.com/mattn/go-sqlite3" // SQLite driver "heckel.io/ntfy/log" "heckel.io/ntfy/util" - "strings" - "time" ) var ( @@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentSize, attachmentExpires, attachmentURL, - m.Sender, + m.Sender.String(), m.Encoding, published, ) @@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) { return nil, err } } + senderIP, err := netip.ParseAddr(sender) + if err != nil { + senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0 + } + var att *attachment if attachmentName != "" && attachmentURL != "" { att = &attachment{ @@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: sender, + Sender: senderIP, // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 23c080d..c72debc 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -3,11 +3,17 @@ package server import ( "database/sql" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "net/netip" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + exampleIP1234 = netip.MustParseAddr("1.2.3.4") ) func TestSqliteCache_Messages(t *testing.T) { @@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = "1.2.3.4" + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = "1.2.3.4" + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = "1.2.3.4" + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", @@ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, int64(5000), messages[0].Attachment.Size) require.Equal(t, expires1, messages[0].Attachment.Expires) require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[0].Sender) + require.Equal(t, "1.2.3.4", messages[0].Sender.String()) require.Equal(t, "sending you a car", messages[1].Message) require.Equal(t, "car.jpg", messages[1].Attachment.Name) @@ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, int64(10000), messages[1].Attachment.Size) require.Equal(t, expires2, messages[1].Attachment.Expires) require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[1].Sender) + require.Equal(t, "1.2.3.4", messages[1].Sender.String()) size, err := c.AttachmentBytesUsed("1.2.3.4") require.Nil(t, err) diff --git a/server/server.go b/server/server.go index 276e56f..ef09100 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "os" "path" @@ -42,7 +43,7 @@ type Server struct { smtpServerBackend *smtpBackend smtpSender mailer topics map[string]*topic - visitors map[string]*visitor + visitors map[netip.Addr]*visitor firebaseClient *firebaseClient messages int64 auth auth.Auther @@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) { smtpSender: mailer, topics: topics, auth: auther, - visitors: make(map[string]*visitor), + visitors: make(map[netip.Addr]*visitor), }, nil } @@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0 for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { remoteAddr := r.RemoteAddr - ip, _, err := net.SplitHostPort(remoteAddr) + addrPort, err := netip.ParseAddrPort(remoteAddr) + ip := addrPort.Addr() if err != nil { - ip = remoteAddr // This should not happen in real life; only in tests. + // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified + ip, err = netip.ParseAddr(remoteAddr) + if err != nil { + ip = netip.IPv4Unspecified() + log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err) + } } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // only the right-most address can be trusted (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") - ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) + realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) + if err != nil { + log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error()) + // Fall back to regular remote address if X-Forwarded-For is damaged + } else { + ip = realIP + } } return s.visitorFromIP(ip) } -func (s *Server) visitorFromIP(ip string) *visitor { +func (s *Server) visitorFromIP(ip netip.Addr) *visitor { s.mu.Lock() defer s.mu.Unlock() v, exists := s.visitors[ip] diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 3e034c0..36fd8b5 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -3,13 +3,15 @@ package server import ( "encoding/json" "errors" - "firebase.google.com/go/v4/messaging" "fmt" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" + "net/netip" "strings" "sync" "testing" + + "firebase.google.com/go/v4/messaging" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/auth" ) type testAuther struct { @@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) { sender := &testFirebaseSender{allowed: 2} client := newFirebaseClient(sender, &testAuther{}) - visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") + visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4")) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Equal(t, 1, len(sender.Messages())) diff --git a/server/server_matrix_test.go b/server/server_matrix_test.go index b2f9b1d..4b5a66c 100644 --- a/server/server_matrix_test.go +++ b/server/server_matrix_test.go @@ -1,11 +1,13 @@ package server import ( - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "net/netip" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { @@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) - v := newVisitor(newTestConfig(t), nil, "1.2.3.4") + v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4")) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) diff --git a/server/server_test.go b/server/server_test.go index ea3495d..a9c6ba6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,18 +6,20 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" "io" "log" "math/rand" "net/http" "net/http/httptest" + "net/netip" "path/filepath" "strings" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "heckel.io/ntfy/auth" "heckel.io/ntfy/util" @@ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) { messages = toMessages(t, response.Body.String()) require.Equal(t, 1, len(messages)) require.Equal(t, "a message", messages[0].Message) - require.Equal(t, "", messages[0].Sender) // Never return the sender! + require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender! messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true) require.Nil(t, err) require.Equal(t, 1, len(messages)) require.Equal(t, "a message", messages[0].Message) - require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though! + require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though! } func TestServer_PublishAtWithCacheError(t *testing.T) { @@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { c := newTestConfig(t) - c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() + c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() s := newTestServer(t, c) for i := 0; i < 65; i++ { // > 60 response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) @@ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) { require.Equal(t, int64(5000), msg.Attachment.Size) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") - require.Equal(t, "", msg.Sender) // Should never be returned + require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) // GET @@ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { require.Equal(t, int64(21), msg.Attachment.Size) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") - require.Equal(t, "", msg.Sender) // Should never be returned + require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") @@ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) { require.Equal(t, "", msg.Attachment.Type) require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Expires) - require.Equal(t, "", msg.Sender) + require.Equal(t, netip.Addr{}, msg.Sender) // Slightly unrelated cross-test: make sure we don't add an owner for external attachments size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") @@ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) { require.Equal(t, "", msg.Attachment.Type) require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Expires) - require.Equal(t, "", msg.Sender) + require.Equal(t, netip.Addr{}, msg.Sender) } func TestServer_PublishAttachmentBadURL(t *testing.T) { @@ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty! v := s.visitor(r) - require.Equal(t, "8.9.10.11", v.ip) + require.Equal(t, "8.9.10.11", v.ip.String()) } func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { @@ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.1.1.1") v := s.visitor(r) - require.Equal(t, "1.1.1.1", v.ip) + require.Equal(t, "1.1.1.1", v.ip.String()) } func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { @@ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") v := s.visitor(r) - require.Equal(t, "234.5.2.1", v.ip) + require.Equal(t, "234.5.2.1", v.ip.String()) } func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { diff --git a/server/smtp_sender.go b/server/smtp_sender.go index ecefd9c..7d6b751 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { if err != nil { return err } - message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) + message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m) if err != nil { return err } diff --git a/server/types.go b/server/types.go index b217b9d..ce57c9b 100644 --- a/server/types.go +++ b/server/types.go @@ -1,9 +1,11 @@ package server import ( - "heckel.io/ntfy/util" "net/http" + "net/netip" "time" + + "heckel.io/ntfy/util" ) // List of possible events @@ -33,7 +35,7 @@ type message struct { Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` - Sender string `json:"-"` // IP address of uploader, used for rate limiting + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes } diff --git a/server/visitor.go b/server/visitor.go index 5a8e186..cd120c4 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,10 +2,12 @@ package server import ( "errors" - "golang.org/x/time/rate" - "heckel.io/ntfy/util" + "net/netip" "sync" "time" + + "golang.org/x/time/rate" + "heckel.io/ntfy/util" ) const ( @@ -23,7 +25,7 @@ var ( type visitor struct { config *Config messageCache *messageCache - ip string + ip netip.Addr requests *rate.Limiter emails *rate.Limiter subscriptions util.Limiter @@ -40,7 +42,7 @@ type visitorStats struct { VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` } -func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { +func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor { return &visitor{ config: conf, messageCache: messageCache, @@ -115,7 +117,7 @@ func (v *visitor) Stale() bool { } func (v *visitor) Stats() (*visitorStats, error) { - attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) + attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) if err != nil { return nil, err } diff --git a/util/util.go b/util/util.go index 0507918..86566ed 100644 --- a/util/util.go +++ b/util/util.go @@ -5,16 +5,18 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gabriel-vasile/mimetype" - "golang.org/x/term" "io" "math/rand" + "net/netip" "os" "regexp" "strconv" "strings" "sync" "time" + + "github.com/gabriel-vasile/mimetype" + "golang.org/x/term" ) const ( @@ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } +// ContainsIP returns true if any one of the of prefixes contains the ip. +func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool { + for _, s := range haystack { + if s.Contains(needle) { + return true + } + } + return false +} + // ContainsAll returns true if all needles are contained in haystack func ContainsAll[T comparable](haystack []T, needles []T) bool { matches := 0 diff --git a/util/util_test.go b/util/util_test.go index 807d599..a435c06 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -1,10 +1,12 @@ package util import ( - "github.com/stretchr/testify/require" + "net/netip" "os" "path/filepath" "testing" + + "github.com/stretchr/testify/require" ) func TestRandomString(t *testing.T) { @@ -42,6 +44,13 @@ func TestContains(t *testing.T) { require.False(t, Contains(s, 3)) } +func TestContainsIP(t *testing.T) { + require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1"))) + require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876"))) + require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1"))) + require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1"))) +} + func TestSplitNoEmpty(t *testing.T) { require.Equal(t, []string{}, SplitNoEmpty("", ",")) require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))