diff --git a/pkg/chunked/compression_linux.go b/pkg/chunked/compression_linux.go
index 3cf6100be..ddd7ff53b 100644
--- a/pkg/chunked/compression_linux.go
+++ b/pkg/chunked/compression_linux.go
@@ -185,7 +185,10 @@ func openTmpFileNoTmpFile(tmpDir string) (*os.File, error) {
 // Returns (manifest blob, parsed manifest, tar-split file or nil, manifest offset).
 // The opened tar-split file’s position is unspecified.
 // It may return an error matching ErrFallbackToOrdinaryLayerDownload / errFallbackCanConvert.
-func readZstdChunkedManifest(tmpDir string, blobStream ImageSourceSeekable, tocDigest digest.Digest, annotations map[string]string) (_ []byte, _ *minimal.TOC, _ *os.File, _ int64, retErr error) {
+// The compressed parameter indicates whether the manifest and tar-split data are zstd-compressed
+// (true) or stored uncompressed (false). Uncompressed data is used only for an optimization to convert
+// a regular OCI layer to zstd:chunked when convert_images is set, and it is not used for distributed images.
+func readZstdChunkedManifest(tmpDir string, blobStream ImageSourceSeekable, tocDigest digest.Digest, annotations map[string]string, compressed bool) (_ []byte, _ *minimal.TOC, _ *os.File, _ int64, retErr error) {
 	offsetMetadata := annotations[minimal.ManifestInfoKey]
 	if offsetMetadata == "" {
 		return nil, nil, nil, 0, fmt.Errorf("%q annotation missing", minimal.ManifestInfoKey)
@@ -261,7 +264,7 @@ func readZstdChunkedManifest(tmpDir string, blobStream ImageSourceSeekable, tocD
 		return nil, nil, nil, 0, err
 	}
 
-	decodedBlob, err := decodeAndValidateBlob(manifest, manifestLengthUncompressed, tocDigest.String())
+	decodedBlob, err := decodeAndValidateBlob(manifest, manifestLengthUncompressed, tocDigest.String(), compressed)
 	if err != nil {
 		return nil, nil, nil, 0, fmt.Errorf("validating and decompressing TOC: %w", err)
 	}
@@ -288,7 +291,7 @@ func readZstdChunkedManifest(tmpDir string, blobStream ImageSourceSeekable, tocD
 			decodedTarSplit.Close()
 		}
 	}()
-	if err := decodeAndValidateBlobToStream(tarSplit, decodedTarSplit, toc.TarSplitDigest.String()); err != nil {
+	if err := decodeAndValidateBlobToStream(tarSplit, decodedTarSplit, toc.TarSplitDigest.String(), compressed); err != nil {
 		return nil, nil, nil, 0, fmt.Errorf("validating and decompressing tar-split: %w", err)
 	}
 	// We use the TOC for creating on-disk files, but the tar-split for creating metadata
@@ -487,11 +490,15 @@ func validateBlob(blob []byte, expectedCompressedChecksum string) error {
 	return nil
 }
 
-func decodeAndValidateBlob(blob []byte, lengthUncompressed uint64, expectedCompressedChecksum string) ([]byte, error) {
+func decodeAndValidateBlob(blob []byte, lengthUncompressed uint64, expectedCompressedChecksum string, compressed bool) ([]byte, error) {
 	if err := validateBlob(blob, expectedCompressedChecksum); err != nil {
 		return nil, err
 	}
+	if !compressed {
+		return blob, nil
+	}
+
 	decoder, err := zstd.NewReader(nil)
 	if err != nil {
 		return nil, err
 	}
@@ -502,11 +509,16 @@ func decodeAndValidateBlob(blob []byte, lengthUncompressed uint64, expectedCompr
 	return decoder.DecodeAll(blob, b)
 }
 
-func decodeAndValidateBlobToStream(blob []byte, w *os.File, expectedCompressedChecksum string) error {
+func decodeAndValidateBlobToStream(blob []byte, w *os.File, expectedCompressedChecksum string, compressed bool) error {
 	if err := validateBlob(blob, expectedCompressedChecksum); err != nil {
 		return err
 	}
 
+	if !compressed {
+		_, err := w.Write(blob)
+		return err
+	}
+
 	decoder, err := zstd.NewReader(bytes.NewReader(blob))
 	if err != nil {
 		return err
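(Illustrative example, not part of the patch.) The `compressed` flag only changes the decoding step: the digest check always runs over the blob exactly as stored, and the zstd pass is skipped for uncompressed data. A minimal, self-contained sketch of that shape — `decodeBlob` is a hypothetical stand-in, not the c/storage helper:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/zstd"
	"github.com/opencontainers/go-digest"
)

// decodeBlob mirrors the shape of decodeAndValidateBlob: validation always
// runs against the stored bytes; decompression is conditional.
func decodeBlob(blob []byte, expected digest.Digest, compressed bool) ([]byte, error) {
	if digest.Canonical.FromBytes(blob) != expected {
		return nil, fmt.Errorf("invalid blob checksum")
	}
	if !compressed {
		return blob, nil
	}
	decoder, err := zstd.NewReader(nil)
	if err != nil {
		return nil, err
	}
	defer decoder.Close()
	return decoder.DecodeAll(blob, nil)
}

func main() {
	plain := []byte(`{"version": 1}`)

	// Uncompressed path: the blob comes back unchanged after validation.
	out, err := decodeBlob(plain, digest.Canonical.FromBytes(plain), false)
	fmt.Println(string(out), err)

	// Compressed path: note the expected digest covers the compressed bytes,
	// matching how the checksum is computed over the stored frame.
	var buf bytes.Buffer
	enc, _ := zstd.NewWriter(&buf)
	_, _ = enc.Write(plain)
	enc.Close()
	out, err = decodeBlob(buf.Bytes(), digest.Canonical.FromBytes(buf.Bytes()), true)
	fmt.Println(string(out), err)
}
```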
diff --git a/pkg/chunked/compressor/compressor.go b/pkg/chunked/compressor/compressor.go
index 0de063a24..2930723aa 100644
--- a/pkg/chunked/compressor/compressor.go
+++ b/pkg/chunked/compressor/compressor.go
@@ -11,7 +11,6 @@ import (
 
 	"github.com/containers/storage/pkg/chunked/internal/minimal"
 	"github.com/containers/storage/pkg/ioutils"
-	"github.com/klauspost/compress/zstd"
 	"github.com/opencontainers/go-digest"
 	"github.com/vbatts/tar-split/archive/tar"
 	"github.com/vbatts/tar-split/tar/asm"
@@ -202,15 +201,15 @@ type tarSplitData struct {
 	compressed          *bytes.Buffer
 	digester            digest.Digester
 	uncompressedCounter *ioutils.WriteCounter
-	zstd                *zstd.Encoder
+	zstd                minimal.ZstdWriter
 	packer              storage.Packer
 }
 
-func newTarSplitData(level int) (*tarSplitData, error) {
+func newTarSplitData(createZstdWriter minimal.CreateZstdWriterFunc) (*tarSplitData, error) {
 	compressed := bytes.NewBuffer(nil)
 	digester := digest.Canonical.Digester()
 
-	zstdWriter, err := minimal.ZstdWriterWithLevel(io.MultiWriter(compressed, digester.Hash()), level)
+	zstdWriter, err := createZstdWriter(io.MultiWriter(compressed, digester.Hash()))
 	if err != nil {
 		return nil, err
 	}
@@ -227,11 +226,11 @@ func newTarSplitData(level int) (*tarSplitData, error) {
 	}, nil
 }
 
-func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, reader io.Reader, level int) error {
+func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, reader io.Reader, createZstdWriter minimal.CreateZstdWriterFunc) error {
 	// total written so far. Used to retrieve partial offsets in the file
 	dest := ioutils.NewWriteCounter(destFile)
 
-	tarSplitData, err := newTarSplitData(level)
+	tarSplitData, err := newTarSplitData(createZstdWriter)
 	if err != nil {
 		return err
 	}
@@ -251,7 +250,7 @@ func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, r
 
 	buf := make([]byte, 4096)
 
-	zstdWriter, err := minimal.ZstdWriterWithLevel(dest, level)
+	zstdWriter, err := createZstdWriter(dest)
 	if err != nil {
 		return err
 	}
@@ -404,18 +403,11 @@ func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, r
 		return err
 	}
 
-	if err := zstdWriter.Flush(); err != nil {
-		zstdWriter.Close()
-		return err
-	}
 	if err := zstdWriter.Close(); err != nil {
 		return err
 	}
 	zstdWriter = nil
 
-	if err := tarSplitData.zstd.Flush(); err != nil {
-		return err
-	}
 	if err := tarSplitData.zstd.Close(); err != nil {
 		return err
 	}
@@ -427,7 +419,7 @@ func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, r
 		UncompressedSize:   tarSplitData.uncompressedCounter.Count,
 	}
 
-	return minimal.WriteZstdChunkedManifest(dest, outMetadata, uint64(dest.Count), &ts, metadata, level)
+	return minimal.WriteZstdChunkedManifest(dest, outMetadata, uint64(dest.Count), &ts, metadata, createZstdWriter)
 }
 
 type zstdChunkedWriter struct {
@@ -454,7 +446,7 @@ func (w zstdChunkedWriter) Write(p []byte) (int, error) {
 	}
 }
 
-// zstdChunkedWriterWithLevel writes a zstd compressed tarball where each file is
+// makeZstdChunkedWriter writes a zstd compressed tarball where each file is
 // compressed separately so it can be addressed separately. Idea based on CRFS:
 // https://github.com/google/crfs
 // The difference with CRFS is that the zstd compression is used instead of gzip.
@@ -469,12 +461,12 @@ func (w zstdChunkedWriter) Write(p []byte) (int, error) {
 // [SKIPPABLE FRAME 1]: [ZSTD SKIPPABLE FRAME, SIZE=MANIFEST LENGTH][MANIFEST]
 // [SKIPPABLE FRAME 2]: [ZSTD SKIPPABLE FRAME, SIZE=16][MANIFEST_OFFSET][MANIFEST_LENGTH][MANIFEST_LENGTH_UNCOMPRESSED][MANIFEST_TYPE][CHUNKED_ZSTD_MAGIC_NUMBER]
 // MANIFEST_OFFSET, MANIFEST_LENGTH, MANIFEST_LENGTH_UNCOMPRESSED and CHUNKED_ZSTD_MAGIC_NUMBER are 64 bits unsigned in little endian format.
-func zstdChunkedWriterWithLevel(out io.Writer, metadata map[string]string, level int) (io.WriteCloser, error) {
+func makeZstdChunkedWriter(out io.Writer, metadata map[string]string, createZstdWriter minimal.CreateZstdWriterFunc) (io.WriteCloser, error) {
 	ch := make(chan error, 1)
 	r, w := io.Pipe()
 
 	go func() {
-		ch <- writeZstdChunkedStream(out, metadata, r, level)
+		ch <- writeZstdChunkedStream(out, metadata, r, createZstdWriter)
 		_, _ = io.Copy(io.Discard, r) // Ordinarily writeZstdChunkedStream consumes all of r. If it fails, ensure the write end never blocks and eventually terminates.
 		r.Close()
 		close(ch)
@@ -493,5 +485,40 @@ func ZstdCompressor(r io.Writer, metadata map[string]string, level *int) (io.Wri
 		level = &l
 	}
 
-	return zstdChunkedWriterWithLevel(r, metadata, *level)
+	createZstdWriter := func(dest io.Writer) (minimal.ZstdWriter, error) {
+		return minimal.ZstdWriterWithLevel(dest, *level)
+	}
+
+	return makeZstdChunkedWriter(r, metadata, createZstdWriter)
+}
+
+type noCompression struct {
+	dest io.Writer
+}
+
+func (n *noCompression) Write(p []byte) (int, error) {
+	return n.dest.Write(p)
+}
+
+func (n *noCompression) Close() error {
+	return nil
+}
+
+func (n *noCompression) Flush() error {
+	return nil
+}
+
+func (n *noCompression) Reset(dest io.Writer) {
+	n.dest = dest
+}
+
+// NoCompression writes directly to the output file without any compression
+//
+// Such an output does not follow the zstd:chunked spec and cannot be generally consumed; this function
+// only exists for internal purposes and should not be called from outside c/storage.
+func NoCompression(r io.Writer, metadata map[string]string) (io.WriteCloser, error) {
+	createZstdWriter := func(dest io.Writer) (minimal.ZstdWriter, error) {
+		return &noCompression{dest: dest}, nil
+	}
+	return makeZstdChunkedWriter(r, metadata, createZstdWriter)
 }
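(Illustrative example, not part of the patch.) The frame layout documented above works because of zstd skippable frames: a decoder that does not understand them simply skips the payload, so the manifest and footer can ride inside a blob that still decompresses as an ordinary zstd stream. A minimal sketch of that framing — `appendSkippableFrame` is a hypothetical stand-in; the real constants and footer serialization live in pkg/chunked/internal/minimal:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// appendSkippableFrame wraps a payload in a zstd skippable frame. 0x184D2A50
// is the first of the sixteen magic values reserved for skippable frames;
// the payload length follows as a little-endian uint32.
func appendSkippableFrame(dest *bytes.Buffer, payload []byte) {
	_ = binary.Write(dest, binary.LittleEndian, uint32(0x184D2A50))
	_ = binary.Write(dest, binary.LittleEndian, uint32(len(payload)))
	dest.Write(payload)
}

func main() {
	var out bytes.Buffer
	manifest := []byte(`{"version":1,"entries":[]}`)
	appendSkippableFrame(&out, manifest)
	// First 8 bytes are the frame header (magic + size), then the payload.
	fmt.Printf("% x ... %q\n", out.Bytes()[:8], out.Bytes()[8:])
}
```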
diff --git a/pkg/chunked/compressor/compressor_test.go b/pkg/chunked/compressor/compressor_test.go
index 100c7d923..b1552e14f 100644
--- a/pkg/chunked/compressor/compressor_test.go
+++ b/pkg/chunked/compressor/compressor_test.go
@@ -3,8 +3,11 @@ package compressor
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"io"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestHole(t *testing.T) {
@@ -88,3 +91,82 @@ func TestTwoHoles(t *testing.T) {
 		t.Error("didn't receive EOF")
 	}
 }
+
+func TestNoCompressionWrite(t *testing.T) {
+	var buf bytes.Buffer
+	nc := &noCompression{dest: &buf}
+
+	data := []byte("hello world")
+	n, err := nc.Write(data)
+	assert.NoError(t, err)
+	assert.Equal(t, len(data), n)
+	assert.Equal(t, data, buf.Bytes())
+
+	data2 := []byte(" again")
+	n, err = nc.Write(data2)
+	assert.NoError(t, err)
+	assert.Equal(t, len(data2), n)
+	assert.Equal(t, append(data, data2...), buf.Bytes())
+}
+
+func TestNoCompressionClose(t *testing.T) {
+	var buf bytes.Buffer
+	nc := &noCompression{dest: &buf}
+	err := nc.Close()
+	assert.NoError(t, err)
+}
+
+func TestNoCompressionFlush(t *testing.T) {
+	var buf bytes.Buffer
+	nc := &noCompression{dest: &buf}
+	err := nc.Flush()
+	assert.NoError(t, err)
+}
+
+func TestNoCompressionReset(t *testing.T) {
+	var buf1 bytes.Buffer
+	nc := &noCompression{dest: &buf1}
+
+	data1 := []byte("initial data")
+	_, err := nc.Write(data1)
+	assert.NoError(t, err)
+	assert.Equal(t, data1, buf1.Bytes())
+
+	err = nc.Close()
+	assert.NoError(t, err)
+
+	var buf2 bytes.Buffer
+	nc.Reset(&buf2)
+
+	data2 := []byte("new data")
+	_, err = nc.Write(data2)
+	assert.NoError(t, err)
+
+	assert.Equal(t, data1, buf1.Bytes(), "Buffer 1 should remain unchanged")
+	assert.Equal(t, data2, buf2.Bytes(), "Buffer 2 should contain the new data")
+
+	err = nc.Close()
+	assert.NoError(t, err)
+
+	// Test Reset with nil, though Write would panic, Reset itself should work
+	nc.Reset(nil)
+	assert.Nil(t, nc.dest)
+}
+
+// Mock writer that returns an error on Write
+type errorWriter struct{}
+
+func (ew *errorWriter) Write(p []byte) (n int, err error) {
+	return 0, errors.New("mock write error")
+}
+
+func TestNoCompressionWriteError(t *testing.T) {
+	ew := &errorWriter{}
+	nc := &noCompression{dest: ew}
+
+	data := []byte("hello world")
+	n, err := nc.Write(data)
+	assert.Error(t, err)
+	assert.Equal(t, 0, n)
+	assert.Equal(t, "mock write error", err.Error())
+}
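(Illustrative example, not part of the patch.) The compressor.go changes above also drop the explicit Flush() calls before Close(). For the klauspost/compress encoder that appears safe — Close flushes buffered data and writes the final frame — and leaving Flush out of the interface is what lets the pass-through writer satisfy it. A quick round-trip check of that assumption:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/zstd"
)

func main() {
	var buf bytes.Buffer
	enc, err := zstd.NewWriter(&buf)
	if err != nil {
		panic(err)
	}
	if _, err := enc.Write([]byte("tar-split data")); err != nil {
		panic(err)
	}
	// No Flush(): Close() flushes buffered data and finishes the frame.
	if err := enc.Close(); err != nil {
		panic(err)
	}

	dec, err := zstd.NewReader(nil)
	if err != nil {
		panic(err)
	}
	defer dec.Close()
	out, err := dec.DecodeAll(buf.Bytes(), nil)
	fmt.Println(string(out), err) // tar-split data <nil>
}
```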
diff --git a/pkg/chunked/internal/minimal/compression.go b/pkg/chunked/internal/minimal/compression.go
index f85c5973c..4191524cc 100644
--- a/pkg/chunked/internal/minimal/compression.go
+++ b/pkg/chunked/internal/minimal/compression.go
@@ -20,6 +20,15 @@ import (
 	"github.com/vbatts/tar-split/archive/tar"
 )
 
+// ZstdWriter is an interface that wraps standard io.WriteCloser and Reset() to reuse the compressor with a new writer.
+type ZstdWriter interface {
+	io.WriteCloser
+	Reset(dest io.Writer)
+}
+
+// CreateZstdWriterFunc is a function that creates a ZstdWriter for the provided destination writer.
+type CreateZstdWriterFunc func(dest io.Writer) (ZstdWriter, error)
+
 // TOC is short for Table of Contents and is used by the zstd:chunked
 // file format to effectively add an overall index into the contents
 // of a tarball; it also includes file metadata.
@@ -179,7 +188,7 @@ type TarSplitData struct {
 	UncompressedSize   int64
 }
 
-func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, offset uint64, tarSplitData *TarSplitData, metadata []FileMetadata, level int) error {
+func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, offset uint64, tarSplitData *TarSplitData, metadata []FileMetadata, createZstdWriter CreateZstdWriterFunc) error {
 	// 8 is the size of the zstd skippable frame header + the frame size
 	const zstdSkippableFrameHeader = 8
 	manifestOffset := offset + zstdSkippableFrameHeader
@@ -198,7 +207,7 @@ func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, off
 	}
 
 	var compressedBuffer bytes.Buffer
-	zstdWriter, err := ZstdWriterWithLevel(&compressedBuffer, level)
+	zstdWriter, err := createZstdWriter(&compressedBuffer)
 	if err != nil {
 		return err
 	}
@@ -244,7 +253,7 @@ func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, off
 	return appendZstdSkippableFrame(dest, manifestDataLE)
 }
 
-func ZstdWriterWithLevel(dest io.Writer, level int) (*zstd.Encoder, error) {
+func ZstdWriterWithLevel(dest io.Writer, level int) (ZstdWriter, error) {
 	el := zstd.EncoderLevelFromZstd(level)
 	return zstd.NewWriter(dest, zstd.WithEncoderLevel(el))
 }
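(Illustrative example, not part of the patch.) The ZstdWriter/CreateZstdWriterFunc pair added above is a small factory abstraction: the stream writer asks the factory for a compressor whenever it needs one and never cares whether the bytes actually get compressed. A self-contained sketch of the same pattern outside c/storage — the `Writer` interface and `passthrough` type here mirror `minimal.ZstdWriter` and the `noCompression` type but are local to this example:

```go
package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/klauspost/compress/zstd"
)

// Writer mirrors minimal.ZstdWriter: the minimal surface the stream writer
// needs from a compressor.
type Writer interface {
	io.WriteCloser
	Reset(dest io.Writer)
}

// createFunc mirrors minimal.CreateZstdWriterFunc.
type createFunc func(dest io.Writer) (Writer, error)

// passthrough is the same idea as the noCompression type in the diff.
type passthrough struct{ dest io.Writer }

func (p *passthrough) Write(b []byte) (int, error) { return p.dest.Write(b) }
func (p *passthrough) Close() error                { return nil }
func (p *passthrough) Reset(dest io.Writer)        { p.dest = dest }

// compressWith only talks to the factory; it is agnostic to the compression.
func compressWith(create createFunc, payload []byte) ([]byte, error) {
	var buf bytes.Buffer
	w, err := create(&buf)
	if err != nil {
		return nil, err
	}
	if _, err := w.Write(payload); err != nil {
		return nil, err
	}
	if err := w.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func main() {
	payload := bytes.Repeat([]byte("zstd:chunked "), 100)

	withZstd := func(dest io.Writer) (Writer, error) {
		return zstd.NewWriter(dest, zstd.WithEncoderLevel(zstd.SpeedFastest))
	}
	withoutZstd := func(dest io.Writer) (Writer, error) {
		return &passthrough{dest: dest}, nil
	}

	for name, create := range map[string]createFunc{"zstd": withZstd, "none": withoutZstd} {
		out, err := compressWith(create, payload)
		fmt.Println(name, len(out), err)
	}
}
```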
diff --git a/pkg/chunked/storage_linux.go b/pkg/chunked/storage_linux.go
index 97dc9b814..f23a96b7a 100644
--- a/pkg/chunked/storage_linux.go
+++ b/pkg/chunked/storage_linux.go
@@ -170,8 +170,7 @@ func (c *chunkedDiffer) convertTarToZstdChunked(destDirectory string, payload *o
 	}
 	newAnnotations := make(map[string]string)
-	level := 1
-	chunked, err := compressor.ZstdCompressor(f, newAnnotations, &level)
+	chunked, err := compressor.NoCompression(f, newAnnotations)
 	if err != nil {
 		f.Close()
 		return 0, nil, "", nil, err
 	}
@@ -341,7 +340,7 @@ func makeConvertFromRawDiffer(store storage.Store, blobDigest digest.Digest, blo
 // makeZstdChunkedDiffer sets up a chunkedDiffer for a zstd:chunked layer.
 // It may return an error matching ErrFallbackToOrdinaryLayerDownload / errFallbackCanConvert.
 func makeZstdChunkedDiffer(store storage.Store, blobSize int64, tocDigest digest.Digest, annotations map[string]string, iss ImageSourceSeekable, pullOptions pullOptions) (_ *chunkedDiffer, retErr error) {
-	manifest, toc, tarSplit, tocOffset, err := readZstdChunkedManifest(store.RunRoot(), iss, tocDigest, annotations)
+	manifest, toc, tarSplit, tocOffset, err := readZstdChunkedManifest(store.RunRoot(), iss, tocDigest, annotations, true)
 	if err != nil { // May be ErrFallbackToOrdinaryLayerDownload / errFallbackCanConvert
 		return nil, fmt.Errorf("read zstd:chunked manifest: %w", err)
 	}
@@ -666,20 +665,17 @@ func (o *originFile) OpenFile() (io.ReadCloser, error) {
 	return srcFile, nil
 }
 
-func (c *chunkedDiffer) prepareCompressedStreamToFile(partCompression compressedFileType, from io.Reader, mf *missingFileChunk) (compressedFileType, error) {
+func (c *chunkedDiffer) prepareCompressedStreamToFile(partCompression compressedFileType, mf *missingFileChunk) (compressedFileType, error) {
 	switch {
 	case partCompression == fileTypeHole:
 		// The entire part is a hole. Do not need to read from a file.
-		c.rawReader = nil
 		return fileTypeHole, nil
 	case mf.Hole:
 		// Only the missing chunk in the requested part refers to a hole.
 		// The received data must be discarded.
-		limitReader := io.LimitReader(from, mf.CompressedSize)
-		_, err := io.CopyBuffer(io.Discard, limitReader, c.copyBuffer)
+		_, err := io.CopyBuffer(io.Discard, c.rawReader, c.copyBuffer)
 		return fileTypeHole, err
 	case partCompression == fileTypeZstdChunked:
-		c.rawReader = io.LimitReader(from, mf.CompressedSize)
 		if c.zstdReader == nil {
 			r, err := zstd.NewReader(c.rawReader)
 			if err != nil {
@@ -692,7 +688,6 @@ func (c *chunkedDiffer) prepareCompressedStreamToFile(partCompression compressed
 			}
 		}
 	case partCompression == fileTypeEstargz:
-		c.rawReader = io.LimitReader(from, mf.CompressedSize)
 		if c.gzipReader == nil {
 			r, err := pgzip.NewReader(c.rawReader)
 			if err != nil {
@@ -705,7 +700,7 @@ func (c *chunkedDiffer) prepareCompressedStreamToFile(partCompression compressed
 			}
 		}
 	case partCompression == fileTypeNoCompression:
-		c.rawReader = io.LimitReader(from, mf.UncompressedSize)
+		return fileTypeNoCompression, nil
 	default:
 		return partCompression, fmt.Errorf("unknown file type %q", c.fileType)
 	}
@@ -905,6 +900,7 @@ func (c *chunkedDiffer) storeMissingFiles(streams chan io.ReadCloser, errs chan
 	for _, missingPart := range missingParts {
 		var part io.ReadCloser
 		partCompression := c.fileType
+		readingFromLocalFile := false
 		switch {
 		case missingPart.Hole:
 			partCompression = fileTypeHole
@@ -915,6 +911,7 @@ func (c *chunkedDiffer) storeMissingFiles(streams chan io.ReadCloser, errs chan
 				return err
 			}
 			partCompression = fileTypeNoCompression
+			readingFromLocalFile = true
 		case missingPart.SourceChunk != nil:
 			select {
 			case p := <-streams:
@@ -948,7 +945,18 @@ func (c *chunkedDiffer) storeMissingFiles(streams chan io.ReadCloser, errs chan
 			goto exit
 		}
-		compression, err := c.prepareCompressedStreamToFile(partCompression, part, &mf)
+		c.rawReader = nil
+		if part != nil {
+			limit := mf.CompressedSize
+			// If we are reading from a source file, use the uncompressed size to limit the reader, because
+			// the compressed size refers to the original layer stream.
+			if readingFromLocalFile {
+				limit = mf.UncompressedSize
+			}
+			c.rawReader = io.LimitReader(part, limit)
+		}
+
+		compression, err := c.prepareCompressedStreamToFile(partCompression, &mf)
 		if err != nil {
 			Err = err
 			goto exit
 		}
@@ -1440,7 +1448,9 @@ func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, diff
 	if err != nil {
 		return graphdriver.DriverWithDifferOutput{}, err
 	}
+
 	c.uncompressedTarSize = tarSize
+
 	// fileSource is a O_TMPFILE file descriptor, so we
 	// need to keep it open until the entire file is processed.
 	defer fileSource.Close()
@@ -1456,7 +1466,7 @@ func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, diff
 	if tocDigest == nil {
 		return graphdriver.DriverWithDifferOutput{}, fmt.Errorf("internal error: just-created zstd:chunked missing TOC digest")
 	}
-	manifest, toc, tarSplit, tocOffset, err := readZstdChunkedManifest(dest, fileSource, *tocDigest, annotations)
+	manifest, toc, tarSplit, tocOffset, err := readZstdChunkedManifest(dest, fileSource, *tocDigest, annotations, false)
 	if err != nil {
 		return graphdriver.DriverWithDifferOutput{}, fmt.Errorf("read zstd:chunked manifest: %w", err)
 	}
@@ -1465,7 +1475,7 @@ func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, diff
 	stream = fileSource
 
 	// fill the chunkedDiffer with the data we just read.
-	c.fileType = fileTypeZstdChunked
+	c.fileType = fileTypeNoCompression
 	c.manifest = manifest
 	c.toc = toc
 	c.tarSplit = tarSplit
diff --git a/pkg/chunked/zstdchunked_test.go b/pkg/chunked/zstdchunked_test.go
index f2834b42d..2565c2f55 100644
--- a/pkg/chunked/zstdchunked_test.go
+++ b/pkg/chunked/zstdchunked_test.go
@@ -129,7 +129,12 @@ func TestGenerateAndParseManifest(t *testing.T) {
 
 	var b bytes.Buffer
 	writer := bufio.NewWriter(&b)
-	if err := minimal.WriteZstdChunkedManifest(writer, annotations, offsetManifest, &ts, someFiles[:], 9); err != nil {
+
+	createZstdWriter := func(dest io.Writer) (minimal.ZstdWriter, error) {
+		return minimal.ZstdWriterWithLevel(dest, 9)
+	}
+
+	if err := minimal.WriteZstdChunkedManifest(writer, annotations, offsetManifest, &ts, someFiles[:], createZstdWriter); err != nil {
 		t.Error(err)
 	}
 	if err := writer.Flush(); err != nil {
@@ -179,7 +184,7 @@ func TestGenerateAndParseManifest(t *testing.T) {
 	tocDigest, err := toc.GetTOCDigest(annotations)
 	require.NoError(t, err)
 	require.NotNil(t, tocDigest)
-	manifest, decodedTOC, _, _, err := readZstdChunkedManifest(t.TempDir(), s, *tocDigest, annotations)
+	manifest, decodedTOC, _, _, err := readZstdChunkedManifest(t.TempDir(), s, *tocDigest, annotations, true)
 	require.NoError(t, err)
 
 	var toc minimal.TOC
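(Illustrative example, not part of the patch.) The subtle piece of the storage_linux.go changes is the sizing rule in storeMissingFiles: chunks still fetched from the registry are bounded by CompressedSize, while chunks replayed from the locally converted — now uncompressed — file must be bounded by UncompressedSize, since CompressedSize describes the original layer stream rather than the local file. A reduced sketch of just that decision; `missingChunk` is a hypothetical stand-in for the fields used from missingFileChunk:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// missingChunk carries the two sizes the differ knows for a missing chunk.
type missingChunk struct {
	CompressedSize   int64
	UncompressedSize int64
}

// limitFor picks the byte limit for the part reader depending on where the
// bytes come from, mirroring the logic added to storeMissingFiles.
func limitFor(part io.Reader, mf missingChunk, readingFromLocalFile bool) io.Reader {
	limit := mf.CompressedSize
	if readingFromLocalFile {
		// Local converted file: data is already uncompressed.
		limit = mf.UncompressedSize
	}
	return io.LimitReader(part, limit)
}

func main() {
	src := bytes.NewReader([]byte("0123456789"))
	mf := missingChunk{CompressedSize: 4, UncompressedSize: 10}

	// Reading from the local file: the full uncompressed chunk is consumed.
	data, _ := io.ReadAll(limitFor(src, mf, true))
	fmt.Printf("%q\n", data) // "0123456789"
}
```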