WIP: Attachments
This commit is contained in:
parent
09515f26df
commit
eb5b86ffe2
13 changed files with 444 additions and 46 deletions
41
util/content_type_writer.go
Normal file
41
util/content_type_writer.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ContentTypeWriter is an implementation of http.ResponseWriter that will detect the content type and set the
|
||||
// Content-Type and (optionally) Content-Disposition headers accordingly.
|
||||
//
|
||||
// It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html"
|
||||
// content type.
|
||||
type ContentTypeWriter struct {
|
||||
w http.ResponseWriter
|
||||
sniffed bool
|
||||
}
|
||||
|
||||
// NewContentTypeWriter creates a new ContentTypeWriter
|
||||
func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter {
|
||||
return &ContentTypeWriter{w, false}
|
||||
}
|
||||
|
||||
func (w *ContentTypeWriter) Write(p []byte) (n int, err error) {
|
||||
if w.sniffed {
|
||||
return w.w.Write(p)
|
||||
}
|
||||
// Detect and set Content-Type header
|
||||
// Fix content types that we don't want to inline-render in the browser. In particular,
|
||||
// we don't want to render HTML in the browser for security reasons.
|
||||
contentType := http.DetectContentType(p)
|
||||
if strings.HasPrefix(contentType, "text/html") {
|
||||
contentType = strings.ReplaceAll(contentType, "text/html", "text/plain")
|
||||
} else if contentType == "application/octet-stream" {
|
||||
contentType = "" // Reset to let downstream http.ResponseWriter take care of it
|
||||
}
|
||||
if contentType != "" {
|
||||
w.w.Header().Set("Content-Type", contentType)
|
||||
}
|
||||
w.sniffed = true
|
||||
return w.w.Write(p)
|
||||
}
|
50
util/content_type_writer_test.go
Normal file
50
util/content_type_writer_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSniffWriter_WriteHTML(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
sw := NewContentTypeWriter(rr)
|
||||
sw.Write([]byte("<script>alert('hi')</script>"))
|
||||
require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
sw := NewContentTypeWriter(rr)
|
||||
sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33})
|
||||
sw.Write([]byte("<script>alert('hi')</script>"))
|
||||
require.Equal(t, "application/pdf", rr.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestSniffWriter_NoSniffWriterWriteHTML(t *testing.T) {
|
||||
// This test just makes sure that without the sniff-w, we would get text/html
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
rr.Write([]byte("<script>alert('hi')</script>"))
|
||||
require.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) {
|
||||
// This test shows how splitting the HTML into two Write() calls will still yield text/plain
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
sw := NewContentTypeWriter(rr)
|
||||
sw.Write([]byte("<scr"))
|
||||
sw.Write([]byte("ipt>alert('hi')</script>"))
|
||||
require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestSniffWriter_WriteUnknownMimeType(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
sw := NewContentTypeWriter(rr)
|
||||
randomBytes := make([]byte, 199)
|
||||
rand.Read(randomBytes)
|
||||
sw.Write(randomBytes)
|
||||
require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type"))
|
||||
}
|
|
@ -2,6 +2,7 @@ package util
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -58,3 +59,43 @@ func (l *Limiter) Value() int64 {
|
|||
defer l.mu.Unlock()
|
||||
return l.value
|
||||
}
|
||||
|
||||
// Limit returns the defined limit
|
||||
func (l *Limiter) Limit() int64 {
|
||||
return l.limit
|
||||
}
|
||||
|
||||
// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
|
||||
// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
|
||||
// Each limiter's value is increased with every write.
|
||||
type LimitWriter struct {
|
||||
w io.Writer
|
||||
written int64
|
||||
limiters []*Limiter
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewLimitWriter creates a new LimitWriter
|
||||
func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
|
||||
return &LimitWriter{
|
||||
w: w,
|
||||
limiters: limiters,
|
||||
}
|
||||
}
|
||||
|
||||
// Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
|
||||
func (w *LimitWriter) Write(p []byte) (n int, err error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
for i := 0; i < len(w.limiters); i++ {
|
||||
if err := w.limiters[i].Add(int64(len(p))); err != nil {
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
w.limiters[j].Sub(int64(len(p)))
|
||||
}
|
||||
return 0, ErrLimitReached
|
||||
}
|
||||
}
|
||||
n, err = w.w.Write(p)
|
||||
w.written += int64(n)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -17,14 +18,68 @@ func TestLimiter_Add(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLimiter_AddSub(t *testing.T) {
|
||||
func TestLimiter_AddSet(t *testing.T) {
|
||||
l := NewLimiter(10)
|
||||
l.Add(5)
|
||||
if l.Value() != 5 {
|
||||
t.Fatalf("expected value to be %d, got %d", 5, l.Value())
|
||||
}
|
||||
l.Sub(2)
|
||||
if l.Value() != 3 {
|
||||
t.Fatalf("expected value to be %d, got %d", 3, l.Value())
|
||||
l.Set(7)
|
||||
if l.Value() != 7 {
|
||||
t.Fatalf("expected value to be %d, got %d", 7, l.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitWriter_WriteNoLimiter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
lw := NewLimitWriter(&buf)
|
||||
if _, err := lw.Write(make([]byte, 10)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := lw.Write(make([]byte, 1)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if buf.Len() != 11 {
|
||||
t.Fatalf("expected buffer length to be %d, got %d", 11, buf.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitWriter_WriteOneLimiter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
l := NewLimiter(10)
|
||||
lw := NewLimitWriter(&buf, l)
|
||||
if _, err := lw.Write(make([]byte, 10)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := lw.Write(make([]byte, 1)); err != ErrLimitReached {
|
||||
t.Fatalf("expected ErrLimitReached, got %#v", err)
|
||||
}
|
||||
if buf.Len() != 10 {
|
||||
t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len())
|
||||
}
|
||||
if l.Value() != 10 {
|
||||
t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
l1 := NewLimiter(11)
|
||||
l2 := NewLimiter(9)
|
||||
lw := NewLimitWriter(&buf, l1, l2)
|
||||
if _, err := lw.Write(make([]byte, 8)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := lw.Write(make([]byte, 2)); err != ErrLimitReached {
|
||||
t.Fatalf("expected ErrLimitReached, got %#v", err)
|
||||
}
|
||||
if buf.Len() != 8 {
|
||||
t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len())
|
||||
}
|
||||
if l1.Value() != 8 {
|
||||
t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value())
|
||||
}
|
||||
if l2.Value() != 8 {
|
||||
t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value())
|
||||
}
|
||||
}
|
||||
|
|
61
util/peak.go
Normal file
61
util/peak.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory.
|
||||
// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully
|
||||
// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining
|
||||
// underlying reader.
|
||||
type PeakedReadCloser struct {
|
||||
PeakedBytes []byte
|
||||
LimitReached bool
|
||||
peaked io.Reader
|
||||
underlying io.ReadCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser
|
||||
func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) {
|
||||
if underlying == nil {
|
||||
underlying = io.NopCloser(strings.NewReader(""))
|
||||
}
|
||||
peaked := make([]byte, limit)
|
||||
read, err := io.ReadFull(underlying, peaked)
|
||||
if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
|
||||
return nil, err
|
||||
}
|
||||
return &PeakedReadCloser{
|
||||
PeakedBytes: peaked[:read],
|
||||
LimitReached: read == limit,
|
||||
underlying: underlying,
|
||||
peaked: bytes.NewReader(peaked[:read]),
|
||||
closed: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Read reads from the peaked bytes and then from the underlying stream
|
||||
func (r *PeakedReadCloser) Read(p []byte) (n int, err error) {
|
||||
if r.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = r.peaked.Read(p)
|
||||
if err == io.EOF {
|
||||
return r.underlying.Read(p)
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Close closes the underlying stream
|
||||
func (r *PeakedReadCloser) Close() error {
|
||||
if r.closed {
|
||||
return io.EOF
|
||||
}
|
||||
r.closed = true
|
||||
return r.underlying.Close()
|
||||
}
|
55
util/peak_test.go
Normal file
55
util/peak_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPeak_LimitReached(t *testing.T) {
|
||||
underlying := io.NopCloser(strings.NewReader("1234567890"))
|
||||
peaked, err := Peak(underlying, 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, []byte("12345"), peaked.PeakedBytes)
|
||||
require.Equal(t, true, peaked.LimitReached)
|
||||
|
||||
all, err := io.ReadAll(peaked)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, []byte("1234567890"), all)
|
||||
require.Equal(t, []byte("12345"), peaked.PeakedBytes)
|
||||
require.Equal(t, true, peaked.LimitReached)
|
||||
}
|
||||
|
||||
func TestPeak_LimitNotReached(t *testing.T) {
|
||||
underlying := io.NopCloser(strings.NewReader("1234567890"))
|
||||
peaked, err := Peak(underlying, 15)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
all, err := io.ReadAll(peaked)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, []byte("1234567890"), all)
|
||||
require.Equal(t, []byte("1234567890"), peaked.PeakedBytes)
|
||||
require.Equal(t, false, peaked.LimitReached)
|
||||
}
|
||||
|
||||
func TestPeak_Nil(t *testing.T) {
|
||||
peaked, err := Peak(nil, 15)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
all, err := io.ReadAll(peaked)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, []byte(""), all)
|
||||
require.Equal(t, []byte(""), peaked.PeakedBytes)
|
||||
require.Equal(t, false, peaked.LimitReached)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue