diff --git a/ioutils/fswriters.go b/ioutils/fswriters.go index ca97670..6dc50a0 100644 --- a/ioutils/fswriters.go +++ b/ioutils/fswriters.go @@ -15,13 +15,15 @@ func NewAtomicFileWriter(filename string, perm os.FileMode) (io.WriteCloser, err if err != nil { return nil, err } + abspath, err := filepath.Abs(filename) if err != nil { return nil, err } return &atomicFileWriter{ - f: f, - fn: abspath, + f: f, + fn: abspath, + perm: perm, }, nil } @@ -34,6 +36,7 @@ func AtomicWriteFile(filename string, data []byte, perm os.FileMode) error { n, err := f.Write(data) if err == nil && n < len(data) { err = io.ErrShortWrite + f.(*atomicFileWriter).writeErr = err } if err1 := f.Close(); err == nil { err = err1 @@ -45,6 +48,7 @@ type atomicFileWriter struct { f *os.File fn string writeErr error + perm os.FileMode } func (w *atomicFileWriter) Write(dt []byte) (int, error) { @@ -57,7 +61,7 @@ func (w *atomicFileWriter) Write(dt []byte) (int, error) { func (w *atomicFileWriter) Close() (retErr error) { defer func() { - if retErr != nil { + if retErr != nil || w.writeErr != nil { os.Remove(w.f.Name()) } }() @@ -68,6 +72,9 @@ func (w *atomicFileWriter) Close() (retErr error) { if err := w.f.Close(); err != nil { return err } + if err := os.Chmod(w.f.Name(), w.perm); err != nil { + return err + } if w.writeErr == nil { return os.Rename(w.f.Name(), w.fn) } diff --git a/ioutils/fswriters_test.go b/ioutils/fswriters_test.go index 40717a5..470ca1a 100644 --- a/ioutils/fswriters_test.go +++ b/ioutils/fswriters_test.go @@ -16,7 +16,7 @@ func TestAtomicWriteToFile(t *testing.T) { defer os.RemoveAll(tmpDir) expected := []byte("barbaz") - if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0600); err != nil { + if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0666); err != nil { t.Fatalf("Error writing to file: %v", err) } @@ -28,4 +28,12 @@ func TestAtomicWriteToFile(t *testing.T) { if bytes.Compare(actual, expected) != 0 { t.Fatalf("Data mismatch, expected %q, got %q", expected, actual) } + + st, err := os.Stat(filepath.Join(tmpDir, "foo")) + if err != nil { + t.Fatalf("Error statting file: %v", err) + } + if expected := os.FileMode(0666); st.Mode() != expected { + t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode()) + } }