From d7a5d5b94cea602f0f811bd4ba46b7f9cabbcfb0 Mon Sep 17 00:00:00 2001 From: Antonio Murdaca Date: Sun, 29 Mar 2015 23:17:23 +0200 Subject: [PATCH] Refactor utils/utils, fixes #11923 Signed-off-by: Antonio Murdaca --- fileutils/fileutils.go | 54 +++++++++++++++ fileutils/fileutils_test.go | 81 +++++++++++++++++++++++ httputils/httputils.go | 26 ++++++++ ioutils/readers.go | 10 +++ ioutils/writers.go | 21 ++++++ ioutils/writers_test.go | 41 ++++++++++++ requestdecorator/requestdecorator_test.go | 16 ++--- resolvconf/resolvconf.go | 4 +- stringutils/stringutils.go | 56 ++++++++++++++++ stringutils/stringutils_test.go | 29 ++++++++ 10 files changed, 328 insertions(+), 10 deletions(-) create mode 100644 fileutils/fileutils_test.go create mode 100644 httputils/httputils.go create mode 100644 ioutils/writers_test.go diff --git a/fileutils/fileutils.go b/fileutils/fileutils.go index 4325297..ef2a652 100644 --- a/fileutils/fileutils.go +++ b/fileutils/fileutils.go @@ -1,6 +1,10 @@ package fileutils import ( + "fmt" + "io" + "io/ioutil" + "os" "path/filepath" "github.com/Sirupsen/logrus" @@ -25,3 +29,53 @@ func Matches(relFilePath string, patterns []string) (bool, error) { } return false, nil } + +func CopyFile(src, dst string) (int64, error) { + if src == dst { + return 0, nil + } + sf, err := os.Open(src) + if err != nil { + return 0, err + } + defer sf.Close() + if err := os.Remove(dst); err != nil && !os.IsNotExist(err) { + return 0, err + } + df, err := os.Create(dst) + if err != nil { + return 0, err + } + defer df.Close() + return io.Copy(df, sf) +} + +func GetTotalUsedFds() int { + if fds, err := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())); err != nil { + logrus.Errorf("Error opening /proc/%d/fd: %s", os.Getpid(), err) + } else { + return len(fds) + } + return -1 +} + +// ReadSymlinkedDirectory returns the target directory of a symlink. +// The target of the symbolic link may not be a file. +func ReadSymlinkedDirectory(path string) (string, error) { + var realPath string + var err error + if realPath, err = filepath.Abs(path); err != nil { + return "", fmt.Errorf("unable to get absolute path for %s: %s", path, err) + } + if realPath, err = filepath.EvalSymlinks(realPath); err != nil { + return "", fmt.Errorf("failed to canonicalise path for %s: %s", path, err) + } + realPathInfo, err := os.Stat(realPath) + if err != nil { + return "", fmt.Errorf("failed to stat target '%s' of '%s': %s", realPath, path, err) + } + if !realPathInfo.Mode().IsDir() { + return "", fmt.Errorf("canonical path points to a file '%s'", realPath) + } + return realPath, nil +} diff --git a/fileutils/fileutils_test.go b/fileutils/fileutils_test.go new file mode 100644 index 0000000..16d00d7 --- /dev/null +++ b/fileutils/fileutils_test.go @@ -0,0 +1,81 @@ +package fileutils + +import ( + "os" + "testing" +) + +// Reading a symlink to a directory must return the directory +func TestReadSymlinkedDirectoryExistingDirectory(t *testing.T) { + var err error + if err = os.Mkdir("/tmp/testReadSymlinkToExistingDirectory", 0777); err != nil { + t.Errorf("failed to create directory: %s", err) + } + + if err = os.Symlink("/tmp/testReadSymlinkToExistingDirectory", "/tmp/dirLinkTest"); err != nil { + t.Errorf("failed to create symlink: %s", err) + } + + var path string + if path, err = ReadSymlinkedDirectory("/tmp/dirLinkTest"); err != nil { + t.Fatalf("failed to read symlink to directory: %s", err) + } + + if path != "/tmp/testReadSymlinkToExistingDirectory" { + t.Fatalf("symlink returned unexpected directory: %s", path) + } + + if err = os.Remove("/tmp/testReadSymlinkToExistingDirectory"); err != nil { + t.Errorf("failed to remove temporary directory: %s", err) + } + + if err = os.Remove("/tmp/dirLinkTest"); err != nil { + t.Errorf("failed to remove symlink: %s", err) + } +} + +// Reading a non-existing symlink must fail +func TestReadSymlinkedDirectoryNonExistingSymlink(t *testing.T) { + var path string + var err error + if path, err = ReadSymlinkedDirectory("/tmp/test/foo/Non/ExistingPath"); err == nil { + t.Fatalf("error expected for non-existing symlink") + } + + if path != "" { + t.Fatalf("expected empty path, but '%s' was returned", path) + } +} + +// Reading a symlink to a file must fail +func TestReadSymlinkedDirectoryToFile(t *testing.T) { + var err error + var file *os.File + + if file, err = os.Create("/tmp/testReadSymlinkToFile"); err != nil { + t.Fatalf("failed to create file: %s", err) + } + + file.Close() + + if err = os.Symlink("/tmp/testReadSymlinkToFile", "/tmp/fileLinkTest"); err != nil { + t.Errorf("failed to create symlink: %s", err) + } + + var path string + if path, err = ReadSymlinkedDirectory("/tmp/fileLinkTest"); err == nil { + t.Fatalf("ReadSymlinkedDirectory on a symlink to a file should've failed") + } + + if path != "" { + t.Fatalf("path should've been empty: %s", path) + } + + if err = os.Remove("/tmp/testReadSymlinkToFile"); err != nil { + t.Errorf("failed to remove file: %s", err) + } + + if err = os.Remove("/tmp/fileLinkTest"); err != nil { + t.Errorf("failed to remove symlink: %s", err) + } +} diff --git a/httputils/httputils.go b/httputils/httputils.go new file mode 100644 index 0000000..1c92224 --- /dev/null +++ b/httputils/httputils.go @@ -0,0 +1,26 @@ +package httputils + +import ( + "fmt" + "net/http" + + "github.com/docker/docker/pkg/jsonmessage" +) + +// Request a given URL and return an io.Reader +func Download(url string) (resp *http.Response, err error) { + if resp, err = http.Get(url); err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("Got HTTP status code >= 400: %s", resp.Status) + } + return resp, nil +} + +func NewHTTPRequestError(msg string, res *http.Response) error { + return &jsonmessage.JSONError{ + Message: msg, + Code: res.StatusCode, + } +} diff --git a/ioutils/readers.go b/ioutils/readers.go index 58ff1af..0e542cb 100644 --- a/ioutils/readers.go +++ b/ioutils/readers.go @@ -3,6 +3,8 @@ package ioutils import ( "bytes" "crypto/rand" + "crypto/sha256" + "encoding/hex" "io" "math/big" "sync" @@ -215,3 +217,11 @@ func (r *bufReader) Close() error { } return closer.Close() } + +func HashData(src io.Reader) (string, error) { + h := sha256.New() + if _, err := io.Copy(h, src); err != nil { + return "", err + } + return "sha256:" + hex.EncodeToString(h.Sum(nil)), nil +} diff --git a/ioutils/writers.go b/ioutils/writers.go index c0b3608..43fdc44 100644 --- a/ioutils/writers.go +++ b/ioutils/writers.go @@ -37,3 +37,24 @@ func NewWriteCloserWrapper(r io.Writer, closer func() error) io.WriteCloser { closer: closer, } } + +// Wrap a concrete io.Writer and hold a count of the number +// of bytes written to the writer during a "session". +// This can be convenient when write return is masked +// (e.g., json.Encoder.Encode()) +type WriteCounter struct { + Count int64 + Writer io.Writer +} + +func NewWriteCounter(w io.Writer) *WriteCounter { + return &WriteCounter{ + Writer: w, + } +} + +func (wc *WriteCounter) Write(p []byte) (count int, err error) { + count, err = wc.Writer.Write(p) + wc.Count += int64(count) + return +} diff --git a/ioutils/writers_test.go b/ioutils/writers_test.go new file mode 100644 index 0000000..80d7f7f --- /dev/null +++ b/ioutils/writers_test.go @@ -0,0 +1,41 @@ +package ioutils + +import ( + "bytes" + "strings" + "testing" +) + +func TestNopWriter(t *testing.T) { + nw := &NopWriter{} + l, err := nw.Write([]byte{'c'}) + if err != nil { + t.Fatal(err) + } + if l != 1 { + t.Fatalf("Expected 1 got %d", l) + } +} + +func TestWriteCounter(t *testing.T) { + dummy1 := "This is a dummy string." + dummy2 := "This is another dummy string." + totalLength := int64(len(dummy1) + len(dummy2)) + + reader1 := strings.NewReader(dummy1) + reader2 := strings.NewReader(dummy2) + + var buffer bytes.Buffer + wc := NewWriteCounter(&buffer) + + reader1.WriteTo(wc) + reader2.WriteTo(wc) + + if wc.Count != totalLength { + t.Errorf("Wrong count: %d vs. %d", wc.Count, totalLength) + } + + if buffer.String() != dummy1+dummy2 { + t.Error("Wrong message written") + } +} diff --git a/requestdecorator/requestdecorator_test.go b/requestdecorator/requestdecorator_test.go index b2c1fb3..f1f9ef7 100644 --- a/requestdecorator/requestdecorator_test.go +++ b/requestdecorator/requestdecorator_test.go @@ -180,8 +180,8 @@ func TestRequestFactory(t *testing.T) { requestFactory := NewRequestFactory(ad, uad) - if dlen := len(requestFactory.GetDecorators()); dlen != 2 { - t.Fatalf("Expected to have two decorators, got %d", dlen) + if l := len(requestFactory.GetDecorators()); l != 2 { + t.Fatalf("Expected to have two decorators, got %d", l) } req, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test")) @@ -209,8 +209,8 @@ func TestRequestFactoryNewRequestWithDecorators(t *testing.T) { requestFactory := NewRequestFactory(ad) - if dlen := len(requestFactory.GetDecorators()); dlen != 1 { - t.Fatalf("Expected to have one decorators, got %d", dlen) + if l := len(requestFactory.GetDecorators()); l != 1 { + t.Fatalf("Expected to have one decorators, got %d", l) } ad2 := NewAuthDecorator("test2", "password2") @@ -235,15 +235,15 @@ func TestRequestFactoryNewRequestWithDecorators(t *testing.T) { func TestRequestFactoryAddDecorator(t *testing.T) { requestFactory := NewRequestFactory() - if dlen := len(requestFactory.GetDecorators()); dlen != 0 { - t.Fatalf("Expected to have zero decorators, got %d", dlen) + if l := len(requestFactory.GetDecorators()); l != 0 { + t.Fatalf("Expected to have zero decorators, got %d", l) } ad := NewAuthDecorator("test", "password") requestFactory.AddDecorator(ad) - if dlen := len(requestFactory.GetDecorators()); dlen != 1 { - t.Fatalf("Expected to have one decorators, got %d", dlen) + if l := len(requestFactory.GetDecorators()); l != 1 { + t.Fatalf("Expected to have one decorators, got %d", l) } } diff --git a/resolvconf/resolvconf.go b/resolvconf/resolvconf.go index d7d53e1..5707b16 100644 --- a/resolvconf/resolvconf.go +++ b/resolvconf/resolvconf.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/Sirupsen/logrus" - "github.com/docker/docker/utils" + "github.com/docker/docker/pkg/ioutils" ) var ( @@ -59,7 +59,7 @@ func GetIfChanged() ([]byte, string, error) { if err != nil { return nil, "", err } - newHash, err := utils.HashData(bytes.NewReader(resolv)) + newHash, err := ioutils.HashData(bytes.NewReader(resolv)) if err != nil { return nil, "", err } diff --git a/stringutils/stringutils.go b/stringutils/stringutils.go index f5f07dd..e3ebf5d 100644 --- a/stringutils/stringutils.go +++ b/stringutils/stringutils.go @@ -1,7 +1,9 @@ package stringutils import ( + "bytes" mathrand "math/rand" + "strings" "time" ) @@ -28,3 +30,57 @@ func GenerateRandomAsciiString(n int) string { } return string(res) } + +// Truncate a string to maxlen +func Truncate(s string, maxlen int) string { + if len(s) <= maxlen { + return s + } + return s[:maxlen] +} + +// Test wheather a string is contained in a slice of strings or not. +// Comparison is case insensitive +func InSlice(slice []string, s string) bool { + for _, ss := range slice { + if strings.ToLower(s) == strings.ToLower(ss) { + return true + } + } + return false +} + +func quote(word string, buf *bytes.Buffer) { + // Bail out early for "simple" strings + if word != "" && !strings.ContainsAny(word, "\\'\"`${[|&;<>()~*?! \t\n") { + buf.WriteString(word) + return + } + + buf.WriteString("'") + + for i := 0; i < len(word); i++ { + b := word[i] + if b == '\'' { + // Replace literal ' with a close ', a \', and a open ' + buf.WriteString("'\\''") + } else { + buf.WriteByte(b) + } + } + + buf.WriteString("'") +} + +// Take a list of strings and escape them so they will be handled right +// when passed as arguments to an program via a shell +func ShellQuoteArguments(args []string) string { + var buf bytes.Buffer + for i, arg := range args { + if i != 0 { + buf.WriteByte(' ') + } + quote(arg, &buf) + } + return buf.String() +} diff --git a/stringutils/stringutils_test.go b/stringutils/stringutils_test.go index a5a01b4..8dcb469 100644 --- a/stringutils/stringutils_test.go +++ b/stringutils/stringutils_test.go @@ -56,3 +56,32 @@ func TestGenerateRandomAsciiStringIsAscii(t *testing.T) { t.Fatalf("%s contained non-ascii characters", str) } } + +func TestTruncate(t *testing.T) { + str := "teststring" + newstr := Truncate(str, 4) + if newstr != "test" { + t.Fatalf("Expected test, got %s", newstr) + } + newstr = Truncate(str, 20) + if newstr != "teststring" { + t.Fatalf("Expected teststring, got %s", newstr) + } +} + +func TestInSlice(t *testing.T) { + slice := []string{"test", "in", "slice"} + + test := InSlice(slice, "test") + if !test { + t.Fatalf("Expected string test to be in slice") + } + test = InSlice(slice, "SLICE") + if !test { + t.Fatalf("Expected string SLICE to be in slice") + } + test = InSlice(slice, "notinslice") + if test { + t.Fatalf("Expected string notinslice not to be in slice") + } +}