diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 3783e72d91..ead85be0bf 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -771,20 +771,33 @@ func NewTempArchive(src Archive, dir string) (*TempArchive, error) { return nil, err } size := st.Size() - return &TempArchive{f, size, 0}, nil + return &TempArchive{File: f, Size: size}, nil } type TempArchive struct { *os.File - Size int64 // Pre-computed from Stat().Size() as a convenience - read int64 + Size int64 // Pre-computed from Stat().Size() as a convenience + read int64 + closed bool +} + +// Close closes the underlying file if it's still open, or does a no-op +// to allow callers to try to close the TempArchive multiple times safely. +func (archive *TempArchive) Close() error { + if archive.closed { + return nil + } + + archive.closed = true + + return archive.File.Close() } func (archive *TempArchive) Read(data []byte) (int, error) { n, err := archive.File.Read(data) archive.read += int64(n) if err != nil || archive.read == archive.Size { - archive.File.Close() + archive.Close() os.Remove(archive.File.Name()) } return n, err diff --git a/pkg/archive/archive_test.go b/pkg/archive/archive_test.go index 05362a21c9..fdba6fb87c 100644 --- a/pkg/archive/archive_test.go +++ b/pkg/archive/archive_test.go @@ -9,6 +9,7 @@ import ( "os/exec" "path" "path/filepath" + "strings" "syscall" "testing" "time" @@ -607,3 +608,18 @@ func TestUntarInvalidSymlink(t *testing.T) { } } } + +func TestTempArchiveCloseMultipleTimes(t *testing.T) { + reader := ioutil.NopCloser(strings.NewReader("hello")) + tempArchive, err := NewTempArchive(reader, "") + buf := make([]byte, 10) + n, err := tempArchive.Read(buf) + if n != 5 { + t.Fatalf("Expected to read 5 bytes. Read %d instead", n) + } + for i := 0; i < 3; i++ { + if err = tempArchive.Close(); err != nil { + t.Fatalf("i=%d. Unexpected error closing temp archive: %v", i, err) + } + } +}