diff --git a/sast-engine/cmd/resolution_report.go b/sast-engine/cmd/resolution_report.go index b04a9d76..3869d3aa 100644 --- a/sast-engine/cmd/resolution_report.go +++ b/sast-engine/cmd/resolution_report.go @@ -11,7 +11,9 @@ import ( "github.com/shivasurya/code-pathfinder/sast-engine/graph" "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/builder" "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/resolution" "github.com/shivasurya/code-pathfinder/sast-engine/output" "github.com/spf13/cobra" ) @@ -41,12 +43,28 @@ Use --csv to export unresolved calls with file, line, target, and reason.`, codeGraph := graph.Initialize(projectInput, nil) fmt.Println("Building call graph...") - cg, registry, _, err := callgraph.InitializeCallGraph(codeGraph, projectInput, output.NewLogger(output.VerbosityDefault)) + logger := output.NewLogger(output.VerbosityDefault) + cg, registry, _, err := callgraph.InitializeCallGraph(codeGraph, projectInput, logger) if err != nil { fmt.Printf("Error building call graph: %v\n", err) return } + // Build Go call graph if go.mod exists (same pipeline as scan.go). 
+ goModPath := filepath.Join(projectInput, "go.mod") + if _, statErr := os.Stat(goModPath); statErr == nil { + goRegistry, goErr := resolution.BuildGoModuleRegistry(projectInput) + if goErr == nil && goRegistry != nil { + builder.InitGoStdlibLoader(goRegistry, projectInput, logger) + builder.InitGoThirdPartyLoader(goRegistry, projectInput, false, logger) + goTypeEngine := resolution.NewGoTypeInferenceEngine(goRegistry) + goCG, goErr := builder.BuildGoCallGraph(codeGraph, goRegistry, goTypeEngine) + if goErr == nil && goCG != nil { + builder.MergeCallGraphs(cg, goCG) + } + } + } + fmt.Printf("\nResolution Report for %s\n", projectInput) fmt.Println("===============================================") @@ -69,6 +87,12 @@ Use --csv to export unresolved calls with file, line, target, and reason.`, fmt.Println() } + // Print Go resolution statistics + if stats.GoTotalCalls > 0 { + printGoResolutionStatistics(stats) + fmt.Println() + } + // Print failure breakdown printFailureBreakdown(stats) fmt.Println() @@ -133,6 +157,16 @@ type resolutionStatistics struct { StdlibViaAnnotation int // Resolved via type annotations StdlibViaInference int // Resolved via type inference StdlibViaBuiltin int // Resolved via builtin registry + + // Go resolution statistics + GoTotalCalls int + GoResolvedCalls int + GoUnresolvedCalls int + GoUserCodeResolved int // Resolved via user-code call graph (Check 1) + GoStdlibResolved int // Resolved via StdlibLoader (Check 2) + GoThirdPartyResolved int // Resolved via ThirdPartyLoader (Check 2.5) + GoStdlibByModule map[string]int // e.g., "net/http" -> 12 + GoThirdPartyByModule map[string]int // e.g., "gorm.io/gorm" -> 5 } // aggregateResolutionStatistics analyzes the call graph and collects statistics. 
@@ -147,6 +181,8 @@ func aggregateResolutionStatistics(cg *core.CallGraph, projectRoot string) *reso ConfidenceDistribution: make(map[string]int), StdlibByModule: make(map[string]int), StdlibByType: make(map[string]int), + GoStdlibByModule: make(map[string]int), + GoThirdPartyByModule: make(map[string]int), } // Iterate through all call sites @@ -154,6 +190,40 @@ func aggregateResolutionStatistics(cg *core.CallGraph, projectRoot string) *reso for _, site := range callSites { stats.TotalCalls++ + // Classify Go call sites: check caller function language or FQN heuristic. + // Go FQNs always contain "/" (e.g., "net/http.Request.FormValue"). + // Caller function node language is the authoritative signal when available. + funcNode := cg.Functions[functionFQN] + isGoCall := (funcNode != nil && funcNode.Language == "go") || + strings.Contains(site.TargetFQN, "/") + + if isGoCall { + stats.GoTotalCalls++ + if site.Resolved { + stats.GoResolvedCalls++ + // Use site.IsStdlib (set by go_builder.go) as the authoritative stdlib signal. + // It correctly handles single-segment stdlib packages (fmt, os, sync, io). + switch { + case site.IsStdlib: + stats.GoStdlibResolved++ + goModule := extractGoModuleName(site.TargetFQN) + if goModule != "" { + stats.GoStdlibByModule[goModule]++ + } + case site.TypeSource == "thirdparty_local" || site.TypeSource == "thirdparty_cdn": + stats.GoThirdPartyResolved++ + goModule := extractGoModuleName(site.TargetFQN) + if goModule != "" { + stats.GoThirdPartyByModule[goModule]++ + } + default: + stats.GoUserCodeResolved++ + } + } else { + stats.GoUnresolvedCalls++ + } + } + if site.Resolved { stats.ResolvedCalls++ @@ -517,6 +587,92 @@ func percentage(part, total int) float64 { return float64(part) * 100.0 / float64(total) } +// extractGoModuleName extracts the Go module/package path from a fully-qualified name. 
// Examples:
//
//	"gorm.io/gorm.DB.Where"                  -> "gorm.io/gorm"
//	"net/http.Request.FormValue"             -> "net/http"
//	"github.com/gin-gonic/gin.Context.Query" -> "github.com/gin-gonic/gin"
//	"gopkg.in/yaml.v3.Unmarshal"             -> "gopkg.in/yaml.v3"
//	"fmt.Println"                            -> "" (no slash, not a module path)
//
// The package path is assumed to end at the first "." that follows the last
// "/". gopkg.in import paths are the one exception handled here: their final
// element always carries a ".vN" version selector, so that selector is kept
// as part of the path instead of being mistaken for a symbol name.
//
// NOTE(review): any other package directory whose own name contains a "."
// is still cut at its first dot; telling it apart from "pkg.Symbol" needs
// module-registry knowledge that this string-only helper does not have.
func extractGoModuleName(fqn string) string {
	lastSlash := strings.LastIndex(fqn, "/")
	if lastSlash == -1 {
		// No slash — single-segment package (fmt, os, sync) or empty input.
		// There is no module path to report, so return empty.
		return ""
	}

	pkgStart := lastSlash + 1
	dotIdx := strings.Index(fqn[pkgStart:], ".")
	if dotIdx == -1 {
		// Nothing follows the package element: the input already is a path.
		return fqn
	}

	pkgEnd := pkgStart + dotIdx
	if strings.HasPrefix(fqn, "gopkg.in/") {
		// Keep the ".vN" selector that belongs to the gopkg.in import path.
		pkgEnd += gopkgVersionSuffixLen(fqn[pkgEnd:])
	}
	return fqn[:pkgEnd]
}

