1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

Write Windows layer diffs to tar in standard format

Previously, Windows layer diffs were written using a Windows-internal
format based on the BackupRead/BackupWrite Win32 APIs. This caused
problems with tar-split and tarsum and led to performance problems
in implementing methods such as DiffPath. It also was just an
unnecessary differentiation point between Windows and Linux.

With this change, Windows layer diffs look much more like their
Linux counterparts. They use AUFS-style whiteout files for files
that have been removed, and they encode all metadata directly in
the tar file.

This change only affects Windows post-TP4, since changes to the Windows
container storage APIs were necessary to make this possible.

Signed-off-by: John Starks <jostarks@microsoft.com>
This commit is contained in:
John Starks 2016-02-18 18:11:36 -08:00
parent 882edc3f0e
commit 5649030e25
6 changed files with 284 additions and 793 deletions

View file

@ -4,16 +4,15 @@ package windows
import (
"fmt"
"strconv"
"strings"
"sync"
"github.com/Microsoft/hcsshim"
"github.com/Sirupsen/logrus"
"github.com/docker/docker/daemon/execdriver"
"github.com/docker/docker/dockerversion"
"github.com/docker/docker/pkg/parsers"
"github.com/docker/engine-api/types/container"
"golang.org/x/sys/windows/registry"
)
// TP4RetryHack is a hack to retry CreateComputeSystem if it fails with
@ -98,33 +97,11 @@ func NewDriver(root string, options []string) (*Driver, error) {
// TODO Windows TP5 timeframe. Remove this next block of code once TP4
// is no longer supported. Also remove the workaround in run.go.
//
// Hack for TP4 - determine the version of Windows from the registry.
// Hack for TP4.
// This overcomes an issue on TP4 which causes CreateComputeSystem to
// intermittently fail. It's predominantly here to make Windows to Windows
// CI more reliable.
TP4RetryHack = false
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
return &Driver{}, err
}
defer k.Close()
s, _, err := k.GetStringValue("BuildLab")
if err != nil {
return &Driver{}, err
}
parts := strings.Split(s, ".")
if len(parts) < 1 {
return &Driver{}, err
}
var val int
if val, err = strconv.Atoi(parts[0]); err != nil {
return &Driver{}, err
}
if val < 14250 {
TP4RetryHack = true
}
// End of Windows TP4 hack
TP4RetryHack = hcsshim.IsTP4()
return &Driver{
root: root,

View file

@ -3,17 +3,23 @@
package windows
import (
"bufio"
"crypto/sha512"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/Microsoft/go-winio"
"github.com/Microsoft/go-winio/archive/tar"
"github.com/Microsoft/go-winio/backuptar"
"github.com/Microsoft/hcsshim"
"github.com/Sirupsen/logrus"
"github.com/docker/docker/daemon/graphdriver"
@ -21,7 +27,6 @@ import (
"github.com/docker/docker/pkg/chrootarchive"
"github.com/docker/docker/pkg/idtools"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/random"
"github.com/vbatts/tar-split/tar/storage"
)
@ -265,7 +270,7 @@ func (d *Driver) Cleanup() error {
// Diff produces an archive of the changes between the specified
// layer and its parent layer which may be "".
func (d *Driver) Diff(id, parent string) (arch archive.Archive, err error) {
func (d *Driver) Diff(id, parent string) (_ archive.Archive, err error) {
rID, err := d.resolveID(id)
if err != nil {
return
@ -277,6 +282,8 @@ func (d *Driver) Diff(id, parent string) (arch archive.Archive, err error) {
return
}
var undo func()
d.Lock()
// To support export, a layer must be activated but not prepared.
@ -286,6 +293,56 @@ func (d *Driver) Diff(id, parent string) (arch archive.Archive, err error) {
d.Unlock()
return
}
undo = func() {
if err := hcsshim.DeactivateLayer(d.info, rID); err != nil {
logrus.Warnf("Failed to Deactivate %s: %s", rID, err)
}
}
} else {
if err = hcsshim.UnprepareLayer(d.info, rID); err != nil {
d.Unlock()
return
}
undo = func() {
if err := hcsshim.PrepareLayer(d.info, rID, layerChain); err != nil {
logrus.Warnf("Failed to re-PrepareLayer %s: %s", rID, err)
}
}
}
}
d.Unlock()
arch, err := d.exportLayer(rID, layerChain)
if err != nil {
undo()
return
}
return ioutils.NewReadCloserWrapper(arch, func() error {
defer undo()
return arch.Close()
}), nil
}
// Changes produces a list of changes between the specified layer
// and its parent layer. If parent is "", then all changes will be ADD changes.
func (d *Driver) Changes(id, parent string) ([]archive.Change, error) {
rID, err := d.resolveID(id)
if err != nil {
return nil, err
}
parentChain, err := d.getLayerChain(rID)
if err != nil {
return nil, err
}
d.Lock()
if d.info.Flavour == filterDriver {
if d.active[rID] == 0 {
if err = hcsshim.ActivateLayer(d.info, rID); err != nil {
d.Unlock()
return nil, err
}
defer func() {
if err := hcsshim.DeactivateLayer(d.info, rID); err != nil {
logrus.Warnf("Failed to Deactivate %s: %s", rID, err)
@ -294,25 +351,41 @@ func (d *Driver) Diff(id, parent string) (arch archive.Archive, err error) {
} else {
if err = hcsshim.UnprepareLayer(d.info, rID); err != nil {
d.Unlock()
return
return nil, err
}
defer func() {
if err := hcsshim.PrepareLayer(d.info, rID, layerChain); err != nil {
if err := hcsshim.PrepareLayer(d.info, rID, parentChain); err != nil {
logrus.Warnf("Failed to re-PrepareLayer %s: %s", rID, err)
}
}()
}
}
d.Unlock()
return d.exportLayer(rID, layerChain)
}
r, err := hcsshim.NewLayerReader(d.info, id, parentChain)
if err != nil {
return nil, err
}
defer r.Close()
// Changes produces a list of changes between the specified layer
// and its parent layer. If parent is "", then all changes will be ADD changes.
func (d *Driver) Changes(id, parent string) ([]archive.Change, error) {
return nil, fmt.Errorf("The Windows graphdriver does not support Changes()")
var changes []archive.Change
for {
name, _, fileInfo, err := r.Next()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
name = filepath.ToSlash(name)
if fileInfo == nil {
changes = append(changes, archive.Change{name, archive.ChangeDelete})
} else {
// Currently there is no way to tell between an add and a modify.
changes = append(changes, archive.Change{name, archive.ChangeModify})
}
}
return changes, nil
}
// ApplyDiff extracts the changeset from the given diff into the
@ -444,71 +517,162 @@ func (d *Driver) GetMetadata(id string) (map[string]string, error) {
return m, nil
}
// exportLayer generates an archive from a layer based on the given ID.
func (d *Driver) exportLayer(id string, parentLayerPaths []string) (arch archive.Archive, err error) {
layerFolder := d.dir(id)
tempFolder := layerFolder + "-" + strconv.FormatUint(uint64(random.Rand.Uint32()), 10)
if err = os.MkdirAll(tempFolder, 0755); err != nil {
logrus.Errorf("Could not create %s %s", tempFolder, err)
return
}
defer func() {
func writeTarFromLayer(r hcsshim.LayerReader, w io.Writer) error {
t := tar.NewWriter(w)
for {
name, size, fileInfo, err := r.Next()
if err == io.EOF {
break
}
if err != nil {
_, folderName := filepath.Split(tempFolder)
if err2 := hcsshim.DestroyLayer(d.info, folderName); err2 != nil {
logrus.Warnf("Couldn't clean-up tempFolder: %s %s", tempFolder, err2)
return err
}
if fileInfo == nil {
// Write a whiteout file.
hdr := &tar.Header{
Name: filepath.ToSlash(filepath.Join(filepath.Dir(name), archive.WhiteoutPrefix+filepath.Base(name))),
}
err := t.WriteHeader(hdr)
if err != nil {
return err
}
} else {
err = backuptar.WriteTarFileFromBackupStream(t, r, name, size, fileInfo)
if err != nil {
return err
}
}
}
return t.Close()
}
// exportLayer generates an archive from a layer based on the given ID.
func (d *Driver) exportLayer(id string, parentLayerPaths []string) (archive.Archive, error) {
if hcsshim.IsTP4() {
// Export in TP4 format to maintain compatibility with existing images and
// because ExportLayer is somewhat broken on TP4 and can't work with the new
// scheme.
tempFolder, err := ioutil.TempDir("", "hcs")
if err != nil {
return nil, err
}
defer func() {
if err != nil {
os.RemoveAll(tempFolder)
}
}()
if err = hcsshim.ExportLayer(d.info, id, tempFolder, parentLayerPaths); err != nil {
return nil, err
}
archive, err := archive.Tar(tempFolder, archive.Uncompressed)
if err != nil {
return nil, err
}
return ioutils.NewReadCloserWrapper(archive, func() error {
err := archive.Close()
os.RemoveAll(tempFolder)
return err
}), nil
}
var r hcsshim.LayerReader
r, err := hcsshim.NewLayerReader(d.info, id, parentLayerPaths)
if err != nil {
return nil, err
}
archive, w := io.Pipe()
go func() {
err := writeTarFromLayer(r, w)
cerr := r.Close()
if err == nil {
err = cerr
}
w.CloseWithError(err)
}()
if err = hcsshim.ExportLayer(d.info, id, tempFolder, parentLayerPaths); err != nil {
return
}
return archive, nil
}
archive, err := archive.Tar(tempFolder, archive.Uncompressed)
if err != nil {
return
}
return ioutils.NewReadCloserWrapper(archive, func() error {
err := archive.Close()
d.Put(id)
_, folderName := filepath.Split(tempFolder)
if err2 := hcsshim.DestroyLayer(d.info, folderName); err2 != nil {
logrus.Warnf("Couldn't clean-up tempFolder: %s %s", tempFolder, err2)
func writeLayerFromTar(r archive.Reader, w hcsshim.LayerWriter) (int64, error) {
t := tar.NewReader(r)
hdr, err := t.Next()
totalSize := int64(0)
buf := bufio.NewWriter(nil)
for err == nil {
base := path.Base(hdr.Name)
if strings.HasPrefix(base, archive.WhiteoutPrefix) {
name := path.Join(path.Dir(hdr.Name), base[len(archive.WhiteoutPrefix):])
err = w.Remove(filepath.FromSlash(name))
if err != nil {
return 0, err
}
hdr, err = t.Next()
} else {
var (
name string
size int64
fileInfo *winio.FileBasicInfo
)
name, size, fileInfo, err = backuptar.FileInfoFromHeader(hdr)
if err != nil {
return 0, err
}
err = w.Add(filepath.FromSlash(name), fileInfo)
if err != nil {
return 0, err
}
buf.Reset(w)
hdr, err = backuptar.WriteBackupStreamFromTarFile(buf, t, hdr)
ferr := buf.Flush()
if ferr != nil {
err = ferr
}
totalSize += size
}
return err
}), nil
}
if err != io.EOF {
return 0, err
}
return totalSize, nil
}
// importLayer adds a new layer to the tag and graph store based on the given data.
func (d *Driver) importLayer(id string, layerData archive.Reader, parentLayerPaths []string) (size int64, err error) {
layerFolder := d.dir(id)
tempFolder := layerFolder + "-" + strconv.FormatUint(uint64(random.Rand.Uint32()), 10)
if err = os.MkdirAll(tempFolder, 0755); err != nil {
logrus.Errorf("Could not create %s %s", tempFolder, err)
return
}
defer func() {
_, folderName := filepath.Split(tempFolder)
if err2 := hcsshim.DestroyLayer(d.info, folderName); err2 != nil {
logrus.Warnf("Couldn't clean-up tempFolder: %s %s", tempFolder, err2)
if hcsshim.IsTP4() {
// Import from TP4 format to maintain compatibility with existing images.
var tempFolder string
tempFolder, err = ioutil.TempDir("", "hcs")
if err != nil {
return
}
}()
defer os.RemoveAll(tempFolder)
start := time.Now().UTC()
logrus.Debugf("Start untar layer")
if size, err = chrootarchive.ApplyLayer(tempFolder, layerData); err != nil {
return
}
logrus.Debugf("Untar time: %vs", time.Now().UTC().Sub(start).Seconds())
if err = hcsshim.ImportLayer(d.info, id, tempFolder, parentLayerPaths); err != nil {
if size, err = chrootarchive.ApplyLayer(tempFolder, layerData); err != nil {
return
}
if err = hcsshim.ImportLayer(d.info, id, tempFolder, parentLayerPaths); err != nil {
return
}
return
}
var w hcsshim.LayerWriter
w, err = hcsshim.NewLayerWriter(d.info, id, parentLayerPaths)
if err != nil {
return
}
size, err = writeLayerFromTar(layerData, w)
if err != nil {
w.Close()
return
}
err = w.Close()
if err != nil {
return
}
return
}
@ -567,51 +731,78 @@ func (d *Driver) setLayerChain(id string, chain []string) error {
return nil
}
type fileGetCloserWithBackupPrivileges struct {
path string
}
func (fg *fileGetCloserWithBackupPrivileges) Get(filename string) (io.ReadCloser, error) {
var f *os.File
// Open the file while holding the Windows backup privilege. This ensures that the
// file can be opened even if the caller does not actually have access to it according
// to the security descriptor.
err := winio.RunWithPrivilege(winio.SeBackupPrivilege, func() error {
path := filepath.Join(fg.path, filename)
p, err := syscall.UTF16FromString(path)
if err != nil {
return err
}
h, err := syscall.CreateFile(&p[0], syscall.GENERIC_READ, syscall.FILE_SHARE_READ, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_BACKUP_SEMANTICS, 0)
if err != nil {
return &os.PathError{Op: "open", Path: path, Err: err}
}
f = os.NewFile(uintptr(h), path)
return nil
})
return f, err
}
func (fg *fileGetCloserWithBackupPrivileges) Close() error {
return nil
}
type fileGetDestroyCloser struct {
storage.FileGetter
d *Driver
folderName string
path string
}
func (f *fileGetDestroyCloser) Close() error {
// TODO: activate layers and release here?
return hcsshim.DestroyLayer(f.d.info, f.folderName)
return os.RemoveAll(f.path)
}
// DiffGetter returns a FileGetCloser that can read files from the directory that
// contains files for the layer differences. Used for direct access for tar-split.
func (d *Driver) DiffGetter(id string) (fg graphdriver.FileGetCloser, err error) {
id, err = d.resolveID(id)
func (d *Driver) DiffGetter(id string) (graphdriver.FileGetCloser, error) {
id, err := d.resolveID(id)
if err != nil {
return
return nil, err
}
// Getting the layer paths must be done outside of the lock.
layerChain, err := d.getLayerChain(id)
if err != nil {
return
}
layerFolder := d.dir(id)
tempFolder := layerFolder + "-" + strconv.FormatUint(uint64(random.Rand.Uint32()), 10)
if err = os.MkdirAll(tempFolder, 0755); err != nil {
logrus.Errorf("Could not create %s %s", tempFolder, err)
return
}
defer func() {
if hcsshim.IsTP4() {
// The export format for TP4 is different from the contents of the layer, so
// fall back to exporting the layer and getting file contents from there.
layerChain, err := d.getLayerChain(id)
if err != nil {
_, folderName := filepath.Split(tempFolder)
if err2 := hcsshim.DestroyLayer(d.info, folderName); err2 != nil {
logrus.Warnf("Couldn't clean-up tempFolder: %s %s", tempFolder, err2)
}
return nil, err
}
}()
if err = hcsshim.ExportLayer(d.info, id, tempFolder, layerChain); err != nil {
return
var tempFolder string
tempFolder, err = ioutil.TempDir("", "hcs")
if err != nil {
return nil, err
}
defer func() {
if err != nil {
os.RemoveAll(tempFolder)
}
}()
if err = hcsshim.ExportLayer(d.info, id, tempFolder, layerChain); err != nil {
return nil, err
}
return &fileGetDestroyCloser{storage.NewPathFileGetter(tempFolder), tempFolder}, nil
}
_, folderName := filepath.Split(tempFolder)
return &fileGetDestroyCloser{storage.NewPathFileGetter(tempFolder), d, folderName}, nil
return &fileGetCloserWithBackupPrivileges{d.dir(id)}, nil
}

View file

@ -1,178 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
// Package registry provides access to the Windows registry.
//
// Here is a simple example, opening a registry key and reading a string value from it.
//
// k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
// if err != nil {
// log.Fatal(err)
// }
// defer k.Close()
//
// s, _, err := k.GetStringValue("SystemRoot")
// if err != nil {
// log.Fatal(err)
// }
// fmt.Printf("Windows system root is %q\n", s)
//
package registry
import (
"io"
"syscall"
"time"
)
const (
// Registry key security and access rights.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms724878.aspx
// for details.
ALL_ACCESS = 0xf003f
CREATE_LINK = 0x00020
CREATE_SUB_KEY = 0x00004
ENUMERATE_SUB_KEYS = 0x00008
EXECUTE = 0x20019
NOTIFY = 0x00010
QUERY_VALUE = 0x00001
READ = 0x20019
SET_VALUE = 0x00002
WOW64_32KEY = 0x00200
WOW64_64KEY = 0x00100
WRITE = 0x20006
)
// Key is a handle to an open Windows registry key.
// Keys can be obtained by calling OpenKey; there are
// also some predefined root keys such as CURRENT_USER.
// Keys can be used directly in the Windows API.
type Key syscall.Handle
const (
// Windows defines some predefined root keys that are always open.
// An application can use these keys as entry points to the registry.
// Normally these keys are used in OpenKey to open new keys,
// but they can also be used anywhere a Key is required.
CLASSES_ROOT = Key(syscall.HKEY_CLASSES_ROOT)
CURRENT_USER = Key(syscall.HKEY_CURRENT_USER)
LOCAL_MACHINE = Key(syscall.HKEY_LOCAL_MACHINE)
USERS = Key(syscall.HKEY_USERS)
CURRENT_CONFIG = Key(syscall.HKEY_CURRENT_CONFIG)
)
// Close closes open key k.
func (k Key) Close() error {
return syscall.RegCloseKey(syscall.Handle(k))
}
// OpenKey opens a new key with path name relative to key k.
// It accepts any open key, including CURRENT_USER and others,
// and returns the new key and an error.
// The access parameter specifies desired access rights to the
// key to be opened.
func OpenKey(k Key, path string, access uint32) (Key, error) {
p, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, err
}
var subkey syscall.Handle
err = syscall.RegOpenKeyEx(syscall.Handle(k), p, 0, access, &subkey)
if err != nil {
return 0, err
}
return Key(subkey), nil
}
// ReadSubKeyNames returns the names of subkeys of key k.
// The parameter n controls the number of returned names,
// analogous to the way os.File.Readdirnames works.
func (k Key) ReadSubKeyNames(n int) ([]string, error) {
ki, err := k.Stat()
if err != nil {
return nil, err
}
names := make([]string, 0, ki.SubKeyCount)
buf := make([]uint16, ki.MaxSubKeyLen+1) // extra room for terminating zero byte
loopItems:
for i := uint32(0); ; i++ {
if n > 0 {
if len(names) == n {
return names, nil
}
}
l := uint32(len(buf))
for {
err := syscall.RegEnumKeyEx(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
if err == nil {
break
}
if err == syscall.ERROR_MORE_DATA {
// Double buffer size and try again.
l = uint32(2 * len(buf))
buf = make([]uint16, l)
continue
}
if err == _ERROR_NO_MORE_ITEMS {
break loopItems
}
return names, err
}
names = append(names, syscall.UTF16ToString(buf[:l]))
}
if n > len(names) {
return names, io.EOF
}
return names, nil
}
// CreateKey creates a key named path under open key k.
// CreateKey returns the new key and a boolean flag that reports
// whether the key already existed.
// The access parameter specifies the access rights for the key
// to be created.
func CreateKey(k Key, path string, access uint32) (newk Key, openedExisting bool, err error) {
var h syscall.Handle
var d uint32
err = regCreateKeyEx(syscall.Handle(k), syscall.StringToUTF16Ptr(path),
0, nil, _REG_OPTION_NON_VOLATILE, access, nil, &h, &d)
if err != nil {
return 0, false, err
}
return Key(h), d == _REG_OPENED_EXISTING_KEY, nil
}
// DeleteKey deletes the subkey path of key k and its values.
func DeleteKey(k Key, path string) error {
return regDeleteKey(syscall.Handle(k), syscall.StringToUTF16Ptr(path))
}
// A KeyInfo describes the statistics of a key. It is returned by Stat.
type KeyInfo struct {
SubKeyCount uint32
MaxSubKeyLen uint32 // size of the key's subkey with the longest name, in Unicode characters, not including the terminating zero byte
ValueCount uint32
MaxValueNameLen uint32 // size of the key's longest value name, in Unicode characters, not including the terminating zero byte
MaxValueLen uint32 // longest data component among the key's values, in bytes
lastWriteTime syscall.Filetime
}
// ModTime returns the key's last write time.
func (ki *KeyInfo) ModTime() time.Time {
return time.Unix(0, ki.lastWriteTime.Nanoseconds())
}
// Stat retrieves information about the open key k.
func (k Key) Stat() (*KeyInfo, error) {
var ki KeyInfo
err := syscall.RegQueryInfoKey(syscall.Handle(k), nil, nil, nil,
&ki.SubKeyCount, &ki.MaxSubKeyLen, nil, &ki.ValueCount,
&ki.MaxValueNameLen, &ki.MaxValueLen, nil, &ki.lastWriteTime)
if err != nil {
return nil, err
}
return &ki, nil
}

View file

@ -1,33 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package registry
import "syscall"
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go syscall.go
const (
_REG_OPTION_NON_VOLATILE = 0
_REG_CREATED_NEW_KEY = 1
_REG_OPENED_EXISTING_KEY = 2
_ERROR_NO_MORE_ITEMS syscall.Errno = 259
)
func LoadRegLoadMUIString() error {
return procRegLoadMUIStringW.Find()
}
//sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW
//sys regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) = advapi32.RegDeleteKeyW
//sys regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) = advapi32.RegSetValueExW
//sys regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) = advapi32.RegEnumValueW
//sys regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) = advapi32.RegDeleteValueW
//sys regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) = advapi32.RegLoadMUIStringW
//sys expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) = kernel32.ExpandEnvironmentStringsW

View file

@ -1,384 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
package registry
import (
"errors"
"io"
"syscall"
"unicode/utf16"
"unsafe"
)
const (
// Registry value types.
NONE = 0
SZ = 1
EXPAND_SZ = 2
BINARY = 3
DWORD = 4
DWORD_BIG_ENDIAN = 5
LINK = 6
MULTI_SZ = 7
RESOURCE_LIST = 8
FULL_RESOURCE_DESCRIPTOR = 9
RESOURCE_REQUIREMENTS_LIST = 10
QWORD = 11
)
var (
// ErrShortBuffer is returned when the buffer was too short for the operation.
ErrShortBuffer = syscall.ERROR_MORE_DATA
// ErrNotExist is returned when a registry key or value does not exist.
ErrNotExist = syscall.ERROR_FILE_NOT_FOUND
// ErrUnexpectedType is returned by Get*Value when the value's type was unexpected.
ErrUnexpectedType = errors.New("unexpected key value type")
)
// GetValue retrieves the type and data for the specified value associated
// with an open key k. It fills up buffer buf and returns the retrieved
// byte count n. If buf is too small to fit the stored value it returns
// ErrShortBuffer error along with the required buffer size n.
// If no buffer is provided, it returns true and actual buffer size n.
// If no buffer is provided, GetValue returns the value's type only.
// If the value does not exist, the error returned is ErrNotExist.
//
// GetValue is a low level function. If value's type is known, use the appropriate
// Get*Value function instead.
func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) {
pname, err := syscall.UTF16PtrFromString(name)
if err != nil {
return 0, 0, err
}
var pbuf *byte
if len(buf) > 0 {
pbuf = (*byte)(unsafe.Pointer(&buf[0]))
}
l := uint32(len(buf))
err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l)
if err != nil {
return int(l), valtype, err
}
return int(l), valtype, nil
}
func (k Key) getValue(name string, buf []byte) (date []byte, valtype uint32, err error) {
p, err := syscall.UTF16PtrFromString(name)
if err != nil {
return nil, 0, err
}
var t uint32
n := uint32(len(buf))
for {
err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
if err == nil {
return buf[:n], t, nil
}
if err != syscall.ERROR_MORE_DATA {
return nil, 0, err
}
if n <= uint32(len(buf)) {
return nil, 0, err
}
buf = make([]byte, n)
}
}
// GetStringValue retrieves the string value for the specified
// value name associated with an open key k. It also returns the value's type.
// If value does not exist, GetStringValue returns ErrNotExist.
// If value is not SZ or EXPAND_SZ, it will return the correct value
// type and ErrUnexpectedType.
func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) {
data, typ, err2 := k.getValue(name, make([]byte, 64))
if err2 != nil {
return "", typ, err2
}
switch typ {
case SZ, EXPAND_SZ:
default:
return "", typ, ErrUnexpectedType
}
if len(data) == 0 {
return "", typ, nil
}
u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[:]
return syscall.UTF16ToString(u), typ, nil
}
// GetMUIStringValue retrieves the localized string value for
// the specified value name associated with an open key k.
// If the value name doesn't exist or the localized string value
// can't be resolved, GetMUIStringValue returns ErrNotExist.
// GetMUIStringValue panics if the system doesn't support
// regLoadMUIString; use LoadRegLoadMUIString to check if
// regLoadMUIString is supported before calling this function.
func (k Key) GetMUIStringValue(name string) (string, error) {
pname, err := syscall.UTF16PtrFromString(name)
if err != nil {
return "", err
}
buf := make([]uint16, 1024)
var buflen uint32
var pdir *uint16
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path
// Try to resolve the string value using the system directory as
// a DLL search path; this assumes the string value is of the form
// @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320.
// This approach works with tzres.dll but may have to be revised
// in the future to allow callers to provide custom search paths.
var s string
s, err = ExpandString("%SystemRoot%\\system32\\")
if err != nil {
return "", err
}
pdir, err = syscall.UTF16PtrFromString(s)
if err != nil {
return "", err
}
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
}
for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed
if buflen <= uint32(len(buf)) {
break // Buffer not growing, assume race; break
}
buf = make([]uint16, buflen)
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
}
if err != nil {
return "", err
}
return syscall.UTF16ToString(buf), nil
}
// ExpandString expands environment-variable strings and replaces
// them with the values defined for the current user.
// Use ExpandString to expand EXPAND_SZ strings.
func ExpandString(value string) (string, error) {
if value == "" {
return "", nil
}
p, err := syscall.UTF16PtrFromString(value)
if err != nil {
return "", err
}
r := make([]uint16, 100)
for {
n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r)))
if err != nil {
return "", err
}
if n <= uint32(len(r)) {
u := (*[1 << 29]uint16)(unsafe.Pointer(&r[0]))[:]
return syscall.UTF16ToString(u), nil
}
r = make([]uint16, n)
}
}
// GetStringsValue retrieves the []string value for the specified
// value name associated with an open key k. It also returns the value's type.
// If value does not exist, GetStringsValue returns ErrNotExist.
// If value is not MULTI_SZ, it will return the correct value
// type and ErrUnexpectedType.
func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) {
data, typ, err2 := k.getValue(name, make([]byte, 64))
if err2 != nil {
return nil, typ, err2
}
if typ != MULTI_SZ {
return nil, typ, ErrUnexpectedType
}
if len(data) == 0 {
return nil, typ, nil
}
p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[:len(data)/2]
if len(p) == 0 {
return nil, typ, nil
}
if p[len(p)-1] == 0 {
p = p[:len(p)-1] // remove terminating null
}
val = make([]string, 0, 5)
from := 0
for i, c := range p {
if c == 0 {
val = append(val, string(utf16.Decode(p[from:i])))
from = i + 1
}
}
return val, typ, nil
}
// GetIntegerValue retrieves the integer value for the specified
// value name associated with an open key k. It also returns the value's type.
// If value does not exist, GetIntegerValue returns ErrNotExist.
// If value is not DWORD or QWORD, it will return the correct value
// type and ErrUnexpectedType.
func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) {
data, typ, err2 := k.getValue(name, make([]byte, 8))
if err2 != nil {
return 0, typ, err2
}
switch typ {
case DWORD:
if len(data) != 4 {
return 0, typ, errors.New("DWORD value is not 4 bytes long")
}
return uint64(*(*uint32)(unsafe.Pointer(&data[0]))), DWORD, nil
case QWORD:
if len(data) != 8 {
return 0, typ, errors.New("QWORD value is not 8 bytes long")
}
return uint64(*(*uint64)(unsafe.Pointer(&data[0]))), QWORD, nil
default:
return 0, typ, ErrUnexpectedType
}
}
// GetBinaryValue retrieves the binary value for the specified
// value name associated with an open key k. It also returns the value's type.
// If value does not exist, GetBinaryValue returns ErrNotExist.
// If value is not BINARY, it will return the correct value
// type and ErrUnexpectedType.
func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) {
data, typ, err2 := k.getValue(name, make([]byte, 64))
if err2 != nil {
return nil, typ, err2
}
if typ != BINARY {
return nil, typ, ErrUnexpectedType
}
return data, typ, nil
}
func (k Key) setValue(name string, valtype uint32, data []byte) error {
p, err := syscall.UTF16PtrFromString(name)
if err != nil {
return err
}
if len(data) == 0 {
return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0)
}
return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data)))
}
// SetDWordValue sets the data and type of a name value
// under key k to value and DWORD.
func (k Key) SetDWordValue(name string, value uint32) error {
return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:])
}
// SetQWordValue sets the data and type of a name value
// under key k to value and QWORD.
func (k Key) SetQWordValue(name string, value uint64) error {
return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:])
}
func (k Key) setStringValue(name string, valtype uint32, value string) error {
v, err := syscall.UTF16FromString(value)
if err != nil {
return err
}
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[:len(v)*2]
return k.setValue(name, valtype, buf)
}
// SetStringValue sets the data and type of a name value
// under key k to value and SZ. The value must not contain a zero byte.
func (k Key) SetStringValue(name, value string) error {
return k.setStringValue(name, SZ, value)
}
// SetExpandStringValue sets the data and type of a name value
// under key k to value and EXPAND_SZ. The value must not contain a zero byte.
func (k Key) SetExpandStringValue(name, value string) error {
return k.setStringValue(name, EXPAND_SZ, value)
}
// SetStringsValue sets the data and type of a name value
// under key k to value and MULTI_SZ. The value strings
// must not contain a zero byte.
func (k Key) SetStringsValue(name string, value []string) error {
ss := ""
for _, s := range value {
for i := 0; i < len(s); i++ {
if s[i] == 0 {
return errors.New("string cannot have 0 inside")
}
}
ss += s + "\x00"
}
v := utf16.Encode([]rune(ss + "\x00"))
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[:len(v)*2]
return k.setValue(name, MULTI_SZ, buf)
}
// SetBinaryValue sets the data and type of a name value
// under key k to value and BINARY.
func (k Key) SetBinaryValue(name string, value []byte) error {
return k.setValue(name, BINARY, value)
}
// DeleteValue removes a named value from the key k.
func (k Key) DeleteValue(name string) error {
return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name))
}
// ReadValueNames returns the value names of key k.
// The parameter n controls the number of returned names,
// analogous to the way os.File.Readdirnames works.
func (k Key) ReadValueNames(n int) ([]string, error) {
ki, err := k.Stat()
if err != nil {
return nil, err
}
names := make([]string, 0, ki.ValueCount)
buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character
loopItems:
for i := uint32(0); ; i++ {
if n > 0 {
if len(names) == n {
return names, nil
}
}
l := uint32(len(buf))
for {
err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
if err == nil {
break
}
if err == syscall.ERROR_MORE_DATA {
// Double buffer size and try again.
l = uint32(2 * len(buf))
buf = make([]uint16, l)
continue
}
if err == _ERROR_NO_MORE_ITEMS {
break loopItems
}
return names, err
}
names = append(names, syscall.UTF16ToString(buf[:l]))
}
if n > len(names) {
return names, io.EOF
}
return names, nil
}

