Introduce text IDs for everything (esp user), to avoid security and accounting issues

This commit is contained in:
binwiederhier 2023-01-21 23:15:22 -05:00
parent 88abd8872d
commit 9c082a8331
13 changed files with 160 additions and 108 deletions

View file

@ -98,8 +98,8 @@ const (
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
) )
// Schema management queries // Schema management queries
@ -563,8 +563,8 @@ func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error)
return c.readAttachmentBytesUsed(rows) return c.readAttachmentBytesUsed(rows)
} }
func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) { func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix()) rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -12,10 +12,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var (
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
)
func TestSqliteCache_Messages(t *testing.T) { func TestSqliteCache_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t)) testCacheMessages(t, newSqliteTestCache(t))
} }
@ -294,10 +290,10 @@ func TestMemCache_Attachments(t *testing.T) {
} }
func testCacheAttachments(t *testing.T, c *messageCache) { func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
m.Sender = exampleIP1234 m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "flower.jpg", Name: "flower.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -310,7 +306,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car") m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2" m.ID = "m2"
m.Sender = exampleIP1234 m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "car.jpg", Name: "car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -323,7 +319,8 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car") m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3" m.ID = "m3"
m.Sender = exampleIP1234 m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &attachment{ m.Attachment = &attachment{
Name: "another-car.jpg", Name: "another-car.jpg",
Type: "image/jpeg", Type: "image/jpeg",
@ -355,11 +352,15 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
size, err := c.AttachmentBytesUsedBySender("1.2.3.4") size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(30000), size) require.Equal(t, int64(10000), size)
size, err = c.AttachmentBytesUsedBySender("5.6.7.8") size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), size) require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
} }
func TestSqliteCache_Attachments_Expired(t *testing.T) { func TestSqliteCache_Attachments_Expired(t *testing.T) {

View file

@ -38,12 +38,13 @@ import (
TODO TODO
-- --
- Security: Account re-creation leads to terrible behavior. Use user ID instead of user name for (a) visitor map, (b) messages.user column, (c) Stripe checkout session
- Reservation: Kill existing subscribers when topic is reserved (deadcade) - Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade) - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Logging: Add detailed logging with username/customerID for all Stripe events (phil) - Logging: Add detailed logging with username/customerID for all Stripe events (phil)
- Rate limiting: Sensitive endpoints (account/login/change-password/...) - Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
- Stripe: Add metadata to customer
races: races:
- v.user --> see publishSyncEventAsync() test - v.user --> see publishSyncEventAsync() test
@ -581,7 +582,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m = newPollRequestMessage(t.ID, m.PollID) m = newPollRequestMessage(t.ID, m.PollID)
} }
if v.user != nil { if v.user != nil {
m.User = v.user.Name m.User = v.user.ID
} }
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix() m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
@ -859,6 +860,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Time > attachmentExpiry { if m.Time > attachmentExpiry {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
} }
fmt.Printf("v = %#v\nlimits = %#v\nstats = %#v\n", v, vinfo.Limits, vinfo.Stats)
contentLengthStr := r.Header.Get("Content-Length") contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
@ -885,6 +887,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
} }
fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
if err == util.ErrLimitReached { if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachment return errHTTPEntityTooLargeAttachment
@ -1657,7 +1660,7 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
} }
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor { func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user) return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
} }
func (s *Server) writeJSON(w http.ResponseWriter, v any) error { func (s *Server) writeJSON(w http.ResponseWriter, v any) error {

View file

@ -337,7 +337,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPTooManyRequestsLimitReservations return errHTTPTooManyRequestsLimitReservations
} }
} }
if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil { if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())

View file

@ -212,7 +212,7 @@ func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)

View file

