WIP: Stripe integration

This commit is contained in:
binwiederhier 2023-01-14 06:43:44 -05:00
parent 7007c0a0bd
commit 01fd4754f9
20 changed files with 557 additions and 43 deletions

View file

@ -44,8 +44,11 @@ const (
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL
attachment_expiry_duration INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tier_id INT,
@ -56,12 +59,16 @@ const (
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
created_by TEXT NOT NULL,
created_at INT NOT NULL,
last_seen INT NOT NULL,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id INT NOT NULL,
topic TEXT NOT NULL,
@ -93,18 +100,24 @@ const (
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
selectTopicPermsQuery = `
SELECT read, write
FROM user_access a
@ -204,9 +217,21 @@ const (
INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTierByCodeQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
)
// Schema management queries
@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) {
return users, nil
}
// User returns the user with the given username if it exists, or ErrNotFound otherwise.
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
// You may also pass Everyone to retrieve the anonymous user and its Grant list.
func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(selectUserByNameQuery, username)
@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) {
return a.readUser(rows)
}
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
if err != nil {
@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrNotFound
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil {
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err
}
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
user.Billing = &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
}
}
if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{
Code: tierCode.String,
Name: tierName.String,
@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String,
}
}
return user, nil
@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error {
return nil
}
func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
return err
}
return nil
}
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(selectTierByCodeQuery, code)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string
var stripePriceID sql.NullString
var paid bool
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
// When changed, note readUser() as well
return &Tier{
Code: code,
Name: name,
Paid: paid,
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, // May be empty!
}, nil
}
func toSQLWildcard(s string) string {
return strings.ReplaceAll(s, "*", "%")
}