// gopkgVersionSuffixLen returns the length of a leading ".v<digits>" token in
// s, or 0 when s does not start with one. The token must be followed by "."
// or the end of the string so that names such as ".v3x" or ".value" are not
// treated as version selectors.
func gopkgVersionSuffixLen(s string) int {
	if len(s) < 3 || s[0] != '.' || s[1] != 'v' {
		return 0
	}
	end := 2
	for end < len(s) && s[end] >= '0' && s[end] <= '9' {
		end++
	}
	if end == 2 {
		return 0 // "v" with no digits
	}
	if end < len(s) && s[end] != '.' {
		return 0 // digits run into more identifier characters
	}
	return end
}

// printTopModules prints at most topN entries of a module→count map to
// stdout, ordered by count descending and, for equal counts, by module name
// ascending so the output is deterministic despite random map iteration.
// A topN of zero or less prints nothing.
func printTopModules(modules map[string]int, topN int) {
	type entry struct {
		name  string
		calls int
	}

	ranked := make([]entry, 0, len(modules))
	for name, calls := range modules {
		ranked = append(ranked, entry{name: name, calls: calls})
	}

	sort.Slice(ranked, func(a, b int) bool {
		if ranked[a].calls == ranked[b].calls {
			return ranked[a].name < ranked[b].name
		}
		return ranked[a].calls > ranked[b].calls
	})

	// Clamp the limit into [0, len(ranked)] before slicing.
	limit := topN
	if limit > len(ranked) {
		limit = len(ranked)
	}
	if limit < 0 {
		limit = 0
	}

	for rank, e := range ranked[:limit] {
		fmt.Printf(" %2d. %-40s %d calls\n", rank+1, e.name, e.calls)
	}
}

