Merge pull request #1506 from huone1/refactor/selectclusters
refactor the selectclusters process
This commit is contained in:
commit
70a08589f5
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/karmada-io/karmada
|
||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/agiledragon/gomonkey/v2 v2.5.0
|
||||||
github.com/distribution/distribution/v3 v3.0.0-20210507173845-9329f6a62b67
|
github.com/distribution/distribution/v3 v3.0.0-20210507173845-9329f6a62b67
|
||||||
github.com/evanphx/json-patch/v5 v5.6.0
|
github.com/evanphx/json-patch/v5 v5.6.0
|
||||||
github.com/gogo/protobuf v1.3.2
|
github.com/gogo/protobuf v1.3.2
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -92,6 +92,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdko
|
||||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
|
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
|
||||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
|
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
|
||||||
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
|
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
|
||||||
|
github.com/agiledragon/gomonkey/v2 v2.5.0 h1:CLygw0ubsk/Gv07g4fhyHPp3YdjdECkofQTS9wg+KVs=
|
||||||
|
github.com/agiledragon/gomonkey/v2 v2.5.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY=
|
||||||
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
|
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
|
|
|
@ -11,10 +11,10 @@ import (
|
||||||
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2"
|
workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2"
|
||||||
"github.com/karmada-io/karmada/pkg/scheduler/cache"
|
"github.com/karmada-io/karmada/pkg/scheduler/cache"
|
||||||
|
"github.com/karmada-io/karmada/pkg/scheduler/core/spreadconstraint"
|
||||||
"github.com/karmada-io/karmada/pkg/scheduler/framework"
|
"github.com/karmada-io/karmada/pkg/scheduler/framework"
|
||||||
"github.com/karmada-io/karmada/pkg/scheduler/framework/runtime"
|
"github.com/karmada-io/karmada/pkg/scheduler/framework/runtime"
|
||||||
"github.com/karmada-io/karmada/pkg/scheduler/metrics"
|
"github.com/karmada-io/karmada/pkg/scheduler/metrics"
|
||||||
"github.com/karmada-io/karmada/pkg/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ScheduleAlgorithm is the interface that should be implemented to schedule a resource to the target clusters.
|
// ScheduleAlgorithm is the interface that should be implemented to schedule a resource to the target clusters.
|
||||||
|
@ -64,7 +64,11 @@ func (g *genericScheduler) Schedule(ctx context.Context, placement *policyv1alph
|
||||||
}
|
}
|
||||||
klog.V(4).Infof("feasible clusters scores: %v", clustersScore)
|
klog.V(4).Infof("feasible clusters scores: %v", clustersScore)
|
||||||
|
|
||||||
clusters := g.selectClusters(clustersScore, placement.SpreadConstraints, feasibleClusters)
|
clusters, err := g.selectClusters(clustersScore, placement, spec)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("failed to select clusters: %v", err)
|
||||||
|
}
|
||||||
|
klog.V(4).Infof("selected clusters: %v", clusters)
|
||||||
|
|
||||||
clustersWithReplicas, err := g.assignReplicas(clusters, placement.ReplicaScheduling, spec)
|
clustersWithReplicas, err := g.assignReplicas(clusters, placement.ReplicaScheduling, spec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -122,76 +126,13 @@ func (g *genericScheduler) prioritizeClusters(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *genericScheduler) selectClusters(clustersScore framework.ClusterScoreList, spreadConstraints []policyv1alpha1.SpreadConstraint, clusters []*clusterv1alpha1.Cluster) []*clusterv1alpha1.Cluster {
|
func (g *genericScheduler) selectClusters(clustersScore framework.ClusterScoreList,
|
||||||
|
placement *policyv1alpha1.Placement, spec *workv1alpha2.ResourceBindingSpec) ([]*clusterv1alpha1.Cluster, error) {
|
||||||
defer metrics.ScheduleStep(metrics.ScheduleStepSelect, time.Now())
|
defer metrics.ScheduleStep(metrics.ScheduleStepSelect, time.Now())
|
||||||
|
|
||||||
if len(spreadConstraints) != 0 {
|
groupClustersInfo := spreadconstraint.GroupClustersWithScore(clustersScore, placement, spec)
|
||||||
return g.matchSpreadConstraints(clusters, spreadConstraints)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clusters
|
return spreadconstraint.SelectBestClusters(placement, groupClustersInfo)
|
||||||
}
|
|
||||||
|
|
||||||
func (g *genericScheduler) matchSpreadConstraints(clusters []*clusterv1alpha1.Cluster, spreadConstraints []policyv1alpha1.SpreadConstraint) []*clusterv1alpha1.Cluster {
|
|
||||||
state := util.NewSpreadGroup()
|
|
||||||
g.runSpreadConstraintsFilter(clusters, spreadConstraints, state)
|
|
||||||
return g.calSpreadResult(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now support spread by cluster. More rules will be implemented later.
|
|
||||||
func (g *genericScheduler) runSpreadConstraintsFilter(clusters []*clusterv1alpha1.Cluster, spreadConstraints []policyv1alpha1.SpreadConstraint, spreadGroup *util.SpreadGroup) {
|
|
||||||
for _, spreadConstraint := range spreadConstraints {
|
|
||||||
spreadGroup.InitialGroupRecord(spreadConstraint)
|
|
||||||
if spreadConstraint.SpreadByField == policyv1alpha1.SpreadByFieldCluster {
|
|
||||||
g.groupByFieldCluster(clusters, spreadConstraint, spreadGroup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *genericScheduler) groupByFieldCluster(clusters []*clusterv1alpha1.Cluster, spreadConstraint policyv1alpha1.SpreadConstraint, spreadGroup *util.SpreadGroup) {
|
|
||||||
for _, cluster := range clusters {
|
|
||||||
clusterGroup := cluster.Name
|
|
||||||
spreadGroup.GroupRecord[spreadConstraint][clusterGroup] = append(spreadGroup.GroupRecord[spreadConstraint][clusterGroup], cluster)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *genericScheduler) calSpreadResult(spreadGroup *util.SpreadGroup) []*clusterv1alpha1.Cluster {
|
|
||||||
// TODO: now support single spread constraint
|
|
||||||
if len(spreadGroup.GroupRecord) > 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return g.chooseSpreadGroup(spreadGroup)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *genericScheduler) chooseSpreadGroup(spreadGroup *util.SpreadGroup) []*clusterv1alpha1.Cluster {
|
|
||||||
var feasibleClusters []*clusterv1alpha1.Cluster
|
|
||||||
for spreadConstraint, clusterGroups := range spreadGroup.GroupRecord {
|
|
||||||
if spreadConstraint.SpreadByField == policyv1alpha1.SpreadByFieldCluster {
|
|
||||||
if len(clusterGroups) < spreadConstraint.MinGroups {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(clusterGroups) <= spreadConstraint.MaxGroups {
|
|
||||||
for _, v := range clusterGroups {
|
|
||||||
feasibleClusters = append(feasibleClusters, v...)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if spreadConstraint.MaxGroups > 0 && len(clusterGroups) > spreadConstraint.MaxGroups {
|
|
||||||
var groups []string
|
|
||||||
for group := range clusterGroups {
|
|
||||||
groups = append(groups, group)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < spreadConstraint.MaxGroups; i++ {
|
|
||||||
feasibleClusters = append(feasibleClusters, clusterGroups[groups[i]]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return feasibleClusters
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *genericScheduler) assignReplicas(
|
func (g *genericScheduler) assignReplicas(
|
||||||
|
|
|
@ -0,0 +1,240 @@
|
||||||
|
package spreadconstraint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
||||||
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
|
workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2"
|
||||||
|
"github.com/karmada-io/karmada/pkg/scheduler/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupClustersInfo indicate the cluster global view
|
||||||
|
type GroupClustersInfo struct {
|
||||||
|
Providers map[string]ProviderInfo
|
||||||
|
Regions map[string]RegionInfo
|
||||||
|
Zones map[string]ZoneInfo
|
||||||
|
|
||||||
|
// Clusters from global view, sorted by cluster.Score descending.
|
||||||
|
Clusters []ClusterDetailInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderInfo indicate the provider information
|
||||||
|
type ProviderInfo struct {
|
||||||
|
Name string
|
||||||
|
AvailableReplicas int64
|
||||||
|
|
||||||
|
// Regions under this provider
|
||||||
|
Regions map[string]struct{}
|
||||||
|
// Zones under this provider
|
||||||
|
Zones map[string]struct{}
|
||||||
|
// Clusters under this provider, sorted by cluster.Score descending.
|
||||||
|
Clusters []ClusterDetailInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegionInfo indicate the region information
|
||||||
|
type RegionInfo struct {
|
||||||
|
Name string
|
||||||
|
AvailableReplicas int64
|
||||||
|
|
||||||
|
// Zones under this provider
|
||||||
|
Zones map[string]struct{}
|
||||||
|
// Clusters under this region, sorted by cluster.Score descending.
|
||||||
|
Clusters []ClusterDetailInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZoneInfo indicate the zone information
|
||||||
|
type ZoneInfo struct {
|
||||||
|
Name string
|
||||||
|
AvailableReplicas int64
|
||||||
|
|
||||||
|
// Clusters under this zone, sorted by cluster.Score descending.
|
||||||
|
Clusters []ClusterDetailInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterDetailInfo indicate the cluster information
|
||||||
|
type ClusterDetailInfo struct {
|
||||||
|
Name string
|
||||||
|
Score int64
|
||||||
|
AvailableReplicas int64
|
||||||
|
|
||||||
|
Cluster *clusterv1alpha1.Cluster
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupClustersWithScore groups cluster base provider/region/zone/cluster
|
||||||
|
func GroupClustersWithScore(
|
||||||
|
clustersScore framework.ClusterScoreList,
|
||||||
|
placement *policyv1alpha1.Placement,
|
||||||
|
spec *workv1alpha2.ResourceBindingSpec,
|
||||||
|
) *GroupClustersInfo {
|
||||||
|
if isTopologyIgnored(placement) {
|
||||||
|
return groupClustersIngoreTopology(clustersScore, spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupClustersBasedTopology(clustersScore, spec, placement.SpreadConstraints)
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupClustersBasedTopology(
|
||||||
|
clustersScore framework.ClusterScoreList,
|
||||||
|
rbSpec *workv1alpha2.ResourceBindingSpec,
|
||||||
|
spreadConstraints []policyv1alpha1.SpreadConstraint,
|
||||||
|
) *GroupClustersInfo {
|
||||||
|
groupClustersInfo := &GroupClustersInfo{
|
||||||
|
Providers: make(map[string]ProviderInfo),
|
||||||
|
Regions: make(map[string]RegionInfo),
|
||||||
|
Zones: make(map[string]ZoneInfo),
|
||||||
|
}
|
||||||
|
|
||||||
|
groupClustersInfo.generateClustersInfo(clustersScore, rbSpec)
|
||||||
|
groupClustersInfo.generateZoneInfo(spreadConstraints)
|
||||||
|
groupClustersInfo.generateRegionInfo(spreadConstraints)
|
||||||
|
groupClustersInfo.generateProviderInfo(spreadConstraints)
|
||||||
|
|
||||||
|
return groupClustersInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupClustersIngoreTopology(
|
||||||
|
clustersScore framework.ClusterScoreList,
|
||||||
|
rbSpec *workv1alpha2.ResourceBindingSpec,
|
||||||
|
) *GroupClustersInfo {
|
||||||
|
groupClustersInfo := &GroupClustersInfo{}
|
||||||
|
groupClustersInfo.generateClustersInfo(clustersScore, rbSpec)
|
||||||
|
|
||||||
|
return groupClustersInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *GroupClustersInfo) generateClustersInfo(clustersScore framework.ClusterScoreList, rbSpec *workv1alpha2.ResourceBindingSpec) {
|
||||||
|
var clusters []*clusterv1alpha1.Cluster
|
||||||
|
for _, clusterScore := range clustersScore {
|
||||||
|
clusterInfo := ClusterDetailInfo{}
|
||||||
|
clusterInfo.Name = clusterScore.Cluster.Name
|
||||||
|
clusterInfo.Score = clusterScore.Score
|
||||||
|
clusterInfo.Cluster = clusterScore.Cluster
|
||||||
|
info.Clusters = append(info.Clusters, clusterInfo)
|
||||||
|
clusters = append(clusters, clusterScore.Cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
clustersReplicas := calAvailableReplicas(clusters, rbSpec)
|
||||||
|
for i, clustersReplica := range clustersReplicas {
|
||||||
|
info.Clusters[i].AvailableReplicas = int64(clustersReplica.Replicas)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortClusters(info.Clusters)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *GroupClustersInfo) generateZoneInfo(spreadConstraints []policyv1alpha1.SpreadConstraint) {
|
||||||
|
if !IsSpreadConstraintExisted(spreadConstraints, policyv1alpha1.SpreadByFieldZone) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clusterInfo := range info.Clusters {
|
||||||
|
zone := clusterInfo.Cluster.Spec.Zone
|
||||||
|
if zone == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneInfo, ok := info.Zones[zone]
|
||||||
|
if !ok {
|
||||||
|
zoneInfo = ZoneInfo{
|
||||||
|
Name: zone,
|
||||||
|
Clusters: make([]ClusterDetailInfo, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
zoneInfo.Clusters = append(zoneInfo.Clusters, clusterInfo)
|
||||||
|
zoneInfo.AvailableReplicas += clusterInfo.AvailableReplicas
|
||||||
|
info.Zones[zone] = zoneInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *GroupClustersInfo) generateRegionInfo(spreadConstraints []policyv1alpha1.SpreadConstraint) {
|
||||||
|
if !IsSpreadConstraintExisted(spreadConstraints, policyv1alpha1.SpreadByFieldRegion) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clusterInfo := range info.Clusters {
|
||||||
|
region := clusterInfo.Cluster.Spec.Region
|
||||||
|
if region == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
regionInfo, ok := info.Regions[region]
|
||||||
|
if !ok {
|
||||||
|
regionInfo = RegionInfo{
|
||||||
|
Name: region,
|
||||||
|
Zones: make(map[string]struct{}),
|
||||||
|
Clusters: make([]ClusterDetailInfo, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if clusterInfo.Cluster.Spec.Zone != "" {
|
||||||
|
regionInfo.Zones[clusterInfo.Cluster.Spec.Zone] = struct{}{}
|
||||||
|
}
|
||||||
|
regionInfo.Clusters = append(regionInfo.Clusters, clusterInfo)
|
||||||
|
regionInfo.AvailableReplicas += clusterInfo.AvailableReplicas
|
||||||
|
info.Regions[region] = regionInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *GroupClustersInfo) generateProviderInfo(spreadConstraints []policyv1alpha1.SpreadConstraint) {
|
||||||
|
if !IsSpreadConstraintExisted(spreadConstraints, policyv1alpha1.SpreadByFieldProvider) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clusterInfo := range info.Clusters {
|
||||||
|
provider := clusterInfo.Cluster.Spec.Provider
|
||||||
|
if provider == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
providerInfo, ok := info.Providers[provider]
|
||||||
|
if !ok {
|
||||||
|
providerInfo = ProviderInfo{
|
||||||
|
Name: provider,
|
||||||
|
Regions: make(map[string]struct{}),
|
||||||
|
Zones: make(map[string]struct{}),
|
||||||
|
Clusters: make([]ClusterDetailInfo, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if clusterInfo.Cluster.Spec.Zone != "" {
|
||||||
|
providerInfo.Zones[clusterInfo.Cluster.Spec.Zone] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if clusterInfo.Cluster.Spec.Region != "" {
|
||||||
|
providerInfo.Regions[clusterInfo.Cluster.Spec.Region] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
providerInfo.Clusters = append(providerInfo.Clusters, clusterInfo)
|
||||||
|
providerInfo.AvailableReplicas += clusterInfo.AvailableReplicas
|
||||||
|
info.Providers[provider] = providerInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTopologyIgnored(placement *policyv1alpha1.Placement) bool {
|
||||||
|
strategy := placement.ReplicaScheduling
|
||||||
|
spreadConstraints := placement.SpreadConstraints
|
||||||
|
|
||||||
|
if len(spreadConstraints) == 0 || (len(spreadConstraints) == 1 && spreadConstraints[0].SpreadByField == policyv1alpha1.SpreadByFieldCluster) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the replica division preference is 'static weighted', ignore the declaration specified by spread constraints.
|
||||||
|
if strategy != nil && strategy.ReplicaSchedulingType == policyv1alpha1.ReplicaSchedulingTypeDivided &&
|
||||||
|
strategy.ReplicaDivisionPreference == policyv1alpha1.ReplicaDivisionPreferenceWeighted &&
|
||||||
|
(len(strategy.WeightPreference.StaticWeightList) != 0 && strategy.WeightPreference.DynamicWeight == "") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortClusters(infos []ClusterDetailInfo) {
|
||||||
|
sort.Slice(infos, func(i, j int) bool {
|
||||||
|
if infos[i].Score != infos[j].Score {
|
||||||
|
return infos[i].Score > infos[j].Score
|
||||||
|
}
|
||||||
|
|
||||||
|
return infos[i].Name < infos[j].Name
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,203 @@
|
||||||
|
package spreadconstraint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/agiledragon/gomonkey/v2"
|
||||||
|
|
||||||
|
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
||||||
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
|
workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2"
|
||||||
|
"github.com/karmada-io/karmada/pkg/scheduler/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateClusterScore() framework.ClusterScoreList {
|
||||||
|
return framework.ClusterScoreList{
|
||||||
|
{
|
||||||
|
Cluster: NewClusterWithTopology("member1", "P1", "R1", "Z1"),
|
||||||
|
Score: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Cluster: NewClusterWithTopology("member2", "P1", "R1", "Z2"),
|
||||||
|
Score: 40,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Cluster: NewClusterWithTopology("member3", "P2", "R1", "Z1"),
|
||||||
|
Score: 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Cluster: NewClusterWithTopology("member4", "P2", "R2", "Z2"),
|
||||||
|
Score: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func Test_GroupClustersWithScore(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
clustersScore framework.ClusterScoreList
|
||||||
|
placement *policyv1alpha1.Placement
|
||||||
|
spec *workv1alpha2.ResourceBindingSpec
|
||||||
|
}
|
||||||
|
type want struct {
|
||||||
|
clusters []string
|
||||||
|
zoneCnt int
|
||||||
|
regionCnt int
|
||||||
|
providerCnt int
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is nil",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is cluster",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldCluster,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is zone",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldZone,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
zoneCnt: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is region",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldRegion,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
regionCnt: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is provider",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldProvider,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
providerCnt: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test SpreadConstraints is provider/region/zone",
|
||||||
|
args: args{
|
||||||
|
clustersScore: generateClusterScore(),
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldProvider,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldRegion,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldZone,
|
||||||
|
MaxGroups: 1,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
clusters: []string{"member4", "member2", "member3", "member1"},
|
||||||
|
providerCnt: 2,
|
||||||
|
regionCnt: 2,
|
||||||
|
zoneCnt: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
patches := gomonkey.ApplyFunc(calAvailableReplicas, func(clusters []*clusterv1alpha1.Cluster, spec *workv1alpha2.ResourceBindingSpec) []workv1alpha2.TargetCluster {
|
||||||
|
availableTargetClusters := make([]workv1alpha2.TargetCluster, len(clusters))
|
||||||
|
|
||||||
|
for i := range availableTargetClusters {
|
||||||
|
availableTargetClusters[i].Name = clusters[i].Name
|
||||||
|
availableTargetClusters[i].Replicas = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return availableTargetClusters
|
||||||
|
})
|
||||||
|
|
||||||
|
defer patches.Reset()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
groupInfo := GroupClustersWithScore(tt.args.clustersScore, tt.args.placement, tt.args.spec)
|
||||||
|
for i, cluster := range groupInfo.Clusters {
|
||||||
|
if cluster.Name != tt.want.clusters[i] {
|
||||||
|
t.Errorf("test %s : the clusters aren't sorted", tt.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.want.zoneCnt != len(groupInfo.Zones) {
|
||||||
|
t.Errorf("test %s : zoneCnt = %v, want %v", tt.name, len(groupInfo.Zones), tt.want.zoneCnt)
|
||||||
|
}
|
||||||
|
if tt.want.regionCnt != len(groupInfo.Regions) {
|
||||||
|
t.Errorf("test %s : regionCnt = %v, want %v", tt.name, len(groupInfo.Regions), tt.want.regionCnt)
|
||||||
|
}
|
||||||
|
if tt.want.providerCnt != len(groupInfo.Providers) {
|
||||||
|
t.Errorf("test %s : providerCnt = %v, want %v", tt.name, len(groupInfo.Providers), tt.want.providerCnt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package spreadconstraint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
||||||
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SelectBestClusters selects the cluster set based the GroupClustersInfo and placement
|
||||||
|
func SelectBestClusters(placement *policyv1alpha1.Placement, groupClustersInfo *GroupClustersInfo) ([]*clusterv1alpha1.Cluster, error) {
|
||||||
|
if len(placement.SpreadConstraints) != 0 {
|
||||||
|
return selectBestClustersBySpreadConstraints(placement.SpreadConstraints, groupClustersInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clusters []*clusterv1alpha1.Cluster
|
||||||
|
for _, cluster := range groupClustersInfo.Clusters {
|
||||||
|
clusters = append(clusters, cluster.Cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clusters, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectBestClustersBySpreadConstraints(spreadConstraints []policyv1alpha1.SpreadConstraint,
|
||||||
|
groupClustersInfo *GroupClustersInfo) ([]*clusterv1alpha1.Cluster, error) {
|
||||||
|
if len(spreadConstraints) > 1 {
|
||||||
|
return nil, fmt.Errorf("just support single spread constraint")
|
||||||
|
}
|
||||||
|
|
||||||
|
spreadConstraint := spreadConstraints[0]
|
||||||
|
if spreadConstraint.SpreadByField == policyv1alpha1.SpreadByFieldCluster {
|
||||||
|
return selectBestClustersByCluster(spreadConstraint, groupClustersInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("just support cluster spread constraint")
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectBestClustersByCluster(spreadConstraint policyv1alpha1.SpreadConstraint, groupClustersInfo *GroupClustersInfo) ([]*clusterv1alpha1.Cluster, error) {
|
||||||
|
totalClusterCnt := len(groupClustersInfo.Clusters)
|
||||||
|
if spreadConstraint.MinGroups > totalClusterCnt {
|
||||||
|
return nil, fmt.Errorf("the number of feasible clusters is less than spreadConstraint.MinGroups")
|
||||||
|
}
|
||||||
|
|
||||||
|
needCnt := spreadConstraint.MaxGroups
|
||||||
|
if spreadConstraint.MaxGroups > totalClusterCnt {
|
||||||
|
needCnt = totalClusterCnt
|
||||||
|
}
|
||||||
|
|
||||||
|
var clusters []*clusterv1alpha1.Cluster
|
||||||
|
for i := 0; i < needCnt; i++ {
|
||||||
|
clusters = append(clusters, groupClustersInfo.Clusters[i].Cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clusters, nil
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package spreadconstraint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
|
||||||
|
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
||||||
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewClusterWithTopology will build a Cluster with topology.
|
||||||
|
func NewClusterWithTopology(name, provider, region, zone string) *clusterv1alpha1.Cluster {
|
||||||
|
return &clusterv1alpha1.Cluster{
|
||||||
|
ObjectMeta: metav1.ObjectMeta{Name: name},
|
||||||
|
Spec: clusterv1alpha1.ClusterSpec{
|
||||||
|
Provider: provider,
|
||||||
|
Region: region,
|
||||||
|
Zone: zone,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateClusterInfo() []ClusterDetailInfo {
|
||||||
|
return []ClusterDetailInfo{
|
||||||
|
{
|
||||||
|
Name: "member4",
|
||||||
|
Score: 60,
|
||||||
|
AvailableReplicas: 101,
|
||||||
|
Cluster: NewClusterWithTopology("member4", "P2", "R2", "Z2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "member2",
|
||||||
|
Score: 40,
|
||||||
|
AvailableReplicas: 101,
|
||||||
|
Cluster: NewClusterWithTopology("member2", "P1", "R1", "Z2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "member3",
|
||||||
|
Score: 30,
|
||||||
|
AvailableReplicas: 101,
|
||||||
|
Cluster: NewClusterWithTopology("member3", "P2", "R1", "Z1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "member1",
|
||||||
|
Score: 20,
|
||||||
|
AvailableReplicas: 101,
|
||||||
|
Cluster: NewClusterWithTopology("member1", "P1", "R1", "Z1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectBestClusters(t *testing.T) {
|
||||||
|
clustetInfos := generateClusterInfo()
|
||||||
|
type args struct {
|
||||||
|
placement *policyv1alpha1.Placement
|
||||||
|
groupClustersInfo *GroupClustersInfo
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []*clusterv1alpha1.Cluster
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "select clusters by cluster score",
|
||||||
|
args: args{
|
||||||
|
placement: &policyv1alpha1.Placement{
|
||||||
|
SpreadConstraints: []policyv1alpha1.SpreadConstraint{
|
||||||
|
{
|
||||||
|
SpreadByField: policyv1alpha1.SpreadByFieldCluster,
|
||||||
|
MaxGroups: 2,
|
||||||
|
MinGroups: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
groupClustersInfo: &GroupClustersInfo{
|
||||||
|
Clusters: clustetInfos,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []*clusterv1alpha1.Cluster{
|
||||||
|
clustetInfos[0].Cluster,
|
||||||
|
clustetInfos[1].Cluster,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := SelectBestClusters(tt.args.placement, tt.args.groupClustersInfo)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SelectBestClusters() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SelectBestClusters() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
package spreadconstraint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"k8s.io/klog/v2"
|
||||||
|
|
||||||
|
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
||||||
|
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
||||||
|
workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2"
|
||||||
|
estimatorclient "github.com/karmada-io/karmada/pkg/estimator/client"
|
||||||
|
"github.com/karmada-io/karmada/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func calAvailableReplicas(clusters []*clusterv1alpha1.Cluster, spec *workv1alpha2.ResourceBindingSpec) []workv1alpha2.TargetCluster {
|
||||||
|
availableClusters := make([]workv1alpha2.TargetCluster, len(clusters))
|
||||||
|
|
||||||
|
// Set the boundary.
|
||||||
|
for i := range availableClusters {
|
||||||
|
availableClusters[i].Name = clusters[i].Name
|
||||||
|
availableClusters[i].Replicas = math.MaxInt32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the minimum value of MaxAvailableReplicas in terms of all estimators.
|
||||||
|
estimators := estimatorclient.GetReplicaEstimators()
|
||||||
|
ctx := context.WithValue(context.TODO(), util.ContextKeyObject,
|
||||||
|
fmt.Sprintf("kind=%s, name=%s/%s", spec.Resource.Kind, spec.Resource.Namespace, spec.Resource.Name))
|
||||||
|
for _, estimator := range estimators {
|
||||||
|
res, err := estimator.MaxAvailableReplicas(ctx, clusters, spec.ReplicaRequirements)
|
||||||
|
if err != nil {
|
||||||
|
klog.Errorf("Max cluster available replicas error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for i := range res {
|
||||||
|
if res[i].Replicas == estimatorclient.UnauthenticReplica {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if availableClusters[i].Name == res[i].Name && availableClusters[i].Replicas > res[i].Replicas {
|
||||||
|
availableClusters[i].Replicas = res[i].Replicas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In most cases, the target cluster max available replicas should not be MaxInt32 unless the workload is best-effort
|
||||||
|
// and the scheduler-estimator has not been enabled. So we set the replicas to spec.Replicas for avoiding overflow.
|
||||||
|
for i := range availableClusters {
|
||||||
|
if availableClusters[i].Replicas == math.MaxInt32 {
|
||||||
|
availableClusters[i].Replicas = spec.Replicas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
klog.V(4).Infof("cluster replicas info: %v", availableClusters)
|
||||||
|
return availableClusters
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSpreadConstraintExisted judge if the specific field is existed in the spread constraints
|
||||||
|
func IsSpreadConstraintExisted(spreadConstraints []policyv1alpha1.SpreadConstraint, field policyv1alpha1.SpreadFieldValue) bool {
|
||||||
|
for _, spreadConstraint := range spreadConstraints {
|
||||||
|
if spreadConstraint.SpreadByField == field {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
|
@ -1,30 +0,0 @@
|
||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
clusterv1alpha1 "github.com/karmada-io/karmada/pkg/apis/cluster/v1alpha1"
|
|
||||||
policyv1alpha1 "github.com/karmada-io/karmada/pkg/apis/policy/v1alpha1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SpreadGroup stores the cluster group info for given spread constraints
|
|
||||||
type SpreadGroup struct {
|
|
||||||
// The outer map's keys are SpreadConstraint. The values (inner map) of the outer map are maps with string
|
|
||||||
// keys and []string values. The inner map's key should specify the cluster group name.
|
|
||||||
GroupRecord map[policyv1alpha1.SpreadConstraint]map[string][]*clusterv1alpha1.Cluster
|
|
||||||
sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSpreadGroup initializes a SpreadGroup
|
|
||||||
func NewSpreadGroup() *SpreadGroup {
|
|
||||||
return &SpreadGroup{
|
|
||||||
GroupRecord: make(map[policyv1alpha1.SpreadConstraint]map[string][]*clusterv1alpha1.Cluster),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitialGroupRecord initials a spread state record
|
|
||||||
func (ss *SpreadGroup) InitialGroupRecord(constraint policyv1alpha1.SpreadConstraint) {
|
|
||||||
ss.Lock()
|
|
||||||
defer ss.Unlock()
|
|
||||||
ss.GroupRecord[constraint] = make(map[string][]*clusterv1alpha1.Cluster)
|
|
||||||
}
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2018 Zhang Xiaolong
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,53 @@
|
||||||
|
# gomonkey
|
||||||
|
|
||||||
|
gomonkey is a library to make monkey patching in unit tests easy, and the core idea of monkey patching comes from [Bouke](https://github.com/bouk), you can read [this blogpost](https://bou.ke/blog/monkey-patching-in-go/) for an explanation on how it works.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
+ support a patch for a function
|
||||||
|
+ support a patch for a public member method
|
||||||
|
+ support a patch for a private member method
|
||||||
|
+ support a patch for a interface
|
||||||
|
+ support a patch for a function variable
|
||||||
|
+ support a patch for a global variable
|
||||||
|
+ support patches of a specified sequence for a function
|
||||||
|
+ support patches of a specified sequence for a member method
|
||||||
|
+ support patches of a specified sequence for a interface
|
||||||
|
+ support patches of a specified sequence for a function variable
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
+ gomonkey fails to patch a function or a member method if inlining is enabled, please running your tests with inlining disabled by adding the command line argument that is `-gcflags=-l`(below go1.10) or `-gcflags=all=-l`(go1.10 and above).
|
||||||
|
+ A panic may happen when a goroutine is patching a function or a member method that is visited by another goroutine at the same time. That is to say, gomonkey is not threadsafe.
|
||||||
|
|
||||||
|
## Supported Platform:
|
||||||
|
|
||||||
|
- ARCH
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- 386
|
||||||
|
|
||||||
|
- OS
|
||||||
|
- Linux
|
||||||
|
- MAC OS X
|
||||||
|
- Windows
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
- below v2.1.0, for example v2.0.2
|
||||||
|
```go
|
||||||
|
$ go get github.com/agiledragon/gomonkey@v2.0.2
|
||||||
|
```
|
||||||
|
- v2.1.0 and above, for example v2.2.0
|
||||||
|
```go
|
||||||
|
$ go get github.com/agiledragon/gomonkey/v2@v2.2.0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Method
|
||||||
|
```go
|
||||||
|
$ cd test
|
||||||
|
$ go test -gcflags=all=-l
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using gomonkey
|
||||||
|
|
||||||
|
Please refer to the test cases as idioms, very complete and detailed.
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
// Customized reflect package for gomonkey,copy most code from go/src/reflect/type.go
|
||||||
|
|
||||||
|
package creflect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rtype is the common implementation of most values.
|
||||||
|
// rtype must be kept in sync with ../runtime/type.go:/^type._type.
|
||||||
|
type rtype struct {
|
||||||
|
size uintptr
|
||||||
|
ptrdata uintptr // number of bytes in the type that can contain pointers
|
||||||
|
hash uint32 // hash of type; avoids computation in hash tables
|
||||||
|
tflag tflag // extra type information flags
|
||||||
|
align uint8 // alignment of variable with this type
|
||||||
|
fieldAlign uint8 // alignment of struct field with this type
|
||||||
|
kind uint8 // enumeration for C
|
||||||
|
// function for comparing objects of this type
|
||||||
|
// (ptr to object A, ptr to object B) -> ==?
|
||||||
|
equal func(unsafe.Pointer, unsafe.Pointer) bool
|
||||||
|
gcdata *byte // garbage collection data
|
||||||
|
str nameOff // string form
|
||||||
|
ptrToThis typeOff // type for pointer to this type, may be zero
|
||||||
|
}
|
||||||
|
|
||||||
|
func Create(t reflect.Type) *rtype {
|
||||||
|
i := *(*funcValue)(unsafe.Pointer(&t))
|
||||||
|
r := (*rtype)(i.p)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
type funcValue struct {
|
||||||
|
_ uintptr
|
||||||
|
p unsafe.Pointer
|
||||||
|
}
|
||||||
|
func funcPointer(v reflect.Method, ok bool) (unsafe.Pointer, bool) {
|
||||||
|
return (*funcValue)(unsafe.Pointer(&v.Func)).p, ok
|
||||||
|
}
|
||||||
|
func MethodByName(r reflect.Type, name string) (fn unsafe.Pointer, ok bool) {
|
||||||
|
t := Create(r)
|
||||||
|
if r.Kind() == reflect.Interface {
|
||||||
|
return funcPointer(r.MethodByName(name))
|
||||||
|
}
|
||||||
|
ut := t.uncommon(r)
|
||||||
|
if ut == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range ut.methods() {
|
||||||
|
if t.nameOff(p.name).name() == name {
|
||||||
|
return t.Method(p), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rtype) Method(p method) (fn unsafe.Pointer) {
|
||||||
|
tfn := t.textOff(p.tfn)
|
||||||
|
fn = unsafe.Pointer(&tfn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type tflag uint8
|
||||||
|
type nameOff int32 // offset to a name
|
||||||
|
type typeOff int32 // offset to an *rtype
|
||||||
|
type textOff int32 // offset from top of text section
|
||||||
|
|
||||||
|
//go:linkname resolveTextOff reflect.resolveTextOff
|
||||||
|
func resolveTextOff(rtype unsafe.Pointer, off int32) unsafe.Pointer
|
||||||
|
|
||||||
|
func (t *rtype) textOff(off textOff) unsafe.Pointer {
|
||||||
|
return resolveTextOff(unsafe.Pointer(t), int32(off))
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname resolveNameOff reflect.resolveNameOff
|
||||||
|
func resolveNameOff(ptrInModule unsafe.Pointer, off int32) unsafe.Pointer
|
||||||
|
|
||||||
|
func (t *rtype) nameOff(off nameOff) name {
|
||||||
|
return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
tflagUncommon tflag = 1 << 0
|
||||||
|
)
|
||||||
|
// uncommonType is present only for defined types or types with methods
|
||||||
|
type uncommonType struct {
|
||||||
|
pkgPath nameOff // import path; empty for built-in types like int, string
|
||||||
|
mcount uint16 // number of methods
|
||||||
|
xcount uint16 // number of exported methods
|
||||||
|
moff uint32 // offset from this uncommontype to [mcount]method
|
||||||
|
_ uint32 // unused
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptrType represents a pointer type.
|
||||||
|
type ptrType struct {
|
||||||
|
rtype
|
||||||
|
elem *rtype // pointer element (pointed at) type
|
||||||
|
}
|
||||||
|
|
||||||
|
// funcType represents a function type.
|
||||||
|
type funcType struct {
|
||||||
|
rtype
|
||||||
|
inCount uint16
|
||||||
|
outCount uint16 // top bit is set if last input parameter is ...
|
||||||
|
}
|
||||||
|
|
||||||
|
func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer {
|
||||||
|
return unsafe.Pointer(uintptr(p) + x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// interfaceType represents an interface type.
|
||||||
|
type interfaceType struct {
|
||||||
|
rtype
|
||||||
|
pkgPath name // import path
|
||||||
|
methods []imethod // sorted by hash
|
||||||
|
}
|
||||||
|
|
||||||
|
type imethod struct {
|
||||||
|
name nameOff // name of method
|
||||||
|
typ typeOff // .(*FuncType) underneath
|
||||||
|
}
|
||||||
|
|
||||||
|
// name is an encoded type name with optional extra data.
|
||||||
|
type name struct {
|
||||||
|
bytes *byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type String struct {
|
||||||
|
Data unsafe.Pointer
|
||||||
|
Len int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n name) name() (s string) {
|
||||||
|
if n.bytes == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b := (*[4]byte)(unsafe.Pointer(n.bytes))
|
||||||
|
|
||||||
|
hdr := (*String)(unsafe.Pointer(&s))
|
||||||
|
hdr.Data = unsafe.Pointer(&b[3])
|
||||||
|
hdr.Len = int(b[1])<<8 | int(b[2])
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rtype) uncommon(r reflect.Type) *uncommonType {
|
||||||
|
if t.tflag&tflagUncommon == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch r.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
type u struct {
|
||||||
|
ptrType
|
||||||
|
u uncommonType
|
||||||
|
}
|
||||||
|
return &(*u)(unsafe.Pointer(t)).u
|
||||||
|
case reflect.Func:
|
||||||
|
type u struct {
|
||||||
|
funcType
|
||||||
|
u uncommonType
|
||||||
|
}
|
||||||
|
return &(*u)(unsafe.Pointer(t)).u
|
||||||
|
case reflect.Interface:
|
||||||
|
type u struct {
|
||||||
|
interfaceType
|
||||||
|
u uncommonType
|
||||||
|
}
|
||||||
|
return &(*u)(unsafe.Pointer(t)).u
|
||||||
|
case reflect.Struct:
|
||||||
|
type u struct {
|
||||||
|
interfaceType
|
||||||
|
u uncommonType
|
||||||
|
}
|
||||||
|
return &(*u)(unsafe.Pointer(t)).u
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method on non-interface type
|
||||||
|
type method struct {
|
||||||
|
name nameOff // name of method
|
||||||
|
mtyp typeOff // method type (without receiver)
|
||||||
|
ifn textOff // fn used in interface call (one-word receiver)
|
||||||
|
tfn textOff // fn used for normal method call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *uncommonType) methods() []method {
|
||||||
|
if t.mcount == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.mcount > 0"))[:t.mcount:t.mcount]
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
func buildJmpDirective(double uintptr) []byte {
|
||||||
|
d0 := byte(double)
|
||||||
|
d1 := byte(double >> 8)
|
||||||
|
d2 := byte(double >> 16)
|
||||||
|
d3 := byte(double >> 24)
|
||||||
|
|
||||||
|
return []byte{
|
||||||
|
0xBA, d0, d1, d2, d3, // MOV edx, double
|
||||||
|
0xFF, 0x22, // JMP [edx]
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
func buildJmpDirective(double uintptr) []byte {
|
||||||
|
d0 := byte(double)
|
||||||
|
d1 := byte(double >> 8)
|
||||||
|
d2 := byte(double >> 16)
|
||||||
|
d3 := byte(double >> 24)
|
||||||
|
d4 := byte(double >> 32)
|
||||||
|
d5 := byte(double >> 40)
|
||||||
|
d6 := byte(double >> 48)
|
||||||
|
d7 := byte(double >> 56)
|
||||||
|
|
||||||
|
return []byte{
|
||||||
|
0x48, 0xBA, d0, d1, d2, d3, d4, d5, d6, d7, // MOV rdx, double
|
||||||
|
0xFF, 0x22, // JMP [rdx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
func buildJmpDirective(double uintptr) []byte {
|
||||||
|
res := make([]byte, 0, 24)
|
||||||
|
d0d1 := double & 0xFFFF
|
||||||
|
d2d3 := double >> 16 & 0xFFFF
|
||||||
|
d4d5 := double >> 32 & 0xFFFF
|
||||||
|
d6d7 := double >> 48 & 0xFFFF
|
||||||
|
|
||||||
|
res = append(res, movImm(0B10, 0, d0d1)...) // MOVZ x26, double[16:0]
|
||||||
|
res = append(res, movImm(0B11, 1, d2d3)...) // MOVK x26, double[32:16]
|
||||||
|
res = append(res, movImm(0B11, 2, d4d5)...) // MOVK x26, double[48:32]
|
||||||
|
res = append(res, movImm(0B11, 3, d6d7)...) // MOVK x26, double[64:48]
|
||||||
|
res = append(res, []byte{0x4A, 0x03, 0x40, 0xF9}...) // LDR x10, [x26]
|
||||||
|
res = append(res, []byte{0x40, 0x01, 0x1F, 0xD6}...) // BR x10
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func movImm(opc, shift int, val uintptr) []byte {
|
||||||
|
var m uint32 = 26 // rd
|
||||||
|
m |= uint32(val) << 5 // imm16
|
||||||
|
m |= uint32(shift&3) << 21 // hw
|
||||||
|
m |= 0b100101 << 23 // const
|
||||||
|
m |= uint32(opc&0x3) << 29 // opc
|
||||||
|
m |= 0b1 << 31 // sf
|
||||||
|
|
||||||
|
res := make([]byte, 4)
|
||||||
|
*(*uint32)(unsafe.Pointer(&res[0])) = m
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func modifyBinary(target uintptr, bytes []byte) {
|
||||||
|
function := entryAddress(target, len(bytes))
|
||||||
|
err := mprotectCrossPage(target, len(bytes), syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
copy(function, bytes)
|
||||||
|
err = mprotectCrossPage(target, len(bytes), syscall.PROT_READ|syscall.PROT_EXEC)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mprotectCrossPage(addr uintptr, length int, prot int) error {
|
||||||
|
pageSize := syscall.Getpagesize()
|
||||||
|
for p := pageStart(addr); p < addr+uintptr(length); p += uintptr(pageSize) {
|
||||||
|
page := entryAddress(p, pageSize)
|
||||||
|
if err := syscall.Mprotect(page, prot); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func modifyBinary(target uintptr, bytes []byte) {
|
||||||
|
function := entryAddress(target, len(bytes))
|
||||||
|
err := mprotectCrossPage(target, len(bytes), syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
copy(function, bytes)
|
||||||
|
err = mprotectCrossPage(target, len(bytes), syscall.PROT_READ|syscall.PROT_EXEC)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mprotectCrossPage(addr uintptr, length int, prot int) error {
|
||||||
|
pageSize := syscall.Getpagesize()
|
||||||
|
for p := pageStart(addr); p < addr+uintptr(length); p += uintptr(pageSize) {
|
||||||
|
page := entryAddress(p, pageSize)
|
||||||
|
if err := syscall.Mprotect(page, prot); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
25
vendor/github.com/agiledragon/gomonkey/v2/modify_binary_windows.go
generated
vendored
Normal file
25
vendor/github.com/agiledragon/gomonkey/v2/modify_binary_windows.go
generated
vendored
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func modifyBinary(target uintptr, bytes []byte) {
|
||||||
|
function := entryAddress(target, len(bytes))
|
||||||
|
|
||||||
|
proc := syscall.NewLazyDLL("kernel32.dll").NewProc("VirtualProtect")
|
||||||
|
const PROT_READ_WRITE = 0x40
|
||||||
|
var old uint32
|
||||||
|
result, _, _ := proc.Call(target, uintptr(len(bytes)), uintptr(PROT_READ_WRITE), uintptr(unsafe.Pointer(&old)))
|
||||||
|
if result == 0 {
|
||||||
|
panic(result)
|
||||||
|
}
|
||||||
|
copy(function, bytes)
|
||||||
|
|
||||||
|
var ignore uint32
|
||||||
|
result, _, _ = proc.Call(target, uintptr(len(bytes)), uintptr(old), uintptr(unsafe.Pointer(&ignore)))
|
||||||
|
if result == 0 {
|
||||||
|
panic(result)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,340 @@
|
||||||
|
package gomonkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/agiledragon/gomonkey/v2/creflect"
|
||||||
|
"reflect"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Patches struct {
|
||||||
|
originals map[uintptr][]byte
|
||||||
|
values map[reflect.Value]reflect.Value
|
||||||
|
valueHolders map[reflect.Value]reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
type Params []interface{}
|
||||||
|
type OutputCell struct {
|
||||||
|
Values Params
|
||||||
|
Times int
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFunc(target, double interface{}) *Patches {
|
||||||
|
return create().ApplyFunc(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyMethod(target reflect.Type, methodName string, double interface{}) *Patches {
|
||||||
|
return create().ApplyMethod(target, methodName, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
|
||||||
|
return create().ApplyMethodFunc(target, methodName, doubleFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
|
||||||
|
return create().ApplyPrivateMethod(target, methodName, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyGlobalVar(target, double interface{}) *Patches {
|
||||||
|
return create().ApplyGlobalVar(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFuncVar(target, double interface{}) *Patches {
|
||||||
|
return create().ApplyFuncVar(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
|
||||||
|
return create().ApplyFuncSeq(target, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyMethodSeq(target reflect.Type, methodName string, outputs []OutputCell) *Patches {
|
||||||
|
return create().ApplyMethodSeq(target, methodName, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
|
||||||
|
return create().ApplyFuncVarSeq(target, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
|
||||||
|
return create().ApplyFuncReturn(target, output...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
|
||||||
|
return create().ApplyMethodReturn(target, methodName, output...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
|
||||||
|
return create().ApplyFuncVarReturn(target, output...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func create() *Patches {
|
||||||
|
return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPatches() *Patches {
|
||||||
|
return create()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
d := reflect.ValueOf(double)
|
||||||
|
return this.ApplyCore(t, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyMethod(target reflect.Type, methodName string, double interface{}) *Patches {
|
||||||
|
m, ok := target.MethodByName(methodName)
|
||||||
|
if !ok {
|
||||||
|
panic("retrieve method by name failed")
|
||||||
|
}
|
||||||
|
d := reflect.ValueOf(double)
|
||||||
|
return this.ApplyCore(m.Func, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
|
||||||
|
m, ok := target.MethodByName(methodName)
|
||||||
|
if !ok {
|
||||||
|
panic("retrieve method by name failed")
|
||||||
|
}
|
||||||
|
d := funcToMethod(m.Type, doubleFunc)
|
||||||
|
return this.ApplyCore(m.Func, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
|
||||||
|
m, ok := creflect.MethodByName(target, methodName)
|
||||||
|
if !ok {
|
||||||
|
panic("retrieve method by name failed")
|
||||||
|
}
|
||||||
|
d := reflect.ValueOf(double)
|
||||||
|
return this.ApplyCoreOnlyForPrivateMethod(m, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches {
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
if t.Type().Kind() != reflect.Ptr {
|
||||||
|
panic("target is not a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.values[t] = reflect.ValueOf(t.Elem().Interface())
|
||||||
|
d := reflect.ValueOf(double)
|
||||||
|
t.Elem().Set(d)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches {
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
d := reflect.ValueOf(double)
|
||||||
|
if t.Type().Kind() != reflect.Ptr {
|
||||||
|
panic("target is not a pointer")
|
||||||
|
}
|
||||||
|
this.check(t.Elem(), d)
|
||||||
|
return this.ApplyGlobalVar(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
|
||||||
|
funcType := reflect.TypeOf(target)
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
d := getDoubleFunc(funcType, outputs)
|
||||||
|
return this.ApplyCore(t, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyMethodSeq(target reflect.Type, methodName string, outputs []OutputCell) *Patches {
|
||||||
|
m, ok := target.MethodByName(methodName)
|
||||||
|
if !ok {
|
||||||
|
panic("retrieve method by name failed")
|
||||||
|
}
|
||||||
|
d := getDoubleFunc(m.Type, outputs)
|
||||||
|
return this.ApplyCore(m.Func, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
if t.Type().Kind() != reflect.Ptr {
|
||||||
|
panic("target is not a pointer")
|
||||||
|
}
|
||||||
|
if t.Elem().Kind() != reflect.Func {
|
||||||
|
panic("target is not a func")
|
||||||
|
}
|
||||||
|
|
||||||
|
funcType := reflect.TypeOf(target).Elem()
|
||||||
|
double := getDoubleFunc(funcType, outputs).Interface()
|
||||||
|
return this.ApplyGlobalVar(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
|
||||||
|
funcType := reflect.TypeOf(target)
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
outputs := []OutputCell{{Values: returns, Times: -1}}
|
||||||
|
d := getDoubleFunc(funcType, outputs)
|
||||||
|
return this.ApplyCore(t, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
|
||||||
|
m, ok := reflect.TypeOf(target).MethodByName(methodName)
|
||||||
|
if !ok {
|
||||||
|
panic("retrieve method by name failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs := []OutputCell{{Values: returns, Times: -1}}
|
||||||
|
d := getDoubleFunc(m.Type, outputs)
|
||||||
|
return this.ApplyCore(m.Func, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
|
||||||
|
t := reflect.ValueOf(target)
|
||||||
|
if t.Type().Kind() != reflect.Ptr {
|
||||||
|
panic("target is not a pointer")
|
||||||
|
}
|
||||||
|
if t.Elem().Kind() != reflect.Func {
|
||||||
|
panic("target is not a func")
|
||||||
|
}
|
||||||
|
|
||||||
|
funcType := reflect.TypeOf(target).Elem()
|
||||||
|
outputs := []OutputCell{{Values: returns, Times: -1}}
|
||||||
|
double := getDoubleFunc(funcType, outputs).Interface()
|
||||||
|
return this.ApplyGlobalVar(target, double)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) Reset() {
|
||||||
|
for target, bytes := range this.originals {
|
||||||
|
modifyBinary(target, bytes)
|
||||||
|
delete(this.originals, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
for target, variable := range this.values {
|
||||||
|
target.Elem().Set(variable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
|
||||||
|
this.check(target, double)
|
||||||
|
assTarget := *(*uintptr)(getPointer(target))
|
||||||
|
if _, ok := this.originals[assTarget]; ok {
|
||||||
|
panic("patch has been existed")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.valueHolders[double] = double
|
||||||
|
original := replace(assTarget, uintptr(getPointer(double)))
|
||||||
|
this.originals[assTarget] = original
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches {
|
||||||
|
if double.Kind() != reflect.Func {
|
||||||
|
panic("double is not a func")
|
||||||
|
}
|
||||||
|
assTarget := *(*uintptr)(target)
|
||||||
|
if _, ok := this.originals[assTarget]; ok {
|
||||||
|
panic("patch has been existed")
|
||||||
|
}
|
||||||
|
this.valueHolders[double] = double
|
||||||
|
original := replace(assTarget, uintptr(getPointer(double)))
|
||||||
|
this.originals[assTarget] = original
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Patches) check(target, double reflect.Value) {
|
||||||
|
if target.Kind() != reflect.Func {
|
||||||
|
panic("target is not a func")
|
||||||
|
}
|
||||||
|
|
||||||
|
if double.Kind() != reflect.Func {
|
||||||
|
panic("double is not a func")
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Type() != double.Type() {
|
||||||
|
panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func replace(target, double uintptr) []byte {
|
||||||
|
code := buildJmpDirective(double)
|
||||||
|
bytes := entryAddress(target, len(code))
|
||||||
|
original := make([]byte, len(bytes))
|
||||||
|
copy(original, bytes)
|
||||||
|
modifyBinary(target, code)
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
|
||||||
|
if funcType.NumOut() != len(outputs[0].Values) {
|
||||||
|
panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
|
||||||
|
funcType.NumOut(), len(outputs[0].Values)))
|
||||||
|
}
|
||||||
|
|
||||||
|
needReturn := false
|
||||||
|
slice := make([]Params, 0)
|
||||||
|
for _, output := range outputs {
|
||||||
|
if output.Times == -1 {
|
||||||
|
needReturn = true
|
||||||
|
slice = []Params{output.Values}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t := 0
|
||||||
|
if output.Times <= 1 {
|
||||||
|
t = 1
|
||||||
|
} else {
|
||||||
|
t = output.Times
|
||||||
|
}
|
||||||
|
for j := 0; j < t; j++ {
|
||||||
|
slice = append(slice, output.Values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
lenOutputs := len(slice)
|
||||||
|
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
|
||||||
|
if needReturn {
|
||||||
|
return GetResultValues(funcType, slice[0]...)
|
||||||
|
}
|
||||||
|
if i < lenOutputs {
|
||||||
|
i++
|
||||||
|
return GetResultValues(funcType, slice[i-1]...)
|
||||||
|
}
|
||||||
|
panic("double seq is less than call seq")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value {
|
||||||
|
var resultValues []reflect.Value
|
||||||
|
for i, r := range results {
|
||||||
|
var resultValue reflect.Value
|
||||||
|
if r == nil {
|
||||||
|
resultValue = reflect.Zero(funcType.Out(i))
|
||||||
|
} else {
|
||||||
|
v := reflect.New(funcType.Out(i))
|
||||||
|
v.Elem().Set(reflect.ValueOf(r))
|
||||||
|
resultValue = v.Elem()
|
||||||
|
}
|
||||||
|
resultValues = append(resultValues, resultValue)
|
||||||
|
}
|
||||||
|
return resultValues
|
||||||
|
}
|
||||||
|
|
||||||
|
type funcValue struct {
|
||||||
|
_ uintptr
|
||||||
|
p unsafe.Pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPointer(v reflect.Value) unsafe.Pointer {
|
||||||
|
return (*funcValue)(unsafe.Pointer(&v)).p
|
||||||
|
}
|
||||||
|
|
||||||
|
func entryAddress(p uintptr, l int) []byte {
|
||||||
|
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func pageStart(ptr uintptr) uintptr {
|
||||||
|
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
|
||||||
|
rf := reflect.TypeOf(doubleFunc)
|
||||||
|
if rf.Kind() != reflect.Func {
|
||||||
|
panic("doubleFunc is not a func")
|
||||||
|
}
|
||||||
|
vf := reflect.ValueOf(doubleFunc)
|
||||||
|
return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
|
||||||
|
return vf.Call(in[1:])
|
||||||
|
})
|
||||||
|
}
|
|
@ -18,6 +18,10 @@ github.com/PuerkitoBio/purell
|
||||||
# github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578
|
# github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578
|
||||||
## explicit
|
## explicit
|
||||||
github.com/PuerkitoBio/urlesc
|
github.com/PuerkitoBio/urlesc
|
||||||
|
# github.com/agiledragon/gomonkey/v2 v2.5.0
|
||||||
|
## explicit; go 1.14
|
||||||
|
github.com/agiledragon/gomonkey/v2
|
||||||
|
github.com/agiledragon/gomonkey/v2/creflect
|
||||||
# github.com/alessio/shellescape v1.4.1
|
# github.com/alessio/shellescape v1.4.1
|
||||||
## explicit; go 1.14
|
## explicit; go 1.14
|
||||||
github.com/alessio/shellescape
|
github.com/alessio/shellescape
|
||||||
|
|
Loading…
Reference in New Issue