diff --git a/libmachine/drivers/rpc/server_driver.go b/libmachine/drivers/rpc/server_driver.go index 0fd3df5150..6fb3f3d5d8 100644 --- a/libmachine/drivers/rpc/server_driver.go +++ b/libmachine/drivers/rpc/server_driver.go @@ -3,6 +3,8 @@ package rpcdriver import ( "encoding/gob" "encoding/json" + "fmt" + "runtime/debug" "github.com/docker/machine/libmachine/drivers" "github.com/docker/machine/libmachine/log" @@ -11,6 +13,20 @@ import ( "github.com/docker/machine/libmachine/version" ) +type Stacker interface { + Stack() []byte +} + +type StandardStack struct{} + +func (ss *StandardStack) Stack() []byte { + return debug.Stack() +} + +var ( + stdStacker Stacker = &StandardStack{} +) + func init() { gob.Register(new(RPCFlags)) gob.Register(new(mcnflag.IntFlag)) @@ -108,8 +124,22 @@ func (r *RPCServerDriver) SetConfigRaw(data []byte, _ *struct{}) error { return json.Unmarshal(data, &r.ActualDriver) } -func (r *RPCServerDriver) Create(_, _ *struct{}) error { - return r.ActualDriver.Create() +func trapPanic(err *error) { + if r := recover(); r != nil { + *err = fmt.Errorf("Panic in the driver: %s\n%s", r.(error), stdStacker.Stack()) + } +} + +func (r *RPCServerDriver) Create(_, _ *struct{}) (err error) { + // In an ideal world, plugins wouldn't ever panic. However, panics + // have been known to happen and cause issues. Therefore, we recover + // and do not crash the RPC server completely in the case of a panic + // during create. + defer trapPanic(&err) + + err = r.ActualDriver.Create() + + return err } func (r *RPCServerDriver) DriverName(_ *struct{}, reply *string) error { diff --git a/libmachine/drivers/rpc/server_driver_test.go b/libmachine/drivers/rpc/server_driver_test.go new file mode 100644 index 0000000000..952684df39 --- /dev/null +++ b/libmachine/drivers/rpc/server_driver_test.go @@ -0,0 +1,75 @@ +package rpcdriver + +import ( + "errors" + "testing" + + "github.com/docker/machine/drivers/fakedriver" + "github.com/stretchr/testify/assert" +) + +type panicDriver struct { + *fakedriver.Driver + panicErr error + returnErr error +} + +type FakeStacker struct { + trace []byte +} + +func (fs *FakeStacker) Stack() []byte { + return fs.trace +} + +func (p *panicDriver) Create() error { + if p.panicErr != nil { + panic(p.panicErr) + } + return p.returnErr +} + +func TestRPCServerDriverCreate(t *testing.T) { + testCases := []struct { + description string + expectedErr error + serverDriver *RPCServerDriver + stacker Stacker + }{ + { + description: "Happy path", + expectedErr: nil, + serverDriver: &RPCServerDriver{ + ActualDriver: &panicDriver{ + returnErr: nil, + }, + }, + }, + { + description: "Normal error, no panic", + expectedErr: errors.New("API not available"), + serverDriver: &RPCServerDriver{ + ActualDriver: &panicDriver{ + returnErr: errors.New("API not available"), + }, + }, + }, + { + description: "Panic happened during create", + expectedErr: errors.New("Panic in the driver: index out of range\nSTACK TRACE"), + serverDriver: &RPCServerDriver{ + ActualDriver: &panicDriver{ + panicErr: errors.New("index out of range"), + }, + }, + stacker: &FakeStacker{ + trace: []byte("STACK TRACE"), + }, + }, + } + + for _, tc := range testCases { + stdStacker = tc.stacker + assert.Equal(t, tc.expectedErr, tc.serverDriver.Create(nil, nil)) + } +}