Merge branch 'ip-range-exempt'
This commit is contained in:
		
						commit
						cbc912d1e3
					
				
					 14 changed files with 161 additions and 60 deletions
				
			
		
							
								
								
									
										37
									
								
								cmd/serve.go
									
										
									
									
									
								
							
							
						
						
									
										37
									
								
								cmd/serve.go
									
										
									
									
									
								
							|  | @ -5,16 +5,18 @@ package cmd | |||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"io/fs" | ||||
| 	"math" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"strings" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"heckel.io/ntfy/log" | ||||
| 
 | ||||
| 	"github.com/urfave/cli/v2" | ||||
| 	"github.com/urfave/cli/v2/altsrc" | ||||
| 	"heckel.io/ntfy/server" | ||||
|  | @ -208,16 +210,14 @@ func execServe(c *cli.Context) error { | |||
| 	} | ||||
| 
 | ||||
| 	// Resolve hosts | ||||
| 	visitorRequestLimitExemptIPs := make([]string, 0) | ||||
| 	visitorRequestLimitExemptIPs := make([]netip.Prefix, 0) | ||||
| 	for _, host := range visitorRequestLimitExemptHosts { | ||||
| 		ips, err := net.LookupIP(host) | ||||
| 		ips, err := parseIPHostPrefix(host) | ||||
| 		if err != nil { | ||||
| 			log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		for _, ip := range ips { | ||||
| 			visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) | ||||
| 		} | ||||
| 		visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...) | ||||
| 	} | ||||
| 
 | ||||
