From b7b551a211e130084b3eb69437b1660edd2363f3 Mon Sep 17 00:00:00 2001 From: "haijian.yang" Date: Thu, 28 May 2026 11:24:43 +0800 Subject: [PATCH] Use secret to create tower client --- api/v1beta1/types.go | 16 + api/v1beta1/zz_generated.deepcopy.go | 25 +- ...tructure.cluster.x-k8s.io_elfclusters.yaml | 16 +- ...tructure.cluster.x-k8s.io_elfmachines.yaml | 2 +- ....cluster.x-k8s.io_elfmachinetemplates.yaml | 2 +- controllers/elfcluster_controller.go | 14 +- controllers/elfcluster_controller_test.go | 13 +- controllers/elfmachine_controller.go | 2 +- .../elfmachine_controller_cloudinit.go | 30 +- .../elfmachine_controller_cloudinit_test.go | 4 +- controllers/elfmachine_controller_gpu_test.go | 2 +- .../elfmachine_controller_resources_test.go | 2 +- controllers/elfmachine_controller_test.go | 8 +- controllers/suite_test.go | 2 +- pkg/cloudtower/tower.go | 277 ++++++++++++++ pkg/cloudtower/tower_test.go | 346 ++++++++++++++++++ pkg/config/vm.go | 2 +- pkg/service/vm.go | 98 ++--- pkg/session/tower.go | 175 --------- pkg/session/tower_test.go | 41 --- templates/cluster-template.yaml | 8 +- test/e2e/tower_test.go | 15 +- 22 files changed, 777 insertions(+), 323 deletions(-) create mode 100644 pkg/cloudtower/tower.go create mode 100644 pkg/cloudtower/tower_test.go delete mode 100644 pkg/session/tower.go delete mode 100644 pkg/session/tower_test.go diff --git a/api/v1beta1/types.go b/api/v1beta1/types.go index 67772d6d..9bcafa67 100644 --- a/api/v1beta1/types.go +++ b/api/v1beta1/types.go @@ -76,6 +76,22 @@ func (z ElfClusterZoneType) ToLower() string { } type Tower struct { + TowerClientConfig `json:",inline"` + + // SecretRef is the reference to the secret containing the tower information. + SecretRef *corev1.SecretReference `json:"secretRef,omitempty"` +} + +func (t *Tower) String() string { + if t.SecretRef != nil { + return fmt.Sprintf("%s/%s", t.SecretRef.Namespace, t.SecretRef.Name) + } + + return t.TowerClientConfig.Server +} + +// TowerClientConfig is the connection information for the tower server. +type TowerClientConfig struct { // Server is address of the tower server. Server string `json:"server,omitempty"` diff --git a/api/v1beta1/zz_generated.deepcopy.go b/api/v1beta1/zz_generated.deepcopy.go index 20affe1e..6cf70250 100644 --- a/api/v1beta1/zz_generated.deepcopy.go +++ b/api/v1beta1/zz_generated.deepcopy.go @@ -63,7 +63,7 @@ func (in *ElfCluster) DeepCopyInto(out *ElfCluster) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - out.Spec = in.Spec + in.Spec.DeepCopyInto(&out.Spec) in.Status.DeepCopyInto(&out.Status) } @@ -120,7 +120,7 @@ func (in *ElfClusterList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ElfClusterSpec) DeepCopyInto(out *ElfClusterSpec) { *out = *in - out.Tower = in.Tower + in.Tower.DeepCopyInto(&out.Tower) out.ControlPlaneEndpoint = in.ControlPlaneEndpoint } @@ -555,6 +555,12 @@ func (in *ResourcesStatus) DeepCopy() *ResourcesStatus { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Tower) DeepCopyInto(out *Tower) { *out = *in + out.TowerClientConfig = in.TowerClientConfig + if in.SecretRef != nil { + in, out := &in.SecretRef, &out.SecretRef + *out = new(v1.SecretReference) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Tower. @@ -567,6 +573,21 @@ func (in *Tower) DeepCopy() *Tower { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *TowerClientConfig) DeepCopyInto(out *TowerClientConfig) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TowerClientConfig. +func (in *TowerClientConfig) DeepCopy() *TowerClientConfig { + if in == nil { + return nil + } + out := new(TowerClientConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *VGPUDeviceSpec) DeepCopyInto(out *VGPUDeviceSpec) { *out = *in diff --git a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfclusters.yaml b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfclusters.yaml index c0bac0b1..6ea7d21e 100644 --- a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfclusters.yaml +++ b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfclusters.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.2 + controller-gen.kubebuilder.io/version: v0.18.0 name: elfclusters.infrastructure.cluster.x-k8s.io spec: group: infrastructure.cluster.x-k8s.io @@ -97,6 +97,20 @@ spec: description: Password is the password used to access the tower server. type: string + secretRef: + description: SecretRef is the reference to the secret containing + the tower information. + properties: + name: + description: name is unique within a namespace to reference + a secret resource. + type: string + namespace: + description: namespace defines the space within which the + secret name must be unique. + type: string + type: object + x-kubernetes-map-type: atomic server: description: Server is address of the tower server. type: string diff --git a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachines.yaml b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachines.yaml index bf8f57b6..b7d82f44 100644 --- a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachines.yaml +++ b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachines.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.2 + controller-gen.kubebuilder.io/version: v0.18.0 name: elfmachines.infrastructure.cluster.x-k8s.io spec: group: infrastructure.cluster.x-k8s.io diff --git a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachinetemplates.yaml b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachinetemplates.yaml index bb4f724a..80cf0e16 100644 --- a/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachinetemplates.yaml +++ b/config/crd/bases/infrastructure.cluster.x-k8s.io_elfmachinetemplates.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.2 + controller-gen.kubebuilder.io/version: v0.18.0 name: elfmachinetemplates.infrastructure.cluster.x-k8s.io spec: group: infrastructure.cluster.x-k8s.io diff --git a/controllers/elfcluster_controller.go b/controllers/elfcluster_controller.go index 766adbaf..1dc9cbfc 100644 --- a/controllers/elfcluster_controller.go +++ b/controllers/elfcluster_controller.go @@ -149,7 +149,7 @@ func (r *ElfClusterReconciler) Reconcile(ctx goctx.Context, req ctrl.Request) (_ // If ElfCluster is being deleting and ForceDeleteCluster flag is set, skip creating the VMService object, // because Tower server may be out of service. So we can force delete ElfCluster. if elfCluster.ObjectMeta.DeletionTimestamp.IsZero() || !elfCluster.HasForceDeleteCluster() { - vmService, err := r.NewVMService(ctx, elfCluster.GetTower(), log) + vmService, err := r.NewVMService(ctx, r.Client, elfCluster.GetTower(), log) if err != nil { conditions.MarkFalse(&elfCluster, infrav1.TowerAvailableCondition, infrav1.TowerUnreachableReason, clusterv1.ConditionSeverityError, err.Error()) @@ -283,25 +283,25 @@ func (r *ElfClusterReconciler) cleanOrphanLabels(ctx goctx.Context, clusterCtx * log := ctrl.LoggerFrom(ctx) // Locking ensures that only one coroutine cleans at the same time - if ok := acquireLockForGCTowerLabels(clusterCtx.ElfCluster.Spec.Tower.Server); ok { - defer releaseLockForForGCTowerLabels(clusterCtx.ElfCluster.Spec.Tower.Server) + if ok := acquireLockForGCTowerLabels(clusterCtx.ElfCluster.Spec.Tower.String()); ok { + defer releaseLockForForGCTowerLabels(clusterCtx.ElfCluster.Spec.Tower.String()) } else { return } - log.V(1).Info(fmt.Sprintf("Cleaning orphan labels in Tower %s created by CAPE", clusterCtx.ElfCluster.Spec.Tower.Server)) + log.V(1).Info(fmt.Sprintf("Cleaning orphan labels in Tower %s created by CAPE", clusterCtx.ElfCluster.Spec.Tower.String())) keys := []string{towerresources.GetVMLabelClusterName(), towerresources.GetVMLabelVIP(), towerresources.GetVMLabelNamespace()} labelIDs, err := clusterCtx.VMService.CleanUnusedLabels(keys) if err != nil { - log.Error(err, "Warning: failed to clean orphan labels in Tower "+clusterCtx.ElfCluster.Spec.Tower.Server) + log.Error(err, "Warning: failed to clean orphan labels in Tower "+clusterCtx.ElfCluster.Spec.Tower.String()) return } - recordGCTimeForTowerLabels(clusterCtx.ElfCluster.Spec.Tower.Server) + recordGCTimeForTowerLabels(clusterCtx.ElfCluster.Spec.Tower.String()) - log.V(1).Info(fmt.Sprintf("Labels of Tower %s are cleaned successfully", clusterCtx.ElfCluster.Spec.Tower.Server), "labelCount", len(labelIDs)) + log.V(1).Info(fmt.Sprintf("Labels of Tower %s are cleaned successfully", clusterCtx.ElfCluster.Spec.Tower.String()), "labelCount", len(labelIDs)) } func (r *ElfClusterReconciler) reconcileNormal(ctx goctx.Context, clusterCtx *context.ClusterContext) (reconcile.Result, error) { //nolint:unparam diff --git a/controllers/elfcluster_controller_test.go b/controllers/elfcluster_controller_test.go index a03c9ada..bd7b1292 100644 --- a/controllers/elfcluster_controller_test.go +++ b/controllers/elfcluster_controller_test.go @@ -35,6 +35,7 @@ import ( clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" capiutil "sigs.k8s.io/cluster-api/util" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/event" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" @@ -71,7 +72,7 @@ var _ = Describe("ElfClusterReconciler", func() { // mock mockCtrl = gomock.NewController(GinkgoT()) mockVMService = mock_services.NewMockVMService(mockCtrl) - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, nil } }) @@ -218,7 +219,7 @@ var _ = Describe("ElfClusterReconciler", func() { }) It("should delete failed when tower is out of service", func() { - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, errors.New("get vm service failed") } ctrlMgrCtx := fake.NewControllerManagerContext(elfCluster, cluster) @@ -234,7 +235,7 @@ var _ = Describe("ElfClusterReconciler", func() { }) It("should force delete when tower is out of service and cluster need to force delete", func() { - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, errors.New("get vm service failed") } elfCluster.Annotations = map[string]string{ @@ -275,16 +276,16 @@ var _ = Describe("ElfClusterReconciler", func() { mockVMService.EXPECT().CleanUnusedLabels(keys).Return(nil, unexpectedError) reconciler := &ElfClusterReconciler{ControllerManagerContext: ctrlMgrCtx, NewVMService: mockNewVMService} reconciler.cleanOrphanLabels(ctx, clusterCtx) - Expect(logBuffer.String()).To(ContainSubstring("Warning: failed to clean orphan labels in Tower " + elfCluster.Spec.Tower.Server)) + Expect(logBuffer.String()).To(ContainSubstring("Warning: failed to clean orphan labels in Tower " + elfCluster.Spec.Tower.String())) logBuffer.Reset() mockVMService.EXPECT().CleanUnusedLabels(keys).Return(nil, nil) reconciler.cleanOrphanLabels(ctx, clusterCtx) - Expect(logBuffer.String()).To(ContainSubstring(fmt.Sprintf("Labels of Tower %s are cleaned successfully", elfCluster.Spec.Tower.Server))) + Expect(logBuffer.String()).To(ContainSubstring(fmt.Sprintf("Labels of Tower %s are cleaned successfully", elfCluster.Spec.Tower.String()))) logBuffer.Reset() reconciler.cleanOrphanLabels(ctx, clusterCtx) - Expect(logBuffer.String()).NotTo(ContainSubstring(fmt.Sprintf("Cleaning orphan labels in Tower %s created by CAPE", elfCluster.Spec.Tower.Server))) + Expect(logBuffer.String()).NotTo(ContainSubstring(fmt.Sprintf("Cleaning orphan labels in Tower %s created by CAPE", elfCluster.Spec.Tower.String()))) }) }) diff --git a/controllers/elfmachine_controller.go b/controllers/elfmachine_controller.go index 7770245a..886828eb 100644 --- a/controllers/elfmachine_controller.go +++ b/controllers/elfmachine_controller.go @@ -198,7 +198,7 @@ func (r *ElfMachineReconciler) Reconcile(ctx goctx.Context, req ctrl.Request) (r // If ElfMachine is being deleting and ElfCLuster ForceDeleteCluster flag is set, skip creating the VMService object, // because Tower server may be out of service. So we can force delete ElfCluster. if elfMachine.ObjectMeta.DeletionTimestamp.IsZero() || !elfCluster.HasForceDeleteCluster() { - vmService, err := r.NewVMService(ctx, elfCluster.GetTower(), log) + vmService, err := r.NewVMService(ctx, r.Client, elfCluster.GetTower(), log) if err != nil { conditions.MarkFalse(&elfMachine, infrav1.TowerAvailableCondition, infrav1.TowerUnreachableReason, clusterv1.ConditionSeverityError, err.Error()) diff --git a/controllers/elfmachine_controller_cloudinit.go b/controllers/elfmachine_controller_cloudinit.go index 09a00c2a..a9d8fa29 100644 --- a/controllers/elfmachine_controller_cloudinit.go +++ b/controllers/elfmachine_controller_cloudinit.go @@ -5,7 +5,6 @@ import ( goctx "context" "fmt" "io" - "strconv" "strings" "github.com/pkg/errors" @@ -17,6 +16,7 @@ const ( kubeadmAPIVersionV1Beta3 = "kubeadm.k8s.io/v1beta3" kubeadmAPIVersionV1Beta4 = "kubeadm.k8s.io/v1beta4" kubeadmProviderIDValue = "elf://{{ ds.meta_data.instance_id }}" + yamlTagString = "!!str" ) type cloudInitMutationContext struct { @@ -107,7 +107,7 @@ func ensureKubeadmConfigInWriteFiles(root *yaml.Node, mutationCtx cloudInitMutat continue } - updated, err := ensureKubeadmConfigContent(content, mutationCtx.hostName) + updated, err := ensureKubeadmConfigContent(content) if err != nil { return false, err } @@ -119,8 +119,8 @@ func ensureKubeadmConfigInWriteFiles(root *yaml.Node, mutationCtx cloudInitMutat return changed, nil } -func ensureKubeadmConfigContent(content *yaml.Node, hostname string) (bool, error) { - if content.Kind != yaml.ScalarNode || (content.Tag != "" && content.Tag != "!!str") { +func ensureKubeadmConfigContent(content *yaml.Node) (bool, error) { + if content.Kind != yaml.ScalarNode || (content.Tag != "" && content.Tag != yamlTagString) { return false, nil } @@ -130,7 +130,7 @@ func ensureKubeadmConfigContent(content *yaml.Node, hostname string) (bool, erro return false, nil } - if !ensureKubeadmNodeRegistrationDocuments(documents, hostname) { + if !ensureKubeadmNodeRegistrationDocuments(documents) { return false, nil } @@ -139,20 +139,20 @@ func ensureKubeadmConfigContent(content *yaml.Node, hostname string) (bool, erro return false, errors.Wrap(err, "failed to marshal kubeadm config after ensuring provider-id") } - content.Tag = "!!str" + content.Tag = yamlTagString content.Value = marshaled return true, nil } -func ensureKubeadmNodeRegistrationDocuments(documents []*yaml.Node, hostname string) bool { +func ensureKubeadmNodeRegistrationDocuments(documents []*yaml.Node) bool { changed := false for _, document := range documents { if !isKubeadmNodeRegistrationDocument(document) { continue } - if ensureKubeadmNodeRegistration(document, hostname) { + if ensureKubeadmNodeRegistration(document) { changed = true } } @@ -169,7 +169,7 @@ func isKubeadmNodeRegistrationDocument(root *yaml.Node) bool { } } -func ensureKubeadmNodeRegistration(root *yaml.Node, hostname string) bool { +func ensureKubeadmNodeRegistration(root *yaml.Node) bool { changed := false nodeRegistration, _ := ensureYAMLMappingValue(root, "nodeRegistration") if ensureKubeletProviderID(root, nodeRegistration) { @@ -368,7 +368,7 @@ func upsertYAMLMapString(parent *yaml.Node, key, value string) bool { return false } - existing.Tag = "!!str" + existing.Tag = yamlTagString existing.Value = value return true @@ -392,14 +392,6 @@ func upsertNamedValueSequenceItem(sequence *yaml.Node, name, value string) bool return true } -func newBoolYAMLNode(value bool) *yaml.Node { - return &yaml.Node{ - Kind: yaml.ScalarNode, - Tag: "!!bool", - Value: strconv.FormatBool(value), - } -} - func newScalarNodes(values []string) []*yaml.Node { nodes := make([]*yaml.Node, 0, len(values)) for _, value := range values { @@ -420,7 +412,7 @@ func newNamedValueMappingNode(name, value string) *yaml.Node { func newStringYAMLNode(value string) *yaml.Node { return &yaml.Node{ Kind: yaml.ScalarNode, - Tag: "!!str", + Tag: yamlTagString, Value: value, } } diff --git a/controllers/elfmachine_controller_cloudinit_test.go b/controllers/elfmachine_controller_cloudinit_test.go index fffa86b7..07aa363c 100644 --- a/controllers/elfmachine_controller_cloudinit_test.go +++ b/controllers/elfmachine_controller_cloudinit_test.go @@ -23,7 +23,7 @@ kind: KubeProxyConfiguration metricsBindAddress: 0.0.0.0:10249 `) - changed, err := ensureKubeadmConfigContent(content, "") + changed, err := ensureKubeadmConfigContent(content) if err != nil { t.Fatalf("ensureKubeadmConfigContent() error = %v", err) } @@ -73,7 +73,7 @@ nodeRegistration: value: "0" `) - changed, err := ensureKubeadmConfigContent(content, "") + changed, err := ensureKubeadmConfigContent(content) if err != nil { t.Fatalf("ensureKubeadmConfigContent() error = %v", err) } diff --git a/controllers/elfmachine_controller_gpu_test.go b/controllers/elfmachine_controller_gpu_test.go index e7e1ee01..aee754c1 100644 --- a/controllers/elfmachine_controller_gpu_test.go +++ b/controllers/elfmachine_controller_gpu_test.go @@ -73,7 +73,7 @@ var _ = Describe("ElfMachineReconciler-GPU", func() { // mock mockCtrl = gomock.NewController(GinkgoT()) mockVMService = mock_services.NewMockVMService(mockCtrl) - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, nil } }) diff --git a/controllers/elfmachine_controller_resources_test.go b/controllers/elfmachine_controller_resources_test.go index 89b67fd4..7d044a00 100644 --- a/controllers/elfmachine_controller_resources_test.go +++ b/controllers/elfmachine_controller_resources_test.go @@ -71,7 +71,7 @@ var _ = Describe("ElfMachineReconciler", func() { // mock mockCtrl = gomock.NewController(GinkgoT()) mockVMService = mock_services.NewMockVMService(mockCtrl) - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, nil } }) diff --git a/controllers/elfmachine_controller_test.go b/controllers/elfmachine_controller_test.go index ebe88ad9..b4667e59 100644 --- a/controllers/elfmachine_controller_test.go +++ b/controllers/elfmachine_controller_test.go @@ -93,7 +93,7 @@ var _ = Describe("ElfMachineReconciler", func() { // mock mockCtrl = gomock.NewController(GinkgoT()) mockVMService = mock_services.NewMockVMService(mockCtrl) - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, nil } @@ -2537,7 +2537,7 @@ var _ = Describe("ElfMachineReconciler", func() { }) It("should delete ElfMachine when tower is out of service and cluster need to force delete", func() { - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, errors.New("get vm service failed") } elfCluster.Annotations = map[string]string{ @@ -2558,7 +2558,7 @@ var _ = Describe("ElfMachineReconciler", func() { }) It("should delete ElfMachine failed when tower is out of service", func() { - mockNewVMService = func(_ goctx.Context, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { + mockNewVMService = func(_ goctx.Context, _ client.Client, _ infrav1.Tower, _ logr.Logger) (service.VMService, error) { return mockVMService, errors.New("get vm service failed") } ctrlMgrCtx := fake.NewControllerManagerContext(elfCluster, cluster, elfMachine, machine, secret, md) @@ -4115,7 +4115,7 @@ var _ = Describe("ElfMachineReconciler", func() { }) }) -func waitStaticIPAllocationSpec(mockNewVMService func(ctx goctx.Context, auth infrav1.Tower, logger logr.Logger) (service.VMService, error), +func waitStaticIPAllocationSpec(mockNewVMService service.NewVMServiceFunc, elfCluster *infrav1.ElfCluster, cluster *clusterv1.Cluster, elfMachine *infrav1.ElfMachine, machine *clusterv1.Machine, secret *corev1.Secret, md *clusterv1.MachineDeployment) { ctrlMgrCtx := fake.NewControllerManagerContext(elfCluster, cluster, elfMachine, machine, secret, md) diff --git a/controllers/suite_test.go b/controllers/suite_test.go index 827d3c0e..69a9df11 100644 --- a/controllers/suite_test.go +++ b/controllers/suite_test.go @@ -142,7 +142,7 @@ func setup() { }() // Setting ConnectionCreationRetryInterval to 2 seconds, otherwise client creation is // only retried every 30s. If we get unlucky tests are then failing with timeout. - clusterCache.(interface{ SetConnectionCreationRetryInterval(time.Duration) }). + clusterCache.(interface{ SetConnectionCreationRetryInterval(interval time.Duration) }). SetConnectionCreationRetryInterval(2 * time.Second) if err := AddClusterControllerToManager(ctx, testEnv.GetControllerManagerContext(), testEnv.Manager, controllerOpts); err != nil { diff --git a/pkg/cloudtower/tower.go b/pkg/cloudtower/tower.go new file mode 100644 index 00000000..f6f5991d --- /dev/null +++ b/pkg/cloudtower/tower.go @@ -0,0 +1,277 @@ +/* +Copyright 2022. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloudtower + +import ( + "bytes" + goctx "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/go-logr/logr" + httptransport "github.com/go-openapi/runtime/client" + "github.com/go-openapi/strfmt" + "github.com/pkg/errors" + towerclient "github.com/smartxworks/cloudtower-go-sdk/v2/client" + "github.com/smartxworks/cloudtower-go-sdk/v2/client/user" + "github.com/smartxworks/cloudtower-go-sdk/v2/models" + "golang.org/x/sync/singleflight" + corev1 "k8s.io/api/core/v1" + apitypes "k8s.io/apimachinery/pkg/types" + utilyaml "k8s.io/apimachinery/pkg/util/yaml" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + infrav1 "github.com/smartxworks/cluster-api-provider-elf/api/v1beta1" + annotationsutil "github.com/smartxworks/cluster-api-provider-elf/pkg/util/annotations" +) + +// constants. +const ( + // CloudTowerServerVersionAnnotation is the annotation identifying the version of cloud tower server configuration. + CloudTowerServerVersionAnnotation = "cape.infrastructure.cluster.x-k8s.io/cloud-server-version" + + CloudTowerServerVersion1_0_0 = "v1.0.0" +) + +var lastGCTime = time.Now() +var gcMinInterval = 10 * time.Minute +var cacheIdleTime = 15 * time.Minute + +// global cache map against cache keys. +// It stores both Tower clients and parsed Tower client configs from immutable Secrets. +var cacheMap sync.Map + +var towerSecretConfigGroup singleflight.Group + +type cacheItem struct { + LastUsedTime time.Time + TowerClient *towerclient.Cloudtower + TowerConfig *infrav1.TowerClientConfig +} + +// NewTowerClient gets a cached client or creates a new one if one does not +// already exist. +func NewTowerClient(ctx goctx.Context, k8sClient client.Client, tower infrav1.Tower) (*towerclient.Cloudtower, error) { + clientConfig, err := GetTowerClientConfig(ctx, k8sClient, tower) + if err != nil { + return nil, err + } + + logger := ctrl.LoggerFrom(ctx).WithName("client").WithValues("server", clientConfig.Server, "username", clientConfig.Username, "source", clientConfig.AuthMode) + + defer func() { + if lastGCTime.Add(gcMinInterval).Before(time.Now()) { + cleanupCache(logger) + } + }() + + clientKey := getTowerClientCacheKey(clientConfig) + if item, ok := loadCacheItem(clientKey); ok && item.TowerClient != nil { + logger.V(3).Info("found active cached tower client") + + return item.TowerClient, nil + } + + client, err := createTowerClient(httptransport.TLSClientOptions{ + InsecureSkipVerify: clientConfig.SkipTLSVerify, + }, towerclient.ClientConfig{ + Host: clientConfig.Server, + BasePath: "/v2/api", + Schemes: []string{"https"}, + }, towerclient.UserConfig{ + Name: clientConfig.Username, + Password: clientConfig.Password, + Source: models.UserSource(clientConfig.AuthMode), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create tower client") + } + + // Cache the client. + cacheMap.Store(clientKey, &cacheItem{LastUsedTime: time.Now(), TowerClient: client}) + logger.V(3).Info("cached tower client") + + return client, nil +} + +func createTowerClient(tlsOpts httptransport.TLSClientOptions, clientConfig towerclient.ClientConfig, userConfig towerclient.UserConfig) (*towerclient.Cloudtower, error) { + transport := httptransport.New(clientConfig.Host, clientConfig.BasePath, clientConfig.Schemes) + roundTripper, err := httptransport.TLSTransport(tlsOpts) + if err != nil { + return nil, err + } + + // For Arcfra vendor, we need to bypass the whitelist for AOC(CloudTower) + rtWithHeader := NewWithHeaderRoundTripper(roundTripper) + rtWithHeader.Set("x-bypass-whitelist", "true") //nolint:canonicalheader + transport.Transport = rtWithHeader + + client := towerclient.New(transport, strfmt.Default) + params := user.NewLoginParams() + params.WithTimeout(10 * time.Second) + params.RequestBody = &models.LoginInput{ + Username: &userConfig.Name, + Password: &userConfig.Password, + Source: userConfig.Source.Pointer(), + } + resp, err := client.User.Login(params) + if err != nil { + return nil, err + } + transport.DefaultAuthentication = httptransport.APIKeyAuth("Authorization", "header", *resp.Payload.Data.Token) + return client, nil +} + +func cleanupCache(logger logr.Logger) { + cacheMap.Range(func(key interface{}, value interface{}) bool { + item := value.(*cacheItem) + if item.LastUsedTime.Add(cacheIdleTime).Before(time.Now()) { + cacheMap.Delete(key) + logger.V(3).Info(fmt.Sprintf("delete inactive tower cache %s from cacheMap", key)) + } + + return true + }) + + lastGCTime = time.Now() +} + +// ClearClientCache removes all cached Tower clients and client configs. +func ClearClientCache() { + cacheMap = sync.Map{} + lastGCTime = time.Now() +} + +func GetTowerClientConfig(ctx goctx.Context, k8sClient client.Client, tower infrav1.Tower) (*infrav1.TowerClientConfig, error) { + if tower.SecretRef == nil { + return &tower.TowerClientConfig, nil + } + + secretKey := apitypes.NamespacedName{Namespace: tower.SecretRef.Namespace, Name: tower.SecretRef.Name} + cacheKey := getTowerSecretCacheKey(secretKey) + if item, ok := loadCacheItem(cacheKey); ok && item.TowerConfig != nil { + return item.TowerConfig, nil + } + + value, err, _ := towerSecretConfigGroup.Do(cacheKey, func() (interface{}, error) { + if item, ok := loadCacheItem(cacheKey); ok && item.TowerConfig != nil { + return item.TowerConfig, nil + } + + var secret corev1.Secret + if err := k8sClient.Get(ctx, secretKey, &secret); err != nil { + return nil, errors.Wrapf(err, "failed to get tower secret %s", secretKey.String()) + } + + config, err := ParseTowerClientConfigFromSecret(&secret) + if err != nil { + return nil, err + } + + // Cache the config if the server version is annotated. + if annotationsutil.HasAnnotation(&secret, CloudTowerServerVersionAnnotation) { + cacheMap.Store(cacheKey, &cacheItem{LastUsedTime: time.Now(), TowerConfig: config}) + } + + return config, nil + }) + if err != nil { + return nil, err + } + + return value.(*infrav1.TowerClientConfig), nil +} + +func loadCacheItem(cacheKey string) (*cacheItem, bool) { + value, ok := cacheMap.Load(cacheKey) + if !ok { + return nil, false + } + + item := value.(*cacheItem) + item.LastUsedTime = time.Now() + + return item, true +} + +func ParseTowerClientConfigFromSecret(secret *corev1.Secret) (*infrav1.TowerClientConfig, error) { + data, ok := secret.Data["cloudtower.yaml"] + if !ok { + return nil, errors.Errorf("tower secret %s missing cloudtower.yaml", client.ObjectKeyFromObject(secret)) + } + + decoder := utilyaml.NewYAMLOrJSONDecoder(bytes.NewReader(data), 1024) + var config infrav1.TowerClientConfig + if err := decoder.Decode(&config); err != nil { + return nil, errors.Wrapf(err, "failed to decode cloudtower.yaml in tower secret %s", client.ObjectKeyFromObject(secret)) + } + + return &config, nil +} + +func getTowerClientCacheKey(tower *infrav1.TowerClientConfig) string { + return "tower-client:" + getClientKey(tower) +} + +func getTowerSecretCacheKey(secretKey apitypes.NamespacedName) string { + return "tower-secret:" + secretKey.String() +} + +func getClientKey(tower *infrav1.TowerClientConfig) string { + encryptedTower := *tower + sum256 := sha256.Sum256([]byte(tower.Password)) + encryptedTower.Password = hex.EncodeToString(sum256[:]) + key, err := json.Marshal(encryptedTower) + if err != nil { + return fmt.Sprintf("%v", encryptedTower) + } + + return string(key) +} + +type withHeaderRoundTripper struct { + http.Header + + rt http.RoundTripper +} + +func NewWithHeaderRoundTripper(rt http.RoundTripper) withHeaderRoundTripper { + if rt == nil { + rt = http.DefaultTransport + } + + return withHeaderRoundTripper{Header: make(http.Header), rt: rt} +} + +func (h withHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if len(h.Header) == 0 { + return h.rt.RoundTrip(req) + } + + req = req.Clone(req.Context()) + for k, v := range h.Header { + req.Header[k] = v + } + + return h.rt.RoundTrip(req) +} diff --git a/pkg/cloudtower/tower_test.go b/pkg/cloudtower/tower_test.go new file mode 100644 index 00000000..7e27a893 --- /dev/null +++ b/pkg/cloudtower/tower_test.go @@ -0,0 +1,346 @@ +package cloudtower + +import ( + goctx "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/onsi/gomega" + towerclient "github.com/smartxworks/cloudtower-go-sdk/v2/client" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apitypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + infrav1 "github.com/smartxworks/cluster-api-provider-elf/api/v1beta1" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func resetTowerCache(t *testing.T) { + t.Helper() + t.Cleanup(func() { + ClearClientCache() + }) +} + +func TestClearClientCache(t *testing.T) { + t.Run("should clear client cache", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + clientKey := getTowerClientCacheKey(&infrav1.TowerClientConfig{}) + cachedClient := &towerclient.Cloudtower{} + cacheMap.Store(clientKey, &cacheItem{TowerClient: cachedClient}) + + ClearClientCache() + + isEmpty := true + cacheMap.Range(func(key, value any) bool { + isEmpty = true + return false + }) + g.Expect(isEmpty).To(gomega.BeTrue()) + }) +} + +func TestNewTowerClient(t *testing.T) { + t.Run("should get cached session and clear inactive session", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + + lastGCTime = time.Now().Add(-gcMinInterval - time.Second) + tower := infrav1.Tower{TowerClientConfig: infrav1.TowerClientConfig{Server: "127.0.0.1", Username: "tower", Password: "tower"}} + inactiveTowerConfig := tower.TowerClientConfig + inactiveTowerConfig.Username = "inactive" + invalidTower := tower.DeepCopy() + invalidTower.Username = "invalid" + + clientKey := getTowerClientCacheKey(&tower.TowerClientConfig) + cachedClient := &towerclient.Cloudtower{} + cacheMap.Store(clientKey, &cacheItem{TowerClient: cachedClient}) + inactiveClientKey := getTowerClientCacheKey(&inactiveTowerConfig) + cacheMap.Store(inactiveClientKey, &cacheItem{TowerClient: &towerclient.Cloudtower{}, LastUsedTime: time.Now().Add(-cacheIdleTime - time.Second)}) + inactiveSecretKey := getTowerSecretCacheKey(apitypes.NamespacedName{Namespace: "default", Name: "inactive-secret"}) + cacheMap.Store(inactiveSecretKey, &cacheItem{TowerConfig: &infrav1.TowerClientConfig{}, LastUsedTime: time.Now().Add(-cacheIdleTime - time.Second)}) + + client, err := NewTowerClient(goctx.Background(), nil, tower) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(client).To(gomega.Equal(cachedClient)) + + _, ok := cacheMap.Load(clientKey) + g.Expect(ok).To(gomega.BeTrue()) + _, ok = cacheMap.Load(inactiveClientKey) + g.Expect(ok).To(gomega.BeFalse()) + _, ok = cacheMap.Load(inactiveSecretKey) + g.Expect(ok).To(gomega.BeFalse()) + + client, err = NewTowerClient(goctx.Background(), nil, *invalidTower) + g.Expect(client).To(gomega.BeNil()) + g.Expect(err).To(gomega.HaveOccurred()) + }) +} + +func TestGetTowerClientConfig(t *testing.T) { + t.Run("returns inline config without requiring secret lookup", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + + tower := infrav1.Tower{TowerClientConfig: infrav1.TowerClientConfig{ + Server: "127.0.0.1", + Username: "tower", + Password: "tower-password", + AuthMode: "LDAP", + SkipTLSVerify: true, + }} + + config, err := GetTowerClientConfig(goctx.Background(), nil, tower) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(config).To(gomega.Equal(&tower.TowerClientConfig)) + }) + + t.Run("reads secret config and reuses cached immutable secret", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + ctx := goctx.Background() + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "cloudtower-server", + Namespace: "sks-system", + Annotations: map[string]string{ + CloudTowerServerVersionAnnotation: CloudTowerServerVersion1_0_0, + }, + }, + Data: map[string][]byte{ + "cloudtower.yaml": []byte("authMode: LOCAL\npassword: K5yt3hcjtUE4Teqe\nserver: 10.255.0.4\nskipTLSVerify: true\nusername: system-service\n"), + }, + } + k8sClient := fake.NewClientBuilder().WithObjects(secret).Build() + + tower := infrav1.Tower{ + TowerClientConfig: infrav1.TowerClientConfig{Server: "127.0.0.1", Username: "ignored", Password: "ignored"}, + SecretRef: &corev1.SecretReference{Name: "cloudtower-server", Namespace: "sks-system"}, + } + config, err := GetTowerClientConfig(ctx, k8sClient, tower) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(config).To(gomega.Equal(&infrav1.TowerClientConfig{ + Server: "10.255.0.4", + Username: "system-service", + Password: "K5yt3hcjtUE4Teqe", + AuthMode: "LOCAL", + SkipTLSVerify: true, + })) + + cacheKey := getTowerSecretCacheKey(apitypes.NamespacedName{Namespace: "sks-system", Name: "cloudtower-server"}) + cached, ok := loadCacheItem(cacheKey) + g.Expect(ok).To(gomega.BeTrue()) + g.Expect(cached.TowerConfig).To(gomega.Equal(config)) + + config, err = GetTowerClientConfig(ctx, fake.NewClientBuilder().Build(), tower) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(config.Server).To(gomega.Equal("10.255.0.4")) + }) + + t.Run("does not cache secret config without server version annotation", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + ctx := goctx.Background() + secret := newTowerSecret("cloudtower-server", "sks-system", "10.255.0.4", "system-service", "K5yt3hcjtUE4Teqe") + k8sClient := fake.NewClientBuilder().WithObjects(secret).Build() + + tower := infrav1.Tower{SecretRef: &corev1.SecretReference{Name: "cloudtower-server", Namespace: "sks-system"}} + config, err := GetTowerClientConfig(ctx, k8sClient, tower) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(config.Server).To(gomega.Equal("10.255.0.4")) + + cacheKey := getTowerSecretCacheKey(apitypes.NamespacedName{Namespace: "sks-system", Name: "cloudtower-server"}) + _, ok := cacheMap.Load(cacheKey) + g.Expect(ok).To(gomega.BeFalse()) + + _, err = GetTowerClientConfig(ctx, fake.NewClientBuilder().Build(), tower) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("failed to get tower secret sks-system/cloudtower-server")) + }) + + t.Run("uses namespace and name to isolate cached secret configs", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + ctx := goctx.Background() + secretA := newTowerSecretWithVersion("cloudtower-server", "namespace-a", "10.255.0.4", "user-a", "password-a") + secretB := newTowerSecretWithVersion("cloudtower-server", "namespace-b", "10.255.0.5", "user-b", "password-b") + k8sClient := fake.NewClientBuilder().WithObjects(secretA, secretB).Build() + + configA, err := GetTowerClientConfig(ctx, k8sClient, infrav1.Tower{SecretRef: &corev1.SecretReference{Name: "cloudtower-server", Namespace: "namespace-a"}}) + g.Expect(err).ToNot(gomega.HaveOccurred()) + configB, err := GetTowerClientConfig(ctx, k8sClient, infrav1.Tower{SecretRef: &corev1.SecretReference{Name: "cloudtower-server", Namespace: "namespace-b"}}) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(configA.Server).To(gomega.Equal("10.255.0.4")) + g.Expect(configA.Username).To(gomega.Equal("user-a")) + g.Expect(configB.Server).To(gomega.Equal("10.255.0.5")) + g.Expect(configB.Username).To(gomega.Equal("user-b")) + }) + + t.Run("returns contextual errors for missing secret", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + resetTowerCache(t) + + _, err := GetTowerClientConfig(goctx.Background(), fake.NewClientBuilder().Build(), infrav1.Tower{ + SecretRef: &corev1.SecretReference{Name: "missing", Namespace: "sks-system"}, + }) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("failed to get tower secret sks-system/missing")) + }) +} + +func TestParseTowerClientConfigFromSecret(t *testing.T) { + t.Run("parses yaml and json cloudtower config", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + yamlConfig, err := ParseTowerClientConfigFromSecret(&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "yaml-secret", Namespace: "sks-system"}, + Data: map[string][]byte{ + "cloudtower.yaml": []byte("authMode: LOCAL\npassword: yaml-password\nserver: 10.255.0.4\nskipTLSVerify: true\nusername: yaml-user\n"), + }, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(yamlConfig).To(gomega.Equal(&infrav1.TowerClientConfig{Server: "10.255.0.4", Username: "yaml-user", Password: "yaml-password", AuthMode: "LOCAL", SkipTLSVerify: true})) + + jsonConfig, err := ParseTowerClientConfigFromSecret(&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "json-secret", Namespace: "sks-system"}, + Data: map[string][]byte{ + "cloudtower.yaml": []byte(`{"authMode":"LDAP","password":"json-password","server":"10.255.0.5","skipTLSVerify":false,"username":"json-user"}`), + }, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(jsonConfig).To(gomega.Equal(&infrav1.TowerClientConfig{Server: "10.255.0.5", Username: "json-user", Password: "json-password", AuthMode: "LDAP"})) + }) + + t.Run("rejects missing cloudtower yaml key", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + _, err := ParseTowerClientConfigFromSecret(&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "invalid-secret", Namespace: "sks-system"}, + Data: map[string][]byte{"other.yaml": []byte("server: 10.255.0.4\n")}, + }) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("tower secret sks-system/invalid-secret missing cloudtower.yaml")) + }) + + t.Run("rejects malformed cloudtower yaml", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + _, err := ParseTowerClientConfigFromSecret(&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "malformed-secret", Namespace: "sks-system"}, + Data: map[string][]byte{"cloudtower.yaml": []byte("server: [unterminated\n")}, + }) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("failed to decode cloudtower.yaml in tower secret sks-system/malformed-secret")) + }) +} + +func TestTowerCacheKeys(t *testing.T) { + t.Run("client cache key hashes password without exposing plaintext", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + config := &infrav1.TowerClientConfig{Server: "10.255.0.4", Username: "system-service", Password: "super-secret", AuthMode: "LOCAL", SkipTLSVerify: true} + key := getTowerClientCacheKey(config) + + g.Expect(key).To(gomega.HavePrefix("tower-client:")) + g.Expect(key).ToNot(gomega.ContainSubstring("super-secret")) + g.Expect(key).To(gomega.ContainSubstring("system-service")) + g.Expect(config.Password).To(gomega.Equal("super-secret")) + }) + + t.Run("client cache key changes when password changes", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + base := &infrav1.TowerClientConfig{Server: "10.255.0.4", Username: "system-service", Password: "password-a", AuthMode: "LOCAL"} + changedPassword := *base + changedPassword.Password = "password-b" + + g.Expect(getTowerClientCacheKey(base)).ToNot(gomega.Equal(getTowerClientCacheKey(&changedPassword))) + }) + + t.Run("secret cache key includes namespace and name", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + g.Expect(getTowerSecretCacheKey(apitypes.NamespacedName{Namespace: "sks-system", Name: "cloudtower-server"})).To(gomega.Equal("tower-secret:sks-system/cloudtower-server")) + }) +} + +func TestWithHeaderRoundTripper(t *testing.T) { + t.Run("uses default transport when nil delegate is provided", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + rt := NewWithHeaderRoundTripper(nil) + + g.Expect(rt.rt).ToNot(gomega.BeNil()) + }) + + t.Run("does not clone or mutate headers when no configured headers exist", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + called := false + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + req.Header.Set("Existing", "value") + + rt := NewWithHeaderRoundTripper(roundTripFunc(func(got *http.Request) (*http.Response, error) { + called = true + g.Expect(got).To(gomega.BeIdenticalTo(req)) + g.Expect(got.Header.Get("Existing")).To(gomega.Equal("value")) + return &http.Response{StatusCode: http.StatusNoContent, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil + })) + + resp, err := rt.RoundTrip(req) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + g.Expect(resp.StatusCode).To(gomega.Equal(http.StatusNoContent)) + g.Expect(called).To(gomega.BeTrue()) + }) + + t.Run("adds configured headers to a cloned request", func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + req.Header.Set("Existing", "value") + + rt := NewWithHeaderRoundTripper(roundTripFunc(func(got *http.Request) (*http.Response, error) { + g.Expect(got).ToNot(gomega.BeIdenticalTo(req)) + g.Expect(got.Header.Get("Existing")).To(gomega.Equal("value")) + g.Expect(got.Header.Values("X-Test")).To(gomega.Equal([]string{"a", "b"})) + return &http.Response{StatusCode: http.StatusAccepted, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil + })) + rt.Set("X-Test", "a") + rt.Add("X-Test", "b") + + resp, err := rt.RoundTrip(req) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + g.Expect(resp.StatusCode).To(gomega.Equal(http.StatusAccepted)) + g.Expect(req.Header.Values("X-Test")).To(gomega.BeEmpty()) + }) +} + +func newTowerSecret(name, namespace, server, username, password string) *corev1.Secret { + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}, + Data: map[string][]byte{ + "cloudtower.yaml": []byte("authMode: LOCAL\npassword: " + password + "\nserver: " + server + "\nskipTLSVerify: true\nusername: " + username + "\n"), + }, + } +} + +func newTowerSecretWithVersion(name, namespace, server, username, password string) *corev1.Secret { + secret := newTowerSecret(name, namespace, server, username, password) + secret.Annotations = map[string]string{ + CloudTowerServerVersionAnnotation: CloudTowerServerVersion1_0_0, + } + return secret +} diff --git a/pkg/config/vm.go b/pkg/config/vm.go index eb8c42ee..8a9f0cec 100644 --- a/pkg/config/vm.go +++ b/pkg/config/vm.go @@ -18,7 +18,7 @@ package config const ( // VMDescription is the default description in a VM. - VMDescription = "Automatically created Kubernetes node by server %s." + VMDescription = "Automatically created Kubernetes node." ) // MaxConcurrentVMCreations is the maximum number of concurrent virtual machine creations. diff --git a/pkg/service/vm.go b/pkg/service/vm.go index 7c3e27d6..07515b95 100644 --- a/pkg/service/vm.go +++ b/pkg/service/vm.go @@ -23,6 +23,7 @@ import ( "github.com/go-logr/logr" "github.com/pkg/errors" + towerclient "github.com/smartxworks/cloudtower-go-sdk/v2/client" clientcluster "github.com/smartxworks/cloudtower-go-sdk/v2/client/cluster" clientvmtemplate "github.com/smartxworks/cloudtower-go-sdk/v2/client/content_library_vm_template" clientgpu "github.com/smartxworks/cloudtower-go-sdk/v2/client/gpu_device" @@ -38,10 +39,11 @@ import ( clientzone "github.com/smartxworks/cloudtower-go-sdk/v2/client/zone" "github.com/smartxworks/cloudtower-go-sdk/v2/models" "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" infrav1 "github.com/smartxworks/cluster-api-provider-elf/api/v1beta1" + "github.com/smartxworks/cluster-api-provider-elf/pkg/cloudtower" "github.com/smartxworks/cluster-api-provider-elf/pkg/config" - "github.com/smartxworks/cluster-api-provider-elf/pkg/session" annotationsutil "github.com/smartxworks/cluster-api-provider-elf/pkg/util/annotations" ) @@ -102,20 +104,20 @@ type VMService interface { GetVMGPUAllocationInfo(id string) (*models.VMGpuInfo, error) } -type NewVMServiceFunc func(ctx goctx.Context, auth infrav1.Tower, logger logr.Logger) (VMService, error) +type NewVMServiceFunc func(ctx goctx.Context, k8sClient client.Client, tower infrav1.Tower, logger logr.Logger) (VMService, error) -func NewVMService(ctx goctx.Context, auth infrav1.Tower, logger logr.Logger) (VMService, error) { - authSession, err := session.GetOrCreate(ctx, auth) +func NewVMService(ctx goctx.Context, k8sClient client.Client, tower infrav1.Tower, logger logr.Logger) (VMService, error) { + towerClient, err := cloudtower.NewTowerClient(ctx, k8sClient, tower) if err != nil { return nil, err } - return &TowerVMService{authSession, logger}, nil + return &TowerVMService{towerClient, logger}, nil } type TowerVMService struct { - Session *session.TowerSession `json:"session"` - Logger logr.Logger `json:"logger"` + Client *towerclient.Cloudtower `json:"towerClient"` + Logger logr.Logger `json:"logger"` } func (svr *TowerVMService) UpdateVM(vm *models.VM, elfMachine *infrav1.ElfMachine) (*models.WithTaskVM, error) { @@ -135,7 +137,7 @@ func (svr *TowerVMService) UpdateVM(vm *models.VM, elfMachine *infrav1.ElfMachin Where: &models.VMWhereInput{ID: TowerString(*vm.ID)}, } - updateVMResp, err := svr.Session.VM.UpdateVM(updateVMParams) + updateVMResp, err := svr.Client.VM.UpdateVM(updateVMParams) if err != nil { return nil, err } @@ -150,7 +152,7 @@ func (svr *TowerVMService) GetVMDisks(vmDiskIDs []string) ([]*models.VMDisk, err OrderBy: models.NewVMDiskOrderByInput(models.VMDiskOrderByInputBootASC), } - getVMDisksResp, err := svr.Session.VMDisk.GetVMDisks(getVMDisksParams) + getVMDisksResp, err := svr.Client.VMDisk.GetVMDisks(getVMDisksParams) if err != nil { return nil, err } @@ -164,7 +166,7 @@ func (svr *TowerVMService) GetVMVolume(volumeID string) (*models.VMVolume, error Where: &models.VMVolumeWhereInput{ID: TowerString(volumeID)}, } - getVMVolumesResp, err := svr.Session.VMVolume.GetVMVolumes(getVMVolumesParams) + getVMVolumesResp, err := svr.Client.VMVolume.GetVMVolumes(getVMVolumesParams) if err != nil { return nil, err } @@ -185,7 +187,7 @@ func (svr *TowerVMService) ResizeVMVolume(vmVolumeID string, size int64) (*model Where: &models.VMVolumeWhereInput{ID: TowerString(vmVolumeID)}, } - updateVMVolumeResp, err := svr.Session.VMVolume.UpdateVMVolume(updateVMVolumeParams) + updateVMVolumeResp, err := svr.Client.VMVolume.UpdateVMVolume(updateVMVolumeParams) if err != nil { return nil, err } @@ -213,7 +215,7 @@ func (svr *TowerVMService) Clone( createVMFromContentLibraryTemplateParams := clientvm.NewCreateVMFromContentLibraryTemplateParams() createVMFromContentLibraryTemplateParams.RequestBody = []*models.VMCreateVMFromContentLibraryTemplateParams{createVMFromTemplateParams} - createVMFromTemplateResp, err := svr.Session.VM.CreateVMFromContentLibraryTemplate(createVMFromContentLibraryTemplateParams) + createVMFromTemplateResp, err := svr.Client.VM.CreateVMFromContentLibraryTemplate(createVMFromContentLibraryTemplateParams) if err != nil { return nil, err } @@ -393,7 +395,7 @@ func (svr *TowerVMService) createVMFromTemplateParams( ClusterID: cluster.ID, HostID: TowerString(hostID), Name: TowerString(elfMachine.Name), - Description: TowerString(fmt.Sprintf(config.VMDescription, elfCluster.Spec.Tower.Server)), + Description: TowerString(config.VMDescription), Owner: owner, Vcpu: vCPU, CPUCores: cpuSocketCores, @@ -423,7 +425,7 @@ func (svr *TowerVMService) Migrate(vmID, hostID string) (*models.WithTaskVM, err }, } - migrateVMResp, err := svr.Session.VM.MigrateVM(migrateVMParams) + migrateVMResp, err := svr.Client.VM.MigrateVM(migrateVMParams) if err != nil { return nil, err } @@ -440,7 +442,7 @@ func (svr *TowerVMService) Delete(id string) (*models.Task, error) { }, } - deleteVMResp, err := svr.Session.VM.DeleteVM(deleteVMParams) + deleteVMResp, err := svr.Client.VM.DeleteVM(deleteVMParams) if err != nil { return nil, err } @@ -461,7 +463,7 @@ func (svr *TowerVMService) PowerOff(id string) (*models.Task, error) { }, } - poweroffVMResp, err := svr.Session.VM.PoweroffVM(poweroffVMParams) + poweroffVMResp, err := svr.Client.VM.PoweroffVM(poweroffVMParams) if err != nil { return nil, err } @@ -489,7 +491,7 @@ func (svr *TowerVMService) PowerOn(id string, hostID string) (*models.Task, erro startVMParams.RequestBody.Data = &models.VMStartParamsData{HostID: TowerString(hostID)} } - startVMResp, err := svr.Session.VM.StartVM(startVMParams) + startVMResp, err := svr.Client.VM.StartVM(startVMParams) if err != nil { return nil, err } @@ -510,7 +512,7 @@ func (svr *TowerVMService) ShutDown(id string) (*models.Task, error) { }, } - shutDownVMResp, err := svr.Session.VM.ShutDownVM(shutDownVMParams) + shutDownVMResp, err := svr.Client.VM.ShutDownVM(shutDownVMParams) if err != nil { return nil, err } @@ -531,7 +533,7 @@ func (svr *TowerVMService) RemoveGPUDevices(id string, gpus []*models.VMGpuOpera }, } - temoveVMGPUDeviceResp, err := svr.Session.VM.RemoveVMGpuDevice(removeVMGpuDeviceParams) + temoveVMGPUDeviceResp, err := svr.Client.VM.RemoveVMGpuDevice(removeVMGpuDeviceParams) if err != nil { return nil, err } @@ -560,7 +562,7 @@ func (svr *TowerVMService) AddGPUDevices(id string, gpuDeviceInfos []*GPUDeviceI }, } - addVMGpuDeviceResp, err := svr.Session.VM.AddVMGpuDevice(addVMGpuDeviceParams) + addVMGpuDeviceResp, err := svr.Client.VM.AddVMGpuDevice(addVMGpuDeviceParams) if err != nil { return nil, err } @@ -581,7 +583,7 @@ func (svr *TowerVMService) Get(id string) (*models.VM, error) { }, } - getVmsResp, err := svr.Session.VM.GetVms(getVmsParams) + getVmsResp, err := svr.Client.VM.GetVms(getVmsParams) if err != nil { return nil, err } @@ -602,7 +604,7 @@ func (svr *TowerVMService) GetByName(name string) (*models.VM, error) { }, } - getVmsResp, err := svr.Session.VM.GetVms(getVmsParams) + getVmsResp, err := svr.Client.VM.GetVms(getVmsParams) if err != nil { return nil, err } @@ -627,7 +629,7 @@ func (svr *TowerVMService) FindByIDs(ids []string) ([]*models.VM, error) { }, } - getVmsResp, err := svr.Session.VM.GetVms(getVmsParams) + getVmsResp, err := svr.Client.VM.GetVms(getVmsParams) if err != nil { return nil, err } @@ -648,7 +650,7 @@ func (svr *TowerVMService) FindVMsByName(name string) ([]*models.VM, error) { }, } - getVmsResp, err := svr.Session.VM.GetVms(getVmsParams) + getVmsResp, err := svr.Client.VM.GetVms(getVmsParams) if err != nil { return nil, err } @@ -668,7 +670,7 @@ func (svr *TowerVMService) GetVMNics(vmID string) ([]*models.VMNic, error) { OrderBy: models.NewVMNicOrderByInput(models.VMNicOrderByInputOrderASC), } - getVMNicsResp, err := svr.Session.VMNic.GetVMNics(getVMNicsParams) + getVMNicsResp, err := svr.Client.VMNic.GetVMNics(getVMNicsParams) if err != nil { return nil, err } @@ -687,7 +689,7 @@ func (svr *TowerVMService) AddVMNics(vmID string, nics []*models.VMNicParams) (* }, } - addVMNicResp, err := svr.Session.VM.AddVMNic(addVMNicParams) + addVMNicResp, err := svr.Client.VM.AddVMNic(addVMNicParams) if err != nil { return nil, err } @@ -708,7 +710,7 @@ func (svr *TowerVMService) GetCluster(id string) (*models.Cluster, error) { }, } - getClustersResp, err := svr.Session.Cluster.GetClusters(getClustersParams) + getClustersResp, err := svr.Client.Cluster.GetClusters(getClustersParams) if err != nil { return nil, err } @@ -730,7 +732,7 @@ func (svr *TowerVMService) GetClusterZones(clusterID string) ([]*models.Zone, er }, } - getZonesResp, err := svr.Session.Zone.GetZones(getZonesParams) + getZonesResp, err := svr.Client.Zone.GetZones(getZonesParams) if err != nil { return nil, err } @@ -746,7 +748,7 @@ func (svr *TowerVMService) GetHost(id string) (*models.Host, error) { }, } - getHostsResp, err := svr.Session.Host.GetHosts(getHostsParams) + getHostsResp, err := svr.Client.Host.GetHosts(getHostsParams) if err != nil { return nil, err } @@ -768,7 +770,7 @@ func (svr *TowerVMService) GetHostsByCluster(clusterID string) (Hosts, error) { }, } - getHostsResp, err := svr.Session.Host.GetHosts(getHostsParams) + getHostsResp, err := svr.Client.Host.GetHosts(getHostsParams) if err != nil { return nil, err } @@ -789,7 +791,7 @@ func (svr *TowerVMService) GetVlan(id string) (*models.Vlan, error) { }, } - getVlansResp, err := svr.Session.Vlan.GetVlans(getVlansParams) + getVlansResp, err := svr.Client.Vlan.GetVlans(getVlansParams) if err != nil { return nil, err } @@ -820,7 +822,7 @@ func (svr *TowerVMService) GetVMTemplate(template string) (*models.ContentLibrar }, } - getVMTemplatesResp, err := svr.Session.ContentLibraryVMTemplate.GetContentLibraryVMTemplates(getVMTemplatesParams) + getVMTemplatesResp, err := svr.Client.ContentLibraryVMTemplate.GetContentLibraryVMTemplates(getVMTemplatesParams) if err != nil { return nil, err } @@ -849,7 +851,7 @@ func (svr *TowerVMService) GetTask(id string) (*models.Task, error) { }, } - getTasksResp, err := svr.Session.Task.GetTasks(getTasksParams) + getTasksResp, err := svr.Client.Task.GetTasks(getTasksParams) if err != nil { return nil, err } @@ -894,7 +896,7 @@ func (svr *TowerVMService) UpsertLabel(key, value string) (*models.Label, error) Value: TowerString(value), }, } - getLabelResp, err := svr.Session.Label.GetLabels(getLabelParams) + getLabelResp, err := svr.Client.Label.GetLabels(getLabelParams) if err != nil { return nil, err } @@ -906,7 +908,7 @@ func (svr *TowerVMService) UpsertLabel(key, value string) (*models.Label, error) createLabelParams.RequestBody = []*models.LabelCreationParams{ {Key: &key, Value: &value}, } - createLabelResp, err := svr.Session.Label.CreateLabel(createLabelParams) + createLabelResp, err := svr.Client.Label.CreateLabel(createLabelParams) if err != nil { return nil, err } @@ -936,7 +938,7 @@ func (svr *TowerVMService) DeleteLabel(key, value string, strict bool) (string, ) } - deleteLabelResp, err := svr.Session.Label.DeleteLabel(deleteLabelParams) + deleteLabelResp, err := svr.Client.Label.DeleteLabel(deleteLabelParams) if err != nil { return "", err } @@ -960,7 +962,7 @@ func (svr *TowerVMService) CleanUnusedLabels(keys []string) ([]string, error) { }, } - deleteLabelResp, err := svr.Session.Label.DeleteLabel(deleteLabelParams) + deleteLabelResp, err := svr.Client.Label.DeleteLabel(deleteLabelParams) if err != nil { return nil, err } @@ -986,7 +988,7 @@ func (svr *TowerVMService) AddLabelsToVM(vmID string, labelIds []string) (*model }, }, } - addLabelsResp, err := svr.Session.Label.AddLabelsToResources(addLabelsParams) + addLabelsResp, err := svr.Client.Label.AddLabelsToResources(addLabelsParams) if err != nil { return nil, err } @@ -1009,7 +1011,7 @@ func (svr *TowerVMService) CreateVMPlacementGroup(name, clusterID string, vmPoli VMVMPolicyEnabled: TowerBool(true), VMVMPolicy: &vmPolicy, }} - createVMPlacementGroupResp, err := svr.Session.VMPlacementGroup.CreateVMPlacementGroup(createVMPlacementGroupParams) + createVMPlacementGroupResp, err := svr.Client.VMPlacementGroup.CreateVMPlacementGroup(createVMPlacementGroupParams) if err != nil { return nil, err } @@ -1026,7 +1028,7 @@ func (svr *TowerVMService) GetVMPlacementGroup(name string) (*models.VMPlacement }, } - getVMPlacementGroupsResp, err := svr.Session.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) + getVMPlacementGroupsResp, err := svr.Client.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) if err != nil { return nil, err } @@ -1057,7 +1059,7 @@ func (svr *TowerVMService) AddVMsToPlacementGroup(placementGroup *models.VMPlace }, } - updateVMPlacementGroupResp, err := svr.Session.VMPlacementGroup.UpdateVMPlacementGroup(updateVMPlacementGroupParams) + updateVMPlacementGroupResp, err := svr.Client.VMPlacementGroup.UpdateVMPlacementGroup(updateVMPlacementGroupParams) if err != nil { return nil, err } @@ -1078,7 +1080,7 @@ func (svr *TowerVMService) DeleteVMPlacementGroupByID(ctx goctx.Context, id stri }, } - getVMPlacementGroupsResp, err := svr.Session.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) + getVMPlacementGroupsResp, err := svr.Client.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) if err != nil { return false, err } @@ -1096,7 +1098,7 @@ func (svr *TowerVMService) DeleteVMPlacementGroupByID(ctx goctx.Context, id stri }, } - if _, err := svr.Session.VMPlacementGroup.DeleteVMPlacementGroup(deleteVMPlacementGroupParams); err != nil { + if _, err := svr.Client.VMPlacementGroup.DeleteVMPlacementGroup(deleteVMPlacementGroupParams); err != nil { return false, err } @@ -1120,7 +1122,7 @@ func (svr *TowerVMService) DeleteVMPlacementGroupsByNamePrefix(ctx goctx.Context }, } - getVMPlacementGroupsResp, err := svr.Session.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) + getVMPlacementGroupsResp, err := svr.Client.VMPlacementGroup.GetVMPlacementGroups(getVMPlacementGroupsParams) if err != nil { return nil, err } else if len(getVMPlacementGroupsResp.Payload) == 0 { @@ -1135,7 +1137,7 @@ func (svr *TowerVMService) DeleteVMPlacementGroupsByNamePrefix(ctx goctx.Context }, } - deleteVMPlacementGroupResp, err := svr.Session.VMPlacementGroup.DeleteVMPlacementGroup(deleteVMPlacementGroupParams) + deleteVMPlacementGroupResp, err := svr.Client.VMPlacementGroup.DeleteVMPlacementGroup(deleteVMPlacementGroupParams) if err != nil { return nil, err } @@ -1180,7 +1182,7 @@ func (svr *TowerVMService) GetGPUDevicesAllocationInfoByIDs(gpuIDs []string) (GP }, } - getDetailVMInfoByGpuDevicesResp, err := svr.Session.GpuDevice.GetDetailVMInfoByGpuDevices(getDetailVMInfoByGpuDevicesParams) + getDetailVMInfoByGpuDevicesResp, err := svr.Client.GpuDevice.GetDetailVMInfoByGpuDevices(getDetailVMInfoByGpuDevicesParams) if err != nil { return nil, err } @@ -1209,7 +1211,7 @@ func (svr *TowerVMService) GetGPUDevicesAllocationInfoByHostIDs(hostIDs []string getDetailVMInfoByGpuDevicesParams.RequestBody.Where.AvailableVgpusNumGt = TowerInt32(0) } - getDetailVMInfoByGpuDevicesResp, err := svr.Session.GpuDevice.GetDetailVMInfoByGpuDevices(getDetailVMInfoByGpuDevicesParams) + getDetailVMInfoByGpuDevicesResp, err := svr.Client.GpuDevice.GetDetailVMInfoByGpuDevices(getDetailVMInfoByGpuDevicesParams) if err != nil { return nil, err } @@ -1226,7 +1228,7 @@ func (svr *TowerVMService) GetVMGPUAllocationInfo(id string) (*models.VMGpuInfo, }, } - getVMGpuDeviceInfoResp, err := svr.Session.VM.GetVMGpuDeviceInfo(getVMGpuDeviceInfoParams) + getVMGpuDeviceInfoResp, err := svr.Client.VM.GetVMGpuDeviceInfo(getVMGpuDeviceInfoParams) if err != nil { return nil, err } diff --git a/pkg/session/tower.go b/pkg/session/tower.go deleted file mode 100644 index 824939f9..00000000 --- a/pkg/session/tower.go +++ /dev/null @@ -1,175 +0,0 @@ -/* -Copyright 2022. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package session - -import ( - goctx "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "net/http" - "sync" - "time" - - "github.com/go-logr/logr" - httptransport "github.com/go-openapi/runtime/client" - "github.com/go-openapi/strfmt" - "github.com/pkg/errors" - towerclient "github.com/smartxworks/cloudtower-go-sdk/v2/client" - "github.com/smartxworks/cloudtower-go-sdk/v2/client/user" - "github.com/smartxworks/cloudtower-go-sdk/v2/models" - ctrl "sigs.k8s.io/controller-runtime" - - infrav1 "github.com/smartxworks/cluster-api-provider-elf/api/v1beta1" -) - -var lastGCTime = time.Now() -var gcMinInterval = 10 * time.Minute -var sessionIdleTime = 10 * time.Minute - -// global Session map against sessionKeys -// in map[sessionKey]cacheItem. -var sessionCache sync.Map - -type cacheItem struct { - LastUsedTime time.Time - Session *TowerSession -} - -type TowerSession struct { - *towerclient.Cloudtower -} - -// GetOrCreate gets a cached session or creates a new one if one does not -// already exist. -func GetOrCreate(ctx goctx.Context, tower infrav1.Tower) (*TowerSession, error) { - logger := ctrl.LoggerFrom(ctx).WithName("session").WithValues("server", tower.Server, "username", tower.Username, "source", tower.AuthMode) - - defer func() { - if lastGCTime.Add(gcMinInterval).Before(time.Now()) { - cleanupSessionCache(logger) - lastGCTime = time.Now() - } - }() - - sessionKey := getSessionKey(&tower) - if value, ok := sessionCache.Load(sessionKey); ok { - item := value.(*cacheItem) - item.LastUsedTime = time.Now() - logger.V(3).Info("found active cached tower client session") - - return item.Session, nil - } - - client, err := createTowerClient(httptransport.TLSClientOptions{ - InsecureSkipVerify: tower.SkipTLSVerify, - }, towerclient.ClientConfig{ - Host: tower.Server, - BasePath: "/v2/api", - Schemes: []string{"https"}, - }, towerclient.UserConfig{ - Name: tower.Username, - Password: tower.Password, - Source: models.UserSource(tower.AuthMode), - }) - if err != nil { - return nil, errors.Wrap(err, "failed to create tower client") - } - - session := &TowerSession{client} - - // Cache the session. - sessionCache.Store(sessionKey, &cacheItem{LastUsedTime: time.Now(), Session: session}) - logger.V(3).Info("cached tower client session") - - return session, nil -} - -func createTowerClient(tlsOpts httptransport.TLSClientOptions, clientConfig towerclient.ClientConfig, userConfig towerclient.UserConfig) (*towerclient.Cloudtower, error) { - transport := httptransport.New(clientConfig.Host, clientConfig.BasePath, clientConfig.Schemes) - roundTripper, err := httptransport.TLSTransport(tlsOpts) - if err != nil { - return nil, err - } - - // For Arcfra vendor, we need to bypass the whitelist for AOC(CloudTower) - rtWithHeader := NewWithHeaderRoundTripper(roundTripper) - rtWithHeader.Set("x-bypass-whitelist", "true") //nolint:canonicalheader - transport.Transport = rtWithHeader - - client := towerclient.New(transport, strfmt.Default) - params := user.NewLoginParams() - params.WithTimeout(10 * time.Second) - params.RequestBody = &models.LoginInput{ - Username: &userConfig.Name, - Password: &userConfig.Password, - Source: userConfig.Source.Pointer(), - } - resp, err := client.User.Login(params) - if err != nil { - return nil, err - } - transport.DefaultAuthentication = httptransport.APIKeyAuth("Authorization", "header", *resp.Payload.Data.Token) - return client, nil -} - -func cleanupSessionCache(logger logr.Logger) { - sessionCache.Range(func(key interface{}, value interface{}) bool { - item := value.(*cacheItem) - if item.LastUsedTime.Add(sessionIdleTime).Before(time.Now()) { - sessionCache.Delete(key) - logger.V(3).Info(fmt.Sprintf("delete inactive tower client session %s from sessionCache", key)) - } - - return true - }) -} - -func getSessionKey(tower *infrav1.Tower) string { - encryptedTower := tower.DeepCopy() - sum256 := sha256.Sum256([]byte(tower.Password)) - encryptedTower.Password = hex.EncodeToString(sum256[:]) - - return fmt.Sprintf("%v", encryptedTower) -} - -type withHeaderRoundTripper struct { - http.Header - - rt http.RoundTripper -} - -func NewWithHeaderRoundTripper(rt http.RoundTripper) withHeaderRoundTripper { - if rt == nil { - rt = http.DefaultTransport - } - - return withHeaderRoundTripper{Header: make(http.Header), rt: rt} -} - -func (h withHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(h.Header) == 0 { - return h.rt.RoundTrip(req) - } - - req = req.Clone(req.Context()) - for k, v := range h.Header { - req.Header[k] = v - } - - return h.rt.RoundTrip(req) -} diff --git a/pkg/session/tower_test.go b/pkg/session/tower_test.go deleted file mode 100644 index a08c25d2..00000000 --- a/pkg/session/tower_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package session - -import ( - goctx "context" - "testing" - "time" - - "github.com/onsi/gomega" - - infrav1 "github.com/smartxworks/cluster-api-provider-elf/api/v1beta1" -) - -func TestGetOrCreate(t *testing.T) { - g := gomega.NewGomegaWithT(t) - - t.Run("should get cached session and clear inactive session", func(t *testing.T) { - lastGCTime = lastGCTime.Add(-gcMinInterval) - tower := infrav1.Tower{Server: "127.0.0.1", Username: "tower", Password: "tower"} - inactiveTower := tower.DeepCopy() - inactiveTower.Username = "inactive" - invalidTower := tower.DeepCopy() - invalidTower.Username = "invalid" - - sessionKey := getSessionKey(&tower) - cachedSession := &TowerSession{} - sessionCache.Store(sessionKey, &cacheItem{Session: cachedSession}) - inactiveSessionKey := getSessionKey(inactiveTower) - sessionCache.Store(inactiveSessionKey, &cacheItem{Session: &TowerSession{}, LastUsedTime: time.Now().Add(-sessionIdleTime)}) - - session, err := GetOrCreate(goctx.Background(), tower) - g.Expect(err).ToNot(gomega.HaveOccurred()) - g.Expect(session).To(gomega.Equal(cachedSession)) - - _, ok := sessionCache.Load(inactiveSessionKey) - g.Expect(ok).To(gomega.BeFalse()) - - session, err = GetOrCreate(goctx.Background(), *invalidTower) - g.Expect(session).To(gomega.BeNil()) - g.Expect(err).To(gomega.HaveOccurred()) - }) -} diff --git a/templates/cluster-template.yaml b/templates/cluster-template.yaml index fa150da1..5e0693a4 100644 --- a/templates/cluster-template.yaml +++ b/templates/cluster-template.yaml @@ -27,11 +27,9 @@ metadata: spec: cluster: "${ELF_CLUSTER}" tower: - server: "${TOWER_SERVER}" - username: "${TOWER_USERNAME}" - password: "${TOWER_PASSWORD}" - authMode: ${TOWER_AUTH_MODE:=LOCAL} - skipTLSVerify: ${TOWER_SKIP_TLS_VERIFY:=false} + secretRef: + name: "${TOWER_SECRET_NAME}" + namespace: "${TOWER_SECRET_NAMESPACE}" controlPlaneEndpoint: host: "${CONTROL_PLANE_ENDPOINT_IP}" port: 6443 diff --git a/test/e2e/tower_test.go b/test/e2e/tower_test.go index 412e712e..8410dee5 100644 --- a/test/e2e/tower_test.go +++ b/test/e2e/tower_test.go @@ -47,12 +47,15 @@ func init() { func initTowerSession() { var err error - vmService, err = service.NewVMService(goctx.Background(), infrav1.Tower{ - Server: towerServer, - Username: towerUsername, - Password: towerPassword, - AuthMode: towerAuthMode, - SkipTLSVerify: towerSkipTLSVerify}, ctrllog.Log) + vmService, err = service.NewVMService(goctx.Background(), nil, infrav1.Tower{ + TowerClientConfig: infrav1.TowerClientConfig{ + Server: towerServer, + Username: towerUsername, + Password: towerPassword, + AuthMode: towerAuthMode, + SkipTLSVerify: towerSkipTLSVerify, + }, + }, ctrllog.Log) Expect(err).ShouldNot(HaveOccurred()) template, err := vmService.GetVMTemplate(vmTemplate)