diff --git a/sast-engine/graph/callgraph/builder/go_builder.go b/sast-engine/graph/callgraph/builder/go_builder.go index 3fc1014d..d9f4fb68 100644 --- a/sast-engine/graph/callgraph/builder/go_builder.go +++ b/sast-engine/graph/callgraph/builder/go_builder.go @@ -155,7 +155,7 @@ func BuildGoCallGraph(codeGraph *graph.CodeGraph, registry *core.GoModuleRegistr importMap = core.NewGoImportMap(callSite.CallerFile) } - targetFQN, resolved, isStdlib := resolveGoCallTarget(callSite, importMap, registry, functionContext, typeEngine, callGraph) + targetFQN, resolved, isStdlib := resolveGoCallTarget(callSite, importMap, registry, functionContext, typeEngine, callGraph, codeGraph) if resolved { resolvedCount++ @@ -438,6 +438,7 @@ func resolveGoCallTarget( functionContext map[string][]*graph.Node, typeEngine *resolution.GoTypeInferenceEngine, callGraph *core.CallGraph, + codeGraph *graph.CodeGraph, ) (string, bool, bool) { // Pattern 1a: Qualified call (pkg.Func or obj.Method) if callSite.ObjectName != "" { @@ -491,6 +492,27 @@ func resolveGoCallTarget( } } + // Source 3: Package-level variable types from CodeGraph nodes. + // Covers `var globalDB *sql.DB` at package scope — not tracked by + // GoTypeInferenceEngine (which only processes := / = assignments in + // function bodies). Only fires when Source 1 and Source 2 both fail. + if typeFQN == "" && codeGraph != nil { + for _, node := range codeGraph.Nodes { + if node.Type != "module_variable" || node.DataType == "" { + continue + } + if node.Name != callSite.ObjectName { + continue + } + if !isSameGoPackage(callSite.CallerFile, node.File) { + continue + } + typeStr := strings.TrimPrefix(node.DataType, "*") + typeFQN = resolveGoTypeFQN(typeStr, importMap) + break + } + } + if typeFQN != "" { methodFQN := typeFQN + "." + callSite.FunctionName diff --git a/sast-engine/graph/callgraph/builder/go_builder_approach_c_test.go b/sast-engine/graph/callgraph/builder/go_builder_approach_c_test.go index 26b77125..979c68fe 100644 --- a/sast-engine/graph/callgraph/builder/go_builder_approach_c_test.go +++ b/sast-engine/graph/callgraph/builder/go_builder_approach_c_test.go @@ -41,7 +41,7 @@ func TestApproachC_ThirdPartyPartialResolution(t *testing.T) { } targetFQN, resolved, _ := resolveGoCallTarget( - callSite, importMap, goRegistry, nil, typeEngine, callGraph, + callSite, importMap, goRegistry, nil, typeEngine, callGraph, nil, ) assert.Equal(t, "github.com/redis/go-redis/v9.Client.Get", targetFQN) @@ -82,7 +82,7 @@ func TestApproachC_UserCodeMethodResolution(t *testing.T) { } targetFQN, resolved, isStdlib := resolveGoCallTarget( - callSite, importMap, goRegistry, nil, typeEngine, callGraph, + callSite, importMap, goRegistry, nil, typeEngine, callGraph, nil, ) assert.Equal(t, "testapp.Service.Handle", targetFQN) @@ -118,7 +118,7 @@ func TestApproachC_PointerTypeStripping(t *testing.T) { } targetFQN, resolved, _ := resolveGoCallTarget( - callSite, importMap, goRegistry, nil, typeEngine, callGraph, + callSite, importMap, goRegistry, nil, typeEngine, callGraph, nil, ) // Pointer * should be stripped: *database/sql.DB → database/sql.DB @@ -184,7 +184,7 @@ func TestApproachC_NoTypeEngine(t *testing.T) { // No typeEngine → Pattern 1b skipped → unresolved targetFQN, resolved, _ := resolveGoCallTarget( - callSite, importMap, goRegistry, nil, nil, callGraph, + callSite, importMap, goRegistry, nil, nil, callGraph, nil, ) assert.Equal(t, "", targetFQN) diff --git a/sast-engine/graph/callgraph/builder/go_builder_pkgvar_test.go b/sast-engine/graph/callgraph/builder/go_builder_pkgvar_test.go new file mode 100644 index 00000000..80f51f11 --- /dev/null +++ b/sast-engine/graph/callgraph/builder/go_builder_pkgvar_test.go @@ -0,0 +1,153 @@ +package builder + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/stretchr/testify/assert" +) + +// makePackageVarCodeGraph builds a CodeGraph containing a module_variable node +// that represents `var ` in the given file. +func makePackageVarCodeGraph(varName, dataType, file string) *graph.CodeGraph { + cg := graph.NewCodeGraph() + cg.Nodes[varName] = &graph.Node{ + ID: varName, + Type: "module_variable", + Name: varName, + DataType: dataType, + File: file, + Language: "go", + } + return cg +} + +// TestSource3_PackageLevelVariable verifies that Source 3 resolves the type of a +// package-level variable and returns the correct method FQN. +func TestSource3_PackageLevelVariable(t *testing.T) { + cg := makePackageVarCodeGraph("globalDB", "sql.DB", "/project/main.go") + + callSite := &CallSiteInternal{ + CallerFQN: "main.handler", + CallerFile: "/project/main.go", + ObjectName: "globalDB", + FunctionName: "Query", + } + + importMap := &core.GoImportMap{ + Imports: map[string]string{"sql": "database/sql"}, + } + + reg := core.NewGoModuleRegistry() + callGraph := &core.CallGraph{Functions: make(map[string]*graph.Node)} + + targetFQN, resolved, _ := resolveGoCallTarget( + callSite, importMap, reg, nil, nil, callGraph, cg, + ) + + assert.True(t, resolved) + assert.Equal(t, "database/sql.DB.Query", targetFQN) +} + +// TestSource3_PointerType verifies that Source 3 strips the leading `*` from the +// DataType field (e.g. `var db *sql.DB` stores DataType as "*sql.DB"). +func TestSource3_PointerType(t *testing.T) { + cg := makePackageVarCodeGraph("db", "*sql.DB", "/project/store.go") + + callSite := &CallSiteInternal{ + CallerFQN: "main.runQuery", + CallerFile: "/project/store.go", + ObjectName: "db", + FunctionName: "Exec", + } + + importMap := &core.GoImportMap{ + Imports: map[string]string{"sql": "database/sql"}, + } + + reg := core.NewGoModuleRegistry() + callGraph := &core.CallGraph{Functions: make(map[string]*graph.Node)} + + targetFQN, resolved, _ := resolveGoCallTarget( + callSite, importMap, reg, nil, nil, callGraph, cg, + ) + + assert.True(t, resolved) + assert.Equal(t, "database/sql.DB.Exec", targetFQN) +} + +// TestSource3_SamePackageFilter verifies that Source 3 only resolves variables +// defined in the same package as the caller (same directory). +func TestSource3_SamePackageFilter(t *testing.T) { + cg := graph.NewCodeGraph() + // Variable in a DIFFERENT package (/project/other/db.go) + cg.Nodes["otherDB"] = &graph.Node{ + ID: "otherDB", + Type: "module_variable", + Name: "globalDB", + DataType: "sql.DB", + File: "/project/other/db.go", + Language: "go", + } + + callSite := &CallSiteInternal{ + CallerFQN: "main.handler", + CallerFile: "/project/main.go", // different directory + ObjectName: "globalDB", + FunctionName: "Query", + } + + importMap := &core.GoImportMap{Imports: map[string]string{"sql": "database/sql"}} + reg := core.NewGoModuleRegistry() + callGraph := &core.CallGraph{Functions: make(map[string]*graph.Node)} + + _, resolved, _ := resolveGoCallTarget( + callSite, importMap, reg, nil, nil, callGraph, cg, + ) + + // Must NOT resolve: variable is in a different package directory. + assert.False(t, resolved) +} + +// TestSource3_NoTypeAnnotation verifies that Source 3 gracefully skips a +// module_variable node whose DataType is empty (e.g. `var db = sql.Open(...)`). +func TestSource3_NoTypeAnnotation(t *testing.T) { + cg := makePackageVarCodeGraph("db", "", "/project/main.go") // empty DataType + + callSite := &CallSiteInternal{ + CallerFQN: "main.handler", + CallerFile: "/project/main.go", + ObjectName: "db", + FunctionName: "Query", + } + + importMap := &core.GoImportMap{Imports: map[string]string{}} + reg := core.NewGoModuleRegistry() + callGraph := &core.CallGraph{Functions: make(map[string]*graph.Node)} + + _, resolved, _ := resolveGoCallTarget( + callSite, importMap, reg, nil, nil, callGraph, cg, + ) + + // Must NOT resolve: no type info available. + assert.False(t, resolved) +} + +// TestSource3_NilCodeGraph verifies that Source 3 does not panic with a nil CodeGraph. +func TestSource3_NilCodeGraph(t *testing.T) { + callSite := &CallSiteInternal{ + CallerFQN: "main.handler", + CallerFile: "/project/main.go", + ObjectName: "globalDB", + FunctionName: "Query", + } + + importMap := &core.GoImportMap{Imports: map[string]string{}} + reg := core.NewGoModuleRegistry() + callGraph := &core.CallGraph{Functions: make(map[string]*graph.Node)} + + assert.NotPanics(t, func() { + resolveGoCallTarget(callSite, importMap, reg, nil, nil, callGraph, nil) + }) +} diff --git a/sast-engine/graph/callgraph/builder/go_builder_stdlib_test.go b/sast-engine/graph/callgraph/builder/go_builder_stdlib_test.go index 648f46df..7a55fe59 100644 --- a/sast-engine/graph/callgraph/builder/go_builder_stdlib_test.go +++ b/sast-engine/graph/callgraph/builder/go_builder_stdlib_test.go @@ -75,7 +75,7 @@ func TestResolveGoCallTarget_StdlibImport(t *testing.T) { cs := &CallSiteInternal{FunctionName: "Println", ObjectName: "fmt"} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) require.True(t, resolved) assert.Equal(t, "fmt.Println", targetFQN) @@ -89,7 +89,7 @@ func TestResolveGoCallTarget_NilStdlibLoader(t *testing.T) { cs := &CallSiteInternal{FunctionName: "Println", ObjectName: "fmt"} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) require.True(t, resolved) assert.Equal(t, "fmt.Println", targetFQN) @@ -105,7 +105,7 @@ func TestResolveGoCallTarget_ThirdPartyImport(t *testing.T) { cs := &CallSiteInternal{FunctionName: "Default", ObjectName: "gin"} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) require.True(t, resolved) assert.Equal(t, "github.com/gin-gonic/gin.Default", targetFQN) @@ -119,7 +119,7 @@ func TestResolveGoCallTarget_StdlibMultiSegmentPath(t *testing.T) { cs := &CallSiteInternal{FunctionName: "ListenAndServe", ObjectName: "http"} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) require.True(t, resolved) assert.Equal(t, "net/http.ListenAndServe", targetFQN) @@ -137,7 +137,7 @@ func TestResolveGoCallTarget_Builtin(t *testing.T) { cs := &CallSiteInternal{FunctionName: "append", ObjectName: ""} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) require.True(t, resolved) assert.Equal(t, "builtin.append", targetFQN) @@ -151,7 +151,7 @@ func TestResolveGoCallTarget_Unresolved(t *testing.T) { cs := &CallSiteInternal{FunctionName: "Foo", ObjectName: "unknown"} - targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil) + targetFQN, resolved, isStdlib := resolveGoCallTarget(cs, importMap, reg, nil, nil, nil, nil) assert.False(t, resolved) assert.Empty(t, targetFQN) diff --git a/sast-engine/graph/callgraph/builder/go_builder_test.go b/sast-engine/graph/callgraph/builder/go_builder_test.go index a1dd41c2..e4dcb2da 100644 --- a/sast-engine/graph/callgraph/builder/go_builder_test.go +++ b/sast-engine/graph/callgraph/builder/go_builder_test.go @@ -391,7 +391,7 @@ func TestResolveGoCallTarget(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Pass nil for typeEngine and callGraph (backward compatibility) - targetFQN, resolved, _ := resolveGoCallTarget(tt.callSite, tt.importMap, tt.registry, tt.funcContext, nil, nil) + targetFQN, resolved, _ := resolveGoCallTarget(tt.callSite, tt.importMap, tt.registry, tt.funcContext, nil, nil, nil) assert.Equal(t, tt.shouldResolve, resolved, "Resolution status mismatch") @@ -833,6 +833,7 @@ func TestResolveGoCallTarget_VariableMethod(t *testing.T) { functionContext, typeEngine, callGraph, + nil, ) // Assert diff --git a/sast-engine/graph/callgraph/builder/go_version.go b/sast-engine/graph/callgraph/builder/go_version.go index 34c3b7de..325e68d9 100644 --- a/sast-engine/graph/callgraph/builder/go_version.go +++ b/sast-engine/graph/callgraph/builder/go_version.go @@ -130,7 +130,7 @@ func InitGoThirdPartyLoader(reg *core.GoModuleRegistry, projectPath string, refr return } - loader := registry.NewGoThirdPartyLocalLoader(projectPath, refreshCache, logger) + loader := registry.NewGoThirdPartyLocalLoader(projectPath, refreshCache, logger, reg) if loader.PackageCount() == 0 { if logger != nil { logger.Debug("No Go third-party dependencies found in go.mod") diff --git a/sast-engine/graph/callgraph/registry/go_thirdparty_crossembed_test.go b/sast-engine/graph/callgraph/registry/go_thirdparty_crossembed_test.go new file mode 100644 index 00000000..66212adc --- /dev/null +++ b/sast-engine/graph/callgraph/registry/go_thirdparty_crossembed_test.go @@ -0,0 +1,199 @@ +package registry + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/stretchr/testify/assert" +) + +// mockStdlibLoaderForEmbed implements core.GoStdlibLoader for embed resolution tests. +type mockStdlibLoaderForEmbed struct { + types map[string]*core.GoStdlibType // key: "importPath.TypeName" +} + +func (m *mockStdlibLoaderForEmbed) ValidateStdlibImport(importPath string) bool { + for k := range m.types { + // key format: "importPath.TypeName" + if len(k) > len(importPath) && k[:len(importPath)] == importPath { + return true + } + } + return false +} + +func (m *mockStdlibLoaderForEmbed) GetFunction(_, _ string) (*core.GoStdlibFunction, error) { + return nil, nil //nolint:nilnil +} + +func (m *mockStdlibLoaderForEmbed) GetType(importPath, typeName string) (*core.GoStdlibType, error) { + key := importPath + "." + typeName + t, ok := m.types[key] + if !ok { + return nil, nil //nolint:nilnil + } + return t, nil +} + +func (m *mockStdlibLoaderForEmbed) PackageCount() int { return len(m.types) } + +// buildLoaderWithRegistry creates a GoThirdPartyLocalLoader whose registry +// field points at a GoModuleRegistry with the given StdlibLoader attached. +func buildLoaderWithRegistry(stdlibLoader core.GoStdlibLoader) *GoThirdPartyLocalLoader { + reg := core.NewGoModuleRegistry() + reg.StdlibLoader = stdlibLoader + return &GoThirdPartyLocalLoader{ + registry: reg, + } +} + +// makePkgWithEmbedType constructs a minimal GoStdlibPackage with one type that +// embeds the given cross-package interface name (e.g. "context.Context"). +func makePkgWithEmbedType(typeName, embedName string) *core.GoStdlibPackage { + pkg := core.NewGoStdlibPackage("github.com/example/mypkg", "") + pkg.Types[typeName] = &core.GoStdlibType{ + Name: typeName, + Kind: "interface", + Methods: map[string]*core.GoStdlibFunction{}, + Embeds: []string{embedName}, + } + return pkg +} + +// --------------------------------------------------------------------------- +// TestResolveEmbeddings_ViaStdlibLoader +// --------------------------------------------------------------------------- + +// TestResolveEmbeddings_ViaStdlibLoader verifies that resolveEmbeddings copies +// methods from a StdlibLoader-provided type when the embed is cross-package. +func TestResolveEmbeddings_ViaStdlibLoader(t *testing.T) { + stdlibLoader := &mockStdlibLoaderForEmbed{ + types: map[string]*core.GoStdlibType{ + "context.Context": { + Name: "Context", + Methods: map[string]*core.GoStdlibFunction{ + "Deadline": {Name: "Deadline"}, + "Done": {Name: "Done"}, + "Err": {Name: "Err"}, + "Value": {Name: "Value"}, + }, + }, + }, + } + + loader := buildLoaderWithRegistry(stdlibLoader) + pkg := makePkgWithEmbedType("CancelableClient", "context.Context") + + loader.resolveEmbeddings(pkg) + + typ := pkg.Types["CancelableClient"] + assert.Contains(t, typ.Methods, "Deadline", "Deadline should be copied from context.Context via StdlibLoader") + assert.Contains(t, typ.Methods, "Done") + assert.Contains(t, typ.Methods, "Err") + assert.Contains(t, typ.Methods, "Value") +} + +// --------------------------------------------------------------------------- +// TestResolveEmbeddings_FallbackToWellKnown +// --------------------------------------------------------------------------- + +// TestResolveEmbeddings_FallbackToWellKnown verifies that when StdlibLoader is +// nil, resolveEmbeddings still resolves io.Closer via the well-known table. +func TestResolveEmbeddings_FallbackToWellKnown(t *testing.T) { + loader := &GoThirdPartyLocalLoader{ + registry: nil, // no registry → StdlibLoader unavailable + } + + pkg := makePkgWithEmbedType("Resource", "io.Closer") + + loader.resolveEmbeddings(pkg) + + typ := pkg.Types["Resource"] + assert.Contains(t, typ.Methods, "Close", "Close should resolve via well-known table even without StdlibLoader") +} + +// --------------------------------------------------------------------------- +// TestResolveEmbeddings_NilRegistryStdlibLoader +// --------------------------------------------------------------------------- + +// TestResolveEmbeddings_NilRegistryStdlibLoader ensures no panic when registry +// is non-nil but StdlibLoader is nil; should fall back to well-known table. +func TestResolveEmbeddings_NilRegistryStdlibLoader(t *testing.T) { + reg := core.NewGoModuleRegistry() + reg.StdlibLoader = nil // explicitly nil + loader := &GoThirdPartyLocalLoader{registry: reg} + + pkg := makePkgWithEmbedType("Resource", "io.Closer") + + assert.NotPanics(t, func() { + loader.resolveEmbeddings(pkg) + }) + + typ := pkg.Types["Resource"] + assert.Contains(t, typ.Methods, "Close", "well-known fallback should fire when StdlibLoader is nil") +} + +// --------------------------------------------------------------------------- +// TestResolveEmbeddings_DoesNotOverwriteExistingMethods +// --------------------------------------------------------------------------- + +// TestResolveEmbeddings_DoesNotOverwriteExistingMethods ensures that methods +// already present on the type are not replaced by embedded versions. +func TestResolveEmbeddings_DoesNotOverwriteExistingMethods(t *testing.T) { + customClose := &core.GoStdlibFunction{Name: "Close", Confidence: 0.5} //nolint:mnd + + reg := core.NewGoModuleRegistry() + reg.StdlibLoader = &mockStdlibLoaderForEmbed{ + types: map[string]*core.GoStdlibType{ + "io.Closer": { + Name: "Closer", + Methods: map[string]*core.GoStdlibFunction{ + "Close": {Name: "Close", Confidence: 1.0}, + }, + }, + }, + } + loader := &GoThirdPartyLocalLoader{registry: reg} + + pkg := core.NewGoStdlibPackage("github.com/example/mypkg", "") + pkg.Types["Resource"] = &core.GoStdlibType{ + Name: "Resource", + Kind: "struct", + Methods: map[string]*core.GoStdlibFunction{ + "Close": customClose, // already present + }, + Embeds: []string{"io.Closer"}, + } + + loader.resolveEmbeddings(pkg) + + // The custom Close (Confidence 0.5) must NOT be replaced by the stdlib one (1.0). + assert.InDelta(t, 0.5, pkg.Types["Resource"].Methods["Close"].Confidence, 0.001) +} + +// --------------------------------------------------------------------------- +// TestResolveEmbeddings_SamePackageEmbedSkipped +// --------------------------------------------------------------------------- + +// TestResolveEmbeddings_SamePackageEmbedSkipped verifies that same-package embeds +// (no dot in name) are ignored by resolveEmbeddings (they're handled earlier by +// flattenEmbeddedMethods). +func TestResolveEmbeddings_SamePackageEmbedSkipped(t *testing.T) { + loader := &GoThirdPartyLocalLoader{registry: nil} + + pkg := core.NewGoStdlibPackage("github.com/example/mypkg", "") + pkg.Types["Client"] = &core.GoStdlibType{ + Name: "Client", + Kind: "interface", + Methods: map[string]*core.GoStdlibFunction{}, + Embeds: []string{"EnqueueClient"}, // no dot → same package + } + + // Must not panic even when EnqueueClient type is absent from pkg.Types. + assert.NotPanics(t, func() { + loader.resolveEmbeddings(pkg) + }) + + // No methods should have been added (same-package embed, no fallback available). + assert.Empty(t, pkg.Types["Client"].Methods) +} diff --git a/sast-engine/graph/callgraph/registry/go_thirdparty_local.go b/sast-engine/graph/callgraph/registry/go_thirdparty_local.go index ccb5911c..89753770 100644 --- a/sast-engine/graph/callgraph/registry/go_thirdparty_local.go +++ b/sast-engine/graph/callgraph/registry/go_thirdparty_local.go @@ -53,6 +53,7 @@ type GoThirdPartyLocalLoader struct { cacheDir string // disk cache directory: {userCacheDir}/code-pathfinder/go-thirdparty/{projectHash}/ diskIndex *cacheIndex // loaded from cache-index.json; nil when disk cache is unavailable logger *output.Logger + registry *core.GoModuleRegistry // for StdlibLoader access in cross-package embed resolution } // NewGoThirdPartyLocalLoader creates a loader that finds and parses third-party @@ -60,11 +61,12 @@ type GoThirdPartyLocalLoader struct { // // When refreshCache is true (set by --refresh-rules on the CLI), the existing // go-thirdparty disk cache for this project is deleted and rebuilt from source. -func NewGoThirdPartyLocalLoader(projectRoot string, refreshCache bool, logger *output.Logger) *GoThirdPartyLocalLoader { +func NewGoThirdPartyLocalLoader(projectRoot string, refreshCache bool, logger *output.Logger, registry *core.GoModuleRegistry) *GoThirdPartyLocalLoader { loader := &GoThirdPartyLocalLoader{ projectRoot: projectRoot, packageCache: make(map[string]*core.GoStdlibPackage), logger: logger, + registry: registry, } loader.moduleVersions = parseGoModRequires(projectRoot) if logger != nil { @@ -223,6 +225,10 @@ func (l *GoThirdPartyLocalLoader) getOrLoadPackage(importPath string) (*core.GoS return nil, err } + // Resolve cross-package embeds (e.g., io.Closer, context.Context) via StdlibLoader + // when available, falling back to the hardcoded well-known table. + l.resolveEmbeddings(pkg) + l.packageCache[importPath] = pkg if l.logger != nil { l.logger.Debug("Extracted third-party package %s: %d types, %d functions", @@ -440,12 +446,53 @@ func extractGoPackageWithTreeSitter(importPath, srcDir string) (*core.GoStdlibPa // e.g., if Client embeds EnqueueClient, copy EnqueueClient's methods into Client. flattenEmbeddedMethods(pkg) - // Resolve cross-package embeds (e.g., io.Closer) using well-known stdlib interfaces. + // Resolve cross-package embeds using the well-known stdlib interface table. + // When called via getOrLoadPackage, resolveEmbeddings will additionally try + // StdlibLoader for interfaces not in this table. resolveWellKnownEmbeds(pkg) return pkg, nil } +// resolveEmbeddings resolves cross-package embedded interfaces for all types in pkg. +// Resolution order: +// 1. StdlibLoader.GetType — covers ALL stdlib interfaces (requires registry to be set) +// 2. getWellKnownInterfaceMethods — hardcoded fallback for when StdlibLoader is unavailable +func (l *GoThirdPartyLocalLoader) resolveEmbeddings(pkg *core.GoStdlibPackage) { + for _, typ := range pkg.Types { + for _, embeddedName := range typ.Embeds { + if !strings.Contains(embeddedName, ".") { + continue // same-package — already handled by flattenEmbeddedMethods + } + dotIdx := strings.LastIndex(embeddedName, ".") + pkgAlias := embeddedName[:dotIdx] + typeName := embeddedName[dotIdx+1:] + + // Try StdlibLoader first: covers all stdlib interfaces (e.g. context.Context, + // sort.Interface) that are not in the hardcoded well-known table. + if l.registry != nil && l.registry.StdlibLoader != nil { + if stdType, err := l.registry.StdlibLoader.GetType(pkgAlias, typeName); err == nil && stdType != nil { + for methodName, method := range stdType.Methods { + if _, exists := typ.Methods[methodName]; !exists { + typ.Methods[methodName] = method + } + } + continue + } + } + + // Fallback: well-known table (when StdlibLoader is unavailable). + if methods := getWellKnownInterfaceMethods(pkgAlias, typeName); methods != nil { + for methodName, method := range methods { + if _, exists := typ.Methods[methodName]; !exists { + typ.Methods[methodName] = method + } + } + } + } + } +} + // resolveWellKnownEmbeds resolves cross-package embedded interfaces using a hardcoded // table of well-known stdlib interfaces (io.Closer, io.Reader, etc.). func resolveWellKnownEmbeds(pkg *core.GoStdlibPackage) { diff --git a/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go b/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go index 9e6a5433..30a36772 100644 --- a/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go +++ b/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go @@ -212,7 +212,7 @@ func Open(dialector interface{}) (*DB, error) { require.NoError(t, err) // Create loader and test - loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil, nil) // Validate import assert.True(t, loader.ValidateImport("gorm.io/gorm")) @@ -282,7 +282,7 @@ func Default() *Engine { return nil } err = os.WriteFile(filepath.Join(vendorDir, "context.go"), []byte(ginSrc), 0644) require.NoError(t, err) - loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil, nil) // Verify type and methods ctxType, err := loader.GetType("github.com/gin-gonic/gin", "Context") @@ -453,7 +453,7 @@ func TestDiskCacheWriteAndRead(t *testing.T) { projectDir := makeVendoredProject(t) // Cold run: loader extracts from vendor/ and writes to disk cache. - loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) dbType, err := loader1.GetType("gorm.io/gorm", "DB") require.NoError(t, err) require.NotNil(t, dbType) @@ -475,7 +475,7 @@ func TestDiskCacheWriteAndRead(t *testing.T) { // Remove vendor/ to prove it's not re-parsing from source. require.NoError(t, os.RemoveAll(filepath.Join(projectDir, "vendor"))) - loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) dbType2, err := loader2.GetType("gorm.io/gorm", "DB") require.NoError(t, err) require.NotNil(t, dbType2, "disk cache hit should return the type even without vendor/") @@ -488,7 +488,7 @@ func TestCacheVersionMismatch(t *testing.T) { projectDir := makeVendoredProject(t) // Cold run at v1.25.7. - loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) _, err := loader1.GetType("gorm.io/gorm", "DB") require.NoError(t, err) @@ -506,7 +506,7 @@ func (db *DB) Save(value interface{}) *DB { return db } require.NoError(t, os.WriteFile(filepath.Join(vendorDir, "gorm.go"), []byte(newSrc), 0644)) // New loader: same cache dir, but go.mod says v1.25.8 — cache entry is v1.25.7 → miss. - loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) dbType, err := loader2.GetType("gorm.io/gorm", "DB") require.NoError(t, err) require.NotNil(t, dbType) @@ -526,7 +526,7 @@ func TestRefreshCacheFlush(t *testing.T) { projectDir := makeVendoredProject(t) // Cold run to populate cache. - loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) _, err := loader1.GetType("gorm.io/gorm", "DB") require.NoError(t, err) cacheDir := loader1.cacheDir @@ -550,7 +550,7 @@ func TestRefreshCacheFlush(t *testing.T) { require.NoError(t, os.WriteFile(pkgPath, patched, 0644)) // Refresh run: cache dir is wiped, extraction happens from vendor/ again. - loader2 := NewGoThirdPartyLocalLoader(projectDir, true, nil) + loader2 := NewGoThirdPartyLocalLoader(projectDir, true, nil, nil) dbType, err := loader2.GetType("gorm.io/gorm", "DB") require.NoError(t, err) require.NotNil(t, dbType) @@ -567,14 +567,14 @@ func TestRefreshCacheFlush(t *testing.T) { // TestPackageCount verifies PackageCount reflects the number of go.mod requires. func TestPackageCount(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) assert.Equal(t, 1, loader.PackageCount()) } // TestGetType_NotFound verifies that GetType returns an error for unknown types. func TestGoThirdPartyLocalGetType_NotFound(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) _, err := loader.GetType("gorm.io/gorm", "NonExistentType") assert.Error(t, err) @@ -583,7 +583,7 @@ func TestGoThirdPartyLocalGetType_NotFound(t *testing.T) { // TestGetFunction_NotFound verifies that GetFunction returns an error for unknown functions. func TestGoThirdPartyLocalGetFunction_NotFound(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) _, err := loader.GetFunction("gorm.io/gorm", "NonExistentFunc") assert.Error(t, err) @@ -597,7 +597,7 @@ func TestGetType_PackageNotFound(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644)) // No vendor/ directory, no GOMODCACHE entry → source not found. - loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil, nil) loader.cacheDir = "" // disable disk cache so we go straight to findPackageSource _, err := loader.GetType("github.com/missing/pkg", "SomeType") @@ -624,7 +624,7 @@ func (c *Client) Call() string { return "" } t.Setenv("GOMODCACHE", fakeCache) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) loader.cacheDir = "" // disable disk cache typ, err := loader.GetType("example.com/mylib", "Client") @@ -652,7 +652,7 @@ func (t *Token) Verify() bool { return true } t.Setenv("GOMODCACHE", fakeCache) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) loader.cacheDir = "" typ, err := loader.GetType("example.com/sdk/auth", "Token") @@ -665,7 +665,7 @@ func (t *Token) Verify() bool { return true } // produces an empty (not nil) index rather than a crash. func TestLoadCacheIndex_InvalidJSON(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) // Overwrite cache-index.json with garbage. require.NoError(t, os.WriteFile( @@ -684,7 +684,7 @@ func TestLoadCacheIndex_InvalidJSON(t *testing.T) { // diskIndex is nil (disk cache unavailable). func TestWriteToDiskCache_NilDiskIndex(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) loader.diskIndex = nil // simulate unavailable disk cache // Should not panic. @@ -724,7 +724,7 @@ func StandaloneFunc() int { return 0 } // loader with PackageCount == 0 (no crash, graceful degradation). func TestInitDiskCache_NoGoMod(t *testing.T) { emptyDir := t.TempDir() - loader := NewGoThirdPartyLocalLoader(emptyDir, false, nil) + loader := NewGoThirdPartyLocalLoader(emptyDir, false, nil, nil) assert.Equal(t, 0, loader.PackageCount()) } @@ -732,7 +732,7 @@ func TestInitDiskCache_NoGoMod(t *testing.T) { // when no module prefix matches — the fallback branch. func TestModuleKeyFor_NoMatch(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) // "unknown.io/pkg" is not in go.mod → should return the importPath unchanged. key := loader.moduleKeyFor("unknown.io/pkg") @@ -743,7 +743,7 @@ func TestModuleKeyFor_NoMatch(t *testing.T) { // when cache-index.json references a file that no longer exists on disk. func TestLoadFromDiskCache_MissingFile(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) // Populate index with an entry pointing at a non-existent file. loader.diskIndex = &cacheIndex{ @@ -761,14 +761,14 @@ func TestLoadFromDiskCache_MissingFile(t *testing.T) { func TestLoadFromDiskCache_WithLogger(t *testing.T) { projectDir := makeVendoredProject(t) logger := output.NewLogger(output.VerbosityDefault) - loader := NewGoThirdPartyLocalLoader(projectDir, false, logger) + loader := NewGoThirdPartyLocalLoader(projectDir, false, logger, nil) // Cold run to populate the disk cache. _, err := loader.GetType("gorm.io/gorm", "DB") require.NoError(t, err) // Warm run: build a new loader sharing the same cacheDir to hit the debug log branch. - loader2 := NewGoThirdPartyLocalLoader(projectDir, false, logger) + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, logger, nil) result := loader2.loadFromDiskCache("gorm.io/gorm") assert.NotNil(t, result) } @@ -776,7 +776,7 @@ func TestLoadFromDiskCache_WithLogger(t *testing.T) { // TestSaveCacheIndex_EmptyCacheDir verifies saveCacheIndex is a no-op when cacheDir is empty. func TestSaveCacheIndex_EmptyCacheDir(t *testing.T) { projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) loader.cacheDir = "" // simulate unavailable cache dir // Should not panic. @@ -788,7 +788,7 @@ func TestSaveCacheIndex_EmptyCacheDir(t *testing.T) { func TestWriteToDiskCache_WriteFailure(t *testing.T) { projectDir := makeVendoredProject(t) logger := output.NewLogger(output.VerbosityDefault) - loader := NewGoThirdPartyLocalLoader(projectDir, false, logger) + loader := NewGoThirdPartyLocalLoader(projectDir, false, logger, nil) // Point cacheDir at a regular file so os.WriteFile fails with ENOTDIR/EISDIR. fakeFile := filepath.Join(t.TempDir(), "not-a-dir") @@ -808,7 +808,7 @@ func TestInitDiskCache_MkdirAllFailure(t *testing.T) { require.NoError(t, os.WriteFile(blockingFile, []byte("x"), 0644)) projectDir := makeVendoredProject(t) - loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil, nil) // Override cacheDir to a path underneath a file (impossible to mkdir). loader.cacheDir = filepath.Join(blockingFile, "subdir") loader.diskIndex = nil @@ -825,12 +825,12 @@ func TestInitDiskCache_RefreshWithLogger(t *testing.T) { logger := output.NewLogger(output.VerbosityDefault) // First pass: populate cache. - loader1 := NewGoThirdPartyLocalLoader(projectDir, false, logger) + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, logger, nil) _, err := loader1.GetType("gorm.io/gorm", "DB") require.NoError(t, err) // Second pass with refreshCache=true should flush and still work. - loader2 := NewGoThirdPartyLocalLoader(projectDir, true, logger) + loader2 := NewGoThirdPartyLocalLoader(projectDir, true, logger, nil) typ, err := loader2.GetType("gorm.io/gorm", "DB") require.NoError(t, err) assert.NotNil(t, typ)