Payment stuff, cont'd

This commit is contained in:
binwiederhier 2023-01-15 23:29:46 -05:00
parent f7f7f469ad
commit c06bfb989e
11 changed files with 457 additions and 309 deletions

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/stripe/stripe-go/v74"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
@ -60,7 +61,9 @@ const (
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
created_by TEXT NOT NULL,
created_at INT NOT NULL,
last_seen INT NOT NULL,
@ -100,20 +103,20 @@ const (
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -231,7 +234,11 @@ const (
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
updateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?
WHERE user = ?
`
)
// Schema management queries
@ -597,14 +604,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -619,16 +626,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Messages: messages,
Emails: emails,
},
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
},
}
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err
}
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
user.Billing = &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
}
}
if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{
@ -868,7 +875,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
}
func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), user.Name); err != nil {
return err
}
return nil
@ -1020,3 +1027,17 @@ func migrateFrom1(db *sql.DB) error {
}
return nil // Update this when a new version is added
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
}
return sql.NullString{String: s, Valid: true}
}
func nullInt64(v int64) sql.NullInt64 {
if v == 0 {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: v, Valid: true}
}