| 	// Run server | ||||
|  | @ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { | ||||
| 	// Try parsing as prefix, e.g. 10.0.1.0/24 | ||||
| 	prefix, err := netip.ParsePrefix(host) | ||||
| 	if err == nil { | ||||
| 		prefixes = append(prefixes, prefix.Masked()) | ||||
| 		return prefixes, nil | ||||
| 	} | ||||
| 	// Not a prefix, parse as host or IP (LookupHost passes through an IP as is) | ||||
| 	ips, err := net.LookupHost(host) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	for _, ipStr := range ips { | ||||
| 		ip, err := netip.ParseAddr(ipStr) | ||||
| 		if err == nil { | ||||
| 			prefix, err := ip.Prefix(ip.BitLen()) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error()) | ||||
| 			} | ||||
| 			prefixes = append(prefixes, prefix.Masked()) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func reloadLogLevel(inputSource altsrc.InputSourceContext) { | ||||
| 	newLevelStr, err := inputSource.String("log-level") | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -2,17 +2,19 @@ package cmd | |||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/client" | ||||
| 	"heckel.io/ntfy/test" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"math/rand" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/client" | ||||
| 	"heckel.io/ntfy/test" | ||||
| 	"heckel.io/ntfy/util" | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
|  | @ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) { | |||
| 	require.Equal(t, "mytopic", m.Topic) | ||||
| } | ||||
| 
 | ||||
| func TestIP_Host_Parsing(t *testing.T) { | ||||
| 	cases := map[string]string{ | ||||
| 		"1.1.1.1":          "1.1.1.1/32", | ||||
| 		"fd00::1234":       "fd00::1234/128", | ||||
| 		"192.168.0.3/24":   "192.168.0.0/24", | ||||
| 		"10.1.2.3/8":       "10.0.0.0/8", | ||||
| 		"201:be93::4a6/21": "201:b800::/21", | ||||
| 	} | ||||
| 	for q, expectedAnswer := range cases { | ||||
| 		ips, err := parseIPHostPrefix(q) | ||||
| 		require.Nil(t, err) | ||||
| 		assert.Equal(t, 1, len(ips)) | ||||
| 		assert.Equal(t, expectedAnswer, ips[0].String()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newEmptyFile(t *testing.T) string { | ||||
| 	filename := filepath.Join(t.TempDir(), "empty") | ||||
| 	require.Nil(t, os.WriteFile(filename, []byte{}, 0600)) | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"io/fs" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
|  | @ -92,7 +93,7 @@ type Config struct { | |||
| 	VisitorAttachmentDailyBandwidthLimit int | ||||
| 	VisitorRequestLimitBurst             int | ||||
| 	VisitorRequestLimitReplenish         time.Duration | ||||
| 	VisitorRequestExemptIPAddrs          []string | ||||
| 	VisitorRequestExemptIPAddrs          []netip.Prefix | ||||
| 	VisitorEmailLimitBurst               int | ||||
| 	VisitorEmailLimitReplenish           time.Duration | ||||
| 	BehindProxy                          bool | ||||
|  | @ -135,7 +136,7 @@ func NewConfig() *Config { | |||
| 		VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, | ||||
| 		VisitorRequestLimitBurst:             DefaultVisitorRequestLimitBurst, | ||||
| 		VisitorRequestLimitReplenish:         DefaultVisitorRequestLimitReplenish, | ||||
| 		VisitorRequestExemptIPAddrs:          make([]string, 0), | ||||
| 		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0), | ||||
| 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst, | ||||
| 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish, | ||||
| 		BehindProxy:                          false, | ||||
|  |  | |||
|  | @ -5,11 +5,13 @@ import ( | |||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	_ "github.com/mattn/go-sqlite3" // SQLite driver | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
|  | @ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error { | |||
| 			attachmentSize, | ||||
| 			attachmentExpires, | ||||
| 			attachmentURL, | ||||
| 			m.Sender, | ||||
| 			m.Sender.String(), | ||||
| 			m.Encoding, | ||||
| 			published, | ||||
| 		) | ||||
|  | @ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) { | |||
| 				return nil, err | ||||
| 			} | ||||
| 		} | ||||
| 		senderIP, err := netip.ParseAddr(sender) | ||||
| 		if err != nil { | ||||
| 			senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0 | ||||
| 		} | ||||
| 
 | ||||
| 		var att *attachment | ||||
| 		if attachmentName != "" && attachmentURL != "" { | ||||
| 			att = &attachment{ | ||||
|  | @ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { | |||
| 			Icon:       icon, | ||||
| 			Actions:    actions, | ||||
| 			Attachment: att, | ||||
| 			Sender:     sender, | ||||
| 			Sender:     senderIP, // Must parse assuming database must be correct | ||||
| 			Encoding:   encoding, | ||||
| 		}) | ||||
| 	} | ||||
|  |  | |||
|  | @ -3,11 +3,17 @@ package server | |||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"net/netip" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	exampleIP1234 = netip.MustParseAddr("1.2.3.4") | ||||
| ) | ||||
| 
 | ||||
| func TestSqliteCache_Messages(t *testing.T) { | ||||
|  | @ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | |||
| 	expires1 := time.Now().Add(-4 * time.Hour).Unix() | ||||
| 	m := newDefaultMessage("mytopic", "flower for you") | ||||
| 	m.ID = "m1" | ||||
| 	m.Sender = "1.2.3.4" | ||||
| 	m.Sender = exampleIP1234 | ||||
| 	m.Attachment = &attachment{ | ||||
| 		Name:    "flower.jpg", | ||||
| 		Type:    "image/jpeg", | ||||
|  | @ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | |||
| 	expires2 := time.Now().Add(2 * time.Hour).Unix() // Future | ||||
| 	m = newDefaultMessage("mytopic", "sending you a car") | ||||
| 	m.ID = "m2" | ||||
| 	m.Sender = "1.2.3.4" | ||||
| 	m.Sender = exampleIP1234 | ||||
| 	m.Attachment = &attachment{ | ||||
| 		Name:    "car.jpg", | ||||
| 		Type:    "image/jpeg", | ||||
|  | @ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | |||
| 	expires3 := time.Now().Add(1 * time.Hour).Unix() // Future | ||||
| 	m = newDefaultMessage("another-topic", "sending you another car") | ||||
| 	m.ID = "m3" | ||||
| 	m.Sender = "1.2.3.4" | ||||
| 	m.Sender = exampleIP1234 | ||||
| 	m.Attachment = &attachment{ | ||||
| 		Name:    "another-car.jpg", | ||||
| 		Type:    "image/jpeg", | ||||
|  | @ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | |||
| 	require.Equal(t, int64(5000), messages[0].Attachment.Size) | ||||
| 	require.Equal(t, expires1, messages[0].Attachment.Expires) | ||||
| 	require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) | ||||
| 	require.Equal(t, "1.2.3.4", messages[0].Sender) | ||||
| 	require.Equal(t, "1.2.3.4", messages[0].Sender.String()) | ||||
| 
 | ||||
| 	require.Equal(t, "sending you a car", messages[1].Message) | ||||
| 	require.Equal(t, "car.jpg", messages[1].Attachment.Name) | ||||
|  | @ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { | |||
| 	require.Equal(t, int64(10000), messages[1].Attachment.Size) | ||||
| 	require.Equal(t, expires2, messages[1].Attachment.Expires) | ||||
| 	require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) | ||||
| 	require.Equal(t, "1.2.3.4", messages[1].Sender) | ||||
| 	require.Equal(t, "1.2.3.4", messages[1].Sender.String()) | ||||
| 
 | ||||
| 	size, err := c.AttachmentBytesUsed("1.2.3.4") | ||||
| 	require.Nil(t, err) | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ import ( | |||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path" | ||||
|  | @ -42,7 +43,7 @@ type Server struct { | |||
| 	smtpServerBackend *smtpBackend | ||||
| 	smtpSender        mailer | ||||
| 	topics            map[string]*topic | ||||
| 	visitors          map[string]*visitor | ||||
| 	visitors          map[netip.Addr]*visitor | ||||
| 	firebaseClient    *firebaseClient | ||||
| 	messages          int64 | ||||
| 	auth              auth.Auther | ||||
|  | @ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) { | |||
| 		smtpSender:     mailer, | ||||
| 		topics:         topics, | ||||
| 		auth:           auther, | ||||
| 		visitors:       make(map[string]*visitor), | ||||
| 		visitors:       make(map[netip.Addr]*visitor), | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() { | |||
| 	if s.firebaseClient == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor | ||||
| 	v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0 | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-time.After(s.config.FirebaseKeepaliveInterval): | ||||
|  | @ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { | |||
| 
 | ||||
| func (s *Server) limitRequests(next handleFunc) handleFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
| 		if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) { | ||||
| 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { | ||||
| 			return next(w, r, v) | ||||
| 		} else if err := v.RequestAllowed(); err != nil { | ||||
| 			return errHTTPTooManyRequestsLimitRequests | ||||
|  | @ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool | |||
| // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). | ||||
| func (s *Server) visitor(r *http.Request) *visitor { | ||||
| 	remoteAddr := r.RemoteAddr | ||||
| 	ip, _, err := net.SplitHostPort(remoteAddr) | ||||
| 	addrPort, err := netip.ParseAddrPort(remoteAddr) | ||||
| 	ip := addrPort.Addr() | ||||
| 	if err != nil { | ||||
| 		ip = remoteAddr // This should not happen in real life; only in tests. | ||||
| 		// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified | ||||
| 		ip, err = netip.ParseAddr(remoteAddr) | ||||
| 		if err != nil { | ||||
| 			ip = netip.IPv4Unspecified() | ||||
| 			log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err) | ||||
| 		} | ||||
| 	} | ||||
| 	if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" { | ||||
| 		// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy, | ||||
| 		// only the right-most address can be trusted (as this is the one added by our proxy server). | ||||
| 		// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. | ||||
| 		ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",") | ||||
| 		ip = strings.TrimSpace(util.LastString(ips, remoteAddr)) | ||||
| 		realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr))) | ||||
| 		if err != nil { | ||||
| 			log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error()) | ||||
| 			// Fall back to regular remote address if X-Forwarded-For is damaged | ||||
| 		} else { | ||||
| 			ip = realIP | ||||
| 		} | ||||
| 	} | ||||
| 	return s.visitorFromIP(ip) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) visitorFromIP(ip string) *visitor { | ||||
| func (s *Server) visitorFromIP(ip netip.Addr) *visitor { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	v, exists := s.visitors[ip] | ||||
|  |  | |||
|  | @ -3,13 +3,15 @@ package server | |||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"firebase.google.com/go/v4/messaging" | ||||
| 	"fmt" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/auth" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"firebase.google.com/go/v4/messaging" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/auth" | ||||
| ) | ||||
| 
 | ||||
| type testAuther struct { | ||||
|  | @ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { | |||
| func TestToFirebaseSender_Abuse(t *testing.T) { | ||||
| 	sender := &testFirebaseSender{allowed: 2} | ||||
| 	client := newFirebaseClient(sender, &testAuther{}) | ||||
| 	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") | ||||
| 	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4")) | ||||
| 
 | ||||
| 	require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) | ||||
| 	require.Equal(t, 1, len(sender.Messages())) | ||||
|  |  | |||
|  | @ -1,11 +1,13 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) { | ||||
|  | @ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) { | |||
| func TestMatrix_WriteMatrixError(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 	r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil) | ||||
| 	v := newVisitor(newTestConfig(t), nil, "1.2.3.4") | ||||
| 	v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4")) | ||||
| 	require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch})) | ||||
| 	require.Equal(t, 200, w.Result().StatusCode) | ||||
| 	require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String()) | ||||
|  |  | |||
|  | @ -6,18 +6,20 @@ import ( | |||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/netip" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/auth" | ||||
| 	"heckel.io/ntfy/util" | ||||
|  | @ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) { | |||
| 	messages = toMessages(t, response.Body.String()) | ||||
| 	require.Equal(t, 1, len(messages)) | ||||
| 	require.Equal(t, "a message", messages[0].Message) | ||||
| 	require.Equal(t, "", messages[0].Sender) // Never return the sender! | ||||
| 	require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender! | ||||
| 
 | ||||
| 	messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true) | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, 1, len(messages)) | ||||
| 	require.Equal(t, "a message", messages[0].Message) | ||||
| 	require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though! | ||||
| 	require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though! | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishAtWithCacheError(t *testing.T) { | ||||
|  | @ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { | |||
| 
 | ||||
| func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { | ||||
| 	c := newTestConfig(t) | ||||
| 	c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request() | ||||
| 	c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() | ||||
| 	s := newTestServer(t, c) | ||||
| 	for i := 0; i < 65; i++ { // > 60 | ||||
| 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) | ||||
|  | @ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) { | |||
| 	require.Equal(t, int64(5000), msg.Attachment.Size) | ||||
| 	require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours | ||||
| 	require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") | ||||
| 	require.Equal(t, "", msg.Sender) // Should never be returned | ||||
| 	require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned | ||||
| 	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) | ||||
| 
 | ||||
| 	// GET | ||||
|  | @ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { | |||
| 	require.Equal(t, int64(21), msg.Attachment.Size) | ||||
| 	require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) | ||||
| 	require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") | ||||
| 	require.Equal(t, "", msg.Sender) // Should never be returned | ||||
| 	require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned | ||||
| 	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) | ||||
| 
 | ||||
| 	path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") | ||||
|  | @ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) { | |||
| 	require.Equal(t, "", msg.Attachment.Type) | ||||
| 	require.Equal(t, int64(0), msg.Attachment.Size) | ||||
| 	require.Equal(t, int64(0), msg.Attachment.Expires) | ||||
| 	require.Equal(t, "", msg.Sender) | ||||
| 	require.Equal(t, netip.Addr{}, msg.Sender) | ||||
| 
 | ||||
| 	// Slightly unrelated cross-test: make sure we don't add an owner for external attachments | ||||
| 	size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") | ||||
|  | @ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) { | |||
| 	require.Equal(t, "", msg.Attachment.Type) | ||||
| 	require.Equal(t, int64(0), msg.Attachment.Size) | ||||
| 	require.Equal(t, int64(0), msg.Attachment.Expires) | ||||
| 	require.Equal(t, "", msg.Sender) | ||||
| 	require.Equal(t, netip.Addr{}, msg.Sender) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishAttachmentBadURL(t *testing.T) { | ||||
|  | @ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) { | |||
| 	r.RemoteAddr = "8.9.10.11" | ||||
| 	r.Header.Set("X-Forwarded-For", "  ") // Spaces, not empty! | ||||
| 	v := s.visitor(r) | ||||
| 	require.Equal(t, "8.9.10.11", v.ip) | ||||
| 	require.Equal(t, "8.9.10.11", v.ip.String()) | ||||
| } | ||||
| 
 | ||||
| func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { | ||||
|  | @ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { | |||
| 	r.RemoteAddr = "8.9.10.11" | ||||
| 	r.Header.Set("X-Forwarded-For", "1.1.1.1") | ||||
| 	v := s.visitor(r) | ||||
| 	require.Equal(t, "1.1.1.1", v.ip) | ||||
| 	require.Equal(t, "1.1.1.1", v.ip.String()) | ||||
| } | ||||
| 
 | ||||
| func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { | ||||
|  | @ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { | |||
| 	r.RemoteAddr = "8.9.10.11" | ||||
| 	r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") | ||||
| 	v := s.visitor(r) | ||||
| 	require.Equal(t, "234.5.2.1", v.ip) | ||||
| 	require.Equal(t, "234.5.2.1", v.ip.String()) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { | ||||
|  |  | |||
|  | @ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error { | |||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m) | ||||
| 		message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  |  | |||
|  | @ -1,9 +1,11 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"heckel.io/ntfy/util" | ||||
| ) | ||||
| 
 | ||||
| // List of possible events | ||||
|  | @ -33,7 +35,7 @@ type message struct { | |||
| 	Actions    []*action   `json:"actions,omitempty"` | ||||
| 	Attachment *attachment `json:"attachment,omitempty"` | ||||
| 	PollID     string      `json:"poll_id,omitempty"` | ||||
| 	Sender     string      `json:"-"`                  // IP address of uploader, used for rate limiting | ||||
| 	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting | ||||
| 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,10 +2,12 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"golang.org/x/time/rate" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"golang.org/x/time/rate" | ||||
| 	"heckel.io/ntfy/util" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -23,7 +25,7 @@ var ( | |||
| type visitor struct { | ||||
| 	config        *Config | ||||
| 	messageCache  *messageCache | ||||
| 	ip            string | ||||
| 	ip            netip.Addr | ||||
| 	requests      *rate.Limiter | ||||
| 	emails        *rate.Limiter | ||||
| 	subscriptions util.Limiter | ||||
|  | @ -40,7 +42,7 @@ type visitorStats struct { | |||
| 	VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` | ||||
| } | ||||
| 
 | ||||
| func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { | ||||
| func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor { | ||||
| 	return &visitor{ | ||||
| 		config:        conf, | ||||
| 		messageCache:  messageCache, | ||||
|  | @ -115,7 +117,7 @@ func (v *visitor) Stale() bool { | |||
| } | ||||
| 
 | ||||
| func (v *visitor) Stats() (*visitorStats, error) { | ||||
| 	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) | ||||
| 	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  |  | |||
							
								
								
									
										16
									
								
								util/util.go
									
										
									
									
									
								
							
							
						
						
									
										16
									
								
								util/util.go
									
										
									
									
									
								
							|  | @ -5,16 +5,18 @@ import ( | |||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gabriel-vasile/mimetype" | ||||
| 	"golang.org/x/term" | ||||
| 	"io" | ||||
| 	"math/rand" | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/gabriel-vasile/mimetype" | ||||
| 	"golang.org/x/term" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool { | |||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // ContainsIP returns true if any one of the of prefixes contains the ip. | ||||
| func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool { | ||||
| 	for _, s := range haystack { | ||||
| 		if s.Contains(needle) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // ContainsAll returns true if all needles are contained in haystack | ||||
| func ContainsAll[T comparable](haystack []T, needles []T) bool { | ||||
| 	matches := 0 | ||||
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| package util | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestRandomString(t *testing.T) { | ||||
|  | @ -42,6 +44,13 @@ func TestContains(t *testing.T) { | |||
| 	require.False(t, Contains(s, 3)) | ||||
| } | ||||
| 
 | ||||
| func TestContainsIP(t *testing.T) { | ||||
| 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1"))) | ||||
| 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876"))) | ||||
| 	require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1"))) | ||||
| 	require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1"))) | ||||
| } | ||||
| 
 | ||||
| func TestSplitNoEmpty(t *testing.T) { | ||||
| 	require.Equal(t, []string{}, SplitNoEmpty("", ",")) | ||||
| 	require.Equal(t, []string{}, SplitNoEmpty(",,,", ",")) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue