Merge branch 'master' into ignore-hyperpod

This commit is contained in:
Yiqing Wang 2025-06-10 13:50:22 -07:00
commit 20810f8d93
103 changed files with 4761 additions and 3488 deletions

View File

@ -19,6 +19,13 @@ package signers
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/auth/credentials" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/auth/credentials"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/errors" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/errors"
@ -26,16 +33,12 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/responses" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/responses"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/utils" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/utils"
"k8s.io/klog/v2" "k8s.io/klog/v2"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
) )
const ( const (
defaultOIDCDurationSeconds = 3600 defaultOIDCDurationSeconds = 3600
oidcTokenFilePath = "ALIBABA_CLOUD_OIDC_TOKEN_FILE"
oldOidcTokenFilePath = "ALICLOUD_OIDC_TOKEN_FILE_PATH"
) )
// OIDCSigner is kind of signer // OIDCSigner is kind of signer
@ -149,7 +152,7 @@ func (signer *OIDCSigner) getOIDCToken(OIDCTokenFilePath string) string {
tokenPath := OIDCTokenFilePath tokenPath := OIDCTokenFilePath
_, err := os.Stat(tokenPath) _, err := os.Stat(tokenPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
tokenPath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") tokenPath = utils.FirstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath))
if tokenPath == "" { if tokenPath == "" {
klog.Error("oidc token file path is missing") klog.Error("oidc token file path is missing")
return "" return ""

View File

@ -22,11 +22,12 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/google/uuid"
"net/url" "net/url"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
"github.com/google/uuid"
) )
/* if you use go 1.10 or higher, you can hack this util by these to avoid "TimeZone.zip not found" on Windows */ /* if you use go 1.10 or higher, you can hack this util by these to avoid "TimeZone.zip not found" on Windows */
@ -127,3 +128,15 @@ func InitStructWithDefaultTag(bean interface{}) {
} }
} }
} }
// FirstNotEmpty returns the first non-empty string from the input list.
// If all strings are empty or no arguments are provided, it returns an empty string.
func FirstNotEmpty(strs ...string) string {
for _, str := range strs {
if str != "" {
return str
}
}
return ""
}

View File

@ -0,0 +1,45 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFirstNotEmpty(t *testing.T) {
// Test case where the first non-empty string is at the beginning
result := FirstNotEmpty("hello", "world", "test")
assert.Equal(t, "hello", result)
// Test case where the first non-empty string is in the middle
result = FirstNotEmpty("", "foo", "bar")
assert.Equal(t, "foo", result)
// Test case where the first non-empty string is at the end
result = FirstNotEmpty("", "", "baz")
assert.Equal(t, "baz", result)
// Test case where all strings are empty
result = FirstNotEmpty("", "", "")
assert.Equal(t, "", result)
// Test case with no arguments
result = FirstNotEmpty()
assert.Equal(t, "", result)
}

View File

@ -19,6 +19,7 @@ package alicloud
import ( import (
"os" "os"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/utils"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/metadata" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/metadata"
"k8s.io/klog/v2" "k8s.io/klog/v2"
) )
@ -63,19 +64,19 @@ func (cc *cloudConfig) isValid() bool {
} }
if cc.OIDCProviderARN == "" { if cc.OIDCProviderARN == "" {
cc.OIDCProviderARN = firstNotEmpty(os.Getenv(oidcProviderARN), os.Getenv(oldOidcProviderARN)) cc.OIDCProviderARN = utils.FirstNotEmpty(os.Getenv(oidcProviderARN), os.Getenv(oldOidcProviderARN))
} }
if cc.OIDCTokenFilePath == "" { if cc.OIDCTokenFilePath == "" {
cc.OIDCTokenFilePath = firstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath)) cc.OIDCTokenFilePath = utils.FirstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath))
} }
if cc.RoleARN == "" { if cc.RoleARN == "" {
cc.RoleARN = firstNotEmpty(os.Getenv(roleARN), os.Getenv(oldRoleARN)) cc.RoleARN = utils.FirstNotEmpty(os.Getenv(roleARN), os.Getenv(oldRoleARN))
} }
if cc.RoleSessionName == "" { if cc.RoleSessionName == "" {
cc.RoleSessionName = firstNotEmpty(os.Getenv(roleSessionName), os.Getenv(oldRoleSessionName)) cc.RoleSessionName = utils.FirstNotEmpty(os.Getenv(roleSessionName), os.Getenv(oldRoleSessionName))
} }
if cc.RegionId != "" && cc.AccessKeyID != "" && cc.AccessKeySecret != "" { if cc.RegionId != "" && cc.AccessKeyID != "" && cc.AccessKeySecret != "" {
@ -133,15 +134,3 @@ func (cc *cloudConfig) getRegion() string {
} }
return r return r
} }
// firstNotEmpty returns the first non-empty string from the input list.
// If all strings are empty or no arguments are provided, it returns an empty string.
func firstNotEmpty(strs ...string) string {
for _, str := range strs {
if str != "" {
return str
}
}
return ""
}

View File

@ -55,25 +55,3 @@ func TestOldRRSACloudConfigIsValid(t *testing.T) {
assert.True(t, cfg.isValid()) assert.True(t, cfg.isValid())
assert.True(t, cfg.RRSAEnabled) assert.True(t, cfg.RRSAEnabled)
} }
func TestFirstNotEmpty(t *testing.T) {
// Test case where the first non-empty string is at the beginning
result := firstNotEmpty("hello", "world", "test")
assert.Equal(t, "hello", result)
// Test case where the first non-empty string is in the middle
result = firstNotEmpty("", "foo", "bar")
assert.Equal(t, "foo", result)
// Test case where the first non-empty string is at the end
result = firstNotEmpty("", "", "baz")
assert.Equal(t, "baz", result)
// Test case where all strings are empty
result = firstNotEmpty("", "", "")
assert.Equal(t, "", result)
// Test case with no arguments
result = firstNotEmpty()
assert.Equal(t, "", result)
}

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"] resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"] - apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"] resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"] - apiGroups: ["batch", "extensions"]
resources: ["jobs"] resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault type: RuntimeDefault
serviceAccountName: cluster-autoscaler serviceAccountName: cluster-autoscaler
containers: containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2 - image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler name: cluster-autoscaler
resources: resources:
limits: limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"] resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"] - apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"] resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"] - apiGroups: ["batch", "extensions"]
resources: ["jobs"] resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault type: RuntimeDefault
serviceAccountName: cluster-autoscaler serviceAccountName: cluster-autoscaler
containers: containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2 - image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler name: cluster-autoscaler
resources: resources:
limits: limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"] resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"] - apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"] resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"] - apiGroups: ["batch", "extensions"]
resources: ["jobs"] resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault type: RuntimeDefault
serviceAccountName: cluster-autoscaler serviceAccountName: cluster-autoscaler
containers: containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2 - image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler name: cluster-autoscaler
resources: resources:
limits: limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"] resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"] - apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"] resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"] verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"] - apiGroups: ["batch", "extensions"]
resources: ["jobs"] resources: ["jobs"]
@ -153,7 +153,7 @@ spec:
nodeSelector: nodeSelector:
kubernetes.io/role: control-plane kubernetes.io/role: control-plane
containers: containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2 - image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler name: cluster-autoscaler
resources: resources:
limits: limits:

View File

@ -20,7 +20,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"testing" "testing"
"time" "time"
@ -422,8 +421,7 @@ func TestDeleteInstances(t *testing.T) {
}, },
}, nil) }, nil)
err = as.DeleteInstances(instances) err = as.DeleteInstances(instances)
expectedErrStr := "The specified account is disabled." assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), expectedErrStr))
} }
func TestAgentPoolDeleteNodes(t *testing.T) { func TestAgentPoolDeleteNodes(t *testing.T) {
@ -478,8 +476,7 @@ func TestAgentPoolDeleteNodes(t *testing.T) {
ObjectMeta: v1.ObjectMeta{Name: "node"}, ObjectMeta: v1.ObjectMeta{Name: "node"},
}, },
}) })
expectedErrStr := "The specified account is disabled." assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), expectedErrStr))
as.minSize = 3 as.minSize = 3
err = as.DeleteNodes([]*apiv1.Node{}) err = as.DeleteNodes([]*apiv1.Node{})

View File

@ -25,6 +25,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to" "github.com/Azure/go-autorest/autorest/to"
"github.com/Azure/skewer" "github.com/Azure/skewer"
@ -67,13 +68,18 @@ type azureCache struct {
// Cache content. // Cache content.
// resourceGroup specifies the name of the resource group that this cache tracks // resourceGroup specifies the name of the node resource group that this cache tracks
resourceGroup string resourceGroup string
clusterResourceGroup string
clusterName string
// enableVMsAgentPool specifies whether VMs agent pool type is supported.
enableVMsAgentPool bool
// vmType can be one of vmTypeVMSS (default), vmTypeStandard // vmType can be one of vmTypeVMSS (default), vmTypeStandard
vmType string vmType string
vmsPoolSet map[string]struct{} // track the nodepools that're vms pool vmsPoolMap map[string]armcontainerservice.AgentPool // track the nodepools that're vms pool
// scaleSets keeps the set of all known scalesets in the resource group, populated/refreshed via VMSS.List() call. // scaleSets keeps the set of all known scalesets in the resource group, populated/refreshed via VMSS.List() call.
// It is only used/populated if vmType is vmTypeVMSS (default). // It is only used/populated if vmType is vmTypeVMSS (default).
@ -106,8 +112,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az
azClient: client, azClient: client,
refreshInterval: cacheTTL, refreshInterval: cacheTTL,
resourceGroup: config.ResourceGroup, resourceGroup: config.ResourceGroup,
clusterResourceGroup: config.ClusterResourceGroup,
clusterName: config.ClusterName,
enableVMsAgentPool: config.EnableVMsAgentPool,
vmType: config.VMType, vmType: config.VMType,
vmsPoolSet: make(map[string]struct{}), vmsPoolMap: make(map[string]armcontainerservice.AgentPool),
scaleSets: make(map[string]compute.VirtualMachineScaleSet), scaleSets: make(map[string]compute.VirtualMachineScaleSet),
virtualMachines: make(map[string][]compute.VirtualMachine), virtualMachines: make(map[string][]compute.VirtualMachine),
registeredNodeGroups: make([]cloudprovider.NodeGroup, 0), registeredNodeGroups: make([]cloudprovider.NodeGroup, 0),
@ -130,11 +139,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az
return cache, nil return cache, nil
} }
func (m *azureCache) getVMsPoolSet() map[string]struct{} { func (m *azureCache) getVMsPoolMap() map[string]armcontainerservice.AgentPool {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.vmsPoolSet return m.vmsPoolMap
} }
func (m *azureCache) getVirtualMachines() map[string][]compute.VirtualMachine { func (m *azureCache) getVirtualMachines() map[string][]compute.VirtualMachine {
@ -232,13 +241,20 @@ func (m *azureCache) fetchAzureResources() error {
return err return err
} }
m.scaleSets = vmssResult m.scaleSets = vmssResult
vmResult, vmsPoolSet, err := m.fetchVirtualMachines() vmResult, err := m.fetchVirtualMachines()
if err != nil { if err != nil {
return err return err
} }
// we fetch both sets of resources since CAS may operate on mixed nodepools // we fetch both sets of resources since CAS may operate on mixed nodepools
m.virtualMachines = vmResult m.virtualMachines = vmResult
m.vmsPoolSet = vmsPoolSet // fetch VMs pools if enabled
if m.enableVMsAgentPool {
vmsPoolMap, err := m.fetchVMsPools()
if err != nil {
return err
}
m.vmsPoolMap = vmsPoolMap
}
return nil return nil
} }
@ -251,19 +267,17 @@ const (
) )
// fetchVirtualMachines returns the updated list of virtual machines in the config resource group using the Azure API. // fetchVirtualMachines returns the updated list of virtual machines in the config resource group using the Azure API.
func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, map[string]struct{}, error) { func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, error) {
ctx, cancel := getContextWithCancel() ctx, cancel := getContextWithCancel()
defer cancel() defer cancel()
result, err := m.azClient.virtualMachinesClient.List(ctx, m.resourceGroup) result, err := m.azClient.virtualMachinesClient.List(ctx, m.resourceGroup)
if err != nil { if err != nil {
klog.Errorf("VirtualMachinesClient.List in resource group %q failed: %v", m.resourceGroup, err) klog.Errorf("VirtualMachinesClient.List in resource group %q failed: %v", m.resourceGroup, err)
return nil, nil, err.Error() return nil, err.Error()
} }
instances := make(map[string][]compute.VirtualMachine) instances := make(map[string][]compute.VirtualMachine)
// track the nodepools that're vms pools
vmsPoolSet := make(map[string]struct{})
for _, instance := range result { for _, instance := range result {
if instance.Tags == nil { if instance.Tags == nil {
continue continue
@ -280,20 +294,43 @@ func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine
} }
instances[to.String(vmPoolName)] = append(instances[to.String(vmPoolName)], instance) instances[to.String(vmPoolName)] = append(instances[to.String(vmPoolName)], instance)
}
// if the nodepool is already in the map, skip it return instances, nil
if _, ok := vmsPoolSet[to.String(vmPoolName)]; ok {
continue
} }
// nodes from vms pool will have tag "aks-managed-agentpool-type" set to "VirtualMachines" // fetchVMsPools returns a name to agentpool map of all the VMs pools in the cluster
if agentpoolType := tags[agentpoolTypeTag]; agentpoolType != nil { func (m *azureCache) fetchVMsPools() (map[string]armcontainerservice.AgentPool, error) {
if strings.EqualFold(to.String(agentpoolType), vmsPoolType) { ctx, cancel := getContextWithTimeout(vmsContextTimeout)
vmsPoolSet[to.String(vmPoolName)] = struct{}{} defer cancel()
// defensive check, should never happen when enableVMsAgentPool toggle is on
if m.azClient.agentPoolClient == nil {
return nil, errors.New("agentPoolClient is nil")
}
vmsPoolMap := make(map[string]armcontainerservice.AgentPool)
pager := m.azClient.agentPoolClient.NewListPager(m.clusterResourceGroup, m.clusterName, nil)
var aps []*armcontainerservice.AgentPool
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
klog.Errorf("agentPoolClient.pager.NextPage in cluster %s resource group %s failed: %v",
m.clusterName, m.clusterResourceGroup, err)
return nil, err
}
aps = append(aps, resp.Value...)
}
for _, ap := range aps {
if ap != nil && ap.Name != nil && ap.Properties != nil && ap.Properties.Type != nil &&
*ap.Properties.Type == armcontainerservice.AgentPoolTypeVirtualMachines {
// we only care about VMs pools, skip other types
klog.V(6).Infof("Found VMs pool %q", *ap.Name)
vmsPoolMap[*ap.Name] = *ap
} }
} }
}
return instances, vmsPoolSet, nil return vmsPoolMap, nil
} }
// fetchScaleSets returns the updated list of scale sets in the config resource group using the Azure API. // fetchScaleSets returns the updated list of scale sets in the config resource group using the Azure API.
@ -422,7 +459,7 @@ func (m *azureCache) HasInstance(providerID string) (bool, error) {
// FindForInstance returns node group of the given Instance // FindForInstance returns node group of the given Instance
func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudprovider.NodeGroup, error) { func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudprovider.NodeGroup, error) {
vmsPoolSet := m.getVMsPoolSet() vmsPoolMap := m.getVMsPoolMap()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -441,7 +478,7 @@ func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudpr
} }
// cluster with vmss pool only // cluster with vmss pool only
if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolSet) == 0 { if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolMap) == 0 {
if m.areAllScaleSetsUniform() { if m.areAllScaleSetsUniform() {
// Omit virtual machines not managed by vmss only in case of uniform scale set. // Omit virtual machines not managed by vmss only in case of uniform scale set.
if ok := virtualMachineRE.Match([]byte(inst.Name)); ok { if ok := virtualMachineRE.Match([]byte(inst.Name)); ok {

View File

@ -22,9 +22,42 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
providerazureconsts "sigs.k8s.io/cloud-provider-azure/pkg/consts" providerazureconsts "sigs.k8s.io/cloud-provider-azure/pkg/consts"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/go-autorest/autorest/to"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
) )
func TestFetchVMsPools(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
provider := newTestProvider(t)
ac := provider.azureManager.azureCache
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ac.azClient.agentPoolClient = mockAgentpoolclient
vmsPool := getTestVMsAgentPool(false)
vmssPoolType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("vmsspool1"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssPoolType,
},
}
invalidPool := armcontainerservice.AgentPool{}
fakeAPListPager := getFakeAgentpoolListPager(&vmsPool, &vmssPool, &invalidPool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
vmsPoolMap, err := ac.fetchVMsPools()
assert.NoError(t, err)
assert.Equal(t, 1, len(vmsPoolMap))
_, ok := vmsPoolMap[to.String(vmsPool.Name)]
assert.True(t, ok)
}
func TestRegister(t *testing.T) { func TestRegister(t *testing.T) {
provider := newTestProvider(t) provider := newTestProvider(t)
ss := newTestScaleSet(provider.azureManager, "ss") ss := newTestScaleSet(provider.azureManager, "ss")

View File

@ -19,6 +19,8 @@ package azure
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"time"
_ "go.uber.org/mock/mockgen/model" // for go:generate _ "go.uber.org/mock/mockgen/model" // for go:generate
@ -29,7 +31,7 @@ import (
azurecore_policy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azurecore_policy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/azure"
@ -47,7 +49,12 @@ import (
providerazureconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" providerazureconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config"
) )
//go:generate sh -c "mockgen k8s.io/autoscaler/cluster-autoscaler/cloudprovider/azure AgentPoolsClient >./agentpool_client.go" //go:generate sh -c "mockgen -source=azure_client.go -destination azure_mock_agentpool_client.go -package azure -exclude_interfaces DeploymentsClient"
const (
vmsContextTimeout = 5 * time.Minute
vmsAsyncContextTimeout = 30 * time.Minute
)
// AgentPoolsClient interface defines the methods needed for scaling vms pool. // AgentPoolsClient interface defines the methods needed for scaling vms pool.
// it is implemented by track2 sdk armcontainerservice.AgentPoolsClient // it is implemented by track2 sdk armcontainerservice.AgentPoolsClient
@ -68,52 +75,89 @@ type AgentPoolsClient interface {
machines armcontainerservice.AgentPoolDeleteMachinesParameter, machines armcontainerservice.AgentPoolDeleteMachinesParameter,
options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) ( options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (
*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) *runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error)
NewListPager(
resourceGroupName, resourceName string,
options *armcontainerservice.AgentPoolsClientListOptions,
) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse]
} }
func getAgentpoolClientCredentials(cfg *Config) (azcore.TokenCredential, error) { func getAgentpoolClientCredentials(cfg *Config) (azcore.TokenCredential, error) {
var cred azcore.TokenCredential if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal {
var err error // Use MSI
if cfg.AuthMethod == authMethodCLI { if cfg.UseManagedIdentityExtension {
cred, err = azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ // Use System Assigned MSI
TenantID: cfg.TenantID}) if cfg.UserAssignedIdentityID == "" {
if err != nil { klog.V(4).Info("Agentpool client: using System Assigned MSI to retrieve access token")
klog.Errorf("NewAzureCLICredential failed: %v", err) return azidentity.NewManagedIdentityCredential(nil)
return nil, err
} }
} else if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal { // Use User Assigned MSI
cred, err = azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil) klog.V(4).Info("Agentpool client: using User Assigned MSI to retrieve access token")
if err != nil { return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
klog.Errorf("NewClientSecretCredential failed: %v", err) ID: azidentity.ClientID(cfg.UserAssignedIdentityID),
return nil, err })
}
} else {
return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod)
}
return cred, nil
} }
func getAgentpoolClientRetryOptions(cfg *Config) azurecore_policy.RetryOptions { // Use Service Principal with ClientID and ClientSecret
if cfg.AuthMethod == authMethodCLI { if cfg.AADClientID != "" && cfg.AADClientSecret != "" {
return azurecore_policy.RetryOptions{ klog.V(2).Infoln("Agentpool client: using client_id+client_secret to retrieve access token")
MaxRetries: -1, // no retry when using CLI auth for UT return azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil)
}
// Use Service Principal with ClientCert and AADClientCertPassword
if cfg.AADClientID != "" && cfg.AADClientCertPath != "" {
klog.V(2).Infoln("Agentpool client: using client_cert+client_private_key to retrieve access token")
certData, err := os.ReadFile(cfg.AADClientCertPath)
if err != nil {
return nil, fmt.Errorf("reading the client certificate from file %s failed with error: %w", cfg.AADClientCertPath, err)
}
certs, privateKey, err := azidentity.ParseCertificates(certData, []byte(cfg.AADClientCertPassword))
if err != nil {
return nil, fmt.Errorf("parsing service principal certificate data failed with error: %w", err)
}
return azidentity.NewClientCertificateCredential(cfg.TenantID, cfg.AADClientID, certs, privateKey, &azidentity.ClientCertificateCredentialOptions{
SendCertificateChain: true,
})
} }
} }
return azextensions.DefaultRetryOpts()
if cfg.UseFederatedWorkloadIdentityExtension {
klog.V(4).Info("Agentpool client: using workload identity for access token")
return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
TokenFilePath: cfg.AADFederatedTokenFile,
})
}
return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod)
} }
func newAgentpoolClient(cfg *Config) (AgentPoolsClient, error) { func newAgentpoolClient(cfg *Config) (AgentPoolsClient, error) {
retryOptions := getAgentpoolClientRetryOptions(cfg) retryOptions := azextensions.DefaultRetryOpts()
cred, err := getAgentpoolClientCredentials(cfg)
if err != nil {
klog.Errorf("failed to get agent pool client credentials: %v", err)
return nil, err
}
env := azure.PublicCloud // default to public cloud
if cfg.Cloud != "" {
var err error
env, err = azure.EnvironmentFromName(cfg.Cloud)
if err != nil {
klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err)
return nil, err
}
}
if cfg.ARMBaseURLForAPClient != "" { if cfg.ARMBaseURLForAPClient != "" {
klog.V(10).Infof("Using ARMBaseURLForAPClient to create agent pool client") klog.V(10).Infof("Using ARMBaseURLForAPClient to create agent pool client")
return newAgentpoolClientWithConfig(cfg.SubscriptionID, nil, cfg.ARMBaseURLForAPClient, "UNKNOWN", retryOptions) return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, cfg.ARMBaseURLForAPClient, env.TokenAudience, retryOptions, true /*insecureAllowCredentialWithHTTP*/)
} }
return newAgentpoolClientWithPublicEndpoint(cfg, retryOptions) return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions, false /*insecureAllowCredentialWithHTTP*/)
} }
func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCredential, func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCredential,
cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) { cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions, insecureAllowCredentialWithHTTP bool) (AgentPoolsClient, error) {
agentPoolsClient, err := armcontainerservice.NewAgentPoolsClient(subscriptionID, cred, agentPoolsClient, err := armcontainerservice.NewAgentPoolsClient(subscriptionID, cred,
&policy.ClientOptions{ &policy.ClientOptions{
ClientOptions: azurecore_policy.ClientOptions{ ClientOptions: azurecore_policy.ClientOptions{
@ -125,6 +169,7 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden
}, },
}, },
}, },
InsecureAllowCredentialWithHTTP: insecureAllowCredentialWithHTTP,
Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()), Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()),
Transport: azextensions.DefaultHTTPClient(), Transport: azextensions.DefaultHTTPClient(),
Retry: retryOptions, Retry: retryOptions,
@ -139,26 +184,6 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden
return agentPoolsClient, nil return agentPoolsClient, nil
} }
func newAgentpoolClientWithPublicEndpoint(cfg *Config, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) {
cred, err := getAgentpoolClientCredentials(cfg)
if err != nil {
klog.Errorf("failed to get agent pool client credentials: %v", err)
return nil, err
}
// default to public cloud
env := azure.PublicCloud
if cfg.Cloud != "" {
env, err = azure.EnvironmentFromName(cfg.Cloud)
if err != nil {
klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err)
return nil, err
}
}
return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions)
}
type azClient struct { type azClient struct {
virtualMachineScaleSetsClient vmssclient.Interface virtualMachineScaleSetsClient vmssclient.Interface
virtualMachineScaleSetVMsClient vmssvmclient.Interface virtualMachineScaleSetVMsClient vmssvmclient.Interface
@ -232,9 +257,11 @@ func newAzClient(cfg *Config, env *azure.Environment) (*azClient, error) {
agentPoolClient, err := newAgentpoolClient(cfg) agentPoolClient, err := newAgentpoolClient(cfg)
if err != nil { if err != nil {
// we don't want to fail the whole process so we don't break any existing functionality klog.Errorf("newAgentpoolClient failed with error: %s", err)
// since this may not be fatal - it is only used by vms pool which is still under development. if cfg.EnableVMsAgentPool {
klog.Warningf("newAgentpoolClient failed with error: %s", err) // only return error if VMs agent pool is supported which is controlled by toggle
return nil, err
}
} }
return &azClient{ return &azClient{

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources" "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
"github.com/Azure/go-autorest/autorest/to" "github.com/Azure/go-autorest/autorest/to"
@ -132,7 +133,7 @@ func TestNodeGroups(t *testing.T) {
) )
assert.True(t, registered) assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup( registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"), newTestVMsPool(provider.azureManager),
) )
assert.True(t, registered) assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2) assert.Equal(t, len(provider.NodeGroups()), 2)
@ -146,9 +147,14 @@ func TestHasInstance(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl) mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient
provider.azureManager.azureCache.clusterName = "test-cluster"
provider.azureManager.azureCache.clusterResourceGroup = "test-rg"
provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types
// Simulate node groups and instances // Simulate node groups and instances
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform) expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform)
@ -158,6 +164,20 @@ func TestHasInstance(t *testing.T) {
mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil).AnyTimes() mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil).AnyTimes()
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes() mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes()
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("test-asg"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssType,
},
}
vmsPool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool)
mockAgentpoolclient.EXPECT().NewListPager(
provider.azureManager.azureCache.clusterResourceGroup,
provider.azureManager.azureCache.clusterName, nil).
Return(fakeAPListPager).AnyTimes()
// Register node groups // Register node groups
assert.Equal(t, len(provider.NodeGroups()), 0) assert.Equal(t, len(provider.NodeGroups()), 0)
@ -168,9 +188,9 @@ func TestHasInstance(t *testing.T) {
assert.True(t, registered) assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup( registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"), newTestVMsPool(provider.azureManager),
) )
provider.azureManager.explicitlyConfigured["test-vms-pool"] = true provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true
assert.True(t, registered) assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2) assert.Equal(t, len(provider.NodeGroups()), 2)
@ -264,9 +284,14 @@ func TestMixedNodeGroups(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl) mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
provider.azureManager.azureCache.clusterName = "test-cluster"
provider.azureManager.azureCache.clusterResourceGroup = "test-rg"
provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types
provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform) expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform)
expectedVMsPoolVMs := newTestVMsPoolVMList(3) expectedVMsPoolVMs := newTestVMsPoolVMList(3)
@ -276,6 +301,19 @@ func TestMixedNodeGroups(t *testing.T) {
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes() mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes()
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("test-asg"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssType,
},
}
vmsPool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool)
mockAgentpoolclient.EXPECT().NewListPager(provider.azureManager.azureCache.clusterResourceGroup, provider.azureManager.azureCache.clusterName, nil).
Return(fakeAPListPager).AnyTimes()
assert.Equal(t, len(provider.NodeGroups()), 0) assert.Equal(t, len(provider.NodeGroups()), 0)
registered := provider.azureManager.RegisterNodeGroup( registered := provider.azureManager.RegisterNodeGroup(
newTestScaleSet(provider.azureManager, "test-asg"), newTestScaleSet(provider.azureManager, "test-asg"),
@ -284,9 +322,9 @@ func TestMixedNodeGroups(t *testing.T) {
assert.True(t, registered) assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup( registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"), newTestVMsPool(provider.azureManager),
) )
provider.azureManager.explicitlyConfigured["test-vms-pool"] = true provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true
assert.True(t, registered) assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2) assert.Equal(t, len(provider.NodeGroups()), 2)
@ -307,7 +345,7 @@ func TestMixedNodeGroups(t *testing.T) {
group, err = provider.NodeGroupForNode(vmsPoolNode) group, err = provider.NodeGroupForNode(vmsPoolNode)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, group, "Group should not be nil") assert.NotNil(t, group, "Group should not be nil")
assert.Equal(t, group.Id(), "test-vms-pool") assert.Equal(t, group.Id(), vmsNodeGroupName)
assert.Equal(t, group.MinSize(), 3) assert.Equal(t, group.MinSize(), 3)
assert.Equal(t, group.MaxSize(), 10) assert.Equal(t, group.MaxSize(), 10)
} }

View File

@ -86,6 +86,9 @@ type Config struct {
// EnableForceDelete defines whether to enable force deletion on the APIs // EnableForceDelete defines whether to enable force deletion on the APIs
EnableForceDelete bool `json:"enableForceDelete,omitempty" yaml:"enableForceDelete,omitempty"` EnableForceDelete bool `json:"enableForceDelete,omitempty" yaml:"enableForceDelete,omitempty"`
// EnableVMsAgentPool defines whether to support VMs agentpool type in addition to VMSS type
EnableVMsAgentPool bool `json:"enableVMsAgentPool,omitempty" yaml:"enableVMsAgentPool,omitempty"`
// (DEPRECATED, DO NOT USE) EnableDynamicInstanceList defines whether to enable dynamic instance workflow for instance information check // (DEPRECATED, DO NOT USE) EnableDynamicInstanceList defines whether to enable dynamic instance workflow for instance information check
EnableDynamicInstanceList bool `json:"enableDynamicInstanceList,omitempty" yaml:"enableDynamicInstanceList,omitempty"` EnableDynamicInstanceList bool `json:"enableDynamicInstanceList,omitempty" yaml:"enableDynamicInstanceList,omitempty"`
@ -122,6 +125,7 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) {
// Static defaults // Static defaults
cfg.EnableDynamicInstanceList = false cfg.EnableDynamicInstanceList = false
cfg.EnableVmssFlexNodes = false cfg.EnableVmssFlexNodes = false
cfg.EnableVMsAgentPool = false
cfg.CloudProviderBackoffRetries = providerazureconsts.BackoffRetriesDefault cfg.CloudProviderBackoffRetries = providerazureconsts.BackoffRetriesDefault
cfg.CloudProviderBackoffExponent = providerazureconsts.BackoffExponentDefault cfg.CloudProviderBackoffExponent = providerazureconsts.BackoffExponentDefault
cfg.CloudProviderBackoffDuration = providerazureconsts.BackoffDurationDefault cfg.CloudProviderBackoffDuration = providerazureconsts.BackoffDurationDefault
@ -257,6 +261,9 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) {
if _, err = assignBoolFromEnvIfExists(&cfg.StrictCacheUpdates, "AZURE_STRICT_CACHE_UPDATES"); err != nil { if _, err = assignBoolFromEnvIfExists(&cfg.StrictCacheUpdates, "AZURE_STRICT_CACHE_UPDATES"); err != nil {
return nil, err return nil, err
} }
if _, err = assignBoolFromEnvIfExists(&cfg.EnableVMsAgentPool, "AZURE_ENABLE_VMS_AGENT_POOLS"); err != nil {
return nil, err
}
if _, err = assignBoolFromEnvIfExists(&cfg.EnableDynamicInstanceList, "AZURE_ENABLE_DYNAMIC_INSTANCE_LIST"); err != nil { if _, err = assignBoolFromEnvIfExists(&cfg.EnableDynamicInstanceList, "AZURE_ENABLE_DYNAMIC_INSTANCE_LIST"); err != nil {
return nil, err return nil, err
} }

View File

@ -22,80 +22,79 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"k8s.io/klog/v2" "k8s.io/klog/v2"
) )
// GetVMSSTypeStatically uses static list of vmss generated at azure_instance_types.go to fetch vmss instance information. // GetInstanceTypeStatically uses static list of vmss generated at azure_instance_types.go to fetch vmss instance information.
// It is declared as a variable for testing purpose. // It is declared as a variable for testing purpose.
var GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { var GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
var vmssType *InstanceType var instanceType *InstanceType
for k := range InstanceTypes { for k := range InstanceTypes {
if strings.EqualFold(k, *template.Sku.Name) { if strings.EqualFold(k, template.SkuName) {
vmssType = InstanceTypes[k] instanceType = InstanceTypes[k]
break break
} }
} }
promoRe := regexp.MustCompile(`(?i)_promo`) promoRe := regexp.MustCompile(`(?i)_promo`)
if promoRe.MatchString(*template.Sku.Name) { if promoRe.MatchString(template.SkuName) {
if vmssType == nil { if instanceType == nil {
// We didn't find an exact match but this is a promo type, check for matching standard // We didn't find an exact match but this is a promo type, check for matching standard
klog.V(4).Infof("No exact match found for %s, checking standard types", *template.Sku.Name) klog.V(4).Infof("No exact match found for %s, checking standard types", template.SkuName)
skuName := promoRe.ReplaceAllString(*template.Sku.Name, "") skuName := promoRe.ReplaceAllString(template.SkuName, "")
for k := range InstanceTypes { for k := range InstanceTypes {
if strings.EqualFold(k, skuName) { if strings.EqualFold(k, skuName) {
vmssType = InstanceTypes[k] instanceType = InstanceTypes[k]
break break
} }
} }
} }
} }
if vmssType == nil { if instanceType == nil {
return vmssType, fmt.Errorf("instance type %q not supported", *template.Sku.Name) return instanceType, fmt.Errorf("instance type %q not supported", template.SkuName)
} }
return vmssType, nil return instanceType, nil
} }
// GetVMSSTypeDynamically fetched vmss instance information using sku api calls. // GetInstanceTypeDynamically fetched vmss instance information using sku api calls.
// It is declared as a variable for testing purpose. // It is declared as a variable for testing purpose.
var GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { var GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
ctx := context.Background() ctx := context.Background()
var vmssType InstanceType var instanceType InstanceType
sku, err := azCache.GetSKU(ctx, *template.Sku.Name, *template.Location) sku, err := azCache.GetSKU(ctx, template.SkuName, template.Location)
if err != nil { if err != nil {
// We didn't find an exact match but this is a promo type, check for matching standard // We didn't find an exact match but this is a promo type, check for matching standard
promoRe := regexp.MustCompile(`(?i)_promo`) promoRe := regexp.MustCompile(`(?i)_promo`)
skuName := promoRe.ReplaceAllString(*template.Sku.Name, "") skuName := promoRe.ReplaceAllString(template.SkuName, "")
if skuName != *template.Sku.Name { if skuName != template.SkuName {
klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", *template.Sku.Name, skuName, err) klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", template.SkuName, skuName, err)
sku, err = azCache.GetSKU(ctx, skuName, *template.Location) sku, err = azCache.GetSKU(ctx, skuName, template.Location)
} }
if err != nil { if err != nil {
return vmssType, fmt.Errorf("instance type %q not supported. Error %v", *template.Sku.Name, err) return instanceType, fmt.Errorf("instance type %q not supported. Error %v", template.SkuName, err)
} }
} }
vmssType.VCPU, err = sku.VCPU() instanceType.VCPU, err = sku.VCPU()
if err != nil { if err != nil {
klog.V(1).Infof("Failed to parse vcpu from sku %q %v", *template.Sku.Name, err) klog.V(1).Infof("Failed to parse vcpu from sku %q %v", template.SkuName, err)
return vmssType, err return instanceType, err
} }
gpu, err := getGpuFromSku(sku) gpu, err := getGpuFromSku(sku)
if err != nil { if err != nil {
klog.V(1).Infof("Failed to parse gpu from sku %q %v", *template.Sku.Name, err) klog.V(1).Infof("Failed to parse gpu from sku %q %v", template.SkuName, err)
return vmssType, err return instanceType, err
} }
vmssType.GPU = gpu instanceType.GPU = gpu
memoryGb, err := sku.Memory() memoryGb, err := sku.Memory()
if err != nil { if err != nil {
klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", *template.Sku.Name, err) klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", template.SkuName, err)
return vmssType, err return instanceType, err
} }
vmssType.MemoryMb = int64(memoryGb) * 1024 instanceType.MemoryMb = int64(memoryGb) * 1024
return vmssType, nil return instanceType, nil
} }

View File

@ -168,6 +168,23 @@ func (m *AzureManager) fetchExplicitNodeGroups(specs []string) error {
return nil return nil
} }
// parseSKUAndVMsAgentpoolNameFromSpecName parses the spec name for a mixed-SKU VMs pool.
// The spec name should be in the format <agentpoolname>/<sku>, e.g., "mypool1/Standard_D2s_v3", if the agent pool is a VMs pool.
// This method returns a boolean indicating if the agent pool is a VMs pool, along with the agent pool name and SKU.
func (m *AzureManager) parseSKUAndVMsAgentpoolNameFromSpecName(name string) (bool, string, string) {
parts := strings.Split(name, "/")
if len(parts) == 2 {
agentPoolName := parts[0]
sku := parts[1]
vmsPoolMap := m.azureCache.getVMsPoolMap()
if _, ok := vmsPoolMap[agentPoolName]; ok {
return true, agentPoolName, sku
}
}
return false, "", ""
}
func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGroup, error) { func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGroup, error) {
scaleToZeroSupported := scaleToZeroSupportedStandard scaleToZeroSupported := scaleToZeroSupportedStandard
if strings.EqualFold(m.config.VMType, providerazureconsts.VMTypeVMSS) { if strings.EqualFold(m.config.VMType, providerazureconsts.VMTypeVMSS) {
@ -177,9 +194,13 @@ func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGr
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse node group spec: %v", err) return nil, fmt.Errorf("failed to parse node group spec: %v", err)
} }
vmsPoolSet := m.azureCache.getVMsPoolSet()
if _, ok := vmsPoolSet[s.Name]; ok { // Starting from release 1.30, a cluster may have both VMSS and VMs pools.
return NewVMsPool(s, m), nil // Therefore, we cannot solely rely on the VMType to determine the node group type.
// Instead, we need to check the cache to determine if the agent pool is a VMs pool.
isVMsPool, agentPoolName, sku := m.parseSKUAndVMsAgentpoolNameFromSpecName(s.Name)
if isVMsPool {
return NewVMPool(s, m, agentPoolName, sku)
} }
switch m.config.VMType { switch m.config.VMType {

View File

@ -297,6 +297,7 @@ func TestCreateAzureManagerValidConfig(t *testing.T) {
VmssVmsCacheJitter: 120, VmssVmsCacheJitter: 120,
MaxDeploymentsCount: 8, MaxDeploymentsCount: 8,
EnableFastDeleteOnFailedProvisioning: true, EnableFastDeleteOnFailedProvisioning: true,
EnableVMsAgentPool: false,
} }
assert.NoError(t, err) assert.NoError(t, err)
@ -618,9 +619,14 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMSSClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachineScaleSet{}, nil).AnyTimes() mockVMSSClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachineScaleSet{}, nil).AnyTimes()
mockVMClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachine{}, nil).AnyTimes() mockVMClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachine{}, nil).AnyTimes()
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
vmspool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmspool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager).AnyTimes()
mockAzClient := &azClient{ mockAzClient := &azClient{
virtualMachinesClient: mockVMClient, virtualMachinesClient: mockVMClient,
virtualMachineScaleSetsClient: mockVMSSClient, virtualMachineScaleSetsClient: mockVMSSClient,
agentPoolClient: mockAgentpoolclient,
} }
expectedConfig := &Config{ expectedConfig := &Config{
@ -702,6 +708,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
VmssVmsCacheJitter: 90, VmssVmsCacheJitter: 90,
MaxDeploymentsCount: 8, MaxDeploymentsCount: 8,
EnableFastDeleteOnFailedProvisioning: true, EnableFastDeleteOnFailedProvisioning: true,
EnableVMsAgentPool: true,
} }
t.Setenv("ARM_CLOUD", "AzurePublicCloud") t.Setenv("ARM_CLOUD", "AzurePublicCloud")
@ -735,6 +742,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
t.Setenv("ARM_CLUSTER_RESOURCE_GROUP", "myrg") t.Setenv("ARM_CLUSTER_RESOURCE_GROUP", "myrg")
t.Setenv("ARM_BASE_URL_FOR_AP_CLIENT", "nodeprovisioner-svc.nodeprovisioner.svc.cluster.local") t.Setenv("ARM_BASE_URL_FOR_AP_CLIENT", "nodeprovisioner-svc.nodeprovisioner.svc.cluster.local")
t.Setenv("AZURE_ENABLE_FAST_DELETE_ON_FAILED_PROVISIONING", "true") t.Setenv("AZURE_ENABLE_FAST_DELETE_ON_FAILED_PROVISIONING", "true")
t.Setenv("AZURE_ENABLE_VMS_AGENT_POOLS", "true")
t.Run("environment variables correctly set", func(t *testing.T) { t.Run("environment variables correctly set", func(t *testing.T) {
manager, err := createAzureManagerInternal(nil, cloudprovider.NodeGroupDiscoveryOptions{}, mockAzClient) manager, err := createAzureManagerInternal(nil, cloudprovider.NodeGroupDiscoveryOptions{}, mockAzClient)

View File

@ -21,7 +21,7 @@ import (
reflect "reflect" reflect "reflect"
runtime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" runtime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4" armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@ -49,46 +49,60 @@ func (m *MockAgentPoolsClient) EXPECT() *MockAgentPoolsClientMockRecorder {
} }
// BeginCreateOrUpdate mocks base method. // BeginCreateOrUpdate mocks base method.
func (m *MockAgentPoolsClient) BeginCreateOrUpdate(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPool, arg5 *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) { func (m *MockAgentPoolsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, parameters armcontainerservice.AgentPool, options *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginCreateOrUpdate", arg0, arg1, arg2, arg3, arg4, arg5) ret := m.ctrl.Call(m, "BeginCreateOrUpdate", ctx, resourceGroupName, resourceName, agentPoolName, parameters, options)
ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]) ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse])
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// BeginCreateOrUpdate indicates an expected call of BeginCreateOrUpdate. // BeginCreateOrUpdate indicates an expected call of BeginCreateOrUpdate.
func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(ctx, resourceGroupName, resourceName, agentPoolName, parameters, options any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), arg0, arg1, arg2, arg3, arg4, arg5) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), ctx, resourceGroupName, resourceName, agentPoolName, parameters, options)
} }
// BeginDeleteMachines mocks base method. // BeginDeleteMachines mocks base method.
func (m *MockAgentPoolsClient) BeginDeleteMachines(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPoolDeleteMachinesParameter, arg5 *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) { func (m *MockAgentPoolsClient) BeginDeleteMachines(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, machines armcontainerservice.AgentPoolDeleteMachinesParameter, options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginDeleteMachines", arg0, arg1, arg2, arg3, arg4, arg5) ret := m.ctrl.Call(m, "BeginDeleteMachines", ctx, resourceGroupName, resourceName, agentPoolName, machines, options)
ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]) ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse])
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// BeginDeleteMachines indicates an expected call of BeginDeleteMachines. // BeginDeleteMachines indicates an expected call of BeginDeleteMachines.
func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(ctx, resourceGroupName, resourceName, agentPoolName, machines, options any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), arg0, arg1, arg2, arg3, arg4, arg5) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), ctx, resourceGroupName, resourceName, agentPoolName, machines, options)
} }
// Get mocks base method. // Get mocks base method.
func (m *MockAgentPoolsClient) Get(arg0 context.Context, arg1, arg2, arg3 string, arg4 *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) { func (m *MockAgentPoolsClient) Get(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, options *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, resourceName, agentPoolName, options)
ret0, _ := ret[0].(armcontainerservice.AgentPoolsClientGetResponse) ret0, _ := ret[0].(armcontainerservice.AgentPoolsClientGetResponse)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// Get indicates an expected call of Get. // Get indicates an expected call of Get.
func (mr *MockAgentPoolsClientMockRecorder) Get(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { func (mr *MockAgentPoolsClientMockRecorder) Get(ctx, resourceGroupName, resourceName, agentPoolName, options any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), ctx, resourceGroupName, resourceName, agentPoolName, options)
}
// NewListPager mocks base method.
func (m *MockAgentPoolsClient) NewListPager(resourceGroupName, resourceName string, options *armcontainerservice.AgentPoolsClientListOptions) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, resourceName, options)
ret0, _ := ret[0].(*runtime.Pager[armcontainerservice.AgentPoolsClientListResponse])
return ret0
}
// NewListPager indicates an expected call of NewListPager.
func (mr *MockAgentPoolsClientMockRecorder) NewListPager(resourceGroupName, resourceName, options any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockAgentPoolsClient)(nil).NewListPager), resourceGroupName, resourceName, options)
} }

View File

@ -651,15 +651,18 @@ func (scaleSet *ScaleSet) Debug() string {
// TemplateNodeInfo returns a node template for this scale set. // TemplateNodeInfo returns a node template for this scale set.
func (scaleSet *ScaleSet) TemplateNodeInfo() (*framework.NodeInfo, error) { func (scaleSet *ScaleSet) TemplateNodeInfo() (*framework.NodeInfo, error) {
template, err := scaleSet.getVMSSFromCache() vmss, err := scaleSet.getVMSSFromCache()
if err != nil { if err != nil {
return nil, err return nil, err
} }
inputLabels := map[string]string{} inputLabels := map[string]string{}
inputTaints := "" inputTaints := ""
node, err := buildNodeFromTemplate(scaleSet.Name, inputLabels, inputTaints, template, scaleSet.manager, scaleSet.enableDynamicInstanceList) template, err := buildNodeTemplateFromVMSS(vmss, inputLabels, inputTaints)
if err != nil {
return nil, err
}
node, err := buildNodeFromTemplate(scaleSet.Name, template, scaleSet.manager, scaleSet.enableDynamicInstanceList)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1232,12 +1232,12 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
// Properly testing dynamic SKU list through skewer is not possible, // Properly testing dynamic SKU list through skewer is not possible,
// because there are no Resource API mocks included yet. // because there are no Resource API mocks included yet.
// Instead, the rest of the (consumer side) tests here // Instead, the rest of the (consumer side) tests here
// override GetVMSSTypeDynamically and GetVMSSTypeStatically functions. // override GetInstanceTypeDynamically and GetInstanceTypeStatically functions.
t.Run("Checking dynamic workflow", func(t *testing.T) { t.Run("Checking dynamic workflow", func(t *testing.T) {
asg.enableDynamicInstanceList = true asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
vmssType := InstanceType{} vmssType := InstanceType{}
vmssType.VCPU = 1 vmssType.VCPU = 1
vmssType.GPU = 2 vmssType.GPU = 2
@ -1255,10 +1255,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Checking static workflow if dynamic fails", func(t *testing.T) { t.Run("Checking static workflow if dynamic fails", func(t *testing.T) {
asg.enableDynamicInstanceList = true asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
return InstanceType{}, fmt.Errorf("dynamic error exists") return InstanceType{}, fmt.Errorf("dynamic error exists")
} }
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
vmssType := InstanceType{} vmssType := InstanceType{}
vmssType.VCPU = 1 vmssType.VCPU = 1
vmssType.GPU = 2 vmssType.GPU = 2
@ -1276,10 +1276,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Fails to find vmss instance information using static and dynamic workflow, instance not supported", func(t *testing.T) { t.Run("Fails to find vmss instance information using static and dynamic workflow, instance not supported", func(t *testing.T) {
asg.enableDynamicInstanceList = true asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
return InstanceType{}, fmt.Errorf("dynamic error exists") return InstanceType{}, fmt.Errorf("dynamic error exists")
} }
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
return &InstanceType{}, fmt.Errorf("static error exists") return &InstanceType{}, fmt.Errorf("static error exists")
} }
nodeInfo, err := asg.TemplateNodeInfo() nodeInfo, err := asg.TemplateNodeInfo()
@ -1292,7 +1292,7 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Checking static-only workflow", func(t *testing.T) { t.Run("Checking static-only workflow", func(t *testing.T) {
asg.enableDynamicInstanceList = false asg.enableDynamicInstanceList = false
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
vmssType := InstanceType{} vmssType := InstanceType{}
vmssType.VCPU = 1 vmssType.VCPU = 1
vmssType.GPU = 2 vmssType.GPU = 2

View File

@ -24,7 +24,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
apiv1 "k8s.io/api/core/v1" apiv1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -84,8 +86,132 @@ const (
clusterLabelKey = AKSLabelKeyPrefixValue + "cluster" clusterLabelKey = AKSLabelKeyPrefixValue + "cluster"
) )
func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, inputTaints string, // VMPoolNodeTemplate holds properties for node from VMPool
template compute.VirtualMachineScaleSet, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) { type VMPoolNodeTemplate struct {
AgentPoolName string
Taints []apiv1.Taint
Labels map[string]*string
OSDiskType *armcontainerservice.OSDiskType
}
// VMSSNodeTemplate holds properties for node from VMSS
type VMSSNodeTemplate struct {
InputLabels map[string]string
InputTaints string
Tags map[string]*string
OSDisk *compute.VirtualMachineScaleSetOSDisk
}
// NodeTemplate represents a template for an Azure node
type NodeTemplate struct {
SkuName string
InstanceOS string
Location string
Zones []string
VMPoolNodeTemplate *VMPoolNodeTemplate
VMSSNodeTemplate *VMSSNodeTemplate
}
func buildNodeTemplateFromVMSS(vmss compute.VirtualMachineScaleSet, inputLabels map[string]string, inputTaints string) (NodeTemplate, error) {
instanceOS := cloudprovider.DefaultOS
if vmss.VirtualMachineProfile != nil &&
vmss.VirtualMachineProfile.OsProfile != nil &&
vmss.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil {
instanceOS = "windows"
}
var osDisk *compute.VirtualMachineScaleSetOSDisk
if vmss.VirtualMachineProfile != nil &&
vmss.VirtualMachineProfile.StorageProfile != nil &&
vmss.VirtualMachineProfile.StorageProfile.OsDisk != nil {
osDisk = vmss.VirtualMachineProfile.StorageProfile.OsDisk
}
if vmss.Sku == nil || vmss.Sku.Name == nil {
return NodeTemplate{}, fmt.Errorf("VMSS %s has no SKU", to.String(vmss.Name))
}
if vmss.Location == nil {
return NodeTemplate{}, fmt.Errorf("VMSS %s has no location", to.String(vmss.Name))
}
zones := []string{}
if vmss.Zones != nil {
zones = *vmss.Zones
}
return NodeTemplate{
SkuName: *vmss.Sku.Name,
Location: *vmss.Location,
Zones: zones,
InstanceOS: instanceOS,
VMSSNodeTemplate: &VMSSNodeTemplate{
InputLabels: inputLabels,
InputTaints: inputTaints,
OSDisk: osDisk,
Tags: vmss.Tags,
},
}, nil
}
func buildNodeTemplateFromVMPool(vmsPool armcontainerservice.AgentPool, location string, skuName string, labelsFromSpec map[string]string, taintsFromSpec string) (NodeTemplate, error) {
if vmsPool.Properties == nil {
return NodeTemplate{}, fmt.Errorf("vmsPool %s has nil properties", to.String(vmsPool.Name))
}
// labels from the agentpool
labels := vmsPool.Properties.NodeLabels
// labels from spec
for k, v := range labelsFromSpec {
if labels == nil {
labels = make(map[string]*string)
}
labels[k] = to.StringPtr(v)
}
// taints from the agentpool
taintsList := []string{}
for _, taint := range vmsPool.Properties.NodeTaints {
if to.String(taint) != "" {
taintsList = append(taintsList, to.String(taint))
}
}
// taints from spec
if taintsFromSpec != "" {
taintsList = append(taintsList, taintsFromSpec)
}
taintsStr := strings.Join(taintsList, ",")
taints := extractTaintsFromSpecString(taintsStr)
var zones []string
if vmsPool.Properties.AvailabilityZones != nil {
for _, zone := range vmsPool.Properties.AvailabilityZones {
if zone != nil {
zones = append(zones, *zone)
}
}
}
var instanceOS string
if vmsPool.Properties.OSType != nil {
instanceOS = strings.ToLower(string(*vmsPool.Properties.OSType))
}
return NodeTemplate{
SkuName: skuName,
Zones: zones,
InstanceOS: instanceOS,
Location: location,
VMPoolNodeTemplate: &VMPoolNodeTemplate{
AgentPoolName: to.String(vmsPool.Name),
OSDiskType: vmsPool.Properties.OSDiskType,
Taints: taints,
Labels: labels,
},
}, nil
}
func buildNodeFromTemplate(nodeGroupName string, template NodeTemplate, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) {
node := apiv1.Node{} node := apiv1.Node{}
nodeName := fmt.Sprintf("%s-asg-%d", nodeGroupName, rand.Int63()) nodeName := fmt.Sprintf("%s-asg-%d", nodeGroupName, rand.Int63())
@ -104,28 +230,28 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
// Fetching SKU information from SKU API if enableDynamicInstanceList is true. // Fetching SKU information from SKU API if enableDynamicInstanceList is true.
var dynamicErr error var dynamicErr error
if enableDynamicInstanceList { if enableDynamicInstanceList {
var vmssTypeDynamic InstanceType var instanceTypeDynamic InstanceType
klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", *template.Sku.Name) klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", template.SkuName)
vmssTypeDynamic, dynamicErr = GetVMSSTypeDynamically(template, manager.azureCache) instanceTypeDynamic, dynamicErr = GetInstanceTypeDynamically(template, manager.azureCache)
if dynamicErr == nil { if dynamicErr == nil {
vcpu = vmssTypeDynamic.VCPU vcpu = instanceTypeDynamic.VCPU
gpuCount = vmssTypeDynamic.GPU gpuCount = instanceTypeDynamic.GPU
memoryMb = vmssTypeDynamic.MemoryMb memoryMb = instanceTypeDynamic.MemoryMb
} else { } else {
klog.Errorf("Dynamically fetching of instance information from SKU api failed with error: %v", dynamicErr) klog.Errorf("Dynamically fetching of instance information from SKU api failed with error: %v", dynamicErr)
} }
} }
if !enableDynamicInstanceList || dynamicErr != nil { if !enableDynamicInstanceList || dynamicErr != nil {
klog.V(1).Infof("Falling back to static SKU list for SKU: %s", *template.Sku.Name) klog.V(1).Infof("Falling back to static SKU list for SKU: %s", template.SkuName)
// fall-back on static list of vmss if dynamic workflow fails. // fall-back on static list of vmss if dynamic workflow fails.
vmssTypeStatic, staticErr := GetVMSSTypeStatically(template) instanceTypeStatic, staticErr := GetInstanceTypeStatically(template)
if staticErr == nil { if staticErr == nil {
vcpu = vmssTypeStatic.VCPU vcpu = instanceTypeStatic.VCPU
gpuCount = vmssTypeStatic.GPU gpuCount = instanceTypeStatic.GPU
memoryMb = vmssTypeStatic.MemoryMb memoryMb = instanceTypeStatic.MemoryMb
} else { } else {
// return error if neither of the workflows results with vmss data. // return error if neither of the workflows results with vmss data.
klog.V(1).Infof("Instance type %q not supported, err: %v", *template.Sku.Name, staticErr) klog.V(1).Infof("Instance type %q not supported, err: %v", template.SkuName, staticErr)
return nil, staticErr return nil, staticErr
} }
} }
@ -134,7 +260,7 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
node.Status.Capacity[apiv1.ResourceCPU] = *resource.NewQuantity(vcpu, resource.DecimalSI) node.Status.Capacity[apiv1.ResourceCPU] = *resource.NewQuantity(vcpu, resource.DecimalSI)
// isNPSeries returns if a SKU is an NP-series SKU // isNPSeries returns if a SKU is an NP-series SKU
// SKU API reports GPUs for NP-series but it's actually FPGAs // SKU API reports GPUs for NP-series but it's actually FPGAs
if isNPSeries(*template.Sku.Name) { if isNPSeries(template.SkuName) {
node.Status.Capacity[xilinxFpgaResourceName] = *resource.NewQuantity(gpuCount, resource.DecimalSI) node.Status.Capacity[xilinxFpgaResourceName] = *resource.NewQuantity(gpuCount, resource.DecimalSI)
} else { } else {
node.Status.Capacity[gpu.ResourceNvidiaGPU] = *resource.NewQuantity(gpuCount, resource.DecimalSI) node.Status.Capacity[gpu.ResourceNvidiaGPU] = *resource.NewQuantity(gpuCount, resource.DecimalSI)
@ -145,9 +271,37 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
// TODO: set real allocatable. // TODO: set real allocatable.
node.Status.Allocatable = node.Status.Capacity node.Status.Allocatable = node.Status.Capacity
if template.VMSSNodeTemplate != nil {
node = processVMSSTemplate(template, nodeName, node)
} else if template.VMPoolNodeTemplate != nil {
node = processVMPoolTemplate(template, nodeName, node)
} else {
return nil, fmt.Errorf("invalid node template: missing both VMSS and VMPool templates")
}
klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels)
klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints)
node.Status.Conditions = cloudprovider.BuildReadyConditions()
return &node, nil
}
func processVMPoolTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node {
labels := buildGenericLabels(template, nodeName)
labels[agentPoolNodeLabelKey] = template.VMPoolNodeTemplate.AgentPoolName
if template.VMPoolNodeTemplate.Labels != nil {
for k, v := range template.VMPoolNodeTemplate.Labels {
labels[k] = to.String(v)
}
}
node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels)
node.Spec.Taints = template.VMPoolNodeTemplate.Taints
return node
}
func processVMSSTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node {
// NodeLabels // NodeLabels
if template.Tags != nil { if template.VMSSNodeTemplate.Tags != nil {
for k, v := range template.Tags { for k, v := range template.VMSSNodeTemplate.Tags {
if v != nil { if v != nil {
node.Labels[k] = *v node.Labels[k] = *v
} else { } else {
@ -164,10 +318,10 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
labels := make(map[string]string) labels := make(map[string]string)
// Prefer the explicit labels in spec coming from RP over the VMSS template // Prefer the explicit labels in spec coming from RP over the VMSS template
if len(inputLabels) > 0 { if len(template.VMSSNodeTemplate.InputLabels) > 0 {
labels = inputLabels labels = template.VMSSNodeTemplate.InputLabels
} else { } else {
labels = extractLabelsFromScaleSet(template.Tags) labels = extractLabelsFromTags(template.VMSSNodeTemplate.Tags)
} }
// Add the agentpool label, its value should come from the VMSS poolName tag // Add the agentpool label, its value should come from the VMSS poolName tag
@ -182,87 +336,74 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
labels[agentPoolNodeLabelKey] = node.Labels[poolNameTag] labels[agentPoolNodeLabelKey] = node.Labels[poolNameTag]
} }
// Add the storage profile and storage tier labels // Add the storage profile and storage tier labels for vmss node
if template.VirtualMachineProfile != nil && template.VirtualMachineProfile.StorageProfile != nil && template.VirtualMachineProfile.StorageProfile.OsDisk != nil { if template.VMSSNodeTemplate.OSDisk != nil {
// ephemeral // ephemeral
if template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings != nil && template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings.Option == compute.Local { if template.VMSSNodeTemplate.OSDisk.DiffDiskSettings != nil && template.VMSSNodeTemplate.OSDisk.DiffDiskSettings.Option == compute.Local {
labels[legacyStorageProfileNodeLabelKey] = "ephemeral" labels[legacyStorageProfileNodeLabelKey] = "ephemeral"
labels[storageProfileNodeLabelKey] = "ephemeral" labels[storageProfileNodeLabelKey] = "ephemeral"
} else { } else {
labels[legacyStorageProfileNodeLabelKey] = "managed" labels[legacyStorageProfileNodeLabelKey] = "managed"
labels[storageProfileNodeLabelKey] = "managed" labels[storageProfileNodeLabelKey] = "managed"
} }
if template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk != nil { if template.VMSSNodeTemplate.OSDisk.ManagedDisk != nil {
labels[legacyStorageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType) labels[legacyStorageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType)
labels[storageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType) labels[storageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType)
} }
// Add ephemeral-storage value // Add ephemeral-storage value
if template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB != nil { if template.VMSSNodeTemplate.OSDisk.DiskSizeGB != nil {
node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI) node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VMSSNodeTemplate.OSDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI)
klog.V(4).Infof("OS Disk Size from template is: %d", *template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB) klog.V(4).Infof("OS Disk Size from template is: %d", *template.VMSSNodeTemplate.OSDisk.DiskSizeGB)
klog.V(4).Infof("Setting ephemeral storage to: %v", node.Status.Capacity[apiv1.ResourceEphemeralStorage]) klog.V(4).Infof("Setting ephemeral storage to: %v", node.Status.Capacity[apiv1.ResourceEphemeralStorage])
} }
} }
// If we are on GPU-enabled SKUs, append the accelerator // If we are on GPU-enabled SKUs, append the accelerator
// label so that CA makes better decision when scaling from zero for GPU pools // label so that CA makes better decision when scaling from zero for GPU pools
if isNvidiaEnabledSKU(*template.Sku.Name) { if isNvidiaEnabledSKU(template.SkuName) {
labels[GPULabel] = "nvidia" labels[GPULabel] = "nvidia"
labels[legacyGPULabel] = "nvidia" labels[legacyGPULabel] = "nvidia"
} }
// Extract allocatables from tags // Extract allocatables from tags
resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.Tags) resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.VMSSNodeTemplate.Tags)
for resourceName, val := range resourcesFromTags { for resourceName, val := range resourcesFromTags {
node.Status.Capacity[apiv1.ResourceName(resourceName)] = *val node.Status.Capacity[apiv1.ResourceName(resourceName)] = *val
} }
node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels) node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels)
klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels)
var taints []apiv1.Taint var taints []apiv1.Taint
// Prefer the explicit taints in spec over the VMSS template // Prefer the explicit taints in spec over the tags from vmss or vm
if inputTaints != "" { if template.VMSSNodeTemplate.InputTaints != "" {
taints = extractTaintsFromSpecString(inputTaints) taints = extractTaintsFromSpecString(template.VMSSNodeTemplate.InputTaints)
} else { } else {
taints = extractTaintsFromScaleSet(template.Tags) taints = extractTaintsFromTags(template.VMSSNodeTemplate.Tags)
} }
// Taints from the Scale Set's Tags // Taints from the Scale Set's Tags
node.Spec.Taints = taints node.Spec.Taints = taints
klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints) return node
node.Status.Conditions = cloudprovider.BuildReadyConditions()
return &node, nil
} }
func buildInstanceOS(template compute.VirtualMachineScaleSet) string { func buildGenericLabels(template NodeTemplate, nodeName string) map[string]string {
instanceOS := cloudprovider.DefaultOS
if template.VirtualMachineProfile != nil && template.VirtualMachineProfile.OsProfile != nil && template.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil {
instanceOS = "windows"
}
return instanceOS
}
func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string) map[string]string {
result := make(map[string]string) result := make(map[string]string)
result[kubeletapis.LabelArch] = cloudprovider.DefaultArch result[kubeletapis.LabelArch] = cloudprovider.DefaultArch
result[apiv1.LabelArchStable] = cloudprovider.DefaultArch result[apiv1.LabelArchStable] = cloudprovider.DefaultArch
result[kubeletapis.LabelOS] = buildInstanceOS(template) result[kubeletapis.LabelOS] = template.InstanceOS
result[apiv1.LabelOSStable] = buildInstanceOS(template) result[apiv1.LabelOSStable] = template.InstanceOS
result[apiv1.LabelInstanceType] = *template.Sku.Name result[apiv1.LabelInstanceType] = template.SkuName
result[apiv1.LabelInstanceTypeStable] = *template.Sku.Name result[apiv1.LabelInstanceTypeStable] = template.SkuName
result[apiv1.LabelZoneRegion] = strings.ToLower(*template.Location) result[apiv1.LabelZoneRegion] = strings.ToLower(template.Location)
result[apiv1.LabelTopologyRegion] = strings.ToLower(*template.Location) result[apiv1.LabelTopologyRegion] = strings.ToLower(template.Location)
if template.Zones != nil && len(*template.Zones) > 0 { if len(template.Zones) > 0 {
failureDomains := make([]string, len(*template.Zones)) failureDomains := make([]string, len(template.Zones))
for k, v := range *template.Zones { for k, v := range template.Zones {
failureDomains[k] = strings.ToLower(*template.Location) + "-" + v failureDomains[k] = strings.ToLower(template.Location) + "-" + v
} }
//Picks random zones for Multi-zone nodepool when scaling from zero. //Picks random zones for Multi-zone nodepool when scaling from zero.
//This random zone will not be the same as the zone of the VMSS that is being created, the purpose of creating //This random zone will not be the same as the zone of the VMSS that is being created, the purpose of creating
@ -283,7 +424,7 @@ func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string
return result return result
} }
func extractLabelsFromScaleSet(tags map[string]*string) map[string]string { func extractLabelsFromTags(tags map[string]*string) map[string]string {
result := make(map[string]string) result := make(map[string]string)
for tagName, tagValue := range tags { for tagName, tagValue := range tags {
@ -300,7 +441,7 @@ func extractLabelsFromScaleSet(tags map[string]*string) map[string]string {
return result return result
} }
func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint { func extractTaintsFromTags(tags map[string]*string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0) taints := make([]apiv1.Taint, 0)
for tagName, tagValue := range tags { for tagName, tagValue := range tags {
@ -327,35 +468,61 @@ func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint {
return taints return taints
} }
// extractTaintsFromSpecString is for nodepool taints
// Example of a valid taints string, is the same argument to kubelet's `--register-with-taints` // Example of a valid taints string, is the same argument to kubelet's `--register-with-taints`
// "dedicated=foo:NoSchedule,group=bar:NoExecute,app=fizz:PreferNoSchedule" // "dedicated=foo:NoSchedule,group=bar:NoExecute,app=fizz:PreferNoSchedule"
func extractTaintsFromSpecString(taintsString string) []apiv1.Taint { func extractTaintsFromSpecString(taintsString string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0) taints := make([]apiv1.Taint, 0)
dedupMap := make(map[string]interface{})
// First split the taints at the separator // First split the taints at the separator
splits := strings.Split(taintsString, ",") splits := strings.Split(taintsString, ",")
for _, split := range splits { for _, split := range splits {
taintSplit := strings.Split(split, "=") if dedupMap[split] != nil {
if len(taintSplit) != 2 {
continue continue
} }
dedupMap[split] = struct{}{}
valid, taint := constructTaintFromString(split)
if valid {
taints = append(taints, taint)
}
}
return taints
}
// buildNodeTaintsForVMPool is for VMPool taints, it looks for the taints in the format
// []string{zone=dmz:NoSchedule, usage=monitoring:NoSchedule}
func buildNodeTaintsForVMPool(taintStrs []string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0)
for _, taintStr := range taintStrs {
valid, taint := constructTaintFromString(taintStr)
if valid {
taints = append(taints, taint)
}
}
return taints
}
// constructTaintFromString constructs a taint from a string in the format <key>=<value>:<effect>
// if the input string is not in the correct format, it returns false and an empty taint
func constructTaintFromString(taintString string) (bool, apiv1.Taint) {
taintSplit := strings.Split(taintString, "=")
if len(taintSplit) != 2 {
return false, apiv1.Taint{}
}
taintKey := taintSplit[0] taintKey := taintSplit[0]
taintValue := taintSplit[1] taintValue := taintSplit[1]
r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)") r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)")
if !r.MatchString(taintValue) { if !r.MatchString(taintValue) {
continue return false, apiv1.Taint{}
} }
values := strings.SplitN(taintValue, ":", 2) values := strings.SplitN(taintValue, ":", 2)
taints = append(taints, apiv1.Taint{ return true, apiv1.Taint{
Key: taintKey, Key: taintKey,
Value: values[0], Value: values[0],
Effect: apiv1.TaintEffect(values[1]), Effect: apiv1.TaintEffect(values[1]),
})
} }
return taints
} }
func extractAutoscalingOptionsFromScaleSetTags(tags map[string]*string) map[string]string { func extractAutoscalingOptionsFromScaleSetTags(tags map[string]*string) map[string]string {

View File

@ -21,6 +21,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/to" "github.com/Azure/go-autorest/autorest/to"
@ -30,7 +31,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
) )
func TestExtractLabelsFromScaleSet(t *testing.T) { func TestExtractLabelsFromTags(t *testing.T) {
expectedNodeLabelKey := "zip" expectedNodeLabelKey := "zip"
expectedNodeLabelValue := "zap" expectedNodeLabelValue := "zap"
extraNodeLabelValue := "buzz" extraNodeLabelValue := "buzz"
@ -52,14 +53,14 @@ func TestExtractLabelsFromScaleSet(t *testing.T) {
fmt.Sprintf("%s%s", nodeLabelTagName, escapedUnderscoreNodeLabelKey): &escapedUnderscoreNodeLabelValue, fmt.Sprintf("%s%s", nodeLabelTagName, escapedUnderscoreNodeLabelKey): &escapedUnderscoreNodeLabelValue,
} }
labels := extractLabelsFromScaleSet(tags) labels := extractLabelsFromTags(tags)
assert.Len(t, labels, 3) assert.Len(t, labels, 3)
assert.Equal(t, expectedNodeLabelValue, labels[expectedNodeLabelKey]) assert.Equal(t, expectedNodeLabelValue, labels[expectedNodeLabelKey])
assert.Equal(t, escapedSlashNodeLabelValue, labels[expectedSlashEscapedNodeLabelKey]) assert.Equal(t, escapedSlashNodeLabelValue, labels[expectedSlashEscapedNodeLabelKey])
assert.Equal(t, escapedUnderscoreNodeLabelValue, labels[expectedUnderscoreEscapedNodeLabelKey]) assert.Equal(t, escapedUnderscoreNodeLabelValue, labels[expectedUnderscoreEscapedNodeLabelKey])
} }
func TestExtractTaintsFromScaleSet(t *testing.T) { func TestExtractTaintsFromTags(t *testing.T) {
noScheduleTaintValue := "foo:NoSchedule" noScheduleTaintValue := "foo:NoSchedule"
noExecuteTaintValue := "bar:NoExecute" noExecuteTaintValue := "bar:NoExecute"
preferNoScheduleTaintValue := "fizz:PreferNoSchedule" preferNoScheduleTaintValue := "fizz:PreferNoSchedule"
@ -100,7 +101,7 @@ func TestExtractTaintsFromScaleSet(t *testing.T) {
}, },
} }
taints := extractTaintsFromScaleSet(tags) taints := extractTaintsFromTags(tags)
assert.Len(t, taints, 4) assert.Len(t, taints, 4)
assert.Equal(t, makeTaintSet(expectedTaints), makeTaintSet(taints)) assert.Equal(t, makeTaintSet(expectedTaints), makeTaintSet(taints))
} }
@ -137,6 +138,11 @@ func TestExtractTaintsFromSpecString(t *testing.T) {
Value: "fizz", Value: "fizz",
Effect: apiv1.TaintEffectPreferNoSchedule, Effect: apiv1.TaintEffectPreferNoSchedule,
}, },
{
Key: "dedicated", // duplicate key, should be ignored
Value: "foo",
Effect: apiv1.TaintEffectNoSchedule,
},
} }
taints := extractTaintsFromSpecString(strings.Join(taintsString, ",")) taints := extractTaintsFromSpecString(strings.Join(taintsString, ","))
@ -176,8 +182,9 @@ func TestTopologyFromScaleSet(t *testing.T) {
Location: to.StringPtr("westus"), Location: to.StringPtr("westus"),
} }
expectedZoneValues := []string{"westus-1", "westus-2", "westus-3"} expectedZoneValues := []string{"westus-1", "westus-2", "westus-3"}
template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "")
labels := buildGenericLabels(testVmss, testNodeName) assert.NoError(t, err)
labels := buildGenericLabels(template, testNodeName)
failureDomain, ok := labels[apiv1.LabelZoneFailureDomain] failureDomain, ok := labels[apiv1.LabelZoneFailureDomain]
assert.True(t, ok) assert.True(t, ok)
topologyZone, ok := labels[apiv1.LabelTopologyZone] topologyZone, ok := labels[apiv1.LabelTopologyZone]
@ -205,7 +212,9 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) {
expectedFailureDomain := "0" expectedFailureDomain := "0"
expectedTopologyZone := "0" expectedTopologyZone := "0"
expectedAzureDiskTopology := "" expectedAzureDiskTopology := ""
labels := buildGenericLabels(testVmss, testNodeName) template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "")
assert.NoError(t, err)
labels := buildGenericLabels(template, testNodeName)
failureDomain, ok := labels[apiv1.LabelZoneFailureDomain] failureDomain, ok := labels[apiv1.LabelZoneFailureDomain]
assert.True(t, ok) assert.True(t, ok)
@ -219,6 +228,61 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, expectedAzureDiskTopology, azureDiskTopology) assert.Equal(t, expectedAzureDiskTopology, azureDiskTopology)
} }
func TestBuildNodeTemplateFromVMPool(t *testing.T) {
agentPoolName := "testpool"
location := "eastus"
skuName := "Standard_DS2_v2"
labelKey := "foo"
labelVal := "bar"
taintStr := "dedicated=foo:NoSchedule,boo=fizz:PreferNoSchedule,group=bar:NoExecute"
osType := armcontainerservice.OSTypeLinux
osDiskType := armcontainerservice.OSDiskTypeEphemeral
zone1 := "1"
zone2 := "2"
vmpool := armcontainerservice.AgentPool{
Name: to.StringPtr(agentPoolName),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
NodeLabels: map[string]*string{
"existing": to.StringPtr("label"),
"department": to.StringPtr("engineering"),
},
NodeTaints: []*string{to.StringPtr("group=bar:NoExecute")},
OSType: &osType,
OSDiskType: &osDiskType,
AvailabilityZones: []*string{&zone1, &zone2},
},
}
labelsFromSpec := map[string]string{labelKey: labelVal}
taintsFromSpec := taintStr
template, err := buildNodeTemplateFromVMPool(vmpool, location, skuName, labelsFromSpec, taintsFromSpec)
assert.NoError(t, err)
assert.Equal(t, skuName, template.SkuName)
assert.Equal(t, location, template.Location)
assert.ElementsMatch(t, []string{zone1, zone2}, template.Zones)
assert.Equal(t, "linux", template.InstanceOS)
assert.NotNil(t, template.VMPoolNodeTemplate)
assert.Equal(t, agentPoolName, template.VMPoolNodeTemplate.AgentPoolName)
assert.Equal(t, &osDiskType, template.VMPoolNodeTemplate.OSDiskType)
// Labels: should include both from NodeLabels and labelsFromSpec
assert.Contains(t, template.VMPoolNodeTemplate.Labels, "existing")
assert.Equal(t, "label", *template.VMPoolNodeTemplate.Labels["existing"])
assert.Contains(t, template.VMPoolNodeTemplate.Labels, "department")
assert.Equal(t, "engineering", *template.VMPoolNodeTemplate.Labels["department"])
assert.Contains(t, template.VMPoolNodeTemplate.Labels, labelKey)
assert.Equal(t, labelVal, *template.VMPoolNodeTemplate.Labels[labelKey])
// Taints: should include both from NodeTaints and taintsFromSpec
taintSet := makeTaintSet(template.VMPoolNodeTemplate.Taints)
expectedTaints := []apiv1.Taint{
{Key: "group", Value: "bar", Effect: apiv1.TaintEffectNoExecute},
{Key: "dedicated", Value: "foo", Effect: apiv1.TaintEffectNoSchedule},
{Key: "boo", Value: "fizz", Effect: apiv1.TaintEffectPreferNoSchedule},
}
assert.Equal(t, makeTaintSet(expectedTaints), taintSet)
}
func makeTaintSet(taints []apiv1.Taint) map[apiv1.Taint]bool { func makeTaintSet(taints []apiv1.Taint) map[apiv1.Taint]bool {
set := make(map[apiv1.Taint]bool) set := make(map[apiv1.Taint]bool)

View File

@ -18,142 +18,426 @@ package azure
import ( import (
"fmt" "fmt"
"net/http"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
apiv1 "k8s.io/api/core/v1" apiv1 "k8s.io/api/core/v1"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"k8s.io/autoscaler/cluster-autoscaler/config" "k8s.io/autoscaler/cluster-autoscaler/config"
"k8s.io/autoscaler/cluster-autoscaler/config/dynamic" "k8s.io/autoscaler/cluster-autoscaler/config/dynamic"
"k8s.io/autoscaler/cluster-autoscaler/simulator/framework" "k8s.io/autoscaler/cluster-autoscaler/simulator/framework"
klog "k8s.io/klog/v2"
) )
// VMsPool is single instance VM pool // VMPool represents a group of standalone virtual machines (VMs) with a single SKU.
// this is a placeholder for now, no real implementation // It is part of a mixed-SKU agent pool (an agent pool with type `VirtualMachines`).
type VMsPool struct { // Terminology:
// - Agent pool: A node pool in an AKS cluster.
// - VMs pool: An agent pool of type `VirtualMachines`, which can contain mixed SKUs.
// - VMPool: A subset of VMs within a VMs pool that share the same SKU.
type VMPool struct {
azureRef azureRef
manager *AzureManager manager *AzureManager
resourceGroup string agentPoolName string // the virtual machines agentpool that this VMPool belongs to
sku string // sku of the VM in the pool
minSize int minSize int
maxSize int maxSize int
curSize int64
// sizeMutex sync.Mutex
// lastSizeRefresh time.Time
} }
// NewVMsPool creates a new VMsPool // NewVMPool creates a new VMPool - a pool of standalone VMs of a single size.
func NewVMsPool(spec *dynamic.NodeGroupSpec, am *AzureManager) *VMsPool { func NewVMPool(spec *dynamic.NodeGroupSpec, am *AzureManager, agentPoolName string, sku string) (*VMPool, error) {
nodepool := &VMsPool{ if am.azClient.agentPoolClient == nil {
return nil, fmt.Errorf("agentPoolClient is nil")
}
nodepool := &VMPool{
azureRef: azureRef{ azureRef: azureRef{
Name: spec.Name, Name: spec.Name, // in format "<agentPoolName>/<sku>"
}, },
manager: am, manager: am,
resourceGroup: am.config.ResourceGroup, sku: sku,
agentPoolName: agentPoolName,
curSize: -1,
minSize: spec.MinSize, minSize: spec.MinSize,
maxSize: spec.MaxSize, maxSize: spec.MaxSize,
} }
return nodepool, nil
return nodepool
} }
// MinSize returns the minimum size the cluster is allowed to scaled down // MinSize returns the minimum size the vmPool is allowed to scaled down
// to as provided by the node spec in --node parameter. // to as provided by the node spec in --node parameter.
func (agentPool *VMsPool) MinSize() int { func (vmPool *VMPool) MinSize() int {
return agentPool.minSize return vmPool.minSize
} }
// Exist is always true since we are initialized with an existing agentpool // Exist is always true since we are initialized with an existing vmPool
func (agentPool *VMsPool) Exist() bool { func (vmPool *VMPool) Exist() bool {
return true return true
} }
// Create creates the node group on the cloud provider side. // Create creates the node group on the cloud provider side.
func (agentPool *VMsPool) Create() (cloudprovider.NodeGroup, error) { func (vmPool *VMPool) Create() (cloudprovider.NodeGroup, error) {
return nil, cloudprovider.ErrAlreadyExist return nil, cloudprovider.ErrAlreadyExist
} }
// Delete deletes the node group on the cloud provider side. // Delete deletes the node group on the cloud provider side.
func (agentPool *VMsPool) Delete() error { func (vmPool *VMPool) Delete() error {
return cloudprovider.ErrNotImplemented
}
// ForceDeleteNodes deletes nodes from the group regardless of constraints.
func (vmPool *VMPool) ForceDeleteNodes(nodes []*apiv1.Node) error {
return cloudprovider.ErrNotImplemented return cloudprovider.ErrNotImplemented
} }
// Autoprovisioned is always false since we are initialized with an existing agentpool // Autoprovisioned is always false since we are initialized with an existing agentpool
func (agentPool *VMsPool) Autoprovisioned() bool { func (vmPool *VMPool) Autoprovisioned() bool {
return false return false
} }
// GetOptions returns NodeGroupAutoscalingOptions that should be used for this particular // GetOptions returns NodeGroupAutoscalingOptions that should be used for this particular
// NodeGroup. Returning a nil will result in using default options. // NodeGroup. Returning a nil will result in using default options.
func (agentPool *VMsPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) { func (vmPool *VMPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) {
// TODO(wenxuan): Implement this method // TODO(wenxuan): implement this method when vmPool can fully support GPU nodepool
return nil, cloudprovider.ErrNotImplemented return nil, nil
} }
// MaxSize returns the maximum size scale limit provided by --node // MaxSize returns the maximum size scale limit provided by --node
// parameter to the autoscaler main // parameter to the autoscaler main
func (agentPool *VMsPool) MaxSize() int { func (vmPool *VMPool) MaxSize() int {
return agentPool.maxSize return vmPool.maxSize
} }
// TargetSize returns the current TARGET size of the node group. It is possible that the // TargetSize returns the current target size of the node group. This value represents
// number is different from the number of nodes registered in Kubernetes. // the desired number of nodes in the VMPool, which may differ from the actual number
func (agentPool *VMsPool) TargetSize() (int, error) { // of nodes currently present.
// TODO(wenxuan): Implement this method func (vmPool *VMPool) TargetSize() (int, error) {
return -1, cloudprovider.ErrNotImplemented // VMs in the "Deleting" state are not counted towards the target size.
size, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false})
return int(size), err
} }
// IncreaseSize increase the size through a PUT AP call. It calculates the expected size // IncreaseSize increases the size of the VMPool by sending a PUT request to update the agent pool.
// based on a delta provided as parameter // This method waits until the asynchronous PUT operation completes or the client-side timeout is reached.
func (agentPool *VMsPool) IncreaseSize(delta int) error { func (vmPool *VMPool) IncreaseSize(delta int) error {
// TODO(wenxuan): Implement this method if delta <= 0 {
return cloudprovider.ErrNotImplemented return fmt.Errorf("size increase must be positive, current delta: %d", delta)
} }
// DeleteNodes extracts the providerIDs from the node spec and // Skip VMs in the failed state so that a PUT AP will be triggered to fix the failed VMs.
// delete or deallocate the nodes from the agent pool based on the scale down policy. currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: true})
func (agentPool *VMsPool) DeleteNodes(nodes []*apiv1.Node) error { if err != nil {
// TODO(wenxuan): Implement this method return err
return cloudprovider.ErrNotImplemented
} }
// ForceDeleteNodes deletes nodes from the group regardless of constraints. if int(currentSize)+delta > vmPool.MaxSize() {
func (agentPool *VMsPool) ForceDeleteNodes(nodes []*apiv1.Node) error { return fmt.Errorf("size-increasing request of %d is bigger than max size %d", int(currentSize)+delta, vmPool.MaxSize())
return cloudprovider.ErrNotImplemented }
updateCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout)
defer cancel()
versionedAP, err := vmPool.getAgentpoolFromCache()
if err != nil {
klog.Errorf("Failed to get vmPool %s, error: %s", vmPool.agentPoolName, err)
return err
}
count := currentSize + int32(delta)
requestBody := armcontainerservice.AgentPool{}
// self-hosted CAS will be using Manual scale profile
if len(versionedAP.Properties.VirtualMachinesProfile.Scale.Manual) > 0 {
requestBody = buildRequestBodyForScaleUp(versionedAP, count, vmPool.sku)
} else { // AKS-managed CAS will use custom header for setting the target count
header := make(http.Header)
header.Set("Target-Count", fmt.Sprintf("%d", count))
updateCtx = policy.WithHTTPHeader(updateCtx, header)
}
defer vmPool.manager.invalidateCache()
poller, err := vmPool.manager.azClient.agentPoolClient.BeginCreateOrUpdate(
updateCtx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName,
requestBody, nil)
if err != nil {
klog.Errorf("Failed to scale up agentpool %s in cluster %s for vmPool %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, err)
return err
}
if _, err := poller.PollUntilDone(updateCtx, nil /*default polling interval is 30s*/); err != nil {
klog.Errorf("agentPoolClient.BeginCreateOrUpdate for aks cluster %s agentpool %s for scaling up vmPool %s failed with error %s",
vmPool.manager.config.ClusterName, vmPool.agentPoolName, vmPool.Name, err)
return err
}
klog.Infof("Successfully scaled up agentpool %s in cluster %s for vmPool %s to size %d",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, count)
return nil
}
// buildRequestBodyForScaleUp builds the request body for scale up for self-hosted CAS
func buildRequestBodyForScaleUp(agentpool armcontainerservice.AgentPool, count int32, vmSku string) armcontainerservice.AgentPool {
requestBody := armcontainerservice.AgentPool{
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: agentpool.Properties.Type,
},
}
// the request body must have the same mode as the original agentpool
// otherwise the PUT request will fail
if agentpool.Properties.Mode != nil &&
*agentpool.Properties.Mode == armcontainerservice.AgentPoolModeSystem {
systemMode := armcontainerservice.AgentPoolModeSystem
requestBody.Properties.Mode = &systemMode
}
// set the count of the matching manual scale profile to the new target value
for _, manualProfile := range agentpool.Properties.VirtualMachinesProfile.Scale.Manual {
if manualProfile != nil && len(manualProfile.Sizes) == 1 &&
strings.EqualFold(to.String(manualProfile.Sizes[0]), vmSku) {
klog.V(5).Infof("Found matching manual profile for VM SKU: %s, updating count to: %d", vmSku, count)
manualProfile.Count = to.Int32Ptr(count)
requestBody.Properties.VirtualMachinesProfile = agentpool.Properties.VirtualMachinesProfile
break
}
}
return requestBody
}
// DeleteNodes removes the specified nodes from the VMPool by extracting their providerIDs
// and performing the appropriate delete or deallocate operation based on the agent pool's
// scale-down policy. This method waits for the asynchronous delete operation to complete,
// with a client-side timeout.
func (vmPool *VMPool) DeleteNodes(nodes []*apiv1.Node) error {
// Ensure we don't scale below the minimum size by excluding VMs in the "Deleting" state.
currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false})
if err != nil {
return fmt.Errorf("unable to retrieve current size: %w", err)
}
if int(currentSize) <= vmPool.MinSize() {
return fmt.Errorf("cannot delete nodes as minimum size of %d has been reached", vmPool.MinSize())
}
providerIDs, err := vmPool.getProviderIDsForNodes(nodes)
if err != nil {
return fmt.Errorf("failed to retrieve provider IDs for nodes: %w", err)
}
if len(providerIDs) == 0 {
return nil
}
klog.V(3).Infof("Deleting nodes from vmPool %s: %v", vmPool.Name, providerIDs)
machineNames := make([]*string, len(providerIDs))
for i, providerID := range providerIDs {
// extract the machine name from the providerID by splitting the providerID by '/' and get the last element
// The providerID look like this:
// "azure:///subscriptions/0000000-0000-0000-0000-00000000000/resourceGroups/mc_myrg_mycluster_eastus/providers/Microsoft.Compute/virtualMachines/aks-mypool-12345678-vms0"
machineName, err := resourceName(providerID)
if err != nil {
return err
}
machineNames[i] = &machineName
}
requestBody := armcontainerservice.AgentPoolDeleteMachinesParameter{
MachineNames: machineNames,
}
deleteCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout)
defer cancel()
defer vmPool.manager.invalidateCache()
poller, err := vmPool.manager.azClient.agentPoolClient.BeginDeleteMachines(
deleteCtx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName,
requestBody, nil)
if err != nil {
klog.Errorf("Failed to delete nodes from agentpool %s in cluster %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, err)
return err
}
if _, err := poller.PollUntilDone(deleteCtx, nil); err != nil {
klog.Errorf("agentPoolClient.BeginDeleteMachines for aks cluster %s for scaling down vmPool %s failed with error %s",
vmPool.manager.config.ClusterName, vmPool.agentPoolName, err)
return err
}
klog.Infof("Successfully deleted %d nodes from vmPool %s", len(providerIDs), vmPool.Name)
return nil
}
func (vmPool *VMPool) getProviderIDsForNodes(nodes []*apiv1.Node) ([]string, error) {
var providerIDs []string
for _, node := range nodes {
belongs, err := vmPool.Belongs(node)
if err != nil {
return nil, fmt.Errorf("failed to check if node %s belongs to vmPool %s: %w", node.Name, vmPool.Name, err)
}
if !belongs {
return nil, fmt.Errorf("node %s does not belong to vmPool %s", node.Name, vmPool.Name)
}
providerIDs = append(providerIDs, node.Spec.ProviderID)
}
return providerIDs, nil
}
// Belongs returns true if the given k8s node belongs to this vms nodepool.
func (vmPool *VMPool) Belongs(node *apiv1.Node) (bool, error) {
klog.V(6).Infof("Check if node belongs to this vmPool:%s, node:%v\n", vmPool, node)
ref := &azureRef{
Name: node.Spec.ProviderID,
}
nodeGroup, err := vmPool.manager.GetNodeGroupForInstance(ref)
if err != nil {
return false, err
}
if nodeGroup == nil {
return false, fmt.Errorf("%s doesn't belong to a known node group", node.Name)
}
if !strings.EqualFold(nodeGroup.Id(), vmPool.Id()) {
return false, nil
}
return true, nil
} }
// DecreaseTargetSize decreases the target size of the node group. // DecreaseTargetSize decreases the target size of the node group.
func (agentPool *VMsPool) DecreaseTargetSize(delta int) error { func (vmPool *VMPool) DecreaseTargetSize(delta int) error {
// TODO(wenxuan): Implement this method // The TargetSize of a VMPool is automatically adjusted after node deletions.
return cloudprovider.ErrNotImplemented // This method is invoked in scenarios such as (see details in clusterstate.go):
// - len(readiness.Registered) > acceptableRange.CurrentTarget
// - len(readiness.Registered) < acceptableRange.CurrentTarget - unregisteredNodes
// For VMPool, this method should not be called because:
// CurrentTarget = len(readiness.Registered) + unregisteredNodes - len(nodesInDeletingState)
// Here, nodesInDeletingState is a subset of unregisteredNodes,
// ensuring len(readiness.Registered) is always within the acceptable range.
// here we just invalidate the cache to avoid any potential bugs
vmPool.manager.invalidateCache()
klog.Warningf("DecreaseTargetSize called for VMPool %s, but it should not be used, invalidating cache", vmPool.Name)
return nil
} }
// Id returns the name of the agentPool // Id returns the name of the agentPool, it is in the format of <agentpoolname>/<sku>
func (agentPool *VMsPool) Id() string { // e.g. mypool1/Standard_D2s_v3
return agentPool.azureRef.Name func (vmPool *VMPool) Id() string {
return vmPool.azureRef.Name
} }
// Debug returns a string with basic details of the agentPool // Debug returns a string with basic details of the agentPool
func (agentPool *VMsPool) Debug() string { func (vmPool *VMPool) Debug() string {
return fmt.Sprintf("%s (%d:%d)", agentPool.Id(), agentPool.MinSize(), agentPool.MaxSize()) return fmt.Sprintf("%s (%d:%d)", vmPool.Id(), vmPool.MinSize(), vmPool.MaxSize())
} }
func (agentPool *VMsPool) getVMsFromCache() ([]compute.VirtualMachine, error) { func isSpotAgentPool(ap armcontainerservice.AgentPool) bool {
// vmsPoolMap is a map of agent pool name to the list of virtual machines if ap.Properties != nil && ap.Properties.ScaleSetPriority != nil {
vmsPoolMap := agentPool.manager.azureCache.getVirtualMachines() return strings.EqualFold(string(*ap.Properties.ScaleSetPriority), "Spot")
if _, ok := vmsPoolMap[agentPool.Name]; !ok { }
return []compute.VirtualMachine{}, fmt.Errorf("vms pool %s not found in the cache", agentPool.Name) return false
} }
return vmsPoolMap[agentPool.Name], nil // skipOption is used to determine whether to skip VMs in certain states when calculating the current size of the vmPool.
type skipOption struct {
// skipDeleting indicates whether to skip VMs in the "Deleting" state.
skipDeleting bool
// skipFailed indicates whether to skip VMs in the "Failed" state.
skipFailed bool
}
// getCurSize determines the current count of VMs in the vmPool, including unregistered ones.
// The source of truth depends on the pool type (spot or non-spot).
func (vmPool *VMPool) getCurSize(op skipOption) (int32, error) {
agentPool, err := vmPool.getAgentpoolFromCache()
if err != nil {
klog.Errorf("Failed to retrieve agent pool %s from cache: %v", vmPool.agentPoolName, err)
return -1, err
}
// spot pool size is retrieved directly from Azure instead of the cache
if isSpotAgentPool(agentPool) {
return vmPool.getSpotPoolSize()
}
// non-spot pool size is retrieved from the cache
vms, err := vmPool.getVMsFromCache(op)
if err != nil {
klog.Errorf("Failed to get VMs from cache for agentpool %s with error: %v", vmPool.agentPoolName, err)
return -1, err
}
return int32(len(vms)), nil
}
// getSpotPoolSize retrieves the current size of a spot agent pool directly from Azure.
func (vmPool *VMPool) getSpotPoolSize() (int32, error) {
ap, err := vmPool.getAgentpoolFromAzure()
if err != nil {
klog.Errorf("Failed to get agentpool %s from Azure with error: %v", vmPool.agentPoolName, err)
return -1, err
}
if ap.Properties != nil {
// the VirtualMachineNodesStatus returned by AKS-RP is constructed from the vm list returned from CRP.
// it only contains VMs in the running state.
for _, status := range ap.Properties.VirtualMachineNodesStatus {
if status != nil {
if strings.EqualFold(to.String(status.Size), vmPool.sku) {
return to.Int32(status.Count), nil
}
}
}
}
return -1, fmt.Errorf("failed to get the size of spot agentpool %s", vmPool.agentPoolName)
}
// getVMsFromCache retrieves the list of virtual machines in this VMPool.
// If excludeDeleting is true, it skips VMs in the "Deleting" state.
// https://learn.microsoft.com/en-us/azure/virtual-machines/states-billing#provisioning-states
func (vmPool *VMPool) getVMsFromCache(op skipOption) ([]compute.VirtualMachine, error) {
vmsMap := vmPool.manager.azureCache.getVirtualMachines()
var filteredVMs []compute.VirtualMachine
for _, vm := range vmsMap[vmPool.agentPoolName] {
if vm.VirtualMachineProperties == nil ||
vm.VirtualMachineProperties.HardwareProfile == nil ||
!strings.EqualFold(string(vm.HardwareProfile.VMSize), vmPool.sku) {
continue
}
if op.skipDeleting && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Deleting") {
klog.V(4).Infof("Skipping VM %s in deleting state", to.String(vm.ID))
continue
}
if op.skipFailed && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Failed") {
klog.V(4).Infof("Skipping VM %s in failed state", to.String(vm.ID))
continue
}
filteredVMs = append(filteredVMs, vm)
}
return filteredVMs, nil
} }
// Nodes returns the list of nodes in the vms agentPool. // Nodes returns the list of nodes in the vms agentPool.
func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) { func (vmPool *VMPool) Nodes() ([]cloudprovider.Instance, error) {
vms, err := agentPool.getVMsFromCache() vms, err := vmPool.getVMsFromCache(skipOption{}) // no skip option, get all VMs
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -163,7 +447,7 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) {
if vm.ID == nil || len(*vm.ID) == 0 { if vm.ID == nil || len(*vm.ID) == 0 {
continue continue
} }
resourceID, err := convertResourceGroupNameToLower("azure://" + *vm.ID) resourceID, err := convertResourceGroupNameToLower("azure://" + to.String(vm.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -173,12 +457,53 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) {
return nodes, nil return nodes, nil
} }
// TemplateNodeInfo is not implemented. // TemplateNodeInfo returns a NodeInfo object that can be used to create a new node in the vmPool.
func (agentPool *VMsPool) TemplateNodeInfo() (*framework.NodeInfo, error) { func (vmPool *VMPool) TemplateNodeInfo() (*framework.NodeInfo, error) {
return nil, cloudprovider.ErrNotImplemented ap, err := vmPool.getAgentpoolFromCache()
if err != nil {
return nil, err
}
inputLabels := map[string]string{}
inputTaints := ""
template, err := buildNodeTemplateFromVMPool(ap, vmPool.manager.config.Location, vmPool.sku, inputLabels, inputTaints)
if err != nil {
return nil, err
}
node, err := buildNodeFromTemplate(vmPool.agentPoolName, template, vmPool.manager, vmPool.manager.config.EnableDynamicInstanceList)
if err != nil {
return nil, err
}
nodeInfo := framework.NewNodeInfo(node, nil, &framework.PodInfo{Pod: cloudprovider.BuildKubeProxy(vmPool.agentPoolName)})
return nodeInfo, nil
}
func (vmPool *VMPool) getAgentpoolFromCache() (armcontainerservice.AgentPool, error) {
vmsPoolMap := vmPool.manager.azureCache.getVMsPoolMap()
if _, exists := vmsPoolMap[vmPool.agentPoolName]; !exists {
return armcontainerservice.AgentPool{}, fmt.Errorf("VMs agent pool %s not found in cache", vmPool.agentPoolName)
}
return vmsPoolMap[vmPool.agentPoolName], nil
}
// getAgentpoolFromAzure returns the AKS agentpool from Azure
func (vmPool *VMPool) getAgentpoolFromAzure() (armcontainerservice.AgentPool, error) {
ctx, cancel := getContextWithTimeout(vmsContextTimeout)
defer cancel()
resp, err := vmPool.manager.azClient.agentPoolClient.Get(
ctx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName, nil)
if err != nil {
return resp.AgentPool, fmt.Errorf("failed to get agentpool %s in cluster %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, err)
}
return resp.AgentPool, nil
} }
// AtomicIncreaseSize is not implemented. // AtomicIncreaseSize is not implemented.
func (agentPool *VMsPool) AtomicIncreaseSize(delta int) error { func (vmPool *VMPool) AtomicIncreaseSize(delta int) error {
return cloudprovider.ErrNotImplemented return cloudprovider.ErrNotImplemented
} }

View File

@ -17,45 +17,64 @@ limitations under the License.
package azure package azure
import ( import (
"context"
"fmt" "fmt"
"net/http"
"testing" "testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to" "github.com/Azure/go-autorest/autorest/to"
"go.uber.org/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
apiv1 "k8s.io/api/core/v1" apiv1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"k8s.io/autoscaler/cluster-autoscaler/config" "k8s.io/autoscaler/cluster-autoscaler/config"
"k8s.io/autoscaler/cluster-autoscaler/config/dynamic" "k8s.io/autoscaler/cluster-autoscaler/config/dynamic"
providerazure "sigs.k8s.io/cloud-provider-azure/pkg/provider" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient"
) )
func newTestVMsPool(manager *AzureManager, name string) *VMsPool { const (
return &VMsPool{ vmSku = "Standard_D2_v2"
vmsAgentPoolName = "test-vms-pool"
vmsNodeGroupName = vmsAgentPoolName + "/" + vmSku
fakeVMsNodeName = "aks-" + vmsAgentPoolName + "-13222729-vms%d"
fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/" + fakeVMsNodeName
)
func newTestVMsPool(manager *AzureManager) *VMPool {
return &VMPool{
azureRef: azureRef{ azureRef: azureRef{
Name: name, Name: vmsNodeGroupName,
}, },
manager: manager, manager: manager,
minSize: 3, minSize: 3,
maxSize: 10, maxSize: 10,
agentPoolName: vmsAgentPoolName,
sku: vmSku,
} }
} }
const (
fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/%d"
)
func newTestVMsPoolVMList(count int) []compute.VirtualMachine { func newTestVMsPoolVMList(count int) []compute.VirtualMachine {
var vmList []compute.VirtualMachine var vmList []compute.VirtualMachine
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
vm := compute.VirtualMachine{ vm := compute.VirtualMachine{
ID: to.StringPtr(fmt.Sprintf(fakeVMsPoolVMID, i)), ID: to.StringPtr(fmt.Sprintf(fakeVMsPoolVMID, i)),
VirtualMachineProperties: &compute.VirtualMachineProperties{ VirtualMachineProperties: &compute.VirtualMachineProperties{
VMID: to.StringPtr(fmt.Sprintf("123E4567-E89B-12D3-A456-426655440000-%d", i)), VMID: to.StringPtr(fmt.Sprintf("123E4567-E89B-12D3-A456-426655440000-%d", i)),
HardwareProfile: &compute.HardwareProfile{
VMSize: compute.VirtualMachineSizeTypes(vmSku),
},
ProvisioningState: to.StringPtr("Succeeded"),
}, },
Tags: map[string]*string{ Tags: map[string]*string{
agentpoolTypeTag: to.StringPtr("VirtualMachines"), agentpoolTypeTag: to.StringPtr("VirtualMachines"),
agentpoolNameTag: to.StringPtr("test-vms-pool"), agentpoolNameTag: to.StringPtr(vmsAgentPoolName),
}, },
} }
vmList = append(vmList, vm) vmList = append(vmList, vm)
@ -63,41 +82,73 @@ func newTestVMsPoolVMList(count int) []compute.VirtualMachine {
return vmList return vmList
} }
func newVMsNode(vmID int64) *apiv1.Node { func newVMsNode(vmIdx int64) *apiv1.Node {
node := &apiv1.Node{ return &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf(fakeVMsNodeName, vmIdx),
},
Spec: apiv1.NodeSpec{ Spec: apiv1.NodeSpec{
ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmID), ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmIdx),
},
}
}
func getTestVMsAgentPool(isSystemPool bool) armcontainerservice.AgentPool {
mode := armcontainerservice.AgentPoolModeUser
if isSystemPool {
mode = armcontainerservice.AgentPoolModeSystem
}
vmsPoolType := armcontainerservice.AgentPoolTypeVirtualMachines
return armcontainerservice.AgentPool{
Name: to.StringPtr(vmsAgentPoolName),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmsPoolType,
Mode: &mode,
VirtualMachinesProfile: &armcontainerservice.VirtualMachinesProfile{
Scale: &armcontainerservice.ScaleProfile{
Manual: []*armcontainerservice.ManualScaleProfile{
{
Count: to.Int32Ptr(3),
Sizes: []*string{to.StringPtr(vmSku)},
},
},
},
},
VirtualMachineNodesStatus: []*armcontainerservice.VirtualMachineNodes{
{
Count: to.Int32Ptr(3),
Size: to.StringPtr(vmSku),
},
},
}, },
} }
return node
} }
func TestNewVMsPool(t *testing.T) { func TestNewVMsPool(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
manager := newTestAzureManager(t)
manager.azClient.agentPoolClient = mockAgentpoolclient
manager.config.ResourceGroup = "MC_rg"
manager.config.ClusterResourceGroup = "rg"
manager.config.ClusterName = "mycluster"
spec := &dynamic.NodeGroupSpec{ spec := &dynamic.NodeGroupSpec{
Name: "test-nodepool", Name: vmsAgentPoolName,
MinSize: 1, MinSize: 1,
MaxSize: 5, MaxSize: 10,
}
am := &AzureManager{
config: &Config{
Config: providerazure.Config{
ResourceGroup: "test-resource-group",
},
},
} }
nodepool := NewVMsPool(spec, am) ap, err := NewVMPool(spec, manager, vmsAgentPoolName, vmSku)
assert.NoError(t, err)
assert.Equal(t, "test-nodepool", nodepool.azureRef.Name) assert.Equal(t, vmsAgentPoolName, ap.azureRef.Name)
assert.Equal(t, "test-resource-group", nodepool.resourceGroup) assert.Equal(t, 1, ap.minSize)
assert.Equal(t, int64(-1), nodepool.curSize) assert.Equal(t, 10, ap.maxSize)
assert.Equal(t, 1, nodepool.minSize)
assert.Equal(t, 5, nodepool.maxSize)
assert.Equal(t, am, nodepool.manager)
} }
func TestMinSize(t *testing.T) { func TestMinSize(t *testing.T) {
agentPool := &VMsPool{ agentPool := &VMPool{
minSize: 1, minSize: 1,
} }
@ -105,12 +156,12 @@ func TestMinSize(t *testing.T) {
} }
func TestExist(t *testing.T) { func TestExist(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
assert.True(t, agentPool.Exist()) assert.True(t, agentPool.Exist())
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
nodeGroup, err := agentPool.Create() nodeGroup, err := agentPool.Create()
assert.Nil(t, nodeGroup) assert.Nil(t, nodeGroup)
@ -118,65 +169,43 @@ func TestCreate(t *testing.T) {
} }
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
err := agentPool.Delete() err := agentPool.Delete()
assert.Equal(t, cloudprovider.ErrNotImplemented, err) assert.Equal(t, cloudprovider.ErrNotImplemented, err)
} }
func TestAutoprovisioned(t *testing.T) { func TestAutoprovisioned(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
assert.False(t, agentPool.Autoprovisioned()) assert.False(t, agentPool.Autoprovisioned())
} }
func TestGetOptions(t *testing.T) { func TestGetOptions(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
defaults := config.NodeGroupAutoscalingOptions{} defaults := config.NodeGroupAutoscalingOptions{}
options, err := agentPool.GetOptions(defaults) options, err := agentPool.GetOptions(defaults)
assert.Nil(t, options) assert.Nil(t, options)
assert.Equal(t, cloudprovider.ErrNotImplemented, err) assert.Nil(t, err)
} }
func TestMaxSize(t *testing.T) { func TestMaxSize(t *testing.T) {
agentPool := &VMsPool{ agentPool := &VMPool{
maxSize: 10, maxSize: 10,
} }
assert.Equal(t, 10, agentPool.MaxSize()) assert.Equal(t, 10, agentPool.MaxSize())
} }
func TestTargetSize(t *testing.T) {
agentPool := &VMsPool{}
size, err := agentPool.TargetSize()
assert.Equal(t, -1, size)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestIncreaseSize(t *testing.T) {
agentPool := &VMsPool{}
err := agentPool.IncreaseSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestDeleteNodes(t *testing.T) {
agentPool := &VMsPool{}
err := agentPool.DeleteNodes(nil)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestDecreaseTargetSize(t *testing.T) { func TestDecreaseTargetSize(t *testing.T) {
agentPool := &VMsPool{} agentPool := newTestVMsPool(newTestAzureManager(t))
err := agentPool.DecreaseTargetSize(1) err := agentPool.DecreaseTargetSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err) assert.Nil(t, err)
} }
func TestId(t *testing.T) { func TestId(t *testing.T) {
agentPool := &VMsPool{ agentPool := &VMPool{
azureRef: azureRef{ azureRef: azureRef{
Name: "test-id", Name: "test-id",
}, },
@ -186,7 +215,7 @@ func TestId(t *testing.T) {
} }
func TestDebug(t *testing.T) { func TestDebug(t *testing.T) {
agentPool := &VMsPool{ agentPool := &VMPool{
azureRef: azureRef{ azureRef: azureRef{
Name: "test-debug", Name: "test-debug",
}, },
@ -198,115 +227,341 @@ func TestDebug(t *testing.T) {
assert.Equal(t, expectedDebugString, agentPool.Debug()) assert.Equal(t, expectedDebugString, agentPool.Debug())
} }
func TestTemplateNodeInfo(t *testing.T) { func TestTemplateNodeInfo(t *testing.T) {
agentPool := &VMsPool{} ctrl := gomock.NewController(t)
defer ctrl.Finish()
nodeInfo, err := agentPool.TemplateNodeInfo() ap := newTestVMsPool(newTestAzureManager(t))
assert.Nil(t, nodeInfo) ap.manager.config.EnableVMsAgentPool = true
assert.Equal(t, cloudprovider.ErrNotImplemented, err) mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
nodeInfo, err := ap.TemplateNodeInfo()
assert.NotNil(t, nodeInfo)
assert.Nil(t, err)
} }
func TestAtomicIncreaseSize(t *testing.T) { func TestAtomicIncreaseSize(t *testing.T) {
agentPool := &VMsPool{} agentPool := &VMPool{}
err := agentPool.AtomicIncreaseSize(1) err := agentPool.AtomicIncreaseSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err) assert.Equal(t, cloudprovider.ErrNotImplemented, err)
} }
// Test cases for getVMsFromCache()
// Test case 1 - when the vms pool is not found in the cache
// Test case 2 - when the vms pool is found in the cache but has no VMs
// Test case 3 - when the vms pool is found in the cache and has VMs
// Test case 4 - when the vms pool is found in the cache and has VMs with no name
func TestGetVMsFromCache(t *testing.T) { func TestGetVMsFromCache(t *testing.T) {
// Test case 1
manager := &AzureManager{ manager := &AzureManager{
azureCache: &azureCache{ azureCache: &azureCache{
virtualMachines: make(map[string][]compute.VirtualMachine), virtualMachines: make(map[string][]compute.VirtualMachine),
vmsPoolMap: make(map[string]armcontainerservice.AgentPool),
}, },
} }
agentPool := &VMsPool{ agentPool := &VMPool{
manager: manager, manager: manager,
azureRef: azureRef{ agentPoolName: vmsAgentPoolName,
Name: "test-vms-pool", sku: vmSku,
},
} }
_, err := agentPool.getVMsFromCache() // Test case 1 - when the vms pool is not found in the cache
assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache") vms, err := agentPool.getVMsFromCache(skipOption{})
assert.Nil(t, err)
assert.Len(t, vms, 0)
// Test case 2 // Test case 2 - when the vms pool is found in the cache but has no VMs
manager.azureCache.virtualMachines["test-vms-pool"] = []compute.VirtualMachine{} manager.azureCache.virtualMachines[vmsAgentPoolName] = []compute.VirtualMachine{}
_, err = agentPool.getVMsFromCache() vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, vms, 0)
// Test case 3 // Test case 3 - when the vms pool is found in the cache and has VMs
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3)
vms, err := agentPool.getVMsFromCache() vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, vms, 3) assert.Len(t, vms, 3)
// Test case 4 // Test case 4 - should skip failed VMs
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) vmList := newTestVMsPoolVMList(3)
agentPool.azureRef.Name = "" vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Failed")
_, err = agentPool.getVMsFromCache() manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
assert.EqualError(t, err, "vms pool not found in the cache") vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true})
assert.NoError(t, err)
assert.Len(t, vms, 2)
// Test case 5 - should skip deleting VMs
vmList = newTestVMsPoolVMList(3)
vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting")
manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
vms, err = agentPool.getVMsFromCache(skipOption{skipDeleting: true})
assert.NoError(t, err)
assert.Len(t, vms, 2)
// Test case 6 - should not skip deleting VMs
vmList = newTestVMsPoolVMList(3)
vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting")
manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true})
assert.NoError(t, err)
assert.Len(t, vms, 3)
// Test case 7 - when the vms pool is found in the cache and has VMs with no name
manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3)
agentPool.agentPoolName = ""
vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err)
assert.Len(t, vms, 0)
}
func TestGetVMsFromCacheForVMsPool(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(2)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ac.enableVMsAgentPool = true
ap.manager.azureCache = ac
vms, err := ap.getVMsFromCache(skipOption{})
assert.Equal(t, 2, len(vms))
assert.NoError(t, err)
} }
// Test cases for Nodes()
// Test case 1 - when there are no VMs in the pool
// Test case 2 - when there are VMs in the pool
// Test case 3 - when there are VMs in the pool with no ID
// Test case 4 - when there is an error converting resource group name
// Test case 5 - when there is an error getting VMs from cache
func TestNodes(t *testing.T) { func TestNodes(t *testing.T) {
// Test case 1 ctrl := gomock.NewController(t)
manager := &AzureManager{ defer ctrl.Finish()
azureCache: &azureCache{
virtualMachines: make(map[string][]compute.VirtualMachine), ap := newTestVMsPool(newTestAzureManager(t))
}, expectedVMs := newTestVMsPoolVMList(2)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
vms, err := ap.Nodes()
assert.Equal(t, 2, len(vms))
assert.NoError(t, err)
} }
agentPool := &VMsPool{
manager: manager, func TestGetCurSizeForVMsPool(t *testing.T) {
azureRef: azureRef{ ctrl := gomock.NewController(t)
Name: "test-vms-pool", defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(3)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
curSize, err := ap.getCurSize(skipOption{})
assert.NoError(t, err)
assert.Equal(t, int32(3), curSize)
}
func TestVMsPoolIncreaseSize(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
manager := newTestAzureManager(t)
ap := newTestVMsPool(manager)
expectedVMs := newTestVMsPoolVMList(3)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
// failure case 1
err1 := ap.IncreaseSize(-1)
expectedErr := fmt.Errorf("size increase must be positive, current delta: -1")
assert.Equal(t, expectedErr, err1)
// failure case 2
err2 := ap.IncreaseSize(8)
expectedErr = fmt.Errorf("size-increasing request of 11 is bigger than max size 10")
assert.Equal(t, expectedErr, err2)
// success case 3
resp := &http.Response{
Header: map[string][]string{
"Fake-Poller-Status": {"Done"},
}, },
} }
nodes, err := agentPool.Nodes() fakePoller, pollerErr := runtime.NewPoller(resp, runtime.Pipeline{},
assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache") &runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{
assert.Empty(t, nodes) Handler: &fakehandler[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{},
})
// Test case 2 assert.NoError(t, pollerErr)
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
nodes, err = agentPool.Nodes()
assert.NoError(t, err)
assert.Len(t, nodes, 3)
// Test case 3 mockAgentpoolclient.EXPECT().BeginCreateOrUpdate(
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) gomock.Any(), manager.config.ClusterResourceGroup,
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = nil manager.config.ClusterName,
nodes, err = agentPool.Nodes() vmsAgentPoolName,
assert.NoError(t, err) gomock.Any(), gomock.Any()).Return(fakePoller, nil)
assert.Len(t, nodes, 2)
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
emptyString := ""
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &emptyString
nodes, err = agentPool.Nodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
// Test case 4 err3 := ap.IncreaseSize(1)
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) assert.NoError(t, err3)
bogusID := "foo" }
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &bogusID
nodes, err = agentPool.Nodes() func TestDeleteVMsPoolNodes_Failed(t *testing.T) {
assert.Empty(t, nodes) ctrl := gomock.NewController(t)
assert.Error(t, err) defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
// Test case 5 node := newVMsNode(0)
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(1)
agentPool.azureRef.Name = "" expectedVMs := newTestVMsPoolVMList(3)
nodes, err = agentPool.Nodes() mockVMClient := mockvmclient.NewMockInterface(ctrl)
assert.Empty(t, nodes) ap.manager.azClient.virtualMachinesClient = mockVMClient
assert.Error(t, err) ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
agentpool := getTestVMsAgentPool(false)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.azureCache.enableVMsAgentPool = true
registered := ap.manager.RegisterNodeGroup(ap)
assert.True(t, registered)
ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
ap.manager.forceRefresh()
// failure case
deleteErr := ap.DeleteNodes([]*apiv1.Node{node})
assert.Error(t, deleteErr)
assert.Contains(t, deleteErr.Error(), "cannot delete nodes as minimum size of 3 has been reached")
}
func TestDeleteVMsPoolNodes_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(5)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
agentpool := getTestVMsAgentPool(false)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.azureCache.enableVMsAgentPool = true
registered := ap.manager.RegisterNodeGroup(ap)
assert.True(t, registered)
ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
ap.manager.forceRefresh()
// success case
resp := &http.Response{
Header: map[string][]string{
"Fake-Poller-Status": {"Done"},
},
}
fakePoller, err := runtime.NewPoller(resp, runtime.Pipeline{},
&runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{
Handler: &fakehandler[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{},
})
assert.NoError(t, err)
mockAgentpoolclient.EXPECT().BeginDeleteMachines(
gomock.Any(), ap.manager.config.ClusterResourceGroup,
ap.manager.config.ClusterName,
vmsAgentPoolName,
gomock.Any(), gomock.Any()).Return(fakePoller, nil)
node := newVMsNode(0)
derr := ap.DeleteNodes([]*apiv1.Node{node})
assert.NoError(t, derr)
}
type fakehandler[T any] struct{}
func (f *fakehandler[T]) Done() bool {
return true
}
func (f *fakehandler[T]) Poll(ctx context.Context) (*http.Response, error) {
return nil, nil
}
func (f *fakehandler[T]) Result(ctx context.Context, out *T) error {
return nil
}
func getFakeAgentpoolListPager(agentpool ...*armcontainerservice.AgentPool) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
fakeFetcher := func(ctx context.Context, response *armcontainerservice.AgentPoolsClientListResponse) (armcontainerservice.AgentPoolsClientListResponse, error) {
return armcontainerservice.AgentPoolsClientListResponse{
AgentPoolListResult: armcontainerservice.AgentPoolListResult{
Value: agentpool,
},
}, nil
}
return runtime.NewPager(runtime.PagingHandler[armcontainerservice.AgentPoolsClientListResponse]{
More: func(response armcontainerservice.AgentPoolsClientListResponse) bool {
return false
},
Fetcher: fakeFetcher,
})
} }

View File

@ -214,6 +214,11 @@ autoscaler about the sizing of the nodes in the node group. At the minimum,
you must specify the CPU and memory annotations, these annotations should you must specify the CPU and memory annotations, these annotations should
match the expected capacity of the nodes created from the infrastructure. match the expected capacity of the nodes created from the infrastructure.
> Note: The scale from zero annotations will override any capacity information
> supplied by the Cluster API provider in the infrastructure machine templates.
> If both the annotations and the provider supplied capacity information are
> present, the annotations will take precedence.
For example, if my MachineDeployment will create nodes that have "16000m" CPU, For example, if my MachineDeployment will create nodes that have "16000m" CPU,
"128G" memory, "100Gi" ephemeral disk storage, 2 NVidia GPUs, and can support "128G" memory, "100Gi" ephemeral disk storage, 2 NVidia GPUs, and can support
200 max pods, the following annotations will instruct the autoscaler how to 200 max pods, the following annotations will instruct the autoscaler how to
@ -240,11 +245,12 @@ metadata:
capacity.cluster-autoscaler.kubernetes.io/gpu-count: "2" capacity.cluster-autoscaler.kubernetes.io/gpu-count: "2"
``` ```
*Note* the `maxPods` annotation will default to `110` if it is not supplied. > Note: the `maxPods` annotation will default to `110` if it is not supplied.
This value is inspired by the Kubernetes best practices > This value is inspired by the Kubernetes best practices
[Considerations for large clusters](https://kubernetes.io/docs/setup/best-practices/cluster-large/). > [Considerations for large clusters](https://kubernetes.io/docs/setup/best-practices/cluster-large/).
*Note* User should select the annotation for GPU either `gpu-type` or `dra-driver` depends on whether using Device Plugin or Dynamic Resource Allocation(DRA). `gpu-count` is a common parameter in both. > Note: User should select the annotation for GPU either `gpu-type` or `dra-driver` depends on whether using
> Device Plugin or Dynamic Resource Allocation(DRA). `gpu-count` is a common parameter in both.
#### RBAC changes for scaling from zero #### RBAC changes for scaling from zero
@ -289,6 +295,12 @@ metadata:
capacity.cluster-autoscaler.kubernetes.io/taints: "key1=value1:NoSchedule,key2=value2:NoExecute" capacity.cluster-autoscaler.kubernetes.io/taints: "key1=value1:NoSchedule,key2=value2:NoExecute"
``` ```
> Note: The labels supplied through the capacity annotation will be combined
> with the labels to be propagated from the scalable Cluster API resource.
> The annotation does not override the labels in the scalable resource.
> Please see the [Cluster API Book chapter on Metadata propagation](https://cluster-api.sigs.k8s.io/reference/api/metadata-propagation)
> for more information.
#### Per-NodeGroup autoscaling options #### Per-NodeGroup autoscaling options
Custom autoscaling options per node group (MachineDeployment/MachinePool/MachineSet) can be specified as annoations with a common prefix: Custom autoscaling options per node group (MachineDeployment/MachinePool/MachineSet) can be specified as annoations with a common prefix:

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -54,9 +55,21 @@ const (
type ActionError struct { type ActionError struct {
Code string Code string
Message string Message string
action *Action
}
// Action returns the [Action] that triggered the error if available.
func (e ActionError) Action() *Action {
return e.action
} }
func (e ActionError) Error() string { func (e ActionError) Error() string {
action := e.Action()
if action != nil {
// For easier debugging, the error string contains the Action ID.
return fmt.Sprintf("%s (%s, %d)", e.Message, e.Code, action.ID)
}
return fmt.Sprintf("%s (%s)", e.Message, e.Code) return fmt.Sprintf("%s (%s)", e.Message, e.Code)
} }
@ -65,6 +78,7 @@ func (a *Action) Error() error {
return ActionError{ return ActionError{
Code: a.ErrorCode, Code: a.ErrorCode,
Message: a.ErrorMessage, Message: a.ErrorMessage,
action: a,
} }
} }
return nil return nil
@ -111,11 +125,15 @@ func (c *ActionClient) List(ctx context.Context, opts ActionListOpts) ([]*Action
} }
// All returns all actions. // All returns all actions.
//
// Deprecated: It is required to pass in a list of IDs since 30 January 2025. Please use [ActionClient.AllWithOpts] instead.
func (c *ActionClient) All(ctx context.Context) ([]*Action, error) { func (c *ActionClient) All(ctx context.Context) ([]*Action, error) {
return c.action.All(ctx, ActionListOpts{ListOpts: ListOpts{PerPage: 50}}) return c.action.All(ctx, ActionListOpts{ListOpts: ListOpts{PerPage: 50}})
} }
// AllWithOpts returns all actions for the given options. // AllWithOpts returns all actions for the given options.
//
// It is required to set [ActionListOpts.ID]. Any other fields set in the opts are ignored.
func (c *ActionClient) AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error) { func (c *ActionClient) AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error) {
return c.action.All(ctx, opts) return c.action.All(ctx, opts)
} }
@ -136,20 +154,19 @@ func (c *ResourceActionClient) getBaseURL() string {
// GetByID retrieves an action by its ID. If the action does not exist, nil is returned. // GetByID retrieves an action by its ID. If the action does not exist, nil is returned.
func (c *ResourceActionClient) GetByID(ctx context.Context, id int64) (*Action, *Response, error) { func (c *ResourceActionClient) GetByID(ctx context.Context, id int64) (*Action, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("%s/actions/%d", c.getBaseURL(), id), nil) opPath := c.getBaseURL() + "/actions/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.ActionGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.ActionGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return ActionFromSchema(body.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// List returns a list of actions for a specific page. // List returns a list of actions for a specific page.
@ -157,44 +174,23 @@ func (c *ResourceActionClient) GetByID(ctx context.Context, id int64) (*Action,
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *ResourceActionClient) List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error) { func (c *ResourceActionClient) List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error) {
req, err := c.client.NewRequest( opPath := c.getBaseURL() + "/actions?%s"
ctx, ctx = ctxutil.SetOpPath(ctx, opPath)
"GET",
fmt.Sprintf("%s/actions?%s", c.getBaseURL(), opts.values().Encode()), reqPath := fmt.Sprintf(opPath, opts.values().Encode())
nil,
) respBody, resp, err := getRequest[schema.ActionListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.ActionListResponse return allFromSchemaFunc(respBody.Actions, ActionFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
actions := make([]*Action, 0, len(body.Actions))
for _, i := range body.Actions {
actions = append(actions, ActionFromSchema(i))
}
return actions, resp, nil
} }
// All returns all actions for the given options. // All returns all actions for the given options.
func (c *ResourceActionClient) All(ctx context.Context, opts ActionListOpts) ([]*Action, error) { func (c *ResourceActionClient) All(ctx context.Context, opts ActionListOpts) ([]*Action, error) {
allActions := []*Action{} return iterPages(func(page int) ([]*Action, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
actions, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allActions = append(allActions, actions...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allActions, nil
} }

View File

@ -16,11 +16,14 @@ type ActionWaiter interface {
var _ ActionWaiter = (*ActionClient)(nil) var _ ActionWaiter = (*ActionClient)(nil)
// WaitForFunc waits until all actions are completed by polling the API at the interval // WaitForFunc waits until all actions are completed by polling the API at the interval
// defined by [WithPollBackoffFunc]. An action is considered as complete when its status is // defined by [WithPollOpts]. An action is considered as complete when its status is
// either [ActionStatusSuccess] or [ActionStatusError]. // either [ActionStatusSuccess] or [ActionStatusError].
// //
// The handleUpdate callback is called every time an action is updated. // The handleUpdate callback is called every time an action is updated.
func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error { func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error {
// Filter out nil actions
actions = slices.DeleteFunc(actions, func(a *Action) bool { return a == nil })
running := make(map[int64]struct{}, len(actions)) running := make(map[int64]struct{}, len(actions))
for _, action := range actions { for _, action := range actions {
if action.Status == ActionStatusRunning { if action.Status == ActionStatusRunning {
@ -48,20 +51,21 @@ func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update
retries++ retries++
} }
updates := make([]*Action, 0, len(running))
for runningIDsChunk := range slices.Chunk(slices.Sorted(maps.Keys(running)), 25) {
opts := ActionListOpts{ opts := ActionListOpts{
Sort: []string{"status", "id"}, Sort: []string{"status", "id"},
ID: make([]int64, 0, len(running)), ID: runningIDsChunk,
} }
for actionID := range running {
opts.ID = append(opts.ID, actionID)
}
slices.Sort(opts.ID)
updates, err := c.AllWithOpts(ctx, opts) updatesChunk, err := c.AllWithOpts(ctx, opts)
if err != nil { if err != nil {
return err return err
} }
updates = append(updates, updatesChunk...)
}
if len(updates) != len(running) { if len(updates) != len(running) {
// Some actions may not exist in the API, also fail early to prevent an // Some actions may not exist in the API, also fail early to prevent an
// infinite loop when updates == 0. // infinite loop when updates == 0.
@ -95,7 +99,7 @@ func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update
} }
// WaitFor waits until all actions succeed by polling the API at the interval defined by // WaitFor waits until all actions succeed by polling the API at the interval defined by
// [WithPollBackoffFunc]. An action is considered as succeeded when its status is either // [WithPollOpts]. An action is considered as succeeded when its status is either
// [ActionStatusSuccess]. // [ActionStatusSuccess].
// //
// If a single action fails, the function will stop waiting and the error set in the // If a single action fails, the function will stop waiting and the error set in the

View File

@ -21,7 +21,7 @@ import (
// timeout, use the [context.Context]. Once the method has stopped watching, // timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed. // both returned channels are closed.
// //
// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait // WatchOverallProgress uses the [WithPollOpts] of the [Client] to wait
// until sending the next request. // until sending the next request.
// //
// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead. // Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead.
@ -86,7 +86,7 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti
// timeout, use the [context.Context]. Once the method has stopped watching, // timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed. // both returned channels are closed.
// //
// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until // WatchProgress uses the [WithPollOpts] of the [Client] to wait until
// sending the next request. // sending the next request.
// //
// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead. // Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead.

View File

@ -1,15 +1,12 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -98,41 +95,32 @@ type CertificateClient struct {
// GetByID retrieves a Certificate by its ID. If the Certificate does not exist, nil is returned. // GetByID retrieves a Certificate by its ID. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) GetByID(ctx context.Context, id int64) (*Certificate, *Response, error) { func (c *CertificateClient) GetByID(ctx context.Context, id int64) (*Certificate, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/certificates/%d", id), nil) const opPath = "/certificates/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.CertificateGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.CertificateGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return CertificateFromSchema(body.Certificate), resp, nil return CertificateFromSchema(respBody.Certificate), resp, nil
} }
// GetByName retrieves a Certificate by its name. If the Certificate does not exist, nil is returned. // GetByName retrieves a Certificate by its name. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) GetByName(ctx context.Context, name string) (*Certificate, *Response, error) { func (c *CertificateClient) GetByName(ctx context.Context, name string) (*Certificate, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Certificate, *Response, error) {
return nil, nil, nil return c.List(ctx, CertificateListOpts{Name: name})
} })
Certificate, response, err := c.List(ctx, CertificateListOpts{Name: name})
if len(Certificate) == 0 {
return nil, response, err
}
return Certificate[0], response, err
} }
// Get retrieves a Certificate by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Certificate by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Certificate by its name. If the Certificate does not exist, nil is returned. // retrieves a Certificate by its name. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) Get(ctx context.Context, idOrName string) (*Certificate, *Response, error) { func (c *CertificateClient) Get(ctx context.Context, idOrName string) (*Certificate, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// CertificateListOpts specifies options for listing Certificates. // CertificateListOpts specifies options for listing Certificates.
@ -158,22 +146,17 @@ func (l CertificateListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *CertificateClient) List(ctx context.Context, opts CertificateListOpts) ([]*Certificate, *Response, error) { func (c *CertificateClient) List(ctx context.Context, opts CertificateListOpts) ([]*Certificate, *Response, error) {
path := "/certificates?" + opts.values().Encode() const opPath = "/certificates?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.CertificateListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.CertificateListResponse return allFromSchemaFunc(respBody.Certificates, CertificateFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
Certificates := make([]*Certificate, 0, len(body.Certificates))
for _, s := range body.Certificates {
Certificates = append(Certificates, CertificateFromSchema(s))
}
return Certificates, resp, nil
} }
// All returns all Certificates. // All returns all Certificates.
@ -183,22 +166,10 @@ func (c *CertificateClient) All(ctx context.Context) ([]*Certificate, error) {
// AllWithOpts returns all Certificates for the given options. // AllWithOpts returns all Certificates for the given options.
func (c *CertificateClient) AllWithOpts(ctx context.Context, opts CertificateListOpts) ([]*Certificate, error) { func (c *CertificateClient) AllWithOpts(ctx context.Context, opts CertificateListOpts) ([]*Certificate, error) {
allCertificates := []*Certificate{} return iterPages(func(page int) ([]*Certificate, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
Certificates, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allCertificates = append(allCertificates, Certificates...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allCertificates, nil
} }
// CertificateCreateOpts specifies options for creating a new Certificate. // CertificateCreateOpts specifies options for creating a new Certificate.
@ -214,7 +185,7 @@ type CertificateCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o CertificateCreateOpts) Validate() error { func (o CertificateCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
switch o.Type { switch o.Type {
case "", CertificateTypeUploaded: case "", CertificateTypeUploaded:
@ -222,23 +193,23 @@ func (o CertificateCreateOpts) Validate() error {
case CertificateTypeManaged: case CertificateTypeManaged:
return o.validateManaged() return o.validateManaged()
default: default:
return fmt.Errorf("invalid type: %s", o.Type) return invalidFieldValue(o, "Type", o.Type)
} }
} }
func (o CertificateCreateOpts) validateManaged() error { func (o CertificateCreateOpts) validateManaged() error {
if len(o.DomainNames) == 0 { if len(o.DomainNames) == 0 {
return errors.New("no domain names") return missingField(o, "DomainNames")
} }
return nil return nil
} }
func (o CertificateCreateOpts) validateUploaded() error { func (o CertificateCreateOpts) validateUploaded() error {
if o.Certificate == "" { if o.Certificate == "" {
return errors.New("missing certificate") return missingField(o, "Certificate")
} }
if o.PrivateKey == "" { if o.PrivateKey == "" {
return errors.New("missing private key") return missingField(o, "PrivateKey")
} }
return nil return nil
} }
@ -249,7 +220,7 @@ func (o CertificateCreateOpts) validateUploaded() error {
// CreateCertificate to create such certificates. // CreateCertificate to create such certificates.
func (c *CertificateClient) Create(ctx context.Context, opts CertificateCreateOpts) (*Certificate, *Response, error) { func (c *CertificateClient) Create(ctx context.Context, opts CertificateCreateOpts) (*Certificate, *Response, error) {
if !(opts.Type == "" || opts.Type == CertificateTypeUploaded) { if !(opts.Type == "" || opts.Type == CertificateTypeUploaded) {
return nil, nil, fmt.Errorf("invalid certificate type: %s", opts.Type) return nil, nil, invalidFieldValue(opts, "Type", opts.Type)
} }
result, resp, err := c.CreateCertificate(ctx, opts) result, resp, err := c.CreateCertificate(ctx, opts)
if err != nil { if err != nil {
@ -262,16 +233,20 @@ func (c *CertificateClient) Create(ctx context.Context, opts CertificateCreateOp
func (c *CertificateClient) CreateCertificate( func (c *CertificateClient) CreateCertificate(
ctx context.Context, opts CertificateCreateOpts, ctx context.Context, opts CertificateCreateOpts,
) (CertificateCreateResult, *Response, error) { ) (CertificateCreateResult, *Response, error) {
var ( const opPath = "/certificates"
action *Action ctx = ctxutil.SetOpPath(ctx, opPath)
reqBody schema.CertificateCreateRequest
) reqPath := opPath
result := CertificateCreateResult{}
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return CertificateCreateResult{}, nil, err return result, nil, err
} }
reqBody.Name = opts.Name reqBody := schema.CertificateCreateRequest{
Name: opts.Name,
}
switch opts.Type { switch opts.Type {
case "", CertificateTypeUploaded: case "", CertificateTypeUploaded:
@ -282,32 +257,24 @@ func (c *CertificateClient) CreateCertificate(
reqBody.Type = string(CertificateTypeManaged) reqBody.Type = string(CertificateTypeManaged)
reqBody.DomainNames = opts.DomainNames reqBody.DomainNames = opts.DomainNames
default: default:
return CertificateCreateResult{}, nil, fmt.Errorf("invalid certificate type: %v", opts.Type) return result, nil, invalidFieldValue(opts, "Type", opts.Type)
} }
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.CertificateCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return CertificateCreateResult{}, nil, err return result, resp, err
}
req, err := c.client.NewRequest(ctx, "POST", "/certificates", bytes.NewReader(reqBodyData))
if err != nil {
return CertificateCreateResult{}, nil, err
} }
respBody := schema.CertificateCreateResponse{} result.Certificate = CertificateFromSchema(respBody.Certificate)
resp, err := c.client.Do(req, &respBody)
if err != nil {
return CertificateCreateResult{}, resp, err
}
cert := CertificateFromSchema(respBody.Certificate)
if respBody.Action != nil { if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action) result.Action = ActionFromSchema(*respBody.Action)
} }
return CertificateCreateResult{Certificate: cert, Action: action}, resp, nil return result, resp, nil
} }
// CertificateUpdateOpts specifies options for updating a Certificate. // CertificateUpdateOpts specifies options for updating a Certificate.
@ -318,6 +285,11 @@ type CertificateUpdateOpts struct {
// Update updates a Certificate. // Update updates a Certificate.
func (c *CertificateClient) Update(ctx context.Context, certificate *Certificate, opts CertificateUpdateOpts) (*Certificate, *Response, error) { func (c *CertificateClient) Update(ctx context.Context, certificate *Certificate, opts CertificateUpdateOpts) (*Certificate, *Response, error) {
const opPath = "/certificates/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, certificate.ID)
reqBody := schema.CertificateUpdateRequest{} reqBody := schema.CertificateUpdateRequest{}
if opts.Name != "" { if opts.Name != "" {
reqBody.Name = &opts.Name reqBody.Name = &opts.Name
@ -325,46 +297,36 @@ func (c *CertificateClient) Update(ctx context.Context, certificate *Certificate
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/certificates/%d", certificate.ID) respBody, resp, err := putRequest[schema.CertificateUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.CertificateUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return CertificateFromSchema(respBody.Certificate), resp, nil return CertificateFromSchema(respBody.Certificate), resp, nil
} }
// Delete deletes a certificate. // Delete deletes a certificate.
func (c *CertificateClient) Delete(ctx context.Context, certificate *Certificate) (*Response, error) { func (c *CertificateClient) Delete(ctx context.Context, certificate *Certificate) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/certificates/%d", certificate.ID), nil) const opPath = "/certificates/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, certificate.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// RetryIssuance retries the issuance of a failed managed certificate. // RetryIssuance retries the issuance of a failed managed certificate.
func (c *CertificateClient) RetryIssuance(ctx context.Context, certificate *Certificate) (*Action, *Response, error) { func (c *CertificateClient) RetryIssuance(ctx context.Context, certificate *Certificate) (*Action, *Response, error) {
var respBody schema.CertificateIssuanceRetryResponse const opPath = "/certificates/%d/actions/retry"
ctx = ctxutil.SetOpPath(ctx, opPath)
req, err := c.client.NewRequest(ctx, "POST", fmt.Sprintf("/certificates/%d/actions/retry", certificate.ID), nil) reqPath := fmt.Sprintf(opPath, certificate.ID)
respBody, resp, err := postRequest[schema.CertificateIssuanceRetryResponse](ctx, c.client, reqPath, nil)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
resp, err := c.client.Do(req, &respBody)
if err != nil { return ActionFromSchema(respBody.Action), resp, nil
return nil, nil, err
}
action := ActionFromSchema(respBody.Action)
return action, resp, nil
} }

View File

@ -3,13 +3,12 @@ package hcloud
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"math/rand"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@ -19,7 +18,6 @@ import (
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
// Endpoint is the base URL of the API. // Endpoint is the base URL of the API.
@ -43,13 +41,43 @@ func ConstantBackoff(d time.Duration) BackoffFunc {
} }
// ExponentialBackoff returns a BackoffFunc which implements an exponential // ExponentialBackoff returns a BackoffFunc which implements an exponential
// backoff. // backoff, truncated to 60 seconds.
// It uses the formula: // See [ExponentialBackoffWithOpts] for more details.
func ExponentialBackoff(multiplier float64, base time.Duration) BackoffFunc {
return ExponentialBackoffWithOpts(ExponentialBackoffOpts{
Base: base,
Multiplier: multiplier,
Cap: time.Minute,
})
}
// ExponentialBackoffOpts defines the options used by [ExponentialBackoffWithOpts].
type ExponentialBackoffOpts struct {
Base time.Duration
Multiplier float64
Cap time.Duration
Jitter bool
}
// ExponentialBackoffWithOpts returns a BackoffFunc which implements an exponential
// backoff, truncated to a maximum, and an optional full jitter.
// //
// b^retries * d // See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
func ExponentialBackoff(b float64, d time.Duration) BackoffFunc { func ExponentialBackoffWithOpts(opts ExponentialBackoffOpts) BackoffFunc {
baseSeconds := opts.Base.Seconds()
capSeconds := opts.Cap.Seconds()
return func(retries int) time.Duration { return func(retries int) time.Duration {
return time.Duration(math.Pow(b, float64(retries))) * d // Exponential backoff
backoff := baseSeconds * math.Pow(opts.Multiplier, float64(retries))
// Cap backoff
backoff = math.Min(capSeconds, backoff)
// Add jitter
if opts.Jitter {
backoff = ((backoff - baseSeconds) * rand.Float64()) + baseSeconds // #nosec G404
}
return time.Duration(backoff * float64(time.Second))
} }
} }
@ -58,7 +86,8 @@ type Client struct {
endpoint string endpoint string
token string token string
tokenValid bool tokenValid bool
backoffFunc BackoffFunc retryBackoffFunc BackoffFunc
retryMaxRetries int
pollBackoffFunc BackoffFunc pollBackoffFunc BackoffFunc
httpClient *http.Client httpClient *http.Client
applicationName string applicationName string
@ -66,6 +95,7 @@ type Client struct {
userAgent string userAgent string
debugWriter io.Writer debugWriter io.Writer
instrumentationRegistry prometheus.Registerer instrumentationRegistry prometheus.Registerer
handler handler
Action ActionClient Action ActionClient
Certificate CertificateClient Certificate CertificateClient
@ -110,30 +140,73 @@ func WithToken(token string) ClientOption {
// polling from the API. // polling from the API.
// //
// Deprecated: Setting the poll interval is deprecated, you can now configure // Deprecated: Setting the poll interval is deprecated, you can now configure
// [WithPollBackoffFunc] with a [ConstantBackoff] to get the same results. To // [WithPollOpts] with a [ConstantBackoff] to get the same results. To
// migrate your code, replace your usage like this: // migrate your code, replace your usage like this:
// //
// // before // // before
// hcloud.WithPollInterval(2 * time.Second) // hcloud.WithPollInterval(2 * time.Second)
// // now // // now
// hcloud.WithPollBackoffFunc(hcloud.ConstantBackoff(2 * time.Second)) // hcloud.WithPollOpts(hcloud.PollOpts{
// BackoffFunc: hcloud.ConstantBackoff(2 * time.Second),
// })
func WithPollInterval(pollInterval time.Duration) ClientOption { func WithPollInterval(pollInterval time.Duration) ClientOption {
return WithPollBackoffFunc(ConstantBackoff(pollInterval)) return WithPollOpts(PollOpts{
BackoffFunc: ConstantBackoff(pollInterval),
})
} }
// WithPollBackoffFunc configures a Client to use the specified backoff // WithPollBackoffFunc configures a Client to use the specified backoff
// function when polling from the API. // function when polling from the API.
//
// Deprecated: WithPollBackoffFunc is deprecated, use [WithPollOpts] instead.
func WithPollBackoffFunc(f BackoffFunc) ClientOption { func WithPollBackoffFunc(f BackoffFunc) ClientOption {
return WithPollOpts(PollOpts{
BackoffFunc: f,
})
}
// PollOpts defines the options used by [WithPollOpts].
type PollOpts struct {
BackoffFunc BackoffFunc
}
// WithPollOpts configures a Client to use the specified options when polling from the API.
//
// If [PollOpts.BackoffFunc] is nil, the existing backoff function will be preserved.
func WithPollOpts(opts PollOpts) ClientOption {
return func(client *Client) { return func(client *Client) {
client.pollBackoffFunc = f if opts.BackoffFunc != nil {
client.pollBackoffFunc = opts.BackoffFunc
}
} }
} }
// WithBackoffFunc configures a Client to use the specified backoff function. // WithBackoffFunc configures a Client to use the specified backoff function.
// The backoff function is used for retrying HTTP requests. // The backoff function is used for retrying HTTP requests.
//
// Deprecated: WithBackoffFunc is deprecated, use [WithRetryOpts] instead.
func WithBackoffFunc(f BackoffFunc) ClientOption { func WithBackoffFunc(f BackoffFunc) ClientOption {
return func(client *Client) { return func(client *Client) {
client.backoffFunc = f client.retryBackoffFunc = f
}
}
// RetryOpts defines the options used by [WithRetryOpts].
type RetryOpts struct {
BackoffFunc BackoffFunc
MaxRetries int
}
// WithRetryOpts configures a Client to use the specified options when retrying API
// requests.
//
// If [RetryOpts.BackoffFunc] is nil, the existing backoff function will be preserved.
func WithRetryOpts(opts RetryOpts) ClientOption {
return func(client *Client) {
if opts.BackoffFunc != nil {
client.retryBackoffFunc = opts.BackoffFunc
}
client.retryMaxRetries = opts.MaxRetries
} }
} }
@ -175,7 +248,15 @@ func NewClient(options ...ClientOption) *Client {
endpoint: Endpoint, endpoint: Endpoint,
tokenValid: true, tokenValid: true,
httpClient: &http.Client{}, httpClient: &http.Client{},
backoffFunc: ExponentialBackoff(2, 500*time.Millisecond),
retryBackoffFunc: ExponentialBackoffWithOpts(ExponentialBackoffOpts{
Base: time.Second,
Multiplier: 2,
Cap: time.Minute,
Jitter: true,
}),
retryMaxRetries: 5,
pollBackoffFunc: ConstantBackoff(500 * time.Millisecond), pollBackoffFunc: ConstantBackoff(500 * time.Millisecond),
} }
@ -186,9 +267,11 @@ func NewClient(options ...ClientOption) *Client {
client.buildUserAgent() client.buildUserAgent()
if client.instrumentationRegistry != nil { if client.instrumentationRegistry != nil {
i := instrumentation.New("api", client.instrumentationRegistry) i := instrumentation.New("api", client.instrumentationRegistry)
client.httpClient.Transport = i.InstrumentedRoundTripper() client.httpClient.Transport = i.InstrumentedRoundTripper(client.httpClient.Transport)
} }
client.handler = assembleHandlerChain(client)
client.Action = ActionClient{action: &ResourceActionClient{client: client}} client.Action = ActionClient{action: &ResourceActionClient{client: client}}
client.Datacenter = DatacenterClient{client: client} client.Datacenter = DatacenterClient{client: client}
client.FloatingIP = FloatingIPClient{client: client, Action: &ResourceActionClient{client: client, resource: "floating_ips"}} client.FloatingIP = FloatingIPClient{client: client, Action: &ResourceActionClient{client: client, resource: "floating_ips"}}
@ -238,97 +321,8 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
// Do performs an HTTP request against the API. // Do performs an HTTP request against the API.
// v can be nil, an io.Writer to write the response body to or a pointer to // v can be nil, an io.Writer to write the response body to or a pointer to
// a struct to json.Unmarshal the response to. // a struct to json.Unmarshal the response to.
func (c *Client) Do(r *http.Request, v interface{}) (*Response, error) { func (c *Client) Do(req *http.Request, v any) (*Response, error) {
var retries int return c.handler.Do(req, v)
var body []byte
var err error
if r.ContentLength > 0 {
body, err = io.ReadAll(r.Body)
if err != nil {
r.Body.Close()
return nil, err
}
r.Body.Close()
}
for {
if r.ContentLength > 0 {
r.Body = io.NopCloser(bytes.NewReader(body))
}
if c.debugWriter != nil {
dumpReq, err := dumpRequest(r)
if err != nil {
return nil, err
}
fmt.Fprintf(c.debugWriter, "--- Request:\n%s\n\n", dumpReq)
}
resp, err := c.httpClient.Do(r)
if err != nil {
return nil, err
}
response := &Response{Response: resp}
body, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return response, err
}
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(body))
if c.debugWriter != nil {
dumpResp, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, err
}
fmt.Fprintf(c.debugWriter, "--- Response:\n%s\n\n", dumpResp)
}
if err = response.readMeta(body); err != nil {
return response, fmt.Errorf("hcloud: error reading response meta data: %s", err)
}
if response.StatusCode >= 400 && response.StatusCode <= 599 {
err = errorFromResponse(response, body)
if err == nil {
err = fmt.Errorf("hcloud: server responded with status code %d", resp.StatusCode)
} else if IsError(err, ErrorCodeConflict) {
c.backoff(retries)
retries++
continue
}
return response, err
}
if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, bytes.NewReader(body))
} else {
err = json.Unmarshal(body, v)
}
}
return response, err
}
}
func (c *Client) backoff(retries int) {
time.Sleep(c.backoffFunc(retries))
}
func (c *Client) all(f func(int) (*Response, error)) error {
var (
page = 1
)
for {
resp, err := f(page)
if err != nil {
return err
}
if resp.Meta.Pagination == nil || resp.Meta.Pagination.NextPage == 0 {
return nil
}
page = resp.Meta.Pagination.NextPage
}
} }
func (c *Client) buildUserAgent() { func (c *Client) buildUserAgent() {
@ -342,43 +336,6 @@ func (c *Client) buildUserAgent() {
} }
} }
func dumpRequest(r *http.Request) ([]byte, error) {
// Duplicate the request, so we can redact the auth header
rDuplicate := r.Clone(context.Background())
rDuplicate.Header.Set("Authorization", "REDACTED")
// To get the request body we need to read it before the request was actually sent.
// See https://github.com/golang/go/issues/29792
dumpReq, err := httputil.DumpRequestOut(rDuplicate, true)
if err != nil {
return nil, err
}
// Set original request body to the duplicate created by DumpRequestOut. The request body is not duplicated
// by .Clone() and instead just referenced, so it would be completely read otherwise.
r.Body = rDuplicate.Body
return dumpReq, nil
}
func errorFromResponse(resp *Response, body []byte) error {
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
return nil
}
var respBody schema.ErrorResponse
if err := json.Unmarshal(body, &respBody); err != nil {
return nil
}
if respBody.Error.Code == "" && respBody.Error.Message == "" {
return nil
}
hcErr := ErrorFromSchema(respBody.Error)
hcErr.response = resp
return hcErr
}
const ( const (
headerCorrelationID = "X-Correlation-Id" headerCorrelationID = "X-Correlation-Id"
) )
@ -387,35 +344,34 @@ const (
type Response struct { type Response struct {
*http.Response *http.Response
Meta Meta Meta Meta
// body holds a copy of the http.Response body that must be used within the handler
// chain. The http.Response.Body is reserved for external users.
body []byte
} }
func (r *Response) readMeta(body []byte) error { // populateBody copies the original [http.Response] body into the internal [Response] body
if h := r.Header.Get("RateLimit-Limit"); h != "" { // property, and restore the original [http.Response] body as if it was untouched.
r.Meta.Ratelimit.Limit, _ = strconv.Atoi(h) func (r *Response) populateBody() error {
} // Read full response body and save it for later use
if h := r.Header.Get("RateLimit-Remaining"); h != "" { body, err := io.ReadAll(r.Body)
r.Meta.Ratelimit.Remaining, _ = strconv.Atoi(h) r.Body.Close()
} if err != nil {
if h := r.Header.Get("RateLimit-Reset"); h != "" {
if ts, err := strconv.ParseInt(h, 10, 64); err == nil {
r.Meta.Ratelimit.Reset = time.Unix(ts, 0)
}
}
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") {
var s schema.MetaResponse
if err := json.Unmarshal(body, &s); err != nil {
return err return err
} }
if s.Meta.Pagination != nil { r.body = body
p := PaginationFromSchema(*s.Meta.Pagination)
r.Meta.Pagination = &p // Restore the body as if it was untouched, as it might be read by external users
} r.Body = io.NopCloser(bytes.NewReader(body))
}
return nil return nil
} }
// hasJSONBody returns whether the response has a JSON body.
func (r *Response) hasJSONBody() bool {
return len(r.body) > 0 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/json")
}
// internalCorrelationID returns the unique ID of the request as set by the API. This ID can help with support requests, // internalCorrelationID returns the unique ID of the request as set by the API. This ID can help with support requests,
// as it allows the people working on identify this request in particular. // as it allows the people working on identify this request in particular.
func (r *Response) internalCorrelationID() string { func (r *Response) internalCorrelationID() string {

View File

@ -0,0 +1,101 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"io"
)
func getRequest[Schema any](ctx context.Context, client *Client, url string) (Schema, *Response, error) {
var respBody Schema
req, err := client.NewRequest(ctx, "GET", url, nil)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func postRequest[Schema any](ctx context.Context, client *Client, url string, reqBody any) (Schema, *Response, error) {
var respBody Schema
var reqBodyReader io.Reader
if reqBody != nil {
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
return respBody, nil, err
}
reqBodyReader = bytes.NewReader(reqBodyBytes)
}
req, err := client.NewRequest(ctx, "POST", url, reqBodyReader)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func putRequest[Schema any](ctx context.Context, client *Client, url string, reqBody any) (Schema, *Response, error) {
var respBody Schema
var reqBodyReader io.Reader
if reqBody != nil {
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
return respBody, nil, err
}
reqBodyReader = bytes.NewReader(reqBodyBytes)
}
req, err := client.NewRequest(ctx, "PUT", url, reqBodyReader)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func deleteRequest[Schema any](ctx context.Context, client *Client, url string) (Schema, *Response, error) {
var respBody Schema
req, err := client.NewRequest(ctx, "DELETE", url, nil)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func deleteRequestNoResult(ctx context.Context, client *Client, url string) (*Response, error) {
req, err := client.NewRequest(ctx, "DELETE", url, nil)
if err != nil {
return nil, err
}
return client.Do(req, nil)
}

View File

@ -0,0 +1,56 @@
package hcloud
import (
"context"
"net/http"
)
// handler is an interface representing a client request transaction. The handler are
// meant to be chained, similarly to the [http.RoundTripper] interface.
//
// The handler chain is placed between the [Client] API operations and the
// [http.Client].
type handler interface {
Do(req *http.Request, v any) (resp *Response, err error)
}
// assembleHandlerChain assembles the chain of handlers used to make API requests.
//
// The order of the handlers is important.
func assembleHandlerChain(client *Client) handler {
// Start down the chain: sending the http request
h := newHTTPHandler(client.httpClient)
// Insert debug writer if enabled
if client.debugWriter != nil {
h = wrapDebugHandler(h, client.debugWriter)
}
// Read rate limit headers
h = wrapRateLimitHandler(h)
// Build error from response
h = wrapErrorHandler(h)
// Retry request if condition are met
h = wrapRetryHandler(h, client.retryBackoffFunc, client.retryMaxRetries)
// Finally parse the response body into the provided schema
h = wrapParseHandler(h)
return h
}
// cloneRequest clones both the request and the request body.
func cloneRequest(req *http.Request, ctx context.Context) (cloned *http.Request, err error) { //revive:disable:context-as-argument
cloned = req.Clone(ctx)
if req.ContentLength > 0 {
cloned.Body, err = req.GetBody()
if err != nil {
return nil, err
}
}
return cloned, nil
}

View File

@ -0,0 +1,50 @@
package hcloud
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
)
func wrapDebugHandler(wrapped handler, output io.Writer) handler {
return &debugHandler{wrapped, output}
}
type debugHandler struct {
handler handler
output io.Writer
}
func (h *debugHandler) Do(req *http.Request, v any) (resp *Response, err error) {
// Clone the request, so we can redact the auth header, read the body
// and use a new context.
cloned, err := cloneRequest(req, context.Background())
if err != nil {
return nil, err
}
cloned.Header.Set("Authorization", "REDACTED")
dumpReq, err := httputil.DumpRequestOut(cloned, true)
if err != nil {
return nil, err
}
fmt.Fprintf(h.output, "--- Request:\n%s\n\n", dumpReq)
resp, err = h.handler.Do(req, v)
if err != nil {
return resp, err
}
dumpResp, err := httputil.DumpResponse(resp.Response, true)
if err != nil {
return nil, err
}
fmt.Fprintf(h.output, "--- Response:\n%s\n\n", dumpResp)
return resp, err
}

View File

@ -0,0 +1,53 @@
package hcloud
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
var ErrStatusCode = errors.New("server responded with status code")
func wrapErrorHandler(wrapped handler) handler {
return &errorHandler{wrapped}
}
type errorHandler struct {
handler handler
}
func (h *errorHandler) Do(req *http.Request, v any) (resp *Response, err error) {
resp, err = h.handler.Do(req, v)
if err != nil {
return resp, err
}
if resp.StatusCode >= 400 && resp.StatusCode <= 599 {
err = errorFromBody(resp)
if err == nil {
err = fmt.Errorf("hcloud: %w %d", ErrStatusCode, resp.StatusCode)
}
}
return resp, err
}
func errorFromBody(resp *Response) error {
if !resp.hasJSONBody() {
return nil
}
var s schema.ErrorResponse
if err := json.Unmarshal(resp.body, &s); err != nil {
return nil // nolint: nilerr
}
if s.Error.Code == "" && s.Error.Message == "" {
return nil
}
hcErr := ErrorFromSchema(s.Error)
hcErr.response = resp
return hcErr
}

View File

@ -0,0 +1,28 @@
package hcloud
import (
"net/http"
)
func newHTTPHandler(httpClient *http.Client) handler {
return &httpHandler{httpClient}
}
type httpHandler struct {
httpClient *http.Client
}
func (h *httpHandler) Do(req *http.Request, _ interface{}) (*Response, error) {
httpResponse, err := h.httpClient.Do(req) //nolint: bodyclose
resp := &Response{Response: httpResponse}
if err != nil {
return resp, err
}
err = resp.populateBody()
if err != nil {
return resp, err
}
return resp, err
}

View File

@ -0,0 +1,50 @@
package hcloud
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
func wrapParseHandler(wrapped handler) handler {
return &parseHandler{wrapped}
}
type parseHandler struct {
handler handler
}
func (h *parseHandler) Do(req *http.Request, v any) (resp *Response, err error) {
// respBody is not needed down the handler chain
resp, err = h.handler.Do(req, nil)
if err != nil {
return resp, err
}
if resp.hasJSONBody() {
// Parse the response meta
var s schema.MetaResponse
if err := json.Unmarshal(resp.body, &s); err != nil {
return resp, fmt.Errorf("hcloud: error reading response meta data: %w", err)
}
if s.Meta.Pagination != nil {
p := PaginationFromSchema(*s.Meta.Pagination)
resp.Meta.Pagination = &p
}
}
// Parse the response schema
if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, bytes.NewReader(resp.body))
} else {
err = json.Unmarshal(resp.body, v)
}
}
return resp, err
}

View File

@ -0,0 +1,36 @@
package hcloud
import (
"net/http"
"strconv"
"time"
)
func wrapRateLimitHandler(wrapped handler) handler {
return &rateLimitHandler{wrapped}
}
type rateLimitHandler struct {
handler handler
}
func (h *rateLimitHandler) Do(req *http.Request, v any) (resp *Response, err error) {
resp, err = h.handler.Do(req, v)
// Ensure the embedded [*http.Response] is not nil, e.g. on canceled context
if resp != nil && resp.Response != nil && resp.Response.Header != nil {
if h := resp.Header.Get("RateLimit-Limit"); h != "" {
resp.Meta.Ratelimit.Limit, _ = strconv.Atoi(h)
}
if h := resp.Header.Get("RateLimit-Remaining"); h != "" {
resp.Meta.Ratelimit.Remaining, _ = strconv.Atoi(h)
}
if h := resp.Header.Get("RateLimit-Reset"); h != "" {
if ts, err := strconv.ParseInt(h, 10, 64); err == nil {
resp.Meta.Ratelimit.Reset = time.Unix(ts, 0)
}
}
}
return resp, err
}

View File

@ -0,0 +1,84 @@
package hcloud
import (
"errors"
"net"
"net/http"
"time"
)
func wrapRetryHandler(wrapped handler, backoffFunc BackoffFunc, maxRetries int) handler {
return &retryHandler{wrapped, backoffFunc, maxRetries}
}
type retryHandler struct {
handler handler
backoffFunc BackoffFunc
maxRetries int
}
func (h *retryHandler) Do(req *http.Request, v any) (resp *Response, err error) {
retries := 0
ctx := req.Context()
for {
// Clone the request using the original context
cloned, err := cloneRequest(req, ctx)
if err != nil {
return nil, err
}
resp, err = h.handler.Do(cloned, v)
if err != nil {
// Beware the diversity of the errors:
// - request preparation
// - network connectivity
// - http status code (see [errorHandler])
if ctx.Err() != nil {
// early return if the context was canceled or timed out
return resp, err
}
if retries < h.maxRetries && retryPolicy(resp, err) {
select {
case <-ctx.Done():
return resp, err
case <-time.After(h.backoffFunc(retries)):
retries++
continue
}
}
}
return resp, err
}
}
func retryPolicy(resp *Response, err error) bool {
if err != nil {
var apiErr Error
var netErr net.Error
switch {
case errors.As(err, &apiErr):
switch apiErr.Code { //nolint:exhaustive
case ErrorCodeConflict:
return true
case ErrorCodeRateLimitExceeded:
return true
}
case errors.Is(err, ErrStatusCode):
switch resp.Response.StatusCode {
// 5xx errors
case http.StatusBadGateway, http.StatusGatewayTimeout:
return true
}
case errors.As(err, &netErr):
if netErr.Timeout() {
return true
}
}
}
return false
}

View File

@ -0,0 +1,85 @@
package hcloud
import (
"context"
"strconv"
)
// allFromSchemaFunc transform each item in the list using the FromSchema function, and
// returns the result.
func allFromSchemaFunc[T, V any](all []T, fn func(T) V) []V {
result := make([]V, len(all))
for i, t := range all {
result[i] = fn(t)
}
return result
}
// iterPages fetches each pages using the list function, and returns the result.
func iterPages[T any](listFn func(int) ([]*T, *Response, error)) ([]*T, error) {
page := 1
result := []*T{}
for {
pageResult, resp, err := listFn(page)
if err != nil {
return nil, err
}
result = append(result, pageResult...)
if resp.Meta.Pagination == nil || resp.Meta.Pagination.NextPage == 0 {
return result, nil
}
page = resp.Meta.Pagination.NextPage
}
}
// firstBy fetches a list of items using the list function, and returns the first item
// of the list if present otherwise nil.
func firstBy[T any](listFn func() ([]*T, *Response, error)) (*T, *Response, error) {
items, resp, err := listFn()
if len(items) == 0 {
return nil, resp, err
}
return items[0], resp, err
}
// firstByName is a wrapper around [firstBy], that checks if the provided name is not
// empty.
func firstByName[T any](name string, listFn func() ([]*T, *Response, error)) (*T, *Response, error) {
if name == "" {
return nil, nil, nil
}
return firstBy(listFn)
}
// getByIDOrName fetches the resource by ID when the identifier is an integer, otherwise
// by Name. To support resources that have a integer as Name, an additional attempt is
// made to fetch the resource by Name using the ID.
//
// Since API managed resources (locations, server types, ...) do not have integers as
// names, this function is only meaningful for user managed resources (ssh keys,
// servers).
func getByIDOrName[T any](
ctx context.Context,
getByIDFn func(ctx context.Context, id int64) (*T, *Response, error),
getByNameFn func(ctx context.Context, name string) (*T, *Response, error),
idOrName string,
) (*T, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
result, resp, err := getByIDFn(ctx, id)
if err != nil {
return result, resp, err
}
if result != nil {
return result, resp, err
}
// Fallback to get by Name if the resource was not found
}
return getByNameFn(ctx, idOrName)
}

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -32,32 +33,27 @@ type DatacenterClient struct {
// GetByID retrieves a datacenter by its ID. If the datacenter does not exist, nil is returned. // GetByID retrieves a datacenter by its ID. If the datacenter does not exist, nil is returned.
func (c *DatacenterClient) GetByID(ctx context.Context, id int64) (*Datacenter, *Response, error) { func (c *DatacenterClient) GetByID(ctx context.Context, id int64) (*Datacenter, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/datacenters/%d", id), nil) const opPath = "/datacenters/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.DatacenterGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.DatacenterGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, resp, err return nil, resp, err
} }
return DatacenterFromSchema(body.Datacenter), resp, nil
return DatacenterFromSchema(respBody.Datacenter), resp, nil
} }
// GetByName retrieves a datacenter by its name. If the datacenter does not exist, nil is returned. // GetByName retrieves a datacenter by its name. If the datacenter does not exist, nil is returned.
func (c *DatacenterClient) GetByName(ctx context.Context, name string) (*Datacenter, *Response, error) { func (c *DatacenterClient) GetByName(ctx context.Context, name string) (*Datacenter, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Datacenter, *Response, error) {
return nil, nil, nil return c.List(ctx, DatacenterListOpts{Name: name})
} })
datacenters, response, err := c.List(ctx, DatacenterListOpts{Name: name})
if len(datacenters) == 0 {
return nil, response, err
}
return datacenters[0], response, err
} }
// Get retrieves a datacenter by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a datacenter by its ID if the input can be parsed as an integer, otherwise it
@ -92,22 +88,17 @@ func (l DatacenterListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *DatacenterClient) List(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, *Response, error) { func (c *DatacenterClient) List(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, *Response, error) {
path := "/datacenters?" + opts.values().Encode() const opPath = "/datacenters?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.DatacenterListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.DatacenterListResponse return allFromSchemaFunc(respBody.Datacenters, DatacenterFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
datacenters := make([]*Datacenter, 0, len(body.Datacenters))
for _, i := range body.Datacenters {
datacenters = append(datacenters, DatacenterFromSchema(i))
}
return datacenters, resp, nil
} }
// All returns all datacenters. // All returns all datacenters.
@ -117,20 +108,8 @@ func (c *DatacenterClient) All(ctx context.Context) ([]*Datacenter, error) {
// AllWithOpts returns all datacenters for the given options. // AllWithOpts returns all datacenters for the given options.
func (c *DatacenterClient) AllWithOpts(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, error) { func (c *DatacenterClient) AllWithOpts(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, error) {
allDatacenters := []*Datacenter{} return iterPages(func(page int) ([]*Datacenter, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
datacenters, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allDatacenters = append(allDatacenters, datacenters...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allDatacenters, nil
} }

View File

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings"
) )
// ErrorCode represents an error code returned from the API. // ErrorCode represents an error code returned from the API.
@ -29,6 +31,7 @@ const (
ErrorCodeRobotUnavailable ErrorCode = "robot_unavailable" // Robot was not available. The caller may retry the operation after a short delay ErrorCodeRobotUnavailable ErrorCode = "robot_unavailable" // Robot was not available. The caller may retry the operation after a short delay
ErrorCodeResourceLocked ErrorCode = "resource_locked" // The resource is locked. The caller should contact support ErrorCodeResourceLocked ErrorCode = "resource_locked" // The resource is locked. The caller should contact support
ErrorUnsupportedError ErrorCode = "unsupported_error" // The given resource does not support this ErrorUnsupportedError ErrorCode = "unsupported_error" // The given resource does not support this
ErrorDeprecatedAPIEndpoint ErrorCode = "deprecated_api_endpoint" // The request can not be answered because the API functionality was removed
// Server related error codes. // Server related error codes.
@ -126,11 +129,16 @@ type ErrorDetailsInvalidInputField struct {
Messages []string Messages []string
} }
// IsError returns whether err is an API error with the given error code. // ErrorDetailsDeprecatedAPIEndpoint contains the details of a 'deprecated_api_endpoint' error.
func IsError(err error, code ErrorCode) bool { type ErrorDetailsDeprecatedAPIEndpoint struct {
Announcement string
}
// IsError returns whether err is an API error with one of the given error codes.
func IsError(err error, code ...ErrorCode) bool {
var apiErr Error var apiErr Error
ok := errors.As(err, &apiErr) ok := errors.As(err, &apiErr)
return ok && apiErr.Code == code return ok && slices.Index(code, apiErr.Code) > -1
} }
type InvalidIPError struct { type InvalidIPError struct {
@ -148,3 +156,40 @@ type DNSNotFoundError struct {
func (e DNSNotFoundError) Error() string { func (e DNSNotFoundError) Error() string {
return fmt.Sprintf("dns for ip %s not found", e.IP.String()) return fmt.Sprintf("dns for ip %s not found", e.IP.String())
} }
// ArgumentError is a type of error returned when validating arguments.
type ArgumentError string
func (e ArgumentError) Error() string { return string(e) }
func newArgumentErrorf(format string, args ...any) ArgumentError {
return ArgumentError(fmt.Sprintf(format, args...))
}
func missingArgument(name string, obj any) error {
return newArgumentErrorf("missing argument '%s' [%T]", name, obj)
}
func invalidArgument(name string, obj any) error {
return newArgumentErrorf("invalid value '%v' for argument '%s' [%T]", obj, name, obj)
}
func missingField(obj any, field string) error {
return newArgumentErrorf("missing field [%s] in [%T]", field, obj)
}
func invalidFieldValue(obj any, field string, value any) error {
return newArgumentErrorf("invalid value '%v' for field [%s] in [%T]", value, field, obj)
}
func missingOneOfFields(obj any, fields ...string) error {
return newArgumentErrorf("missing one of fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}
func mutuallyExclusiveFields(obj any, fields ...string) error {
return newArgumentErrorf("found mutually exclusive fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}
func missingRequiredTogetherFields(obj any, fields ...string) error {
return newArgumentErrorf("missing required together fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}

View File

@ -0,0 +1,11 @@
package actionutil
import "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud"
// AppendNext return the action and the next actions in a new slice.
func AppendNext(action *hcloud.Action, nextActions []*hcloud.Action) []*hcloud.Action {
all := make([]*hcloud.Action, 0, 1+len(nextActions))
all = append(all, action)
all = append(all, nextActions...)
return all
}

View File

@ -0,0 +1,30 @@
package ctxutil
import (
"context"
"strings"
)
// key is an unexported type to prevents collisions with keys defined in other packages.
type key struct{}
// opPathKey is the key for operation path in Contexts.
var opPathKey = key{}
// SetOpPath processes the operation path and save it in the context before returning it.
func SetOpPath(ctx context.Context, path string) context.Context {
path, _, _ = strings.Cut(path, "?")
path = strings.ReplaceAll(path, "%d", "-")
path = strings.ReplaceAll(path, "%s", "-")
return context.WithValue(ctx, opPathKey, path)
}
// OpPath returns the operation path from the context.
func OpPath(ctx context.Context) string {
result, ok := ctx.Value(opPathKey).(string)
if !ok {
return ""
}
return result
}

View File

@ -0,0 +1,4 @@
// Package exp is a namespace that holds experimental features for the `hcloud-go` library.
//
// Breaking changes may occur without notice. Do not use in production!
package exp

View File

@ -0,0 +1,40 @@
package envutil
import (
"fmt"
"os"
"strings"
)
// LookupEnvWithFile retrieves the value of the environment variable named by the key (e.g.
// HCLOUD_TOKEN). If the previous environment variable is not set, it retrieves the
// content of the file located by a second environment variable named by the key +
// '_FILE' (.e.g HCLOUD_TOKEN_FILE).
//
// For both cases, the returned value may be empty.
//
// The value from the environment takes precedence over the value from the file.
func LookupEnvWithFile(key string) (string, error) {
// Check if the value is set in the environment (e.g. HCLOUD_TOKEN)
value, ok := os.LookupEnv(key)
if ok {
return value, nil
}
key += "_FILE"
// Check if the value is set via a file (e.g. HCLOUD_TOKEN_FILE)
valueFile, ok := os.LookupEnv(key)
if !ok {
// Validation of the value happens outside of this function
return "", nil
}
// Read the content of the file
valueBytes, err := os.ReadFile(valueFile)
if err != nil {
return "", fmt.Errorf("failed to read %s: %w", key, err)
}
return strings.TrimSpace(string(valueBytes)), nil
}

View File

@ -0,0 +1,19 @@
package randutil
import (
"crypto/rand"
"encoding/hex"
"fmt"
)
// GenerateID returns a hex encoded random string with a len of 8 chars similar to
// "2873fce7".
func GenerateID() string {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
// Should never happen as of go1.24: https://github.com/golang/go/issues/66821
panic(fmt.Errorf("failed to generate random string: %w", err))
}
return hex.EncodeToString(b)
}

View File

@ -0,0 +1,86 @@
package sshutil
import (
"crypto"
"crypto/ed25519"
"encoding/pem"
"fmt"
"golang.org/x/crypto/ssh"
)
// GenerateKeyPair generates a new ed25519 ssh key pair, and returns the private key and
// the public key respectively.
func GenerateKeyPair() ([]byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
return nil, nil, fmt.Errorf("could not generate key pair: %w", err)
}
privBytes, err := encodePrivateKey(priv)
if err != nil {
return nil, nil, fmt.Errorf("could not encode private key: %w", err)
}
pubBytes, err := encodePublicKey(pub)
if err != nil {
return nil, nil, fmt.Errorf("could not encode public key: %w", err)
}
return privBytes, pubBytes, nil
}
func encodePrivateKey(priv crypto.PrivateKey) ([]byte, error) {
privPem, err := ssh.MarshalPrivateKey(priv, "")
if err != nil {
return nil, err
}
return pem.EncodeToMemory(privPem), nil
}
func encodePublicKey(pub crypto.PublicKey) ([]byte, error) {
sshPub, err := ssh.NewPublicKey(pub)
if err != nil {
return nil, err
}
return ssh.MarshalAuthorizedKey(sshPub), nil
}
type privateKeyWithPublicKey interface {
crypto.PrivateKey
Public() crypto.PublicKey
}
// GeneratePublicKey generate a public key from the provided private key.
func GeneratePublicKey(privBytes []byte) ([]byte, error) {
priv, err := ssh.ParseRawPrivateKey(privBytes)
if err != nil {
return nil, fmt.Errorf("could not decode private key: %w", err)
}
key, ok := priv.(privateKeyWithPublicKey)
if !ok {
return nil, fmt.Errorf("private key doesn't export Public() crypto.PublicKey")
}
pubBytes, err := encodePublicKey(key.Public())
if err != nil {
return nil, fmt.Errorf("could not encode public key: %w", err)
}
return pubBytes, nil
}
// GetPublicKeyFingerprint generate the finger print for the provided public key.
func GetPublicKeyFingerprint(pubBytes []byte) (string, error) {
pub, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes)
if err != nil {
return "", fmt.Errorf("could not decode public key: %w", err)
}
fingerprint := ssh.FingerprintLegacyMD5(pub)
return fingerprint, nil
}

View File

@ -0,0 +1,24 @@
package labelutil
import (
"fmt"
"sort"
"strings"
)
// Selector combines the label set into a [label selector](https://docs.hetzner.cloud/#label-selector) that only selects
// resources have all specified labels set.
//
// The selector string can be used to filter resources when listing, for example with [hcloud.ServerClient.AllWithOpts()].
func Selector(labels map[string]string) string {
selectors := make([]string, 0, len(labels))
for k, v := range labels {
selectors = append(selectors, fmt.Sprintf("%s=%s", k, v))
}
// Reproducible result for tests
sort.Strings(selectors)
return strings.Join(selectors, ",")
}

View File

@ -0,0 +1,123 @@
package mockutil
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Request describes a http request that a [httptest.Server] should receive, and the
// corresponding response to return.
//
// Additional checks on the request (e.g. request body) may be added using the
// [Request.Want] function.
//
// The response body is populated from either a JSON struct, or a JSON string.
type Request struct {
Method string
Path string
Want func(t *testing.T, r *http.Request)
Status int
JSON any
JSONRaw string
}
// Handler is using a [Server] to mock http requests provided by the user.
func Handler(t *testing.T, requests []Request) http.HandlerFunc {
t.Helper()
server := NewServer(t, requests)
t.Cleanup(server.close)
return server.handler
}
// NewServer returns a new mock server that closes itself at the end of the test.
func NewServer(t *testing.T, requests []Request) *Server {
t.Helper()
o := &Server{t: t}
o.Server = httptest.NewServer(http.HandlerFunc(o.handler))
t.Cleanup(o.close)
o.Expect(requests)
return o
}
// Server embeds a [httptest.Server] that answers HTTP calls with a list of expected [Request].
//
// Request matching is based on the request count, and the user provided request will be
// iterated over.
//
// A Server must be created using the [NewServer] function.
type Server struct {
*httptest.Server
t *testing.T
requests []Request
index int
}
// Expect adds requests to the list of requests expected by the [Server].
func (m *Server) Expect(requests []Request) {
m.requests = append(m.requests, requests...)
}
func (m *Server) close() {
m.t.Helper()
m.Server.Close()
assert.EqualValues(m.t, len(m.requests), m.index, "expected more calls")
}
func (m *Server) handler(w http.ResponseWriter, r *http.Request) {
if testing.Verbose() {
m.t.Logf("call %d: %s %s\n", m.index, r.Method, r.RequestURI)
}
if m.index >= len(m.requests) {
m.t.Fatalf("received unknown call %d", m.index)
}
expected := m.requests[m.index]
expectedCall := expected.Method
foundCall := r.Method
if expected.Path != "" {
expectedCall += " " + expected.Path
foundCall += " " + r.RequestURI
}
require.Equal(m.t, expectedCall, foundCall) // nolint: testifylint
if expected.Want != nil {
expected.Want(m.t, r)
}
switch {
case expected.JSON != nil:
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(expected.Status)
if err := json.NewEncoder(w).Encode(expected.JSON); err != nil {
m.t.Fatal(err)
}
case expected.JSONRaw != "":
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(expected.Status)
_, err := w.Write([]byte(expected.JSONRaw))
if err != nil {
m.t.Fatal(err)
}
default:
w.WriteHeader(expected.Status)
}
m.index++
}

View File

@ -1,16 +1,13 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -96,41 +93,33 @@ type FirewallClient struct {
// GetByID retrieves a Firewall by its ID. If the Firewall does not exist, nil is returned. // GetByID retrieves a Firewall by its ID. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) GetByID(ctx context.Context, id int64) (*Firewall, *Response, error) { func (c *FirewallClient) GetByID(ctx context.Context, id int64) (*Firewall, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/firewalls/%d", id), nil) const opPath = "/firewalls/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.FirewallGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.FirewallGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return FirewallFromSchema(body.Firewall), resp, nil
return FirewallFromSchema(respBody.Firewall), resp, nil
} }
// GetByName retrieves a Firewall by its name. If the Firewall does not exist, nil is returned. // GetByName retrieves a Firewall by its name. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) GetByName(ctx context.Context, name string) (*Firewall, *Response, error) { func (c *FirewallClient) GetByName(ctx context.Context, name string) (*Firewall, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Firewall, *Response, error) {
return nil, nil, nil return c.List(ctx, FirewallListOpts{Name: name})
} })
firewalls, response, err := c.List(ctx, FirewallListOpts{Name: name})
if len(firewalls) == 0 {
return nil, response, err
}
return firewalls[0], response, err
} }
// Get retrieves a Firewall by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Firewall by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Firewall by its name. If the Firewall does not exist, nil is returned. // retrieves a Firewall by its name. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) Get(ctx context.Context, idOrName string) (*Firewall, *Response, error) { func (c *FirewallClient) Get(ctx context.Context, idOrName string) (*Firewall, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// FirewallListOpts specifies options for listing Firewalls. // FirewallListOpts specifies options for listing Firewalls.
@ -156,22 +145,17 @@ func (l FirewallListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *FirewallClient) List(ctx context.Context, opts FirewallListOpts) ([]*Firewall, *Response, error) { func (c *FirewallClient) List(ctx context.Context, opts FirewallListOpts) ([]*Firewall, *Response, error) {
path := "/firewalls?" + opts.values().Encode() const opPath = "/firewalls?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.FirewallListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.FirewallListResponse return allFromSchemaFunc(respBody.Firewalls, FirewallFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
firewalls := make([]*Firewall, 0, len(body.Firewalls))
for _, s := range body.Firewalls {
firewalls = append(firewalls, FirewallFromSchema(s))
}
return firewalls, resp, nil
} }
// All returns all Firewalls. // All returns all Firewalls.
@ -181,22 +165,10 @@ func (c *FirewallClient) All(ctx context.Context) ([]*Firewall, error) {
// AllWithOpts returns all Firewalls for the given options. // AllWithOpts returns all Firewalls for the given options.
func (c *FirewallClient) AllWithOpts(ctx context.Context, opts FirewallListOpts) ([]*Firewall, error) { func (c *FirewallClient) AllWithOpts(ctx context.Context, opts FirewallListOpts) ([]*Firewall, error) {
allFirewalls := []*Firewall{} return iterPages(func(page int) ([]*Firewall, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
firewalls, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allFirewalls = append(allFirewalls, firewalls...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allFirewalls, nil
} }
// FirewallCreateOpts specifies options for creating a new Firewall. // FirewallCreateOpts specifies options for creating a new Firewall.
@ -210,7 +182,7 @@ type FirewallCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o FirewallCreateOpts) Validate() error { func (o FirewallCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
return nil return nil
} }
@ -223,28 +195,27 @@ type FirewallCreateResult struct {
// Create creates a new Firewall. // Create creates a new Firewall.
func (c *FirewallClient) Create(ctx context.Context, opts FirewallCreateOpts) (FirewallCreateResult, *Response, error) { func (c *FirewallClient) Create(ctx context.Context, opts FirewallCreateOpts) (FirewallCreateResult, *Response, error) {
const opPath = "/firewalls"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := FirewallCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return FirewallCreateResult{}, nil, err return result, nil, err
}
reqBody := firewallCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return FirewallCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/firewalls", bytes.NewReader(reqBodyData))
if err != nil {
return FirewallCreateResult{}, nil, err
} }
respBody := schema.FirewallCreateResponse{} reqBody := firewallCreateOptsToSchema(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FirewallCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return FirewallCreateResult{}, resp, err return result, resp, err
}
result := FirewallCreateResult{
Firewall: FirewallFromSchema(respBody.Firewall),
Actions: ActionsFromSchema(respBody.Actions),
} }
result.Firewall = FirewallFromSchema(respBody.Firewall)
result.Actions = ActionsFromSchema(respBody.Actions)
return result, resp, nil return result, resp, nil
} }
@ -256,6 +227,11 @@ type FirewallUpdateOpts struct {
// Update updates a Firewall. // Update updates a Firewall.
func (c *FirewallClient) Update(ctx context.Context, firewall *Firewall, opts FirewallUpdateOpts) (*Firewall, *Response, error) { func (c *FirewallClient) Update(ctx context.Context, firewall *Firewall, opts FirewallUpdateOpts) (*Firewall, *Response, error) {
const opPath = "/firewalls/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
reqBody := schema.FirewallUpdateRequest{} reqBody := schema.FirewallUpdateRequest{}
if opts.Name != "" { if opts.Name != "" {
reqBody.Name = &opts.Name reqBody.Name = &opts.Name
@ -263,32 +239,23 @@ func (c *FirewallClient) Update(ctx context.Context, firewall *Firewall, opts Fi
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d", firewall.ID) respBody, resp, err := putRequest[schema.FirewallUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FirewallUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return FirewallFromSchema(respBody.Firewall), resp, nil return FirewallFromSchema(respBody.Firewall), resp, nil
} }
// Delete deletes a Firewall. // Delete deletes a Firewall.
func (c *FirewallClient) Delete(ctx context.Context, firewall *Firewall) (*Response, error) { func (c *FirewallClient) Delete(ctx context.Context, firewall *Firewall) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/firewalls/%d", firewall.ID), nil) const opPath = "/firewalls/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, firewall.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// FirewallSetRulesOpts specifies options for setting rules of a Firewall. // FirewallSetRulesOpts specifies options for setting rules of a Firewall.
@ -298,75 +265,59 @@ type FirewallSetRulesOpts struct {
// SetRules sets the rules of a Firewall. // SetRules sets the rules of a Firewall.
func (c *FirewallClient) SetRules(ctx context.Context, firewall *Firewall, opts FirewallSetRulesOpts) ([]*Action, *Response, error) { func (c *FirewallClient) SetRules(ctx context.Context, firewall *Firewall, opts FirewallSetRulesOpts) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/set_rules"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
reqBody := firewallSetRulesOptsToSchema(opts) reqBody := firewallSetRulesOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody) respBody, resp, err := postRequest[schema.FirewallActionSetRulesResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/set_rules", firewall.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionSetRulesResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionsFromSchema(respBody.Actions), resp, nil return ActionsFromSchema(respBody.Actions), resp, nil
} }
func (c *FirewallClient) ApplyResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) { func (c *FirewallClient) ApplyResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/apply_to_resources"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
applyTo := make([]schema.FirewallResource, len(resources)) applyTo := make([]schema.FirewallResource, len(resources))
for i, r := range resources { for i, r := range resources {
applyTo[i] = firewallResourceToSchema(r) applyTo[i] = firewallResourceToSchema(r)
} }
reqBody := schema.FirewallActionApplyToResourcesRequest{ApplyTo: applyTo} reqBody := schema.FirewallActionApplyToResourcesRequest{ApplyTo: applyTo}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/apply_to_resources", firewall.ID) respBody, resp, err := postRequest[schema.FirewallActionApplyToResourcesResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionApplyToResourcesResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionsFromSchema(respBody.Actions), resp, nil return ActionsFromSchema(respBody.Actions), resp, nil
} }
func (c *FirewallClient) RemoveResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) { func (c *FirewallClient) RemoveResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/remove_from_resources"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
removeFrom := make([]schema.FirewallResource, len(resources)) removeFrom := make([]schema.FirewallResource, len(resources))
for i, r := range resources { for i, r := range resources {
removeFrom[i] = firewallResourceToSchema(r) removeFrom[i] = firewallResourceToSchema(r)
} }
reqBody := schema.FirewallActionRemoveFromResourcesRequest{RemoveFrom: removeFrom} reqBody := schema.FirewallActionRemoveFromResourcesRequest{RemoveFrom: removeFrom}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/remove_from_resources", firewall.ID) respBody, resp, err := postRequest[schema.FirewallActionRemoveFromResourcesResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionRemoveFromResourcesResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionsFromSchema(respBody.Actions), resp, nil return ActionsFromSchema(respBody.Actions), resp, nil
} }

View File

@ -1,16 +1,13 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -54,26 +51,21 @@ const (
// changeDNSPtr changes or resets the reverse DNS pointer for an IP address. // changeDNSPtr changes or resets the reverse DNS pointer for an IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value. // Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (f *FloatingIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) { func (f *FloatingIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, f.ID)
reqBody := schema.FloatingIPActionChangeDNSPtrRequest{ reqBody := schema.FloatingIPActionChangeDNSPtrRequest{
IP: ip.String(), IP: ip.String(),
DNSPtr: ptr, DNSPtr: ptr,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/change_dns_ptr", f.ID) respBody, resp, err := postRequest[schema.FloatingIPActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPActionChangeDNSPtrResponse{}
resp, err := client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -97,41 +89,33 @@ type FloatingIPClient struct {
// GetByID retrieves a Floating IP by its ID. If the Floating IP does not exist, // GetByID retrieves a Floating IP by its ID. If the Floating IP does not exist,
// nil is returned. // nil is returned.
func (c *FloatingIPClient) GetByID(ctx context.Context, id int64) (*FloatingIP, *Response, error) { func (c *FloatingIPClient) GetByID(ctx context.Context, id int64) (*FloatingIP, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/floating_ips/%d", id), nil) const opPath = "/floating_ips/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.FloatingIPGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.FloatingIPGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, resp, err return nil, resp, err
} }
return FloatingIPFromSchema(body.FloatingIP), resp, nil
return FloatingIPFromSchema(respBody.FloatingIP), resp, nil
} }
// GetByName retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned. // GetByName retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned.
func (c *FloatingIPClient) GetByName(ctx context.Context, name string) (*FloatingIP, *Response, error) { func (c *FloatingIPClient) GetByName(ctx context.Context, name string) (*FloatingIP, *Response, error) {
if name == "" { return firstByName(name, func() ([]*FloatingIP, *Response, error) {
return nil, nil, nil return c.List(ctx, FloatingIPListOpts{Name: name})
} })
floatingIPs, response, err := c.List(ctx, FloatingIPListOpts{Name: name})
if len(floatingIPs) == 0 {
return nil, response, err
}
return floatingIPs[0], response, err
} }
// Get retrieves a Floating IP by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Floating IP by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned. // retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned.
func (c *FloatingIPClient) Get(ctx context.Context, idOrName string) (*FloatingIP, *Response, error) { func (c *FloatingIPClient) Get(ctx context.Context, idOrName string) (*FloatingIP, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// FloatingIPListOpts specifies options for listing Floating IPs. // FloatingIPListOpts specifies options for listing Floating IPs.
@ -157,22 +141,17 @@ func (l FloatingIPListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *FloatingIPClient) List(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, *Response, error) { func (c *FloatingIPClient) List(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, *Response, error) {
path := "/floating_ips?" + opts.values().Encode() const opPath = "/floating_ips?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.FloatingIPListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.FloatingIPListResponse return allFromSchemaFunc(respBody.FloatingIPs, FloatingIPFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
floatingIPs := make([]*FloatingIP, 0, len(body.FloatingIPs))
for _, s := range body.FloatingIPs {
floatingIPs = append(floatingIPs, FloatingIPFromSchema(s))
}
return floatingIPs, resp, nil
} }
// All returns all Floating IPs. // All returns all Floating IPs.
@ -182,22 +161,10 @@ func (c *FloatingIPClient) All(ctx context.Context) ([]*FloatingIP, error) {
// AllWithOpts returns all Floating IPs for the given options. // AllWithOpts returns all Floating IPs for the given options.
func (c *FloatingIPClient) AllWithOpts(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, error) { func (c *FloatingIPClient) AllWithOpts(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, error) {
allFloatingIPs := []*FloatingIP{} return iterPages(func(page int) ([]*FloatingIP, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
floatingIPs, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allFloatingIPs = append(allFloatingIPs, floatingIPs...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allFloatingIPs, nil
} }
// FloatingIPCreateOpts specifies options for creating a Floating IP. // FloatingIPCreateOpts specifies options for creating a Floating IP.
@ -216,10 +183,10 @@ func (o FloatingIPCreateOpts) Validate() error {
case FloatingIPTypeIPv4, FloatingIPTypeIPv6: case FloatingIPTypeIPv4, FloatingIPTypeIPv6:
break break
default: default:
return errors.New("missing or invalid type") return invalidFieldValue(o, "Type", o.Type)
} }
if o.HomeLocation == nil && o.Server == nil { if o.HomeLocation == nil && o.Server == nil {
return errors.New("one of home location or server is required") return missingOneOfFields(o, "HomeLocation", "Server")
} }
return nil return nil
} }
@ -232,8 +199,15 @@ type FloatingIPCreateResult struct {
// Create creates a Floating IP. // Create creates a Floating IP.
func (c *FloatingIPClient) Create(ctx context.Context, opts FloatingIPCreateOpts) (FloatingIPCreateResult, *Response, error) { func (c *FloatingIPClient) Create(ctx context.Context, opts FloatingIPCreateOpts) (FloatingIPCreateResult, *Response, error) {
result := FloatingIPCreateResult{}
const opPath = "/floating_ips"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return FloatingIPCreateResult{}, nil, err return result, nil, err
} }
reqBody := schema.FloatingIPCreateRequest{ reqBody := schema.FloatingIPCreateRequest{
@ -250,38 +224,28 @@ func (c *FloatingIPClient) Create(ctx context.Context, opts FloatingIPCreateOpts
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.FloatingIPCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return FloatingIPCreateResult{}, nil, err return result, resp, err
} }
req, err := c.client.NewRequest(ctx, "POST", "/floating_ips", bytes.NewReader(reqBodyData)) result.FloatingIP = FloatingIPFromSchema(respBody.FloatingIP)
if err != nil {
return FloatingIPCreateResult{}, nil, err
}
var respBody schema.FloatingIPCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return FloatingIPCreateResult{}, resp, err
}
var action *Action
if respBody.Action != nil { if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action) result.Action = ActionFromSchema(*respBody.Action)
} }
return FloatingIPCreateResult{
FloatingIP: FloatingIPFromSchema(respBody.FloatingIP), return result, resp, nil
Action: action,
}, resp, nil
} }
// Delete deletes a Floating IP. // Delete deletes a Floating IP.
func (c *FloatingIPClient) Delete(ctx context.Context, floatingIP *FloatingIP) (*Response, error) { func (c *FloatingIPClient) Delete(ctx context.Context, floatingIP *FloatingIP) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/floating_ips/%d", floatingIP.ID), nil) const opPath = "/floating_ips/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, floatingIP.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// FloatingIPUpdateOpts specifies options for updating a Floating IP. // FloatingIPUpdateOpts specifies options for updating a Floating IP.
@ -293,6 +257,11 @@ type FloatingIPUpdateOpts struct {
// Update updates a Floating IP. // Update updates a Floating IP.
func (c *FloatingIPClient) Update(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPUpdateOpts) (*FloatingIP, *Response, error) { func (c *FloatingIPClient) Update(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPUpdateOpts) (*FloatingIP, *Response, error) {
const opPath = "/floating_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPUpdateRequest{ reqBody := schema.FloatingIPUpdateRequest{
Description: opts.Description, Description: opts.Description,
Name: opts.Name, Name: opts.Name,
@ -300,68 +269,48 @@ func (c *FloatingIPClient) Update(ctx context.Context, floatingIP *FloatingIP, o
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d", floatingIP.ID) respBody, resp, err := putRequest[schema.FloatingIPUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return FloatingIPFromSchema(respBody.FloatingIP), resp, nil return FloatingIPFromSchema(respBody.FloatingIP), resp, nil
} }
// Assign assigns a Floating IP to a server. // Assign assigns a Floating IP to a server.
func (c *FloatingIPClient) Assign(ctx context.Context, floatingIP *FloatingIP, server *Server) (*Action, *Response, error) { func (c *FloatingIPClient) Assign(ctx context.Context, floatingIP *FloatingIP, server *Server) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/assign"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPActionAssignRequest{ reqBody := schema.FloatingIPActionAssignRequest{
Server: server.ID, Server: server.ID,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/assign", floatingIP.ID) respBody, resp, err := postRequest[schema.FloatingIPActionAssignResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FloatingIPActionAssignResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// Unassign unassigns a Floating IP from the currently assigned server. // Unassign unassigns a Floating IP from the currently assigned server.
func (c *FloatingIPClient) Unassign(ctx context.Context, floatingIP *FloatingIP) (*Action, *Response, error) { func (c *FloatingIPClient) Unassign(ctx context.Context, floatingIP *FloatingIP) (*Action, *Response, error) {
var reqBody schema.FloatingIPActionUnassignRequest const opPath = "/floating_ips/%d/actions/unassign"
reqBodyData, err := json.Marshal(reqBody) ctx = ctxutil.SetOpPath(ctx, opPath)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/unassign", floatingIP.ID) reqPath := fmt.Sprintf(opPath, floatingIP.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FloatingIPActionUnassignResponse reqBody := schema.FloatingIPActionUnassignRequest{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FloatingIPActionUnassignResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -382,24 +331,19 @@ type FloatingIPChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a Floating IP. // ChangeProtection changes the resource protection level of a Floating IP.
func (c *FloatingIPClient) ChangeProtection(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPChangeProtectionOpts) (*Action, *Response, error) { func (c *FloatingIPClient) ChangeProtection(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPActionChangeProtectionRequest{ reqBody := schema.FloatingIPActionChangeProtectionRequest{
Delete: opts.Delete, Delete: opts.Delete,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/change_protection", floatingIP.ID) respBody, resp, err := postRequest[schema.FloatingIPActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }

View File

@ -1,5 +1,34 @@
// Package hcloud is a library for the Hetzner Cloud API. /*
Package hcloud is a library for the Hetzner Cloud API.
The Hetzner Cloud API reference is available at https://docs.hetzner.cloud.
Make sure to follow our API changelog available at https://docs.hetzner.cloud/changelog
(or the RRS feed available at https://docs.hetzner.cloud/changelog/feed.rss) to be
notified about additions, deprecations and removals.
# Retry mechanism
The [Client.Do] method will retry failed requests that match certain criteria. The
default retry interval is defined by an exponential backoff algorithm truncated to 60s
with jitter. The default maximal number of retries is 5.
The following rules defines when a request can be retried:
When the [http.Client] returned a network timeout error.
When the API returned an HTTP error, with the status code:
- [http.StatusBadGateway]
- [http.StatusGatewayTimeout]
When the API returned an application error, with the code:
- [ErrorCodeConflict]
- [ErrorCodeRateLimitExceeded]
Changes to the retry policy might occur between releases, and will not be considered
breaking changes.
*/
package hcloud package hcloud
// Version is the library's version following Semantic Versioning. // Version is the library's version following Semantic Versioning.
const Version = "2.8.0" // x-release-please-version const Version = "2.21.1" // x-releaser-pleaser-version

View File

@ -1,14 +1,13 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -83,34 +82,29 @@ type ImageClient struct {
// GetByID retrieves an image by its ID. If the image does not exist, nil is returned. // GetByID retrieves an image by its ID. If the image does not exist, nil is returned.
func (c *ImageClient) GetByID(ctx context.Context, id int64) (*Image, *Response, error) { func (c *ImageClient) GetByID(ctx context.Context, id int64) (*Image, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/images/%d", id), nil) const opPath = "/images/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.ImageGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.ImageGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return ImageFromSchema(body.Image), resp, nil
return ImageFromSchema(respBody.Image), resp, nil
} }
// GetByName retrieves an image by its name. If the image does not exist, nil is returned. // GetByName retrieves an image by its name. If the image does not exist, nil is returned.
// //
// Deprecated: Use [ImageClient.GetByNameAndArchitecture] instead. // Deprecated: Use [ImageClient.GetByNameAndArchitecture] instead.
func (c *ImageClient) GetByName(ctx context.Context, name string) (*Image, *Response, error) { func (c *ImageClient) GetByName(ctx context.Context, name string) (*Image, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Image, *Response, error) {
return nil, nil, nil return c.List(ctx, ImageListOpts{Name: name})
} })
images, response, err := c.List(ctx, ImageListOpts{Name: name})
if len(images) == 0 {
return nil, response, err
}
return images[0], response, err
} }
// GetByNameAndArchitecture retrieves an image by its name and architecture. If the image does not exist, // GetByNameAndArchitecture retrieves an image by its name and architecture. If the image does not exist,
@ -118,14 +112,9 @@ func (c *ImageClient) GetByName(ctx context.Context, name string) (*Image, *Resp
// In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should // In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should
// check for this in your calling method. // check for this in your calling method.
func (c *ImageClient) GetByNameAndArchitecture(ctx context.Context, name string, architecture Architecture) (*Image, *Response, error) { func (c *ImageClient) GetByNameAndArchitecture(ctx context.Context, name string, architecture Architecture) (*Image, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Image, *Response, error) {
return nil, nil, nil return c.List(ctx, ImageListOpts{Name: name, Architecture: []Architecture{architecture}, IncludeDeprecated: true})
} })
images, response, err := c.List(ctx, ImageListOpts{Name: name, Architecture: []Architecture{architecture}, IncludeDeprecated: true})
if len(images) == 0 {
return nil, response, err
}
return images[0], response, err
} }
// Get retrieves an image by its ID if the input can be parsed as an integer, otherwise it // Get retrieves an image by its ID if the input can be parsed as an integer, otherwise it
@ -133,10 +122,7 @@ func (c *ImageClient) GetByNameAndArchitecture(ctx context.Context, name string,
// //
// Deprecated: Use [ImageClient.GetForArchitecture] instead. // Deprecated: Use [ImageClient.GetForArchitecture] instead.
func (c *ImageClient) Get(ctx context.Context, idOrName string) (*Image, *Response, error) { func (c *ImageClient) Get(ctx context.Context, idOrName string) (*Image, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// GetForArchitecture retrieves an image by its ID if the input can be parsed as an integer, otherwise it // GetForArchitecture retrieves an image by its ID if the input can be parsed as an integer, otherwise it
@ -145,10 +131,13 @@ func (c *ImageClient) Get(ctx context.Context, idOrName string) (*Image, *Respon
// In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should // In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should
// check for this in your calling method. // check for this in your calling method.
func (c *ImageClient) GetForArchitecture(ctx context.Context, idOrName string, architecture Architecture) (*Image, *Response, error) { func (c *ImageClient) GetForArchitecture(ctx context.Context, idOrName string, architecture Architecture) (*Image, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx,
return c.GetByID(ctx, id) c.GetByID,
} func(ctx context.Context, name string) (*Image, *Response, error) {
return c.GetByNameAndArchitecture(ctx, idOrName, architecture) return c.GetByNameAndArchitecture(ctx, name, architecture)
},
idOrName,
)
} }
// ImageListOpts specifies options for listing images. // ImageListOpts specifies options for listing images.
@ -194,22 +183,17 @@ func (l ImageListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *ImageClient) List(ctx context.Context, opts ImageListOpts) ([]*Image, *Response, error) { func (c *ImageClient) List(ctx context.Context, opts ImageListOpts) ([]*Image, *Response, error) {
path := "/images?" + opts.values().Encode() const opPath = "/images?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ImageListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.ImageListResponse return allFromSchemaFunc(respBody.Images, ImageFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
images := make([]*Image, 0, len(body.Images))
for _, i := range body.Images {
images = append(images, ImageFromSchema(i))
}
return images, resp, nil
} }
// All returns all images. // All returns all images.
@ -219,31 +203,20 @@ func (c *ImageClient) All(ctx context.Context) ([]*Image, error) {
// AllWithOpts returns all images for the given options. // AllWithOpts returns all images for the given options.
func (c *ImageClient) AllWithOpts(ctx context.Context, opts ImageListOpts) ([]*Image, error) { func (c *ImageClient) AllWithOpts(ctx context.Context, opts ImageListOpts) ([]*Image, error) {
allImages := []*Image{} return iterPages(func(page int) ([]*Image, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
images, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allImages = append(allImages, images...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allImages, nil
} }
// Delete deletes an image. // Delete deletes an image.
func (c *ImageClient) Delete(ctx context.Context, image *Image) (*Response, error) { func (c *ImageClient) Delete(ctx context.Context, image *Image) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/images/%d", image.ID), nil) const opPath = "/images/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, image.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// ImageUpdateOpts specifies options for updating an image. // ImageUpdateOpts specifies options for updating an image.
@ -255,6 +228,11 @@ type ImageUpdateOpts struct {
// Update updates an image. // Update updates an image.
func (c *ImageClient) Update(ctx context.Context, image *Image, opts ImageUpdateOpts) (*Image, *Response, error) { func (c *ImageClient) Update(ctx context.Context, image *Image, opts ImageUpdateOpts) (*Image, *Response, error) {
const opPath = "/images/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, image.ID)
reqBody := schema.ImageUpdateRequest{ reqBody := schema.ImageUpdateRequest{
Description: opts.Description, Description: opts.Description,
} }
@ -264,22 +242,12 @@ func (c *ImageClient) Update(ctx context.Context, image *Image, opts ImageUpdate
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/images/%d", image.ID) respBody, resp, err := putRequest[schema.ImageUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.ImageUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ImageFromSchema(respBody.Image), resp, nil return ImageFromSchema(respBody.Image), resp, nil
} }
@ -290,24 +258,19 @@ type ImageChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of an image. // ChangeProtection changes the resource protection level of an image.
func (c *ImageClient) ChangeProtection(ctx context.Context, image *Image, opts ImageChangeProtectionOpts) (*Action, *Response, error) { func (c *ImageClient) ChangeProtection(ctx context.Context, image *Image, opts ImageChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/images/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, image.ID)
reqBody := schema.ImageActionChangeProtectionRequest{ reqBody := schema.ImageActionChangeProtectionRequest{
Delete: opts.Delete, Delete: opts.Delete,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/images/%d/actions/change_protection", image.ID) respBody, resp, err := postRequest[schema.ImageActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.ImageActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }

View File

@ -1,6 +1,7 @@
package instrumentation package instrumentation
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
@ -9,6 +10,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
) )
type Instrumenter struct { type Instrumenter struct {
@ -22,7 +25,12 @@ func New(subsystemIdentifier string, instrumentationRegistry prometheus.Register
} }
// InstrumentedRoundTripper returns an instrumented round tripper. // InstrumentedRoundTripper returns an instrumented round tripper.
func (i *Instrumenter) InstrumentedRoundTripper() http.RoundTripper { func (i *Instrumenter) InstrumentedRoundTripper(transport http.RoundTripper) http.RoundTripper {
// By default, http client would use DefaultTransport on nil, but we internally are relying on it being configured
if transport == nil {
transport = http.DefaultTransport
}
inFlightRequestsGauge := registerOrReuse( inFlightRequestsGauge := registerOrReuse(
i.instrumentationRegistry, i.instrumentationRegistry,
prometheus.NewGauge(prometheus.GaugeOpts{ prometheus.NewGauge(prometheus.GaugeOpts{
@ -57,7 +65,7 @@ func (i *Instrumenter) InstrumentedRoundTripper() http.RoundTripper {
return promhttp.InstrumentRoundTripperInFlight(inFlightRequestsGauge, return promhttp.InstrumentRoundTripperInFlight(inFlightRequestsGauge,
promhttp.InstrumentRoundTripperDuration(requestLatencyHistogram, promhttp.InstrumentRoundTripperDuration(requestLatencyHistogram,
i.instrumentRoundTripperEndpoint(requestsPerEndpointCounter, i.instrumentRoundTripperEndpoint(requestsPerEndpointCounter,
http.DefaultTransport, transport,
), ),
), ),
) )
@ -73,8 +81,17 @@ func (i *Instrumenter) instrumentRoundTripperEndpoint(counter *prometheus.Counte
return func(r *http.Request) (*http.Response, error) { return func(r *http.Request) (*http.Response, error) {
resp, err := next.RoundTrip(r) resp, err := next.RoundTrip(r)
if err == nil { if err == nil {
statusCode := strconv.Itoa(resp.StatusCode) apiEndpoint := ctxutil.OpPath(r.Context())
counter.WithLabelValues(statusCode, strings.ToLower(resp.Request.Method), preparePathForLabel(resp.Request.URL.Path)).Inc() // If the request does not set the operation path, we must construct it. Happens e.g. for
// user crafted requests.
if apiEndpoint == "" {
apiEndpoint = preparePathForLabel(resp.Request.URL.Path)
}
counter.WithLabelValues(
strconv.Itoa(resp.StatusCode),
strings.ToLower(resp.Request.Method),
apiEndpoint,
).Inc()
} }
return resp, err return resp, err
@ -87,9 +104,10 @@ func (i *Instrumenter) instrumentRoundTripperEndpoint(counter *prometheus.Counte
func registerOrReuse[C prometheus.Collector](registry prometheus.Registerer, collector C) C { func registerOrReuse[C prometheus.Collector](registry prometheus.Registerer, collector C) C {
err := registry.Register(collector) err := registry.Register(collector)
if err != nil { if err != nil {
var arErr prometheus.AlreadyRegisteredError
// If we get a AlreadyRegisteredError we can return the existing collector // If we get a AlreadyRegisteredError we can return the existing collector
if are, ok := err.(prometheus.AlreadyRegisteredError); ok { if errors.As(err, &arErr) {
if existingCollector, ok := are.ExistingCollector.(C); ok { if existingCollector, ok := arErr.ExistingCollector.(C); ok {
collector = existingCollector collector = existingCollector
} else { } else {
panic("received incompatible existing collector") panic("received incompatible existing collector")
@ -102,16 +120,16 @@ func registerOrReuse[C prometheus.Collector](registry prometheus.Registerer, col
return collector return collector
} }
var pathLabelRegexp = regexp.MustCompile("[^a-z/_]+")
func preparePathForLabel(path string) string { func preparePathForLabel(path string) string {
path = strings.ToLower(path) path = strings.ToLower(path)
// replace the /v1/ that indicated the API version
path, _ = strings.CutPrefix(path, "/v1")
// replace all numbers and chars that are not a-z, / or _ // replace all numbers and chars that are not a-z, / or _
reg := regexp.MustCompile("[^a-z/_]+") path = pathLabelRegexp.ReplaceAllString(path, "-")
path = reg.ReplaceAllString(path, "")
// replace all artifacts of number replacement (//) return path
path = strings.ReplaceAll(path, "//", "/")
// replace the /v/ that indicated the API version
return strings.Replace(path, "/v/", "/", 1)
} }

View File

@ -4,9 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -40,40 +40,32 @@ type ISOClient struct {
// GetByID retrieves an ISO by its ID. // GetByID retrieves an ISO by its ID.
func (c *ISOClient) GetByID(ctx context.Context, id int64) (*ISO, *Response, error) { func (c *ISOClient) GetByID(ctx context.Context, id int64) (*ISO, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/isos/%d", id), nil) const opPath = "/isos/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.ISOGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.ISOGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, resp, err return nil, resp, err
} }
return ISOFromSchema(body.ISO), resp, nil
return ISOFromSchema(respBody.ISO), resp, nil
} }
// GetByName retrieves an ISO by its name. // GetByName retrieves an ISO by its name.
func (c *ISOClient) GetByName(ctx context.Context, name string) (*ISO, *Response, error) { func (c *ISOClient) GetByName(ctx context.Context, name string) (*ISO, *Response, error) {
if name == "" { return firstByName(name, func() ([]*ISO, *Response, error) {
return nil, nil, nil return c.List(ctx, ISOListOpts{Name: name})
} })
isos, response, err := c.List(ctx, ISOListOpts{Name: name})
if len(isos) == 0 {
return nil, response, err
}
return isos[0], response, err
} }
// Get retrieves an ISO by its ID if the input can be parsed as an integer, otherwise it retrieves an ISO by its name. // Get retrieves an ISO by its ID if the input can be parsed as an integer, otherwise it retrieves an ISO by its name.
func (c *ISOClient) Get(ctx context.Context, idOrName string) (*ISO, *Response, error) { func (c *ISOClient) Get(ctx context.Context, idOrName string) (*ISO, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// ISOListOpts specifies options for listing isos. // ISOListOpts specifies options for listing isos.
@ -115,22 +107,17 @@ func (l ISOListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *ISOClient) List(ctx context.Context, opts ISOListOpts) ([]*ISO, *Response, error) { func (c *ISOClient) List(ctx context.Context, opts ISOListOpts) ([]*ISO, *Response, error) {
path := "/isos?" + opts.values().Encode() const opPath = "/isos?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ISOListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.ISOListResponse return allFromSchemaFunc(respBody.ISOs, ISOFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
isos := make([]*ISO, 0, len(body.ISOs))
for _, i := range body.ISOs {
isos = append(isos, ISOFromSchema(i))
}
return isos, resp, nil
} }
// All returns all ISOs. // All returns all ISOs.
@ -140,20 +127,8 @@ func (c *ISOClient) All(ctx context.Context) ([]*ISO, error) {
// AllWithOpts returns all ISOs for the given options. // AllWithOpts returns all ISOs for the given options.
func (c *ISOClient) AllWithOpts(ctx context.Context, opts ISOListOpts) ([]*ISO, error) { func (c *ISOClient) AllWithOpts(ctx context.Context, opts ISOListOpts) ([]*ISO, error) {
allISOs := []*ISO{} return iterPages(func(page int) ([]*ISO, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
isos, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allISOs = append(allISOs, isos...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allISOs, nil
} }

View File

@ -1,16 +1,14 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strconv" "strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -200,26 +198,21 @@ type LoadBalancerProtection struct {
// changeDNSPtr changes or resets the reverse DNS pointer for an IP address. // changeDNSPtr changes or resets the reverse DNS pointer for an IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value. // Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (lb *LoadBalancer) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) { func (lb *LoadBalancer) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, lb.ID)
reqBody := schema.LoadBalancerActionChangeDNSPtrRequest{ reqBody := schema.LoadBalancerActionChangeDNSPtrRequest{
IP: ip.String(), IP: ip.String(),
DNSPtr: ptr, DNSPtr: ptr,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_dns_ptr", lb.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeDNSPtrResponse{}
resp, err := client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -243,41 +236,33 @@ type LoadBalancerClient struct {
// GetByID retrieves a Load Balancer by its ID. If the Load Balancer does not exist, nil is returned. // GetByID retrieves a Load Balancer by its ID. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) GetByID(ctx context.Context, id int64) (*LoadBalancer, *Response, error) { func (c *LoadBalancerClient) GetByID(ctx context.Context, id int64) (*LoadBalancer, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/load_balancers/%d", id), nil) const opPath = "/load_balancers/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.LoadBalancerGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.LoadBalancerGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return LoadBalancerFromSchema(body.LoadBalancer), resp, nil
return LoadBalancerFromSchema(respBody.LoadBalancer), resp, nil
} }
// GetByName retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned. // GetByName retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) GetByName(ctx context.Context, name string) (*LoadBalancer, *Response, error) { func (c *LoadBalancerClient) GetByName(ctx context.Context, name string) (*LoadBalancer, *Response, error) {
if name == "" { return firstByName(name, func() ([]*LoadBalancer, *Response, error) {
return nil, nil, nil return c.List(ctx, LoadBalancerListOpts{Name: name})
} })
LoadBalancer, response, err := c.List(ctx, LoadBalancerListOpts{Name: name})
if len(LoadBalancer) == 0 {
return nil, response, err
}
return LoadBalancer[0], response, err
} }
// Get retrieves a Load Balancer by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Load Balancer by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned. // retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) Get(ctx context.Context, idOrName string) (*LoadBalancer, *Response, error) { func (c *LoadBalancerClient) Get(ctx context.Context, idOrName string) (*LoadBalancer, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// LoadBalancerListOpts specifies options for listing Load Balancers. // LoadBalancerListOpts specifies options for listing Load Balancers.
@ -303,22 +288,17 @@ func (l LoadBalancerListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *LoadBalancerClient) List(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, *Response, error) { func (c *LoadBalancerClient) List(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, *Response, error) {
path := "/load_balancers?" + opts.values().Encode() const opPath = "/load_balancers?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.LoadBalancerListResponse return allFromSchemaFunc(respBody.LoadBalancers, LoadBalancerFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
LoadBalancers := make([]*LoadBalancer, 0, len(body.LoadBalancers))
for _, s := range body.LoadBalancers {
LoadBalancers = append(LoadBalancers, LoadBalancerFromSchema(s))
}
return LoadBalancers, resp, nil
} }
// All returns all Load Balancers. // All returns all Load Balancers.
@ -328,22 +308,10 @@ func (c *LoadBalancerClient) All(ctx context.Context) ([]*LoadBalancer, error) {
// AllWithOpts returns all Load Balancers for the given options. // AllWithOpts returns all Load Balancers for the given options.
func (c *LoadBalancerClient) AllWithOpts(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, error) { func (c *LoadBalancerClient) AllWithOpts(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, error) {
allLoadBalancers := []*LoadBalancer{} return iterPages(func(page int) ([]*LoadBalancer, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
LoadBalancers, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allLoadBalancers = append(allLoadBalancers, LoadBalancers...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allLoadBalancers, nil
} }
// LoadBalancerUpdateOpts specifies options for updating a Load Balancer. // LoadBalancerUpdateOpts specifies options for updating a Load Balancer.
@ -354,6 +322,11 @@ type LoadBalancerUpdateOpts struct {
// Update updates a Load Balancer. // Update updates a Load Balancer.
func (c *LoadBalancerClient) Update(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerUpdateOpts) (*LoadBalancer, *Response, error) { func (c *LoadBalancerClient) Update(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerUpdateOpts) (*LoadBalancer, *Response, error) {
const opPath = "/load_balancers/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerUpdateRequest{} reqBody := schema.LoadBalancerUpdateRequest{}
if opts.Name != "" { if opts.Name != "" {
reqBody.Name = &opts.Name reqBody.Name = &opts.Name
@ -361,22 +334,12 @@ func (c *LoadBalancerClient) Update(ctx context.Context, loadBalancer *LoadBalan
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d", loadBalancer.ID) respBody, resp, err := putRequest[schema.LoadBalancerUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return LoadBalancerFromSchema(respBody.LoadBalancer), resp, nil return LoadBalancerFromSchema(respBody.LoadBalancer), resp, nil
} }
@ -472,73 +435,61 @@ type LoadBalancerCreateResult struct {
// Create creates a new Load Balancer. // Create creates a new Load Balancer.
func (c *LoadBalancerClient) Create(ctx context.Context, opts LoadBalancerCreateOpts) (LoadBalancerCreateResult, *Response, error) { func (c *LoadBalancerClient) Create(ctx context.Context, opts LoadBalancerCreateOpts) (LoadBalancerCreateResult, *Response, error) {
const opPath = "/load_balancers"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := LoadBalancerCreateResult{}
reqPath := opPath
reqBody := loadBalancerCreateOptsToSchema(opts) reqBody := loadBalancerCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.LoadBalancerCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return LoadBalancerCreateResult{}, nil, err return result, resp, err
}
req, err := c.client.NewRequest(ctx, "POST", "/load_balancers", bytes.NewReader(reqBodyData))
if err != nil {
return LoadBalancerCreateResult{}, nil, err
} }
respBody := schema.LoadBalancerCreateResponse{} result.LoadBalancer = LoadBalancerFromSchema(respBody.LoadBalancer)
resp, err := c.client.Do(req, &respBody) result.Action = ActionFromSchema(respBody.Action)
if err != nil {
return LoadBalancerCreateResult{}, resp, err return result, resp, nil
}
return LoadBalancerCreateResult{
LoadBalancer: LoadBalancerFromSchema(respBody.LoadBalancer),
Action: ActionFromSchema(respBody.Action),
}, resp, nil
} }
// Delete deletes a Load Balancer. // Delete deletes a Load Balancer.
func (c *LoadBalancerClient) Delete(ctx context.Context, loadBalancer *LoadBalancer) (*Response, error) { func (c *LoadBalancerClient) Delete(ctx context.Context, loadBalancer *LoadBalancer) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/load_balancers/%d", loadBalancer.ID), nil) const opPath = "/load_balancers/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
func (c *LoadBalancerClient) addTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionAddTargetRequest) (*Action, *Response, error) { func (c *LoadBalancerClient) addTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionAddTargetRequest) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(reqBody) const opPath = "/load_balancers/%d/actions/add_target"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/add_target", loadBalancer.ID) reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionAddTargetResponse respBody, resp, err := postRequest[schema.LoadBalancerActionAddTargetResponse](ctx, c.client, reqPath, reqBody)
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
func (c *LoadBalancerClient) removeTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionRemoveTargetRequest) (*Action, *Response, error) { func (c *LoadBalancerClient) removeTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionRemoveTargetRequest) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(reqBody) const opPath = "/load_balancers/%d/actions/remove_target"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/remove_target", loadBalancer.ID) reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionRemoveTargetResponse respBody, resp, err := postRequest[schema.LoadBalancerActionRemoveTargetResponse](ctx, c.client, reqPath, reqBody)
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -671,23 +622,18 @@ type LoadBalancerAddServiceOptsHealthCheckHTTP struct {
// AddService adds a service to a Load Balancer. // AddService adds a service to a Load Balancer.
func (c *LoadBalancerClient) AddService(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAddServiceOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) AddService(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAddServiceOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/add_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := loadBalancerAddServiceOptsToSchema(opts) reqBody := loadBalancerAddServiceOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/add_service", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionAddServiceResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionAddServiceResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -732,48 +678,38 @@ type LoadBalancerUpdateServiceOptsHealthCheckHTTP struct {
// UpdateService updates a Load Balancer service. // UpdateService updates a Load Balancer service.
func (c *LoadBalancerClient) UpdateService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int, opts LoadBalancerUpdateServiceOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) UpdateService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int, opts LoadBalancerUpdateServiceOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/update_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := loadBalancerUpdateServiceOptsToSchema(opts) reqBody := loadBalancerUpdateServiceOptsToSchema(opts)
reqBody.ListenPort = listenPort reqBody.ListenPort = listenPort
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/update_service", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionUpdateServiceResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionUpdateServiceResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// DeleteService deletes a Load Balancer service. // DeleteService deletes a Load Balancer service.
func (c *LoadBalancerClient) DeleteService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int) (*Action, *Response, error) { func (c *LoadBalancerClient) DeleteService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/delete_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerDeleteServiceRequest{ reqBody := schema.LoadBalancerDeleteServiceRequest{
ListenPort: listenPort, ListenPort: listenPort,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/delete_service", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerDeleteServiceResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerDeleteServiceResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -784,26 +720,21 @@ type LoadBalancerChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a Load Balancer. // ChangeProtection changes the resource protection level of a Load Balancer.
func (c *LoadBalancerClient) ChangeProtection(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeProtectionOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) ChangeProtection(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeProtectionRequest{ reqBody := schema.LoadBalancerActionChangeProtectionRequest{
Delete: opts.Delete, Delete: opts.Delete,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_protection", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// LoadBalancerChangeAlgorithmOpts specifies options for changing the algorithm of a Load Balancer. // LoadBalancerChangeAlgorithmOpts specifies options for changing the algorithm of a Load Balancer.
@ -813,26 +744,21 @@ type LoadBalancerChangeAlgorithmOpts struct {
// ChangeAlgorithm changes the algorithm of a Load Balancer. // ChangeAlgorithm changes the algorithm of a Load Balancer.
func (c *LoadBalancerClient) ChangeAlgorithm(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeAlgorithmOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) ChangeAlgorithm(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeAlgorithmOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_algorithm"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeAlgorithmRequest{ reqBody := schema.LoadBalancerActionChangeAlgorithmRequest{
Type: string(opts.Type), Type: string(opts.Type),
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_algorithm", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionChangeAlgorithmResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeAlgorithmResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// LoadBalancerAttachToNetworkOpts specifies options for attaching a Load Balancer to a network. // LoadBalancerAttachToNetworkOpts specifies options for attaching a Load Balancer to a network.
@ -843,29 +769,24 @@ type LoadBalancerAttachToNetworkOpts struct {
// AttachToNetwork attaches a Load Balancer to a network. // AttachToNetwork attaches a Load Balancer to a network.
func (c *LoadBalancerClient) AttachToNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAttachToNetworkOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) AttachToNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAttachToNetworkOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/attach_to_network"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionAttachToNetworkRequest{ reqBody := schema.LoadBalancerActionAttachToNetworkRequest{
Network: opts.Network.ID, Network: opts.Network.ID,
} }
if opts.IP != nil { if opts.IP != nil {
reqBody.IP = Ptr(opts.IP.String()) reqBody.IP = Ptr(opts.IP.String())
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/attach_to_network", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionAttachToNetworkResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionAttachToNetworkResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// LoadBalancerDetachFromNetworkOpts specifies options for detaching a Load Balancer from a network. // LoadBalancerDetachFromNetworkOpts specifies options for detaching a Load Balancer from a network.
@ -875,56 +796,51 @@ type LoadBalancerDetachFromNetworkOpts struct {
// DetachFromNetwork detaches a Load Balancer from a network. // DetachFromNetwork detaches a Load Balancer from a network.
func (c *LoadBalancerClient) DetachFromNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerDetachFromNetworkOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) DetachFromNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerDetachFromNetworkOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/detach_from_network"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionDetachFromNetworkRequest{ reqBody := schema.LoadBalancerActionDetachFromNetworkRequest{
Network: opts.Network.ID, Network: opts.Network.ID,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/detach_from_network", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionDetachFromNetworkResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionDetachFromNetworkResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// EnablePublicInterface enables the Load Balancer's public network interface. // EnablePublicInterface enables the Load Balancer's public network interface.
func (c *LoadBalancerClient) EnablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) { func (c *LoadBalancerClient) EnablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) {
path := fmt.Sprintf("/load_balancers/%d/actions/enable_public_interface", loadBalancer.ID) const opPath = "/load_balancers/%d/actions/enable_public_interface"
req, err := c.client.NewRequest(ctx, "POST", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
if err != nil {
return nil, nil, err reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
}
respBody := schema.LoadBalancerActionEnablePublicInterfaceResponse{} respBody, resp, err := postRequest[schema.LoadBalancerActionEnablePublicInterfaceResponse](ctx, c.client, reqPath, nil)
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// DisablePublicInterface disables the Load Balancer's public network interface. // DisablePublicInterface disables the Load Balancer's public network interface.
func (c *LoadBalancerClient) DisablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) { func (c *LoadBalancerClient) DisablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) {
path := fmt.Sprintf("/load_balancers/%d/actions/disable_public_interface", loadBalancer.ID) const opPath = "/load_balancers/%d/actions/disable_public_interface"
req, err := c.client.NewRequest(ctx, "POST", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
if err != nil {
return nil, nil, err reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
}
respBody := schema.LoadBalancerActionDisablePublicInterfaceResponse{} respBody, resp, err := postRequest[schema.LoadBalancerActionDisablePublicInterfaceResponse](ctx, c.client, reqPath, nil)
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// LoadBalancerChangeTypeOpts specifies options for changing a Load Balancer's type. // LoadBalancerChangeTypeOpts specifies options for changing a Load Balancer's type.
@ -934,28 +850,21 @@ type LoadBalancerChangeTypeOpts struct {
// ChangeType changes a Load Balancer's type. // ChangeType changes a Load Balancer's type.
func (c *LoadBalancerClient) ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error) { func (c *LoadBalancerClient) ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_type"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeTypeRequest{} reqBody := schema.LoadBalancerActionChangeTypeRequest{}
if opts.LoadBalancerType.ID != 0 { if opts.LoadBalancerType.ID != 0 || opts.LoadBalancerType.Name != "" {
reqBody.LoadBalancerType = opts.LoadBalancerType.ID reqBody.LoadBalancerType = schema.IDOrName{ID: opts.LoadBalancerType.ID, Name: opts.LoadBalancerType.Name}
} else {
reqBody.LoadBalancerType = opts.LoadBalancerType.Name
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
} }
path := fmt.Sprintf("/load_balancers/%d/actions/change_type", loadBalancer.ID) respBody, resp, err := postRequest[schema.LoadBalancerActionChangeTypeResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeTypeResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -980,32 +889,34 @@ type LoadBalancerGetMetricsOpts struct {
Step int Step int
} }
func (o *LoadBalancerGetMetricsOpts) addQueryParams(req *http.Request) error { func (o LoadBalancerGetMetricsOpts) Validate() error {
query := req.URL.Query()
if len(o.Types) == 0 { if len(o.Types) == 0 {
return fmt.Errorf("no metric types specified") return missingField(o, "Types")
} }
if o.Start.IsZero() {
return missingField(o, "Start")
}
if o.End.IsZero() {
return missingField(o, "End")
}
return nil
}
func (o LoadBalancerGetMetricsOpts) values() url.Values {
query := url.Values{}
for _, typ := range o.Types { for _, typ := range o.Types {
query.Add("type", string(typ)) query.Add("type", string(typ))
} }
if o.Start.IsZero() {
return fmt.Errorf("no start time specified")
}
query.Add("start", o.Start.Format(time.RFC3339)) query.Add("start", o.Start.Format(time.RFC3339))
if o.End.IsZero() {
return fmt.Errorf("no end time specified")
}
query.Add("end", o.End.Format(time.RFC3339)) query.Add("end", o.End.Format(time.RFC3339))
if o.Step > 0 { if o.Step > 0 {
query.Add("step", strconv.Itoa(o.Step)) query.Add("step", strconv.Itoa(o.Step))
} }
req.URL.RawQuery = query.Encode()
return nil return query
} }
// LoadBalancerMetrics contains the metrics requested for a Load Balancer. // LoadBalancerMetrics contains the metrics requested for a Load Balancer.
@ -1024,31 +935,32 @@ type LoadBalancerMetricsValue struct {
// GetMetrics obtains metrics for a Load Balancer. // GetMetrics obtains metrics for a Load Balancer.
func (c *LoadBalancerClient) GetMetrics( func (c *LoadBalancerClient) GetMetrics(
ctx context.Context, lb *LoadBalancer, opts LoadBalancerGetMetricsOpts, ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerGetMetricsOpts,
) (*LoadBalancerMetrics, *Response, error) { ) (*LoadBalancerMetrics, *Response, error) {
var respBody schema.LoadBalancerGetMetricsResponse const opPath = "/load_balancers/%d/metrics?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
if lb == nil { if loadBalancer == nil {
return nil, nil, fmt.Errorf("illegal argument: load balancer is nil") return nil, nil, missingArgument("loadBalancer", loadBalancer)
} }
path := fmt.Sprintf("/load_balancers/%d/metrics", lb.ID) if err := opts.Validate(); err != nil {
req, err := c.client.NewRequest(ctx, "GET", path, nil) return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, loadBalancer.ID, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerGetMetricsResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("new request: %v", err) return nil, resp, err
} }
if err := opts.addQueryParams(req); err != nil {
return nil, nil, fmt.Errorf("add query params: %v", err) metrics, err := loadBalancerMetricsFromSchema(&respBody)
}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("get metrics: %v", err) return nil, nil, fmt.Errorf("convert response body: %w", err)
} }
ms, err := loadBalancerMetricsFromSchema(&respBody)
if err != nil { return metrics, resp, nil
return nil, nil, fmt.Errorf("convert response body: %v", err)
}
return ms, resp, nil
} }
// ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer. // ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer.

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -29,32 +30,27 @@ type LoadBalancerTypeClient struct {
// GetByID retrieves a Load Balancer type by its ID. If the Load Balancer type does not exist, nil is returned. // GetByID retrieves a Load Balancer type by its ID. If the Load Balancer type does not exist, nil is returned.
func (c *LoadBalancerTypeClient) GetByID(ctx context.Context, id int64) (*LoadBalancerType, *Response, error) { func (c *LoadBalancerTypeClient) GetByID(ctx context.Context, id int64) (*LoadBalancerType, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/load_balancer_types/%d", id), nil) const opPath = "/load_balancer_types/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.LoadBalancerTypeGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.LoadBalancerTypeGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return LoadBalancerTypeFromSchema(body.LoadBalancerType), resp, nil
return LoadBalancerTypeFromSchema(respBody.LoadBalancerType), resp, nil
} }
// GetByName retrieves a Load Balancer type by its name. If the Load Balancer type does not exist, nil is returned. // GetByName retrieves a Load Balancer type by its name. If the Load Balancer type does not exist, nil is returned.
func (c *LoadBalancerTypeClient) GetByName(ctx context.Context, name string) (*LoadBalancerType, *Response, error) { func (c *LoadBalancerTypeClient) GetByName(ctx context.Context, name string) (*LoadBalancerType, *Response, error) {
if name == "" { return firstByName(name, func() ([]*LoadBalancerType, *Response, error) {
return nil, nil, nil return c.List(ctx, LoadBalancerTypeListOpts{Name: name})
} })
LoadBalancerTypes, response, err := c.List(ctx, LoadBalancerTypeListOpts{Name: name})
if len(LoadBalancerTypes) == 0 {
return nil, response, err
}
return LoadBalancerTypes[0], response, err
} }
// Get retrieves a Load Balancer type by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Load Balancer type by its ID if the input can be parsed as an integer, otherwise it
@ -89,22 +85,17 @@ func (l LoadBalancerTypeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *LoadBalancerTypeClient) List(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, *Response, error) { func (c *LoadBalancerTypeClient) List(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, *Response, error) {
path := "/load_balancer_types?" + opts.values().Encode() const opPath = "/load_balancer_types?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerTypeListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.LoadBalancerTypeListResponse return allFromSchemaFunc(respBody.LoadBalancerTypes, LoadBalancerTypeFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
LoadBalancerTypes := make([]*LoadBalancerType, 0, len(body.LoadBalancerTypes))
for _, s := range body.LoadBalancerTypes {
LoadBalancerTypes = append(LoadBalancerTypes, LoadBalancerTypeFromSchema(s))
}
return LoadBalancerTypes, resp, nil
} }
// All returns all Load Balancer types. // All returns all Load Balancer types.
@ -114,20 +105,8 @@ func (c *LoadBalancerTypeClient) All(ctx context.Context) ([]*LoadBalancerType,
// AllWithOpts returns all Load Balancer types for the given options. // AllWithOpts returns all Load Balancer types for the given options.
func (c *LoadBalancerTypeClient) AllWithOpts(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, error) { func (c *LoadBalancerTypeClient) AllWithOpts(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, error) {
allLoadBalancerTypes := []*LoadBalancerType{} return iterPages(func(page int) ([]*LoadBalancerType, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
LoadBalancerTypes, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allLoadBalancerTypes = append(allLoadBalancerTypes, LoadBalancerTypes...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allLoadBalancerTypes, nil
} }

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -28,32 +29,27 @@ type LocationClient struct {
// GetByID retrieves a location by its ID. If the location does not exist, nil is returned. // GetByID retrieves a location by its ID. If the location does not exist, nil is returned.
func (c *LocationClient) GetByID(ctx context.Context, id int64) (*Location, *Response, error) { func (c *LocationClient) GetByID(ctx context.Context, id int64) (*Location, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/locations/%d", id), nil) const opPath = "/locations/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.LocationGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.LocationGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, resp, err return nil, resp, err
} }
return LocationFromSchema(body.Location), resp, nil
return LocationFromSchema(respBody.Location), resp, nil
} }
// GetByName retrieves an location by its name. If the location does not exist, nil is returned. // GetByName retrieves an location by its name. If the location does not exist, nil is returned.
func (c *LocationClient) GetByName(ctx context.Context, name string) (*Location, *Response, error) { func (c *LocationClient) GetByName(ctx context.Context, name string) (*Location, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Location, *Response, error) {
return nil, nil, nil return c.List(ctx, LocationListOpts{Name: name})
} })
locations, response, err := c.List(ctx, LocationListOpts{Name: name})
if len(locations) == 0 {
return nil, response, err
}
return locations[0], response, err
} }
// Get retrieves a location by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a location by its ID if the input can be parsed as an integer, otherwise it
@ -88,22 +84,17 @@ func (l LocationListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *LocationClient) List(ctx context.Context, opts LocationListOpts) ([]*Location, *Response, error) { func (c *LocationClient) List(ctx context.Context, opts LocationListOpts) ([]*Location, *Response, error) {
path := "/locations?" + opts.values().Encode() const opPath = "/locations?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LocationListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.LocationListResponse return allFromSchemaFunc(respBody.Locations, LocationFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
locations := make([]*Location, 0, len(body.Locations))
for _, i := range body.Locations {
locations = append(locations, LocationFromSchema(i))
}
return locations, resp, nil
} }
// All returns all locations. // All returns all locations.
@ -113,20 +104,8 @@ func (c *LocationClient) All(ctx context.Context) ([]*Location, error) {
// AllWithOpts returns all locations for the given options. // AllWithOpts returns all locations for the given options.
func (c *LocationClient) AllWithOpts(ctx context.Context, opts LocationListOpts) ([]*Location, error) { func (c *LocationClient) AllWithOpts(ctx context.Context, opts LocationListOpts) ([]*Location, error) {
allLocations := []*Location{} return iterPages(func(page int) ([]*Location, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
locations, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allLocations = append(allLocations, locations...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allLocations, nil
} }

View File

@ -1,6 +1,8 @@
package metadata package metadata
import ( import (
"bytes"
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -11,6 +13,7 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation"
) )
@ -72,24 +75,33 @@ func NewClient(options ...ClientOption) *Client {
if client.instrumentationRegistry != nil { if client.instrumentationRegistry != nil {
i := instrumentation.New("metadata", client.instrumentationRegistry) i := instrumentation.New("metadata", client.instrumentationRegistry)
client.httpClient.Transport = i.InstrumentedRoundTripper() client.httpClient.Transport = i.InstrumentedRoundTripper(client.httpClient.Transport)
} }
return client return client
} }
// get executes an HTTP request against the API. // get executes an HTTP request against the API.
func (c *Client) get(path string) (string, error) { func (c *Client) get(path string) (string, error) {
url := c.endpoint + path ctx := ctxutil.SetOpPath(context.Background(), path)
resp, err := c.httpClient.Get(url)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpoint+path, http.NoBody)
if err != nil { if err != nil {
return "", err return "", err
} }
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close() defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", err return "", err
} }
body := string(bodyBytes)
body := string(bytes.TrimSpace(bodyBytes))
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return body, fmt.Errorf("response status was %d", resp.StatusCode) return body, fmt.Errorf("response status was %d", resp.StatusCode)
} }

View File

@ -1,16 +1,13 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -22,6 +19,7 @@ const (
NetworkZoneEUCentral NetworkZone = "eu-central" NetworkZoneEUCentral NetworkZone = "eu-central"
NetworkZoneUSEast NetworkZone = "us-east" NetworkZoneUSEast NetworkZone = "us-east"
NetworkZoneUSWest NetworkZone = "us-west" NetworkZoneUSWest NetworkZone = "us-west"
NetworkZoneAPSouthEast NetworkZone = "ap-southeast"
) )
// NetworkSubnetType specifies a type of a subnet. // NetworkSubnetType specifies a type of a subnet.
@ -29,8 +27,15 @@ type NetworkSubnetType string
// List of available network subnet types. // List of available network subnet types.
const ( const (
// Used to connect cloud servers and load balancers.
NetworkSubnetTypeCloud NetworkSubnetType = "cloud" NetworkSubnetTypeCloud NetworkSubnetType = "cloud"
// Used to connect cloud servers and load balancers.
//
// Deprecated: Use [NetworkSubnetTypeCloud] instead.
NetworkSubnetTypeServer NetworkSubnetType = "server" NetworkSubnetTypeServer NetworkSubnetType = "server"
// Used to connect cloud servers and load balancers with dedicated servers.
//
// See https://docs.hetzner.com/cloud/networks/connect-dedi-vswitch/
NetworkSubnetTypeVSwitch NetworkSubnetType = "vswitch" NetworkSubnetTypeVSwitch NetworkSubnetType = "vswitch"
) )
@ -43,6 +48,7 @@ type Network struct {
Subnets []NetworkSubnet Subnets []NetworkSubnet
Routes []NetworkRoute Routes []NetworkRoute
Servers []*Server Servers []*Server
LoadBalancers []*LoadBalancer
Protection NetworkProtection Protection NetworkProtection
Labels map[string]string Labels map[string]string
@ -78,41 +84,33 @@ type NetworkClient struct {
// GetByID retrieves a network by its ID. If the network does not exist, nil is returned. // GetByID retrieves a network by its ID. If the network does not exist, nil is returned.
func (c *NetworkClient) GetByID(ctx context.Context, id int64) (*Network, *Response, error) { func (c *NetworkClient) GetByID(ctx context.Context, id int64) (*Network, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/networks/%d", id), nil) const opPath = "/networks/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.NetworkGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.NetworkGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return NetworkFromSchema(body.Network), resp, nil
return NetworkFromSchema(respBody.Network), resp, nil
} }
// GetByName retrieves a network by its name. If the network does not exist, nil is returned. // GetByName retrieves a network by its name. If the network does not exist, nil is returned.
func (c *NetworkClient) GetByName(ctx context.Context, name string) (*Network, *Response, error) { func (c *NetworkClient) GetByName(ctx context.Context, name string) (*Network, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Network, *Response, error) {
return nil, nil, nil return c.List(ctx, NetworkListOpts{Name: name})
} })
Networks, response, err := c.List(ctx, NetworkListOpts{Name: name})
if len(Networks) == 0 {
return nil, response, err
}
return Networks[0], response, err
} }
// Get retrieves a network by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a network by its ID if the input can be parsed as an integer, otherwise it
// retrieves a network by its name. If the network does not exist, nil is returned. // retrieves a network by its name. If the network does not exist, nil is returned.
func (c *NetworkClient) Get(ctx context.Context, idOrName string) (*Network, *Response, error) { func (c *NetworkClient) Get(ctx context.Context, idOrName string) (*Network, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// NetworkListOpts specifies options for listing networks. // NetworkListOpts specifies options for listing networks.
@ -138,22 +136,17 @@ func (l NetworkListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *NetworkClient) List(ctx context.Context, opts NetworkListOpts) ([]*Network, *Response, error) { func (c *NetworkClient) List(ctx context.Context, opts NetworkListOpts) ([]*Network, *Response, error) {
path := "/networks?" + opts.values().Encode() const opPath = "/networks?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.NetworkListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.NetworkListResponse return allFromSchemaFunc(respBody.Networks, NetworkFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
Networks := make([]*Network, 0, len(body.Networks))
for _, s := range body.Networks {
Networks = append(Networks, NetworkFromSchema(s))
}
return Networks, resp, nil
} }
// All returns all networks. // All returns all networks.
@ -163,31 +156,20 @@ func (c *NetworkClient) All(ctx context.Context) ([]*Network, error) {
// AllWithOpts returns all networks for the given options. // AllWithOpts returns all networks for the given options.
func (c *NetworkClient) AllWithOpts(ctx context.Context, opts NetworkListOpts) ([]*Network, error) { func (c *NetworkClient) AllWithOpts(ctx context.Context, opts NetworkListOpts) ([]*Network, error) {
allNetworks := []*Network{} return iterPages(func(page int) ([]*Network, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
Networks, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allNetworks = append(allNetworks, Networks...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allNetworks, nil
} }
// Delete deletes a network. // Delete deletes a network.
func (c *NetworkClient) Delete(ctx context.Context, network *Network) (*Response, error) { func (c *NetworkClient) Delete(ctx context.Context, network *Network) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/networks/%d", network.ID), nil) const opPath = "/networks/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, network.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// NetworkUpdateOpts specifies options for updating a network. // NetworkUpdateOpts specifies options for updating a network.
@ -201,6 +183,11 @@ type NetworkUpdateOpts struct {
// Update updates a network. // Update updates a network.
func (c *NetworkClient) Update(ctx context.Context, network *Network, opts NetworkUpdateOpts) (*Network, *Response, error) { func (c *NetworkClient) Update(ctx context.Context, network *Network, opts NetworkUpdateOpts) (*Network, *Response, error) {
const opPath = "/networks/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkUpdateRequest{ reqBody := schema.NetworkUpdateRequest{
Name: opts.Name, Name: opts.Name,
} }
@ -211,22 +198,11 @@ func (c *NetworkClient) Update(ctx context.Context, network *Network, opts Netwo
reqBody.ExposeRoutesToVSwitch = opts.ExposeRoutesToVSwitch reqBody.ExposeRoutesToVSwitch = opts.ExposeRoutesToVSwitch
} }
reqBodyData, err := json.Marshal(reqBody) respBody, resp, err := putRequest[schema.NetworkUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d", network.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return NetworkFromSchema(respBody.Network), resp, nil return NetworkFromSchema(respBody.Network), resp, nil
} }
@ -245,16 +221,21 @@ type NetworkCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o NetworkCreateOpts) Validate() error { func (o NetworkCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
if o.IPRange == nil || o.IPRange.String() == "" { if o.IPRange == nil || o.IPRange.String() == "" {
return errors.New("missing IP range") return missingField(o, "IPRange")
} }
return nil return nil
} }
// Create creates a new network. // Create creates a new network.
func (c *NetworkClient) Create(ctx context.Context, opts NetworkCreateOpts) (*Network, *Response, error) { func (c *NetworkClient) Create(ctx context.Context, opts NetworkCreateOpts) (*Network, *Response, error) {
const opPath = "/networks"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -283,20 +264,12 @@ func (c *NetworkClient) Create(ctx context.Context, opts NetworkCreateOpts) (*Ne
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/networks", bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkCreateResponse{} respBody, resp, err := postRequest[schema.NetworkCreateResponse](ctx, c.client, reqPath, reqBody)
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return NetworkFromSchema(respBody.Network), resp, nil return NetworkFromSchema(respBody.Network), resp, nil
} }
@ -307,25 +280,20 @@ type NetworkChangeIPRangeOpts struct {
// ChangeIPRange changes the IP range of a network. // ChangeIPRange changes the IP range of a network.
func (c *NetworkClient) ChangeIPRange(ctx context.Context, network *Network, opts NetworkChangeIPRangeOpts) (*Action, *Response, error) { func (c *NetworkClient) ChangeIPRange(ctx context.Context, network *Network, opts NetworkChangeIPRangeOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/change_ip_range"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionChangeIPRangeRequest{ reqBody := schema.NetworkActionChangeIPRangeRequest{
IPRange: opts.IPRange.String(), IPRange: opts.IPRange.String(),
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/change_ip_range", network.ID) respBody, resp, err := postRequest[schema.NetworkActionChangeIPRangeResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionChangeIPRangeResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -336,6 +304,11 @@ type NetworkAddSubnetOpts struct {
// AddSubnet adds a subnet to a network. // AddSubnet adds a subnet to a network.
func (c *NetworkClient) AddSubnet(ctx context.Context, network *Network, opts NetworkAddSubnetOpts) (*Action, *Response, error) { func (c *NetworkClient) AddSubnet(ctx context.Context, network *Network, opts NetworkAddSubnetOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/add_subnet"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionAddSubnetRequest{ reqBody := schema.NetworkActionAddSubnetRequest{
Type: string(opts.Subnet.Type), Type: string(opts.Subnet.Type),
NetworkZone: string(opts.Subnet.NetworkZone), NetworkZone: string(opts.Subnet.NetworkZone),
@ -346,22 +319,12 @@ func (c *NetworkClient) AddSubnet(ctx context.Context, network *Network, opts Ne
if opts.Subnet.VSwitchID != 0 { if opts.Subnet.VSwitchID != 0 {
reqBody.VSwitchID = opts.Subnet.VSwitchID reqBody.VSwitchID = opts.Subnet.VSwitchID
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/add_subnet", network.ID) respBody, resp, err := postRequest[schema.NetworkActionAddSubnetResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionAddSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -372,25 +335,20 @@ type NetworkDeleteSubnetOpts struct {
// DeleteSubnet deletes a subnet from a network. // DeleteSubnet deletes a subnet from a network.
func (c *NetworkClient) DeleteSubnet(ctx context.Context, network *Network, opts NetworkDeleteSubnetOpts) (*Action, *Response, error) { func (c *NetworkClient) DeleteSubnet(ctx context.Context, network *Network, opts NetworkDeleteSubnetOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/delete_subnet"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionDeleteSubnetRequest{ reqBody := schema.NetworkActionDeleteSubnetRequest{
IPRange: opts.Subnet.IPRange.String(), IPRange: opts.Subnet.IPRange.String(),
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/delete_subnet", network.ID) respBody, resp, err := postRequest[schema.NetworkActionDeleteSubnetResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionDeleteSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -401,26 +359,21 @@ type NetworkAddRouteOpts struct {
// AddRoute adds a route to a network. // AddRoute adds a route to a network.
func (c *NetworkClient) AddRoute(ctx context.Context, network *Network, opts NetworkAddRouteOpts) (*Action, *Response, error) { func (c *NetworkClient) AddRoute(ctx context.Context, network *Network, opts NetworkAddRouteOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/add_route"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionAddRouteRequest{ reqBody := schema.NetworkActionAddRouteRequest{
Destination: opts.Route.Destination.String(), Destination: opts.Route.Destination.String(),
Gateway: opts.Route.Gateway.String(), Gateway: opts.Route.Gateway.String(),
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/add_route", network.ID) respBody, resp, err := postRequest[schema.NetworkActionAddRouteResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionAddSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -431,26 +384,21 @@ type NetworkDeleteRouteOpts struct {
// DeleteRoute deletes a route from a network. // DeleteRoute deletes a route from a network.
func (c *NetworkClient) DeleteRoute(ctx context.Context, network *Network, opts NetworkDeleteRouteOpts) (*Action, *Response, error) { func (c *NetworkClient) DeleteRoute(ctx context.Context, network *Network, opts NetworkDeleteRouteOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/delete_route"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionDeleteRouteRequest{ reqBody := schema.NetworkActionDeleteRouteRequest{
Destination: opts.Route.Destination.String(), Destination: opts.Route.Destination.String(),
Gateway: opts.Route.Gateway.String(), Gateway: opts.Route.Gateway.String(),
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/delete_route", network.ID) respBody, resp, err := postRequest[schema.NetworkActionDeleteRouteResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionDeleteSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -461,24 +409,19 @@ type NetworkChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a network. // ChangeProtection changes the resource protection level of a network.
func (c *NetworkClient) ChangeProtection(ctx context.Context, network *Network, opts NetworkChangeProtectionOpts) (*Action, *Response, error) { func (c *NetworkClient) ChangeProtection(ctx context.Context, network *Network, opts NetworkChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionChangeProtectionRequest{ reqBody := schema.NetworkActionChangeProtectionRequest{
Delete: opts.Delete, Delete: opts.Delete,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/change_protection", network.ID) respBody, resp, err := postRequest[schema.NetworkActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }

View File

@ -1,15 +1,12 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -38,41 +35,33 @@ type PlacementGroupClient struct {
// GetByID retrieves a PlacementGroup by its ID. If the PlacementGroup does not exist, nil is returned. // GetByID retrieves a PlacementGroup by its ID. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) GetByID(ctx context.Context, id int64) (*PlacementGroup, *Response, error) { func (c *PlacementGroupClient) GetByID(ctx context.Context, id int64) (*PlacementGroup, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/placement_groups/%d", id), nil) const opPath = "/placement_groups/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.PlacementGroupGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.PlacementGroupGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return PlacementGroupFromSchema(body.PlacementGroup), resp, nil
return PlacementGroupFromSchema(respBody.PlacementGroup), resp, nil
} }
// GetByName retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned. // GetByName retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) GetByName(ctx context.Context, name string) (*PlacementGroup, *Response, error) { func (c *PlacementGroupClient) GetByName(ctx context.Context, name string) (*PlacementGroup, *Response, error) {
if name == "" { return firstByName(name, func() ([]*PlacementGroup, *Response, error) {
return nil, nil, nil return c.List(ctx, PlacementGroupListOpts{Name: name})
} })
placementGroups, response, err := c.List(ctx, PlacementGroupListOpts{Name: name})
if len(placementGroups) == 0 {
return nil, response, err
}
return placementGroups[0], response, err
} }
// Get retrieves a PlacementGroup by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a PlacementGroup by its ID if the input can be parsed as an integer, otherwise it
// retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned. // retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) Get(ctx context.Context, idOrName string) (*PlacementGroup, *Response, error) { func (c *PlacementGroupClient) Get(ctx context.Context, idOrName string) (*PlacementGroup, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// PlacementGroupListOpts specifies options for listing PlacementGroup. // PlacementGroupListOpts specifies options for listing PlacementGroup.
@ -102,22 +91,17 @@ func (l PlacementGroupListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *PlacementGroupClient) List(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, *Response, error) { func (c *PlacementGroupClient) List(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, *Response, error) {
path := "/placement_groups?" + opts.values().Encode() const opPath = "/placement_groups?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.PlacementGroupListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.PlacementGroupListResponse return allFromSchemaFunc(respBody.PlacementGroups, PlacementGroupFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
placementGroups := make([]*PlacementGroup, 0, len(body.PlacementGroups))
for _, g := range body.PlacementGroups {
placementGroups = append(placementGroups, PlacementGroupFromSchema(g))
}
return placementGroups, resp, nil
} }
// All returns all PlacementGroups. // All returns all PlacementGroups.
@ -133,22 +117,10 @@ func (c *PlacementGroupClient) All(ctx context.Context) ([]*PlacementGroup, erro
// AllWithOpts returns all PlacementGroups for the given options. // AllWithOpts returns all PlacementGroups for the given options.
func (c *PlacementGroupClient) AllWithOpts(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, error) { func (c *PlacementGroupClient) AllWithOpts(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, error) {
allPlacementGroups := []*PlacementGroup{} return iterPages(func(page int) ([]*PlacementGroup, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
placementGroups, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allPlacementGroups = append(allPlacementGroups, placementGroups...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allPlacementGroups, nil
} }
// PlacementGroupCreateOpts specifies options for creating a new PlacementGroup. // PlacementGroupCreateOpts specifies options for creating a new PlacementGroup.
@ -161,7 +133,7 @@ type PlacementGroupCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o PlacementGroupCreateOpts) Validate() error { func (o PlacementGroupCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
return nil return nil
} }
@ -174,27 +146,25 @@ type PlacementGroupCreateResult struct {
// Create creates a new PlacementGroup. // Create creates a new PlacementGroup.
func (c *PlacementGroupClient) Create(ctx context.Context, opts PlacementGroupCreateOpts) (PlacementGroupCreateResult, *Response, error) { func (c *PlacementGroupClient) Create(ctx context.Context, opts PlacementGroupCreateOpts) (PlacementGroupCreateResult, *Response, error) {
const opPath = "/placement_groups"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := PlacementGroupCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return PlacementGroupCreateResult{}, nil, err return result, nil, err
}
reqBody := placementGroupCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return PlacementGroupCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/placement_groups", bytes.NewReader(reqBodyData))
if err != nil {
return PlacementGroupCreateResult{}, nil, err
} }
respBody := schema.PlacementGroupCreateResponse{} reqBody := placementGroupCreateOptsToSchema(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PlacementGroupCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return PlacementGroupCreateResult{}, nil, err return result, resp, err
}
result := PlacementGroupCreateResult{
PlacementGroup: PlacementGroupFromSchema(respBody.PlacementGroup),
} }
result.PlacementGroup = PlacementGroupFromSchema(respBody.PlacementGroup)
if respBody.Action != nil { if respBody.Action != nil {
result.Action = ActionFromSchema(*respBody.Action) result.Action = ActionFromSchema(*respBody.Action)
} }
@ -210,6 +180,11 @@ type PlacementGroupUpdateOpts struct {
// Update updates a PlacementGroup. // Update updates a PlacementGroup.
func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *PlacementGroup, opts PlacementGroupUpdateOpts) (*PlacementGroup, *Response, error) { func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *PlacementGroup, opts PlacementGroupUpdateOpts) (*PlacementGroup, *Response, error) {
const opPath = "/placement_groups/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, placementGroup.ID)
reqBody := schema.PlacementGroupUpdateRequest{} reqBody := schema.PlacementGroupUpdateRequest{}
if opts.Name != "" { if opts.Name != "" {
reqBody.Name = &opts.Name reqBody.Name = &opts.Name
@ -217,19 +192,8 @@ func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *Place
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/placement_groups/%d", placementGroup.ID) respBody, resp, err := putRequest[schema.PlacementGroupUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.PlacementGroupUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
@ -239,9 +203,10 @@ func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *Place
// Delete deletes a PlacementGroup. // Delete deletes a PlacementGroup.
func (c *PlacementGroupClient) Delete(ctx context.Context, placementGroup *PlacementGroup) (*Response, error) { func (c *PlacementGroupClient) Delete(ctx context.Context, placementGroup *PlacementGroup) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/placement_groups/%d", placementGroup.ID), nil) const opPath = "/placement_groups/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, placementGroup.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }

View File

@ -3,15 +3,19 @@ package hcloud
import ( import (
"context" "context"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
// Pricing specifies pricing information for various resources. // Pricing specifies pricing information for various resources.
type Pricing struct { type Pricing struct {
Image ImagePricing Image ImagePricing
// Deprecated: [Pricing.FloatingIP] is deprecated, use [Pricing.FloatingIPs] instead.
FloatingIP FloatingIPPricing FloatingIP FloatingIPPricing
FloatingIPs []FloatingIPTypePricing FloatingIPs []FloatingIPTypePricing
PrimaryIPs []PrimaryIPPricing PrimaryIPs []PrimaryIPPricing
// Deprecated: [Pricing.Traffic] is deprecated and will report 0 after 2024-08-05.
// Use traffic pricing from [Pricing.ServerTypes] or [Pricing.LoadBalancerTypes] instead.
Traffic TrafficPricing Traffic TrafficPricing
ServerBackup ServerBackupPricing ServerBackup ServerBackupPricing
ServerTypes []ServerTypePricing ServerTypes []ServerTypePricing
@ -102,6 +106,10 @@ type ServerTypeLocationPricing struct {
Location *Location Location *Location
Hourly Price Hourly Price
Monthly Price Monthly Price
// IncludedTraffic is the free traffic per month in bytes
IncludedTraffic uint64
PerTBTraffic Price
} }
// LoadBalancerTypePricing provides pricing information for a Load Balancer type. // LoadBalancerTypePricing provides pricing information for a Load Balancer type.
@ -116,6 +124,10 @@ type LoadBalancerTypeLocationPricing struct {
Location *Location Location *Location
Hourly Price Hourly Price
Monthly Price Monthly Price
// IncludedTraffic is the free traffic per month in bytes
IncludedTraffic uint64
PerTBTraffic Price
} }
// PricingClient is a client for the pricing API. // PricingClient is a client for the pricing API.
@ -125,15 +137,15 @@ type PricingClient struct {
// Get retrieves pricing information. // Get retrieves pricing information.
func (c *PricingClient) Get(ctx context.Context) (Pricing, *Response, error) { func (c *PricingClient) Get(ctx context.Context) (Pricing, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", "/pricing", nil) const opPath = "/pricing"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
respBody, resp, err := getRequest[schema.PricingGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return Pricing{}, nil, err return Pricing{}, resp, err
} }
var body schema.PricingGetResponse return PricingFromSchema(respBody.Pricing), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return Pricing{}, nil, err
}
return PricingFromSchema(body.Pricing), resp, nil
} }

View File

@ -1,15 +1,13 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -46,26 +44,21 @@ type PrimaryIPDNSPTR struct {
// changeDNSPtr changes or resets the reverse DNS pointer for a IP address. // changeDNSPtr changes or resets the reverse DNS pointer for a IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value. // Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (p *PrimaryIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) { func (p *PrimaryIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/primary_ips/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, p.ID)
reqBody := schema.PrimaryIPActionChangeDNSPtrRequest{ reqBody := schema.PrimaryIPActionChangeDNSPtrRequest{
IP: ip.String(), IP: ip.String(),
DNSPtr: ptr, DNSPtr: ptr,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d/actions/change_dns_ptr", p.ID) respBody, resp, err := postRequest[schema.PrimaryIPActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPChangeDNSPtrResult
resp, err := client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -92,13 +85,13 @@ const (
// PrimaryIPCreateOpts defines the request to // PrimaryIPCreateOpts defines the request to
// create a Primary IP. // create a Primary IP.
type PrimaryIPCreateOpts struct { type PrimaryIPCreateOpts struct {
AssigneeID *int64 `json:"assignee_id,omitempty"` AssigneeID *int64
AssigneeType string `json:"assignee_type"` AssigneeType string
AutoDelete *bool `json:"auto_delete,omitempty"` AutoDelete *bool
Datacenter string `json:"datacenter,omitempty"` Datacenter string
Labels map[string]string `json:"labels,omitempty"` Labels map[string]string
Name string `json:"name"` Name string
Type PrimaryIPType `json:"type"` Type PrimaryIPType
} }
// PrimaryIPCreateResult defines the response // PrimaryIPCreateResult defines the response
@ -111,51 +104,42 @@ type PrimaryIPCreateResult struct {
// PrimaryIPUpdateOpts defines the request to // PrimaryIPUpdateOpts defines the request to
// update a Primary IP. // update a Primary IP.
type PrimaryIPUpdateOpts struct { type PrimaryIPUpdateOpts struct {
AutoDelete *bool `json:"auto_delete,omitempty"` AutoDelete *bool
Labels *map[string]string `json:"labels,omitempty"` Labels *map[string]string
Name string `json:"name,omitempty"` Name string
} }
// PrimaryIPAssignOpts defines the request to // PrimaryIPAssignOpts defines the request to
// assign a Primary IP to an assignee (usually a server). // assign a Primary IP to an assignee (usually a server).
type PrimaryIPAssignOpts struct { type PrimaryIPAssignOpts struct {
ID int64 ID int64
AssigneeID int64 `json:"assignee_id"` AssigneeID int64
AssigneeType string `json:"assignee_type"` AssigneeType string
} }
// PrimaryIPAssignResult defines the response // Deprecated: Please use [schema.PrimaryIPActionAssignResponse] instead.
// when assigning a Primary IP to a assignee. type PrimaryIPAssignResult = schema.PrimaryIPActionAssignResponse
type PrimaryIPAssignResult struct {
Action schema.Action `json:"action"`
}
// PrimaryIPChangeDNSPtrOpts defines the request to // PrimaryIPChangeDNSPtrOpts defines the request to
// change a DNS PTR entry from a Primary IP. // change a DNS PTR entry from a Primary IP.
type PrimaryIPChangeDNSPtrOpts struct { type PrimaryIPChangeDNSPtrOpts struct {
ID int64 ID int64
DNSPtr string `json:"dns_ptr"` DNSPtr string
IP string `json:"ip"` IP string
} }
// PrimaryIPChangeDNSPtrResult defines the response // Deprecated: Please use [schema.PrimaryIPChangeDNSPtrResponse] instead.
// when assigning a Primary IP to a assignee. type PrimaryIPChangeDNSPtrResult = schema.PrimaryIPActionChangeDNSPtrResponse
type PrimaryIPChangeDNSPtrResult struct {
Action schema.Action `json:"action"`
}
// PrimaryIPChangeProtectionOpts defines the request to // PrimaryIPChangeProtectionOpts defines the request to
// change protection configuration of a Primary IP. // change protection configuration of a Primary IP.
type PrimaryIPChangeProtectionOpts struct { type PrimaryIPChangeProtectionOpts struct {
ID int64 ID int64
Delete bool `json:"delete"` Delete bool
} }
// PrimaryIPChangeProtectionResult defines the response // Deprecated: Please use [schema.PrimaryIPActionChangeProtectionResponse] instead.
// when changing a protection of a PrimaryIP. type PrimaryIPChangeProtectionResult = schema.PrimaryIPActionChangeProtectionResponse
type PrimaryIPChangeProtectionResult struct {
Action schema.Action `json:"action"`
}
// PrimaryIPClient is a client for the Primary IP API. // PrimaryIPClient is a client for the Primary IP API.
type PrimaryIPClient struct { type PrimaryIPClient struct {
@ -165,20 +149,20 @@ type PrimaryIPClient struct {
// GetByID retrieves a Primary IP by its ID. If the Primary IP does not exist, nil is returned. // GetByID retrieves a Primary IP by its ID. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) GetByID(ctx context.Context, id int64) (*PrimaryIP, *Response, error) { func (c *PrimaryIPClient) GetByID(ctx context.Context, id int64) (*PrimaryIP, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/primary_ips/%d", id), nil) const opPath = "/primary_ips/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.PrimaryIPGetResult reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.PrimaryIPGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return PrimaryIPFromSchema(body.PrimaryIP), resp, nil
return PrimaryIPFromSchema(respBody.PrimaryIP), resp, nil
} }
// GetByIP retrieves a Primary IP by its IP Address. If the Primary IP does not exist, nil is returned. // GetByIP retrieves a Primary IP by its IP Address. If the Primary IP does not exist, nil is returned.
@ -186,32 +170,22 @@ func (c *PrimaryIPClient) GetByIP(ctx context.Context, ip string) (*PrimaryIP, *
if ip == "" { if ip == "" {
return nil, nil, nil return nil, nil, nil
} }
primaryIPs, response, err := c.List(ctx, PrimaryIPListOpts{IP: ip}) return firstBy(func() ([]*PrimaryIP, *Response, error) {
if len(primaryIPs) == 0 { return c.List(ctx, PrimaryIPListOpts{IP: ip})
return nil, response, err })
}
return primaryIPs[0], response, err
} }
// GetByName retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned. // GetByName retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) GetByName(ctx context.Context, name string) (*PrimaryIP, *Response, error) { func (c *PrimaryIPClient) GetByName(ctx context.Context, name string) (*PrimaryIP, *Response, error) {
if name == "" { return firstByName(name, func() ([]*PrimaryIP, *Response, error) {
return nil, nil, nil return c.List(ctx, PrimaryIPListOpts{Name: name})
} })
primaryIPs, response, err := c.List(ctx, PrimaryIPListOpts{Name: name})
if len(primaryIPs) == 0 {
return nil, response, err
}
return primaryIPs[0], response, err
} }
// Get retrieves a Primary IP by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a Primary IP by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned. // retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) Get(ctx context.Context, idOrName string) (*PrimaryIP, *Response, error) { func (c *PrimaryIPClient) Get(ctx context.Context, idOrName string) (*PrimaryIP, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// PrimaryIPListOpts specifies options for listing Primary IPs. // PrimaryIPListOpts specifies options for listing Primary IPs.
@ -241,22 +215,17 @@ func (l PrimaryIPListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *PrimaryIPClient) List(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, *Response, error) { func (c *PrimaryIPClient) List(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, *Response, error) {
path := "/primary_ips?" + opts.values().Encode() const opPath = "/primary_ips?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.PrimaryIPListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.PrimaryIPListResult return allFromSchemaFunc(respBody.PrimaryIPs, PrimaryIPFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
primaryIPs := make([]*PrimaryIP, 0, len(body.PrimaryIPs))
for _, s := range body.PrimaryIPs {
primaryIPs = append(primaryIPs, PrimaryIPFromSchema(s))
}
return primaryIPs, resp, nil
} }
// All returns all Primary IPs. // All returns all Primary IPs.
@ -266,157 +235,125 @@ func (c *PrimaryIPClient) All(ctx context.Context) ([]*PrimaryIP, error) {
// AllWithOpts returns all Primary IPs for the given options. // AllWithOpts returns all Primary IPs for the given options.
func (c *PrimaryIPClient) AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error) { func (c *PrimaryIPClient) AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error) {
allPrimaryIPs := []*PrimaryIP{} return iterPages(func(page int) ([]*PrimaryIP, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
primaryIPs, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allPrimaryIPs = append(allPrimaryIPs, primaryIPs...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allPrimaryIPs, nil
} }
// Create creates a Primary IP. // Create creates a Primary IP.
func (c *PrimaryIPClient) Create(ctx context.Context, reqBody PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error) { func (c *PrimaryIPClient) Create(ctx context.Context, opts PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error) {
reqBodyData, err := json.Marshal(reqBody) const opPath = "/primary_ips"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := &PrimaryIPCreateResult{}
reqPath := opPath
reqBody := SchemaFromPrimaryIPCreateOpts(opts)
respBody, resp, err := postRequest[schema.PrimaryIPCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return &PrimaryIPCreateResult{}, nil, err return result, resp, err
} }
req, err := c.client.NewRequest(ctx, "POST", "/primary_ips", bytes.NewReader(reqBodyData)) result.PrimaryIP = PrimaryIPFromSchema(respBody.PrimaryIP)
if err != nil {
return &PrimaryIPCreateResult{}, nil, err
}
var respBody schema.PrimaryIPCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return &PrimaryIPCreateResult{}, resp, err
}
var action *Action
if respBody.Action != nil { if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action) result.Action = ActionFromSchema(*respBody.Action)
} }
primaryIP := PrimaryIPFromSchema(respBody.PrimaryIP)
return &PrimaryIPCreateResult{ return result, resp, nil
PrimaryIP: primaryIP,
Action: action,
}, resp, nil
} }
// Delete deletes a Primary IP. // Delete deletes a Primary IP.
func (c *PrimaryIPClient) Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error) { func (c *PrimaryIPClient) Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/primary_ips/%d", primaryIP.ID), nil) const opPath = "/primary_ips/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, primaryIP.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// Update updates a Primary IP. // Update updates a Primary IP.
func (c *PrimaryIPClient) Update(ctx context.Context, primaryIP *PrimaryIP, reqBody PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error) { func (c *PrimaryIPClient) Update(ctx context.Context, primaryIP *PrimaryIP, opts PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error) {
reqBodyData, err := json.Marshal(reqBody) const opPath = "/primary_ips/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d", primaryIP.ID) reqPath := fmt.Sprintf(opPath, primaryIP.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.PrimaryIPUpdateResult reqBody := SchemaFromPrimaryIPUpdateOpts(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.PrimaryIPUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return PrimaryIPFromSchema(respBody.PrimaryIP), resp, nil return PrimaryIPFromSchema(respBody.PrimaryIP), resp, nil
} }
// Assign a Primary IP to a resource. // Assign a Primary IP to a resource.
func (c *PrimaryIPClient) Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error) { func (c *PrimaryIPClient) Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts) const opPath = "/primary_ips/%d/actions/assign"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d/actions/assign", opts.ID) reqPath := fmt.Sprintf(opPath, opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPAssignResult reqBody := SchemaFromPrimaryIPAssignOpts(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PrimaryIPActionAssignResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// Unassign a Primary IP from a resource. // Unassign a Primary IP from a resource.
func (c *PrimaryIPClient) Unassign(ctx context.Context, id int64) (*Action, *Response, error) { func (c *PrimaryIPClient) Unassign(ctx context.Context, id int64) (*Action, *Response, error) {
path := fmt.Sprintf("/primary_ips/%d/actions/unassign", id) const opPath = "/primary_ips/%d/actions/unassign"
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader([]byte{})) ctx = ctxutil.SetOpPath(ctx, opPath)
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPAssignResult reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PrimaryIPActionUnassignResponse](ctx, c.client, reqPath, nil)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// ChangeDNSPtr Change the reverse DNS from a Primary IP. // ChangeDNSPtr Change the reverse DNS from a Primary IP.
func (c *PrimaryIPClient) ChangeDNSPtr(ctx context.Context, opts PrimaryIPChangeDNSPtrOpts) (*Action, *Response, error) { func (c *PrimaryIPClient) ChangeDNSPtr(ctx context.Context, opts PrimaryIPChangeDNSPtrOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts) const opPath = "/primary_ips/%d/actions/change_dns_ptr"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d/actions/change_dns_ptr", opts.ID) reqPath := fmt.Sprintf(opPath, opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPChangeDNSPtrResult reqBody := SchemaFromPrimaryIPChangeDNSPtrOpts(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PrimaryIPActionChangeDNSPtrResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
// ChangeProtection Changes the protection configuration of a Primary IP. // ChangeProtection Changes the protection configuration of a Primary IP.
func (c *PrimaryIPClient) ChangeProtection(ctx context.Context, opts PrimaryIPChangeProtectionOpts) (*Action, *Response, error) { func (c *PrimaryIPClient) ChangeProtection(ctx context.Context, opts PrimaryIPChangeProtectionOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts) const opPath = "/primary_ips/%d/actions/change_protection"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d/actions/change_protection", opts.ID) reqPath := fmt.Sprintf(opPath, opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPChangeProtectionResult reqBody := SchemaFromPrimaryIPChangeProtectionOpts(opts)
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PrimaryIPActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }

View File

@ -49,6 +49,26 @@ func SchemaFromPrimaryIP(p *PrimaryIP) schema.PrimaryIP {
return c.SchemaFromPrimaryIP(p) return c.SchemaFromPrimaryIP(p)
} }
func SchemaFromPrimaryIPCreateOpts(o PrimaryIPCreateOpts) schema.PrimaryIPCreateRequest {
return c.SchemaFromPrimaryIPCreateOpts(o)
}
func SchemaFromPrimaryIPUpdateOpts(o PrimaryIPUpdateOpts) schema.PrimaryIPUpdateRequest {
return c.SchemaFromPrimaryIPUpdateOpts(o)
}
func SchemaFromPrimaryIPChangeDNSPtrOpts(o PrimaryIPChangeDNSPtrOpts) schema.PrimaryIPActionChangeDNSPtrRequest {
return c.SchemaFromPrimaryIPChangeDNSPtrOpts(o)
}
func SchemaFromPrimaryIPChangeProtectionOpts(o PrimaryIPChangeProtectionOpts) schema.PrimaryIPActionChangeProtectionRequest {
return c.SchemaFromPrimaryIPChangeProtectionOpts(o)
}
func SchemaFromPrimaryIPAssignOpts(o PrimaryIPAssignOpts) schema.PrimaryIPActionAssignRequest {
return c.SchemaFromPrimaryIPAssignOpts(o)
}
// ISOFromSchema converts a schema.ISO to an ISO. // ISOFromSchema converts a schema.ISO to an ISO.
func ISOFromSchema(s schema.ISO) *ISO { func ISOFromSchema(s schema.ISO) *ISO {
return c.ISOFromSchema(s) return c.ISOFromSchema(s)

View File

@ -0,0 +1,4 @@
// The schema package holds API schemas for the `hcloud-go` library.
// Breaking changes may occur without notice. Do not use in production!
package schema

View File

@ -7,7 +7,7 @@ type Error struct {
Code string `json:"code"` Code string `json:"code"`
Message string `json:"message"` Message string `json:"message"`
DetailsRaw json.RawMessage `json:"details"` DetailsRaw json.RawMessage `json:"details"`
Details interface{} Details any `json:"-"`
} }
// UnmarshalJSON overrides default json unmarshalling. // UnmarshalJSON overrides default json unmarshalling.
@ -17,13 +17,20 @@ func (e *Error) UnmarshalJSON(data []byte) (err error) {
if err = json.Unmarshal(data, alias); err != nil { if err = json.Unmarshal(data, alias); err != nil {
return return
} }
if e.Code == "invalid_input" { if e.Code == "invalid_input" && len(e.DetailsRaw) > 0 {
details := ErrorDetailsInvalidInput{} details := ErrorDetailsInvalidInput{}
if err = json.Unmarshal(e.DetailsRaw, &details); err != nil { if err = json.Unmarshal(e.DetailsRaw, &details); err != nil {
return return
} }
alias.Details = details alias.Details = details
} }
if e.Code == "deprecated_api_endpoint" && len(e.DetailsRaw) > 0 {
details := ErrorDetailsDeprecatedAPIEndpoint{}
if err = json.Unmarshal(e.DetailsRaw, &details); err != nil {
return
}
alias.Details = details
}
return return
} }
@ -40,3 +47,9 @@ type ErrorDetailsInvalidInput struct {
Messages []string `json:"messages"` Messages []string `json:"messages"`
} `json:"fields"` } `json:"fields"`
} }
// ErrorDetailsDeprecatedAPIEndpoint defines the schema of the Details field
// of an error with code 'deprecated_api_endpoint'.
type ErrorDetailsDeprecatedAPIEndpoint struct {
Announcement string `json:"announcement"`
}

View File

@ -0,0 +1,68 @@
package schema
import (
"bytes"
"encoding/json"
"reflect"
"strconv"
)
// IDOrName can be used in API requests where either a resource id or name can be
// specified.
type IDOrName struct {
ID int64
Name string
}
var _ json.Unmarshaler = (*IDOrName)(nil)
var _ json.Marshaler = (*IDOrName)(nil)
func (o IDOrName) MarshalJSON() ([]byte, error) {
if o.ID != 0 {
return json.Marshal(o.ID)
}
if o.Name != "" {
return json.Marshal(o.Name)
}
// We want to preserve the behavior of an empty interface{} to prevent breaking
// changes (marshaled to null when empty).
return json.Marshal(nil)
}
func (o *IDOrName) UnmarshalJSON(data []byte) error {
d := json.NewDecoder(bytes.NewBuffer(data))
// This ensures we won't lose precision on large IDs, see json.Number below
d.UseNumber()
var v any
if err := d.Decode(&v); err != nil {
return err
}
switch typed := v.(type) {
case string:
id, err := strconv.ParseInt(typed, 10, 64)
if err == nil {
o.ID = id
} else if typed != "" {
o.Name = typed
}
case json.Number:
id, err := typed.Int64()
if err != nil {
return &json.UnmarshalTypeError{
Value: string(data),
Type: reflect.TypeOf(*o),
}
}
o.ID = id
default:
return &json.UnmarshalTypeError{
Value: string(data),
Type: reflect.TypeOf(*o),
}
}
return nil
}

View File

@ -251,7 +251,7 @@ type LoadBalancerDeleteServiceResponse struct {
type LoadBalancerCreateRequest struct { type LoadBalancerCreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
LoadBalancerType interface{} `json:"load_balancer_type"` // int or string LoadBalancerType IDOrName `json:"load_balancer_type"`
Algorithm *LoadBalancerCreateRequestAlgorithm `json:"algorithm,omitempty"` Algorithm *LoadBalancerCreateRequestAlgorithm `json:"algorithm,omitempty"`
Location *string `json:"location,omitempty"` Location *string `json:"location,omitempty"`
NetworkZone *string `json:"network_zone,omitempty"` NetworkZone *string `json:"network_zone,omitempty"`
@ -380,7 +380,7 @@ type LoadBalancerActionDisablePublicInterfaceResponse struct {
} }
type LoadBalancerActionChangeTypeRequest struct { type LoadBalancerActionChangeTypeRequest struct {
LoadBalancerType interface{} `json:"load_balancer_type"` // int or string LoadBalancerType IDOrName `json:"load_balancer_type"`
} }
type LoadBalancerActionChangeTypeResponse struct { type LoadBalancerActionChangeTypeResponse struct {

View File

@ -11,6 +11,7 @@ type Network struct {
Subnets []NetworkSubnet `json:"subnets"` Subnets []NetworkSubnet `json:"subnets"`
Routes []NetworkRoute `json:"routes"` Routes []NetworkRoute `json:"routes"`
Servers []int64 `json:"servers"` Servers []int64 `json:"servers"`
LoadBalancers []int64 `json:"load_balancers"`
Protection NetworkProtection `json:"protection"` Protection NetworkProtection `json:"protection"`
Labels map[string]string `json:"labels"` Labels map[string]string `json:"labels"`
ExposeRoutesToVSwitch bool `json:"expose_routes_to_vswitch"` ExposeRoutesToVSwitch bool `json:"expose_routes_to_vswitch"`

View File

@ -5,9 +5,12 @@ type Pricing struct {
Currency string `json:"currency"` Currency string `json:"currency"`
VATRate string `json:"vat_rate"` VATRate string `json:"vat_rate"`
Image PricingImage `json:"image"` Image PricingImage `json:"image"`
// Deprecated: [Pricing.FloatingIP] is deprecated, use [Pricing.FloatingIPs] instead.
FloatingIP PricingFloatingIP `json:"floating_ip"` FloatingIP PricingFloatingIP `json:"floating_ip"`
FloatingIPs []PricingFloatingIPType `json:"floating_ips"` FloatingIPs []PricingFloatingIPType `json:"floating_ips"`
PrimaryIPs []PricingPrimaryIP `json:"primary_ips"` PrimaryIPs []PricingPrimaryIP `json:"primary_ips"`
// Deprecated: [Pricing.Traffic] is deprecated and will report 0 after 2024-08-05.
// Use traffic pricing from [Pricing.ServerTypes] or [Pricing.LoadBalancerTypes] instead.
Traffic PricingTraffic `json:"traffic"` Traffic PricingTraffic `json:"traffic"`
ServerBackup PricingServerBackup `json:"server_backup"` ServerBackup PricingServerBackup `json:"server_backup"`
ServerTypes []PricingServerType `json:"server_types"` ServerTypes []PricingServerType `json:"server_types"`
@ -72,6 +75,9 @@ type PricingServerTypePrice struct {
Location string `json:"location"` Location string `json:"location"`
PriceHourly Price `json:"price_hourly"` PriceHourly Price `json:"price_hourly"`
PriceMonthly Price `json:"price_monthly"` PriceMonthly Price `json:"price_monthly"`
IncludedTraffic uint64 `json:"included_traffic"`
PricePerTBTraffic Price `json:"price_per_tb_traffic"`
} }
// PricingLoadBalancerType defines the schema of pricing information for a Load Balancer type. // PricingLoadBalancerType defines the schema of pricing information for a Load Balancer type.
@ -87,6 +93,9 @@ type PricingLoadBalancerTypePrice struct {
Location string `json:"location"` Location string `json:"location"`
PriceHourly Price `json:"price_hourly"` PriceHourly Price `json:"price_hourly"`
PriceMonthly Price `json:"price_monthly"` PriceMonthly Price `json:"price_monthly"`
IncludedTraffic uint64 `json:"included_traffic"`
PricePerTBTraffic Price `json:"price_per_tb_traffic"`
} }
// PricingGetResponse defines the schema of the response when retrieving pricing information. // PricingGetResponse defines the schema of the response when retrieving pricing information.

View File

@ -31,6 +31,18 @@ type PrimaryIPDNSPTR struct {
IP string `json:"ip"` IP string `json:"ip"`
} }
// PrimaryIPCreateOpts defines the request to
// create a Primary IP.
type PrimaryIPCreateRequest struct {
Name string `json:"name"`
Type string `json:"type"`
AssigneeType string `json:"assignee_type"`
AssigneeID *int64 `json:"assignee_id,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
AutoDelete *bool `json:"auto_delete,omitempty"`
Datacenter string `json:"datacenter,omitempty"`
}
// PrimaryIPCreateResponse defines the schema of the response // PrimaryIPCreateResponse defines the schema of the response
// when creating a Primary IP. // when creating a Primary IP.
type PrimaryIPCreateResponse struct { type PrimaryIPCreateResponse struct {
@ -38,19 +50,27 @@ type PrimaryIPCreateResponse struct {
Action *Action `json:"action"` Action *Action `json:"action"`
} }
// PrimaryIPGetResult defines the response when retrieving a single Primary IP. // PrimaryIPGetResponse defines the response when retrieving a single Primary IP.
type PrimaryIPGetResult struct { type PrimaryIPGetResponse struct {
PrimaryIP PrimaryIP `json:"primary_ip"` PrimaryIP PrimaryIP `json:"primary_ip"`
} }
// PrimaryIPListResult defines the response when listing Primary IPs. // PrimaryIPListResponse defines the response when listing Primary IPs.
type PrimaryIPListResult struct { type PrimaryIPListResponse struct {
PrimaryIPs []PrimaryIP `json:"primary_ips"` PrimaryIPs []PrimaryIP `json:"primary_ips"`
} }
// PrimaryIPUpdateResult defines the response // PrimaryIPUpdateOpts defines the request to
// update a Primary IP.
type PrimaryIPUpdateRequest struct {
Name string `json:"name,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
AutoDelete *bool `json:"auto_delete,omitempty"`
}
// PrimaryIPUpdateResponse defines the response
// when updating a Primary IP. // when updating a Primary IP.
type PrimaryIPUpdateResult struct { type PrimaryIPUpdateResponse struct {
PrimaryIP PrimaryIP `json:"primary_ip"` PrimaryIP PrimaryIP `json:"primary_ip"`
} }
@ -60,3 +80,39 @@ type PrimaryIPActionChangeDNSPtrRequest struct {
IP string `json:"ip"` IP string `json:"ip"`
DNSPtr *string `json:"dns_ptr"` DNSPtr *string `json:"dns_ptr"`
} }
// PrimaryIPActionChangeDNSPtrResponse defines the response when setting a reverse DNS
// pointer for a IP address.
type PrimaryIPActionChangeDNSPtrResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionAssignRequest defines the request to
// assign a Primary IP to an assignee (usually a server).
type PrimaryIPActionAssignRequest struct {
AssigneeID int64 `json:"assignee_id"`
AssigneeType string `json:"assignee_type"`
}
// PrimaryIPActionAssignResponse defines the response when assigning a Primary IP to a
// assignee.
type PrimaryIPActionAssignResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionUnassignResponse defines the response to unassign a Primary IP.
type PrimaryIPActionUnassignResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionChangeProtectionRequest defines the request to
// change protection configuration of a Primary IP.
type PrimaryIPActionChangeProtectionRequest struct {
Delete bool `json:"delete"`
}
// PrimaryIPActionChangeProtectionResponse defines the response when changing the
// protection of a Primary IP.
type PrimaryIPActionChangeProtectionResponse struct {
Action Action `json:"action"`
}

View File

@ -99,8 +99,8 @@ type ServerListResponse struct {
// create a server. // create a server.
type ServerCreateRequest struct { type ServerCreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
ServerType interface{} `json:"server_type"` // int or string ServerType IDOrName `json:"server_type"`
Image interface{} `json:"image"` // int or string Image IDOrName `json:"image"`
SSHKeys []int64 `json:"ssh_keys,omitempty"` SSHKeys []int64 `json:"ssh_keys,omitempty"`
Location string `json:"location,omitempty"` Location string `json:"location,omitempty"`
Datacenter string `json:"datacenter,omitempty"` Datacenter string `json:"datacenter,omitempty"`
@ -257,7 +257,7 @@ type ServerActionDisableRescueResponse struct {
// ServerActionRebuildRequest defines the schema for the request to // ServerActionRebuildRequest defines the schema for the request to
// rebuild a server. // rebuild a server.
type ServerActionRebuildRequest struct { type ServerActionRebuildRequest struct {
Image interface{} `json:"image"` // int or string Image IDOrName `json:"image"`
} }
// ServerActionRebuildResponse defines the schema of the response when // ServerActionRebuildResponse defines the schema of the response when
@ -270,7 +270,7 @@ type ServerActionRebuildResponse struct {
// ServerActionAttachISORequest defines the schema for the request to // ServerActionAttachISORequest defines the schema for the request to
// attach an ISO to a server. // attach an ISO to a server.
type ServerActionAttachISORequest struct { type ServerActionAttachISORequest struct {
ISO interface{} `json:"iso"` // int or string ISO IDOrName `json:"iso"`
} }
// ServerActionAttachISOResponse defines the schema of the response when // ServerActionAttachISOResponse defines the schema of the response when
@ -289,12 +289,6 @@ type ServerActionDetachISOResponse struct {
Action Action `json:"action"` Action Action `json:"action"`
} }
// ServerActionEnableBackupRequest defines the schema for the request to
// enable backup for a server.
type ServerActionEnableBackupRequest struct {
BackupWindow *string `json:"backup_window,omitempty"`
}
// ServerActionEnableBackupResponse defines the schema of the response when // ServerActionEnableBackupResponse defines the schema of the response when
// creating a enable_backup server action. // creating a enable_backup server action.
type ServerActionEnableBackupResponse struct { type ServerActionEnableBackupResponse struct {
@ -314,7 +308,7 @@ type ServerActionDisableBackupResponse struct {
// ServerActionChangeTypeRequest defines the schema for the request to // ServerActionChangeTypeRequest defines the schema for the request to
// change a server's type. // change a server's type.
type ServerActionChangeTypeRequest struct { type ServerActionChangeTypeRequest struct {
ServerType interface{} `json:"server_type"` // int or string ServerType IDOrName `json:"server_type"`
UpgradeDisk bool `json:"upgrade_disk"` UpgradeDisk bool `json:"upgrade_disk"`
} }

View File

@ -11,6 +11,9 @@ type ServerType struct {
StorageType string `json:"storage_type"` StorageType string `json:"storage_type"`
CPUType string `json:"cpu_type"` CPUType string `json:"cpu_type"`
Architecture string `json:"architecture"` Architecture string `json:"architecture"`
// Deprecated: [ServerType.IncludedTraffic] is deprecated and will always report 0 after 2024-08-05.
// Use [ServerType.Prices] instead to get the included traffic for each location.
IncludedTraffic int64 `json:"included_traffic"` IncludedTraffic int64 `json:"included_traffic"`
Prices []PricingServerTypePrice `json:"prices"` Prices []PricingServerTypePrice `json:"prices"`
Deprecated bool `json:"deprecated"` Deprecated bool `json:"deprecated"`

View File

@ -23,7 +23,7 @@ type VolumeCreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
Size int `json:"size"` Size int `json:"size"`
Server *int64 `json:"server,omitempty"` Server *int64 `json:"server,omitempty"`
Location interface{} `json:"location,omitempty"` // int, string, or nil Location *IDOrName `json:"location,omitempty"`
Labels *map[string]string `json:"labels,omitempty"` Labels *map[string]string `json:"labels,omitempty"`
Automount *bool `json:"automount,omitempty"` Automount *bool `json:"automount,omitempty"`
Format *string `json:"format,omitempty"` Format *string `json:"format,omitempty"`

View File

@ -69,7 +69,6 @@ You can find a documentation of goverter here: https://goverter.jmattheis.de/
// goverter:extend durationFromIntSeconds // goverter:extend durationFromIntSeconds
// goverter:extend intSecondsFromDuration // goverter:extend intSecondsFromDuration
// goverter:extend serverFromImageCreatedFromSchema // goverter:extend serverFromImageCreatedFromSchema
// goverter:extend anyFromLoadBalancerType
// goverter:extend serverMetricsTimeSeriesFromSchema // goverter:extend serverMetricsTimeSeriesFromSchema
// goverter:extend loadBalancerMetricsTimeSeriesFromSchema // goverter:extend loadBalancerMetricsTimeSeriesFromSchema
// goverter:extend stringPtrFromLoadBalancerServiceProtocol // goverter:extend stringPtrFromLoadBalancerServiceProtocol
@ -108,6 +107,12 @@ type converter interface {
// goverter:map AssigneeID | mapZeroInt64ToNil // goverter:map AssigneeID | mapZeroInt64ToNil
SchemaFromPrimaryIP(*PrimaryIP) schema.PrimaryIP SchemaFromPrimaryIP(*PrimaryIP) schema.PrimaryIP
SchemaFromPrimaryIPCreateOpts(PrimaryIPCreateOpts) schema.PrimaryIPCreateRequest
SchemaFromPrimaryIPUpdateOpts(PrimaryIPUpdateOpts) schema.PrimaryIPUpdateRequest
SchemaFromPrimaryIPChangeDNSPtrOpts(PrimaryIPChangeDNSPtrOpts) schema.PrimaryIPActionChangeDNSPtrRequest
SchemaFromPrimaryIPChangeProtectionOpts(PrimaryIPChangeProtectionOpts) schema.PrimaryIPActionChangeProtectionRequest
SchemaFromPrimaryIPAssignOpts(PrimaryIPAssignOpts) schema.PrimaryIPActionAssignRequest
ISOFromSchema(schema.ISO) *ISO ISOFromSchema(schema.ISO) *ISO
// We cannot use goverter settings when mapping a struct to a struct pointer // We cannot use goverter settings when mapping a struct to a struct pointer
@ -207,10 +212,12 @@ type converter interface {
// goverter:map PriceHourly Hourly // goverter:map PriceHourly Hourly
// goverter:map PriceMonthly Monthly // goverter:map PriceMonthly Monthly
// goverter:map PricePerTBTraffic PerTBTraffic
LoadBalancerTypeLocationPricingFromSchema(schema.PricingLoadBalancerTypePrice) LoadBalancerTypeLocationPricing LoadBalancerTypeLocationPricingFromSchema(schema.PricingLoadBalancerTypePrice) LoadBalancerTypeLocationPricing
// goverter:map Hourly PriceHourly // goverter:map Hourly PriceHourly
// goverter:map Monthly PriceMonthly // goverter:map Monthly PriceMonthly
// goverter:map PerTBTraffic PricePerTBTraffic
SchemaFromLoadBalancerTypeLocationPricing(LoadBalancerTypeLocationPricing) schema.PricingLoadBalancerTypePrice SchemaFromLoadBalancerTypeLocationPricing(LoadBalancerTypeLocationPricing) schema.PricingLoadBalancerTypePrice
LoadBalancerServiceFromSchema(schema.LoadBalancerService) LoadBalancerService LoadBalancerServiceFromSchema(schema.LoadBalancerService) LoadBalancerService
@ -263,6 +270,7 @@ type converter interface {
// goverter:map PriceHourly Hourly // goverter:map PriceHourly Hourly
// goverter:map PriceMonthly Monthly // goverter:map PriceMonthly Monthly
// goverter:map PricePerTBTraffic PerTBTraffic
serverTypePricingFromSchema(schema.PricingServerTypePrice) ServerTypeLocationPricing serverTypePricingFromSchema(schema.PricingServerTypePrice) ServerTypeLocationPricing
// goverter:map Image.PerGBMonth.Currency Currency // goverter:map Image.PerGBMonth.Currency Currency
@ -306,6 +314,7 @@ type converter interface {
// goverter:map Monthly PriceMonthly // goverter:map Monthly PriceMonthly
// goverter:map Hourly PriceHourly // goverter:map Hourly PriceHourly
// goverter:map PerTBTraffic PricePerTBTraffic
schemaFromServerTypeLocationPricing(ServerTypeLocationPricing) schema.PricingServerTypePrice schemaFromServerTypeLocationPricing(ServerTypeLocationPricing) schema.PricingServerTypePrice
FirewallFromSchema(schema.Firewall) *Firewall FirewallFromSchema(schema.Firewall) *Firewall
@ -606,37 +615,48 @@ func intSecondsFromDuration(d time.Duration) int {
} }
func errorDetailsFromSchema(d interface{}) interface{} { func errorDetailsFromSchema(d interface{}) interface{} {
if d, ok := d.(schema.ErrorDetailsInvalidInput); ok { switch typed := d.(type) {
case schema.ErrorDetailsInvalidInput:
details := ErrorDetailsInvalidInput{ details := ErrorDetailsInvalidInput{
Fields: make([]ErrorDetailsInvalidInputField, len(d.Fields)), Fields: make([]ErrorDetailsInvalidInputField, len(typed.Fields)),
} }
for i, field := range d.Fields { for i, field := range typed.Fields {
details.Fields[i] = ErrorDetailsInvalidInputField{ details.Fields[i] = ErrorDetailsInvalidInputField{
Name: field.Name, Name: field.Name,
Messages: field.Messages, Messages: field.Messages,
} }
} }
return details return details
case schema.ErrorDetailsDeprecatedAPIEndpoint:
return ErrorDetailsDeprecatedAPIEndpoint{
Announcement: typed.Announcement,
}
} }
return nil return nil
} }
func schemaFromErrorDetails(d interface{}) interface{} { func schemaFromErrorDetails(d interface{}) interface{} {
if d, ok := d.(ErrorDetailsInvalidInput); ok { switch typed := d.(type) {
case ErrorDetailsInvalidInput:
details := schema.ErrorDetailsInvalidInput{ details := schema.ErrorDetailsInvalidInput{
Fields: make([]struct { Fields: make([]struct {
Name string `json:"name"` Name string `json:"name"`
Messages []string `json:"messages"` Messages []string `json:"messages"`
}, len(d.Fields)), }, len(typed.Fields)),
} }
for i, field := range d.Fields { for i, field := range typed.Fields {
details.Fields[i] = struct { details.Fields[i] = struct {
Name string `json:"name"` Name string `json:"name"`
Messages []string `json:"messages"` Messages []string `json:"messages"`
}{Name: field.Name, Messages: field.Messages} }{Name: field.Name, Messages: field.Messages}
} }
return details return details
case ErrorDetailsDeprecatedAPIEndpoint:
return schema.ErrorDetailsDeprecatedAPIEndpoint{Announcement: typed.Announcement}
} }
return nil return nil
} }
@ -654,8 +674,8 @@ func imagePricingFromSchema(s schema.Pricing) ImagePricing {
func floatingIPPricingFromSchema(s schema.Pricing) FloatingIPPricing { func floatingIPPricingFromSchema(s schema.Pricing) FloatingIPPricing {
return FloatingIPPricing{ return FloatingIPPricing{
Monthly: Price{ Monthly: Price{
Net: s.FloatingIP.PriceMonthly.Net, Net: s.FloatingIP.PriceMonthly.Net, // nolint:staticcheck // Field is deprecated, but removal is not planned
Gross: s.FloatingIP.PriceMonthly.Gross, Gross: s.FloatingIP.PriceMonthly.Gross, // nolint:staticcheck // Field is deprecated, but removal is not planned
Currency: s.Currency, Currency: s.Currency,
VATRate: s.VATRate, VATRate: s.VATRate,
}, },
@ -707,8 +727,8 @@ func primaryIPPricingFromSchema(s schema.Pricing) []PrimaryIPPricing {
func trafficPricingFromSchema(s schema.Pricing) TrafficPricing { func trafficPricingFromSchema(s schema.Pricing) TrafficPricing {
return TrafficPricing{ return TrafficPricing{
PerTB: Price{ PerTB: Price{
Net: s.Traffic.PricePerTB.Net, Net: s.Traffic.PricePerTB.Net, // nolint:staticcheck // Field is deprecated, but we still need to map it as long as it is available
Gross: s.Traffic.PricePerTB.Gross, Gross: s.Traffic.PricePerTB.Gross, // nolint:staticcheck // Field is deprecated, but we still need to map it as long as it is available
Currency: s.Currency, Currency: s.Currency,
VATRate: s.VATRate, VATRate: s.VATRate,
}, },
@ -734,6 +754,13 @@ func serverTypePricingFromSchema(s schema.Pricing) []ServerTypePricing {
Net: price.PriceMonthly.Net, Net: price.PriceMonthly.Net,
Gross: price.PriceMonthly.Gross, Gross: price.PriceMonthly.Gross,
}, },
IncludedTraffic: price.IncludedTraffic,
PerTBTraffic: Price{
Currency: s.Currency,
VATRate: s.VATRate,
Net: price.PricePerTBTraffic.Net,
Gross: price.PricePerTBTraffic.Gross,
},
} }
} }
p[i] = ServerTypePricing{ p[i] = ServerTypePricing{
@ -766,6 +793,13 @@ func loadBalancerTypePricingFromSchema(s schema.Pricing) []LoadBalancerTypePrici
Net: price.PriceMonthly.Net, Net: price.PriceMonthly.Net,
Gross: price.PriceMonthly.Gross, Gross: price.PriceMonthly.Gross,
}, },
IncludedTraffic: price.IncludedTraffic,
PerTBTraffic: Price{
Currency: s.Currency,
VATRate: s.VATRate,
Net: price.PricePerTBTraffic.Net,
Gross: price.PricePerTBTraffic.Gross,
},
} }
} }
p[i] = LoadBalancerTypePricing{ p[i] = LoadBalancerTypePricing{
@ -790,16 +824,6 @@ func volumePricingFromSchema(s schema.Pricing) VolumePricing {
} }
} }
func anyFromLoadBalancerType(t *LoadBalancerType) interface{} {
if t == nil {
return nil
}
if t.ID != 0 {
return t.ID
}
return t.Name
}
func serverMetricsTimeSeriesFromSchema(s schema.ServerTimeSeriesVals) ([]ServerMetricsValue, error) { func serverMetricsTimeSeriesFromSchema(s schema.ServerTimeSeriesVals) ([]ServerMetricsValue, error) {
vals := make([]ServerMetricsValue, len(s.Values)) vals := make([]ServerMetricsValue, len(s.Values))
@ -922,7 +946,10 @@ func rawSchemaFromErrorDetails(v interface{}) json.RawMessage {
if v == nil { if v == nil {
return nil return nil
} }
msg, _ := json.Marshal(d) msg, err := json.Marshal(d)
if err != nil {
return nil
}
return msg return msg
} }

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -20,7 +21,9 @@ type ServerType struct {
StorageType StorageType StorageType StorageType
CPUType CPUType CPUType CPUType
Architecture Architecture Architecture Architecture
// IncludedTraffic is the free traffic per month in bytes
// Deprecated: [ServerType.IncludedTraffic] is deprecated and will always report 0 after 2024-08-05.
// Use [ServerType.Pricings] instead to get the included traffic for each location.
IncludedTraffic int64 IncludedTraffic int64
Pricings []ServerTypeLocationPricing Pricings []ServerTypeLocationPricing
DeprecatableResource DeprecatableResource
@ -55,32 +58,27 @@ type ServerTypeClient struct {
// GetByID retrieves a server type by its ID. If the server type does not exist, nil is returned. // GetByID retrieves a server type by its ID. If the server type does not exist, nil is returned.
func (c *ServerTypeClient) GetByID(ctx context.Context, id int64) (*ServerType, *Response, error) { func (c *ServerTypeClient) GetByID(ctx context.Context, id int64) (*ServerType, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/server_types/%d", id), nil) const opPath = "/server_types/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.ServerTypeGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.ServerTypeGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return ServerTypeFromSchema(body.ServerType), resp, nil
return ServerTypeFromSchema(respBody.ServerType), resp, nil
} }
// GetByName retrieves a server type by its name. If the server type does not exist, nil is returned. // GetByName retrieves a server type by its name. If the server type does not exist, nil is returned.
func (c *ServerTypeClient) GetByName(ctx context.Context, name string) (*ServerType, *Response, error) { func (c *ServerTypeClient) GetByName(ctx context.Context, name string) (*ServerType, *Response, error) {
if name == "" { return firstByName(name, func() ([]*ServerType, *Response, error) {
return nil, nil, nil return c.List(ctx, ServerTypeListOpts{Name: name})
} })
serverTypes, response, err := c.List(ctx, ServerTypeListOpts{Name: name})
if len(serverTypes) == 0 {
return nil, response, err
}
return serverTypes[0], response, err
} }
// Get retrieves a server type by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a server type by its ID if the input can be parsed as an integer, otherwise it
@ -115,22 +113,17 @@ func (l ServerTypeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *ServerTypeClient) List(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, *Response, error) { func (c *ServerTypeClient) List(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, *Response, error) {
path := "/server_types?" + opts.values().Encode() const opPath = "/server_types?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ServerTypeListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.ServerTypeListResponse return allFromSchemaFunc(respBody.ServerTypes, ServerTypeFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
serverTypes := make([]*ServerType, 0, len(body.ServerTypes))
for _, s := range body.ServerTypes {
serverTypes = append(serverTypes, ServerTypeFromSchema(s))
}
return serverTypes, resp, nil
} }
// All returns all server types. // All returns all server types.
@ -140,20 +133,8 @@ func (c *ServerTypeClient) All(ctx context.Context) ([]*ServerType, error) {
// AllWithOpts returns all server types for the given options. // AllWithOpts returns all server types for the given options.
func (c *ServerTypeClient) AllWithOpts(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, error) { func (c *ServerTypeClient) AllWithOpts(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, error) {
allServerTypes := []*ServerType{} return iterPages(func(page int) ([]*ServerType, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
serverTypes, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allServerTypes = append(allServerTypes, serverTypes...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allServerTypes, nil
} }

View File

@ -1,15 +1,12 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -30,50 +27,40 @@ type SSHKeyClient struct {
// GetByID retrieves a SSH key by its ID. If the SSH key does not exist, nil is returned. // GetByID retrieves a SSH key by its ID. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByID(ctx context.Context, id int64) (*SSHKey, *Response, error) { func (c *SSHKeyClient) GetByID(ctx context.Context, id int64) (*SSHKey, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/ssh_keys/%d", id), nil) const opPath = "/ssh_keys/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.SSHKeyGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.SSHKeyGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return SSHKeyFromSchema(body.SSHKey), resp, nil
return SSHKeyFromSchema(respBody.SSHKey), resp, nil
} }
// GetByName retrieves a SSH key by its name. If the SSH key does not exist, nil is returned. // GetByName retrieves a SSH key by its name. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByName(ctx context.Context, name string) (*SSHKey, *Response, error) { func (c *SSHKeyClient) GetByName(ctx context.Context, name string) (*SSHKey, *Response, error) {
if name == "" { return firstByName(name, func() ([]*SSHKey, *Response, error) {
return nil, nil, nil return c.List(ctx, SSHKeyListOpts{Name: name})
} })
sshKeys, response, err := c.List(ctx, SSHKeyListOpts{Name: name})
if len(sshKeys) == 0 {
return nil, response, err
}
return sshKeys[0], response, err
} }
// GetByFingerprint retreives a SSH key by its fingerprint. If the SSH key does not exist, nil is returned. // GetByFingerprint retreives a SSH key by its fingerprint. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByFingerprint(ctx context.Context, fingerprint string) (*SSHKey, *Response, error) { func (c *SSHKeyClient) GetByFingerprint(ctx context.Context, fingerprint string) (*SSHKey, *Response, error) {
sshKeys, response, err := c.List(ctx, SSHKeyListOpts{Fingerprint: fingerprint}) return firstBy(func() ([]*SSHKey, *Response, error) {
if len(sshKeys) == 0 { return c.List(ctx, SSHKeyListOpts{Fingerprint: fingerprint})
return nil, response, err })
}
return sshKeys[0], response, err
} }
// Get retrieves a SSH key by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a SSH key by its ID if the input can be parsed as an integer, otherwise it
// retrieves a SSH key by its name. If the SSH key does not exist, nil is returned. // retrieves a SSH key by its name. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) Get(ctx context.Context, idOrName string) (*SSHKey, *Response, error) { func (c *SSHKeyClient) Get(ctx context.Context, idOrName string) (*SSHKey, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// SSHKeyListOpts specifies options for listing SSH keys. // SSHKeyListOpts specifies options for listing SSH keys.
@ -103,22 +90,17 @@ func (l SSHKeyListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *SSHKeyClient) List(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, *Response, error) { func (c *SSHKeyClient) List(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, *Response, error) {
path := "/ssh_keys?" + opts.values().Encode() const opPath = "/ssh_keys?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.SSHKeyListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.SSHKeyListResponse return allFromSchemaFunc(respBody.SSHKeys, SSHKeyFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
sshKeys := make([]*SSHKey, 0, len(body.SSHKeys))
for _, s := range body.SSHKeys {
sshKeys = append(sshKeys, SSHKeyFromSchema(s))
}
return sshKeys, resp, nil
} }
// All returns all SSH keys. // All returns all SSH keys.
@ -128,22 +110,10 @@ func (c *SSHKeyClient) All(ctx context.Context) ([]*SSHKey, error) {
// AllWithOpts returns all SSH keys with the given options. // AllWithOpts returns all SSH keys with the given options.
func (c *SSHKeyClient) AllWithOpts(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, error) { func (c *SSHKeyClient) AllWithOpts(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, error) {
allSSHKeys := []*SSHKey{} return iterPages(func(page int) ([]*SSHKey, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
sshKeys, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allSSHKeys = append(allSSHKeys, sshKeys...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allSSHKeys, nil
} }
// SSHKeyCreateOpts specifies parameters for creating a SSH key. // SSHKeyCreateOpts specifies parameters for creating a SSH key.
@ -156,16 +126,21 @@ type SSHKeyCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o SSHKeyCreateOpts) Validate() error { func (o SSHKeyCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
if o.PublicKey == "" { if o.PublicKey == "" {
return errors.New("missing public key") return missingField(o, "PublicKey")
} }
return nil return nil
} }
// Create creates a new SSH key with the given options. // Create creates a new SSH key with the given options.
func (c *SSHKeyClient) Create(ctx context.Context, opts SSHKeyCreateOpts) (*SSHKey, *Response, error) { func (c *SSHKeyClient) Create(ctx context.Context, opts SSHKeyCreateOpts) (*SSHKey, *Response, error) {
const opPath = "/ssh_keys"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -176,31 +151,23 @@ func (c *SSHKeyClient) Create(ctx context.Context, opts SSHKeyCreateOpts) (*SSHK
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/ssh_keys", bytes.NewReader(reqBodyData)) respBody, resp, err := postRequest[schema.SSHKeyCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, nil, err
}
var respBody schema.SSHKeyCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return SSHKeyFromSchema(respBody.SSHKey), resp, nil return SSHKeyFromSchema(respBody.SSHKey), resp, nil
} }
// Delete deletes a SSH key. // Delete deletes a SSH key.
func (c *SSHKeyClient) Delete(ctx context.Context, sshKey *SSHKey) (*Response, error) { func (c *SSHKeyClient) Delete(ctx context.Context, sshKey *SSHKey) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/ssh_keys/%d", sshKey.ID), nil) const opPath = "/ssh_keys/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, sshKey.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// SSHKeyUpdateOpts specifies options for updating a SSH key. // SSHKeyUpdateOpts specifies options for updating a SSH key.
@ -211,27 +178,22 @@ type SSHKeyUpdateOpts struct {
// Update updates a SSH key. // Update updates a SSH key.
func (c *SSHKeyClient) Update(ctx context.Context, sshKey *SSHKey, opts SSHKeyUpdateOpts) (*SSHKey, *Response, error) { func (c *SSHKeyClient) Update(ctx context.Context, sshKey *SSHKey, opts SSHKeyUpdateOpts) (*SSHKey, *Response, error) {
const opPath = "/ssh_keys/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, sshKey.ID)
reqBody := schema.SSHKeyUpdateRequest{ reqBody := schema.SSHKeyUpdateRequest{
Name: opts.Name, Name: opts.Name,
} }
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/ssh_keys/%d", sshKey.ID) respBody, resp, err := putRequest[schema.SSHKeyUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.SSHKeyUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return SSHKeyFromSchema(respBody.SSHKey), resp, nil return SSHKeyFromSchema(respBody.SSHKey), resp, nil
} }

View File

@ -1,15 +1,12 @@
package hcloud package hcloud
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"time" "time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
) )
@ -57,41 +54,33 @@ const (
// GetByID retrieves a volume by its ID. If the volume does not exist, nil is returned. // GetByID retrieves a volume by its ID. If the volume does not exist, nil is returned.
func (c *VolumeClient) GetByID(ctx context.Context, id int64) (*Volume, *Response, error) { func (c *VolumeClient) GetByID(ctx context.Context, id int64) (*Volume, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/volumes/%d", id), nil) const opPath = "/volumes/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, nil, err
}
var body schema.VolumeGetResponse reqPath := fmt.Sprintf(opPath, id)
resp, err := c.client.Do(req, &body)
respBody, resp, err := getRequest[schema.VolumeGetResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
if IsError(err, ErrorCodeNotFound) { if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil return nil, resp, nil
} }
return nil, nil, err return nil, resp, err
} }
return VolumeFromSchema(body.Volume), resp, nil
return VolumeFromSchema(respBody.Volume), resp, nil
} }
// GetByName retrieves a volume by its name. If the volume does not exist, nil is returned. // GetByName retrieves a volume by its name. If the volume does not exist, nil is returned.
func (c *VolumeClient) GetByName(ctx context.Context, name string) (*Volume, *Response, error) { func (c *VolumeClient) GetByName(ctx context.Context, name string) (*Volume, *Response, error) {
if name == "" { return firstByName(name, func() ([]*Volume, *Response, error) {
return nil, nil, nil return c.List(ctx, VolumeListOpts{Name: name})
} })
volumes, response, err := c.List(ctx, VolumeListOpts{Name: name})
if len(volumes) == 0 {
return nil, response, err
}
return volumes[0], response, err
} }
// Get retrieves a volume by its ID if the input can be parsed as an integer, otherwise it // Get retrieves a volume by its ID if the input can be parsed as an integer, otherwise it
// retrieves a volume by its name. If the volume does not exist, nil is returned. // retrieves a volume by its name. If the volume does not exist, nil is returned.
func (c *VolumeClient) Get(ctx context.Context, idOrName string) (*Volume, *Response, error) { func (c *VolumeClient) Get(ctx context.Context, idOrName string) (*Volume, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil { return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
} }
// VolumeListOpts specifies options for listing volumes. // VolumeListOpts specifies options for listing volumes.
@ -121,22 +110,17 @@ func (l VolumeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account // Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
func (c *VolumeClient) List(ctx context.Context, opts VolumeListOpts) ([]*Volume, *Response, error) { func (c *VolumeClient) List(ctx context.Context, opts VolumeListOpts) ([]*Volume, *Response, error) {
path := "/volumes?" + opts.values().Encode() const opPath = "/volumes?%s"
req, err := c.client.NewRequest(ctx, "GET", path, nil) ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.VolumeListResponse](ctx, c.client, reqPath)
if err != nil { if err != nil {
return nil, nil, err return nil, resp, err
} }
var body schema.VolumeListResponse return allFromSchemaFunc(respBody.Volumes, VolumeFromSchema), resp, nil
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
volumes := make([]*Volume, 0, len(body.Volumes))
for _, s := range body.Volumes {
volumes = append(volumes, VolumeFromSchema(s))
}
return volumes, resp, nil
} }
// All returns all volumes. // All returns all volumes.
@ -146,22 +130,10 @@ func (c *VolumeClient) All(ctx context.Context) ([]*Volume, error) {
// AllWithOpts returns all volumes with the given options. // AllWithOpts returns all volumes with the given options.
func (c *VolumeClient) AllWithOpts(ctx context.Context, opts VolumeListOpts) ([]*Volume, error) { func (c *VolumeClient) AllWithOpts(ctx context.Context, opts VolumeListOpts) ([]*Volume, error) {
allVolumes := []*Volume{} return iterPages(func(page int) ([]*Volume, *Response, error) {
err := c.client.all(func(page int) (*Response, error) {
opts.Page = page opts.Page = page
volumes, resp, err := c.List(ctx, opts) return c.List(ctx, opts)
if err != nil {
return resp, err
}
allVolumes = append(allVolumes, volumes...)
return resp, nil
}) })
if err != nil {
return nil, err
}
return allVolumes, nil
} }
// VolumeCreateOpts specifies parameters for creating a volume. // VolumeCreateOpts specifies parameters for creating a volume.
@ -178,19 +150,19 @@ type VolumeCreateOpts struct {
// Validate checks if options are valid. // Validate checks if options are valid.
func (o VolumeCreateOpts) Validate() error { func (o VolumeCreateOpts) Validate() error {
if o.Name == "" { if o.Name == "" {
return errors.New("missing name") return missingField(o, "Name")
} }
if o.Size <= 0 { if o.Size <= 0 {
return errors.New("size must be greater than 0") return invalidFieldValue(o, "Size", o.Size)
} }
if o.Server == nil && o.Location == nil { if o.Server == nil && o.Location == nil {
return errors.New("one of server or location must be provided") return missingOneOfFields(o, "Server", "Location")
} }
if o.Server != nil && o.Location != nil { if o.Server != nil && o.Location != nil {
return errors.New("only one of server or location must be provided") return mutuallyExclusiveFields(o, "Server", "Location")
} }
if o.Server == nil && (o.Automount != nil && *o.Automount) { if o.Server == nil && (o.Automount != nil && *o.Automount) {
return errors.New("server must be provided when automount is true") return missingRequiredTogetherFields(o, "Automount", "Server")
} }
return nil return nil
} }
@ -204,8 +176,15 @@ type VolumeCreateResult struct {
// Create creates a new volume with the given options. // Create creates a new volume with the given options.
func (c *VolumeClient) Create(ctx context.Context, opts VolumeCreateOpts) (VolumeCreateResult, *Response, error) { func (c *VolumeClient) Create(ctx context.Context, opts VolumeCreateOpts) (VolumeCreateResult, *Response, error) {
const opPath = "/volumes"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := VolumeCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return VolumeCreateResult{}, nil, err return result, nil, err
} }
reqBody := schema.VolumeCreateRequest{ reqBody := schema.VolumeCreateRequest{
Name: opts.Name, Name: opts.Name,
@ -220,48 +199,33 @@ func (c *VolumeClient) Create(ctx context.Context, opts VolumeCreateOpts) (Volum
reqBody.Server = Ptr(opts.Server.ID) reqBody.Server = Ptr(opts.Server.ID)
} }
if opts.Location != nil { if opts.Location != nil {
if opts.Location.ID != 0 { if opts.Location.ID != 0 || opts.Location.Name != "" {
reqBody.Location = opts.Location.ID reqBody.Location = &schema.IDOrName{ID: opts.Location.ID, Name: opts.Location.Name}
} else {
reqBody.Location = opts.Location.Name
} }
} }
reqBodyData, err := json.Marshal(reqBody) respBody, resp, err := postRequest[schema.VolumeCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil { if err != nil {
return VolumeCreateResult{}, nil, err return result, resp, err
} }
req, err := c.client.NewRequest(ctx, "POST", "/volumes", bytes.NewReader(reqBodyData)) result.Volume = VolumeFromSchema(respBody.Volume)
if err != nil {
return VolumeCreateResult{}, nil, err
}
var respBody schema.VolumeCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return VolumeCreateResult{}, resp, err
}
var action *Action
if respBody.Action != nil { if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action) result.Action = ActionFromSchema(*respBody.Action)
} }
result.NextActions = ActionsFromSchema(respBody.NextActions)
return VolumeCreateResult{ return result, resp, nil
Volume: VolumeFromSchema(respBody.Volume),
Action: action,
NextActions: ActionsFromSchema(respBody.NextActions),
}, resp, nil
} }
// Delete deletes a volume. // Delete deletes a volume.
func (c *VolumeClient) Delete(ctx context.Context, volume *Volume) (*Response, error) { func (c *VolumeClient) Delete(ctx context.Context, volume *Volume) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/volumes/%d", volume.ID), nil) const opPath = "/volumes/%d"
if err != nil { ctx = ctxutil.SetOpPath(ctx, opPath)
return nil, err
} reqPath := fmt.Sprintf(opPath, volume.ID)
return c.client.Do(req, nil)
return deleteRequestNoResult(ctx, c.client, reqPath)
} }
// VolumeUpdateOpts specifies options for updating a volume. // VolumeUpdateOpts specifies options for updating a volume.
@ -272,28 +236,23 @@ type VolumeUpdateOpts struct {
// Update updates a volume. // Update updates a volume.
func (c *VolumeClient) Update(ctx context.Context, volume *Volume, opts VolumeUpdateOpts) (*Volume, *Response, error) { func (c *VolumeClient) Update(ctx context.Context, volume *Volume, opts VolumeUpdateOpts) (*Volume, *Response, error) {
const opPath = "/volumes/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeUpdateRequest{ reqBody := schema.VolumeUpdateRequest{
Name: opts.Name, Name: opts.Name,
} }
if opts.Labels != nil { if opts.Labels != nil {
reqBody.Labels = &opts.Labels reqBody.Labels = &opts.Labels
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d", volume.ID) respBody, resp, err := putRequest[schema.VolumeUpdateResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return VolumeFromSchema(respBody.Volume), resp, nil return VolumeFromSchema(respBody.Volume), resp, nil
} }
@ -305,27 +264,21 @@ type VolumeAttachOpts struct {
// AttachWithOpts attaches a volume to a server. // AttachWithOpts attaches a volume to a server.
func (c *VolumeClient) AttachWithOpts(ctx context.Context, volume *Volume, opts VolumeAttachOpts) (*Action, *Response, error) { func (c *VolumeClient) AttachWithOpts(ctx context.Context, volume *Volume, opts VolumeAttachOpts) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/attach"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionAttachVolumeRequest{ reqBody := schema.VolumeActionAttachVolumeRequest{
Server: opts.Server.ID, Server: opts.Server.ID,
Automount: opts.Automount, Automount: opts.Automount,
} }
reqBodyData, err := json.Marshal(reqBody) respBody, resp, err := postRequest[schema.VolumeActionAttachVolumeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/attach", volume.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.VolumeActionAttachVolumeResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -336,23 +289,18 @@ func (c *VolumeClient) Attach(ctx context.Context, volume *Volume, server *Serve
// Detach detaches a volume from a server. // Detach detaches a volume from a server.
func (c *VolumeClient) Detach(ctx context.Context, volume *Volume) (*Action, *Response, error) { func (c *VolumeClient) Detach(ctx context.Context, volume *Volume) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/detach"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
var reqBody schema.VolumeActionDetachVolumeRequest var reqBody schema.VolumeActionDetachVolumeRequest
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/detach", volume.ID) respBody, resp, err := postRequest[schema.VolumeActionDetachVolumeResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.VolumeActionDetachVolumeResponse
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, nil return ActionFromSchema(respBody.Action), resp, nil
} }
@ -363,48 +311,38 @@ type VolumeChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a volume. // ChangeProtection changes the resource protection level of a volume.
func (c *VolumeClient) ChangeProtection(ctx context.Context, volume *Volume, opts VolumeChangeProtectionOpts) (*Action, *Response, error) { func (c *VolumeClient) ChangeProtection(ctx context.Context, volume *Volume, opts VolumeChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionChangeProtectionRequest{ reqBody := schema.VolumeActionChangeProtectionRequest{
Delete: opts.Delete, Delete: opts.Delete,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/change_protection", volume.ID) respBody, resp, err := postRequest[schema.VolumeActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
} }
// Resize changes the size of a volume. // Resize changes the size of a volume.
func (c *VolumeClient) Resize(ctx context.Context, volume *Volume, size int) (*Action, *Response, error) { func (c *VolumeClient) Resize(ctx context.Context, volume *Volume, size int) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/resize"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionResizeVolumeRequest{ reqBody := schema.VolumeActionResizeVolumeRequest{
Size: size, Size: size,
} }
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/resize", volume.ID) respBody, resp, err := postRequest[schema.VolumeActionResizeVolumeResponse](ctx, c.client, reqPath, reqBody)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeActionResizeVolumeResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return ActionFromSchema(respBody.Action), resp, err return ActionFromSchema(respBody.Action), resp, err
} }

View File

@ -16,8 +16,12 @@ type IActionClient interface {
// when their value corresponds to their zero value or when they are empty. // when their value corresponds to their zero value or when they are empty.
List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error) List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error)
// All returns all actions. // All returns all actions.
//
// Deprecated: It is required to pass in a list of IDs since 30 January 2025. Please use [ActionClient.AllWithOpts] instead.
All(ctx context.Context) ([]*Action, error) All(ctx context.Context) ([]*Action, error)
// AllWithOpts returns all actions for the given options. // AllWithOpts returns all actions for the given options.
//
// It is required to set [ActionListOpts.ID]. Any other fields set in the opts are ignored.
AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error) AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error)
// WatchOverallProgress watches several actions' progress until they complete // WatchOverallProgress watches several actions' progress until they complete
// with success or error. This watching happens in a goroutine and updates are // with success or error. This watching happens in a goroutine and updates are
@ -35,7 +39,7 @@ type IActionClient interface {
// timeout, use the [context.Context]. Once the method has stopped watching, // timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed. // both returned channels are closed.
// //
// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait // WatchOverallProgress uses the [WithPollOpts] of the [Client] to wait
// until sending the next request. // until sending the next request.
// //
// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead. // Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead.
@ -56,19 +60,19 @@ type IActionClient interface {
// timeout, use the [context.Context]. Once the method has stopped watching, // timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed. // both returned channels are closed.
// //
// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until // WatchProgress uses the [WithPollOpts] of the [Client] to wait until
// sending the next request. // sending the next request.
// //
// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead. // Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead.
WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error) WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error)
// WaitForFunc waits until all actions are completed by polling the API at the interval // WaitForFunc waits until all actions are completed by polling the API at the interval
// defined by [WithPollBackoffFunc]. An action is considered as complete when its status is // defined by [WithPollOpts]. An action is considered as complete when its status is
// either [ActionStatusSuccess] or [ActionStatusError]. // either [ActionStatusSuccess] or [ActionStatusError].
// //
// The handleUpdate callback is called every time an action is updated. // The handleUpdate callback is called every time an action is updated.
WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error
// WaitFor waits until all actions succeed by polling the API at the interval defined by // WaitFor waits until all actions succeed by polling the API at the interval defined by
// [WithPollBackoffFunc]. An action is considered as succeeded when its status is either // [WithPollOpts]. An action is considered as succeeded when its status is either
// [ActionStatusSuccess]. // [ActionStatusSuccess].
// //
// If a single action fails, the function will stop waiting and the error set in the // If a single action fails, the function will stop waiting and the error set in the

View File

@ -64,7 +64,7 @@ type ILoadBalancerClient interface {
// ChangeType changes a Load Balancer's type. // ChangeType changes a Load Balancer's type.
ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error) ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error)
// GetMetrics obtains metrics for a Load Balancer. // GetMetrics obtains metrics for a Load Balancer.
GetMetrics(ctx context.Context, lb *LoadBalancer, opts LoadBalancerGetMetricsOpts) (*LoadBalancerMetrics, *Response, error) GetMetrics(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerGetMetricsOpts) (*LoadBalancerMetrics, *Response, error)
// ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer. // ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer.
// Pass a nil ptr to reset the reverse DNS pointer to its default value. // Pass a nil ptr to reset the reverse DNS pointer to its default value.
ChangeDNSPtr(ctx context.Context, lb *LoadBalancer, ip string, ptr *string) (*Action, *Response, error) ChangeDNSPtr(ctx context.Context, lb *LoadBalancer, ip string, ptr *string) (*Action, *Response, error)

View File

@ -27,11 +27,11 @@ type IPrimaryIPClient interface {
// AllWithOpts returns all Primary IPs for the given options. // AllWithOpts returns all Primary IPs for the given options.
AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error) AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error)
// Create creates a Primary IP. // Create creates a Primary IP.
Create(ctx context.Context, reqBody PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error) Create(ctx context.Context, opts PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error)
// Delete deletes a Primary IP. // Delete deletes a Primary IP.
Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error) Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error)
// Update updates a Primary IP. // Update updates a Primary IP.
Update(ctx context.Context, primaryIP *PrimaryIP, reqBody PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error) Update(ctx context.Context, primaryIP *PrimaryIP, opts PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error)
// Assign a Primary IP to a resource. // Assign a Primary IP to a resource.
Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error) Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error)
// Unassign a Primary IP from a resource. // Unassign a Primary IP from a resource.

File diff suppressed because it is too large Load Diff

View File

@ -62,9 +62,8 @@ type IServerClient interface {
AttachISO(ctx context.Context, server *Server, iso *ISO) (*Action, *Response, error) AttachISO(ctx context.Context, server *Server, iso *ISO) (*Action, *Response, error)
// DetachISO detaches the currently attached ISO from a server. // DetachISO detaches the currently attached ISO from a server.
DetachISO(ctx context.Context, server *Server) (*Action, *Response, error) DetachISO(ctx context.Context, server *Server) (*Action, *Response, error)
// EnableBackup enables backup for a server. Pass in an empty backup window to let the // EnableBackup enables backup for a server.
// API pick a window for you. See the API documentation at docs.hetzner.cloud for a list // The window parameter is deprecated and will be ignored.
// of valid backup windows.
EnableBackup(ctx context.Context, server *Server, window string) (*Action, *Response, error) EnableBackup(ctx context.Context, server *Server, window string) (*Action, *Response, error)
// DisableBackup disables backup for a server. // DisableBackup disables backup for a server.
DisableBackup(ctx context.Context, server *Server) (*Action, *Response, error) DisableBackup(ctx context.Context, server *Server) (*Action, *Response, error)

View File

@ -95,7 +95,9 @@ func newManager() (*hetznerManager, error) {
hcloud.WithToken(token), hcloud.WithToken(token),
hcloud.WithHTTPClient(httpClient), hcloud.WithHTTPClient(httpClient),
hcloud.WithApplication("cluster-autoscaler", version.ClusterAutoscalerVersion), hcloud.WithApplication("cluster-autoscaler", version.ClusterAutoscalerVersion),
hcloud.WithPollBackoffFunc(hcloud.ExponentialBackoff(2, 500*time.Millisecond)), hcloud.WithPollOpts(hcloud.PollOpts{
BackoffFunc: hcloud.ExponentialBackoff(2, 500*time.Millisecond),
}),
hcloud.WithDebugWriter(&debugWriter{}), hcloud.WithDebugWriter(&debugWriter{}),
} }
@ -252,7 +254,7 @@ func (m *hetznerManager) deleteByNode(node *apiv1.Node) error {
} }
func (m *hetznerManager) deleteServer(server *hcloud.Server) error { func (m *hetznerManager) deleteServer(server *hcloud.Server) error {
_, err := m.client.Server.Delete(m.apiCallContext, server) _, _, err := m.client.Server.DeleteWithResult(m.apiCallContext, server)
return err return err
} }

View File

@ -34,7 +34,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/utils/units" "k8s.io/autoscaler/cluster-autoscaler/utils/units"
"k8s.io/client-go/rest" "k8s.io/client-go/rest"
klog "k8s.io/klog/v2" "k8s.io/klog/v2"
kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config" kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config"
scheduler_config "k8s.io/kubernetes/pkg/scheduler/apis/config" scheduler_config "k8s.io/kubernetes/pkg/scheduler/apis/config"
) )
@ -269,12 +269,12 @@ func createAutoscalingOptions() config.AutoscalingOptions {
klog.Fatalf("Failed to get scheduler config: %v", err) klog.Fatalf("Failed to get scheduler config: %v", err)
} }
if isFlagPassed("drain-priority-config") && isFlagPassed("max-graceful-termination-sec") { if pflag.CommandLine.Changed("drain-priority-config") && pflag.CommandLine.Changed("max-graceful-termination-sec") {
klog.Fatalf("Invalid configuration, could not use --drain-priority-config together with --max-graceful-termination-sec") klog.Fatalf("Invalid configuration, could not use --drain-priority-config together with --max-graceful-termination-sec")
} }
var drainPriorityConfigMap []kubelet_config.ShutdownGracePeriodByPodPriority var drainPriorityConfigMap []kubelet_config.ShutdownGracePeriodByPodPriority
if isFlagPassed("drain-priority-config") { if pflag.CommandLine.Changed("drain-priority-config") {
drainPriorityConfigMap = parseShutdownGracePeriodsAndPriorities(*drainPriorityConfig) drainPriorityConfigMap = parseShutdownGracePeriodsAndPriorities(*drainPriorityConfig)
if len(drainPriorityConfigMap) == 0 { if len(drainPriorityConfigMap) == 0 {
klog.Fatalf("Invalid configuration, parsing --drain-priority-config") klog.Fatalf("Invalid configuration, parsing --drain-priority-config")
@ -409,16 +409,6 @@ func createAutoscalingOptions() config.AutoscalingOptions {
} }
} }
func isFlagPassed(name string) bool {
found := false
flag.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}
func minMaxFlagString(min, max int64) string { func minMaxFlagString(min, max int64) string {
return fmt.Sprintf("%v:%v", min, max) return fmt.Sprintf("%v:%v", min, max)
} }

View File

@ -17,11 +17,15 @@ limitations under the License.
package flags package flags
import ( import (
"flag"
"testing" "testing"
"k8s.io/autoscaler/cluster-autoscaler/config" "k8s.io/autoscaler/cluster-autoscaler/config"
kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config" kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -146,3 +150,47 @@ func TestParseShutdownGracePeriodsAndPriorities(t *testing.T) {
}) })
} }
} }
func TestCreateAutoscalingOptions(t *testing.T) {
for _, tc := range []struct {
testName string
flags []string
wantOptionsAsserter func(t *testing.T, gotOptions config.AutoscalingOptions)
}{
{
testName: "DrainPriorityConfig defaults to an empty list when the flag isn't passed",
flags: []string{},
wantOptionsAsserter: func(t *testing.T, gotOptions config.AutoscalingOptions) {
if diff := cmp.Diff([]kubelet_config.ShutdownGracePeriodByPodPriority{}, gotOptions.DrainPriorityConfig, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("createAutoscalingOptions(): unexpected DrainPriorityConfig field (-want +got): %s", diff)
}
},
},
{
testName: "DrainPriorityConfig is parsed correctly when the flag passed",
flags: []string{"--drain-priority-config", "5000:60,3000:50,0:40"},
wantOptionsAsserter: func(t *testing.T, gotOptions config.AutoscalingOptions) {
wantConfig := []kubelet_config.ShutdownGracePeriodByPodPriority{
{Priority: 5000, ShutdownGracePeriodSeconds: 60},
{Priority: 3000, ShutdownGracePeriodSeconds: 50},
{Priority: 0, ShutdownGracePeriodSeconds: 40},
}
if diff := cmp.Diff(wantConfig, gotOptions.DrainPriorityConfig); diff != "" {
t.Errorf("createAutoscalingOptions(): unexpected DrainPriorityConfig field (-want +got): %s", diff)
}
},
},
} {
t.Run(tc.testName, func(t *testing.T) {
pflag.CommandLine = pflag.NewFlagSet("test", pflag.ExitOnError)
pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
err := pflag.CommandLine.Parse(tc.flags)
if err != nil {
t.Errorf("pflag.CommandLine.Parse() got unexpected error: %v", err)
}
gotOptions := createAutoscalingOptions()
tc.wantOptionsAsserter(t, gotOptions)
})
}
}

View File

@ -6,9 +6,9 @@ require (
cloud.google.com/go/compute/metadata v0.5.0 cloud.google.com/go/compute/metadata v0.5.0
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible github.com/Azure/azure-sdk-for-go v68.0.0+incompatible
github.com/Azure/azure-sdk-for-go-extensions v0.1.6 github.com/Azure/azure-sdk-for-go-extensions v0.1.6
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2
github.com/Azure/go-autorest/autorest v0.11.29 github.com/Azure/go-autorest/autorest v0.11.29
github.com/Azure/go-autorest/autorest/adal v0.9.24 github.com/Azure/go-autorest/autorest/adal v0.9.24
github.com/Azure/go-autorest/autorest/azure/auth v0.5.13 github.com/Azure/go-autorest/autorest/azure/auth v0.5.13
@ -32,6 +32,7 @@ require (
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/vburenin/ifacemaker v1.2.1 github.com/vburenin/ifacemaker v1.2.1
go.uber.org/mock v0.4.0 go.uber.org/mock v0.4.0
golang.org/x/crypto v0.35.0
golang.org/x/net v0.33.0 golang.org/x/net v0.33.0
golang.org/x/oauth2 v0.27.0 golang.org/x/oauth2 v0.27.0
golang.org/x/sys v0.30.0 golang.org/x/sys v0.30.0
@ -62,11 +63,12 @@ require (
require ( require (
cel.dev/expr v0.19.1 // indirect cel.dev/expr v0.19.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5 v5.6.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5 v5.6.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 // indirect
@ -186,7 +188,6 @@ require (
go.opentelemetry.io/proto/otlp v1.4.0 // indirect go.opentelemetry.io/proto/otlp v1.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.35.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/mod v0.21.0 // indirect golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.11.0 // indirect golang.org/x/sync v0.11.0 // indirect

View File

@ -4,17 +4,16 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/Azure/azure-sdk-for-go v46.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go-extensions v0.1.6 h1:EXGvDcj54u98XfaI/Cy65Ds6vNsIJeGKYf0eNLB1y4Q= github.com/Azure/azure-sdk-for-go-extensions v0.1.6 h1:EXGvDcj54u98XfaI/Cy65Ds6vNsIJeGKYf0eNLB1y4Q=
github.com/Azure/azure-sdk-for-go-extensions v0.1.6/go.mod h1:27StPiXJp6Xzkq2AQL7gPK7VC0hgmCnUKlco1dO1jaM= github.com/Azure/azure-sdk-for-go-extensions v0.1.6/go.mod h1:27StPiXJp6Xzkq2AQL7gPK7VC0hgmCnUKlco1dO1jaM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 h1:FDif4R1+UUR+00q6wquyX90K7A8dN+R5E8GEadoP7sU= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2/go.mod h1:aiYBYui4BJ/BJCAIKs92XiPyQfTaBWqvHujDwKb6CBU= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 h1:xnO4sFyG8UH2fElBkcqLTOZsAajvKfnSlgBBW8dXYjw= github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 h1:xnO4sFyG8UH2fElBkcqLTOZsAajvKfnSlgBBW8dXYjw=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0/go.mod h1:XD3DIOOVgBCO03OleB1fHjgktVRFxlT++KwKgIOewdM= github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0/go.mod h1:XD3DIOOVgBCO03OleB1fHjgktVRFxlT++KwKgIOewdM=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw= github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw=
@ -25,10 +24,14 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armconta
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0/go.mod h1:E7ltexgRDmeJ0fJWv0D/HLwY2xbDdN+uv+X2uZtOx3w= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0/go.mod h1:E7ltexgRDmeJ0fJWv0D/HLwY2xbDdN+uv+X2uZtOx3w=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0 h1:1u/K2BFv0MwkG6he8RYuUcbbeK22rkoZbg4lKa/msZU= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0 h1:1u/K2BFv0MwkG6he8RYuUcbbeK22rkoZbg4lKa/msZU=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0/go.mod h1:U5gpsREQZE6SLk1t/cFfc1eMhYAlYpEzvaYXuDfefy8= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0/go.mod h1:U5gpsREQZE6SLk1t/cFfc1eMhYAlYpEzvaYXuDfefy8=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1 h1:iqhrjj9w9/AQZsHjaOVyloamkeAFRbWI0iHNy6INMYk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 h1:0nGmzwBv5ougvzfGPCO2ljFRHvun57KpNrVCMrlk0ns=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2 h1:re+BEe/OafvSyRy2vM+Fyu+EcUK34O2o/Fa6WO3ITZM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2/go.mod h1:5zx285T5OLk+iQbfOuexhhO7J6dfzkqVkFgS/+s7XaA=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 h1:HlZMUZW8S4P9oob1nCHxCCKrytxyLc+24nUJGssoEto= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 h1:HlZMUZW8S4P9oob1nCHxCCKrytxyLc+24nUJGssoEto=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0/go.mod h1:StGsLbuJh06Bd8IBfnAlIFV3fLb+gkczONWf15hpX2E= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0/go.mod h1:StGsLbuJh06Bd8IBfnAlIFV3fLb+gkczONWf15hpX2E=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o=
@ -45,12 +48,9 @@ github.com/Azure/go-armbalancer v0.0.2 h1:NVnxsTWHI5/fEzL6k6TjxPUfcB/3Si3+HFOZXO
github.com/Azure/go-armbalancer v0.0.2/go.mod h1:yTg7MA/8YnfKQc9o97tzAJ7fbdVkod1xGsIvKmhYPRE= github.com/Azure/go-armbalancer v0.0.2/go.mod h1:yTg7MA/8YnfKQc9o97tzAJ7fbdVkod1xGsIvKmhYPRE=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24=
github.com/Azure/go-autorest/autorest v0.11.4/go.mod h1:JFgpikqFJ/MleTTxwepExTKnFUKKszPS8UavbQYUMuw=
github.com/Azure/go-autorest/autorest v0.11.28/go.mod h1:MrkzG3Y3AH668QyF9KRk5neJnGgmhQ6krbhR8Q5eMvA= github.com/Azure/go-autorest/autorest v0.11.28/go.mod h1:MrkzG3Y3AH668QyF9KRk5neJnGgmhQ6krbhR8Q5eMvA=
github.com/Azure/go-autorest/autorest v0.11.29 h1:I4+HL/JDvErx2LjyzaVxllw2lRDB5/BT2Bm4g20iqYw= github.com/Azure/go-autorest/autorest v0.11.29 h1:I4+HL/JDvErx2LjyzaVxllw2lRDB5/BT2Bm4g20iqYw=
github.com/Azure/go-autorest/autorest v0.11.29/go.mod h1:ZtEzC4Jy2JDrZLxvWs8LrBWEBycl1hbT1eknI8MtfAs= github.com/Azure/go-autorest/autorest v0.11.29/go.mod h1:ZtEzC4Jy2JDrZLxvWs8LrBWEBycl1hbT1eknI8MtfAs=
github.com/Azure/go-autorest/autorest/adal v0.9.0/go.mod h1:/c022QCutn2P7uY+/oQWWNcK9YU+MH96NgK+jErpbcg=
github.com/Azure/go-autorest/autorest/adal v0.9.2/go.mod h1:/3SMAM86bP6wC9Ev35peQDUeqFZBMH07vvUOmg4z/fE=
github.com/Azure/go-autorest/autorest/adal v0.9.18/go.mod h1:XVVeme+LZwABT8K5Lc3hA4nAe8LDBVle26gTrguhhPQ= github.com/Azure/go-autorest/autorest/adal v0.9.18/go.mod h1:XVVeme+LZwABT8K5Lc3hA4nAe8LDBVle26gTrguhhPQ=
github.com/Azure/go-autorest/autorest/adal v0.9.22/go.mod h1:XuAbAEUv2Tta//+voMI038TrJBqjKam0me7qR+L8Cmk= github.com/Azure/go-autorest/autorest/adal v0.9.22/go.mod h1:XuAbAEUv2Tta//+voMI038TrJBqjKam0me7qR+L8Cmk=
github.com/Azure/go-autorest/autorest/adal v0.9.24 h1:BHZfgGsGwdkHDyZdtQRQk1WeUdW0m2WPAwuHZwUi5i4= github.com/Azure/go-autorest/autorest/adal v0.9.24 h1:BHZfgGsGwdkHDyZdtQRQk1WeUdW0m2WPAwuHZwUi5i4=
@ -61,22 +61,17 @@ github.com/Azure/go-autorest/autorest/azure/cli v0.4.6 h1:w77/uPk80ZET2F+AfQExZy
github.com/Azure/go-autorest/autorest/azure/cli v0.4.6/go.mod h1:piCfgPho7BiIDdEQ1+g4VmKyD5y+p/XtSNqE6Hc4QD0= github.com/Azure/go-autorest/autorest/azure/cli v0.4.6/go.mod h1:piCfgPho7BiIDdEQ1+g4VmKyD5y+p/XtSNqE6Hc4QD0=
github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw=
github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74=
github.com/Azure/go-autorest/autorest/mocks v0.4.0/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9AGVwYHG9/fkDYhtAfw= github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9AGVwYHG9/fkDYhtAfw=
github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU= github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU=
github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk=
github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE=
github.com/Azure/go-autorest/autorest/validation v0.3.0/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E=
github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac= github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac=
github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E=
github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg=
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/Azure/skewer v0.0.14 h1:0mzUJhspECkajYyynYsOCp//E2PSnYXrgP45bcskqfQ=
github.com/Azure/skewer v0.0.14/go.mod h1:6WTecuPyfGtuvS8Mh4JYWuHhO4kcWycGfsUBB+XTFG4=
github.com/Azure/skewer v0.0.19 h1:+qA1z8isKmlNkhAwZErNS2wD2jaemSk9NszYKr8dddU= github.com/Azure/skewer v0.0.19 h1:+qA1z8isKmlNkhAwZErNS2wD2jaemSk9NszYKr8dddU=
github.com/Azure/skewer v0.0.19/go.mod h1:LVH7jmduRKmPj8YcIz7V4f53xJEntjweL4aoLyChkwk= github.com/Azure/skewer v0.0.19/go.mod h1:LVH7jmduRKmPj8YcIz7V4f53xJEntjweL4aoLyChkwk=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
@ -139,7 +134,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/digitalocean/godo v1.27.0 h1:78iE9oVvTnAEqhMip2UHFvL01b8LJcydbNUpr0cAmN4= github.com/digitalocean/godo v1.27.0 h1:78iE9oVvTnAEqhMip2UHFvL01b8LJcydbNUpr0cAmN4=
github.com/digitalocean/godo v1.27.0/go.mod h1:iJnN9rVu6K5LioLxLimlq0uRI+y/eAQjROUmeU/r0hY= github.com/digitalocean/godo v1.27.0/go.mod h1:iJnN9rVu6K5LioLxLimlq0uRI+y/eAQjROUmeU/r0hY=
github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U= github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U=
@ -232,7 +226,6 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@ -442,7 +435,6 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=

View File

@ -504,9 +504,7 @@ func UpdateDurationFromStart(label FunctionLabel, start time.Time) {
// UpdateDuration records the duration of the step identified by the label // UpdateDuration records the duration of the step identified by the label
func UpdateDuration(label FunctionLabel, duration time.Duration) { func UpdateDuration(label FunctionLabel, duration time.Duration) {
// TODO(maciekpytel): remove second condition if we manage to get if duration > LogLongDurationThreshold {
// asynchronous node drain
if duration > LogLongDurationThreshold && label != ScaleDown {
klog.V(4).Infof("Function %s took %v to complete", label, duration) klog.V(4).Infof("Function %s took %v to complete", label, duration)
} }
functionDuration.WithLabelValues(string(label)).Observe(duration.Seconds()) functionDuration.WithLabelValues(string(label)).Observe(duration.Seconds())

View File

@ -24,6 +24,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/simulator/clustersnapshot" "k8s.io/autoscaler/cluster-autoscaler/simulator/clustersnapshot"
drautils "k8s.io/autoscaler/cluster-autoscaler/simulator/dynamicresources/utils" drautils "k8s.io/autoscaler/cluster-autoscaler/simulator/dynamicresources/utils"
"k8s.io/autoscaler/cluster-autoscaler/simulator/framework" "k8s.io/autoscaler/cluster-autoscaler/simulator/framework"
"k8s.io/dynamic-resource-allocation/resourceclaim"
schedulerframework "k8s.io/kubernetes/pkg/scheduler/framework" schedulerframework "k8s.io/kubernetes/pkg/scheduler/framework"
) )
@ -245,7 +246,7 @@ func (s *PredicateSnapshot) modifyResourceClaimsForNewPod(podInfo *framework.Pod
// so we don't add them. The claims should already be allocated in the provided PodInfo. // so we don't add them. The claims should already be allocated in the provided PodInfo.
var podOwnedClaims []*resourceapi.ResourceClaim var podOwnedClaims []*resourceapi.ResourceClaim
for _, claim := range podInfo.NeededResourceClaims { for _, claim := range podInfo.NeededResourceClaims {
if ownerName, _ := drautils.ClaimOwningPod(claim); ownerName != "" { if err := resourceclaim.IsForPod(podInfo.Pod, claim); err == nil {
podOwnedClaims = append(podOwnedClaims, claim) podOwnedClaims = append(podOwnedClaims, claim)
} }
} }

View File

@ -178,7 +178,7 @@ func (s Snapshot) RemovePodOwnedClaims(pod *apiv1.Pod) {
// The claim isn't tracked in the snapshot for some reason. Nothing to remove/modify, so continue to the next claim. // The claim isn't tracked in the snapshot for some reason. Nothing to remove/modify, so continue to the next claim.
continue continue
} }
if ownerName, ownerUid := drautils.ClaimOwningPod(claim); ownerName == pod.Name && ownerUid == pod.UID { if err := resourceclaim.IsForPod(pod, claim); err == nil {
delete(s.resourceClaimsById, claimId) delete(s.resourceClaimsById, claimId)
} else { } else {
drautils.ClearPodReservationInPlace(claim, pod) drautils.ClearPodReservationInPlace(claim, pod)
@ -214,9 +214,7 @@ func (s Snapshot) UnreservePodClaims(pod *apiv1.Pod) error {
return err return err
} }
for _, claim := range claims { for _, claim := range claims {
ownerPodName, ownerPodUid := drautils.ClaimOwningPod(claim) podOwnedClaim := resourceclaim.IsForPod(pod, claim) == nil
podOwnedClaim := ownerPodName == pod.Name && ownerPodUid == ownerPodUid
drautils.ClearPodReservationInPlace(claim, pod) drautils.ClearPodReservationInPlace(claim, pod)
if podOwnedClaim || !drautils.ClaimInUse(claim) { if podOwnedClaim || !drautils.ClaimInUse(claim) {
drautils.DeallocateClaimInPlace(claim) drautils.DeallocateClaimInPlace(claim)

View File

@ -22,23 +22,10 @@ import (
apiv1 "k8s.io/api/core/v1" apiv1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1beta1" resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/component-helpers/scheduling/corev1" "k8s.io/component-helpers/scheduling/corev1"
resourceclaim "k8s.io/dynamic-resource-allocation/resourceclaim" resourceclaim "k8s.io/dynamic-resource-allocation/resourceclaim"
"k8s.io/utils/ptr"
) )
// ClaimOwningPod returns the name and UID of the Pod owner of the provided claim. If the claim isn't
// owned by a Pod, empty strings are returned.
func ClaimOwningPod(claim *resourceapi.ResourceClaim) (string, types.UID) {
for _, owner := range claim.OwnerReferences {
if ptr.Deref(owner.Controller, false) && owner.APIVersion == "v1" && owner.Kind == "Pod" {
return owner.Name, owner.UID
}
}
return "", ""
}
// ClaimAllocated returns whether the provided claim is allocated. // ClaimAllocated returns whether the provided claim is allocated.
func ClaimAllocated(claim *resourceapi.ResourceClaim) bool { func ClaimAllocated(claim *resourceapi.ResourceClaim) bool {
return claim.Status.Allocation != nil return claim.Status.Allocation != nil

View File

@ -25,84 +25,8 @@ import (
apiv1 "k8s.io/api/core/v1" apiv1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1beta1" resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
) )
func TestClaimOwningPod(t *testing.T) {
truePtr := true
for _, tc := range []struct {
testName string
claim *resourceapi.ResourceClaim
wantName string
wantUid types.UID
}{
{
testName: "claim with no owners",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with non-Pod owners",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController", Controller: &truePtr},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
},
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with a Pod non-controller owner",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController"},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
{Name: "owner3", UID: "owner3uid", APIVersion: "v1", Kind: "Pod"},
},
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with a Pod controller owner",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController"},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
{Name: "owner3", UID: "owner3uid", APIVersion: "v1", Kind: "Pod", Controller: &truePtr},
},
},
},
wantName: "owner3",
wantUid: "owner3uid",
},
} {
t.Run(tc.testName, func(t *testing.T) {
name, uid := ClaimOwningPod(tc.claim)
if tc.wantName != name {
t.Errorf("ClaimOwningPod(): unexpected output name: want %s, got %s", tc.wantName, name)
}
if tc.wantUid != uid {
t.Errorf("ClaimOwningPod(): unexpected output UID: want %v, got %v", tc.wantUid, uid)
}
})
}
}
func TestClaimAllocated(t *testing.T) { func TestClaimAllocated(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
testName string testName string

View File

@ -23,6 +23,7 @@ import (
resourceapi "k8s.io/api/resource/v1beta1" resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/uuid" "k8s.io/apimachinery/pkg/util/uuid"
"k8s.io/dynamic-resource-allocation/resourceclaim"
"k8s.io/utils/set" "k8s.io/utils/set"
) )
@ -61,7 +62,7 @@ func SanitizedNodeResourceSlices(nodeLocalSlices []*resourceapi.ResourceSlice, n
func SanitizedPodResourceClaims(newOwner, oldOwner *v1.Pod, claims []*resourceapi.ResourceClaim, nameSuffix, newNodeName, oldNodeName string, oldNodePoolNames set.Set[string]) ([]*resourceapi.ResourceClaim, error) { func SanitizedPodResourceClaims(newOwner, oldOwner *v1.Pod, claims []*resourceapi.ResourceClaim, nameSuffix, newNodeName, oldNodeName string, oldNodePoolNames set.Set[string]) ([]*resourceapi.ResourceClaim, error) {
var result []*resourceapi.ResourceClaim var result []*resourceapi.ResourceClaim
for _, claim := range claims { for _, claim := range claims {
if ownerName, ownerUid := ClaimOwningPod(claim); ownerName != oldOwner.Name || ownerUid != oldOwner.UID { if err := resourceclaim.IsForPod(oldOwner, claim); err != nil {
// Only claims owned by the pod are bound to its lifecycle. The lifecycle of other claims is independent, and they're most likely shared // Only claims owned by the pod are bound to its lifecycle. The lifecycle of other claims is independent, and they're most likely shared
// by multiple pods. They shouldn't be sanitized or duplicated - just add unchanged to the result. // by multiple pods. They shouldn't be sanitized or duplicated - just add unchanged to the result.
result = append(result, claim) result = append(result, claim)

View File

@ -49,9 +49,12 @@ func SanitizedTemplateNodeInfoFromNodeGroup(nodeGroup nodeGroupTemplateNodeInfoG
if err != nil { if err != nil {
return nil, errors.ToAutoscalerError(errors.CloudProviderError, err).AddPrefix("failed to obtain template NodeInfo from node group %q: ", nodeGroup.Id()) return nil, errors.ToAutoscalerError(errors.CloudProviderError, err).AddPrefix("failed to obtain template NodeInfo from node group %q: ", nodeGroup.Id())
} }
labels.UpdateDeprecatedLabels(baseNodeInfo.Node().ObjectMeta.Labels) sanitizedNodeInfo, aErr := SanitizedTemplateNodeInfoFromNodeInfo(baseNodeInfo, nodeGroup.Id(), daemonsets, true, taintConfig)
if aErr != nil {
return SanitizedTemplateNodeInfoFromNodeInfo(baseNodeInfo, nodeGroup.Id(), daemonsets, true, taintConfig) return nil, aErr
}
labels.UpdateDeprecatedLabels(sanitizedNodeInfo.Node().Labels)
return sanitizedNodeInfo, nil
} }
// SanitizedTemplateNodeInfoFromNodeInfo returns a template NodeInfo object based on a real example NodeInfo from the cluster. The template is sanitized, and only // SanitizedTemplateNodeInfoFromNodeInfo returns a template NodeInfo object based on a real example NodeInfo from the cluster. The template is sanitized, and only

View File

@ -40,6 +40,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/utils/labels" "k8s.io/autoscaler/cluster-autoscaler/utils/labels"
"k8s.io/autoscaler/cluster-autoscaler/utils/taints" "k8s.io/autoscaler/cluster-autoscaler/utils/taints"
. "k8s.io/autoscaler/cluster-autoscaler/utils/test" . "k8s.io/autoscaler/cluster-autoscaler/utils/test"
"k8s.io/dynamic-resource-allocation/resourceclaim"
) )
var ( var (
@ -93,6 +94,11 @@ func TestSanitizedTemplateNodeInfoFromNodeGroup(t *testing.T) {
exampleNode.Spec.Taints = []apiv1.Taint{ exampleNode.Spec.Taints = []apiv1.Taint{
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule}, {Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
} }
exampleNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
for _, tc := range []struct { for _, tc := range []struct {
testName string testName string
@ -155,7 +161,7 @@ func TestSanitizedTemplateNodeInfoFromNodeGroup(t *testing.T) {
// Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because // Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because
// TemplateNodeInfoFromNodeGroupTemplate randomizes the suffix. // TemplateNodeInfoFromNodeGroupTemplate randomizes the suffix.
// Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed). // Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed).
if err := verifyNodeInfoSanitization(tc.nodeGroup.templateNodeInfoResult, templateNodeInfo, tc.wantPods, "template-node-for-"+tc.nodeGroup.id, "", nil); err != nil { if err := verifyNodeInfoSanitization(tc.nodeGroup.templateNodeInfoResult, templateNodeInfo, tc.wantPods, "template-node-for-"+tc.nodeGroup.id, "", true, nil); err != nil {
t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err) t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
} }
}) })
@ -167,6 +173,11 @@ func TestSanitizedTemplateNodeInfoFromNodeInfo(t *testing.T) {
exampleNode.Spec.Taints = []apiv1.Taint{ exampleNode.Spec.Taints = []apiv1.Taint{
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule}, {Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
} }
exampleNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
testCases := []struct { testCases := []struct {
name string name string
@ -317,7 +328,7 @@ func TestSanitizedTemplateNodeInfoFromNodeInfo(t *testing.T) {
// Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because // Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because
// TemplateNodeInfoFromExampleNodeInfo randomizes the suffix. // TemplateNodeInfoFromExampleNodeInfo randomizes the suffix.
// Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed). // Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed).
if err := verifyNodeInfoSanitization(exampleNodeInfo, templateNodeInfo, tc.wantPods, "template-node-for-"+nodeGroupId, "", nil); err != nil { if err := verifyNodeInfoSanitization(exampleNodeInfo, templateNodeInfo, tc.wantPods, "template-node-for-"+nodeGroupId, "", false, nil); err != nil {
t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err) t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
} }
}) })
@ -332,6 +343,12 @@ func TestSanitizedNodeInfo(t *testing.T) {
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule}, {Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
{Key: "a", Value: "b", Effect: apiv1.TaintEffectNoSchedule}, {Key: "a", Value: "b", Effect: apiv1.TaintEffectNoSchedule},
} }
templateNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
pods := []*framework.PodInfo{ pods := []*framework.PodInfo{
{Pod: BuildTestPod("p1", 80, 0, WithNodeName(nodeName))}, {Pod: BuildTestPod("p1", 80, 0, WithNodeName(nodeName))},
{Pod: BuildTestPod("p2", 80, 0, WithNodeName(nodeName))}, {Pod: BuildTestPod("p2", 80, 0, WithNodeName(nodeName))},
@ -346,7 +363,7 @@ func TestSanitizedNodeInfo(t *testing.T) {
// Verify that the taints are not sanitized (they should be sanitized in the template already). // Verify that the taints are not sanitized (they should be sanitized in the template already).
// Verify that the NodeInfo is sanitized using the template Node name as base. // Verify that the NodeInfo is sanitized using the template Node name as base.
initialTaints := templateNodeInfo.Node().Spec.Taints initialTaints := templateNodeInfo.Node().Spec.Taints
if err := verifyNodeInfoSanitization(templateNodeInfo, freshNodeInfo, nil, templateNodeInfo.Node().Name, suffix, initialTaints); err != nil { if err := verifyNodeInfoSanitization(templateNodeInfo, freshNodeInfo, nil, templateNodeInfo.Node().Name, suffix, false, initialTaints); err != nil {
t.Fatalf("FreshNodeInfoFromTemplateNodeInfo(): NodeInfo wasn't properly sanitized: %v", err) t.Fatalf("FreshNodeInfoFromTemplateNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
} }
} }
@ -360,6 +377,8 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
apiv1.LabelHostname: oldNodeName, apiv1.LabelHostname: oldNodeName,
"a": "b", "a": "b",
"x": "y", "x": "y",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
} }
taintsNode := basicNode.DeepCopy() taintsNode := basicNode.DeepCopy()
@ -491,7 +510,7 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("sanitizeNodeInfo(): want nil error, got %v", err) t.Fatalf("sanitizeNodeInfo(): want nil error, got %v", err)
} }
if err := verifyNodeInfoSanitization(tc.nodeInfo, nodeInfo, nil, newNameBase, suffix, tc.wantTaints); err != nil { if err := verifyNodeInfoSanitization(tc.nodeInfo, nodeInfo, nil, newNameBase, suffix, false, tc.wantTaints); err != nil {
t.Fatalf("sanitizeNodeInfo(): NodeInfo wasn't properly sanitized: %v", err) t.Fatalf("sanitizeNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
} }
}) })
@ -506,7 +525,7 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
// //
// If expectedPods is nil, the set of pods is expected not to change between initialNodeInfo and sanitizedNodeInfo. If the sanitization is // If expectedPods is nil, the set of pods is expected not to change between initialNodeInfo and sanitizedNodeInfo. If the sanitization is
// expected to change the set of pods, the expected set should be passed to expectedPods. // expected to change the set of pods, the expected set should be passed to expectedPods.
func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.NodeInfo, expectedPods []*apiv1.Pod, nameBase, nameSuffix string, wantTaints []apiv1.Taint) error { func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.NodeInfo, expectedPods []*apiv1.Pod, nameBase, nameSuffix string, wantDeprecatedLabels bool, wantTaints []apiv1.Taint) error {
if nameSuffix == "" { if nameSuffix == "" {
// Determine the suffix from the provided sanitized NodeInfo - it should be the last part of a dash-separated name. // Determine the suffix from the provided sanitized NodeInfo - it should be the last part of a dash-separated name.
nameParts := strings.Split(sanitizedNodeInfo.Node().Name, "-") nameParts := strings.Split(sanitizedNodeInfo.Node().Name, "-")
@ -526,7 +545,7 @@ func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.No
// Verification below assumes the same set of pods between initialNodeInfo and sanitizedNodeInfo. // Verification below assumes the same set of pods between initialNodeInfo and sanitizedNodeInfo.
wantNodeName := fmt.Sprintf("%s-%s", nameBase, nameSuffix) wantNodeName := fmt.Sprintf("%s-%s", nameBase, nameSuffix)
if err := verifySanitizedNode(initialNodeInfo.Node(), sanitizedNodeInfo.Node(), wantNodeName, wantTaints); err != nil { if err := verifySanitizedNode(initialNodeInfo.Node(), sanitizedNodeInfo.Node(), wantNodeName, wantDeprecatedLabels, wantTaints); err != nil {
return err return err
} }
if err := verifySanitizedNodeResourceSlices(initialNodeInfo.LocalResourceSlices, sanitizedNodeInfo.LocalResourceSlices, nameSuffix); err != nil { if err := verifySanitizedNodeResourceSlices(initialNodeInfo.LocalResourceSlices, sanitizedNodeInfo.LocalResourceSlices, nameSuffix); err != nil {
@ -539,7 +558,7 @@ func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.No
return nil return nil
} }
func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName string, wantTaints []apiv1.Taint) error { func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName string, wantDeprecatedLabels bool, wantTaints []apiv1.Taint) error {
if gotName := sanitizedNode.Name; gotName != wantNodeName { if gotName := sanitizedNode.Name; gotName != wantNodeName {
return fmt.Errorf("want sanitized Node name %q, got %q", wantNodeName, gotName) return fmt.Errorf("want sanitized Node name %q, got %q", wantNodeName, gotName)
} }
@ -552,6 +571,9 @@ func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName st
wantLabels[k] = v wantLabels[k] = v
} }
wantLabels[apiv1.LabelHostname] = wantNodeName wantLabels[apiv1.LabelHostname] = wantNodeName
if wantDeprecatedLabels {
labels.UpdateDeprecatedLabels(wantLabels)
}
if diff := cmp.Diff(wantLabels, sanitizedNode.Labels); diff != "" { if diff := cmp.Diff(wantLabels, sanitizedNode.Labels); diff != "" {
return fmt.Errorf("sanitized Node labels unexpected, diff (-want +got): %s", diff) return fmt.Errorf("sanitized Node labels unexpected, diff (-want +got): %s", diff)
} }
@ -601,7 +623,7 @@ func verifySanitizedPods(initialPods, sanitizedPods []*framework.PodInfo, wantNo
return fmt.Errorf("sanitized Pod unexpected diff (-want +got): %s", diff) return fmt.Errorf("sanitized Pod unexpected diff (-want +got): %s", diff)
} }
if err := verifySanitizedPodResourceClaims(initialPod.NeededResourceClaims, sanitizedPod.NeededResourceClaims, nameSuffix); err != nil { if err := verifySanitizedPodResourceClaims(initialPod, sanitizedPod, nameSuffix); err != nil {
return err return err
} }
} }
@ -633,7 +655,11 @@ func verifySanitizedNodeResourceSlices(initialSlices, sanitizedSlices []*resourc
return nil return nil
} }
func verifySanitizedPodResourceClaims(initialClaims, sanitizedClaims []*resourceapi.ResourceClaim, nameSuffix string) error { func verifySanitizedPodResourceClaims(initialPod, sanitizedPod *framework.PodInfo, nameSuffix string) error {
initialClaims := initialPod.NeededResourceClaims
sanitizedClaims := sanitizedPod.NeededResourceClaims
owningPod := initialPod.Pod
if len(initialClaims) != len(sanitizedClaims) { if len(initialClaims) != len(sanitizedClaims) {
return fmt.Errorf("want %d NeededResourceClaims in sanitized NodeInfo, got %d", len(initialClaims), len(sanitizedClaims)) return fmt.Errorf("want %d NeededResourceClaims in sanitized NodeInfo, got %d", len(initialClaims), len(sanitizedClaims))
} }
@ -642,7 +668,9 @@ func verifySanitizedPodResourceClaims(initialClaims, sanitizedClaims []*resource
initialClaim := initialClaims[i] initialClaim := initialClaims[i]
// Pod-owned claims should be sanitized, other claims shouldn't. // Pod-owned claims should be sanitized, other claims shouldn't.
if owningPod, _ := drautils.ClaimOwningPod(initialClaim); owningPod != "" { err := resourceclaim.IsForPod(owningPod, initialClaim)
isPodOwned := err == nil
if isPodOwned {
// Pod-owned claim, verify that it was sanitized. // Pod-owned claim, verify that it was sanitized.
if sanitizedClaim.Name == initialClaim.Name || !strings.HasSuffix(sanitizedClaim.Name, nameSuffix) { if sanitizedClaim.Name == initialClaim.Name || !strings.HasSuffix(sanitizedClaim.Name, nameSuffix) {
return fmt.Errorf("sanitized ResourceClaim name unexpected: want (different than %q, ending in %q), got %q", initialClaim.Name, nameSuffix, sanitizedClaim.Name) return fmt.Errorf("sanitized ResourceClaim name unexpected: want (different than %q, ending in %q), got %q", initialClaim.Name, nameSuffix, sanitizedClaim.Name)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
FROM --platform=$BUILDPLATFORM golang:1.24.3 AS builder FROM --platform=$BUILDPLATFORM golang:1.24.4 AS builder
WORKDIR /workspace WORKDIR /workspace

View File

@ -276,119 +276,120 @@ func TestChangedCAReloader(t *testing.T) {
assert.NotEqual(t, oldCAEncodedString, newCAEncodedString, "expected CA to change") assert.NotEqual(t, oldCAEncodedString, newCAEncodedString, "expected CA to change")
} }
func TestUnchangedCAReloader(t *testing.T) { // TODO(omerap12): Temporary workaround for flakiness (#7831)
tempDir := t.TempDir() // func TestUnchangedCAReloader(t *testing.T) {
caCert := &x509.Certificate{ // tempDir := t.TempDir()
SerialNumber: big.NewInt(0), // caCert := &x509.Certificate{
Subject: pkix.Name{ // SerialNumber: big.NewInt(0),
Organization: []string{"ca"}, // Subject: pkix.Name{
}, // Organization: []string{"ca"},
NotBefore: time.Now(), // },
NotAfter: time.Now().AddDate(2, 0, 0), // NotBefore: time.Now(),
IsCA: true, // NotAfter: time.Now().AddDate(2, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, // IsCA: true,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, // ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true, // KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
} // BasicConstraintsValid: true,
caKey, err := rsa.GenerateKey(rand.Reader, 4096) // }
if err != nil { // caKey, err := rsa.GenerateKey(rand.Reader, 4096)
t.Error(err) // if err != nil {
} // t.Error(err)
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey) // }
if err != nil { // caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey)
t.Error(err) // if err != nil {
} // t.Error(err)
caPath := path.Join(tempDir, "ca.crt") // }
caFile, err := os.Create(caPath) // caPath := path.Join(tempDir, "ca.crt")
if err != nil { // caFile, err := os.Create(caPath)
t.Error(err) // if err != nil {
} // t.Error(err)
err = pem.Encode(caFile, &pem.Block{ // }
Type: "CERTIFICATE", // err = pem.Encode(caFile, &pem.Block{
Bytes: caBytes, // Type: "CERTIFICATE",
}) // Bytes: caBytes,
if err != nil { // })
t.Error(err) // if err != nil {
} // t.Error(err)
// }
testClientSet := fake.NewSimpleClientset() // testClientSet := fake.NewSimpleClientset()
selfRegistration( // selfRegistration(
testClientSet, // testClientSet,
readFile(caPath), // readFile(caPath),
0*time.Second, // 0*time.Second,
"default", // "default",
"vpa-service", // "vpa-service",
"http://example.com/", // "http://example.com/",
true, // true,
int32(32), // int32(32),
"", // "",
[]string{}, // []string{},
false, // false,
"key1:value1,key2:value2", // "key1:value1,key2:value2",
) // )
webhookConfigInterface := testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations() // webhookConfigInterface := testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations()
oldWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{}) // oldWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
if err != nil { // if err != nil {
t.Error(err) // t.Error(err)
} // }
assert.Len(t, oldWebhookConfig.Webhooks, 1, "expected one webhook configuration") // assert.Len(t, oldWebhookConfig.Webhooks, 1, "expected one webhook configuration")
webhook := oldWebhookConfig.Webhooks[0] // webhook := oldWebhookConfig.Webhooks[0]
oldWebhookCABundle := webhook.ClientConfig.CABundle // oldWebhookCABundle := webhook.ClientConfig.CABundle
var reloadWebhookCACalled, patchCalled atomic.Bool // var reloadWebhookCACalled, patchCalled atomic.Bool
reloadWebhookCACalled.Store(false) // reloadWebhookCACalled.Store(false)
patchCalled.Store(false) // patchCalled.Store(false)
testClientSet.PrependReactor("get", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) { // testClientSet.PrependReactor("get", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
reloadWebhookCACalled.Store(true) // reloadWebhookCACalled.Store(true)
return false, nil, nil // return false, nil, nil
}) // })
testClientSet.PrependReactor("patch", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) { // testClientSet.PrependReactor("patch", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
patchCalled.Store(true) // patchCalled.Store(true)
return false, nil, nil // return false, nil, nil
}) // })
reloader := certReloader{ // reloader := certReloader{
clientCaPath: caPath, // clientCaPath: caPath,
mutatingWebhookClient: testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations(), // mutatingWebhookClient: testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations(),
} // }
stop := make(chan struct{}) // stop := make(chan struct{})
defer close(stop) // defer close(stop)
if err := reloader.start(stop); err != nil { // if err := reloader.start(stop); err != nil {
t.Error(err) // t.Error(err)
} // }
originalCaFile, err := os.ReadFile(caPath) // originalCaFile, err := os.ReadFile(caPath)
if err != nil { // if err != nil {
t.Error(err) // t.Error(err)
} // }
err = os.WriteFile(caPath, originalCaFile, 0666) // err = os.WriteFile(caPath, originalCaFile, 0666)
if err != nil { // if err != nil {
t.Error(err) // t.Error(err)
} // }
oldCAEncodedString := base64.StdEncoding.EncodeToString(oldWebhookCABundle) // oldCAEncodedString := base64.StdEncoding.EncodeToString(oldWebhookCABundle)
for tries := 0; tries < 10; tries++ { // for tries := 0; tries < 10; tries++ {
if reloadWebhookCACalled.Load() { // if reloadWebhookCACalled.Load() {
break // break
} // }
time.Sleep(1 * time.Second) // time.Sleep(1 * time.Second)
} // }
if !reloadWebhookCACalled.Load() { // if !reloadWebhookCACalled.Load() {
t.Error("expected reloadWebhookCA to be called") // t.Error("expected reloadWebhookCA to be called")
} // }
assert.False(t, patchCalled.Load(), "expected patch to not be called") // assert.False(t, patchCalled.Load(), "expected patch to not be called")
newWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{}) // newWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
assert.Nil(t, err, "expected no error") // assert.Nil(t, err, "expected no error")
assert.NotNil(t, newWebhookConfig, "expected webhook configuration") // assert.NotNil(t, newWebhookConfig, "expected webhook configuration")
assert.Len(t, newWebhookConfig.Webhooks, 1, "expected one webhook configuration") // assert.Len(t, newWebhookConfig.Webhooks, 1, "expected one webhook configuration")
newWebhookCABundle := newWebhookConfig.Webhooks[0].ClientConfig.CABundle // newWebhookCABundle := newWebhookConfig.Webhooks[0].ClientConfig.CABundle
newCAEncodedString := base64.StdEncoding.EncodeToString(newWebhookCABundle) // newCAEncodedString := base64.StdEncoding.EncodeToString(newWebhookCABundle)
assert.Equal(t, oldCAEncodedString, newCAEncodedString, "expected CA to not change") // assert.Equal(t, oldCAEncodedString, newCAEncodedString, "expected CA to not change")
} // }

Some files were not shown because too many files have changed in this diff Show More