Update to gotest.tools/v3

We have been using a version over 6 years old.

Signed-off-by: Miloslav Trmač <mitr@redhat.com>
This commit is contained in:
Miloslav Trmač 2025-03-14 03:59:17 +01:00
parent 05de8c7758
commit dd32248f47
23 changed files with 960 additions and 442 deletions

View File

@ -14,8 +14,8 @@ import (
"github.com/containers/storage/pkg/system" "github.com/containers/storage/pkg/system"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"gotest.tools/assert" "gotest.tools/v3/assert"
is "gotest.tools/assert/cmp" is "gotest.tools/v3/assert/cmp"
) )
func TestCopy(t *testing.T) { func TestCopy(t *testing.T) {

2
go.mod
View File

@ -31,7 +31,7 @@ require (
github.com/vbatts/tar-split v0.12.1 github.com/vbatts/tar-split v0.12.1
golang.org/x/sync v0.12.0 golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0 golang.org/x/sys v0.31.0
gotest.tools v2.2.0+incompatible gotest.tools/v3 v3.5.2
) )
require ( require (

4
go.sum
View File

@ -198,7 +198,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -14,7 +14,7 @@ import (
"github.com/containers/storage/pkg/archive" "github.com/containers/storage/pkg/archive"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"gotest.tools/assert" "gotest.tools/v3/assert"
) )
// Test for CVE-2018-15664 // Test for CVE-2018-15664

View File

@ -11,7 +11,7 @@ import (
"github.com/containers/storage/pkg/unshare" "github.com/containers/storage/pkg/unshare"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gotest.tools/assert" "gotest.tools/v3/assert"
) )
func TestGetRootlessStorageOpts(t *testing.T) { func TestGetRootlessStorageOpts(t *testing.T) {

View File

@ -7,7 +7,7 @@ import (
"testing" "testing"
"github.com/containers/storage/pkg/unshare" "github.com/containers/storage/pkg/unshare"
"gotest.tools/assert" "gotest.tools/v3/assert"
) )
func TestDefaultStoreOpts(t *testing.T) { func TestDefaultStoreOpts(t *testing.T) {

View File

@ -1,311 +0,0 @@
/*Package assert provides assertions for comparing expected values to actual
values. When an assertion fails a helpful error message is printed.
Assert and Check
Assert() and Check() both accept a Comparison, and fail the test when the
comparison fails. The one difference is that Assert() will end the test execution
immediately (using t.FailNow()) whereas Check() will fail the test (using t.Fail()),
return the value of the comparison, then proceed with the rest of the test case.
Example usage
The example below shows assert used with some common types.
import (
"testing"
"gotest.tools/assert"
is "gotest.tools/assert/cmp"
)
func TestEverything(t *testing.T) {
// booleans
assert.Assert(t, ok)
assert.Assert(t, !missing)
// primitives
assert.Equal(t, count, 1)
assert.Equal(t, msg, "the message")
assert.Assert(t, total != 10) // NotEqual
// errors
assert.NilError(t, closer.Close())
assert.Error(t, err, "the exact error message")
assert.ErrorContains(t, err, "includes this")
assert.ErrorType(t, err, os.IsNotExist)
// complex types
assert.DeepEqual(t, result, myStruct{Name: "title"})
assert.Assert(t, is.Len(items, 3))
assert.Assert(t, len(sequence) != 0) // NotEmpty
assert.Assert(t, is.Contains(mapping, "key"))
// pointers and interface
assert.Assert(t, is.Nil(ref))
assert.Assert(t, ref != nil) // NotNil
}
Comparisons
Package https://godoc.org/gotest.tools/assert/cmp provides
many common comparisons. Additional comparisons can be written to compare
values in other ways. See the example Assert (CustomComparison).
Automated migration from testify
gty-migrate-from-testify is a binary which can update source code which uses
testify assertions to use the assertions provided by this package.
See http://bit.do/cmd-gty-migrate-from-testify.
*/
package assert // import "gotest.tools/assert"
import (
"fmt"
"go/ast"
"go/token"
gocmp "github.com/google/go-cmp/cmp"
"gotest.tools/assert/cmp"
"gotest.tools/internal/format"
"gotest.tools/internal/source"
)
// BoolOrComparison can be a bool, or cmp.Comparison. See Assert() for usage.
type BoolOrComparison interface{}
// TestingT is the subset of testing.T used by the assert package.
type TestingT interface {
FailNow()
Fail()
Log(args ...interface{})
}
type helperT interface {
Helper()
}
const failureMessage = "assertion failed: "
// nolint: gocyclo
func assert(
t TestingT,
failer func(),
argSelector argSelector,
comparison BoolOrComparison,
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)
// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)
case nil:
return true
case error:
msg := "error is not nil: "
t.Log(format.WithCustomMessage(failureMessage+msg+check.Error(), msgAndArgs...))
case cmp.Comparison:
success = runComparison(t, argSelector, check, msgAndArgs...)
case func() cmp.Result:
success = runComparison(t, argSelector, check, msgAndArgs...)
default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}
if success {
return true
}
failer()
return false
}
func runCompareFunc(
t TestingT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}
func logFailureFromBool(t TestingT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), formatFailureFromBool()
const comparisonArgPos = 1
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
return
}
msg, err := boolFailureMessage(args[comparisonArgPos])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
return x + " is " + y, nil
}
if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}
formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}
// Assert performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is logged, and execution is stopped immediately.
//
// The comparison argument may be one of three types: bool, cmp.Comparison or
// error.
// When called with a bool the failure message will contain the literal source
// code of the expression.
// When called with a cmp.Comparison the comparison is responsible for producing
// a helpful failure message.
// When called with an error a nil value is considered success. A non-nil error
// is a failure, and Error() is used as the failure message.
func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsFromComparisonCall, comparison, msgAndArgs...)
}
// Check performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is logged, and Check returns false. Otherwise returns
// true.
//
// See Assert for details about the comparison arg and failure messages.
func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
return assert(t, t.Fail, argsFromComparisonCall, comparison, msgAndArgs...)
}
// NilError fails the test immediately if err is not nil.
// This is equivalent to Assert(t, err)
func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, err, msgAndArgs...)
}
// Equal uses the == operator to assert two values are equal and fails the test
// if they are not equal.
//
// If the comparison fails Equal will use the variable names for x and y as part
// of the failure message to identify the actual and expected values.
//
// If either x or y are a multi-line string the failure message will include a
// unified diff of the two values. If the values only differ by whitespace
// the unified diff will be augmented by replacing whitespace characters with
// visible characters to identify the whitespace difference.
//
// This is equivalent to Assert(t, cmp.Equal(x, y)).
func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Equal(x, y), msgAndArgs...)
}
// DeepEqual uses google/go-cmp (http://bit.do/go-cmp) to assert two values are
// equal and fails the test if they are not equal.
//
// Package https://godoc.org/gotest.tools/assert/opt provides some additional
// commonly used Options.
//
// This is equivalent to Assert(t, cmp.DeepEqual(x, y)).
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.DeepEqual(x, y, opts...))
}
// Error fails the test if err is nil, or the error message is not the expected
// message.
// Equivalent to Assert(t, cmp.Error(err, message)).
func Error(t TestingT, err error, message string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Error(err, message), msgAndArgs...)
}
// ErrorContains fails the test if err is nil, or the error message does not
// contain the expected substring.
// Equivalent to Assert(t, cmp.ErrorContains(err, substring)).
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...)
}
// ErrorType fails the test if err is nil, or err is not the expected type.
//
// Expected can be one of:
// a func(error) bool which returns true if the error is the expected type,
// an instance of (or a pointer to) a struct of the expected type,
// a pointer to an interface the error is expected to implement,
// a reflect.Type of the expected struct or interface.
//
// Equivalent to Assert(t, cmp.ErrorType(err, expected)).
func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorType(err, expected), msgAndArgs...)
}

