Upstream access token

This commit is contained in:
binwiederhier 2023-05-18 13:08:10 -04:00
parent f13a654fe8
commit 25d3a66f91
9 changed files with 144 additions and 50 deletions

View file

@ -98,6 +98,7 @@ type Config struct {
FirebasePollInterval time.Duration
FirebaseQuotaExceededPenaltyDuration time.Duration
UpstreamBaseURL string
UpstreamAccessToken string
SMTPSenderAddr string
SMTPSenderUser string
SMTPSenderPass string
@ -182,6 +183,7 @@ func NewConfig() *Config {
FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
UpstreamBaseURL: "",
UpstreamAccessToken: "",
SMTPSenderAddr: "",
SMTPSenderUser: "",
SMTPSenderPass: "",

View file

@ -855,7 +855,11 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
logvm(v, m).Err(err).Warn("Unable to publish poll request")
return
}
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Set("X-Poll-ID", m.ID)
if s.config.UpstreamAccessToken != "" {
req.Header.Set("Authorization", util.BearerAuth(s.config.UpstreamAccessToken))
}
var httpClient = &http.Client{
Timeout: time.Second * 10,
}

View file

@ -208,7 +208,12 @@
# the message ID of the original message, instructing the iOS app to poll this server for the actual message contents.
# This is to prevent the upstream server and Firebase/APNS from being able to read the message.
#
# - upstream-base-url is the base URL of the upstream server. Should be "https://ntfy.sh".
# - upstream-access-token is the token used to authenticate with the upstream server. This is only required
# if you exceed the upstream rate limits, or the uptream server requires authentication.
#
# upstream-base-url:
# upstream-access-token:
# Rate limiting: Total number of topics before the server rejects new topics.
#

View file

@ -18,6 +18,7 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -2491,6 +2492,66 @@ func TestServer_PublishWithUTF8MimeHeader(t *testing.T) {
require.Equal(t, "ntfy 很棒", m.Tags[1])
}
func TestServer_UpstreamBaseURL_Success(t *testing.T) {
var pollID atomic.Pointer[string]
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/87c9cddf7b0105f5fe849bf084c6e600be0fde99be3223335199b4965bd7b735", r.URL.Path)
require.Equal(t, "", string(body))
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
s := newTestServer(t, c)
// Send message, and wait for upstream server to receive it
response := request(t, s, "PUT", "/mytopic", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
waitFor(t, func() bool {
pID := pollID.Load()
return pID != nil && *pID == m.ID
})
}
func TestServer_UpstreamBaseURL_With_Access_Token_Success(t *testing.T) {
var pollID atomic.Pointer[string]
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/a1c72bcb4daf5af54d13ef86aea8f76c11e8b88320d55f1811d5d7b173bcc1df", r.URL.Path)
require.Equal(t, "Bearer tk_1234567890", r.Header.Get("Authorization"))
require.Equal(t, "", string(body))
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
c.UpstreamAccessToken = "tk_1234567890"
s := newTestServer(t, c)
// Send message, and wait for upstream server to receive it
response := request(t, s, "PUT", "/mytopic1", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
waitFor(t, func() bool {
pID := pollID.Load()
return pID != nil && *pID == m.ID
})
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345"
@ -2592,7 +2653,7 @@ func waitForWithMaxWait(t *testing.T, maxWait time.Duration, f func() bool) {
if f() {
return
}
time.Sleep(100 * time.Millisecond)
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("Function f did not succeed after %v: %v", maxWait, string(debug.Stack()))
}

View file

@ -87,8 +87,9 @@ func (s *Server) callPhoneInternal(data url.Values) (string, error) {
if err != nil {
return "", err
}
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
@ -110,8 +111,9 @@ func (s *Server) verifyPhoneNumber(v *visitor, r *http.Request, phoneNumber, cha
if err != nil {
return err
}
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
@ -135,8 +137,9 @@ func (s *Server) verifyPhoneNumberCheck(v *visitor, r *http.Request, phoneNumber
if err != nil {
return err
}
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err