diff --git a/e2e/src/include/bar/bar.controller.ts b/e2e/src/include/bar/bar.controller.ts new file mode 100644 index 000000000..bfba2df8d --- /dev/null +++ b/e2e/src/include/bar/bar.controller.ts @@ -0,0 +1,9 @@ +import { Controller, Get } from '@nestjs/common'; + +@Controller('bar') +export class BarController { + @Get() + getBar() { + return 'bar'; + } +} diff --git a/e2e/src/include/bar/bar.module.ts b/e2e/src/include/bar/bar.module.ts new file mode 100644 index 000000000..994b960de --- /dev/null +++ b/e2e/src/include/bar/bar.module.ts @@ -0,0 +1,7 @@ +import { Module } from '@nestjs/common'; +import { BarController } from './bar.controller'; + +@Module({ + controllers: [BarController] +}) +export class BarModule {} diff --git a/e2e/src/include/baz/baz.controller.ts b/e2e/src/include/baz/baz.controller.ts new file mode 100644 index 000000000..90d3dc6a5 --- /dev/null +++ b/e2e/src/include/baz/baz.controller.ts @@ -0,0 +1,9 @@ +import { Controller, Get } from '@nestjs/common'; + +@Controller('baz') +export class BazController { + @Get() + getBaz() { + return 'baz'; + } +} diff --git a/e2e/src/include/baz/baz.module.ts b/e2e/src/include/baz/baz.module.ts new file mode 100644 index 000000000..ffcdbe3ec --- /dev/null +++ b/e2e/src/include/baz/baz.module.ts @@ -0,0 +1,7 @@ +import { Module } from '@nestjs/common'; +import { BazController } from './baz.controller'; + +@Module({ + controllers: [BazController] +}) +export class BazModule {} diff --git a/e2e/src/include/foo/foo.controller.ts b/e2e/src/include/foo/foo.controller.ts new file mode 100644 index 000000000..03a226451 --- /dev/null +++ b/e2e/src/include/foo/foo.controller.ts @@ -0,0 +1,9 @@ +import { Controller, Get } from '@nestjs/common'; + +@Controller('foo') +export class FooController { + @Get() + getFoo() { + return 'foo'; + } +} diff --git a/e2e/src/include/foo/foo.module.ts b/e2e/src/include/foo/foo.module.ts new file mode 100644 index 000000000..c6b3b0aa8 --- /dev/null +++ b/e2e/src/include/foo/foo.module.ts @@ -0,0 +1,7 @@ +import { Module } from '@nestjs/common'; +import { FooController } from './foo.controller'; + +@Module({ + controllers: [FooController] +}) +export class FooModule {} diff --git a/e2e/src/include/include.module.ts b/e2e/src/include/include.module.ts new file mode 100644 index 000000000..5b5b47706 --- /dev/null +++ b/e2e/src/include/include.module.ts @@ -0,0 +1,9 @@ +import { Module } from '@nestjs/common'; +import { BarModule } from './bar/bar.module'; +import { BazModule } from './baz/baz.module'; +import { FooModule } from './foo/foo.module'; + +@Module({ + imports: [FooModule, BarModule, BazModule] +}) +export class IncludeModule {} diff --git a/e2e/validate-schema.e2e-spec.ts b/e2e/validate-schema.e2e-spec.ts index c58026706..bf6e52311 100644 --- a/e2e/validate-schema.e2e-spec.ts +++ b/e2e/validate-schema.e2e-spec.ts @@ -8,6 +8,7 @@ import { DocumentBuilder, getSchemaPath, OpenAPIObject, + SwaggerDocumentOptions, SwaggerModule } from '../lib'; import { SchemaObject } from '../lib/interfaces/open-api-spec.interface'; @@ -16,138 +17,224 @@ import { Cat } from './src/cats/classes/cat.class'; import { TagDto } from './src/cats/dto/tag.dto'; import { ValidationErrorDto } from './src/common/dto/validation-error.dto'; import { ExpressController } from './src/express.controller'; +import { IncludeModule } from './src/include/include.module'; +import { FooModule } from './src/include/foo/foo.module'; +import { BarModule } from './src/include/bar/bar.module'; +import { BazModule } from './src/include/baz/baz.module'; describe('Validate OpenAPI schema', () => { - let app: INestApplication; - let options: Omit; + describe('general schema', () => { + let app: INestApplication; + let options: Omit; - beforeEach(async () => { - app = await NestFactory.create( - { - module: class {}, - imports: [ApplicationModule], - controllers: [ExpressController] - }, - { - logger: false - } - ); - app.setGlobalPrefix('api/'); - app.enableVersioning(); - - options = new DocumentBuilder() - .setTitle('Cats example') - .setDescription('The cats API description') - .setVersion('1.0') - .setBasePath('api') - .addTag('cats') - .addBasicAuth() - .addBearerAuth() - .addOAuth2() - .addApiKey() - .addApiKey({ type: 'apiKey' }, 'key1') - .addApiKey({ type: 'apiKey' }, 'key2') - .addCookieAuth() - .addSecurityRequirements('bearer') - .addSecurityRequirements({ basic: [], cookie: [] }) - .addGlobalResponse({ - status: 500, - description: 'Internal server error' - }) - .addGlobalResponse({ - status: 400, - description: 'Bad request', - type: ValidationErrorDto - }) - .addGlobalParameters({ - name: 'x-tenant-id', - in: 'header', - schema: { type: 'string' } - }) - .addExtension('x-test', { test: 'test' }) - .addExtension('x-logo', { url: 'https://example.com/logo.png' }, 'info') - .addServer( - 'http://localhost:3000', - 'Local server', + beforeEach(async () => { + app = await NestFactory.create( { - someVariable: { - default: 'Variable default value here', - description: 'A variable description here' - } + module: class {}, + imports: [ApplicationModule], + controllers: [ExpressController] }, { - 'x-google-endpoint': { - allowCors: true - }, - 'x-another-field': 'another value' + logger: false } - ) - .build(); - }); + ); + app.setGlobalPrefix('api/'); + app.enableVersioning(); - it('should produce a valid OpenAPI 3.0 schema', async () => { - await SwaggerModule.loadPluginMetadata(async () => ({ - '@nestjs/swagger': { - models: [ - [ - import('./src/cats/classes/cat.class'), - { - Cat: { - tags: { - description: 'Tags of the cat', - example: ['tag1', 'tag2'], - required: false - }, - siblings: { - required: false, - type: () => ({ - ids: { required: true, type: () => Number } - }) - } - } + options = new DocumentBuilder() + .setTitle('Cats example') + .setDescription('The cats API description') + .setVersion('1.0') + .setBasePath('api') + .addTag('cats') + .addBasicAuth() + .addBearerAuth() + .addOAuth2() + .addApiKey() + .addApiKey({ type: 'apiKey' }, 'key1') + .addApiKey({ type: 'apiKey' }, 'key2') + .addCookieAuth() + .addSecurityRequirements('bearer') + .addSecurityRequirements({ basic: [], cookie: [] }) + .addGlobalResponse({ + status: 500, + description: 'Internal server error' + }) + .addGlobalResponse({ + status: 400, + description: 'Bad request', + type: ValidationErrorDto + }) + .addGlobalParameters({ + name: 'x-tenant-id', + in: 'header', + schema: { type: 'string' } + }) + .addExtension('x-test', { test: 'test' }) + .addExtension('x-logo', { url: 'https://example.com/logo.png' }, 'info') + .addServer( + 'http://localhost:3000', + 'Local server', + { + someVariable: { + default: 'Variable default value here', + description: 'A variable description here' } - ], - [ - import('./src/cats/dto/create-cat.dto'), - { - CreateCatDto: { - enumWithDescription: { - enum: await import('./src/cats/dto/pagination-query.dto').then( - (f) => f.LettersEnum - ) - }, - name: { - description: 'Name of the cat' + }, + { + 'x-google-endpoint': { + allowCors: true + }, + 'x-another-field': 'another value' + } + ) + .build(); + }); + + it('should produce a valid OpenAPI 3.0 schema', async () => { + await SwaggerModule.loadPluginMetadata(async () => ({ + '@nestjs/swagger': { + models: [ + [ + import('./src/cats/classes/cat.class'), + { + Cat: { + tags: { + description: 'Tags of the cat', + example: ['tag1', 'tag2'], + required: false + }, + siblings: { + required: false, + type: () => ({ + ids: { required: true, type: () => Number } + }) + } } } - } - ] - ], - controllers: [ - [ - import('./src/cats/cats.controller'), - { - CatsController: { - findAllBulk: { - type: [ - await import('./src/cats/classes/cat.class').then( - (f) => f.Cat + ], + [ + import('./src/cats/dto/create-cat.dto'), + { + CreateCatDto: { + enumWithDescription: { + enum: await import('./src/cats/dto/pagination-query.dto').then( + (f) => f.LettersEnum ) - ], - summary: 'Find all cats in bulk' + }, + name: { + description: 'Name of the cat' + } } } - } + ] + ], + controllers: [ + [ + import('./src/cats/cats.controller'), + { + CatsController: { + findAllBulk: { + type: [ + await import('./src/cats/classes/cat.class').then( + (f) => f.Cat + ) + ], + summary: 'Find all cats in bulk' + } + } + } + ] ] - ] + } + })); + const document = SwaggerModule.createDocument(app, options); + + const doc = JSON.stringify(document, null, 2); + writeFileSync(join(__dirname, 'api-spec.json'), doc); + + try { + const api = (await SwaggerParser.validate( + document as any + )) as OpenAPIV3.Document; + console.log( + 'API name: %s, Version: %s', + api.info.title, + api.info.version + ); + expect(api.info.title).toEqual('Cats example'); + expect( + api.components.schemas['Cat']['x-schema-extension']['test'] + ).toEqual('test'); + expect( + api.components.schemas['Cat']['x-schema-extension-multiple']['test'] + ).toEqual('test*2'); + expect( + api.paths['/api/cats']['post']['callbacks']['myEvent'][ + '{$request.body#/callbackUrl}' + ]['post']['requestBody']['content']['application/json']['schema'][ + 'properties' + ]['breed']['type'] + ).toEqual('string'); + expect( + api.paths['/api/cats']['post']['callbacks']['mySecondEvent'][ + '{$request.body#/callbackUrl}' + ]['post']['requestBody']['content']['application/json']['schema'][ + 'properties' + ]['breed']['type'] + ).toEqual('string'); + expect( + api.paths['/api/cats']['get']['x-codeSamples'][0]['lang'] + ).toEqual('JavaScript'); + expect(api.paths['/api/cats']['get']['x-multiple']['test']).toEqual( + 'test' + ); + expect(api.paths['/api/cats']['get']['tags']).toContain('tag1'); + expect(api.paths['/api/cats']['get']['tags']).toContain('tag2'); + } catch (err) { + console.log(doc); + expect(err).toBeUndefined(); } - })); - const document = SwaggerModule.createDocument(app, options); + }); + + it('should fix colons in url', async () => { + const document = SwaggerModule.createDocument(app, options); + expect( + document.paths['/api/v1/express:colon:another/{prop}'] + ).toBeDefined(); + }); + + it('should merge custom components passed via config', async () => { + const components = { + schemas: { + Person: { + oneOf: [ + { + $ref: getSchemaPath(Cat) + }, + { + $ref: getSchemaPath(TagDto) + } + ], + discriminator: { + propertyName: '_resolveType', + mapping: { + cat: getSchemaPath(Cat), + tag: getSchemaPath(TagDto) + } + } + } + } + }; - const doc = JSON.stringify(document, null, 2); - writeFileSync(join(__dirname, 'api-spec.json'), doc); + const document = SwaggerModule.createDocument(app, { + ...options, + components: { + ...options.components, + ...components + } + }); - try { const api = (await SwaggerParser.validate( document as any )) as OpenAPIV3.Document; @@ -156,146 +243,133 @@ describe('Validate OpenAPI schema', () => { api.info.title, api.info.version ); - expect(api.info.title).toEqual('Cats example'); - expect( - api.components.schemas['Cat']['x-schema-extension']['test'] - ).toEqual('test'); - expect( - api.components.schemas['Cat']['x-schema-extension-multiple']['test'] - ).toEqual('test*2'); - expect( - api.paths['/api/cats']['post']['callbacks']['myEvent'][ - '{$request.body#/callbackUrl}' - ]['post']['requestBody']['content']['application/json']['schema'][ - 'properties' - ]['breed']['type'] - ).toEqual('string'); - expect( - api.paths['/api/cats']['post']['callbacks']['mySecondEvent'][ - '{$request.body#/callbackUrl}' - ]['post']['requestBody']['content']['application/json']['schema'][ - 'properties' - ]['breed']['type'] - ).toEqual('string'); - expect(api.paths['/api/cats']['get']['x-codeSamples'][0]['lang']).toEqual( - 'JavaScript' - ); - expect(api.paths['/api/cats']['get']['x-multiple']['test']).toEqual( - 'test' - ); - expect(api.paths['/api/cats']['get']['tags']).toContain('tag1'); - expect(api.paths['/api/cats']['get']['tags']).toContain('tag2'); - } catch (err) { - console.log(doc); - expect(err).toBeUndefined(); - } - }); - - it('should fix colons in url', async () => { - const document = SwaggerModule.createDocument(app, options); - expect( - document.paths['/api/v1/express:colon:another/{prop}'] - ).toBeDefined(); - }); + expect(api.components.schemas).toHaveProperty('Person'); + expect(api.components.schemas).toHaveProperty('Cat'); + }); - it('should merge custom components passed via config', async () => { - const components = { - schemas: { - Person: { - oneOf: [ - { - $ref: getSchemaPath(Cat) + it('should consider explicit config over auto-detected schema', () => { + const document = SwaggerModule.createDocument(app, options); + expect(document.paths['/api/cats/download'].get.responses).toEqual({ + '200': { + description: 'binary file for download', + content: { + 'application/pdf': { + schema: { type: 'string', format: 'binary' } }, - { - $ref: getSchemaPath(TagDto) - } - ], - discriminator: { - propertyName: '_resolveType', - mapping: { - cat: getSchemaPath(Cat), - tag: getSchemaPath(TagDto) - } + 'image/jpeg': { schema: { type: 'string', format: 'binary' } } } } - } - }; + }); + }); - const document = SwaggerModule.createDocument(app, { - ...options, - components: { - ...options.components, - ...components - } + it('should not add optional properties to required list', () => { + const document = SwaggerModule.createDocument(app, options); + const required = (document.components?.schemas?.Cat as SchemaObject) + ?.required; + expect(required).not.toContain('optionalRawDefinition'); }); - const api = (await SwaggerParser.validate( - document as any - )) as OpenAPIV3.Document; - console.log('API name: %s, Version: %s', api.info.title, api.info.version); - expect(api.components.schemas).toHaveProperty('Person'); - expect(api.components.schemas).toHaveProperty('Cat'); - }); + it('should fail if extension is not prefixed with x-', () => { + expect(() => + new DocumentBuilder().addExtension('test', { test: 'test' }).build() + ).toThrow( + 'Extension key is not prefixed. Please ensure you prefix it with `x-`.' + ); + }); - it('should consider explicit config over auto-detected schema', () => { - const document = SwaggerModule.createDocument(app, options); - expect(document.paths['/api/cats/download'].get.responses).toEqual({ - '200': { - description: 'binary file for download', - content: { - 'application/pdf': { - schema: { type: 'string', format: 'binary' } - }, - 'image/jpeg': { schema: { type: 'string', format: 'binary' } } - } - } + it('should add extension to root', () => { + const document = SwaggerModule.createDocument(app, options); + expect(document['x-test']).toEqual({ test: 'test' }); }); - }); - it('should not add optional properties to required list', () => { - const document = SwaggerModule.createDocument(app, options); - const required = (document.components?.schemas?.Cat as SchemaObject) - ?.required; - expect(required).not.toContain('optionalRawDefinition'); - }); + it('should add extension to info', () => { + const document = SwaggerModule.createDocument(app, options); + expect(document.info['x-logo']).toEqual({ + url: 'https://example.com/logo.png' + }); + }); - it('should fail if extension is not prefixed with x-', () => { - expect(() => - new DocumentBuilder().addExtension('test', { test: 'test' }).build() - ).toThrow( - 'Extension key is not prefixed. Please ensure you prefix it with `x-`.' - ); + it('should add server to the root', () => { + const document = SwaggerModule.createDocument(app, options); + expect(document.servers).toBeDefined(); + expect(document.servers).toHaveLength(1); + expect(document.servers?.[0]).toEqual({ + url: 'http://localhost:3000', + description: 'Local server', + variables: { + someVariable: { + default: 'Variable default value here', + description: 'A variable description here' + } + }, + 'x-google-endpoint': { + allowCors: true + }, + 'x-another-field': 'another value' + }); + }); }); - it('should add extension to root', () => { - const document = SwaggerModule.createDocument(app, options); - expect(document['x-test']).toEqual({ test: 'test' }); - }); + describe('include', () => { + const createDocument = async (swaggerOptions?: SwaggerDocumentOptions) => { + const app = await NestFactory.create(IncludeModule, { + logger: false + }); + app.setGlobalPrefix('api/'); + + const options = new DocumentBuilder() + .setTitle('Include') + .setVersion('1.0') + .build(); + + return SwaggerModule.createDocument(app, options, swaggerOptions); + }; + + const getSortedPaths = (doc: OpenAPIObject): readonly string[] => + Object.keys(doc.paths).sort(); - it('should add extension to info', () => { - const document = SwaggerModule.createDocument(app, options); - expect(document.info['x-logo']).toEqual({ - url: 'https://example.com/logo.png' + it('should include all modules by default', async () => { + const doc = await createDocument({}); + const paths = getSortedPaths(doc); + expect(paths).toEqual(['/api/bar', '/api/baz', '/api/foo']); }); - }); - it('should add server to the root', () => { - const document = SwaggerModule.createDocument(app, options); - expect(document.servers).toBeDefined(); - expect(document.servers).toHaveLength(1); - expect(document.servers?.[0]).toEqual({ - url: 'http://localhost:3000', - description: 'Local server', - variables: { - someVariable: { - default: 'Variable default value here', - description: 'A variable description here' - } - }, - 'x-google-endpoint': { - allowCors: true - }, - 'x-another-field': 'another value' + it('should include all modules if includes is an empty array', async () => { + const doc = await createDocument({ include: [] }); + const paths = getSortedPaths(doc); + expect(paths).toEqual(['/api/bar', '/api/baz', '/api/foo']); + }); + + it('should include only specified modules - array mode', async () => { + const doc = await createDocument({ + include: [FooModule] + }); + const paths = getSortedPaths(doc); + expect(paths).toEqual(['/api/foo']); + }); + + it('should include all modules if the fn allows', async () => { + const doc = await createDocument({ + include: () => true + }); + const paths = getSortedPaths(doc); + expect(paths).toEqual(['/api/bar', '/api/baz', '/api/foo']); + }); + + it('should exclude all modules if the fn rejects', async () => { + const doc = await createDocument({ + include: () => false + }); + const paths = getSortedPaths(doc); + expect(paths).toEqual([]); + }); + + it('should include matching modules', async () => { + const doc = await createDocument({ + include: (t) => t === BazModule + }); + const paths = getSortedPaths(doc); + expect(paths).toEqual(['/api/baz']); }); }); }); diff --git a/lib/interfaces/swagger-document-options.interface.ts b/lib/interfaces/swagger-document-options.interface.ts index 3bd0db189..7202f81a5 100644 --- a/lib/interfaces/swagger-document-options.interface.ts +++ b/lib/interfaces/swagger-document-options.interface.ts @@ -1,17 +1,21 @@ +import { Module } from '@nestjs/core/injector/module'; + export type OperationIdFactory = ( controllerKey: string, methodKey: string, version?: string ) => string; +export type SwaggerIncludeModuleFn = (type: Function) => boolean; + /** * @publicApi */ export interface SwaggerDocumentOptions { /** - * List of modules to include in the specification + * Modules to include in the specification. Can either be a list of modules or a predicate function. */ - include?: Function[]; + include?: SwaggerIncludeModuleFn | Function[]; /** * Additional, extra models that should be inspected and included in the specification diff --git a/lib/swagger-scanner.ts b/lib/swagger-scanner.ts index 168204c9d..6313c5cb3 100644 --- a/lib/swagger-scanner.ts +++ b/lib/swagger-scanner.ts @@ -7,7 +7,8 @@ import { flatten, isEmpty } from 'lodash'; import { OpenAPIObject, OperationIdFactory, - SwaggerDocumentOptions + SwaggerDocumentOptions, + SwaggerIncludeModuleFn } from './interfaces'; import { ModuleRoute } from './interfaces/module-route.interface'; import { @@ -22,6 +23,8 @@ import { SwaggerTransformer } from './swagger-transformer'; import { getGlobalPrefix } from './utils/get-global-prefix'; import { stripLastSlash } from './utils/strip-last-slash.util'; +type IncludeConfig = SwaggerDocumentOptions['include']; + export class SwaggerScanner { private readonly transformer = new SwaggerTransformer(); private readonly schemaObjectFactory = new SchemaObjectFactory( @@ -36,7 +39,7 @@ export class SwaggerScanner { ): Omit { const { deepScanRoutes, - include: includedModules = [], + include = [], extraModels = [], ignoreGlobalPrefix = false, operationIdFactory, @@ -50,10 +53,7 @@ export class SwaggerScanner { const httpAdapterType = app.getHttpAdapter().getType(); this.initializeSwaggerExplorer(httpAdapterType); - const modules: Module[] = this.getModules( - container.getModules(), - includedModules - ); + const modules: Module[] = this.getModules(container.getModules(), include); const globalPrefix = !ignoreGlobalPrefix ? stripLastSlash(getGlobalPrefix(app)) : ''; @@ -132,14 +132,25 @@ export class SwaggerScanner { public getModules( modulesContainer: Map, - include: Function[] + include: IncludeConfig ): Module[] { - if (!include || isEmpty(include)) { - return [...modulesContainer.values()]; + const fn = + typeof include === 'function' + ? include + : this.generateModuleIncludeFn(include); + + return [...modulesContainer.values()].filter((o) => fn(o.metatype)); + } + + private generateModuleIncludeFn( + arr: readonly Function[] + ): SwaggerIncludeModuleFn { + if (arr.length === 0) { + return () => true; } - return [...modulesContainer.values()].filter(({ metatype }) => - include.some((item) => item === metatype) - ); + + const set = new Set(arr); + return (metatype) => set.has(metatype); } public addExtraModels(