Merge branch 'master' into ignore-hyperpod

This commit is contained in:
Yiqing Wang 2025-06-10 13:50:22 -07:00
commit 20810f8d93
103 changed files with 4761 additions and 3488 deletions

View File

@ -19,6 +19,13 @@ package signers
import (
"encoding/json"
"fmt"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
"github.com/jmespath/go-jmespath"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/auth/credentials"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/errors"
@ -26,16 +33,12 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/responses"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/utils"
"k8s.io/klog/v2"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
)
const (
defaultOIDCDurationSeconds = 3600
oidcTokenFilePath = "ALIBABA_CLOUD_OIDC_TOKEN_FILE"
oldOidcTokenFilePath = "ALICLOUD_OIDC_TOKEN_FILE_PATH"
)
// OIDCSigner is kind of signer
@ -149,7 +152,7 @@ func (signer *OIDCSigner) getOIDCToken(OIDCTokenFilePath string) string {
tokenPath := OIDCTokenFilePath
_, err := os.Stat(tokenPath)
if os.IsNotExist(err) {
tokenPath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
tokenPath = utils.FirstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath))
if tokenPath == "" {
klog.Error("oidc token file path is missing")
return ""

View File

@ -22,11 +22,12 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/google/uuid"
"net/url"
"reflect"
"strconv"
"time"
"github.com/google/uuid"
)
/* if you use go 1.10 or higher, you can hack this util by these to avoid "TimeZone.zip not found" on Windows */
@ -127,3 +128,15 @@ func InitStructWithDefaultTag(bean interface{}) {
}
}
}
// FirstNotEmpty returns the first non-empty string from the input list.
// If all strings are empty or no arguments are provided, it returns an empty string.
func FirstNotEmpty(strs ...string) string {
for _, str := range strs {
if str != "" {
return str
}
}
return ""
}

View File

@ -0,0 +1,45 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFirstNotEmpty(t *testing.T) {
// Test case where the first non-empty string is at the beginning
result := FirstNotEmpty("hello", "world", "test")
assert.Equal(t, "hello", result)
// Test case where the first non-empty string is in the middle
result = FirstNotEmpty("", "foo", "bar")
assert.Equal(t, "foo", result)
// Test case where the first non-empty string is at the end
result = FirstNotEmpty("", "", "baz")
assert.Equal(t, "baz", result)
// Test case where all strings are empty
result = FirstNotEmpty("", "", "")
assert.Equal(t, "", result)
// Test case with no arguments
result = FirstNotEmpty()
assert.Equal(t, "", result)
}

View File

@ -19,6 +19,7 @@ package alicloud
import (
"os"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/alibaba-cloud-sdk-go/sdk/utils"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/alicloud/metadata"
"k8s.io/klog/v2"
)
@ -63,19 +64,19 @@ func (cc *cloudConfig) isValid() bool {
}
if cc.OIDCProviderARN == "" {
cc.OIDCProviderARN = firstNotEmpty(os.Getenv(oidcProviderARN), os.Getenv(oldOidcProviderARN))
cc.OIDCProviderARN = utils.FirstNotEmpty(os.Getenv(oidcProviderARN), os.Getenv(oldOidcProviderARN))
}
if cc.OIDCTokenFilePath == "" {
cc.OIDCTokenFilePath = firstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath))
cc.OIDCTokenFilePath = utils.FirstNotEmpty(os.Getenv(oidcTokenFilePath), os.Getenv(oldOidcTokenFilePath))
}
if cc.RoleARN == "" {
cc.RoleARN = firstNotEmpty(os.Getenv(roleARN), os.Getenv(oldRoleARN))
cc.RoleARN = utils.FirstNotEmpty(os.Getenv(roleARN), os.Getenv(oldRoleARN))
}
if cc.RoleSessionName == "" {
cc.RoleSessionName = firstNotEmpty(os.Getenv(roleSessionName), os.Getenv(oldRoleSessionName))
cc.RoleSessionName = utils.FirstNotEmpty(os.Getenv(roleSessionName), os.Getenv(oldRoleSessionName))
}
if cc.RegionId != "" && cc.AccessKeyID != "" && cc.AccessKeySecret != "" {
@ -133,15 +134,3 @@ func (cc *cloudConfig) getRegion() string {
}
return r
}
// firstNotEmpty returns the first non-empty string from the input list.
// If all strings are empty or no arguments are provided, it returns an empty string.
func firstNotEmpty(strs ...string) string {
for _, str := range strs {
if str != "" {
return str
}
}
return ""
}

View File

@ -55,25 +55,3 @@ func TestOldRRSACloudConfigIsValid(t *testing.T) {
assert.True(t, cfg.isValid())
assert.True(t, cfg.RRSAEnabled)
}
func TestFirstNotEmpty(t *testing.T) {
// Test case where the first non-empty string is at the beginning
result := firstNotEmpty("hello", "world", "test")
assert.Equal(t, "hello", result)
// Test case where the first non-empty string is in the middle
result = firstNotEmpty("", "foo", "bar")
assert.Equal(t, "foo", result)
// Test case where the first non-empty string is at the end
result = firstNotEmpty("", "", "baz")
assert.Equal(t, "baz", result)
// Test case where all strings are empty
result = firstNotEmpty("", "", "")
assert.Equal(t, "", result)
// Test case with no arguments
result = firstNotEmpty()
assert.Equal(t, "", result)
}

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"]
resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault
serviceAccountName: cluster-autoscaler
containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler
resources:
limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"]
resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault
serviceAccountName: cluster-autoscaler
containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler
resources:
limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"]
resources: ["jobs"]
@ -146,7 +146,7 @@ spec:
type: RuntimeDefault
serviceAccountName: cluster-autoscaler
containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler
resources:
limits:

View File

@ -51,7 +51,7 @@ rules:
resources: ["statefulsets", "replicasets", "daemonsets"]
verbs: ["watch", "list", "get"]
- apiGroups: ["storage.k8s.io"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities"]
resources: ["storageclasses", "csinodes", "csidrivers", "csistoragecapacities", "volumeattachments"]
verbs: ["watch", "list", "get"]
- apiGroups: ["batch", "extensions"]
resources: ["jobs"]
@ -153,7 +153,7 @@ spec:
nodeSelector:
kubernetes.io/role: control-plane
containers:
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.26.2
- image: registry.k8s.io/autoscaling/cluster-autoscaler:v1.32.1
name: cluster-autoscaler
resources:
limits:

View File

@ -20,7 +20,6 @@ import (
"context"
"fmt"
"net/http"
"strings"
"testing"
"time"
@ -422,8 +421,7 @@ func TestDeleteInstances(t *testing.T) {
},
}, nil)
err = as.DeleteInstances(instances)
expectedErrStr := "The specified account is disabled."
assert.True(t, strings.Contains(err.Error(), expectedErrStr))
assert.Error(t, err)
}
func TestAgentPoolDeleteNodes(t *testing.T) {
@ -478,8 +476,7 @@ func TestAgentPoolDeleteNodes(t *testing.T) {
ObjectMeta: v1.ObjectMeta{Name: "node"},
},
})
expectedErrStr := "The specified account is disabled."
assert.True(t, strings.Contains(err.Error(), expectedErrStr))
assert.Error(t, err)
as.minSize = 3
err = as.DeleteNodes([]*apiv1.Node{})

View File

@ -25,6 +25,7 @@ import (
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
"github.com/Azure/skewer"
@ -67,13 +68,18 @@ type azureCache struct {
// Cache content.
// resourceGroup specifies the name of the resource group that this cache tracks
resourceGroup string
// resourceGroup specifies the name of the node resource group that this cache tracks
resourceGroup string
clusterResourceGroup string
clusterName string
// enableVMsAgentPool specifies whether VMs agent pool type is supported.
enableVMsAgentPool bool
// vmType can be one of vmTypeVMSS (default), vmTypeStandard
vmType string
vmsPoolSet map[string]struct{} // track the nodepools that're vms pool
vmsPoolMap map[string]armcontainerservice.AgentPool // track the nodepools that're vms pool
// scaleSets keeps the set of all known scalesets in the resource group, populated/refreshed via VMSS.List() call.
// It is only used/populated if vmType is vmTypeVMSS (default).
@ -106,8 +112,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az
azClient: client,
refreshInterval: cacheTTL,
resourceGroup: config.ResourceGroup,
clusterResourceGroup: config.ClusterResourceGroup,
clusterName: config.ClusterName,
enableVMsAgentPool: config.EnableVMsAgentPool,
vmType: config.VMType,
vmsPoolSet: make(map[string]struct{}),
vmsPoolMap: make(map[string]armcontainerservice.AgentPool),
scaleSets: make(map[string]compute.VirtualMachineScaleSet),
virtualMachines: make(map[string][]compute.VirtualMachine),
registeredNodeGroups: make([]cloudprovider.NodeGroup, 0),
@ -130,11 +139,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az
return cache, nil
}
func (m *azureCache) getVMsPoolSet() map[string]struct{} {
func (m *azureCache) getVMsPoolMap() map[string]armcontainerservice.AgentPool {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.vmsPoolSet
return m.vmsPoolMap
}
func (m *azureCache) getVirtualMachines() map[string][]compute.VirtualMachine {
@ -232,13 +241,20 @@ func (m *azureCache) fetchAzureResources() error {
return err
}
m.scaleSets = vmssResult
vmResult, vmsPoolSet, err := m.fetchVirtualMachines()
vmResult, err := m.fetchVirtualMachines()
if err != nil {
return err
}
// we fetch both sets of resources since CAS may operate on mixed nodepools
m.virtualMachines = vmResult
m.vmsPoolSet = vmsPoolSet
// fetch VMs pools if enabled
if m.enableVMsAgentPool {
vmsPoolMap, err := m.fetchVMsPools()
if err != nil {
return err
}
m.vmsPoolMap = vmsPoolMap
}
return nil
}
@ -251,19 +267,17 @@ const (
)
// fetchVirtualMachines returns the updated list of virtual machines in the config resource group using the Azure API.
func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, map[string]struct{}, error) {
func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, error) {
ctx, cancel := getContextWithCancel()
defer cancel()
result, err := m.azClient.virtualMachinesClient.List(ctx, m.resourceGroup)
if err != nil {
klog.Errorf("VirtualMachinesClient.List in resource group %q failed: %v", m.resourceGroup, err)
return nil, nil, err.Error()
return nil, err.Error()
}
instances := make(map[string][]compute.VirtualMachine)
// track the nodepools that're vms pools
vmsPoolSet := make(map[string]struct{})
for _, instance := range result {
if instance.Tags == nil {
continue
@ -280,20 +294,43 @@ func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine
}
instances[to.String(vmPoolName)] = append(instances[to.String(vmPoolName)], instance)
}
return instances, nil
}
// if the nodepool is already in the map, skip it
if _, ok := vmsPoolSet[to.String(vmPoolName)]; ok {
continue
// fetchVMsPools returns a name to agentpool map of all the VMs pools in the cluster
func (m *azureCache) fetchVMsPools() (map[string]armcontainerservice.AgentPool, error) {
ctx, cancel := getContextWithTimeout(vmsContextTimeout)
defer cancel()
// defensive check, should never happen when enableVMsAgentPool toggle is on
if m.azClient.agentPoolClient == nil {
return nil, errors.New("agentPoolClient is nil")
}
vmsPoolMap := make(map[string]armcontainerservice.AgentPool)
pager := m.azClient.agentPoolClient.NewListPager(m.clusterResourceGroup, m.clusterName, nil)
var aps []*armcontainerservice.AgentPool
for pager.More() {
resp, err := pager.NextPage(ctx)
if err != nil {
klog.Errorf("agentPoolClient.pager.NextPage in cluster %s resource group %s failed: %v",
m.clusterName, m.clusterResourceGroup, err)
return nil, err
}
aps = append(aps, resp.Value...)
}
// nodes from vms pool will have tag "aks-managed-agentpool-type" set to "VirtualMachines"
if agentpoolType := tags[agentpoolTypeTag]; agentpoolType != nil {
if strings.EqualFold(to.String(agentpoolType), vmsPoolType) {
vmsPoolSet[to.String(vmPoolName)] = struct{}{}
}
for _, ap := range aps {
if ap != nil && ap.Name != nil && ap.Properties != nil && ap.Properties.Type != nil &&
*ap.Properties.Type == armcontainerservice.AgentPoolTypeVirtualMachines {
// we only care about VMs pools, skip other types
klog.V(6).Infof("Found VMs pool %q", *ap.Name)
vmsPoolMap[*ap.Name] = *ap
}
}
return instances, vmsPoolSet, nil
return vmsPoolMap, nil
}
// fetchScaleSets returns the updated list of scale sets in the config resource group using the Azure API.
@ -422,7 +459,7 @@ func (m *azureCache) HasInstance(providerID string) (bool, error) {
// FindForInstance returns node group of the given Instance
func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudprovider.NodeGroup, error) {
vmsPoolSet := m.getVMsPoolSet()
vmsPoolMap := m.getVMsPoolMap()
m.mutex.Lock()
defer m.mutex.Unlock()
@ -441,7 +478,7 @@ func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudpr
}
// cluster with vmss pool only
if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolSet) == 0 {
if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolMap) == 0 {
if m.areAllScaleSetsUniform() {
// Omit virtual machines not managed by vmss only in case of uniform scale set.
if ok := virtualMachineRE.Match([]byte(inst.Name)); ok {

View File

@ -22,9 +22,42 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
providerazureconsts "sigs.k8s.io/cloud-provider-azure/pkg/consts"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/go-autorest/autorest/to"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
func TestFetchVMsPools(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
provider := newTestProvider(t)
ac := provider.azureManager.azureCache
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ac.azClient.agentPoolClient = mockAgentpoolclient
vmsPool := getTestVMsAgentPool(false)
vmssPoolType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("vmsspool1"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssPoolType,
},
}
invalidPool := armcontainerservice.AgentPool{}
fakeAPListPager := getFakeAgentpoolListPager(&vmsPool, &vmssPool, &invalidPool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
vmsPoolMap, err := ac.fetchVMsPools()
assert.NoError(t, err)
assert.Equal(t, 1, len(vmsPoolMap))
_, ok := vmsPoolMap[to.String(vmsPool.Name)]
assert.True(t, ok)
}
func TestRegister(t *testing.T) {
provider := newTestProvider(t)
ss := newTestScaleSet(provider.azureManager, "ss")

View File

@ -19,6 +19,8 @@ package azure
import (
"context"
"fmt"
"os"
"time"
_ "go.uber.org/mock/mockgen/model" // for go:generate
@ -29,7 +31,7 @@ import (
azurecore_policy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
@ -47,7 +49,12 @@ import (
providerazureconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config"
)
//go:generate sh -c "mockgen k8s.io/autoscaler/cluster-autoscaler/cloudprovider/azure AgentPoolsClient >./agentpool_client.go"
//go:generate sh -c "mockgen -source=azure_client.go -destination azure_mock_agentpool_client.go -package azure -exclude_interfaces DeploymentsClient"
const (
vmsContextTimeout = 5 * time.Minute
vmsAsyncContextTimeout = 30 * time.Minute
)
// AgentPoolsClient interface defines the methods needed for scaling vms pool.
// it is implemented by track2 sdk armcontainerservice.AgentPoolsClient
@ -68,52 +75,89 @@ type AgentPoolsClient interface {
machines armcontainerservice.AgentPoolDeleteMachinesParameter,
options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (
*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error)
NewListPager(
resourceGroupName, resourceName string,
options *armcontainerservice.AgentPoolsClientListOptions,
) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse]
}
func getAgentpoolClientCredentials(cfg *Config) (azcore.TokenCredential, error) {
var cred azcore.TokenCredential
var err error
if cfg.AuthMethod == authMethodCLI {
cred, err = azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
TenantID: cfg.TenantID})
if err != nil {
klog.Errorf("NewAzureCLICredential failed: %v", err)
return nil, err
if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal {
// Use MSI
if cfg.UseManagedIdentityExtension {
// Use System Assigned MSI
if cfg.UserAssignedIdentityID == "" {
klog.V(4).Info("Agentpool client: using System Assigned MSI to retrieve access token")
return azidentity.NewManagedIdentityCredential(nil)
}
// Use User Assigned MSI
klog.V(4).Info("Agentpool client: using User Assigned MSI to retrieve access token")
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(cfg.UserAssignedIdentityID),
})
}
} else if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal {
cred, err = azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil)
if err != nil {
klog.Errorf("NewClientSecretCredential failed: %v", err)
return nil, err
}
} else {
return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod)
}
return cred, nil
}
func getAgentpoolClientRetryOptions(cfg *Config) azurecore_policy.RetryOptions {
if cfg.AuthMethod == authMethodCLI {
return azurecore_policy.RetryOptions{
MaxRetries: -1, // no retry when using CLI auth for UT
// Use Service Principal with ClientID and ClientSecret
if cfg.AADClientID != "" && cfg.AADClientSecret != "" {
klog.V(2).Infoln("Agentpool client: using client_id+client_secret to retrieve access token")
return azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil)
}
// Use Service Principal with ClientCert and AADClientCertPassword
if cfg.AADClientID != "" && cfg.AADClientCertPath != "" {
klog.V(2).Infoln("Agentpool client: using client_cert+client_private_key to retrieve access token")
certData, err := os.ReadFile(cfg.AADClientCertPath)
if err != nil {
return nil, fmt.Errorf("reading the client certificate from file %s failed with error: %w", cfg.AADClientCertPath, err)
}
certs, privateKey, err := azidentity.ParseCertificates(certData, []byte(cfg.AADClientCertPassword))
if err != nil {
return nil, fmt.Errorf("parsing service principal certificate data failed with error: %w", err)
}
return azidentity.NewClientCertificateCredential(cfg.TenantID, cfg.AADClientID, certs, privateKey, &azidentity.ClientCertificateCredentialOptions{
SendCertificateChain: true,
})
}
}
return azextensions.DefaultRetryOpts()
if cfg.UseFederatedWorkloadIdentityExtension {
klog.V(4).Info("Agentpool client: using workload identity for access token")
return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
TokenFilePath: cfg.AADFederatedTokenFile,
})
}
return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod)
}
func newAgentpoolClient(cfg *Config) (AgentPoolsClient, error) {
retryOptions := getAgentpoolClientRetryOptions(cfg)
retryOptions := azextensions.DefaultRetryOpts()
cred, err := getAgentpoolClientCredentials(cfg)
if err != nil {
klog.Errorf("failed to get agent pool client credentials: %v", err)
return nil, err
}
env := azure.PublicCloud // default to public cloud
if cfg.Cloud != "" {
var err error
env, err = azure.EnvironmentFromName(cfg.Cloud)
if err != nil {
klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err)
return nil, err
}
}
if cfg.ARMBaseURLForAPClient != "" {
klog.V(10).Infof("Using ARMBaseURLForAPClient to create agent pool client")
return newAgentpoolClientWithConfig(cfg.SubscriptionID, nil, cfg.ARMBaseURLForAPClient, "UNKNOWN", retryOptions)
return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, cfg.ARMBaseURLForAPClient, env.TokenAudience, retryOptions, true /*insecureAllowCredentialWithHTTP*/)
}
return newAgentpoolClientWithPublicEndpoint(cfg, retryOptions)
return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions, false /*insecureAllowCredentialWithHTTP*/)
}
func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCredential,
cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) {
cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions, insecureAllowCredentialWithHTTP bool) (AgentPoolsClient, error) {
agentPoolsClient, err := armcontainerservice.NewAgentPoolsClient(subscriptionID, cred,
&policy.ClientOptions{
ClientOptions: azurecore_policy.ClientOptions{
@ -125,9 +169,10 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden
},
},
},
Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()),
Transport: azextensions.DefaultHTTPClient(),
Retry: retryOptions,
InsecureAllowCredentialWithHTTP: insecureAllowCredentialWithHTTP,
Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()),
Transport: azextensions.DefaultHTTPClient(),
Retry: retryOptions,
},
})
@ -139,26 +184,6 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden
return agentPoolsClient, nil
}
func newAgentpoolClientWithPublicEndpoint(cfg *Config, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) {
cred, err := getAgentpoolClientCredentials(cfg)
if err != nil {
klog.Errorf("failed to get agent pool client credentials: %v", err)
return nil, err
}
// default to public cloud
env := azure.PublicCloud
if cfg.Cloud != "" {
env, err = azure.EnvironmentFromName(cfg.Cloud)
if err != nil {
klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err)
return nil, err
}
}
return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions)
}
type azClient struct {
virtualMachineScaleSetsClient vmssclient.Interface
virtualMachineScaleSetVMsClient vmssvmclient.Interface
@ -232,9 +257,11 @@ func newAzClient(cfg *Config, env *azure.Environment) (*azClient, error) {
agentPoolClient, err := newAgentpoolClient(cfg)
if err != nil {
// we don't want to fail the whole process so we don't break any existing functionality
// since this may not be fatal - it is only used by vms pool which is still under development.
klog.Warningf("newAgentpoolClient failed with error: %s", err)
klog.Errorf("newAgentpoolClient failed with error: %s", err)
if cfg.EnableVMsAgentPool {
// only return error if VMs agent pool is supported which is controlled by toggle
return nil, err
}
}
return &azClient{

View File

@ -20,6 +20,7 @@ import (
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
"github.com/Azure/go-autorest/autorest/to"
@ -132,7 +133,7 @@ func TestNodeGroups(t *testing.T) {
)
assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"),
newTestVMsPool(provider.azureManager),
)
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2)
@ -146,9 +147,14 @@ func TestHasInstance(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient
provider.azureManager.azureCache.clusterName = "test-cluster"
provider.azureManager.azureCache.clusterResourceGroup = "test-rg"
provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types
// Simulate node groups and instances
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform)
@ -158,6 +164,20 @@ func TestHasInstance(t *testing.T) {
mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil).AnyTimes()
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes()
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("test-asg"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssType,
},
}
vmsPool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool)
mockAgentpoolclient.EXPECT().NewListPager(
provider.azureManager.azureCache.clusterResourceGroup,
provider.azureManager.azureCache.clusterName, nil).
Return(fakeAPListPager).AnyTimes()
// Register node groups
assert.Equal(t, len(provider.NodeGroups()), 0)
@ -168,9 +188,9 @@ func TestHasInstance(t *testing.T) {
assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"),
newTestVMsPool(provider.azureManager),
)
provider.azureManager.explicitlyConfigured["test-vms-pool"] = true
provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2)
@ -264,9 +284,14 @@ func TestMixedNodeGroups(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
provider.azureManager.azureCache.clusterName = "test-cluster"
provider.azureManager.azureCache.clusterResourceGroup = "test-rg"
provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types
provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform)
expectedVMsPoolVMs := newTestVMsPoolVMList(3)
@ -276,6 +301,19 @@ func TestMixedNodeGroups(t *testing.T) {
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes()
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
vmssPool := armcontainerservice.AgentPool{
Name: to.StringPtr("test-asg"),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmssType,
},
}
vmsPool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool)
mockAgentpoolclient.EXPECT().NewListPager(provider.azureManager.azureCache.clusterResourceGroup, provider.azureManager.azureCache.clusterName, nil).
Return(fakeAPListPager).AnyTimes()
assert.Equal(t, len(provider.NodeGroups()), 0)
registered := provider.azureManager.RegisterNodeGroup(
newTestScaleSet(provider.azureManager, "test-asg"),
@ -284,9 +322,9 @@ func TestMixedNodeGroups(t *testing.T) {
assert.True(t, registered)
registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"),
newTestVMsPool(provider.azureManager),
)
provider.azureManager.explicitlyConfigured["test-vms-pool"] = true
provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2)
@ -307,7 +345,7 @@ func TestMixedNodeGroups(t *testing.T) {
group, err = provider.NodeGroupForNode(vmsPoolNode)
assert.NoError(t, err)
assert.NotNil(t, group, "Group should not be nil")
assert.Equal(t, group.Id(), "test-vms-pool")
assert.Equal(t, group.Id(), vmsNodeGroupName)
assert.Equal(t, group.MinSize(), 3)
assert.Equal(t, group.MaxSize(), 10)
}

View File

@ -86,6 +86,9 @@ type Config struct {
// EnableForceDelete defines whether to enable force deletion on the APIs
EnableForceDelete bool `json:"enableForceDelete,omitempty" yaml:"enableForceDelete,omitempty"`
// EnableVMsAgentPool defines whether to support VMs agentpool type in addition to VMSS type
EnableVMsAgentPool bool `json:"enableVMsAgentPool,omitempty" yaml:"enableVMsAgentPool,omitempty"`
// (DEPRECATED, DO NOT USE) EnableDynamicInstanceList defines whether to enable dynamic instance workflow for instance information check
EnableDynamicInstanceList bool `json:"enableDynamicInstanceList,omitempty" yaml:"enableDynamicInstanceList,omitempty"`
@ -122,6 +125,7 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) {
// Static defaults
cfg.EnableDynamicInstanceList = false
cfg.EnableVmssFlexNodes = false
cfg.EnableVMsAgentPool = false
cfg.CloudProviderBackoffRetries = providerazureconsts.BackoffRetriesDefault
cfg.CloudProviderBackoffExponent = providerazureconsts.BackoffExponentDefault
cfg.CloudProviderBackoffDuration = providerazureconsts.BackoffDurationDefault
@ -257,6 +261,9 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) {
if _, err = assignBoolFromEnvIfExists(&cfg.StrictCacheUpdates, "AZURE_STRICT_CACHE_UPDATES"); err != nil {
return nil, err
}
if _, err = assignBoolFromEnvIfExists(&cfg.EnableVMsAgentPool, "AZURE_ENABLE_VMS_AGENT_POOLS"); err != nil {
return nil, err
}
if _, err = assignBoolFromEnvIfExists(&cfg.EnableDynamicInstanceList, "AZURE_ENABLE_DYNAMIC_INSTANCE_LIST"); err != nil {
return nil, err
}

View File

@ -22,80 +22,79 @@ import (
"regexp"
"strings"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"k8s.io/klog/v2"
)
// GetVMSSTypeStatically uses static list of vmss generated at azure_instance_types.go to fetch vmss instance information.
// GetInstanceTypeStatically uses static list of vmss generated at azure_instance_types.go to fetch vmss instance information.
// It is declared as a variable for testing purpose.
var GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) {
var vmssType *InstanceType
var GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
var instanceType *InstanceType
for k := range InstanceTypes {
if strings.EqualFold(k, *template.Sku.Name) {
vmssType = InstanceTypes[k]
if strings.EqualFold(k, template.SkuName) {
instanceType = InstanceTypes[k]
break
}
}
promoRe := regexp.MustCompile(`(?i)_promo`)
if promoRe.MatchString(*template.Sku.Name) {
if vmssType == nil {
if promoRe.MatchString(template.SkuName) {
if instanceType == nil {
// We didn't find an exact match but this is a promo type, check for matching standard
klog.V(4).Infof("No exact match found for %s, checking standard types", *template.Sku.Name)
skuName := promoRe.ReplaceAllString(*template.Sku.Name, "")
klog.V(4).Infof("No exact match found for %s, checking standard types", template.SkuName)
skuName := promoRe.ReplaceAllString(template.SkuName, "")
for k := range InstanceTypes {
if strings.EqualFold(k, skuName) {
vmssType = InstanceTypes[k]
instanceType = InstanceTypes[k]
break
}
}
}
}
if vmssType == nil {
return vmssType, fmt.Errorf("instance type %q not supported", *template.Sku.Name)
if instanceType == nil {
return instanceType, fmt.Errorf("instance type %q not supported", template.SkuName)
}
return vmssType, nil
return instanceType, nil
}
// GetVMSSTypeDynamically fetched vmss instance information using sku api calls.
// GetInstanceTypeDynamically fetched vmss instance information using sku api calls.
// It is declared as a variable for testing purpose.
var GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) {
var GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
ctx := context.Background()
var vmssType InstanceType
var instanceType InstanceType
sku, err := azCache.GetSKU(ctx, *template.Sku.Name, *template.Location)
sku, err := azCache.GetSKU(ctx, template.SkuName, template.Location)
if err != nil {
// We didn't find an exact match but this is a promo type, check for matching standard
promoRe := regexp.MustCompile(`(?i)_promo`)
skuName := promoRe.ReplaceAllString(*template.Sku.Name, "")
if skuName != *template.Sku.Name {
klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", *template.Sku.Name, skuName, err)
sku, err = azCache.GetSKU(ctx, skuName, *template.Location)
skuName := promoRe.ReplaceAllString(template.SkuName, "")
if skuName != template.SkuName {
klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", template.SkuName, skuName, err)
sku, err = azCache.GetSKU(ctx, skuName, template.Location)
}
if err != nil {
return vmssType, fmt.Errorf("instance type %q not supported. Error %v", *template.Sku.Name, err)
return instanceType, fmt.Errorf("instance type %q not supported. Error %v", template.SkuName, err)
}
}
vmssType.VCPU, err = sku.VCPU()
instanceType.VCPU, err = sku.VCPU()
if err != nil {
klog.V(1).Infof("Failed to parse vcpu from sku %q %v", *template.Sku.Name, err)
return vmssType, err
klog.V(1).Infof("Failed to parse vcpu from sku %q %v", template.SkuName, err)
return instanceType, err
}
gpu, err := getGpuFromSku(sku)
if err != nil {
klog.V(1).Infof("Failed to parse gpu from sku %q %v", *template.Sku.Name, err)
return vmssType, err
klog.V(1).Infof("Failed to parse gpu from sku %q %v", template.SkuName, err)
return instanceType, err
}
vmssType.GPU = gpu
instanceType.GPU = gpu
memoryGb, err := sku.Memory()
if err != nil {
klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", *template.Sku.Name, err)
return vmssType, err
klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", template.SkuName, err)
return instanceType, err
}
vmssType.MemoryMb = int64(memoryGb) * 1024
instanceType.MemoryMb = int64(memoryGb) * 1024
return vmssType, nil
return instanceType, nil
}

View File

@ -168,6 +168,23 @@ func (m *AzureManager) fetchExplicitNodeGroups(specs []string) error {
return nil
}
// parseSKUAndVMsAgentpoolNameFromSpecName parses the spec name for a mixed-SKU VMs pool.
// The spec name should be in the format <agentpoolname>/<sku>, e.g., "mypool1/Standard_D2s_v3", if the agent pool is a VMs pool.
// This method returns a boolean indicating if the agent pool is a VMs pool, along with the agent pool name and SKU.
func (m *AzureManager) parseSKUAndVMsAgentpoolNameFromSpecName(name string) (bool, string, string) {
parts := strings.Split(name, "/")
if len(parts) == 2 {
agentPoolName := parts[0]
sku := parts[1]
vmsPoolMap := m.azureCache.getVMsPoolMap()
if _, ok := vmsPoolMap[agentPoolName]; ok {
return true, agentPoolName, sku
}
}
return false, "", ""
}
func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGroup, error) {
scaleToZeroSupported := scaleToZeroSupportedStandard
if strings.EqualFold(m.config.VMType, providerazureconsts.VMTypeVMSS) {
@ -177,9 +194,13 @@ func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGr
if err != nil {
return nil, fmt.Errorf("failed to parse node group spec: %v", err)
}
vmsPoolSet := m.azureCache.getVMsPoolSet()
if _, ok := vmsPoolSet[s.Name]; ok {
return NewVMsPool(s, m), nil
// Starting from release 1.30, a cluster may have both VMSS and VMs pools.
// Therefore, we cannot solely rely on the VMType to determine the node group type.
// Instead, we need to check the cache to determine if the agent pool is a VMs pool.
isVMsPool, agentPoolName, sku := m.parseSKUAndVMsAgentpoolNameFromSpecName(s.Name)
if isVMsPool {
return NewVMPool(s, m, agentPoolName, sku)
}
switch m.config.VMType {

View File

@ -297,6 +297,7 @@ func TestCreateAzureManagerValidConfig(t *testing.T) {
VmssVmsCacheJitter: 120,
MaxDeploymentsCount: 8,
EnableFastDeleteOnFailedProvisioning: true,
EnableVMsAgentPool: false,
}
assert.NoError(t, err)
@ -618,9 +619,14 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMSSClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachineScaleSet{}, nil).AnyTimes()
mockVMClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachine{}, nil).AnyTimes()
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
vmspool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&vmspool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager).AnyTimes()
mockAzClient := &azClient{
virtualMachinesClient: mockVMClient,
virtualMachineScaleSetsClient: mockVMSSClient,
agentPoolClient: mockAgentpoolclient,
}
expectedConfig := &Config{
@ -702,6 +708,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
VmssVmsCacheJitter: 90,
MaxDeploymentsCount: 8,
EnableFastDeleteOnFailedProvisioning: true,
EnableVMsAgentPool: true,
}
t.Setenv("ARM_CLOUD", "AzurePublicCloud")
@ -735,6 +742,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) {
t.Setenv("ARM_CLUSTER_RESOURCE_GROUP", "myrg")
t.Setenv("ARM_BASE_URL_FOR_AP_CLIENT", "nodeprovisioner-svc.nodeprovisioner.svc.cluster.local")
t.Setenv("AZURE_ENABLE_FAST_DELETE_ON_FAILED_PROVISIONING", "true")
t.Setenv("AZURE_ENABLE_VMS_AGENT_POOLS", "true")
t.Run("environment variables correctly set", func(t *testing.T) {
manager, err := createAzureManagerInternal(nil, cloudprovider.NodeGroupDiscoveryOptions{}, mockAzClient)

View File

@ -21,7 +21,7 @@ import (
reflect "reflect"
runtime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4"
armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
gomock "go.uber.org/mock/gomock"
)
@ -49,46 +49,60 @@ func (m *MockAgentPoolsClient) EXPECT() *MockAgentPoolsClientMockRecorder {
}
// BeginCreateOrUpdate mocks base method.
func (m *MockAgentPoolsClient) BeginCreateOrUpdate(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPool, arg5 *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) {
func (m *MockAgentPoolsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, parameters armcontainerservice.AgentPool, options *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginCreateOrUpdate", arg0, arg1, arg2, arg3, arg4, arg5)
ret := m.ctrl.Call(m, "BeginCreateOrUpdate", ctx, resourceGroupName, resourceName, agentPoolName, parameters, options)
ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse])
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginCreateOrUpdate indicates an expected call of BeginCreateOrUpdate.
func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call {
func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(ctx, resourceGroupName, resourceName, agentPoolName, parameters, options any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), arg0, arg1, arg2, arg3, arg4, arg5)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), ctx, resourceGroupName, resourceName, agentPoolName, parameters, options)
}
// BeginDeleteMachines mocks base method.
func (m *MockAgentPoolsClient) BeginDeleteMachines(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPoolDeleteMachinesParameter, arg5 *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) {
func (m *MockAgentPoolsClient) BeginDeleteMachines(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, machines armcontainerservice.AgentPoolDeleteMachinesParameter, options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginDeleteMachines", arg0, arg1, arg2, arg3, arg4, arg5)
ret := m.ctrl.Call(m, "BeginDeleteMachines", ctx, resourceGroupName, resourceName, agentPoolName, machines, options)
ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse])
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginDeleteMachines indicates an expected call of BeginDeleteMachines.
func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call {
func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(ctx, resourceGroupName, resourceName, agentPoolName, machines, options any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), arg0, arg1, arg2, arg3, arg4, arg5)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), ctx, resourceGroupName, resourceName, agentPoolName, machines, options)
}
// Get mocks base method.
func (m *MockAgentPoolsClient) Get(arg0 context.Context, arg1, arg2, arg3 string, arg4 *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) {
func (m *MockAgentPoolsClient) Get(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, options *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3, arg4)
ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, resourceName, agentPoolName, options)
ret0, _ := ret[0].(armcontainerservice.AgentPoolsClientGetResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockAgentPoolsClientMockRecorder) Get(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
func (mr *MockAgentPoolsClientMockRecorder) Get(ctx, resourceGroupName, resourceName, agentPoolName, options any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), arg0, arg1, arg2, arg3, arg4)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), ctx, resourceGroupName, resourceName, agentPoolName, options)
}
// NewListPager mocks base method.
func (m *MockAgentPoolsClient) NewListPager(resourceGroupName, resourceName string, options *armcontainerservice.AgentPoolsClientListOptions) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, resourceName, options)
ret0, _ := ret[0].(*runtime.Pager[armcontainerservice.AgentPoolsClientListResponse])
return ret0
}
// NewListPager indicates an expected call of NewListPager.
func (mr *MockAgentPoolsClientMockRecorder) NewListPager(resourceGroupName, resourceName, options any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockAgentPoolsClient)(nil).NewListPager), resourceGroupName, resourceName, options)
}

View File

