From 808dbbd723aaf31587fb01ce819e31f718430774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Scha=CC=88fer?= <101886095+PeterSchafer@users.noreply.github.com> Date: Tue, 24 Mar 2026 11:25:55 +0100 Subject: [PATCH] chore: Introduce re-usable teardown function --- cliv2/cmd/cliv2/instrumentation.go | 78 +++++++++++ cliv2/cmd/cliv2/main.go | 178 ++++++++++--------------- cliv2/cmd/cliv2/main_test.go | 11 +- cliv2/internal/cliv2/cliv2.go | 19 ++- cliv2/internal/cliv2/cliv2_test.go | 18 +-- cliv2/pkg/basic_workflows/legacycli.go | 4 +- 6 files changed, 170 insertions(+), 138 deletions(-) diff --git a/cliv2/cmd/cliv2/instrumentation.go b/cliv2/cmd/cliv2/instrumentation.go index 09e7e81bf5..546f03cfbe 100644 --- a/cliv2/cmd/cliv2/instrumentation.go +++ b/cliv2/cmd/cliv2/instrumentation.go @@ -4,14 +4,19 @@ package main import _ "github.com/snyk/go-application-framework/pkg/networking/fips_enable" import ( + "context" + "encoding/json" "os/exec" + "strconv" "strings" "time" + "github.com/rs/zerolog" "github.com/snyk/go-application-framework/pkg/analytics" "github.com/snyk/go-application-framework/pkg/configuration" "github.com/snyk/go-application-framework/pkg/instrumentation" + "github.com/snyk/cli/cliv2/internal/constants" cli_utils "github.com/snyk/cli/cliv2/internal/utils" localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows" @@ -74,3 +79,76 @@ func updateInstrumentationDataBeforeSending(cliAnalytics analytics.Analytics, st cliAnalytics.GetInstrumentation().SetStatus(analytics.Failure) } } + +func sendAnalytics(ctx context.Context, a analytics.Analytics, debugLogger *zerolog.Logger) { + debugLogger.Print("Sending Analytics") + + a.SetApiUrl(globalConfiguration.GetString(configuration.API_URL)) + + request, err := a.GetRequest() + if err != nil { + debugLogger.Err(err).Msg("Failed to create Analytics request") + return + } + + // Use context to respect teardown timeout + request = request.WithContext(ctx) + + client := globalEngine.GetNetworkAccess().GetHttpClient() + res, err := client.Do(request) + if err != nil { + debugLogger.Err(err).Msg("Failed to send Analytics") + return + } + defer func() { + _ = res.Body.Close() + }() + + successfullySend := 200 <= res.StatusCode && res.StatusCode < 300 + if successfullySend { + debugLogger.Print("Analytics successfully send") + } else { + debugLogger.Print("Failed to send Analytics:", res.Status) + } +} + +func sendInstrumentation(ctx context.Context, eng workflow.Engine, instrumentor analytics.InstrumentationCollector, logger *zerolog.Logger) { + // Avoid duplicate data to be sent for IDE integrations that use the CLI + if !shallSendInstrumentation(eng.GetConfiguration(), instrumentor) { + logger.Print("This CLI call is not instrumented!") + return + } + + // add temporary static nodejs binary flag, remove once linuxstatic is official + staticNodeJsBinaryBool, parseErr := strconv.ParseBool(constants.StaticNodeJsBinary) + if parseErr != nil { + logger.Print("Failed to parse staticNodeJsBinary:", parseErr) + } else { + // the legacycli:: prefix is added to maintain compatibility with our monitoring dashboard + instrumentor.AddExtension("legacycli::static-nodejs-binary", staticNodeJsBinaryBool) + } + + logger.Print("Sending Instrumentation") + data, err := analytics.GetV2InstrumentationObject(instrumentor, analytics.WithLogger(logger)) + if err != nil { + logger.Err(err).Msg("Failed to derive data object") + } + + v2InstrumentationData := utils.ValueOf(json.Marshal(data)) + localConfiguration := globalConfiguration.Clone() + // the report analytics workflow needs --experimental to run + // we pass the flag here so that we report at every interaction + localConfiguration.Set(configuration.FLAG_EXPERIMENTAL, true) + localConfiguration.Set("inputData", string(v2InstrumentationData)) + _, err = eng.Invoke( + localworkflows.WORKFLOWID_REPORT_ANALYTICS, + workflow.WithConfig(localConfiguration), + workflow.WithContext(ctx), + ) + + if err != nil { + logger.Err(err).Msg("Failed to send Instrumentation") + } else { + logger.Print("Instrumentation successfully sent") + } +} diff --git a/cliv2/cmd/cliv2/main.go b/cliv2/cmd/cliv2/main.go index cb2fb369c0..593b653895 100644 --- a/cliv2/cmd/cliv2/main.go +++ b/cliv2/cmd/cliv2/main.go @@ -11,7 +11,6 @@ import ( "io" "os" "os/exec" - "strconv" "strings" "sync" "time" @@ -75,6 +74,7 @@ import ( var internalOS string var globalEngine workflow.Engine var globalConfiguration configuration.Configuration +var globalContext context.Context var helpProvided bool var noopLogger zerolog.Logger = zerolog.New(io.Discard) @@ -88,6 +88,7 @@ const ( debug_level_flag string = "log-level" integrationNameFlag string = "integration-name" maxNetworkRequestAttempts string = "max-attempts" + teardownTimeout = 5 * time.Second ) type JsonErrorStruct struct { @@ -194,98 +195,33 @@ func runMainWorkflow(config configuration.Configuration, cmd *cobra.Command, arg globalLogger.Print("Running ", name) globalEngine.GetAnalytics().SetCommand(name) - err = runWorkflowAndProcessData(globalEngine, globalLogger, name) + err = runWorkflowAndProcessData(globalContext, globalEngine, globalLogger, name) return err } -func runWorkflowAndProcessData(engine workflow.Engine, logger *zerolog.Logger, name string) error { +func runWorkflowAndProcessData(ctx context.Context, engine workflow.Engine, logger *zerolog.Logger, name string) error { ic := engine.GetAnalytics().GetInstrumentation() - output, err := engine.Invoke(workflow.NewWorkflowIdentifier(name), workflow.WithInstrumentationCollector(ic)) + output, err := engine.Invoke(workflow.NewWorkflowIdentifier(name), workflow.WithContext(ctx), workflow.WithInstrumentationCollector(ic)) if err != nil { logger.Print("Failed to execute the command! ", err) return err } - outputFiltered, err := engine.Invoke(localworkflows.WORKFLOWID_FILTER_FINDINGS, workflow.WithInput(output), workflow.WithInstrumentationCollector(ic)) + outputFiltered, err := engine.Invoke(localworkflows.WORKFLOWID_FILTER_FINDINGS, workflow.WithContext(ctx), workflow.WithInput(output), workflow.WithInstrumentationCollector(ic)) if err != nil { logger.Err(err).Msg(err.Error()) return err } - _, err = engine.Invoke(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW, workflow.WithInput(outputFiltered), workflow.WithInstrumentationCollector(ic)) + _, err = engine.Invoke(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW, workflow.WithContext(ctx), workflow.WithInput(outputFiltered), workflow.WithInstrumentationCollector(ic)) if err == nil { err = getErrorFromWorkFlowData(engine, outputFiltered) } return err } -func sendAnalytics(analytics analytics.Analytics, debugLogger *zerolog.Logger) { - debugLogger.Print("Sending Analytics") - - analytics.SetApiUrl(globalConfiguration.GetString(configuration.API_URL)) - - res, err := analytics.Send() - if err != nil { - debugLogger.Err(err).Msg("Failed to send Analytics") - return - } - defer func() { _ = res.Body.Close() }() - - successfullySend := 200 <= res.StatusCode && res.StatusCode < 300 - if successfullySend { - debugLogger.Print("Analytics successfully send") - } else { - var details string - if res != nil { - details = res.Status - } - - debugLogger.Print("Failed to send Analytics:", details) - } -} - -func sendInstrumentation(eng workflow.Engine, instrumentor analytics.InstrumentationCollector, logger *zerolog.Logger) { - // Avoid duplicate data to be sent for IDE integrations that use the CLI - if !shallSendInstrumentation(eng.GetConfiguration(), instrumentor) { - logger.Print("This CLI call is not instrumented!") - return - } - - // add temporary static nodejs binary flag, remove once linuxstatic is official - staticNodeJsBinaryBool, parseErr := strconv.ParseBool(constants.StaticNodeJsBinary) - if parseErr != nil { - logger.Print("Failed to parse staticNodeJsBinary:", parseErr) - } else { - // the legacycli:: prefix is added to maintain compatibility with our monitoring dashboard - instrumentor.AddExtension("legacycli::static-nodejs-binary", staticNodeJsBinaryBool) - } - - logger.Print("Sending Instrumentation") - data, err := analytics.GetV2InstrumentationObject(instrumentor, analytics.WithLogger(logger)) - if err != nil { - logger.Err(err).Msg("Failed to derive data object") - } - - v2InstrumentationData := utils.ValueOf(json.Marshal(data)) - localConfiguration := globalConfiguration.Clone() - // the report analytics workflow needs --experimental to run - // we pass the flag here so that we report at every interaction - localConfiguration.Set(configuration.FLAG_EXPERIMENTAL, true) - localConfiguration.Set("inputData", string(v2InstrumentationData)) - _, err = eng.InvokeWithConfig( - localworkflows.WORKFLOWID_REPORT_ANALYTICS, - localConfiguration, - ) - - if err != nil { - logger.Err(err).Msg("Failed to send Instrumentation") - } else { - logger.Print("Instrumentation successfully sent") - } -} - func help(_ *cobra.Command, _ []string) error { helpProvided = true args := utils.RemoveSimilar(os.Args[1:], "--") // remove all double dash arguments to avoid issues with the help command @@ -548,11 +484,65 @@ func initExtensions(engine workflow.Engine, config configuration.Configuration) } } +// tearDown handles sending analytics and instrumentation +// It is used both for normal exit and signal-triggered exit +func tearDown(err error, errorList []error, startTime time.Time, ua networking.UserAgentInfo, cliAnalytics analytics.Analytics, networkAccess networking.NetworkAccess) int { + // Create a context with timeout for teardown operations to ensure we don't hang indefinitely + teardownCtx, cancel := context.WithTimeout(context.Background(), teardownTimeout) + defer cancel() + + outputError := err + allErrors := errorList + + if err != nil { + allErrors, outputError = processError(err, errorList) + + for _, tempError := range allErrors { + if tempError != nil { + cliAnalytics.AddError(tempError) + } + } + } + + exitCode := cliv2.DeriveExitCode(outputError) + globalLogger.Printf("Deriving Exit Code %d (cause: %v)", exitCode, outputError) + + displayError(outputError, globalEngine.GetUserInterface(), globalConfiguration, globalContext) + + updateInstrumentationDataBeforeSending(cliAnalytics, startTime, ua, exitCode) + + if !globalConfiguration.GetBool(configuration.ANALYTICS_DISABLED) { + sendAnalytics(teardownCtx, cliAnalytics, globalLogger) + } + sendInstrumentation(teardownCtx, globalEngine, cliAnalytics.GetInstrumentation(), globalLogger) + + // cleanup resources in use + // WARNING: deferred actions will execute AFTER cleanup; only defer if not impacted by this + if _, cleanupErr := globalEngine.Invoke(basic_workflows.WORKFLOWID_GLOBAL_CLEANUP, workflow.WithContext(teardownCtx)); cleanupErr != nil { + globalLogger.Printf("Failed to cleanup %v", cleanupErr) + } + + if globalConfiguration.GetBool(configuration.DEBUG) { + writeLogFooter(exitCode, allErrors, globalConfiguration, networkAccess) + } + + return exitCode +} + func MainWithErrorCode() int { initDebugBuild() errorList := []error{} errorListMutex := sync.Mutex{} + var finalExitCode int + + // preparing the possibility to tearDown from different threads while ensure it is only called once + var tearDownOnce sync.Once + + // init context + ctx := context.Background() + ctx = context.WithValue(ctx, networking.InteractionIdKey, instrumentation.AssembleUrnFromUUID(interactionId)) + globalContext = ctx startTime := time.Now() var err error @@ -633,10 +623,6 @@ func MainWithErrorCode() int { return constants.SNYK_EXIT_CODE_ERROR } - // init context - ctx := context.Background() - ctx = context.WithValue(ctx, networking.InteractionIdKey, instrumentation.AssembleUrnFromUUID(interactionId)) - // add output flags as persistent flags outputWorkflow, _ := globalEngine.GetWorkflow(localworkflows.WORKFLOWID_OUTPUT_WORKFLOW) outputFlags := workflow.FlagsetFromConfigurationOptions(outputWorkflow.GetConfigurationOptions()) @@ -681,43 +667,15 @@ func MainWithErrorCode() int { // ignore } - outputError := err - allErrors := errorList - - if err != nil { - allErrors, outputError = processError(err, errorList) - - for _, tempError := range allErrors { - if tempError != nil { - cliAnalytics.AddError(tempError) - } - } - } - - displayError(outputError, globalEngine.GetUserInterface(), globalConfiguration, ctx) - - exitCode := cliv2.DeriveExitCode(outputError) - globalLogger.Printf("Deriving Exit Code %d (cause: %v)", exitCode, outputError) - - updateInstrumentationDataBeforeSending(cliAnalytics, startTime, ua, exitCode) - - if !globalConfiguration.GetBool(configuration.ANALYTICS_DISABLED) { - sendAnalytics(cliAnalytics, globalLogger) - } - sendInstrumentation(globalEngine, cliAnalytics.GetInstrumentation(), globalLogger) - - // cleanup resources in use - // WARNING: deferred actions will execute AFTER cleanup; only defer if not impacted by this - _, err = globalEngine.Invoke(basic_workflows.WORKFLOWID_GLOBAL_CLEANUP) - if err != nil { - globalLogger.Printf("Failed to cleanup %v", err) - } + tearDownOnce.Do(func() { + errorListMutex.Lock() + errorListCopy := append([]error{}, errorList...) + errorListMutex.Unlock() - if debugEnabled { - writeLogFooter(exitCode, allErrors, globalConfiguration, networkAccess) - } + finalExitCode = tearDown(err, errorListCopy, startTime, ua, cliAnalytics, networkAccess) + }) - return exitCode + return finalExitCode } func processError(err error, errorList []error) ([]error, error) { diff --git a/cliv2/cmd/cliv2/main_test.go b/cliv2/cmd/cliv2/main_test.go index 0a70446a24..f96cebe1e5 100644 --- a/cliv2/cmd/cliv2/main_test.go +++ b/cliv2/cmd/cliv2/main_test.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/json" "errors" "fmt" @@ -466,7 +465,7 @@ func Test_runWorkflowAndProcessData(t *testing.T) { // invoke method under test logger := zerolog.New(os.Stderr) - err = runWorkflowAndProcessData(globalEngine, &logger, testCmnd) + err = runWorkflowAndProcessData(t.Context(), globalEngine, &logger, testCmnd) var expectedError *clierrors.ErrorWithExitCode assert.ErrorAs(t, err, &expectedError) @@ -560,7 +559,7 @@ func Test_runWorkflowAndProcessData_with_Filtering(t *testing.T) { assert.NoError(t, err) logger := zerolog.New(os.Stderr) - err = runWorkflowAndProcessData(globalEngine, &logger, testCmnd) + err = runWorkflowAndProcessData(t.Context(), globalEngine, &logger, testCmnd) } func Test_setTimeout(t *testing.T) { @@ -588,7 +587,7 @@ func Test_displayError(t *testing.T) { userInterface.EXPECT().OutputError(err, gomock.Any()).Times(1) config := configuration.NewWithOpts(configuration.WithAutomaticEnv()) - displayError(err, userInterface, config, context.Background()) + displayError(err, userInterface, config, t.Context()) }) scenarios := []struct { @@ -609,7 +608,7 @@ func Test_displayError(t *testing.T) { t.Run(fmt.Sprintf("%s does not display anything", scenario.name), func(t *testing.T) { config := configuration.NewWithOpts(configuration.WithAutomaticEnv()) err := scenario.err - displayError(err, userInterface, config, context.Background()) + displayError(err, userInterface, config, t.Context()) }) } @@ -618,7 +617,7 @@ func Test_displayError(t *testing.T) { userInterface.EXPECT().OutputError(err, gomock.Any()).Times(1) config := configuration.NewWithOpts(configuration.WithAutomaticEnv()) - displayError(err, userInterface, config, context.Background()) + displayError(err, userInterface, config, t.Context()) }) } diff --git a/cliv2/internal/cliv2/cliv2.go b/cliv2/internal/cliv2/cliv2.go index 98233e462b..bac899a594 100644 --- a/cliv2/internal/cliv2/cliv2.go +++ b/cliv2/internal/cliv2/cliv2.go @@ -263,8 +263,8 @@ func (c *CLI) commandVersion(passthroughArgs []string) error { } } -func (c *CLI) commandAbout(proxyInfo *proxy.ProxyInfo, passthroughArgs []string) error { - err := c.executeV1Default(proxyInfo, passthroughArgs) +func (c *CLI) commandAbout(ctx context.Context, proxyInfo *proxy.ProxyInfo, passthroughArgs []string) error { + err := c.executeV1Default(ctx, proxyInfo, passthroughArgs) if err != nil { return err } @@ -433,14 +433,11 @@ func (c *CLI) PrepareV1Command( return snykCmd, err } -func (c *CLI) executeV1Default(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { +func (c *CLI) executeV1Default(ctx context.Context, proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { timeout := c.globalConfig.GetInt(configuration.TIMEOUT) - var ctx context.Context var cancel context.CancelFunc - if timeout == 0 { - ctx = context.Background() - } else { - ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() } @@ -545,7 +542,7 @@ func GetErrorFromFile(execErr error, errFilePath string, config configuration.Co return nil, ErrIPCNoDataSent } -func (c *CLI) Execute(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { +func (c *CLI) Execute(ctx context.Context, proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error { var err error handler := determineHandler(passThroughArgs) @@ -553,11 +550,11 @@ func (c *CLI) Execute(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) erro case V2_VERSION: err = c.commandVersion(passThroughArgs) case V2_ABOUT: - err = c.commandAbout(proxyInfo, passThroughArgs) + err = c.commandAbout(ctx, proxyInfo, passThroughArgs) case V1_DEFAULT: fallthrough default: - err = c.executeV1Default(proxyInfo, passThroughArgs) + err = c.executeV1Default(ctx, proxyInfo, passThroughArgs) } return err diff --git a/cliv2/internal/cliv2/cliv2_test.go b/cliv2/internal/cliv2/cliv2_test.go index be38d023fc..847f66f2da 100644 --- a/cliv2/internal/cliv2/cliv2_test.go +++ b/cliv2/internal/cliv2/cliv2_test.go @@ -59,7 +59,7 @@ func Test_NewCLIv2_SubprocessEnv_OverridesIfSet_AndDefaultsToOsEnv(t *testing.T) assert.NoError(t, err) cmd, err := cli.PrepareV1Command( - context.Background(), + t.Context(), "someExecutable", []string{"--help"}, getProxyInfoForTest(), @@ -83,7 +83,7 @@ func Test_NewCLIv2_SubprocessEnv_OverridesIfSet_AndDefaultsToOsEnv(t *testing.T) assert.NoError(t, err) cmd, err := cli.PrepareV1Command( - context.Background(), + t.Context(), "someExecutable", []string{"--help"}, getProxyInfoForTest(), @@ -352,7 +352,7 @@ func Test_prepareV1Command(t *testing.T) { assert.NoError(t, err) snykCmd, err := cli.PrepareV1Command( - context.Background(), + t.Context(), "someExecutable", expectedArgs, getProxyInfoForTest(), @@ -376,7 +376,7 @@ func Test_prepareV1Command_InjectsExecutablePath(t *testing.T) { assert.NoError(t, err) snykCmd, err := cli.PrepareV1Command( - context.Background(), + t.Context(), "someExecutable", []string{"--help"}, getProxyInfoForTest(), @@ -408,7 +408,7 @@ func Test_extractOnlyOnce(t *testing.T) { assert.NoError(t, cli.Init()) // run once - err = cli.Execute(getProxyInfoForTest(), []string{"--help"}) + err = cli.Execute(t.Context(), getProxyInfoForTest(), []string{"--help"}) assert.Error(t, err) // invalid binary expected here assert.FileExists(t, cli.GetBinaryLocation()) fileInfo1, err := os.Stat(cli.GetBinaryLocation()) @@ -419,7 +419,7 @@ func Test_extractOnlyOnce(t *testing.T) { // run twice assert.Nil(t, cli.Init()) - err = cli.Execute(getProxyInfoForTest(), []string{"--help"}) + err = cli.Execute(t.Context(), getProxyInfoForTest(), []string{"--help"}) assert.Error(t, err) // invalid binary expected here assert.FileExists(t, cli.GetBinaryLocation()) fileInfo2, err := os.Stat(cli.GetBinaryLocation()) @@ -479,7 +479,7 @@ func Test_executeRunV2only(t *testing.T) { assert.NoError(t, err) assert.NoError(t, cli.Init()) - actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"--version"})) + actualReturnCode := cliv2.DeriveExitCode(cli.Execute(t.Context(), getProxyInfoForTest(), []string{"--version"})) assert.Equal(t, expectedReturnCode, actualReturnCode) assert.FileExists(t, cli.GetBinaryLocation()) } @@ -496,7 +496,7 @@ func Test_executeUnknownCommand(t *testing.T) { assert.NoError(t, err) assert.NoError(t, cli.Init()) - actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"bogusCommand"})) + actualReturnCode := cliv2.DeriveExitCode(cli.Execute(t.Context(), getProxyInfoForTest(), []string{"bogusCommand"})) assert.Equal(t, expectedReturnCode, actualReturnCode) } @@ -590,7 +590,7 @@ func Test_setTimeout(t *testing.T) { // sleep for 2s cli.SetV1BinaryLocation("/bin/sleep") - err = cli.Execute(getProxyInfoForTest(), []string{"2"}) + err = cli.Execute(t.Context(), getProxyInfoForTest(), []string{"2"}) assert.ErrorIs(t, err, context.DeadlineExceeded) } diff --git a/cliv2/pkg/basic_workflows/legacycli.go b/cliv2/pkg/basic_workflows/legacycli.go index 9a9f711f37..24e4a7ffac 100644 --- a/cliv2/pkg/basic_workflows/legacycli.go +++ b/cliv2/pkg/basic_workflows/legacycli.go @@ -151,9 +151,9 @@ func legacycliWorkflow( return output, err } - // run the cli + // run the cli with context from invocation (allows cancellation on signal) proxyInfo := wrapperProxy.ProxyInfo() - err = cli.Execute(proxyInfo, finalizeArguments(args, config.GetStringSlice(configuration.UNKNOWN_ARGS))) + err = cli.Execute(invocation.Context(), proxyInfo, finalizeArguments(args, config.GetStringSlice(configuration.UNKNOWN_ARGS))) if !useStdIo { _ = outWriter.Flush()