Websockets; working

This commit is contained in:
Philipp Heckel 2022-01-15 13:23:35 -05:00
parent cdc9c0d62c
commit 846ee0fb2d
6 changed files with 118 additions and 4 deletions

View file

@ -10,6 +10,8 @@ import (
"firebase.google.com/go/messaging"
"fmt"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
"google.golang.org/api/option"
"heckel.io/ntfy/util"
"html/template"
@ -99,6 +101,7 @@ var (
jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
staticRegex = regexp.MustCompile(`^/static/.+`)
@ -156,6 +159,10 @@ const (
emptyMessageBody = "triggered" // Used if message body is empty
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details
wsWriteWait = 2 * time.Second
wsBufferSize = 1024
wsReadLimit = 64 // We only ever receive PINGs
wsPongWait = 15 * time.Second
)
// New instantiates a new Server. It creates the cache and adds a Firebase
@ -404,6 +411,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
return s.withRateLimit(w, r, s.handleSubscribeSSE)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeRaw)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeWS)
}
return errHTTPNotFound
}
@ -805,6 +814,106 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
}
}
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
if err := v.SubscriptionAllowed(); err != nil {
return errHTTPTooManyRequestsLimitSubscriptions
}
defer v.RemoveSubscription()
topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/ws") // Hack
topicIDs := util.SplitNoEmpty(topicsStr, ",")
topics, err := s.topicsFromIDs(topicIDs...)
if err != nil {
return err
}
poll := readParam(r, "x-poll", "poll", "po") == "1"
scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1"
since, err := parseSince(r, poll)
if err != nil {
return err
}
messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
if err != nil {
return err
}
upgrader := &websocket.Upgrader{
ReadBufferSize: wsBufferSize,
WriteBufferSize: wsBufferSize,
CheckOrigin: func(r *http.Request) bool {
return true // We're open for business!
},
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return err
}
defer conn.Close()
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
pongWait := s.config.KeepaliveInterval + wsPongWait
conn.SetReadLimit(wsReadLimit)
if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
return err
}
conn.SetPongHandler(func(appData string) error {
return conn.SetReadDeadline(time.Now().Add(pongWait))
})
for {
_, _, err := conn.NextReader()
if err != nil {
return err
}
}
})
g.Go(func() error {
ping := func() error {
if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
return err
}
return conn.WriteMessage(websocket.PingMessage, nil)
}
for {
select {
case <-ctx.Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
if err := ping(); err != nil {
return err
}
}
}
})
sub := func(msg *message) error {
if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
return nil
}
if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
return err
}
return conn.WriteJSON(msg)
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
}
defer func() {
for i, subscriberID := range subscriberIDs {
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
return err
}
return g.Wait()
}
func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) {
messageFilter = readParam(r, "x-message", "message", "m")
titleFilter = readParam(r, "x-title", "title", "t")