diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListener.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListener.java index 5a7b05030..7ed2a6a68 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListener.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListener.java @@ -87,10 +87,12 @@ public void onApplicationEvent(RefreshRoutesResultEvent event) { routeLocator.getRoutes().collectList().subscribe(routes -> { // pre-populate with pre-existing global cors configurations to combine with. Map corsConfigurations = new LinkedHashMap<>(); + Map routeCorsConfigurations = new LinkedHashMap<>(); routes.forEach(route -> { Optional corsConfiguration = getCorsConfiguration(route); corsConfiguration.ifPresent(configuration -> { + routeCorsConfigurations.put(route.getId(), configuration); String pathPredicate = getPathPredicate(route); corsConfigurations.put(pathPredicate, configuration); }); @@ -101,6 +103,7 @@ public void onApplicationEvent(RefreshRoutesResultEvent event) { corsConfigurations.put(path, config); } }); + routePredicateHandlerMapping.setRouteCorsConfigurations(routeCorsConfigurations); routePredicateHandlerMapping.setCorsConfigurations(corsConfigurations); }); } diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java index c80e70709..6408e5648 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java @@ -16,6 +16,8 @@ package org.springframework.cloud.gateway.handler; +import java.util.Collections; +import java.util.Map; import java.util.function.Function; import reactor.core.publisher.Mono; @@ -27,6 +29,7 @@ import org.springframework.cloud.gateway.support.ServerWebExchangeUtils; import org.springframework.core.env.Environment; import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource; import org.springframework.web.reactive.handler.AbstractHandlerMapping; import org.springframework.web.server.ServerWebExchange; @@ -51,6 +54,8 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { private final ManagementPortType managementPortType; + private volatile Map routeCorsConfigurations = Collections.emptyMap(); + public RoutePredicateHandlerMapping(FilteringWebHandler webHandler, RouteLocator routeLocator, GlobalCorsProperties globalCorsProperties, Environment environment) { this.webHandler = webHandler; @@ -108,6 +113,30 @@ protected Mono getHandlerInternal(ServerWebExchange exchange) { }); } + public void setRouteCorsConfigurations(Map routeCorsConfigurations) { + this.routeCorsConfigurations = routeCorsConfigurations; + } + + @Override + public void setCorsConfigurations(Map corsConfigurations) { + if (this.routeCorsConfigurations.isEmpty()) { + super.setCorsConfigurations(corsConfigurations); + return; + } + UrlBasedCorsConfigurationSource pathBasedSource = new UrlBasedCorsConfigurationSource(getPathPatternParser()); + pathBasedSource.setCorsConfigurations(corsConfigurations); + setCorsConfigurationSource(exchange -> { + Route route = exchange.getAttribute(GATEWAY_ROUTE_ATTR); + if (route != null) { + CorsConfiguration routeConfig = this.routeCorsConfigurations.get(route.getId()); + if (routeConfig != null) { + return routeConfig; + } + } + return pathBasedSource.getCorsConfiguration(exchange); + }); + } + @Override protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) { // TODO: support cors configuration via properties on a route see gh-229 diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/cors/CorsPerRouteTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/cors/CorsPerRouteTests.java index 2650c12e7..d91852c70 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/cors/CorsPerRouteTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/cors/CorsPerRouteTests.java @@ -96,6 +96,50 @@ public void testPreFlightCorsRequestJavaConfig() { }); } + @Test + public void testPreFlightCorsRequestHostPredicateA() { + testClient.options() + .uri("/anything") + .header("Origin", "https://origin-a.com") + .header("Host", "hosta.example.com") + .header("Access-Control-Request-Method", "GET") + .exchange() + .expectBody(Map.class) + .consumeWith(result -> { + assertThat(result.getResponseBody()).isNull(); + assertThat(result.getStatus()).isEqualTo(HttpStatus.OK); + + HttpHeaders responseHeaders = result.getResponseHeaders(); + assertThat(responseHeaders.getAccessControlAllowOrigin()).as(missingHeader(ACCESS_CONTROL_ALLOW_ORIGIN)) + .isEqualTo("https://origin-a.com"); + assertThat(responseHeaders.getAccessControlAllowMethods()) + .as(missingHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)) + .containsExactly(HttpMethod.GET); + }); + } + + @Test + public void testPreFlightCorsRequestHostPredicateB() { + testClient.options() + .uri("/anything") + .header("Origin", "https://origin-b.com") + .header("Host", "hostb.example.com") + .header("Access-Control-Request-Method", "POST") + .exchange() + .expectBody(Map.class) + .consumeWith(result -> { + assertThat(result.getResponseBody()).isNull(); + assertThat(result.getStatus()).isEqualTo(HttpStatus.OK); + + HttpHeaders responseHeaders = result.getResponseHeaders(); + assertThat(responseHeaders.getAccessControlAllowOrigin()).as(missingHeader(ACCESS_CONTROL_ALLOW_ORIGIN)) + .isEqualTo("https://origin-b.com"); + assertThat(responseHeaders.getAccessControlAllowMethods()) + .as(missingHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)) + .containsExactly(HttpMethod.POST); + }); + } + @Test public void testPreFlightForbiddenCorsRequest() { testClient.get() diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListenerTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListenerTests.java index 5344bcd62..ccc12de88 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListenerTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/cors/CorsGatewayFilterApplicationListenerTests.java @@ -34,6 +34,7 @@ import org.springframework.cloud.gateway.config.GlobalCorsProperties; import org.springframework.cloud.gateway.event.RefreshRoutesResultEvent; import org.springframework.cloud.gateway.handler.RoutePredicateHandlerMapping; +import org.springframework.cloud.gateway.handler.predicate.HostRoutePredicateFactory; import org.springframework.cloud.gateway.handler.predicate.PathRoutePredicateFactory; import org.springframework.cloud.gateway.route.Route; import org.springframework.cloud.gateway.route.RouteLocator; @@ -104,6 +105,9 @@ class CorsGatewayFilterApplicationListenerTests { @Captor private ArgumentCaptor> corsConfigurations; + @Captor + private ArgumentCaptor> routeCorsConfigurations; + private GlobalCorsProperties globalCorsProperties; private CorsGatewayFilterApplicationListener listener; @@ -150,6 +154,60 @@ private CorsConfiguration createCorsConfig(String origin) { return config; } + @Test + void testOnApplicationEvent_hostOnlyRoutes_storesRouteCorsConfigurations() { + + String hostA = "hosta.example.com"; + String hostB = "hostb.example.com"; + String originA = "https://originA.com"; + String originB = "https://originB.com"; + String routeIdA = "host-route-a"; + String routeIdB = "host-route-b"; + + Route routeA = buildHostRoute(routeIdA, hostA, originA); + Route routeB = buildHostRoute(routeIdB, hostB, originB); + + when(routeLocator.getRoutes()).thenReturn(Flux.just(routeA, routeB)); + + listener.onApplicationEvent(new RefreshRoutesResultEvent(this)); + + Awaitility.await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> { + + verify(handlerMapping).setRouteCorsConfigurations(routeCorsConfigurations.capture()); + + Map routeConfigs = routeCorsConfigurations.getValue(); + assertThat(routeConfigs).containsKeys(routeIdA, routeIdB); + assertThat(routeConfigs.get(routeIdA).getAllowedOrigins()).containsExactly(originA); + assertThat(routeConfigs.get(routeIdB).getAllowedOrigins()).containsExactly(originB); + }); + } + + @Test + void testOnApplicationEvent_pathRoutes_alsoStoresRouteCorsConfigurations() { + + Route route1 = buildRoute(ROUTE_ID_1, ROUTE_PATH_1, ORIGIN_ROUTE_1); + Route route2 = buildRoute(ROUTE_ID_2, ROUTE_PATH_2, ORIGIN_ROUTE_2); + + when(routeLocator.getRoutes()).thenReturn(Flux.just(route1, route2)); + + listener.onApplicationEvent(new RefreshRoutesResultEvent(this)); + + Awaitility.await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> { + + verify(handlerMapping).setRouteCorsConfigurations(routeCorsConfigurations.capture()); + verify(handlerMapping).setCorsConfigurations(corsConfigurations.capture()); + + Map routeConfigs = routeCorsConfigurations.getValue(); + assertThat(routeConfigs).containsKeys(ROUTE_ID_1, ROUTE_ID_2); + assertThat(routeConfigs.get(ROUTE_ID_1).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_1); + assertThat(routeConfigs.get(ROUTE_ID_2).getAllowedOrigins()).containsExactly(ORIGIN_ROUTE_2); + + // path-based configurations should still work + Map pathConfigs = corsConfigurations.getValue(); + assertThat(pathConfigs).containsKeys(ROUTE_PATH_1, ROUTE_PATH_2); + }); + } + private Route buildRoute(String id, String path, String allowedOrigin) { return Route.async() @@ -160,4 +218,14 @@ private Route buildRoute(String id, String path, String allowedOrigin) { .build(); } + private Route buildHostRoute(String id, String host, String allowedOrigin) { + + return Route.async() + .id(id) + .uri(ROUTE_URI) + .predicate(new HostRoutePredicateFactory().apply(config -> config.setPatterns(List.of(host)))) + .metadata(METADATA_KEY, Map.of(ALLOWED_ORIGINS_KEY, List.of(allowedOrigin))) + .build(); + } + } diff --git a/spring-cloud-gateway-server/src/test/resources/application-cors-per-route-config.yml b/spring-cloud-gateway-server/src/test/resources/application-cors-per-route-config.yml index b1a875360..0b6a8f9d5 100644 --- a/spring-cloud-gateway-server/src/test/resources/application-cors-per-route-config.yml +++ b/spring-cloud-gateway-server/src/test/resources/application-cors-per-route-config.yml @@ -25,4 +25,22 @@ spring: allowedMethods: - GET - PUT - allowedHeaders: '*' \ No newline at end of file + allowedHeaders: '*' + - id: cors_host_a + uri: ${test.uri} + predicates: + - Host=hosta.example.com + metadata: + cors: + allowedOrigins: 'https://origin-a.com' + allowedMethods: + - GET + - id: cors_host_b + uri: ${test.uri} + predicates: + - Host=hostb.example.com + metadata: + cors: + allowedOrigins: 'https://origin-b.com' + allowedMethods: + - POST