diff --git a/sast-engine/graph/callgraph/builder/go_builder.go b/sast-engine/graph/callgraph/builder/go_builder.go index bfefeb1b..3fc1014d 100644 --- a/sast-engine/graph/callgraph/builder/go_builder.go +++ b/sast-engine/graph/callgraph/builder/go_builder.go @@ -71,7 +71,7 @@ func BuildGoCallGraph(codeGraph *graph.CodeGraph, registry *core.GoModuleRegistr // Pass 1: Index all function definitions fmt.Fprintf(os.Stderr, " Pass 1: Indexing functions...\n") - functionContext := indexGoFunctions(codeGraph, callGraph, registry) + functionContext := indexGoFunctions(codeGraph, callGraph, registry, typeEngine) fmt.Fprintf(os.Stderr, " Indexed %d functions\n", len(callGraph.Functions)) // Pass 2a: Extract return types from all indexed Go functions @@ -112,7 +112,7 @@ func BuildGoCallGraph(codeGraph *graph.CodeGraph, registry *core.GoModuleRegistr // Extract variable assignments for this file // ExtractGoVariableAssignments is thread-safe (uses mutex internally) - _ = extraction.ExtractGoVariableAssignments(filePath, sourceCode, typeEngine, registry, importMaps[filePath]) + _ = extraction.ExtractGoVariableAssignments(filePath, sourceCode, typeEngine, registry, importMaps[filePath], callGraph) // Progress tracking count := varProcessed.Add(1) @@ -284,7 +284,7 @@ func BuildGoCallGraph(codeGraph *graph.CodeGraph, registry *core.GoModuleRegistr // // Returns: // - functionContext: map from simple name to list of nodes for resolution -func indexGoFunctions(codeGraph *graph.CodeGraph, callGraph *core.CallGraph, registry *core.GoModuleRegistry) map[string][]*graph.Node { +func indexGoFunctions(codeGraph *graph.CodeGraph, callGraph *core.CallGraph, registry *core.GoModuleRegistry, typeEngine *resolution.GoTypeInferenceEngine) map[string][]*graph.Node { functionContext := make(map[string][]*graph.Node) totalNodes := len(codeGraph.Nodes) @@ -312,6 +312,12 @@ func indexGoFunctions(codeGraph *graph.CodeGraph, callGraph *core.CallGraph, reg // Add to CallGraph.Functions callGraph.Functions[fqn] = node + // Eagerly create scope so Pattern 1b Source 2 always finds one. + // Guard with GetScope == nil so Pass 2b bindings are not overwritten. + if typeEngine != nil && typeEngine.GetScope(fqn) == nil { + typeEngine.AddScope(resolution.NewGoFunctionScope(fqn)) + } + // Add to function context for name-based lookup functionContext[node.Name] = append(functionContext[node.Name], node) indexed++ diff --git a/sast-engine/graph/callgraph/builder/go_builder_scope_test.go b/sast-engine/graph/callgraph/builder/go_builder_scope_test.go new file mode 100644 index 00000000..c9ac3571 --- /dev/null +++ b/sast-engine/graph/callgraph/builder/go_builder_scope_test.go @@ -0,0 +1,141 @@ +package builder + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/resolution" + "github.com/stretchr/testify/assert" +) + +// TestIndexGoFunctions_EagerScopeCreation verifies that indexGoFunctions creates +// an empty GoFunctionScope for every indexed Go function even when no variable +// bindings are available. This ensures Pattern 1b Source 2 always finds a scope. +func TestIndexGoFunctions_EagerScopeCreation(t *testing.T) { + codeGraph := &graph.CodeGraph{ + Nodes: map[string]*graph.Node{ + "fn1": { + ID: "fn1", + Type: "function_declaration", + Name: "HandleRequest", + File: "/project/main.go", + }, + "fn2": { + ID: "fn2", + Type: "method", + Name: "Close", + File: "/project/main.go", + }, + "other": { + ID: "other", + Type: "identifier", // non-function node — must be skipped + Name: "x", + File: "/project/main.go", + }, + }, + } + + callGraph := &core.CallGraph{ + Functions: make(map[string]*graph.Node), + } + + registry := core.NewGoModuleRegistry() + registry.ModulePath = "github.com/example/app" + registry.DirToImport = map[string]string{ + "/project": "github.com/example/app", + } + + typeEngine := resolution.NewGoTypeInferenceEngine(registry) + + indexGoFunctions(codeGraph, callGraph, registry, typeEngine) + + // Every indexed function should have an eager scope. + for fqn := range callGraph.Functions { + scope := typeEngine.GetScope(fqn) + assert.NotNil(t, scope, "scope should be eagerly created for %s", fqn) + } + + // Non-function node must not appear in Functions map. + for fqn := range callGraph.Functions { + assert.NotContains(t, fqn, "identifier") + } +} + +// TestIndexGoFunctions_EagerScope_NilTypeEngine verifies that passing nil for +// typeEngine does not panic and still indexes functions normally. +func TestIndexGoFunctions_EagerScope_NilTypeEngine(t *testing.T) { + codeGraph := &graph.CodeGraph{ + Nodes: map[string]*graph.Node{ + "fn1": { + ID: "fn1", + Type: "function_declaration", + Name: "Run", + File: "/project/main.go", + }, + }, + } + + callGraph := &core.CallGraph{ + Functions: make(map[string]*graph.Node), + } + + registry := core.NewGoModuleRegistry() + registry.ModulePath = "github.com/example/app" + registry.DirToImport = map[string]string{ + "/project": "github.com/example/app", + } + + // Must not panic with nil typeEngine. + assert.NotPanics(t, func() { + indexGoFunctions(codeGraph, callGraph, registry, nil) + }) + + assert.NotEmpty(t, callGraph.Functions) +} + +// TestIndexGoFunctions_EagerScope_NotOverwritten verifies that if a scope already +// exists (e.g., created by a previous Pass 2b run), it is not replaced. +func TestIndexGoFunctions_EagerScope_NotOverwritten(t *testing.T) { + fqn := "github.com/example/app.ExistingFunc" + + codeGraph := &graph.CodeGraph{ + Nodes: map[string]*graph.Node{ + "fn1": { + ID: "fn1", + Type: "function_declaration", + Name: "ExistingFunc", + File: "/project/main.go", + }, + }, + } + + callGraph := &core.CallGraph{ + Functions: make(map[string]*graph.Node), + } + + registry := core.NewGoModuleRegistry() + registry.ModulePath = "github.com/example/app" + registry.DirToImport = map[string]string{ + "/project": "github.com/example/app", + } + + typeEngine := resolution.NewGoTypeInferenceEngine(registry) + + // Pre-create scope with a binding to simulate Pass 2b having run first. + preScope := resolution.NewGoFunctionScope(fqn) + preScope.AddVariable(&resolution.GoVariableBinding{ + VarName: "existing", + Type: &core.TypeInfo{TypeFQN: "builtin.string"}, + }) + typeEngine.AddScope(preScope) + + indexGoFunctions(codeGraph, callGraph, registry, typeEngine) + + // The pre-created scope must still have the binding. + scope := typeEngine.GetScope(fqn) + assert.NotNil(t, scope) + bindings, ok := scope.Variables["existing"] + assert.True(t, ok, "pre-existing binding should not be overwritten") + assert.Len(t, bindings, 1) +} diff --git a/sast-engine/graph/callgraph/builder/go_builder_test.go b/sast-engine/graph/callgraph/builder/go_builder_test.go index 121dd4a5..a1dd41c2 100644 --- a/sast-engine/graph/callgraph/builder/go_builder_test.go +++ b/sast-engine/graph/callgraph/builder/go_builder_test.go @@ -161,7 +161,7 @@ func TestIndexGoFunctions(t *testing.T) { }, } - functionContext := indexGoFunctions(codeGraph, callGraph, registry) + functionContext := indexGoFunctions(codeGraph, callGraph, registry, nil) // Check expected FQNs for _, expectedFQN := range tt.expectedFQNs { diff --git a/sast-engine/graph/callgraph/extraction/go_variables.go b/sast-engine/graph/callgraph/extraction/go_variables.go index deabd13a..3cc0a2ab 100644 --- a/sast-engine/graph/callgraph/extraction/go_variables.go +++ b/sast-engine/graph/callgraph/extraction/go_variables.go @@ -68,6 +68,7 @@ func ExtractGoVariableAssignments( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) error { // Parse with tree-sitter parser := sitter.NewParser() @@ -106,6 +107,7 @@ func ExtractGoVariableAssignments( typeEngine, registry, importMap, + callGraph, ) return nil @@ -122,6 +124,7 @@ func traverseForVariableAssignments( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) { if node == nil { return @@ -163,6 +166,7 @@ func traverseForVariableAssignments( typeEngine, registry, importMap, + callGraph, ) } @@ -177,6 +181,7 @@ func traverseForVariableAssignments( typeEngine, registry, importMap, + callGraph, ) } } @@ -193,6 +198,7 @@ func traverseForVariableAssignments( typeEngine, registry, importMap, + callGraph, ) } } @@ -226,6 +232,7 @@ func processShortVarDeclaration( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) { // Use existing helper to extract variable info varInfos := golangpkg.ParseShortVarDeclaration(node, sourceCode) @@ -256,6 +263,7 @@ func processShortVarDeclaration( typeEngine, registry, importMap, + callGraph, ) if typeInfo == nil { @@ -296,6 +304,7 @@ func processAssignmentStatement( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) { // Use existing helper to extract variable info varInfos := golangpkg.ParseAssignment(node, sourceCode) @@ -320,6 +329,7 @@ func processAssignmentStatement( typeEngine, registry, importMap, + callGraph, ) if typeInfo == nil { @@ -367,6 +377,7 @@ func inferTypeFromRHS( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) *core.TypeInfo { if rhsNode == nil { return nil @@ -432,14 +443,12 @@ func inferTypeFromRHS( // Function call - look up return type case "call_expression": - return inferTypeFromFunctionCall( - rhsNode, - sourceCode, - filePath, - typeEngine, - registry, - importMap, - ) + if result := inferTypeFromFunctionCall(rhsNode, sourceCode, filePath, typeEngine, registry, importMap); result != nil { + return result + } + // Param-aware fallback: if RHS is obj.Method() and obj is a function parameter, + // resolve the method's return type via StdlibLoader / ThirdPartyLoader. + return inferTypeFromParamMethodCall(rhsNode, sourceCode, functionFQN, callGraph, registry, importMap) // Variable reference - copy type from scope case "identifier": @@ -465,6 +474,7 @@ func inferTypeFromRHS( typeEngine, registry, importMap, + callGraph, ) // Expression list - for multi-assignment, get first element @@ -479,6 +489,7 @@ func inferTypeFromRHS( typeEngine, registry, importMap, + callGraph, ) } return nil @@ -651,6 +662,142 @@ func inferTypeFromThirdPartyFunction(importPath, funcName string, registry *core return nil } +// inferTypeFromParamMethodCall resolves the return type of a method call where the +// receiver is a function parameter (e.g. r.FormValue("id") when r is *http.Request). +// +// This is the param-aware fallback in inferTypeFromRHS: it fires only when the +// standard inferTypeFromFunctionCall path returned nil (i.e., the receiver is not +// a package alias and not tracked as a :=-variable in the scope). +// +// Resolution order follows the Check 2 / Check 2.5 precedence: +// 1. StdlibLoader — for stdlib types (e.g. net/http.Request) +// 2. ThirdPartyLoader — for vendored/GOMODCACHE types (e.g. gin.Context) +func inferTypeFromParamMethodCall( + callNode *sitter.Node, + sourceCode []byte, + functionFQN string, + callGraph *core.CallGraph, + registry *core.GoModuleRegistry, + importMap *core.GoImportMap, +) *core.TypeInfo { + if callGraph == nil || callNode == nil { + return nil + } + + // Must be a selector_expression receiver: obj.Method(...) + funcNode := callNode.ChildByFieldName("function") + if funcNode == nil || funcNode.Type() != "selector_expression" { + return nil + } + + operandNode := funcNode.ChildByFieldName("operand") + fieldNode := funcNode.ChildByFieldName("field") + if operandNode == nil || fieldNode == nil { + return nil + } + + objectName := operandNode.Content(sourceCode) + methodName := fieldNode.Content(sourceCode) + + // If the operand is a known package alias, it was already handled by inferTypeFromFunctionCall. + if importMap != nil { + if _, ok := importMap.Imports[objectName]; ok { + return nil + } + } + + // Look up the enclosing function's parameter list. + callerNode, ok := callGraph.Functions[functionFQN] + if !ok || callerNode == nil { + return nil + } + + for i, paramName := range callerNode.MethodArgumentsValue { + if paramName != objectName || i >= len(callerNode.MethodArgumentsType) { + continue + } + + typeStr := callerNode.MethodArgumentsType[i] + // Strip "name: " prefix that the parser sometimes prepends. + if colonIdx := strings.Index(typeStr, ": "); colonIdx >= 0 { + typeStr = typeStr[colonIdx+2:] + } + // Strip pointer qualifier — we look up the base type. + typeStr = strings.TrimPrefix(typeStr, "*") + + // Resolve short qualifier (e.g. "http.Request" → "net/http.Request"). + paramTypeFQN := extractionResolveGoTypeFQN(typeStr, importMap) + + importPath, typeName, split := extractionSplitGoTypeFQN(paramTypeFQN) + if !split { + continue + } + + // Check StdlibLoader first, then ThirdPartyLoader. + var method *core.GoStdlibFunction + if registry.StdlibLoader != nil { + if t, err := registry.StdlibLoader.GetType(importPath, typeName); err == nil && t != nil { + method = t.Methods[methodName] + } + } + if method == nil && registry.ThirdPartyLoader != nil { + if t, err := registry.ThirdPartyLoader.GetType(importPath, typeName); err == nil && t != nil { + method = t.Methods[methodName] + } + } + if method == nil || len(method.Returns) == 0 { + continue + } + + for _, ret := range method.Returns { + if ret.Type == "" || ret.Type == "error" { + continue + } + return &core.TypeInfo{ + TypeFQN: normalizeStdlibReturnType(ret.Type, importPath), + Confidence: 0.85, + Source: "method_return_type", + } + } + } + + return nil +} + +// extractionSplitGoTypeFQN splits a fully-qualified Go type name into its package +// import path and type name. Duplicated from builder/helpers.go to avoid an +// import cycle (extraction → builder → extraction). +func extractionSplitGoTypeFQN(typeFQN string) (importPath, typeName string, ok bool) { + if typeFQN == "" { + return "", "", false + } + lastDot := strings.LastIndex(typeFQN, ".") + if lastDot < 0 || lastDot == len(typeFQN)-1 { + return "", "", false + } + return typeFQN[:lastDot], typeFQN[lastDot+1:], true +} + +// extractionResolveGoTypeFQN resolves a short Go type name to a fully-qualified +// import path using the file's import map. Duplicated from builder/helpers.go to +// avoid an import cycle. +func extractionResolveGoTypeFQN(shortType string, importMap *core.GoImportMap) string { + if shortType == "" || importMap == nil { + return shortType + } + dotIdx := strings.Index(shortType, ".") + if dotIdx < 0 { + return shortType + } + alias := shortType[:dotIdx] + rest := shortType[dotIdx+1:] + importPath, ok := importMap.Resolve(alias) + if !ok { + return shortType + } + return importPath + "." + rest +} + // extractGoFunctionName extracts the function name from a function node. // Handles: // - Simple calls: foo() @@ -761,6 +908,7 @@ func inferTypeFromUnaryExpression( typeEngine *resolution.GoTypeInferenceEngine, registry *core.GoModuleRegistry, importMap *core.GoImportMap, + callGraph *core.CallGraph, ) *core.TypeInfo { // Check operator operatorNode := unaryNode.ChildByFieldName("operator") @@ -789,6 +937,7 @@ func inferTypeFromUnaryExpression( typeEngine, registry, importMap, + callGraph, ) default: diff --git a/sast-engine/graph/callgraph/extraction/go_variables_param_test.go b/sast-engine/graph/callgraph/extraction/go_variables_param_test.go new file mode 100644 index 00000000..6186da61 --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/go_variables_param_test.go @@ -0,0 +1,537 @@ +package extraction + +import ( + "testing" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/resolution" + "github.com/stretchr/testify/assert" +) + +// mockStdlibLoaderWithTypes extends mockStdlibLoader to support GetType. +type mockStdlibLoaderWithTypes struct { + stdlibPkgs map[string]bool + functions map[string]*core.GoStdlibFunction // key: "importPath.funcName" + types map[string]*core.GoStdlibType // key: "importPath.TypeName" +} + +func (m *mockStdlibLoaderWithTypes) ValidateStdlibImport(importPath string) bool { + return m.stdlibPkgs[importPath] +} + +func (m *mockStdlibLoaderWithTypes) GetFunction(importPath, funcName string) (*core.GoStdlibFunction, error) { + key := importPath + "." + funcName + fn, ok := m.functions[key] + if !ok { + return nil, errMockNotImplemented + } + return fn, nil +} + +func (m *mockStdlibLoaderWithTypes) GetType(importPath, typeName string) (*core.GoStdlibType, error) { + key := importPath + "." + typeName + t, ok := m.types[key] + if !ok { + return nil, errMockNotImplemented + } + return t, nil +} + +func (m *mockStdlibLoaderWithTypes) PackageCount() int { + return len(m.stdlibPkgs) +} + +// mockThirdPartyLoaderWithTypes is a core.GoThirdPartyLoader that serves type data. +type mockThirdPartyLoaderWithTypes struct { + packages map[string]bool + types map[string]*core.GoStdlibType // key: "importPath.TypeName" +} + +func (m *mockThirdPartyLoaderWithTypes) ValidateImport(importPath string) bool { + return m.packages[importPath] +} + +func (m *mockThirdPartyLoaderWithTypes) GetFunction(_, _ string) (*core.GoStdlibFunction, error) { + return nil, errMockNotImplemented +} + +func (m *mockThirdPartyLoaderWithTypes) GetType(importPath, typeName string) (*core.GoStdlibType, error) { + key := importPath + "." + typeName + t, ok := m.types[key] + if !ok { + return nil, errMockNotImplemented + } + return t, nil +} + +func (m *mockThirdPartyLoaderWithTypes) PackageCount() int { + return len(m.packages) +} + +// --------------------------------------------------------------------------- +// extractionSplitGoTypeFQN tests +// --------------------------------------------------------------------------- + +func TestExtractionSplitGoTypeFQN_Valid(t *testing.T) { + imp, name, ok := extractionSplitGoTypeFQN("net/http.Request") + assert.True(t, ok) + assert.Equal(t, "net/http", imp) + assert.Equal(t, "Request", name) +} + +func TestExtractionSplitGoTypeFQN_NoPackage(t *testing.T) { + _, _, ok := extractionSplitGoTypeFQN("string") + assert.False(t, ok) +} + +func TestExtractionSplitGoTypeFQN_Empty(t *testing.T) { + _, _, ok := extractionSplitGoTypeFQN("") + assert.False(t, ok) +} + +func TestExtractionSplitGoTypeFQN_TrailingDot(t *testing.T) { + _, _, ok := extractionSplitGoTypeFQN("net/http.") + assert.False(t, ok) +} + +// --------------------------------------------------------------------------- +// extractionResolveGoTypeFQN tests +// --------------------------------------------------------------------------- + +func TestExtractionResolveGoTypeFQN_KnownAlias(t *testing.T) { + importMap := &core.GoImportMap{ + Imports: map[string]string{"http": "net/http"}, + } + result := extractionResolveGoTypeFQN("http.Request", importMap) + assert.Equal(t, "net/http.Request", result) +} + +func TestExtractionResolveGoTypeFQN_UnknownAlias(t *testing.T) { + importMap := &core.GoImportMap{ + Imports: map[string]string{}, + } + result := extractionResolveGoTypeFQN("gin.Context", importMap) + assert.Equal(t, "gin.Context", result) +} + +func TestExtractionResolveGoTypeFQN_NilImportMap(t *testing.T) { + result := extractionResolveGoTypeFQN("http.Request", nil) + assert.Equal(t, "http.Request", result) +} + +func TestExtractionResolveGoTypeFQN_Unqualified(t *testing.T) { + importMap := &core.GoImportMap{Imports: map[string]string{}} + result := extractionResolveGoTypeFQN("MyStruct", importMap) + assert.Equal(t, "MyStruct", result) +} + +// --------------------------------------------------------------------------- +// inferTypeFromParamMethodCall — unit tests via ExtractGoVariableAssignments +// --------------------------------------------------------------------------- + +// TestParamAwareRHSInference_StdlibParam tests that `input := r.FormValue("id")` +// resolves to builtin.string when r is a *http.Request parameter. +func TestParamAwareRHSInference_StdlibParam(t *testing.T) { + code := `package main + +import "net/http" + +func handler(w http.ResponseWriter, r *http.Request) { + input := r.FormValue("id") + _ = input +} +` + + // Build a minimal registry pointing /test → "main". + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + // Attach a stdlib loader that knows about net/http.Request.FormValue. + loader := &mockStdlibLoaderWithTypes{ + stdlibPkgs: map[string]bool{"net/http": true}, + functions: map[string]*core.GoStdlibFunction{}, + types: map[string]*core.GoStdlibType{ + "net/http.Request": { + Name: "Request", + Methods: map[string]*core.GoStdlibFunction{ + "FormValue": { + Name: "FormValue", + Returns: []*core.GoReturnValue{{Type: "string"}}, + }, + }, + }, + }, + } + reg.StdlibLoader = loader + + importMap := &core.GoImportMap{ + Imports: map[string]string{"http": "net/http"}, + } + + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + // Build a minimal call graph with the handler function so param lookup works. + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.handler": { + ID: "handler", + Name: "handler", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"w", "r"}, + MethodArgumentsType: []string{"w: http.ResponseWriter", "r: *http.Request"}, + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + scope := typeEngine.GetScope("main.handler") + assert.NotNil(t, scope, "scope should exist for handler") + + bindings, ok := scope.Variables["input"] + assert.True(t, ok, "input binding should be created") + if assert.Len(t, bindings, 1) { + assert.Equal(t, "builtin.string", bindings[0].Type.TypeFQN) + assert.Equal(t, "method_return_type", bindings[0].Type.Source) + assert.InDelta(t, 0.85, bindings[0].Type.Confidence, 0.001) + } +} + +// TestParamAwareRHSInference_NilCallGraph ensures no panic and graceful nil +// return when callGraph is nil. +func TestParamAwareRHSInference_NilCallGraph(t *testing.T) { + code := `package main + +import "net/http" + +func handler(r *http.Request) { + input := r.FormValue("id") + _ = input +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + importMap := &core.GoImportMap{Imports: map[string]string{"http": "net/http"}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + // nil callGraph — must not panic; input binding should simply be absent. + assert.NotPanics(t, func() { + _ = ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, nil) + }) +} + +// TestParamAwareRHSInference_ThirdPartyParam tests that `q := c.Query("search")` +// resolves via ThirdPartyLoader when c is a third-party type parameter. +func TestParamAwareRHSInference_ThirdPartyParam(t *testing.T) { + code := `package main + +import "github.com/gin-gonic/gin" + +func handle(c *gin.Context) { + q := c.Query("search") + _ = q +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + tpLoader := &mockThirdPartyLoaderWithTypes{ + packages: map[string]bool{"github.com/gin-gonic/gin": true}, + types: map[string]*core.GoStdlibType{ + "github.com/gin-gonic/gin.Context": { + Name: "Context", + Methods: map[string]*core.GoStdlibFunction{ + "Query": { + Name: "Query", + Returns: []*core.GoReturnValue{{Type: "string"}}, + }, + }, + }, + }, + } + reg.ThirdPartyLoader = tpLoader + + importMap := &core.GoImportMap{ + Imports: map[string]string{"gin": "github.com/gin-gonic/gin"}, + } + + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.handle": { + ID: "handle", + Name: "handle", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"c"}, + MethodArgumentsType: []string{"c: *gin.Context"}, + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + scope := typeEngine.GetScope("main.handle") + assert.NotNil(t, scope) + + bindings, ok := scope.Variables["q"] + assert.True(t, ok, "q binding should be created via ThirdPartyLoader") + if assert.Len(t, bindings, 1) { + assert.Equal(t, "builtin.string", bindings[0].Type.TypeFQN) + assert.Equal(t, "method_return_type", bindings[0].Type.Source) + } +} + +// TestParamAwareRHSInference_UnknownParam tests that no binding is created when +// the parameter name is not in the function's MethodArgumentsValue list. +func TestParamAwareRHSInference_UnknownParam(t *testing.T) { + code := `package main + +import "net/http" + +func handler(r *http.Request) { + input := x.FormValue("id") + _ = input +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + importMap := &core.GoImportMap{Imports: map[string]string{"http": "net/http"}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.handler": { + ID: "handler", + Name: "handler", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"r"}, + MethodArgumentsType: []string{"r: *http.Request"}, + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + scope := typeEngine.GetScope("main.handler") + if scope != nil { + _, hasBinding := scope.Variables["input"] + assert.False(t, hasBinding, "no binding for x (not a parameter named r)") + } +} + +// TestParamAwareRHSInference_PackageAlias ensures package-qualified calls like +// http.NewRequest() are NOT intercepted by param-aware inference (they're already +// handled by inferTypeFromFunctionCall). +func TestParamAwareRHSInference_PackageAlias(t *testing.T) { + code := `package main + +import "net/http" + +func makeReq() { + req, _ := http.NewRequest("GET", "/", nil) + _ = req +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + loader := &mockStdlibLoaderWithTypes{ + stdlibPkgs: map[string]bool{"net/http": true}, + functions: map[string]*core.GoStdlibFunction{ + "net/http.NewRequest": { + Name: "NewRequest", + Returns: []*core.GoReturnValue{{Type: "*Request"}, {Type: "error"}}, + }, + }, + types: map[string]*core.GoStdlibType{}, + } + reg.StdlibLoader = loader + + importMap := &core.GoImportMap{Imports: map[string]string{"http": "net/http"}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.makeReq": { + ID: "makeReq", + Name: "makeReq", + Type: "function_declaration", + File: "/test/main.go", + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + // req should be resolved via inferTypeFromFunctionCall (stdlib path), not param-aware + scope := typeEngine.GetScope("main.makeReq") + assert.NotNil(t, scope) + bindings, ok := scope.Variables["req"] + assert.True(t, ok, "req should have a binding from stdlib lookup") + if assert.Len(t, bindings, 1) { + assert.Equal(t, "net/http.Request", bindings[0].Type.TypeFQN) + } +} + +// TestParamAwareRHSInference_MethodNotFound ensures no binding is created when the +// type is known but does not have the called method. +func TestParamAwareRHSInference_MethodNotFound(t *testing.T) { + code := `package main + +import "net/http" + +func handler(r *http.Request) { + v := r.UnknownMethod() + _ = v +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + loader := &mockStdlibLoaderWithTypes{ + stdlibPkgs: map[string]bool{"net/http": true}, + functions: map[string]*core.GoStdlibFunction{}, + types: map[string]*core.GoStdlibType{ + "net/http.Request": { + Name: "Request", + Methods: map[string]*core.GoStdlibFunction{}, + }, + }, + } + reg.StdlibLoader = loader + + importMap := &core.GoImportMap{Imports: map[string]string{"http": "net/http"}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.handler": { + ID: "handler", + Name: "handler", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"r"}, + MethodArgumentsType: []string{"r: *http.Request"}, + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + scope := typeEngine.GetScope("main.handler") + if scope != nil { + _, hasBinding := scope.Variables["v"] + assert.False(t, hasBinding, "no binding when method is not in type's method set") + } +} + +// TestParamAwareRHSInference_ErrorOnlyReturns ensures no binding is created when +// the method only returns error (no non-error return value to infer from). +func TestParamAwareRHSInference_ErrorOnlyReturns(t *testing.T) { + code := `package main + +import "net/http" + +func handler(r *http.Request) { + err := r.ParseForm() + _ = err +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + loader := &mockStdlibLoaderWithTypes{ + stdlibPkgs: map[string]bool{"net/http": true}, + functions: map[string]*core.GoStdlibFunction{}, + types: map[string]*core.GoStdlibType{ + "net/http.Request": { + Name: "Request", + Methods: map[string]*core.GoStdlibFunction{ + "ParseForm": { + Name: "ParseForm", + Returns: []*core.GoReturnValue{{Type: "error"}}, + }, + }, + }, + }, + } + reg.StdlibLoader = loader + + importMap := &core.GoImportMap{Imports: map[string]string{"http": "net/http"}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.handler": { + ID: "handler", + Name: "handler", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"r"}, + MethodArgumentsType: []string{"r: *http.Request"}, + }, + }, + } + + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + assert.NoError(t, err) + + scope := typeEngine.GetScope("main.handler") + if scope != nil { + _, hasBinding := scope.Variables["err"] + assert.False(t, hasBinding, "no binding when method only returns error") + } +} + +// TestParamAwareRHSInference_UnqualifiedParamType verifies graceful handling when +// the parameter type is unqualified (no package prefix, e.g. "MyStruct"). +func TestParamAwareRHSInference_UnqualifiedParamType(t *testing.T) { + code := `package main + +func process(s MyStruct) { + v := s.Compute() + _ = v +} +` + reg := core.NewGoModuleRegistry() + reg.ModulePath = "github.com/example/app" + reg.DirToImport = map[string]string{"/test": "main"} + + importMap := &core.GoImportMap{Imports: map[string]string{}} + typeEngine := resolution.NewGoTypeInferenceEngine(reg) + + callGraph := &core.CallGraph{ + Functions: map[string]*graph.Node{ + "main.process": { + ID: "process", + Name: "process", + Type: "function_declaration", + File: "/test/main.go", + MethodArgumentsValue: []string{"s"}, + MethodArgumentsType: []string{"s: MyStruct"}, + }, + }, + } + + // Must not panic; type has no dot, so extractionSplitGoTypeFQN returns false. + assert.NotPanics(t, func() { + _ = ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, callGraph) + }) +} diff --git a/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go b/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go index 63d37ac4..347000e0 100644 --- a/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go +++ b/sast-engine/graph/callgraph/extraction/go_variables_stdlib_test.go @@ -242,7 +242,7 @@ func Handler() { Imports: map[string]string{"http": "net/http"}, } - err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap) + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, nil) require.NoError(t, err) scope := typeEngine.GetScope("test.Handler") @@ -278,7 +278,7 @@ func ReadFile() { Imports: map[string]string{"os": "os"}, } - err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap) + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, nil) require.NoError(t, err) scope := typeEngine.GetScope("test.ReadFile") @@ -319,7 +319,7 @@ func Greet(name string) { Imports: map[string]string{"fmt": "fmt"}, } - err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap) + err := ExtractGoVariableAssignments("/test/main.go", []byte(code), typeEngine, reg, importMap, nil) require.NoError(t, err) scope := typeEngine.GetScope("test.Greet") diff --git a/sast-engine/graph/callgraph/extraction/go_variables_test.go b/sast-engine/graph/callgraph/extraction/go_variables_test.go index f909517d..13eb2d80 100644 --- a/sast-engine/graph/callgraph/extraction/go_variables_test.go +++ b/sast-engine/graph/callgraph/extraction/go_variables_test.go @@ -81,7 +81,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -150,7 +151,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -199,7 +201,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -273,7 +276,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -334,7 +338,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -389,7 +394,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -442,7 +448,8 @@ func (u *User) Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify assert.NoError(t, err) @@ -481,7 +488,8 @@ func TestExtractGoVariables_EmptyFile(t *testing.T) { typeEngine, registry, importMap, - ) + nil, +) // Verify - should not error assert.NoError(t, err) @@ -515,7 +523,8 @@ func Test() { typeEngine, registry, importMap, - ) + nil, +) // Verify - should return nil without error assert.NoError(t, err) @@ -579,6 +588,7 @@ func TestExtractGoVariables_Integration(t *testing.T) { typeEngine, registry, importMap, + nil, ) // Verify @@ -662,6 +672,7 @@ func Demo() { typeEngine, registry, &core.GoImportMap{}, + nil, ) // Should return nil (skipped) since path not in registry, not an error @@ -686,6 +697,7 @@ func Demo() { typeEngine, registry, &core.GoImportMap{}, + nil, ) assert.NoError(t, err)