From c2382d29a1c5e28b4682047f1eeadc78c644d07f Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Wed, 5 Oct 2022 15:42:07 -0500 Subject: [PATCH 1/8] refactor visitor IPs and allow exempting IP Ranges Use netip.Addr instead of storing addresses as strings. This requires conversions at the database level and in tests, but is more memory efficient otherwise, and facilitates the following. Parse rate limit exemptions as netip.Prefix. This allows storing IP ranges in the exemption list. Regular IP addresses (entered explicitly or resolved from hostnames) are IPV4/32, denoting a range of one address. --- cmd/serve.go | 37 ++++++++++++++++++++++++++++++---- server/config.go | 5 +++-- server/message_cache.go | 10 +++++---- server/message_cache_test.go | 12 ++++++----- server/server.go | 27 ++++++++++++++++--------- server/server_firebase_test.go | 10 +++++---- server/server_matrix_test.go | 6 ++++-- server/server_test.go | 6 ++++-- server/smtp_sender.go | 2 +- server/types.go | 6 ++++-- server/visitor.go | 12 ++++++----- util/util.go | 15 ++++++++++++-- 12 files changed, 106 insertions(+), 42 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 952c426..3cc0114 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,16 +5,18 @@ package cmd import ( "errors" "fmt" - "heckel.io/ntfy/log" "io/fs" "math" "net" + "net/netip" "os" "os/signal" "strings" "syscall" "time" + "heckel.io/ntfy/log" + "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "heckel.io/ntfy/server" @@ -208,15 +210,15 @@ func execServe(c *cli.Context) error { } // Resolve hosts - visitorRequestLimitExemptIPs := make([]string, 0) + visitorRequestLimitExemptIPs := make([]netip.Prefix, 0) for _, host := range visitorRequestLimitExemptHosts { - ips, err := net.LookupIP(host) + ips, err := parseIPHostPrefix(host) if err != nil { log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) continue } for _, ip := range ips { - visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip) } } @@ -303,6 +305,33 @@ func sigHandlerConfigReload(config string) { } } +func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { + //try parsing as prefix + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter + return prefixes, nil // success + } + + // not a prefix, parse as host or IP + // LookupHost forwards through if it's an IP + ips, err := net.LookupHost(host) + if err == nil { + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) + } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + } + } + } + return +} + + func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/server/config.go b/server/config.go index e117da8..d8fd429 100644 --- a/server/config.go +++ b/server/config.go @@ -2,6 +2,7 @@ package server import ( "io/fs" + "net/netip" "time" ) @@ -92,7 +93,7 @@ type Config struct { VisitorAttachmentDailyBandwidthLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration - VisitorRequestExemptIPAddrs []string + VisitorRequestExemptIPAddrs []netip.Prefix VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration BehindProxy bool @@ -135,7 +136,7 @@ func NewConfig() *Config { VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, - VisitorRequestExemptIPAddrs: make([]string, 0), + VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, BehindProxy: false, diff --git a/server/message_cache.go b/server/message_cache.go index a2f49e7..4845a91 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "strings" + "time" + _ "github.com/mattn/go-sqlite3" // SQLite driver "heckel.io/ntfy/log" "heckel.io/ntfy/util" - "strings" - "time" ) var ( @@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentSize, attachmentExpires, attachmentURL, - m.Sender, + m.Sender.String(), m.Encoding, published, ) @@ -477,7 +479,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: sender, + Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 23c080d..ea9580a 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -3,11 +3,13 @@ package server import ( "database/sql" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "net/netip" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSqliteCache_Messages(t *testing.T) { @@ -281,7 +283,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -294,7 +296,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -307,7 +309,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = "1.2.3.4" + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", diff --git a/server/server.go b/server/server.go index 276e56f..0b9cb21 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "os" "path" @@ -42,7 +43,7 @@ type Server struct { smtpServerBackend *smtpBackend smtpSender mailer topics map[string]*topic - visitors map[string]*visitor + visitors map[netip.Addr]*visitor firebaseClient *firebaseClient messages int64 auth auth.Auther @@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) { smtpSender: mailer, topics: topics, auth: auther, - visitors: make(map[string]*visitor), + visitors: make(map[netip.Addr]*visitor), }, nil } @@ -642,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() - m.Sender = v.ip // Important for rate limiting } + m.Sender = v.ip // Important for rate limiting actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { m.Actions, err = parseActions(actionsStr) @@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1436,21 +1437,29 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { remoteAddr := r.RemoteAddr - ip, _, err := net.SplitHostPort(remoteAddr) + ipport, err := netip.ParseAddrPort(remoteAddr) + ip := ipport.Addr() if err != nil { - ip = remoteAddr // This should not happen in real life; only in tests. + ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error. } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, // only the right-most address can be trusted (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") - ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) + myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) + if err != nil { + log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err) + // fall back to regular remote address if x forwarded for is damaged + } else { + ip = myip + } + } return s.visitorFromIP(ip) } -func (s *Server) visitorFromIP(ip string) *visitor { +func (s *Server) visitorFromIP(ip netip.Addr) *visitor { s.mu.Lock() defer s.mu.Unlock() v, exists := s.visitors[ip] diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 3e034c0..36fd8b5 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -3,13 +3,15 @@ package server import ( "encoding/json" "errors" - "firebase.google.com/go/v4/messaging" "fmt" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" + "net/netip" "strings" "sync" "testing" + + "firebase.google.com/go/v4/messaging" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/auth" ) type testAuther struct { @@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { func TestToFirebaseSender_Abuse(t *testing.T) { sender := &testFirebaseSender{allowed: 2} client := newFirebaseClient(sender, &testAuther{}) - visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") + visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4")) require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) require.Equal(t, 1, len(sender.Messages())) diff --git a/server/server_matrix_test.go b/server/server_matrix_test.go index b2f9b1d..4b5a66c 100644 --- a/server/server_matrix_test.go +++ b/server/server_matrix_test.go @@ -1,11 +1,13 @@ package server import ( - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "net/netip" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { @@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) { func TestMatrix_WriteMatrixError(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) - v := newVisitor(newTestConfig(t), nil, "1.2.3.4") + v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4")) require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) require.Equal(t, 200, w.Result().StatusCode) require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) diff --git a/server/server_test.go b/server/server_test.go index ea3495d..5a3dcc8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,18 +6,20 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" "io" "log" "math/rand" "net/http" "net/http/httptest" + "net/netip" "path/filepath" "strings" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "heckel.io/ntfy/auth" "heckel.io/ntfy/util" @@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { c := newTestConfig(t) - c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() + c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() s := newTestServer(t, c) for i := 0; i < 65; i++ { // > 60 response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) diff --git a/server/smtp_sender.go b/server/smtp_sender.go index ecefd9c..7d6b751 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { if err != nil { return err } - message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) + message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m) if err != nil { return err } diff --git a/server/types.go b/server/types.go index b217b9d..ce57c9b 100644 --- a/server/types.go +++ b/server/types.go @@ -1,9 +1,11 @@ package server import ( - "heckel.io/ntfy/util" "net/http" + "net/netip" "time" + + "heckel.io/ntfy/util" ) // List of possible events @@ -33,7 +35,7 @@ type message struct { Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` - Sender string `json:"-"` // IP address of uploader, used for rate limiting + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes } diff --git a/server/visitor.go b/server/visitor.go index 5a8e186..cd120c4 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,10 +2,12 @@ package server import ( "errors" - "golang.org/x/time/rate" - "heckel.io/ntfy/util" + "net/netip" "sync" "time" + + "golang.org/x/time/rate" + "heckel.io/ntfy/util" ) const ( @@ -23,7 +25,7 @@ var ( type visitor struct { config *Config messageCache *messageCache - ip string + ip netip.Addr requests *rate.Limiter emails *rate.Limiter subscriptions util.Limiter @@ -40,7 +42,7 @@ type visitorStats struct { VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` } -func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { +func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor { return &visitor{ config: conf, messageCache: messageCache, @@ -115,7 +117,7 @@ func (v *visitor) Stale() bool { } func (v *visitor) Stats() (*visitorStats, error) { - attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) + attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) if err != nil { return nil, err } diff --git a/util/util.go b/util/util.go index 0507918..de4b908 100644 --- a/util/util.go +++ b/util/util.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gabriel-vasile/mimetype" - "golang.org/x/term" "io" "math/rand" "os" @@ -15,6 +13,9 @@ import ( "strings" "sync" "time" + + "github.com/gabriel-vasile/mimetype" + "golang.org/x/term" ) const ( @@ -45,6 +46,16 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } +// ContainsContains returns true if any element of haystack .Contains(needle). +func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool { + for _, s := range haystack { + if s.Contains(needle) { + return true + } + } + return false +} + // ContainsAll returns true if all needles are contained in haystack func ContainsAll[T comparable](haystack []T, needles []T) bool { matches := 0 From de2ca33700442c134724a768b94f710555feeaa5 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 16:16:20 -0500 Subject: [PATCH 2/8] recommended fixes [1 of 2] --- cmd/serve.go | 45 ++++++++++++++++++++--------------------- server/message_cache.go | 7 ++++++- server/server.go | 10 ++++----- util/util.go | 5 +++-- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 3cc0114..be772e6 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -306,32 +306,31 @@ func sigHandlerConfigReload(config string) { } func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { - //try parsing as prefix - prefix, err := netip.ParsePrefix(host) - if err == nil { - prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter - return prefixes, nil // success - } + //try parsing as prefix + prefix, err := netip.ParsePrefix(host) + if err == nil { + prefixes = append(prefixes, prefix.Masked()) // Masked returns the prefix in its canonical form, the same for every ip in the range. This exists for ease of debugging. For example, 10.1.2.3/16 is 10.1.0.0/16. + return prefixes, nil // success + } - // not a prefix, parse as host or IP - // LookupHost forwards through if it's an IP - ips, err := net.LookupHost(host) - if err == nil { - for _, i := range ips { - ip, err := netip.ParseAddr(i) - if err == nil { - prefix, err := ip.Prefix(ip.BitLen()) - if err != nil { - return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) - } - prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip - } - } - } - return + // not a prefix, parse as host or IP + // LookupHost forwards through if it's an IP + ips, err := net.LookupHost(host) + if err == nil { + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) + } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + } + } + } + return } - func reloadLogLevel(inputSource altsrc.InputSourceContext) { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/server/message_cache.go b/server/message_cache.go index 4845a91..f443339 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -456,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) { return nil, err } } + senderIP, err := netip.ParseAddr(sender) + if err != nil { + senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0 + } + var att *attachment if attachmentName != "" && attachmentURL != "" { att = &attachment{ @@ -479,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Icon: icon, Actions: actions, Attachment: att, - Sender: netip.MustParseAddr(sender), // Must parse assuming database must be correct + Sender: senderIP, // Must parse assuming database must be correct Encoding: encoding, }) } diff --git a/server/server.go b/server/server.go index 0b9cb21..6b80173 100644 --- a/server/server.go +++ b/server/server.go @@ -643,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, errHTTPBadRequestDelayTooLarge } m.Time = delay.Unix() + m.Sender = v.ip // Important for rate limiting } - m.Sender = v.ip // Important for rate limiting actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { m.Actions, err = parseActions(actionsStr) @@ -1220,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor + v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0 for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): @@ -1287,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) { + if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests @@ -1449,8 +1449,8 @@ func (s *Server) visitor(r *http.Request) *visitor { ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) if err != nil { - log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err) - // fall back to regular remote address if x forwarded for is damaged + log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error()) + // fall back to regular remote address if X-Forwarded-For is damaged } else { ip = myip } diff --git a/util/util.go b/util/util.go index de4b908..86566ed 100644 --- a/util/util.go +++ b/util/util.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "math/rand" + "net/netip" "os" "regexp" "strconv" @@ -46,8 +47,8 @@ func Contains[T comparable](haystack []T, needle T) bool { return false } -// ContainsContains returns true if any element of haystack .Contains(needle). -func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool { +// ContainsIP returns true if any one of the of prefixes contains the ip. +func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool { for _, s := range haystack { if s.Contains(needle) { return true From 511d3f6aafb2c71992751c2204f863929fc64e39 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 16:24:11 -0500 Subject: [PATCH 3/8] recommended fixes [2 of 2] --- cmd/serve.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index be772e6..36dbf31 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -314,18 +314,19 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { } // not a prefix, parse as host or IP - // LookupHost forwards through if it's an IP + // LookupHost passes through an IP as is ips, err := net.LookupHost(host) - if err == nil { - for _, i := range ips { - ip, err := netip.ParseAddr(i) - if err == nil { - prefix, err := ip.Prefix(ip.BitLen()) - if err != nil { - return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error())) - } - prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + if err != nil { + return nil, err + } + for _, i := range ips { + ip, err := netip.ParseAddr(i) + if err == nil { + prefix, err := ip.Prefix(ip.BitLen()) + if err != nil { + return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error()) } + prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip } } return From 3b2929467914b910507ae3ae963909c387142ce3 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 20:27:22 -0500 Subject: [PATCH 4/8] minor modification to tests involving ips --- server/message_cache_test.go | 14 +++++++++----- server/server.go | 7 ++++++- server/server_test.go | 18 +++++++++--------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/server/message_cache_test.go b/server/message_cache_test.go index ea9580a..c72debc 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -12,6 +12,10 @@ import ( "github.com/stretchr/testify/require" ) +var ( + exampleIP1234 = netip.MustParseAddr("1.2.3.4") +) + func TestSqliteCache_Messages(t *testing.T) { testCacheMessages(t, newSqliteTestCache(t)) } @@ -283,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = netip.MustParseAddr("1.2.3.4") + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -296,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = netip.MustParseAddr("1.2.3.4") + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -309,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = netip.MustParseAddr("1.2.3.4") + m.Sender = exampleIP1234 m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", @@ -329,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, int64(5000), messages[0].Attachment.Size) require.Equal(t, expires1, messages[0].Attachment.Expires) require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[0].Sender) + require.Equal(t, "1.2.3.4", messages[0].Sender.String()) require.Equal(t, "sending you a car", messages[1].Message) require.Equal(t, "car.jpg", messages[1].Attachment.Name) @@ -337,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, int64(10000), messages[1].Attachment.Size) require.Equal(t, expires2, messages[1].Attachment.Expires) require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[1].Sender) + require.Equal(t, "1.2.3.4", messages[1].Sender.String()) size, err := c.AttachmentBytesUsed("1.2.3.4") require.Nil(t, err) diff --git a/server/server.go b/server/server.go index 6b80173..a4ef193 100644 --- a/server/server.go +++ b/server/server.go @@ -1440,7 +1440,12 @@ func (s *Server) visitor(r *http.Request) *visitor { ipport, err := netip.ParseAddrPort(remoteAddr) ip := ipport.Addr() if err != nil { - ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error. + // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified + ip, err = netip.ParseAddr(remoteAddr) + if err != nil { + ip = netip.IPv4Unspecified() + log.Error("Unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err) + } } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, diff --git a/server/server_test.go b/server/server_test.go index 5a3dcc8..a9c6ba6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -294,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) { messages = toMessages(t, response.Body.String()) require.Equal(t, 1, len(messages)) require.Equal(t, "a message", messages[0].Message) - require.Equal(t, "", messages[0].Sender) // Never return the sender! + require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender! messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true) require.Nil(t, err) require.Equal(t, 1, len(messages)) require.Equal(t, "a message", messages[0].Message) - require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though! + require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though! } func TestServer_PublishAtWithCacheError(t *testing.T) { @@ -1134,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) { require.Equal(t, int64(5000), msg.Attachment.Size) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") - require.Equal(t, "", msg.Sender) // Should never be returned + require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) // GET @@ -1170,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { require.Equal(t, int64(21), msg.Attachment.Size) require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") - require.Equal(t, "", msg.Sender) // Should never be returned + require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") @@ -1197,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) { require.Equal(t, "", msg.Attachment.Type) require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Expires) - require.Equal(t, "", msg.Sender) + require.Equal(t, netip.Addr{}, msg.Sender) // Slightly unrelated cross-test: make sure we don't add an owner for external attachments size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") @@ -1218,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) { require.Equal(t, "", msg.Attachment.Type) require.Equal(t, int64(0), msg.Attachment.Size) require.Equal(t, int64(0), msg.Attachment.Expires) - require.Equal(t, "", msg.Sender) + require.Equal(t, netip.Addr{}, msg.Sender) } func TestServer_PublishAttachmentBadURL(t *testing.T) { @@ -1393,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty! v := s.visitor(r) - require.Equal(t, "8.9.10.11", v.ip) + require.Equal(t, "8.9.10.11", v.ip.String()) } func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { @@ -1404,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.1.1.1") v := s.visitor(r) - require.Equal(t, "1.1.1.1", v.ip) + require.Equal(t, "1.1.1.1", v.ip.String()) } func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { @@ -1415,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") v := s.visitor(r) - require.Equal(t, "234.5.2.1", v.ip) + require.Equal(t, "234.5.2.1", v.ip.String()) } func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { From 4edc625331a5050a5aa6fe8d4a5b318010561178 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 20:36:01 -0500 Subject: [PATCH 5/8] fix lint --- cmd/serve.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 36dbf31..6d7754e 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -217,9 +217,7 @@ func execServe(c *cli.Context) error { log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) continue } - for _, ip := range ips { - visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip) - } + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...) } // Run server From bc5060b218e65d334a79a5c2ea1db2a09f55a8f2 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 21:15:45 -0500 Subject: [PATCH 6/8] test new config parsing --- cmd/serve_test.go | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/cmd/serve_test.go b/cmd/serve_test.go index 3b2b9b0..774166c 100644 --- a/cmd/serve_test.go +++ b/cmd/serve_test.go @@ -2,17 +2,19 @@ package cmd import ( "fmt" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/client" - "heckel.io/ntfy/test" - "heckel.io/ntfy/util" "math/rand" "os" "os/exec" "path/filepath" "testing" "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/client" + "heckel.io/ntfy/test" + "heckel.io/ntfy/util" ) func init() { @@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) { require.Equal(t, "mytopic", m.Topic) } +func TestIP_Host_Parsing(t *testing.T) { + cases := map[string]string{ + "1.1.1.1": "1.1.1.1/32", + "fd00::1234": "fd00::1234/128", + "192.168.0.3/24": "192.168.0.0/24", + "10.1.2.3/8": "10.0.0.0/8", + "201:be93::4a6/21": "201:b800::/21", + } + for q, expectedAnswer := range cases { + ips, err := parseIPHostPrefix(q) + require.Nil(t, err) + assert.Equal(t, 1, len(ips)) + assert.Equal(t, expectedAnswer, ips[0].String()) + } +} + func newEmptyFile(t *testing.T) string { filename := filepath.Join(t.TempDir(), "empty") require.Nil(t, os.WriteFile(filename, []byte{}, 0600)) From 1672322fc10d0304080e1e3e866d31c26bb71557 Mon Sep 17 00:00:00 2001 From: Karmanyaah Malhotra Date: Fri, 7 Oct 2022 21:22:22 -0500 Subject: [PATCH 7/8] test ContainsIP utility --- util/util_test.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/util/util_test.go b/util/util_test.go index 807d599..a435c06 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -1,10 +1,12 @@ package util import ( - "github.com/stretchr/testify/require" + "net/netip" "os" "path/filepath" "testing" + + "github.com/stretchr/testify/require" ) func TestRandomString(t *testing.T) { @@ -42,6 +44,13 @@ func TestContains(t *testing.T) { require.False(t, Contains(s, 3)) } +func TestContainsIP(t *testing.T) { + require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1"))) + require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876"))) + require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1"))) + require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1"))) +} + func TestSplitNoEmpty(t *testing.T) { require.Equal(t, []string{}, SplitNoEmpty("", ",")) require.Equal(t, []string{}, SplitNoEmpty(",,,", ",")) From 16ad94441b537f83f939d54ea7da41005581f153 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Sat, 8 Oct 2022 17:58:05 -0400 Subject: [PATCH 8/8] Personal preference --- cmd/serve.go | 16 +++++++--------- server/server.go | 13 ++++++------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 6d7754e..aff7c7c 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -304,27 +304,25 @@ func sigHandlerConfigReload(config string) { } func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { - //try parsing as prefix + // Try parsing as prefix, e.g. 10.0.1.0/24 prefix, err := netip.ParsePrefix(host) if err == nil { - prefixes = append(prefixes, prefix.Masked()) // Masked returns the prefix in its canonical form, the same for every ip in the range. This exists for ease of debugging. For example, 10.1.2.3/16 is 10.1.0.0/16. - return prefixes, nil // success + prefixes = append(prefixes, prefix.Masked()) + return prefixes, nil } - - // not a prefix, parse as host or IP - // LookupHost passes through an IP as is + // Not a prefix, parse as host or IP (LookupHost passes through an IP as is) ips, err := net.LookupHost(host) if err != nil { return nil, err } - for _, i := range ips { - ip, err := netip.ParseAddr(i) + for _, ipStr := range ips { + ip, err := netip.ParseAddr(ipStr) if err == nil { prefix, err := ip.Prefix(ip.BitLen()) if err != nil { return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error()) } - prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip + prefixes = append(prefixes, prefix.Masked()) } } return diff --git a/server/server.go b/server/server.go index a4ef193..ef09100 100644 --- a/server/server.go +++ b/server/server.go @@ -1437,14 +1437,14 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { remoteAddr := r.RemoteAddr - ipport, err := netip.ParseAddrPort(remoteAddr) - ip := ipport.Addr() + addrPort, err := netip.ParseAddrPort(remoteAddr) + ip := addrPort.Addr() if err != nil { // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified ip, err = netip.ParseAddr(remoteAddr) if err != nil { ip = netip.IPv4Unspecified() - log.Error("Unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err) + log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err) } } if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { @@ -1452,14 +1452,13 @@ func (s *Server) visitor(r *http.Request) *visitor { // only the right-most address can be trusted (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") - myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) + realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) if err != nil { log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error()) - // fall back to regular remote address if X-Forwarded-For is damaged + // Fall back to regular remote address if X-Forwarded-For is damaged } else { - ip = myip + ip = realIP } - } return s.visitorFromIP(ip) }