azure: Add support for NAT gateway

This commit is contained in:
Ciprian Hacman 2023-08-01 05:47:58 +03:00
parent bfefb0cd97
commit d382b0e44f
16 changed files with 514 additions and 30 deletions

View File

@ -269,12 +269,29 @@ func (b *NetworkModelBuilder) Build(c *fi.CloudupModelBuilderContext) error {
})
c.AddTask(nsgTask)
ngwPipTask := &azuretasks.PublicIPAddress{
Name: fi.PtrTo(b.NameForVirtualNetwork()),
Lifecycle: b.Lifecycle,
ResourceGroup: b.LinkToResourceGroup(),
Tags: map[string]*string{},
}
c.AddTask(ngwPipTask)
ngwTask := &azuretasks.NatGateway{
Name: fi.PtrTo(b.NameForVirtualNetwork()),
Lifecycle: b.Lifecycle,
PublicIPAddresses: []*azuretasks.PublicIPAddress{ngwPipTask},
ResourceGroup: b.LinkToResourceGroup(),
Tags: map[string]*string{},
}
c.AddTask(ngwTask)
for _, subnetSpec := range b.Cluster.Spec.Networking.Subnets {
subnetTask := &azuretasks.Subnet{
Name: fi.PtrTo(subnetSpec.Name),
Lifecycle: b.Lifecycle,
ResourceGroup: b.LinkToResourceGroup(),
VirtualNetwork: b.LinkToVirtualNetwork(),
NatGateway: ngwTask,
NetworkSecurityGroup: nsgTask,
CIDR: fi.PtrTo(subnetSpec.CIDR),
Shared: fi.PtrTo(b.Cluster.SharedVPC()),

View File

@ -42,6 +42,7 @@ const (
typeRoleAssignment = "RoleAssignment"
typeLoadBalancer = "LoadBalancer"
typePublicIPAddress = "PublicIPAddress"
typeNatGateway = "NatGateway"
)
// ListResourcesAzure lists all resources for the cluster by quering Azure.
@ -94,6 +95,7 @@ func (g *resourceGetter) listAll() ([]*resources.Resource, error) {
g.listDisks,
g.listLoadBalancers,
g.listPublicIPAddresses,
g.listNatGateways,
}
var resources []*resources.Resource
@ -216,6 +218,14 @@ func (g *resourceGetter) listSubnets(ctx context.Context, vnetName string) ([]*r
}
func (g *resourceGetter) toSubnetResource(subnet *network.Subnet, vnetName string) *resources.Resource {
var blocks []string
blocks = append(blocks, toKey(typeVirtualNetwork, vnetName))
blocks = append(blocks, toKey(typeResourceGroup, g.resourceGroupName()))
if subnet.NatGateway != nil {
blocks = append(blocks, toKey(typeNatGateway, *subnet.NatGateway.ID))
}
return &resources.Resource{
Obj: subnet,
Type: typeSubnet,
@ -224,10 +234,7 @@ func (g *resourceGetter) toSubnetResource(subnet *network.Subnet, vnetName strin
Deleter: func(_ fi.Cloud, r *resources.Resource) error {
return g.deleteSubnet(vnetName, r)
},
Blocks: []string{
toKey(typeVirtualNetwork, vnetName),
toKey(typeResourceGroup, g.resourceGroupName()),
},
Blocks: blocks,
Shared: g.clusterInfo.AzureNetworkShared,
}
}
@ -637,6 +644,59 @@ func (g *resourceGetter) deletePublicIPAddress(_ fi.Cloud, r *resources.Resource
return g.cloud.PublicIPAddress().Delete(context.TODO(), g.resourceGroupName(), r.Name)
}
func (g *resourceGetter) listNatGateways(ctx context.Context) ([]*resources.Resource, error) {
natGateways, err := g.cloud.NatGateway().List(ctx, g.resourceGroupName())
if err != nil {
return nil, err
}
var rs []*resources.Resource
for i := range natGateways {
p := &natGateways[i]
if !g.isOwnedByCluster(p.Tags) {
continue
}
r, err := g.toNatGatewayResource(p)
if err != nil {
return nil, err
}
rs = append(rs, r)
}
return rs, nil
}
func (g *resourceGetter) toNatGatewayResource(natGateway *network.NatGateway) (*resources.Resource, error) {
var blocks []string
blocks = append(blocks, toKey(typeResourceGroup, g.resourceGroupName()))
pips := set.New[string]()
if natGateway.PublicIPAddresses != nil {
for _, pip := range *natGateway.PublicIPAddresses {
pipID, err := azure.ParsePublicIPAddressID(*pip.ID)
if err != nil {
return nil, fmt.Errorf("parsing public IP address ID: %s", err)
}
pips.Insert(pipID.PublicIPAddressName)
}
}
for pip := range pips {
blocks = append(blocks, toKey(typePublicIPAddress, pip))
}
return &resources.Resource{
Obj: natGateway,
Type: typeNatGateway,
ID: *natGateway.ID,
Name: *natGateway.Name,
Deleter: g.deleteNatGateway,
Blocks: blocks,
}, nil
}
func (g *resourceGetter) deleteNatGateway(_ fi.Cloud, r *resources.Resource) error {
return g.cloud.NatGateway().Delete(context.TODO(), g.resourceGroupName(), r.Name)
}
// isOwnedByCluster returns true if the resource is owned by the cluster.
func (g *resourceGetter) isOwnedByCluster(tags map[string]*string) bool {
for k, v := range tags {

View File

@ -73,7 +73,8 @@ func TestListResourcesAzure(t *testing.T) {
subnets := cloud.SubnetsClient.Subnets
subnets[rgName] = network.Subnet{
Name: to.StringPtr(subnetName),
Name: to.StringPtr(subnetName),
SubnetPropertiesFormat: &network.SubnetPropertiesFormat{},
}
vnets[irrelevantName] = network.VirtualNetwork{
Name: to.StringPtr(irrelevantName),

View File

@ -58,6 +58,7 @@ type AzureCloud interface {
NetworkInterface() NetworkInterfacesClient
LoadBalancer() LoadBalancersClient
PublicIPAddress() PublicIPAddressesClient
NatGateway() NatGatewaysClient
}
type azureCloudImplementation struct {
@ -77,6 +78,7 @@ type azureCloudImplementation struct {
networkInterfacesClient NetworkInterfacesClient
loadBalancersClient LoadBalancersClient
publicIPAddressesClient PublicIPAddressesClient
natGatewaysClient NatGatewaysClient
}
var _ fi.Cloud = &azureCloudImplementation{}
@ -105,6 +107,7 @@ func NewAzureCloud(subscriptionID, location string, tags map[string]string) (Azu
networkInterfacesClient: newNetworkInterfacesClientImpl(subscriptionID, authorizer),
loadBalancersClient: newLoadBalancersClientImpl(subscriptionID, authorizer),
publicIPAddressesClient: newPublicIPAddressesClientImpl(subscriptionID, authorizer),
natGatewaysClient: newNatGatewaysClientImpl(subscriptionID, authorizer),
}, nil
}
@ -327,3 +330,7 @@ func (c *azureCloudImplementation) LoadBalancer() LoadBalancersClient {
func (c *azureCloudImplementation) PublicIPAddress() PublicIPAddressesClient {
return c.publicIPAddressesClient
}
func (c *azureCloudImplementation) NatGateway() NatGatewaysClient {
return c.natGatewaysClient
}

View File

@ -0,0 +1,83 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"fmt"
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-05-01/network"
"github.com/Azure/go-autorest/autorest"
)
// NatGatewaysClient is a client for managing Nat Gateways.
type NatGatewaysClient interface {
CreateOrUpdate(ctx context.Context, resourceGroupName, natGatewayName string, parameters network.NatGateway) (*network.NatGateway, error)
List(ctx context.Context, resourceGroupName string) ([]network.NatGateway, error)
Delete(ctx context.Context, resourceGroupName, natGatewayName string) error
}
type NatGatewaysClientImpl struct {
c *network.NatGatewaysClient
}
var _ NatGatewaysClient = &NatGatewaysClientImpl{}
func (c *NatGatewaysClientImpl) CreateOrUpdate(ctx context.Context, resourceGroupName, natGatewayName string, parameters network.NatGateway) (*network.NatGateway, error) {
future, err := c.c.CreateOrUpdate(ctx, resourceGroupName, natGatewayName, parameters)
if err != nil {
return nil, fmt.Errorf("creating/updating nat gateway: %w", err)
}
if err := future.WaitForCompletionRef(ctx, c.c.Client); err != nil {
return nil, fmt.Errorf("waiting for nat gateway create/update completion: %w", err)
}
asg, err := future.Result(*c.c)
if err != nil {
return nil, fmt.Errorf("obtaining result for nat gateway create/update: %w", err)
}
return &asg, err
}
func (c *NatGatewaysClientImpl) List(ctx context.Context, resourceGroupName string) ([]network.NatGateway, error) {
var l []network.NatGateway
for iter, err := c.c.ListComplete(ctx, resourceGroupName); iter.NotDone(); err = iter.NextWithContext(ctx) {
if err != nil {
return nil, fmt.Errorf("listing nat gateways: %w", err)
}
l = append(l, iter.Value())
}
return l, nil
}
func (c *NatGatewaysClientImpl) Delete(ctx context.Context, resourceGroupName, natGatewayName string) error {
future, err := c.c.Delete(ctx, resourceGroupName, natGatewayName)
if err != nil {
return fmt.Errorf("deleting nat gateway: %w", err)
}
if err := future.WaitForCompletionRef(ctx, c.c.Client); err != nil {
return fmt.Errorf("waiting for nat gateway deletion completion: %w", err)
}
return nil
}
func newNatGatewaysClientImpl(subscriptionID string, authorizer autorest.Authorizer) *NatGatewaysClientImpl {
c := network.NewNatGatewaysClient(subscriptionID)
c.Authorizer = authorizer
return &NatGatewaysClientImpl{
c: &c,
}
}

View File

@ -26,7 +26,7 @@ import (
// PublicIPAddressesClient is a client for public ip addresses.
type PublicIPAddressesClient interface {
CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) error
CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) (*network.PublicIPAddress, error)
List(ctx context.Context, resourceGroupName string) ([]network.PublicIPAddress, error)
Delete(ctx context.Context, resourceGroupName, publicIPAddressName string) error
}
@ -37,9 +37,19 @@ type publicIPAddressesClientImpl struct {
var _ PublicIPAddressesClient = &publicIPAddressesClientImpl{}
func (c *publicIPAddressesClientImpl) CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) error {
_, err := c.c.CreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters)
return err
func (c *publicIPAddressesClientImpl) CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) (*network.PublicIPAddress, error) {
future, err := c.c.CreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters)
if err != nil {
return nil, fmt.Errorf("creating/updating public ip address: %w", err)
}
if err := future.WaitForCompletionRef(ctx, c.c.Client); err != nil {
return nil, fmt.Errorf("waiting for public ip address create/update completion: %w", err)
}
pip, err := future.Result(*c.c)
if err != nil {
return nil, fmt.Errorf("obtaining result for public ip address create/update: %w", err)
}
return &pip, err
}
func (c *publicIPAddressesClientImpl) List(ctx context.Context, resourceGroupName string) ([]network.PublicIPAddress, error) {

View File

@ -26,7 +26,7 @@ import (
// SubnetsClient is a client for managing Subnets.
type SubnetsClient interface {
CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) error
CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) (*network.Subnet, error)
List(ctx context.Context, resourceGroupName, virtualNetworkName string) ([]network.Subnet, error)
Delete(ctx context.Context, resourceGroupName, vnetName, subnetName string) error
}
@ -37,9 +37,19 @@ type subnetsClientImpl struct {
var _ SubnetsClient = &subnetsClientImpl{}
func (c *subnetsClientImpl) CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) error {
_, err := c.c.CreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, subnetName, parameters)
return err
func (c *subnetsClientImpl) CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) (*network.Subnet, error) {
future, err := c.c.CreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, subnetName, parameters)
if err != nil {
return nil, fmt.Errorf("creating/updating subnet: %w", err)
}
if err := future.WaitForCompletionRef(ctx, c.c.Client); err != nil {
return nil, fmt.Errorf("waiting for subnet create/update completion: %w", err)
}
sn, err := future.Result(*c.c)
if err != nil {
return nil, fmt.Errorf("obtaining result for subnet create/update: %w", err)
}
return &sn, err
}
func (c *subnetsClientImpl) List(ctx context.Context, resourceGroupName, virtualNetworkName string) ([]network.Subnet, error) {

View File

@ -0,0 +1,155 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azuretasks
import (
"context"
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-05-01/network"
"github.com/Azure/go-autorest/autorest/to"
"k8s.io/klog/v2"
"k8s.io/kops/upup/pkg/fi"
"k8s.io/kops/upup/pkg/fi/cloudup/azure"
)
// NatGateway is an Azure Nat Gateway
// +kops:fitask
type NatGateway struct {
Name *string
Lifecycle fi.Lifecycle
ID *string
PublicIPAddresses []*PublicIPAddress
ResourceGroup *ResourceGroup
Tags map[string]*string
}
var (
_ fi.CloudupTask = &NatGateway{}
_ fi.CompareWithID = &NatGateway{}
_ fi.CloudupTaskNormalize = &NatGateway{}
)
// CompareWithID returns the Name of the Nat Gateway
func (ngw *NatGateway) CompareWithID() *string {
return ngw.ID
}
// Find discovers the Nat Gateway in the cloud provider
func (ngw *NatGateway) Find(c *fi.CloudupContext) (*NatGateway, error) {
cloud := c.T.Cloud.(azure.AzureCloud)
l, err := cloud.NatGateway().List(context.TODO(), *ngw.ResourceGroup.Name)
if err != nil {
return nil, err
}
var found *network.NatGateway
for _, v := range l {
if *v.Name == *ngw.Name {
found = &v
break
}
}
if found == nil {
return nil, nil
}
ngw.ID = found.ID
var pips []*PublicIPAddress
if found.PublicIPAddresses != nil {
for _, pip := range *found.PublicIPAddresses {
pips = append(pips, &PublicIPAddress{ID: pip.ID})
}
}
return &NatGateway{
Name: ngw.Name,
Lifecycle: ngw.Lifecycle,
ResourceGroup: &ResourceGroup{Name: ngw.ResourceGroup.Name},
ID: found.ID,
PublicIPAddresses: pips,
Tags: found.Tags,
}, nil
}
func (ngw *NatGateway) Normalize(c *fi.CloudupContext) error {
c.T.Cloud.(azure.AzureCloud).AddClusterTags(ngw.Tags)
return nil
}
// Run implements fi.Task.Run.
func (ngw *NatGateway) Run(c *fi.CloudupContext) error {
return fi.CloudupDefaultDeltaRunMethod(ngw, c)
}
// CheckChanges returns an error if a change is not allowed.
func (*NatGateway) CheckChanges(a, e, changes *NatGateway) error {
if a == nil {
// Check if required fields are set when a new resource is created.
if e.Name == nil {
return fi.RequiredField("Name")
}
return nil
}
// Check if unchangeable fields won't be changed.
if changes.Name != nil {
return fi.CannotChangeField("Name")
}
return nil
}
// RenderAzure creates or updates a Nat Gateway.
func (*NatGateway) RenderAzure(t *azure.AzureAPITarget, a, e, changes *NatGateway) error {
if a == nil {
klog.Infof("Creating a new Nat Gateway with name: %s", fi.ValueOf(e.Name))
} else {
klog.Infof("Updating a Nat Gateway with name: %s", fi.ValueOf(e.Name))
}
p := network.NatGateway{
Location: to.StringPtr(t.Cloud.Region()),
Name: to.StringPtr(*e.Name),
NatGatewayPropertiesFormat: &network.NatGatewayPropertiesFormat{},
Sku: &network.NatGatewaySku{
Name: network.NatGatewaySkuNameStandard,
},
Tags: e.Tags,
}
if len(e.PublicIPAddresses) > 0 {
var pips []network.SubResource
for _, pip := range e.PublicIPAddresses {
pips = append(pips, network.SubResource{ID: pip.ID})
}
p.PublicIPAddresses = &pips
}
ngw, err := t.Cloud.NatGateway().CreateOrUpdate(
context.TODO(),
*e.ResourceGroup.Name,
*e.Name,
p)
if err != nil {
return err
}
e.ID = ngw.ID
return nil
}

View File

@ -0,0 +1,52 @@
//go:build !ignore_autogenerated
// +build !ignore_autogenerated
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Code generated by fitask. DO NOT EDIT.
package azuretasks
import (
"k8s.io/kops/upup/pkg/fi"
)
// NatGateway
var _ fi.HasLifecycle = &NatGateway{}
// GetLifecycle returns the Lifecycle of the object, implementing fi.HasLifecycle
func (o *NatGateway) GetLifecycle() fi.Lifecycle {
return o.Lifecycle
}
// SetLifecycle sets the Lifecycle of the object, implementing fi.SetLifecycle
func (o *NatGateway) SetLifecycle(lifecycle fi.Lifecycle) {
o.Lifecycle = lifecycle
}
var _ fi.HasName = &NatGateway{}
// GetName returns the Name of the object, implementing fi.HasName
func (o *NatGateway) GetName() *string {
return o.Name
}
// String is the stringer function for the task, producing readable output using fi.TaskAsString
func (o *NatGateway) String() string {
return fi.CloudupTaskAsString(o)
}

View File

@ -29,8 +29,10 @@ import (
// PublicIPAddress is an Azure Cloud Public IP Address
// +kops:fitask
type PublicIPAddress struct {
Name *string
Lifecycle fi.Lifecycle
Name *string
Lifecycle fi.Lifecycle
ID *string
ResourceGroup *ResourceGroup
Tags map[string]*string
@ -44,7 +46,7 @@ var (
// CompareWithID returns the Name of the Public IP Address
func (p *PublicIPAddress) CompareWithID() *string {
return p.Name
return p.ID
}
// Find discovers the Public IP Address in the cloud provider
@ -65,13 +67,15 @@ func (p *PublicIPAddress) Find(c *fi.CloudupContext) (*PublicIPAddress, error) {
return nil, nil
}
p.ID = found.ID
return &PublicIPAddress{
Name: p.Name,
Lifecycle: p.Lifecycle,
ResourceGroup: &ResourceGroup{
Name: p.ResourceGroup.Name,
},
ID: found.ID,
Tags: found.Tags,
}, nil
}
@ -124,9 +128,17 @@ func (*PublicIPAddress) RenderAzure(t *azure.AzureAPITarget, a, e, changes *Publ
Tags: e.Tags,
}
return t.Cloud.PublicIPAddress().CreateOrUpdate(
pip, err := t.Cloud.PublicIPAddress().CreateOrUpdate(
context.TODO(),
*e.ResourceGroup.Name,
*e.Name,
p)
if err != nil {
return err
}
e.ID = pip.ID
return nil
}

View File

@ -94,7 +94,7 @@ func TestPublicIPAddressFind(t *testing.T) {
PublicIPAllocationMethod: network.Dynamic,
},
}
if err := cloud.PublicIPAddress().CreateOrUpdate(context.Background(), *rg.Name, *publicIPAddress.Name, publicIPAddressParameters); err != nil {
if _, err := cloud.PublicIPAddress().CreateOrUpdate(context.Background(), *rg.Name, *publicIPAddress.Name, publicIPAddressParameters); err != nil {
t.Fatalf("failed to create: %s", err)
}
// Find again.

View File

@ -31,8 +31,10 @@ type Subnet struct {
Name *string
Lifecycle fi.Lifecycle
ID *string
ResourceGroup *ResourceGroup
VirtualNetwork *VirtualNetwork
NatGateway *NatGateway
NetworkSecurityGroup *NetworkSecurityGroup
CIDR *string
@ -46,7 +48,7 @@ var (
// CompareWithID returns the Name of the VM Scale Set.
func (s *Subnet) CompareWithID() *string {
return s.Name
return s.ID
}
// Find discovers the Subnet in the cloud provider.
@ -67,6 +69,8 @@ func (s *Subnet) Find(c *fi.CloudupContext) (*Subnet, error) {
return nil, nil
}
s.ID = found.ID
fs := &Subnet{
Name: s.Name,
Lifecycle: s.Lifecycle,
@ -77,8 +81,14 @@ func (s *Subnet) Find(c *fi.CloudupContext) (*Subnet, error) {
VirtualNetwork: &VirtualNetwork{
Name: s.VirtualNetwork.Name,
},
ID: found.ID,
CIDR: found.AddressPrefix,
}
if found.NatGateway != nil {
fs.NatGateway = &NatGateway{
ID: found.NatGateway.ID,
}
}
if found.NetworkSecurityGroup != nil {
fs.NetworkSecurityGroup = &NetworkSecurityGroup{
ID: found.NetworkSecurityGroup.ID,
@ -123,16 +133,28 @@ func (*Subnet) RenderAzure(t *azure.AzureAPITarget, a, e, changes *Subnet) error
AddressPrefix: e.CIDR,
},
}
if e.NatGateway != nil {
subnet.NatGateway = &network.SubResource{
ID: e.NatGateway.ID,
}
}
if e.NetworkSecurityGroup != nil {
subnet.NetworkSecurityGroup = &network.SecurityGroup{
ID: e.NetworkSecurityGroup.ID,
}
}
return t.Cloud.Subnet().CreateOrUpdate(
sn, err := t.Cloud.Subnet().CreateOrUpdate(
context.TODO(),
*e.ResourceGroup.Name,
*e.VirtualNetwork.Name,
*e.Name,
subnet)
if err != nil {
return err
}
e.ID = sn.ID
return nil
}

View File

@ -90,7 +90,7 @@ func TestSubnetFind(t *testing.T) {
AddressPrefix: to.StringPtr(cidr),
},
}
if err := cloud.Subnet().CreateOrUpdate(context.Background(), *rg.Name, *vnet.Name, *subnet.Name, subnetParameters); err != nil {
if _, err := cloud.Subnet().CreateOrUpdate(context.Background(), *rg.Name, *vnet.Name, *subnet.Name, subnetParameters); err != nil {
t.Fatalf("failed to create: %s", err)
}
// Find again.

View File

@ -60,6 +60,7 @@ type MockAzureCloud struct {
NetworkInterfacesClient *MockNetworkInterfacesClient
LoadBalancersClient *MockLoadBalancersClient
PublicIPAddressesClient *MockPublicIPAddressesClient
NatGatewaysClient *MockNatGatewaysClient
}
var _ azure.AzureCloud = &MockAzureCloud{}
@ -107,6 +108,9 @@ func NewMockAzureCloud(location string) *MockAzureCloud {
PublicIPAddressesClient: &MockPublicIPAddressesClient{
PubIPs: map[string]network.PublicIPAddress{},
},
NatGatewaysClient: &MockNatGatewaysClient{
NGWs: map[string]network.NatGateway{},
},
}
}
@ -248,6 +252,11 @@ func (c *MockAzureCloud) PublicIPAddress() azure.PublicIPAddressesClient {
return c.PublicIPAddressesClient
}
// NatGateway returns the nat gateway client.
func (c *MockAzureCloud) NatGateway() azure.NatGatewaysClient {
return c.NatGatewaysClient
}
// MockResourceGroupsClient is a mock implementation of resource group client.
type MockResourceGroupsClient struct {
RGs map[string]resources.Group
@ -327,14 +336,14 @@ type MockSubnetsClient struct {
var _ azure.SubnetsClient = &MockSubnetsClient{}
// CreateOrUpdate creates or updates a subnet.
func (c *MockSubnetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) error {
func (c *MockSubnetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, parameters network.Subnet) (*network.Subnet, error) {
// Ignore resourceGroupName and virtualNetworkName for simplicity.
if _, ok := c.Subnets[subnetName]; ok {
return fmt.Errorf("update not supported")
return nil, fmt.Errorf("update not supported")
}
parameters.Name = &subnetName
c.Subnets[subnetName] = parameters
return nil
return &parameters, nil
}
// List returns a slice of subnets.
@ -613,13 +622,13 @@ type MockPublicIPAddressesClient struct {
var _ azure.PublicIPAddressesClient = &MockPublicIPAddressesClient{}
// CreateOrUpdate creates a new public ip address.
func (c *MockPublicIPAddressesClient) CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) error {
func (c *MockPublicIPAddressesClient) CreateOrUpdate(ctx context.Context, resourceGroupName, publicIPAddressName string, parameters network.PublicIPAddress) (*network.PublicIPAddress, error) {
if _, ok := c.PubIPs[publicIPAddressName]; ok {
return nil
return nil, fmt.Errorf("update not supported")
}
parameters.Name = &publicIPAddressName
c.PubIPs[publicIPAddressName] = parameters
return nil
return &parameters, nil
}
// List returns a slice of public ip address.
@ -732,3 +741,49 @@ func (c *MockApplicationSecurityGroupsClient) Delete(ctx context.Context, resour
delete(c.ASGs, asgName)
return nil
}
// MockNatGatewaysClient is a mock implementation of Nat Gateway client.
type MockNatGatewaysClient struct {
NGWs map[string]network.NatGateway
}
var _ azure.NatGatewaysClient = &MockNatGatewaysClient{}
// CreateOrUpdate creates or updates a Nat Gateway.
func (c *MockNatGatewaysClient) CreateOrUpdate(ctx context.Context, resourceGroupName, ngwName string, parameters network.NatGateway) (*network.NatGateway, error) {
// Ignore resourceGroupName for simplicity.
if _, ok := c.NGWs[ngwName]; ok {
return nil, fmt.Errorf("update not supported")
}
parameters.Name = &ngwName
c.NGWs[ngwName] = parameters
return &parameters, nil
}
// List returns a slice of Nat Gateways.
func (c *MockNatGatewaysClient) List(ctx context.Context, resourceGroupName string) ([]network.NatGateway, error) {
var l []network.NatGateway
for _, ngw := range c.NGWs {
l = append(l, ngw)
}
return l, nil
}
// Get Returns a specified Nat Gateway.
func (c *MockNatGatewaysClient) Get(ctx context.Context, resourceGroupName string, ngwName string) (*network.NatGateway, error) {
ngw, ok := c.NGWs[ngwName]
if !ok {
return nil, nil
}
return &ngw, nil
}
// Delete deletes a specified Nat Gateway.
func (c *MockNatGatewaysClient) Delete(ctx context.Context, resourceGroupName, ngwName string) error {
// Ignore resourceGroupName for simplicity.
if _, ok := c.NGWs[ngwName]; !ok {
return fmt.Errorf("%s does not exist", ngwName)
}
delete(c.NGWs, ngwName)
return nil
}

View File

@ -153,7 +153,7 @@ func (s *VMScaleSet) Find(c *fi.CloudupContext) (*VMScaleSet, error) {
Name: to.StringPtr(subnetID.VirtualNetworkName),
},
Subnet: &Subnet{
Name: to.StringPtr(subnetID.SubnetName),
ID: ipConfig.Subnet.ID,
},
StorageProfile: &VMScaleSetStorageProfile{
VirtualMachineScaleSetStorageProfile: profile.StorageProfile,

View File

@ -280,7 +280,7 @@ func TestVMScaleSetFind(t *testing.T) {
if a, e := *actual.VirtualNetwork.Name, subnetID.VirtualNetworkName; a != e {
t.Errorf("unexpected Resource Group name: expected %s, but got %s", e, a)
}
if a, e := *actual.Subnet.Name, subnetID.SubnetName; a != e {
if a, e := *actual.Subnet.ID, subnetID.String(); a != e {
t.Errorf("unexpected Resource Group name: expected %s, but got %s", e, a)
}
if a, e := *actual.LoadBalancer.Name, loadBalancerID.LoadBalancerName; a != e {