Allow mocking the Stripe API
This commit is contained in:
parent
3bd6518309
commit
4e51a715c1
6 changed files with 224 additions and 29 deletions
1
go.mod
1
go.mod
|
@ -49,6 +49,7 @@ require (
|
||||||
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.0 // indirect
|
||||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
golang.org/x/net v0.4.0 // indirect
|
golang.org/x/net v0.4.0 // indirect
|
||||||
|
|
1
go.sum
1
go.sum
|
@ -94,6 +94,7 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
|
|
@ -37,8 +37,6 @@ import (
|
||||||
/*
|
/*
|
||||||
TODO
|
TODO
|
||||||
payments:
|
payments:
|
||||||
- send dunning emails when overdue
|
|
||||||
- payment methods
|
|
||||||
- delete subscription when account deleted
|
- delete subscription when account deleted
|
||||||
- delete messages + reserved topics on ResetTier
|
- delete messages + reserved topics on ResetTier
|
||||||
|
|
||||||
|
@ -76,9 +74,10 @@ type Server struct {
|
||||||
visitors map[string]*visitor // ip:<ip> or user:<user>
|
visitors map[string]*visitor // ip:<ip> or user:<user>
|
||||||
firebaseClient *firebaseClient
|
firebaseClient *firebaseClient
|
||||||
messages int64
|
messages int64
|
||||||
userManager *user.Manager // Might be nil!
|
userManager *user.Manager // Might be nil!
|
||||||
messageCache *messageCache
|
messageCache *messageCache // Database that stores the messages
|
||||||
fileCache *fileCache
|
fileCache *fileCache // File system based cache that stores attachments
|
||||||
|
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||||
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
|
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -160,6 +159,10 @@ func New(conf *Config) (*Server, error) {
|
||||||
if conf.SMTPSenderAddr != "" {
|
if conf.SMTPSenderAddr != "" {
|
||||||
mailer = &smtpSender{config: conf}
|
mailer = &smtpSender{config: conf}
|
||||||
}
|
}
|
||||||
|
var stripe stripeAPI
|
||||||
|
if conf.StripeSecretKey != "" {
|
||||||
|
stripe = newStripeAPI()
|
||||||
|
}
|
||||||
messageCache, err := createMessageCache(conf)
|
messageCache, err := createMessageCache(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -190,7 +193,7 @@ func New(conf *Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
firebaseClient = newFirebaseClient(sender, userManager)
|
firebaseClient = newFirebaseClient(sender, userManager)
|
||||||
}
|
}
|
||||||
return &Server{
|
s := &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
fileCache: fileCache,
|
fileCache: fileCache,
|
||||||
|
@ -199,8 +202,10 @@ func New(conf *Config) (*Server, error) {
|
||||||
topics: topics,
|
topics: topics,
|
||||||
userManager: userManager,
|
userManager: userManager,
|
||||||
visitors: make(map[string]*visitor),
|
visitors: make(map[string]*visitor),
|
||||||
priceCache: util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration),
|
stripe: stripe,
|
||||||
}, nil
|
}
|
||||||
|
s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config) (*messageCache, error) {
|
func createMessageCache(conf *Config) (*messageCache, error) {
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
|
||||||
|
|
||||||
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
|
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if s.config.StripeSecretKey == "" {
|
if s.config.StripeSecretKey == "" || s.stripe == nil {
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
var stripeCustomerID *string
|
var stripeCustomerID *string
|
||||||
if v.user.Billing.StripeCustomerID != "" {
|
if v.user.Billing.StripeCustomerID != "" {
|
||||||
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
||||||
stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
|
stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
|
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
|
||||||
|
@ -120,7 +120,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
Enabled: stripe.Bool(true),
|
Enabled: stripe.Bool(true),
|
||||||
},*/
|
},*/
|
||||||
}
|
}
|
||||||
sess, err := session.New(params)
|
sess, err := s.stripe.NewCheckoutSession(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -137,14 +137,14 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||||
return errHTTPInternalErrorInvalidPath
|
return errHTTPInternalErrorInvalidPath
|
||||||
}
|
}
|
||||||
sessionID := matches[1]
|
sessionID := matches[1]
|
||||||
sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
|
sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("Stripe: %s", err)
|
log.Warn("Stripe: %s", err)
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
|
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
|
||||||
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
|
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
|
||||||
}
|
}
|
||||||
sub, err := subscription.Get(sess.Subscription.ID, nil)
|
sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
|
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
|
||||||
|
@ -180,7 +180,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
|
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
|
||||||
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
|
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -194,7 +194,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, err = subscription.Update(sub.ID, params)
|
_, err = s.stripe.UpdateSubscription(sub.ID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
||||||
params := &stripe.SubscriptionParams{
|
params := &stripe.SubscriptionParams{
|
||||||
CancelAtPeriodEnd: stripe.Bool(true),
|
CancelAtPeriodEnd: stripe.Bool(true),
|
||||||
}
|
}
|
||||||
_, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
|
_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -224,7 +224,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
||||||
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
||||||
ReturnURL: stripe.String(s.config.BaseURL),
|
ReturnURL: stripe.String(s.config.BaseURL),
|
||||||
}
|
}
|
||||||
ps, err := portalsession.New(params)
|
ps, err := s.stripe.NewPortalSession(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -248,7 +248,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
||||||
} else if body.LimitReached {
|
} else if body.LimitReached {
|
||||||
return errHTTPEntityTooLargeJSONBody
|
return errHTTPEntityTooLargeJSONBody
|
||||||
}
|
}
|
||||||
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
} else if event.Data == nil || event.Data.Raw == nil {
|
} else if event.Data == nil || event.Data.Raw == nil {
|
||||||
|
@ -331,24 +331,82 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
||||||
|
|
||||||
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
||||||
// in memory, and ultimately for the web app to display the price table.
|
// in memory, and ultimately for the web app to display the price table.
|
||||||
func fetchStripePrices() (map[string]string, error) {
|
func (s *Server) fetchStripePrices() (map[string]string, error) {
|
||||||
log.Debug("Caching prices from Stripe API")
|
log.Debug("Caching prices from Stripe API")
|
||||||
prices := make(map[string]string)
|
priceMap := make(map[string]string)
|
||||||
iter := price.List(&stripe.PriceListParams{
|
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
|
||||||
Active: stripe.Bool(true),
|
if err != nil {
|
||||||
})
|
log.Warn("Fetching Stripe prices failed: %s", err.Error())
|
||||||
for iter.Next() {
|
return nil, err
|
||||||
p := iter.Price()
|
}
|
||||||
|
for _, p := range prices {
|
||||||
if p.UnitAmount%100 == 0 {
|
if p.UnitAmount%100 == 0 {
|
||||||
prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||||
} else {
|
} else {
|
||||||
prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||||
}
|
}
|
||||||
log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
|
||||||
|
}
|
||||||
|
return priceMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripeAPI is a small interface to facilitate mocking of the Stripe API
|
||||||
|
type stripeAPI interface {
|
||||||
|
NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
|
||||||
|
NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
|
||||||
|
ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
|
||||||
|
GetCustomer(id string) (*stripe.Customer, error)
|
||||||
|
GetSession(id string) (*stripe.CheckoutSession, error)
|
||||||
|
GetSubscription(id string) (*stripe.Subscription, error)
|
||||||
|
UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
|
||||||
|
ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
|
||||||
|
type realStripeAPI struct{}
|
||||||
|
|
||||||
|
var _ stripeAPI = (*realStripeAPI)(nil)
|
||||||
|
|
||||||
|
func newStripeAPI() stripeAPI {
|
||||||
|
return &realStripeAPI{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
|
||||||
|
return session.New(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
|
||||||
|
return portalsession.New(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
|
||||||
|
prices := make([]*stripe.Price, 0)
|
||||||
|
iter := price.List(params)
|
||||||
|
for iter.Next() {
|
||||||
|
prices = append(prices, iter.Price())
|
||||||
}
|
}
|
||||||
if iter.Err() != nil {
|
if iter.Err() != nil {
|
||||||
log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
|
||||||
return nil, iter.Err()
|
return nil, iter.Err()
|
||||||
}
|
}
|
||||||
return prices, nil
|
return prices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
|
||||||
|
return customer.Get(id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
|
||||||
|
return session.Get(id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
|
||||||
|
return subscription.Get(id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
|
||||||
|
return subscription.Update(id, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
|
||||||
|
return webhook.ConstructEvent(payload, header, secret)
|
||||||
|
}
|
||||||
|
|
130
server/server_payments_test.go
Normal file
130
server/server_payments_test.go
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stripe/stripe-go/v74"
|
||||||
|
"heckel.io/ntfy/user"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
||||||
|
stripeMock := &testStripeAPI{}
|
||||||
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.StripeSecretKey = "secret key"
|
||||||
|
c.StripeWebhookKey = "webhook key"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
s.stripe = stripeMock
|
||||||
|
|
||||||
|
// Define how the mock should react
|
||||||
|
stripeMock.
|
||||||
|
On("NewCheckoutSession", mock.Anything).
|
||||||
|
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
|
||||||
|
|
||||||
|
// Create tier and user
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
StripePriceID: "price_123",
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||||
|
|
||||||
|
// Create subscription
|
||||||
|
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
||||||
|
stripeMock := &testStripeAPI{}
|
||||||
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.StripeSecretKey = "secret key"
|
||||||
|
c.StripeWebhookKey = "webhook key"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
s.stripe = stripeMock
|
||||||
|
|
||||||
|
// Define how the mock should react
|
||||||
|
stripeMock.
|
||||||
|
On("GetCustomer", "acct_123").
|
||||||
|
Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
|
||||||
|
stripeMock.
|
||||||
|
On("NewCheckoutSession", mock.Anything).
|
||||||
|
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
|
||||||
|
|
||||||
|
// Create tier and user
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
StripePriceID: "price_123",
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||||
|
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
u.Billing.StripeCustomerID = "acct_123"
|
||||||
|
require.Nil(t, s.userManager.ChangeBilling(u))
|
||||||
|
|
||||||
|
// Create subscription
|
||||||
|
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStripeAPI struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
|
||||||
|
args := s.Called(params)
|
||||||
|
return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
|
||||||
|
args := s.Called(params)
|
||||||
|
return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
|
||||||
|
args := s.Called(params)
|
||||||
|
return args.Get(0).([]*stripe.Price), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
|
||||||
|
args := s.Called(id)
|
||||||
|
return args.Get(0).(*stripe.Customer), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
|
||||||
|
args := s.Called(id)
|
||||||
|
return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
|
||||||
|
args := s.Called(id)
|
||||||
|
return args.Get(0).(*stripe.Subscription), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
|
||||||
|
args := s.Called(id)
|
||||||
|
return args.Get(0).(*stripe.Subscription), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
|
||||||
|
args := s.Called(payload, header, secret)
|
||||||
|
return args.Get(0).(stripe.Event), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ stripeAPI = (*testStripeAPI)(nil)
|
Loading…
Reference in a new issue