@ -651,15 +651,18 @@ func (scaleSet *ScaleSet) Debug() string {
// TemplateNodeInfo returns a node template for this scale set.
func (scaleSet *ScaleSet) TemplateNodeInfo() (*framework.NodeInfo, error) {
template, err := scaleSet.getVMSSFromCache()
vmss, err := scaleSet.getVMSSFromCache()
if err != nil {
return nil, err
}
inputLabels := map[string]string{}
inputTaints := ""
node, err := buildNodeFromTemplate(scaleSet.Name, inputLabels, inputTaints, template, scaleSet.manager, scaleSet.enableDynamicInstanceList)
template, err := buildNodeTemplateFromVMSS(vmss, inputLabels, inputTaints)
if err != nil {
return nil, err
}
node, err := buildNodeFromTemplate(scaleSet.Name, template, scaleSet.manager, scaleSet.enableDynamicInstanceList)
if err != nil {
return nil, err
}

View File

@ -1232,12 +1232,12 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
// Properly testing dynamic SKU list through skewer is not possible,
// because there are no Resource API mocks included yet.
// Instead, the rest of the (consumer side) tests here
// override GetVMSSTypeDynamically and GetVMSSTypeStatically functions.
// override GetInstanceTypeDynamically and GetInstanceTypeStatically functions.
t.Run("Checking dynamic workflow", func(t *testing.T) {
asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) {
GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
vmssType := InstanceType{}
vmssType.VCPU = 1
vmssType.GPU = 2
@ -1255,10 +1255,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Checking static workflow if dynamic fails", func(t *testing.T) {
asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) {
GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
return InstanceType{}, fmt.Errorf("dynamic error exists")
}
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) {
GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
vmssType := InstanceType{}
vmssType.VCPU = 1
vmssType.GPU = 2
@ -1276,10 +1276,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Fails to find vmss instance information using static and dynamic workflow, instance not supported", func(t *testing.T) {
asg.enableDynamicInstanceList = true
GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) {
GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
return InstanceType{}, fmt.Errorf("dynamic error exists")
}
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) {
GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
return &InstanceType{}, fmt.Errorf("static error exists")
}
nodeInfo, err := asg.TemplateNodeInfo()
@ -1292,7 +1292,7 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) {
t.Run("Checking static-only workflow", func(t *testing.T) {
asg.enableDynamicInstanceList = false
GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) {
GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
vmssType := InstanceType{}
vmssType.VCPU = 1
vmssType.GPU = 2

View File

@ -24,7 +24,9 @@ import (
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
apiv1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -84,8 +86,132 @@ const (
clusterLabelKey = AKSLabelKeyPrefixValue + "cluster"
)
func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, inputTaints string,
template compute.VirtualMachineScaleSet, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) {
// VMPoolNodeTemplate holds properties for node from VMPool
type VMPoolNodeTemplate struct {
AgentPoolName string
Taints []apiv1.Taint
Labels map[string]*string
OSDiskType *armcontainerservice.OSDiskType
}
// VMSSNodeTemplate holds properties for node from VMSS
type VMSSNodeTemplate struct {
InputLabels map[string]string
InputTaints string
Tags map[string]*string
OSDisk *compute.VirtualMachineScaleSetOSDisk
}
// NodeTemplate represents a template for an Azure node
type NodeTemplate struct {
SkuName string
InstanceOS string
Location string
Zones []string
VMPoolNodeTemplate *VMPoolNodeTemplate
VMSSNodeTemplate *VMSSNodeTemplate
}
func buildNodeTemplateFromVMSS(vmss compute.VirtualMachineScaleSet, inputLabels map[string]string, inputTaints string) (NodeTemplate, error) {
instanceOS := cloudprovider.DefaultOS
if vmss.VirtualMachineProfile != nil &&
vmss.VirtualMachineProfile.OsProfile != nil &&
vmss.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil {
instanceOS = "windows"
}
var osDisk *compute.VirtualMachineScaleSetOSDisk
if vmss.VirtualMachineProfile != nil &&
vmss.VirtualMachineProfile.StorageProfile != nil &&
vmss.VirtualMachineProfile.StorageProfile.OsDisk != nil {
osDisk = vmss.VirtualMachineProfile.StorageProfile.OsDisk
}
if vmss.Sku == nil || vmss.Sku.Name == nil {
return NodeTemplate{}, fmt.Errorf("VMSS %s has no SKU", to.String(vmss.Name))
}
if vmss.Location == nil {
return NodeTemplate{}, fmt.Errorf("VMSS %s has no location", to.String(vmss.Name))
}
zones := []string{}
if vmss.Zones != nil {
zones = *vmss.Zones
}
return NodeTemplate{
SkuName: *vmss.Sku.Name,
Location: *vmss.Location,
Zones: zones,
InstanceOS: instanceOS,
VMSSNodeTemplate: &VMSSNodeTemplate{
InputLabels: inputLabels,
InputTaints: inputTaints,
OSDisk: osDisk,
Tags: vmss.Tags,
},
}, nil
}
func buildNodeTemplateFromVMPool(vmsPool armcontainerservice.AgentPool, location string, skuName string, labelsFromSpec map[string]string, taintsFromSpec string) (NodeTemplate, error) {
if vmsPool.Properties == nil {
return NodeTemplate{}, fmt.Errorf("vmsPool %s has nil properties", to.String(vmsPool.Name))
}
// labels from the agentpool
labels := vmsPool.Properties.NodeLabels
// labels from spec
for k, v := range labelsFromSpec {
if labels == nil {
labels = make(map[string]*string)
}
labels[k] = to.StringPtr(v)
}
// taints from the agentpool
taintsList := []string{}
for _, taint := range vmsPool.Properties.NodeTaints {
if to.String(taint) != "" {
taintsList = append(taintsList, to.String(taint))
}
}
// taints from spec
if taintsFromSpec != "" {
taintsList = append(taintsList, taintsFromSpec)
}
taintsStr := strings.Join(taintsList, ",")
taints := extractTaintsFromSpecString(taintsStr)
var zones []string
if vmsPool.Properties.AvailabilityZones != nil {
for _, zone := range vmsPool.Properties.AvailabilityZones {
if zone != nil {
zones = append(zones, *zone)
}
}
}
var instanceOS string
if vmsPool.Properties.OSType != nil {
instanceOS = strings.ToLower(string(*vmsPool.Properties.OSType))
}
return NodeTemplate{
SkuName: skuName,
Zones: zones,
InstanceOS: instanceOS,
Location: location,
VMPoolNodeTemplate: &VMPoolNodeTemplate{
AgentPoolName: to.String(vmsPool.Name),
OSDiskType: vmsPool.Properties.OSDiskType,
Taints: taints,
Labels: labels,
},
}, nil
}
func buildNodeFromTemplate(nodeGroupName string, template NodeTemplate, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) {
node := apiv1.Node{}
nodeName := fmt.Sprintf("%s-asg-%d", nodeGroupName, rand.Int63())
@ -104,28 +230,28 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
// Fetching SKU information from SKU API if enableDynamicInstanceList is true.
var dynamicErr error
if enableDynamicInstanceList {
var vmssTypeDynamic InstanceType
klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", *template.Sku.Name)
vmssTypeDynamic, dynamicErr = GetVMSSTypeDynamically(template, manager.azureCache)
var instanceTypeDynamic InstanceType
klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", template.SkuName)
instanceTypeDynamic, dynamicErr = GetInstanceTypeDynamically(template, manager.azureCache)
if dynamicErr == nil {
vcpu = vmssTypeDynamic.VCPU
gpuCount = vmssTypeDynamic.GPU
memoryMb = vmssTypeDynamic.MemoryMb
vcpu = instanceTypeDynamic.VCPU
gpuCount = instanceTypeDynamic.GPU
memoryMb = instanceTypeDynamic.MemoryMb
} else {
klog.Errorf("Dynamically fetching of instance information from SKU api failed with error: %v", dynamicErr)
}
}
if !enableDynamicInstanceList || dynamicErr != nil {
klog.V(1).Infof("Falling back to static SKU list for SKU: %s", *template.Sku.Name)
klog.V(1).Infof("Falling back to static SKU list for SKU: %s", template.SkuName)
// fall-back on static list of vmss if dynamic workflow fails.
vmssTypeStatic, staticErr := GetVMSSTypeStatically(template)
instanceTypeStatic, staticErr := GetInstanceTypeStatically(template)
if staticErr == nil {
vcpu = vmssTypeStatic.VCPU
gpuCount = vmssTypeStatic.GPU
memoryMb = vmssTypeStatic.MemoryMb
vcpu = instanceTypeStatic.VCPU
gpuCount = instanceTypeStatic.GPU
memoryMb = instanceTypeStatic.MemoryMb
} else {
// return error if neither of the workflows results with vmss data.
klog.V(1).Infof("Instance type %q not supported, err: %v", *template.Sku.Name, staticErr)
klog.V(1).Infof("Instance type %q not supported, err: %v", template.SkuName, staticErr)
return nil, staticErr
}
}
@ -134,7 +260,7 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
node.Status.Capacity[apiv1.ResourceCPU] = *resource.NewQuantity(vcpu, resource.DecimalSI)
// isNPSeries returns if a SKU is an NP-series SKU
// SKU API reports GPUs for NP-series but it's actually FPGAs
if isNPSeries(*template.Sku.Name) {
if isNPSeries(template.SkuName) {
node.Status.Capacity[xilinxFpgaResourceName] = *resource.NewQuantity(gpuCount, resource.DecimalSI)
} else {
node.Status.Capacity[gpu.ResourceNvidiaGPU] = *resource.NewQuantity(gpuCount, resource.DecimalSI)
@ -145,9 +271,37 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
// TODO: set real allocatable.
node.Status.Allocatable = node.Status.Capacity
if template.VMSSNodeTemplate != nil {
node = processVMSSTemplate(template, nodeName, node)
} else if template.VMPoolNodeTemplate != nil {
node = processVMPoolTemplate(template, nodeName, node)
} else {
return nil, fmt.Errorf("invalid node template: missing both VMSS and VMPool templates")
}
klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels)
klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints)
node.Status.Conditions = cloudprovider.BuildReadyConditions()
return &node, nil
}
func processVMPoolTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node {
labels := buildGenericLabels(template, nodeName)
labels[agentPoolNodeLabelKey] = template.VMPoolNodeTemplate.AgentPoolName
if template.VMPoolNodeTemplate.Labels != nil {
for k, v := range template.VMPoolNodeTemplate.Labels {
labels[k] = to.String(v)
}
}
node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels)
node.Spec.Taints = template.VMPoolNodeTemplate.Taints
return node
}
func processVMSSTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node {
// NodeLabels
if template.Tags != nil {
for k, v := range template.Tags {
if template.VMSSNodeTemplate.Tags != nil {
for k, v := range template.VMSSNodeTemplate.Tags {
if v != nil {
node.Labels[k] = *v
} else {
@ -164,10 +318,10 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
labels := make(map[string]string)
// Prefer the explicit labels in spec coming from RP over the VMSS template
if len(inputLabels) > 0 {
labels = inputLabels
if len(template.VMSSNodeTemplate.InputLabels) > 0 {
labels = template.VMSSNodeTemplate.InputLabels
} else {
labels = extractLabelsFromScaleSet(template.Tags)
labels = extractLabelsFromTags(template.VMSSNodeTemplate.Tags)
}
// Add the agentpool label, its value should come from the VMSS poolName tag
@ -182,87 +336,74 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string,
labels[agentPoolNodeLabelKey] = node.Labels[poolNameTag]
}
// Add the storage profile and storage tier labels
if template.VirtualMachineProfile != nil && template.VirtualMachineProfile.StorageProfile != nil && template.VirtualMachineProfile.StorageProfile.OsDisk != nil {
// Add the storage profile and storage tier labels for vmss node
if template.VMSSNodeTemplate.OSDisk != nil {
// ephemeral
if template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings != nil && template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings.Option == compute.Local {
if template.VMSSNodeTemplate.OSDisk.DiffDiskSettings != nil && template.VMSSNodeTemplate.OSDisk.DiffDiskSettings.Option == compute.Local {
labels[legacyStorageProfileNodeLabelKey] = "ephemeral"
labels[storageProfileNodeLabelKey] = "ephemeral"
} else {
labels[legacyStorageProfileNodeLabelKey] = "managed"
labels[storageProfileNodeLabelKey] = "managed"
}
if template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk != nil {
labels[legacyStorageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType)
labels[storageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType)
if template.VMSSNodeTemplate.OSDisk.ManagedDisk != nil {
labels[legacyStorageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType)
labels[storageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType)
}
// Add ephemeral-storage value
if template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB != nil {
node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI)
klog.V(4).Infof("OS Disk Size from template is: %d", *template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB)
if template.VMSSNodeTemplate.OSDisk.DiskSizeGB != nil {
node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VMSSNodeTemplate.OSDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI)
klog.V(4).Infof("OS Disk Size from template is: %d", *template.VMSSNodeTemplate.OSDisk.DiskSizeGB)
klog.V(4).Infof("Setting ephemeral storage to: %v", node.Status.Capacity[apiv1.ResourceEphemeralStorage])
}
}
// If we are on GPU-enabled SKUs, append the accelerator
// label so that CA makes better decision when scaling from zero for GPU pools
if isNvidiaEnabledSKU(*template.Sku.Name) {
if isNvidiaEnabledSKU(template.SkuName) {
labels[GPULabel] = "nvidia"
labels[legacyGPULabel] = "nvidia"
}
// Extract allocatables from tags
resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.Tags)
resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.VMSSNodeTemplate.Tags)
for resourceName, val := range resourcesFromTags {
node.Status.Capacity[apiv1.ResourceName(resourceName)] = *val
}
node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels)
klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels)
var taints []apiv1.Taint
// Prefer the explicit taints in spec over the VMSS template
if inputTaints != "" {
taints = extractTaintsFromSpecString(inputTaints)
// Prefer the explicit taints in spec over the tags from vmss or vm
if template.VMSSNodeTemplate.InputTaints != "" {
taints = extractTaintsFromSpecString(template.VMSSNodeTemplate.InputTaints)
} else {
taints = extractTaintsFromScaleSet(template.Tags)
taints = extractTaintsFromTags(template.VMSSNodeTemplate.Tags)
}
// Taints from the Scale Set's Tags
node.Spec.Taints = taints
klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints)
node.Status.Conditions = cloudprovider.BuildReadyConditions()
return &node, nil
return node
}
func buildInstanceOS(template compute.VirtualMachineScaleSet) string {
instanceOS := cloudprovider.DefaultOS
if template.VirtualMachineProfile != nil && template.VirtualMachineProfile.OsProfile != nil && template.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil {
instanceOS = "windows"
}
return instanceOS
}
func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string) map[string]string {
func buildGenericLabels(template NodeTemplate, nodeName string) map[string]string {
result := make(map[string]string)
result[kubeletapis.LabelArch] = cloudprovider.DefaultArch
result[apiv1.LabelArchStable] = cloudprovider.DefaultArch
result[kubeletapis.LabelOS] = buildInstanceOS(template)
result[apiv1.LabelOSStable] = buildInstanceOS(template)
result[kubeletapis.LabelOS] = template.InstanceOS
result[apiv1.LabelOSStable] = template.InstanceOS
result[apiv1.LabelInstanceType] = *template.Sku.Name
result[apiv1.LabelInstanceTypeStable] = *template.Sku.Name
result[apiv1.LabelZoneRegion] = strings.ToLower(*template.Location)
result[apiv1.LabelTopologyRegion] = strings.ToLower(*template.Location)
result[apiv1.LabelInstanceType] = template.SkuName
result[apiv1.LabelInstanceTypeStable] = template.SkuName
result[apiv1.LabelZoneRegion] = strings.ToLower(template.Location)
result[apiv1.LabelTopologyRegion] = strings.ToLower(template.Location)
if template.Zones != nil && len(*template.Zones) > 0 {
failureDomains := make([]string, len(*template.Zones))
for k, v := range *template.Zones {
failureDomains[k] = strings.ToLower(*template.Location) + "-" + v
if len(template.Zones) > 0 {
failureDomains := make([]string, len(template.Zones))
for k, v := range template.Zones {
failureDomains[k] = strings.ToLower(template.Location) + "-" + v
}
//Picks random zones for Multi-zone nodepool when scaling from zero.
//This random zone will not be the same as the zone of the VMSS that is being created, the purpose of creating
@ -283,7 +424,7 @@ func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string
return result
}
func extractLabelsFromScaleSet(tags map[string]*string) map[string]string {
func extractLabelsFromTags(tags map[string]*string) map[string]string {
result := make(map[string]string)
for tagName, tagValue := range tags {
@ -300,7 +441,7 @@ func extractLabelsFromScaleSet(tags map[string]*string) map[string]string {
return result
}
func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint {
func extractTaintsFromTags(tags map[string]*string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0)
for tagName, tagValue := range tags {
@ -327,35 +468,61 @@ func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint {
return taints
}
// extractTaintsFromSpecString is for nodepool taints
// Example of a valid taints string, is the same argument to kubelet's `--register-with-taints`
// "dedicated=foo:NoSchedule,group=bar:NoExecute,app=fizz:PreferNoSchedule"
func extractTaintsFromSpecString(taintsString string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0)
dedupMap := make(map[string]interface{})
// First split the taints at the separator
splits := strings.Split(taintsString, ",")
for _, split := range splits {
taintSplit := strings.Split(split, "=")
if len(taintSplit) != 2 {
if dedupMap[split] != nil {
continue
}
taintKey := taintSplit[0]
taintValue := taintSplit[1]
r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)")
if !r.MatchString(taintValue) {
continue
dedupMap[split] = struct{}{}
valid, taint := constructTaintFromString(split)
if valid {
taints = append(taints, taint)
}
}
return taints
}
values := strings.SplitN(taintValue, ":", 2)
taints = append(taints, apiv1.Taint{
Key: taintKey,
Value: values[0],
Effect: apiv1.TaintEffect(values[1]),
})
// buildNodeTaintsForVMPool is for VMPool taints, it looks for the taints in the format
// []string{zone=dmz:NoSchedule, usage=monitoring:NoSchedule}
func buildNodeTaintsForVMPool(taintStrs []string) []apiv1.Taint {
taints := make([]apiv1.Taint, 0)
for _, taintStr := range taintStrs {
valid, taint := constructTaintFromString(taintStr)
if valid {
taints = append(taints, taint)
}
}
return taints
}
// constructTaintFromString constructs a taint from a string in the format <key>=<value>:<effect>
// if the input string is not in the correct format, it returns false and an empty taint
func constructTaintFromString(taintString string) (bool, apiv1.Taint) {
taintSplit := strings.Split(taintString, "=")
if len(taintSplit) != 2 {
return false, apiv1.Taint{}
}
taintKey := taintSplit[0]
taintValue := taintSplit[1]
r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)")
if !r.MatchString(taintValue) {
return false, apiv1.Taint{}
}
return taints
values := strings.SplitN(taintValue, ":", 2)
return true, apiv1.Taint{
Key: taintKey,
Value: values[0],
Effect: apiv1.TaintEffect(values[1]),
}
}
func extractAutoscalingOptionsFromScaleSetTags(tags map[string]*string) map[string]string {

View File

@ -21,6 +21,7 @@ import (
"strings"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/to"
@ -30,7 +31,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
)
func TestExtractLabelsFromScaleSet(t *testing.T) {
func TestExtractLabelsFromTags(t *testing.T) {
expectedNodeLabelKey := "zip"
expectedNodeLabelValue := "zap"
extraNodeLabelValue := "buzz"
@ -52,14 +53,14 @@ func TestExtractLabelsFromScaleSet(t *testing.T) {
fmt.Sprintf("%s%s", nodeLabelTagName, escapedUnderscoreNodeLabelKey): &escapedUnderscoreNodeLabelValue,
}
labels := extractLabelsFromScaleSet(tags)
labels := extractLabelsFromTags(tags)
assert.Len(t, labels, 3)
assert.Equal(t, expectedNodeLabelValue, labels[expectedNodeLabelKey])
assert.Equal(t, escapedSlashNodeLabelValue, labels[expectedSlashEscapedNodeLabelKey])
assert.Equal(t, escapedUnderscoreNodeLabelValue, labels[expectedUnderscoreEscapedNodeLabelKey])
}
func TestExtractTaintsFromScaleSet(t *testing.T) {
func TestExtractTaintsFromTags(t *testing.T) {
noScheduleTaintValue := "foo:NoSchedule"
noExecuteTaintValue := "bar:NoExecute"
preferNoScheduleTaintValue := "fizz:PreferNoSchedule"
@ -100,7 +101,7 @@ func TestExtractTaintsFromScaleSet(t *testing.T) {
},
}
taints := extractTaintsFromScaleSet(tags)
taints := extractTaintsFromTags(tags)
assert.Len(t, taints, 4)
assert.Equal(t, makeTaintSet(expectedTaints), makeTaintSet(taints))
}
@ -137,6 +138,11 @@ func TestExtractTaintsFromSpecString(t *testing.T) {
Value: "fizz",
Effect: apiv1.TaintEffectPreferNoSchedule,
},
{
Key: "dedicated", // duplicate key, should be ignored
Value: "foo",
Effect: apiv1.TaintEffectNoSchedule,
},
}
taints := extractTaintsFromSpecString(strings.Join(taintsString, ","))
@ -176,8 +182,9 @@ func TestTopologyFromScaleSet(t *testing.T) {
Location: to.StringPtr("westus"),
}
expectedZoneValues := []string{"westus-1", "westus-2", "westus-3"}
labels := buildGenericLabels(testVmss, testNodeName)
template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "")
assert.NoError(t, err)
labels := buildGenericLabels(template, testNodeName)
failureDomain, ok := labels[apiv1.LabelZoneFailureDomain]
assert.True(t, ok)
topologyZone, ok := labels[apiv1.LabelTopologyZone]
@ -205,7 +212,9 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) {
expectedFailureDomain := "0"
expectedTopologyZone := "0"
expectedAzureDiskTopology := ""
labels := buildGenericLabels(testVmss, testNodeName)
template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "")
assert.NoError(t, err)
labels := buildGenericLabels(template, testNodeName)
failureDomain, ok := labels[apiv1.LabelZoneFailureDomain]
assert.True(t, ok)
@ -219,6 +228,61 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, expectedAzureDiskTopology, azureDiskTopology)
}
func TestBuildNodeTemplateFromVMPool(t *testing.T) {
agentPoolName := "testpool"
location := "eastus"
skuName := "Standard_DS2_v2"
labelKey := "foo"
labelVal := "bar"
taintStr := "dedicated=foo:NoSchedule,boo=fizz:PreferNoSchedule,group=bar:NoExecute"
osType := armcontainerservice.OSTypeLinux
osDiskType := armcontainerservice.OSDiskTypeEphemeral
zone1 := "1"
zone2 := "2"
vmpool := armcontainerservice.AgentPool{
Name: to.StringPtr(agentPoolName),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
NodeLabels: map[string]*string{
"existing": to.StringPtr("label"),
"department": to.StringPtr("engineering"),
},
NodeTaints: []*string{to.StringPtr("group=bar:NoExecute")},
OSType: &osType,
OSDiskType: &osDiskType,
AvailabilityZones: []*string{&zone1, &zone2},
},
}
labelsFromSpec := map[string]string{labelKey: labelVal}
taintsFromSpec := taintStr
template, err := buildNodeTemplateFromVMPool(vmpool, location, skuName, labelsFromSpec, taintsFromSpec)
assert.NoError(t, err)
assert.Equal(t, skuName, template.SkuName)
assert.Equal(t, location, template.Location)
assert.ElementsMatch(t, []string{zone1, zone2}, template.Zones)
assert.Equal(t, "linux", template.InstanceOS)
assert.NotNil(t, template.VMPoolNodeTemplate)
assert.Equal(t, agentPoolName, template.VMPoolNodeTemplate.AgentPoolName)
assert.Equal(t, &osDiskType, template.VMPoolNodeTemplate.OSDiskType)
// Labels: should include both from NodeLabels and labelsFromSpec
assert.Contains(t, template.VMPoolNodeTemplate.Labels, "existing")
assert.Equal(t, "label", *template.VMPoolNodeTemplate.Labels["existing"])
assert.Contains(t, template.VMPoolNodeTemplate.Labels, "department")
assert.Equal(t, "engineering", *template.VMPoolNodeTemplate.Labels["department"])
assert.Contains(t, template.VMPoolNodeTemplate.Labels, labelKey)
assert.Equal(t, labelVal, *template.VMPoolNodeTemplate.Labels[labelKey])
// Taints: should include both from NodeTaints and taintsFromSpec
taintSet := makeTaintSet(template.VMPoolNodeTemplate.Taints)
expectedTaints := []apiv1.Taint{
{Key: "group", Value: "bar", Effect: apiv1.TaintEffectNoExecute},
{Key: "dedicated", Value: "foo", Effect: apiv1.TaintEffectNoSchedule},
{Key: "boo", Value: "fizz", Effect: apiv1.TaintEffectPreferNoSchedule},
}
assert.Equal(t, makeTaintSet(expectedTaints), taintSet)
}
func makeTaintSet(taints []apiv1.Taint) map[apiv1.Taint]bool {
set := make(map[apiv1.Taint]bool)

View File

@ -18,142 +18,426 @@ package azure
import (
"fmt"
"net/http"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
apiv1 "k8s.io/api/core/v1"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"k8s.io/autoscaler/cluster-autoscaler/config"
"k8s.io/autoscaler/cluster-autoscaler/config/dynamic"
"k8s.io/autoscaler/cluster-autoscaler/simulator/framework"
klog "k8s.io/klog/v2"
)
// VMsPool is single instance VM pool
// this is a placeholder for now, no real implementation
type VMsPool struct {
// VMPool represents a group of standalone virtual machines (VMs) with a single SKU.
// It is part of a mixed-SKU agent pool (an agent pool with type `VirtualMachines`).
// Terminology:
// - Agent pool: A node pool in an AKS cluster.
// - VMs pool: An agent pool of type `VirtualMachines`, which can contain mixed SKUs.
// - VMPool: A subset of VMs within a VMs pool that share the same SKU.
type VMPool struct {
azureRef
manager *AzureManager
resourceGroup string
agentPoolName string // the virtual machines agentpool that this VMPool belongs to
sku string // sku of the VM in the pool
minSize int
maxSize int
curSize int64
// sizeMutex sync.Mutex
// lastSizeRefresh time.Time
}
// NewVMsPool creates a new VMsPool
func NewVMsPool(spec *dynamic.NodeGroupSpec, am *AzureManager) *VMsPool {
nodepool := &VMsPool{
azureRef: azureRef{
Name: spec.Name,
},
manager: am,
resourceGroup: am.config.ResourceGroup,
curSize: -1,
minSize: spec.MinSize,
maxSize: spec.MaxSize,
// NewVMPool creates a new VMPool - a pool of standalone VMs of a single size.
func NewVMPool(spec *dynamic.NodeGroupSpec, am *AzureManager, agentPoolName string, sku string) (*VMPool, error) {
if am.azClient.agentPoolClient == nil {
return nil, fmt.Errorf("agentPoolClient is nil")
}
return nodepool
nodepool := &VMPool{
azureRef: azureRef{
Name: spec.Name, // in format "<agentPoolName>/<sku>"
},
manager: am,
sku: sku,
agentPoolName: agentPoolName,
minSize: spec.MinSize,
maxSize: spec.MaxSize,
}
return nodepool, nil
}
// MinSize returns the minimum size the cluster is allowed to scaled down
// MinSize returns the minimum size the vmPool is allowed to scaled down
// to as provided by the node spec in --node parameter.
func (agentPool *VMsPool) MinSize() int {
return agentPool.minSize
func (vmPool *VMPool) MinSize() int {
return vmPool.minSize
}
// Exist is always true since we are initialized with an existing agentpool
func (agentPool *VMsPool) Exist() bool {
// Exist is always true since we are initialized with an existing vmPool
func (vmPool *VMPool) Exist() bool {
return true
}
// Create creates the node group on the cloud provider side.
func (agentPool *VMsPool) Create() (cloudprovider.NodeGroup, error) {
func (vmPool *VMPool) Create() (cloudprovider.NodeGroup, error) {
return nil, cloudprovider.ErrAlreadyExist
}
// Delete deletes the node group on the cloud provider side.
func (agentPool *VMsPool) Delete() error {
func (vmPool *VMPool) Delete() error {
return cloudprovider.ErrNotImplemented
}
// ForceDeleteNodes deletes nodes from the group regardless of constraints.
func (vmPool *VMPool) ForceDeleteNodes(nodes []*apiv1.Node) error {
return cloudprovider.ErrNotImplemented
}
// Autoprovisioned is always false since we are initialized with an existing agentpool
func (agentPool *VMsPool) Autoprovisioned() bool {
func (vmPool *VMPool) Autoprovisioned() bool {
return false
}
// GetOptions returns NodeGroupAutoscalingOptions that should be used for this particular
// NodeGroup. Returning a nil will result in using default options.
func (agentPool *VMsPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) {
// TODO(wenxuan): Implement this method
return nil, cloudprovider.ErrNotImplemented
func (vmPool *VMPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) {
// TODO(wenxuan): implement this method when vmPool can fully support GPU nodepool
return nil, nil
}
// MaxSize returns the maximum size scale limit provided by --node
// parameter to the autoscaler main
func (agentPool *VMsPool) MaxSize() int {
return agentPool.maxSize
func (vmPool *VMPool) MaxSize() int {
return vmPool.maxSize
}
// TargetSize returns the current TARGET size of the node group. It is possible that the
// number is different from the number of nodes registered in Kubernetes.
func (agentPool *VMsPool) TargetSize() (int, error) {
// TODO(wenxuan): Implement this method
return -1, cloudprovider.ErrNotImplemented
// TargetSize returns the current target size of the node group. This value represents
// the desired number of nodes in the VMPool, which may differ from the actual number
// of nodes currently present.
func (vmPool *VMPool) TargetSize() (int, error) {
// VMs in the "Deleting" state are not counted towards the target size.
size, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false})
return int(size), err
}
// IncreaseSize increase the size through a PUT AP call. It calculates the expected size
// based on a delta provided as parameter
func (agentPool *VMsPool) IncreaseSize(delta int) error {
// TODO(wenxuan): Implement this method
return cloudprovider.ErrNotImplemented
// IncreaseSize increases the size of the VMPool by sending a PUT request to update the agent pool.
// This method waits until the asynchronous PUT operation completes or the client-side timeout is reached.
func (vmPool *VMPool) IncreaseSize(delta int) error {
if delta <= 0 {
return fmt.Errorf("size increase must be positive, current delta: %d", delta)
}
// Skip VMs in the failed state so that a PUT AP will be triggered to fix the failed VMs.
currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: true})
if err != nil {
return err
}
if int(currentSize)+delta > vmPool.MaxSize() {
return fmt.Errorf("size-increasing request of %d is bigger than max size %d", int(currentSize)+delta, vmPool.MaxSize())
}
updateCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout)
defer cancel()
versionedAP, err := vmPool.getAgentpoolFromCache()
if err != nil {
klog.Errorf("Failed to get vmPool %s, error: %s", vmPool.agentPoolName, err)
return err
}
count := currentSize + int32(delta)
requestBody := armcontainerservice.AgentPool{}
// self-hosted CAS will be using Manual scale profile
if len(versionedAP.Properties.VirtualMachinesProfile.Scale.Manual) > 0 {
requestBody = buildRequestBodyForScaleUp(versionedAP, count, vmPool.sku)
} else { // AKS-managed CAS will use custom header for setting the target count
header := make(http.Header)
header.Set("Target-Count", fmt.Sprintf("%d", count))
updateCtx = policy.WithHTTPHeader(updateCtx, header)
}
defer vmPool.manager.invalidateCache()
poller, err := vmPool.manager.azClient.agentPoolClient.BeginCreateOrUpdate(
updateCtx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName,
requestBody, nil)
if err != nil {
klog.Errorf("Failed to scale up agentpool %s in cluster %s for vmPool %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, err)
return err
}
if _, err := poller.PollUntilDone(updateCtx, nil /*default polling interval is 30s*/); err != nil {
klog.Errorf("agentPoolClient.BeginCreateOrUpdate for aks cluster %s agentpool %s for scaling up vmPool %s failed with error %s",
vmPool.manager.config.ClusterName, vmPool.agentPoolName, vmPool.Name, err)
return err
}
klog.Infof("Successfully scaled up agentpool %s in cluster %s for vmPool %s to size %d",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, count)
return nil
}
// DeleteNodes extracts the providerIDs from the node spec and
// delete or deallocate the nodes from the agent pool based on the scale down policy.
func (agentPool *VMsPool) DeleteNodes(nodes []*apiv1.Node) error {
// TODO(wenxuan): Implement this method
return cloudprovider.ErrNotImplemented
// buildRequestBodyForScaleUp builds the request body for scale up for self-hosted CAS
func buildRequestBodyForScaleUp(agentpool armcontainerservice.AgentPool, count int32, vmSku string) armcontainerservice.AgentPool {
requestBody := armcontainerservice.AgentPool{
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: agentpool.Properties.Type,
},
}
// the request body must have the same mode as the original agentpool
// otherwise the PUT request will fail
if agentpool.Properties.Mode != nil &&
*agentpool.Properties.Mode == armcontainerservice.AgentPoolModeSystem {
systemMode := armcontainerservice.AgentPoolModeSystem
requestBody.Properties.Mode = &systemMode
}
// set the count of the matching manual scale profile to the new target value
for _, manualProfile := range agentpool.Properties.VirtualMachinesProfile.Scale.Manual {
if manualProfile != nil && len(manualProfile.Sizes) == 1 &&
strings.EqualFold(to.String(manualProfile.Sizes[0]), vmSku) {
klog.V(5).Infof("Found matching manual profile for VM SKU: %s, updating count to: %d", vmSku, count)
manualProfile.Count = to.Int32Ptr(count)
requestBody.Properties.VirtualMachinesProfile = agentpool.Properties.VirtualMachinesProfile
break
}
}
return requestBody
}
// ForceDeleteNodes deletes nodes from the group regardless of constraints.
func (agentPool *VMsPool) ForceDeleteNodes(nodes []*apiv1.Node) error {
return cloudprovider.ErrNotImplemented
// DeleteNodes removes the specified nodes from the VMPool by extracting their providerIDs
// and performing the appropriate delete or deallocate operation based on the agent pool's
// scale-down policy. This method waits for the asynchronous delete operation to complete,
// with a client-side timeout.
func (vmPool *VMPool) DeleteNodes(nodes []*apiv1.Node) error {
// Ensure we don't scale below the minimum size by excluding VMs in the "Deleting" state.
currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false})
if err != nil {
return fmt.Errorf("unable to retrieve current size: %w", err)
}
if int(currentSize) <= vmPool.MinSize() {
return fmt.Errorf("cannot delete nodes as minimum size of %d has been reached", vmPool.MinSize())
}
providerIDs, err := vmPool.getProviderIDsForNodes(nodes)
if err != nil {
return fmt.Errorf("failed to retrieve provider IDs for nodes: %w", err)
}
if len(providerIDs) == 0 {
return nil
}
klog.V(3).Infof("Deleting nodes from vmPool %s: %v", vmPool.Name, providerIDs)
machineNames := make([]*string, len(providerIDs))
for i, providerID := range providerIDs {
// extract the machine name from the providerID by splitting the providerID by '/' and get the last element
// The providerID look like this:
// "azure:///subscriptions/0000000-0000-0000-0000-00000000000/resourceGroups/mc_myrg_mycluster_eastus/providers/Microsoft.Compute/virtualMachines/aks-mypool-12345678-vms0"
machineName, err := resourceName(providerID)
if err != nil {
return err
}
machineNames[i] = &machineName
}
requestBody := armcontainerservice.AgentPoolDeleteMachinesParameter{
MachineNames: machineNames,
}
deleteCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout)
defer cancel()
defer vmPool.manager.invalidateCache()
poller, err := vmPool.manager.azClient.agentPoolClient.BeginDeleteMachines(
deleteCtx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName,
requestBody, nil)
if err != nil {
klog.Errorf("Failed to delete nodes from agentpool %s in cluster %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, err)
return err
}
if _, err := poller.PollUntilDone(deleteCtx, nil); err != nil {
klog.Errorf("agentPoolClient.BeginDeleteMachines for aks cluster %s for scaling down vmPool %s failed with error %s",
vmPool.manager.config.ClusterName, vmPool.agentPoolName, err)
return err
}
klog.Infof("Successfully deleted %d nodes from vmPool %s", len(providerIDs), vmPool.Name)
return nil
}
func (vmPool *VMPool) getProviderIDsForNodes(nodes []*apiv1.Node) ([]string, error) {
var providerIDs []string
for _, node := range nodes {
belongs, err := vmPool.Belongs(node)
if err != nil {
return nil, fmt.Errorf("failed to check if node %s belongs to vmPool %s: %w", node.Name, vmPool.Name, err)
}
if !belongs {
return nil, fmt.Errorf("node %s does not belong to vmPool %s", node.Name, vmPool.Name)
}
providerIDs = append(providerIDs, node.Spec.ProviderID)
}
return providerIDs, nil
}
// Belongs returns true if the given k8s node belongs to this vms nodepool.
func (vmPool *VMPool) Belongs(node *apiv1.Node) (bool, error) {
klog.V(6).Infof("Check if node belongs to this vmPool:%s, node:%v\n", vmPool, node)
ref := &azureRef{
Name: node.Spec.ProviderID,
}
nodeGroup, err := vmPool.manager.GetNodeGroupForInstance(ref)
if err != nil {
return false, err
}
if nodeGroup == nil {
return false, fmt.Errorf("%s doesn't belong to a known node group", node.Name)
}
if !strings.EqualFold(nodeGroup.Id(), vmPool.Id()) {
return false, nil
}
return true, nil
}
// DecreaseTargetSize decreases the target size of the node group.
func (agentPool *VMsPool) DecreaseTargetSize(delta int) error {
// TODO(wenxuan): Implement this method
return cloudprovider.ErrNotImplemented
func (vmPool *VMPool) DecreaseTargetSize(delta int) error {
// The TargetSize of a VMPool is automatically adjusted after node deletions.
// This method is invoked in scenarios such as (see details in clusterstate.go):
// - len(readiness.Registered) > acceptableRange.CurrentTarget
// - len(readiness.Registered) < acceptableRange.CurrentTarget - unregisteredNodes
// For VMPool, this method should not be called because:
// CurrentTarget = len(readiness.Registered) + unregisteredNodes - len(nodesInDeletingState)
// Here, nodesInDeletingState is a subset of unregisteredNodes,
// ensuring len(readiness.Registered) is always within the acceptable range.
// here we just invalidate the cache to avoid any potential bugs
vmPool.manager.invalidateCache()
klog.Warningf("DecreaseTargetSize called for VMPool %s, but it should not be used, invalidating cache", vmPool.Name)
return nil
}
// Id returns the name of the agentPool
func (agentPool *VMsPool) Id() string {
return agentPool.azureRef.Name
// Id returns the name of the agentPool, it is in the format of <agentpoolname>/<sku>
// e.g. mypool1/Standard_D2s_v3
func (vmPool *VMPool) Id() string {
return vmPool.azureRef.Name
}
// Debug returns a string with basic details of the agentPool
func (agentPool *VMsPool) Debug() string {
return fmt.Sprintf("%s (%d:%d)", agentPool.Id(), agentPool.MinSize(), agentPool.MaxSize())
func (vmPool *VMPool) Debug() string {
return fmt.Sprintf("%s (%d:%d)", vmPool.Id(), vmPool.MinSize(), vmPool.MaxSize())
}
func (agentPool *VMsPool) getVMsFromCache() ([]compute.VirtualMachine, error) {
// vmsPoolMap is a map of agent pool name to the list of virtual machines
vmsPoolMap := agentPool.manager.azureCache.getVirtualMachines()
if _, ok := vmsPoolMap[agentPool.Name]; !ok {
return []compute.VirtualMachine{}, fmt.Errorf("vms pool %s not found in the cache", agentPool.Name)
func isSpotAgentPool(ap armcontainerservice.AgentPool) bool {
if ap.Properties != nil && ap.Properties.ScaleSetPriority != nil {
return strings.EqualFold(string(*ap.Properties.ScaleSetPriority), "Spot")
}
return false
}
// skipOption is used to determine whether to skip VMs in certain states when calculating the current size of the vmPool.
type skipOption struct {
// skipDeleting indicates whether to skip VMs in the "Deleting" state.
skipDeleting bool
// skipFailed indicates whether to skip VMs in the "Failed" state.
skipFailed bool
}
// getCurSize determines the current count of VMs in the vmPool, including unregistered ones.
// The source of truth depends on the pool type (spot or non-spot).
func (vmPool *VMPool) getCurSize(op skipOption) (int32, error) {
agentPool, err := vmPool.getAgentpoolFromCache()
if err != nil {
klog.Errorf("Failed to retrieve agent pool %s from cache: %v", vmPool.agentPoolName, err)
return -1, err
}
return vmsPoolMap[agentPool.Name], nil
// spot pool size is retrieved directly from Azure instead of the cache
if isSpotAgentPool(agentPool) {
return vmPool.getSpotPoolSize()
}
// non-spot pool size is retrieved from the cache
vms, err := vmPool.getVMsFromCache(op)
if err != nil {
klog.Errorf("Failed to get VMs from cache for agentpool %s with error: %v", vmPool.agentPoolName, err)
return -1, err
}
return int32(len(vms)), nil
}
// getSpotPoolSize retrieves the current size of a spot agent pool directly from Azure.
func (vmPool *VMPool) getSpotPoolSize() (int32, error) {
ap, err := vmPool.getAgentpoolFromAzure()
if err != nil {
klog.Errorf("Failed to get agentpool %s from Azure with error: %v", vmPool.agentPoolName, err)
return -1, err
}
if ap.Properties != nil {
// the VirtualMachineNodesStatus returned by AKS-RP is constructed from the vm list returned from CRP.
// it only contains VMs in the running state.
for _, status := range ap.Properties.VirtualMachineNodesStatus {
if status != nil {
if strings.EqualFold(to.String(status.Size), vmPool.sku) {
return to.Int32(status.Count), nil
}
}
}
}
return -1, fmt.Errorf("failed to get the size of spot agentpool %s", vmPool.agentPoolName)
}
// getVMsFromCache retrieves the list of virtual machines in this VMPool.
// If excludeDeleting is true, it skips VMs in the "Deleting" state.
// https://learn.microsoft.com/en-us/azure/virtual-machines/states-billing#provisioning-states
func (vmPool *VMPool) getVMsFromCache(op skipOption) ([]compute.VirtualMachine, error) {
vmsMap := vmPool.manager.azureCache.getVirtualMachines()
var filteredVMs []compute.VirtualMachine
for _, vm := range vmsMap[vmPool.agentPoolName] {
if vm.VirtualMachineProperties == nil ||
vm.VirtualMachineProperties.HardwareProfile == nil ||
!strings.EqualFold(string(vm.HardwareProfile.VMSize), vmPool.sku) {
continue
}
if op.skipDeleting && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Deleting") {
klog.V(4).Infof("Skipping VM %s in deleting state", to.String(vm.ID))
continue
}
if op.skipFailed && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Failed") {
klog.V(4).Infof("Skipping VM %s in failed state", to.String(vm.ID))
continue
}
filteredVMs = append(filteredVMs, vm)
}
return filteredVMs, nil
}
// Nodes returns the list of nodes in the vms agentPool.
func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) {
vms, err := agentPool.getVMsFromCache()
func (vmPool *VMPool) Nodes() ([]cloudprovider.Instance, error) {
vms, err := vmPool.getVMsFromCache(skipOption{}) // no skip option, get all VMs
if err != nil {
return nil, err
}
@ -163,7 +447,7 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) {
if vm.ID == nil || len(*vm.ID) == 0 {
continue
}
resourceID, err := convertResourceGroupNameToLower("azure://" + *vm.ID)
resourceID, err := convertResourceGroupNameToLower("azure://" + to.String(vm.ID))
if err != nil {
return nil, err
}
@ -173,12 +457,53 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) {
return nodes, nil
}
// TemplateNodeInfo is not implemented.
func (agentPool *VMsPool) TemplateNodeInfo() (*framework.NodeInfo, error) {
return nil, cloudprovider.ErrNotImplemented
// TemplateNodeInfo returns a NodeInfo object that can be used to create a new node in the vmPool.
func (vmPool *VMPool) TemplateNodeInfo() (*framework.NodeInfo, error) {
ap, err := vmPool.getAgentpoolFromCache()
if err != nil {
return nil, err
}
inputLabels := map[string]string{}
inputTaints := ""
template, err := buildNodeTemplateFromVMPool(ap, vmPool.manager.config.Location, vmPool.sku, inputLabels, inputTaints)
if err != nil {
return nil, err
}
node, err := buildNodeFromTemplate(vmPool.agentPoolName, template, vmPool.manager, vmPool.manager.config.EnableDynamicInstanceList)
if err != nil {
return nil, err
}
nodeInfo := framework.NewNodeInfo(node, nil, &framework.PodInfo{Pod: cloudprovider.BuildKubeProxy(vmPool.agentPoolName)})
return nodeInfo, nil
}
func (vmPool *VMPool) getAgentpoolFromCache() (armcontainerservice.AgentPool, error) {
vmsPoolMap := vmPool.manager.azureCache.getVMsPoolMap()
if _, exists := vmsPoolMap[vmPool.agentPoolName]; !exists {
return armcontainerservice.AgentPool{}, fmt.Errorf("VMs agent pool %s not found in cache", vmPool.agentPoolName)
}
return vmsPoolMap[vmPool.agentPoolName], nil
}
// getAgentpoolFromAzure returns the AKS agentpool from Azure
func (vmPool *VMPool) getAgentpoolFromAzure() (armcontainerservice.AgentPool, error) {
ctx, cancel := getContextWithTimeout(vmsContextTimeout)
defer cancel()
resp, err := vmPool.manager.azClient.agentPoolClient.Get(
ctx,
vmPool.manager.config.ClusterResourceGroup,
vmPool.manager.config.ClusterName,
vmPool.agentPoolName, nil)
if err != nil {
return resp.AgentPool, fmt.Errorf("failed to get agentpool %s in cluster %s with error: %v",
vmPool.agentPoolName, vmPool.manager.config.ClusterName, err)
}
return resp.AgentPool, nil
}
// AtomicIncreaseSize is not implemented.
func (agentPool *VMsPool) AtomicIncreaseSize(delta int) error {
func (vmPool *VMPool) AtomicIncreaseSize(delta int) error {
return cloudprovider.ErrNotImplemented
}

View File

@ -17,45 +17,64 @@ limitations under the License.
package azure
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/go-autorest/autorest/to"
"go.uber.org/mock/gomock"
"github.com/stretchr/testify/assert"
apiv1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"k8s.io/autoscaler/cluster-autoscaler/config"
"k8s.io/autoscaler/cluster-autoscaler/config/dynamic"
providerazure "sigs.k8s.io/cloud-provider-azure/pkg/provider"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient"
)
func newTestVMsPool(manager *AzureManager, name string) *VMsPool {
return &VMsPool{
const (
vmSku = "Standard_D2_v2"
vmsAgentPoolName = "test-vms-pool"
vmsNodeGroupName = vmsAgentPoolName + "/" + vmSku
fakeVMsNodeName = "aks-" + vmsAgentPoolName + "-13222729-vms%d"
fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/" + fakeVMsNodeName
)
func newTestVMsPool(manager *AzureManager) *VMPool {
return &VMPool{
azureRef: azureRef{
Name: name,
Name: vmsNodeGroupName,
},
manager: manager,
minSize: 3,
maxSize: 10,
manager: manager,
minSize: 3,
maxSize: 10,
agentPoolName: vmsAgentPoolName,
sku: vmSku,
}
}
const (
fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/%d"
)
func newTestVMsPoolVMList(count int) []compute.VirtualMachine {
var vmList []compute.VirtualMachine
for i := 0; i < count; i++ {
vm := compute.VirtualMachine{
ID: to.StringPtr(fmt.Sprintf(fakeVMsPoolVMID, i)),
VirtualMachineProperties: &compute.VirtualMachineProperties{
VMID: to.StringPtr(fmt.Sprintf("123E4567-E89B-12D3-A456-426655440000-%d", i)),
HardwareProfile: &compute.HardwareProfile{
VMSize: compute.VirtualMachineSizeTypes(vmSku),
},
ProvisioningState: to.StringPtr("Succeeded"),
},
Tags: map[string]*string{
agentpoolTypeTag: to.StringPtr("VirtualMachines"),
agentpoolNameTag: to.StringPtr("test-vms-pool"),
agentpoolNameTag: to.StringPtr(vmsAgentPoolName),
},
}
vmList = append(vmList, vm)
@ -63,41 +82,73 @@ func newTestVMsPoolVMList(count int) []compute.VirtualMachine {
return vmList
}
func newVMsNode(vmID int64) *apiv1.Node {
node := &apiv1.Node{
func newVMsNode(vmIdx int64) *apiv1.Node {
return &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf(fakeVMsNodeName, vmIdx),
},
Spec: apiv1.NodeSpec{
ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmID),
ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmIdx),
},
}
return node
}
func TestNewVMsPool(t *testing.T) {
spec := &dynamic.NodeGroupSpec{
Name: "test-nodepool",
MinSize: 1,
MaxSize: 5,
func getTestVMsAgentPool(isSystemPool bool) armcontainerservice.AgentPool {
mode := armcontainerservice.AgentPoolModeUser
if isSystemPool {
mode = armcontainerservice.AgentPoolModeSystem
}
am := &AzureManager{
config: &Config{
Config: providerazure.Config{
ResourceGroup: "test-resource-group",
vmsPoolType := armcontainerservice.AgentPoolTypeVirtualMachines
return armcontainerservice.AgentPool{
Name: to.StringPtr(vmsAgentPoolName),
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
Type: &vmsPoolType,
Mode: &mode,
VirtualMachinesProfile: &armcontainerservice.VirtualMachinesProfile{
Scale: &armcontainerservice.ScaleProfile{
Manual: []*armcontainerservice.ManualScaleProfile{
{
Count: to.Int32Ptr(3),
Sizes: []*string{to.StringPtr(vmSku)},
},
},
},
},
VirtualMachineNodesStatus: []*armcontainerservice.VirtualMachineNodes{
{
Count: to.Int32Ptr(3),
Size: to.StringPtr(vmSku),
},
},
},
}
}
nodepool := NewVMsPool(spec, am)
func TestNewVMsPool(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
manager := newTestAzureManager(t)
manager.azClient.agentPoolClient = mockAgentpoolclient
manager.config.ResourceGroup = "MC_rg"
manager.config.ClusterResourceGroup = "rg"
manager.config.ClusterName = "mycluster"
assert.Equal(t, "test-nodepool", nodepool.azureRef.Name)
assert.Equal(t, "test-resource-group", nodepool.resourceGroup)
assert.Equal(t, int64(-1), nodepool.curSize)
assert.Equal(t, 1, nodepool.minSize)
assert.Equal(t, 5, nodepool.maxSize)
assert.Equal(t, am, nodepool.manager)
spec := &dynamic.NodeGroupSpec{
Name: vmsAgentPoolName,
MinSize: 1,
MaxSize: 10,
}
ap, err := NewVMPool(spec, manager, vmsAgentPoolName, vmSku)
assert.NoError(t, err)
assert.Equal(t, vmsAgentPoolName, ap.azureRef.Name)
assert.Equal(t, 1, ap.minSize)
assert.Equal(t, 10, ap.maxSize)
}
func TestMinSize(t *testing.T) {
agentPool := &VMsPool{
agentPool := &VMPool{
minSize: 1,
}
@ -105,12 +156,12 @@ func TestMinSize(t *testing.T) {
}
func TestExist(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
assert.True(t, agentPool.Exist())
}
func TestCreate(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
nodeGroup, err := agentPool.Create()
assert.Nil(t, nodeGroup)
@ -118,65 +169,43 @@ func TestCreate(t *testing.T) {
}
func TestDelete(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
err := agentPool.Delete()
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestAutoprovisioned(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
assert.False(t, agentPool.Autoprovisioned())
}
func TestGetOptions(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
defaults := config.NodeGroupAutoscalingOptions{}
options, err := agentPool.GetOptions(defaults)
assert.Nil(t, options)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
assert.Nil(t, err)
}
func TestMaxSize(t *testing.T) {
agentPool := &VMsPool{
agentPool := &VMPool{
maxSize: 10,
}
assert.Equal(t, 10, agentPool.MaxSize())
}
func TestTargetSize(t *testing.T) {
agentPool := &VMsPool{}
size, err := agentPool.TargetSize()
assert.Equal(t, -1, size)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestIncreaseSize(t *testing.T) {
agentPool := &VMsPool{}
err := agentPool.IncreaseSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestDeleteNodes(t *testing.T) {
agentPool := &VMsPool{}
err := agentPool.DeleteNodes(nil)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
func TestDecreaseTargetSize(t *testing.T) {
agentPool := &VMsPool{}
agentPool := newTestVMsPool(newTestAzureManager(t))
err := agentPool.DecreaseTargetSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
assert.Nil(t, err)
}
func TestId(t *testing.T) {
agentPool := &VMsPool{
agentPool := &VMPool{
azureRef: azureRef{
Name: "test-id",
},
@ -186,7 +215,7 @@ func TestId(t *testing.T) {
}
func TestDebug(t *testing.T) {
agentPool := &VMsPool{
agentPool := &VMPool{
azureRef: azureRef{
Name: "test-debug",
},
@ -198,115 +227,341 @@ func TestDebug(t *testing.T) {
assert.Equal(t, expectedDebugString, agentPool.Debug())
}
func TestTemplateNodeInfo(t *testing.T) {
agentPool := &VMsPool{}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
nodeInfo, err := agentPool.TemplateNodeInfo()
assert.Nil(t, nodeInfo)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
ap := newTestVMsPool(newTestAzureManager(t))
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
nodeInfo, err := ap.TemplateNodeInfo()
assert.NotNil(t, nodeInfo)
assert.Nil(t, err)
}
func TestAtomicIncreaseSize(t *testing.T) {
agentPool := &VMsPool{}
agentPool := &VMPool{}
err := agentPool.AtomicIncreaseSize(1)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}
// Test cases for getVMsFromCache()
// Test case 1 - when the vms pool is not found in the cache
// Test case 2 - when the vms pool is found in the cache but has no VMs
// Test case 3 - when the vms pool is found in the cache and has VMs
// Test case 4 - when the vms pool is found in the cache and has VMs with no name
func TestGetVMsFromCache(t *testing.T) {
// Test case 1
manager := &AzureManager{
azureCache: &azureCache{
virtualMachines: make(map[string][]compute.VirtualMachine),
vmsPoolMap: make(map[string]armcontainerservice.AgentPool),
},
}
agentPool := &VMsPool{
manager: manager,
azureRef: azureRef{
Name: "test-vms-pool",
},
agentPool := &VMPool{
manager: manager,
agentPoolName: vmsAgentPoolName,
sku: vmSku,
}
_, err := agentPool.getVMsFromCache()
assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache")
// Test case 1 - when the vms pool is not found in the cache
vms, err := agentPool.getVMsFromCache(skipOption{})
assert.Nil(t, err)
assert.Len(t, vms, 0)
// Test case 2
manager.azureCache.virtualMachines["test-vms-pool"] = []compute.VirtualMachine{}
_, err = agentPool.getVMsFromCache()
// Test case 2 - when the vms pool is found in the cache but has no VMs
manager.azureCache.virtualMachines[vmsAgentPoolName] = []compute.VirtualMachine{}
vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err)
assert.Len(t, vms, 0)
// Test case 3
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
vms, err := agentPool.getVMsFromCache()
// Test case 3 - when the vms pool is found in the cache and has VMs
manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3)
vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err)
assert.Len(t, vms, 3)
// Test case 4
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
agentPool.azureRef.Name = ""
_, err = agentPool.getVMsFromCache()
assert.EqualError(t, err, "vms pool not found in the cache")
// Test case 4 - should skip failed VMs
vmList := newTestVMsPoolVMList(3)
vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Failed")
manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true})
assert.NoError(t, err)
assert.Len(t, vms, 2)
// Test case 5 - should skip deleting VMs
vmList = newTestVMsPoolVMList(3)
vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting")
manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
vms, err = agentPool.getVMsFromCache(skipOption{skipDeleting: true})
assert.NoError(t, err)
assert.Len(t, vms, 2)
// Test case 6 - should not skip deleting VMs
vmList = newTestVMsPoolVMList(3)
vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting")
manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList
vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true})
assert.NoError(t, err)
assert.Len(t, vms, 3)
// Test case 7 - when the vms pool is found in the cache and has VMs with no name
manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3)
agentPool.agentPoolName = ""
vms, err = agentPool.getVMsFromCache(skipOption{})
assert.NoError(t, err)
assert.Len(t, vms, 0)
}
func TestGetVMsFromCacheForVMsPool(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(2)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ac.enableVMsAgentPool = true
ap.manager.azureCache = ac
vms, err := ap.getVMsFromCache(skipOption{})
assert.Equal(t, 2, len(vms))
assert.NoError(t, err)
}
// Test cases for Nodes()
// Test case 1 - when there are no VMs in the pool
// Test case 2 - when there are VMs in the pool
// Test case 3 - when there are VMs in the pool with no ID
// Test case 4 - when there is an error converting resource group name
// Test case 5 - when there is an error getting VMs from cache
func TestNodes(t *testing.T) {
// Test case 1
manager := &AzureManager{
azureCache: &azureCache{
virtualMachines: make(map[string][]compute.VirtualMachine),
},
}
agentPool := &VMsPool{
manager: manager,
azureRef: azureRef{
Name: "test-vms-pool",
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
nodes, err := agentPool.Nodes()
assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache")
assert.Empty(t, nodes)
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(2)
// Test case 2
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
nodes, err = agentPool.Nodes()
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
assert.Len(t, nodes, 3)
ap.manager.azureCache = ac
// Test case 3
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = nil
nodes, err = agentPool.Nodes()
vms, err := ap.Nodes()
assert.Equal(t, 2, len(vms))
assert.NoError(t, err)
assert.Len(t, nodes, 2)
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
emptyString := ""
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &emptyString
nodes, err = agentPool.Nodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
// Test case 4
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
bogusID := "foo"
manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &bogusID
nodes, err = agentPool.Nodes()
assert.Empty(t, nodes)
assert.Error(t, err)
// Test case 5
manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(1)
agentPool.azureRef.Name = ""
nodes, err = agentPool.Nodes()
assert.Empty(t, nodes)
assert.Error(t, err)
}
func TestGetCurSizeForVMsPool(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(3)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
curSize, err := ap.getCurSize(skipOption{})
assert.NoError(t, err)
assert.Equal(t, int32(3), curSize)
}
func TestVMsPoolIncreaseSize(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
manager := newTestAzureManager(t)
ap := newTestVMsPool(manager)
expectedVMs := newTestVMsPoolVMList(3)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
agentpool := getTestVMsAgentPool(false)
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).
Return(fakeAPListPager)
ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
assert.NoError(t, err)
ap.manager.azureCache = ac
// failure case 1
err1 := ap.IncreaseSize(-1)
expectedErr := fmt.Errorf("size increase must be positive, current delta: -1")
assert.Equal(t, expectedErr, err1)
// failure case 2
err2 := ap.IncreaseSize(8)
expectedErr = fmt.Errorf("size-increasing request of 11 is bigger than max size 10")
assert.Equal(t, expectedErr, err2)
// success case 3
resp := &http.Response{
Header: map[string][]string{
"Fake-Poller-Status": {"Done"},
},
}
fakePoller, pollerErr := runtime.NewPoller(resp, runtime.Pipeline{},
&runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{
Handler: &fakehandler[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{},
})
assert.NoError(t, pollerErr)
mockAgentpoolclient.EXPECT().BeginCreateOrUpdate(
gomock.Any(), manager.config.ClusterResourceGroup,
manager.config.ClusterName,
vmsAgentPoolName,
gomock.Any(), gomock.Any()).Return(fakePoller, nil)
err3 := ap.IncreaseSize(1)
assert.NoError(t, err3)
}
func TestDeleteVMsPoolNodes_Failed(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
node := newVMsNode(0)
expectedVMs := newTestVMsPoolVMList(3)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
agentpool := getTestVMsAgentPool(false)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.azureCache.enableVMsAgentPool = true
registered := ap.manager.RegisterNodeGroup(ap)
assert.True(t, registered)
ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
ap.manager.forceRefresh()
// failure case
deleteErr := ap.DeleteNodes([]*apiv1.Node{node})
assert.Error(t, deleteErr)
assert.Contains(t, deleteErr.Error(), "cannot delete nodes as minimum size of 3 has been reached")
}
func TestDeleteVMsPoolNodes_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ap := newTestVMsPool(newTestAzureManager(t))
expectedVMs := newTestVMsPoolVMList(5)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
ap.manager.azClient.virtualMachinesClient = mockVMClient
ap.manager.config.EnableVMsAgentPool = true
mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
agentpool := getTestVMsAgentPool(false)
ap.manager.azClient.agentPoolClient = mockAgentpoolclient
fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
ap.manager.azureCache.enableVMsAgentPool = true
registered := ap.manager.RegisterNodeGroup(ap)
assert.True(t, registered)
ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
ap.manager.forceRefresh()
// success case
resp := &http.Response{
Header: map[string][]string{
"Fake-Poller-Status": {"Done"},
},
}
fakePoller, err := runtime.NewPoller(resp, runtime.Pipeline{},
&runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{
Handler: &fakehandler[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{},
})
assert.NoError(t, err)
mockAgentpoolclient.EXPECT().BeginDeleteMachines(
gomock.Any(), ap.manager.config.ClusterResourceGroup,
ap.manager.config.ClusterName,
vmsAgentPoolName,
gomock.Any(), gomock.Any()).Return(fakePoller, nil)
node := newVMsNode(0)
derr := ap.DeleteNodes([]*apiv1.Node{node})
assert.NoError(t, derr)
}
type fakehandler[T any] struct{}
func (f *fakehandler[T]) Done() bool {
return true
}
func (f *fakehandler[T]) Poll(ctx context.Context) (*http.Response, error) {
return nil, nil
}
func (f *fakehandler[T]) Result(ctx context.Context, out *T) error {
return nil
}
func getFakeAgentpoolListPager(agentpool ...*armcontainerservice.AgentPool) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
fakeFetcher := func(ctx context.Context, response *armcontainerservice.AgentPoolsClientListResponse) (armcontainerservice.AgentPoolsClientListResponse, error) {
return armcontainerservice.AgentPoolsClientListResponse{
AgentPoolListResult: armcontainerservice.AgentPoolListResult{
Value: agentpool,
},
}, nil
}
return runtime.NewPager(runtime.PagingHandler[armcontainerservice.AgentPoolsClientListResponse]{
More: func(response armcontainerservice.AgentPoolsClientListResponse) bool {
return false
},
Fetcher: fakeFetcher,
})
}

View File

@ -214,6 +214,11 @@ autoscaler about the sizing of the nodes in the node group. At the minimum,
you must specify the CPU and memory annotations, these annotations should
match the expected capacity of the nodes created from the infrastructure.
> Note: The scale from zero annotations will override any capacity information
> supplied by the Cluster API provider in the infrastructure machine templates.
> If both the annotations and the provider supplied capacity information are
> present, the annotations will take precedence.
For example, if my MachineDeployment will create nodes that have "16000m" CPU,
"128G" memory, "100Gi" ephemeral disk storage, 2 NVidia GPUs, and can support
200 max pods, the following annotations will instruct the autoscaler how to
@ -240,11 +245,12 @@ metadata:
capacity.cluster-autoscaler.kubernetes.io/gpu-count: "2"
```
*Note* the `maxPods` annotation will default to `110` if it is not supplied.
This value is inspired by the Kubernetes best practices
[Considerations for large clusters](https://kubernetes.io/docs/setup/best-practices/cluster-large/).
> Note: the `maxPods` annotation will default to `110` if it is not supplied.
> This value is inspired by the Kubernetes best practices
> [Considerations for large clusters](https://kubernetes.io/docs/setup/best-practices/cluster-large/).
*Note* User should select the annotation for GPU either `gpu-type` or `dra-driver` depends on whether using Device Plugin or Dynamic Resource Allocation(DRA). `gpu-count` is a common parameter in both.
> Note: User should select the annotation for GPU either `gpu-type` or `dra-driver` depends on whether using
> Device Plugin or Dynamic Resource Allocation(DRA). `gpu-count` is a common parameter in both.
#### RBAC changes for scaling from zero
@ -289,6 +295,12 @@ metadata:
capacity.cluster-autoscaler.kubernetes.io/taints: "key1=value1:NoSchedule,key2=value2:NoExecute"
```
> Note: The labels supplied through the capacity annotation will be combined
> with the labels to be propagated from the scalable Cluster API resource.
> The annotation does not override the labels in the scalable resource.
> Please see the [Cluster API Book chapter on Metadata propagation](https://cluster-api.sigs.k8s.io/reference/api/metadata-propagation)
> for more information.
#### Per-NodeGroup autoscaling options
Custom autoscaling options per node group (MachineDeployment/MachinePool/MachineSet) can be specified as annoations with a common prefix:

View File

@ -6,6 +6,7 @@ import (
"net/url"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -54,9 +55,21 @@ const (
type ActionError struct {
Code string
Message string
action *Action
}
// Action returns the [Action] that triggered the error if available.
func (e ActionError) Action() *Action {
return e.action
}
func (e ActionError) Error() string {
action := e.Action()
if action != nil {
// For easier debugging, the error string contains the Action ID.
return fmt.Sprintf("%s (%s, %d)", e.Message, e.Code, action.ID)
}
return fmt.Sprintf("%s (%s)", e.Message, e.Code)
}
@ -65,6 +78,7 @@ func (a *Action) Error() error {
return ActionError{
Code: a.ErrorCode,
Message: a.ErrorMessage,
action: a,
}
}
return nil
@ -111,11 +125,15 @@ func (c *ActionClient) List(ctx context.Context, opts ActionListOpts) ([]*Action
}
// All returns all actions.
//
// Deprecated: It is required to pass in a list of IDs since 30 January 2025. Please use [ActionClient.AllWithOpts] instead.
func (c *ActionClient) All(ctx context.Context) ([]*Action, error) {
return c.action.All(ctx, ActionListOpts{ListOpts: ListOpts{PerPage: 50}})
}
// AllWithOpts returns all actions for the given options.
//
// It is required to set [ActionListOpts.ID]. Any other fields set in the opts are ignored.
func (c *ActionClient) AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error) {
return c.action.All(ctx, opts)
}
@ -136,20 +154,19 @@ func (c *ResourceActionClient) getBaseURL() string {
// GetByID retrieves an action by its ID. If the action does not exist, nil is returned.
func (c *ResourceActionClient) GetByID(ctx context.Context, id int64) (*Action, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("%s/actions/%d", c.getBaseURL(), id), nil)
if err != nil {
return nil, nil, err
}
opPath := c.getBaseURL() + "/actions/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.ActionGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.ActionGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return ActionFromSchema(body.Action), resp, nil
return ActionFromSchema(respBody.Action), resp, nil
}
// List returns a list of actions for a specific page.
@ -157,44 +174,23 @@ func (c *ResourceActionClient) GetByID(ctx context.Context, id int64) (*Action,
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *ResourceActionClient) List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error) {
req, err := c.client.NewRequest(
ctx,
"GET",
fmt.Sprintf("%s/actions?%s", c.getBaseURL(), opts.values().Encode()),
nil,
)
opPath := c.getBaseURL() + "/actions?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ActionListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.ActionListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
actions := make([]*Action, 0, len(body.Actions))
for _, i := range body.Actions {
actions = append(actions, ActionFromSchema(i))
}
return actions, resp, nil
return allFromSchemaFunc(respBody.Actions, ActionFromSchema), resp, nil
}
// All returns all actions for the given options.
func (c *ResourceActionClient) All(ctx context.Context, opts ActionListOpts) ([]*Action, error) {
allActions := []*Action{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Action, *Response, error) {
opts.Page = page
actions, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allActions = append(allActions, actions...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allActions, nil
}

View File

@ -16,11 +16,14 @@ type ActionWaiter interface {
var _ ActionWaiter = (*ActionClient)(nil)
// WaitForFunc waits until all actions are completed by polling the API at the interval
// defined by [WithPollBackoffFunc]. An action is considered as complete when its status is
// defined by [WithPollOpts]. An action is considered as complete when its status is
// either [ActionStatusSuccess] or [ActionStatusError].
//
// The handleUpdate callback is called every time an action is updated.
func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error {
// Filter out nil actions
actions = slices.DeleteFunc(actions, func(a *Action) bool { return a == nil })
running := make(map[int64]struct{}, len(actions))
for _, action := range actions {
if action.Status == ActionStatusRunning {
@ -48,18 +51,19 @@ func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update
retries++
}
opts := ActionListOpts{
Sort: []string{"status", "id"},
ID: make([]int64, 0, len(running)),
}
for actionID := range running {
opts.ID = append(opts.ID, actionID)
}
slices.Sort(opts.ID)
updates := make([]*Action, 0, len(running))
for runningIDsChunk := range slices.Chunk(slices.Sorted(maps.Keys(running)), 25) {
opts := ActionListOpts{
Sort: []string{"status", "id"},
ID: runningIDsChunk,
}
updates, err := c.AllWithOpts(ctx, opts)
if err != nil {
return err
updatesChunk, err := c.AllWithOpts(ctx, opts)
if err != nil {
return err
}
updates = append(updates, updatesChunk...)
}
if len(updates) != len(running) {
@ -95,7 +99,7 @@ func (c *ActionClient) WaitForFunc(ctx context.Context, handleUpdate func(update
}
// WaitFor waits until all actions succeed by polling the API at the interval defined by
// [WithPollBackoffFunc]. An action is considered as succeeded when its status is either
// [WithPollOpts]. An action is considered as succeeded when its status is either
// [ActionStatusSuccess].
//
// If a single action fails, the function will stop waiting and the error set in the

View File

@ -21,7 +21,7 @@ import (
// timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed.
//
// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait
// WatchOverallProgress uses the [WithPollOpts] of the [Client] to wait
// until sending the next request.
//
// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead.
@ -86,7 +86,7 @@ func (c *ActionClient) WatchOverallProgress(ctx context.Context, actions []*Acti
// timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed.
//
// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until
// WatchProgress uses the [WithPollOpts] of the [Client] to wait until
// sending the next request.
//
// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead.

View File

@ -1,15 +1,12 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -98,41 +95,32 @@ type CertificateClient struct {
// GetByID retrieves a Certificate by its ID. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) GetByID(ctx context.Context, id int64) (*Certificate, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/certificates/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/certificates/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.CertificateGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.CertificateGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return CertificateFromSchema(body.Certificate), resp, nil
return CertificateFromSchema(respBody.Certificate), resp, nil
}
// GetByName retrieves a Certificate by its name. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) GetByName(ctx context.Context, name string) (*Certificate, *Response, error) {
if name == "" {
return nil, nil, nil
}
Certificate, response, err := c.List(ctx, CertificateListOpts{Name: name})
if len(Certificate) == 0 {
return nil, response, err
}
return Certificate[0], response, err
return firstByName(name, func() ([]*Certificate, *Response, error) {
return c.List(ctx, CertificateListOpts{Name: name})
})
}
// Get retrieves a Certificate by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Certificate by its name. If the Certificate does not exist, nil is returned.
func (c *CertificateClient) Get(ctx context.Context, idOrName string) (*Certificate, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// CertificateListOpts specifies options for listing Certificates.
@ -158,22 +146,17 @@ func (l CertificateListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *CertificateClient) List(ctx context.Context, opts CertificateListOpts) ([]*Certificate, *Response, error) {
path := "/certificates?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/certificates?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.CertificateListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.CertificateListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
Certificates := make([]*Certificate, 0, len(body.Certificates))
for _, s := range body.Certificates {
Certificates = append(Certificates, CertificateFromSchema(s))
}
return Certificates, resp, nil
return allFromSchemaFunc(respBody.Certificates, CertificateFromSchema), resp, nil
}
// All returns all Certificates.
@ -183,22 +166,10 @@ func (c *CertificateClient) All(ctx context.Context) ([]*Certificate, error) {
// AllWithOpts returns all Certificates for the given options.
func (c *CertificateClient) AllWithOpts(ctx context.Context, opts CertificateListOpts) ([]*Certificate, error) {
allCertificates := []*Certificate{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Certificate, *Response, error) {
opts.Page = page
Certificates, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allCertificates = append(allCertificates, Certificates...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allCertificates, nil
}
// CertificateCreateOpts specifies options for creating a new Certificate.
@ -214,7 +185,7 @@ type CertificateCreateOpts struct {
// Validate checks if options are valid.
func (o CertificateCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
switch o.Type {
case "", CertificateTypeUploaded:
@ -222,23 +193,23 @@ func (o CertificateCreateOpts) Validate() error {
case CertificateTypeManaged:
return o.validateManaged()
default:
return fmt.Errorf("invalid type: %s", o.Type)
return invalidFieldValue(o, "Type", o.Type)
}
}
func (o CertificateCreateOpts) validateManaged() error {
if len(o.DomainNames) == 0 {
return errors.New("no domain names")
return missingField(o, "DomainNames")
}
return nil
}
func (o CertificateCreateOpts) validateUploaded() error {
if o.Certificate == "" {
return errors.New("missing certificate")
return missingField(o, "Certificate")
}
if o.PrivateKey == "" {
return errors.New("missing private key")
return missingField(o, "PrivateKey")
}
return nil
}
@ -249,7 +220,7 @@ func (o CertificateCreateOpts) validateUploaded() error {
// CreateCertificate to create such certificates.
func (c *CertificateClient) Create(ctx context.Context, opts CertificateCreateOpts) (*Certificate, *Response, error) {
if !(opts.Type == "" || opts.Type == CertificateTypeUploaded) {
return nil, nil, fmt.Errorf("invalid certificate type: %s", opts.Type)
return nil, nil, invalidFieldValue(opts, "Type", opts.Type)
}
result, resp, err := c.CreateCertificate(ctx, opts)
if err != nil {
@ -262,16 +233,20 @@ func (c *CertificateClient) Create(ctx context.Context, opts CertificateCreateOp
func (c *CertificateClient) CreateCertificate(
ctx context.Context, opts CertificateCreateOpts,
) (CertificateCreateResult, *Response, error) {
var (
action *Action
reqBody schema.CertificateCreateRequest
)
const opPath = "/certificates"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
result := CertificateCreateResult{}
if err := opts.Validate(); err != nil {
return CertificateCreateResult{}, nil, err
return result, nil, err
}
reqBody.Name = opts.Name
reqBody := schema.CertificateCreateRequest{
Name: opts.Name,
}
switch opts.Type {
case "", CertificateTypeUploaded:
@ -282,32 +257,24 @@ func (c *CertificateClient) CreateCertificate(
reqBody.Type = string(CertificateTypeManaged)
reqBody.DomainNames = opts.DomainNames
default:
return CertificateCreateResult{}, nil, fmt.Errorf("invalid certificate type: %v", opts.Type)
return result, nil, invalidFieldValue(opts, "Type", opts.Type)
}
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.CertificateCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return CertificateCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/certificates", bytes.NewReader(reqBodyData))
if err != nil {
return CertificateCreateResult{}, nil, err
return result, resp, err
}
respBody := schema.CertificateCreateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil {
return CertificateCreateResult{}, resp, err
}
cert := CertificateFromSchema(respBody.Certificate)
result.Certificate = CertificateFromSchema(respBody.Certificate)
if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action)
result.Action = ActionFromSchema(*respBody.Action)
}
return CertificateCreateResult{Certificate: cert, Action: action}, resp, nil
return result, resp, nil
}
// CertificateUpdateOpts specifies options for updating a Certificate.
@ -318,6 +285,11 @@ type CertificateUpdateOpts struct {
// Update updates a Certificate.
func (c *CertificateClient) Update(ctx context.Context, certificate *Certificate, opts CertificateUpdateOpts) (*Certificate, *Response, error) {
const opPath = "/certificates/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, certificate.ID)
reqBody := schema.CertificateUpdateRequest{}
if opts.Name != "" {
reqBody.Name = &opts.Name
@ -325,46 +297,36 @@ func (c *CertificateClient) Update(ctx context.Context, certificate *Certificate
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/certificates/%d", certificate.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.CertificateUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.CertificateUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return CertificateFromSchema(respBody.Certificate), resp, nil
}
// Delete deletes a certificate.
func (c *CertificateClient) Delete(ctx context.Context, certificate *Certificate) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/certificates/%d", certificate.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/certificates/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, certificate.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// RetryIssuance retries the issuance of a failed managed certificate.
func (c *CertificateClient) RetryIssuance(ctx context.Context, certificate *Certificate) (*Action, *Response, error) {
var respBody schema.CertificateIssuanceRetryResponse
const opPath = "/certificates/%d/actions/retry"
ctx = ctxutil.SetOpPath(ctx, opPath)
req, err := c.client.NewRequest(ctx, "POST", fmt.Sprintf("/certificates/%d/actions/retry", certificate.ID), nil)
reqPath := fmt.Sprintf(opPath, certificate.ID)
respBody, resp, err := postRequest[schema.CertificateIssuanceRetryResponse](ctx, c.client, reqPath, nil)
if err != nil {
return nil, nil, err
return nil, resp, err
}
resp, err := c.client.Do(req, &respBody)
if err != nil {
return nil, nil, err
}
action := ActionFromSchema(respBody.Action)
return action, resp, nil
return ActionFromSchema(respBody.Action), resp, nil
}

View File

@ -3,13 +3,12 @@ package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
@ -19,7 +18,6 @@ import (
"golang.org/x/net/http/httpguts"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
// Endpoint is the base URL of the API.
@ -43,13 +41,43 @@ func ConstantBackoff(d time.Duration) BackoffFunc {
}
// ExponentialBackoff returns a BackoffFunc which implements an exponential
// backoff.
// It uses the formula:
// backoff, truncated to 60 seconds.
// See [ExponentialBackoffWithOpts] for more details.
func ExponentialBackoff(multiplier float64, base time.Duration) BackoffFunc {
return ExponentialBackoffWithOpts(ExponentialBackoffOpts{
Base: base,
Multiplier: multiplier,
Cap: time.Minute,
})
}
// ExponentialBackoffOpts defines the options used by [ExponentialBackoffWithOpts].
type ExponentialBackoffOpts struct {
Base time.Duration
Multiplier float64
Cap time.Duration
Jitter bool
}
// ExponentialBackoffWithOpts returns a BackoffFunc which implements an exponential
// backoff, truncated to a maximum, and an optional full jitter.
//
// b^retries * d
func ExponentialBackoff(b float64, d time.Duration) BackoffFunc {
// See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
func ExponentialBackoffWithOpts(opts ExponentialBackoffOpts) BackoffFunc {
baseSeconds := opts.Base.Seconds()
capSeconds := opts.Cap.Seconds()
return func(retries int) time.Duration {
return time.Duration(math.Pow(b, float64(retries))) * d
// Exponential backoff
backoff := baseSeconds * math.Pow(opts.Multiplier, float64(retries))
// Cap backoff
backoff = math.Min(capSeconds, backoff)
// Add jitter
if opts.Jitter {
backoff = ((backoff - baseSeconds) * rand.Float64()) + baseSeconds // #nosec G404
}
return time.Duration(backoff * float64(time.Second))
}
}
@ -58,7 +86,8 @@ type Client struct {
endpoint string
token string
tokenValid bool
backoffFunc BackoffFunc
retryBackoffFunc BackoffFunc
retryMaxRetries int
pollBackoffFunc BackoffFunc
httpClient *http.Client
applicationName string
@ -66,6 +95,7 @@ type Client struct {
userAgent string
debugWriter io.Writer
instrumentationRegistry prometheus.Registerer
handler handler
Action ActionClient
Certificate CertificateClient
@ -110,30 +140,73 @@ func WithToken(token string) ClientOption {
// polling from the API.
//
// Deprecated: Setting the poll interval is deprecated, you can now configure
// [WithPollBackoffFunc] with a [ConstantBackoff] to get the same results. To
// [WithPollOpts] with a [ConstantBackoff] to get the same results. To
// migrate your code, replace your usage like this:
//
// // before
// hcloud.WithPollInterval(2 * time.Second)
// // now
// hcloud.WithPollBackoffFunc(hcloud.ConstantBackoff(2 * time.Second))
// hcloud.WithPollOpts(hcloud.PollOpts{
// BackoffFunc: hcloud.ConstantBackoff(2 * time.Second),
// })
func WithPollInterval(pollInterval time.Duration) ClientOption {
return WithPollBackoffFunc(ConstantBackoff(pollInterval))
return WithPollOpts(PollOpts{
BackoffFunc: ConstantBackoff(pollInterval),
})
}
// WithPollBackoffFunc configures a Client to use the specified backoff
// function when polling from the API.
//
// Deprecated: WithPollBackoffFunc is deprecated, use [WithPollOpts] instead.
func WithPollBackoffFunc(f BackoffFunc) ClientOption {
return WithPollOpts(PollOpts{
BackoffFunc: f,
})
}
// PollOpts defines the options used by [WithPollOpts].
type PollOpts struct {
BackoffFunc BackoffFunc
}
// WithPollOpts configures a Client to use the specified options when polling from the API.
//
// If [PollOpts.BackoffFunc] is nil, the existing backoff function will be preserved.
func WithPollOpts(opts PollOpts) ClientOption {
return func(client *Client) {
client.pollBackoffFunc = f
if opts.BackoffFunc != nil {
client.pollBackoffFunc = opts.BackoffFunc
}
}
}
// WithBackoffFunc configures a Client to use the specified backoff function.
// The backoff function is used for retrying HTTP requests.
//
// Deprecated: WithBackoffFunc is deprecated, use [WithRetryOpts] instead.
func WithBackoffFunc(f BackoffFunc) ClientOption {
return func(client *Client) {
client.backoffFunc = f
client.retryBackoffFunc = f
}
}
// RetryOpts defines the options used by [WithRetryOpts].
type RetryOpts struct {
BackoffFunc BackoffFunc
MaxRetries int
}
// WithRetryOpts configures a Client to use the specified options when retrying API
// requests.
//
// If [RetryOpts.BackoffFunc] is nil, the existing backoff function will be preserved.
func WithRetryOpts(opts RetryOpts) ClientOption {
return func(client *Client) {
if opts.BackoffFunc != nil {
client.retryBackoffFunc = opts.BackoffFunc
}
client.retryMaxRetries = opts.MaxRetries
}
}
@ -172,10 +245,18 @@ func WithInstrumentation(registry prometheus.Registerer) ClientOption {
// NewClient creates a new client.
func NewClient(options ...ClientOption) *Client {
client := &Client{
endpoint: Endpoint,
tokenValid: true,
httpClient: &http.Client{},
backoffFunc: ExponentialBackoff(2, 500*time.Millisecond),
endpoint: Endpoint,
tokenValid: true,
httpClient: &http.Client{},
retryBackoffFunc: ExponentialBackoffWithOpts(ExponentialBackoffOpts{
Base: time.Second,
Multiplier: 2,
Cap: time.Minute,
Jitter: true,
}),
retryMaxRetries: 5,
pollBackoffFunc: ConstantBackoff(500 * time.Millisecond),
}
@ -186,9 +267,11 @@ func NewClient(options ...ClientOption) *Client {
client.buildUserAgent()
if client.instrumentationRegistry != nil {
i := instrumentation.New("api", client.instrumentationRegistry)
client.httpClient.Transport = i.InstrumentedRoundTripper()
client.httpClient.Transport = i.InstrumentedRoundTripper(client.httpClient.Transport)
}
client.handler = assembleHandlerChain(client)
client.Action = ActionClient{action: &ResourceActionClient{client: client}}
client.Datacenter = DatacenterClient{client: client}
client.FloatingIP = FloatingIPClient{client: client, Action: &ResourceActionClient{client: client, resource: "floating_ips"}}
@ -238,97 +321,8 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
// Do performs an HTTP request against the API.
// v can be nil, an io.Writer to write the response body to or a pointer to
// a struct to json.Unmarshal the response to.
func (c *Client) Do(r *http.Request, v interface{}) (*Response, error) {
var retries int
var body []byte
var err error
if r.ContentLength > 0 {
body, err = io.ReadAll(r.Body)
if err != nil {
r.Body.Close()
return nil, err
}
r.Body.Close()
}
for {
if r.ContentLength > 0 {
r.Body = io.NopCloser(bytes.NewReader(body))
}
if c.debugWriter != nil {
dumpReq, err := dumpRequest(r)
if err != nil {
return nil, err
}
fmt.Fprintf(c.debugWriter, "--- Request:\n%s\n\n", dumpReq)
}
resp, err := c.httpClient.Do(r)
if err != nil {
return nil, err
}
response := &Response{Response: resp}
body, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return response, err
}
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(body))
if c.debugWriter != nil {
dumpResp, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, err
}
fmt.Fprintf(c.debugWriter, "--- Response:\n%s\n\n", dumpResp)
}
if err = response.readMeta(body); err != nil {
return response, fmt.Errorf("hcloud: error reading response meta data: %s", err)
}
if response.StatusCode >= 400 && response.StatusCode <= 599 {
err = errorFromResponse(response, body)
if err == nil {
err = fmt.Errorf("hcloud: server responded with status code %d", resp.StatusCode)
} else if IsError(err, ErrorCodeConflict) {
c.backoff(retries)
retries++
continue
}
return response, err
}
if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, bytes.NewReader(body))
} else {
err = json.Unmarshal(body, v)
}
}
return response, err
}
}
func (c *Client) backoff(retries int) {
time.Sleep(c.backoffFunc(retries))
}
func (c *Client) all(f func(int) (*Response, error)) error {
var (
page = 1
)
for {
resp, err := f(page)
if err != nil {
return err
}
if resp.Meta.Pagination == nil || resp.Meta.Pagination.NextPage == 0 {
return nil
}
page = resp.Meta.Pagination.NextPage
}
func (c *Client) Do(req *http.Request, v any) (*Response, error) {
return c.handler.Do(req, v)
}
func (c *Client) buildUserAgent() {
@ -342,43 +336,6 @@ func (c *Client) buildUserAgent() {
}
}
func dumpRequest(r *http.Request) ([]byte, error) {
// Duplicate the request, so we can redact the auth header
rDuplicate := r.Clone(context.Background())
rDuplicate.Header.Set("Authorization", "REDACTED")
// To get the request body we need to read it before the request was actually sent.
// See https://github.com/golang/go/issues/29792
dumpReq, err := httputil.DumpRequestOut(rDuplicate, true)
if err != nil {
return nil, err
}
// Set original request body to the duplicate created by DumpRequestOut. The request body is not duplicated
// by .Clone() and instead just referenced, so it would be completely read otherwise.
r.Body = rDuplicate.Body
return dumpReq, nil
}
func errorFromResponse(resp *Response, body []byte) error {
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
return nil
}
var respBody schema.ErrorResponse
if err := json.Unmarshal(body, &respBody); err != nil {
return nil
}
if respBody.Error.Code == "" && respBody.Error.Message == "" {
return nil
}
hcErr := ErrorFromSchema(respBody.Error)
hcErr.response = resp
return hcErr
}
const (
headerCorrelationID = "X-Correlation-Id"
)
@ -387,35 +344,34 @@ const (
type Response struct {
*http.Response
Meta Meta
// body holds a copy of the http.Response body that must be used within the handler
// chain. The http.Response.Body is reserved for external users.
body []byte
}
func (r *Response) readMeta(body []byte) error {
if h := r.Header.Get("RateLimit-Limit"); h != "" {
r.Meta.Ratelimit.Limit, _ = strconv.Atoi(h)
}
if h := r.Header.Get("RateLimit-Remaining"); h != "" {
r.Meta.Ratelimit.Remaining, _ = strconv.Atoi(h)
}
if h := r.Header.Get("RateLimit-Reset"); h != "" {
if ts, err := strconv.ParseInt(h, 10, 64); err == nil {
r.Meta.Ratelimit.Reset = time.Unix(ts, 0)
}
// populateBody copies the original [http.Response] body into the internal [Response] body
// property, and restore the original [http.Response] body as if it was untouched.
func (r *Response) populateBody() error {
// Read full response body and save it for later use
body, err := io.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
r.body = body
if strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") {
var s schema.MetaResponse
if err := json.Unmarshal(body, &s); err != nil {
return err
}
if s.Meta.Pagination != nil {
p := PaginationFromSchema(*s.Meta.Pagination)
r.Meta.Pagination = &p
}
}
// Restore the body as if it was untouched, as it might be read by external users
r.Body = io.NopCloser(bytes.NewReader(body))
return nil
}
// hasJSONBody returns whether the response has a JSON body.
func (r *Response) hasJSONBody() bool {
return len(r.body) > 0 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/json")
}
// internalCorrelationID returns the unique ID of the request as set by the API. This ID can help with support requests,
// as it allows the people working on identify this request in particular.
func (r *Response) internalCorrelationID() string {

View File

@ -0,0 +1,101 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"io"
)
func getRequest[Schema any](ctx context.Context, client *Client, url string) (Schema, *Response, error) {
var respBody Schema
req, err := client.NewRequest(ctx, "GET", url, nil)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func postRequest[Schema any](ctx context.Context, client *Client, url string, reqBody any) (Schema, *Response, error) {
var respBody Schema
var reqBodyReader io.Reader
if reqBody != nil {
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
return respBody, nil, err
}
reqBodyReader = bytes.NewReader(reqBodyBytes)
}
req, err := client.NewRequest(ctx, "POST", url, reqBodyReader)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func putRequest[Schema any](ctx context.Context, client *Client, url string, reqBody any) (Schema, *Response, error) {
var respBody Schema
var reqBodyReader io.Reader
if reqBody != nil {
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
return respBody, nil, err
}
reqBodyReader = bytes.NewReader(reqBodyBytes)
}
req, err := client.NewRequest(ctx, "PUT", url, reqBodyReader)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func deleteRequest[Schema any](ctx context.Context, client *Client, url string) (Schema, *Response, error) {
var respBody Schema
req, err := client.NewRequest(ctx, "DELETE", url, nil)
if err != nil {
return respBody, nil, err
}
resp, err := client.Do(req, &respBody)
if err != nil {
return respBody, resp, err
}
return respBody, resp, nil
}
func deleteRequestNoResult(ctx context.Context, client *Client, url string) (*Response, error) {
req, err := client.NewRequest(ctx, "DELETE", url, nil)
if err != nil {
return nil, err
}
return client.Do(req, nil)
}

View File

@ -0,0 +1,56 @@
package hcloud
import (
"context"
"net/http"
)
// handler is an interface representing a client request transaction. The handler are
// meant to be chained, similarly to the [http.RoundTripper] interface.
//
// The handler chain is placed between the [Client] API operations and the
// [http.Client].
type handler interface {
Do(req *http.Request, v any) (resp *Response, err error)
}
// assembleHandlerChain assembles the chain of handlers used to make API requests.
//
// The order of the handlers is important.
func assembleHandlerChain(client *Client) handler {
// Start down the chain: sending the http request
h := newHTTPHandler(client.httpClient)
// Insert debug writer if enabled
if client.debugWriter != nil {
h = wrapDebugHandler(h, client.debugWriter)
}
// Read rate limit headers
h = wrapRateLimitHandler(h)
// Build error from response
h = wrapErrorHandler(h)
// Retry request if condition are met
h = wrapRetryHandler(h, client.retryBackoffFunc, client.retryMaxRetries)
// Finally parse the response body into the provided schema
h = wrapParseHandler(h)
return h
}
// cloneRequest clones both the request and the request body.
func cloneRequest(req *http.Request, ctx context.Context) (cloned *http.Request, err error) { //revive:disable:context-as-argument
cloned = req.Clone(ctx)
if req.ContentLength > 0 {
cloned.Body, err = req.GetBody()
if err != nil {
return nil, err
}
}
return cloned, nil
}

View File

@ -0,0 +1,50 @@
package hcloud
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
)
func wrapDebugHandler(wrapped handler, output io.Writer) handler {
return &debugHandler{wrapped, output}
}
type debugHandler struct {
handler handler
output io.Writer
}
func (h *debugHandler) Do(req *http.Request, v any) (resp *Response, err error) {
// Clone the request, so we can redact the auth header, read the body
// and use a new context.
cloned, err := cloneRequest(req, context.Background())
if err != nil {
return nil, err
}
cloned.Header.Set("Authorization", "REDACTED")
dumpReq, err := httputil.DumpRequestOut(cloned, true)
if err != nil {
return nil, err
}
fmt.Fprintf(h.output, "--- Request:\n%s\n\n", dumpReq)
resp, err = h.handler.Do(req, v)
if err != nil {
return resp, err
}
dumpResp, err := httputil.DumpResponse(resp.Response, true)
if err != nil {
return nil, err
}
fmt.Fprintf(h.output, "--- Response:\n%s\n\n", dumpResp)
return resp, err
}

View File

@ -0,0 +1,53 @@
package hcloud
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
var ErrStatusCode = errors.New("server responded with status code")
func wrapErrorHandler(wrapped handler) handler {
return &errorHandler{wrapped}
}
type errorHandler struct {
handler handler
}
func (h *errorHandler) Do(req *http.Request, v any) (resp *Response, err error) {
resp, err = h.handler.Do(req, v)
if err != nil {
return resp, err
}
if resp.StatusCode >= 400 && resp.StatusCode <= 599 {
err = errorFromBody(resp)
if err == nil {
err = fmt.Errorf("hcloud: %w %d", ErrStatusCode, resp.StatusCode)
}
}
return resp, err
}
func errorFromBody(resp *Response) error {
if !resp.hasJSONBody() {
return nil
}
var s schema.ErrorResponse
if err := json.Unmarshal(resp.body, &s); err != nil {
return nil // nolint: nilerr
}
if s.Error.Code == "" && s.Error.Message == "" {
return nil
}
hcErr := ErrorFromSchema(s.Error)
hcErr.response = resp
return hcErr
}

View File

@ -0,0 +1,28 @@
package hcloud
import (
"net/http"
)
func newHTTPHandler(httpClient *http.Client) handler {
return &httpHandler{httpClient}
}
type httpHandler struct {
httpClient *http.Client
}
func (h *httpHandler) Do(req *http.Request, _ interface{}) (*Response, error) {
httpResponse, err := h.httpClient.Do(req) //nolint: bodyclose
resp := &Response{Response: httpResponse}
if err != nil {
return resp, err
}
err = resp.populateBody()
if err != nil {
return resp, err
}
return resp, err
}

View File

@ -0,0 +1,50 @@
package hcloud
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
func wrapParseHandler(wrapped handler) handler {
return &parseHandler{wrapped}
}
type parseHandler struct {
handler handler
}
func (h *parseHandler) Do(req *http.Request, v any) (resp *Response, err error) {
// respBody is not needed down the handler chain
resp, err = h.handler.Do(req, nil)
if err != nil {
return resp, err
}
if resp.hasJSONBody() {
// Parse the response meta
var s schema.MetaResponse
if err := json.Unmarshal(resp.body, &s); err != nil {
return resp, fmt.Errorf("hcloud: error reading response meta data: %w", err)
}
if s.Meta.Pagination != nil {
p := PaginationFromSchema(*s.Meta.Pagination)
resp.Meta.Pagination = &p
}
}
// Parse the response schema
if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, bytes.NewReader(resp.body))
} else {
err = json.Unmarshal(resp.body, v)
}
}
return resp, err
}

View File

@ -0,0 +1,36 @@
package hcloud
import (
"net/http"
"strconv"
"time"
)
func wrapRateLimitHandler(wrapped handler) handler {
return &rateLimitHandler{wrapped}
}
type rateLimitHandler struct {
handler handler
}
func (h *rateLimitHandler) Do(req *http.Request, v any) (resp *Response, err error) {
resp, err = h.handler.Do(req, v)
// Ensure the embedded [*http.Response] is not nil, e.g. on canceled context
if resp != nil && resp.Response != nil && resp.Response.Header != nil {
if h := resp.Header.Get("RateLimit-Limit"); h != "" {
resp.Meta.Ratelimit.Limit, _ = strconv.Atoi(h)
}
if h := resp.Header.Get("RateLimit-Remaining"); h != "" {
resp.Meta.Ratelimit.Remaining, _ = strconv.Atoi(h)
}
if h := resp.Header.Get("RateLimit-Reset"); h != "" {
if ts, err := strconv.ParseInt(h, 10, 64); err == nil {
resp.Meta.Ratelimit.Reset = time.Unix(ts, 0)
}
}
}
return resp, err
}

View File

@ -0,0 +1,84 @@
package hcloud
import (
"errors"
"net"
"net/http"
"time"
)
func wrapRetryHandler(wrapped handler, backoffFunc BackoffFunc, maxRetries int) handler {
return &retryHandler{wrapped, backoffFunc, maxRetries}
}
type retryHandler struct {
handler handler
backoffFunc BackoffFunc
maxRetries int
}
func (h *retryHandler) Do(req *http.Request, v any) (resp *Response, err error) {
retries := 0
ctx := req.Context()
for {
// Clone the request using the original context
cloned, err := cloneRequest(req, ctx)
if err != nil {
return nil, err
}
resp, err = h.handler.Do(cloned, v)
if err != nil {
// Beware the diversity of the errors:
// - request preparation
// - network connectivity
// - http status code (see [errorHandler])
if ctx.Err() != nil {
// early return if the context was canceled or timed out
return resp, err
}
if retries < h.maxRetries && retryPolicy(resp, err) {
select {
case <-ctx.Done():
return resp, err
case <-time.After(h.backoffFunc(retries)):
retries++
continue
}
}
}
return resp, err
}
}
func retryPolicy(resp *Response, err error) bool {
if err != nil {
var apiErr Error
var netErr net.Error
switch {
case errors.As(err, &apiErr):
switch apiErr.Code { //nolint:exhaustive
case ErrorCodeConflict:
return true
case ErrorCodeRateLimitExceeded:
return true
}
case errors.Is(err, ErrStatusCode):
switch resp.Response.StatusCode {
// 5xx errors
case http.StatusBadGateway, http.StatusGatewayTimeout:
return true
}
case errors.As(err, &netErr):
if netErr.Timeout() {
return true
}
}
}
return false
}

View File

@ -0,0 +1,85 @@
package hcloud
import (
"context"
"strconv"
)
// allFromSchemaFunc transform each item in the list using the FromSchema function, and
// returns the result.
func allFromSchemaFunc[T, V any](all []T, fn func(T) V) []V {
result := make([]V, len(all))
for i, t := range all {
result[i] = fn(t)
}
return result
}
// iterPages fetches each pages using the list function, and returns the result.
func iterPages[T any](listFn func(int) ([]*T, *Response, error)) ([]*T, error) {
page := 1
result := []*T{}
for {
pageResult, resp, err := listFn(page)
if err != nil {
return nil, err
}
result = append(result, pageResult...)
if resp.Meta.Pagination == nil || resp.Meta.Pagination.NextPage == 0 {
return result, nil
}
page = resp.Meta.Pagination.NextPage
}
}
// firstBy fetches a list of items using the list function, and returns the first item
// of the list if present otherwise nil.
func firstBy[T any](listFn func() ([]*T, *Response, error)) (*T, *Response, error) {
items, resp, err := listFn()
if len(items) == 0 {
return nil, resp, err
}
return items[0], resp, err
}
// firstByName is a wrapper around [firstBy], that checks if the provided name is not
// empty.
func firstByName[T any](name string, listFn func() ([]*T, *Response, error)) (*T, *Response, error) {
if name == "" {
return nil, nil, nil
}
return firstBy(listFn)
}
// getByIDOrName fetches the resource by ID when the identifier is an integer, otherwise
// by Name. To support resources that have a integer as Name, an additional attempt is
// made to fetch the resource by Name using the ID.
//
// Since API managed resources (locations, server types, ...) do not have integers as
// names, this function is only meaningful for user managed resources (ssh keys,
// servers).
func getByIDOrName[T any](
ctx context.Context,
getByIDFn func(ctx context.Context, id int64) (*T, *Response, error),
getByNameFn func(ctx context.Context, name string) (*T, *Response, error),
idOrName string,
) (*T, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
result, resp, err := getByIDFn(ctx, id)
if err != nil {
return result, resp, err
}
if result != nil {
return result, resp, err
}
// Fallback to get by Name if the resource was not found
}
return getByNameFn(ctx, idOrName)
}

View File

@ -6,6 +6,7 @@ import (
"net/url"
"strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -32,32 +33,27 @@ type DatacenterClient struct {
// GetByID retrieves a datacenter by its ID. If the datacenter does not exist, nil is returned.
func (c *DatacenterClient) GetByID(ctx context.Context, id int64) (*Datacenter, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/datacenters/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/datacenters/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.DatacenterGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.DatacenterGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, resp, err
}
return DatacenterFromSchema(body.Datacenter), resp, nil
return DatacenterFromSchema(respBody.Datacenter), resp, nil
}
// GetByName retrieves a datacenter by its name. If the datacenter does not exist, nil is returned.
func (c *DatacenterClient) GetByName(ctx context.Context, name string) (*Datacenter, *Response, error) {
if name == "" {
return nil, nil, nil
}
datacenters, response, err := c.List(ctx, DatacenterListOpts{Name: name})
if len(datacenters) == 0 {
return nil, response, err
}
return datacenters[0], response, err
return firstByName(name, func() ([]*Datacenter, *Response, error) {
return c.List(ctx, DatacenterListOpts{Name: name})
})
}
// Get retrieves a datacenter by its ID if the input can be parsed as an integer, otherwise it
@ -92,22 +88,17 @@ func (l DatacenterListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *DatacenterClient) List(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, *Response, error) {
path := "/datacenters?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/datacenters?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.DatacenterListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.DatacenterListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
datacenters := make([]*Datacenter, 0, len(body.Datacenters))
for _, i := range body.Datacenters {
datacenters = append(datacenters, DatacenterFromSchema(i))
}
return datacenters, resp, nil
return allFromSchemaFunc(respBody.Datacenters, DatacenterFromSchema), resp, nil
}
// All returns all datacenters.
@ -117,20 +108,8 @@ func (c *DatacenterClient) All(ctx context.Context) ([]*Datacenter, error) {
// AllWithOpts returns all datacenters for the given options.
func (c *DatacenterClient) AllWithOpts(ctx context.Context, opts DatacenterListOpts) ([]*Datacenter, error) {
allDatacenters := []*Datacenter{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Datacenter, *Response, error) {
opts.Page = page
datacenters, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allDatacenters = append(allDatacenters, datacenters...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allDatacenters, nil
}

View File

@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"net"
"slices"
"strings"
)
// ErrorCode represents an error code returned from the API.
@ -29,6 +31,7 @@ const (
ErrorCodeRobotUnavailable ErrorCode = "robot_unavailable" // Robot was not available. The caller may retry the operation after a short delay
ErrorCodeResourceLocked ErrorCode = "resource_locked" // The resource is locked. The caller should contact support
ErrorUnsupportedError ErrorCode = "unsupported_error" // The given resource does not support this
ErrorDeprecatedAPIEndpoint ErrorCode = "deprecated_api_endpoint" // The request can not be answered because the API functionality was removed
// Server related error codes.
@ -126,11 +129,16 @@ type ErrorDetailsInvalidInputField struct {
Messages []string
}
// IsError returns whether err is an API error with the given error code.
func IsError(err error, code ErrorCode) bool {
// ErrorDetailsDeprecatedAPIEndpoint contains the details of a 'deprecated_api_endpoint' error.
type ErrorDetailsDeprecatedAPIEndpoint struct {
Announcement string
}
// IsError returns whether err is an API error with one of the given error codes.
func IsError(err error, code ...ErrorCode) bool {
var apiErr Error
ok := errors.As(err, &apiErr)
return ok && apiErr.Code == code
return ok && slices.Index(code, apiErr.Code) > -1
}
type InvalidIPError struct {
@ -148,3 +156,40 @@ type DNSNotFoundError struct {
func (e DNSNotFoundError) Error() string {
return fmt.Sprintf("dns for ip %s not found", e.IP.String())
}
// ArgumentError is a type of error returned when validating arguments.
type ArgumentError string
func (e ArgumentError) Error() string { return string(e) }
func newArgumentErrorf(format string, args ...any) ArgumentError {
return ArgumentError(fmt.Sprintf(format, args...))
}
func missingArgument(name string, obj any) error {
return newArgumentErrorf("missing argument '%s' [%T]", name, obj)
}
func invalidArgument(name string, obj any) error {
return newArgumentErrorf("invalid value '%v' for argument '%s' [%T]", obj, name, obj)
}
func missingField(obj any, field string) error {
return newArgumentErrorf("missing field [%s] in [%T]", field, obj)
}
func invalidFieldValue(obj any, field string, value any) error {
return newArgumentErrorf("invalid value '%v' for field [%s] in [%T]", value, field, obj)
}
func missingOneOfFields(obj any, fields ...string) error {
return newArgumentErrorf("missing one of fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}
func mutuallyExclusiveFields(obj any, fields ...string) error {
return newArgumentErrorf("found mutually exclusive fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}
func missingRequiredTogetherFields(obj any, fields ...string) error {
return newArgumentErrorf("missing required together fields [%s] in [%T]", strings.Join(fields, ", "), obj)
}

View File

@ -0,0 +1,11 @@
package actionutil
import "k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud"
// AppendNext return the action and the next actions in a new slice.
func AppendNext(action *hcloud.Action, nextActions []*hcloud.Action) []*hcloud.Action {
all := make([]*hcloud.Action, 0, 1+len(nextActions))
all = append(all, action)
all = append(all, nextActions...)
return all
}

View File

@ -0,0 +1,30 @@
package ctxutil
import (
"context"
"strings"
)
// key is an unexported type to prevents collisions with keys defined in other packages.
type key struct{}
// opPathKey is the key for operation path in Contexts.
var opPathKey = key{}
// SetOpPath processes the operation path and save it in the context before returning it.
func SetOpPath(ctx context.Context, path string) context.Context {
path, _, _ = strings.Cut(path, "?")
path = strings.ReplaceAll(path, "%d", "-")
path = strings.ReplaceAll(path, "%s", "-")
return context.WithValue(ctx, opPathKey, path)
}
// OpPath returns the operation path from the context.
func OpPath(ctx context.Context) string {
result, ok := ctx.Value(opPathKey).(string)
if !ok {
return ""
}
return result
}

View File

@ -0,0 +1,4 @@
// Package exp is a namespace that holds experimental features for the `hcloud-go` library.
//
// Breaking changes may occur without notice. Do not use in production!
package exp

View File

@ -0,0 +1,40 @@
package envutil
import (
"fmt"
"os"
"strings"
)
// LookupEnvWithFile retrieves the value of the environment variable named by the key (e.g.
// HCLOUD_TOKEN). If the previous environment variable is not set, it retrieves the
// content of the file located by a second environment variable named by the key +
// '_FILE' (.e.g HCLOUD_TOKEN_FILE).
//
// For both cases, the returned value may be empty.
//
// The value from the environment takes precedence over the value from the file.
func LookupEnvWithFile(key string) (string, error) {
// Check if the value is set in the environment (e.g. HCLOUD_TOKEN)
value, ok := os.LookupEnv(key)
if ok {
return value, nil
}
key += "_FILE"
// Check if the value is set via a file (e.g. HCLOUD_TOKEN_FILE)
valueFile, ok := os.LookupEnv(key)
if !ok {
// Validation of the value happens outside of this function
return "", nil
}
// Read the content of the file
valueBytes, err := os.ReadFile(valueFile)
if err != nil {
return "", fmt.Errorf("failed to read %s: %w", key, err)
}
return strings.TrimSpace(string(valueBytes)), nil
}

View File

@ -0,0 +1,19 @@
package randutil
import (
"crypto/rand"
"encoding/hex"
"fmt"
)
// GenerateID returns a hex encoded random string with a len of 8 chars similar to
// "2873fce7".
func GenerateID() string {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
// Should never happen as of go1.24: https://github.com/golang/go/issues/66821
panic(fmt.Errorf("failed to generate random string: %w", err))
}
return hex.EncodeToString(b)
}

View File

@ -0,0 +1,86 @@
package sshutil
import (
"crypto"
"crypto/ed25519"
"encoding/pem"
"fmt"
"golang.org/x/crypto/ssh"
)
// GenerateKeyPair generates a new ed25519 ssh key pair, and returns the private key and
// the public key respectively.
func GenerateKeyPair() ([]byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
return nil, nil, fmt.Errorf("could not generate key pair: %w", err)
}
privBytes, err := encodePrivateKey(priv)
if err != nil {
return nil, nil, fmt.Errorf("could not encode private key: %w", err)
}
pubBytes, err := encodePublicKey(pub)
if err != nil {
return nil, nil, fmt.Errorf("could not encode public key: %w", err)
}
return privBytes, pubBytes, nil
}
func encodePrivateKey(priv crypto.PrivateKey) ([]byte, error) {
privPem, err := ssh.MarshalPrivateKey(priv, "")
if err != nil {
return nil, err
}
return pem.EncodeToMemory(privPem), nil
}
func encodePublicKey(pub crypto.PublicKey) ([]byte, error) {
sshPub, err := ssh.NewPublicKey(pub)
if err != nil {
return nil, err
}
return ssh.MarshalAuthorizedKey(sshPub), nil
}
type privateKeyWithPublicKey interface {
crypto.PrivateKey
Public() crypto.PublicKey
}
// GeneratePublicKey generate a public key from the provided private key.
func GeneratePublicKey(privBytes []byte) ([]byte, error) {
priv, err := ssh.ParseRawPrivateKey(privBytes)
if err != nil {
return nil, fmt.Errorf("could not decode private key: %w", err)
}
key, ok := priv.(privateKeyWithPublicKey)
if !ok {
return nil, fmt.Errorf("private key doesn't export Public() crypto.PublicKey")
}
pubBytes, err := encodePublicKey(key.Public())
if err != nil {
return nil, fmt.Errorf("could not encode public key: %w", err)
}
return pubBytes, nil
}
// GetPublicKeyFingerprint generate the finger print for the provided public key.
func GetPublicKeyFingerprint(pubBytes []byte) (string, error) {
pub, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes)
if err != nil {
return "", fmt.Errorf("could not decode public key: %w", err)
}
fingerprint := ssh.FingerprintLegacyMD5(pub)
return fingerprint, nil
}

View File

@ -0,0 +1,24 @@
package labelutil
import (
"fmt"
"sort"
"strings"
)
// Selector combines the label set into a [label selector](https://docs.hetzner.cloud/#label-selector) that only selects
// resources have all specified labels set.
//
// The selector string can be used to filter resources when listing, for example with [hcloud.ServerClient.AllWithOpts()].
func Selector(labels map[string]string) string {
selectors := make([]string, 0, len(labels))
for k, v := range labels {
selectors = append(selectors, fmt.Sprintf("%s=%s", k, v))
}
// Reproducible result for tests
sort.Strings(selectors)
return strings.Join(selectors, ",")
}

View File

@ -0,0 +1,123 @@
package mockutil
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Request describes a http request that a [httptest.Server] should receive, and the
// corresponding response to return.
//
// Additional checks on the request (e.g. request body) may be added using the
// [Request.Want] function.
//
// The response body is populated from either a JSON struct, or a JSON string.
type Request struct {
Method string
Path string
Want func(t *testing.T, r *http.Request)
Status int
JSON any
JSONRaw string
}
// Handler is using a [Server] to mock http requests provided by the user.
func Handler(t *testing.T, requests []Request) http.HandlerFunc {
t.Helper()
server := NewServer(t, requests)
t.Cleanup(server.close)
return server.handler
}
// NewServer returns a new mock server that closes itself at the end of the test.
func NewServer(t *testing.T, requests []Request) *Server {
t.Helper()
o := &Server{t: t}
o.Server = httptest.NewServer(http.HandlerFunc(o.handler))
t.Cleanup(o.close)
o.Expect(requests)
return o
}
// Server embeds a [httptest.Server] that answers HTTP calls with a list of expected [Request].
//
// Request matching is based on the request count, and the user provided request will be
// iterated over.
//
// A Server must be created using the [NewServer] function.
type Server struct {
*httptest.Server
t *testing.T
requests []Request
index int
}
// Expect adds requests to the list of requests expected by the [Server].
func (m *Server) Expect(requests []Request) {
m.requests = append(m.requests, requests...)
}
func (m *Server) close() {
m.t.Helper()
m.Server.Close()
assert.EqualValues(m.t, len(m.requests), m.index, "expected more calls")
}
func (m *Server) handler(w http.ResponseWriter, r *http.Request) {
if testing.Verbose() {
m.t.Logf("call %d: %s %s\n", m.index, r.Method, r.RequestURI)
}
if m.index >= len(m.requests) {
m.t.Fatalf("received unknown call %d", m.index)
}
expected := m.requests[m.index]
expectedCall := expected.Method
foundCall := r.Method
if expected.Path != "" {
expectedCall += " " + expected.Path
foundCall += " " + r.RequestURI
}
require.Equal(m.t, expectedCall, foundCall) // nolint: testifylint
if expected.Want != nil {
expected.Want(m.t, r)
}
switch {
case expected.JSON != nil:
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(expected.Status)
if err := json.NewEncoder(w).Encode(expected.JSON); err != nil {
m.t.Fatal(err)
}
case expected.JSONRaw != "":
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(expected.Status)
_, err := w.Write([]byte(expected.JSONRaw))
if err != nil {
m.t.Fatal(err)
}
default:
w.WriteHeader(expected.Status)
}
m.index++
}

View File

@ -1,16 +1,13 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -96,41 +93,33 @@ type FirewallClient struct {
// GetByID retrieves a Firewall by its ID. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) GetByID(ctx context.Context, id int64) (*Firewall, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/firewalls/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/firewalls/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.FirewallGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.FirewallGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return FirewallFromSchema(body.Firewall), resp, nil
return FirewallFromSchema(respBody.Firewall), resp, nil
}
// GetByName retrieves a Firewall by its name. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) GetByName(ctx context.Context, name string) (*Firewall, *Response, error) {
if name == "" {
return nil, nil, nil
}
firewalls, response, err := c.List(ctx, FirewallListOpts{Name: name})
if len(firewalls) == 0 {
return nil, response, err
}
return firewalls[0], response, err
return firstByName(name, func() ([]*Firewall, *Response, error) {
return c.List(ctx, FirewallListOpts{Name: name})
})
}
// Get retrieves a Firewall by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Firewall by its name. If the Firewall does not exist, nil is returned.
func (c *FirewallClient) Get(ctx context.Context, idOrName string) (*Firewall, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// FirewallListOpts specifies options for listing Firewalls.
@ -156,22 +145,17 @@ func (l FirewallListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *FirewallClient) List(ctx context.Context, opts FirewallListOpts) ([]*Firewall, *Response, error) {
path := "/firewalls?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/firewalls?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.FirewallListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.FirewallListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
firewalls := make([]*Firewall, 0, len(body.Firewalls))
for _, s := range body.Firewalls {
firewalls = append(firewalls, FirewallFromSchema(s))
}
return firewalls, resp, nil
return allFromSchemaFunc(respBody.Firewalls, FirewallFromSchema), resp, nil
}
// All returns all Firewalls.
@ -181,22 +165,10 @@ func (c *FirewallClient) All(ctx context.Context) ([]*Firewall, error) {
// AllWithOpts returns all Firewalls for the given options.
func (c *FirewallClient) AllWithOpts(ctx context.Context, opts FirewallListOpts) ([]*Firewall, error) {
allFirewalls := []*Firewall{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Firewall, *Response, error) {
opts.Page = page
firewalls, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allFirewalls = append(allFirewalls, firewalls...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allFirewalls, nil
}
// FirewallCreateOpts specifies options for creating a new Firewall.
@ -210,7 +182,7 @@ type FirewallCreateOpts struct {
// Validate checks if options are valid.
func (o FirewallCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
return nil
}
@ -223,28 +195,27 @@ type FirewallCreateResult struct {
// Create creates a new Firewall.
func (c *FirewallClient) Create(ctx context.Context, opts FirewallCreateOpts) (FirewallCreateResult, *Response, error) {
const opPath = "/firewalls"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := FirewallCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil {
return FirewallCreateResult{}, nil, err
}
reqBody := firewallCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return FirewallCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/firewalls", bytes.NewReader(reqBodyData))
if err != nil {
return FirewallCreateResult{}, nil, err
return result, nil, err
}
respBody := schema.FirewallCreateResponse{}
resp, err := c.client.Do(req, &respBody)
reqBody := firewallCreateOptsToSchema(opts)
respBody, resp, err := postRequest[schema.FirewallCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return FirewallCreateResult{}, resp, err
}
result := FirewallCreateResult{
Firewall: FirewallFromSchema(respBody.Firewall),
Actions: ActionsFromSchema(respBody.Actions),
return result, resp, err
}
result.Firewall = FirewallFromSchema(respBody.Firewall)
result.Actions = ActionsFromSchema(respBody.Actions)
return result, resp, nil
}
@ -256,6 +227,11 @@ type FirewallUpdateOpts struct {
// Update updates a Firewall.
func (c *FirewallClient) Update(ctx context.Context, firewall *Firewall, opts FirewallUpdateOpts) (*Firewall, *Response, error) {
const opPath = "/firewalls/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
reqBody := schema.FirewallUpdateRequest{}
if opts.Name != "" {
reqBody.Name = &opts.Name
@ -263,32 +239,23 @@ func (c *FirewallClient) Update(ctx context.Context, firewall *Firewall, opts Fi
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d", firewall.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FirewallUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.FirewallUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return FirewallFromSchema(respBody.Firewall), resp, nil
}
// Delete deletes a Firewall.
func (c *FirewallClient) Delete(ctx context.Context, firewall *Firewall) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/firewalls/%d", firewall.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/firewalls/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// FirewallSetRulesOpts specifies options for setting rules of a Firewall.
@ -298,75 +265,59 @@ type FirewallSetRulesOpts struct {
// SetRules sets the rules of a Firewall.
func (c *FirewallClient) SetRules(ctx context.Context, firewall *Firewall, opts FirewallSetRulesOpts) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/set_rules"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
reqBody := firewallSetRulesOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/set_rules", firewall.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionSetRulesResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FirewallActionSetRulesResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionsFromSchema(respBody.Actions), resp, nil
}
func (c *FirewallClient) ApplyResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/apply_to_resources"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
applyTo := make([]schema.FirewallResource, len(resources))
for i, r := range resources {
applyTo[i] = firewallResourceToSchema(r)
}
reqBody := schema.FirewallActionApplyToResourcesRequest{ApplyTo: applyTo}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/apply_to_resources", firewall.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionApplyToResourcesResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FirewallActionApplyToResourcesResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionsFromSchema(respBody.Actions), resp, nil
}
func (c *FirewallClient) RemoveResources(ctx context.Context, firewall *Firewall, resources []FirewallResource) ([]*Action, *Response, error) {
const opPath = "/firewalls/%d/actions/remove_from_resources"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, firewall.ID)
removeFrom := make([]schema.FirewallResource, len(resources))
for i, r := range resources {
removeFrom[i] = firewallResourceToSchema(r)
}
reqBody := schema.FirewallActionRemoveFromResourcesRequest{RemoveFrom: removeFrom}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/firewalls/%d/actions/remove_from_resources", firewall.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FirewallActionRemoveFromResourcesResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FirewallActionRemoveFromResourcesResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionsFromSchema(respBody.Actions), resp, nil
}

View File

@ -1,16 +1,13 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -54,26 +51,21 @@ const (
// changeDNSPtr changes or resets the reverse DNS pointer for an IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (f *FloatingIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, f.ID)
reqBody := schema.FloatingIPActionChangeDNSPtrRequest{
IP: ip.String(),
DNSPtr: ptr,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/change_dns_ptr", f.ID)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPActionChangeDNSPtrResponse{}
resp, err := client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FloatingIPActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -97,41 +89,33 @@ type FloatingIPClient struct {
// GetByID retrieves a Floating IP by its ID. If the Floating IP does not exist,
// nil is returned.
func (c *FloatingIPClient) GetByID(ctx context.Context, id int64) (*FloatingIP, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/floating_ips/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/floating_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.FloatingIPGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.FloatingIPGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, resp, err
}
return FloatingIPFromSchema(body.FloatingIP), resp, nil
return FloatingIPFromSchema(respBody.FloatingIP), resp, nil
}
// GetByName retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned.
func (c *FloatingIPClient) GetByName(ctx context.Context, name string) (*FloatingIP, *Response, error) {
if name == "" {
return nil, nil, nil
}
floatingIPs, response, err := c.List(ctx, FloatingIPListOpts{Name: name})
if len(floatingIPs) == 0 {
return nil, response, err
}
return floatingIPs[0], response, err
return firstByName(name, func() ([]*FloatingIP, *Response, error) {
return c.List(ctx, FloatingIPListOpts{Name: name})
})
}
// Get retrieves a Floating IP by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Floating IP by its name. If the Floating IP does not exist, nil is returned.
func (c *FloatingIPClient) Get(ctx context.Context, idOrName string) (*FloatingIP, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// FloatingIPListOpts specifies options for listing Floating IPs.
@ -157,22 +141,17 @@ func (l FloatingIPListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *FloatingIPClient) List(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, *Response, error) {
path := "/floating_ips?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/floating_ips?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.FloatingIPListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.FloatingIPListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
floatingIPs := make([]*FloatingIP, 0, len(body.FloatingIPs))
for _, s := range body.FloatingIPs {
floatingIPs = append(floatingIPs, FloatingIPFromSchema(s))
}
return floatingIPs, resp, nil
return allFromSchemaFunc(respBody.FloatingIPs, FloatingIPFromSchema), resp, nil
}
// All returns all Floating IPs.
@ -182,22 +161,10 @@ func (c *FloatingIPClient) All(ctx context.Context) ([]*FloatingIP, error) {
// AllWithOpts returns all Floating IPs for the given options.
func (c *FloatingIPClient) AllWithOpts(ctx context.Context, opts FloatingIPListOpts) ([]*FloatingIP, error) {
allFloatingIPs := []*FloatingIP{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*FloatingIP, *Response, error) {
opts.Page = page
floatingIPs, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allFloatingIPs = append(allFloatingIPs, floatingIPs...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allFloatingIPs, nil
}
// FloatingIPCreateOpts specifies options for creating a Floating IP.
@ -216,10 +183,10 @@ func (o FloatingIPCreateOpts) Validate() error {
case FloatingIPTypeIPv4, FloatingIPTypeIPv6:
break
default:
return errors.New("missing or invalid type")
return invalidFieldValue(o, "Type", o.Type)
}
if o.HomeLocation == nil && o.Server == nil {
return errors.New("one of home location or server is required")
return missingOneOfFields(o, "HomeLocation", "Server")
}
return nil
}
@ -232,8 +199,15 @@ type FloatingIPCreateResult struct {
// Create creates a Floating IP.
func (c *FloatingIPClient) Create(ctx context.Context, opts FloatingIPCreateOpts) (FloatingIPCreateResult, *Response, error) {
result := FloatingIPCreateResult{}
const opPath = "/floating_ips"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil {
return FloatingIPCreateResult{}, nil, err
return result, nil, err
}
reqBody := schema.FloatingIPCreateRequest{
@ -250,38 +224,28 @@ func (c *FloatingIPClient) Create(ctx context.Context, opts FloatingIPCreateOpts
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.FloatingIPCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return FloatingIPCreateResult{}, nil, err
return result, resp, err
}
req, err := c.client.NewRequest(ctx, "POST", "/floating_ips", bytes.NewReader(reqBodyData))
if err != nil {
return FloatingIPCreateResult{}, nil, err
}
var respBody schema.FloatingIPCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return FloatingIPCreateResult{}, resp, err
}
var action *Action
result.FloatingIP = FloatingIPFromSchema(respBody.FloatingIP)
if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action)
result.Action = ActionFromSchema(*respBody.Action)
}
return FloatingIPCreateResult{
FloatingIP: FloatingIPFromSchema(respBody.FloatingIP),
Action: action,
}, resp, nil
return result, resp, nil
}
// Delete deletes a Floating IP.
func (c *FloatingIPClient) Delete(ctx context.Context, floatingIP *FloatingIP) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/floating_ips/%d", floatingIP.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/floating_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// FloatingIPUpdateOpts specifies options for updating a Floating IP.
@ -293,6 +257,11 @@ type FloatingIPUpdateOpts struct {
// Update updates a Floating IP.
func (c *FloatingIPClient) Update(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPUpdateOpts) (*FloatingIP, *Response, error) {
const opPath = "/floating_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPUpdateRequest{
Description: opts.Description,
Name: opts.Name,
@ -300,68 +269,48 @@ func (c *FloatingIPClient) Update(ctx context.Context, floatingIP *FloatingIP, o
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d", floatingIP.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.FloatingIPUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return FloatingIPFromSchema(respBody.FloatingIP), resp, nil
}
// Assign assigns a Floating IP to a server.
func (c *FloatingIPClient) Assign(ctx context.Context, floatingIP *FloatingIP, server *Server) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/assign"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPActionAssignRequest{
Server: server.ID,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/assign", floatingIP.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.FloatingIPActionAssignResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FloatingIPActionAssignResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
// Unassign unassigns a Floating IP from the currently assigned server.
func (c *FloatingIPClient) Unassign(ctx context.Context, floatingIP *FloatingIP) (*Action, *Response, error) {
var reqBody schema.FloatingIPActionUnassignRequest
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
const opPath = "/floating_ips/%d/actions/unassign"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/floating_ips/%d/actions/unassign", floatingIP.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
var respBody schema.FloatingIPActionUnassignResponse
resp, err := c.client.Do(req, &respBody)
reqBody := schema.FloatingIPActionUnassignRequest{}
respBody, resp, err := postRequest[schema.FloatingIPActionUnassignResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -382,24 +331,19 @@ type FloatingIPChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a Floating IP.
func (c *FloatingIPClient) ChangeProtection(ctx context.Context, floatingIP *FloatingIP, opts FloatingIPChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/floating_ips/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, floatingIP.ID)
reqBody := schema.FloatingIPActionChangeProtectionRequest{
Delete: opts.Delete,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/floating_ips/%d/actions/change_protection", floatingIP.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.FloatingIPActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.FloatingIPActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}

View File

@ -1,5 +1,34 @@
// Package hcloud is a library for the Hetzner Cloud API.
/*
Package hcloud is a library for the Hetzner Cloud API.
The Hetzner Cloud API reference is available at https://docs.hetzner.cloud.
Make sure to follow our API changelog available at https://docs.hetzner.cloud/changelog
(or the RRS feed available at https://docs.hetzner.cloud/changelog/feed.rss) to be
notified about additions, deprecations and removals.
# Retry mechanism
The [Client.Do] method will retry failed requests that match certain criteria. The
default retry interval is defined by an exponential backoff algorithm truncated to 60s
with jitter. The default maximal number of retries is 5.
The following rules defines when a request can be retried:
When the [http.Client] returned a network timeout error.
When the API returned an HTTP error, with the status code:
- [http.StatusBadGateway]
- [http.StatusGatewayTimeout]
When the API returned an application error, with the code:
- [ErrorCodeConflict]
- [ErrorCodeRateLimitExceeded]
Changes to the retry policy might occur between releases, and will not be considered
breaking changes.
*/
package hcloud
// Version is the library's version following Semantic Versioning.
const Version = "2.8.0" // x-release-please-version
const Version = "2.21.1" // x-releaser-pleaser-version

View File

@ -1,14 +1,13 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -83,34 +82,29 @@ type ImageClient struct {
// GetByID retrieves an image by its ID. If the image does not exist, nil is returned.
func (c *ImageClient) GetByID(ctx context.Context, id int64) (*Image, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/images/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/images/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.ImageGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.ImageGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return ImageFromSchema(body.Image), resp, nil
return ImageFromSchema(respBody.Image), resp, nil
}
// GetByName retrieves an image by its name. If the image does not exist, nil is returned.
//
// Deprecated: Use [ImageClient.GetByNameAndArchitecture] instead.
func (c *ImageClient) GetByName(ctx context.Context, name string) (*Image, *Response, error) {
if name == "" {
return nil, nil, nil
}
images, response, err := c.List(ctx, ImageListOpts{Name: name})
if len(images) == 0 {
return nil, response, err
}
return images[0], response, err
return firstByName(name, func() ([]*Image, *Response, error) {
return c.List(ctx, ImageListOpts{Name: name})
})
}
// GetByNameAndArchitecture retrieves an image by its name and architecture. If the image does not exist,
@ -118,14 +112,9 @@ func (c *ImageClient) GetByName(ctx context.Context, name string) (*Image, *Resp
// In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should
// check for this in your calling method.
func (c *ImageClient) GetByNameAndArchitecture(ctx context.Context, name string, architecture Architecture) (*Image, *Response, error) {
if name == "" {
return nil, nil, nil
}
images, response, err := c.List(ctx, ImageListOpts{Name: name, Architecture: []Architecture{architecture}, IncludeDeprecated: true})
if len(images) == 0 {
return nil, response, err
}
return images[0], response, err
return firstByName(name, func() ([]*Image, *Response, error) {
return c.List(ctx, ImageListOpts{Name: name, Architecture: []Architecture{architecture}, IncludeDeprecated: true})
})
}
// Get retrieves an image by its ID if the input can be parsed as an integer, otherwise it
@ -133,10 +122,7 @@ func (c *ImageClient) GetByNameAndArchitecture(ctx context.Context, name string,
//
// Deprecated: Use [ImageClient.GetForArchitecture] instead.
func (c *ImageClient) Get(ctx context.Context, idOrName string) (*Image, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// GetForArchitecture retrieves an image by its ID if the input can be parsed as an integer, otherwise it
@ -145,10 +131,13 @@ func (c *ImageClient) Get(ctx context.Context, idOrName string) (*Image, *Respon
// In contrast to [ImageClient.Get], this method also returns deprecated images. Depending on your needs you should
// check for this in your calling method.
func (c *ImageClient) GetForArchitecture(ctx context.Context, idOrName string, architecture Architecture) (*Image, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByNameAndArchitecture(ctx, idOrName, architecture)
return getByIDOrName(ctx,
c.GetByID,
func(ctx context.Context, name string) (*Image, *Response, error) {
return c.GetByNameAndArchitecture(ctx, name, architecture)
},
idOrName,
)
}
// ImageListOpts specifies options for listing images.
@ -194,22 +183,17 @@ func (l ImageListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *ImageClient) List(ctx context.Context, opts ImageListOpts) ([]*Image, *Response, error) {
path := "/images?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/images?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ImageListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.ImageListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
images := make([]*Image, 0, len(body.Images))
for _, i := range body.Images {
images = append(images, ImageFromSchema(i))
}
return images, resp, nil
return allFromSchemaFunc(respBody.Images, ImageFromSchema), resp, nil
}
// All returns all images.
@ -219,31 +203,20 @@ func (c *ImageClient) All(ctx context.Context) ([]*Image, error) {
// AllWithOpts returns all images for the given options.
func (c *ImageClient) AllWithOpts(ctx context.Context, opts ImageListOpts) ([]*Image, error) {
allImages := []*Image{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Image, *Response, error) {
opts.Page = page
images, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allImages = append(allImages, images...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allImages, nil
}
// Delete deletes an image.
func (c *ImageClient) Delete(ctx context.Context, image *Image) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/images/%d", image.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/images/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, image.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// ImageUpdateOpts specifies options for updating an image.
@ -255,6 +228,11 @@ type ImageUpdateOpts struct {
// Update updates an image.
func (c *ImageClient) Update(ctx context.Context, image *Image, opts ImageUpdateOpts) (*Image, *Response, error) {
const opPath = "/images/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, image.ID)
reqBody := schema.ImageUpdateRequest{
Description: opts.Description,
}
@ -264,22 +242,12 @@ func (c *ImageClient) Update(ctx context.Context, image *Image, opts ImageUpdate
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/images/%d", image.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.ImageUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.ImageUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ImageFromSchema(respBody.Image), resp, nil
}
@ -290,24 +258,19 @@ type ImageChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of an image.
func (c *ImageClient) ChangeProtection(ctx context.Context, image *Image, opts ImageChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/images/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, image.ID)
reqBody := schema.ImageActionChangeProtectionRequest{
Delete: opts.Delete,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/images/%d/actions/change_protection", image.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.ImageActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.ImageActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}

View File

@ -1,6 +1,7 @@
package instrumentation
import (
"errors"
"fmt"
"net/http"
"regexp"
@ -9,6 +10,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
)
type Instrumenter struct {
@ -22,7 +25,12 @@ func New(subsystemIdentifier string, instrumentationRegistry prometheus.Register
}
// InstrumentedRoundTripper returns an instrumented round tripper.
func (i *Instrumenter) InstrumentedRoundTripper() http.RoundTripper {
func (i *Instrumenter) InstrumentedRoundTripper(transport http.RoundTripper) http.RoundTripper {
// By default, http client would use DefaultTransport on nil, but we internally are relying on it being configured
if transport == nil {
transport = http.DefaultTransport
}
inFlightRequestsGauge := registerOrReuse(
i.instrumentationRegistry,
prometheus.NewGauge(prometheus.GaugeOpts{
@ -57,7 +65,7 @@ func (i *Instrumenter) InstrumentedRoundTripper() http.RoundTripper {
return promhttp.InstrumentRoundTripperInFlight(inFlightRequestsGauge,
promhttp.InstrumentRoundTripperDuration(requestLatencyHistogram,
i.instrumentRoundTripperEndpoint(requestsPerEndpointCounter,
http.DefaultTransport,
transport,
),
),
)
@ -73,8 +81,17 @@ func (i *Instrumenter) instrumentRoundTripperEndpoint(counter *prometheus.Counte
return func(r *http.Request) (*http.Response, error) {
resp, err := next.RoundTrip(r)
if err == nil {
statusCode := strconv.Itoa(resp.StatusCode)
counter.WithLabelValues(statusCode, strings.ToLower(resp.Request.Method), preparePathForLabel(resp.Request.URL.Path)).Inc()
apiEndpoint := ctxutil.OpPath(r.Context())
// If the request does not set the operation path, we must construct it. Happens e.g. for
// user crafted requests.
if apiEndpoint == "" {
apiEndpoint = preparePathForLabel(resp.Request.URL.Path)
}
counter.WithLabelValues(
strconv.Itoa(resp.StatusCode),
strings.ToLower(resp.Request.Method),
apiEndpoint,
).Inc()
}
return resp, err
@ -87,9 +104,10 @@ func (i *Instrumenter) instrumentRoundTripperEndpoint(counter *prometheus.Counte
func registerOrReuse[C prometheus.Collector](registry prometheus.Registerer, collector C) C {
err := registry.Register(collector)
if err != nil {
var arErr prometheus.AlreadyRegisteredError
// If we get a AlreadyRegisteredError we can return the existing collector
if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
if existingCollector, ok := are.ExistingCollector.(C); ok {
if errors.As(err, &arErr) {
if existingCollector, ok := arErr.ExistingCollector.(C); ok {
collector = existingCollector
} else {
panic("received incompatible existing collector")
@ -102,16 +120,16 @@ func registerOrReuse[C prometheus.Collector](registry prometheus.Registerer, col
return collector
}
var pathLabelRegexp = regexp.MustCompile("[^a-z/_]+")
func preparePathForLabel(path string) string {
path = strings.ToLower(path)
// replace the /v1/ that indicated the API version
path, _ = strings.CutPrefix(path, "/v1")
// replace all numbers and chars that are not a-z, / or _
reg := regexp.MustCompile("[^a-z/_]+")
path = reg.ReplaceAllString(path, "")
path = pathLabelRegexp.ReplaceAllString(path, "-")
// replace all artifacts of number replacement (//)
path = strings.ReplaceAll(path, "//", "/")
// replace the /v/ that indicated the API version
return strings.Replace(path, "/v/", "/", 1)
return path
}

View File

@ -4,9 +4,9 @@ import (
"context"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -40,40 +40,32 @@ type ISOClient struct {
// GetByID retrieves an ISO by its ID.
func (c *ISOClient) GetByID(ctx context.Context, id int64) (*ISO, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/isos/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/isos/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.ISOGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.ISOGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, resp, err
}
return ISOFromSchema(body.ISO), resp, nil
return ISOFromSchema(respBody.ISO), resp, nil
}
// GetByName retrieves an ISO by its name.
func (c *ISOClient) GetByName(ctx context.Context, name string) (*ISO, *Response, error) {
if name == "" {
return nil, nil, nil
}
isos, response, err := c.List(ctx, ISOListOpts{Name: name})
if len(isos) == 0 {
return nil, response, err
}
return isos[0], response, err
return firstByName(name, func() ([]*ISO, *Response, error) {
return c.List(ctx, ISOListOpts{Name: name})
})
}
// Get retrieves an ISO by its ID if the input can be parsed as an integer, otherwise it retrieves an ISO by its name.
func (c *ISOClient) Get(ctx context.Context, idOrName string) (*ISO, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// ISOListOpts specifies options for listing isos.
@ -115,22 +107,17 @@ func (l ISOListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *ISOClient) List(ctx context.Context, opts ISOListOpts) ([]*ISO, *Response, error) {
path := "/isos?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/isos?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ISOListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.ISOListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
isos := make([]*ISO, 0, len(body.ISOs))
for _, i := range body.ISOs {
isos = append(isos, ISOFromSchema(i))
}
return isos, resp, nil
return allFromSchemaFunc(respBody.ISOs, ISOFromSchema), resp, nil
}
// All returns all ISOs.
@ -140,20 +127,8 @@ func (c *ISOClient) All(ctx context.Context) ([]*ISO, error) {
// AllWithOpts returns all ISOs for the given options.
func (c *ISOClient) AllWithOpts(ctx context.Context, opts ISOListOpts) ([]*ISO, error) {
allISOs := []*ISO{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*ISO, *Response, error) {
opts.Page = page
isos, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allISOs = append(allISOs, isos...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allISOs, nil
}

View File

@ -1,16 +1,14 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -200,26 +198,21 @@ type LoadBalancerProtection struct {
// changeDNSPtr changes or resets the reverse DNS pointer for an IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (lb *LoadBalancer) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, lb.ID)
reqBody := schema.LoadBalancerActionChangeDNSPtrRequest{
IP: ip.String(),
DNSPtr: ptr,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_dns_ptr", lb.ID)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeDNSPtrResponse{}
resp, err := client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -243,41 +236,33 @@ type LoadBalancerClient struct {
// GetByID retrieves a Load Balancer by its ID. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) GetByID(ctx context.Context, id int64) (*LoadBalancer, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/load_balancers/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/load_balancers/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.LoadBalancerGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.LoadBalancerGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return LoadBalancerFromSchema(body.LoadBalancer), resp, nil
return LoadBalancerFromSchema(respBody.LoadBalancer), resp, nil
}
// GetByName retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) GetByName(ctx context.Context, name string) (*LoadBalancer, *Response, error) {
if name == "" {
return nil, nil, nil
}
LoadBalancer, response, err := c.List(ctx, LoadBalancerListOpts{Name: name})
if len(LoadBalancer) == 0 {
return nil, response, err
}
return LoadBalancer[0], response, err
return firstByName(name, func() ([]*LoadBalancer, *Response, error) {
return c.List(ctx, LoadBalancerListOpts{Name: name})
})
}
// Get retrieves a Load Balancer by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Load Balancer by its name. If the Load Balancer does not exist, nil is returned.
func (c *LoadBalancerClient) Get(ctx context.Context, idOrName string) (*LoadBalancer, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// LoadBalancerListOpts specifies options for listing Load Balancers.
@ -303,22 +288,17 @@ func (l LoadBalancerListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *LoadBalancerClient) List(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, *Response, error) {
path := "/load_balancers?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/load_balancers?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.LoadBalancerListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
LoadBalancers := make([]*LoadBalancer, 0, len(body.LoadBalancers))
for _, s := range body.LoadBalancers {
LoadBalancers = append(LoadBalancers, LoadBalancerFromSchema(s))
}
return LoadBalancers, resp, nil
return allFromSchemaFunc(respBody.LoadBalancers, LoadBalancerFromSchema), resp, nil
}
// All returns all Load Balancers.
@ -328,22 +308,10 @@ func (c *LoadBalancerClient) All(ctx context.Context) ([]*LoadBalancer, error) {
// AllWithOpts returns all Load Balancers for the given options.
func (c *LoadBalancerClient) AllWithOpts(ctx context.Context, opts LoadBalancerListOpts) ([]*LoadBalancer, error) {
allLoadBalancers := []*LoadBalancer{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*LoadBalancer, *Response, error) {
opts.Page = page
LoadBalancers, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allLoadBalancers = append(allLoadBalancers, LoadBalancers...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allLoadBalancers, nil
}
// LoadBalancerUpdateOpts specifies options for updating a Load Balancer.
@ -354,6 +322,11 @@ type LoadBalancerUpdateOpts struct {
// Update updates a Load Balancer.
func (c *LoadBalancerClient) Update(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerUpdateOpts) (*LoadBalancer, *Response, error) {
const opPath = "/load_balancers/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerUpdateRequest{}
if opts.Name != "" {
reqBody.Name = &opts.Name
@ -361,22 +334,12 @@ func (c *LoadBalancerClient) Update(ctx context.Context, loadBalancer *LoadBalan
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.LoadBalancerUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return LoadBalancerFromSchema(respBody.LoadBalancer), resp, nil
}
@ -472,73 +435,61 @@ type LoadBalancerCreateResult struct {
// Create creates a new Load Balancer.
func (c *LoadBalancerClient) Create(ctx context.Context, opts LoadBalancerCreateOpts) (LoadBalancerCreateResult, *Response, error) {
const opPath = "/load_balancers"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := LoadBalancerCreateResult{}
reqPath := opPath
reqBody := loadBalancerCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.LoadBalancerCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return LoadBalancerCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/load_balancers", bytes.NewReader(reqBodyData))
if err != nil {
return LoadBalancerCreateResult{}, nil, err
return result, resp, err
}
respBody := schema.LoadBalancerCreateResponse{}
resp, err := c.client.Do(req, &respBody)
if err != nil {
return LoadBalancerCreateResult{}, resp, err
}
return LoadBalancerCreateResult{
LoadBalancer: LoadBalancerFromSchema(respBody.LoadBalancer),
Action: ActionFromSchema(respBody.Action),
}, resp, nil
result.LoadBalancer = LoadBalancerFromSchema(respBody.LoadBalancer)
result.Action = ActionFromSchema(respBody.Action)
return result, resp, nil
}
// Delete deletes a Load Balancer.
func (c *LoadBalancerClient) Delete(ctx context.Context, loadBalancer *LoadBalancer) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/load_balancers/%d", loadBalancer.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/load_balancers/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
func (c *LoadBalancerClient) addTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionAddTargetRequest) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
const opPath = "/load_balancers/%d/actions/add_target"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/load_balancers/%d/actions/add_target", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
var respBody schema.LoadBalancerActionAddTargetResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionAddTargetResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
func (c *LoadBalancerClient) removeTarget(ctx context.Context, loadBalancer *LoadBalancer, reqBody schema.LoadBalancerActionRemoveTargetRequest) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
const opPath = "/load_balancers/%d/actions/remove_target"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/load_balancers/%d/actions/remove_target", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
var respBody schema.LoadBalancerActionRemoveTargetResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionRemoveTargetResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -671,23 +622,18 @@ type LoadBalancerAddServiceOptsHealthCheckHTTP struct {
// AddService adds a service to a Load Balancer.
func (c *LoadBalancerClient) AddService(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAddServiceOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/add_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := loadBalancerAddServiceOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/add_service", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionAddServiceResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionAddServiceResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -732,48 +678,38 @@ type LoadBalancerUpdateServiceOptsHealthCheckHTTP struct {
// UpdateService updates a Load Balancer service.
func (c *LoadBalancerClient) UpdateService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int, opts LoadBalancerUpdateServiceOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/update_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := loadBalancerUpdateServiceOptsToSchema(opts)
reqBody.ListenPort = listenPort
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/update_service", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerActionUpdateServiceResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionUpdateServiceResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
// DeleteService deletes a Load Balancer service.
func (c *LoadBalancerClient) DeleteService(ctx context.Context, loadBalancer *LoadBalancer, listenPort int) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/delete_service"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerDeleteServiceRequest{
ListenPort: listenPort,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/delete_service", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.LoadBalancerDeleteServiceResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerDeleteServiceResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -784,26 +720,21 @@ type LoadBalancerChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a Load Balancer.
func (c *LoadBalancerClient) ChangeProtection(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeProtectionRequest{
Delete: opts.Delete,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_protection", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// LoadBalancerChangeAlgorithmOpts specifies options for changing the algorithm of a Load Balancer.
@ -813,26 +744,21 @@ type LoadBalancerChangeAlgorithmOpts struct {
// ChangeAlgorithm changes the algorithm of a Load Balancer.
func (c *LoadBalancerClient) ChangeAlgorithm(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeAlgorithmOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_algorithm"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeAlgorithmRequest{
Type: string(opts.Type),
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_algorithm", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeAlgorithmResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionChangeAlgorithmResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// LoadBalancerAttachToNetworkOpts specifies options for attaching a Load Balancer to a network.
@ -843,29 +769,24 @@ type LoadBalancerAttachToNetworkOpts struct {
// AttachToNetwork attaches a Load Balancer to a network.
func (c *LoadBalancerClient) AttachToNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerAttachToNetworkOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/attach_to_network"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionAttachToNetworkRequest{
Network: opts.Network.ID,
}
if opts.IP != nil {
reqBody.IP = Ptr(opts.IP.String())
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/attach_to_network", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionAttachToNetworkResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionAttachToNetworkResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// LoadBalancerDetachFromNetworkOpts specifies options for detaching a Load Balancer from a network.
@ -875,56 +796,51 @@ type LoadBalancerDetachFromNetworkOpts struct {
// DetachFromNetwork detaches a Load Balancer from a network.
func (c *LoadBalancerClient) DetachFromNetwork(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerDetachFromNetworkOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/detach_from_network"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionDetachFromNetworkRequest{
Network: opts.Network.ID,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/load_balancers/%d/actions/detach_from_network", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionDetachFromNetworkResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionDetachFromNetworkResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// EnablePublicInterface enables the Load Balancer's public network interface.
func (c *LoadBalancerClient) EnablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) {
path := fmt.Sprintf("/load_balancers/%d/actions/enable_public_interface", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, nil)
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionEnablePublicInterfaceResponse{}
resp, err := c.client.Do(req, &respBody)
const opPath = "/load_balancers/%d/actions/enable_public_interface"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
respBody, resp, err := postRequest[schema.LoadBalancerActionEnablePublicInterfaceResponse](ctx, c.client, reqPath, nil)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// DisablePublicInterface disables the Load Balancer's public network interface.
func (c *LoadBalancerClient) DisablePublicInterface(ctx context.Context, loadBalancer *LoadBalancer) (*Action, *Response, error) {
path := fmt.Sprintf("/load_balancers/%d/actions/disable_public_interface", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, nil)
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionDisablePublicInterfaceResponse{}
resp, err := c.client.Do(req, &respBody)
const opPath = "/load_balancers/%d/actions/disable_public_interface"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
respBody, resp, err := postRequest[schema.LoadBalancerActionDisablePublicInterfaceResponse](ctx, c.client, reqPath, nil)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// LoadBalancerChangeTypeOpts specifies options for changing a Load Balancer's type.
@ -934,28 +850,21 @@ type LoadBalancerChangeTypeOpts struct {
// ChangeType changes a Load Balancer's type.
func (c *LoadBalancerClient) ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error) {
const opPath = "/load_balancers/%d/actions/change_type"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, loadBalancer.ID)
reqBody := schema.LoadBalancerActionChangeTypeRequest{}
if opts.LoadBalancerType.ID != 0 {
reqBody.LoadBalancerType = opts.LoadBalancerType.ID
} else {
reqBody.LoadBalancerType = opts.LoadBalancerType.Name
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
if opts.LoadBalancerType.ID != 0 || opts.LoadBalancerType.Name != "" {
reqBody.LoadBalancerType = schema.IDOrName{ID: opts.LoadBalancerType.ID, Name: opts.LoadBalancerType.Name}
}
path := fmt.Sprintf("/load_balancers/%d/actions/change_type", loadBalancer.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.LoadBalancerActionChangeTypeResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.LoadBalancerActionChangeTypeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -980,32 +889,34 @@ type LoadBalancerGetMetricsOpts struct {
Step int
}
func (o *LoadBalancerGetMetricsOpts) addQueryParams(req *http.Request) error {
query := req.URL.Query()
func (o LoadBalancerGetMetricsOpts) Validate() error {
if len(o.Types) == 0 {
return fmt.Errorf("no metric types specified")
return missingField(o, "Types")
}
if o.Start.IsZero() {
return missingField(o, "Start")
}
if o.End.IsZero() {
return missingField(o, "End")
}
return nil
}
func (o LoadBalancerGetMetricsOpts) values() url.Values {
query := url.Values{}
for _, typ := range o.Types {
query.Add("type", string(typ))
}
if o.Start.IsZero() {
return fmt.Errorf("no start time specified")
}
query.Add("start", o.Start.Format(time.RFC3339))
if o.End.IsZero() {
return fmt.Errorf("no end time specified")
}
query.Add("end", o.End.Format(time.RFC3339))
if o.Step > 0 {
query.Add("step", strconv.Itoa(o.Step))
}
req.URL.RawQuery = query.Encode()
return nil
return query
}
// LoadBalancerMetrics contains the metrics requested for a Load Balancer.
@ -1024,31 +935,32 @@ type LoadBalancerMetricsValue struct {
// GetMetrics obtains metrics for a Load Balancer.
func (c *LoadBalancerClient) GetMetrics(
ctx context.Context, lb *LoadBalancer, opts LoadBalancerGetMetricsOpts,
ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerGetMetricsOpts,
) (*LoadBalancerMetrics, *Response, error) {
var respBody schema.LoadBalancerGetMetricsResponse
const opPath = "/load_balancers/%d/metrics?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
if lb == nil {
return nil, nil, fmt.Errorf("illegal argument: load balancer is nil")
if loadBalancer == nil {
return nil, nil, missingArgument("loadBalancer", loadBalancer)
}
path := fmt.Sprintf("/load_balancers/%d/metrics", lb.ID)
req, err := c.client.NewRequest(ctx, "GET", path, nil)
if err := opts.Validate(); err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, loadBalancer.ID, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerGetMetricsResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, fmt.Errorf("new request: %v", err)
return nil, resp, err
}
if err := opts.addQueryParams(req); err != nil {
return nil, nil, fmt.Errorf("add query params: %v", err)
}
resp, err := c.client.Do(req, &respBody)
metrics, err := loadBalancerMetricsFromSchema(&respBody)
if err != nil {
return nil, nil, fmt.Errorf("get metrics: %v", err)
return nil, nil, fmt.Errorf("convert response body: %w", err)
}
ms, err := loadBalancerMetricsFromSchema(&respBody)
if err != nil {
return nil, nil, fmt.Errorf("convert response body: %v", err)
}
return ms, resp, nil
return metrics, resp, nil
}
// ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer.

View File

@ -6,6 +6,7 @@ import (
"net/url"
"strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -29,32 +30,27 @@ type LoadBalancerTypeClient struct {
// GetByID retrieves a Load Balancer type by its ID. If the Load Balancer type does not exist, nil is returned.
func (c *LoadBalancerTypeClient) GetByID(ctx context.Context, id int64) (*LoadBalancerType, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/load_balancer_types/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/load_balancer_types/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.LoadBalancerTypeGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.LoadBalancerTypeGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return LoadBalancerTypeFromSchema(body.LoadBalancerType), resp, nil
return LoadBalancerTypeFromSchema(respBody.LoadBalancerType), resp, nil
}
// GetByName retrieves a Load Balancer type by its name. If the Load Balancer type does not exist, nil is returned.
func (c *LoadBalancerTypeClient) GetByName(ctx context.Context, name string) (*LoadBalancerType, *Response, error) {
if name == "" {
return nil, nil, nil
}
LoadBalancerTypes, response, err := c.List(ctx, LoadBalancerTypeListOpts{Name: name})
if len(LoadBalancerTypes) == 0 {
return nil, response, err
}
return LoadBalancerTypes[0], response, err
return firstByName(name, func() ([]*LoadBalancerType, *Response, error) {
return c.List(ctx, LoadBalancerTypeListOpts{Name: name})
})
}
// Get retrieves a Load Balancer type by its ID if the input can be parsed as an integer, otherwise it
@ -89,22 +85,17 @@ func (l LoadBalancerTypeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *LoadBalancerTypeClient) List(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, *Response, error) {
path := "/load_balancer_types?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/load_balancer_types?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LoadBalancerTypeListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.LoadBalancerTypeListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
LoadBalancerTypes := make([]*LoadBalancerType, 0, len(body.LoadBalancerTypes))
for _, s := range body.LoadBalancerTypes {
LoadBalancerTypes = append(LoadBalancerTypes, LoadBalancerTypeFromSchema(s))
}
return LoadBalancerTypes, resp, nil
return allFromSchemaFunc(respBody.LoadBalancerTypes, LoadBalancerTypeFromSchema), resp, nil
}
// All returns all Load Balancer types.
@ -114,20 +105,8 @@ func (c *LoadBalancerTypeClient) All(ctx context.Context) ([]*LoadBalancerType,
// AllWithOpts returns all Load Balancer types for the given options.
func (c *LoadBalancerTypeClient) AllWithOpts(ctx context.Context, opts LoadBalancerTypeListOpts) ([]*LoadBalancerType, error) {
allLoadBalancerTypes := []*LoadBalancerType{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*LoadBalancerType, *Response, error) {
opts.Page = page
LoadBalancerTypes, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allLoadBalancerTypes = append(allLoadBalancerTypes, LoadBalancerTypes...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allLoadBalancerTypes, nil
}

View File

@ -6,6 +6,7 @@ import (
"net/url"
"strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -28,32 +29,27 @@ type LocationClient struct {
// GetByID retrieves a location by its ID. If the location does not exist, nil is returned.
func (c *LocationClient) GetByID(ctx context.Context, id int64) (*Location, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/locations/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/locations/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.LocationGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.LocationGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, resp, err
}
return LocationFromSchema(body.Location), resp, nil
return LocationFromSchema(respBody.Location), resp, nil
}
// GetByName retrieves an location by its name. If the location does not exist, nil is returned.
func (c *LocationClient) GetByName(ctx context.Context, name string) (*Location, *Response, error) {
if name == "" {
return nil, nil, nil
}
locations, response, err := c.List(ctx, LocationListOpts{Name: name})
if len(locations) == 0 {
return nil, response, err
}
return locations[0], response, err
return firstByName(name, func() ([]*Location, *Response, error) {
return c.List(ctx, LocationListOpts{Name: name})
})
}
// Get retrieves a location by its ID if the input can be parsed as an integer, otherwise it
@ -88,22 +84,17 @@ func (l LocationListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *LocationClient) List(ctx context.Context, opts LocationListOpts) ([]*Location, *Response, error) {
path := "/locations?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/locations?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.LocationListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.LocationListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
locations := make([]*Location, 0, len(body.Locations))
for _, i := range body.Locations {
locations = append(locations, LocationFromSchema(i))
}
return locations, resp, nil
return allFromSchemaFunc(respBody.Locations, LocationFromSchema), resp, nil
}
// All returns all locations.
@ -113,20 +104,8 @@ func (c *LocationClient) All(ctx context.Context) ([]*Location, error) {
// AllWithOpts returns all locations for the given options.
func (c *LocationClient) AllWithOpts(ctx context.Context, opts LocationListOpts) ([]*Location, error) {
allLocations := []*Location{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Location, *Response, error) {
opts.Page = page
locations, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allLocations = append(allLocations, locations...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allLocations, nil
}

View File

@ -1,6 +1,8 @@
package metadata
import (
"bytes"
"context"
"fmt"
"io"
"net"
@ -11,6 +13,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/internal/instrumentation"
)
@ -72,24 +75,33 @@ func NewClient(options ...ClientOption) *Client {
if client.instrumentationRegistry != nil {
i := instrumentation.New("metadata", client.instrumentationRegistry)
client.httpClient.Transport = i.InstrumentedRoundTripper()
client.httpClient.Transport = i.InstrumentedRoundTripper(client.httpClient.Transport)
}
return client
}
// get executes an HTTP request against the API.
func (c *Client) get(path string) (string, error) {
url := c.endpoint + path
resp, err := c.httpClient.Get(url)
ctx := ctxutil.SetOpPath(context.Background(), path)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpoint+path, http.NoBody)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
body := string(bodyBytes)
body := string(bytes.TrimSpace(bodyBytes))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return body, fmt.Errorf("response status was %d", resp.StatusCode)
}

View File

@ -1,16 +1,13 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -19,9 +16,10 @@ type NetworkZone string
// List of available Network Zones.
const (
NetworkZoneEUCentral NetworkZone = "eu-central"
NetworkZoneUSEast NetworkZone = "us-east"
NetworkZoneUSWest NetworkZone = "us-west"
NetworkZoneEUCentral NetworkZone = "eu-central"
NetworkZoneUSEast NetworkZone = "us-east"
NetworkZoneUSWest NetworkZone = "us-west"
NetworkZoneAPSouthEast NetworkZone = "ap-southeast"
)
// NetworkSubnetType specifies a type of a subnet.
@ -29,22 +27,30 @@ type NetworkSubnetType string
// List of available network subnet types.
const (
NetworkSubnetTypeCloud NetworkSubnetType = "cloud"
NetworkSubnetTypeServer NetworkSubnetType = "server"
// Used to connect cloud servers and load balancers.
NetworkSubnetTypeCloud NetworkSubnetType = "cloud"
// Used to connect cloud servers and load balancers.
//
// Deprecated: Use [NetworkSubnetTypeCloud] instead.
NetworkSubnetTypeServer NetworkSubnetType = "server"
// Used to connect cloud servers and load balancers with dedicated servers.
//
// See https://docs.hetzner.com/cloud/networks/connect-dedi-vswitch/
NetworkSubnetTypeVSwitch NetworkSubnetType = "vswitch"
)
// Network represents a network in the Hetzner Cloud.
type Network struct {
ID int64
Name string
Created time.Time
IPRange *net.IPNet
Subnets []NetworkSubnet
Routes []NetworkRoute
Servers []*Server
Protection NetworkProtection
Labels map[string]string
ID int64
Name string
Created time.Time
IPRange *net.IPNet
Subnets []NetworkSubnet
Routes []NetworkRoute
Servers []*Server
LoadBalancers []*LoadBalancer
Protection NetworkProtection
Labels map[string]string
// ExposeRoutesToVSwitch indicates if the routes from this network should be exposed to the vSwitch connection.
ExposeRoutesToVSwitch bool
@ -78,41 +84,33 @@ type NetworkClient struct {
// GetByID retrieves a network by its ID. If the network does not exist, nil is returned.
func (c *NetworkClient) GetByID(ctx context.Context, id int64) (*Network, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/networks/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/networks/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.NetworkGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.NetworkGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return NetworkFromSchema(body.Network), resp, nil
return NetworkFromSchema(respBody.Network), resp, nil
}
// GetByName retrieves a network by its name. If the network does not exist, nil is returned.
func (c *NetworkClient) GetByName(ctx context.Context, name string) (*Network, *Response, error) {
if name == "" {
return nil, nil, nil
}
Networks, response, err := c.List(ctx, NetworkListOpts{Name: name})
if len(Networks) == 0 {
return nil, response, err
}
return Networks[0], response, err
return firstByName(name, func() ([]*Network, *Response, error) {
return c.List(ctx, NetworkListOpts{Name: name})
})
}
// Get retrieves a network by its ID if the input can be parsed as an integer, otherwise it
// retrieves a network by its name. If the network does not exist, nil is returned.
func (c *NetworkClient) Get(ctx context.Context, idOrName string) (*Network, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// NetworkListOpts specifies options for listing networks.
@ -138,22 +136,17 @@ func (l NetworkListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *NetworkClient) List(ctx context.Context, opts NetworkListOpts) ([]*Network, *Response, error) {
path := "/networks?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/networks?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.NetworkListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.NetworkListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
Networks := make([]*Network, 0, len(body.Networks))
for _, s := range body.Networks {
Networks = append(Networks, NetworkFromSchema(s))
}
return Networks, resp, nil
return allFromSchemaFunc(respBody.Networks, NetworkFromSchema), resp, nil
}
// All returns all networks.
@ -163,31 +156,20 @@ func (c *NetworkClient) All(ctx context.Context) ([]*Network, error) {
// AllWithOpts returns all networks for the given options.
func (c *NetworkClient) AllWithOpts(ctx context.Context, opts NetworkListOpts) ([]*Network, error) {
allNetworks := []*Network{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Network, *Response, error) {
opts.Page = page
Networks, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allNetworks = append(allNetworks, Networks...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allNetworks, nil
}
// Delete deletes a network.
func (c *NetworkClient) Delete(ctx context.Context, network *Network) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/networks/%d", network.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/networks/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// NetworkUpdateOpts specifies options for updating a network.
@ -201,6 +183,11 @@ type NetworkUpdateOpts struct {
// Update updates a network.
func (c *NetworkClient) Update(ctx context.Context, network *Network, opts NetworkUpdateOpts) (*Network, *Response, error) {
const opPath = "/networks/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkUpdateRequest{
Name: opts.Name,
}
@ -211,22 +198,11 @@ func (c *NetworkClient) Update(ctx context.Context, network *Network, opts Netwo
reqBody.ExposeRoutesToVSwitch = opts.ExposeRoutesToVSwitch
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d", network.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.NetworkUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return NetworkFromSchema(respBody.Network), resp, nil
}
@ -245,16 +221,21 @@ type NetworkCreateOpts struct {
// Validate checks if options are valid.
func (o NetworkCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
if o.IPRange == nil || o.IPRange.String() == "" {
return errors.New("missing IP range")
return missingField(o, "IPRange")
}
return nil
}
// Create creates a new network.
func (c *NetworkClient) Create(ctx context.Context, opts NetworkCreateOpts) (*Network, *Response, error) {
const opPath = "/networks"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil {
return nil, nil, err
}
@ -283,20 +264,12 @@ func (c *NetworkClient) Create(ctx context.Context, opts NetworkCreateOpts) (*Ne
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/networks", bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkCreateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return NetworkFromSchema(respBody.Network), resp, nil
}
@ -307,25 +280,20 @@ type NetworkChangeIPRangeOpts struct {
// ChangeIPRange changes the IP range of a network.
func (c *NetworkClient) ChangeIPRange(ctx context.Context, network *Network, opts NetworkChangeIPRangeOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/change_ip_range"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionChangeIPRangeRequest{
IPRange: opts.IPRange.String(),
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/change_ip_range", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionChangeIPRangeResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionChangeIPRangeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -336,6 +304,11 @@ type NetworkAddSubnetOpts struct {
// AddSubnet adds a subnet to a network.
func (c *NetworkClient) AddSubnet(ctx context.Context, network *Network, opts NetworkAddSubnetOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/add_subnet"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionAddSubnetRequest{
Type: string(opts.Subnet.Type),
NetworkZone: string(opts.Subnet.NetworkZone),
@ -346,22 +319,12 @@ func (c *NetworkClient) AddSubnet(ctx context.Context, network *Network, opts Ne
if opts.Subnet.VSwitchID != 0 {
reqBody.VSwitchID = opts.Subnet.VSwitchID
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/add_subnet", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionAddSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionAddSubnetResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -372,25 +335,20 @@ type NetworkDeleteSubnetOpts struct {
// DeleteSubnet deletes a subnet from a network.
func (c *NetworkClient) DeleteSubnet(ctx context.Context, network *Network, opts NetworkDeleteSubnetOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/delete_subnet"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionDeleteSubnetRequest{
IPRange: opts.Subnet.IPRange.String(),
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/delete_subnet", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionDeleteSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionDeleteSubnetResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -401,26 +359,21 @@ type NetworkAddRouteOpts struct {
// AddRoute adds a route to a network.
func (c *NetworkClient) AddRoute(ctx context.Context, network *Network, opts NetworkAddRouteOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/add_route"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionAddRouteRequest{
Destination: opts.Route.Destination.String(),
Gateway: opts.Route.Gateway.String(),
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/add_route", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionAddSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionAddRouteResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -431,26 +384,21 @@ type NetworkDeleteRouteOpts struct {
// DeleteRoute deletes a route from a network.
func (c *NetworkClient) DeleteRoute(ctx context.Context, network *Network, opts NetworkDeleteRouteOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/delete_route"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionDeleteRouteRequest{
Destination: opts.Route.Destination.String(),
Gateway: opts.Route.Gateway.String(),
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/delete_route", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionDeleteSubnetResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionDeleteRouteResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -461,24 +409,19 @@ type NetworkChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a network.
func (c *NetworkClient) ChangeProtection(ctx context.Context, network *Network, opts NetworkChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/networks/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, network.ID)
reqBody := schema.NetworkActionChangeProtectionRequest{
Delete: opts.Delete,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/networks/%d/actions/change_protection", network.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.NetworkActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.NetworkActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}

View File

@ -1,15 +1,12 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -38,41 +35,33 @@ type PlacementGroupClient struct {
// GetByID retrieves a PlacementGroup by its ID. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) GetByID(ctx context.Context, id int64) (*PlacementGroup, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/placement_groups/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/placement_groups/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.PlacementGroupGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.PlacementGroupGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return PlacementGroupFromSchema(body.PlacementGroup), resp, nil
return PlacementGroupFromSchema(respBody.PlacementGroup), resp, nil
}
// GetByName retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) GetByName(ctx context.Context, name string) (*PlacementGroup, *Response, error) {
if name == "" {
return nil, nil, nil
}
placementGroups, response, err := c.List(ctx, PlacementGroupListOpts{Name: name})
if len(placementGroups) == 0 {
return nil, response, err
}
return placementGroups[0], response, err
return firstByName(name, func() ([]*PlacementGroup, *Response, error) {
return c.List(ctx, PlacementGroupListOpts{Name: name})
})
}
// Get retrieves a PlacementGroup by its ID if the input can be parsed as an integer, otherwise it
// retrieves a PlacementGroup by its name. If the PlacementGroup does not exist, nil is returned.
func (c *PlacementGroupClient) Get(ctx context.Context, idOrName string) (*PlacementGroup, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// PlacementGroupListOpts specifies options for listing PlacementGroup.
@ -102,22 +91,17 @@ func (l PlacementGroupListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *PlacementGroupClient) List(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, *Response, error) {
path := "/placement_groups?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/placement_groups?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.PlacementGroupListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.PlacementGroupListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
placementGroups := make([]*PlacementGroup, 0, len(body.PlacementGroups))
for _, g := range body.PlacementGroups {
placementGroups = append(placementGroups, PlacementGroupFromSchema(g))
}
return placementGroups, resp, nil
return allFromSchemaFunc(respBody.PlacementGroups, PlacementGroupFromSchema), resp, nil
}
// All returns all PlacementGroups.
@ -133,22 +117,10 @@ func (c *PlacementGroupClient) All(ctx context.Context) ([]*PlacementGroup, erro
// AllWithOpts returns all PlacementGroups for the given options.
func (c *PlacementGroupClient) AllWithOpts(ctx context.Context, opts PlacementGroupListOpts) ([]*PlacementGroup, error) {
allPlacementGroups := []*PlacementGroup{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*PlacementGroup, *Response, error) {
opts.Page = page
placementGroups, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allPlacementGroups = append(allPlacementGroups, placementGroups...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allPlacementGroups, nil
}
// PlacementGroupCreateOpts specifies options for creating a new PlacementGroup.
@ -161,7 +133,7 @@ type PlacementGroupCreateOpts struct {
// Validate checks if options are valid.
func (o PlacementGroupCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
return nil
}
@ -174,27 +146,25 @@ type PlacementGroupCreateResult struct {
// Create creates a new PlacementGroup.
func (c *PlacementGroupClient) Create(ctx context.Context, opts PlacementGroupCreateOpts) (PlacementGroupCreateResult, *Response, error) {
const opPath = "/placement_groups"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := PlacementGroupCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil {
return PlacementGroupCreateResult{}, nil, err
}
reqBody := placementGroupCreateOptsToSchema(opts)
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return PlacementGroupCreateResult{}, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/placement_groups", bytes.NewReader(reqBodyData))
if err != nil {
return PlacementGroupCreateResult{}, nil, err
return result, nil, err
}
respBody := schema.PlacementGroupCreateResponse{}
resp, err := c.client.Do(req, &respBody)
reqBody := placementGroupCreateOptsToSchema(opts)
respBody, resp, err := postRequest[schema.PlacementGroupCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return PlacementGroupCreateResult{}, nil, err
}
result := PlacementGroupCreateResult{
PlacementGroup: PlacementGroupFromSchema(respBody.PlacementGroup),
return result, resp, err
}
result.PlacementGroup = PlacementGroupFromSchema(respBody.PlacementGroup)
if respBody.Action != nil {
result.Action = ActionFromSchema(*respBody.Action)
}
@ -210,6 +180,11 @@ type PlacementGroupUpdateOpts struct {
// Update updates a PlacementGroup.
func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *PlacementGroup, opts PlacementGroupUpdateOpts) (*PlacementGroup, *Response, error) {
const opPath = "/placement_groups/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, placementGroup.ID)
reqBody := schema.PlacementGroupUpdateRequest{}
if opts.Name != "" {
reqBody.Name = &opts.Name
@ -217,19 +192,8 @@ func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *Place
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/placement_groups/%d", placementGroup.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.PlacementGroupUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.PlacementGroupUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
@ -239,9 +203,10 @@ func (c *PlacementGroupClient) Update(ctx context.Context, placementGroup *Place
// Delete deletes a PlacementGroup.
func (c *PlacementGroupClient) Delete(ctx context.Context, placementGroup *PlacementGroup) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/placement_groups/%d", placementGroup.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/placement_groups/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, placementGroup.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}

View File

@ -3,15 +3,19 @@ package hcloud
import (
"context"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
// Pricing specifies pricing information for various resources.
type Pricing struct {
Image ImagePricing
FloatingIP FloatingIPPricing
FloatingIPs []FloatingIPTypePricing
PrimaryIPs []PrimaryIPPricing
Image ImagePricing
// Deprecated: [Pricing.FloatingIP] is deprecated, use [Pricing.FloatingIPs] instead.
FloatingIP FloatingIPPricing
FloatingIPs []FloatingIPTypePricing
PrimaryIPs []PrimaryIPPricing
// Deprecated: [Pricing.Traffic] is deprecated and will report 0 after 2024-08-05.
// Use traffic pricing from [Pricing.ServerTypes] or [Pricing.LoadBalancerTypes] instead.
Traffic TrafficPricing
ServerBackup ServerBackupPricing
ServerTypes []ServerTypePricing
@ -102,6 +106,10 @@ type ServerTypeLocationPricing struct {
Location *Location
Hourly Price
Monthly Price
// IncludedTraffic is the free traffic per month in bytes
IncludedTraffic uint64
PerTBTraffic Price
}
// LoadBalancerTypePricing provides pricing information for a Load Balancer type.
@ -116,6 +124,10 @@ type LoadBalancerTypeLocationPricing struct {
Location *Location
Hourly Price
Monthly Price
// IncludedTraffic is the free traffic per month in bytes
IncludedTraffic uint64
PerTBTraffic Price
}
// PricingClient is a client for the pricing API.
@ -125,15 +137,15 @@ type PricingClient struct {
// Get retrieves pricing information.
func (c *PricingClient) Get(ctx context.Context) (Pricing, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", "/pricing", nil)
const opPath = "/pricing"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
respBody, resp, err := getRequest[schema.PricingGetResponse](ctx, c.client, reqPath)
if err != nil {
return Pricing{}, nil, err
return Pricing{}, resp, err
}
var body schema.PricingGetResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return Pricing{}, nil, err
}
return PricingFromSchema(body.Pricing), resp, nil
return PricingFromSchema(respBody.Pricing), resp, nil
}

View File

@ -1,15 +1,13 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -46,26 +44,21 @@ type PrimaryIPDNSPTR struct {
// changeDNSPtr changes or resets the reverse DNS pointer for a IP address.
// Pass a nil ptr to reset the reverse DNS pointer to its default value.
func (p *PrimaryIP) changeDNSPtr(ctx context.Context, client *Client, ip net.IP, ptr *string) (*Action, *Response, error) {
const opPath = "/primary_ips/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, p.ID)
reqBody := schema.PrimaryIPActionChangeDNSPtrRequest{
IP: ip.String(),
DNSPtr: ptr,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/primary_ips/%d/actions/change_dns_ptr", p.ID)
req, err := client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody PrimaryIPChangeDNSPtrResult
resp, err := client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.PrimaryIPActionChangeDNSPtrResponse](ctx, client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -92,13 +85,13 @@ const (
// PrimaryIPCreateOpts defines the request to
// create a Primary IP.
type PrimaryIPCreateOpts struct {
AssigneeID *int64 `json:"assignee_id,omitempty"`
AssigneeType string `json:"assignee_type"`
AutoDelete *bool `json:"auto_delete,omitempty"`
Datacenter string `json:"datacenter,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Name string `json:"name"`
Type PrimaryIPType `json:"type"`
AssigneeID *int64
AssigneeType string
AutoDelete *bool
Datacenter string
Labels map[string]string
Name string
Type PrimaryIPType
}
// PrimaryIPCreateResult defines the response
@ -111,51 +104,42 @@ type PrimaryIPCreateResult struct {
// PrimaryIPUpdateOpts defines the request to
// update a Primary IP.
type PrimaryIPUpdateOpts struct {
AutoDelete *bool `json:"auto_delete,omitempty"`
Labels *map[string]string `json:"labels,omitempty"`
Name string `json:"name,omitempty"`
AutoDelete *bool
Labels *map[string]string
Name string
}
// PrimaryIPAssignOpts defines the request to
// assign a Primary IP to an assignee (usually a server).
type PrimaryIPAssignOpts struct {
ID int64
AssigneeID int64 `json:"assignee_id"`
AssigneeType string `json:"assignee_type"`
AssigneeID int64
AssigneeType string
}
// PrimaryIPAssignResult defines the response
// when assigning a Primary IP to a assignee.
type PrimaryIPAssignResult struct {
Action schema.Action `json:"action"`
}
// Deprecated: Please use [schema.PrimaryIPActionAssignResponse] instead.
type PrimaryIPAssignResult = schema.PrimaryIPActionAssignResponse
// PrimaryIPChangeDNSPtrOpts defines the request to
// change a DNS PTR entry from a Primary IP.
type PrimaryIPChangeDNSPtrOpts struct {
ID int64
DNSPtr string `json:"dns_ptr"`
IP string `json:"ip"`
DNSPtr string
IP string
}
// PrimaryIPChangeDNSPtrResult defines the response
// when assigning a Primary IP to a assignee.
type PrimaryIPChangeDNSPtrResult struct {
Action schema.Action `json:"action"`
}
// Deprecated: Please use [schema.PrimaryIPChangeDNSPtrResponse] instead.
type PrimaryIPChangeDNSPtrResult = schema.PrimaryIPActionChangeDNSPtrResponse
// PrimaryIPChangeProtectionOpts defines the request to
// change protection configuration of a Primary IP.
type PrimaryIPChangeProtectionOpts struct {
ID int64
Delete bool `json:"delete"`
Delete bool
}
// PrimaryIPChangeProtectionResult defines the response
// when changing a protection of a PrimaryIP.
type PrimaryIPChangeProtectionResult struct {
Action schema.Action `json:"action"`
}
// Deprecated: Please use [schema.PrimaryIPActionChangeProtectionResponse] instead.
type PrimaryIPChangeProtectionResult = schema.PrimaryIPActionChangeProtectionResponse
// PrimaryIPClient is a client for the Primary IP API.
type PrimaryIPClient struct {
@ -165,20 +149,20 @@ type PrimaryIPClient struct {
// GetByID retrieves a Primary IP by its ID. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) GetByID(ctx context.Context, id int64) (*PrimaryIP, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/primary_ips/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/primary_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.PrimaryIPGetResult
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.PrimaryIPGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return PrimaryIPFromSchema(body.PrimaryIP), resp, nil
return PrimaryIPFromSchema(respBody.PrimaryIP), resp, nil
}
// GetByIP retrieves a Primary IP by its IP Address. If the Primary IP does not exist, nil is returned.
@ -186,32 +170,22 @@ func (c *PrimaryIPClient) GetByIP(ctx context.Context, ip string) (*PrimaryIP, *
if ip == "" {
return nil, nil, nil
}
primaryIPs, response, err := c.List(ctx, PrimaryIPListOpts{IP: ip})
if len(primaryIPs) == 0 {
return nil, response, err
}
return primaryIPs[0], response, err
return firstBy(func() ([]*PrimaryIP, *Response, error) {
return c.List(ctx, PrimaryIPListOpts{IP: ip})
})
}
// GetByName retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) GetByName(ctx context.Context, name string) (*PrimaryIP, *Response, error) {
if name == "" {
return nil, nil, nil
}
primaryIPs, response, err := c.List(ctx, PrimaryIPListOpts{Name: name})
if len(primaryIPs) == 0 {
return nil, response, err
}
return primaryIPs[0], response, err
return firstByName(name, func() ([]*PrimaryIP, *Response, error) {
return c.List(ctx, PrimaryIPListOpts{Name: name})
})
}
// Get retrieves a Primary IP by its ID if the input can be parsed as an integer, otherwise it
// retrieves a Primary IP by its name. If the Primary IP does not exist, nil is returned.
func (c *PrimaryIPClient) Get(ctx context.Context, idOrName string) (*PrimaryIP, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// PrimaryIPListOpts specifies options for listing Primary IPs.
@ -241,22 +215,17 @@ func (l PrimaryIPListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *PrimaryIPClient) List(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, *Response, error) {
path := "/primary_ips?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/primary_ips?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.PrimaryIPListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.PrimaryIPListResult
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
primaryIPs := make([]*PrimaryIP, 0, len(body.PrimaryIPs))
for _, s := range body.PrimaryIPs {
primaryIPs = append(primaryIPs, PrimaryIPFromSchema(s))
}
return primaryIPs, resp, nil
return allFromSchemaFunc(respBody.PrimaryIPs, PrimaryIPFromSchema), resp, nil
}
// All returns all Primary IPs.
@ -266,157 +235,125 @@ func (c *PrimaryIPClient) All(ctx context.Context) ([]*PrimaryIP, error) {
// AllWithOpts returns all Primary IPs for the given options.
func (c *PrimaryIPClient) AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error) {
allPrimaryIPs := []*PrimaryIP{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*PrimaryIP, *Response, error) {
opts.Page = page
primaryIPs, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allPrimaryIPs = append(allPrimaryIPs, primaryIPs...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allPrimaryIPs, nil
}
// Create creates a Primary IP.
func (c *PrimaryIPClient) Create(ctx context.Context, reqBody PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error) {
reqBodyData, err := json.Marshal(reqBody)
func (c *PrimaryIPClient) Create(ctx context.Context, opts PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error) {
const opPath = "/primary_ips"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := &PrimaryIPCreateResult{}
reqPath := opPath
reqBody := SchemaFromPrimaryIPCreateOpts(opts)
respBody, resp, err := postRequest[schema.PrimaryIPCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return &PrimaryIPCreateResult{}, nil, err
return result, resp, err
}
req, err := c.client.NewRequest(ctx, "POST", "/primary_ips", bytes.NewReader(reqBodyData))
if err != nil {
return &PrimaryIPCreateResult{}, nil, err
}
var respBody schema.PrimaryIPCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return &PrimaryIPCreateResult{}, resp, err
}
var action *Action
result.PrimaryIP = PrimaryIPFromSchema(respBody.PrimaryIP)
if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action)
result.Action = ActionFromSchema(*respBody.Action)
}
primaryIP := PrimaryIPFromSchema(respBody.PrimaryIP)
return &PrimaryIPCreateResult{
PrimaryIP: primaryIP,
Action: action,
}, resp, nil
return result, resp, nil
}
// Delete deletes a Primary IP.
func (c *PrimaryIPClient) Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/primary_ips/%d", primaryIP.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/primary_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, primaryIP.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// Update updates a Primary IP.
func (c *PrimaryIPClient) Update(ctx context.Context, primaryIP *PrimaryIP, reqBody PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error) {
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
func (c *PrimaryIPClient) Update(ctx context.Context, primaryIP *PrimaryIP, opts PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error) {
const opPath = "/primary_ips/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/primary_ips/%d", primaryIP.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, primaryIP.ID)
var respBody schema.PrimaryIPUpdateResult
resp, err := c.client.Do(req, &respBody)
reqBody := SchemaFromPrimaryIPUpdateOpts(opts)
respBody, resp, err := putRequest[schema.PrimaryIPUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return PrimaryIPFromSchema(respBody.PrimaryIP), resp, nil
}
// Assign a Primary IP to a resource.
func (c *PrimaryIPClient) Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts)
if err != nil {
return nil, nil, err
}
const opPath = "/primary_ips/%d/actions/assign"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/primary_ips/%d/actions/assign", opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, opts.ID)
var respBody PrimaryIPAssignResult
resp, err := c.client.Do(req, &respBody)
reqBody := SchemaFromPrimaryIPAssignOpts(opts)
respBody, resp, err := postRequest[schema.PrimaryIPActionAssignResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
// Unassign a Primary IP from a resource.
func (c *PrimaryIPClient) Unassign(ctx context.Context, id int64) (*Action, *Response, error) {
path := fmt.Sprintf("/primary_ips/%d/actions/unassign", id)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader([]byte{}))
if err != nil {
return nil, nil, err
}
const opPath = "/primary_ips/%d/actions/unassign"
ctx = ctxutil.SetOpPath(ctx, opPath)
var respBody PrimaryIPAssignResult
resp, err := c.client.Do(req, &respBody)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := postRequest[schema.PrimaryIPActionUnassignResponse](ctx, c.client, reqPath, nil)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
// ChangeDNSPtr Change the reverse DNS from a Primary IP.
func (c *PrimaryIPClient) ChangeDNSPtr(ctx context.Context, opts PrimaryIPChangeDNSPtrOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts)
if err != nil {
return nil, nil, err
}
const opPath = "/primary_ips/%d/actions/change_dns_ptr"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/primary_ips/%d/actions/change_dns_ptr", opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, opts.ID)
var respBody PrimaryIPChangeDNSPtrResult
resp, err := c.client.Do(req, &respBody)
reqBody := SchemaFromPrimaryIPChangeDNSPtrOpts(opts)
respBody, resp, err := postRequest[schema.PrimaryIPActionChangeDNSPtrResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
// ChangeProtection Changes the protection configuration of a Primary IP.
func (c *PrimaryIPClient) ChangeProtection(ctx context.Context, opts PrimaryIPChangeProtectionOpts) (*Action, *Response, error) {
reqBodyData, err := json.Marshal(opts)
if err != nil {
return nil, nil, err
}
const opPath = "/primary_ips/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
path := fmt.Sprintf("/primary_ips/%d/actions/change_protection", opts.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
reqPath := fmt.Sprintf(opPath, opts.ID)
var respBody PrimaryIPChangeProtectionResult
resp, err := c.client.Do(req, &respBody)
reqBody := SchemaFromPrimaryIPChangeProtectionOpts(opts)
respBody, resp, err := postRequest[schema.PrimaryIPActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}

View File

@ -49,6 +49,26 @@ func SchemaFromPrimaryIP(p *PrimaryIP) schema.PrimaryIP {
return c.SchemaFromPrimaryIP(p)
}
func SchemaFromPrimaryIPCreateOpts(o PrimaryIPCreateOpts) schema.PrimaryIPCreateRequest {
return c.SchemaFromPrimaryIPCreateOpts(o)
}
func SchemaFromPrimaryIPUpdateOpts(o PrimaryIPUpdateOpts) schema.PrimaryIPUpdateRequest {
return c.SchemaFromPrimaryIPUpdateOpts(o)
}
func SchemaFromPrimaryIPChangeDNSPtrOpts(o PrimaryIPChangeDNSPtrOpts) schema.PrimaryIPActionChangeDNSPtrRequest {
return c.SchemaFromPrimaryIPChangeDNSPtrOpts(o)
}
func SchemaFromPrimaryIPChangeProtectionOpts(o PrimaryIPChangeProtectionOpts) schema.PrimaryIPActionChangeProtectionRequest {
return c.SchemaFromPrimaryIPChangeProtectionOpts(o)
}
func SchemaFromPrimaryIPAssignOpts(o PrimaryIPAssignOpts) schema.PrimaryIPActionAssignRequest {
return c.SchemaFromPrimaryIPAssignOpts(o)
}
// ISOFromSchema converts a schema.ISO to an ISO.
func ISOFromSchema(s schema.ISO) *ISO {
return c.ISOFromSchema(s)

View File

@ -0,0 +1,4 @@
// The schema package holds API schemas for the `hcloud-go` library.
// Breaking changes may occur without notice. Do not use in production!
package schema

View File

@ -7,7 +7,7 @@ type Error struct {
Code string `json:"code"`
Message string `json:"message"`
DetailsRaw json.RawMessage `json:"details"`
Details interface{}
Details any `json:"-"`
}
// UnmarshalJSON overrides default json unmarshalling.
@ -17,13 +17,20 @@ func (e *Error) UnmarshalJSON(data []byte) (err error) {
if err = json.Unmarshal(data, alias); err != nil {
return
}
if e.Code == "invalid_input" {
if e.Code == "invalid_input" && len(e.DetailsRaw) > 0 {
details := ErrorDetailsInvalidInput{}
if err = json.Unmarshal(e.DetailsRaw, &details); err != nil {
return
}
alias.Details = details
}
if e.Code == "deprecated_api_endpoint" && len(e.DetailsRaw) > 0 {
details := ErrorDetailsDeprecatedAPIEndpoint{}
if err = json.Unmarshal(e.DetailsRaw, &details); err != nil {
return
}
alias.Details = details
}
return
}
@ -40,3 +47,9 @@ type ErrorDetailsInvalidInput struct {
Messages []string `json:"messages"`
} `json:"fields"`
}
// ErrorDetailsDeprecatedAPIEndpoint defines the schema of the Details field
// of an error with code 'deprecated_api_endpoint'.
type ErrorDetailsDeprecatedAPIEndpoint struct {
Announcement string `json:"announcement"`
}

View File

@ -0,0 +1,68 @@
package schema
import (
"bytes"
"encoding/json"
"reflect"
"strconv"
)
// IDOrName can be used in API requests where either a resource id or name can be
// specified.
type IDOrName struct {
ID int64
Name string
}
var _ json.Unmarshaler = (*IDOrName)(nil)
var _ json.Marshaler = (*IDOrName)(nil)
func (o IDOrName) MarshalJSON() ([]byte, error) {
if o.ID != 0 {
return json.Marshal(o.ID)
}
if o.Name != "" {
return json.Marshal(o.Name)
}
// We want to preserve the behavior of an empty interface{} to prevent breaking
// changes (marshaled to null when empty).
return json.Marshal(nil)
}
func (o *IDOrName) UnmarshalJSON(data []byte) error {
d := json.NewDecoder(bytes.NewBuffer(data))
// This ensures we won't lose precision on large IDs, see json.Number below
d.UseNumber()
var v any
if err := d.Decode(&v); err != nil {
return err
}
switch typed := v.(type) {
case string:
id, err := strconv.ParseInt(typed, 10, 64)
if err == nil {
o.ID = id
} else if typed != "" {
o.Name = typed
}
case json.Number:
id, err := typed.Int64()
if err != nil {
return &json.UnmarshalTypeError{
Value: string(data),
Type: reflect.TypeOf(*o),
}
}
o.ID = id
default:
return &json.UnmarshalTypeError{
Value: string(data),
Type: reflect.TypeOf(*o),
}
}
return nil
}

View File

@ -251,7 +251,7 @@ type LoadBalancerDeleteServiceResponse struct {
type LoadBalancerCreateRequest struct {
Name string `json:"name"`
LoadBalancerType interface{} `json:"load_balancer_type"` // int or string
LoadBalancerType IDOrName `json:"load_balancer_type"`
Algorithm *LoadBalancerCreateRequestAlgorithm `json:"algorithm,omitempty"`
Location *string `json:"location,omitempty"`
NetworkZone *string `json:"network_zone,omitempty"`
@ -380,7 +380,7 @@ type LoadBalancerActionDisablePublicInterfaceResponse struct {
}
type LoadBalancerActionChangeTypeRequest struct {
LoadBalancerType interface{} `json:"load_balancer_type"` // int or string
LoadBalancerType IDOrName `json:"load_balancer_type"`
}
type LoadBalancerActionChangeTypeResponse struct {

View File

@ -11,6 +11,7 @@ type Network struct {
Subnets []NetworkSubnet `json:"subnets"`
Routes []NetworkRoute `json:"routes"`
Servers []int64 `json:"servers"`
LoadBalancers []int64 `json:"load_balancers"`
Protection NetworkProtection `json:"protection"`
Labels map[string]string `json:"labels"`
ExposeRoutesToVSwitch bool `json:"expose_routes_to_vswitch"`

View File

@ -2,12 +2,15 @@ package schema
// Pricing defines the schema for pricing information.
type Pricing struct {
Currency string `json:"currency"`
VATRate string `json:"vat_rate"`
Image PricingImage `json:"image"`
FloatingIP PricingFloatingIP `json:"floating_ip"`
FloatingIPs []PricingFloatingIPType `json:"floating_ips"`
PrimaryIPs []PricingPrimaryIP `json:"primary_ips"`
Currency string `json:"currency"`
VATRate string `json:"vat_rate"`
Image PricingImage `json:"image"`
// Deprecated: [Pricing.FloatingIP] is deprecated, use [Pricing.FloatingIPs] instead.
FloatingIP PricingFloatingIP `json:"floating_ip"`
FloatingIPs []PricingFloatingIPType `json:"floating_ips"`
PrimaryIPs []PricingPrimaryIP `json:"primary_ips"`
// Deprecated: [Pricing.Traffic] is deprecated and will report 0 after 2024-08-05.
// Use traffic pricing from [Pricing.ServerTypes] or [Pricing.LoadBalancerTypes] instead.
Traffic PricingTraffic `json:"traffic"`
ServerBackup PricingServerBackup `json:"server_backup"`
ServerTypes []PricingServerType `json:"server_types"`
@ -72,6 +75,9 @@ type PricingServerTypePrice struct {
Location string `json:"location"`
PriceHourly Price `json:"price_hourly"`
PriceMonthly Price `json:"price_monthly"`
IncludedTraffic uint64 `json:"included_traffic"`
PricePerTBTraffic Price `json:"price_per_tb_traffic"`
}
// PricingLoadBalancerType defines the schema of pricing information for a Load Balancer type.
@ -87,6 +93,9 @@ type PricingLoadBalancerTypePrice struct {
Location string `json:"location"`
PriceHourly Price `json:"price_hourly"`
PriceMonthly Price `json:"price_monthly"`
IncludedTraffic uint64 `json:"included_traffic"`
PricePerTBTraffic Price `json:"price_per_tb_traffic"`
}
// PricingGetResponse defines the schema of the response when retrieving pricing information.

View File

@ -31,6 +31,18 @@ type PrimaryIPDNSPTR struct {
IP string `json:"ip"`
}
// PrimaryIPCreateOpts defines the request to
// create a Primary IP.
type PrimaryIPCreateRequest struct {
Name string `json:"name"`
Type string `json:"type"`
AssigneeType string `json:"assignee_type"`
AssigneeID *int64 `json:"assignee_id,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
AutoDelete *bool `json:"auto_delete,omitempty"`
Datacenter string `json:"datacenter,omitempty"`
}
// PrimaryIPCreateResponse defines the schema of the response
// when creating a Primary IP.
type PrimaryIPCreateResponse struct {
@ -38,19 +50,27 @@ type PrimaryIPCreateResponse struct {
Action *Action `json:"action"`
}
// PrimaryIPGetResult defines the response when retrieving a single Primary IP.
type PrimaryIPGetResult struct {
// PrimaryIPGetResponse defines the response when retrieving a single Primary IP.
type PrimaryIPGetResponse struct {
PrimaryIP PrimaryIP `json:"primary_ip"`
}
// PrimaryIPListResult defines the response when listing Primary IPs.
type PrimaryIPListResult struct {
// PrimaryIPListResponse defines the response when listing Primary IPs.
type PrimaryIPListResponse struct {
PrimaryIPs []PrimaryIP `json:"primary_ips"`
}
// PrimaryIPUpdateResult defines the response
// PrimaryIPUpdateOpts defines the request to
// update a Primary IP.
type PrimaryIPUpdateRequest struct {
Name string `json:"name,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
AutoDelete *bool `json:"auto_delete,omitempty"`
}
// PrimaryIPUpdateResponse defines the response
// when updating a Primary IP.
type PrimaryIPUpdateResult struct {
type PrimaryIPUpdateResponse struct {
PrimaryIP PrimaryIP `json:"primary_ip"`
}
@ -60,3 +80,39 @@ type PrimaryIPActionChangeDNSPtrRequest struct {
IP string `json:"ip"`
DNSPtr *string `json:"dns_ptr"`
}
// PrimaryIPActionChangeDNSPtrResponse defines the response when setting a reverse DNS
// pointer for a IP address.
type PrimaryIPActionChangeDNSPtrResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionAssignRequest defines the request to
// assign a Primary IP to an assignee (usually a server).
type PrimaryIPActionAssignRequest struct {
AssigneeID int64 `json:"assignee_id"`
AssigneeType string `json:"assignee_type"`
}
// PrimaryIPActionAssignResponse defines the response when assigning a Primary IP to a
// assignee.
type PrimaryIPActionAssignResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionUnassignResponse defines the response to unassign a Primary IP.
type PrimaryIPActionUnassignResponse struct {
Action Action `json:"action"`
}
// PrimaryIPActionChangeProtectionRequest defines the request to
// change protection configuration of a Primary IP.
type PrimaryIPActionChangeProtectionRequest struct {
Delete bool `json:"delete"`
}
// PrimaryIPActionChangeProtectionResponse defines the response when changing the
// protection of a Primary IP.
type PrimaryIPActionChangeProtectionResponse struct {
Action Action `json:"action"`
}

View File

@ -99,8 +99,8 @@ type ServerListResponse struct {
// create a server.
type ServerCreateRequest struct {
Name string `json:"name"`
ServerType interface{} `json:"server_type"` // int or string
Image interface{} `json:"image"` // int or string
ServerType IDOrName `json:"server_type"`
Image IDOrName `json:"image"`
SSHKeys []int64 `json:"ssh_keys,omitempty"`
Location string `json:"location,omitempty"`
Datacenter string `json:"datacenter,omitempty"`
@ -257,7 +257,7 @@ type ServerActionDisableRescueResponse struct {
// ServerActionRebuildRequest defines the schema for the request to
// rebuild a server.
type ServerActionRebuildRequest struct {
Image interface{} `json:"image"` // int or string
Image IDOrName `json:"image"`
}
// ServerActionRebuildResponse defines the schema of the response when
@ -270,7 +270,7 @@ type ServerActionRebuildResponse struct {
// ServerActionAttachISORequest defines the schema for the request to
// attach an ISO to a server.
type ServerActionAttachISORequest struct {
ISO interface{} `json:"iso"` // int or string
ISO IDOrName `json:"iso"`
}
// ServerActionAttachISOResponse defines the schema of the response when
@ -289,12 +289,6 @@ type ServerActionDetachISOResponse struct {
Action Action `json:"action"`
}
// ServerActionEnableBackupRequest defines the schema for the request to
// enable backup for a server.
type ServerActionEnableBackupRequest struct {
BackupWindow *string `json:"backup_window,omitempty"`
}
// ServerActionEnableBackupResponse defines the schema of the response when
// creating a enable_backup server action.
type ServerActionEnableBackupResponse struct {
@ -314,8 +308,8 @@ type ServerActionDisableBackupResponse struct {
// ServerActionChangeTypeRequest defines the schema for the request to
// change a server's type.
type ServerActionChangeTypeRequest struct {
ServerType interface{} `json:"server_type"` // int or string
UpgradeDisk bool `json:"upgrade_disk"`
ServerType IDOrName `json:"server_type"`
UpgradeDisk bool `json:"upgrade_disk"`
}
// ServerActionChangeTypeResponse defines the schema of the response when

View File

@ -2,15 +2,18 @@ package schema
// ServerType defines the schema of a server type.
type ServerType struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Cores int `json:"cores"`
Memory float32 `json:"memory"`
Disk int `json:"disk"`
StorageType string `json:"storage_type"`
CPUType string `json:"cpu_type"`
Architecture string `json:"architecture"`
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Cores int `json:"cores"`
Memory float32 `json:"memory"`
Disk int `json:"disk"`
StorageType string `json:"storage_type"`
CPUType string `json:"cpu_type"`
Architecture string `json:"architecture"`
// Deprecated: [ServerType.IncludedTraffic] is deprecated and will always report 0 after 2024-08-05.
// Use [ServerType.Prices] instead to get the included traffic for each location.
IncludedTraffic int64 `json:"included_traffic"`
Prices []PricingServerTypePrice `json:"prices"`
Deprecated bool `json:"deprecated"`

View File

@ -23,7 +23,7 @@ type VolumeCreateRequest struct {
Name string `json:"name"`
Size int `json:"size"`
Server *int64 `json:"server,omitempty"`
Location interface{} `json:"location,omitempty"` // int, string, or nil
Location *IDOrName `json:"location,omitempty"`
Labels *map[string]string `json:"labels,omitempty"`
Automount *bool `json:"automount,omitempty"`
Format *string `json:"format,omitempty"`

View File

@ -69,7 +69,6 @@ You can find a documentation of goverter here: https://goverter.jmattheis.de/
// goverter:extend durationFromIntSeconds
// goverter:extend intSecondsFromDuration
// goverter:extend serverFromImageCreatedFromSchema
// goverter:extend anyFromLoadBalancerType
// goverter:extend serverMetricsTimeSeriesFromSchema
// goverter:extend loadBalancerMetricsTimeSeriesFromSchema
// goverter:extend stringPtrFromLoadBalancerServiceProtocol
@ -108,6 +107,12 @@ type converter interface {
// goverter:map AssigneeID | mapZeroInt64ToNil
SchemaFromPrimaryIP(*PrimaryIP) schema.PrimaryIP
SchemaFromPrimaryIPCreateOpts(PrimaryIPCreateOpts) schema.PrimaryIPCreateRequest
SchemaFromPrimaryIPUpdateOpts(PrimaryIPUpdateOpts) schema.PrimaryIPUpdateRequest
SchemaFromPrimaryIPChangeDNSPtrOpts(PrimaryIPChangeDNSPtrOpts) schema.PrimaryIPActionChangeDNSPtrRequest
SchemaFromPrimaryIPChangeProtectionOpts(PrimaryIPChangeProtectionOpts) schema.PrimaryIPActionChangeProtectionRequest
SchemaFromPrimaryIPAssignOpts(PrimaryIPAssignOpts) schema.PrimaryIPActionAssignRequest
ISOFromSchema(schema.ISO) *ISO
// We cannot use goverter settings when mapping a struct to a struct pointer
@ -207,10 +212,12 @@ type converter interface {
// goverter:map PriceHourly Hourly
// goverter:map PriceMonthly Monthly
// goverter:map PricePerTBTraffic PerTBTraffic
LoadBalancerTypeLocationPricingFromSchema(schema.PricingLoadBalancerTypePrice) LoadBalancerTypeLocationPricing
// goverter:map Hourly PriceHourly
// goverter:map Monthly PriceMonthly
// goverter:map PerTBTraffic PricePerTBTraffic
SchemaFromLoadBalancerTypeLocationPricing(LoadBalancerTypeLocationPricing) schema.PricingLoadBalancerTypePrice
LoadBalancerServiceFromSchema(schema.LoadBalancerService) LoadBalancerService
@ -263,6 +270,7 @@ type converter interface {
// goverter:map PriceHourly Hourly
// goverter:map PriceMonthly Monthly
// goverter:map PricePerTBTraffic PerTBTraffic
serverTypePricingFromSchema(schema.PricingServerTypePrice) ServerTypeLocationPricing
// goverter:map Image.PerGBMonth.Currency Currency
@ -306,6 +314,7 @@ type converter interface {
// goverter:map Monthly PriceMonthly
// goverter:map Hourly PriceHourly
// goverter:map PerTBTraffic PricePerTBTraffic
schemaFromServerTypeLocationPricing(ServerTypeLocationPricing) schema.PricingServerTypePrice
FirewallFromSchema(schema.Firewall) *Firewall
@ -606,37 +615,48 @@ func intSecondsFromDuration(d time.Duration) int {
}
func errorDetailsFromSchema(d interface{}) interface{} {
if d, ok := d.(schema.ErrorDetailsInvalidInput); ok {
switch typed := d.(type) {
case schema.ErrorDetailsInvalidInput:
details := ErrorDetailsInvalidInput{
Fields: make([]ErrorDetailsInvalidInputField, len(d.Fields)),
Fields: make([]ErrorDetailsInvalidInputField, len(typed.Fields)),
}
for i, field := range d.Fields {
for i, field := range typed.Fields {
details.Fields[i] = ErrorDetailsInvalidInputField{
Name: field.Name,
Messages: field.Messages,
}
}
return details
case schema.ErrorDetailsDeprecatedAPIEndpoint:
return ErrorDetailsDeprecatedAPIEndpoint{
Announcement: typed.Announcement,
}
}
return nil
}
func schemaFromErrorDetails(d interface{}) interface{} {
if d, ok := d.(ErrorDetailsInvalidInput); ok {
switch typed := d.(type) {
case ErrorDetailsInvalidInput:
details := schema.ErrorDetailsInvalidInput{
Fields: make([]struct {
Name string `json:"name"`
Messages []string `json:"messages"`
}, len(d.Fields)),
}, len(typed.Fields)),
}
for i, field := range d.Fields {
for i, field := range typed.Fields {
details.Fields[i] = struct {
Name string `json:"name"`
Messages []string `json:"messages"`
}{Name: field.Name, Messages: field.Messages}
}
return details
case ErrorDetailsDeprecatedAPIEndpoint:
return schema.ErrorDetailsDeprecatedAPIEndpoint{Announcement: typed.Announcement}
}
return nil
}
@ -654,8 +674,8 @@ func imagePricingFromSchema(s schema.Pricing) ImagePricing {
func floatingIPPricingFromSchema(s schema.Pricing) FloatingIPPricing {
return FloatingIPPricing{
Monthly: Price{
Net: s.FloatingIP.PriceMonthly.Net,
Gross: s.FloatingIP.PriceMonthly.Gross,
Net: s.FloatingIP.PriceMonthly.Net, // nolint:staticcheck // Field is deprecated, but removal is not planned
Gross: s.FloatingIP.PriceMonthly.Gross, // nolint:staticcheck // Field is deprecated, but removal is not planned
Currency: s.Currency,
VATRate: s.VATRate,
},
@ -707,8 +727,8 @@ func primaryIPPricingFromSchema(s schema.Pricing) []PrimaryIPPricing {
func trafficPricingFromSchema(s schema.Pricing) TrafficPricing {
return TrafficPricing{
PerTB: Price{
Net: s.Traffic.PricePerTB.Net,
Gross: s.Traffic.PricePerTB.Gross,
Net: s.Traffic.PricePerTB.Net, // nolint:staticcheck // Field is deprecated, but we still need to map it as long as it is available
Gross: s.Traffic.PricePerTB.Gross, // nolint:staticcheck // Field is deprecated, but we still need to map it as long as it is available
Currency: s.Currency,
VATRate: s.VATRate,
},
@ -734,6 +754,13 @@ func serverTypePricingFromSchema(s schema.Pricing) []ServerTypePricing {
Net: price.PriceMonthly.Net,
Gross: price.PriceMonthly.Gross,
},
IncludedTraffic: price.IncludedTraffic,
PerTBTraffic: Price{
Currency: s.Currency,
VATRate: s.VATRate,
Net: price.PricePerTBTraffic.Net,
Gross: price.PricePerTBTraffic.Gross,
},
}
}
p[i] = ServerTypePricing{
@ -766,6 +793,13 @@ func loadBalancerTypePricingFromSchema(s schema.Pricing) []LoadBalancerTypePrici
Net: price.PriceMonthly.Net,
Gross: price.PriceMonthly.Gross,
},
IncludedTraffic: price.IncludedTraffic,
PerTBTraffic: Price{
Currency: s.Currency,
VATRate: s.VATRate,
Net: price.PricePerTBTraffic.Net,
Gross: price.PricePerTBTraffic.Gross,
},
}
}
p[i] = LoadBalancerTypePricing{
@ -790,16 +824,6 @@ func volumePricingFromSchema(s schema.Pricing) VolumePricing {
}
}
func anyFromLoadBalancerType(t *LoadBalancerType) interface{} {
if t == nil {
return nil
}
if t.ID != 0 {
return t.ID
}
return t.Name
}
func serverMetricsTimeSeriesFromSchema(s schema.ServerTimeSeriesVals) ([]ServerMetricsValue, error) {
vals := make([]ServerMetricsValue, len(s.Values))
@ -922,7 +946,10 @@ func rawSchemaFromErrorDetails(v interface{}) json.RawMessage {
if v == nil {
return nil
}
msg, _ := json.Marshal(d)
msg, err := json.Marshal(d)
if err != nil {
return nil
}
return msg
}

View File

@ -6,6 +6,7 @@ import (
"net/url"
"strconv"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -20,7 +21,9 @@ type ServerType struct {
StorageType StorageType
CPUType CPUType
Architecture Architecture
// IncludedTraffic is the free traffic per month in bytes
// Deprecated: [ServerType.IncludedTraffic] is deprecated and will always report 0 after 2024-08-05.
// Use [ServerType.Pricings] instead to get the included traffic for each location.
IncludedTraffic int64
Pricings []ServerTypeLocationPricing
DeprecatableResource
@ -55,32 +58,27 @@ type ServerTypeClient struct {
// GetByID retrieves a server type by its ID. If the server type does not exist, nil is returned.
func (c *ServerTypeClient) GetByID(ctx context.Context, id int64) (*ServerType, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/server_types/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/server_types/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.ServerTypeGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.ServerTypeGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return ServerTypeFromSchema(body.ServerType), resp, nil
return ServerTypeFromSchema(respBody.ServerType), resp, nil
}
// GetByName retrieves a server type by its name. If the server type does not exist, nil is returned.
func (c *ServerTypeClient) GetByName(ctx context.Context, name string) (*ServerType, *Response, error) {
if name == "" {
return nil, nil, nil
}
serverTypes, response, err := c.List(ctx, ServerTypeListOpts{Name: name})
if len(serverTypes) == 0 {
return nil, response, err
}
return serverTypes[0], response, err
return firstByName(name, func() ([]*ServerType, *Response, error) {
return c.List(ctx, ServerTypeListOpts{Name: name})
})
}
// Get retrieves a server type by its ID if the input can be parsed as an integer, otherwise it
@ -115,22 +113,17 @@ func (l ServerTypeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *ServerTypeClient) List(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, *Response, error) {
path := "/server_types?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/server_types?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.ServerTypeListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.ServerTypeListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
serverTypes := make([]*ServerType, 0, len(body.ServerTypes))
for _, s := range body.ServerTypes {
serverTypes = append(serverTypes, ServerTypeFromSchema(s))
}
return serverTypes, resp, nil
return allFromSchemaFunc(respBody.ServerTypes, ServerTypeFromSchema), resp, nil
}
// All returns all server types.
@ -140,20 +133,8 @@ func (c *ServerTypeClient) All(ctx context.Context) ([]*ServerType, error) {
// AllWithOpts returns all server types for the given options.
func (c *ServerTypeClient) AllWithOpts(ctx context.Context, opts ServerTypeListOpts) ([]*ServerType, error) {
allServerTypes := []*ServerType{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*ServerType, *Response, error) {
opts.Page = page
serverTypes, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allServerTypes = append(allServerTypes, serverTypes...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allServerTypes, nil
}

View File

@ -1,15 +1,12 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -30,50 +27,40 @@ type SSHKeyClient struct {
// GetByID retrieves a SSH key by its ID. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByID(ctx context.Context, id int64) (*SSHKey, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/ssh_keys/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/ssh_keys/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.SSHKeyGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.SSHKeyGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return SSHKeyFromSchema(body.SSHKey), resp, nil
return SSHKeyFromSchema(respBody.SSHKey), resp, nil
}
// GetByName retrieves a SSH key by its name. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByName(ctx context.Context, name string) (*SSHKey, *Response, error) {
if name == "" {
return nil, nil, nil
}
sshKeys, response, err := c.List(ctx, SSHKeyListOpts{Name: name})
if len(sshKeys) == 0 {
return nil, response, err
}
return sshKeys[0], response, err
return firstByName(name, func() ([]*SSHKey, *Response, error) {
return c.List(ctx, SSHKeyListOpts{Name: name})
})
}
// GetByFingerprint retreives a SSH key by its fingerprint. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) GetByFingerprint(ctx context.Context, fingerprint string) (*SSHKey, *Response, error) {
sshKeys, response, err := c.List(ctx, SSHKeyListOpts{Fingerprint: fingerprint})
if len(sshKeys) == 0 {
return nil, response, err
}
return sshKeys[0], response, err
return firstBy(func() ([]*SSHKey, *Response, error) {
return c.List(ctx, SSHKeyListOpts{Fingerprint: fingerprint})
})
}
// Get retrieves a SSH key by its ID if the input can be parsed as an integer, otherwise it
// retrieves a SSH key by its name. If the SSH key does not exist, nil is returned.
func (c *SSHKeyClient) Get(ctx context.Context, idOrName string) (*SSHKey, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// SSHKeyListOpts specifies options for listing SSH keys.
@ -103,22 +90,17 @@ func (l SSHKeyListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *SSHKeyClient) List(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, *Response, error) {
path := "/ssh_keys?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/ssh_keys?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.SSHKeyListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.SSHKeyListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
sshKeys := make([]*SSHKey, 0, len(body.SSHKeys))
for _, s := range body.SSHKeys {
sshKeys = append(sshKeys, SSHKeyFromSchema(s))
}
return sshKeys, resp, nil
return allFromSchemaFunc(respBody.SSHKeys, SSHKeyFromSchema), resp, nil
}
// All returns all SSH keys.
@ -128,22 +110,10 @@ func (c *SSHKeyClient) All(ctx context.Context) ([]*SSHKey, error) {
// AllWithOpts returns all SSH keys with the given options.
func (c *SSHKeyClient) AllWithOpts(ctx context.Context, opts SSHKeyListOpts) ([]*SSHKey, error) {
allSSHKeys := []*SSHKey{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*SSHKey, *Response, error) {
opts.Page = page
sshKeys, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allSSHKeys = append(allSSHKeys, sshKeys...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allSSHKeys, nil
}
// SSHKeyCreateOpts specifies parameters for creating a SSH key.
@ -156,16 +126,21 @@ type SSHKeyCreateOpts struct {
// Validate checks if options are valid.
func (o SSHKeyCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
if o.PublicKey == "" {
return errors.New("missing public key")
return missingField(o, "PublicKey")
}
return nil
}
// Create creates a new SSH key with the given options.
func (c *SSHKeyClient) Create(ctx context.Context, opts SSHKeyCreateOpts) (*SSHKey, *Response, error) {
const opPath = "/ssh_keys"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := opPath
if err := opts.Validate(); err != nil {
return nil, nil, err
}
@ -176,31 +151,23 @@ func (c *SSHKeyClient) Create(ctx context.Context, opts SSHKeyCreateOpts) (*SSHK
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
req, err := c.client.NewRequest(ctx, "POST", "/ssh_keys", bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.SSHKeyCreateResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.SSHKeyCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return SSHKeyFromSchema(respBody.SSHKey), resp, nil
}
// Delete deletes a SSH key.
func (c *SSHKeyClient) Delete(ctx context.Context, sshKey *SSHKey) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/ssh_keys/%d", sshKey.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/ssh_keys/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, sshKey.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// SSHKeyUpdateOpts specifies options for updating a SSH key.
@ -211,27 +178,22 @@ type SSHKeyUpdateOpts struct {
// Update updates a SSH key.
func (c *SSHKeyClient) Update(ctx context.Context, sshKey *SSHKey, opts SSHKeyUpdateOpts) (*SSHKey, *Response, error) {
const opPath = "/ssh_keys/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, sshKey.ID)
reqBody := schema.SSHKeyUpdateRequest{
Name: opts.Name,
}
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/ssh_keys/%d", sshKey.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.SSHKeyUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.SSHKeyUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return SSHKeyFromSchema(respBody.SSHKey), resp, nil
}

View File

@ -1,15 +1,12 @@
package hcloud
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"time"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/exp/ctxutil"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider/hetzner/hcloud-go/hcloud/schema"
)
@ -57,41 +54,33 @@ const (
// GetByID retrieves a volume by its ID. If the volume does not exist, nil is returned.
func (c *VolumeClient) GetByID(ctx context.Context, id int64) (*Volume, *Response, error) {
req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("/volumes/%d", id), nil)
if err != nil {
return nil, nil, err
}
const opPath = "/volumes/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
var body schema.VolumeGetResponse
resp, err := c.client.Do(req, &body)
reqPath := fmt.Sprintf(opPath, id)
respBody, resp, err := getRequest[schema.VolumeGetResponse](ctx, c.client, reqPath)
if err != nil {
if IsError(err, ErrorCodeNotFound) {
return nil, resp, nil
}
return nil, nil, err
return nil, resp, err
}
return VolumeFromSchema(body.Volume), resp, nil
return VolumeFromSchema(respBody.Volume), resp, nil
}
// GetByName retrieves a volume by its name. If the volume does not exist, nil is returned.
func (c *VolumeClient) GetByName(ctx context.Context, name string) (*Volume, *Response, error) {
if name == "" {
return nil, nil, nil
}
volumes, response, err := c.List(ctx, VolumeListOpts{Name: name})
if len(volumes) == 0 {
return nil, response, err
}
return volumes[0], response, err
return firstByName(name, func() ([]*Volume, *Response, error) {
return c.List(ctx, VolumeListOpts{Name: name})
})
}
// Get retrieves a volume by its ID if the input can be parsed as an integer, otherwise it
// retrieves a volume by its name. If the volume does not exist, nil is returned.
func (c *VolumeClient) Get(ctx context.Context, idOrName string) (*Volume, *Response, error) {
if id, err := strconv.ParseInt(idOrName, 10, 64); err == nil {
return c.GetByID(ctx, id)
}
return c.GetByName(ctx, idOrName)
return getByIDOrName(ctx, c.GetByID, c.GetByName, idOrName)
}
// VolumeListOpts specifies options for listing volumes.
@ -121,22 +110,17 @@ func (l VolumeListOpts) values() url.Values {
// Please note that filters specified in opts are not taken into account
// when their value corresponds to their zero value or when they are empty.
func (c *VolumeClient) List(ctx context.Context, opts VolumeListOpts) ([]*Volume, *Response, error) {
path := "/volumes?" + opts.values().Encode()
req, err := c.client.NewRequest(ctx, "GET", path, nil)
const opPath = "/volumes?%s"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, opts.values().Encode())
respBody, resp, err := getRequest[schema.VolumeListResponse](ctx, c.client, reqPath)
if err != nil {
return nil, nil, err
return nil, resp, err
}
var body schema.VolumeListResponse
resp, err := c.client.Do(req, &body)
if err != nil {
return nil, nil, err
}
volumes := make([]*Volume, 0, len(body.Volumes))
for _, s := range body.Volumes {
volumes = append(volumes, VolumeFromSchema(s))
}
return volumes, resp, nil
return allFromSchemaFunc(respBody.Volumes, VolumeFromSchema), resp, nil
}
// All returns all volumes.
@ -146,22 +130,10 @@ func (c *VolumeClient) All(ctx context.Context) ([]*Volume, error) {
// AllWithOpts returns all volumes with the given options.
func (c *VolumeClient) AllWithOpts(ctx context.Context, opts VolumeListOpts) ([]*Volume, error) {
allVolumes := []*Volume{}
err := c.client.all(func(page int) (*Response, error) {
return iterPages(func(page int) ([]*Volume, *Response, error) {
opts.Page = page
volumes, resp, err := c.List(ctx, opts)
if err != nil {
return resp, err
}
allVolumes = append(allVolumes, volumes...)
return resp, nil
return c.List(ctx, opts)
})
if err != nil {
return nil, err
}
return allVolumes, nil
}
// VolumeCreateOpts specifies parameters for creating a volume.
@ -178,19 +150,19 @@ type VolumeCreateOpts struct {
// Validate checks if options are valid.
func (o VolumeCreateOpts) Validate() error {
if o.Name == "" {
return errors.New("missing name")
return missingField(o, "Name")
}
if o.Size <= 0 {
return errors.New("size must be greater than 0")
return invalidFieldValue(o, "Size", o.Size)
}
if o.Server == nil && o.Location == nil {
return errors.New("one of server or location must be provided")
return missingOneOfFields(o, "Server", "Location")
}
if o.Server != nil && o.Location != nil {
return errors.New("only one of server or location must be provided")
return mutuallyExclusiveFields(o, "Server", "Location")
}
if o.Server == nil && (o.Automount != nil && *o.Automount) {
return errors.New("server must be provided when automount is true")
return missingRequiredTogetherFields(o, "Automount", "Server")
}
return nil
}
@ -204,8 +176,15 @@ type VolumeCreateResult struct {
// Create creates a new volume with the given options.
func (c *VolumeClient) Create(ctx context.Context, opts VolumeCreateOpts) (VolumeCreateResult, *Response, error) {
const opPath = "/volumes"
ctx = ctxutil.SetOpPath(ctx, opPath)
result := VolumeCreateResult{}
reqPath := opPath
if err := opts.Validate(); err != nil {
return VolumeCreateResult{}, nil, err
return result, nil, err
}
reqBody := schema.VolumeCreateRequest{
Name: opts.Name,
@ -220,48 +199,33 @@ func (c *VolumeClient) Create(ctx context.Context, opts VolumeCreateOpts) (Volum
reqBody.Server = Ptr(opts.Server.ID)
}
if opts.Location != nil {
if opts.Location.ID != 0 {
reqBody.Location = opts.Location.ID
} else {
reqBody.Location = opts.Location.Name
if opts.Location.ID != 0 || opts.Location.Name != "" {
reqBody.Location = &schema.IDOrName{ID: opts.Location.ID, Name: opts.Location.Name}
}
}
reqBodyData, err := json.Marshal(reqBody)
respBody, resp, err := postRequest[schema.VolumeCreateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return VolumeCreateResult{}, nil, err
return result, resp, err
}
req, err := c.client.NewRequest(ctx, "POST", "/volumes", bytes.NewReader(reqBodyData))
if err != nil {
return VolumeCreateResult{}, nil, err
}
var respBody schema.VolumeCreateResponse
resp, err := c.client.Do(req, &respBody)
if err != nil {
return VolumeCreateResult{}, resp, err
}
var action *Action
result.Volume = VolumeFromSchema(respBody.Volume)
if respBody.Action != nil {
action = ActionFromSchema(*respBody.Action)
result.Action = ActionFromSchema(*respBody.Action)
}
result.NextActions = ActionsFromSchema(respBody.NextActions)
return VolumeCreateResult{
Volume: VolumeFromSchema(respBody.Volume),
Action: action,
NextActions: ActionsFromSchema(respBody.NextActions),
}, resp, nil
return result, resp, nil
}
// Delete deletes a volume.
func (c *VolumeClient) Delete(ctx context.Context, volume *Volume) (*Response, error) {
req, err := c.client.NewRequest(ctx, "DELETE", fmt.Sprintf("/volumes/%d", volume.ID), nil)
if err != nil {
return nil, err
}
return c.client.Do(req, nil)
const opPath = "/volumes/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
return deleteRequestNoResult(ctx, c.client, reqPath)
}
// VolumeUpdateOpts specifies options for updating a volume.
@ -272,28 +236,23 @@ type VolumeUpdateOpts struct {
// Update updates a volume.
func (c *VolumeClient) Update(ctx context.Context, volume *Volume, opts VolumeUpdateOpts) (*Volume, *Response, error) {
const opPath = "/volumes/%d"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeUpdateRequest{
Name: opts.Name,
}
if opts.Labels != nil {
reqBody.Labels = &opts.Labels
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d", volume.ID)
req, err := c.client.NewRequest(ctx, "PUT", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeUpdateResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := putRequest[schema.VolumeUpdateResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return VolumeFromSchema(respBody.Volume), resp, nil
}
@ -305,27 +264,21 @@ type VolumeAttachOpts struct {
// AttachWithOpts attaches a volume to a server.
func (c *VolumeClient) AttachWithOpts(ctx context.Context, volume *Volume, opts VolumeAttachOpts) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/attach"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionAttachVolumeRequest{
Server: opts.Server.ID,
Automount: opts.Automount,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/attach", volume.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.VolumeActionAttachVolumeResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.VolumeActionAttachVolumeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -336,23 +289,18 @@ func (c *VolumeClient) Attach(ctx context.Context, volume *Volume, server *Serve
// Detach detaches a volume from a server.
func (c *VolumeClient) Detach(ctx context.Context, volume *Volume) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/detach"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
var reqBody schema.VolumeActionDetachVolumeRequest
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/detach", volume.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
var respBody schema.VolumeActionDetachVolumeResponse
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.VolumeActionDetachVolumeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, nil
}
@ -363,48 +311,38 @@ type VolumeChangeProtectionOpts struct {
// ChangeProtection changes the resource protection level of a volume.
func (c *VolumeClient) ChangeProtection(ctx context.Context, volume *Volume, opts VolumeChangeProtectionOpts) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/change_protection"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionChangeProtectionRequest{
Delete: opts.Delete,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/change_protection", volume.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeActionChangeProtectionResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.VolumeActionChangeProtectionResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
return ActionFromSchema(respBody.Action), resp, nil
}
// Resize changes the size of a volume.
func (c *VolumeClient) Resize(ctx context.Context, volume *Volume, size int) (*Action, *Response, error) {
const opPath = "/volumes/%d/actions/resize"
ctx = ctxutil.SetOpPath(ctx, opPath)
reqPath := fmt.Sprintf(opPath, volume.ID)
reqBody := schema.VolumeActionResizeVolumeRequest{
Size: size,
}
reqBodyData, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, err
}
path := fmt.Sprintf("/volumes/%d/actions/resize", volume.ID)
req, err := c.client.NewRequest(ctx, "POST", path, bytes.NewReader(reqBodyData))
if err != nil {
return nil, nil, err
}
respBody := schema.VolumeActionResizeVolumeResponse{}
resp, err := c.client.Do(req, &respBody)
respBody, resp, err := postRequest[schema.VolumeActionResizeVolumeResponse](ctx, c.client, reqPath, reqBody)
if err != nil {
return nil, resp, err
}
return ActionFromSchema(respBody.Action), resp, err
}

View File

@ -16,8 +16,12 @@ type IActionClient interface {
// when their value corresponds to their zero value or when they are empty.
List(ctx context.Context, opts ActionListOpts) ([]*Action, *Response, error)
// All returns all actions.
//
// Deprecated: It is required to pass in a list of IDs since 30 January 2025. Please use [ActionClient.AllWithOpts] instead.
All(ctx context.Context) ([]*Action, error)
// AllWithOpts returns all actions for the given options.
//
// It is required to set [ActionListOpts.ID]. Any other fields set in the opts are ignored.
AllWithOpts(ctx context.Context, opts ActionListOpts) ([]*Action, error)
// WatchOverallProgress watches several actions' progress until they complete
// with success or error. This watching happens in a goroutine and updates are
@ -35,7 +39,7 @@ type IActionClient interface {
// timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed.
//
// WatchOverallProgress uses the [WithPollBackoffFunc] of the [Client] to wait
// WatchOverallProgress uses the [WithPollOpts] of the [Client] to wait
// until sending the next request.
//
// Deprecated: WatchOverallProgress is deprecated, use [WaitForFunc] instead.
@ -56,19 +60,19 @@ type IActionClient interface {
// timeout, use the [context.Context]. Once the method has stopped watching,
// both returned channels are closed.
//
// WatchProgress uses the [WithPollBackoffFunc] of the [Client] to wait until
// WatchProgress uses the [WithPollOpts] of the [Client] to wait until
// sending the next request.
//
// Deprecated: WatchProgress is deprecated, use [WaitForFunc] instead.
WatchProgress(ctx context.Context, action *Action) (<-chan int, <-chan error)
// WaitForFunc waits until all actions are completed by polling the API at the interval
// defined by [WithPollBackoffFunc]. An action is considered as complete when its status is
// defined by [WithPollOpts]. An action is considered as complete when its status is
// either [ActionStatusSuccess] or [ActionStatusError].
//
// The handleUpdate callback is called every time an action is updated.
WaitForFunc(ctx context.Context, handleUpdate func(update *Action) error, actions ...*Action) error
// WaitFor waits until all actions succeed by polling the API at the interval defined by
// [WithPollBackoffFunc]. An action is considered as succeeded when its status is either
// [WithPollOpts]. An action is considered as succeeded when its status is either
// [ActionStatusSuccess].
//
// If a single action fails, the function will stop waiting and the error set in the

View File

@ -64,7 +64,7 @@ type ILoadBalancerClient interface {
// ChangeType changes a Load Balancer's type.
ChangeType(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerChangeTypeOpts) (*Action, *Response, error)
// GetMetrics obtains metrics for a Load Balancer.
GetMetrics(ctx context.Context, lb *LoadBalancer, opts LoadBalancerGetMetricsOpts) (*LoadBalancerMetrics, *Response, error)
GetMetrics(ctx context.Context, loadBalancer *LoadBalancer, opts LoadBalancerGetMetricsOpts) (*LoadBalancerMetrics, *Response, error)
// ChangeDNSPtr changes or resets the reverse DNS pointer for a Load Balancer.
// Pass a nil ptr to reset the reverse DNS pointer to its default value.
ChangeDNSPtr(ctx context.Context, lb *LoadBalancer, ip string, ptr *string) (*Action, *Response, error)

View File

@ -27,11 +27,11 @@ type IPrimaryIPClient interface {
// AllWithOpts returns all Primary IPs for the given options.
AllWithOpts(ctx context.Context, opts PrimaryIPListOpts) ([]*PrimaryIP, error)
// Create creates a Primary IP.
Create(ctx context.Context, reqBody PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error)
Create(ctx context.Context, opts PrimaryIPCreateOpts) (*PrimaryIPCreateResult, *Response, error)
// Delete deletes a Primary IP.
Delete(ctx context.Context, primaryIP *PrimaryIP) (*Response, error)
// Update updates a Primary IP.
Update(ctx context.Context, primaryIP *PrimaryIP, reqBody PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error)
Update(ctx context.Context, primaryIP *PrimaryIP, opts PrimaryIPUpdateOpts) (*PrimaryIP, *Response, error)
// Assign a Primary IP to a resource.
Assign(ctx context.Context, opts PrimaryIPAssignOpts) (*Action, *Response, error)
// Unassign a Primary IP from a resource.

File diff suppressed because it is too large Load Diff

View File

@ -62,9 +62,8 @@ type IServerClient interface {
AttachISO(ctx context.Context, server *Server, iso *ISO) (*Action, *Response, error)
// DetachISO detaches the currently attached ISO from a server.
DetachISO(ctx context.Context, server *Server) (*Action, *Response, error)
// EnableBackup enables backup for a server. Pass in an empty backup window to let the
// API pick a window for you. See the API documentation at docs.hetzner.cloud for a list
// of valid backup windows.
// EnableBackup enables backup for a server.
// The window parameter is deprecated and will be ignored.
EnableBackup(ctx context.Context, server *Server, window string) (*Action, *Response, error)
// DisableBackup disables backup for a server.
DisableBackup(ctx context.Context, server *Server) (*Action, *Response, error)

View File

@ -95,7 +95,9 @@ func newManager() (*hetznerManager, error) {
hcloud.WithToken(token),
hcloud.WithHTTPClient(httpClient),
hcloud.WithApplication("cluster-autoscaler", version.ClusterAutoscalerVersion),
hcloud.WithPollBackoffFunc(hcloud.ExponentialBackoff(2, 500*time.Millisecond)),
hcloud.WithPollOpts(hcloud.PollOpts{
BackoffFunc: hcloud.ExponentialBackoff(2, 500*time.Millisecond),
}),
hcloud.WithDebugWriter(&debugWriter{}),
}
@ -252,7 +254,7 @@ func (m *hetznerManager) deleteByNode(node *apiv1.Node) error {
}
func (m *hetznerManager) deleteServer(server *hcloud.Server) error {
_, err := m.client.Server.Delete(m.apiCallContext, server)
_, _, err := m.client.Server.DeleteWithResult(m.apiCallContext, server)
return err
}

View File

@ -34,7 +34,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/utils/units"
"k8s.io/client-go/rest"
klog "k8s.io/klog/v2"
"k8s.io/klog/v2"
kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config"
scheduler_config "k8s.io/kubernetes/pkg/scheduler/apis/config"
)
@ -269,12 +269,12 @@ func createAutoscalingOptions() config.AutoscalingOptions {
klog.Fatalf("Failed to get scheduler config: %v", err)
}
if isFlagPassed("drain-priority-config") && isFlagPassed("max-graceful-termination-sec") {
if pflag.CommandLine.Changed("drain-priority-config") && pflag.CommandLine.Changed("max-graceful-termination-sec") {
klog.Fatalf("Invalid configuration, could not use --drain-priority-config together with --max-graceful-termination-sec")
}
var drainPriorityConfigMap []kubelet_config.ShutdownGracePeriodByPodPriority
if isFlagPassed("drain-priority-config") {
if pflag.CommandLine.Changed("drain-priority-config") {
drainPriorityConfigMap = parseShutdownGracePeriodsAndPriorities(*drainPriorityConfig)
if len(drainPriorityConfigMap) == 0 {
klog.Fatalf("Invalid configuration, parsing --drain-priority-config")
@ -409,16 +409,6 @@ func createAutoscalingOptions() config.AutoscalingOptions {
}
}
func isFlagPassed(name string) bool {
found := false
flag.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}
func minMaxFlagString(min, max int64) string {
return fmt.Sprintf("%v:%v", min, max)
}

View File

@ -17,11 +17,15 @@ limitations under the License.
package flags
import (
"flag"
"testing"
"k8s.io/autoscaler/cluster-autoscaler/config"
kubelet_config "k8s.io/kubernetes/pkg/kubelet/apis/config"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
)
@ -146,3 +150,47 @@ func TestParseShutdownGracePeriodsAndPriorities(t *testing.T) {
})
}
}
func TestCreateAutoscalingOptions(t *testing.T) {
for _, tc := range []struct {
testName string
flags []string
wantOptionsAsserter func(t *testing.T, gotOptions config.AutoscalingOptions)
}{
{
testName: "DrainPriorityConfig defaults to an empty list when the flag isn't passed",
flags: []string{},
wantOptionsAsserter: func(t *testing.T, gotOptions config.AutoscalingOptions) {
if diff := cmp.Diff([]kubelet_config.ShutdownGracePeriodByPodPriority{}, gotOptions.DrainPriorityConfig, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("createAutoscalingOptions(): unexpected DrainPriorityConfig field (-want +got): %s", diff)
}
},
},
{
testName: "DrainPriorityConfig is parsed correctly when the flag passed",
flags: []string{"--drain-priority-config", "5000:60,3000:50,0:40"},
wantOptionsAsserter: func(t *testing.T, gotOptions config.AutoscalingOptions) {
wantConfig := []kubelet_config.ShutdownGracePeriodByPodPriority{
{Priority: 5000, ShutdownGracePeriodSeconds: 60},
{Priority: 3000, ShutdownGracePeriodSeconds: 50},
{Priority: 0, ShutdownGracePeriodSeconds: 40},
}
if diff := cmp.Diff(wantConfig, gotOptions.DrainPriorityConfig); diff != "" {
t.Errorf("createAutoscalingOptions(): unexpected DrainPriorityConfig field (-want +got): %s", diff)
}
},
},
} {
t.Run(tc.testName, func(t *testing.T) {
pflag.CommandLine = pflag.NewFlagSet("test", pflag.ExitOnError)
pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
err := pflag.CommandLine.Parse(tc.flags)
if err != nil {
t.Errorf("pflag.CommandLine.Parse() got unexpected error: %v", err)
}
gotOptions := createAutoscalingOptions()
tc.wantOptionsAsserter(t, gotOptions)
})
}
}

View File

@ -6,9 +6,9 @@ require (
cloud.google.com/go/compute/metadata v0.5.0
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible
github.com/Azure/azure-sdk-for-go-extensions v0.1.6
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2
github.com/Azure/go-autorest/autorest v0.11.29
github.com/Azure/go-autorest/autorest/adal v0.9.24
github.com/Azure/go-autorest/autorest/azure/auth v0.5.13
@ -32,6 +32,7 @@ require (
github.com/stretchr/testify v1.10.0
github.com/vburenin/ifacemaker v1.2.1
go.uber.org/mock v0.4.0
golang.org/x/crypto v0.35.0
golang.org/x/net v0.33.0
golang.org/x/oauth2 v0.27.0
golang.org/x/sys v0.30.0
@ -62,11 +63,12 @@ require (
require (
cel.dev/expr v0.19.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5 v5.6.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 // indirect
@ -186,7 +188,6 @@ require (
go.opentelemetry.io/proto/otlp v1.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.35.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sync v0.11.0 // indirect

View File

@ -4,17 +4,16 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
github.com/Azure/azure-sdk-for-go v46.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go-extensions v0.1.6 h1:EXGvDcj54u98XfaI/Cy65Ds6vNsIJeGKYf0eNLB1y4Q=
github.com/Azure/azure-sdk-for-go-extensions v0.1.6/go.mod h1:27StPiXJp6Xzkq2AQL7gPK7VC0hgmCnUKlco1dO1jaM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 h1:FDif4R1+UUR+00q6wquyX90K7A8dN+R5E8GEadoP7sU=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2/go.mod h1:aiYBYui4BJ/BJCAIKs92XiPyQfTaBWqvHujDwKb6CBU=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 h1:xnO4sFyG8UH2fElBkcqLTOZsAajvKfnSlgBBW8dXYjw=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0/go.mod h1:XD3DIOOVgBCO03OleB1fHjgktVRFxlT++KwKgIOewdM=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw=
@ -25,10 +24,14 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armconta
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0/go.mod h1:E7ltexgRDmeJ0fJWv0D/HLwY2xbDdN+uv+X2uZtOx3w=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0 h1:1u/K2BFv0MwkG6he8RYuUcbbeK22rkoZbg4lKa/msZU=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0/go.mod h1:U5gpsREQZE6SLk1t/cFfc1eMhYAlYpEzvaYXuDfefy8=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1 h1:iqhrjj9w9/AQZsHjaOVyloamkeAFRbWI0iHNy6INMYk=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 h1:0nGmzwBv5ougvzfGPCO2ljFRHvun57KpNrVCMrlk0ns=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2 h1:re+BEe/OafvSyRy2vM+Fyu+EcUK34O2o/Fa6WO3ITZM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2/go.mod h1:5zx285T5OLk+iQbfOuexhhO7J6dfzkqVkFgS/+s7XaA=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 h1:HlZMUZW8S4P9oob1nCHxCCKrytxyLc+24nUJGssoEto=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0/go.mod h1:StGsLbuJh06Bd8IBfnAlIFV3fLb+gkczONWf15hpX2E=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o=
@ -45,12 +48,9 @@ github.com/Azure/go-armbalancer v0.0.2 h1:NVnxsTWHI5/fEzL6k6TjxPUfcB/3Si3+HFOZXO
github.com/Azure/go-armbalancer v0.0.2/go.mod h1:yTg7MA/8YnfKQc9o97tzAJ7fbdVkod1xGsIvKmhYPRE=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24=
github.com/Azure/go-autorest/autorest v0.11.4/go.mod h1:JFgpikqFJ/MleTTxwepExTKnFUKKszPS8UavbQYUMuw=
github.com/Azure/go-autorest/autorest v0.11.28/go.mod h1:MrkzG3Y3AH668QyF9KRk5neJnGgmhQ6krbhR8Q5eMvA=
github.com/Azure/go-autorest/autorest v0.11.29 h1:I4+HL/JDvErx2LjyzaVxllw2lRDB5/BT2Bm4g20iqYw=
github.com/Azure/go-autorest/autorest v0.11.29/go.mod h1:ZtEzC4Jy2JDrZLxvWs8LrBWEBycl1hbT1eknI8MtfAs=
github.com/Azure/go-autorest/autorest/adal v0.9.0/go.mod h1:/c022QCutn2P7uY+/oQWWNcK9YU+MH96NgK+jErpbcg=
github.com/Azure/go-autorest/autorest/adal v0.9.2/go.mod h1:/3SMAM86bP6wC9Ev35peQDUeqFZBMH07vvUOmg4z/fE=
github.com/Azure/go-autorest/autorest/adal v0.9.18/go.mod h1:XVVeme+LZwABT8K5Lc3hA4nAe8LDBVle26gTrguhhPQ=
github.com/Azure/go-autorest/autorest/adal v0.9.22/go.mod h1:XuAbAEUv2Tta//+voMI038TrJBqjKam0me7qR+L8Cmk=
github.com/Azure/go-autorest/autorest/adal v0.9.24 h1:BHZfgGsGwdkHDyZdtQRQk1WeUdW0m2WPAwuHZwUi5i4=
@ -61,22 +61,17 @@ github.com/Azure/go-autorest/autorest/azure/cli v0.4.6 h1:w77/uPk80ZET2F+AfQExZy
github.com/Azure/go-autorest/autorest/azure/cli v0.4.6/go.mod h1:piCfgPho7BiIDdEQ1+g4VmKyD5y+p/XtSNqE6Hc4QD0=
github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw=
github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74=
github.com/Azure/go-autorest/autorest/mocks v0.4.0/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9AGVwYHG9/fkDYhtAfw=
github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU=
github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk=
github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE=
github.com/Azure/go-autorest/autorest/validation v0.3.0/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E=
github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac=
github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E=
github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg=
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/Azure/skewer v0.0.14 h1:0mzUJhspECkajYyynYsOCp//E2PSnYXrgP45bcskqfQ=
github.com/Azure/skewer v0.0.14/go.mod h1:6WTecuPyfGtuvS8Mh4JYWuHhO4kcWycGfsUBB+XTFG4=
github.com/Azure/skewer v0.0.19 h1:+qA1z8isKmlNkhAwZErNS2wD2jaemSk9NszYKr8dddU=
github.com/Azure/skewer v0.0.19/go.mod h1:LVH7jmduRKmPj8YcIz7V4f53xJEntjweL4aoLyChkwk=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
@ -139,7 +134,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/digitalocean/godo v1.27.0 h1:78iE9oVvTnAEqhMip2UHFvL01b8LJcydbNUpr0cAmN4=
github.com/digitalocean/godo v1.27.0/go.mod h1:iJnN9rVu6K5LioLxLimlq0uRI+y/eAQjROUmeU/r0hY=
github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U=
@ -232,7 +226,6 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@ -442,7 +435,6 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=

View File

@ -504,9 +504,7 @@ func UpdateDurationFromStart(label FunctionLabel, start time.Time) {
// UpdateDuration records the duration of the step identified by the label
func UpdateDuration(label FunctionLabel, duration time.Duration) {
// TODO(maciekpytel): remove second condition if we manage to get
// asynchronous node drain
if duration > LogLongDurationThreshold && label != ScaleDown {
if duration > LogLongDurationThreshold {
klog.V(4).Infof("Function %s took %v to complete", label, duration)
}
functionDuration.WithLabelValues(string(label)).Observe(duration.Seconds())

View File

@ -24,6 +24,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/simulator/clustersnapshot"
drautils "k8s.io/autoscaler/cluster-autoscaler/simulator/dynamicresources/utils"
"k8s.io/autoscaler/cluster-autoscaler/simulator/framework"
"k8s.io/dynamic-resource-allocation/resourceclaim"
schedulerframework "k8s.io/kubernetes/pkg/scheduler/framework"
)
@ -245,7 +246,7 @@ func (s *PredicateSnapshot) modifyResourceClaimsForNewPod(podInfo *framework.Pod
// so we don't add them. The claims should already be allocated in the provided PodInfo.
var podOwnedClaims []*resourceapi.ResourceClaim
for _, claim := range podInfo.NeededResourceClaims {
if ownerName, _ := drautils.ClaimOwningPod(claim); ownerName != "" {
if err := resourceclaim.IsForPod(podInfo.Pod, claim); err == nil {
podOwnedClaims = append(podOwnedClaims, claim)
}
}

View File

@ -178,7 +178,7 @@ func (s Snapshot) RemovePodOwnedClaims(pod *apiv1.Pod) {
// The claim isn't tracked in the snapshot for some reason. Nothing to remove/modify, so continue to the next claim.
continue
}
if ownerName, ownerUid := drautils.ClaimOwningPod(claim); ownerName == pod.Name && ownerUid == pod.UID {
if err := resourceclaim.IsForPod(pod, claim); err == nil {
delete(s.resourceClaimsById, claimId)
} else {
drautils.ClearPodReservationInPlace(claim, pod)
@ -214,9 +214,7 @@ func (s Snapshot) UnreservePodClaims(pod *apiv1.Pod) error {
return err
}
for _, claim := range claims {
ownerPodName, ownerPodUid := drautils.ClaimOwningPod(claim)
podOwnedClaim := ownerPodName == pod.Name && ownerPodUid == ownerPodUid
podOwnedClaim := resourceclaim.IsForPod(pod, claim) == nil
drautils.ClearPodReservationInPlace(claim, pod)
if podOwnedClaim || !drautils.ClaimInUse(claim) {
drautils.DeallocateClaimInPlace(claim)

View File

@ -22,23 +22,10 @@ import (
apiv1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/component-helpers/scheduling/corev1"
resourceclaim "k8s.io/dynamic-resource-allocation/resourceclaim"
"k8s.io/utils/ptr"
)
// ClaimOwningPod returns the name and UID of the Pod owner of the provided claim. If the claim isn't
// owned by a Pod, empty strings are returned.
func ClaimOwningPod(claim *resourceapi.ResourceClaim) (string, types.UID) {
for _, owner := range claim.OwnerReferences {
if ptr.Deref(owner.Controller, false) && owner.APIVersion == "v1" && owner.Kind == "Pod" {
return owner.Name, owner.UID
}
}
return "", ""
}
// ClaimAllocated returns whether the provided claim is allocated.
func ClaimAllocated(claim *resourceapi.ResourceClaim) bool {
return claim.Status.Allocation != nil

View File

@ -25,84 +25,8 @@ import (
apiv1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
)
func TestClaimOwningPod(t *testing.T) {
truePtr := true
for _, tc := range []struct {
testName string
claim *resourceapi.ResourceClaim
wantName string
wantUid types.UID
}{
{
testName: "claim with no owners",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with non-Pod owners",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController", Controller: &truePtr},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
},
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with a Pod non-controller owner",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController"},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
{Name: "owner3", UID: "owner3uid", APIVersion: "v1", Kind: "Pod"},
},
},
},
wantName: "",
wantUid: "",
},
{
testName: "claim with a Pod controller owner",
claim: &resourceapi.ResourceClaim{
ObjectMeta: metav1.ObjectMeta{
Name: "claim", UID: "claim", Namespace: "default",
OwnerReferences: []metav1.OwnerReference{
{Name: "owner1", UID: "owner1uid", APIVersion: "v1", Kind: "ReplicationController"},
{Name: "owner2", UID: "owner2uid", APIVersion: "v1", Kind: "ConfigMap"},
{Name: "owner3", UID: "owner3uid", APIVersion: "v1", Kind: "Pod", Controller: &truePtr},
},
},
},
wantName: "owner3",
wantUid: "owner3uid",
},
} {
t.Run(tc.testName, func(t *testing.T) {
name, uid := ClaimOwningPod(tc.claim)
if tc.wantName != name {
t.Errorf("ClaimOwningPod(): unexpected output name: want %s, got %s", tc.wantName, name)
}
if tc.wantUid != uid {
t.Errorf("ClaimOwningPod(): unexpected output UID: want %v, got %v", tc.wantUid, uid)
}
})
}
}
func TestClaimAllocated(t *testing.T) {
for _, tc := range []struct {
testName string

View File

@ -23,6 +23,7 @@ import (
resourceapi "k8s.io/api/resource/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/uuid"
"k8s.io/dynamic-resource-allocation/resourceclaim"
"k8s.io/utils/set"
)
@ -61,7 +62,7 @@ func SanitizedNodeResourceSlices(nodeLocalSlices []*resourceapi.ResourceSlice, n
func SanitizedPodResourceClaims(newOwner, oldOwner *v1.Pod, claims []*resourceapi.ResourceClaim, nameSuffix, newNodeName, oldNodeName string, oldNodePoolNames set.Set[string]) ([]*resourceapi.ResourceClaim, error) {
var result []*resourceapi.ResourceClaim
for _, claim := range claims {
if ownerName, ownerUid := ClaimOwningPod(claim); ownerName != oldOwner.Name || ownerUid != oldOwner.UID {
if err := resourceclaim.IsForPod(oldOwner, claim); err != nil {
// Only claims owned by the pod are bound to its lifecycle. The lifecycle of other claims is independent, and they're most likely shared
// by multiple pods. They shouldn't be sanitized or duplicated - just add unchanged to the result.
result = append(result, claim)

View File

@ -49,9 +49,12 @@ func SanitizedTemplateNodeInfoFromNodeGroup(nodeGroup nodeGroupTemplateNodeInfoG
if err != nil {
return nil, errors.ToAutoscalerError(errors.CloudProviderError, err).AddPrefix("failed to obtain template NodeInfo from node group %q: ", nodeGroup.Id())
}
labels.UpdateDeprecatedLabels(baseNodeInfo.Node().ObjectMeta.Labels)
return SanitizedTemplateNodeInfoFromNodeInfo(baseNodeInfo, nodeGroup.Id(), daemonsets, true, taintConfig)
sanitizedNodeInfo, aErr := SanitizedTemplateNodeInfoFromNodeInfo(baseNodeInfo, nodeGroup.Id(), daemonsets, true, taintConfig)
if aErr != nil {
return nil, aErr
}
labels.UpdateDeprecatedLabels(sanitizedNodeInfo.Node().Labels)
return sanitizedNodeInfo, nil
}
// SanitizedTemplateNodeInfoFromNodeInfo returns a template NodeInfo object based on a real example NodeInfo from the cluster. The template is sanitized, and only

View File

@ -40,6 +40,7 @@ import (
"k8s.io/autoscaler/cluster-autoscaler/utils/labels"
"k8s.io/autoscaler/cluster-autoscaler/utils/taints"
. "k8s.io/autoscaler/cluster-autoscaler/utils/test"
"k8s.io/dynamic-resource-allocation/resourceclaim"
)
var (
@ -93,6 +94,11 @@ func TestSanitizedTemplateNodeInfoFromNodeGroup(t *testing.T) {
exampleNode.Spec.Taints = []apiv1.Taint{
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
}
exampleNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
for _, tc := range []struct {
testName string
@ -155,7 +161,7 @@ func TestSanitizedTemplateNodeInfoFromNodeGroup(t *testing.T) {
// Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because
// TemplateNodeInfoFromNodeGroupTemplate randomizes the suffix.
// Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed).
if err := verifyNodeInfoSanitization(tc.nodeGroup.templateNodeInfoResult, templateNodeInfo, tc.wantPods, "template-node-for-"+tc.nodeGroup.id, "", nil); err != nil {
if err := verifyNodeInfoSanitization(tc.nodeGroup.templateNodeInfoResult, templateNodeInfo, tc.wantPods, "template-node-for-"+tc.nodeGroup.id, "", true, nil); err != nil {
t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
}
})
@ -167,6 +173,11 @@ func TestSanitizedTemplateNodeInfoFromNodeInfo(t *testing.T) {
exampleNode.Spec.Taints = []apiv1.Taint{
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
}
exampleNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
testCases := []struct {
name string
@ -317,7 +328,7 @@ func TestSanitizedTemplateNodeInfoFromNodeInfo(t *testing.T) {
// Pass empty string as nameSuffix so that it's auto-determined from the sanitized templateNodeInfo, because
// TemplateNodeInfoFromExampleNodeInfo randomizes the suffix.
// Pass non-empty expectedPods to verify that the set of pods is changed as expected (e.g. DS pods added, non-DS/deleted pods removed).
if err := verifyNodeInfoSanitization(exampleNodeInfo, templateNodeInfo, tc.wantPods, "template-node-for-"+nodeGroupId, "", nil); err != nil {
if err := verifyNodeInfoSanitization(exampleNodeInfo, templateNodeInfo, tc.wantPods, "template-node-for-"+nodeGroupId, "", false, nil); err != nil {
t.Fatalf("TemplateNodeInfoFromExampleNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
}
})
@ -332,6 +343,12 @@ func TestSanitizedNodeInfo(t *testing.T) {
{Key: taints.ToBeDeletedTaint, Value: "2312532423", Effect: apiv1.TaintEffectNoSchedule},
{Key: "a", Value: "b", Effect: apiv1.TaintEffectNoSchedule},
}
templateNode.Labels = map[string]string{
"custom": "label",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
pods := []*framework.PodInfo{
{Pod: BuildTestPod("p1", 80, 0, WithNodeName(nodeName))},
{Pod: BuildTestPod("p2", 80, 0, WithNodeName(nodeName))},
@ -346,7 +363,7 @@ func TestSanitizedNodeInfo(t *testing.T) {
// Verify that the taints are not sanitized (they should be sanitized in the template already).
// Verify that the NodeInfo is sanitized using the template Node name as base.
initialTaints := templateNodeInfo.Node().Spec.Taints
if err := verifyNodeInfoSanitization(templateNodeInfo, freshNodeInfo, nil, templateNodeInfo.Node().Name, suffix, initialTaints); err != nil {
if err := verifyNodeInfoSanitization(templateNodeInfo, freshNodeInfo, nil, templateNodeInfo.Node().Name, suffix, false, initialTaints); err != nil {
t.Fatalf("FreshNodeInfoFromTemplateNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
}
}
@ -357,9 +374,11 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
labelsNode := basicNode.DeepCopy()
labelsNode.Labels = map[string]string{
apiv1.LabelHostname: oldNodeName,
"a": "b",
"x": "y",
apiv1.LabelHostname: oldNodeName,
"a": "b",
"x": "y",
apiv1.LabelInstanceTypeStable: "some-instance",
apiv1.LabelTopologyRegion: "some-region",
}
taintsNode := basicNode.DeepCopy()
@ -491,7 +510,7 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
if err != nil {
t.Fatalf("sanitizeNodeInfo(): want nil error, got %v", err)
}
if err := verifyNodeInfoSanitization(tc.nodeInfo, nodeInfo, nil, newNameBase, suffix, tc.wantTaints); err != nil {
if err := verifyNodeInfoSanitization(tc.nodeInfo, nodeInfo, nil, newNameBase, suffix, false, tc.wantTaints); err != nil {
t.Fatalf("sanitizeNodeInfo(): NodeInfo wasn't properly sanitized: %v", err)
}
})
@ -506,7 +525,7 @@ func TestCreateSanitizedNodeInfo(t *testing.T) {
//
// If expectedPods is nil, the set of pods is expected not to change between initialNodeInfo and sanitizedNodeInfo. If the sanitization is
// expected to change the set of pods, the expected set should be passed to expectedPods.
func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.NodeInfo, expectedPods []*apiv1.Pod, nameBase, nameSuffix string, wantTaints []apiv1.Taint) error {
func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.NodeInfo, expectedPods []*apiv1.Pod, nameBase, nameSuffix string, wantDeprecatedLabels bool, wantTaints []apiv1.Taint) error {
if nameSuffix == "" {
// Determine the suffix from the provided sanitized NodeInfo - it should be the last part of a dash-separated name.
nameParts := strings.Split(sanitizedNodeInfo.Node().Name, "-")
@ -526,7 +545,7 @@ func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.No
// Verification below assumes the same set of pods between initialNodeInfo and sanitizedNodeInfo.
wantNodeName := fmt.Sprintf("%s-%s", nameBase, nameSuffix)
if err := verifySanitizedNode(initialNodeInfo.Node(), sanitizedNodeInfo.Node(), wantNodeName, wantTaints); err != nil {
if err := verifySanitizedNode(initialNodeInfo.Node(), sanitizedNodeInfo.Node(), wantNodeName, wantDeprecatedLabels, wantTaints); err != nil {
return err
}
if err := verifySanitizedNodeResourceSlices(initialNodeInfo.LocalResourceSlices, sanitizedNodeInfo.LocalResourceSlices, nameSuffix); err != nil {
@ -539,7 +558,7 @@ func verifyNodeInfoSanitization(initialNodeInfo, sanitizedNodeInfo *framework.No
return nil
}
func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName string, wantTaints []apiv1.Taint) error {
func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName string, wantDeprecatedLabels bool, wantTaints []apiv1.Taint) error {
if gotName := sanitizedNode.Name; gotName != wantNodeName {
return fmt.Errorf("want sanitized Node name %q, got %q", wantNodeName, gotName)
}
@ -552,6 +571,9 @@ func verifySanitizedNode(initialNode, sanitizedNode *apiv1.Node, wantNodeName st
wantLabels[k] = v
}
wantLabels[apiv1.LabelHostname] = wantNodeName
if wantDeprecatedLabels {
labels.UpdateDeprecatedLabels(wantLabels)
}
if diff := cmp.Diff(wantLabels, sanitizedNode.Labels); diff != "" {
return fmt.Errorf("sanitized Node labels unexpected, diff (-want +got): %s", diff)
}
@ -601,7 +623,7 @@ func verifySanitizedPods(initialPods, sanitizedPods []*framework.PodInfo, wantNo
return fmt.Errorf("sanitized Pod unexpected diff (-want +got): %s", diff)
}
if err := verifySanitizedPodResourceClaims(initialPod.NeededResourceClaims, sanitizedPod.NeededResourceClaims, nameSuffix); err != nil {
if err := verifySanitizedPodResourceClaims(initialPod, sanitizedPod, nameSuffix); err != nil {
return err
}
}
@ -633,7 +655,11 @@ func verifySanitizedNodeResourceSlices(initialSlices, sanitizedSlices []*resourc
return nil
}
func verifySanitizedPodResourceClaims(initialClaims, sanitizedClaims []*resourceapi.ResourceClaim, nameSuffix string) error {
func verifySanitizedPodResourceClaims(initialPod, sanitizedPod *framework.PodInfo, nameSuffix string) error {
initialClaims := initialPod.NeededResourceClaims
sanitizedClaims := sanitizedPod.NeededResourceClaims
owningPod := initialPod.Pod
if len(initialClaims) != len(sanitizedClaims) {
return fmt.Errorf("want %d NeededResourceClaims in sanitized NodeInfo, got %d", len(initialClaims), len(sanitizedClaims))
}
@ -642,7 +668,9 @@ func verifySanitizedPodResourceClaims(initialClaims, sanitizedClaims []*resource
initialClaim := initialClaims[i]
// Pod-owned claims should be sanitized, other claims shouldn't.
if owningPod, _ := drautils.ClaimOwningPod(initialClaim); owningPod != "" {
err := resourceclaim.IsForPod(owningPod, initialClaim)
isPodOwned := err == nil
if isPodOwned {
// Pod-owned claim, verify that it was sanitized.
if sanitizedClaim.Name == initialClaim.Name || !strings.HasSuffix(sanitizedClaim.Name, nameSuffix) {
return fmt.Errorf("sanitized ResourceClaim name unexpected: want (different than %q, ending in %q), got %q", initialClaim.Name, nameSuffix, sanitizedClaim.Name)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM --platform=$BUILDPLATFORM golang:1.24.3 AS builder
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS builder
WORKDIR /workspace

View File

@ -276,119 +276,120 @@ func TestChangedCAReloader(t *testing.T) {
assert.NotEqual(t, oldCAEncodedString, newCAEncodedString, "expected CA to change")
}
func TestUnchangedCAReloader(t *testing.T) {
tempDir := t.TempDir()
caCert := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
Organization: []string{"ca"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(2, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
t.Error(err)
}
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey)
if err != nil {
t.Error(err)
}
caPath := path.Join(tempDir, "ca.crt")
caFile, err := os.Create(caPath)
if err != nil {
t.Error(err)
}
err = pem.Encode(caFile, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
if err != nil {
t.Error(err)
}
// TODO(omerap12): Temporary workaround for flakiness (#7831)
// func TestUnchangedCAReloader(t *testing.T) {
// tempDir := t.TempDir()
// caCert := &x509.Certificate{
// SerialNumber: big.NewInt(0),
// Subject: pkix.Name{
// Organization: []string{"ca"},
// },
// NotBefore: time.Now(),
// NotAfter: time.Now().AddDate(2, 0, 0),
// IsCA: true,
// ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
// KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
// BasicConstraintsValid: true,
// }
// caKey, err := rsa.GenerateKey(rand.Reader, 4096)
// if err != nil {
// t.Error(err)
// }
// caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caKey.PublicKey, caKey)
// if err != nil {
// t.Error(err)
// }
// caPath := path.Join(tempDir, "ca.crt")
// caFile, err := os.Create(caPath)
// if err != nil {
// t.Error(err)
// }
// err = pem.Encode(caFile, &pem.Block{
// Type: "CERTIFICATE",
// Bytes: caBytes,
// })
// if err != nil {
// t.Error(err)
// }
testClientSet := fake.NewSimpleClientset()
// testClientSet := fake.NewSimpleClientset()
selfRegistration(
testClientSet,
readFile(caPath),
0*time.Second,
"default",
"vpa-service",
"http://example.com/",
true,
int32(32),
"",
[]string{},
false,
"key1:value1,key2:value2",
)
// selfRegistration(
// testClientSet,
// readFile(caPath),
// 0*time.Second,
// "default",
// "vpa-service",
// "http://example.com/",
// true,
// int32(32),
// "",
// []string{},
// false,
// "key1:value1,key2:value2",
// )
webhookConfigInterface := testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations()
oldWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
if err != nil {
t.Error(err)
}
// webhookConfigInterface := testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations()
// oldWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
// if err != nil {
// t.Error(err)
// }
assert.Len(t, oldWebhookConfig.Webhooks, 1, "expected one webhook configuration")
webhook := oldWebhookConfig.Webhooks[0]
oldWebhookCABundle := webhook.ClientConfig.CABundle
// assert.Len(t, oldWebhookConfig.Webhooks, 1, "expected one webhook configuration")
// webhook := oldWebhookConfig.Webhooks[0]
// oldWebhookCABundle := webhook.ClientConfig.CABundle
var reloadWebhookCACalled, patchCalled atomic.Bool
reloadWebhookCACalled.Store(false)
patchCalled.Store(false)
testClientSet.PrependReactor("get", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
reloadWebhookCACalled.Store(true)
return false, nil, nil
})
testClientSet.PrependReactor("patch", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
patchCalled.Store(true)
return false, nil, nil
})
// var reloadWebhookCACalled, patchCalled atomic.Bool
// reloadWebhookCACalled.Store(false)
// patchCalled.Store(false)
// testClientSet.PrependReactor("get", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
// reloadWebhookCACalled.Store(true)
// return false, nil, nil
// })
// testClientSet.PrependReactor("patch", "mutatingwebhookconfigurations", func(action k8stesting.Action) (bool, runtime.Object, error) {
// patchCalled.Store(true)
// return false, nil, nil
// })
reloader := certReloader{
clientCaPath: caPath,
mutatingWebhookClient: testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations(),
}
stop := make(chan struct{})
defer close(stop)
if err := reloader.start(stop); err != nil {
t.Error(err)
}
// reloader := certReloader{
// clientCaPath: caPath,
// mutatingWebhookClient: testClientSet.AdmissionregistrationV1().MutatingWebhookConfigurations(),
// }
// stop := make(chan struct{})
// defer close(stop)
// if err := reloader.start(stop); err != nil {
// t.Error(err)
// }
originalCaFile, err := os.ReadFile(caPath)
if err != nil {
t.Error(err)
}
err = os.WriteFile(caPath, originalCaFile, 0666)
if err != nil {
t.Error(err)
}
// originalCaFile, err := os.ReadFile(caPath)
// if err != nil {
// t.Error(err)
// }
// err = os.WriteFile(caPath, originalCaFile, 0666)
// if err != nil {
// t.Error(err)
// }
oldCAEncodedString := base64.StdEncoding.EncodeToString(oldWebhookCABundle)
// oldCAEncodedString := base64.StdEncoding.EncodeToString(oldWebhookCABundle)
for tries := 0; tries < 10; tries++ {
if reloadWebhookCACalled.Load() {
break
}
time.Sleep(1 * time.Second)
}
if !reloadWebhookCACalled.Load() {
t.Error("expected reloadWebhookCA to be called")
}
// for tries := 0; tries < 10; tries++ {
// if reloadWebhookCACalled.Load() {
// break
// }
// time.Sleep(1 * time.Second)
// }
// if !reloadWebhookCACalled.Load() {
// t.Error("expected reloadWebhookCA to be called")
// }
assert.False(t, patchCalled.Load(), "expected patch to not be called")
// assert.False(t, patchCalled.Load(), "expected patch to not be called")
newWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
assert.Nil(t, err, "expected no error")
assert.NotNil(t, newWebhookConfig, "expected webhook configuration")
assert.Len(t, newWebhookConfig.Webhooks, 1, "expected one webhook configuration")
// newWebhookConfig, err := webhookConfigInterface.Get(context.TODO(), webhookConfigName, metav1.GetOptions{})
// assert.Nil(t, err, "expected no error")
// assert.NotNil(t, newWebhookConfig, "expected webhook configuration")
// assert.Len(t, newWebhookConfig.Webhooks, 1, "expected one webhook configuration")
newWebhookCABundle := newWebhookConfig.Webhooks[0].ClientConfig.CABundle
newCAEncodedString := base64.StdEncoding.EncodeToString(newWebhookCABundle)
assert.Equal(t, oldCAEncodedString, newCAEncodedString, "expected CA to not change")
}
// newWebhookCABundle := newWebhookConfig.Webhooks[0].ClientConfig.CABundle
// newCAEncodedString := base64.StdEncoding.EncodeToString(newWebhookCABundle)
// assert.Equal(t, oldCAEncodedString, newCAEncodedString, "expected CA to not change")
// }

Some files were not shown because too many files have changed in this diff Show More