diff --git a/client/daemon/storage/keepalive.go b/client/daemon/storage/keepalive.go new file mode 100644 index 000000000..e7d4df4b4 --- /dev/null +++ b/client/daemon/storage/keepalive.go @@ -0,0 +1,76 @@ +/* + * Copyright 2024 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package storage + +import ( + "context" + "io" + + commonv1 "d7y.io/api/v2/pkg/apis/common/v1" + + "d7y.io/dragonfly/v2/client/util" +) + +type keepAliveTaskStorageDriver struct { + TaskStorageDriver + util.KeepAlive +} + +func (k *keepAliveTaskStorageDriver) WritePiece(ctx context.Context, req *WritePieceRequest) (int64, error) { + k.Keep() + return k.TaskStorageDriver.WritePiece(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) ReadPiece(ctx context.Context, req *ReadPieceRequest) (io.Reader, io.Closer, error) { + k.Keep() + return k.TaskStorageDriver.ReadPiece(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) ReadAllPieces(ctx context.Context, req *ReadAllPiecesRequest) (io.ReadCloser, error) { + k.Keep() + return k.TaskStorageDriver.ReadAllPieces(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) GetPieces(ctx context.Context, req *commonv1.PieceTaskRequest) (*commonv1.PiecePacket, error) { + k.Keep() + return k.TaskStorageDriver.GetPieces(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) GetTotalPieces(ctx context.Context, req *PeerTaskMetadata) (int32, error) { + k.Keep() + return k.TaskStorageDriver.GetTotalPieces(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) GetExtendAttribute(ctx context.Context, req *PeerTaskMetadata) (*commonv1.ExtendAttribute, error) { + k.Keep() + return k.TaskStorageDriver.GetExtendAttribute(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) UpdateTask(ctx context.Context, req *UpdateTaskRequest) error { + k.Keep() + return k.TaskStorageDriver.UpdateTask(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) Store(ctx context.Context, req *StoreRequest) error { + k.Keep() + return k.TaskStorageDriver.Store(ctx, req) +} + +func (k *keepAliveTaskStorageDriver) ValidateDigest(req *PeerTaskMetadata) error { + k.Keep() + return k.TaskStorageDriver.ValidateDigest(req) +} diff --git a/client/daemon/storage/storage_manager.go b/client/daemon/storage/storage_manager.go index 52e00b468..2ae420b13 100644 --- a/client/daemon/storage/storage_manager.go +++ b/client/daemon/storage/storage_manager.go @@ -257,7 +257,7 @@ func (s *storageManager) RegisterTask(ctx context.Context, req *RegisterTaskRequ TaskID: req.TaskID, }) if ok { - return ts, nil + return s.keepAliveTaskStorageDriver(ts), nil } // double check if task store exists // if ok, just unlock and return @@ -268,10 +268,14 @@ func (s *storageManager) RegisterTask(ctx context.Context, req *RegisterTaskRequ PeerID: req.PeerID, TaskID: req.TaskID, }); ok { - return ts, nil + return s.keepAliveTaskStorageDriver(ts), nil } // still not exist, create a new task store - return s.CreateTask(req) + ts, err := s.CreateTask(req) + if err != nil { + return nil, err + } + return s.keepAliveTaskStorageDriver(ts), err } func (s *storageManager) RegisterSubTask(ctx context.Context, req *RegisterSubTaskRequest) (TaskStorageDriver, error) { @@ -301,7 +305,7 @@ func (s *storageManager) RegisterSubTask(ctx context.Context, req *RegisterSubTa TaskID: req.SubTask.TaskID, }, subtask) s.Unlock() - return subtask, nil + return s.keepAliveTaskStorageDriver(subtask), nil } func (s *storageManager) WritePiece(ctx context.Context, req *WritePieceRequest) (int64, error) { @@ -1106,3 +1110,10 @@ func (s *storageManager) ListAllPeers(perGroupCount int) [][]*dfdaemonv1.PeerMet } return allPeers } + +func (s *storageManager) keepAliveTaskStorageDriver(ts TaskStorageDriver) TaskStorageDriver { + return &keepAliveTaskStorageDriver{ + KeepAlive: s, + TaskStorageDriver: ts, + } +}