[KMSv2] store hash of encrypted DEK as key in cache
Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com> Kubernetes-commit: f72cf5c510cf2cf7b8ee375f5c2ec835e3ed225a
This commit is contained in:
		
							parent
							
								
									e3ca625155
								
							
						
					
					
						commit
						421ef770de
					
				| 
						 | 
					@ -18,8 +18,11 @@ limitations under the License.
 | 
				
			||||||
package kmsv2
 | 
					package kmsv2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/base64"
 | 
						"crypto/sha256"
 | 
				
			||||||
 | 
						"hash"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
						"unsafe"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	utilcache "k8s.io/apimachinery/pkg/util/cache"
 | 
						utilcache "k8s.io/apimachinery/pkg/util/cache"
 | 
				
			||||||
	"k8s.io/apiserver/pkg/storage/value"
 | 
						"k8s.io/apiserver/pkg/storage/value"
 | 
				
			||||||
| 
						 | 
					@ -29,18 +32,26 @@ import (
 | 
				
			||||||
type simpleCache struct {
 | 
					type simpleCache struct {
 | 
				
			||||||
	cache *utilcache.Expiring
 | 
						cache *utilcache.Expiring
 | 
				
			||||||
	ttl   time.Duration
 | 
						ttl   time.Duration
 | 
				
			||||||
 | 
						// hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash)
 | 
				
			||||||
 | 
						// SHA-256 is used to prevent collisions
 | 
				
			||||||
 | 
						hashPool *sync.Pool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
 | 
					func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
 | 
				
			||||||
	return &simpleCache{
 | 
						return &simpleCache{
 | 
				
			||||||
		cache: utilcache.NewExpiringWithClock(clock),
 | 
							cache: utilcache.NewExpiringWithClock(clock),
 | 
				
			||||||
		ttl:   ttl,
 | 
							ttl:   ttl,
 | 
				
			||||||
 | 
							hashPool: &sync.Pool{
 | 
				
			||||||
 | 
								New: func() interface{} {
 | 
				
			||||||
 | 
									return sha256.New()
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// given a key, return the transformer, or nil if it does not exist in the cache
 | 
					// given a key, return the transformer, or nil if it does not exist in the cache
 | 
				
			||||||
func (c *simpleCache) get(key []byte) value.Transformer {
 | 
					func (c *simpleCache) get(key []byte) value.Transformer {
 | 
				
			||||||
	record, ok := c.cache.Get(base64.StdEncoding.EncodeToString(key))
 | 
						record, ok := c.cache.Get(c.keyFunc(key))
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -55,5 +66,25 @@ func (c *simpleCache) set(key []byte, transformer value.Transformer) {
 | 
				
			||||||
	if transformer == nil {
 | 
						if transformer == nil {
 | 
				
			||||||
		panic("transformer must not be nil")
 | 
							panic("transformer must not be nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.cache.Set(base64.StdEncoding.EncodeToString(key), transformer, c.ttl)
 | 
						c.cache.Set(c.keyFunc(key), transformer, c.ttl)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// keyFunc generates a string key by hashing the inputs.
 | 
				
			||||||
 | 
					// This lowers the memory requirement of the cache.
 | 
				
			||||||
 | 
					func (c *simpleCache) keyFunc(s []byte) string {
 | 
				
			||||||
 | 
						h := c.hashPool.Get().(hash.Hash)
 | 
				
			||||||
 | 
						h.Reset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := h.Write(s); err != nil {
 | 
				
			||||||
 | 
							panic(err) // Write() on hash never fails
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation
 | 
				
			||||||
 | 
						c.hashPool.Put(h)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return key
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// toString performs unholy acts to avoid allocations
 | 
				
			||||||
 | 
					func toString(b []byte) string {
 | 
				
			||||||
 | 
						return *(*string)(unsafe.Pointer(&b))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,9 @@ limitations under the License.
 | 
				
			||||||
package kmsv2
 | 
					package kmsv2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/sha256"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -57,3 +60,75 @@ func TestSimpleCacheSetError(t *testing.T) {
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestKeyFunc(t *testing.T) {
 | 
				
			||||||
 | 
						fakeClock := testingclock.NewFakeClock(time.Now())
 | 
				
			||||||
 | 
						cache := newSimpleCache(fakeClock, time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Run("AllocsPerRun test", func(t *testing.T) {
 | 
				
			||||||
 | 
							key, err := generateKey(encryptedDEKMaxSize) // simulate worst case EDEK
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							f := func() {
 | 
				
			||||||
 | 
								out := cache.keyFunc(key)
 | 
				
			||||||
 | 
								if len(out) != sha256.Size {
 | 
				
			||||||
 | 
									t.Errorf("Expected %d bytes, got %d", sha256.Size, len(out))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// prime the key func
 | 
				
			||||||
 | 
							var wg sync.WaitGroup
 | 
				
			||||||
 | 
							for i := 0; i < 100; i++ {
 | 
				
			||||||
 | 
								wg.Add(1)
 | 
				
			||||||
 | 
								go func() {
 | 
				
			||||||
 | 
									f()
 | 
				
			||||||
 | 
									wg.Done()
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							wg.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							allocs := testing.AllocsPerRun(100, f)
 | 
				
			||||||
 | 
							if allocs > 1 {
 | 
				
			||||||
 | 
								t.Errorf("Expected 1 allocations, got %v", allocs)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSimpleCache(t *testing.T) {
 | 
				
			||||||
 | 
						fakeClock := testingclock.NewFakeClock(time.Now())
 | 
				
			||||||
 | 
						cache := newSimpleCache(fakeClock, 5*time.Second)
 | 
				
			||||||
 | 
						envelopeTransformer := &envelopeTransformer{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
 | 
							k := fmt.Sprintf("key-%d", i)
 | 
				
			||||||
 | 
							wg.Add(1)
 | 
				
			||||||
 | 
							go func(key string) {
 | 
				
			||||||
 | 
								defer wg.Done()
 | 
				
			||||||
 | 
								cache.set([]byte(key), envelopeTransformer)
 | 
				
			||||||
 | 
							}(k)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if cache.cache.Len() != 10 {
 | 
				
			||||||
 | 
							t.Fatalf("Expected 10 items in the cache, got %v", cache.cache.Len())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
 | 
							k := fmt.Sprintf("key-%d", i)
 | 
				
			||||||
 | 
							if cache.get([]byte(k)) != envelopeTransformer {
 | 
				
			||||||
 | 
								t.Fatalf("Expected to get the transformer for key %v", k)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Wait for the cache to expire
 | 
				
			||||||
 | 
						fakeClock.Step(6 * time.Second)
 | 
				
			||||||
 | 
						for i := 0; i < 10; i++ {
 | 
				
			||||||
 | 
							k := fmt.Sprintf("key-%d", i)
 | 
				
			||||||
 | 
							if cache.get([]byte(k)) != nil {
 | 
				
			||||||
 | 
								t.Fatalf("Expected to get nil for key %v", k)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue