package daemon import ( "os" "os/exec" "strconv" "strings" "github.com/containerd/containerd/contrib/nvidia" "github.com/docker/docker/pkg/capabilities" "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" ) // TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object. // TODO: add list of device capabilities in daemon/node info var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request") const nvidiaHook = "nvidia-container-runtime-hook" // These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps var allNvidiaCaps = map[nvidia.Capability]struct{}{ nvidia.Compute: {}, nvidia.Compat32: {}, nvidia.Graphics: {}, nvidia.Utility: {}, nvidia.Video: {}, nvidia.Display: {}, } func init() { if _, err := exec.LookPath(nvidiaHook); err != nil { // do not register Nvidia driver if helper binary is not present. return } capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}} nvidiaDriver := &deviceDriver{ capset: capset, updateSpec: setNvidiaGPUs, } for c := range allNvidiaCaps { nvidiaDriver.capset[string(c)] = struct{}{} } registerDeviceDriver("nvidia", nvidiaDriver) } func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error { req := dev.req if req.Count != 0 && len(req.DeviceIDs) > 0 { return errConflictCountDeviceIDs } if len(req.DeviceIDs) > 0 { s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ",")) } else if req.Count > 0 { s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+countToDevices(req.Count)) } else if req.Count < 0 { s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES=all") } var nvidiaCaps []string // req.Capabilities contains device capabilities, some but not all are NVIDIA driver capabilities. for _, c := range dev.selectedCaps { nvcap := nvidia.Capability(c) if _, isNvidiaCap := allNvidiaCaps[nvcap]; isNvidiaCap { nvidiaCaps = append(nvidiaCaps, c) continue } // TODO: nvidia.WithRequiredCUDAVersion // for now we let the prestart hook verify cuda versions but errors are not pretty. } if nvidiaCaps != nil { s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ",")) } path, err := exec.LookPath(nvidiaHook) if err != nil { return err } if s.Hooks == nil { s.Hooks = &specs.Hooks{} } s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ Path: path, Args: []string{ nvidiaHook, "prestart", }, Env: os.Environ(), }) return nil } // countToDevices returns the list 0, 1, ... count-1 of deviceIDs. func countToDevices(count int) string { devices := make([]string, count) for i := range devices { devices[i] = strconv.Itoa(i) } return strings.Join(devices, ",") }