WIP: Stripe integration

This commit is contained in:
binwiederhier 2023-01-14 06:43:44 -05:00
parent 7007c0a0bd
commit 01fd4754f9
20 changed files with 557 additions and 43 deletions

View file

@ -103,7 +103,7 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
u, err := manager.User(username) u, err := manager.User(username)
if err == user.ErrNotFound { if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} else if u.Role == user.RoleAdmin { } else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
@ -173,7 +173,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error {
func showUserAccess(c *cli.Context, manager *user.Manager, username string) error { func showUserAccess(c *cli.Context, manager *user.Manager, username string) error {
users, err := manager.User(username) users, err := manager.User(username)
if err == user.ErrNotFound { if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} else if err != nil { } else if err != nil {
return err return err

View file

@ -5,6 +5,7 @@ package cmd
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"io/fs" "io/fs"
"math" "math"
@ -61,7 +62,6 @@ var flagsServe = append(
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-payments", Aliases: []string{"enable_payments"}, EnvVars: []string{"NTFY_ENABLE_PAYMENTS"}, Value: false, Usage: "enables payments integration [preliminary option, may change]"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
@ -80,6 +80,8 @@ var flagsServe = append(
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
) )
var cmdServe = &cli.Command{ var cmdServe = &cli.Command{
@ -132,7 +134,6 @@ func execServe(c *cli.Context) error {
webRoot := c.String("web-root") webRoot := c.String("web-root")
enableSignup := c.Bool("enable-signup") enableSignup := c.Bool("enable-signup")
enableLogin := c.Bool("enable-login") enableLogin := c.Bool("enable-login")
enablePayments := c.Bool("enable-payments")
enableReservations := c.Bool("enable-reservations") enableReservations := c.Bool("enable-reservations")
upstreamBaseURL := c.String("upstream-base-url") upstreamBaseURL := c.String("upstream-base-url")
smtpSenderAddr := c.String("smtp-sender-addr") smtpSenderAddr := c.String("smtp-sender-addr")
@ -152,6 +153,8 @@ func execServe(c *cli.Context) error {
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy") behindProxy := c.Bool("behind-proxy")
stripeKey := c.String("stripe-key")
stripeWebhookKey := c.String("stripe-webhook-key")
// Check values // Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
@ -188,14 +191,17 @@ func execServe(c *cli.Context) error {
return errors.New("if upstream-base-url is set, base-url must also be set") return errors.New("if upstream-base-url is set, base-url must also be set")
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL { } else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications") return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || enablePayments) { } else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or enable-payments if auth-file is not set") return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
} else if enableSignup && !enableLogin { } else if enableSignup && !enableLogin {
return errors.New("cannot set enable-signup without also setting enable-login") return errors.New("cannot set enable-signup without also setting enable-login")
} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
} }
webRootIsApp := webRoot == "app" webRootIsApp := webRoot == "app"
enableWeb := webRoot != "disable" enableWeb := webRoot != "disable"
enablePayments := stripeKey != ""
// Default auth permissions // Default auth permissions
authDefault, err := user.ParsePermission(authDefaultAccess) authDefault, err := user.ParsePermission(authDefaultAccess)
@ -239,6 +245,11 @@ func execServe(c *cli.Context) error {
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...) visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
} }
// Stripe things
if stripeKey != "" {
stripe.Key = stripeKey
}
// Run server // Run server
conf := server.NewConfig() conf := server.NewConfig()
conf.BaseURL = baseURL conf.BaseURL = baseURL
@ -282,6 +293,8 @@ func execServe(c *cli.Context) error {
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy conf.BehindProxy = behindProxy
conf.StripeKey = stripeKey
conf.StripeWebhookKey = stripeWebhookKey
conf.EnableWeb = enableWeb conf.EnableWeb = enableWeb
conf.EnableSignup = enableSignup conf.EnableSignup = enableSignup
conf.EnableLogin = enableLogin conf.EnableLogin = enableLogin

View file

@ -215,7 +215,7 @@ func execUserDel(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == user.ErrNotFound { if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if err := manager.RemoveUser(username); err != nil { if err := manager.RemoveUser(username); err != nil {
@ -237,7 +237,7 @@ func execUserChangePass(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == user.ErrNotFound { if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if password == "" { if password == "" {
@ -265,7 +265,7 @@ func execUserChangeRole(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == user.ErrNotFound { if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if err := manager.ChangeRole(username, role); err != nil { if err := manager.ChangeRole(username, role); err != nil {
@ -289,7 +289,7 @@ func execUserChangeTier(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := manager.User(username); err == user.ErrNotFound { if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} }
if tier == tierReset { if tier == tierReset {

4
go.mod
View file

@ -46,6 +46,10 @@ require (
github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/googleapis/gax-go/v2 v2.7.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/stripe/stripe-go/v74 v74.5.0 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
golang.org/x/net v0.4.0 // indirect golang.org/x/net v0.4.0 // indirect

12
go.sum
View file

@ -95,10 +95,20 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
@ -119,6 +129,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@ -135,6 +146,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View file

@ -110,6 +110,8 @@ type Config struct {
VisitorAccountCreateLimitReplenish time.Duration VisitorAccountCreateLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool BehindProxy bool
StripeKey string
StripeWebhookKey string
EnableWeb bool EnableWeb bool
EnableSignup bool // Enable creation of accounts via API and UI EnableSignup bool // Enable creation of accounts via API and UI
EnableLogin bool EnableLogin bool

View file

@ -58,6 +58,8 @@ var (
errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""} errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""}
errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""} errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""} errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
errHTTPBadRequestInvalidStripeRequest = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View file

@ -36,6 +36,10 @@ import (
/* /*
TODO TODO
payments:
- handle overdue payment (-> downgrade after 7 days)
- delete stripe subscription when acocunt is deleted
Limits & rate limiting: Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful? users without tier: should the stats be persisted? are they meaningful?
-> test that the visitor is based on the IP address! -> test that the visitor is based on the IP address!
@ -43,6 +47,7 @@ import (
update last_seen when API is accessed update last_seen when API is accessed
Make sure account endpoints make sense for admins Make sure account endpoints make sense for admins
triggerChange after publishing a message
UI: UI:
- flicker of upgrade banner - flicker of upgrade banner
- JS constants - JS constants
@ -100,6 +105,11 @@ var (
accountSettingsPath = "/v1/account/settings" accountSettingsPath = "/v1/account/settings"
accountSubscriptionPath = "/v1/account/subscription" accountSubscriptionPath = "/v1/account/subscription"
accountReservationPath = "/v1/account/reservation" accountReservationPath = "/v1/account/reservation"
accountBillingPortalPath = "/v1/account/billing/portal"
accountBillingWebhookPath = "/v1/account/billing/webhook"
accountCheckoutPath = "/v1/account/checkout"
accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}"
accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`)
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
matrixPushPath = "/_matrix/push/v1/notify" matrixPushPath = "/_matrix/push/v1/notify"
@ -362,6 +372,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureUser(s.handleAccountReservationAdd)(w, r, v) return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountReservationDelete)(w, r, v) return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath {
return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v)
} else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w) return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {

View file

@ -2,6 +2,14 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/http" "net/http"
@ -9,6 +17,7 @@ import (
const ( const (
jsonBodyBytesLimit = 4096 jsonBodyBytesLimit = 4096
stripeBodyBytesLimit = 16384
subscriptionIDLength = 16 subscriptionIDLength = 16
createdByAPI = "api" createdByAPI = "api"
) )
@ -386,3 +395,226 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
return nil return nil
} }
func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
}
if tier.StripePriceID == "" {
log.Info("Checkout: Downgrading to no tier")
return errors.New("not a paid tier")
} else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" {
log.Info("Checkout: Changing tier and subscription to %s", tier.Code)
// Upgrade/downgrade tier
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
} else {
// Checkout flow
log.Info("Checkout: No existing subscription, creating checkout flow")
}
successURL := s.config.BaseURL + accountCheckoutSuccessTemplate
var stripeCustomerID *string
if v.user.Billing != nil {
stripeCustomerID = &v.user.Billing.StripeCustomerID
}
params := &stripe.CheckoutSessionParams{
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
Customer: stripeCustomerID,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Quantity: stripe.Int64(1),
},
},
}
sess, err := session.New(params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
sessionID := matches[1]
// FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
return errHTTPBadRequestInvalidStripeRequest
}
priceID := sub.Items.Data[0].Price.ID
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
u, err := s.userManager.User(sess.ClientReferenceID)
if err != nil {
return err
}
if u.Billing == nil {
u.Billing = &user.Billing{}
}
u.Billing.StripeCustomerID = sess.Customer.ID
u.Billing.StripeSubscriptionID = sess.Subscription.ID
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
return nil
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing == nil {
return errHTTPBadRequestNotAPaidUser
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
}
ps, err := portalsession.New(params)
if err != nil {
return err
}
response := &apiAccountBillingPortalRedirectResponse{
RedirectURL: ps.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
if err != nil {
return err
} else if body.LimitReached {
return errHTTPEntityTooLargeJSONBody
}
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
log.Warn("Stripe: invalid request: %s", err.Error())
return errHTTPBadRequestInvalidStripeRequest
} else if event.Data == nil || event.Data.Raw == nil {
log.Warn("Stripe: invalid request, data is nil")
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: webhook event %s received", event.Type)
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
switch event.Type {
case "checkout.session.completed":
// Payment is successful and the subscription is created.
// Provision the subscription, save the customer ID.
return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw)
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
case "invoice.paid":
// Continue to provision the subscription as payments continue to be made.
// Store the status in your database and check when a user accesses your service.
// This approach helps you avoid hitting rate limits.
return nil // FIXME
case "invoice.payment_failed":
// The payment failed or the customer does not have a valid payment method.
// The subscription becomes past_due. Notify your customer and send them to the
// customer portal to update their payment information.
return nil // FIXME
default:
log.Warn("Stripe: unhandled webhook %s", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error {
log.Info("Stripe: checkout completed for customer %s", stripeCustomerID)
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !status.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
if err != nil {
return err
}
tier, err := s.userManager.TierByStripePrice(priceID.String())
if err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
return nil
}

View file

@ -295,3 +295,15 @@ type apiConfigResponse struct {
EnableReservations bool `json:"enable_reservations"` EnableReservations bool `json:"enable_reservations"`
DisallowedTopics []string `json:"disallowed_topics"` DisallowedTopics []string `json:"disallowed_topics"`
} }
type apiAccountTierChangeRequest struct {
Tier string `json:"tier"`
}
type apiAccountCheckoutResponse struct {
RedirectURL string `json:"redirect_url"`
}
type apiAccountBillingPortalRedirectResponse struct {
RedirectURL string `json:"redirect_url"`
}

View file

@ -44,8 +44,11 @@ const (
reservations_limit INT NOT NULL, reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL attachment_expiry_duration INT NOT NULL,
stripe_price_id TEXT
); );
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
tier_id INT, tier_id INT,
@ -56,12 +59,16 @@ const (
sync_topic TEXT NOT NULL, sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0), stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
created_by TEXT NOT NULL, created_by TEXT NOT NULL,
created_at INT NOT NULL, created_at INT NOT NULL,
last_seen INT NOT NULL, last_seen INT NOT NULL,
FOREIGN KEY (tier_id) REFERENCES tier (id) FOREIGN KEY (tier_id) REFERENCES tier (id)
); );
CREATE UNIQUE INDEX idx_user ON user (user); CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access ( CREATE TABLE IF NOT EXISTS user_access (
user_id INT NOT NULL, user_id INT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
@ -93,18 +100,24 @@ const (
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
selectTopicPermsQuery = ` selectTopicPermsQuery = `
SELECT read, write SELECT read, write
FROM user_access a FROM user_access a
@ -205,8 +218,20 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTierByCodeQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
) )
// Schema management queries // Schema management queries
@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) {
return users, nil return users, nil
} }
// User returns the user with the given username if it exists, or ErrNotFound otherwise. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
// You may also pass Everyone to retrieve the anonymous user and its Grant list. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
func (a *Manager) User(username string) (*User, error) { func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(selectUserByNameQuery, username) rows, err := a.db.Query(selectUserByNameQuery, username)
@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) {
return a.readUser(rows) return a.readUser(rows)
} }
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix()) rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
if err != nil { if err != nil {
@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role, prefs, syncTopic string var username, hash, role, prefs, syncTopic string
var tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool var paid sql.NullBool
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil { if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err return nil, err
} }
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
user.Billing = &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
}
}
if tierCode.Valid { if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{ user.Tier = &Tier{
Code: tierCode.String, Code: tierCode.String,
Name: tierName.String, Name: tierName.String,
@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String,
} }
} }
return user, nil return user, nil
@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error {
return nil return nil
} }
func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
return err
}
return nil
}
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(selectTierByCodeQuery, code)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string
var stripePriceID sql.NullString
var paid bool
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
// When changed, note readUser() as well
return &Tier{
Code: code,
Name: name,
Paid: paid,
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, // May be empty!
}, nil
}
func toSQLWildcard(s string) string { func toSQLWildcard(s string) string {
return strings.ReplaceAll(s, "*", "%") return strings.ReplaceAll(s, "*", "%")
} }

View file

@ -208,7 +208,7 @@ func TestManager_UserManagement(t *testing.T) {
// Remove user // Remove user
require.Nil(t, a.RemoveUser("ben")) require.Nil(t, a.RemoveUser("ben"))
_, err = a.User("ben") _, err = a.User("ben")
require.Equal(t, ErrNotFound, err) require.Equal(t, ErrUserNotFound, err)
users, err = a.Users() users, err = a.Users()
require.Nil(t, err) require.Nil(t, err)

View file

@ -16,6 +16,7 @@ type User struct {
Prefs *Prefs Prefs *Prefs
Tier *Tier Tier *Tier
Stats *Stats Stats *Stats
Billing *Billing
SyncTopic string SyncTopic string
Created time.Time Created time.Time
LastSeen time.Time LastSeen time.Time
@ -58,6 +59,7 @@ type Tier struct {
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentExpiryDuration time.Duration AttachmentExpiryDuration time.Duration
StripePriceID string
} }
// Subscription represents a user's topic subscription // Subscription represents a user's topic subscription
@ -81,6 +83,12 @@ type Stats struct {
Emails int64 Emails int64
} }
// Billing is a struct holding a user's billing information
type Billing struct {
StripeCustomerID string
StripeSubscriptionID string
}
// Grant is a struct that represents an access control entry to a topic by a user // Grant is a struct that represents an access control entry to a topic by a user
type Grant struct { type Grant struct {
TopicPattern string // May include wildcard (*) TopicPattern string // May include wildcard (*)
@ -212,5 +220,6 @@ var (
ErrUnauthenticated = errors.New("unauthenticated") ErrUnauthenticated = errors.New("unauthenticated")
ErrUnauthorized = errors.New("unauthorized") ErrUnauthorized = errors.New("unauthorized")
ErrInvalidArgument = errors.New("invalid argument") ErrInvalidArgument = errors.New("invalid argument")
ErrNotFound = errors.New("not found") ErrUserNotFound = errors.New("user not found")
ErrTierNotFound = errors.New("tier not found")
) )

View file

@ -8,7 +8,7 @@ import {
accountTokenUrl, accountTokenUrl,
accountUrl, maybeWithAuth, topicUrl, accountUrl, maybeWithAuth, topicUrl,
withBasicAuth, withBasicAuth,
withBearerAuth withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl
} from "./utils"; } from "./utils";
import session from "./Session"; import session from "./Session";
import subscriptionManager from "./SubscriptionManager"; import subscriptionManager from "./SubscriptionManager";
@ -228,7 +228,7 @@ class AccountApi {
this.triggerChange(); // Dangle! this.triggerChange(); // Dangle!
} }
async upsertAccess(topic, everyone) { async upsertReservation(topic, everyone) {
const url = accountReservationUrl(config.base_url); const url = accountReservationUrl(config.base_url);
console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`); console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`);
const response = await fetch(url, { const response = await fetch(url, {
@ -249,7 +249,7 @@ class AccountApi {
this.triggerChange(); // Dangle! this.triggerChange(); // Dangle!
} }
async deleteAccess(topic) { async deleteReservation(topic) {
const url = accountReservationSingleUrl(config.base_url, topic); const url = accountReservationSingleUrl(config.base_url, topic);
console.log(`[AccountApi] Removing topic reservation ${url}`); console.log(`[AccountApi] Removing topic reservation ${url}`);
const response = await fetch(url, { const response = await fetch(url, {
@ -264,6 +264,39 @@ class AccountApi {
this.triggerChange(); // Dangle! this.triggerChange(); // Dangle!
} }
async createCheckoutSession(tier) {
const url = accountCheckoutUrl(config.base_url);
console.log(`[AccountApi] Creating checkout session`);
const response = await fetch(url, {
method: "POST",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({
tier: tier
})
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async createBillingPortalSession() {
const url = accountBillingPortalUrl(config.base_url);
console.log(`[AccountApi] Creating billing portal session`);
const response = await fetch(url, {
method: "POST",
headers: withBearerAuth({}, session.token())
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async sync() { async sync() {
try { try {
if (!session.token()) { if (!session.token()) {

View file

@ -26,6 +26,8 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr
export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`; export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`;
export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`; export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`;
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`;
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`; export const expandSecureUrl = (url) => `https://${url}`;

View file

@ -171,10 +171,28 @@ const Stats = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { account } = useContext(AccountContext); const { account } = useContext(AccountContext);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false); const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) { if (!account) {
return <></>; return <></>;
} }
const normalize = (value, max) => Math.min(value / max * 100, 100);
const normalize = (value, max) => {
return Math.min(value / max * 100, 100);
};
const handleManageBilling = async () => {
try {
const response = await accountApi.createBillingPortalSession();
window.location.href = response.redirect_url;
} catch (e) {
console.log(`[Account] Error changing password`, e);
if ((e instanceof UnauthorizedError)) {
session.resetAndRedirect(routes.login);
}
// TODO show error
}
};
return ( return (
<Card sx={{p: 3}} aria-label={t("account_usage_title")}> <Card sx={{p: 3}} aria-label={t("account_usage_title")}>
<Typography variant="h5" sx={{marginBottom: 2}}> <Typography variant="h5" sx={{marginBottom: 2}}>
@ -201,12 +219,20 @@ const Stats = () => {
>{t("account_usage_tier_upgrade_button")}</Button> >{t("account_usage_tier_upgrade_button")}</Button>
} }
{config.enable_payments && account.role === "user" && account.tier?.paid && {config.enable_payments && account.role === "user" && account.tier?.paid &&
<>
<Button <Button
variant="outlined" variant="outlined"
size="small" size="small"
onClick={() => setUpgradeDialogOpen(true)} onClick={() => setUpgradeDialogOpen(true)}
sx={{ml: 1}} sx={{ml: 1}}
>{t("account_usage_tier_change_button")}</Button> >{t("account_usage_tier_change_button")}</Button>
<Button
variant="outlined"
size="small"
onClick={handleManageBilling}
sx={{ml: 1}}
>Manage billing</Button>
</>
} }
<UpgradeDialog <UpgradeDialog
open={upgradeDialogOpen} open={upgradeDialogOpen}

View file

@ -501,7 +501,7 @@ const Reservations = () => {
const handleDialogSubmit = async (reservation) => { const handleDialogSubmit = async (reservation) => {
setDialogOpen(false); setDialogOpen(false);
try { try {
await accountApi.upsertAccess(reservation.topic, reservation.everyone); await accountApi.upsertReservation(reservation.topic, reservation.everyone);
await accountApi.sync(); await accountApi.sync();
console.debug(`[Preferences] Added topic reservation`, reservation); console.debug(`[Preferences] Added topic reservation`, reservation);
} catch (e) { } catch (e) {
@ -557,7 +557,7 @@ const ReservationsTable = (props) => {
const handleDialogSubmit = async (reservation) => { const handleDialogSubmit = async (reservation) => {
setDialogOpen(false); setDialogOpen(false);
try { try {
await accountApi.upsertAccess(reservation.topic, reservation.everyone); await accountApi.upsertReservation(reservation.topic, reservation.everyone);
await accountApi.sync(); await accountApi.sync();
console.debug(`[Preferences] Added topic reservation`, reservation); console.debug(`[Preferences] Added topic reservation`, reservation);
} catch (e) { } catch (e) {
@ -568,7 +568,7 @@ const ReservationsTable = (props) => {
const handleDeleteClick = async (reservation) => { const handleDeleteClick = async (reservation) => {
try { try {
await accountApi.deleteAccess(reservation.topic); await accountApi.deleteReservation(reservation.topic);
await accountApi.sync(); await accountApi.sync();
console.debug(`[Preferences] Deleted topic reservation`, reservation); console.debug(`[Preferences] Deleted topic reservation`, reservation);
} catch (e) { } catch (e) {

View file

@ -110,7 +110,7 @@ const SubscribePage = (props) => {
if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) { if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) {
console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`); console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`);
try { try {
await accountApi.upsertAccess(topic, everyone); await accountApi.upsertReservation(topic, everyone);
// Account sync later after it was added // Account sync later after it was added
} catch (e) { } catch (e) {
console.log(`[SubscribeDialog] Error reserving topic`, e); console.log(`[SubscribeDialog] Error reserving topic`, e);

View file

@ -37,9 +37,9 @@ const SubscriptionSettingsDialog = (props) => {
// Reservation // Reservation
if (reserveTopicVisible) { if (reserveTopicVisible) {
await accountApi.upsertAccess(subscription.topic, everyone); await accountApi.upsertReservation(subscription.topic, everyone);
} else if (!reserveTopicVisible && subscription.reservation) { // Was removed } else if (!reserveTopicVisible && subscription.reservation) { // Was removed
await accountApi.deleteAccess(subscription.topic); await accountApi.deleteReservation(subscription.topic);
} }
// Sync account // Sync account

View file

@ -2,28 +2,83 @@ import * as React from 'react';
import Dialog from '@mui/material/Dialog'; import Dialog from '@mui/material/Dialog';
import DialogContent from '@mui/material/DialogContent'; import DialogContent from '@mui/material/DialogContent';
import DialogTitle from '@mui/material/DialogTitle'; import DialogTitle from '@mui/material/DialogTitle';
import {useMediaQuery} from "@mui/material"; import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
import theme from "./theme"; import theme from "./theme";
import DialogFooter from "./DialogFooter"; import DialogFooter from "./DialogFooter";
import Button from "@mui/material/Button";
import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
import session from "../app/Session";
import routes from "./routes";
import {useContext, useState} from "react";
import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography";
import {AccountContext} from "./App";
const UpgradeDialog = (props) => { const UpgradeDialog = (props) => {
const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [selected, setSelected] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState("");
const handleSuccess = async () => { const handleCheckout = async () => {
// TODO try {
const response = await accountApi.createCheckoutSession(selected);
if (response.redirect_url) {
window.location.href = response.redirect_url;
} else {
await accountApi.sync();
}
} catch (e) {
console.log(`[UpgradeDialog] Error creating checkout session`, e);
if ((e instanceof UnauthorizedError)) {
session.resetAndRedirect(routes.login);
}
// FIXME show error
}
} }
return ( return (
<Dialog open={props.open} onClose={props.onCancel} fullScreen={fullScreen}> <Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
<DialogTitle>Upgrade to Pro</DialogTitle> <DialogTitle>Upgrade to Pro</DialogTitle>
<DialogContent> <DialogContent>
Content <div style={{
display: "flex",
flexDirection: "row"
}}>
<TierCard code={null} name={"Free"} selected={selected === null} onClick={() => setSelected(null)}/>
<TierCard code="starter" name={"Starter"} selected={selected === "starter"} onClick={() => setSelected("starter")}/>
<TierCard code="pro" name={"Pro"} selected={selected === "pro"} onClick={() => setSelected("pro")}/>
<TierCard code="business" name={"Business"} selected={selected === "business"} onClick={() => setSelected("business")}/>
</div>
</DialogContent> </DialogContent>
<DialogFooter> <DialogFooter status={errorText}>
Footer <Button onClick={handleCheckout}>Checkout</Button>
</DialogFooter> </DialogFooter>
</Dialog> </Dialog>
); );
}; };
const TierCard = (props) => {
const cardStyle = (props.selected) ? {
border: "1px solid red",
} : {};
return (
<Card sx={{ m: 1, maxWidth: 345 }}>
<CardActionArea>
<CardContent sx={{...cardStyle}} onClick={props.onClick}>
<Typography gutterBottom variant="h5" component="div">
{props.name}
</Typography>
<Typography variant="body2" color="text.secondary">
Lizards are a widespread group of squamate reptiles, with over 6,000
species, ranging across all continents except Antarctica
</Typography>
</CardContent>
</CardActionArea>
</Card>
);
}
export default UpgradeDialog; export default UpgradeDialog;