diff --git a/client/daemon/peer/piece_downloader.go b/client/daemon/peer/piece_downloader.go index 40111e689..45e21c040 100644 --- a/client/daemon/peer/piece_downloader.go +++ b/client/daemon/peer/piece_downloader.go @@ -178,7 +178,7 @@ func (p *pieceDownloader) DownloadPiece(ctx context.Context, req *DownloadPieceR reader, closer := resp.Body.(io.Reader), resp.Body.(io.Closer) if req.CalcDigest { req.log.Debugf("calculate digest for piece %d, digest: %s", req.piece.PieceNum, req.piece.PieceMd5) - reader, err = digest.NewReader(req.log, io.LimitReader(resp.Body, int64(req.piece.RangeSize)), req.piece.PieceMd5) + reader, err = digest.NewReader(io.LimitReader(resp.Body, int64(req.piece.RangeSize)), digest.WithDigest(req.piece.PieceMd5), digest.WithLogger(req.log)) if err != nil { _ = closer.Close() req.log.Errorf("init digest reader error: %s", err.Error()) diff --git a/client/daemon/peer/piece_manager.go b/client/daemon/peer/piece_manager.go index 954f8615b..c5540bb8b 100644 --- a/client/daemon/peer/piece_manager.go +++ b/client/daemon/peer/piece_manager.go @@ -205,7 +205,7 @@ func (pm *pieceManager) processPieceFromSource(pt Task, } if pm.calculateDigest { pt.Log().Debugf("calculate digest") - reader, _ = digest.NewReader(pt.Log(), reader) + reader, _ = digest.NewReader(reader, digest.WithLogger(pt.Log())) } var n int64 result.Size, err = pt.GetStorage().WritePiece( @@ -239,7 +239,7 @@ func (pm *pieceManager) processPieceFromSource(pt Task, return } if pm.calculateDigest { - md5 = reader.(digest.Reader).Digest() + md5 = reader.(digest.Reader).Encoded() } return } @@ -296,7 +296,7 @@ func (pm *pieceManager) DownloadSource(ctx context.Context, pt Task, request *sc // calc total if pm.calculateDigest { - reader, err = digest.NewReader(pt.Log(), response.Body, request.UrlMeta.Digest) + reader, err = digest.NewReader(response.Body, digest.WithDigest(request.UrlMeta.Digest), digest.WithLogger(pt.Log())) if err != nil { log.Errorf("init digest reader error: %s", 
err.Error()) return err @@ -462,7 +462,7 @@ func (pm *pieceManager) processPieceFromFile(ctx context.Context, ptm storage.Pe if pm.calculateDigest { log.Debugf("calculate digest in processPieceFromFile") - reader, _ = digest.NewReader(log, r) + reader, _ = digest.NewReader(r, digest.WithLogger(log)) } n, err := tsd.WritePiece(ctx, &storage.WritePieceRequest{ diff --git a/client/daemon/storage/local_storage.go b/client/daemon/storage/local_storage.go index 6918e384e..1f5afbf5a 100644 --- a/client/daemon/storage/local_storage.go +++ b/client/daemon/storage/local_storage.go @@ -156,7 +156,7 @@ func (t *localTaskStore) WritePiece(ctx context.Context, req *WritePieceRequest) if req.PieceMetadata.Md5 == "" { t.Debugf("piece md5 not found in metadata, read from reader") if get, ok := req.Reader.(digest.Reader); ok { - req.PieceMetadata.Md5 = get.Digest() + req.PieceMetadata.Md5 = get.Encoded() t.Infof("read md5 from reader, value: %s", req.PieceMetadata.Md5) } else { t.Debugf("reader is not a digest.Reader") diff --git a/client/daemon/storage/local_storage_subtask.go b/client/daemon/storage/local_storage_subtask.go index 426966585..584fc67ad 100644 --- a/client/daemon/storage/local_storage_subtask.go +++ b/client/daemon/storage/local_storage_subtask.go @@ -104,7 +104,7 @@ func (t *localSubTaskStore) WritePiece(ctx context.Context, req *WritePieceReque if req.PieceMetadata.Md5 == "" { t.Debugf("piece md5 not found in metadata, read from reader") if get, ok := req.Reader.(digest.Reader); ok { - req.PieceMetadata.Md5 = get.Digest() + req.PieceMetadata.Md5 = get.Encoded() t.Infof("read md5 from reader, value: %s", req.PieceMetadata.Md5) } else { t.Debugf("reader is not a digest.Reader") diff --git a/client/dfget/dfget.go b/client/dfget/dfget.go index e661811d2..108f83423 100644 --- a/client/dfget/dfget.go +++ b/client/dfget/dfget.go @@ -174,11 +174,18 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st } if !pkgstrings.IsBlank(cfg.Digest) { - 
parsedHash := digest.Parse(cfg.Digest) - realHash := digest.HashFile(target.Name(), digest.Algorithms[parsedHash[0]]) + d, err := digest.Parse(cfg.Digest) + if err != nil { + return err + } - if realHash != "" && realHash != parsedHash[1] { - return errors.Errorf("%s digest is not matched: real[%s] expected[%s]", parsedHash[0], realHash, parsedHash[1]) + encoded, err := digest.HashFile(target.Name(), d.Algorithm) + if err != nil { + return err + } + + if encoded != "" && encoded != d.Encoded { + return errors.Errorf("%s digest is not matched: real[%s] expected[%s]", d.Algorithm, encoded, d.Encoded) } } diff --git a/client/dfget/dfget_test.go b/client/dfget/dfget_test.go index 862e2a4d5..b25a96a9e 100644 --- a/client/dfget/dfget_test.go +++ b/client/dfget/dfget_test.go @@ -53,7 +53,7 @@ func Test_downloadFromSource(t *testing.T) { cfg := &config.DfgetConfig{ URL: "http://a.b.c/xx", Output: output, - Digest: strings.Join([]string{digest.Sha256Hash.String(), digest.Sha256(content)}, ":"), + Digest: strings.Join([]string{digest.AlgorithmSHA256, digest.Sha256(content)}, ":"), } request, err := source.NewRequest(cfg.URL) assert.Nil(t, err) diff --git a/go.mod b/go.mod index d723179d1..09715e2a1 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/distribution/distribution/v3 v3.0.0-20220526142353-ffbd94cbe269 github.com/docker/go-connections v0.4.0 github.com/docker/go-units v0.4.0 - github.com/emirpasic/gods v1.18.1 github.com/envoyproxy/protoc-gen-validate v0.6.7 github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/static v0.0.1 @@ -33,7 +32,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/golang/mock v1.6.0 github.com/gomodule/redigo v2.0.0+incompatible - github.com/google/go-cmp v0.5.8 github.com/google/go-github v17.0.0+incompatible github.com/google/uuid v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 @@ -46,7 +44,6 @@ require ( github.com/montanaflynn/stats v0.6.6 github.com/onsi/ginkgo/v2 
v2.1.4 github.com/onsi/gomega v1.19.0 - github.com/opencontainers/go-digest v1.0.0 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.12.2 @@ -122,6 +119,7 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/google/go-cmp v0.5.8 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/googleapis/gax-go/v2 v2.4.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -161,6 +159,7 @@ require ( github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pelletier/go-toml v1.9.5 // indirect diff --git a/go.sum b/go.sum index ffa2e71fe..7adc3595d 100644 --- a/go.sum +++ b/go.sum @@ -253,8 +253,6 @@ github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFP github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= -github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= -github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane 
const (
	// AlgorithmSHA1 is the name of the SHA-1 hash algorithm.
	AlgorithmSHA1 = "sha1"

	// AlgorithmSHA256 is the name of the SHA-256 hash algorithm.
	AlgorithmSHA256 = "sha256"

	// AlgorithmSHA512 is the name of the SHA-512 hash algorithm.
	AlgorithmSHA512 = "sha512"

	// AlgorithmMD5 is the name of the MD5 hash algorithm.
	AlgorithmMD5 = "md5"
)

// Digest is a parsed digest string of the form "<algorithm>:<encoded>".
type Digest struct {
	// Algorithm is the hash algorithm name, e.g. AlgorithmSHA256.
	Algorithm string

	// Encoded is the hex-encoded hash value.
	Encoded string
}

// newHash returns a fresh hash.Hash for the named algorithm, or an error when
// the algorithm is not one of the Algorithm* constants.
func newHash(algorithm string) (hash.Hash, error) {
	switch algorithm {
	case AlgorithmSHA1:
		return sha1.New(), nil
	case AlgorithmSHA256:
		return sha256.New(), nil
	case AlgorithmSHA512:
		return sha512.New(), nil
	case AlgorithmMD5:
		return md5.New(), nil
	default:
		return nil, fmt.Errorf("unsupport digest method: %s", algorithm)
	}
}

// HashFile computes the hex-encoded hash of the file at path using the named
// algorithm (one of the Algorithm* constants).
//
// It returns an error when the algorithm is unsupported or when the file
// cannot be opened or read.
func HashFile(path string, algorithm string) (string, error) {
	// Validate the algorithm first so an unsupported name never touches the
	// file system.
	h, err := newHash(algorithm)
	if err != nil {
		return "", err
	}

	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	if _, err := io.Copy(h, bufio.NewReader(f)); err != nil {
		return "", err
	}

	return hex.EncodeToString(h.Sum(nil)), nil
}

// Parse splits a digest string into its algorithm and encoded parts.
//
// Two forms are accepted: "<algorithm>:<encoded>" and a bare "<encoded>",
// which is treated as MD5. Surrounding whitespace is ignored. An error is
// returned when the string has more than one ":" or when either part is
// empty, so a malformed digest is rejected here instead of surfacing later as
// a mismatch.
func Parse(digest string) (*Digest, error) {
	values := strings.Split(strings.TrimSpace(digest), ":")

	switch len(values) {
	case 1:
		if values[0] != "" {
			return &Digest{
				Algorithm: AlgorithmMD5,
				Encoded:   values[0],
			}, nil
		}
	case 2:
		if values[0] != "" && values[1] != "" {
			return &Digest{
				Algorithm: values[0],
				Encoded:   values[1],
			}, nil
		}
	}

	return nil, errors.New("invalid digest")
}

// Sha256 computes the hex-encoded SHA-256 checksum of the concatenation of
// data. It returns an empty string when called with no arguments.
func Sha256(data ...string) string {
	if len(data) == 0 {
		return ""
	}

	h := sha256.New()
	for _, s := range data {
		// hash.Hash documents that Write never returns an error.
		h.Write([]byte(s))
	}

	return hex.EncodeToString(h.Sum(nil))
}

// Md5Reader computes the hex-encoded MD5 checksum of everything read from
// reader. It returns an empty string when reading fails.
func Md5Reader(reader io.Reader) string {
	h := md5.New()
	if _, err := io.Copy(h, bufio.NewReader(reader)); err != nil {
		return ""
	}

	return hex.EncodeToString(h.Sum(nil))
}

// Md5Bytes computes the hex-encoded MD5 checksum of bytes.
func Md5Bytes(bytes []byte) string {
	sum := md5.Sum(bytes)
	return hex.EncodeToString(sum[:])
}

// Reader is an io.Reader that can also report the encoded digest of the
// content read through it.
type Reader interface {
	io.Reader
	Encoded() string
}
+type reader struct { + r io.Reader + hash hash.Hash + digest string + encoded string + logger *logger.SugaredLoggerOnWith +} + +// Option is a functional option for digest reader. +type Option func(reader *reader) + +// WithLogger sets the logger for digest reader. +func WithLogger(logger *logger.SugaredLoggerOnWith) Option { + return func(reader *reader) { + reader.logger = logger + } +} + +// WithDigest sets the digest to be verified. +func WithDigest(digest string) Option { + return func(reader *reader) { + reader.digest = digest + } } // TODO add AF_ALG digest https://github.com/golang/sys/commit/e24f485414aeafb646f6fca458b0bf869c0880a1 -func NewReader(log *logger.SugaredLoggerOnWith, r io.Reader, digest ...string) (io.Reader, error) { - var ( - d string - hashMethod hash.Hash - ) - if len(digest) > 0 { - d = digest[0] +func NewReader(r io.Reader, options ...Option) (io.Reader, error) { + reader := &reader{ + r: r, + hash: md5.New(), + logger: &logger.SugaredLoggerOnWith{}, } - ds := strings.Split(d, ":") - if len(ds) == 2 { - d = ds[1] - switch ds[0] { - case "sha1": - hashMethod = sha1.New() - case "sha256": - hashMethod = sha256.New() - case "sha512": - hashMethod = sha512.New() - case "md5": - hashMethod = md5.New() - default: - return nil, fmt.Errorf("unsupport digest method: %s", ds[0]) + + for _, opt := range options { + opt(reader) + } + + if reader.digest != "" { + d, err := Parse(reader.digest) + if err != nil { + return nil, errors.New("invalid digest") } - } else { - hashMethod = md5.New() + + var h hash.Hash + switch d.Algorithm { + case AlgorithmSHA1: + h = sha1.New() + case AlgorithmSHA256: + h = sha256.New() + case AlgorithmSHA512: + h = sha512.New() + case AlgorithmMD5: + h = md5.New() + default: + return nil, fmt.Errorf("unsupport digest method: %s", d.Algorithm) + } + + reader.encoded = d.Encoded + reader.hash = h } - return &reader{ - SugaredLoggerOnWith: log, - digest: d, - hash: hashMethod, - r: r, - }, nil + + return reader, nil } -func 
(dr *reader) Read(p []byte) (int, error) { - n, err := dr.r.Read(p) +// Read uses to read content and validate encoded. +func (r *reader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) if err != nil && err != io.EOF { return n, err } + if n > 0 { - dr.hash.Write(p[:n]) + r.hash.Write(p[:n]) } - if err == io.EOF && dr.digest != "" { - digest := dr.Digest() - if digest != dr.digest { - dr.Warnf("digest not match, desired: %s, actual: %s", dr.digest, digest) - return n, ErrDigestNotMatch + + if err == io.EOF && r.digest != "" { + encoded := r.Encoded() + if encoded != r.encoded { + r.logger.Warnf("digest encoded not match, desired: %s, actual: %s", r.encoded, encoded) + return n, errors.New("digest encoded not match") } - dr.Debugf("digest match: %s", digest) + + r.logger.Debugf("digest encoded match: %s", encoded) } + return n, err } -// Digest returns the digest of contents. -func (dr *reader) Digest() string { - return hex.EncodeToString(dr.hash.Sum(nil)) +// Encoded returns the encoded of algorithm. 
+func (r *reader) Encoded() string { + return hex.EncodeToString(r.hash.Sum(nil)) } diff --git a/pkg/digest/digest_reader_test.go b/pkg/digest/digest_reader_test.go index f05e99a5b..362e1af1e 100644 --- a/pkg/digest/digest_reader_test.go +++ b/pkg/digest/digest_reader_test.go @@ -24,7 +24,6 @@ import ( "crypto/sha512" "encoding/hex" "io" - "os" "testing" testifyassert "github.com/stretchr/testify/assert" @@ -32,10 +31,6 @@ import ( logger "d7y.io/dragonfly/v2/internal/dflog" ) -func TestMain(m *testing.M) { - os.Exit(m.Run()) -} - func TestNewReader(t *testing.T) { assert := testifyassert.New(t) @@ -95,7 +90,7 @@ func TestNewReader(t *testing.T) { t.Run(tc.name, func(t *testing.T) { digest := tc.digest(tc.data) buf := bytes.NewBuffer(tc.data) - reader, err := NewReader(logger.With("test", "test"), buf, digest) + reader, err := NewReader(buf, WithDigest(digest), WithLogger(logger.With("test", "test"))) assert.Nil(err) data, err := io.ReadAll(reader) assert.Nil(err) diff --git a/pkg/digest/digest_test.go b/pkg/digest/digest_test.go index 7fa499ca9..313ef7be8 100644 --- a/pkg/digest/digest_test.go +++ b/pkg/digest/digest_test.go @@ -18,6 +18,7 @@ package digest import ( "crypto/md5" + "encoding/hex" "io/fs" "os" "path/filepath" @@ -45,7 +46,7 @@ func TestToHashString(t *testing.T) { var expected = "5d41402abc4b2a76b9719d911017c592" h := md5.New() h.Write([]byte("hello")) - assert.Equal(t, expected, ToHashString(h)) + assert.Equal(t, expected, hex.EncodeToString(h.Sum(nil))) } func TestMd5Reader(t *testing.T) { @@ -63,6 +64,7 @@ func TestHashFile(t *testing.T) { if _, err := f.Write([]byte("hello")); err != nil { t.Fatal(err) } - - assert.Equal(t, expected, HashFile(path, Md5Hash)) + encoded, err := HashFile(path, AlgorithmMD5) + assert.NoError(t, err) + assert.Equal(t, expected, encoded) } diff --git a/pkg/digest/mocks/digest_reader_mock.go b/pkg/digest/mocks/digest_reader_mock.go new file mode 100644 index 000000000..c52dbfe1b --- /dev/null +++ 
b/pkg/digest/mocks/digest_reader_mock.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: digest_reader.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockReader is a mock of Reader interface. +type MockReader struct { + ctrl *gomock.Controller + recorder *MockReaderMockRecorder +} + +// MockReaderMockRecorder is the mock recorder for MockReader. +type MockReaderMockRecorder struct { + mock *MockReader +} + +// NewMockReader creates a new mock instance. +func NewMockReader(ctrl *gomock.Controller) *MockReader { + mock := &MockReader{ctrl: ctrl} + mock.recorder = &MockReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReader) EXPECT() *MockReaderMockRecorder { + return m.recorder +} + +// Encoded mocks base method. +func (m *MockReader) Encoded() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encoded") + ret0, _ := ret[0].(string) + return ret0 +} + +// Encoded indicates an expected call of Encoded. +func (mr *MockReaderMockRecorder) Encoded() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoded", reflect.TypeOf((*MockReader)(nil).Encoded)) +} + +// Read mocks base method. +func (m *MockReader) Read(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockReaderMockRecorder) Read(p interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReader)(nil).Read), p) +}