311
vendor/gotest.tools/v3/assert/assert.go vendored Normal file
View File

@ -0,0 +1,311 @@
/*
Package assert provides assertions for comparing expected values to actual
values in tests. When an assertion fails a helpful error message is printed.
# Example usage
All the assertions in this package use [testing.T.Helper] to mark themselves as
test helpers. This allows the testing package to print the filename and line
number of the file function that failed.
assert.NilError(t, err)
// filename_test.go:212: assertion failed: error is not nil: file not found
If any assertion is called from a helper function, make sure to call t.Helper
from the helper function so that the filename and line number remain correct.
The examples below show assert used with some common types and the failure
messages it produces. The filename and line number portion of the failure
message is omitted from these examples for brevity.
// booleans
assert.Assert(t, ok)
// assertion failed: ok is false
assert.Assert(t, !missing)
// assertion failed: missing is true
// primitives
assert.Equal(t, count, 1)
// assertion failed: 0 (count int) != 1 (int)
assert.Equal(t, msg, "the message")
// assertion failed: my message (msg string) != the message (string)
assert.Assert(t, total != 10) // use Assert for NotEqual
// assertion failed: total is 10
assert.Assert(t, count > 20, "count=%v", count)
// assertion failed: count is <= 20: count=1
// errors
assert.NilError(t, closer.Close())
// assertion failed: error is not nil: close /file: errno 11
assert.Error(t, err, "the exact error message")
// assertion failed: expected error "the exact error message", got "oops"
assert.ErrorContains(t, err, "includes this")
// assertion failed: expected error to contain "includes this", got "oops"
assert.ErrorIs(t, err, os.ErrNotExist)
// assertion failed: error is "oops", not "file does not exist" (os.ErrNotExist)
// complex types
assert.DeepEqual(t, result, myStruct{Name: "title"})
// assertion failed: ... (diff of the two structs)
assert.Assert(t, is.Len(items, 3))
// assertion failed: expected [] (length 0) to have length 3
assert.Assert(t, len(sequence) != 0) // use Assert for NotEmpty
// assertion failed: len(sequence) is 0
assert.Assert(t, is.Contains(mapping, "key"))
// assertion failed: map[other:1] does not contain key
// pointers and interface
assert.Assert(t, ref == nil)
// assertion failed: ref is not nil
assert.Assert(t, ref != nil) // use Assert for NotNil
// assertion failed: ref is nil
# Assert and Check
[Assert] and [Check] are very similar, they both accept a [cmp.Comparison], and fail
the test when the comparison fails. The one difference is that Assert uses
[testing.T.FailNow] to fail the test, which will end the test execution immediately.
Check uses [testing.T.Fail] to fail the test, which allows it to return the
result of the comparison, then proceed with the rest of the test case.
Like [testing.T.FailNow], [Assert] must be called from the goroutine running the test,
not from other goroutines created during the test. [Check] is safe to use from any
goroutine.
# Comparisons
Package [gotest.tools/v3/assert/cmp] provides
many common comparisons. Additional comparisons can be written to compare
values in other ways. See the example Assert (CustomComparison).
# Automated migration from testify
gty-migrate-from-testify is a command which translates Go source code from
testify assertions to the assertions provided by this package.
See http://pkg.go.dev/gotest.tools/v3/assert/cmd/gty-migrate-from-testify.
*/
package assert // import "gotest.tools/v3/assert"
import (
gocmp "github.com/google/go-cmp/cmp"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/assert"
)
// BoolOrComparison can be a bool, [cmp.Comparison], or error. See [Assert] for
// details about how this type is used.
type BoolOrComparison interface{}
// TestingT is the subset of [testing.T] (see also [testing.TB]) used by the assert package.
type TestingT interface {
FailNow()
Fail()
Log(args ...interface{})
}
type helperT interface {
Helper()
}
// Assert performs a comparison. If the comparison fails, the test is marked as
// failed, a failure message is logged, and execution is stopped immediately.
//
// The comparison argument may be one of three types:
//
// bool
// True is success. False is a failure. The failure message will contain
// the literal source code of the expression.
//
// cmp.Comparison
// Uses cmp.Result.Success() to check for success or failure.
// The comparison is responsible for producing a helpful failure message.
// http://pkg.go.dev/gotest.tools/v3/assert/cmp provides many common comparisons.
//
// error
// A nil value is considered success, and a non-nil error is a failure.
// The return value of error.Error is used as the failure message.
//
// Extra details can be added to the failure message using msgAndArgs. msgAndArgs
// may be either a single string, or a format string and args that will be
// passed to [fmt.Sprintf].
//
// Assert uses [testing.TB.FailNow] to fail the test. Like t.FailNow, Assert must be called
// from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] from other goroutines.
func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.FailNow()
}
}
// Check performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is printed, and Check returns false. If the comparison
// is successful Check returns true. Check may be called from any goroutine.
//
// See [Assert] for details about the comparison arg and failure messages.
func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.Fail()
return false
}
return true
}
// NilError fails the test immediately if err is not nil, and includes err.Error
// in the failure message.
//
// NilError uses [testing.TB.FailNow] to fail the test. Like t.FailNow, NilError must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] from other goroutines.
func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, err, msgAndArgs...) {
t.FailNow()
}
}
// Equal uses the == operator to assert two values are equal and fails the test
// if they are not equal.
//
// If the comparison fails Equal will use the variable names and types of
// x and y as part of the failure message to identify the actual and expected
// values.
//
// assert.Equal(t, actual, expected)
// // main_test.go:41: assertion failed: 1 (actual int) != 21 (expected int32)
//
// If either x or y are a multi-line string the failure message will include a
// unified diff of the two values. If the values only differ by whitespace
// the unified diff will be augmented by replacing whitespace characters with
// visible characters to identify the whitespace difference.
//
// Equal uses [testing.T.FailNow] to fail the test. Like t.FailNow, Equal must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.Equal] from other
// goroutines.
func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.Equal(x, y), msgAndArgs...) {
t.FailNow()
}
}
// DeepEqual uses [github.com/google/go-cmp/cmp]
// to assert two values are equal and fails the test if they are not equal.
//
// Package [gotest.tools/v3/assert/opt] provides some additional
// commonly used Options.
//
// DeepEqual uses [testing.T.FailNow] to fail the test. Like t.FailNow, DeepEqual must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.DeepEqual] from other
// goroutines.
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.DeepEqual(x, y, opts...)) {
t.FailNow()
}
}
// Error fails the test if err is nil, or if err.Error is not equal to expected.
// Both err.Error and expected will be included in the failure message.
// Error performs an exact match of the error text. Use [ErrorContains] if only
// part of the error message is relevant. Use [ErrorType] or [ErrorIs] to compare
// errors by type.
//
// Error uses [testing.T.FailNow] to fail the test. Like t.FailNow, Error must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.Error] from other
// goroutines.
func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.Error(err, expected), msgAndArgs...) {
t.FailNow()
}
}
// ErrorContains fails the test if err is nil, or if err.Error does not
// contain the expected substring. Both err.Error and the expected substring
// will be included in the failure message.
//
// ErrorContains uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorContains
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorContains] from other
// goroutines.
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...) {
t.FailNow()
}
}
// ErrorType fails the test if err is nil, or err is not the expected type.
// New code should use ErrorIs instead.
//
// Expected can be one of:
//
// func(error) bool
// The function should return true if the error is the expected type.
//
// struct{} or *struct{}
// A struct or a pointer to a struct. The assertion fails if the error is
// not of the same type.
//
// *interface{}
// A pointer to an interface type. The assertion fails if err does not
// implement the interface.
//
// reflect.Type
// The assertion fails if err does not implement the reflect.Type.
//
// ErrorType uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorType
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorType] from other
// goroutines.
func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorType(err, expected), msgAndArgs...) {
t.FailNow()
}
}
// ErrorIs fails the test if err is nil, or the error does not match expected
// when compared using errors.Is. See [errors.Is] for
// accepted arguments.
//
// ErrorIs uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorIs
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorIs] from other
// goroutines.
func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorIs(err, expected), msgAndArgs...) {
t.FailNow()
}
}

