diff --git a/api/handler/setting/handler.go b/api/handler/setting/handler.go index 0b4d09b..bd8b077 100644 --- a/api/handler/setting/handler.go +++ b/api/handler/setting/handler.go @@ -113,6 +113,13 @@ func ConfigureDomain(c *fiber.Ctx) error { return c.Status(restErr.StatusCode()).JSON(restErr) } + restErr = tlsCertificateService.EnableHttpsRedirects() + if restErr != nil { + logger.Error("ENABLE_HTTPS_REDIRECTS", restErr) + sqlclient.Rollback(txHandle) + return c.Status(restErr.StatusCode()).JSON(restErr) + } + sqlclient.Commit(txHandle) return c.Status(http.StatusOK).JSON(responder.NewResponse(responder.SuccessMessage{Message: "domain configured successfully!"})) } diff --git a/api/handler/setting/handler_test.go b/api/handler/setting/handler_test.go index f40d739..65ffe7d 100644 --- a/api/handler/setting/handler_test.go +++ b/api/handler/setting/handler_test.go @@ -144,6 +144,7 @@ var ( tlsGetTraefikDeploymentFunc func() (*appsv1.Deployment, restErrors.IRestErr) tlsConfigureLetsEncryptFunc func(domain string, resolverNme string, acmeEmail string) restErrors.IRestErr tlsConfigureCustomCertificateFunc func(secretName string) restErrors.IRestErr + tlsEnableHttpsRedirectsFunc func() restErrors.IRestErr ) func (tls tlsCertificateServiceMock) GetTraefikDeployment() (*appsv1.Deployment, restErrors.IRestErr) { @@ -155,6 +156,9 @@ func (tls tlsCertificateServiceMock) ConfigureLetsEncrypt(domain string, resolve func (tls tlsCertificateServiceMock) ConfigureCustomCertificate(secretName string) restErrors.IRestErr { return tlsConfigureCustomCertificateFunc(secretName) } +func (tls tlsCertificateServiceMock) EnableHttpsRedirects() restErrors.IRestErr { + return tlsEnableHttpsRedirectsFunc() +} var ( UserWithTransactionFunc func(txHandle *gorm.DB) user.IService @@ -304,6 +308,9 @@ func TestConfigureDomain(t *testing.T) { tlsConfigureLetsEncryptFunc = func(domain string, resolverNme string, acmeEmail string) restErrors.IRestErr { return nil } + tlsEnableHttpsRedirectsFunc = func() restErrors.IRestErr { + return nil + } networkIdentifiers = func() (ip string, hostName string, restErr restErrors.IRestErr) { return "1223", "", nil } @@ -350,6 +357,9 @@ func TestConfigureDomain(t *testing.T) { tlsConfigureLetsEncryptFunc = func(domain string, resolverNme string, acmeEmail string) restErrors.IRestErr { return nil } + tlsEnableHttpsRedirectsFunc = func() restErrors.IRestErr { + return nil + } networkIdentifiers = func() (ip string, hostName string, restErr restErrors.IRestErr) { return "", "hostname.amazon.com", nil } diff --git a/k8s/tlscertificate/service.go b/k8s/tlscertificate/service.go index 4306a5b..138bd65 100644 --- a/k8s/tlscertificate/service.go +++ b/k8s/tlscertificate/service.go @@ -23,6 +23,7 @@ type TLSCertificate interface { GetTraefikDeployment() (*appsv1.Deployment, restErrors.IRestErr) ConfigureLetsEncrypt(domain string, resolverNme string, acmeEmail string) restErrors.IRestErr ConfigureCustomCertificate(secretName string) restErrors.IRestErr + EnableHttpsRedirects() restErrors.IRestErr } type tlsCertificate struct{} @@ -167,3 +168,48 @@ func (t *tlsCertificate) ConfigureCustomCertificate(secretName string) restError return nil } + +func (t *tlsCertificate) EnableHttpsRedirects() restErrors.IRestErr { + deploy, restErr := t.GetTraefikDeployment() + if restErr != nil { + return restErr + } + + httpRedirctsArgs := []string{ + "--entrypoints.web.http.redirections.entryPoint.to=:443", + "--entrypoints.web.http.redirections.entryPoint.scheme=https", + "--entrypoints.web.http.redirections.entryPoint.permanent=true", + } + + var newArgs []string + for i, container := range deploy.Spec.Template.Spec.Containers { + if container.Name == config.Environment.TraefikDeploymentName { + for _, arg := range httpRedirctsArgs { + if !contains(container.Args, arg) { + newArgs = append(newArgs, arg) + } + } + deploy.Spec.Template.Spec.Containers[i].Args = append(deploy.Spec.Template.Spec.Containers[i].Args, newArgs...) + break + } + } + + if len(newArgs) > 0 { + err := k8sClient.Update(context.Background(), deploy) + if err != nil { + go logger.Warn(t.ConfigureLetsEncrypt, err) + return restErrors.NewInternalServerError(err.Error()) + } + } + + return nil +} + +func contains(slice []string, item string) bool { + for _, s := range slice { + if strings.TrimSpace(s) == strings.TrimSpace(item) { + return true + } + } + return false +}