diff --git a/server/server.go b/server/server.go index 6461085..47d4f19 100644 --- a/server/server.go +++ b/server/server.go @@ -33,6 +33,8 @@ type Server struct { config *Config httpServer *http.Server httpsServer *http.Server + smtpServer *smtp.Server + smtpBackend *smtpBackend topics map[string]*topic visitors map[string]*visitor firebase subscriber @@ -85,11 +87,12 @@ var ( ) var ( - topicRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! - jsonRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`) - sseRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`) - rawRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`) - sendRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`) + topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! + topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! + jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`) + ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`) + rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`) + publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`) staticRegex = regexp.MustCompile(`^/static/.+`) docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) @@ -225,6 +228,9 @@ func (s *Server) Run() error { if s.config.ListenHTTPS != "" { listenStr += fmt.Sprintf(" %s/https", s.config.ListenHTTPS) } + if s.config.SMTPServerListen != "" { + listenStr += fmt.Sprintf(" %s/smtp", s.config.SMTPServerListen) + } log.Printf("Listening on %s", listenStr) mux := http.NewServeMux() mux.HandleFunc("/", s.handle) @@ -243,7 +249,7 @@ func (s *Server) Run() error { } if s.config.SMTPServerListen != "" { go func() { - errChan <- s.runMailserver() + errChan <- s.runSMTPServer() }() } s.mu.Unlock() @@ -264,6 +270,9 @@ func (s *Server) Stop() { if s.httpsServer != nil { s.httpsServer.Close() } + if s.smtpServer != nil { + s.smtpServer.Close() + } close(s.closeChan) } @@ -295,17 +304,17 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { return s.handleDocs(w, r) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) - } else if r.Method == http.MethodGet && topicRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) { return s.handleTopic(w, r) - } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { + } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handlePublish) - } else if r.Method == http.MethodGet && sendRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handlePublish) - } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handleSubscribeJSON) - } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handleSubscribeSSE) - } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handleSubscribeRaw) } return errHTTPNotFound @@ -726,12 +735,15 @@ func (s *Server) updateStatsAndPrune() { messages += msgs } + // Mail + mailSuccess, mailFailure := s.smtpBackend.Counts() + // Print stats - log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", - s.messages, len(s.topics), subscribers, messages, len(s.visitors)) + log.Printf("Stats: %d message(s) published, %d in cache, %d successful mails, %d failed, %d topic(s) active, %d subscriber(s), %d visitor(s)", + s.messages, messages, mailSuccess, mailFailure, len(s.topics), subscribers, len(s.visitors)) } -func (s *Server) runMailserver() error { +func (s *Server) runSMTPServer() error { sub := func(m *message) error { url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message)) @@ -748,18 +760,16 @@ func (s *Server) runMailserver() error { } return nil } - ms := smtp.NewServer(newMailBackend(s.config, sub)) - - ms.Addr = s.config.SMTPServerListen - ms.Domain = s.config.SMTPServerDomain - ms.ReadTimeout = 10 * time.Second - ms.WriteTimeout = 10 * time.Second - ms.MaxMessageBytes = 2 * s.config.MessageLimit - ms.MaxRecipients = 1 - ms.AllowInsecureAuth = true - - log.Println("Starting server at", ms.Addr) - return ms.ListenAndServe() + s.smtpBackend = newMailBackend(s.config, sub) + s.smtpServer = smtp.NewServer(s.smtpBackend) + s.smtpServer.Addr = s.config.SMTPServerListen + s.smtpServer.Domain = s.config.SMTPServerDomain + s.smtpServer.ReadTimeout = 10 * time.Second + s.smtpServer.WriteTimeout = 10 * time.Second + s.smtpServer.MaxMessageBytes = 2 * s.config.MessageLimit + s.smtpServer.MaxRecipients = 1 + s.smtpServer.AllowInsecureAuth = true + return s.smtpServer.ListenAndServe() } func (s *Server) runManager() { diff --git a/server/smtp_server.go b/server/smtp_server.go index f304dea..df6a18b 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -5,17 +5,25 @@ import ( "errors" "github.com/emersion/go-smtp" "io" - "io/ioutil" - "log" "net/mail" "strings" "sync" ) +var ( + errInvalidDomain = errors.New("invalid domain") + errInvalidAddress = errors.New("invalid address") + errInvalidTopic = errors.New("invalid topic") + errTooManyRecipients = errors.New("too many recipients") +) + // smtpBackend implements SMTP server methods. type smtpBackend struct { - config *Config - sub subscriber + config *Config + sub subscriber + success int64 + failure int64 + mu sync.Mutex } func newMailBackend(conf *Config, sub subscriber) *smtpBackend { @@ -26,19 +34,24 @@ func newMailBackend(conf *Config, sub subscriber) *smtpBackend { } func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - return &smtpSession{config: b.config, sub: b.sub}, nil + return &smtpSession{backend: b}, nil } func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - return &smtpSession{config: b.config, sub: b.sub}, nil + return &smtpSession{backend: b}, nil +} + +func (b *smtpBackend) Counts() (success int64, failure int64) { + b.mu.Lock() + defer b.mu.Unlock() + return b.success, b.failure } // smtpSession is returned after EHLO. type smtpSession struct { - config *Config - sub subscriber - from, to string - mu sync.Mutex + backend *smtpBackend + topic string + mu sync.Mutex } func (s *smtpSession) AuthPlain(username, password string) error { @@ -46,63 +59,84 @@ func (s *smtpSession) AuthPlain(username, password string) error { } func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error { - s.mu.Lock() - defer s.mu.Unlock() - s.from = from return nil } func (s *smtpSession) Rcpt(to string) error { - s.mu.Lock() - defer s.mu.Unlock() - addressList, err := mail.ParseAddressList(to) - if err != nil { - return err - } else if len(addressList) != 1 { - return errors.New("only one recipient supported") - } else if !strings.HasSuffix(addressList[0].Address, "@"+s.config.SMTPServerDomain) { - return errors.New("invalid domain") - } else if s.config.SMTPServerAddrPrefix != "" && !strings.HasPrefix(addressList[0].Address, s.config.SMTPServerAddrPrefix) { - return errors.New("invalid address") - } - // FIXME check topic format - s.to = addressList[0].Address - return nil + return s.withFailCount(func() error { + conf := s.backend.config + addressList, err := mail.ParseAddressList(to) + if err != nil { + return err + } else if len(addressList) != 1 { + return errTooManyRecipients + } + to = addressList[0].Address + if !strings.HasSuffix(to, "@"+conf.SMTPServerDomain) { + return errInvalidDomain + } + to = strings.TrimSuffix(to, "@"+conf.SMTPServerDomain) + if conf.SMTPServerAddrPrefix != "" { + if !strings.HasPrefix(to, conf.SMTPServerAddrPrefix) { + return errInvalidAddress + } + to = strings.TrimPrefix(to, conf.SMTPServerAddrPrefix) + } + if !topicRegex.MatchString(to) { + return errInvalidTopic + } + s.mu.Lock() + s.topic = to + s.mu.Unlock() + return nil + }) } func (s *smtpSession) Data(r io.Reader) error { - s.mu.Lock() - defer s.mu.Unlock() - b, err := ioutil.ReadAll(r) - if err != nil { - return err - } - - log.Println("Data:", string(b)) - msg, err := mail.ReadMessage(bytes.NewReader(b)) - if err != nil { - return err - } - body, err := io.ReadAll(msg.Body) - if err != nil { - return err - } - topic := strings.TrimSuffix(s.to, "@"+s.config.SMTPServerDomain) - m := newDefaultMessage(topic, string(body)) - subject := msg.Header.Get("Subject") - if subject != "" { - m.Title = subject - } - return s.sub(m) + return s.withFailCount(func() error { + b, err := io.ReadAll(r) // Protected by MaxMessageBytes + if err != nil { + return err + } + msg, err := mail.ReadMessage(bytes.NewReader(b)) + if err != nil { + return err + } + body, err := io.ReadAll(io.LimitReader(msg.Body, int64(s.backend.config.MessageLimit))) + if err != nil { + return err + } + m := newDefaultMessage(s.topic, string(body)) + subject := msg.Header.Get("Subject") + if subject != "" { + m.Title = subject + } + if err := s.backend.sub(m); err != nil { + return err + } + s.backend.mu.Lock() + s.backend.success++ + s.backend.mu.Unlock() + return nil + }) } func (s *smtpSession) Reset() { s.mu.Lock() - s.from = "" - s.to = "" + s.topic = "" s.mu.Unlock() } func (s *smtpSession) Logout() error { return nil } + +func (s *smtpSession) withFailCount(fn func() error) error { + err := fn() + s.backend.mu.Lock() + defer s.backend.mu.Unlock() + if err != nil { + s.backend.failure++ + } + return err +}