diff --git a/README.md b/README.md new file mode 100644 index 0000000..93232bc --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ + + +echo "mychan:long process is done" | nc -N ntfy.sh 9999 +curl -d "long process is done" ntfy.sh/mychan + publish on channel + +curl ntfy.sh/mychan + subscribe to channel + +ntfy.sh/mychan/ws diff --git a/go.mod b/go.mod index e3f7ce5..ee3af69 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module heckel.io/notifyme go 1.16 + +require github.com/gorilla/websocket v1.4.2 // indirect diff --git a/go.sum b/go.sum index e69de29..85efffd 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/main.go b/main.go index 973a8ac..43af89f 100644 --- a/main.go +++ b/main.go @@ -1,138 +1,13 @@ package main import ( - "context" - "encoding/json" - "errors" - "io" + "heckel.io/notifyme/server" "log" - "math/rand" - "net/http" - "sync" - "time" ) -type Message struct { - Time int64 `json:"time"` - Message string `json:"message"` -} - -type Channel struct { - id string - listeners map[int]listener - last time.Time - ctx context.Context - mu sync.Mutex -} - -type Server struct { - channels map[string]*Channel - mu sync.Mutex -} - -type listener func(msg *Message) - func main() { - s := &Server{ - channels: make(map[string]*Channel), - } - go func() { - for { - time.Sleep(5 * time.Second) - s.mu.Lock() - log.Printf("channels: %d", len(s.channels)) - s.mu.Unlock() - } - }() - http.HandleFunc("/", s.handle) - if err := http.ListenAndServe(":9997", nil); err != nil { + s := server.New() + if err := s.Run(); err != nil { log.Fatalln(err) } } - -func (s *Server) handle(w http.ResponseWriter, r *http.Request) { - if err := s.handleInternal(w, r); err != nil { - w.WriteHeader(http.StatusInternalServerError) - _, _ = io.WriteString(w, err.Error()) - } -} - -func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { - if len(r.URL.Path) == 0 { - return errors.New("invalid path") - } - channel := s.channel(r.URL.Path[1:]) - switch r.Method { - case http.MethodGet: - return s.handleGET(w, r, channel) - case http.MethodPut: - return s.handlePUT(w, r, channel) - default: - return errors.New("invalid method") - } -} - -func (s *Server) handleGET(w http.ResponseWriter, r *http.Request, ch *Channel) error { - fl, ok := w.(http.Flusher) - if !ok { - return errors.New("not a flusher") - } - listenerID := rand.Int() - l := func (msg *Message) { - json.NewEncoder(w).Encode(&msg) - fl.Flush() - } - ch.mu.Lock() - ch.listeners[listenerID] = l - ch.last = time.Now() - ch.mu.Unlock() - select { - case <-ch.ctx.Done(): - case <-r.Context().Done(): - } - ch.mu.Lock() - delete(ch.listeners, listenerID) - if len(ch.listeners) == 0 { - s.mu.Lock() - delete(s.channels, ch.id) - s.mu.Unlock() - } - ch.mu.Unlock() - return nil -} - -func (s *Server) handlePUT(w http.ResponseWriter, r *http.Request, ch *Channel) error { - ch.mu.Lock() - defer ch.mu.Unlock() - if len(ch.listeners) == 0 { - return errors.New("no listeners") - } - defer r.Body.Close() - ch.last = time.Now() - msg, _ := io.ReadAll(r.Body) - for _, l := range ch.listeners { - l(&Message{ - Time: time.Now().UnixMilli(), - Message: string(msg), - }) - } - return nil -} - -func (s *Server) channel(channelID string) *Channel { - s.mu.Lock() - defer s.mu.Unlock() - c, ok := s.channels[channelID] - if !ok { - ctx, _ := context.WithCancel(context.Background()) // FIXME - c = &Channel{ - id: channelID, - listeners: make(map[int]listener), - last: time.Now(), - ctx: ctx, - mu: sync.Mutex{}, - } - s.channels[channelID] = c - } - return c -} diff --git a/server/index.html b/server/index.html new file mode 100644 index 0000000..9ab3b9b --- /dev/null +++ b/server/index.html @@ -0,0 +1,79 @@ + + + + ntfy.sh + + + +

ntfy.sh

