diff --git a/src/parser.ts b/src/parser.ts index b185904..482fd12 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -823,6 +823,8 @@ function stateMachineStatementParser( let openBlocks = 0; + const declaredVariables = new Set(); + /* eslint arrow-body-style: 0, no-extra-parens: 0 */ const isValidToken = (step: Step, token: Token) => { if (!step.validation) { @@ -924,11 +926,17 @@ function stateMachineStatementParser( } } - if ( - token.type === 'parameter' && - (token.value === '?' || !statement.parameters.includes(token.value)) - ) { - statement.parameters.push(token.value); + if (token.type === 'parameter') { + // Variables declared via DECLARE inside a function/procedure body are + // local — they aren't user-supplied parameters, so exclude them. + if (prevNonWhitespaceToken?.value.toUpperCase() === 'DECLARE') { + declaredVariables.add(token.value); + } else if ( + !declaredVariables.has(token.value) && + (token.value === '?' || !statement.parameters.includes(token.value)) + ) { + statement.parameters.push(token.value); + } } if (statement.type && statement.start >= 0) { @@ -1112,8 +1120,13 @@ export function defaultParamTypesFor(dialect: Dialect): ParamTypes { numbered: ['$'], }; case 'mssql': + return { + named: ['@', ':'], + }; + case 'oracle': return { named: [':'], + numbered: [':'], }; case 'bigquery': return { @@ -1125,7 +1138,7 @@ export function defaultParamTypesFor(dialect: Dialect): ParamTypes { return { positional: true, numbered: ['?'], - named: [':', '@'], + named: [':', '@', '$'], }; default: return { diff --git a/test/identifier/single-statement.spec.ts b/test/identifier/single-statement.spec.ts index f86c1d8..5bdaccd 100644 --- a/test/identifier/single-statement.spec.ts +++ b/test/identifier/single-statement.spec.ts @@ -791,7 +791,7 @@ describe('identifier', () => { text: query, type: 'CREATE_FUNCTION', executionType: 'MODIFICATION', - parameters: [], + parameters: ['@DATE'], tables: [], }, ]; @@ -1445,7 +1445,7 @@ describe('identifier', () => { }); it('Should extract named Parameters', () => { - const actual = identify('SELECT * FROM Persons where x = :one and y = :two and a = :one', { + const actual = identify('SELECT * FROM Persons where x = @one and y = @two and a = @one', { dialect: 'mssql', strict: true, }); @@ -1453,10 +1453,10 @@ describe('identifier', () => { { start: 0, end: 61, - text: 'SELECT * FROM Persons where x = :one and y = :two and a = :one', + text: 'SELECT * FROM Persons where x = @one and y = @two and a = @one', type: 'SELECT', executionType: 'LISTING', - parameters: [':one', ':two'], + parameters: ['@one', '@two'], tables: [], }, ]; @@ -1465,7 +1465,7 @@ describe('identifier', () => { }); it('Should extract named Parameters with trailing commas', () => { - const actual = identify('SELECT * FROM Persons where x in (:one, :two, :three)', { + const actual = identify('SELECT * FROM Persons where x in (@one, @two, @three)', { dialect: 'mssql', strict: true, }); @@ -1473,10 +1473,90 @@ describe('identifier', () => { { start: 0, end: 52, - text: 'SELECT * FROM Persons where x in (:one, :two, :three)', + text: 'SELECT * FROM Persons where x in (@one, @two, @three)', + type: 'SELECT', + executionType: 'LISTING', + parameters: ['@one', '@two', '@three'], + tables: [], + }, + ]; + + expect(actual).to.eql(expected); + }); + + it('Should extract mssql colon-prefixed named parameters', () => { + const actual = identify('SELECT * FROM Persons where x = :one and y = :two and a = :one', { + dialect: 'mssql', + strict: true, + }); + const expected = [ + { + start: 0, + end: 61, + text: 'SELECT * FROM Persons where x = :one and y = :two and a = :one', + type: 'SELECT', + executionType: 'LISTING', + parameters: [':one', ':two'], + tables: [], + }, + ]; + + expect(actual).to.eql(expected); + }); + + it('Should extract oracle named parameters', () => { + const actual = identify('SELECT * FROM persons WHERE id = :one AND status = :two', { + dialect: 'oracle', + strict: true, + }); + const expected = [ + { + start: 0, + end: 54, + text: 'SELECT * FROM persons WHERE id = :one AND status = :two', + type: 'SELECT', + executionType: 'LISTING', + parameters: [':one', ':two'], + tables: [], + }, + ]; + + expect(actual).to.eql(expected); + }); + + it('Should extract oracle numbered parameters', () => { + const actual = identify('SELECT * FROM persons WHERE id = :1 AND status = :2', { + dialect: 'oracle', + strict: true, + }); + const expected = [ + { + start: 0, + end: 50, + text: 'SELECT * FROM persons WHERE id = :1 AND status = :2', + type: 'SELECT', + executionType: 'LISTING', + parameters: [':1', ':2'], + tables: [], + }, + ]; + + expect(actual).to.eql(expected); + }); + + it('Should extract sqlite $name parameters', () => { + const actual = identify('SELECT * FROM persons WHERE id = $myid', { + dialect: 'sqlite', + strict: true, + }); + const expected = [ + { + start: 0, + end: 37, + text: 'SELECT * FROM persons WHERE id = $myid', type: 'SELECT', executionType: 'LISTING', - parameters: [':one', ':two', ':three'], + parameters: ['$myid'], tables: [], }, ]; diff --git a/test/index.spec.ts b/test/index.spec.ts index 557cd86..bbc4443 100644 --- a/test/index.spec.ts +++ b/test/index.spec.ts @@ -138,17 +138,12 @@ describe('Regression tests', () => { // Regression test: https://github.com/beekeeper-studio/beekeeper-studio/issues/2560 it('Double colon should not be recognized as a param for mssql', () => { const result = identify( - ` - DECLARE @g geometry; - DECLARE @h geometry; - SET @g = geometry::STGeomFromText('POLYGON((0 0, 2 0, 2 2, 0 2, 0 0))', 0); - set @h = geometry::STGeomFromText('POLYGON((1 1, 3 1, 3 3, 1 3, 1 1))', 0); - SELECT @g.STWithin(@h); - `, + "SET @g = geometry::STGeomFromText('POLYGON((0 0, 2 0, 2 2, 0 2, 0 0))', 0);", { strict: false, dialect: 'mssql' as Dialect }, ); result.forEach((res) => { - expect(res.parameters.length).to.equal(0); + // :: cast syntax should not produce colon-prefixed parameters + expect(res.parameters.every((param) => !param.startsWith(':'))).to.equal(true); }); }); diff --git a/test/parser/single-statements.spec.ts b/test/parser/single-statements.spec.ts index 4fa5b0b..b2f142a 100644 --- a/test/parser/single-statements.spec.ts +++ b/test/parser/single-statements.spec.ts @@ -804,7 +804,7 @@ describe('parser', () => { it('should extract mssql parameters', () => { const actual = parse( - 'select x from a where x = :foo', + 'select x from a where x = @foo', true, 'mssql', false, @@ -826,13 +826,13 @@ describe('parser', () => { }, { type: 'parameter', - value: ':foo', + value: '@foo', start: 26, end: 29, }, ]; expect(actual.tokens).to.eql(expected); - expect(actual.body[0].parameters).to.eql([':foo']); + expect(actual.body[0].parameters).to.eql(['@foo']); }); it('should not identify params in a comment', () => { @@ -875,7 +875,7 @@ describe('parser', () => { it('should extract multiple mssql parameters', () => { const actual = parse( - 'select x from a where x = :foo and y = :bar', + 'select x from a where x = @foo and y = @bar', true, 'mssql', false, @@ -897,7 +897,7 @@ describe('parser', () => { }, { type: 'parameter', - value: ':foo', + value: '@foo', start: 26, end: 29, }, @@ -909,13 +909,13 @@ describe('parser', () => { }, { type: 'parameter', - value: ':bar', + value: '@bar', start: 39, end: 42, }, ]; expect(actual.tokens).to.eql(expected); - expect(actual.body[0].parameters).to.eql([':foo', ':bar']); + expect(actual.body[0].parameters).to.eql(['@foo', '@bar']); }); }); }); diff --git a/test/tokenizer/index.spec.ts b/test/tokenizer/index.spec.ts index 7195735..c0b7805 100644 --- a/test/tokenizer/index.spec.ts +++ b/test/tokenizer/index.spec.ts @@ -271,7 +271,7 @@ describe('scan', () => { ['?', 'generic'], ['?', 'mysql'], ['?', 'sqlite'], - [':', 'mssql'], + ['@', 'mssql'], ].forEach(([ch, dialect]) => { it(`scans just ${ch} as parameter for ${dialect}`, () => { const input = `${ch}`; @@ -322,10 +322,7 @@ describe('scan', () => { expect(actual).to.eql(expected); }); }); - [ - ['$', 'psql'], - [':', 'mssql'], - ].forEach(([ch, dialect]) => { + [['$', 'psql']].forEach(([ch, dialect]) => { it(`should scan ${ch}1 for ${dialect}`, () => { const input = `${ch}1`; const actual = scanToken( @@ -369,6 +366,24 @@ describe('scan', () => { it('should not include trailing non-alphanumerics for mssql', () => { const paramTypes = defaultParamTypesFor('mssql'); [ + { + actual: scanToken(initState('@one,'), 'mssql', paramTypes), + expected: { + type: 'parameter', + value: '@one', + start: 0, + end: 3, + }, + }, + { + actual: scanToken(initState('@two)'), 'mssql', paramTypes), + expected: { + type: 'parameter', + value: '@two', + start: 0, + end: 3, + }, + }, { actual: scanToken(initState(':one,'), 'mssql', paramTypes), expected: { @@ -390,6 +405,33 @@ describe('scan', () => { ].forEach(({ actual, expected }) => expect(actual).to.eql(expected)); }); + it('should not include trailing non-alphanumerics for oracle', () => { + const paramTypes = defaultParamTypesFor('oracle'); + [ + { + actual: scanToken(initState(':one,'), 'oracle', paramTypes), + expected: { + type: 'parameter', + value: ':one', + start: 0, + end: 3, + }, + }, + ].forEach(({ actual, expected }) => expect(actual).to.eql(expected)); + }); + + it('should recognize $name for sqlite', () => { + const paramTypes = defaultParamTypesFor('sqlite'); + const actual = scanToken(initState('$myvar'), 'sqlite', paramTypes); + const expected = { + type: 'parameter', + value: '$myvar', + start: 0, + end: 5, + }; + expect(actual).to.eql(expected); + }); + describe('custom parameters', () => { describe('positional parameters', () => { const paramTypes = {