diff --git a/timeout/timeout.go b/timeout/timeout.go new file mode 100644 index 0000000..be85ebd --- /dev/null +++ b/timeout/timeout.go @@ -0,0 +1,26 @@ +package timeout + +import ( + "net" + "time" +) + +func New(netConn net.Conn, timeout time.Duration) net.Conn { + return &conn{netConn, timeout} +} + +// A net.Conn that sets a deadline for every Read or Write operation +type conn struct { + net.Conn + timeout time.Duration +} + +func (c *conn) Read(b []byte) (int, error) { + if c.timeout > 0 { + err := c.Conn.SetReadDeadline(time.Now().Add(c.timeout)) + if err != nil { + return 0, err + } + } + return c.Conn.Read(b) +} diff --git a/timeout/timeout_test.go b/timeout/timeout_test.go new file mode 100644 index 0000000..f2bebae --- /dev/null +++ b/timeout/timeout_test.go @@ -0,0 +1,33 @@ +package timeout + +import ( + "bufio" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRead(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello") + })) + defer ts.Close() + conn, err := net.Dial("tcp", ts.URL[7:]) + if err != nil { + t.Fatalf("failed to create connection to %q: %v", ts.URL, err) + } + tconn := New(conn, 1*time.Second) + + if _, err = bufio.NewReader(tconn).ReadString('\n'); err == nil { + t.Fatalf("expected timeout error, got none") + } + if _, err := fmt.Fprintf(tconn, "GET / HTTP/1.0\r\n\r\n"); err != nil { + t.Errorf("unexpected error: %v", err) + } + if _, err = bufio.NewReader(tconn).ReadString('\n'); err != nil { + t.Errorf("unexpected error: %v", err) + } +}