Websockets; working
This commit is contained in:
parent
cdc9c0d62c
commit
846ee0fb2d
6 changed files with 118 additions and 4 deletions
109
server/server.go
109
server/server.go
|
@ -10,6 +10,8 @@ import (
|
|||
"firebase.google.com/go/messaging"
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/api/option"
|
||||
"heckel.io/ntfy/util"
|
||||
"html/template"
|
||||
|
@ -99,6 +101,7 @@ var (
|
|||
jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
|
||||
ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
|
||||
rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
|
||||
wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
|
||||
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
|
||||
|
||||
staticRegex = regexp.MustCompile(`^/static/.+`)
|
||||
|
@ -156,6 +159,10 @@ const (
|
|||
emptyMessageBody = "triggered" // Used if message body is empty
|
||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||
fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details
|
||||
wsWriteWait = 2 * time.Second
|
||||
wsBufferSize = 1024
|
||||
wsReadLimit = 64 // We only ever receive PINGs
|
||||
wsPongWait = 15 * time.Second
|
||||
)
|
||||
|
||||
// New instantiates a new Server. It creates the cache and adds a Firebase
|
||||
|
@ -404,6 +411,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
|||
return s.withRateLimit(w, r, s.handleSubscribeSSE)
|
||||
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeRaw)
|
||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeWS)
|
||||
}
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
@ -805,6 +814,106 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if err := v.SubscriptionAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitSubscriptions
|
||||
}
|
||||
defer v.RemoveSubscription()
|
||||
topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/ws") // Hack
|
||||
topicIDs := util.SplitNoEmpty(topicsStr, ",")
|
||||
topics, err := s.topicsFromIDs(topicIDs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
poll := readParam(r, "x-poll", "poll", "po") == "1"
|
||||
scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1"
|
||||
since, err := parseSince(r, poll)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upgrader := &websocket.Upgrader{
|
||||
ReadBufferSize: wsBufferSize,
|
||||
WriteBufferSize: wsBufferSize,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // We're open for business!
|
||||
},
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
g.Go(func() error {
|
||||
pongWait := s.config.KeepaliveInterval + wsPongWait
|
||||
conn.SetReadLimit(wsReadLimit)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
conn.SetPongHandler(func(appData string) error {
|
||||
return conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
})
|
||||
for {
|
||||
_, _, err := conn.NextReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
g.Go(func() error {
|
||||
ping := func() error {
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
v.Keepalive()
|
||||
if err := ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
sub := func(msg *message) error {
|
||||
if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
|
||||
return nil
|
||||
}
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteJSON(msg)
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) {
|
||||
messageFilter = readParam(r, "x-message", "message", "m")
|
||||
titleFilter = readParam(r, "x-title", "title", "t")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue