Fix almost all tests

This commit is contained in:
binwiederhier 2022-12-27 22:14:14 -05:00
parent 95a8e64fbb
commit d9722a9825
17 changed files with 197 additions and 253 deletions

View file

@ -56,13 +56,6 @@ func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1025)))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
@ -95,7 +88,7 @@ func TestFileCache_RemoveExpired(t *testing.T) {
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024, 1*1024)
cache, err := newFileCache(dir, 10*1024)
require.Nil(t, err)
return dir, cache
}

View file

@ -40,7 +40,7 @@ const (
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
user TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
@ -95,7 +95,7 @@ const (
// Schema management queries
const (
currentSchemaVersion = 9
currentSchemaVersion = 10
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
@ -193,6 +193,11 @@ const (
migrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
`
)
type messageCache struct {
@ -614,8 +619,9 @@ func setupCacheDB(db *sql.DB, startupQueries string) error {
return migrateFrom7(db)
} else if schemaVersion == 8 {
return migrateFrom8(db)
} else if schemaVersion == 9 {
return migrateFrom9(db)
}
// TODO add user column
return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
}
@ -731,5 +737,16 @@ func migrateFrom8(db *sql.DB) error {
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
return err
}
return migrateFrom9(db)
}
func migrateFrom9(db *sql.DB) error {
log.Info("Migrating cache database schema: from 9 to 10")
if _, err := db.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 10); err != nil {
return err
}
return nil // Update this when a new version is added
}

View file

@ -43,12 +43,15 @@ import (
"user list" shows * twice
"ntfy access everyone user4topic <bla>" twice -> UNIQUE constraint error
Account usage not updated "in real time"
Attachment expiration based on plan
Plan: Keep 10000 messages or keep X days?
Sync:
- "mute" setting
- figure out what settings are "web" or "phone"
UI:
- Subscription dotmenu dropdown: Move to nav bar, or make same as profile dropdown
- "Logout and delete local storage" option
- Delete local storage when deleting account
Pages:
- Home
- Password reset
@ -61,7 +64,8 @@ import (
- APIs
- CRUD tokens
- Expire tokens
-
- userManager can be nil
- visitor with/without user
*/
// Server is the main server, providing the UI and API for ntfy
@ -77,7 +81,7 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
userManager user.Manager
userManager *user.Manager // Might be nil!
messageCache *messageCache
fileCache *fileCache
closeChan chan bool
@ -165,9 +169,9 @@ func New(conf *Config) (*Server, error) {
return nil, err
}
}
var auther user.Manager
var userManager *user.Manager
if conf.AuthFile != "" {
auther, err = user.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
userManager, err = user.NewManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
if err != nil {
return nil, err
}
@ -178,7 +182,7 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
firebaseClient = newFirebaseClient(sender, auther)
firebaseClient = newFirebaseClient(sender, userManager)
}
return &Server{
config: conf,
@ -187,7 +191,7 @@ func New(conf *Config) (*Server, error) {
firebaseClient: firebaseClient,
smtpSender: mailer,
topics: topics,
userManager: auther,
userManager: userManager,
visitors: make(map[string]*visitor),
}, nil
}
@ -341,27 +345,27 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPath {
return s.handleAccountCreate(w, r, v)
return s.ensureAccountsEnabled(s.handleAccountCreate)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == accountPath {
return s.handleAccountGet(w, r, v)
return s.handleAccountGet(w, r, v) // Allowed by anonymous
} else if r.Method == http.MethodDelete && r.URL.Path == accountPath {
return s.handleAccountDelete(w, r, v)
return s.ensureWithAccount(s.handleAccountDelete)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
return s.handleAccountPasswordChange(w, r, v)
return s.ensureWithAccount(s.handleAccountPasswordChange)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath {
return s.handleAccountTokenIssue(w, r, v)
return s.ensureWithAccount(s.handleAccountTokenIssue)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
return s.handleAccountTokenExtend(w, r, v)
return s.ensureWithAccount(s.handleAccountTokenExtend)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
return s.handleAccountTokenDelete(w, r, v)
return s.ensureWithAccount(s.handleAccountTokenDelete)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
return s.handleAccountSettingsChange(w, r, v)
return s.ensureWithAccount(s.handleAccountSettingsChange)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
return s.handleAccountSubscriptionAdd(w, r, v)
return s.ensureWithAccount(s.handleAccountSubscriptionAdd)(w, r, v)
} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.handleAccountSubscriptionChange(w, r, v)
return s.ensureWithAccount(s.handleAccountSubscriptionChange)(w, r, v)
} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.handleAccountSubscriptionDelete(w, r, v)
return s.ensureWithAccount(s.handleAccountSubscriptionDelete)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@ -804,7 +808,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
}
stats, err := v.Stats()
stats, err := v.Info()
if err != nil {
return err
}
@ -1182,7 +1186,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return topics, nil
}
func (s *Server) updateStatsAndPrune() {
func (s *Server) execManager() {
log.Debug("Manager: Starting")
defer log.Debug("Manager: Finished")
@ -1203,8 +1207,10 @@ func (s *Server) updateStatsAndPrune() {
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error())
if s.userManager != nil {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error())
}
}
// Delete expired attachments
@ -1293,7 +1299,7 @@ func (s *Server) runManager() {
for {
select {
case <-time.After(s.config.ManagerInterval):
s.updateStatsAndPrune()
s.execManager()
case <-s.closeChan:
return
}
@ -1399,6 +1405,24 @@ func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
}
}
func (s *Server) ensureAccountsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.userManager != nil {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureWithAccount(next handleFunc) handleFunc {
return s.ensureAccountsEnabled(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user != nil {
return errHTTPNotFound
}
return next(w, r, v)
})
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
@ -1502,17 +1526,17 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy)
var user *user.User // may stay nil if no auth header!
if user, err = s.authenticate(r); err != nil {
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
if user != nil {
v = s.visitorFromUser(user, ip)
if u != nil {
v = s.visitorFromUser(u, ip)
} else {
v = s.visitorFromIP(ip)
}
v.user = user // Update user -- FIXME race?
v.user = u // Update user -- FIXME race?
return v, err // Always return visitor, even when error occurs!
}
@ -1521,17 +1545,19 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
value := r.Header.Get("Authorization")
value := strings.TrimSpace(r.Header.Get("Authorization"))
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
}
value = string(a)
value = strings.TrimSpace(string(a))
}
if value == "" {
return nil, nil
} else if s.userManager == nil {
return nil, errHTTPUnauthorized
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(value)

View file

@ -45,11 +45,11 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
stats, err := v.Stats()
stats, err := v.Info()
if err != nil {
return err
}
response := &apiAccountSettingsResponse{
response := &apiAccountResponse{
Stats: &apiAccountStats{
Messages: stats.Messages,
MessagesRemaining: stats.MessagesRemaining,

View file

@ -28,10 +28,10 @@ var (
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
type firebaseClient struct {
sender firebaseSender
auther user.Manager
auther user.Auther
}
func newFirebaseClient(sender firebaseSender, auther user.Manager) *firebaseClient {
func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClient {
return &firebaseClient{
sender: sender,
auther: auther,
@ -112,7 +112,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
// to Firebase here. This is mainly for iOS to support self-hosted servers.
func toFirebaseMessage(m *message, auther user.Manager) (*messaging.Message, error) {
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
var apnsConfig *messaging.APNSConfig
switch m.Event {

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"heckel.io/ntfy/user"
"net/netip"
"strings"
"sync"
@ -17,7 +18,9 @@ type testAuther struct {
Allow bool
}
func (t testAuther) AuthenticateUser(_, _ string) (*user.User, error) {
var _ user.Auther = (*testAuther)(nil)
func (t testAuther) Authenticate(_, _ string) (*user.User, error) {
return nil, errors.New("not used")
}
@ -323,7 +326,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"), nil)
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages()))

View file

@ -72,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
func TestMatrix_WriteMatrixError(t *testing.T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())

View file

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"heckel.io/ntfy/user"
"io"
"log"
"math/rand"
@ -171,7 +172,7 @@ func TestServer_StaticSites(t *testing.T) {
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
require.Equal(t, 200, rr.Code)
require.Contains(t, rr.Body.String(), `html, body {`)
require.Contains(t, rr.Body.String(), `/* general styling */`)
rr = request(t, s, "GET", "/docs", "", nil)
require.Equal(t, 301, rr.Code)
@ -353,7 +354,7 @@ func TestServer_PublishAtAndPrune(t *testing.T) {
"In": "1h",
})
require.Equal(t, 200, response.Code)
s.updateStatsAndPrune() // Fire pruning
s.execManager() // Fire pruning
response = request(t, s, "GET", "/mytopic/json?poll=1&scheduled=1", "", nil)
messages := toMessages(t, response.Body.String())
@ -625,8 +626,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c.AuthFile = filepath.Join(t.TempDir(), "user.db")
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("phil:phil"),
@ -642,9 +642,8 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefaultWrite = false
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", true, true))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("ben:ben"),
@ -659,10 +658,9 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefaultWrite = false
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, manager.AllowAccess("ben", "anothertopic", true, true))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", true, true))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", true, true))
response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{
"Authorization": basicAuth("ben:ben"),
@ -682,8 +680,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefaultWrite = false
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("phil:INVALID"),
@ -698,9 +695,8 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefaultWrite = false
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic!
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": basicAuth("ben:ben"),
@ -715,10 +711,9 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefaultWrite = true // Open by default
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, manager.AllowAccess(user.Everyone, "private", false, false))
require.Nil(t, manager.AllowAccess(user.Everyone, "announcements", true, false))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", false, false))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", true, false))
response := request(t, s, "PUT", "/mytopic", "test", nil)
require.Equal(t, 200, response.Code)
@ -748,8 +743,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefaultWrite = false
s := newTestServer(t, c)
manager := s.userManager.(user.Manager)
require.Nil(t, manager.AddUser("ben", "some pass", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(basicAuth("ben:some pass"))))
response := request(t, s, "GET", u, "", nil)
@ -760,27 +754,6 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
require.Equal(t, 401, response.Code)
}
/*
func TestServer_Curl_Publish_Poll(t *testing.T) {
s, port := test.StartServer(t)
defer test.StopServer(t, s, port)
cmd := exec.Command("sh", "-c", fmt.Sprintf(`curl -sd "This is a test" localhost:%d/mytopic`, port))
require.Nil(t, cmd.Run())
b, err := cmd.CombinedOutput()
require.Nil(t, err)
msg := toMessage(t, string(b))
require.Equal(t, "This is a test", msg.Message)
cmd = exec.Command("sh", "-c", fmt.Sprintf(`curl "localhost:%d/mytopic?poll=1"`, port))
require.Nil(t, cmd.Run())
b, err = cmd.CombinedOutput()
require.Nil(t, err)
msg = toMessage(t, string(b))
require.Equal(t, "This is a test", msg.Message)
}
*/
type testMailer struct {
count int
mu sync.Mutex
@ -1306,7 +1279,7 @@ func TestServer_PublishAttachmentAndPrune(t *testing.T) {
// Prune and makes sure it's gone
time.Sleep(time.Second) // Sigh ...
s.updateStatsAndPrune()
s.execManager()
require.NoFileExists(t, file)
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 404, response.Code)
@ -1360,7 +1333,7 @@ func TestServer_PublishAttachmentBandwidthLimitUploadOnly(t *testing.T) {
require.Equal(t, 41301, err.Code)
}
func TestServer_PublishAttachmentUserStats(t *testing.T) {
func TestServer_PublishAttachmentAccountStats(t *testing.T) {
content := util.RandomString(4999) // > 4096
c := newTestConfig(t)
@ -1374,14 +1347,14 @@ func TestServer_PublishAttachmentUserStats(t *testing.T) {
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
// User stats
response = request(t, s, "GET", "/user/stats", "", nil)
response = request(t, s, "GET", "/v1/account", "", nil)
require.Equal(t, 200, response.Code)
var stats visitorStats
require.Nil(t, json.NewDecoder(strings.NewReader(response.Body.String())).Decode(&stats))
require.Equal(t, int64(5000), stats.AttachmentFileSizeLimit)
require.Equal(t, int64(6000), stats.VisitorAttachmentBytesTotal)
require.Equal(t, int64(4999), stats.AttachmentBytes)
require.Equal(t, int64(1001), stats.VisitorAttachmentBytesRemaining)
var account *apiAccountResponse
require.Nil(t, json.NewDecoder(strings.NewReader(response.Body.String())).Decode(&account))
require.Equal(t, int64(5000), account.Limits.AttachmentFileSize)
require.Equal(t, int64(6000), account.Limits.AttachmentTotalSize)
require.Equal(t, int64(4999), account.Stats.AttachmentTotalSize)
require.Equal(t, int64(1001), account.Stats.AttachmentTotalSizeRemaining)
}
func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
@ -1391,7 +1364,8 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
v := s.visitor(r)
v, err := s.visitor(r)
require.Nil(t, err)
require.Equal(t, "8.9.10.11", v.ip.String())
}
@ -1402,7 +1376,8 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.1.1.1")
v := s.visitor(r)
v, err := s.visitor(r)
require.Nil(t, err)
require.Equal(t, "1.1.1.1", v.ip.String())
}
@ -1413,7 +1388,8 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
v := s.visitor(r)
v, err := s.visitor(r)
require.Nil(t, err)
require.Equal(t, "234.5.2.1", v.ip.String())
}
@ -1442,7 +1418,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
go func() {
log.Printf("Updating stats")
start := time.Now()
s.updateStatsAndPrune()
s.execManager()
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
statsChan <- true
}()

View file

@ -252,7 +252,7 @@ type apiAccountStats struct {
AttachmentTotalSizeRemaining int64 `json:"attachment_total_size_remaining"`
}
type apiAccountSettingsResponse struct {
type apiAccountResponse struct {
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`

View file

@ -40,7 +40,7 @@ type visitor struct {
mu sync.Mutex
}
type visitorStats struct {
type visitorInfo struct {
Basis string // "ip", "role" or "plan"
Messages int64
MessagesLimit int64
@ -165,30 +165,30 @@ func (v *visitor) IncrEmails() {
}
}
func (v *visitor) Stats() (*visitorStats, error) {
func (v *visitor) Info() (*visitorInfo, error) {
v.mu.Lock()
messages := v.messages
emails := v.emails
v.mu.Unlock()
stats := &visitorStats{}
info := &visitorInfo{}
if v.user != nil && v.user.Role == user.RoleAdmin {
stats.Basis = "role"
stats.MessagesLimit = 0
stats.EmailsLimit = 0
stats.AttachmentTotalSizeLimit = 0
stats.AttachmentFileSizeLimit = 0
info.Basis = "role"
info.MessagesLimit = 0
info.EmailsLimit = 0
info.AttachmentTotalSizeLimit = 0
info.AttachmentFileSizeLimit = 0
} else if v.user != nil && v.user.Plan != nil {
stats.Basis = "plan"
stats.MessagesLimit = v.user.Plan.MessagesLimit
stats.EmailsLimit = v.user.Plan.EmailsLimit
stats.AttachmentTotalSizeLimit = v.user.Plan.AttachmentTotalSizeLimit
stats.AttachmentFileSizeLimit = v.user.Plan.AttachmentFileSizeLimit
info.Basis = "plan"
info.MessagesLimit = v.user.Plan.MessagesLimit
info.EmailsLimit = v.user.Plan.EmailsLimit
info.AttachmentTotalSizeLimit = v.user.Plan.AttachmentTotalSizeLimit
info.AttachmentFileSizeLimit = v.user.Plan.AttachmentFileSizeLimit
} else {
stats.Basis = "ip"
stats.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish)
stats.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish)
stats.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit
stats.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit
info.Basis = "ip"
info.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish)
info.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish)
info.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit
info.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit
}
var attachmentsBytesUsed int64
var err error
@ -200,13 +200,13 @@ func (v *visitor) Stats() (*visitorStats, error) {
if err != nil {
return nil, err
}
stats.Messages = messages
stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.Messages)
stats.Emails = emails
stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.Emails)
stats.AttachmentTotalSize = attachmentsBytesUsed
stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize)
return stats, nil
info.Messages = messages
info.MessagesRemaining = zeroIfNegative(info.MessagesLimit - info.Messages)
info.Emails = emails
info.EmailsRemaining = zeroIfNegative(info.EmailsLimit - info.Emails)
info.AttachmentTotalSize = attachmentsBytesUsed
info.AttachmentTotalSizeRemaining = zeroIfNegative(info.AttachmentTotalSizeLimit - info.AttachmentTotalSize)
return info, nil
}
func zeroIfNegative(value int64) int64 {