From 0d4a46ce285b0f5295561e1845f75537e80d2d53 Mon Sep 17 00:00:00 2001 From: Nalin Dahyabhai Date: Tue, 16 Jul 2019 16:10:41 -0400 Subject: [PATCH] Keep track of the UIDs and GIDs used in applied layers Add a field to the Layer structure that lets us make note of the set of UIDs and GIDs which own files in the layer, populated by scanning the diff that we used to populate the layer, if there was one. Signed-off-by: Nalin Dahyabhai --- layers.go | 38 +++++++++++++++++++++++++- pkg/tarlog/tarlogger.go | 47 ++++++++++++++++++++++++++++++++ pkg/tarlog/tarlogger_test.go | 52 ++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 pkg/tarlog/tarlogger.go create mode 100644 pkg/tarlog/tarlogger_test.go diff --git a/layers.go b/layers.go index f40423e17..ce5d97917 100644 --- a/layers.go +++ b/layers.go @@ -1,15 +1,18 @@ package storage import ( + "archive/tar" "bytes" "encoding/json" "fmt" "io" "io/ioutil" "os" + "path" "path/filepath" "reflect" "sort" + "strings" "time" drivers "github.com/containers/storage/drivers" @@ -18,6 +21,7 @@ import ( "github.com/containers/storage/pkg/ioutils" "github.com/containers/storage/pkg/stringid" "github.com/containers/storage/pkg/system" + "github.com/containers/storage/pkg/tarlog" "github.com/containers/storage/pkg/truncindex" "github.com/klauspost/pgzip" digest "github.com/opencontainers/go-digest" @@ -96,6 +100,12 @@ type Layer struct { // that was last passed to ApplyDiff() or Put(). CompressionType archive.Compression `json:"compression,omitempty"` + // UIDs and GIDs are lists of UIDs and GIDs used in the layer. This + // field is only populated (i.e., will only contain one or more + // entries) if the layer was created using ApplyDiff() or Put(). + UIDs []uint32 `json:"uidset,omitempty"` + GIDs []uint32 `json:"gidset,omitempty"` + // Flags is arbitrary data about the layer. Flags map[string]interface{} `json:"flags,omitempty"` @@ -1236,7 +1246,18 @@ func (r *layerStore) ApplyDiff(to string, diff io.Reader) (size int64, err error } uncompressedDigest := digest.Canonical.Digester() uncompressedCounter := ioutils.NewWriteCounter(uncompressedDigest.Hash()) - payload, err := asm.NewInputTarStream(io.TeeReader(uncompressed, uncompressedCounter), metadata, storage.NewDiscardFilePutter()) + uidLog := make(map[uint32]struct{}) + gidLog := make(map[uint32]struct{}) + idLogger, err := tarlog.NewLogger(func(h *tar.Header) { + if !strings.HasPrefix(path.Base(h.Name), archive.WhiteoutPrefix) { + uidLog[uint32(h.Uid)] = struct{}{} + gidLog[uint32(h.Gid)] = struct{}{} + } + }) + if err != nil { + return -1, err + } + payload, err := asm.NewInputTarStream(io.TeeReader(uncompressed, io.MultiWriter(uncompressedCounter, idLogger)), metadata, storage.NewDiscardFilePutter()) if err != nil { return -1, err } @@ -1245,6 +1266,7 @@ func (r *layerStore) ApplyDiff(to string, diff io.Reader) (size int64, err error return -1, err } compressor.Close() + idLogger.Close() if err == nil { if err := os.MkdirAll(filepath.Dir(r.tspath(layer.ID)), 0700); err != nil { return -1, err @@ -1279,6 +1301,20 @@ func (r *layerStore) ApplyDiff(to string, diff io.Reader) (size int64, err error layer.UncompressedDigest = uncompressedDigest.Digest() layer.UncompressedSize = uncompressedCounter.Count layer.CompressionType = compression + layer.UIDs = make([]uint32, 0, len(uidLog)) + for uid := range uidLog { + layer.UIDs = append(layer.UIDs, uid) + } + sort.Slice(layer.UIDs, func(i, j int) bool { + return layer.UIDs[i] < layer.UIDs[j] + }) + layer.GIDs = make([]uint32, 0, len(gidLog)) + for gid := range gidLog { + layer.GIDs = append(layer.GIDs, gid) + } + sort.Slice(layer.GIDs, func(i, j int) bool { + return layer.GIDs[i] < layer.GIDs[j] + }) err = r.Save() diff --git a/pkg/tarlog/tarlogger.go b/pkg/tarlog/tarlogger.go new file mode 100644 index 000000000..8451de01e --- /dev/null +++ b/pkg/tarlog/tarlogger.go @@ -0,0 +1,47 @@ +package tarlog + +import ( + "archive/tar" + "io" + "os" + "sync" + + "github.com/pkg/errors" +) + +type tarLogger struct { + writer *os.File + wg sync.WaitGroup +} + +// NewLogger returns a writer that, when a tar archive is written to it, calls +// `logger` for each file header it encounters in the archive. +func NewLogger(logger func(*tar.Header)) (io.WriteCloser, error) { + reader, writer, err := os.Pipe() + if err != nil { + return nil, errors.Wrapf(err, "error creating pipe for tar logger") + } + t := &tarLogger{writer: writer} + tr := tar.NewReader(reader) + t.wg.Add(1) + go func() { + hdr, err := tr.Next() + for err == nil { + logger(hdr) + hdr, err = tr.Next() + } + reader.Close() + t.wg.Done() + }() + return t, nil +} + +func (t *tarLogger) Write(b []byte) (int, error) { + return t.writer.Write(b) +} + +func (t *tarLogger) Close() error { + err := t.writer.Close() + t.wg.Wait() + return err +} diff --git a/pkg/tarlog/tarlogger_test.go b/pkg/tarlog/tarlogger_test.go new file mode 100644 index 000000000..f29093daf --- /dev/null +++ b/pkg/tarlog/tarlogger_test.go @@ -0,0 +1,52 @@ +package tarlog + +import ( + "archive/tar" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTarLogger(t *testing.T) { + cases := make([]struct { + name string + data []byte + }, 32) + for i := range cases { + cases[i].name = "ext" + if i > 0 { + cases[i].name = cases[i-1].name + "." + cases[i].name + } + cases[i].data = make([]byte, i*64) + } + + loggedNames := []string{} + logNames := func(h *tar.Header) { + loggedNames = append(loggedNames, h.Name) + } + + logger, err := NewLogger(logNames) + require.NoError(t, err, "error creating new TarLogger") + + writer := tar.NewWriter(logger) + for i := range cases { + h := &tar.Header{ + Name: cases[i].name, + Typeflag: tar.TypeReg, + Size: int64(len(cases[i].data)), + } + err := writer.WriteHeader(h) + require.NoError(t, err, "error writing header to tar buffer") + n, err := writer.Write(cases[i].data) + require.NoError(t, err, "error writing data to tar buffer") + require.Equal(t, n, len(cases[i].data), "expected to write %d bytes, wrote %d", len(cases[i].data), n) + } + writer.Close() + + logger.Close() + + require.Equal(t, len(loggedNames), len(cases), "expected to log %d names, logged %d", len(cases), len(loggedNames)) + for i := range cases { + require.Equal(t, loggedNames[i], cases[i].name, "expected to see name %q, got name %q", cases[i].name, loggedNames[i]) + } +}