+ +Topics: + + + + + + +
+ + + + diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..3b839b8 --- /dev/null +++ b/server/server.go @@ -0,0 +1,185 @@ +package server + +import ( + "bytes" + _ "embed" // required for go:embed + "encoding/json" + "errors" + "github.com/gorilla/websocket" + "io" + "log" + "net/http" + "regexp" + "strings" + "sync" + "time" +) + +type Server struct { + topics map[string]*topic + mu sync.Mutex +} + +type message struct { + Time int64 `json:"time"` + Message string `json:"message"` +} + +const ( + messageLimit = 1024 +) + +var ( + topicRegex = regexp.MustCompile(`^/[^/]+$`) + wsRegex = regexp.MustCompile(`^/[^/]+/ws$`) + jsonRegex = regexp.MustCompile(`^/[^/]+/json$`) + wsUpgrader = websocket.Upgrader{ + ReadBufferSize: messageLimit, + WriteBufferSize: messageLimit, + } + + //go:embed "index.html" + indexSource string +) + +func New() *Server { + return &Server{ + topics: make(map[string]*topic), + } +} + +func (s *Server) Run() error { + go func() { + for { + time.Sleep(5 * time.Second) + s.mu.Lock() + log.Printf("topics: %d", len(s.topics)) + for _, t := range s.topics { + t.mu.Lock() + log.Printf("- %s: %d subscriber(s), %d message(s) sent, last active = %s", + t.id, len(t.subscribers), t.messages, t.last.String()) + t.mu.Unlock() + } + // TODO kill dead topics + s.mu.Unlock() + } + }() + log.Printf("Listening on :9997") + http.HandleFunc("/", s.handle) + return http.ListenAndServe(":9997", nil) +} + +func (s *Server) handle(w http.ResponseWriter, r *http.Request) { + if err := s.handleInternal(w, r); err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, err.Error()) + } +} + +func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { + if r.Method == http.MethodGet && r.URL.Path == "/" { + return s.handleHome(w, r) + } else if r.Method == http.MethodGet && wsRegex.MatchString(r.URL.Path) { + return s.handleSubscribeWS(w, r) + } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { + return s.handleSubscribeHTTP(w, r) + } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { + return s.handlePublishHTTP(w, r) + } + http.NotFound(w, r) + return nil +} + +func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { + _, err := io.WriteString(w, indexSource) + return err +} + +func (s *Server) handlePublishHTTP(w http.ResponseWriter, r *http.Request) error { + t, err := s.topic(r.URL.Path[1:]) + if err != nil { + return err + } + reader := io.LimitReader(r.Body, messageLimit) + b, err := io.ReadAll(reader) + if err != nil { + return err + } + msg := &message{ + Time: time.Now().UnixMilli(), + Message: string(b), + } + return t.Publish(msg) +} + +func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request) error { + t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/json")) // Hack + subscriberID := t.Subscribe(func (msg *message) error { + if err := json.NewEncoder(w).Encode(&msg); err != nil { + return err + } + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + return nil + }) + defer t.Unsubscribe(subscriberID) + select { + case <-t.ctx.Done(): + case <-r.Context().Done(): + } + return nil +} + +func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request) error { + conn, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + return err + } + t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/ws")) // Hack + t.Subscribe(func (msg *message) error { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(&msg); err != nil { + return err + } + defer conn.Close() + /*conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + }*/ + + w, err := conn.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + if _, err := w.Write([]byte(msg.Message)); err != nil { + return err + } + if err := w.Close(); err != nil { + return err + } + return nil + }) + return nil +} + +func (s *Server) createTopic(id string) *topic { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.topics[id]; !ok { + s.topics[id] = newTopic(id) + } + return s.topics[id] +} + +func (s *Server) topic(topicID string) (*topic, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.topics[topicID] + if !ok { + return nil, errors.New("topic does not exist") + } + return c, nil +} diff --git a/server/topic.go b/server/topic.go new file mode 100644 index 0000000..65dd741 --- /dev/null +++ b/server/topic.go @@ -0,0 +1,68 @@ +package server + +import ( + "context" + "errors" + "log" + "math/rand" + "sync" + "time" +) + +type topic struct { + id string + subscribers map[int]subscriber + messages int + last time.Time + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +type subscriber func(msg *message) error + +func newTopic(id string) *topic { + ctx, cancel := context.WithCancel(context.Background()) + return &topic{ + id: id, + subscribers: make(map[int]subscriber), + last: time.Now(), + ctx: ctx, + cancel: cancel, + } +} + +func (t *topic) Subscribe(s subscriber) int { + t.mu.Lock() + defer t.mu.Unlock() + subscriberID := rand.Int() + t.subscribers[subscriberID] = s + t.last = time.Now() + return subscriberID +} + +func (t *topic) Unsubscribe(id int) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.subscribers, id) +} + +func (t *topic) Publish(m *message) error { + t.mu.Lock() + defer t.mu.Unlock() + if len(t.subscribers) == 0 { + return errors.New("no subscribers") + } + t.last = time.Now() + t.messages++ + for _, s := range t.subscribers { + if err := s(m); err != nil { + log.Printf("error publishing message to subscriber x") + } + } + return nil +} + +func (t *topic) Close() { + t.cancel() +}