// Patch overview (unified git diff touching src/cmd/initContainer.go,
// src/cmd/run.go and src/pkg/utils/utils.go).
//
// NOTE(review): the diff below appears to have had its line structure
// collapsed (many diff lines per physical line; tabs and blank context lines
// are not visible). It is reproduced unchanged; verify against the original
// patch before trying to apply it.
//
// What the patch adds, as far as the visible hunks show:
//
// src/pkg/utils/utils.go
//   - ErrFlockRelease: new sentinel error, "failed to release lock".
//   - GetReferenceCountGlobalLock(targetUser): returns the path
//     "container-reference-count.lock" joined onto the directory returned by
//     GetRuntimeDirectory(targetUser), or that function's error.
//
// src/cmd/initContainer.go
//   - Imports "context".
//   - initContainer: resolves the global lock path, then tries a non-blocking
//     exclusive flock (LOCK_EX|LOCK_NB) on it. If that succeeds, waitForRun is
//     set to true and the lock file is closed at once (a Close failure returns
//     utils.ErrFlockRelease). If the attempt fails for any reason, waitForRun
//     stays false and the error is neither logged nor returned.
//     It then creates a context.WithCancelCause context, defers a cancel with
//     cause "clean-up", starts detectWhenContainerIsUnsedAsync, and adds a
//     `case <-done` arm to the existing select loop. That arm removes the
//     initialization stamp (returning "failed to remove initialization stamp"
//     if os.Remove fails), logs context.Cause, and returns nil.
//   - detectWhenContainerIsUnsedAsync: starts a goroutine and returns
//     immediately; nothing is returned to wait on it. When waitForRun is true
//     it first logs that the entry point was not started by 'toolbox enter' or
//     'toolbox run' and sleeps 25 seconds. It then loops:
//       1. waitForExecToBegin; on ErrFlockRelease it cancels with that error,
//          on any other error it only logs, and in both cases the goroutine
//          ends.
//       2. waitForContainerToBeUnused; syscall.EWOULDBLOCK restarts the loop,
//          ErrFlockRelease cancels with that error, any other error only logs
//          and ends the goroutine, and success cancels with cause
//          "all 'podman exec' sessions exited".
//     NOTE(review): "Unsed" in the function name looks like a typo for
//     "Unused"; the call site in initContainer uses the same spelling.
//   - waitForExecToBegin: takes a blocking exclusive flock (LOCK_EX) on the
//     global lock and closes it immediately; returns the Flock error, or
//     utils.ErrFlockRelease if Close fails.
//   - waitForContainerToBeUnused: takes a blocking exclusive flock on
//     initializedStamp (panicking if that call reports EWOULDBLOCK), then a
//     non-blocking exclusive flock on the global lock. If the second flock
//     fails, the first lock file is closed and the flock error is returned
//     (or utils.ErrFlockRelease if that Close fails). On success it returns
//     nil without closing either lock file; the second file handle is
//     discarded with `_`.
//     NOTE(review): presumably both locks are meant to stay held until the
//     process exits; whether an unreferenced handle returned by utils.Flock
//     can be closed behind the caller's back (e.g. by a finalizer) depends on
//     its return type, which is not visible here -- confirm.
//
// src/cmd/run.go
//   - runCommand: before inspecting the container, takes a shared flock
//     (LOCK_SH) on the global lock and defers its Close. After
//     ensureContainerIsInitialized succeeds, takes a shared flock on
//     initializedStamp, defers its Close, and then explicitly closes the
//     global lock (returning utils.ErrFlockRelease on failure) before running
//     the command. For both Flock calls, an error that matches
//     *utils.FlockError is reduced to utils.ErrFlockAcquire or
//     utils.ErrFlockCreate, and any other *utils.FlockError panics.
//     NOTE(review): the global lock file is closed explicitly and also by the
//     earlier defer, so the deferred Close runs on an already-closed file and
//     its result is ignored.
//     NOTE(review): these hunks use syscall, errors and fmt, but no import
//     hunk for run.go is visible -- presumably already imported; confirm.
//   - ensureContainerIsInitialized: now receives the initializedStamp path
//     instead of entryPointPID; the utils.GetInitializedStamp call moved into
//     runCommand.
diff --git a/src/cmd/initContainer.go b/src/cmd/initContainer.go index cac4ddf..0ead557 100644 --- a/src/cmd/initContainer.go +++ b/src/cmd/initContainer.go @@ -17,6 +17,7 @@ package cmd import ( + "context" "errors" "fmt" "io/ioutil" @@ -353,6 +354,28 @@ func initContainer(cmd *cobra.Command, args []string) error { return err } + referenceCountGlobalLock, err := utils.GetReferenceCountGlobalLock(targetUser) + if err != nil { + return err + } + + var waitForRun bool + if referenceCountGlobalLockFile, err := utils.Flock(referenceCountGlobalLock, + syscall.LOCK_EX|syscall.LOCK_NB); err == nil { + waitForRun = true + if err := referenceCountGlobalLockFile.Close(); err != nil { + logrus.Debugf("Releasing global reference count lock: %s", err) + return utils.ErrFlockRelease + } + } + + parentCtx := context.Background() + waitForExitCtx, waitForExitCancel := context.WithCancelCause(parentCtx) + defer waitForExitCancel(errors.New("clean-up")) + + detectWhenContainerIsUnsedAsync(waitForExitCancel, initializedStamp, referenceCountGlobalLock, waitForRun) + done := waitForExitCtx.Done() + logrus.Debugf("Creating initialization stamp %s", initializedStamp) initializedStampFile, err := os.Create(initializedStamp) @@ -372,6 +395,16 @@ func initContainer(cmd *cobra.Command, args []string) error { for { select { + case <-done: + logrus.Debugf("Removing initialization stamp %s", initializedStamp) + if err := os.Remove(initializedStamp); err != nil { + logrus.Debugf("Removing initialization stamp %s failed: %s", initializedStamp, err) + return errors.New("failed to remove initialization stamp") + } + + cause := context.Cause(waitForExitCtx) + logrus.Debugf("Exiting entry point: %s", cause) + return nil case event := <-tickerDaily.C: handleDailyTick(event) case event := <-watcherForHostEvents: @@ -788,6 +821,55 @@ func createSymbolicLink(existingTarget, newLink string) error { return nil } +func detectWhenContainerIsUnsedAsync(cancel context.CancelCauseFunc, + initializedStamp, 
referenceCountGlobalLock string, + waitForRun bool) { + + go func() { + if waitForRun { + logrus.Debugf("This entry point was not started by 'toolbox enter' or 'toolbox run'") + logrus.Debugf("Waiting for 'toolbox enter' or 'toolbox run'") + time.Sleep(25 * time.Second) + } + + for { + logrus.Debugf("Waiting for 'podman exec' to begin") + if err := waitForExecToBegin(referenceCountGlobalLock); err != nil { + if errors.Is(err, utils.ErrFlockRelease) { + cancel(err) + } else { + logrus.Debugf("Waiting for 'podman exec' to begin: %s", err) + logrus.Debug("This entry point will not exit when the container is unused") + } + + return + } + + logrus.Debugf("Waiting for the container to be unused") + if err := waitForContainerToBeUnused(initializedStamp, + referenceCountGlobalLock); err != nil { + if errors.Is(err, syscall.EWOULDBLOCK) { + logrus.Debug("Detected potentially new use of the container") + continue + } + + if errors.Is(err, utils.ErrFlockRelease) { + cancel(err) + } else { + logrus.Debugf("Waiting for the container to be unused: %s", err) + logrus.Debug("This entry point will not exit when the container is unused") + } + + return + } + + cause := errors.New("all 'podman exec' sessions exited") + cancel(cause) + return + } + }() +} + func getDelayEntryPoint() (time.Duration, bool) { valueString := os.Getenv("TOOLBX_DELAY_ENTRY_POINT") if valueString == "" { @@ -1123,6 +1205,43 @@ func updateTimeZoneFromLocalTime() error { return nil } +func waitForExecToBegin(referenceCountGlobalLock string) error { + referenceCountGlobalLockFile, err := utils.Flock(referenceCountGlobalLock, syscall.LOCK_EX) + if err != nil { + return err + } + + if err := referenceCountGlobalLockFile.Close(); err != nil { + logrus.Debugf("Releasing global reference count lock: %s", err) + return utils.ErrFlockRelease + } + + return nil +} + +func waitForContainerToBeUnused(initializedStamp, referenceCountGlobalLock string) error { + referenceCountLocalLockFile, err := 
utils.Flock(initializedStamp, syscall.LOCK_EX) + if err != nil { + if errors.Is(err, syscall.EWOULDBLOCK) { + panicMsg := fmt.Sprintf("unexpected %T: %s", err, err) + panic(panicMsg) + } + + return err + } + + if _, err := utils.Flock(referenceCountGlobalLock, syscall.LOCK_EX|syscall.LOCK_NB); err != nil { + if err := referenceCountLocalLockFile.Close(); err != nil { + logrus.Debugf("Releasing local reference count lock: %s", err) + return utils.ErrFlockRelease + } + + return err + } + + return nil +} + func writeTimeZone(timeZone string) error { const etcTimeZone = "/etc/timezone" diff --git a/src/cmd/run.go b/src/cmd/run.go index 389ea16..dde645a 100644 --- a/src/cmd/run.go +++ b/src/cmd/run.go @@ -243,6 +243,35 @@ func runCommand(container string, } } + logrus.Debug("Acquiring global reference count lock") + + referenceCountGlobalLock, err := utils.GetReferenceCountGlobalLock(currentUser) + if err != nil { + return err + } + + referenceCountGlobalLockFile, err := utils.Flock(referenceCountGlobalLock, syscall.LOCK_SH) + if err != nil { + logrus.Debugf("Acquiring global reference count lock: %s", err) + + var errFlock *utils.FlockError + + if errors.As(err, &errFlock) { + if errors.Is(err, utils.ErrFlockAcquire) { + err = utils.ErrFlockAcquire + } else if errors.Is(err, utils.ErrFlockCreate) { + err = utils.ErrFlockCreate + } else { + panicMsg := fmt.Sprintf("unexpected %T: %s", err, err) + panic(panicMsg) + } + } + + return err + } + + defer referenceCountGlobalLockFile.Close() + logrus.Debugf("Inspecting container %s", container) containerObj, err := podman.InspectContainer(container) if err != nil { @@ -335,11 +364,45 @@ func runCommand(container string, logrus.Debugf("Waiting for container %s to finish initializing", container) } - if err := ensureContainerIsInitialized(container, entryPointPID, startContainerTimestamp); err != nil { + initializedStamp, err := utils.GetInitializedStamp(entryPointPID, currentUser) + if err != nil { + return err + } + + if err 
:= ensureContainerIsInitialized(container, initializedStamp, startContainerTimestamp); err != nil { return err } logrus.Debugf("Container %s is initialized", container) + logrus.Debug("Acquiring local reference count lock") + + referenceCountLocalLockFile, err := utils.Flock(initializedStamp, syscall.LOCK_SH) + if err != nil { + logrus.Debugf("Acquiring local reference count lock: %s", err) + + var errFlock *utils.FlockError + + if errors.As(err, &errFlock) { + if errors.Is(err, utils.ErrFlockAcquire) { + err = utils.ErrFlockAcquire + } else if errors.Is(err, utils.ErrFlockCreate) { + err = utils.ErrFlockCreate + } else { + panicMsg := fmt.Sprintf("unexpected %T: %s", err, err) + panic(panicMsg) + } + } + + return err + } + + defer referenceCountLocalLockFile.Close() + + logrus.Debug("Releasing global reference count lock") + if err := referenceCountGlobalLockFile.Close(); err != nil { + logrus.Debugf("Releasing global reference count lock: %s", err) + return utils.ErrFlockRelease + } environ := append(cdiEnviron, p11KitServerEnviron...) 
if err := runCommandWithFallbacks(container, @@ -598,12 +661,7 @@ func constructExecArgs(container, preserveFDs string, return execArgs } -func ensureContainerIsInitialized(container string, entryPointPID int, timestamp time.Time) error { - initializedStamp, err := utils.GetInitializedStamp(entryPointPID, currentUser) - if err != nil { - return err - } - +func ensureContainerIsInitialized(container, initializedStamp string, timestamp time.Time) error { logrus.Debugf("Checking if initialization stamp %s exists", initializedStamp) shouldUsePolling := isUsePollingSet() diff --git a/src/pkg/utils/utils.go b/src/pkg/utils/utils.go index 627bdeb..10481fd 100644 --- a/src/pkg/utils/utils.go +++ b/src/pkg/utils/utils.go @@ -184,6 +184,8 @@ var ( ErrFlockCreate = errors.New("failed to create lock file") + ErrFlockRelease = errors.New("failed to release lock") + ErrImageWithoutBasename = errors.New("image does not have a basename") ) @@ -498,6 +500,16 @@ func GetP11KitServerSocketLock(targetUser *user.User) (string, error) { return p11KitServerSocketLock, nil } +func GetReferenceCountGlobalLock(targetUser *user.User) (string, error) { + toolbxRuntimeDirectory, err := GetRuntimeDirectory(targetUser) + if err != nil { + return "", err + } + + referenceCountGlobalLock := filepath.Join(toolbxRuntimeDirectory, "container-reference-count.lock") + return referenceCountGlobalLock, nil +} + func GetRuntimeDirectory(targetUser *user.User) (string, error) { if runtimeDirectories == nil { runtimeDirectories = make(map[string]string)