// podman/vendor/github.com/container-orchestrated-devices/container-device-interface/pkg/devices.go
// (file listing metadata: 181 lines, 3.6 KiB, Go)

package pkg
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
cdispec "github.com/container-orchestrated-devices/container-device-interface/specs-go"
spec "github.com/opencontainers/runtime-spec/specs-go"
)
// root is the directory that collectCDISpecs walks looking for *.json CDI
// spec files.
const root = "/etc/cdi"
// collectCDISpecs walks root and loads every *.json file found beneath it as
// a CDI spec, returning the specs keyed by their Kind.
//
// Loading is best-effort: an unreadable or missing root, unreadable entries,
// and files that fail to load are skipped rather than reported. When several
// files declare the same Kind, the first one visited wins; filepath.Walk
// visits entries in lexical order, so the choice is deterministic.
func collectCDISpecs() (map[string]*cdispec.Spec, error) {
	var files []string
	// The walk error argument is deliberately ignored (best-effort scan, see
	// above); returning nil keeps the walk going.
	err := filepath.Walk(root, func(path string, info os.FileInfo, _ error) error {
		// info is nil when the entry could not be stat'ed (e.g. root does not
		// exist); skip that as well as directories themselves.
		if info == nil || info.IsDir() {
			return nil
		}
		if filepath.Ext(path) != ".json" {
			return nil
		}
		files = append(files, path)
		return nil
	})
	if err != nil {
		return nil, err
	}

	vendors := make(map[string]*cdispec.Spec, len(files))
	for _, path := range files {
		s, err := loadCDIFile(path)
		// A file holding JSON null can decode to a nil spec without error;
		// skip it like any other unloadable file instead of dereferencing nil.
		if err != nil || s == nil {
			continue
		}
		if _, dup := vendors[s.Kind]; dup {
			continue
		}
		vendors[s.Kind] = s
	}
	return vendors, nil
}
// loadCDIFile reads the file at path and decodes its JSON content as a CDI
// spec. It never returns a nil spec together with a nil error.
//
// TODO: Validate (e.g: duplicate device names)
func loadCDIFile(path string) (*cdispec.Spec, error) {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var s *cdispec.Spec
	if err := json.Unmarshal(data, &s); err != nil {
		return nil, fmt.Errorf("parsing CDI spec %q: %w", path, err)
	}
	// Unmarshalling a JSON null into a pointer leaves it nil and reports no
	// error; surface that so callers never dereference a nil spec.
	if s == nil {
		return nil, fmt.Errorf("CDI spec %q is empty (null)", path)
	}
	return s, nil
}
// extractVendor splits a device reference of the form
// "vendor.com/device=myDevice" at its first '=' into the vendor part and the
// device name. A reference without '=' has no vendor: it is returned as the
// device name with an empty vendor.
func extractVendor(dev string) (string, string) {
	if sep := strings.IndexByte(dev, '='); sep >= 0 {
		return dev[:sep], dev[sep+1:]
	}
	return "", dev
}
// GetCDIForDevice returns the CDI specification that matches the device name the user provided.
//
// dev is either a bare device name ("myDevice") or a vendor-qualified one
// ("vendor.com/device=myDevice"), split with extractVendor. A qualified name
// is only looked up in the spec stored under that vendor key in specs. A
// bare name is searched for in every spec; if more than one spec declares
// it, an ambiguity error listing the matching vendors (in no particular
// order) is returned. An error is also returned when nothing matches.
func GetCDIForDevice(dev string, specs map[string]*cdispec.Spec) (*cdispec.Spec, error) {
	vendor, device := extractVendor(dev)
	if vendor != "" {
		s, ok := specs[vendor]
		// Treat a nil entry like a missing one instead of dereferencing it.
		if !ok || s == nil {
			return nil, fmt.Errorf("Could not find vendor %q for device %q", vendor, device)
		}
		for _, d := range s.Devices {
			if d.Name == device {
				return s, nil
			}
		}
		return nil, fmt.Errorf("Could not find device %q for vendor %q", device, vendor)
	}

	var found []*cdispec.Spec
	var vendors []string
	for v, s := range specs {
		if s == nil {
			continue
		}
		for _, d := range s.Devices {
			if d.Name != device {
				continue
			}
			found = append(found, s)
			vendors = append(vendors, v)
			// Count each spec at most once, even if it (invalidly) lists the
			// same device name several times; otherwise a single vendor
			// would be reported as ambiguous with itself.
			break
		}
	}

	switch len(found) {
	case 0:
		return nil, fmt.Errorf("Could not find device %q", dev)
	case 1:
		return found[0], nil
	default:
		return nil, fmt.Errorf("%q is ambiguous and currently refers to multiple devices from different vendors: %q", dev, vendors)
	}
}
// HasDevice reports whether dev resolves to a device declared by one of the
// CDI specs collected from root.
//
// The boolean is false whenever an error is returned. Errors come either from
// collecting the specs or from GetCDIForDevice; the latter also fails when dev
// is simply not a known CDI device (or is ambiguous), so a non-CDI device
// yields (false, err) rather than (false, nil).
func HasDevice(dev string) (bool, error) {
	all, err := collectCDISpecs()
	if err != nil {
		return false, err
	}
	match, lookupErr := GetCDIForDevice(dev, all)
	if lookupErr != nil {
		return false, lookupErr
	}
	return match != nil, nil
}
// UpdateOCISpecForDevices updates the given OCI spec based on the requested
// CDI devices.
//
// The CDI specs are collected from root on every call and then handed to
// UpdateOCISpecForDevicesWithSpec; the first error from either step is
// returned.
func UpdateOCISpecForDevices(ociconfig *spec.Spec, devs []string) error {
	loaded, err := collectCDISpecs()
	if err == nil {
		err = UpdateOCISpecForDevicesWithSpec(ociconfig, devs, loaded)
	}
	return err
}
// UpdateOCISpecForDevicesWithSpec updates the given OCI spec based on the
// requested CDI devices, resolving them against the supplied specs instead of
// the ones found on disk. It is mainly used for testing.
//
// Each entry of devs may be a bare or a vendor-qualified device name (see
// GetCDIForDevice). Device-level edits are applied in request order. After
// that, the spec-level edits are applied once per distinct spec Kind, in the
// order in which that Kind was first requested. The first error stops
// processing and is returned.
func UpdateOCISpecForDevicesWithSpec(ociconfig *spec.Spec, devs []string, specs map[string]*cdispec.Spec) error {
	// Keep the per-Kind specs in a slice instead of ranging over a map, so the
	// resulting OCI spec does not depend on randomized map iteration order.
	var edits []*cdispec.Spec
	index := make(map[string]int)
	for _, d := range devs {
		s, err := GetCDIForDevice(d, specs)
		if err != nil {
			return err
		}
		if i, seen := index[s.Kind]; seen {
			// Preserve the previous "last spec for a Kind wins" behavior.
			edits[i] = s
		} else {
			index[s.Kind] = len(edits)
			edits = append(edits, s)
		}
		// GetCDIForDevice matches spec devices against the bare device name
		// (without any "vendor=" prefix); pass that same bare name on so the
		// per-device lookup agrees with it for vendor-qualified requests.
		_, name := extractVendor(d)
		if err := cdispec.ApplyOCIEditsForDevice(ociconfig, s, name); err != nil {
			return err
		}
	}
	for _, s := range edits {
		if err := cdispec.ApplyOCIEdits(ociconfig, s); err != nil {
			return err
		}
	}
	return nil
}