token/cache: use go 1.20's approach for no-copy string/bytes conversions

Note that this fixes a bug in the existing `toBytes` implementation
which does not correctly set the capacity on the returned slice.

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: aa80f8fb856bb2b645c90457f9b1dd75e4e57c73
This commit is contained in:
Monis Khan 2023-02-21 12:24:21 -05:00 committed by Kubernetes Publisher
parent a8f9a38ca8
commit 6ab879299d
2 changed files with 126 additions and 2 deletions

View File

@ -277,12 +277,24 @@ func writeLength(w io.Writer, b []byte, length int) {
// toBytes performs unholy acts to avoid allocations
func toBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(&s))
// unsafe.StringData is unspecified for the empty string, so we provide a strict interpretation
if len(s) == 0 {
return nil
}
// Copied from go 1.20.1 os.File.WriteString
// https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/os/file.go#L246
return unsafe.Slice(unsafe.StringData(s), len(s))
}
// toString performs unholy acts to avoid allocations
func toString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
// unsafe.SliceData relies on cap whereas we want to rely on len
if len(b) == 0 {
return ""
}
// Copied from go 1.20.1 strings.Builder.String
// https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/strings/builder.go#L48
return unsafe.String(unsafe.SliceData(b), len(b))
}
// simple recorder that only appends warning

View File

@ -17,6 +17,7 @@ limitations under the License.
package cache
import (
"bytes"
"context"
"crypto/hmac"
"crypto/rand"
@ -31,6 +32,8 @@ import (
"time"
"github.com/google/go-cmp/cmp"
utilrand "k8s.io/apimachinery/pkg/util/rand"
"k8s.io/apimachinery/pkg/util/uuid"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
@ -547,3 +550,112 @@ func withAudit(ctx context.Context) context.Context {
ac.Event = &auditinternal.Event{Level: auditinternal.LevelMetadata}
return ctx
}
func TestUnsafeConversions(t *testing.T) {
t.Parallel()
// needs to be large to force allocations so we pick a random value between [1024, 2048]
size := utilrand.IntnRange(1024, 2048+1)
t.Run("toBytes semantics", func(t *testing.T) {
t.Parallel()
s := utilrand.String(size)
b := toBytes(s)
if len(b) != size {
t.Errorf("unexpected length: %d", len(b))
}
if cap(b) != size {
t.Errorf("unexpected capacity: %d", cap(b))
}
if !bytes.Equal(b, []byte(s)) {
t.Errorf("unexpected equality failure: %#v", b)
}
})
t.Run("toBytes allocations", func(t *testing.T) {
t.Parallel()
s := utilrand.String(size)
f := func() {
b := toBytes(s)
if len(b) != size {
t.Errorf("invalid length: %d", len(b))
}
}
allocs := testing.AllocsPerRun(100, f)
if allocs > 0 {
t.Errorf("expected zero allocations, got %v", allocs)
}
})
t.Run("toString semantics", func(t *testing.T) {
t.Parallel()
b := make([]byte, size)
if _, err := rand.Read(b); err != nil {
t.Fatal(err)
}
s := toString(b)
if len(s) != size {
t.Errorf("unexpected length: %d", len(s))
}
if s != string(b) {
t.Errorf("unexpected equality failure: %#v", s)
}
})
t.Run("toString allocations", func(t *testing.T) {
t.Parallel()
b := make([]byte, size)
if _, err := rand.Read(b); err != nil {
t.Fatal(err)
}
f := func() {
s := toString(b)
if len(s) != size {
t.Errorf("invalid length: %d", len(s))
}
}
allocs := testing.AllocsPerRun(100, f)
if allocs > 0 {
t.Errorf("expected zero allocations, got %v", allocs)
}
})
}
func TestKeyFunc(t *testing.T) {
t.Parallel()
hashPool := &sync.Pool{
New: func() interface{} {
return hmac.New(sha256.New, []byte("098c9e46-b7f4-4358-bb3c-35cb7495b836")) // deterministic HMAC for testing
},
}
// use realistic audiences
auds := []string{"7daf30b7-a85c-429b-8b21-e666aecbb235", "c22aa267-bdde-4acb-8505-998be7818400", "44f9b4f3-7125-4333-b04c-1446a16c6113"}
keyWithAuds := "\"\xf7\xac\xcd\x12\xf5\x83l\xa9;@\n\xa13a;\nd\x1f\xdelL\xd1\xe1!\x8a\xdahٛ\xbb\xf0"
keyWithoutAuds := "\x054a \xa5\x8e\xea\xb2?\x8c\x88\xb9,e\n5\xe7ȵ>\xfdK\x0e\x93+\x02˿&\xf98\x1e"
t.Run("has audiences", func(t *testing.T) {
t.Parallel()
key := keyFunc(hashPool, auds, jwtToken)
if key != keyWithAuds {
t.Errorf("unexpected equality failure: %#v", key)
}
})
t.Run("nil audiences", func(t *testing.T) {
t.Parallel()
key := keyFunc(hashPool, nil, jwtToken)
if key != keyWithoutAuds {
t.Errorf("unexpected equality failure: %#v", key)
}
})
}