diff --git a/cmd/cloud-controller-manager/main.go b/cmd/cloud-controller-manager/main.go index 5ce3bdd51e..c3f302be74 100644 --- a/cmd/cloud-controller-manager/main.go +++ b/cmd/cloud-controller-manager/main.go @@ -27,8 +27,6 @@ import ( "github.com/spf13/pflag" "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" - "k8s.io/cloud-provider-gcp/providers/gce" - _ "k8s.io/cloud-provider-gcp/providers/gce" "k8s.io/cloud-provider/app" "k8s.io/cloud-provider/app/config" "k8s.io/cloud-provider/names" @@ -39,6 +37,9 @@ import ( _ "k8s.io/component-base/metrics/prometheus/version" // for version metric registration "k8s.io/klog/v2" kcmnames "k8s.io/kubernetes/cmd/kube-controller-manager/names" + + "k8s.io/cloud-provider-gcp/providers/gce" + _ "k8s.io/cloud-provider-gcp/providers/gce" ) const ( @@ -47,30 +48,36 @@ const ( gkeServiceAlias = "gke-service" ) -// enableMultiProject is bound to a command-line flag. When true, it enables the -// projectFromNodeProviderID option of the GCE cloud provider, instructing it to -// use the project specified in the Node's providerID for GCE API calls. -// -// This flag should only be enabled when the Node's providerID can be fully -// trusted. -// -// Flag binding occurs in main() -var enableMultiProject bool - -// enableDiscretePortForwarding is bound to a command-line flag. It enables -// the same option of the GCE cloud provider to forward individual ports -// instead of port ranges in Forwarding Rules for external load balancers. -var enableDiscretePortForwarding bool - -// enableRBSDefaultForGCEL4NetLB is bound to a command-line flag. It enables -// the option to default L4 NetLB to RBS, only controlling NetLB services with -// LoadBalancerClass -var enableRBSDefaultForL4NetLB bool - -// enableL4LBAnnotations is bound to a command-line flag. It enables -// the controller to write annotations related to the provisioned resources -// for L4 Load Balancers services -var enableL4LBAnnotations bool +var ( + // enableMultiProject is bound to a command-line flag. When true, it enables the + // projectFromNodeProviderID option of the GCE cloud provider, instructing it to + // use the project specified in the Node's providerID for GCE API calls. + // + // This flag should only be enabled when the Node's providerID can be fully + // trusted. + // + // Flag binding occurs in main() + enableMultiProject bool + + // enableRBSDefaultForGCEL4NetLB is bound to a command-line flag. It enables + // the option to default L4 NetLB to RBS, only controlling NetLB services with + // LoadBalancerClass + enableRBSDefaultForL4NetLB bool + + // enableL4LBAnnotations is bound to a command-line flag. It enables + // the controller to write annotations related to the provisioned resources + // for L4 Load Balancers services + enableL4LBAnnotations bool + + // enableL4DenyFirewall creates and manages an additional deny firewall rule + // at priority 1000 and moves the node and healthcheck firewall rule to priority 999. + enableL4DenyFirewall bool + + // enableL4DenyFirewallRollbackCleanup enable cleanup codepath of the deny firewalls for rollback. + // The reason for it not being enabled by default is the additional GCE API calls that are made + // for checking if the deny firewalls exist/deletion which will eat up the quota unnecessarily. + enableL4DenyFirewallRollbackCleanup bool +) func main() { rand.Seed(time.Now().UnixNano()) @@ -88,9 +95,10 @@ func main() { cloudProviderFS := fss.FlagSet("GCE Cloud Provider") cloudProviderFS.BoolVar(&enableMultiProject, "enable-multi-project", false, "Enables project selection from Node providerID for GCE API calls. CAUTION: Only enable if Node providerID is configured by a trusted source.") - cloudProviderFS.BoolVar(&enableDiscretePortForwarding, "enable-discrete-port-forwarding", false, "Enables forwarding of individual ports instead of port ranges for GCE external load balancers.") cloudProviderFS.BoolVar(&enableRBSDefaultForL4NetLB, "enable-rbs-default-l4-netlb", false, "Enables RBS defaulting for GCE L4 NetLB") cloudProviderFS.BoolVar(&enableL4LBAnnotations, "enable-l4-lb-annotations", false, "Enables Annotations for GCE L4 LB Services") + cloudProviderFS.BoolVar(&enableL4DenyFirewall, "enable-l4-deny-firewall", false, "Enable creation and updates of Deny VPC Firewall Rules for L4 external load balancers. Requires --enable-pinhole and --enable-l4-deny-firewall-rollback-cleanup to be true.") + cloudProviderFS.BoolVar(&enableL4DenyFirewallRollbackCleanup, "enable-l4-deny-firewall-rollback-cleanup", false, "Enable cleanup codepath of the deny firewalls for rollback. The reason for it not being enabled by default is the additional GCE API calls that are made for checking if the deny firewalls exist/deletion which will eat up the quota unnecessarily.") // add new controllers and initializers nodeIpamController := nodeIPAMController{} @@ -158,16 +166,6 @@ func cloudInitializer(config *config.CompletedConfig) cloudprovider.Interface { gceCloud.SetProjectFromNodeProviderID(true) } - if enableDiscretePortForwarding { - gceCloud, ok := (cloud).(*gce.Cloud) - if !ok { - // Fail-fast: If enableDiscretePortForwarding is set, the cloud - // provider MUST be GCE. - klog.Fatalf("enable-discrete-port-forwarding requires GCE cloud provider, but got %T", cloud) - } - gceCloud.SetEnableDiscretePortForwarding(true) - } - if enableRBSDefaultForL4NetLB { gceCloud, ok := (cloud).(*gce.Cloud) if !ok { @@ -188,5 +186,16 @@ func cloudInitializer(config *config.CompletedConfig) cloudprovider.Interface { gceCloud.SetEnableL4LBAnnotations(true) } + if enableL4DenyFirewall || enableL4DenyFirewallRollbackCleanup { + gceCloud, ok := (cloud).(*gce.Cloud) + if !ok { + klog.Fatalf("enable-l4-deny-firewall and enable-l4-deny-firewall-rollback-cleanup require GCE cloud provider, but got %T", cloud) + } + if enableL4DenyFirewall && !enableL4DenyFirewallRollbackCleanup { + klog.Fatal("enable-l4-deny-firewall requires enable-l4-deny-firewall-rollback-cleanup to be true") + } + gceCloud.SetEnableL4DenyFirewallRule(enableL4DenyFirewall, enableL4DenyFirewallRollbackCleanup) + } + return cloud } diff --git a/providers/gce/BUILD b/providers/gce/BUILD index d5ed33001e..aa8abb2d84 100644 --- a/providers/gce/BUILD +++ b/providers/gce/BUILD @@ -101,9 +101,11 @@ go_test( "gce_annotations_test.go", "gce_disks_test.go", "gce_instances_test.go", + "gce_loadbalancer_external_deny_test.go", "gce_loadbalancer_external_test.go", "gce_loadbalancer_internal_test.go", "gce_loadbalancer_metrics_test.go", + "gce_loadbalancer_naming_test.go", "gce_loadbalancer_test.go", "gce_loadbalancer_utils_test.go", "gce_test.go", @@ -116,6 +118,7 @@ go_test( "//vendor/github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta", "//vendor/github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/mock", "//vendor/github.com/google/go-cmp/cmp", + "//vendor/github.com/google/go-cmp/cmp/cmpopts", "//vendor/github.com/stretchr/testify/assert", "//vendor/github.com/stretchr/testify/require", "//vendor/golang.org/x/oauth2/google", @@ -132,6 +135,7 @@ go_test( "//vendor/k8s.io/client-go/tools/record", "//vendor/k8s.io/cloud-provider", "//vendor/k8s.io/cloud-provider/service/helpers", + "//vendor/k8s.io/component-base/metrics/testutil", "//vendor/k8s.io/utils/net", ], ) diff --git a/providers/gce/gce.go b/providers/gce/gce.go index a24771dfd6..0b7ae1d791 100644 --- a/providers/gce/gce.go +++ b/providers/gce/gce.go @@ -85,13 +85,15 @@ const ( gceComputeAPIEndpointBeta = "https://www.googleapis.com/compute/beta/" ) -var _ cloudprovider.Interface = (*Cloud)(nil) -var _ cloudprovider.Instances = (*Cloud)(nil) -var _ cloudprovider.LoadBalancer = (*Cloud)(nil) -var _ cloudprovider.Routes = (*Cloud)(nil) -var _ cloudprovider.Zones = (*Cloud)(nil) -var _ cloudprovider.PVLabeler = (*Cloud)(nil) -var _ cloudprovider.Clusters = (*Cloud)(nil) +var ( + _ cloudprovider.Interface = (*Cloud)(nil) + _ cloudprovider.Instances = (*Cloud)(nil) + _ cloudprovider.LoadBalancer = (*Cloud)(nil) + _ cloudprovider.Routes = (*Cloud)(nil) + _ cloudprovider.Zones = (*Cloud)(nil) + _ cloudprovider.PVLabeler = (*Cloud)(nil) + _ cloudprovider.Clusters = (*Cloud)(nil) +) type StackType string @@ -203,15 +205,18 @@ type Cloud struct { // Enable this ony when the Node's .spec.providerID can be fully trusted. projectFromNodeProviderID bool - // enableDiscretePortForwarding enables forwarding of individual ports - // instead of port ranges in Forwarding Rules for external load balancers. - enableDiscretePortForwarding bool - // enableRBSDefaultForL4NetLB disable Service controller from picking up services by default enableRBSDefaultForL4NetLB bool // enableL4LBAnnotations enable annotations related to provisioned resources in GCE enableL4LBAnnotations bool + + // enableL4DenyFirewallRule creates an additional deny firewall rule at priority 1000 + // and moves the allow rule to priority 999 to improve security posture. + enableL4DenyFirewallRule bool + + // enableL4DenyFirewallRollbackCleanup + enableL4DenyFirewallRollbackCleanup bool } // ConfigGlobal is the in memory representation of the gce.conf config data @@ -864,11 +869,6 @@ func (g *Cloud) SetProjectFromNodeProviderID(enabled bool) { g.projectFromNodeProviderID = enabled } -// SetEnableDiscretePortForwarding configures enableDiscretePortForwarding option. -func (g *Cloud) SetEnableDiscretePortForwarding(enabled bool) { - g.enableDiscretePortForwarding = enabled -} - func (g *Cloud) SetEnableRBSDefaultForL4NetLB(enabled bool) { g.enableRBSDefaultForL4NetLB = enabled } @@ -877,6 +877,11 @@ func (g *Cloud) SetEnableL4LBAnnotations(enabled bool) { g.enableL4LBAnnotations = enabled } +func (g *Cloud) SetEnableL4DenyFirewallRule(firewallEnabled, rollbackEnabled bool) { + g.enableL4DenyFirewallRule = firewallEnabled + g.enableL4DenyFirewallRollbackCleanup = rollbackEnabled +} + // getProjectsBasePath returns the compute API endpoint with the `projects/` element. // The suffix must be added when generating compute resource urls. func getProjectsBasePath(basePath string) string { @@ -970,7 +975,7 @@ func getZonesForRegion(svc *compute.Service, projectID, region string) ([]string // listCall = listCall.Filter("region eq " + region) var zones []string - var accumulator = func(response *compute.ZoneList) error { + accumulator := func(response *compute.ZoneList) error { for _, zone := range response.Items { regionName := lastComponent(zone.Region) if regionName == region { diff --git a/providers/gce/gce_loadbalancer_external.go b/providers/gce/gce_loadbalancer_external.go index 0164c58183..e9e785ee7b 100644 --- a/providers/gce/gce_loadbalancer_external.go +++ b/providers/gce/gce_loadbalancer_external.go @@ -21,6 +21,7 @@ package gce import ( "context" + "errors" "fmt" "net/http" "reflect" @@ -41,10 +42,11 @@ import ( ) const ( - errStrLbNoHosts = "cannot EnsureLoadBalancer() with no hosts" - maxNodeNamesToLog = 50 - // maxForwardedPorts is the maximum number of ports that can be specified in an Forwarding Rule - maxForwardedPorts = 5 + errStrLbNoHosts = "cannot EnsureLoadBalancer() with no hosts" + maxNodeNamesToLog = 50 + firewallPriorityDefault = 1000 + firewallPriorityDeny = firewallPriorityDefault + firewallPriorityAllow = firewallPriorityDefault - 1 ) // ensureExternalLoadBalancer is the external implementation of LoadBalancer.EnsureLoadBalancer. @@ -65,6 +67,18 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string, return nil, cloudprovider.ImplementedElsewhere } + nm := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name} + metricsState := L4NetLBServiceState{ + Status: StatusError, + DenyFirewall: DenyFirewallStatusNone, + } + if !g.enableL4DenyFirewallRule && g.enableL4DenyFirewallRollbackCleanup { + metricsState.DenyFirewall = DenyFirewallStatusDisabled + } + defer func() { + g.metricsCollector.SetL4NetLBService(nm.String(), metricsState) + }() + if hasLoadBalancerClass(apiService, LegacyRegionalExternalLoadBalancerClass) { if apiService.Annotations[ServiceAnnotationLoadBalancerType] == string(LBTypeInternal) { g.eventRecorder.Event(apiService, v1.EventTypeWarning, "ConflictingConfiguration", fmt.Sprintf("loadBalancerClass conflicts with %s: %s annotation. External LoadBalancer Service provisioned.", ServiceAnnotationLoadBalancerType, string(LBTypeInternal))) @@ -195,35 +209,16 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string, // Deal with the firewall next. The reason we do this here rather than last // is because the forwarding rule is used as the indicator that the load // balancer is fully created - it's what getLoadBalancer checks for. - // Check if user specified the allow source range - sourceRanges, err := servicehelpers.GetLoadBalancerSourceRanges(apiService) - if err != nil { - return nil, err - } - - firewallExists, firewallNeedsUpdate, err := g.firewallNeedsUpdate(loadBalancerName, serviceName.String(), ipAddressToUse, ports, sourceRanges) - if err != nil { - return nil, err - } - - if firewallNeedsUpdate { - desc := makeFirewallDescription(serviceName.String(), ipAddressToUse) - // Unlike forwarding rules and target pools, firewalls can be updated - // without needing to be deleted and recreated. - if firewallExists { - klog.Infof("ensureExternalLoadBalancer(%s): Updating firewall.", lbRefStr) - if err := g.updateFirewall(apiService, MakeFirewallName(loadBalancerName), desc, ipAddressToUse, sourceRanges, ports, hosts); err != nil { - return nil, err - } - klog.Infof("ensureExternalLoadBalancer(%s): Updated firewall.", lbRefStr) - } else { - klog.Infof("ensureExternalLoadBalancer(%s): Creating firewall.", lbRefStr) - if err := g.createFirewall(apiService, MakeFirewallName(loadBalancerName), desc, ipAddressToUse, sourceRanges, ports, hosts); err != nil { - return nil, err - } - klog.Infof("ensureExternalLoadBalancer(%s): Created firewall.", lbRefStr) + if !g.enableL4DenyFirewallRule && g.enableL4DenyFirewallRollbackCleanup { + // clean up the resource, if the flag got disabled + fwDenyName := MakeFirewallDenyName(loadBalancerName) + if err := g.ensureFirewallDeleted(fwDenyName); err != nil { + return nil, fmt.Errorf("failed to clean up deny firewall for load balancer (%s): %v", lbRefStr, err) } } + if err := g.ensureAllowNodeFirewall(apiService, loadBalancerName, ipAddressToUse, lbRefStr, hosts); err != nil { + return nil, fmt.Errorf("failed to ensure node firewall for load balancer (%s): %v", lbRefStr, err) + } tpExists, tpNeedsRecreation, err := g.targetPoolNeedsRecreation(loadBalancerName, g.region, apiService.Spec.SessionAffinity) if err != nil { @@ -288,7 +283,7 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string, if tpNeedsRecreation || fwdRuleNeedsUpdate { klog.Infof("ensureExternalLoadBalancer(%s): Creating forwarding rule, IP %s (tier: %s).", lbRefStr, ipAddressToUse, netTier) - if err := createForwardingRule(g, loadBalancerName, serviceName.String(), g.region, ipAddressToUse, g.targetPoolURL(loadBalancerName), ports, netTier, g.enableDiscretePortForwarding); err != nil { + if err := createForwardingRule(g, loadBalancerName, serviceName.String(), g.region, ipAddressToUse, g.targetPoolURL(loadBalancerName), ports, netTier); err != nil { return nil, fmt.Errorf("failed to create forwarding rule for load balancer (%s): %v", lbRefStr, err) } // End critical section. It is safe to release the static IP (which @@ -299,9 +294,21 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string, klog.Infof("ensureExternalLoadBalancer(%s): Created forwarding rule, IP %s.", lbRefStr, ipAddressToUse) } + // We can create deny firewall rule only after making sure that the allow firewalls for nodes and healthchecks are created/updated to 999 priority + if g.enableL4DenyFirewallRule { + if err := g.ensureDenyNodeFirewall(apiService, loadBalancerName, ipAddressToUse, lbRefStr, hosts); err != nil { + return nil, fmt.Errorf("failed to ensure deny firewall rule for load balancer(%s): %v", lbRefStr, err) + } + } + status := &v1.LoadBalancerStatus{} status.Ingress = []v1.LoadBalancerIngress{{IP: ipAddressToUse}} + metricsState.Status = StatusSuccess + if g.enableL4DenyFirewallRule { + metricsState.DenyFirewall = DenyFirewallStatusIPv4 + } + syncResult.status = status return syncResult, nil } @@ -374,6 +381,21 @@ func (g *Cloud) ensureExternalLoadBalancerDeleted(clusterName, clusterID string, } return err }, + func() error { + if !g.enableL4DenyFirewallRollbackCleanup { + klog.Infof("ensureExternalLoadBalancerDeleted(%s): Skipping deleting deny firewall rule, as it hasn't been enabled.", lbRefStr) + return nil + } + klog.Infof("ensureExternalLoadBalancerDeleted(%s): Deleting deny firewall rule.", lbRefStr) + fwName := MakeFirewallDenyName(loadBalancerName) + err := ignoreNotFound(g.DeleteFirewall(fwName)) + if isForbidden(err) && g.OnXPN() { + klog.V(4).Infof("ensureExternalLoadBalancerDeleted(%s): Do not have permission to delete deny firewall rule %v (on XPN). Raising event.", lbRefStr, fwName) + g.raiseFirewallChangeNeededEvent(service, FirewallToGCloudDeleteCmd(fwName, g.NetworkProjectID())) + return nil + } + return err + }, // Even though we don't hold on to static IPs for load balancers, it's // possible that EnsureLoadBalancer left one around in a failed // creation/update attempt, so make sure we clean it up here just in case. @@ -404,6 +426,7 @@ func (g *Cloud) ensureExternalLoadBalancerDeleted(clusterName, clusterID string, klog.Errorf("Failed to remove finalizer '%s' from service %s - %v", NetLBFinalizerV1, service.Name, err) return err } + g.metricsCollector.DeleteL4NetLBService(serviceName.String()) return nil } @@ -567,7 +590,18 @@ func (g *Cloud) ensureTargetPoolAndHealthCheck(tpExists, tpNeedsRecreation bool, if hc, err := g.ensureHTTPHealthCheck(hcToCreate.Name, hcToCreate.RequestPath, int32(hcToCreate.Port)); err != nil || hc == nil { return fmt.Errorf("failed to ensure health check for %v port %d path %v: %v", loadBalancerName, hcToCreate.Port, hcToCreate.RequestPath, err) } + // Check whether it is nodes health check, which has different name from the load-balancer. + isNodesHealthCheck := hcToCreate.Name != serviceName.Name + if isNodesHealthCheck { + // Lock to prevent necessary nodes health check / firewall gets deleted. + g.sharedResourceLock.Lock() + defer g.sharedResourceLock.Unlock() + } + if err := g.ensureHTTPHealthCheckFirewall(svc, serviceName.String(), ipAddressToUse, g.region, clusterID, hosts, hcToCreate.Name, int32(hcToCreate.Port), isNodesHealthCheck); err != nil { + return fmt.Errorf("failed to ensure health check firewall %v for %v: %w", hcToCreate.Name, loadBalancerName, err) + } } + } else { // Panic worthy. klog.Errorf("ensureTargetPoolAndHealthCheck(%s): target pool not exists and doesn't need to be created.", lbRefStr) @@ -809,14 +843,9 @@ func (g *Cloud) forwardingRuleNeedsUpdate(name, region string, loadBalancerIP st // We never want to end up recreating resources because g api flaked. return true, false, "", err } - newPorts := []string{} - if frPorts := getPorts(ports); len(frPorts) <= maxForwardedPorts && g.enableDiscretePortForwarding { - newPorts = frPorts - newPortRange = "" - } - frEqualPorts := equalPorts(fwd.Ports, newPorts, fwd.PortRange, newPortRange, g.enableDiscretePortForwarding) - if !frEqualPorts { - klog.Infof("Forwarding rule port range / ports are not equal, old (port range: %v, ports: %v), new (port range: %v, ports: %v)", fwd.PortRange, fwd.Ports, newPortRange, newPorts) + + if newPortRange != fwd.PortRange { + klog.Infof("LoadBalancer port range for forwarding rule %v was expected to be %v, but was actually %v", fwd.Name, fwd.PortRange, newPortRange) return true, true, fwd.IPAddress, nil } @@ -894,25 +923,16 @@ func getProtocol(svcPorts []v1.ServicePort) (v1.Protocol, error) { return protocol, nil } -func getPorts(svcPorts []v1.ServicePort) []string { - ports := []string{} - for _, p := range svcPorts { - ports = append(ports, strconv.Itoa(int(p.Port))) - } - - return ports -} - func minMaxPort[T v1.ServicePort | string](svcPorts []T) (int32, int32) { minPort := int32(65536) maxPort := int32(0) for _, svcPort := range svcPorts { port := func(value any) int32 { - switch value.(type) { + switch value := value.(type) { case v1.ServicePort: - return value.(v1.ServicePort).Port + return value.Port case string: - i, _ := strconv.ParseInt(value.(string), 10, 32) + i, _ := strconv.ParseInt(value, 10, 32) return int32(i) default: return 0 @@ -937,23 +957,6 @@ func loadBalancerPortRange[T v1.ServicePort | string](svcPorts []T) (string, err return fmt.Sprintf("%d-%d", minPort, maxPort), nil } -// equalPorts compares two port ranges or slices of ports. Before comparison, -// slices of ports are converted into a port range from smallest to largest -// port. This is done so we don't unnecessarily recreate forwarding rules -// when upgrading from port ranges to distinct ports, because recreating -// forwarding rules is traffic impacting. -func equalPorts(existingPorts, newPorts []string, existingPortRange, newPortRange string, enableDiscretePortForwarding bool) bool { - if len(existingPorts) != 0 || !enableDiscretePortForwarding { - return equalStringSets(existingPorts, newPorts) && existingPortRange == newPortRange - } - // Existing forwarding rule contains a port range. To keep it that way, - // compare new list of ports as if it was a port range, too. - if len(newPorts) != 0 { - newPortRange, _ = loadBalancerPortRange(newPorts) - } - return existingPortRange == newPortRange -} - // translate from what K8s supports to what the cloud provider supports for session affinity. func translateAffinityType(affinityType v1.ServiceAffinity) string { switch affinityType { @@ -967,7 +970,7 @@ func translateAffinityType(affinityType v1.ServiceAffinity) string { } } -func (g *Cloud) firewallNeedsUpdate(name, serviceName, ipAddress string, ports []v1.ServicePort, sourceRanges utilnet.IPNetSet) (exists bool, needsUpdate bool, err error) { +func (g *Cloud) firewallNeedsUpdate(name, serviceName, ipAddress string, ports []v1.ServicePort, sourceRanges utilnet.IPNetSet, priority int64) (exists bool, needsUpdate bool, err error) { fw, err := g.GetFirewall(MakeFirewallName(name)) if err != nil { if isHTTPErrorCode(err, http.StatusNotFound) { @@ -1009,6 +1012,10 @@ func (g *Cloud) firewallNeedsUpdate(name, serviceName, ipAddress string, ports [ return true, true, nil } + if fw.Priority != priority { + return true, true, nil + } + return true, false, nil } @@ -1020,6 +1027,10 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd } sourceRanges := l4LbSrcRngsFlag.ipn ports := []v1.ServicePort{{Protocol: "tcp", Port: hcPort}} + allowPriority := firewallPriorityDefault + if g.enableL4DenyFirewallRule { + allowPriority = firewallPriorityAllow + } fwName := MakeHealthCheckFirewallName(clusterID, hcName, isNodesHealthCheck) fw, err := g.GetFirewall(fwName) @@ -1028,7 +1039,7 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd return fmt.Errorf("error getting firewall for health checks: %v", err) } klog.Infof("Creating firewall %v for health checks.", fwName) - if err := g.createFirewall(svc, fwName, desc, ipAddress, sourceRanges, ports, hosts); err != nil { + if err := g.createFirewall(svc, fwName, desc, ipAddress, sourceRanges, ports, hosts, allowPriority); err != nil { return err } klog.Infof("Created firewall %v for health checks.", fwName) @@ -1039,9 +1050,10 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd len(fw.Allowed) != 1 || fw.Allowed[0].IPProtocol != string(ports[0].Protocol) || !equalStringSets(fw.Allowed[0].Ports, []string{strconv.Itoa(int(ports[0].Port))}) || - !equalStringSets(fw.SourceRanges, sourceRanges.StringSlice()) { + !equalStringSets(fw.SourceRanges, sourceRanges.StringSlice()) || + fw.Priority != int64(allowPriority) { klog.Warningf("Firewall %v exists but parameters have drifted - updating...", fwName) - if err := g.updateFirewall(svc, fwName, desc, ipAddress, sourceRanges, ports, hosts); err != nil { + if err := g.updateFirewall(svc, fwName, desc, ipAddress, sourceRanges, ports, hosts, allowPriority); err != nil { klog.Warningf("Failed to reconcile firewall %v parameters.", fwName) return err } @@ -1050,8 +1062,7 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd return nil } -func createForwardingRule(s CloudForwardingRuleService, name, serviceName, region, ipAddress, target string, ports []v1.ServicePort, netTier cloud.NetworkTier, enableDiscretePortForwarding bool) error { - frPorts := getPorts(ports) +func createForwardingRule(s CloudForwardingRuleService, name, serviceName, region, ipAddress, target string, ports []v1.ServicePort, netTier cloud.NetworkTier) error { protocol, err := getProtocol(ports) if err != nil { return err @@ -1072,11 +1083,6 @@ func createForwardingRule(s CloudForwardingRuleService, name, serviceName, regio NetworkTier: netTier.ToGCEValue(), } - if len(frPorts) <= maxForwardedPorts && enableDiscretePortForwarding { - rule.Ports = frPorts - rule.PortRange = "" - } - err = s.CreateRegionForwardingRule(rule, region) if err != nil && !isHTTPErrorCode(err, http.StatusConflict) { @@ -1086,8 +1092,8 @@ func createForwardingRule(s CloudForwardingRuleService, name, serviceName, regio return nil } -func (g *Cloud) createFirewall(svc *v1.Service, name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error { - firewall, err := g.firewallObject(name, desc, destinationIP, sourceRanges, ports, hosts) +func (g *Cloud) createFirewall(svc *v1.Service, name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance, priority int) error { + firewall, err := g.firewallObject(name, desc, destinationIP, sourceRanges, ports, hosts, priority) if err != nil { return err } @@ -1104,8 +1110,8 @@ func (g *Cloud) createFirewall(svc *v1.Service, name, desc, destinationIP string return nil } -func (g *Cloud) updateFirewall(svc *v1.Service, name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error { - firewall, err := g.firewallObject(name, desc, destinationIP, sourceRanges, ports, hosts) +func (g *Cloud) updateFirewall(svc *v1.Service, name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance, priority int) error { + firewall, err := g.firewallObject(name, desc, destinationIP, sourceRanges, ports, hosts, priority) if err != nil { return err } @@ -1123,7 +1129,7 @@ func (g *Cloud) updateFirewall(svc *v1.Service, name, desc, destinationIP string return nil } -func (g *Cloud) firewallObject(name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) (*compute.Firewall, error) { +func (g *Cloud) firewallObject(name, desc, destinationIP string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance, priority int) (*compute.Firewall, error) { // destinationIP can be empty string "" and this means that it is not set. // GCE considers empty destinationRanges as "all" for ingress firewall-rules. // Concatenate service ports into port ranges. This help to workaround the gce firewall limitation where only @@ -1157,6 +1163,7 @@ func (g *Cloud) firewallObject(name, desc, destinationIP string, sourceRanges ut Ports: portRanges, }, }, + Priority: int64(priority), } if destinationIP != "" { firewall.DestinationRanges = []string{destinationIP} @@ -1164,6 +1171,243 @@ func (g *Cloud) firewallObject(name, desc, destinationIP string, sourceRanges ut return firewall, nil } +func (g *Cloud) ensureAllowNodeFirewall(apiService *v1.Service, loadBalancerName, ipAddressToUse, lbRefStr string, hosts []*gceInstance) error { + fwAllowName := MakeFirewallName(loadBalancerName) + + // Check if user specified the allow source range + sourceRanges, err := servicehelpers.GetLoadBalancerSourceRanges(apiService) + if err != nil { + return err + } + ports := apiService.Spec.Ports + + serviceName := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name}.String() + desc := makeFirewallDescription(serviceName, ipAddressToUse) + + allowPriority := firewallPriorityDefault + if g.enableL4DenyFirewallRule { + allowPriority = firewallPriorityAllow + } + + firewallExists, firewallNeedsUpdate, err := g.firewallNeedsUpdate(loadBalancerName, serviceName, ipAddressToUse, ports, sourceRanges, int64(allowPriority)) + if err != nil { + return err + } + + if firewallNeedsUpdate { + // Unlike forwarding rules and target pools, firewalls can be updated + // without needing to be deleted and recreated. + if firewallExists { + klog.Infof("ensureExternalLoadBalancer(%s): Updating firewall.", lbRefStr) + if err := g.updateFirewall(apiService, fwAllowName, desc, ipAddressToUse, sourceRanges, ports, hosts, allowPriority); err != nil { + return err + } + klog.Infof("ensureExternalLoadBalancer(%s): Updated firewall.", lbRefStr) + } else { + klog.Infof("ensureExternalLoadBalancer(%s): Creating firewall.", lbRefStr) + if err := g.createFirewall(apiService, fwAllowName, desc, ipAddressToUse, sourceRanges, ports, hosts, allowPriority); err != nil { + return err + } + klog.Infof("ensureExternalLoadBalancer(%s): Created firewall.", lbRefStr) + } + } + return nil +} + +func (g *Cloud) ensureDenyNodeFirewall(apiService *v1.Service, loadBalancerName, ipAddressToUse, lbRefStr string, hosts []*gceInstance) error { + // If the node tags to be used for this cluster have been predefined in the + // provider config, just use them. Otherwise, invoke computeHostTags method to get the tags. + hostTags := g.nodeTags + if len(hostTags) == 0 { + var err error + if hostTags, err = g.computeHostTags(hosts); err != nil { + return fmt.Errorf("no node tags supplied and also failed to parse the given lists of hosts for tags. Abort ensuring firewall rule") + } + } + name := MakeFirewallDenyName(loadBalancerName) + serviceName := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name}.String() + desc := makeFirewallDescription(serviceName, ipAddressToUse) + + want := &compute.Firewall{ + Name: name, + Description: desc, + Network: g.networkURL, + SourceRanges: []string{"0.0.0.0/0"}, + DestinationRanges: []string{ipAddressToUse}, + TargetTags: hostTags, + Denied: []*compute.FirewallDenied{{IPProtocol: "all"}}, + Priority: firewallPriorityDeny, + } + + got, err := g.GetFirewall(name) + if ignoreNotFound(err) != nil { + return err + } + + if create := isNotFound(err); create { + klog.Infof("ensureDenyNodeFirewall(%s): Creating firewall %q.", lbRefStr, name) + if err := g.CreateFirewall(want); err != nil { + if isForbidden(err) && g.OnXPN() { + klog.V(4).Infof("ensureDenyNodeFirewall(%q): Do not have permission to create firewall rule (on XPN) %q. Skipping creation.", name, err) + return nil + } + return err + } + klog.Infof("ensureDenyNodeFirewall(%s): Created firewall %q.", lbRefStr, name) + return nil + } + + // otherwise exists + equal, err := firewallsEqual(got, want) + if err != nil { + return err + } + if equal { + // no need to update + klog.Infof("ensureDenyNodeFirewall(%s): Firewall %q already exists and is up to date.", lbRefStr, name) + return nil + } + + // needs update + klog.Infof("ensureDenyNodeFirewall(%s): Updating firewall %q.", lbRefStr, name) + if err := g.PatchFirewall(want); err != nil { + if isForbidden(err) && g.OnXPN() { + klog.V(4).Infof("ensureDenyNodeFirewall(%q): Do not have permission to update firewall rule (on XPN) %q. Skipping update.", name, err) + return nil + } + return err + } + + return nil +} + +func (g *Cloud) ensureFirewallDeleted(fwName string) error { + // We do an additional call to check if the resource is there + // If it isn't there we don't call delete which will leave the + // 404 in the project Audit Logs. + _, err := g.GetFirewall(fwName) + if isNotFound(err) || (isForbidden(err) && g.OnXPN()) { + klog.V(4).Infof("ensureFirewallDeleted(%q): Firewall does not exist or do not have permission to delete (on XPN) %q. Skipping deletion.", fwName, err) + return nil + } + if err != nil { + return err + } + + // If there is a firewall we delete it. + err = ignoreNotFound(g.DeleteFirewall(fwName)) + if isForbidden(err) && g.OnXPN() { + klog.V(4).Infof("ensureFirewallDeleted(%q): Do not have permission to delete firewall rule (on XPN) %q. Skipping deletion.", fwName, err) + return nil + } + return err +} + +func firewallsEqual(a, b *compute.Firewall) (bool, error) { + if equal, err := ipRangesEqual(a.SourceRanges, b.SourceRanges); !equal || err != nil { + return equal, err + } + if equal, err := firewallEffectsEqual(a.Allowed, b.Allowed); !equal || err != nil { + return equal, err + } + if equal, err := firewallEffectsEqual(a.Denied, b.Denied); !equal || err != nil { + return equal, err + } + + return a.Priority == b.Priority && reflect.DeepEqual(a.DestinationRanges, b.DestinationRanges) && + a.Description == b.Description, nil +} + +func firewallEffectsEqual[T compute.FirewallAllowed | compute.FirewallDenied](a, b []*T) (bool, error) { + mapA, err := portsPerProtocol(a) + if err != nil { + return false, err + } + + mapB, err := portsPerProtocol(b) + if err != nil { + return false, err + } + + return reflect.DeepEqual(mapA, mapB), nil +} + +type protocol string + +func portsPerProtocol[T compute.FirewallAllowed | compute.FirewallDenied](a []*T) (map[protocol]sets.Set[int], error) { + mapped := make(map[protocol]sets.Set[int]) + for _, item := range a { + var protoStr string + var ports []string + + switch v := any(item).(type) { + case *compute.FirewallAllowed: + protoStr = v.IPProtocol + ports = v.Ports + case *compute.FirewallDenied: + protoStr = v.IPProtocol + ports = v.Ports + } + proto := protocol(strings.ToUpper(protoStr)) + + if _, ok := mapped[proto]; !ok { + mapped[proto] = sets.New[int]() + } + for _, port := range ports { + start, end, err := parsePort(port) + if err != nil { + return nil, err + } + for i := start; i <= end; i++ { + mapped[proto].Insert(i) + } + } + } + return mapped, nil +} + +// parsePort parses firewall definition of ports to a int inclusive range [start, end] +// +// For example: +// - "80" will return 80, 80, nil +// - "443-8080" will return 443, 8080, nil +// +// It returns an error when int parsing has failed. +func parsePort(portStr string) (int, int, error) { + parts := strings.Split(portStr, "-") + + if len(parts) == 2 { + start, errStart := strconv.Atoi(parts[0]) + end, errEnd := strconv.Atoi(parts[1]) + if errStart != nil || errEnd != nil { + return 0, -1, errors.Join(errStart, errEnd) + } + return start, end, nil + } + + if len(parts) == 1 { + port, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, -1, err + } + return port, port, nil + } + + return 0, -1, fmt.Errorf("unexpected port format %q, expects single integer or a range delimited by `-`", portStr) +} + +func ipRangesEqual(a, b []string) (bool, error) { + as, err := utilnet.ParseIPNets(a...) + if err != nil { + return false, err + } + bs, err := utilnet.ParseIPNets(b...) + if err != nil { + return false, err + } + return as.Equal(bs), nil +} + func ensureStaticIP(s CloudAddressService, name, serviceName, region, existingIP string, netTier cloud.NetworkTier) (ipAddress string, existing bool, err error) { // If the address doesn't exist, this will create it. // If the existingIP exists but is ephemeral, this will promote it to static. diff --git a/providers/gce/gce_loadbalancer_external_deny_test.go b/providers/gce/gce_loadbalancer_external_deny_test.go new file mode 100644 index 0000000000..647e40b602 --- /dev/null +++ b/providers/gce/gce_loadbalancer_external_deny_test.go @@ -0,0 +1,613 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gce + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "testing" + + "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud" + "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/meta" + "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/mock" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/api/compute/v1" + "google.golang.org/api/googleapi" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/record" +) + +const ( + fakeDenyFirewallName = "k8s-fw-a-deny" + fakeNodeFirewallName = "k8s-fw-a" + fakeHealthCheckFirewallName = "k8s-test-cluster-id-node-http-hc" +) + +// firewallTracker is used to check if there are multiple firewalls +// on the same IP or IP range that have conflicting priority, which +// would result in blocking all traffic. +type firewallTracker struct { + // firewalls contains all firewalls for IP specified in the key + // for IPv6 ranges we store just the prefix + firewalls map[ipPrefix]map[resourceName]*compute.Firewall + + mu sync.Mutex +} + +type ( + ipPrefix string + resourceName string + firewallMap map[ipPrefix]map[resourceName]*compute.Firewall +) + +// patch will return an error if there is a situation that modifying fw +// would cause blocking traffic - allow overruled by deny firewall. +func (f *firewallTracker) patch(fw *compute.Firewall) error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.firewalls == nil { + f.firewalls = make(firewallMap) + } + + if len(fw.DestinationRanges) < 1 { // Does not concern us - likely for healthcheck + return nil + } + + if len(fw.DestinationRanges) > 1 { + return fmt.Errorf("unexpected count of destination ranges, expected at most 1: %v", fw.DestinationRanges) + } + + cidrOrIP := fw.DestinationRanges[0] + key := ipPrefix(strings.TrimSuffix(cidrOrIP, "/96")) + + if f.firewalls[key] == nil { + f.firewalls[key] = make(map[resourceName]*compute.Firewall) + } + rName := resourceName(fw.Name) + f.firewalls[key][rName] = fw + + for _, other := range f.firewalls[key] { + if fw.Name == other.Name { + continue + } + if areBlocked(fw, other) { + return fmt.Errorf( + "two firewalls block each other on %q: %s (priority %d) and %s (priority %d)", + key, fw.Name, fw.Priority, other.Name, other.Priority, + ) + } + } + + return nil +} + +func (f *firewallTracker) delete(name string) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.firewalls == nil { + return + } + + // this could be done a tad quicker with an additional map + // but this should be fast enough for the test + for _, fw := range f.firewalls { + delete(fw, resourceName(name)) + } +} + +func (f *firewallTracker) hookTo(mockGCE *cloud.MockGCE) { + mockGCE.MockFirewalls.InsertHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) (bool, error) { + if err := f.patch(obj); err != nil { + return true, err + } + return false, nil + } + mockGCE.MockFirewalls.UpdateHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) error { + if err := f.patch(obj); err != nil { + return err + } + return mock.UpdateFirewallHook(ctx, key, obj, m, options...) + } + mockGCE.MockFirewalls.PatchHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) error { + if err := f.patch(obj); err != nil { + return err + } + return mock.UpdateFirewallHook(ctx, key, obj, m, options...) + } + mockGCE.MockFirewalls.DeleteHook = func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls, options ...cloud.Option) (bool, error) { + f.delete(key.Name) + return false, nil + } +} + +// areBlocked only works if fw1 and fw2 are using the same +// destination range, direction, etc +func areBlocked(fw1, fw2 *compute.Firewall) bool { + if fw1 == nil || fw2 == nil { + return false + } + + if len(fw2.Denied) > 0 { + fw1, fw2 = fw2, fw1 + } + + // Both are deny or allow - won't block themselves + if len(fw1.Denied) == 0 || len(fw2.Allowed) == 0 { + return false + } + + // deny takes precedence over allow if they have the same priority + return fw1.Priority <= fw2.Priority +} + +func TestDenyFirewall(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Setup + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + if err != nil { + t.Fatalf("fakeGCECloud error: %v", err) + } + + // Hook firewall tracker which will throw errors on mockGCE calls when firewalls are blocking each other + tracker := &firewallTracker{} + mockGCE := gce.Compute().(*cloud.MockGCE) + tracker.hookTo(mockGCE) + + // Create service and nodes + svc := fakeLoadbalancerService("") + nodeName := "test-node-1" + nodes, err := createAndInsertNodes(gce, []string{nodeName}, vals.ZoneName) + if err != nil { + t.Fatalf("createAndInsertNodes error: %v", err) + } + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(ctx, svc, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + // Ensure with deny enabled + gce.enableL4DenyFirewallRule = true + gce.enableL4DenyFirewallRollbackCleanup = true + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=false) error: %v", err) + } + + // Verify health check firewall exists at 999 + fw, err := gce.GetFirewall(fakeHealthCheckFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeHealthCheckFirewallName, err) + } + if fw.Priority != 999 { + t.Errorf("Allow firewall priority = %d, want 999", fw.Priority) + } + + wantDestinationRange := []string{"1.2.3.0"} + + // Verify node firewall exists at 999 + fw, err = gce.GetFirewall(fakeNodeFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeNodeFirewallName, err) + } + if fw.Priority != 999 { + t.Errorf("Allow firewall priority = %d, want 999", fw.Priority) + } + if diff := cmp.Diff(wantDestinationRange, fw.DestinationRanges); diff != "" { + t.Errorf("allow destination range got != want (-want, +got)\n%s", diff) + } + + // Verify deny firewall + fw, err = gce.GetFirewall(fakeDenyFirewallName) + if err != nil { + t.Errorf("GetFirewall(%q) error: %v", fakeDenyFirewallName, err) + } + want := &compute.Firewall{ + Name: fakeDenyFirewallName, + Denied: []*compute.FirewallDenied{{IPProtocol: "all"}}, + Description: `{"kubernetes.io/service-name":"/fakesvc", "kubernetes.io/service-ip":"1.2.3.0"}`, + DestinationRanges: wantDestinationRange, + SourceRanges: []string{"0.0.0.0/0"}, + TargetTags: []string{nodeName}, + Priority: 1000, + } + fwCmpOpt := cmpopts.IgnoreFields(compute.Firewall{}, "SelfLink") + if diff := cmp.Diff(want, fw, fwCmpOpt); diff != "" { + t.Errorf("deny firewalls got != want (-want, +got)\n%s", diff) + } +} + +func TestDenyRollforwardDoesNotBlockTraffic(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Setup + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + if err != nil { + t.Fatalf("fakeGCECloud error: %v", err) + } + + // Hook firewall tracker which will throw errors on mockGCE calls when firewalls are blocking each other + tracker := &firewallTracker{} + mockGCE := gce.Compute().(*cloud.MockGCE) + tracker.hookTo(mockGCE) + + // Create service and nodes + svc := fakeLoadbalancerService("") + nodeName := "test-node-1" + nodes, err := createAndInsertNodes(gce, []string{nodeName}, vals.ZoneName) + if err != nil { + t.Fatalf("createAndInsertNodes error: %v", err) + } + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(ctx, svc, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + // 1. Ensure with Deny Disabled + gce.enableL4DenyFirewallRule = false + gce.enableL4DenyFirewallRollbackCleanup = true + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=false) error: %v", err) + } + + // Verify Allow exists at 1000 + fw, err := gce.GetFirewall(fakeNodeFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeNodeFirewallName, err) + } + if fw.Priority != 1000 { + t.Errorf("Allow firewall priority = %d, want 1000", fw.Priority) + } + + // Verify Health Check firewall exists at 1000 + fw, err = gce.GetFirewall(fakeHealthCheckFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeHealthCheckFirewallName, err) + } + if fw.Priority != 1000 { + t.Errorf("Allow firewall priority = %d, want 1000", fw.Priority) + } + + // Verify Deny does not exist + _, err = gce.GetFirewall(fakeDenyFirewallName) + if !isNotFound(err) { + t.Errorf("Deny firewall %q should not exist, err: %v", fakeDenyFirewallName, err) + } + + // 2. Ensure with Deny Enabled (Rollforward) + gce.enableL4DenyFirewallRule = true + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=true) error: %v", err) + } + + // Verify Allow exists at 999 + fw, err = gce.GetFirewall(fakeNodeFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeNodeFirewallName, err) + } + if fw.Priority != 999 { + t.Errorf("Allow firewall priority = %d, want 999", fw.Priority) + } + + // Verify Healthcheck Firewall exists at 999 + fwHealthcheck, err := gce.GetFirewall(fakeHealthCheckFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeHealthCheckFirewallName, err) + } + if fwHealthcheck.Priority != 999 { + t.Errorf("Healthcheck firewall priority = %d, want 999", fwHealthcheck.Priority) + } + + // Verify Deny exists at 1000 + fwDeny, err := gce.GetFirewall(fakeDenyFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeDenyFirewallName, err) + } + if fwDeny.Priority != 1000 { + t.Errorf("Deny firewall priority = %d, want 1000", fwDeny.Priority) + } + + // 3. Delete service + err = gce.ensureExternalLoadBalancerDeleted(vals.ClusterName, vals.ClusterID, svc) + if err != nil { + t.Fatal(err) + } + + // Verify firewalls are cleaned up + for _, fwName := range []string{fakeNodeFirewallName, fakeHealthCheckFirewallName, fakeDenyFirewallName} { + got, err := gce.GetFirewall(fwName) + if got != nil { + t.Errorf("firewall %v wasn't deleted after delete service", fwName) + } + if !isNotFound(err) { + t.Errorf("got unexpected err %v when checking for deleted %v firewall", err, fwName) + } + } +} + +func TestDenyRollback(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Setup + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + if err != nil { + t.Fatalf("fakeGCECloud error: %v", err) + } + gce.eventRecorder = record.NewFakeRecorder(1024) + + // Hook firewall tracker which will throw errors on mockGCE calls when firewalls are blocking each other + tracker := &firewallTracker{} + mockGCE := gce.Compute().(*cloud.MockGCE) + tracker.hookTo(mockGCE) + + // Create service and nodes + svc := fakeLoadbalancerService("") + nodeName := "test-node-1" + nodes, err := createAndInsertNodes(gce, []string{nodeName}, vals.ZoneName) + if err != nil { + t.Fatalf("createAndInsertNodes error: %v", err) + } + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(ctx, svc, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + // 1. Ensure with Deny Enabled + gce.enableL4DenyFirewallRule = true + gce.enableL4DenyFirewallRollbackCleanup = true + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=true) error: %v", err) + } + + // 2. Ensure with Deny Disabled (Rollback) + gce.enableL4DenyFirewallRule = false + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=false) error: %v", err) + } + + // Verify Allow exists at 1000 + fw, err := gce.GetFirewall(fakeNodeFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeNodeFirewallName, err) + } + if fw.Priority != 1000 { + t.Errorf("Allow firewall priority = %d, want 1000", fw.Priority) + } + + // Verify Health Check firewall exists at 1000 + fw, err = gce.GetFirewall(fakeHealthCheckFirewallName) + if err != nil { + t.Fatalf("GetFirewall(%q) error: %v", fakeNodeFirewallName, err) + } + if fw.Priority != 1000 { + t.Errorf("Health check firewall priority = %d, want 1000", fw.Priority) + } + + // Verify Deny does not exist + _, err = gce.GetFirewall(fakeDenyFirewallName) + if !isNotFound(err) { + t.Errorf("Deny firewall %q should not exist, err: %v", fakeDenyFirewallName, err) + } +} + +func TestDenyIsNotCreatedWhenPriorityUpdateFails(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + firewallNameToFail string + }{ + { + name: "node_firewall", + firewallNameToFail: fakeNodeFirewallName, + }, + { + name: "healthcheck_firewall", + firewallNameToFail: fakeHealthCheckFirewallName, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() + // Setup + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + if err != nil { + t.Fatalf("fakeGCECloud error: %v", err) + } + + svc := fakeLoadbalancerService("") + nodeName := "test-node-1" + nodes, err := createAndInsertNodes(gce, []string{nodeName}, vals.ZoneName) + if err != nil { + t.Fatalf("createAndInsertNodes error: %v", err) + } + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(ctx, svc, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + // 1. Ensure with Deny Disabled just to provision firewalls with 1000 priority + gce.enableL4DenyFirewallRule = false + gce.enableL4DenyFirewallRollbackCleanup = false + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + if err != nil { + t.Fatalf("ensureExternalLoadBalancer(deny=false) error: %v", err) + } + + // 2. Inject error on Allow update + mockGCE := gce.Compute().(*cloud.MockGCE) + injectedError := errors.New("injected error on allow patch") + + mockGCE.MockFirewalls.PatchHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) error { + if key.Name == tc.firewallNameToFail { + return injectedError + } + return mock.UpdateFirewallHook(ctx, key, obj, m, options...) + } + mockGCE.MockFirewalls.UpdateHook = mockGCE.MockFirewalls.PatchHook + + // 3. Ensure with Deny Enabled to force priority decrease + gce.enableL4DenyFirewallRule = true + gce.enableL4DenyFirewallRollbackCleanup = true + + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + + // Assert error returned + if err == nil || !strings.Contains(err.Error(), injectedError.Error()) { + t.Errorf("got unexpected err %q, wanted %q", err, injectedError) + } + + // Assert Deny rule NOT created + _, err = gce.GetFirewall(fakeDenyFirewallName) + if !isNotFound(err) { + t.Errorf("Deny firewall %q should not exist after failure to update %q rule", fakeDenyFirewallName, tc.firewallNameToFail) + } + }) + } +} + +// TestContinueOnXPN403s verifies that we don't error out on XPN (shared VPC) clusters that don't have permissions to create firewalls +func TestContinueOnXPN403s(t *testing.T) { + testCases := []struct { + name string + denyFirewallEnabled bool + denyFirewallCleanupEnabled bool + }{ + { + name: "disabled", + }, + { + name: "rolled_back", + denyFirewallEnabled: true, + }, + { + name: "enabled", + denyFirewallEnabled: true, + denyFirewallCleanupEnabled: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Arrange + vals := DefaultTestClusterValues() + vals.OnXPN = true + gce, err := fakeGCECloud(vals) + if err != nil { + t.Fatalf("fakeGCECloud error: %v", err) + } + + gce.enableL4DenyFirewallRule = tc.denyFirewallEnabled + gce.enableL4DenyFirewallRollbackCleanup = tc.denyFirewallCleanupEnabled + + svc := fakeLoadbalancerService("") + nodeName := "test-node-1" + nodes, err := createAndInsertNodes(gce, []string{nodeName}, vals.ZoneName) + if err != nil { + t.Fatalf("createAndInsertNodes error: %v", err) + } + ctx := t.Context() + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(ctx, svc, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + xpnErr := func(call, name string) *googleapi.Error { + return &googleapi.Error{ + Code: http.StatusForbidden, + Message: fmt.Sprintf("Required 'compute.firewalls.%s' permission for 'projects/something/global/firewalls/%s'.", call, name), + } + } + + mockGCE := gce.Compute().(*cloud.MockGCE) + mockGCE.MockFirewalls.InsertHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) (bool, error) { + return true, xpnErr("insert", key.Name) + } + mockGCE.MockFirewalls.GetHook = func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls, options ...cloud.Option) (bool, *compute.Firewall, error) { + return false, nil, nil // For some reason get doesn't return 403 + } + mockGCE.MockFirewalls.DeleteHook = func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls, options ...cloud.Option) (bool, error) { + return true, xpnErr("delete", key.Name) + } + mockGCE.MockFirewalls.PatchHook = func(ctx context.Context, key *meta.Key, obj *compute.Firewall, m *cloud.MockFirewalls, options ...cloud.Option) error { + return xpnErr("patch", key.Name) + } + mockGCE.MockFirewalls.UpdateHook = mockGCE.MockFirewalls.PatchHook + + // Act: create load balancer + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + + // Assert: we don't expect any errors to be returned + if err != nil { + t.Fatal(err) + } + + // Assert: no firewalls were created + fwNames := []string{"k8s-fw-a", "k8s-fw-a-deny"} + for _, name := range fwNames { + fw, _ := gce.GetFirewall(name) + if fw != nil { + t.Errorf("something is wrong with the test logic, the firewall %v should not have been created", name) + } + } + + // Assert: forwarding rule exists either way + forwardingRuleName := "a" + _, err = gce.GetRegionForwardingRule(forwardingRuleName, gce.Region()) + if err != nil { + t.Errorf("something is wrong with the test logic, the forwarding rule %v should have been created, but got %v", forwardingRuleName, err) + } + + // Act: delete the service + err = gce.ensureExternalLoadBalancerDeleted(vals.ClusterName, vals.ClusterID, svc) + if err != nil { + t.Fatal(err) + } + + // Assert: forwarding rules were cleaned up + fr, err := gce.GetRegionForwardingRule(forwardingRuleName, gce.Region()) + if !isNotFound(err) || fr != nil { + t.Errorf("something is wrong with the test logic, the forwarding rule %v should have been cleaned up, but got err: %v and fw: %v", forwardingRuleName, err, fr) + } + }) + } +} diff --git a/providers/gce/gce_loadbalancer_external_test.go b/providers/gce/gce_loadbalancer_external_test.go index a299a674bb..4d7d72012c 100644 --- a/providers/gce/gce_loadbalancer_external_test.go +++ b/providers/gce/gce_loadbalancer_external_test.go @@ -183,7 +183,8 @@ func TestMinMaxPortRange(t *testing.T) { svcPorts: []v1.ServicePort{ {Port: 1}, {Port: 10}, - {Port: 100}}, + {Port: 100}, + }, expectedRange: "1-100", expectError: false, }, @@ -193,20 +194,23 @@ func TestMinMaxPortRange(t *testing.T) { {Port: 1}, {Port: 50}, {Port: 100}, - {Port: 90}}, + {Port: 90}, + }, expectedRange: "1-100", expectError: false, }, { svcPorts: []v1.ServicePort{ - {Port: 10}}, + {Port: 10}, + }, expectedRange: "10-10", expectError: false, }, { svcPorts: []v1.ServicePort{ {Port: 100}, - {Port: 10}}, + {Port: 10}, + }, expectedRange: "10-100", expectError: false, }, @@ -214,7 +218,8 @@ func TestMinMaxPortRange(t *testing.T) { svcPorts: []v1.ServicePort{ {Port: 100}, {Port: 50}, - {Port: 10}}, + {Port: 10}, + }, expectedRange: "10-100", expectError: false, }, @@ -285,7 +290,7 @@ func TestCreateForwardingRuleWithTier(t *testing.T) { lbName := tc.expectedRule.Name ipAddr := tc.expectedRule.IPAddress - err = createForwardingRule(s, lbName, serviceName, s.region, ipAddr, target, ports, tc.netTier, false) + err = createForwardingRule(s, lbName, serviceName, s.region, ipAddr, target, ports, tc.netTier) assert.NoError(t, err) Rule, err := s.GetRegionForwardingRule(lbName, s.region) @@ -323,90 +328,49 @@ func TestCreateForwardingRulePorts(t *testing.T) { wideRangePortsTCP := basePortsTCP[:] for _, tc := range []struct { - desc string - frName string - ports []v1.ServicePort - discretePortForwarding bool - expectedPorts []string - expectedPortRange string + desc string + frName string + ports []v1.ServicePort + expectedPorts []string + expectedPortRange string }{ { - desc: "Single Port, discretePorts enabled", - frName: "fwd-rule1", - ports: onePortUDP, - discretePortForwarding: true, - expectedPorts: []string{"80"}, - expectedPortRange: "", + desc: "Single Port (PortRange)", + frName: "fwd-rule5", + ports: onePortUDP, + expectedPorts: []string{}, + expectedPortRange: "80-80", }, { - desc: "Individual Ports, discretePorts enabled", - frName: "fwd-rule2", - ports: fivePortsTCP, - discretePortForwarding: true, - expectedPorts: []string{"80", "81", "82", "83", "84"}, - expectedPortRange: "", + desc: "5 Ports PortRange", + frName: "fwd-rule6", + ports: fivePortsTCP, + expectedPorts: []string{}, + expectedPortRange: "80-84", }, { - desc: "PortRange, discretePorts enabled", - frName: "fwd-rule3", - ports: sixPortsTCP, - discretePortForwarding: true, - expectedPorts: []string{}, - expectedPortRange: "80-85", + desc: "6 ports PortRange", + frName: "fwd-rule7", + ports: sixPortsTCP, + expectedPorts: []string{}, + expectedPortRange: "80-85", }, { - desc: "Wide PortRange, discretePorts enabled", - frName: "fwd-rule4", - ports: wideRangePortsTCP, - discretePortForwarding: true, - expectedPorts: []string{}, - expectedPortRange: "80-8080", - }, - { - desc: "Single Port (PortRange)", - frName: "fwd-rule5", - ports: onePortUDP, - discretePortForwarding: false, - expectedPorts: []string{}, - expectedPortRange: "80-80", - }, - { - desc: "5 Ports PortRange", - frName: "fwd-rule6", - ports: fivePortsTCP, - discretePortForwarding: false, - expectedPorts: []string{}, - expectedPortRange: "80-84", - }, - { - desc: "6 ports PortRange", - frName: "fwd-rule7", - ports: sixPortsTCP, - discretePortForwarding: false, - expectedPorts: []string{}, - expectedPortRange: "80-85", - }, - { - desc: "Wide PortRange", - frName: "fwd-rule8", - ports: wideRangePortsTCP, - discretePortForwarding: false, - expectedPorts: []string{}, - expectedPortRange: "80-8080", + desc: "Wide PortRange", + frName: "fwd-rule8", + ports: wideRangePortsTCP, + expectedPorts: []string{}, + expectedPortRange: "80-8080", }, } { t.Run(tc.desc, func(t *testing.T) { gce, err := fakeGCECloud(vals) require.NoError(t, err) - if tc.discretePortForwarding { - gce.SetEnableDiscretePortForwarding(true) - } - frName := tc.frName ports := tc.ports - err = createForwardingRule(gce, frName, serviceName, gce.region, ipAddr, target, ports, cloud.NetworkTierStandard, tc.discretePortForwarding) + err = createForwardingRule(gce, frName, serviceName, gce.region, ipAddr, target, ports, cloud.NetworkTierStandard) assert.NoError(t, err) fwdRule, err := gce.GetRegionForwardingRule(frName, gce.region) @@ -546,7 +510,6 @@ func TestShouldNotRecreateLBWhenNetworkTiersMismatch(t *testing.T) { mutateSvc: func(service *v1.Service) { svc.Annotations[NetworkTierAnnotationKey] = string(NetworkTierAnnotationStandard) svc.Spec.LoadBalancerIP = staticIP - }, expectNetTier: NetworkTierAnnotationStandard.ToGCEValue(), }, @@ -754,7 +717,6 @@ func TestLoadBalancerWrongTierResourceDeletion(t *testing.T) { gce.targetPoolURL(lbName), svc.Spec.Ports, cloud.NetworkTierStandard, - false, ) require.NoError(t, err) @@ -1229,10 +1191,6 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { vals := DefaultTestClusterValues() serviceName := "foo-svc" - onePortTCP8080 := []v1.ServicePort{ - {Name: "tcp1", Protocol: v1.ProtocolTCP, Port: int32(8080)}, - } - onePortUDP := []v1.ServicePort{ {Name: "udp1", Protocol: v1.ProtocolUDP, Port: int32(80)}, } @@ -1253,14 +1211,13 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { sevenPortsTCP := basePortsTCP[:] for _, tc := range []struct { - desc string - oldFwdRule *compute.ForwardingRule - oldPorts []v1.ServicePort - newlbIP string - newPorts []v1.ServicePort - discretePortForwarding bool - needsUpdate bool - expectError bool + desc string + oldFwdRule *compute.ForwardingRule + oldPorts []v1.ServicePort + newlbIP string + newPorts []v1.ServicePort + needsUpdate bool + expectError bool }{ { desc: "different ip address on update", @@ -1268,12 +1225,11 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { Name: "fwd-rule1", IPAddress: "1.1.1.1", }, - oldPorts: onePortTCP, - newlbIP: "2.2.2.2", - newPorts: onePortTCP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, + oldPorts: onePortTCP, + newlbIP: "2.2.2.2", + newPorts: onePortTCP, + needsUpdate: true, + expectError: false, }, { desc: "different protocol", @@ -1281,15 +1237,14 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { Name: "fwd-rule2", IPAddress: "1.1.1.1", }, - oldPorts: onePortTCP, - newlbIP: "1.1.1.1", - newPorts: onePortUDP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, + oldPorts: onePortTCP, + newlbIP: "1.1.1.1", + newPorts: onePortUDP, + needsUpdate: true, + expectError: false, }, { - desc: "same ports (PortRange)", + desc: "same ports", oldFwdRule: &compute.ForwardingRule{ Name: "fwd-rule3", IPAddress: "1.1.1.1", @@ -1298,25 +1253,9 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: onePortTCP, newlbIP: "1.1.1.1", // "80-80" - newPorts: onePortTCP, - discretePortForwarding: false, - needsUpdate: false, - expectError: false, - }, - { - desc: "same ports, discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule4", - IPAddress: "1.1.1.1", - }, - // ["8080"] - oldPorts: onePortTCP8080, - newlbIP: "1.1.1.1", - // ["8080"] - newPorts: onePortTCP8080, - discretePortForwarding: true, - needsUpdate: false, - expectError: false, + newPorts: onePortTCP, + needsUpdate: false, + expectError: false, }, { desc: "same Port Range", @@ -1328,13 +1267,12 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: sixPortsTCP, newlbIP: "1.1.1.1", // "80-85" - newPorts: sixPortsTCP, - discretePortForwarding: false, - needsUpdate: false, - expectError: false, + newPorts: sixPortsTCP, + needsUpdate: false, + expectError: false, }, { - desc: "same Port Range, discretePorts enabled", + desc: "same Port Range", oldFwdRule: &compute.ForwardingRule{ Name: "fwd-rule6", IPAddress: "1.1.1.1", @@ -1343,10 +1281,9 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: sevenPortsTCP, newlbIP: "1.1.1.1", // "80-86" - newPorts: sevenPortsTCP, - discretePortForwarding: true, - needsUpdate: false, - expectError: false, + newPorts: sevenPortsTCP, + needsUpdate: false, + expectError: false, }, { desc: "port range mismatch", @@ -1358,28 +1295,12 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: sixPortsTCP, newlbIP: "1.1.1.1", // "80-86" - newPorts: sevenPortsTCP, - discretePortForwarding: false, - needsUpdate: true, - expectError: false, - }, - { - desc: "port range mismatch, discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule8", - IPAddress: "1.1.1.1", - }, - // "80-85" - oldPorts: sixPortsTCP, - newlbIP: "1.1.1.1", - // "80-86" - newPorts: sevenPortsTCP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, + newPorts: sevenPortsTCP, + needsUpdate: true, + expectError: false, }, { - desc: "ports mismatch (PortRange)", + desc: "single port to multiple mismatch", oldFwdRule: &compute.ForwardingRule{ Name: "fwd-rule9", IPAddress: "1.1.1.1", @@ -1388,95 +1309,26 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: onePortTCP, newlbIP: "1.1.1.1", // "80-84" - newPorts: fivePortsTCP, - discretePortForwarding: false, - needsUpdate: true, - expectError: false, - }, - { - desc: "ports mismatch, discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule10", - IPAddress: "1.1.1.1", - }, - // ["80", "81", "82", "83", "84"] - oldPorts: fivePortsTCP, - newlbIP: "1.1.1.1", - // ["80"] - newPorts: onePortTCP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, - }, - { - desc: "PortRange to ports (PortRange)", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule11", - IPAddress: "1.1.1.1", - }, - // "80-85" - oldPorts: sixPortsTCP, - newlbIP: "1.1.1.1", - // "80-84" five ports are still considered PortRange since discretePorts is disabled - newPorts: fivePortsTCP, - discretePortForwarding: false, - needsUpdate: true, - expectError: false, - }, - { - desc: "PortRange to ports discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule12", - IPAddress: "1.1.1.1", - }, - // "80-85" - oldPorts: sixPortsTCP, - newlbIP: "1.1.1.1", - // ["80", "81", "82", "83", "84"] - newPorts: fivePortsTCP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, - }, - { - desc: "PortRange to ports within existing port range discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule13", - IPAddress: "1.1.1.1", - }, - // "80-85" - oldPorts: sixPortsTCP, - newlbIP: "1.1.1.1", - // ["80", "85"] - newPorts: []v1.ServicePort{ - {Name: "tcp1", Protocol: v1.ProtocolTCP, Port: int32(80)}, - {Name: "tcp2", Protocol: v1.ProtocolTCP, Port: int32(85)}, - }, - discretePortForwarding: true, - // we don't want to unnecessarily recreate forwarding rules - // when upgrading from port ranges to distinct ports, because recreating - // forwarding rules is traffic impacting. - needsUpdate: false, + newPorts: fivePortsTCP, + needsUpdate: true, expectError: false, }, { - desc: "PortRange to ports, discretePorts enabled, port outside of PortRange", + desc: "port range shrinks", oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule14", + Name: "fwd-rule11", IPAddress: "1.1.1.1", }, // "80-85" oldPorts: sixPortsTCP, newlbIP: "1.1.1.1", - // ["8080"] - newPorts: onePortTCP8080, - discretePortForwarding: true, - // Since port is outside of portrange we expect to recreate forwarding rule + // "80-84" + newPorts: fivePortsTCP, needsUpdate: true, expectError: false, }, { - desc: "ports (PortRange) to PortRange", + desc: "ports grow", oldFwdRule: &compute.ForwardingRule{ Name: "fwd-rule15", IPAddress: "1.1.1.1", @@ -1485,56 +1337,22 @@ func TestCreateForwardingRuleNeedsUpdate(t *testing.T) { oldPorts: fivePortsTCP, newlbIP: "1.1.1.1", // "80-85" - newPorts: sixPortsTCP, - discretePortForwarding: false, - needsUpdate: true, - expectError: false, - }, - { - desc: "ports to PortRange, discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule16", - IPAddress: "1.1.1.1", - }, - // ["80", "81", "82", "83", "84"] - oldPorts: fivePortsTCP, - newlbIP: "1.1.1.1", - // "80-85" - newPorts: sixPortsTCP, - discretePortForwarding: true, - needsUpdate: true, - expectError: false, - }, - { - desc: "update to empty ports, discretePorts enabled", - oldFwdRule: &compute.ForwardingRule{ - Name: "fwd-rule17", - IPAddress: "1.1.1.1", - }, - // ["80", "81", "82", "83", "84"] - oldPorts: fivePortsTCP, - newlbIP: "1.1.1.1", - newPorts: []v1.ServicePort{}, - discretePortForwarding: true, - needsUpdate: false, - expectError: true, + newPorts: sixPortsTCP, + needsUpdate: true, + expectError: false, }, } { t.Run(tc.desc, func(t *testing.T) { gce, err := fakeGCECloud(vals) require.NoError(t, err) - if tc.discretePortForwarding { - gce.SetEnableDiscretePortForwarding(true) - } - frName := tc.oldFwdRule.Name ipAddr := tc.oldFwdRule.IPAddress ports := tc.oldPorts newlbIP := tc.newlbIP newPorts := tc.newPorts - err = createForwardingRule(gce, frName, serviceName, gce.region, ipAddr, target, ports, cloud.NetworkTierStandard, tc.discretePortForwarding) + err = createForwardingRule(gce, frName, serviceName, gce.region, ipAddr, target, ports, cloud.NetworkTierStandard) assert.NoError(t, err) exists, needsUpdate, _, err := gce.forwardingRuleNeedsUpdate(frName, vals.Region, newlbIP, newPorts) @@ -1894,6 +1712,7 @@ func TestFirewallNeedsUpdate(t *testing.T) { tc.ipAddr, tc.ports, tc.ipnet, + int64(firewallPriorityDefault), ) assert.Equal(t, tc.exists, exists, "'exists' didn't return as expected "+desc) assert.Equal(t, tc.needsUpdate, needsUpdate, "'needsUpdate' didn't return as expected "+desc) @@ -1911,7 +1730,6 @@ func TestFirewallNeedsUpdate(t *testing.T) { require.NoError(t, err) require.Equal(t, fw.Allowed[0].IPProtocol, "tcp") require.Equal(t, fw.SourceRanges[0], trueSourceRange) - }) } } @@ -2037,7 +1855,9 @@ func TestCreateAndUpdateFirewallSucceedsOnXPN(t *testing.T) { "A sad little firewall", ipnet, svc.Spec.Ports, - hosts) + hosts, + firewallPriorityDefault, + ) require.NoError(t, err) msg := fmt.Sprintf("%s %s %s", v1.EventTypeNormal, eventReasonManualChange, eventMsgFirewallChange) @@ -2050,7 +1870,9 @@ func TestCreateAndUpdateFirewallSucceedsOnXPN(t *testing.T) { "10.0.0.1", ipnet, svc.Spec.Ports, - hosts) + hosts, + firewallPriorityDefault, + ) require.NoError(t, err) msg = fmt.Sprintf("%s %s %s", v1.EventTypeNormal, eventReasonManualChange, eventMsgFirewallChange) @@ -2242,7 +2064,6 @@ func TestExternalLoadBalancerEnsureHttpHealthCheck(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - gce, err := fakeGCECloud(DefaultTestClusterValues()) require.NoError(t, err) c := gce.c.(*cloud.MockGCE) @@ -2273,7 +2094,6 @@ func TestExternalLoadBalancerEnsureHttpHealthCheck(t *testing.T) { } }) } - } func TestMergeHttpHealthChecks(t *testing.T) { @@ -2380,6 +2200,7 @@ func TestFirewallObject(t *testing.T) { Ports: []string{"80"}, }, }, + Priority: 1000, } for _, tc := range []struct { @@ -2466,7 +2287,7 @@ func TestFirewallObject(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - ret, err := gce.firewallObject(fwName, fwDesc, tc.destinationIP, tc.sourceRanges, tc.svcPorts, nil) + ret, err := gce.firewallObject(fwName, fwDesc, tc.destinationIP, tc.sourceRanges, tc.svcPorts, nil, firewallPriorityDefault) require.NoError(t, err) expectedFirewall := tc.expectedFirewall(baseFw) retSrcRanges := sets.NewString(ret.SourceRanges...) @@ -2623,3 +2444,268 @@ func TestEnsureExternalLoadBalancerClass(t *testing.T) { } } } + +func TestFirewallsEqual(t *testing.T) { + t.Parallel() + testCases := []struct { + desc string + a *compute.Firewall + b *compute.Firewall + want bool + }{ + { + desc: "same allow", + a: &compute.Firewall{ + Priority: 1000, + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + b: &compute.Firewall{ + Priority: 1000, + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + want: true, + }, + { + desc: "same allow with different order", + a: &compute.Firewall{ + Priority: 1000, + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + b: &compute.Firewall{ + Priority: 1000, + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"11", "10"}}, + }, + SourceRanges: []string{"2.3.4.5/24", "1.2.3.0/24"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + want: true, + }, + { + desc: "same deny", + a: &compute.Firewall{ + Priority: 999, + Denied: []*compute.FirewallDenied{{IPProtocol: "all"}}, + SourceRanges: []string{"0.0.0.0/0"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + b: &compute.Firewall{ + Priority: 999, + Denied: []*compute.FirewallDenied{{IPProtocol: "all"}}, + SourceRanges: []string{"0.0.0.0/0"}, + DestinationRanges: []string{"12.34.56.78"}, + Description: "abcdef", + }, + want: true, + }, + { + desc: "different_priority", + a: &compute.Firewall{ + Priority: 1000, + }, + b: &compute.Firewall{ + Priority: 999, + }, + want: false, + }, + { + desc: "same_source_ranges", + a: &compute.Firewall{ + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + }, + b: &compute.Firewall{ + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + }, + want: true, + }, + { + desc: "different_source_ranges", + a: &compute.Firewall{ + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/24"}, + }, + b: &compute.Firewall{ + SourceRanges: []string{"1.2.3.0/24", "2.3.4.5/32"}, + }, + want: false, + }, + { + desc: "same_destination_ranges", + a: &compute.Firewall{ + DestinationRanges: []string{"12.34.56.78"}, + }, + b: &compute.Firewall{ + DestinationRanges: []string{"12.34.56.78"}, + }, + want: true, + }, + { + desc: "different_destination_ranges", + a: &compute.Firewall{ + DestinationRanges: []string{"12.34.56.78"}, + }, + b: &compute.Firewall{ + DestinationRanges: []string{"1.2.3.4"}, + }, + want: false, + }, + { + desc: "different_description", + a: &compute.Firewall{ + Description: "cat", + }, + b: &compute.Firewall{ + Description: "dog", + }, + want: false, + }, + { + desc: "same_description", + a: &compute.Firewall{ + Description: "cat", + }, + b: &compute.Firewall{ + Description: "cat", + }, + want: true, + }, + { + desc: "different_protocol", + a: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + }, + b: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "udp", Ports: []string{"10", "11"}}, + }, + }, + want: false, + }, + { + desc: "different_ports", + a: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + }, + b: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "12"}}, + }, + }, + want: false, + }, + { + desc: "different_port_count", + a: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10", "11"}}, + }, + }, + b: &compute.Firewall{ + Allowed: []*compute.FirewallAllowed{ + {IPProtocol: "tcp", Ports: []string{"10"}}, + }, + }, + want: false, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + t.Parallel() + + got, err := firewallsEqual(tC.a, tC.b) + if err != nil { + t.Fatalf("got unexpected err when comparing firewalls %v", err) + } + if got != tC.want { + t.Fatalf("got %v, want %v", got, tC.want) + } + }) + } +} + +func TestEnsureExternalLoadBalancerMetrics(t *testing.T) { + // t.Parallel() // Disable parallel to avoid race with global metrics registry + + vals := DefaultTestClusterValues() + gce, err := fakeGCECloud(vals) + require.NoError(t, err) + + lm, ok := gce.metricsCollector.(*LoadBalancerMetrics) + require.True(t, ok) + + svc := fakeLoadbalancerService("") + svc, err = gce.client.CoreV1().Services(svc.Namespace).Create(context.TODO(), svc, metav1.CreateOptions{}) + require.NoError(t, err) + + nodes, err := createAndInsertNodes(gce, []string{"test-node-1"}, vals.ZoneName) + require.NoError(t, err) + + // Case 1: Success + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + assert.NoError(t, err) + + // We expect 1 success, and deny firewall None (default) + lm.exportNetLBMetrics() + verifyL4NetLBMetric(t, 1, StatusSuccess, DenyFirewallStatusNone) + + // Case 2: Enable deny firewall cleanup + gce.enableL4DenyFirewallRollbackCleanup = true + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + assert.NoError(t, err) + + // We expect 1 success, and deny firewall Disabled + lm.exportNetLBMetrics() + verifyL4NetLBMetric(t, 1, StatusSuccess, DenyFirewallStatusDisabled) + + // Case 3: Enable deny firewall + gce.enableL4DenyFirewallRule = true + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + assert.NoError(t, err) + + // We expect 1 success, and deny firewall IPv4 + lm.exportNetLBMetrics() + verifyL4NetLBMetric(t, 1, StatusSuccess, DenyFirewallStatusIPv4) + + // Case 4: Error on fetch + mockGCE := gce.Compute().(*cloud.MockGCE) + mockGCE.MockFirewalls.GetHook = func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls, options ...cloud.Option) (bool, *compute.Firewall, error) { + return true, nil, fmt.Errorf("error on fetch") + } + _, err = gce.ensureExternalLoadBalancer(vals.ClusterName, vals.ClusterID, svc, nil, nodes) + assert.Error(t, err) + + // We expect 1 error, and deny firewall IPv4 + lm.exportNetLBMetrics() + verifyL4NetLBMetric(t, 1, StatusError, DenyFirewallStatusNone) + + // Clear mock + mockGCE.MockFirewalls.GetHook = nil + + // Case 5: Delete + err = gce.ensureExternalLoadBalancerDeleted(vals.ClusterName, vals.ClusterID, svc) + assert.NoError(t, err) + + // Now verify success count is 0 (since we deleted the success service) + lm.exportNetLBMetrics() + verifyL4NetLBMetric(t, 0, StatusError, DenyFirewallStatusNone) +} diff --git a/providers/gce/gce_loadbalancer_metrics.go b/providers/gce/gce_loadbalancer_metrics.go index 3cdbd9c0c9..0941a8d4d5 100644 --- a/providers/gce/gce_loadbalancer_metrics.go +++ b/providers/gce/gce_loadbalancer_metrics.go @@ -43,12 +43,21 @@ var ( }, []string{label}, ) + l4NetLBCount = metrics.NewGaugeVec( + &metrics.GaugeOpts{ + Name: "number_of_l4_netlbs", + Help: "Metric containing the number of NetLBs that can be filtered by feature labels and status", + }, + []string{"status", "deny_firewall"}, + ) ) // init registers L4 internal loadbalancer usage metrics. func init() { klog.V(3).Infof("Registering Service Controller loadbalancer usage metrics %v", l4ILBCount) legacyregistry.MustRegister(l4ILBCount) + klog.V(3).Infof("Registering Service Controller loadbalancer usage metrics %v", l4NetLBCount) + legacyregistry.MustRegister(l4NetLBCount) } // LoadBalancerMetrics is a cache that contains loadbalancer service resource @@ -56,6 +65,7 @@ func init() { type LoadBalancerMetrics struct { // l4ILBServiceMap is a map of service key and L4 ILB service state. l4ILBServiceMap map[string]L4ILBServiceState + l4NetLBMap map[string]L4NetLBServiceState sync.Mutex } @@ -97,6 +107,10 @@ type loadbalancerMetricsCollector interface { SetL4ILBService(svcKey string, state L4ILBServiceState) // DeleteL4ILBService removes the given L4 ILB service key. DeleteL4ILBService(svcKey string) + // SetL4NetLBService adds/updates L4 NetLB service state for given service key. + SetL4NetLBService(svcKey string, state L4NetLBServiceState) + // DeleteL4NetLBService removes the given L4 NetLB service key. + DeleteL4NetLBService(svcKey string) } // newLoadBalancerMetrics initializes LoadBalancerMetrics and starts a goroutine @@ -104,6 +118,7 @@ type loadbalancerMetricsCollector interface { func newLoadBalancerMetrics() loadbalancerMetricsCollector { return &LoadBalancerMetrics{ l4ILBServiceMap: make(map[string]L4ILBServiceState), + l4NetLBMap: make(map[string]L4NetLBServiceState), } } @@ -140,6 +155,11 @@ func (lm *LoadBalancerMetrics) DeleteL4ILBService(svcKey string) { // export computes and exports loadbalancer usage metrics. func (lm *LoadBalancerMetrics) export() { + lm.exportILBMetrics() + lm.exportNetLBMetrics() +} + +func (lm *LoadBalancerMetrics) exportILBMetrics() { ilbCount := lm.computeL4ILBMetrics() klog.V(5).Infof("Exporting L4 ILB usage metrics: %#v", ilbCount) for feature, count := range ilbCount { @@ -180,3 +200,60 @@ func (lm *LoadBalancerMetrics) computeL4ILBMetrics() map[feature]int { klog.V(4).Info("L4 ILB usage metrics computed.") return counts } + +// L4ServiceStatus denotes the status of the service +type L4ServiceStatus string + +// L4ServiceStatus denotes the status of the service +const ( + StatusSuccess = L4ServiceStatus("Success") + StatusUserError = L4ServiceStatus("UserError") + StatusError = L4ServiceStatus("Error") + StatusPersistentError = L4ServiceStatus("PersistentError") +) + +// DenyFirewallStatus represents IP stack used when the deny firewalls are provisioned. +type DenyFirewallStatus string + +// DenyFirewallStatus represents IP stack used when the deny firewalls are provisioned. +const ( + DenyFirewallStatusUnknown = DenyFirewallStatus("UNKNOWN") // Shouldn't happen, but if it does something is wrong. + DenyFirewallStatusNone = DenyFirewallStatus("") // Case when no firewalls have been provisioned yet or when the feature has not been enabled explicitly + DenyFirewallStatusDisabled = DenyFirewallStatus("DISABLED") // Case to mark when the feature has been enabled then explicitly disabled - for example when the feature is rolled back + DenyFirewallStatusIPv4 = DenyFirewallStatus("IPv4") +) + +type L4NetLBServiceState struct { + Status L4ServiceStatus + DenyFirewall DenyFirewallStatus +} + +// SetL4NetLBService patches information about L4 NetLB +func (lm *LoadBalancerMetrics) SetL4NetLBService(svcKey string, state L4NetLBServiceState) { + lm.Lock() + defer lm.Unlock() + + lm.l4NetLBMap[svcKey] = state +} + +// DeleteL4NetLBService removes the given L4 NetLB service key. +func (lm *LoadBalancerMetrics) DeleteL4NetLBService(svcKey string) { + lm.Lock() + defer lm.Unlock() + + delete(lm.l4NetLBMap, svcKey) +} + +// exportNetLBMetrics computes and exports loadbalancer usage metrics. +func (lm *LoadBalancerMetrics) exportNetLBMetrics() { + lm.Lock() + defer lm.Unlock() + + klog.Info("Exporting L4 NetLB usage metrics for services", "serviceCount", len(lm.l4NetLBMap)) + + l4NetLBCount.Reset() + for _, svcState := range lm.l4NetLBMap { + l4NetLBCount.WithLabelValues(string(svcState.Status), string(svcState.DenyFirewall)).Inc() + } + klog.Info("L4 NetLB usage metrics exported") +} diff --git a/providers/gce/gce_loadbalancer_metrics_test.go b/providers/gce/gce_loadbalancer_metrics_test.go index 50f9a4e5a4..1c0d3bf715 100644 --- a/providers/gce/gce_loadbalancer_metrics_test.go +++ b/providers/gce/gce_loadbalancer_metrics_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "k8s.io/component-base/metrics/testutil" ) func TestComputeL4ILBMetrics(t *testing.T) { @@ -168,3 +169,63 @@ func newL4ILBServiceState(globalAccess, customSubnet, inSuccess bool) L4ILBServi InSuccess: inSuccess, } } + +func TestL4NetLBMetrics(t *testing.T) { + metrics := newLoadBalancerMetrics() + // Cast to *LoadBalancerMetrics to access methods + lbMetrics, ok := metrics.(*LoadBalancerMetrics) + if !ok { + t.Fatalf("Failed to cast loadbalancerMetricsCollector to *LoadBalancerMetrics") + } + + lbMetrics.SetL4NetLBService("svc-success-ipv4", L4NetLBServiceState{ + Status: StatusSuccess, + DenyFirewall: DenyFirewallStatusIPv4, + }) + lbMetrics.SetL4NetLBService("svc-success-ipv4-2", L4NetLBServiceState{ + Status: StatusSuccess, + DenyFirewall: DenyFirewallStatusIPv4, + }) + lbMetrics.SetL4NetLBService("svc-success-disabled", L4NetLBServiceState{ + Status: StatusSuccess, + DenyFirewall: DenyFirewallStatusDisabled, + }) + lbMetrics.SetL4NetLBService("svc-error-none", L4NetLBServiceState{ + Status: StatusError, + DenyFirewall: DenyFirewallStatusNone, + }) + lbMetrics.SetL4NetLBService("svc-user-error-none", L4NetLBServiceState{ + Status: StatusUserError, + DenyFirewall: DenyFirewallStatusNone, + }) + lbMetrics.SetL4NetLBService("svc-persistent-error-none", L4NetLBServiceState{ + Status: StatusPersistentError, + DenyFirewall: DenyFirewallStatusNone, + }) + + // Add keys to be checked for deletion + lbMetrics.SetL4NetLBService("svc-to-delete", L4NetLBServiceState{ + Status: StatusSuccess, + DenyFirewall: DenyFirewallStatusNone, + }) + lbMetrics.DeleteL4NetLBService("svc-to-delete") + + lbMetrics.exportNetLBMetrics() + + verifyL4NetLBMetric(t, 2, StatusSuccess, DenyFirewallStatusIPv4) + verifyL4NetLBMetric(t, 1, StatusSuccess, DenyFirewallStatusDisabled) + verifyL4NetLBMetric(t, 1, StatusError, DenyFirewallStatusNone) + verifyL4NetLBMetric(t, 1, StatusUserError, DenyFirewallStatusNone) + verifyL4NetLBMetric(t, 1, StatusPersistentError, DenyFirewallStatusNone) +} + +func verifyL4NetLBMetric(t *testing.T, expectedCount int, status L4ServiceStatus, denyFirewall DenyFirewallStatus) { + t.Helper() + val, err := testutil.GetGaugeMetricValue(l4NetLBCount.WithLabelValues(string(status), string(denyFirewall))) + if err != nil { + t.Errorf("Failed to get metric value: %v", err) + } + if int(val) != expectedCount { + t.Errorf("Expected count %d but got %d for status %s, denyFirewall %s", expectedCount, int(val), status, denyFirewall) + } +} diff --git a/providers/gce/gce_loadbalancer_naming.go b/providers/gce/gce_loadbalancer_naming.go index 01c8765e94..10ab5e5632 100644 --- a/providers/gce/gce_loadbalancer_naming.go +++ b/providers/gce/gce_loadbalancer_naming.go @@ -26,7 +26,7 @@ import ( "strings" "github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" ) @@ -123,6 +123,13 @@ func MakeFirewallName(name string) string { return fmt.Sprintf("k8s-fw-%s", name) } +// MakeFirewallDenyName returns the name of the deny firewall rule +// used by the GCE L4 NetLBs for blocking all traffic that is not +// defined by the firewall rule. +func MakeFirewallDenyName(name string) string { + return fmt.Sprintf("k8s-fw-%s-deny", name) +} + func makeFirewallDescription(serviceName, ipAddress string) string { return fmt.Sprintf(`{"kubernetes.io/service-name":"%s", "kubernetes.io/service-ip":"%s"}`, serviceName, ipAddress) diff --git a/providers/gce/gce_loadbalancer_naming_test.go b/providers/gce/gce_loadbalancer_naming_test.go new file mode 100644 index 0000000000..7becd45020 --- /dev/null +++ b/providers/gce/gce_loadbalancer_naming_test.go @@ -0,0 +1,88 @@ +//go:build !providerless +// +build !providerless + +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package gce_test + +import ( + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + cloudprovider "k8s.io/cloud-provider" + + "k8s.io/cloud-provider-gcp/providers/gce" +) + +func TestLoadBalancerNames(t *testing.T) { + t.Parallel() + type names struct { + FirewallName string + DenyFirewallName string + } + + testCases := []struct { + desc string + svc *v1.Service + want names + }{ + { + desc: "short_uid", + svc: &v1.Service{ObjectMeta: metav1.ObjectMeta{UID: "shortuidwith19chars"}}, + want: names{ + FirewallName: "k8s-fw-ashortuidwith19chars", + DenyFirewallName: "k8s-fw-ashortuidwith19chars-deny", + }, + }, + { + desc: "long_uid", + svc: &v1.Service{ObjectMeta: metav1.ObjectMeta{UID: "nextremelylonguidwithmorethan32charsthatwillbecutbecauseofaws32charlimitforloadbalancernames"}}, + want: names{ + FirewallName: "k8s-fw-anextremelylonguidwithmorethan32", + DenyFirewallName: "k8s-fw-anextremelylonguidwithmorethan32-deny", + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + t.Parallel() + + lbName := cloudprovider.DefaultLoadBalancerName(tC.svc) + + got := names{ + FirewallName: gce.MakeFirewallName(lbName), + DenyFirewallName: gce.MakeFirewallDenyName(lbName), + } + if diff := cmp.Diff(tC.want, got); diff != "" { + t.Errorf("got != want, (-want, +got):/n%s", diff) + } + + // https://docs.cloud.google.com/compute/docs/naming-resources#resource-name-format + const gcpResourceNameLengthUpperLimit = 63 + v := reflect.ValueOf(got) + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if len(f.String()) > gcpResourceNameLengthUpperLimit || len(f.String()) < 1 { + t.Errorf("unacceptable length of resource name %q in field %q", f.String(), v.Type().Field(i).Name) + } + } + }) + } +} diff --git a/providers/go.mod b/providers/go.mod index 7a26b047bc..a153af8c23 100644 --- a/providers/go.mod +++ b/providers/go.mod @@ -47,6 +47,7 @@ require ( github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/BUILD b/vendor/github.com/google/go-cmp/cmp/cmpopts/BUILD new file mode 100644 index 0000000000..ec43905420 --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/BUILD @@ -0,0 +1,19 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "cmpopts", + srcs = [ + "equate.go", + "ignore.go", + "sort.go", + "struct_filter.go", + "xform.go", + ], + importmap = "k8s.io/cloud-provider-gcp/vendor/github.com/google/go-cmp/cmp/cmpopts", + importpath = "github.com/google/go-cmp/cmp/cmpopts", + visibility = ["//visibility:public"], + deps = [ + "//vendor/github.com/google/go-cmp/cmp", + "//vendor/github.com/google/go-cmp/cmp/internal/function", + ], +) diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/equate.go b/vendor/github.com/google/go-cmp/cmp/cmpopts/equate.go new file mode 100644 index 0000000000..3d8d0cd3ae --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/equate.go @@ -0,0 +1,185 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cmpopts provides common options for the cmp package. +package cmpopts + +import ( + "errors" + "fmt" + "math" + "reflect" + "time" + + "github.com/google/go-cmp/cmp" +) + +func equateAlways(_, _ interface{}) bool { return true } + +// EquateEmpty returns a [cmp.Comparer] option that determines all maps and slices +// with a length of zero to be equal, regardless of whether they are nil. +// +// EquateEmpty can be used in conjunction with [SortSlices] and [SortMaps]. +func EquateEmpty() cmp.Option { + return cmp.FilterValues(isEmpty, cmp.Comparer(equateAlways)) +} + +func isEmpty(x, y interface{}) bool { + vx, vy := reflect.ValueOf(x), reflect.ValueOf(y) + return (x != nil && y != nil && vx.Type() == vy.Type()) && + (vx.Kind() == reflect.Slice || vx.Kind() == reflect.Map) && + (vx.Len() == 0 && vy.Len() == 0) +} + +// EquateApprox returns a [cmp.Comparer] option that determines float32 or float64 +// values to be equal if they are within a relative fraction or absolute margin. +// This option is not used when either x or y is NaN or infinite. +// +// The fraction determines that the difference of two values must be within the +// smaller fraction of the two values, while the margin determines that the two +// values must be within some absolute margin. +// To express only a fraction or only a margin, use 0 for the other parameter. +// The fraction and margin must be non-negative. +// +// The mathematical expression used is equivalent to: +// +// |x-y| ≤ max(fraction*min(|x|, |y|), margin) +// +// EquateApprox can be used in conjunction with [EquateNaNs]. +func EquateApprox(fraction, margin float64) cmp.Option { + if margin < 0 || fraction < 0 || math.IsNaN(margin) || math.IsNaN(fraction) { + panic("margin or fraction must be a non-negative number") + } + a := approximator{fraction, margin} + return cmp.Options{ + cmp.FilterValues(areRealF64s, cmp.Comparer(a.compareF64)), + cmp.FilterValues(areRealF32s, cmp.Comparer(a.compareF32)), + } +} + +type approximator struct{ frac, marg float64 } + +func areRealF64s(x, y float64) bool { + return !math.IsNaN(x) && !math.IsNaN(y) && !math.IsInf(x, 0) && !math.IsInf(y, 0) +} +func areRealF32s(x, y float32) bool { + return areRealF64s(float64(x), float64(y)) +} +func (a approximator) compareF64(x, y float64) bool { + relMarg := a.frac * math.Min(math.Abs(x), math.Abs(y)) + return math.Abs(x-y) <= math.Max(a.marg, relMarg) +} +func (a approximator) compareF32(x, y float32) bool { + return a.compareF64(float64(x), float64(y)) +} + +// EquateNaNs returns a [cmp.Comparer] option that determines float32 and float64 +// NaN values to be equal. +// +// EquateNaNs can be used in conjunction with [EquateApprox]. +func EquateNaNs() cmp.Option { + return cmp.Options{ + cmp.FilterValues(areNaNsF64s, cmp.Comparer(equateAlways)), + cmp.FilterValues(areNaNsF32s, cmp.Comparer(equateAlways)), + } +} + +func areNaNsF64s(x, y float64) bool { + return math.IsNaN(x) && math.IsNaN(y) +} +func areNaNsF32s(x, y float32) bool { + return areNaNsF64s(float64(x), float64(y)) +} + +// EquateApproxTime returns a [cmp.Comparer] option that determines two non-zero +// [time.Time] values to be equal if they are within some margin of one another. +// If both times have a monotonic clock reading, then the monotonic time +// difference will be used. The margin must be non-negative. +func EquateApproxTime(margin time.Duration) cmp.Option { + if margin < 0 { + panic("margin must be a non-negative number") + } + a := timeApproximator{margin} + return cmp.FilterValues(areNonZeroTimes, cmp.Comparer(a.compare)) +} + +func areNonZeroTimes(x, y time.Time) bool { + return !x.IsZero() && !y.IsZero() +} + +type timeApproximator struct { + margin time.Duration +} + +func (a timeApproximator) compare(x, y time.Time) bool { + // Avoid subtracting times to avoid overflow when the + // difference is larger than the largest representable duration. + if x.After(y) { + // Ensure x is always before y + x, y = y, x + } + // We're within the margin if x+margin >= y. + // Note: time.Time doesn't have AfterOrEqual method hence the negation. + return !x.Add(a.margin).Before(y) +} + +// AnyError is an error that matches any non-nil error. +var AnyError anyError + +type anyError struct{} + +func (anyError) Error() string { return "any error" } +func (anyError) Is(err error) bool { return err != nil } + +// EquateErrors returns a [cmp.Comparer] option that determines errors to be equal +// if [errors.Is] reports them to match. The [AnyError] error can be used to +// match any non-nil error. +func EquateErrors() cmp.Option { + return cmp.FilterValues(areConcreteErrors, cmp.Comparer(compareErrors)) +} + +// areConcreteErrors reports whether x and y are types that implement error. +// The input types are deliberately of the interface{} type rather than the +// error type so that we can handle situations where the current type is an +// interface{}, but the underlying concrete types both happen to implement +// the error interface. +func areConcreteErrors(x, y interface{}) bool { + _, ok1 := x.(error) + _, ok2 := y.(error) + return ok1 && ok2 +} + +func compareErrors(x, y interface{}) bool { + xe := x.(error) + ye := y.(error) + return errors.Is(xe, ye) || errors.Is(ye, xe) +} + +// EquateComparable returns a [cmp.Option] that determines equality +// of comparable types by directly comparing them using the == operator in Go. +// The types to compare are specified by passing a value of that type. +// This option should only be used on types that are documented as being +// safe for direct == comparison. For example, [net/netip.Addr] is documented +// as being semantically safe to use with ==, while [time.Time] is documented +// to discourage the use of == on time values. +func EquateComparable(typs ...interface{}) cmp.Option { + types := make(typesFilter) + for _, typ := range typs { + switch t := reflect.TypeOf(typ); { + case !t.Comparable(): + panic(fmt.Sprintf("%T is not a comparable Go type", typ)) + case types[t]: + panic(fmt.Sprintf("%T is already specified", typ)) + default: + types[t] = true + } + } + return cmp.FilterPath(types.filter, cmp.Comparer(equateAny)) +} + +type typesFilter map[reflect.Type]bool + +func (tf typesFilter) filter(p cmp.Path) bool { return tf[p.Last().Type()] } + +func equateAny(x, y interface{}) bool { return x == y } diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/ignore.go b/vendor/github.com/google/go-cmp/cmp/cmpopts/ignore.go new file mode 100644 index 0000000000..fb84d11d70 --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/ignore.go @@ -0,0 +1,206 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cmpopts + +import ( + "fmt" + "reflect" + "unicode" + "unicode/utf8" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/internal/function" +) + +// IgnoreFields returns an [cmp.Option] that ignores fields of the +// given names on a single struct type. It respects the names of exported fields +// that are forwarded due to struct embedding. +// The struct type is specified by passing in a value of that type. +// +// The name may be a dot-delimited string (e.g., "Foo.Bar") to ignore a +// specific sub-field that is embedded or nested within the parent struct. +func IgnoreFields(typ interface{}, names ...string) cmp.Option { + sf := newStructFilter(typ, names...) + return cmp.FilterPath(sf.filter, cmp.Ignore()) +} + +// IgnoreTypes returns an [cmp.Option] that ignores all values assignable to +// certain types, which are specified by passing in a value of each type. +func IgnoreTypes(typs ...interface{}) cmp.Option { + tf := newTypeFilter(typs...) + return cmp.FilterPath(tf.filter, cmp.Ignore()) +} + +type typeFilter []reflect.Type + +func newTypeFilter(typs ...interface{}) (tf typeFilter) { + for _, typ := range typs { + t := reflect.TypeOf(typ) + if t == nil { + // This occurs if someone tries to pass in sync.Locker(nil) + panic("cannot determine type; consider using IgnoreInterfaces") + } + tf = append(tf, t) + } + return tf +} +func (tf typeFilter) filter(p cmp.Path) bool { + if len(p) < 1 { + return false + } + t := p.Last().Type() + for _, ti := range tf { + if t.AssignableTo(ti) { + return true + } + } + return false +} + +// IgnoreInterfaces returns an [cmp.Option] that ignores all values or references of +// values assignable to certain interface types. These interfaces are specified +// by passing in an anonymous struct with the interface types embedded in it. +// For example, to ignore [sync.Locker], pass in struct{sync.Locker}{}. +func IgnoreInterfaces(ifaces interface{}) cmp.Option { + tf := newIfaceFilter(ifaces) + return cmp.FilterPath(tf.filter, cmp.Ignore()) +} + +type ifaceFilter []reflect.Type + +func newIfaceFilter(ifaces interface{}) (tf ifaceFilter) { + t := reflect.TypeOf(ifaces) + if ifaces == nil || t.Name() != "" || t.Kind() != reflect.Struct { + panic("input must be an anonymous struct") + } + for i := 0; i < t.NumField(); i++ { + fi := t.Field(i) + switch { + case !fi.Anonymous: + panic("struct cannot have named fields") + case fi.Type.Kind() != reflect.Interface: + panic("embedded field must be an interface type") + case fi.Type.NumMethod() == 0: + // This matches everything; why would you ever want this? + panic("cannot ignore empty interface") + default: + tf = append(tf, fi.Type) + } + } + return tf +} +func (tf ifaceFilter) filter(p cmp.Path) bool { + if len(p) < 1 { + return false + } + t := p.Last().Type() + for _, ti := range tf { + if t.AssignableTo(ti) { + return true + } + if t.Kind() != reflect.Ptr && reflect.PtrTo(t).AssignableTo(ti) { + return true + } + } + return false +} + +// IgnoreUnexported returns an [cmp.Option] that only ignores the immediate unexported +// fields of a struct, including anonymous fields of unexported types. +// In particular, unexported fields within the struct's exported fields +// of struct types, including anonymous fields, will not be ignored unless the +// type of the field itself is also passed to IgnoreUnexported. +// +// Avoid ignoring unexported fields of a type which you do not control (i.e. a +// type from another repository), as changes to the implementation of such types +// may change how the comparison behaves. Prefer a custom [cmp.Comparer] instead. +func IgnoreUnexported(typs ...interface{}) cmp.Option { + ux := newUnexportedFilter(typs...) + return cmp.FilterPath(ux.filter, cmp.Ignore()) +} + +type unexportedFilter struct{ m map[reflect.Type]bool } + +func newUnexportedFilter(typs ...interface{}) unexportedFilter { + ux := unexportedFilter{m: make(map[reflect.Type]bool)} + for _, typ := range typs { + t := reflect.TypeOf(typ) + if t == nil || t.Kind() != reflect.Struct { + panic(fmt.Sprintf("%T must be a non-pointer struct", typ)) + } + ux.m[t] = true + } + return ux +} +func (xf unexportedFilter) filter(p cmp.Path) bool { + sf, ok := p.Index(-1).(cmp.StructField) + if !ok { + return false + } + return xf.m[p.Index(-2).Type()] && !isExported(sf.Name()) +} + +// isExported reports whether the identifier is exported. +func isExported(id string) bool { + r, _ := utf8.DecodeRuneInString(id) + return unicode.IsUpper(r) +} + +// IgnoreSliceElements returns an [cmp.Option] that ignores elements of []V. +// The discard function must be of the form "func(T) bool" which is used to +// ignore slice elements of type V, where V is assignable to T. +// Elements are ignored if the function reports true. +func IgnoreSliceElements(discardFunc interface{}) cmp.Option { + vf := reflect.ValueOf(discardFunc) + if !function.IsType(vf.Type(), function.ValuePredicate) || vf.IsNil() { + panic(fmt.Sprintf("invalid discard function: %T", discardFunc)) + } + return cmp.FilterPath(func(p cmp.Path) bool { + si, ok := p.Index(-1).(cmp.SliceIndex) + if !ok { + return false + } + if !si.Type().AssignableTo(vf.Type().In(0)) { + return false + } + vx, vy := si.Values() + if vx.IsValid() && vf.Call([]reflect.Value{vx})[0].Bool() { + return true + } + if vy.IsValid() && vf.Call([]reflect.Value{vy})[0].Bool() { + return true + } + return false + }, cmp.Ignore()) +} + +// IgnoreMapEntries returns an [cmp.Option] that ignores entries of map[K]V. +// The discard function must be of the form "func(T, R) bool" which is used to +// ignore map entries of type K and V, where K and V are assignable to T and R. +// Entries are ignored if the function reports true. +func IgnoreMapEntries(discardFunc interface{}) cmp.Option { + vf := reflect.ValueOf(discardFunc) + if !function.IsType(vf.Type(), function.KeyValuePredicate) || vf.IsNil() { + panic(fmt.Sprintf("invalid discard function: %T", discardFunc)) + } + return cmp.FilterPath(func(p cmp.Path) bool { + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + if !mi.Key().Type().AssignableTo(vf.Type().In(0)) || !mi.Type().AssignableTo(vf.Type().In(1)) { + return false + } + k := mi.Key() + vx, vy := mi.Values() + if vx.IsValid() && vf.Call([]reflect.Value{k, vx})[0].Bool() { + return true + } + if vy.IsValid() && vf.Call([]reflect.Value{k, vy})[0].Bool() { + return true + } + return false + }, cmp.Ignore()) +} diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/sort.go b/vendor/github.com/google/go-cmp/cmp/cmpopts/sort.go new file mode 100644 index 0000000000..720f3cdf57 --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/sort.go @@ -0,0 +1,171 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cmpopts + +import ( + "fmt" + "reflect" + "sort" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/internal/function" +) + +// SortSlices returns a [cmp.Transformer] option that sorts all []V. +// The lessOrCompareFunc function must be either +// a less function of the form "func(T, T) bool" or +// a compare function of the format "func(T, T) int" +// which is used to sort any slice with element type V that is assignable to T. +// +// A less function must be: +// - Deterministic: less(x, y) == less(x, y) +// - Irreflexive: !less(x, x) +// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) +// +// A compare function must be: +// - Deterministic: compare(x, y) == compare(x, y) +// - Irreflexive: compare(x, x) == 0 +// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) +// +// The function does not have to be "total". That is, if x != y, but +// less or compare report inequality, their relative order is maintained. +// +// SortSlices can be used in conjunction with [EquateEmpty]. +func SortSlices(lessOrCompareFunc interface{}) cmp.Option { + vf := reflect.ValueOf(lessOrCompareFunc) + if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() { + panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc)) + } + ss := sliceSorter{vf.Type().In(0), vf} + return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort)) +} + +type sliceSorter struct { + in reflect.Type // T + fnc reflect.Value // func(T, T) bool +} + +func (ss sliceSorter) filter(x, y interface{}) bool { + vx, vy := reflect.ValueOf(x), reflect.ValueOf(y) + if !(x != nil && y != nil && vx.Type() == vy.Type()) || + !(vx.Kind() == reflect.Slice && vx.Type().Elem().AssignableTo(ss.in)) || + (vx.Len() <= 1 && vy.Len() <= 1) { + return false + } + // Check whether the slices are already sorted to avoid an infinite + // recursion cycle applying the same transform to itself. + ok1 := sort.SliceIsSorted(x, func(i, j int) bool { return ss.less(vx, i, j) }) + ok2 := sort.SliceIsSorted(y, func(i, j int) bool { return ss.less(vy, i, j) }) + return !ok1 || !ok2 +} +func (ss sliceSorter) sort(x interface{}) interface{} { + src := reflect.ValueOf(x) + dst := reflect.MakeSlice(src.Type(), src.Len(), src.Len()) + for i := 0; i < src.Len(); i++ { + dst.Index(i).Set(src.Index(i)) + } + sort.SliceStable(dst.Interface(), func(i, j int) bool { return ss.less(dst, i, j) }) + ss.checkSort(dst) + return dst.Interface() +} +func (ss sliceSorter) checkSort(v reflect.Value) { + start := -1 // Start of a sequence of equal elements. + for i := 1; i < v.Len(); i++ { + if ss.less(v, i-1, i) { + // Check that first and last elements in v[start:i] are equal. + if start >= 0 && (ss.less(v, start, i-1) || ss.less(v, i-1, start)) { + panic(fmt.Sprintf("incomparable values detected: want equal elements: %v", v.Slice(start, i))) + } + start = -1 + } else if start == -1 { + start = i + } + } +} +func (ss sliceSorter) less(v reflect.Value, i, j int) bool { + vx, vy := v.Index(i), v.Index(j) + vo := ss.fnc.Call([]reflect.Value{vx, vy})[0] + if vo.Kind() == reflect.Bool { + return vo.Bool() + } else { + return vo.Int() < 0 + } +} + +// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be +// a sorted []struct{K, V}. The lessOrCompareFunc function must be either +// a less function of the form "func(T, T) bool" or +// a compare function of the format "func(T, T) int" +// which is used to sort any map with key K that is assignable to T. +// +// Flattening the map into a slice has the property that [cmp.Equal] is able to +// use [cmp.Comparer] options on K or the K.Equal method if it exists. +// +// A less function must be: +// - Deterministic: less(x, y) == less(x, y) +// - Irreflexive: !less(x, x) +// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z) +// - Total: if x != y, then either less(x, y) or less(y, x) +// +// A compare function must be: +// - Deterministic: compare(x, y) == compare(x, y) +// - Irreflexive: compare(x, x) == 0 +// - Transitive: if compare(x, y) < 0 and compare(y, z) < 0, then compare(x, z) < 0 +// - Total: if x != y, then compare(x, y) != 0 +// +// SortMaps can be used in conjunction with [EquateEmpty]. +func SortMaps(lessOrCompareFunc interface{}) cmp.Option { + vf := reflect.ValueOf(lessOrCompareFunc) + if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() { + panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc)) + } + ms := mapSorter{vf.Type().In(0), vf} + return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort)) +} + +type mapSorter struct { + in reflect.Type // T + fnc reflect.Value // func(T, T) bool +} + +func (ms mapSorter) filter(x, y interface{}) bool { + vx, vy := reflect.ValueOf(x), reflect.ValueOf(y) + return (x != nil && y != nil && vx.Type() == vy.Type()) && + (vx.Kind() == reflect.Map && vx.Type().Key().AssignableTo(ms.in)) && + (vx.Len() != 0 || vy.Len() != 0) +} +func (ms mapSorter) sort(x interface{}) interface{} { + src := reflect.ValueOf(x) + outType := reflect.StructOf([]reflect.StructField{ + {Name: "K", Type: src.Type().Key()}, + {Name: "V", Type: src.Type().Elem()}, + }) + dst := reflect.MakeSlice(reflect.SliceOf(outType), src.Len(), src.Len()) + for i, k := range src.MapKeys() { + v := reflect.New(outType).Elem() + v.Field(0).Set(k) + v.Field(1).Set(src.MapIndex(k)) + dst.Index(i).Set(v) + } + sort.Slice(dst.Interface(), func(i, j int) bool { return ms.less(dst, i, j) }) + ms.checkSort(dst) + return dst.Interface() +} +func (ms mapSorter) checkSort(v reflect.Value) { + for i := 1; i < v.Len(); i++ { + if !ms.less(v, i-1, i) { + panic(fmt.Sprintf("partial order detected: want %v < %v", v.Index(i-1), v.Index(i))) + } + } +} +func (ms mapSorter) less(v reflect.Value, i, j int) bool { + vx, vy := v.Index(i).Field(0), v.Index(j).Field(0) + vo := ms.fnc.Call([]reflect.Value{vx, vy})[0] + if vo.Kind() == reflect.Bool { + return vo.Bool() + } else { + return vo.Int() < 0 + } +} diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/struct_filter.go b/vendor/github.com/google/go-cmp/cmp/cmpopts/struct_filter.go new file mode 100644 index 0000000000..ca11a40249 --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/struct_filter.go @@ -0,0 +1,189 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cmpopts + +import ( + "fmt" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" +) + +// filterField returns a new Option where opt is only evaluated on paths that +// include a specific exported field on a single struct type. +// The struct type is specified by passing in a value of that type. +// +// The name may be a dot-delimited string (e.g., "Foo.Bar") to select a +// specific sub-field that is embedded or nested within the parent struct. +func filterField(typ interface{}, name string, opt cmp.Option) cmp.Option { + // TODO: This is currently unexported over concerns of how helper filters + // can be composed together easily. + // TODO: Add tests for FilterField. + + sf := newStructFilter(typ, name) + return cmp.FilterPath(sf.filter, opt) +} + +type structFilter struct { + t reflect.Type // The root struct type to match on + ft fieldTree // Tree of fields to match on +} + +func newStructFilter(typ interface{}, names ...string) structFilter { + // TODO: Perhaps allow * as a special identifier to allow ignoring any + // number of path steps until the next field match? + // This could be useful when a concrete struct gets transformed into + // an anonymous struct where it is not possible to specify that by type, + // but the transformer happens to provide guarantees about the names of + // the transformed fields. + + t := reflect.TypeOf(typ) + if t == nil || t.Kind() != reflect.Struct { + panic(fmt.Sprintf("%T must be a non-pointer struct", typ)) + } + var ft fieldTree + for _, name := range names { + cname, err := canonicalName(t, name) + if err != nil { + panic(fmt.Sprintf("%s: %v", strings.Join(cname, "."), err)) + } + ft.insert(cname) + } + return structFilter{t, ft} +} + +func (sf structFilter) filter(p cmp.Path) bool { + for i, ps := range p { + if ps.Type().AssignableTo(sf.t) && sf.ft.matchPrefix(p[i+1:]) { + return true + } + } + return false +} + +// fieldTree represents a set of dot-separated identifiers. +// +// For example, inserting the following selectors: +// +// Foo +// Foo.Bar.Baz +// Foo.Buzz +// Nuka.Cola.Quantum +// +// Results in a tree of the form: +// +// {sub: { +// "Foo": {ok: true, sub: { +// "Bar": {sub: { +// "Baz": {ok: true}, +// }}, +// "Buzz": {ok: true}, +// }}, +// "Nuka": {sub: { +// "Cola": {sub: { +// "Quantum": {ok: true}, +// }}, +// }}, +// }} +type fieldTree struct { + ok bool // Whether this is a specified node + sub map[string]fieldTree // The sub-tree of fields under this node +} + +// insert inserts a sequence of field accesses into the tree. +func (ft *fieldTree) insert(cname []string) { + if ft.sub == nil { + ft.sub = make(map[string]fieldTree) + } + if len(cname) == 0 { + ft.ok = true + return + } + sub := ft.sub[cname[0]] + sub.insert(cname[1:]) + ft.sub[cname[0]] = sub +} + +// matchPrefix reports whether any selector in the fieldTree matches +// the start of path p. +func (ft fieldTree) matchPrefix(p cmp.Path) bool { + for _, ps := range p { + switch ps := ps.(type) { + case cmp.StructField: + ft = ft.sub[ps.Name()] + if ft.ok { + return true + } + if len(ft.sub) == 0 { + return false + } + case cmp.Indirect: + default: + return false + } + } + return false +} + +// canonicalName returns a list of identifiers where any struct field access +// through an embedded field is expanded to include the names of the embedded +// types themselves. +// +// For example, suppose field "Foo" is not directly in the parent struct, +// but actually from an embedded struct of type "Bar". Then, the canonical name +// of "Foo" is actually "Bar.Foo". +// +// Suppose field "Foo" is not directly in the parent struct, but actually +// a field in two different embedded structs of types "Bar" and "Baz". +// Then the selector "Foo" causes a panic since it is ambiguous which one it +// refers to. The user must specify either "Bar.Foo" or "Baz.Foo". +func canonicalName(t reflect.Type, sel string) ([]string, error) { + var name string + sel = strings.TrimPrefix(sel, ".") + if sel == "" { + return nil, fmt.Errorf("name must not be empty") + } + if i := strings.IndexByte(sel, '.'); i < 0 { + name, sel = sel, "" + } else { + name, sel = sel[:i], sel[i:] + } + + // Type must be a struct or pointer to struct. + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("%v must be a struct", t) + } + + // Find the canonical name for this current field name. + // If the field exists in an embedded struct, then it will be expanded. + sf, _ := t.FieldByName(name) + if !isExported(name) { + // Avoid using reflect.Type.FieldByName for unexported fields due to + // buggy behavior with regard to embeddeding and unexported fields. + // See https://golang.org/issue/4876 for details. + sf = reflect.StructField{} + for i := 0; i < t.NumField() && sf.Name == ""; i++ { + if t.Field(i).Name == name { + sf = t.Field(i) + } + } + } + if sf.Name == "" { + return []string{name}, fmt.Errorf("does not exist") + } + var ss []string + for i := range sf.Index { + ss = append(ss, t.FieldByIndex(sf.Index[:i+1]).Name) + } + if sel == "" { + return ss, nil + } + ssPost, err := canonicalName(sf.Type, sel) + return append(ss, ssPost...), err +} diff --git a/vendor/github.com/google/go-cmp/cmp/cmpopts/xform.go b/vendor/github.com/google/go-cmp/cmp/cmpopts/xform.go new file mode 100644 index 0000000000..25b4bd05bd --- /dev/null +++ b/vendor/github.com/google/go-cmp/cmp/cmpopts/xform.go @@ -0,0 +1,36 @@ +// Copyright 2018, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cmpopts + +import ( + "github.com/google/go-cmp/cmp" +) + +type xformFilter struct{ xform cmp.Option } + +func (xf xformFilter) filter(p cmp.Path) bool { + for _, ps := range p { + if t, ok := ps.(cmp.Transform); ok && t.Option() == xf.xform { + return false + } + } + return true +} + +// AcyclicTransformer returns a [cmp.Transformer] with a filter applied that ensures +// that the transformer cannot be recursively applied upon its own output. +// +// An example use case is a transformer that splits a string by lines: +// +// AcyclicTransformer("SplitLines", func(s string) []string{ +// return strings.Split(s, "\n") +// }) +// +// Had this been an unfiltered [cmp.Transformer] instead, this would result in an +// infinite cycle converting a string to []string to [][]string and so on. +func AcyclicTransformer(name string, xformFunc interface{}) cmp.Option { + xf := xformFilter{cmp.Transformer(name, xformFunc)} + return cmp.FilterPath(xf.filter, xf.xform) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 01bb5b86d0..0162b96789 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -175,6 +175,7 @@ github.com/google/gnostic-models/openapiv3 # github.com/google/go-cmp v0.7.0 ## explicit; go 1.21 github.com/google/go-cmp/cmp +github.com/google/go-cmp/cmp/cmpopts github.com/google/go-cmp/cmp/internal/diff github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function