@ -128,7 +128,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{ params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name, ClientReferenceID: &v.user.ID,
SuccessURL: &successURL, SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
AllowPromotionCodes: stripe.Bool(true), AllowPromotionCodes: stripe.Bool(true),
@ -178,7 +178,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil { if err != nil {
return err return err
} }
u, err := s.userManager.User(sess.ClientReferenceID) u, err := s.userManager.UserByID(sess.ClientReferenceID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -176,8 +176,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll)) require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll)) require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
// Add billing details // Add billing details
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")

View file

@ -830,7 +830,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) { func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.VisitorRequestLimitBurst = 60 c.VisitorRequestLimitBurst = 60
c.VisitorRequestLimitReplenish = 500 * time.Millisecond c.VisitorRequestLimitReplenish = time.Second
s := newTestServer(t, c) s := newTestServer(t, c)
for i := 0; i < 60; i++ { for i := 0; i < 60; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@ -839,7 +839,7 @@ func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
response := request(t, s, "PUT", "/mytopic", "message", nil) response := request(t, s, "PUT", "/mytopic", "message", nil)
require.Equal(t, 429, response.Code) require.Equal(t, 429, response.Code)
time.Sleep(520 * time.Millisecond) time.Sleep(1020 * time.Millisecond)
response = request(t, s, "PUT", "/mytopic", "message", nil) response = request(t, s, "PUT", "/mytopic", "message", nil)
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
} }

View file

@ -241,7 +241,7 @@ func (v *visitor) Info() (*visitorInfo, error) {
var attachmentsBytesUsed int64 var attachmentsBytesUsed int64
var err error var err error
if v.user != nil { if v.user != nil {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name) attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.ID)
} else { } else {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
} }

View file