View File

@ -1,26 +1,27 @@
/*Package cmp provides Comparisons for Assert and Check*/ /*Package cmp provides Comparisons for Assert and Check*/
package cmp // import "gotest.tools/assert/cmp" package cmp // import "gotest.tools/v3/assert/cmp"
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"gotest.tools/internal/format" "gotest.tools/v3/internal/format"
) )
// Comparison is a function which compares values and returns ResultSuccess if // Comparison is a function which compares values and returns [ResultSuccess] if
// the actual value matches the expected value. If the values do not match the // the actual value matches the expected value. If the values do not match the
// Result will contain a message about why it failed. // [Result] will contain a message about why it failed.
type Comparison func() Result type Comparison func() Result
// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp) // DeepEqual compares two values using [github.com/google/go-cmp/cmp]
// and succeeds if the values are equal. // and succeeds if the values are equal.
// //
// The comparison can be customized using comparison Options. // The comparison can be customized using comparison Options.
// Package https://godoc.org/gotest.tools/assert/opt provides some additional // Package [gotest.tools/v3/assert/opt] provides some additional
// commonly used Options. // commonly used Options.
func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
return func() (result Result) { return func() (result Result) {
@ -33,7 +34,7 @@ func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
if diff == "" { if diff == "" {
return ResultSuccess return ResultSuccess
} }
return multiLineDiffResult(diff) return multiLineDiffResult(diff, x, y)
} }
} }
@ -59,16 +60,17 @@ func toResult(success bool, msg string) Result {
return ResultFailure(msg) return ResultFailure(msg)
} }
// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid // RegexOrPattern may be either a [*regexp.Regexp] or a string that is a valid
// regexp pattern. // regexp pattern.
type RegexOrPattern interface{} type RegexOrPattern interface{}
// Regexp succeeds if value v matches regular expression re. // Regexp succeeds if value v matches regular expression re.
// //
// Example: // Example:
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str)) //
// r := regexp.MustCompile("^[0-9a-f]{32}$") // assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
// assert.Assert(t, cmp.Regexp(r, str)) // r := regexp.MustCompile("^[0-9a-f]{32}$")
// assert.Assert(t, cmp.Regexp(r, str))
func Regexp(re RegexOrPattern, v string) Comparison { func Regexp(re RegexOrPattern, v string) Comparison {
match := func(re *regexp.Regexp) Result { match := func(re *regexp.Regexp) Result {
return toResult( return toResult(
@ -92,7 +94,7 @@ func Regexp(re RegexOrPattern, v string) Comparison {
} }
} }
// Equal succeeds if x == y. See assert.Equal for full documentation. // Equal succeeds if x == y. See [gotest.tools/v3/assert.Equal] for full documentation.
func Equal(x, y interface{}) Comparison { func Equal(x, y interface{}) Comparison {
return func() Result { return func() Result {
switch { switch {
@ -100,13 +102,13 @@ func Equal(x, y interface{}) Comparison {
return ResultSuccess return ResultSuccess
case isMultiLineStringCompare(x, y): case isMultiLineStringCompare(x, y):
diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
return multiLineDiffResult(diff) return multiLineDiffResult(diff, x, y)
} }
return ResultFailureTemplate(` return ResultFailureTemplate(`
{{- .Data.x}} ( {{- printf "%v" .Data.x}} (
{{- with callArg 0 }}{{ formatNode . }} {{end -}} {{- with callArg 0 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.x -}} {{- printf "%T" .Data.x -}}
) != {{ .Data.y}} ( ) != {{ printf "%v" .Data.y}} (
{{- with callArg 1 }}{{ formatNode . }} {{end -}} {{- with callArg 1 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.y -}} {{- printf "%T" .Data.y -}}
)`, )`,
@ -126,12 +128,12 @@ func isMultiLineStringCompare(x, y interface{}) bool {
return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
} }
func multiLineDiffResult(diff string) Result { func multiLineDiffResult(diff string, x, y interface{}) Result {
return ResultFailureTemplate(` return ResultFailureTemplate(`
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}{{end}} --- {{ with callArg 0 }}{{ formatNode . }}{{else}}{{end}}
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}{{end}} +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}{{end}}
{{ .Data.diff }}`, {{ .Data.diff }}`,
map[string]interface{}{"diff": diff}) map[string]interface{}{"diff": diff, "x": x, "y": y})
} }
// Len succeeds if the sequence has the expected length. // Len succeeds if the sequence has the expected length.
@ -156,15 +158,15 @@ func Len(seq interface{}, expected int) Comparison {
// slice, or array. // slice, or array.
// //
// If collection is a string, item must also be a string, and is compared using // If collection is a string, item must also be a string, and is compared using
// strings.Contains(). // [strings.Contains].
// If collection is a Map, contains will succeed if item is a key in the map. // If collection is a Map, contains will succeed if item is a key in the map.
// If collection is a slice or array, item is compared to each item in the // If collection is a slice or array, item is compared to each item in the
// sequence using reflect.DeepEqual(). // sequence using [reflect.DeepEqual].
func Contains(collection interface{}, item interface{}) Comparison { func Contains(collection interface{}, item interface{}) Comparison {
return func() Result { return func() Result {
colValue := reflect.ValueOf(collection) colValue := reflect.ValueOf(collection)
if !colValue.IsValid() { if !colValue.IsValid() {
return ResultFailure(fmt.Sprintf("nil does not contain items")) return ResultFailure("nil does not contain items")
} }
msg := fmt.Sprintf("%v does not contain %v", collection, item) msg := fmt.Sprintf("%v does not contain %v", collection, item)
@ -241,10 +243,13 @@ func ErrorContains(err error, substring string) Comparison {
} }
} }
type causer interface {
Cause() error
}
func formatErrorMessage(err error) string { func formatErrorMessage(err error) string {
if _, ok := err.(interface { //nolint:errorlint,nolintlint // unwrapping is not appropriate here
Cause() error if _, ok := err.(causer); ok {
}); ok {
return fmt.Sprintf("%q\n%+v", err, err) return fmt.Sprintf("%q\n%+v", err, err)
} }
// This error was not wrapped with github.com/pkg/errors // This error was not wrapped with github.com/pkg/errors
@ -253,7 +258,7 @@ func formatErrorMessage(err error) string {
// Nil succeeds if obj is a nil interface, pointer, or function. // Nil succeeds if obj is a nil interface, pointer, or function.
// //
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices, // Use [gotest.tools/v3/assert.NilError] for comparing errors. Use Len(obj, 0) for comparing slices,
// maps, and channels. // maps, and channels.
func Nil(obj interface{}) Comparison { func Nil(obj interface{}) Comparison {
msgFunc := func(value reflect.Value) string { msgFunc := func(value reflect.Value) string {
@ -281,12 +286,27 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
} }
// ErrorType succeeds if err is not nil and is of the expected type. // ErrorType succeeds if err is not nil and is of the expected type.
// New code should use [ErrorIs] instead.
// //
// Expected can be one of: // Expected can be one of:
// a func(error) bool which returns true if the error is the expected type, //
// an instance of (or a pointer to) a struct of the expected type, // func(error) bool
// a pointer to an interface the error is expected to implement, //
// a reflect.Type of the expected struct or interface. // Function should return true if the error is the expected type.
//
// type struct{}, type &struct{}
//
// A struct or a pointer to a struct.
// Fails if the error is not of the same type as expected.
//
// type &interface{}
//
// A pointer to an interface type.
// Fails if err does not implement the interface.
//
// reflect.Type
//
// Fails if err does not implement the [reflect.Type].
func ErrorType(err error, expected interface{}) Comparison { func ErrorType(err error, expected interface{}) Comparison {
return func() Result { return func() Result {
switch expectedType := expected.(type) { switch expectedType := expected.(type) {
@ -298,7 +318,7 @@ func ErrorType(err error, expected interface{}) Comparison {
} }
return cmpErrorTypeEqualType(err, expectedType) return cmpErrorTypeEqualType(err, expectedType)
case nil: case nil:
return ResultFailure(fmt.Sprintf("invalid type for expected: nil")) return ResultFailure("invalid type for expected: nil")
} }
expectedType := reflect.TypeOf(expected) expectedType := reflect.TypeOf(expected)
@ -354,3 +374,30 @@ func isPtrToInterface(typ reflect.Type) bool {
func isPtrToStruct(typ reflect.Type) bool { func isPtrToStruct(typ reflect.Type) bool {
return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
} }
var (
stdlibErrorNewType = reflect.TypeOf(errors.New(""))
stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", fmt.Errorf("")))
)
// ErrorIs succeeds if errors.Is(actual, expected) returns true. See
// [errors.Is] for accepted argument values.
func ErrorIs(actual error, expected error) Comparison {
return func() Result {
if errors.Is(actual, expected) {
return ResultSuccess
}
// The type of stdlib errors is excluded because the type is not relevant
// in those cases. The type is only important when it is a user defined
// custom error type.
return ResultFailureTemplate(`error is
{{- if not .Data.a }} nil,{{ else }}
{{- printf " \"%v\"" .Data.a }}
{{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
{{- end }} not {{ printf "\"%v\"" .Data.x }} (
{{- with callArg 1 }}{{ formatNode . }}{{ end }}
{{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
map[string]interface{}{"a": actual, "x": expected})
}
}

View File

@ -4,39 +4,46 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"go/ast" "go/ast"
"reflect"
"text/template" "text/template"
"gotest.tools/internal/source" "gotest.tools/v3/internal/source"
) )
// Result of a Comparison. // A Result of a [Comparison].
type Result interface { type Result interface {
Success() bool Success() bool
} }
type result struct { // StringResult is an implementation of [Result] that reports the error message
// string verbatim and does not provide any templating or formatting of the
// message.
type StringResult struct {
success bool success bool
message string message string
} }
func (r result) Success() bool { // Success returns true if the comparison was successful.
func (r StringResult) Success() bool {
return r.success return r.success
} }
func (r result) FailureMessage() string { // FailureMessage returns the message used to provide additional information
// about the failure.
func (r StringResult) FailureMessage() string {
return r.message return r.message
} }
// ResultSuccess is a constant which is returned by a ComparisonWithResult to // ResultSuccess is a constant which is returned by a [Comparison] to
// indicate success. // indicate success.
var ResultSuccess = result{success: true} var ResultSuccess = StringResult{success: true}
// ResultFailure returns a failed Result with a failure message. // ResultFailure returns a failed [Result] with a failure message.
func ResultFailure(message string) Result { func ResultFailure(message string) StringResult {
return result{message: message} return StringResult{message: message}
} }
// ResultFromError returns ResultSuccess if err is nil. Otherwise ResultFailure // ResultFromError returns [ResultSuccess] if err is nil. Otherwise [ResultFailure]
// is returned with the error message as the failure message. // is returned with the error message as the failure message.
func ResultFromError(err error) Result { func ResultFromError(err error) Result {
if err == nil { if err == nil {
@ -46,13 +53,12 @@ func ResultFromError(err error) Result {
} }
type templatedResult struct { type templatedResult struct {
success bool
template string template string
data map[string]interface{} data map[string]interface{}
} }
func (r templatedResult) Success() bool { func (r templatedResult) Success() bool {
return r.success return false
} }
func (r templatedResult) FailureMessage(args []ast.Expr) string { func (r templatedResult) FailureMessage(args []ast.Expr) string {
@ -63,7 +69,12 @@ func (r templatedResult) FailureMessage(args []ast.Expr) string {
return msg return msg
} }
// ResultFailureTemplate returns a Result with a template string and data which func (r templatedResult) UpdatedExpected(stackIndex int) error {
// TODO: would be nice to have structured data instead of a map
return source.UpdateExpectedValue(stackIndex+1, r.data["x"], r.data["y"])
}
// ResultFailureTemplate returns a [Result] with a template string and data which
// can be used to format a failure message. The template may access data from .Data, // can be used to format a failure message. The template may access data from .Data,
// the comparison args with the callArg function, and the formatNode function may // the comparison args with the callArg function, and the formatNode function may
// be used to format the call args. // be used to format the call args.
@ -80,6 +91,11 @@ func renderMessage(result templatedResult, args []ast.Expr) (string, error) {
} }
return args[index] return args[index]
}, },
// TODO: any way to include this from ErrorIS instead of here?
"notStdlibErrorType": func(typ interface{}) bool {
r := reflect.TypeOf(typ)
return r != stdlibFmtErrorType && r != stdlibErrorNewType
},
}) })
var err error var err error
tmpl, err = tmpl.Parse(result.template) tmpl, err = tmpl.Parse(result.template)

View File

@ -0,0 +1,160 @@
// Package assert provides internal utilties for assertions.
package assert
import (
"fmt"
"go/ast"
"go/token"
"reflect"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/format"
"gotest.tools/v3/internal/source"
)
// LogT is the subset of testing.T used by the assert package.
type LogT interface {
Log(args ...interface{})
}
type helperT interface {
Helper()
}
const failureMessage = "assertion failed: "
// Eval the comparison and print a failure messages if the comparison has failed.
func Eval(
t LogT,
argSelector argSelector,
comparison interface{},
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)
// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)
case nil:
return true
case error:
msg := failureMsgFromError(check)
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
case cmp.Comparison:
success = RunComparison(t, argSelector, check, msgAndArgs...)
case func() cmp.Result:
success = RunComparison(t, argSelector, check, msgAndArgs...)
default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}
return success
}
func runCompareFunc(
t LogT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}
func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
}
var msg string
const comparisonArgIndex = 1 // Assert(t, comparison)
if len(args) <= comparisonArgIndex {
msg = "but assert failed to find the expression to print"
} else {
msg, err = boolFailureMessage(args[comparisonArgIndex])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}
}
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
func failureMsgFromError(err error) string {
// Handle errors with non-nil types
v := reflect.ValueOf(err)
if v.Kind() == reflect.Ptr && v.IsNil() {
return fmt.Sprintf("error is not nil: error has type %T", err)
}
return "error is not nil: " + err.Error()
}
func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
switch binaryExpr.Op {
case token.NEQ:
return x + " is " + y, nil
case token.EQL:
return x + " is not " + y, nil
case token.GTR:
return x + " is <= " + y, nil
case token.LSS:
return x + " is >= " + y, nil
case token.GEQ:
return x + " is less than " + y, nil
case token.LEQ:
return x + " is greater than " + y, nil
}
}
if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}
if ident, ok := expr.(*ast.Ident); ok {
return ident.Name + " is false", nil
}
formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}

View File

@ -1,16 +1,19 @@
package assert package assert
import ( import (
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"gotest.tools/assert/cmp" "gotest.tools/v3/assert/cmp"
"gotest.tools/internal/format" "gotest.tools/v3/internal/format"
"gotest.tools/internal/source" "gotest.tools/v3/internal/source"
) )
func runComparison( // RunComparison and return Comparison.Success. If the comparison fails a messages
t TestingT, // will be printed using t.Log.
func RunComparison(
t LogT,
argSelector argSelector, argSelector argSelector,
f cmp.Comparison, f cmp.Comparison,
msgAndArgs ...interface{}, msgAndArgs ...interface{},
@ -23,10 +26,26 @@ func runComparison(
return true return true
} }
if source.IsUpdate() {
if updater, ok := result.(updateExpected); ok {
const stackIndex = 3 // Assert/Check, assert, RunComparison
err := updater.UpdatedExpected(stackIndex)
switch {
case err == nil:
return true
case errors.Is(err, source.ErrNotFound):
// do nothing, fallthrough to regular failure message
default:
t.Log("failed to update source", err)
return false
}
}
}
var message string var message string
switch typed := result.(type) { switch typed := result.(type) {
case resultWithComparisonArgs: case resultWithComparisonArgs:
const stackIndex = 3 // Assert/Check, assert, runComparison const stackIndex = 3 // Assert/Check, assert, RunComparison
args, err := source.CallExprArgs(stackIndex) args, err := source.CallExprArgs(stackIndex)
if err != nil { if err != nil {
t.Log(err.Error()) t.Log(err.Error())
@ -50,6 +69,10 @@ type resultBasic interface {
FailureMessage() string FailureMessage() string
} }
type updateExpected interface {
UpdatedExpected(stackIndex int) error
}
// filterPrintableExpr filters the ast.Expr slice to only include Expr that are // filterPrintableExpr filters the ast.Expr slice to only include Expr that are
// easy to read when printed and contain relevant information to an assertion. // easy to read when printed and contain relevant information to an assertion.
// //
@ -88,15 +111,20 @@ func isShortPrintableExpr(expr ast.Expr) bool {
type argSelector func([]ast.Expr) []ast.Expr type argSelector func([]ast.Expr) []ast.Expr
func argsAfterT(args []ast.Expr) []ast.Expr { // ArgsAfterT selects args starting at position 1. Used when the caller has a
// testing.T as the first argument, and the args to select should follow it.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
if len(args) < 1 { if len(args) < 1 {
return nil return nil
} }
return args[1:] return args[1:]
} }
func argsFromComparisonCall(args []ast.Expr) []ast.Expr { // ArgsFromComparisonCall selects args from the CallExpression at position 1.
if len(args) < 1 { // Used when the caller has a testing.T as the first argument, and the args to
// select are passed to the cmp.Comparison at position 1.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
if len(args) <= 1 {
return nil return nil
} }
if callExpr, ok := args[1].(*ast.CallExpr); ok { if callExpr, ok := args[1].(*ast.CallExpr); ok {
@ -104,3 +132,15 @@ func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
} }
return nil return nil
} }
// ArgsAtZeroIndex selects args from the CallExpression at position 1.
// Used when the caller accepts a single cmp.Comparison argument.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
if len(args) == 0 {
return nil
}
if callExpr, ok := args[0].(*ast.CallExpr); ok {
return callExpr.Args
}
return nil
}

View File

@ -1,19 +1,20 @@
/*Package difflib is a partial port of Python difflib module. /*
Package difflib is a partial port of Python difflib module.
Original source: https://github.com/pmezard/go-difflib Original source: https://github.com/pmezard/go-difflib
This file is trimmed to only the parts used by this repository. This file is trimmed to only the parts used by this repository.
*/ */
package difflib // import "gotest.tools/internal/difflib" package difflib // import "gotest.tools/v3/internal/difflib"
func min(a, b int) int { func minInt(a, b int) int {
if a < b { if a < b {
return a return a
} }
return b return b
} }
func max(a, b int) int { func maxInt(a, b int) int {
if a > b { if a > b {
return a return a
} }
@ -170,12 +171,15 @@ func (m *SequenceMatcher) isBJunk(s string) bool {
// If IsJunk is not defined: // If IsJunk is not defined:
// //
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where // Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
// alo <= i <= i+k <= ahi //
// blo <= j <= j+k <= bhi // alo <= i <= i+k <= ahi
// blo <= j <= j+k <= bhi
//
// and for all (i',j',k') meeting those conditions, // and for all (i',j',k') meeting those conditions,
// k >= k' //
// i <= i' // k >= k'
// and if i == i', j <= j' // i <= i'
// and if i == i', j <= j'
// //
// In other words, of all maximal matching blocks, return one that // In other words, of all maximal matching blocks, return one that
// starts earliest in a, and of all those maximal matching blocks that // starts earliest in a, and of all those maximal matching blocks that
@ -393,12 +397,12 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
if codes[0].Tag == 'e' { if codes[0].Tag == 'e' {
c := codes[0] c := codes[0]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} codes[0] = OpCode{c.Tag, maxInt(i1, i2-n), i2, maxInt(j1, j2-n), j2}
} }
if codes[len(codes)-1].Tag == 'e' { if codes[len(codes)-1].Tag == 'e' {
c := codes[len(codes)-1] c := codes[len(codes)-1]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} codes[len(codes)-1] = OpCode{c.Tag, i1, minInt(i2, i1+n), j1, minInt(j2, j1+n)}
} }
nn := n + n nn := n + n
groups := [][]OpCode{} groups := [][]OpCode{}
@ -408,11 +412,11 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
// End the current group and start a new one whenever // End the current group and start a new one whenever
// there is a large range with no changes. // there is a large range with no changes.
if c.Tag == 'e' && i2-i1 > nn { if c.Tag == 'e' && i2-i1 > nn {
group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), group = append(group, OpCode{c.Tag, i1, minInt(i2, i1+n),
j1, min(j2, j1+n)}) j1, minInt(j2, j1+n)})
groups = append(groups, group) groups = append(groups, group)
group = []OpCode{} group = []OpCode{}
i1, j1 = max(i1, i2-n), max(j1, j2-n) i1, j1 = maxInt(i1, i2-n), maxInt(j1, j2-n)
} }
group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
} }

View File

@ -1,3 +1,4 @@
// Package format provides utilities for formatting diffs and messages.
package format package format
import ( import (
@ -6,7 +7,7 @@ import (
"strings" "strings"
"unicode" "unicode"
"gotest.tools/internal/difflib" "gotest.tools/v3/internal/difflib"
) )
const ( const (

View File

@ -1,4 +1,4 @@
package format // import "gotest.tools/internal/format" package format // import "gotest.tools/v3/internal/format"
import "fmt" import "fmt"

View File

@ -0,0 +1,51 @@
package source
import (
"fmt"
"os"
"path/filepath"
)
// These Bazel env vars are documented here:
// https://bazel.build/reference/test-encyclopedia
// Signifies test executable is being driven by `bazel test`.
//
// Due to Bazel's compilation and sandboxing strategy,
// some care is required to handle resolving the original *.go source file.
var inBazelTest = os.Getenv("BAZEL_TEST") == "1"
// The name of the target being tested (ex: //some_package:some_package_test)
var bazelTestTarget = os.Getenv("TEST_TARGET")
// Absolute path to the base of the runfiles tree
var bazelTestSrcdir = os.Getenv("TEST_SRCDIR")
// The local repository's workspace name (ex: __main__)
var bazelTestWorkspace = os.Getenv("TEST_WORKSPACE")
func bazelSourcePath(filename string) (string, error) {
// Use the env vars to resolve the test source files,
// which must be provided as test data in the respective go_test target.
filename = filepath.Join(bazelTestSrcdir, bazelTestWorkspace, filename)
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return "", fmt.Errorf(bazelMissingSourceMsg, filename, bazelTestTarget)
}
return filename, nil
}
var bazelMissingSourceMsg = `
the test source file does not exist: %s
It appears that you are running this test under Bazel (target: %s).
Check that your test source files are added as test data in your go_test targets.
Example:
go_test(
name = "your_package_test",
srcs = ["your_test.go"],
deps = ["@tools_gotest_v3//assert"],
data = glob(["*_test.go"])
)"
`

View File

@ -1,10 +1,9 @@
package source package source
import ( import (
"fmt"
"go/ast" "go/ast"
"go/token" "go/token"
"github.com/pkg/errors"
) )
func scanToDeferLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { func scanToDeferLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
@ -29,11 +28,11 @@ func guessDefer(node ast.Node) (ast.Node, error) {
defers := collectDefers(node) defers := collectDefers(node)
switch len(defers) { switch len(defers) {
case 0: case 0:
return nil, errors.New("failed to expression in defer") return nil, fmt.Errorf("failed to find expression in defer")
case 1: case 1:
return defers[0].Call, nil return defers[0].Call, nil
default: default:
return nil, errors.Errorf( return nil, fmt.Errorf(
"ambiguous call expression: multiple (%d) defers in call block", "ambiguous call expression: multiple (%d) defers in call block",
len(defers)) len(defers))
} }

View File

@ -1,22 +1,19 @@
package source // import "gotest.tools/internal/source" // Package source provides utilities for handling source-code.
package source // import "gotest.tools/v3/internal/source"
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/format" "go/format"
"go/parser" "go/parser"
"go/token" "go/token"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strconv"
"strings"
"github.com/pkg/errors"
) )
const baseStackIndex = 1
// FormattedCallExprArg returns the argument from an ast.CallExpr at the // FormattedCallExprArg returns the argument from an ast.CallExpr at the
// index in the call stack. The argument is formatted using FormatNode. // index in the call stack. The argument is formatted using FormatNode.
func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
@ -33,28 +30,39 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
// the index in the call stack. // the index in the call stack.
func CallExprArgs(stackIndex int) ([]ast.Expr, error) { func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
_, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) _, filename, line, ok := runtime.Caller(stackIndex + 1)
if !ok { if !ok {
return nil, errors.New("failed to get call stack") return nil, errors.New("failed to get call stack")
} }
debug("call stack position: %s:%d", filename, lineNum) debug("call stack position: %s:%d", filename, line)
node, err := getNodeAtLine(filename, lineNum) // Normally, `go` will compile programs with absolute paths in
if err != nil { // the debug metadata. However, in the name of reproducibility,
return nil, err // Bazel uses a compilation strategy that results in relative paths
// (otherwise, since Bazel uses a random tmp dir for compile and sandboxing,
// the resulting binaries would change across compiles/test runs).
if inBazelTest && !filepath.IsAbs(filename) {
var err error
filename, err = bazelSourcePath(filename)
if err != nil {
return nil, err
}
} }
debug("found node: %s", debugFormatNode{node})
return getCallExprArgs(node)
}
func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
fileset := token.NewFileSet() fileset := token.NewFileSet()
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "failed to parse source file: %s", filename) return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
} }
expr, err := getCallExprArgs(fileset, astFile, line)
if err != nil {
return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
}
return expr, nil
}
func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
if node := scanToLine(fileset, astFile, lineNum); node != nil { if node := scanToLine(fileset, astFile, lineNum); node != nil {
return node, nil return node, nil
} }
@ -64,8 +72,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
return node, err return node, err
} }
} }
return nil, errors.Errorf( return nil, errors.New("failed to find expression")
"failed to find an expression on line %d in %s", lineNum, filename)
} }
func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
@ -74,7 +81,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
switch { switch {
case node == nil || matchedNode != nil: case node == nil || matchedNode != nil:
return false return false
case nodePosition(fileset, node).Line == lineNum: case fileset.Position(node.Pos()).Line == lineNum:
matchedNode = node matchedNode = node
return false return false
} }
@ -83,35 +90,18 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
return matchedNode return matchedNode
} }
// In golang 1.9 the line number changed from being the line where the statement func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
// ended to the line where the statement began. node, err := getNodeAtLine(fileset, astFile, line)
func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { if err != nil {
if goVersionBefore19 { return nil, err
return fileset.Position(node.End())
} }
return fileset.Position(node.Pos())
}
var goVersionBefore19 = func() bool { debug("found node: %s", debugFormatNode{node})
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
minor, err := strconv.ParseInt(parts[1], 10, 32)
return err == nil && parts[0] == "1" && minor < 9
}()
func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
visitor := &callExprVisitor{} visitor := &callExprVisitor{}
ast.Walk(visitor, node) ast.Walk(visitor, node)
if visitor.expr == nil { if visitor.expr == nil {
return nil, errors.New("failed to find call expression") return nil, errors.New("failed to find an expression")
} }
debug("callExpr: %s", debugFormatNode{visitor.expr}) debug("callExpr: %s", debugFormatNode{visitor.expr})
return visitor.expr.Args, nil return visitor.expr.Args, nil
@ -158,6 +148,9 @@ type debugFormatNode struct {
} }
func (n debugFormatNode) String() string { func (n debugFormatNode) String() string {
if n.Node == nil {
return "none"
}
out, err := FormatNode(n.Node) out, err := FormatNode(n.Node)
if err != nil { if err != nil {
return fmt.Sprintf("failed to format %s: %s", n.Node, err) return fmt.Sprintf("failed to format %s: %s", n.Node, err)

View File

@ -0,0 +1,171 @@
package source
import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"runtime"
"strings"
)
// IsUpdate is returns true if the -update flag is set. It indicates the user
// running the tests would like to update any golden values.
func IsUpdate() bool {
if Update {
return true
}
return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
}
// Update is a shim for testing, and for compatibility with the old -update-golden
// flag.
var Update bool
func init() {
if f := flag.Lookup("update"); f != nil {
getter, ok := f.Value.(flag.Getter)
msg := "some other package defined an incompatible -update flag, expected a flag.Bool"
if !ok {
panic(msg)
}
if _, ok := getter.Get().(bool); !ok {
panic(msg)
}
return
}
flag.Bool("update", false, "update golden values")
}
// ErrNotFound indicates that UpdateExpectedValue failed to find the
// variable to update, likely because it is not a package level variable.
var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value")
// UpdateExpectedValue looks for a package-level variable with a name that
// starts with expected in the arguments to the caller. If the variable is
// found, the value of the variable will be updated to value of the other
// argument to the caller.
func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
_, filename, line, ok := runtime.Caller(stackIndex + 1)
if !ok {
return errors.New("failed to get call stack")
}
debug("call stack position: %s:%d", filename, line)
fileset := token.NewFileSet()
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse source file %s: %w", filename, err)
}
expr, err := getCallExprArgs(fileset, astFile, line)
if err != nil {
return fmt.Errorf("call from %s:%d: %w", filename, line, err)
}
if len(expr) < 3 {
debug("not enough arguments %d: %v",
len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
}
argIndex, ident := getIdentForExpectedValueArg(expr)
if argIndex < 0 || ident == nil {
debug("no arguments started with the word 'expected': %v",
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
}
value := x
if argIndex == 1 {
value = y
}
strValue, ok := value.(string)
if !ok {
debug("value must be type string, got %T", value)
return ErrNotFound
}
return UpdateVariable(filename, fileset, astFile, ident, strValue)
}
// UpdateVariable writes to filename the contents of astFile with the value of
// the variable updated to value.
func UpdateVariable(
filename string,
fileset *token.FileSet,
astFile *ast.File,
ident *ast.Ident,
value string,
) error {
obj := ident.Obj
if obj == nil {
return ErrNotFound
}
if obj.Kind != ast.Con && obj.Kind != ast.Var {
debug("can only update var and const, found %v", obj.Kind)
return ErrNotFound
}
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
if len(decl.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}
decl.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}
case *ast.AssignStmt:
if len(decl.Lhs) != 1 {
debug("more than one name in ast.AssignStmt")
return ErrNotFound
}
decl.Rhs[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}
default:
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
return ErrNotFound
}
var buf bytes.Buffer
if err := format.Node(&buf, fileset, astFile); err != nil {
return fmt.Errorf("failed to format file after update: %w", err)
}
fh, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to open file %v: %w", filename, err)
}
if _, err = fh.Write(buf.Bytes()); err != nil {
return fmt.Errorf("failed to write file %v: %w", filename, err)
}
if err := fh.Sync(); err != nil {
return fmt.Errorf("failed to sync file %v: %w", filename, err)
}
return nil
}
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
for i := 1; i < 3; i++ {
switch e := expr[i].(type) {
case *ast.Ident:
if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
return i, e
}
}
}
return -1, nil
}

View File

@ -0,0 +1,35 @@
package source
import (
"runtime"
"strconv"
"strings"
)
// GoVersionLessThan returns true if runtime.Version() is semantically less than
// version major.minor. Returns false if a release version can not be parsed from
// runtime.Version().
func GoVersionLessThan(major, minor int64) bool {
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
rMajor, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return false
}
if rMajor != major {
return rMajor < major
}
rMinor, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return false
}
return rMinor < minor
}

15
vendor/modules.txt vendored
View File

@ -224,10 +224,11 @@ google.golang.org/protobuf/types/known/anypb
# gopkg.in/yaml.v3 v3.0.1 # gopkg.in/yaml.v3 v3.0.1
## explicit ## explicit
gopkg.in/yaml.v3 gopkg.in/yaml.v3
# gotest.tools v2.2.0+incompatible # gotest.tools/v3 v3.5.2
## explicit ## explicit; go 1.17
gotest.tools/assert gotest.tools/v3/assert
gotest.tools/assert/cmp gotest.tools/v3/assert/cmp
gotest.tools/internal/difflib gotest.tools/v3/internal/assert
gotest.tools/internal/format gotest.tools/v3/internal/difflib
gotest.tools/internal/source gotest.tools/v3/internal/format
gotest.tools/v3/internal/source