diff --git a/sast-engine/cmd/scan.go b/sast-engine/cmd/scan.go index 8cb168c8..b0d58b9c 100644 --- a/sast-engine/cmd/scan.go +++ b/sast-engine/cmd/scan.go @@ -255,6 +255,11 @@ Examples: } else { // Initialize Go stdlib loader and type inference engine builder.InitGoStdlibLoader(goRegistry, projectPath, logger) + + // Initialize Go third-party type loader (vendor/ + GOMODCACHE). + // Pass refreshRules so --refresh-rules also flushes the go-thirdparty disk cache. + builder.InitGoThirdPartyLoader(goRegistry, projectPath, refreshRules, logger) + goTypeEngine := resolution.NewGoTypeInferenceEngine(goRegistry) goCG, err := builder.BuildGoCallGraph(codeGraph, goRegistry, goTypeEngine) diff --git a/sast-engine/graph/callgraph/builder/go_builder.go b/sast-engine/graph/callgraph/builder/go_builder.go index aebe60aa..bfefeb1b 100644 --- a/sast-engine/graph/callgraph/builder/go_builder.go +++ b/sast-engine/graph/callgraph/builder/go_builder.go @@ -506,6 +506,24 @@ func resolveGoCallTarget( } } + // Check 2.5: Validate method via ThirdPartyLoader (vendor/GOMODCACHE) + if registry != nil && registry.ThirdPartyLoader != nil { + importPath, typeName, ok := splitGoTypeFQN(typeFQN) + if ok { + // Skip if already checked as stdlib + isStdlib := registry.StdlibLoader != nil && + registry.StdlibLoader.ValidateStdlibImport(importPath) + if !isStdlib && registry.ThirdPartyLoader.ValidateImport(importPath) { + tpType, err := registry.ThirdPartyLoader.GetType(importPath, typeName) + if err == nil && tpType != nil { + if _, hasMethod := tpType.Methods[callSite.FunctionName]; hasMethod { + return methodFQN, true, false // resolved via third-party + } + } + } + } + } + // Check 3: Promoted method via struct embedding if promotedFQN, resolved, isStdlib := resolvePromotedMethod( typeFQN, callSite.FunctionName, registry, @@ -513,7 +531,7 @@ func resolveGoCallTarget( return promotedFQN, true, isStdlib } - // Check 4: Third-party / unvalidated — accept with best-effort FQN + // Check 4: Unvalidated — accept with best-effort FQN return methodFQN, true, false } } diff --git a/sast-engine/graph/callgraph/builder/go_builder_thirdparty_test.go b/sast-engine/graph/callgraph/builder/go_builder_thirdparty_test.go new file mode 100644 index 00000000..82f2b574 --- /dev/null +++ b/sast-engine/graph/callgraph/builder/go_builder_thirdparty_test.go @@ -0,0 +1,180 @@ +package builder + +import ( + "os" + "path/filepath" + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/resolution" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestThirdPartyResolution_Check25_MethodValidation tests the full pipeline: +// go.mod dependency → vendor/ source → tree-sitter extraction → Pattern 1b Check 2.5 resolution. +func TestThirdPartyResolution_Check25_MethodValidation(t *testing.T) { + tmpDir := t.TempDir() + + // 1. Create go.mod with gorm dependency + goMod := `module testapp + +go 1.21 + +require gorm.io/gorm v1.25.7 +` + err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644) + require.NoError(t, err) + + // 2. Create vendor/gorm.io/gorm/ with type metadata source + vendorDir := filepath.Join(tmpDir, "vendor", "gorm.io", "gorm") + err = os.MkdirAll(vendorDir, 0755) + require.NoError(t, err) + + gormSrc := `package gorm + +type DB struct { + Error error +} + +func (db *DB) Where(query interface{}, args ...interface{}) *DB { + return db +} + +func (db *DB) Raw(sql string, values ...interface{}) *DB { + return db +} + +func (db *DB) Exec(sql string, values ...interface{}) *DB { + return db +} + +func Open(dialector interface{}) (*DB, error) { + return nil, nil +} +` + err = os.WriteFile(filepath.Join(vendorDir, "gorm.go"), []byte(gormSrc), 0644) + require.NoError(t, err) + + // 3. Create user code that uses gorm + mainSrc := `package main + +import "gorm.io/gorm" + +func handler(db *gorm.DB) { + input := "user input" + db.Raw(input) + db.Where(input) + db.Exec(input) +} +` + err = os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte(mainSrc), 0644) + require.NoError(t, err) + + // 4. Build code graph and call graph + codeGraph := graph.Initialize(tmpDir, nil) + require.NotNil(t, codeGraph) + + goRegistry, err := resolution.BuildGoModuleRegistry(tmpDir) + require.NoError(t, err) + + // Initialize third-party loader (this is what scan.go would do) + InitGoThirdPartyLoader(goRegistry, tmpDir, false, nil) + require.NotNil(t, goRegistry.ThirdPartyLoader, "ThirdPartyLoader should be initialized") + + goTypeEngine := resolution.NewGoTypeInferenceEngine(goRegistry) + + callGraph, err := BuildGoCallGraph(codeGraph, goRegistry, goTypeEngine) + require.NoError(t, err) + require.NotNil(t, callGraph) + + // 5. Verify that third-party methods resolved correctly via Check 2.5 + // Look for call sites from handler function + handlerFQN := "testapp.handler" + callSites, ok := callGraph.CallSites[handlerFQN] + require.True(t, ok, "handler function should have call sites") + + resolvedTargets := make(map[string]bool) + for _, cs := range callSites { + if cs.Resolved { + resolvedTargets[cs.TargetFQN] = true + } + } + + // These should be resolved via Check 2.5 (third-party vendor/) + assert.True(t, resolvedTargets["gorm.io/gorm.DB.Raw"], + "db.Raw() should resolve to gorm.io/gorm.DB.Raw via Check 2.5") + assert.True(t, resolvedTargets["gorm.io/gorm.DB.Where"], + "db.Where() should resolve to gorm.io/gorm.DB.Where via Check 2.5") + assert.True(t, resolvedTargets["gorm.io/gorm.DB.Exec"], + "db.Exec() should resolve to gorm.io/gorm.DB.Exec via Check 2.5") +} + +// TestThirdPartyResolution_SubpackagePath tests resolution for subpackages +// within a third-party module (e.g., github.com/gin-gonic/gin/binding). +func TestThirdPartyResolution_SubpackagePath(t *testing.T) { + tmpDir := t.TempDir() + + goMod := `module testapp + +go 1.21 + +require github.com/gin-gonic/gin v1.9.1 +` + err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644) + require.NoError(t, err) + + // Create vendor with gin Context + vendorDir := filepath.Join(tmpDir, "vendor", "github.com", "gin-gonic", "gin") + err = os.MkdirAll(vendorDir, 0755) + require.NoError(t, err) + + ginSrc := `package gin + +type Context struct{} + +func (c *Context) Query(key string) string { return "" } +func (c *Context) Param(key string) string { return "" } + +type Engine struct{} + +func Default() *Engine { return nil } +` + err = os.WriteFile(filepath.Join(vendorDir, "gin.go"), []byte(ginSrc), 0644) + require.NoError(t, err) + + mainSrc := `package main + +import "github.com/gin-gonic/gin" + +func handler(c *gin.Context) { + q := c.Query("search") + _ = q +} +` + err = os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte(mainSrc), 0644) + require.NoError(t, err) + + codeGraph := graph.Initialize(tmpDir, nil) + goRegistry, err := resolution.BuildGoModuleRegistry(tmpDir) + require.NoError(t, err) + + InitGoThirdPartyLoader(goRegistry, tmpDir, false, nil) + goTypeEngine := resolution.NewGoTypeInferenceEngine(goRegistry) + + callGraph, err := BuildGoCallGraph(codeGraph, goRegistry, goTypeEngine) + require.NoError(t, err) + + handlerFQN := "testapp.handler" + callSites := callGraph.CallSites[handlerFQN] + + resolvedTargets := make(map[string]bool) + for _, cs := range callSites { + if cs.Resolved { + resolvedTargets[cs.TargetFQN] = true + } + } + + assert.True(t, resolvedTargets["github.com/gin-gonic/gin.Context.Query"], + "c.Query() should resolve to github.com/gin-gonic/gin.Context.Query") +} diff --git a/sast-engine/graph/callgraph/builder/go_version.go b/sast-engine/graph/callgraph/builder/go_version.go index b2256adb..34c3b7de 100644 --- a/sast-engine/graph/callgraph/builder/go_version.go +++ b/sast-engine/graph/callgraph/builder/go_version.go @@ -119,3 +119,27 @@ func initGoStdlibLoaderWithBase(reg *core.GoModuleRegistry, projectPath string, logger.Progress("Loaded Go %s stdlib manifest (%d packages)", version, remote.PackageCount()) reg.StdlibLoader = remote } + +// InitGoThirdPartyLoader initializes the third-party type loader for Go dependencies. +// Parses go.mod require directives and lazily loads type metadata from vendor/ or GOMODCACHE. +// +// When refreshCache is true (triggered by --refresh-rules on the CLI), the on-disk +// go-thirdparty extraction cache for this project is flushed and rebuilt. +func InitGoThirdPartyLoader(reg *core.GoModuleRegistry, projectPath string, refreshCache bool, logger *output.Logger) { + if reg == nil { + return + } + + loader := registry.NewGoThirdPartyLocalLoader(projectPath, refreshCache, logger) + if loader.PackageCount() == 0 { + if logger != nil { + logger.Debug("No Go third-party dependencies found in go.mod") + } + return + } + + reg.ThirdPartyLoader = loader + if logger != nil { + logger.Progress("Go third-party loader ready (%d dependencies from go.mod)", loader.PackageCount()) + } +} diff --git a/sast-engine/graph/callgraph/builder/go_version_test.go b/sast-engine/graph/callgraph/builder/go_version_test.go index 49d4c324..4d0c9740 100644 --- a/sast-engine/graph/callgraph/builder/go_version_test.go +++ b/sast-engine/graph/callgraph/builder/go_version_test.go @@ -290,3 +290,58 @@ func TestInitGoStdlibLoader_PublicAPI_CallsInner(t *testing.T) { require.NotNil(t, reg.StdlibLoader) } + +// ----------------------------------------------------------------------------- +// InitGoThirdPartyLoader +// ----------------------------------------------------------------------------- + +func TestInitGoThirdPartyLoader_NilReg(t *testing.T) { + // Must not panic when reg is nil. + InitGoThirdPartyLoader(nil, t.TempDir(), false, nil) +} + +func TestInitGoThirdPartyLoader_NoDependencies(t *testing.T) { + // go.mod with no require directives → PackageCount == 0 → ThirdPartyLoader stays nil. + tmpDir := t.TempDir() + writeTempFile(t, tmpDir, "go.mod", "module github.com/example/app\n\ngo 1.21\n") + + reg := core.NewGoModuleRegistry() + logger := newGoVersionTestLogger() + InitGoThirdPartyLoader(reg, tmpDir, false, logger) + + assert.Nil(t, reg.ThirdPartyLoader) +} + +func TestInitGoThirdPartyLoader_WithDependencies(t *testing.T) { + // go.mod with one require + vendored source → ThirdPartyLoader is set. + tmpDir := t.TempDir() + writeTempFile(t, tmpDir, "go.mod", + "module github.com/example/app\n\ngo 1.21\n\nrequire example.com/lib v1.0.0\n") + + vendorDir := filepath.Join(tmpDir, "vendor", "example.com", "lib") + require.NoError(t, os.MkdirAll(vendorDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(vendorDir, "lib.go"), + []byte("package lib\ntype Client struct{}\n"), 0o644)) + + reg := core.NewGoModuleRegistry() + logger := newGoVersionTestLogger() + InitGoThirdPartyLoader(reg, tmpDir, false, logger) + + assert.NotNil(t, reg.ThirdPartyLoader) +} + +func TestInitGoThirdPartyLoader_RefreshCache(t *testing.T) { + // refreshCache=true should not panic and should still set the loader. + tmpDir := t.TempDir() + writeTempFile(t, tmpDir, "go.mod", + "module github.com/example/app\n\ngo 1.21\n\nrequire example.com/lib v1.0.0\n") + + vendorDir := filepath.Join(tmpDir, "vendor", "example.com", "lib") + require.NoError(t, os.MkdirAll(vendorDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(vendorDir, "lib.go"), + []byte("package lib\ntype Client struct{}\n"), 0o644)) + + reg := core.NewGoModuleRegistry() + InitGoThirdPartyLoader(reg, tmpDir, true, nil) + assert.NotNil(t, reg.ThirdPartyLoader) +} diff --git a/sast-engine/graph/callgraph/core/go_stdlib_types.go b/sast-engine/graph/callgraph/core/go_stdlib_types.go index 464bce59..9f8ab1fe 100644 --- a/sast-engine/graph/callgraph/core/go_stdlib_types.go +++ b/sast-engine/graph/callgraph/core/go_stdlib_types.go @@ -251,6 +251,12 @@ type GoStdlibType struct { IsGeneric bool `json:"is_generic"` TypeParams []*GoTypeParam `json:"type_params"` Docstring string `json:"docstring"` + + // Embeds lists type names embedded by this interface or struct. + // For interfaces: embedded interface names (e.g., "io.Closer", "EnqueueClient"). + // For structs: embedded struct names (e.g., "*sql.DB"). + // Used by the third-party loader to flatten embedded methods post-extraction. + Embeds []string `json:"embeds,omitempty"` } // GoStructField represents a single exported field in a struct type declaration. diff --git a/sast-engine/graph/callgraph/core/types.go b/sast-engine/graph/callgraph/core/types.go index 2b66b42f..8ecdfe38 100644 --- a/sast-engine/graph/callgraph/core/types.go +++ b/sast-engine/graph/callgraph/core/types.go @@ -381,6 +381,10 @@ type GoModuleRegistry struct { // It is initialized lazily from the CDN registry during call graph construction. // Nil when stdlib registry loading is disabled or unavailable. StdlibLoader GoStdlibLoader + + // ThirdPartyLoader provides type metadata for Go third-party libraries. + // Parses from vendor/ or GOMODCACHE. Nil when unavailable. + ThirdPartyLoader GoThirdPartyLoader } // NewGoModuleRegistry creates an initialized GoModuleRegistry. @@ -530,6 +534,23 @@ type GoStdlibLoader interface { PackageCount() int } +// GoThirdPartyLoader provides access to Go third-party library type metadata. +// Mirrors GoStdlibLoader and reuses the same GoStdlibType/GoStdlibFunction structs. +// Implemented by registry.GoThirdPartyLocalLoader. +type GoThirdPartyLoader interface { + // ValidateImport reports whether the given import path is a known third-party package. + ValidateImport(importPath string) bool + + // GetFunction returns the metadata for a named function in the given third-party package. + GetFunction(importPath, funcName string) (*GoStdlibFunction, error) + + // GetType returns the metadata for a named type in the given third-party package. + GetType(importPath, typeName string) (*GoStdlibType, error) + + // PackageCount returns the total number of third-party packages available. + PackageCount() int +} + // Helper function to extract the last component of a dotted path. // Example: "myapp.utils.helpers" → "helpers". func extractShortName(modulePath string) string { diff --git a/sast-engine/graph/callgraph/extraction/go_variables.go b/sast-engine/graph/callgraph/extraction/go_variables.go index 05fb4e12..deabd13a 100644 --- a/sast-engine/graph/callgraph/extraction/go_variables.go +++ b/sast-engine/graph/callgraph/extraction/go_variables.go @@ -534,6 +534,10 @@ func inferTypeFromFunctionCall( if ti := inferTypeFromStdlibFunction(importPath, fnName, registry); ti != nil { return ti } + // Fallback: attempt third-party lookup for non-stdlib cross-package calls. + if ti := inferTypeFromThirdPartyFunction(importPath, fnName, registry); ti != nil { + return ti + } } // Function not found or has no return type @@ -617,6 +621,36 @@ func normalizeStdlibReturnType(rawType, importPath string) string { return importPath + "." + t } +// inferTypeFromThirdPartyFunction looks up the primary return type of a Go +// third-party function using the ThirdPartyLoader attached to the registry. +func inferTypeFromThirdPartyFunction(importPath, funcName string, registry *core.GoModuleRegistry) *core.TypeInfo { + if registry.ThirdPartyLoader == nil { + return nil + } + if !registry.ThirdPartyLoader.ValidateImport(importPath) { + return nil + } + fn, err := registry.ThirdPartyLoader.GetFunction(importPath, funcName) + if err != nil || fn == nil || len(fn.Returns) == 0 { + return nil + } + for _, ret := range fn.Returns { + if ret.Type == "" || ret.Type == "error" { + continue + } + typeFQN := normalizeStdlibReturnType(ret.Type, importPath) + if typeFQN == "" { + continue + } + return &core.TypeInfo{ + TypeFQN: typeFQN, + Confidence: 0.85, + Source: "thirdparty_local", + } + } + return nil +} + // extractGoFunctionName extracts the function name from a function node. // Handles: // - Simple calls: foo() diff --git a/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go b/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go index f9626304..63d37ac4 100644 --- a/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go +++ b/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go @@ -329,3 +329,94 @@ func Greet(name string) { require.NotEmpty(t, bindings) assert.Equal(t, "builtin.string", bindings[0].Type.TypeFQN) } + +// ----------------------------------------------------------------------------- +// inferTypeFromThirdPartyFunction +// ----------------------------------------------------------------------------- + +// mockThirdPartyLoader implements core.GoThirdPartyLoader for testing. +type mockThirdPartyLoader struct { + knownPkgs map[string]bool + functions map[string]*core.GoStdlibFunction // key: "importPath.funcName" +} + +func (m *mockThirdPartyLoader) ValidateImport(importPath string) bool { + return m.knownPkgs[importPath] +} + +func (m *mockThirdPartyLoader) GetFunction(importPath, funcName string) (*core.GoStdlibFunction, error) { + key := importPath + "." + funcName + fn, ok := m.functions[key] + if !ok { + return nil, errors.New("not found") + } + return fn, nil +} + +func (m *mockThirdPartyLoader) GetType(_, _ string) (*core.GoStdlibType, error) { + return nil, errors.New("not implemented") +} + +func (m *mockThirdPartyLoader) PackageCount() int { return len(m.knownPkgs) } + +func TestInferTypeFromThirdPartyFunction_NilLoader(t *testing.T) { + reg := core.NewGoModuleRegistry() + // ThirdPartyLoader is nil → returns nil without panic. + result := inferTypeFromThirdPartyFunction("gorm.io/gorm", "Open", reg) + assert.Nil(t, result) +} + +func TestInferTypeFromThirdPartyFunction_UnknownPackage(t *testing.T) { + reg := core.NewGoModuleRegistry() + reg.ThirdPartyLoader = &mockThirdPartyLoader{knownPkgs: map[string]bool{}} + result := inferTypeFromThirdPartyFunction("gorm.io/gorm", "Open", reg) + assert.Nil(t, result) +} + +func TestInferTypeFromThirdPartyFunction_FuncNotFound(t *testing.T) { + reg := core.NewGoModuleRegistry() + reg.ThirdPartyLoader = &mockThirdPartyLoader{ + knownPkgs: map[string]bool{"gorm.io/gorm": true}, + functions: map[string]*core.GoStdlibFunction{}, + } + result := inferTypeFromThirdPartyFunction("gorm.io/gorm", "Open", reg) + assert.Nil(t, result) +} + +func TestInferTypeFromThirdPartyFunction_ErrorOnly(t *testing.T) { + // Function exists but its only return type is "error" → returns nil. + reg := core.NewGoModuleRegistry() + reg.ThirdPartyLoader = &mockThirdPartyLoader{ + knownPkgs: map[string]bool{"gorm.io/gorm": true}, + functions: map[string]*core.GoStdlibFunction{ + "gorm.io/gorm.Close": { + Name: "Close", + Returns: []*core.GoReturnValue{{Type: "error"}}, + }, + }, + } + result := inferTypeFromThirdPartyFunction("gorm.io/gorm", "Close", reg) + assert.Nil(t, result) +} + +func TestInferTypeFromThirdPartyFunction_Success(t *testing.T) { + // Open(dialector) (*DB, error) → first non-error return is *DB. + reg := core.NewGoModuleRegistry() + reg.ThirdPartyLoader = &mockThirdPartyLoader{ + knownPkgs: map[string]bool{"gorm.io/gorm": true}, + functions: map[string]*core.GoStdlibFunction{ + "gorm.io/gorm.Open": { + Name: "Open", + Returns: []*core.GoReturnValue{ + {Type: "*DB"}, + {Type: "error"}, + }, + }, + }, + } + result := inferTypeFromThirdPartyFunction("gorm.io/gorm", "Open", reg) + require.NotNil(t, result) + assert.Equal(t, "gorm.io/gorm.DB", result.TypeFQN) + assert.InDelta(t, 0.85, result.Confidence, 0.001) + assert.Equal(t, "thirdparty_local", result.Source) +} diff --git a/sast-engine/graph/callgraph/registry/go_thirdparty_local.go b/sast-engine/graph/callgraph/registry/go_thirdparty_local.go new file mode 100644 index 00000000..ccb5911c --- /dev/null +++ b/sast-engine/graph/callgraph/registry/go_thirdparty_local.go @@ -0,0 +1,913 @@ +package registry + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + "unicode" + + sitter "github.com/smacker/go-tree-sitter" + golang "github.com/smacker/go-tree-sitter/golang" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/output" +) + +// cacheIndexVersion is the schema version written into cache-index.json. +// Increment when the cache format changes to force re-extraction on upgrade. +const cacheIndexVersion = "1.0.0" + +// errPackageSourceNotFound is returned by getOrLoadPackage when no source +// directory can be located for the requested import path. +var errPackageSourceNotFound = errors.New("go-thirdparty: package source not found in vendor/ or GOMODCACHE") + +// cacheIndexEntry is one record in cache-index.json. +type cacheIndexEntry struct { + Version string `json:"version"` + File string `json:"file"` + CachedAt time.Time `json:"cachedAt"` +} + +// cacheIndex is the in-memory representation of cache-index.json. +type cacheIndex struct { + Version string `json:"version"` + Entries map[string]*cacheIndexEntry `json:"entries"` +} + +// GoThirdPartyLocalLoader extracts type metadata from third-party Go packages +// found in vendor/ or GOMODCACHE. Uses tree-sitter for lightweight parsing. +// Implements core.GoThirdPartyLoader. +type GoThirdPartyLocalLoader struct { + projectRoot string + moduleVersions map[string]string // import path → version (from go.mod require) + packageCache map[string]*core.GoStdlibPackage // import path → extracted package (in-memory) + cacheMutex sync.RWMutex + cacheDir string // disk cache directory: {userCacheDir}/code-pathfinder/go-thirdparty/{projectHash}/ + diskIndex *cacheIndex // loaded from cache-index.json; nil when disk cache is unavailable + logger *output.Logger +} + +// NewGoThirdPartyLocalLoader creates a loader that finds and parses third-party +// Go packages from vendor/ or GOMODCACHE. +// +// When refreshCache is true (set by --refresh-rules on the CLI), the existing +// go-thirdparty disk cache for this project is deleted and rebuilt from source. +func NewGoThirdPartyLocalLoader(projectRoot string, refreshCache bool, logger *output.Logger) *GoThirdPartyLocalLoader { + loader := &GoThirdPartyLocalLoader{ + projectRoot: projectRoot, + packageCache: make(map[string]*core.GoStdlibPackage), + logger: logger, + } + loader.moduleVersions = parseGoModRequires(projectRoot) + if logger != nil { + logger.Debug("Go third-party local loader: found %d dependencies in go.mod", len(loader.moduleVersions)) + } + loader.cacheDir = goThirdPartyCacheDir(projectRoot) + loader.initDiskCache(refreshCache) + return loader +} + +// goThirdPartyCacheDir returns the project-specific disk cache directory. +// Path: {os.UserCacheDir}/code-pathfinder/go-thirdparty/{sha256(projectRoot)[:12]}. +func goThirdPartyCacheDir(projectRoot string) string { + base, err := os.UserCacheDir() + if err != nil { + base = os.TempDir() + } + h := sha256.Sum256([]byte(projectRoot)) + projectHash := hex.EncodeToString(h[:])[:12] + return filepath.Join(base, "code-pathfinder", "go-thirdparty", projectHash) +} + +// initDiskCache prepares the on-disk cache directory and loads cache-index.json. +// If refreshCache is true, the directory is wiped before loading (always a miss). +func (l *GoThirdPartyLocalLoader) initDiskCache(refreshCache bool) { + if refreshCache { + if err := os.RemoveAll(l.cacheDir); err != nil && l.logger != nil { + l.logger.Debug("go-thirdparty: failed to flush cache dir %s: %v", l.cacheDir, err) + } + } + if err := os.MkdirAll(l.cacheDir, 0o755); err != nil { + if l.logger != nil { + l.logger.Debug("go-thirdparty: could not create cache dir %s: %v", l.cacheDir, err) + } + return + } + l.diskIndex = l.loadCacheIndex() +} + +// loadCacheIndex reads cache-index.json from disk. Returns an empty index on any error. +func (l *GoThirdPartyLocalLoader) loadCacheIndex() *cacheIndex { + indexPath := filepath.Join(l.cacheDir, "cache-index.json") + data, err := os.ReadFile(indexPath) + if err != nil { + return &cacheIndex{Version: cacheIndexVersion, Entries: make(map[string]*cacheIndexEntry)} + } + var idx cacheIndex + if err := json.Unmarshal(data, &idx); err != nil || idx.Entries == nil { + return &cacheIndex{Version: cacheIndexVersion, Entries: make(map[string]*cacheIndexEntry)} + } + return &idx +} + +// saveCacheIndex writes the current diskIndex to cache-index.json. +// Called while the write lock is held. +func (l *GoThirdPartyLocalLoader) saveCacheIndex() { + if l.diskIndex == nil || l.cacheDir == "" { + return + } + data, err := json.MarshalIndent(l.diskIndex, "", " ") + if err != nil { + return + } + _ = os.WriteFile(filepath.Join(l.cacheDir, "cache-index.json"), data, 0o644) +} + +// encodeCachePath converts an import path to a safe filename component. +// e.g. "gorm.io/gorm" → "gorm.io_gorm". +func encodeCachePath(importPath string) string { + return strings.ReplaceAll(importPath, "/", "_") +} + +// ValidateImport reports whether the import path is a known third-party dependency. +func (l *GoThirdPartyLocalLoader) ValidateImport(importPath string) bool { + // Check if any known module is a prefix of the import path. + // e.g., importPath "gorm.io/gorm" matches module "gorm.io/gorm" + // e.g., importPath "github.com/gin-gonic/gin/binding" matches module "github.com/gin-gonic/gin" + for modPath := range l.moduleVersions { + if importPath == modPath || strings.HasPrefix(importPath, modPath+"/") { + return true + } + } + return false +} + +// GetFunction returns function metadata for a third-party package function. +func (l *GoThirdPartyLocalLoader) GetFunction(importPath, funcName string) (*core.GoStdlibFunction, error) { + pkg, err := l.getOrLoadPackage(importPath) + if err != nil || pkg == nil { + return nil, err + } + fn, ok := pkg.Functions[funcName] + if !ok { + return nil, fmt.Errorf("function %s not found in %s", funcName, importPath) + } + return fn, nil +} + +// GetType returns type metadata for a third-party package type. +func (l *GoThirdPartyLocalLoader) GetType(importPath, typeName string) (*core.GoStdlibType, error) { + pkg, err := l.getOrLoadPackage(importPath) + if err != nil || pkg == nil { + return nil, err + } + typ, ok := pkg.Types[typeName] + if !ok { + return nil, fmt.Errorf("type %s not found in %s", typeName, importPath) + } + return typ, nil +} + +// PackageCount returns the number of known third-party dependencies. +func (l *GoThirdPartyLocalLoader) PackageCount() int { + return len(l.moduleVersions) +} + +// getOrLoadPackage retrieves a package from the in-memory cache, the disk cache, +// or by parsing from vendor/GOMODCACHE (in that priority order). +func (l *GoThirdPartyLocalLoader) getOrLoadPackage(importPath string) (*core.GoStdlibPackage, error) { + // Fast path: in-memory cache (includes negative results stored as nil). + l.cacheMutex.RLock() + if pkg, ok := l.packageCache[importPath]; ok { + l.cacheMutex.RUnlock() + return pkg, nil + } + l.cacheMutex.RUnlock() + + // Slow path: disk cache then source parse. + l.cacheMutex.Lock() + defer l.cacheMutex.Unlock() + + // Double-check under write lock. + if pkg, ok := l.packageCache[importPath]; ok { + return pkg, nil + } + + // Disk cache hit: version must match go.mod require version. + if pkg := l.loadFromDiskCache(importPath); pkg != nil { + l.packageCache[importPath] = pkg + return pkg, nil + } + + // Parse from vendor/ or GOMODCACHE. + srcDir := l.findPackageSource(importPath) + if srcDir == "" { + l.packageCache[importPath] = nil + return nil, errPackageSourceNotFound + } + + pkg, err := extractGoPackageWithTreeSitter(importPath, srcDir) + if err != nil { + if l.logger != nil { + l.logger.Debug("Failed to extract third-party package %s: %v", importPath, err) + } + l.packageCache[importPath] = nil + return nil, err + } + + l.packageCache[importPath] = pkg + if l.logger != nil { + l.logger.Debug("Extracted third-party package %s: %d types, %d functions", + importPath, len(pkg.Types), len(pkg.Functions)) + } + + // Persist to disk cache for subsequent runs. + l.writeToDiskCache(importPath, pkg) + return pkg, nil +} + +// loadFromDiskCache attempts to read a GoStdlibPackage from the disk cache. +// Returns nil on any cache miss, version mismatch, or read error. +func (l *GoThirdPartyLocalLoader) loadFromDiskCache(importPath string) *core.GoStdlibPackage { + if l.diskIndex == nil || l.cacheDir == "" { + return nil + } + entry, ok := l.diskIndex.Entries[importPath] + if !ok { + return nil + } + // Version mismatch → stale cache, re-extract. + if wantVer := l.moduleVersions[l.moduleKeyFor(importPath)]; wantVer != "" && entry.Version != wantVer { + return nil + } + data, err := os.ReadFile(filepath.Join(l.cacheDir, entry.File)) + if err != nil { + return nil + } + var pkg core.GoStdlibPackage + if err := json.Unmarshal(data, &pkg); err != nil { + return nil + } + if l.logger != nil { + l.logger.Debug("go-thirdparty: disk cache hit for %s (%s)", importPath, entry.Version) + } + return &pkg +} + +// writeToDiskCache serialises pkg to a JSON file and updates cache-index.json. +// Errors are logged at debug level and silently ignored (cache is best-effort). +func (l *GoThirdPartyLocalLoader) writeToDiskCache(importPath string, pkg *core.GoStdlibPackage) { + if l.diskIndex == nil || l.cacheDir == "" { + return + } + version := l.moduleVersions[l.moduleKeyFor(importPath)] + fileName := encodeCachePath(importPath) + "@" + version + ".json" + + data, err := json.MarshalIndent(pkg, "", " ") + if err != nil { + return + } + if err := os.WriteFile(filepath.Join(l.cacheDir, fileName), data, 0o644); err != nil { + if l.logger != nil { + l.logger.Debug("go-thirdparty: failed to write disk cache for %s: %v", importPath, err) + } + return + } + l.diskIndex.Entries[importPath] = &cacheIndexEntry{ + Version: version, + File: fileName, + CachedAt: time.Now().UTC(), + } + l.saveCacheIndex() +} + +// moduleKeyFor returns the go.mod module path that owns the given import path. +func (l *GoThirdPartyLocalLoader) moduleKeyFor(importPath string) string { + for modPath := range l.moduleVersions { + if importPath == modPath || strings.HasPrefix(importPath, modPath+"/") { + return modPath + } + } + return importPath +} + +// findPackageSource locates the source directory for an import path. +// Checks vendor/ first, then GOMODCACHE. +func (l *GoThirdPartyLocalLoader) findPackageSource(importPath string) string { + // 1. Check vendor/ + vendorPath := filepath.Join(l.projectRoot, "vendor", importPath) + if hasGoFiles(vendorPath) { + return vendorPath + } + + // 2. Check GOMODCACHE + modCache := os.Getenv("GOMODCACHE") + if modCache == "" { + gopath := os.Getenv("GOPATH") + if gopath == "" { + gopath = filepath.Join(os.Getenv("HOME"), "go") + } + modCache = filepath.Join(gopath, "pkg", "mod") + } + + // Find the module that owns this import path + for modPath, version := range l.moduleVersions { + if importPath == modPath || strings.HasPrefix(importPath, modPath+"/") { + // The subpackage path within the module + subPkg := strings.TrimPrefix(importPath, modPath) + modDir := filepath.Join(modCache, modPath+"@"+version) + pkgDir := filepath.Join(modDir, subPkg) + if hasGoFiles(pkgDir) { + return pkgDir + } + // Also try without subpackage (root of the module) + if subPkg == "" && hasGoFiles(modDir) { + return modDir + } + } + } + + return "" +} + +// hasGoFiles checks if a directory exists and contains at least one non-test .go file. +func hasGoFiles(dir string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && + !strings.HasSuffix(entry.Name(), "_test.go") { + return true + } + } + return false +} + +// parseGoModRequires extracts require directives from go.mod. +// Returns map of module path → version. +func parseGoModRequires(projectRoot string) map[string]string { + goModPath := filepath.Join(projectRoot, "go.mod") + content, err := os.ReadFile(goModPath) + if err != nil { + return nil + } + + requires := make(map[string]string) + inRequireBlock := false + + for _, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + + if line == "require (" { + inRequireBlock = true + continue + } + if inRequireBlock && line == ")" { + inRequireBlock = false + continue + } + + // Single-line require + if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { + parts := strings.Fields(line) + if len(parts) >= 3 { + requires[parts[1]] = parts[2] + } + continue + } + + // Inside require block + if inRequireBlock { + line = strings.TrimSuffix(line, "// indirect") + line = strings.TrimSpace(line) + parts := strings.Fields(line) + if len(parts) >= 2 { + requires[parts[0]] = parts[1] + } + } + } + + return requires +} + +// extractGoPackageWithTreeSitter parses .go files in a directory using tree-sitter +// and extracts exported types, methods, and functions into a GoStdlibPackage. +// After extraction, flattens embedded interface methods into parent interfaces. +func extractGoPackageWithTreeSitter(importPath, srcDir string) (*core.GoStdlibPackage, error) { + entries, err := os.ReadDir(srcDir) + if err != nil { + return nil, fmt.Errorf("reading directory %s: %w", srcDir, err) + } + + pkg := core.NewGoStdlibPackage(importPath, "") + + parser := sitter.NewParser() + parser.SetLanguage(golang.GetLanguage()) + defer parser.Close() + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".go") || + strings.HasSuffix(entry.Name(), "_test.go") { + continue + } + + filePath := filepath.Join(srcDir, entry.Name()) + src, err := os.ReadFile(filePath) + if err != nil { + continue + } + + tree, err := parser.ParseCtx(context.Background(), nil, src) + if err != nil { + continue + } + + extractFromTree(tree.RootNode(), src, pkg) + tree.Close() + } + + // Post-processing: flatten embedded interface/struct methods into parent types. + // e.g., if Client embeds EnqueueClient, copy EnqueueClient's methods into Client. + flattenEmbeddedMethods(pkg) + + // Resolve cross-package embeds (e.g., io.Closer) using well-known stdlib interfaces. + resolveWellKnownEmbeds(pkg) + + return pkg, nil +} + +// resolveWellKnownEmbeds resolves cross-package embedded interfaces using a hardcoded +// table of well-known stdlib interfaces (io.Closer, io.Reader, etc.). +func resolveWellKnownEmbeds(pkg *core.GoStdlibPackage) { + for _, typ := range pkg.Types { + for _, embeddedName := range typ.Embeds { + if !strings.Contains(embeddedName, ".") { + continue // same-package — already handled + } + dotIdx := strings.LastIndex(embeddedName, ".") + pkgAlias := embeddedName[:dotIdx] + typeName := embeddedName[dotIdx+1:] + + if methods := getWellKnownInterfaceMethods(pkgAlias, typeName); methods != nil { + for methodName, method := range methods { + if _, exists := typ.Methods[methodName]; !exists { + typ.Methods[methodName] = method + } + } + } + } + } +} + +// flattenEmbeddedMethods resolves embedded type references and copies their methods +// into the parent type. Handles same-package embeds (e.g., EnqueueClient) by looking +// up in pkg.Types. Cross-package embeds (e.g., io.Closer) are deferred to the loader. +func flattenEmbeddedMethods(pkg *core.GoStdlibPackage) { + for _, typ := range pkg.Types { + if len(typ.Embeds) == 0 { + continue + } + + for _, embeddedName := range typ.Embeds { + // Same-package embed: look up directly in pkg.Types + // e.g., "EnqueueClient" in posthog package + bareEmbed := strings.TrimPrefix(embeddedName, "*") + if embeddedType, ok := pkg.Types[bareEmbed]; ok { + for methodName, method := range embeddedType.Methods { + if _, exists := typ.Methods[methodName]; !exists { + typ.Methods[methodName] = method + } + } + // Recursively flatten (for multi-level embedding) + if len(embeddedType.Embeds) > 0 { + for _, deepEmbed := range embeddedType.Embeds { + deepBare := strings.TrimPrefix(deepEmbed, "*") + if deepType, ok2 := pkg.Types[deepBare]; ok2 { + for methodName, method := range deepType.Methods { + if _, exists := typ.Methods[methodName]; !exists { + typ.Methods[methodName] = method + } + } + } + } + } + } + // Cross-package embeds (e.g., "io.Closer") are resolved by resolveWellKnownEmbeds. + } + } +} + +// getWellKnownInterfaceMethods returns methods for commonly embedded stdlib interfaces. +// This is a hardcoded fallback for when we don't have access to the StdlibLoader +// (avoiding import cycle). Covers the most security-relevant embedded interfaces. +func getWellKnownInterfaceMethods(pkg, typeName string) map[string]*core.GoStdlibFunction { + key := pkg + "." + typeName + + wellKnown := map[string]map[string]*core.GoStdlibFunction{ + "io.Closer": { + "Close": {Name: "Close", Returns: []*core.GoReturnValue{{Type: "error"}}, Confidence: 1.0}, + }, + "io.Reader": { + "Read": {Name: "Read", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, + Returns: []*core.GoReturnValue{{Name: "n", Type: "int"}, {Name: "err", Type: "error"}}, Confidence: 1.0}, + }, + "io.Writer": { + "Write": {Name: "Write", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, + Returns: []*core.GoReturnValue{{Name: "n", Type: "int"}, {Name: "err", Type: "error"}}, Confidence: 1.0}, + }, + "io.ReadCloser": { + "Read": {Name: "Read", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, Returns: []*core.GoReturnValue{{Type: "int"}, {Type: "error"}}, Confidence: 1.0}, + "Close": {Name: "Close", Returns: []*core.GoReturnValue{{Type: "error"}}, Confidence: 1.0}, + }, + "io.WriteCloser": { + "Write": {Name: "Write", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, Returns: []*core.GoReturnValue{{Type: "int"}, {Type: "error"}}, Confidence: 1.0}, + "Close": {Name: "Close", Returns: []*core.GoReturnValue{{Type: "error"}}, Confidence: 1.0}, + }, + "io.ReadWriter": { + "Read": {Name: "Read", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, Returns: []*core.GoReturnValue{{Type: "int"}, {Type: "error"}}, Confidence: 1.0}, + "Write": {Name: "Write", Params: []*core.GoFunctionParam{{Name: "p", Type: "[]byte"}}, Returns: []*core.GoReturnValue{{Type: "int"}, {Type: "error"}}, Confidence: 1.0}, + }, + "fmt.Stringer": { + "String": {Name: "String", Returns: []*core.GoReturnValue{{Type: "string"}}, Confidence: 1.0}, + }, + "error": { + "Error": {Name: "Error", Returns: []*core.GoReturnValue{{Type: "string"}}, Confidence: 1.0}, + }, + } + + return wellKnown[key] +} + +// extractFromTree walks a Go AST and extracts exported declarations. +func extractFromTree(root *sitter.Node, src []byte, pkg *core.GoStdlibPackage) { + for i := 0; i < int(root.ChildCount()); i++ { + child := root.Child(i) + switch child.Type() { + case "function_declaration": + extractFunctionDecl(child, src, pkg) + case "method_declaration": + extractMethodDecl(child, src, pkg) + case "type_declaration": + extractTypeDecl(child, src, pkg) + } + } +} + +// extractFunctionDecl extracts an exported package-level function. +func extractFunctionDecl(node *sitter.Node, src []byte, pkg *core.GoStdlibPackage) { + nameNode := node.ChildByFieldName("name") + if nameNode == nil { + return + } + name := nameNode.Content(src) + if !isExported(name) { + return + } + + fn := &core.GoStdlibFunction{ + Name: name, + Confidence: 1.0, + } + + // Extract parameters + paramsNode := node.ChildByFieldName("parameters") + if paramsNode != nil { + fn.Params = extractParams(paramsNode, src) + fn.Signature = fmt.Sprintf("func %s%s", name, paramsNode.Content(src)) + } + + // Extract return type + resultNode := node.ChildByFieldName("result") + if resultNode != nil { + fn.Returns = extractReturns(resultNode, src) + fn.Signature += " " + resultNode.Content(src) + } + + pkg.Functions[name] = fn +} + +// extractMethodDecl extracts an exported method on a type. +func extractMethodDecl(node *sitter.Node, src []byte, pkg *core.GoStdlibPackage) { + nameNode := node.ChildByFieldName("name") + if nameNode == nil { + return + } + name := nameNode.Content(src) + if !isExported(name) { + return + } + + // Extract receiver type + receiverNode := node.ChildByFieldName("receiver") + if receiverNode == nil { + return + } + receiverType := extractReceiverTypeName(receiverNode, src) + if receiverType == "" { + return + } + + fn := &core.GoStdlibFunction{ + Name: name, + ReceiverType: receiverType, + Confidence: 1.0, + } + + // Extract parameters + paramsNode := node.ChildByFieldName("parameters") + if paramsNode != nil { + fn.Params = extractParams(paramsNode, src) + } + + // Extract return type + resultNode := node.ChildByFieldName("result") + if resultNode != nil { + fn.Returns = extractReturns(resultNode, src) + } + + // Ensure type exists in package + bareReceiver := strings.TrimPrefix(receiverType, "*") + typ, ok := pkg.Types[bareReceiver] + if !ok { + typ = &core.GoStdlibType{ + Name: bareReceiver, + Kind: "struct", + Methods: make(map[string]*core.GoStdlibFunction), + } + pkg.Types[bareReceiver] = typ + } + if typ.Methods == nil { + typ.Methods = make(map[string]*core.GoStdlibFunction) + } + typ.Methods[name] = fn +} + +// extractTypeDecl extracts exported type declarations (struct, interface, alias). +func extractTypeDecl(node *sitter.Node, src []byte, pkg *core.GoStdlibPackage) { + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "type_spec" { + continue + } + + nameNode := child.ChildByFieldName("name") + if nameNode == nil { + continue + } + name := nameNode.Content(src) + if !isExported(name) { + continue + } + + typeNode := child.ChildByFieldName("type") + if typeNode == nil { + continue + } + + // Get or create type entry (methods may have been added already) + typ, ok := pkg.Types[name] + if !ok { + typ = &core.GoStdlibType{ + Name: name, + Methods: make(map[string]*core.GoStdlibFunction), + } + pkg.Types[name] = typ + } + + switch typeNode.Type() { + case "struct_type": + typ.Kind = "struct" + typ.Fields = extractStructFields(typeNode, src) + case "interface_type": + typ.Kind = "interface" + // Interface methods are part of the interface body + extractInterfaceMethods(typeNode, src, typ) + default: + typ.Kind = "alias" + typ.Underlying = typeNode.Content(src) + } + } +} + +// extractStructFields extracts exported fields from a struct_type node. +func extractStructFields(node *sitter.Node, src []byte) []*core.GoStructField { + var fields []*core.GoStructField + + // Find field_declaration_list + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "field_declaration_list" { + continue + } + + for j := 0; j < int(child.ChildCount()); j++ { + field := child.Child(j) + if field.Type() != "field_declaration" { + continue + } + + nameNode := field.ChildByFieldName("name") + typeNode := field.ChildByFieldName("type") + + if typeNode == nil { + continue + } + + fieldName := "" + if nameNode != nil { + fieldName = nameNode.Content(src) + } + + exported := fieldName == "" || isExported(fieldName) // embedded fields are exported + if !exported { + continue + } + + tagNode := field.ChildByFieldName("tag") + tag := "" + if tagNode != nil { + tag = tagNode.Content(src) + } + + fields = append(fields, &core.GoStructField{ + Name: fieldName, + Type: typeNode.Content(src), + Tag: tag, + Exported: exported, + }) + } + } + + return fields +} + +// extractInterfaceMethods extracts method signatures and embedded interface names +// from an interface_type node. Embedded interfaces (e.g., io.Closer, EnqueueClient) +// are recorded in typ.Embeds for post-extraction flattening. +func extractInterfaceMethods(node *sitter.Node, src []byte, typ *core.GoStdlibType) { + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + + switch child.Type() { + case "method_spec", "method_elem": + // Direct method declaration: IsFeatureEnabled(payload string) (interface{}, error) + extractInterfaceMethodElem(child, src, typ) + + case "type_elem": + // Embedded type reference, wrapped in type_elem node. + // Contains either type_identifier (same-package) or qualified_type (cross-package). + for j := 0; j < int(child.ChildCount()); j++ { + inner := child.Child(j) + switch inner.Type() { + case "type_identifier": + typ.Embeds = append(typ.Embeds, inner.Content(src)) + case "qualified_type": + typ.Embeds = append(typ.Embeds, inner.Content(src)) + } + } + + case "type_identifier": + // Embedded same-package interface (direct child, some grammars) + typ.Embeds = append(typ.Embeds, child.Content(src)) + + case "qualified_type": + // Embedded cross-package interface (direct child, some grammars) + typ.Embeds = append(typ.Embeds, child.Content(src)) + } + } +} + +// extractInterfaceMethodElem extracts a single method from a method_elem or method_spec node. +func extractInterfaceMethodElem(child *sitter.Node, src []byte, typ *core.GoStdlibType) { + nameNode := child.ChildByFieldName("name") + if nameNode == nil { + return + } + name := nameNode.Content(src) + if !isExported(name) { + return + } + + fn := &core.GoStdlibFunction{ + Name: name, + Confidence: 1.0, + } + + paramsNode := child.ChildByFieldName("parameters") + if paramsNode != nil { + fn.Params = extractParams(paramsNode, src) + } + + resultNode := child.ChildByFieldName("result") + if resultNode != nil { + fn.Returns = extractReturns(resultNode, src) + } + + typ.Methods[name] = fn +} + +// extractParams extracts function parameters from a parameter_list node. +func extractParams(node *sitter.Node, src []byte) []*core.GoFunctionParam { + var params []*core.GoFunctionParam + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "parameter_declaration" { + continue + } + + typeNode := child.ChildByFieldName("type") + if typeNode == nil { + continue + } + paramType := typeNode.Content(src) + + // Check for variadic + isVariadic := strings.HasPrefix(paramType, "...") + if isVariadic { + paramType = strings.TrimPrefix(paramType, "...") + } + + // Extract parameter name(s) + nameNode := child.ChildByFieldName("name") + paramName := "" + if nameNode != nil { + paramName = nameNode.Content(src) + } + + params = append(params, &core.GoFunctionParam{ + Name: paramName, + Type: paramType, + IsVariadic: isVariadic, + }) + } + return params +} + +// extractReturns extracts return types from a result node. +func extractReturns(node *sitter.Node, src []byte) []*core.GoReturnValue { + content := node.Content(src) + + // Simple return type (no parens) + if !strings.HasPrefix(content, "(") { + return []*core.GoReturnValue{{Type: content}} + } + + // Multiple returns: (type1, type2, ...) + inner := strings.TrimPrefix(content, "(") + inner = strings.TrimSuffix(inner, ")") + var returns []*core.GoReturnValue + for _, part := range strings.Split(inner, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + // Handle named returns: "name type" + fields := strings.Fields(part) + if len(fields) == 2 { + returns = append(returns, &core.GoReturnValue{Name: fields[0], Type: fields[1]}) + } else { + returns = append(returns, &core.GoReturnValue{Type: part}) + } + } + return returns +} + +// extractReceiverTypeName extracts the type name from a receiver parameter list. +// e.g., "(db *DB)" → "*DB", "(s Server)" → "Server". +func extractReceiverTypeName(node *sitter.Node, src []byte) string { + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() != "parameter_declaration" { + continue + } + typeNode := child.ChildByFieldName("type") + if typeNode != nil { + typeName := typeNode.Content(src) + // Strip package qualifiers for receiver — we only care about the bare type name + if strings.Contains(typeName, ".") { + parts := strings.SplitN(typeName, ".", 2) + typeName = parts[len(parts)-1] + } + return typeName + } + } + return "" +} + +// isExported checks if a Go identifier is exported (starts with uppercase). +func isExported(name string) bool { + if name == "" { + return false + } + return unicode.IsUpper(rune(name[0])) +} diff --git a/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go b/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go new file mode 100644 index 00000000..9e6a5433 --- /dev/null +++ b/sast-engine/graph/callgraph/registry/go_thirdparty_local_test.go @@ -0,0 +1,859 @@ +package registry + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/output" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestParseGoModRequires tests parsing require directives from go.mod. +func TestParseGoModRequires(t *testing.T) { + tmpDir := t.TempDir() + + goMod := `module github.com/example/myapp + +go 1.21 + +require ( + gorm.io/gorm v1.25.7 + github.com/gin-gonic/gin v1.9.1 + github.com/stretchr/testify v1.9.0 // indirect +) + +require github.com/redis/go-redis/v9 v9.5.1 +` + err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644) + require.NoError(t, err) + + requires := parseGoModRequires(tmpDir) + + assert.Equal(t, "v1.25.7", requires["gorm.io/gorm"]) + assert.Equal(t, "v1.9.1", requires["github.com/gin-gonic/gin"]) + assert.Equal(t, "v1.9.0", requires["github.com/stretchr/testify"]) + assert.Equal(t, "v9.5.1", requires["github.com/redis/go-redis/v9"]) +} + +// TestExtractGoPackageWithTreeSitter tests extracting type metadata from Go source. +func TestExtractGoPackageWithTreeSitter(t *testing.T) { + tmpDir := t.TempDir() + + // Write a minimal third-party package source file + src := `package gorm + +// DB is the main database handle. +type DB struct { + Error error + RowsAffected int64 + Statement *Statement +} + +// Statement holds the current query state. +type Statement struct { + SQL string +} + +// Dialector is the database driver interface. +type Dialector interface { + Initialize(db *DB) error + Name() string +} + +// Where adds a WHERE clause. +func (db *DB) Where(query interface{}, args ...interface{}) *DB { + return db +} + +// Find retrieves records. +func (db *DB) Find(dest interface{}, conds ...interface{}) *DB { + return db +} + +// Raw executes a raw SQL query. +func (db *DB) Raw(sql string, values ...interface{}) *DB { + return db +} + +// Exec executes a raw SQL statement. +func (db *DB) Exec(sql string, values ...interface{}) *DB { + return db +} + +// Create inserts a new record. +func (db *DB) Create(value interface{}) *DB { + return db +} + +// Open creates a new DB connection. +func Open(dialector Dialector, opts ...Option) (*DB, error) { + return nil, nil +} + +// Option configures the DB. +type Option struct{} + +// unexportedFunc should not be extracted. +func unexportedFunc() {} + +// unexportedType should not be extracted. +type unexportedType struct { + field string +} +` + err := os.WriteFile(filepath.Join(tmpDir, "gorm.go"), []byte(src), 0644) + require.NoError(t, err) + + pkg, err := extractGoPackageWithTreeSitter("gorm.io/gorm", tmpDir) + require.NoError(t, err) + require.NotNil(t, pkg) + + // Verify package metadata + assert.Equal(t, "gorm.io/gorm", pkg.ImportPath) + + // Verify types extracted + assert.Contains(t, pkg.Types, "DB") + assert.Contains(t, pkg.Types, "Statement") + assert.Contains(t, pkg.Types, "Dialector") + assert.Contains(t, pkg.Types, "Option") + + // Verify unexported types NOT extracted + assert.NotContains(t, pkg.Types, "unexportedType") + + // Verify DB type details + dbType := pkg.Types["DB"] + assert.Equal(t, "struct", dbType.Kind) + assert.NotNil(t, dbType.Methods) + + // Verify methods on DB + assert.Contains(t, dbType.Methods, "Where") + assert.Contains(t, dbType.Methods, "Find") + assert.Contains(t, dbType.Methods, "Raw") + assert.Contains(t, dbType.Methods, "Exec") + assert.Contains(t, dbType.Methods, "Create") + + // Verify Where method details + whereMethod := dbType.Methods["Where"] + assert.Equal(t, "Where", whereMethod.Name) + assert.NotEmpty(t, whereMethod.Params) + assert.NotEmpty(t, whereMethod.Returns) + assert.Equal(t, "*DB", whereMethod.Returns[0].Type) + + // Verify DB fields + assert.NotEmpty(t, dbType.Fields) + fieldNames := make([]string, 0, len(dbType.Fields)) + for _, f := range dbType.Fields { + fieldNames = append(fieldNames, f.Name) + } + assert.Contains(t, fieldNames, "Error") + assert.Contains(t, fieldNames, "RowsAffected") + assert.Contains(t, fieldNames, "Statement") + + // Verify Dialector interface + dialType := pkg.Types["Dialector"] + assert.Equal(t, "interface", dialType.Kind) + assert.Contains(t, dialType.Methods, "Initialize") + assert.Contains(t, dialType.Methods, "Name") + + // Verify package-level function + assert.Contains(t, pkg.Functions, "Open") + openFn := pkg.Functions["Open"] + assert.Equal(t, "Open", openFn.Name) + assert.Len(t, openFn.Returns, 2) + assert.Equal(t, "*DB", openFn.Returns[0].Type) + assert.Equal(t, "error", openFn.Returns[1].Type) + + // Verify unexported function NOT extracted + assert.NotContains(t, pkg.Functions, "unexportedFunc") +} + +// TestGoThirdPartyLocalLoader_VendorResolution tests resolution from vendor/. +func TestGoThirdPartyLocalLoader_VendorResolution(t *testing.T) { + tmpDir := t.TempDir() + + // Create go.mod with dependency + goMod := `module github.com/example/myapp + +go 1.21 + +require gorm.io/gorm v1.25.7 +` + err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644) + require.NoError(t, err) + + // Create vendor/gorm.io/gorm/ with Go source + vendorDir := filepath.Join(tmpDir, "vendor", "gorm.io", "gorm") + err = os.MkdirAll(vendorDir, 0755) + require.NoError(t, err) + + gormSrc := `package gorm + +type DB struct { + Error error +} + +func (db *DB) Where(query interface{}, args ...interface{}) *DB { + return db +} + +func (db *DB) Raw(sql string, values ...interface{}) *DB { + return db +} + +func Open(dialector interface{}) (*DB, error) { + return nil, nil +} +` + err = os.WriteFile(filepath.Join(vendorDir, "gorm.go"), []byte(gormSrc), 0644) + require.NoError(t, err) + + // Create loader and test + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + + // Validate import + assert.True(t, loader.ValidateImport("gorm.io/gorm")) + assert.False(t, loader.ValidateImport("unknown.io/pkg")) + + // Get type + dbType, err := loader.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + require.NotNil(t, dbType) + + assert.Equal(t, "struct", dbType.Kind) + assert.Contains(t, dbType.Methods, "Where") + assert.Contains(t, dbType.Methods, "Raw") + + // Get function + openFn, err := loader.GetFunction("gorm.io/gorm", "Open") + require.NoError(t, err) + require.NotNil(t, openFn) + assert.Equal(t, "Open", openFn.Name) + assert.Len(t, openFn.Returns, 2) + assert.Equal(t, "*DB", openFn.Returns[0].Type) +} + +// TestGoThirdPartyLocalLoader_MethodValidation tests that method existence +// can be validated — the key use case for Pattern 1b Check 2.5. +func TestGoThirdPartyLocalLoader_MethodValidation(t *testing.T) { + tmpDir := t.TempDir() + + goMod := `module github.com/example/myapp + +go 1.21 + +require github.com/gin-gonic/gin v1.9.1 +` + err := os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644) + require.NoError(t, err) + + vendorDir := filepath.Join(tmpDir, "vendor", "github.com", "gin-gonic", "gin") + err = os.MkdirAll(vendorDir, 0755) + require.NoError(t, err) + + ginSrc := `package gin + +type Context struct { + Request interface{} + Writer interface{} +} + +func (c *Context) Query(key string) string { + return "" +} + +func (c *Context) Param(key string) string { + return "" +} + +func (c *Context) PostForm(key string) string { + return "" +} + +func (c *Context) JSON(code int, obj interface{}) {} + +type Engine struct{} + +func Default() *Engine { return nil } +` + err = os.WriteFile(filepath.Join(vendorDir, "context.go"), []byte(ginSrc), 0644) + require.NoError(t, err) + + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + + // Verify type and methods + ctxType, err := loader.GetType("github.com/gin-gonic/gin", "Context") + require.NoError(t, err) + require.NotNil(t, ctxType) + + // Methods that exist + _, hasQuery := ctxType.Methods["Query"] + assert.True(t, hasQuery, "Context should have Query method") + + _, hasParam := ctxType.Methods["Param"] + assert.True(t, hasParam, "Context should have Param method") + + _, hasPostForm := ctxType.Methods["PostForm"] + assert.True(t, hasPostForm, "Context should have PostForm method") + + _, hasJSON := ctxType.Methods["JSON"] + assert.True(t, hasJSON, "Context should have JSON method") + + // Method that doesn't exist — this is what Check 2.5 needs to detect + _, hasFake := ctxType.Methods["FakeMethod"] + assert.False(t, hasFake, "Context should NOT have FakeMethod") + + // Return type inference + defaultFn, err := loader.GetFunction("github.com/gin-gonic/gin", "Default") + require.NoError(t, err) + require.NotNil(t, defaultFn) + assert.Equal(t, "*Engine", defaultFn.Returns[0].Type) +} + +// TestInterfaceEmbedding_SamePackage tests that embedded interface methods are flattened. +// This is the posthog.Client pattern: Client embeds EnqueueClient and io.Closer. +func TestInterfaceEmbedding_SamePackage(t *testing.T) { + tmpDir := t.TempDir() + + // Simulate posthog package with interface embedding + src := `package posthog + +type Client interface { + io.Closer + EnqueueClient + + IsFeatureEnabled(payload FeatureFlagPayload) (interface{}, error) + GetFeatureFlag(payload FeatureFlagPayload) (interface{}, error) +} + +type EnqueueClient interface { + Enqueue(msg Message) error +} + +type Message interface{} +type FeatureFlagPayload struct{} +type Capture struct { + DistinctId string + Event string +} +` + err := os.WriteFile(filepath.Join(tmpDir, "posthog.go"), []byte(src), 0644) + require.NoError(t, err) + + pkg, err := extractGoPackageWithTreeSitter("github.com/posthog/posthog-go", tmpDir) + require.NoError(t, err) + require.NotNil(t, pkg) + + // Client interface should exist + clientType, ok := pkg.Types["Client"] + require.True(t, ok, "Client type should exist") + assert.Equal(t, "interface", clientType.Kind) + + // Direct methods + assert.Contains(t, clientType.Methods, "IsFeatureEnabled", "Direct method should be extracted") + assert.Contains(t, clientType.Methods, "GetFeatureFlag", "Direct method should be extracted") + + // Flattened from EnqueueClient (same-package embed) + assert.Contains(t, clientType.Methods, "Enqueue", + "Enqueue should be flattened from embedded EnqueueClient interface") + + // Flattened from io.Closer (cross-package embed, well-known) + assert.Contains(t, clientType.Methods, "Close", + "Close should be flattened from embedded io.Closer interface") + + // Verify Embeds field captured the embedding info + assert.Contains(t, clientType.Embeds, "EnqueueClient") + assert.Contains(t, clientType.Embeds, "io.Closer") +} + +// TestInterfaceEmbedding_MultiLevel tests multi-level embedding: +// A embeds B, B embeds C → A should have C's methods. +func TestInterfaceEmbedding_MultiLevel(t *testing.T) { + tmpDir := t.TempDir() + + src := `package pkg + +type Level1 interface { + Level2 + L1Method() string +} + +type Level2 interface { + Level3 + L2Method() int +} + +type Level3 interface { + L3Method() error +} +` + err := os.WriteFile(filepath.Join(tmpDir, "levels.go"), []byte(src), 0644) + require.NoError(t, err) + + pkg, err := extractGoPackageWithTreeSitter("example.com/pkg", tmpDir) + require.NoError(t, err) + + l1 := pkg.Types["Level1"] + require.NotNil(t, l1) + + assert.Contains(t, l1.Methods, "L1Method", "Own method") + assert.Contains(t, l1.Methods, "L2Method", "Flattened from Level2") + assert.Contains(t, l1.Methods, "L3Method", "Flattened from Level3 (multi-level)") +} + +// TestHasGoFiles tests the hasGoFiles helper. +func TestHasGoFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Empty dir + assert.False(t, hasGoFiles(tmpDir)) + + // Dir with non-Go file + err := os.WriteFile(filepath.Join(tmpDir, "readme.md"), []byte("# hi"), 0644) + require.NoError(t, err) + assert.False(t, hasGoFiles(tmpDir)) + + // Dir with test file only + err = os.WriteFile(filepath.Join(tmpDir, "foo_test.go"), []byte("package foo"), 0644) + require.NoError(t, err) + assert.False(t, hasGoFiles(tmpDir)) + + // Dir with Go source file + err = os.WriteFile(filepath.Join(tmpDir, "foo.go"), []byte("package foo"), 0644) + require.NoError(t, err) + assert.True(t, hasGoFiles(tmpDir)) +} + +// makeVendoredProject creates a minimal project with a vendored gorm stub +// and returns the project root directory. Shared by disk-cache tests. +func makeVendoredProject(t *testing.T) string { + t.Helper() + projectDir := t.TempDir() + + goMod := "module github.com/example/myapp\n\ngo 1.21\n\nrequire gorm.io/gorm v1.25.7\n" + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(goMod), 0644)) + + vendorDir := filepath.Join(projectDir, "vendor", "gorm.io", "gorm") + require.NoError(t, os.MkdirAll(vendorDir, 0755)) + + gormSrc := `package gorm +type DB struct{ Error error } +func (db *DB) Where(query interface{}, args ...interface{}) *DB { return db } +` + require.NoError(t, os.WriteFile(filepath.Join(vendorDir, "gorm.go"), []byte(gormSrc), 0644)) + return projectDir +} + +// TestDiskCacheWriteAndRead verifies that a cold extraction is persisted to disk +// and a second loader instance reads from the disk cache instead of re-parsing. +func TestDiskCacheWriteAndRead(t *testing.T) { + projectDir := makeVendoredProject(t) + + // Cold run: loader extracts from vendor/ and writes to disk cache. + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + dbType, err := loader1.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + require.NotNil(t, dbType) + assert.Contains(t, dbType.Methods, "Where") + + // Cache-index.json must exist after the cold run. + indexPath := filepath.Join(loader1.cacheDir, "cache-index.json") + _, statErr := os.Stat(indexPath) + assert.NoError(t, statErr, "cache-index.json should be written after cold extraction") + + // Package JSON file must exist. + entry := loader1.diskIndex.Entries["gorm.io/gorm"] + require.NotNil(t, entry, "cache-index.json should contain gorm.io/gorm entry") + pkgFile := filepath.Join(loader1.cacheDir, entry.File) + _, statErr = os.Stat(pkgFile) + assert.NoError(t, statErr, "package JSON file should exist on disk") + + // Warm run: new loader with same projectDir reads from disk cache. + // Remove vendor/ to prove it's not re-parsing from source. + require.NoError(t, os.RemoveAll(filepath.Join(projectDir, "vendor"))) + + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + dbType2, err := loader2.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + require.NotNil(t, dbType2, "disk cache hit should return the type even without vendor/") + assert.Contains(t, dbType2.Methods, "Where", "disk-cached type should retain methods") +} + +// TestCacheVersionMismatch verifies that a go.mod version bump causes re-extraction +// and invalidates the stale disk cache entry. +func TestCacheVersionMismatch(t *testing.T) { + projectDir := makeVendoredProject(t) + + // Cold run at v1.25.7. + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + _, err := loader1.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + + // Simulate a go.mod upgrade to v1.25.8: update go.mod and add new method to vendor source. + goMod := "module github.com/example/myapp\n\ngo 1.21\n\nrequire gorm.io/gorm v1.25.8\n" + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(goMod), 0644)) + + vendorDir := filepath.Join(projectDir, "vendor", "gorm.io", "gorm") + require.NoError(t, os.MkdirAll(vendorDir, 0755)) + newSrc := `package gorm +type DB struct{ Error error } +func (db *DB) Where(query interface{}, args ...interface{}) *DB { return db } +func (db *DB) Save(value interface{}) *DB { return db } +` + require.NoError(t, os.WriteFile(filepath.Join(vendorDir, "gorm.go"), []byte(newSrc), 0644)) + + // New loader: same cache dir, but go.mod says v1.25.8 — cache entry is v1.25.7 → miss. + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + dbType, err := loader2.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + require.NotNil(t, dbType) + + // New method from v1.25.8 source must be present (re-extraction happened). + assert.Contains(t, dbType.Methods, "Save", "re-extraction should pick up new method after version bump") + + // Cache-index entry should now reflect v1.25.8. + entry := loader2.diskIndex.Entries["gorm.io/gorm"] + require.NotNil(t, entry) + assert.Equal(t, "v1.25.8", entry.Version, "cache-index.json should be updated to new version") +} + +// TestRefreshCacheFlush verifies that refreshCache=true wipes existing cache files +// and forces a fresh extraction on the next load. +func TestRefreshCacheFlush(t *testing.T) { + projectDir := makeVendoredProject(t) + + // Cold run to populate cache. + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, nil) + _, err := loader1.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + cacheDir := loader1.cacheDir + + // Verify cache files exist. + indexPath := filepath.Join(cacheDir, "cache-index.json") + _, statErr := os.Stat(indexPath) + require.NoError(t, statErr, "cache-index.json should exist after cold run") + + // Inject a sentinel into the cached package JSON to detect whether it gets reused. + entry := loader1.diskIndex.Entries["gorm.io/gorm"] + require.NotNil(t, entry) + pkgPath := filepath.Join(cacheDir, entry.File) + data, err := os.ReadFile(pkgPath) + require.NoError(t, err) + var pkgMap map[string]interface{} + require.NoError(t, json.Unmarshal(data, &pkgMap)) + pkgMap["_sentinel"] = "stale" + patched, err := json.Marshal(pkgMap) + require.NoError(t, err) + require.NoError(t, os.WriteFile(pkgPath, patched, 0644)) + + // Refresh run: cache dir is wiped, extraction happens from vendor/ again. + loader2 := NewGoThirdPartyLocalLoader(projectDir, true, nil) + dbType, err := loader2.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + require.NotNil(t, dbType) + assert.Contains(t, dbType.Methods, "Where", "fresh extraction should succeed after cache flush") + + // The new cache entry should NOT contain the sentinel we injected. + entry2 := loader2.diskIndex.Entries["gorm.io/gorm"] + require.NotNil(t, entry2) + freshData, err := os.ReadFile(filepath.Join(cacheDir, entry2.File)) + require.NoError(t, err) + assert.NotContains(t, string(freshData), `"_sentinel"`, "flushed cache should not contain stale sentinel") +} + +// TestPackageCount verifies PackageCount reflects the number of go.mod requires. +func TestPackageCount(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + assert.Equal(t, 1, loader.PackageCount()) +} + +// TestGetType_NotFound verifies that GetType returns an error for unknown types. +func TestGoThirdPartyLocalGetType_NotFound(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + + _, err := loader.GetType("gorm.io/gorm", "NonExistentType") + assert.Error(t, err) +} + +// TestGetFunction_NotFound verifies that GetFunction returns an error for unknown functions. +func TestGoThirdPartyLocalGetFunction_NotFound(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + + _, err := loader.GetFunction("gorm.io/gorm", "NonExistentFunc") + assert.Error(t, err) +} + +// TestGetType_PackageNotFound verifies that GetType returns an error when the +// package cannot be located in vendor/ or GOMODCACHE. +func TestGetType_PackageNotFound(t *testing.T) { + tmpDir := t.TempDir() + goMod := "module github.com/example/myapp\n\ngo 1.21\n\nrequire github.com/missing/pkg v1.0.0\n" + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte(goMod), 0644)) + + // No vendor/ directory, no GOMODCACHE entry → source not found. + loader := NewGoThirdPartyLocalLoader(tmpDir, false, nil) + loader.cacheDir = "" // disable disk cache so we go straight to findPackageSource + + _, err := loader.GetType("github.com/missing/pkg", "SomeType") + assert.Error(t, err) +} + +// TestFindPackageSource_GOMODCACHE verifies that findPackageSource falls back to +// GOMODCACHE when vendor/ does not contain the package. +func TestFindPackageSource_GOMODCACHE(t *testing.T) { + projectDir := t.TempDir() + + goMod := "module github.com/example/myapp\n\ngo 1.21\n\nrequire example.com/mylib v1.2.3\n" + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(goMod), 0644)) + + // Set up a fake GOMODCACHE with the module source. + fakeCache := t.TempDir() + modDir := filepath.Join(fakeCache, "example.com", "mylib@v1.2.3") + require.NoError(t, os.MkdirAll(modDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(modDir, "mylib.go"), []byte(`package mylib + +type Client struct{} +func (c *Client) Call() string { return "" } +`), 0644)) + + t.Setenv("GOMODCACHE", fakeCache) + + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader.cacheDir = "" // disable disk cache + + typ, err := loader.GetType("example.com/mylib", "Client") + require.NoError(t, err) + require.NotNil(t, typ) + assert.Contains(t, typ.Methods, "Call") +} + +// TestFindPackageSource_GOMODCACHE_Subpackage verifies resolution of a subpackage +// (import path = module/subpkg) from GOMODCACHE. +func TestFindPackageSource_GOMODCACHE_Subpackage(t *testing.T) { + projectDir := t.TempDir() + + goMod := "module github.com/example/myapp\n\ngo 1.21\n\nrequire example.com/sdk v2.0.0\n" + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(goMod), 0644)) + + fakeCache := t.TempDir() + subDir := filepath.Join(fakeCache, "example.com", "sdk@v2.0.0", "auth") + require.NoError(t, os.MkdirAll(subDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(subDir, "auth.go"), []byte(`package auth + +type Token struct{} +func (t *Token) Verify() bool { return true } +`), 0644)) + + t.Setenv("GOMODCACHE", fakeCache) + + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader.cacheDir = "" + + typ, err := loader.GetType("example.com/sdk/auth", "Token") + require.NoError(t, err) + require.NotNil(t, typ) + assert.Contains(t, typ.Methods, "Verify") +} + +// TestLoadCacheIndex_InvalidJSON verifies that a corrupt cache-index.json +// produces an empty (not nil) index rather than a crash. +func TestLoadCacheIndex_InvalidJSON(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + + // Overwrite cache-index.json with garbage. + require.NoError(t, os.WriteFile( + filepath.Join(loader.cacheDir, "cache-index.json"), + []byte("not-valid-json{{{"), + 0644, + )) + + // Re-load the index — should return an empty index, not panic. + idx := loader.loadCacheIndex() + require.NotNil(t, idx) + assert.Empty(t, idx.Entries) +} + +// TestWriteToDiskCache_NilDiskIndex verifies writeToDiskCache is a no-op when +// diskIndex is nil (disk cache unavailable). +func TestWriteToDiskCache_NilDiskIndex(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader.diskIndex = nil // simulate unavailable disk cache + + // Should not panic. + loader.writeToDiskCache("gorm.io/gorm", nil) +} + +// TestIsExported_EmptyString verifies isExported handles empty input without panic. +func TestIsExported_EmptyString(t *testing.T) { + assert.False(t, isExported("")) +} + +// TestExtractMethodDecl_NoReceiverType verifies that method declarations with +// no parseable receiver type are silently skipped. +func TestExtractMethodDecl_NoReceiverType(t *testing.T) { + tmpDir := t.TempDir() + + // Valid method + exported function only; no unusual receiver syntax. + src := `package mypkg + +type Svc struct{} + +func (s *Svc) DoWork() string { return "" } +func StandaloneFunc() int { return 0 } +` + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "svc.go"), []byte(src), 0644)) + + pkg, err := extractGoPackageWithTreeSitter("example.com/mypkg", tmpDir) + require.NoError(t, err) + require.NotNil(t, pkg) + + assert.Contains(t, pkg.Types, "Svc") + assert.Contains(t, pkg.Types["Svc"].Methods, "DoWork") + assert.Contains(t, pkg.Functions, "StandaloneFunc") +} + +// TestInitDiskCache_NoGoMod verifies that a project with no go.mod produces a +// loader with PackageCount == 0 (no crash, graceful degradation). +func TestInitDiskCache_NoGoMod(t *testing.T) { + emptyDir := t.TempDir() + loader := NewGoThirdPartyLocalLoader(emptyDir, false, nil) + assert.Equal(t, 0, loader.PackageCount()) +} + +// TestModuleKeyFor_NoMatch verifies moduleKeyFor returns the importPath itself +// when no module prefix matches — the fallback branch. +func TestModuleKeyFor_NoMatch(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + + // "unknown.io/pkg" is not in go.mod → should return the importPath unchanged. + key := loader.moduleKeyFor("unknown.io/pkg") + assert.Equal(t, "unknown.io/pkg", key) +} + +// TestLoadFromDiskCache_MissingFile verifies that loadFromDiskCache returns nil +// when cache-index.json references a file that no longer exists on disk. +func TestLoadFromDiskCache_MissingFile(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + + // Populate index with an entry pointing at a non-existent file. + loader.diskIndex = &cacheIndex{ + Version: cacheIndexVersion, + Entries: map[string]*cacheIndexEntry{ + "gorm.io/gorm": {Version: "v1.25.7", File: "nonexistent_file.json"}, + }, + } + result := loader.loadFromDiskCache("gorm.io/gorm") + assert.Nil(t, result) +} + +// TestLoadFromDiskCache_WithLogger verifies the debug-log branch executes when +// a logger is attached and the cache hits successfully. +func TestLoadFromDiskCache_WithLogger(t *testing.T) { + projectDir := makeVendoredProject(t) + logger := output.NewLogger(output.VerbosityDefault) + loader := NewGoThirdPartyLocalLoader(projectDir, false, logger) + + // Cold run to populate the disk cache. + _, err := loader.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + + // Warm run: build a new loader sharing the same cacheDir to hit the debug log branch. + loader2 := NewGoThirdPartyLocalLoader(projectDir, false, logger) + result := loader2.loadFromDiskCache("gorm.io/gorm") + assert.NotNil(t, result) +} + +// TestSaveCacheIndex_EmptyCacheDir verifies saveCacheIndex is a no-op when cacheDir is empty. +func TestSaveCacheIndex_EmptyCacheDir(t *testing.T) { + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + loader.cacheDir = "" // simulate unavailable cache dir + + // Should not panic. + loader.saveCacheIndex() +} + +// TestWriteToDiskCache_WriteFailure verifies that writeToDiskCache logs and +// continues gracefully when the output file cannot be written. +func TestWriteToDiskCache_WriteFailure(t *testing.T) { + projectDir := makeVendoredProject(t) + logger := output.NewLogger(output.VerbosityDefault) + loader := NewGoThirdPartyLocalLoader(projectDir, false, logger) + + // Point cacheDir at a regular file so os.WriteFile fails with ENOTDIR/EISDIR. + fakeFile := filepath.Join(t.TempDir(), "not-a-dir") + require.NoError(t, os.WriteFile(fakeFile, []byte("x"), 0644)) + loader.cacheDir = fakeFile + + // Should not panic; WriteFile will fail but error is swallowed. + loader.writeToDiskCache("gorm.io/gorm", &core.GoStdlibPackage{ImportPath: "gorm.io/gorm"}) +} + +// TestInitDiskCache_MkdirAllFailure verifies initDiskCache handles the case +// where the cache directory cannot be created. +func TestInitDiskCache_MkdirAllFailure(t *testing.T) { + // Put a regular file at the path we'll try to mkdir — this forces MkdirAll to fail. + parent := t.TempDir() + blockingFile := filepath.Join(parent, "blocked") + require.NoError(t, os.WriteFile(blockingFile, []byte("x"), 0644)) + + projectDir := makeVendoredProject(t) + loader := NewGoThirdPartyLocalLoader(projectDir, false, nil) + // Override cacheDir to a path underneath a file (impossible to mkdir). + loader.cacheDir = filepath.Join(blockingFile, "subdir") + loader.diskIndex = nil + + // initDiskCache must not panic; diskIndex stays nil on failure. + loader.initDiskCache(false) + assert.Nil(t, loader.diskIndex) +} + +// TestInitDiskCache_RefreshWithLogger verifies the refreshCache=true path +// with a logger attached (covers the RemoveAll + logger.Debug branch). +func TestInitDiskCache_RefreshWithLogger(t *testing.T) { + projectDir := makeVendoredProject(t) + logger := output.NewLogger(output.VerbosityDefault) + + // First pass: populate cache. + loader1 := NewGoThirdPartyLocalLoader(projectDir, false, logger) + _, err := loader1.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + + // Second pass with refreshCache=true should flush and still work. + loader2 := NewGoThirdPartyLocalLoader(projectDir, true, logger) + typ, err := loader2.GetType("gorm.io/gorm", "DB") + require.NoError(t, err) + assert.NotNil(t, typ) +} + +// TestExtractMethodDecl_UnexportedMethod verifies that unexported methods are skipped. +func TestExtractMethodDecl_UnexportedMethod(t *testing.T) { + tmpDir := t.TempDir() + src := `package mypkg + +type Svc struct{} + +func (s *Svc) ExportedMethod() string { return "" } +func (s *Svc) unexportedMethod() string { return "" } +` + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "svc.go"), []byte(src), 0644)) + + pkg, err := extractGoPackageWithTreeSitter("example.com/mypkg", tmpDir) + require.NoError(t, err) + require.NotNil(t, pkg) + + svc := pkg.Types["Svc"] + require.NotNil(t, svc) + assert.Contains(t, svc.Methods, "ExportedMethod") + assert.NotContains(t, svc.Methods, "unexportedMethod") +}