From 887d90826495cd37de5cc60a7da7bfb5c7869be1 Mon Sep 17 00:00:00 2001 From: Paul Chesnais Date: Thu, 1 Aug 2024 17:14:30 -0400 Subject: [PATCH] mem: introduce `mem` package to facilitate memory reuse (#7432) --- internal/internal.go | 4 + mem/buffer_pool.go | 186 +++++++++++++++++++++++++ mem/buffer_pool_test.go | 75 ++++++++++ mem/buffer_slice.go | 194 ++++++++++++++++++++++++++ mem/buffer_slice_test.go | 173 +++++++++++++++++++++++ mem/buffers.go | 149 ++++++++++++++++++++ mem/buffers_test.go | 293 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 1074 insertions(+) create mode 100644 mem/buffer_pool.go create mode 100644 mem/buffer_pool_test.go create mode 100644 mem/buffer_slice.go create mode 100644 mem/buffer_slice_test.go create mode 100644 mem/buffers.go create mode 100644 mem/buffers_test.go diff --git a/internal/internal.go b/internal/internal.go index e1e1422e1..433e697f1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -221,6 +221,10 @@ var ( // sets the metric registry to its original state. Only called in testing // functions. SnapshotMetricRegistryForTesting any // func(t *testing.T) + + // SetDefaultBufferPoolForTesting updates the default buffer pool, for + // testing purposes. + SetDefaultBufferPoolForTesting any // func(mem.BufferPool) ) // HealthChecker defines the signature of the client-side LB channel health diff --git a/mem/buffer_pool.go b/mem/buffer_pool.go new file mode 100644 index 000000000..d2e8cd448 --- /dev/null +++ b/mem/buffer_pool.go @@ -0,0 +1,186 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem + +import ( + "sort" + "sync" + + "google.golang.org/grpc/internal" +) + +// BufferPool is a pool of buffers that can be shared and reused, resulting in +// decreased memory allocation. +type BufferPool interface { + // Get returns a buffer with specified length from the pool. + Get(length int) []byte + + // Put returns a buffer to the pool. + Put([]byte) +} + +var defaultBufferPoolSizes = []int{ + 256, + 4 << 10, // 4KB (go page size) + 16 << 10, // 16KB (max HTTP/2 frame size used by gRPC) + 32 << 10, // 32KB (default buffer size for io.Copy) + 1 << 20, // 1MB +} + +var defaultBufferPool BufferPool + +func init() { + defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...) + + internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) { defaultBufferPool = pool } +} + +// DefaultBufferPool returns the current default buffer pool. It is a BufferPool +// created with NewBufferPool that uses a set of default sizes optimized for +// expected workflows. +func DefaultBufferPool() BufferPool { + return defaultBufferPool +} + +// NewTieredBufferPool returns a BufferPool implementation that uses multiple +// underlying pools of the given pool sizes. +func NewTieredBufferPool(poolSizes ...int) BufferPool { + sort.Ints(poolSizes) + pools := make([]*sizedBufferPool, len(poolSizes)) + for i, s := range poolSizes { + pools[i] = newSizedBufferPool(s) + } + return &tieredBufferPool{ + sizedPools: pools, + } +} + +// tieredBufferPool implements the BufferPool interface with multiple tiers of +// buffer pools for different sizes of buffers. +type tieredBufferPool struct { + sizedPools []*sizedBufferPool + fallbackPool simpleBufferPool +} + +func (p *tieredBufferPool) Get(size int) []byte { + return p.getPool(size).Get(size) +} + +func (p *tieredBufferPool) Put(buf []byte) { + p.getPool(cap(buf)).Put(buf) +} + +func (p *tieredBufferPool) getPool(size int) BufferPool { + poolIdx := sort.Search(len(p.sizedPools), func(i int) bool { + return p.sizedPools[i].defaultSize >= size + }) + + if poolIdx == len(p.sizedPools) { + return &p.fallbackPool + } + + return p.sizedPools[poolIdx] +} + +// sizedBufferPool is a BufferPool implementation that is optimized for specific +// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size +// of 16kb and a sizedBufferPool can be configured to only return buffers with a +// capacity of 16kb. Note that however it does not support returning larger +// buffers and in fact panics if such a buffer is requested. Because of this, +// this BufferPool implementation is not meant to be used on its own and rather +// is intended to be embedded in a tieredBufferPool such that Get is only +// invoked when the required size is smaller than or equal to defaultSize. +type sizedBufferPool struct { + pool sync.Pool + defaultSize int +} + +func (p *sizedBufferPool) Get(size int) []byte { + bs := *p.pool.Get().(*[]byte) + return bs[:size] +} + +func (p *sizedBufferPool) Put(buf []byte) { + if cap(buf) < p.defaultSize { + // Ignore buffers that are too small to fit in the pool. Otherwise, when + // Get is called it will panic as it tries to index outside the bounds + // of the buffer. + return + } + buf = buf[:cap(buf)] + clear(buf) + p.pool.Put(&buf) +} + +func newSizedBufferPool(size int) *sizedBufferPool { + return &sizedBufferPool{ + pool: sync.Pool{ + New: func() any { + buf := make([]byte, size) + return &buf + }, + }, + defaultSize: size, + } +} + +var _ BufferPool = (*simpleBufferPool)(nil) + +// simpleBufferPool is an implementation of the BufferPool interface that +// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to +// acquire a buffer from the pool but if that buffer is too small, it returns it +// to the pool and creates a new one. +type simpleBufferPool struct { + pool sync.Pool +} + +func (p *simpleBufferPool) Get(size int) []byte { + bs, ok := p.pool.Get().(*[]byte) + if ok && cap(*bs) >= size { + return (*bs)[:size] + } + + // A buffer was pulled from the pool, but it is tool small. Put it back in + // the pool and create one large enough. + if ok { + p.pool.Put(bs) + } + + return make([]byte, size) +} + +func (p *simpleBufferPool) Put(buf []byte) { + buf = buf[:cap(buf)] + clear(buf) + p.pool.Put(&buf) +} + +var _ BufferPool = NopBufferPool{} + +// NopBufferPool is a buffer pool that returns new buffers without pooling. +type NopBufferPool struct{} + +// Get returns a buffer with specified length from the pool. +func (NopBufferPool) Get(length int) []byte { + return make([]byte, length) +} + +// Put returns a buffer to the pool. +func (NopBufferPool) Put([]byte) { +} diff --git a/mem/buffer_pool_test.go b/mem/buffer_pool_test.go new file mode 100644 index 000000000..d6b9d42af --- /dev/null +++ b/mem/buffer_pool_test.go @@ -0,0 +1,75 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/mem" +) + +func (s) TestBufferPool(t *testing.T) { + var poolSizes = []int{4, 8, 16, 32} + pools := []mem.BufferPool{ + mem.NopBufferPool{}, + mem.NewTieredBufferPool(poolSizes...), + } + + testSizes := append([]int{1}, poolSizes...) + testSizes = append(testSizes, 64) + + for _, p := range pools { + for _, l := range testSizes { + bs := p.Get(l) + if len(bs) != l { + t.Fatalf("Get(%d) returned buffer of length %d, want %d", l, len(bs), l) + } + + p.Put(bs) + } + } +} + +func (s) TestBufferPoolClears(t *testing.T) { + pool := mem.NewTieredBufferPool(4) + + buf := pool.Get(4) + copy(buf, "1234") + pool.Put(buf) + + if !cmp.Equal(buf, make([]byte, 4)) { + t.Fatalf("buffer not cleared") + } +} + +func (s) TestBufferPoolIgnoresShortBuffers(t *testing.T) { + pool := mem.NewTieredBufferPool(10, 20) + buf := pool.Get(1) + if cap(buf) != 10 { + t.Fatalf("Get(1) returned buffer with capacity: %d, want 10", cap(buf)) + } + + // Insert a short buffer into the pool, which is currently empty. + pool.Put(make([]byte, 1)) + // Then immediately request a buffer that would be pulled from the pool where the + // short buffer would have been returned. If the short buffer is pulled from the + // pool, it could cause a panic. + pool.Get(10) +} diff --git a/mem/buffer_slice.go b/mem/buffer_slice.go new file mode 100644 index 000000000..ec508d0ca --- /dev/null +++ b/mem/buffer_slice.go @@ -0,0 +1,194 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem + +import ( + "io" +) + +// BufferSlice offers a means to represent data that spans one or more Buffer +// instances. A BufferSlice is meant to be immutable after creation, and methods +// like Ref create and return copies of the slice. This is why all methods have +// value receivers rather than pointer receivers. +// +// Note that any of the methods that read the underlying buffers such as Ref, +// Len or CopyTo etc., will panic if any underlying buffers have already been +// freed. It is recommended to not directly interact with any of the underlying +// buffers directly, rather such interactions should be mediated through the +// various methods on this type. +// +// By convention, any APIs that return (mem.BufferSlice, error) should reduce +// the burden on the caller by never returning a mem.BufferSlice that needs to +// be freed if the error is non-nil, unless explicitly stated. +type BufferSlice []*Buffer + +// Len returns the sum of the length of all the Buffers in this slice. +// +// # Warning +// +// Invoking the built-in len on a BufferSlice will return the number of buffers +// in the slice, and *not* the value returned by this function. +func (s BufferSlice) Len() int { + var length int + for _, b := range s { + length += b.Len() + } + return length +} + +// Ref returns a new BufferSlice containing a new reference of each Buffer in the +// input slice. +func (s BufferSlice) Ref() BufferSlice { + out := make(BufferSlice, len(s)) + for i, b := range s { + out[i] = b.Ref() + } + return out +} + +// Free invokes Buffer.Free() on each Buffer in the slice. +func (s BufferSlice) Free() { + for _, b := range s { + b.Free() + } +} + +// CopyTo copies each of the underlying Buffer's data into the given buffer, +// returning the number of bytes copied. Has the same semantics as the copy +// builtin in that it will copy as many bytes as it can, stopping when either dst +// is full or s runs out of data, returning the minimum of s.Len() and len(dst). +func (s BufferSlice) CopyTo(dst []byte) int { + off := 0 + for _, b := range s { + off += copy(dst[off:], b.ReadOnlyData()) + } + return off +} + +// Materialize concatenates all the underlying Buffer's data into a single +// contiguous buffer using CopyTo. +func (s BufferSlice) Materialize() []byte { + l := s.Len() + if l == 0 { + return nil + } + out := make([]byte, l) + s.CopyTo(out) + return out +} + +// MaterializeToBuffer functions like Materialize except that it writes the data +// to a single Buffer pulled from the given BufferPool. As a special case, if the +// input BufferSlice only actually has one Buffer, this function has nothing to +// do and simply returns said Buffer. +func (s BufferSlice) MaterializeToBuffer(pool BufferPool) *Buffer { + if len(s) == 1 { + return s[0].Ref() + } + buf := pool.Get(s.Len()) + s.CopyTo(buf) + return NewBuffer(buf, pool.Put) +} + +// Reader returns a new Reader for the input slice after taking references to +// each underlying buffer. +func (s BufferSlice) Reader() *Reader { + return &Reader{ + data: s.Ref(), + len: s.Len(), + } +} + +var _ io.ReadCloser = (*Reader)(nil) + +// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface +// with other parts systems. It also provides an additional convenience method +// Remaining(), which returns the number of unread bytes remaining in the slice. +// +// Note that reading data from the reader does not free the underlying buffers! +// Only calling Close once all data is read will free the buffers. +type Reader struct { + data BufferSlice + len int + // The index into data[0].ReadOnlyData(). + bufferIdx int +} + +// Remaining returns the number of unread bytes remaining in the slice. +func (r *Reader) Remaining() int { + return r.len +} + +// Close frees the underlying BufferSlice and never returns an error. Subsequent +// calls to Read will return (0, io.EOF). +func (r *Reader) Close() error { + r.data.Free() + r.data = nil + r.len = 0 + return nil +} + +func (r *Reader) Read(buf []byte) (n int, _ error) { + if r.len == 0 { + return 0, io.EOF + } + + for len(buf) != 0 && r.len != 0 { + // Copy as much as possible from the first Buffer in the slice into the + // given byte slice. + data := r.data[0].ReadOnlyData() + copied := copy(buf, data[r.bufferIdx:]) + r.len -= copied // Reduce len by the number of bytes copied. + r.bufferIdx += copied // Increment the buffer index. + n += copied // Increment the total number of bytes read. + buf = buf[copied:] // Shrink the given byte slice. + + // If we have copied all of the data from the first Buffer, free it and + // advance to the next in the slice. + if r.bufferIdx == len(data) { + oldBuffer := r.data[0] + oldBuffer.Free() + r.data = r.data[1:] + r.bufferIdx = 0 + } + } + + return n, nil +} + +var _ io.Writer = (*writer)(nil) + +type writer struct { + buffers *BufferSlice + pool BufferPool +} + +func (w *writer) Write(p []byte) (n int, err error) { + b := Copy(p, w.pool) + *w.buffers = append(*w.buffers, b) + return b.Len(), nil +} + +// NewWriter wraps the given BufferSlice and BufferPool to implement the +// io.Writer interface. Every call to Write copies the contents of the given +// buffer into a new Buffer pulled from the given pool and the Buffer is added to +// the given BufferSlice. +func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer { + return &writer{buffers: buffers, pool: pool} +} diff --git a/mem/buffer_slice_test.go b/mem/buffer_slice_test.go new file mode 100644 index 000000000..d98055bba --- /dev/null +++ b/mem/buffer_slice_test.go @@ -0,0 +1,173 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem_test + +import ( + "bytes" + "fmt" + "io" + "testing" + + "google.golang.org/grpc/mem" +) + +func (s) TestBufferSlice_Len(t *testing.T) { + tests := []struct { + name string + in mem.BufferSlice + want int + }{ + { + name: "empty", + in: nil, + want: 0, + }, + { + name: "single", + in: mem.BufferSlice{mem.NewBuffer([]byte("abcd"), nil)}, + want: 4, + }, + { + name: "multiple", + in: mem.BufferSlice{ + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + }, + want: 12, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.in.Len(); got != tt.want { + t.Errorf("BufferSlice.Len() = %v, want %v", got, tt.want) + } + }) + } +} + +func (s) TestBufferSlice_Ref(t *testing.T) { + // Create a new buffer slice and a reference to it. + bs := mem.BufferSlice{ + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + } + bsRef := bs.Ref() + + // Free the original buffer slice and verify that the reference can still + // read data from it. + bs.Free() + got := bsRef.Materialize() + want := []byte("abcdabcd") + if !bytes.Equal(got, want) { + t.Errorf("BufferSlice.Materialize() = %s, want %s", string(got), string(want)) + } +} + +func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) { + tests := []struct { + name string + in mem.BufferSlice + pool mem.BufferPool + wantData []byte + }{ + { + name: "single", + in: mem.BufferSlice{mem.NewBuffer([]byte("abcd"), nil)}, + pool: nil, // MaterializeToBuffer should not use the pool in this case. + wantData: []byte("abcd"), + }, + { + name: "multiple", + in: mem.BufferSlice{ + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + }, + pool: mem.DefaultBufferPool(), + wantData: []byte("abcdabcdabcd"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.in.MaterializeToBuffer(tt.pool) + defer got.Free() + if !bytes.Equal(got.ReadOnlyData(), tt.wantData) { + t.Errorf("BufferSlice.MaterializeToBuffer() = %s, want %s", string(got.ReadOnlyData()), string(tt.wantData)) + } + }) + } +} + +func (s) TestBufferSlice_Reader(t *testing.T) { + bs := mem.BufferSlice{ + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + mem.NewBuffer([]byte("abcd"), nil), + } + wantData := []byte("abcdabcdabcd") + + reader := bs.Reader() + var gotData []byte + // Read into a buffer of size 1 until EOF, and verify that the data matches. + for { + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n > 0 { + gotData = append(gotData, buf[:n]...) + } + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("BufferSlice.Reader() failed unexpectedly: %v", err) + } + } + if !bytes.Equal(gotData, wantData) { + t.Errorf("BufferSlice.Reader() returned data %v, want %v", string(gotData), string(wantData)) + } + + // Reader should have released its references to the underlying buffers, but + // bs still holds its reference and it should be able to read data from it. + gotData = bs.Materialize() + if !bytes.Equal(gotData, wantData) { + t.Errorf("BufferSlice.Materialize() = %s, want %s", string(gotData), string(wantData)) + } +} + +func ExampleNewWriter() { + var bs mem.BufferSlice + pool := mem.DefaultBufferPool() + writer := mem.NewWriter(&bs, pool) + + for _, data := range [][]byte{ + []byte("abcd"), + []byte("abcd"), + []byte("abcd"), + } { + n, err := writer.Write(data) + fmt.Printf("Wrote %d bytes, err: %v\n", n, err) + } + fmt.Println(string(bs.Materialize())) + // Output: + // Wrote 4 bytes, err: + // Wrote 4 bytes, err: + // Wrote 4 bytes, err: + // abcdabcdabcd +} diff --git a/mem/buffers.go b/mem/buffers.go new file mode 100644 index 000000000..3b8f8addb --- /dev/null +++ b/mem/buffers.go @@ -0,0 +1,149 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package mem provides utilities that facilitate memory reuse in byte slices +// that are used as buffers. +// +// # Experimental +// +// Notice: All APIs in this package are EXPERIMENTAL and may be changed or +// removed in a later release. +package mem + +import ( + "fmt" + "sync/atomic" +) + +// A Buffer represents a reference counted piece of data (in bytes) that can be +// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be +// released by calling Free(), which invokes the given free function only after +// all references are released. +// +// Note that a Buffer is not safe for concurrent access and instead each +// goroutine should use its own reference to the data, which can be acquired via +// a call to Ref(). +// +// Attempts to access the underlying data after releasing the reference to the +// Buffer will panic. +type Buffer struct { + data []byte + refs *atomic.Int32 + free func() + freed bool +} + +// NewBuffer creates a new Buffer from the given data, initializing the +// reference counter to 1. The given free function is called when all references +// to the returned Buffer are released. +// +// Note that the backing array of the given data is not copied. +func NewBuffer(data []byte, onFree func([]byte)) *Buffer { + b := &Buffer{data: data, refs: new(atomic.Int32)} + if onFree != nil { + b.free = func() { onFree(data) } + } + b.refs.Add(1) + return b +} + +// Copy creates a new Buffer from the given data, initializing the reference +// counter to 1. +// +// It acquires a []byte from the given pool and copies over the backing array +// of the given data. The []byte acquired from the pool is returned to the +// pool when all references to the returned Buffer are released. +func Copy(data []byte, pool BufferPool) *Buffer { + buf := pool.Get(len(data)) + copy(buf, data) + return NewBuffer(buf, pool.Put) +} + +// ReadOnlyData returns the underlying byte slice. Note that it is undefined +// behavior to modify the contents of this slice in any way. +func (b *Buffer) ReadOnlyData() []byte { + if b.freed { + panic("Cannot read freed buffer") + } + return b.data +} + +// Ref returns a new reference to this Buffer's underlying byte slice. +func (b *Buffer) Ref() *Buffer { + if b.freed { + panic("Cannot ref freed buffer") + } + + b.refs.Add(1) + return &Buffer{ + data: b.data, + refs: b.refs, + free: b.free, + } +} + +// Free decrements this Buffer's reference counter and frees the underlying +// byte slice if the counter reaches 0 as a result of this call. +func (b *Buffer) Free() { + if b.freed { + return + } + + b.freed = true + refs := b.refs.Add(-1) + if refs == 0 && b.free != nil { + b.free() + } + b.data = nil +} + +// Len returns the Buffer's size. +func (b *Buffer) Len() int { + // Convenience: io.Reader returns (n int, err error), and n is often checked + // before err is checked. To mimic this, Len() should work on nil Buffers. + if b == nil { + return 0 + } + return len(b.ReadOnlyData()) +} + +// Split modifies the receiver to point to the first n bytes while it returns a +// new reference to the remaining bytes. The returned Buffer functions just like +// a normal reference acquired using Ref(). +func (b *Buffer) Split(n int) *Buffer { + if b.freed { + panic("Cannot split freed buffer") + } + + b.refs.Add(1) + + split := &Buffer{ + refs: b.refs, + free: b.free, + } + + b.data, split.data = b.data[:n], b.data[n:] + + return split +} + +// String returns a string representation of the buffer. May be used for +// debugging purposes. +func (b *Buffer) String() string { + return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData())) +} diff --git a/mem/buffers_test.go b/mem/buffers_test.go new file mode 100644 index 000000000..b761fda79 --- /dev/null +++ b/mem/buffers_test.go @@ -0,0 +1,293 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package mem_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/mem" +) + +const ( + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 100 * time.Millisecond +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// Tests that a buffer created with NewBuffer, which when later freed, invokes +// the free function with the correct data. +func (s) TestBuffer_NewBufferAndFree(t *testing.T) { + data := "abcd" + errCh := make(chan error, 1) + freeF := func(got []byte) { + if !bytes.Equal(got, []byte(data)) { + errCh <- fmt.Errorf("Free function called with bytes %s, want %s", string(got), string(data)) + return + } + errCh <- nil + } + + buf := mem.NewBuffer([]byte(data), freeF) + if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("Buffer contains data %s, want %s", string(got), string(data)) + } + + // Verify that the free function is invoked when all references are freed. + buf.Free() + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(defaultTestTimeout): + t.Fatalf("Timeout waiting for Buffer to be freed") + } +} + +// Tests that a buffer created with NewBuffer, on which an additional reference +// is acquired, which when later freed, invokes the free function with the +// correct data, but only after all references are released. +func (s) TestBuffer_NewBufferRefAndFree(t *testing.T) { + data := "abcd" + errCh := make(chan error, 1) + freeF := func(got []byte) { + if !bytes.Equal(got, []byte(data)) { + errCh <- fmt.Errorf("Free function called with bytes %s, want %s", string(got), string(data)) + return + } + errCh <- nil + } + + buf := mem.NewBuffer([]byte(data), freeF) + if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("Buffer contains data %s, want %s", string(got), string(data)) + } + + bufRef := buf.Ref() + if got := bufRef.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data)) + } + + // Verify that the free function is not invoked when all references are yet + // to be freed. + buf.Free() + select { + case <-errCh: + t.Fatalf("Free function called before all references freed") + case <-time.After(defaultTestShortTimeout): + } + + // Verify that the free function is invoked when all references are freed. + bufRef.Free() + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(defaultTestTimeout): + t.Fatalf("Timeout waiting for Buffer to be freed") + } +} + +// testBufferPool is a buffer pool that makes new buffer without pooling, and +// notifies on a channel that a buffer was returned to the pool. +type testBufferPool struct { + putCh chan []byte +} + +func (t *testBufferPool) Get(length int) []byte { + return make([]byte, length) +} + +func (t *testBufferPool) Put(data []byte) { + t.putCh <- data +} + +func newTestBufferPool() *testBufferPool { + return &testBufferPool{putCh: make(chan []byte, 1)} +} + +// Tests that a buffer created with Copy, which when later freed, returns the underlying +// byte slice to the buffer pool. +func (s) TestBufer_CopyAndFree(t *testing.T) { + data := "abcd" + testPool := newTestBufferPool() + + buf := mem.Copy([]byte(data), testPool) + if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("Buffer contains data %s, want %s", string(got), string(data)) + } + + // Verify that the free function is invoked when all references are freed. + buf.Free() + select { + case got := <-testPool.putCh: + if !bytes.Equal(got, []byte(data)) { + t.Fatalf("Free function called with bytes %s, want %s", string(got), string(data)) + } + case <-time.After(defaultTestTimeout): + t.Fatalf("Timeout waiting for Buffer to be freed") + } +} + +// Tests that a buffer created with Copy, on which an additional reference is +// acquired, which when later freed, returns the underlying byte slice to the +// buffer pool. +func (s) TestBuffer_CopyRefAndFree(t *testing.T) { + data := "abcd" + testPool := newTestBufferPool() + + buf := mem.Copy([]byte(data), testPool) + if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("Buffer contains data %s, want %s", string(got), string(data)) + } + + bufRef := buf.Ref() + if got := bufRef.ReadOnlyData(); !bytes.Equal(got, []byte(data)) { + t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data)) + } + + // Verify that the free function is not invoked when all references are yet + // to be freed. + buf.Free() + select { + case <-testPool.putCh: + t.Fatalf("Free function called before all references freed") + case <-time.After(defaultTestShortTimeout): + } + + // Verify that the free function is invoked when all references are freed. + bufRef.Free() + select { + case got := <-testPool.putCh: + if !bytes.Equal(got, []byte(data)) { + t.Fatalf("Free function called with bytes %s, want %s", string(got), string(data)) + } + case <-time.After(defaultTestTimeout): + t.Fatalf("Timeout waiting for Buffer to be freed") + } +} + +func (s) TestBuffer_Split(t *testing.T) { + ready := false + freed := false + data := []byte{1, 2, 3, 4} + buf := mem.NewBuffer(data, func(bytes []byte) { + if !ready { + t.Fatalf("Freed too early") + } + freed = true + }) + checkBufData := func(b *mem.Buffer, expected []byte) { + if !bytes.Equal(b.ReadOnlyData(), expected) { + t.Fatalf("Buffer did not contain expected data %v, got %v", expected, b.ReadOnlyData()) + } + } + + // Take a ref of the original buffer + ref1 := buf.Ref() + + split1 := buf.Split(2) + checkBufData(buf, data[:2]) + checkBufData(split1, data[2:]) + // Check that even though buf was split, the reference wasn't modified + checkBufData(ref1, data) + ref1.Free() + + // Check that splitting the buffer more than once works as intended. + split2 := split1.Split(1) + checkBufData(split1, data[2:3]) + checkBufData(split2, data[3:]) + + // If any of the following frees actually free the buffer, the test will fail. + buf.Free() + split2.Free() + + ready = true + split1.Free() + + if !freed { + t.Fatalf("Buffer never freed") + } +} + +func checkForPanic(t *testing.T, wantErr string) { + t.Helper() + r := recover() + if r == nil { + t.Fatalf("Use after free dit not panic") + } + if r.(string) != wantErr { + t.Fatalf("panic called with %v, want %s", r, wantErr) + } +} + +func (s) TestBuffer_ReadOnlyDataAfterFree(t *testing.T) { + // Verify that reading before freeing does not panic. + buf := mem.NewBuffer([]byte("abcd"), nil) + buf.ReadOnlyData() + + buf.Free() + defer checkForPanic(t, "Cannot read freed buffer") + buf.ReadOnlyData() +} + +func (s) TestBuffer_RefAfterFree(t *testing.T) { + // Verify that acquiring a ref before freeing does not panic. + buf := mem.NewBuffer([]byte("abcd"), nil) + bufRef := buf.Ref() + defer bufRef.Free() + + buf.Free() + defer checkForPanic(t, "Cannot ref freed buffer") + buf.Ref() +} + +func (s) TestBuffer_SplitAfterFree(t *testing.T) { + // Verify that splitting before freeing does not panic. + buf := mem.NewBuffer([]byte("abcd"), nil) + bufSplit := buf.Split(2) + defer bufSplit.Free() + + buf.Free() + defer checkForPanic(t, "Cannot split freed buffer") + buf.Split(1) +} + +func (s) TestBuffer_FreeAfterFree(t *testing.T) { + buf := mem.NewBuffer([]byte("abcd"), nil) + if buf.Len() != 4 { + t.Fatalf("Buffer length is %d, want 4", buf.Len()) + } + + // Ensure that a double free does not panic. + buf.Free() + buf.Free() +}