188 lines
5.6 KiB
Go
188 lines
5.6 KiB
Go
/*
|
|
* Copyright 2024 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package ratelimiter
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
logger "d7y.io/dragonfly/v2/internal/dflog"
|
|
"d7y.io/dragonfly/v2/manager/config"
|
|
"d7y.io/dragonfly/v2/manager/database"
|
|
"d7y.io/dragonfly/v2/manager/models"
|
|
"d7y.io/dragonfly/v2/manager/types"
|
|
)
|
|
|
|
const (
	// defaultRefreshInterval is how often the rate limiters are rebuilt
	// when no other interval is configured.
	defaultRefreshInterval = 3 * time.Minute

	// jobRateLimiterSuffix is appended to the cluster ID to form the
	// job rate limiter key.
	jobRateLimiterSuffix = "job"
)
|
|
|
|
// JobRateLimiter limits the rate of jobs per scheduler cluster and exposes
// the lifecycle hooks of its background server.
type JobRateLimiter interface {
	// AllowByClusterID reports whether a request for the given cluster ID
	// is permitted by that cluster's rate limit.
	AllowByClusterID(ctx context.Context, clusterID uint) bool

	// AllowByClusterIDs reports whether a request is permitted for every
	// one of the given cluster IDs. It returns false as soon as any
	// cluster ID is not allowed.
	AllowByClusterIDs(ctx context.Context, clusterIDs []uint) bool

	// Serve starts the job rate limiter server.
	Serve()

	// Stop stops the job rate limiter server.
	Stop()
}
|
|
|
|
// jobRateLimiter is an implementation of JobRateLimiter.
|
|
type jobRateLimiter struct {
|
|
// database used to store the rate limit.
|
|
database *database.Database
|
|
|
|
// clusters is a map of rate limiters for each cluster.
|
|
clusters *sync.Map
|
|
|
|
// refreshInterval is the interval to refresh the rate limiters.
|
|
refreshInterval time.Duration
|
|
|
|
// done is the channel to stop the rate limiter server.
|
|
done chan struct{}
|
|
}
|
|
|
|
// NewJobRateLimiter creates a new instance of JobRateLimiter.
|
|
func NewJobRateLimiter(database *database.Database) (JobRateLimiter, error) {
|
|
j := &jobRateLimiter{
|
|
database: database,
|
|
clusters: &sync.Map{},
|
|
refreshInterval: defaultRefreshInterval,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
if err := j.refresh(context.Background()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return j, nil
|
|
}
|
|
|
|
// AllowByClusterID checks if a request is allowed based on the rate limit for a specific cluster ID.
|
|
func (j *jobRateLimiter) AllowByClusterID(ctx context.Context, clusterID uint) bool {
|
|
rawLimiter, loaded := j.clusters.Load(clusterID)
|
|
if !loaded {
|
|
logger.Errorf("[job-rate-limiter]: cluster %d not found", clusterID)
|
|
return false
|
|
}
|
|
|
|
limiter, ok := rawLimiter.(DistributedRateLimiter)
|
|
if !ok {
|
|
logger.Errorf("[job-rate-limiter]: cluster %d is not a distributed rate limiter", clusterID)
|
|
return false
|
|
}
|
|
|
|
result, err := limiter.Allow(ctx)
|
|
if err != nil {
|
|
logger.Errorf("[job-rate-limiter]: cluster %d allow failed: %v", clusterID, err)
|
|
return false
|
|
}
|
|
|
|
if result.Allowed == 0 {
|
|
logger.Errorf("[job-rate-limiter]: cluster %d rate limit exceeded", clusterID)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// AllowByClusterIDs checks if a request is allowed based on the rate limit for multiple cluster IDs.
|
|
// If any cluster ID is not allowed, it returns false.
|
|
func (j *jobRateLimiter) AllowByClusterIDs(ctx context.Context, clusterIDs []uint) bool {
|
|
for _, clusterID := range clusterIDs {
|
|
if allowed := j.AllowByClusterID(ctx, clusterID); !allowed {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Serve started rate limiter server.
|
|
func (j *jobRateLimiter) Serve() {
|
|
tick := time.NewTicker(j.refreshInterval)
|
|
for {
|
|
select {
|
|
case <-tick.C:
|
|
logger.Infof("[job-rate-limiter]: refresh job rate limiter started")
|
|
if err := j.refresh(context.Background()); err != nil {
|
|
logger.Errorf("[job-rate-limiter]: refresh job rate limiter failed: %v", err)
|
|
}
|
|
case <-j.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop rate limiter server.
|
|
func (j *jobRateLimiter) Stop() {
|
|
close(j.done)
|
|
}
|
|
|
|
// refresh refreshes the rate limiters for all scheduler clusters.
|
|
func (j *jobRateLimiter) refresh(ctx context.Context) error {
|
|
var schedulerClusters []models.SchedulerCluster
|
|
if err := j.database.DB.WithContext(ctx).Find(&schedulerClusters).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
j.clusters.Clear()
|
|
for _, schedulerCluster := range schedulerClusters {
|
|
b, err := schedulerCluster.Config.MarshalJSON()
|
|
if err != nil {
|
|
logger.Errorf("[job-rate-limiter]: marshal scheduler cluster %d config failed: %v", schedulerCluster.ID, err)
|
|
return err
|
|
}
|
|
|
|
var schedulerClusterConfig types.SchedulerClusterConfig
|
|
if err := json.Unmarshal(b, &schedulerClusterConfig); err != nil {
|
|
logger.Errorf("[job-rate-limiter]: unmarshal scheduler cluster %d config failed: %v", schedulerCluster.ID, err)
|
|
return err
|
|
}
|
|
|
|
// Use the default rate limit if the rate limit is not set.
|
|
jobRateLimit := config.DefaultClusterJobRateLimit
|
|
if schedulerClusterConfig.JobRateLimit != 0 {
|
|
jobRateLimit = schedulerClusterConfig.JobRateLimit
|
|
}
|
|
|
|
logger.Debugf("[job-rate-limiter]: create job rate limiter for scheduler cluster %d with rate limit %d", schedulerCluster.ID, jobRateLimit)
|
|
j.clusters.Store(schedulerCluster.ID,
|
|
NewDistributedRateLimiter(j.database.RDB, j.key(schedulerCluster.ID), jobRateLimit))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// key is the rate limiter key for storing value in the database.
|
|
func (j *jobRateLimiter) key(clusterID uint) string {
|
|
return fmt.Sprintf("%d-%s", clusterID, jobRateLimiterSuffix)
|
|
}
|