@ -16,13 +16,19 @@ import (
) )
const ( const (
bcryptCost = 10 tierIDPrefix = "ti_"
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost tierIDLength = 8
userStatsQueueWriterInterval = 33 * time.Second syncTopicPrefix = "st_"
tokenLength = 32 syncTopicLength = 16
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much userIDPrefix = "u_"
syncTopicLength = 16 userIDLength = 12
tokenMaxCount = 10 // Only keep this many tokens in the table per user userPasswordBcryptCost = 10
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
userStatsQueueWriterInterval = 33 * time.Second
tokenPrefix = "tk_"
tokenLength = 32
tokenMaxCount = 10 // Only keep this many tokens in the table per user
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
) )
var ( var (
@ -35,7 +41,7 @@ var (
const ( const (
createTablesQueriesNoTx = ` createTablesQueriesNoTx = `
CREATE TABLE IF NOT EXISTS tier ( CREATE TABLE IF NOT EXISTS tier (
id INTEGER PRIMARY KEY AUTOINCREMENT, id TEXT PRIMARY KEY,
code TEXT NOT NULL, code TEXT NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
messages_limit INT NOT NULL, messages_limit INT NOT NULL,
@ -50,7 +56,7 @@ const (
CREATE UNIQUE INDEX idx_tier_code ON tier (code); CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT, id TEXT PRIMARY KEY,
tier_id INT, tier_id INT,
user TEXT NOT NULL, user TEXT NOT NULL,
pass TEXT NOT NULL, pass TEXT NOT NULL,
@ -72,7 +78,7 @@ const (
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access ( CREATE TABLE IF NOT EXISTS user_access (
user_id INT NOT NULL, user_id TEXT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
read INT NOT NULL, read INT NOT NULL,
write INT NOT NULL, write INT NOT NULL,
@ -82,7 +88,7 @@ const (
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
); );
CREATE TABLE IF NOT EXISTS user_token ( CREATE TABLE IF NOT EXISTS user_token (
user_id INT NOT NULL, user_id TEXT NOT NULL,
token TEXT NOT NULL, token TEXT NOT NULL,
expires INT NOT NULL, expires INT NOT NULL,
PRIMARY KEY (user_id, token), PRIMARY KEY (user_id, token),
@ -93,7 +99,7 @@ const (
version INT NOT NULL version INT NOT NULL
); );
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH()) VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING; ON CONFLICT (id) DO NOTHING;
` `
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -101,21 +107,27 @@ const (
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
` `
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -129,8 +141,8 @@ const (
` `
insertUserQuery = ` insertUserQuery = `
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at) INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
` `
selectUsernamesQuery = ` selectUsernamesQuery = `
SELECT user SELECT user
@ -145,7 +157,7 @@ const (
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?` updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?` updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
deleteUserQuery = `DELETE FROM user WHERE user = ?` deleteUserQuery = `DELETE FROM user WHERE user = ?`
@ -199,8 +211,8 @@ const (
AND topic = ? AND topic = ?
` `
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)` selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
@ -209,27 +221,27 @@ const (
WHERE (user_id, token) NOT IN ( WHERE (user_id, token) NOT IN (
SELECT user_id, token SELECT user_id, token
FROM user_token FROM user_token
WHERE user_id = (SELECT id FROM user WHERE user = ?) WHERE user_id = ?
ORDER BY expires DESC ORDER BY expires DESC
LIMIT ? LIMIT ?
) )
` `
insertTierQuery = ` insertTierQuery = `
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
selectTiersQuery = ` selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier FROM tier
` `
selectTierByCodeQuery = ` selectTierByCodeQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier FROM tier
WHERE code = ? WHERE code = ?
` `
selectTierByPriceIDQuery = ` selectTierByPriceIDQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier FROM tier
WHERE stripe_price_id = ? WHERE stripe_price_id = ?
` `
@ -254,10 +266,12 @@ const (
migrate1To2RenameUserTableQueryNoTx = ` migrate1To2RenameUserTableQueryNoTx = `
ALTER TABLE user RENAME TO user_old; ALTER TABLE user RENAME TO user_old;
` `
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = ` migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old;
INSERT INTO user_access (user_id, topic, read, write) INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write SELECT u.id, a.topic, a.read, a.write
FROM user u FROM user u
@ -266,8 +280,7 @@ const (
DROP TABLE access; DROP TABLE access;
DROP TABLE user_old; DROP TABLE user_old;
` `
migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user` migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
) )
// Manager is an implementation of Manager. It stores users and access control list // Manager is an implementation of Manager. It stores users and access control list
@ -317,7 +330,7 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
user, err := a.User(username) user, err := a.User(username)
if err != nil { if err != nil {
log.Trace("authentication of user %s failed (1): %s", username, err.Error()) log.Trace("authentication of user %s failed (1): %s", username, err.Error())
bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} }
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
@ -345,16 +358,16 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
// given user, if there are too many of them. // given user, if there are too many of them.
func (a *Manager) CreateToken(user *User) (*Token, error) { func (a *Manager) CreateToken(user *User) (*Token, error) {
token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration) token, expires := util.RandomStringPrefix(tokenPrefix, tokenLength), time.Now().Add(tokenExpiryDuration)
tx, err := a.db.Begin() tx, err := a.db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { if _, err := tx.Exec(insertTokenQuery, user.ID, token, expires.Unix()); err != nil {
return nil, err return nil, err
} }
rows, err := tx.Query(selectTokenCountQuery, user.Name) rows, err := tx.Query(selectTokenCountQuery, user.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -369,7 +382,7 @@ func (a *Manager) CreateToken(user *User) (*Token, error) {
if tokenCount >= tokenMaxCount { if tokenCount >= tokenMaxCount {
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
// on two indices, whereas the query below is a full table scan. // on two indices, whereas the query below is a full table scan.
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil { if _, err := tx.Exec(deleteExcessTokensQuery, user.ID, tokenMaxCount); err != nil {
return nil, err return nil, err
} }
} }
@ -444,7 +457,7 @@ func (a *Manager) ResetStats() error {
func (a *Manager) EnqueueStats(user *User) { func (a *Manager) EnqueueStats(user *User) {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
a.statsQueue[user.Name] = user a.statsQueue[user.ID] = user
} }
func (a *Manager) userStatsQueueWriter(interval time.Duration) { func (a *Manager) userStatsQueueWriter(interval time.Duration) {
@ -472,9 +485,9 @@ func (a *Manager) writeUserStatsQueue() error {
} }
defer tx.Rollback() defer tx.Rollback()
log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue)) log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
for username, u := range statsQueue { for userID, u := range statsQueue {
log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails) log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails)
if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil { if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil {
return err return err
} }
} }
@ -524,12 +537,13 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
if err != nil { if err != nil {
return err return err
} }
syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix() userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil { syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil {
return err return err
} }
return nil return nil
@ -587,6 +601,15 @@ func (a *Manager) User(username string) (*User, error) {
return a.readUser(rows) return a.readUser(rows)
} }
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByID(id string) (*User, error) {
rows, err := a.db.Query(selectUserByIDQuery, id)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID) rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
@ -606,19 +629,20 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
} }
user := &User{ user := &User{
ID: id,
Name: username, Name: username,
Hash: hash, Hash: hash,
Role: Role(role), Role: Role(role),
@ -744,7 +768,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
// ChangePassword changes a user's password // ChangePassword changes a user's password
func (a *Manager) ChangePassword(username, password string) error { func (a *Manager) ChangePassword(username, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
if err != nil { if err != nil {
return err return err
} }
@ -818,6 +842,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
// CheckAllowAccess tests if a user may create an access control entry for the given topic. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
// If there are any ACL entries that are not owned by the user, an error is returned. // If there are any ACL entries that are not owned by the user, an error is returned.
// FIXME is this the same as HasReservation?
func (a *Manager) CheckAllowAccess(username string, topic string) error { func (a *Manager) CheckAllowAccess(username string, topic string) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
return ErrInvalidArgument return ErrInvalidArgument
@ -856,24 +881,6 @@ func (a *Manager) AllowAccess(username string, topicPattern string, permission P
return nil return nil
} }
func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*). // empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error { func (a *Manager) ResetAccess(username string, topicPattern string) error {
@ -893,6 +900,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error {
return err return err
} }
// AddReservation creates two access control entries for the given topic: one with full read/write access for the
// given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
// can modify or delete them.
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// RemoveReservations deletes the access control entries associated with the given username/topic, as
// well as all entries with Everyone/topic. This is the counterpart for AddReservation.
func (a *Manager) RemoveReservations(username string, topics ...string) error { func (a *Manager) RemoveReservations(username string, topics ...string) error {
if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
return ErrInvalidArgument return ErrInvalidArgument
@ -925,7 +955,8 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database // CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error { func (a *Manager) CreateTier(tier *Tier) error {
if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil { tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
return err return err
} }
return nil return nil
@ -980,19 +1011,20 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
} }
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var code, name string var id, code, name string
var stripePriceID sql.NullString var stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrTierNotFound return nil, ErrTierNotFound
} }
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
} }
// When changed, note readUser() as well // When changed, note readUser() as well
return &Tier{ return &Tier{
ID: id,
Code: code, Code: code,
Name: name, Name: name,
Paid: stripePriceID.Valid, // If there is a price, it's a paid tier Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
@ -1069,36 +1101,41 @@ func migrateFrom1(db *sql.DB) error {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil { if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
return err return err
} }
if _, err := tx.Exec(createTablesQueriesNoTx); err != nil { if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
return err return err
} }
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { // Insert users from user_old into new user table, with ID and sync_topic
return err rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
}
rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close() defer rows.Close()
syncTopics := make(map[int]string) usernames := make([]string, 0)
for rows.Next() { for rows.Next() {
var userID int var username string
if err := rows.Scan(&userID); err != nil { if err := rows.Scan(&username); err != nil {
return err return err
} }
syncTopics[userID] = util.RandomString(syncTopicLength) usernames = append(usernames, username)
} }
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return err return err
} }
for userID, syncTopic := range syncTopics { for _, username := range usernames {
if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil { userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err return err
} }
} }
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
return err return err
} }

View file

@ -259,8 +259,8 @@ func TestManager_ChangeRole(t *testing.T) {
func TestManager_Reservations(t *testing.T) { func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
reservations, err := a.Reservations("ben") reservations, err := a.Reservations("ben")
@ -294,7 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
})) }))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ChangeTier("ben", "pro")) require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
ben, err := a.User("ben") ben, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
@ -626,11 +626,14 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
everyoneGrants, err := a.Grants(Everyone) everyoneGrants, err := a.Grants(Everyone)
require.Nil(t, err) require.Nil(t, err)
require.True(t, strings.HasPrefix(phil.ID, "u_"))
require.Equal(t, "phil", phil.Name) require.Equal(t, "phil", phil.Name)
require.Equal(t, RoleAdmin, phil.Role) require.Equal(t, RoleAdmin, phil.Role)
require.Equal(t, syncTopicLength, len(phil.SyncTopic)) require.Equal(t, syncTopicLength, len(phil.SyncTopic))
require.Equal(t, 0, len(philGrants)) require.Equal(t, 0, len(philGrants))
require.True(t, strings.HasPrefix(ben.ID, "u_"))
require.NotEqual(t, phil.ID, ben.ID)
require.Equal(t, "ben", ben.Name) require.Equal(t, "ben", ben.Name)
require.Equal(t, RoleUser, ben.Role) require.Equal(t, RoleUser, ben.Role)
require.Equal(t, syncTopicLength, len(ben.SyncTopic)) require.Equal(t, syncTopicLength, len(ben.SyncTopic))
@ -641,6 +644,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Equal(t, "secret", benGrants[1].TopicPattern) require.Equal(t, "secret", benGrants[1].TopicPattern)
require.Equal(t, PermissionRead, benGrants[1].Allow) require.Equal(t, PermissionRead, benGrants[1].Allow)
require.Equal(t, "u_everyone", everyone.ID)
require.Equal(t, Everyone, everyone.Name) require.Equal(t, Everyone, everyone.Name)
require.Equal(t, RoleAnonymous, everyone.Role) require.Equal(t, RoleAnonymous, everyone.Role)
require.Equal(t, 1, len(everyoneGrants)) require.Equal(t, 1, len(everyoneGrants))

