Bill to visitor and set TTL in response

This commit is contained in:
Karmanyaah Malhotra 2023-02-14 14:07:02 -06:00
parent fb2fa4c478
commit 6bfe4a9779
2 changed files with 16 additions and 1 deletions

View file

@ -92,4 +92,5 @@ var (
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""}
) )

View file

@ -372,6 +372,7 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("TTL", "0") // if message is not being stored because of an error, tell them
w.WriteHeader(httpErr.HTTPCode) w.WriteHeader(httpErr.HTTPCode)
io.WriteString(w, httpErr.JSON()+"\n") io.WriteString(w, httpErr.JSON()+"\n")
} }
@ -605,6 +606,14 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if err != nil { if err != nil {
return nil, err return nil, err
} }
v_old := v
if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) {
v = t.getBillee()
if v == nil {
return nil, errHTTPWontStoreMessage
}
}
if !v.MessageAllowed() { if !v.MessageAllowed() {
return nil, errHTTPTooManyRequestsLimitMessages return nil, errHTTPTooManyRequestsLimitMessages
} }
@ -639,8 +648,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
"message_email": email, "message_email": email,
}). }).
Debug("Received message") Debug("Received message")
//Where should I log the original visitor vs the billing visitor
if log.IsTrace() { if log.IsTrace() {
logvrm(v, r, m). logvrm(v_old, r, m).
Tag(tagPublish). Tag(tagPublish).
Field("message_body", util.MaybeMarshalJSON(m)). Field("message_body", util.MaybeMarshalJSON(m)).
Trace("Message body") Trace("Message body")
@ -684,6 +694,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err != nil { if err != nil {
return err return err
} }
w.Header().Set("TTL", strconv.FormatInt(m.Expires-m.Time, 10)) // return how long a message will be stored for
// using m.Time, not time.Now() so the value isn't negative if the request is processed at a second boundary
return s.writeJSON(w, m) return s.writeJSON(w, m)
} }