// printGoResolutionStatistics prints the Go call graph resolution statistics.
+func printGoResolutionStatistics(stats *resolutionStatistics) { + fmt.Println("Go Resolution Statistics:") + fmt.Printf(" Total Go calls: %d\n", stats.GoTotalCalls) + fmt.Printf(" Resolved: %d (%.1f%%)\n", + stats.GoResolvedCalls, + percentage(stats.GoResolvedCalls, stats.GoTotalCalls)) + fmt.Printf(" Unresolved: %d (%.1f%%)\n", + stats.GoUnresolvedCalls, + percentage(stats.GoUnresolvedCalls, stats.GoTotalCalls)) + fmt.Println() + + if stats.GoResolvedCalls > 0 { + fmt.Println(" Resolution Breakdown:") + fmt.Printf(" User code: %d (%.1f%%)\n", + stats.GoUserCodeResolved, + percentage(stats.GoUserCodeResolved, stats.GoResolvedCalls)) + fmt.Printf(" Stdlib (CDN): %d (%.1f%%)\n", + stats.GoStdlibResolved, + percentage(stats.GoStdlibResolved, stats.GoResolvedCalls)) + fmt.Printf(" Third-party: %d (%.1f%%)\n", + stats.GoThirdPartyResolved, + percentage(stats.GoThirdPartyResolved, stats.GoResolvedCalls)) + fmt.Println() + } + + if len(stats.GoStdlibByModule) > 0 { + fmt.Println(" Top Go Stdlib Modules:") + printTopModules(stats.GoStdlibByModule, 10) + fmt.Println() + } + + if len(stats.GoThirdPartyByModule) > 0 { + fmt.Println(" Top Go Third-Party Modules:") + printTopModules(stats.GoThirdPartyByModule, 10) + } +} + // isStdlibResolution checks if a FQN resolves to Python stdlib. 
func isStdlibResolution(fqn string) bool { // List of common stdlib modules diff --git a/sast-engine/cmd/resolution_report_test.go b/sast-engine/cmd/resolution_report_test.go index b6e4e379..ee08891a 100644 --- a/sast-engine/cmd/resolution_report_test.go +++ b/sast-engine/cmd/resolution_report_test.go @@ -2,9 +2,11 @@ package cmd import ( "os" + "path/filepath" "strings" "testing" + "github.com/shivasurya/code-pathfinder/sast-engine/graph" "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" "github.com/stretchr/testify/assert" ) @@ -616,3 +618,251 @@ func TestAggregateResolutionStatistics_FailureReasonBreakdown(t *testing.T) { assert.Equal(t, 1, stats.FailuresByReason["variable_method"]) assert.Equal(t, 1, stats.FailuresByReason["attribute_chain"]) } + +func TestExtractGoModuleName(t *testing.T) { + tests := []struct { + name string + fqn string + expected string + }{ + {"gorm multi-segment", "gorm.io/gorm.DB.Where", "gorm.io/gorm"}, + {"net/http two-segment", "net/http.Request.FormValue", "net/http"}, + {"gin three-segment", "github.com/gin-gonic/gin.Context.Query", "github.com/gin-gonic/gin"}, + {"pgxpool four-segment", "github.com/jackc/pgx/v5/pgxpool.Pool.Query", "github.com/jackc/pgx/v5/pgxpool"}, + {"no slash single-segment stdlib", "fmt.Println", ""}, + {"no slash no dot", "fmt", ""}, + {"empty string", "", ""}, + {"no dot after last slash", "github.com/foo/bar", "github.com/foo/bar"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractGoModuleName(tt.fqn) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestPrintTopModules(t *testing.T) { + // Should not panic with a normal map. + modules := map[string]int{ + "net/http": 8, + "fmt": 5, + "gorm.io/gorm": 4, + "github.com/gin-gonic/gin": 3, + } + printTopModules(modules, 10) + + // Should not panic with empty map. + printTopModules(map[string]int{}, 10) + + // Should respect topN limit. 
+ printTopModules(modules, 2) +} + +func TestPrintGoResolutionStatistics_Empty(t *testing.T) { + stats := &resolutionStatistics{ + GoTotalCalls: 0, + GoResolvedCalls: 0, + GoUnresolvedCalls: 0, + GoStdlibByModule: make(map[string]int), + GoThirdPartyByModule: make(map[string]int), + } + // Should not panic. + printGoResolutionStatistics(stats) +} + +func TestPrintGoResolutionStatistics_StdlibOnly(t *testing.T) { + stats := &resolutionStatistics{ + GoTotalCalls: 20, + GoResolvedCalls: 18, + GoUnresolvedCalls: 2, + GoUserCodeResolved: 0, + GoStdlibResolved: 18, + GoStdlibByModule: map[string]int{ + "net/http": 10, + "fmt": 8, + }, + GoThirdPartyByModule: make(map[string]int), + } + // Should not panic. + printGoResolutionStatistics(stats) +} + +func TestPrintGoResolutionStatistics_ThirdPartyOnly(t *testing.T) { + stats := &resolutionStatistics{ + GoTotalCalls: 10, + GoResolvedCalls: 10, + GoUnresolvedCalls: 0, + GoThirdPartyResolved: 10, + GoStdlibByModule: make(map[string]int), + GoThirdPartyByModule: map[string]int{ + "gorm.io/gorm": 4, + "github.com/gin-gonic/gin": 3, + "github.com/redis/go-redis/v9": 3, + }, + } + // Should not panic. + printGoResolutionStatistics(stats) +} + +func TestPrintGoResolutionStatistics_Mixed(t *testing.T) { + stats := &resolutionStatistics{ + GoTotalCalls: 45, + GoResolvedCalls: 42, + GoUnresolvedCalls: 3, + GoUserCodeResolved: 12, + GoStdlibResolved: 20, + GoThirdPartyResolved: 10, + GoStdlibByModule: map[string]int{ + "net/http": 8, + "fmt": 5, + "crypto/tls": 3, + "os": 4, + }, + GoThirdPartyByModule: map[string]int{ + "gorm.io/gorm": 4, + "github.com/gin-gonic/gin": 3, + }, + } + // Should not panic. 
+ printGoResolutionStatistics(stats) +} + +func TestAggregateResolutionStatistics_GoCallSites(t *testing.T) { + cg := core.NewCallGraph() + + // Go stdlib call (IsStdlib=true, FQN has slash) + cg.AddCallSite("main.handler", core.CallSite{ + Target: "FormValue", + Resolved: true, + TargetFQN: "net/http.Request.FormValue", + IsStdlib: true, + }) + + // Go third-party call (TypeSource=thirdparty_local) + cg.AddCallSite("main.handler", core.CallSite{ + Target: "Raw", + Resolved: true, + TargetFQN: "gorm.io/gorm.DB.Raw", + TypeSource: "thirdparty_local", + }) + + // Go user code call (resolved, not stdlib, not third-party) + cg.AddCallSite("main.handler", core.CallSite{ + Target: "Process", + Resolved: true, + TargetFQN: "testapp/svc.Service.Process", + }) + + // Go unresolved call + cg.AddCallSite("main.handler", core.CallSite{ + Target: "Unknown", + Resolved: false, + TargetFQN: "github.com/some/pkg.Type.Unknown", + }) + + stats := aggregateResolutionStatistics(cg, "/project") + + assert.Equal(t, 4, stats.GoTotalCalls) + assert.Equal(t, 3, stats.GoResolvedCalls) + assert.Equal(t, 1, stats.GoUnresolvedCalls) + assert.Equal(t, 1, stats.GoStdlibResolved) + assert.Equal(t, 1, stats.GoThirdPartyResolved) + assert.Equal(t, 1, stats.GoUserCodeResolved) + assert.Equal(t, 1, stats.GoStdlibByModule["net/http"]) + assert.Equal(t, 1, stats.GoThirdPartyByModule["gorm.io/gorm"]) +} + +func TestAggregateResolutionStatistics_GoThirdPartyCDN(t *testing.T) { + cg := core.NewCallGraph() + + // Go CDN third-party call + cg.AddCallSite("main.handler", core.CallSite{ + Target: "Get", + Resolved: true, + TargetFQN: "github.com/redis/go-redis/v9.Client.Get", + TypeSource: "thirdparty_cdn", + }) + + stats := aggregateResolutionStatistics(cg, "/project") + + assert.Equal(t, 1, stats.GoTotalCalls) + assert.Equal(t, 1, stats.GoThirdPartyResolved) + assert.Equal(t, 1, stats.GoThirdPartyByModule["github.com/redis/go-redis/v9"]) +} + +func TestAggregateResolutionStatistics_GoModuleNameEmpty(t 
*testing.T) { + cg := core.NewCallGraph() + + // Register the caller function with Language="go" so the Go-detection path + // triggers even for single-segment stdlib FQNs (fmt, os) that have no slash. + cg.Functions["main.main"] = &graph.Node{ + ID: "main_main", + Type: "function", + Name: "main", + Language: "go", + } + + // Go stdlib single-segment — no slash, extractGoModuleName returns "" + cg.AddCallSite("main.main", core.CallSite{ + Target: "Println", + Resolved: true, + TargetFQN: "fmt.Println", + IsStdlib: true, + }) + + stats := aggregateResolutionStatistics(cg, "/project") + + // fmt.Println has IsStdlib=true, caller has Language="go" → classified as Go stdlib. + // extractGoModuleName returns "" (no slash) → GoStdlibByModule stays empty. + assert.Equal(t, 1, stats.GoStdlibResolved) + assert.Equal(t, 0, len(stats.GoStdlibByModule)) +} + +// --------------------------------------------------------------------------- +// Integration tests for the resolutionReportCmd Run function +// (covers Go pipeline construction + printGoResolutionStatistics call) +// --------------------------------------------------------------------------- + +func TestResolutionReportCmd_WithoutGoMod(t *testing.T) { + tmpDir := t.TempDir() + + // Python-only project — no go.mod, Go pipeline must be skipped. + err := os.WriteFile(filepath.Join(tmpDir, "main.py"), []byte("def handler():\n pass\n"), 0644) + assert.NoError(t, err) + + err = resolutionReportCmd.Flags().Set("project", tmpDir) + assert.NoError(t, err) + + // Should not panic. + assert.NotPanics(t, func() { + resolutionReportCmd.Run(resolutionReportCmd, []string{}) + }) +} + +func TestResolutionReportCmd_WithGoMod(t *testing.T) { + tmpDir := t.TempDir() + + // Minimal Go project — go.mod + one Go file. 
+ err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte("module example.com/test\n\ngo 1.22\n"), 0644) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte(`package main + +import "fmt" + +func main() { + fmt.Println("hello") +} +`), 0644) + assert.NoError(t, err) + + err = resolutionReportCmd.Flags().Set("project", tmpDir) + assert.NoError(t, err) + + // Go pipeline runs; must not panic regardless of what it resolves. + assert.NotPanics(t, func() { + resolutionReportCmd.Run(resolutionReportCmd, []string{}) + }) +}