View file

@ -10,6 +10,7 @@ import (
// User is a struct that represents a user // User is a struct that represents a user
type User struct { type User struct {
ID string
Name string Name string
Hash string // password hash (bcrypt) Hash string // password hash (bcrypt)
Token string // Only set if token was used to log in Token string // Only set if token was used to log in
@ -50,6 +51,7 @@ type Prefs struct {
// Tier represents a user's account type, including its account limits // Tier represents a user's account type, including its account limits
type Tier struct { type Tier struct {
ID string
Code string Code string
Name string Name string
Paid bool Paid bool

View file

@ -107,13 +107,18 @@ func LastString(s []string, def string) string {
// RandomString returns a random string with a given length // RandomString returns a random string with a given length
func RandomString(length int) string { func RandomString(length int) string {
return RandomStringPrefix("", length)
}
// RandomStringPrefix returns a random string with a given length, with a prefix
func RandomStringPrefix(prefix string, length int) string {
randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?! randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!
defer randomMutex.Unlock() defer randomMutex.Unlock()
b := make([]byte, length) b := make([]byte, length-len(prefix))
for i := range b { for i := range b {
b[i] = randomStringCharset[random.Intn(len(randomStringCharset))] b[i] = randomStringCharset[random.Intn(len(randomStringCharset))]
} }
return string(b) return prefix + string(b)
} }
// ValidRandomString returns true if the given string matches the format created by RandomString // ValidRandomString returns true if the given string matches the format created by RandomString