From a5f237c2b54177ffe45cf371461db1892e092452 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Mon, 23 Sep 2019 11:09:01 -0700 Subject: [PATCH] Use FILE_SHARE_DELETE for log files on Windows. This fixes issues where one goroutine tries to delete or rename a file while another goroutine has the file open (e.g. a log reader). Signed-off-by: Brian Goff --- daemon/logger/loggerutils/file_unix.go | 9 + daemon/logger/loggerutils/file_windows.go | 224 ++++++++++++++++++ .../logger/loggerutils/file_windows_test.go | 34 +++ daemon/logger/loggerutils/logfile.go | 8 +- 4 files changed, 271 insertions(+), 4 deletions(-) create mode 100644 daemon/logger/loggerutils/file_unix.go create mode 100644 daemon/logger/loggerutils/file_windows.go create mode 100644 daemon/logger/loggerutils/file_windows_test.go diff --git a/daemon/logger/loggerutils/file_unix.go b/daemon/logger/loggerutils/file_unix.go new file mode 100644 index 0000000000..eb39d4b73d --- /dev/null +++ b/daemon/logger/loggerutils/file_unix.go @@ -0,0 +1,9 @@ +// +build !windows + +package loggerutils + +import "os" + +func openFile(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.OpenFile(name, flag, perm) +} diff --git a/daemon/logger/loggerutils/file_windows.go b/daemon/logger/loggerutils/file_windows.go new file mode 100644 index 0000000000..507da6874d --- /dev/null +++ b/daemon/logger/loggerutils/file_windows.go @@ -0,0 +1,224 @@ +package loggerutils + +import ( + "os" + "syscall" + "unsafe" + + "github.com/pkg/errors" +) + +func openFile(name string, flag int, perm os.FileMode) (*os.File, error) { + if name == "" { + return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT} + } + h, err := syscallOpen(fixLongPath(name), flag|syscall.O_CLOEXEC, syscallMode(perm)) + if err != nil { + return nil, errors.Wrap(err, "error opening file") + } + return os.NewFile(uintptr(h), name), nil +} + +// syscallOpen is copied from syscall.Open but is modified to +// always open a file with FILE_SHARE_DELETE +func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) { + if len(path) == 0 { + return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND + } + + pathp, err := syscall.UTF16PtrFromString(path) + if err != nil { + return syscall.InvalidHandle, err + } + var access uint32 + switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) { + case syscall.O_RDONLY: + access = syscall.GENERIC_READ + case syscall.O_WRONLY: + access = syscall.GENERIC_WRITE + case syscall.O_RDWR: + access = syscall.GENERIC_READ | syscall.GENERIC_WRITE + } + if mode&syscall.O_CREAT != 0 { + access |= syscall.GENERIC_WRITE + } + if mode&syscall.O_APPEND != 0 { + access &^= syscall.GENERIC_WRITE + access |= syscall.FILE_APPEND_DATA + } + sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE) + var sa *syscall.SecurityAttributes + if mode&syscall.O_CLOEXEC == 0 { + sa = makeInheritSa() + } + var createmode uint32 + switch { + case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL): + createmode = syscall.CREATE_NEW + case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC): + createmode = syscall.CREATE_ALWAYS + case mode&syscall.O_CREAT == syscall.O_CREAT: + createmode = syscall.OPEN_ALWAYS + case mode&syscall.O_TRUNC == syscall.O_TRUNC: + createmode = syscall.TRUNCATE_EXISTING + default: + createmode = syscall.OPEN_EXISTING + } + h, e := syscall.CreateFile(pathp, access, sharemode, sa, createmode, syscall.FILE_ATTRIBUTE_NORMAL, 0) + return h, e +} + +func makeInheritSa() *syscall.SecurityAttributes { + var sa syscall.SecurityAttributes + sa.Length = uint32(unsafe.Sizeof(sa)) + sa.InheritHandle = 1 + return &sa +} + +// fixLongPath returns the extended-length (\\?\-prefixed) form of +// path when needed, in order to avoid the default 260 character file +// path limit imposed by Windows. If path is not easily converted to +// the extended-length form (for example, if path is a relative path +// or contains .. elements), or is short enough, fixLongPath returns +// path unmodified. +// +// See https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247(v=vs.85).aspx#maxpath +// +// Copied from os.OpenFile +func fixLongPath(path string) string { + // Do nothing (and don't allocate) if the path is "short". + // Empirically (at least on the Windows Server 2013 builder), + // the kernel is arbitrarily okay with < 248 bytes. That + // matches what the docs above say: + // "When using an API to create a directory, the specified + // path cannot be so long that you cannot append an 8.3 file + // name (that is, the directory name cannot exceed MAX_PATH + // minus 12)." Since MAX_PATH is 260, 260 - 12 = 248. + // + // The MSDN docs appear to say that a normal path that is 248 bytes long + // will work; empirically the path must be less then 248 bytes long. + if len(path) < 248 { + // Don't fix. (This is how Go 1.7 and earlier worked, + // not automatically generating the \\?\ form) + return path + } + + // The extended form begins with \\?\, as in + // \\?\c:\windows\foo.txt or \\?\UNC\server\share\foo.txt. + // The extended form disables evaluation of . and .. path + // elements and disables the interpretation of / as equivalent + // to \. The conversion here rewrites / to \ and elides + // . elements as well as trailing or duplicate separators. For + // simplicity it avoids the conversion entirely for relative + // paths or paths containing .. elements. For now, + // \\server\share paths are not converted to + // \\?\UNC\server\share paths because the rules for doing so + // are less well-specified. + if len(path) >= 2 && path[:2] == `\\` { + // Don't canonicalize UNC paths. + return path + } + if !isAbs(path) { + // Relative path + return path + } + + const prefix = `\\?` + + pathbuf := make([]byte, len(prefix)+len(path)+len(`\`)) + copy(pathbuf, prefix) + n := len(path) + r, w := 0, len(prefix) + for r < n { + switch { + case os.IsPathSeparator(path[r]): + // empty block + r++ + case path[r] == '.' && (r+1 == n || os.IsPathSeparator(path[r+1])): + // /./ + r++ + case r+1 < n && path[r] == '.' && path[r+1] == '.' && (r+2 == n || os.IsPathSeparator(path[r+2])): + // /../ is currently unhandled + return path + default: + pathbuf[w] = '\\' + w++ + for ; r < n && !os.IsPathSeparator(path[r]); r++ { + pathbuf[w] = path[r] + w++ + } + } + } + // A drive's root directory needs a trailing \ + if w == len(`\\?\c:`) { + pathbuf[w] = '\\' + w++ + } + return string(pathbuf[:w]) +} + +// copied from os package for os.OpenFile +func syscallMode(i os.FileMode) (o uint32) { + o |= uint32(i.Perm()) + if i&os.ModeSetuid != 0 { + o |= syscall.S_ISUID + } + if i&os.ModeSetgid != 0 { + o |= syscall.S_ISGID + } + if i&os.ModeSticky != 0 { + o |= syscall.S_ISVTX + } + // No mapping for Go's ModeTemporary (plan9 only). + return +} + +func isAbs(path string) (b bool) { + v := volumeName(path) + if v == "" { + return false + } + path = path[len(v):] + if path == "" { + return false + } + return os.IsPathSeparator(path[0]) +} + +func volumeName(path string) (v string) { + if len(path) < 2 { + return "" + } + // with drive letter + c := path[0] + if path[1] == ':' && + ('0' <= c && c <= '9' || 'a' <= c && c <= 'z' || + 'A' <= c && c <= 'Z') { + return path[:2] + } + // is it UNC + if l := len(path); l >= 5 && os.IsPathSeparator(path[0]) && os.IsPathSeparator(path[1]) && + !os.IsPathSeparator(path[2]) && path[2] != '.' { + // first, leading `\\` and next shouldn't be `\`. its server name. + for n := 3; n < l-1; n++ { + // second, next '\' shouldn't be repeated. + if os.IsPathSeparator(path[n]) { + n++ + // third, following something characters. its share name. + if !os.IsPathSeparator(path[n]) { + if path[n] == '.' { + break + } + for ; n < l; n++ { + if os.IsPathSeparator(path[n]) { + break + } + } + return path[:n] + } + break + } + } + } + return "" +} diff --git a/daemon/logger/loggerutils/file_windows_test.go b/daemon/logger/loggerutils/file_windows_test.go new file mode 100644 index 0000000000..a7e69097b5 --- /dev/null +++ b/daemon/logger/loggerutils/file_windows_test.go @@ -0,0 +1,34 @@ +package loggerutils + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "gotest.tools/assert" +) + +func TestOpenFileDelete(t *testing.T) { + tmpDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(tmpDir) + + f, err := openFile(filepath.Join(tmpDir, "test.txt"), os.O_CREATE|os.O_RDWR, 644) + assert.NilError(t, err) + defer f.Close() + + assert.NilError(t, os.RemoveAll(f.Name())) +} + +func TestOpenFileRename(t *testing.T) { + tmpDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(tmpDir) + + f, err := openFile(filepath.Join(tmpDir, "test.txt"), os.O_CREATE|os.O_RDWR, 0644) + assert.NilError(t, err) + defer f.Close() + + assert.NilError(t, os.Rename(f.Name(), f.Name()+"renamed")) +} diff --git a/daemon/logger/loggerutils/logfile.go b/daemon/logger/loggerutils/logfile.go index cfc4a2e210..2f6e7381df 100644 --- a/daemon/logger/loggerutils/logfile.go +++ b/daemon/logger/loggerutils/logfile.go @@ -111,7 +111,7 @@ type GetTailReaderFunc func(ctx context.Context, f SizeReaderAt, nLogLines int) // NewLogFile creates new LogFile func NewLogFile(logPath string, capacity int64, maxFiles int, compress bool, marshaller logger.MarshalFunc, decodeFunc makeDecoderFunc, perms os.FileMode, getTailReader GetTailReaderFunc) (*LogFile, error) { - log, err := os.OpenFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, perms) + log, err := openFile(logPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, perms) if err != nil { return nil, err } @@ -182,7 +182,7 @@ func (w *LogFile) checkCapacityAndRotate() error { w.rotateMu.Unlock() return err } - file, err := os.OpenFile(fname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, w.perms) + file, err := openFile(fname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, w.perms) if err != nil { w.rotateMu.Unlock() return err @@ -250,7 +250,7 @@ func compressFile(fileName string, lastTimestamp time.Time) { } }() - outFile, err := os.OpenFile(fileName+".gz", os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0640) + outFile, err := openFile(fileName+".gz", os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0640) if err != nil { logrus.Errorf("Failed to open or create gzip log file: %v", err) return @@ -451,7 +451,7 @@ func decompressfile(fileName, destFileName string, since time.Time) (*os.File, e return nil, nil } - rs, err := os.OpenFile(destFileName, os.O_CREATE|os.O_RDWR, 0640) + rs, err := openFile(destFileName, os.O_CREATE|os.O_RDWR, 0640) if err != nil { return nil, errors.Wrap(err, "error creating file for copying decompressed log stream") }