View file

@ -1,82 +0,0 @@
// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT
package registry
import "unsafe"
import "syscall"
var _ unsafe.Pointer
var (
modadvapi32 = syscall.NewLazyDLL("advapi32.dll")
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procRegCreateKeyExW = modadvapi32.NewProc("RegCreateKeyExW")
procRegDeleteKeyW = modadvapi32.NewProc("RegDeleteKeyW")
procRegSetValueExW = modadvapi32.NewProc("RegSetValueExW")
procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW")
procRegDeleteValueW = modadvapi32.NewProc("RegDeleteValueW")
procRegLoadMUIStringW = modadvapi32.NewProc("RegLoadMUIStringW")
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
)
func regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) {
r0, _, _ := syscall.Syscall9(procRegCreateKeyExW.Addr(), 9, uintptr(key), uintptr(unsafe.Pointer(subkey)), uintptr(reserved), uintptr(unsafe.Pointer(class)), uintptr(options), uintptr(desired), uintptr(unsafe.Pointer(sa)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(disposition)))
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) {
r0, _, _ := syscall.Syscall(procRegDeleteKeyW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(subkey)), 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) {
r0, _, _ := syscall.Syscall6(procRegSetValueExW.Addr(), 6, uintptr(key), uintptr(unsafe.Pointer(valueName)), uintptr(reserved), uintptr(vtype), uintptr(unsafe.Pointer(buf)), uintptr(bufsize))
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) {
r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valtype)), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(buflen)), 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) {
r0, _, _ := syscall.Syscall(procRegDeleteValueW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(name)), 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) {
r0, _, _ := syscall.Syscall9(procRegLoadMUIStringW.Addr(), 7, uintptr(key), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(unsafe.Pointer(buflenCopied)), uintptr(flags), uintptr(unsafe.Pointer(dir)), 0, 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}
func expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall(procExpandEnvironmentStringsW.Addr(), 3, uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(size))
n = uint32(r0)
if n == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}