Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package com.zimdugo.auth.application;

import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class OAuth2CallbackUrlCookieManager {

private static final String CALLBACK_URL_PARAM = "callbackUrl";
private static final String CALLBACK_URL_COOKIE_NAME = "oauth2_callback_url";
private static final int CALLBACK_URL_COOKIE_MAX_AGE_SECONDS = 300;
private static final String SAME_SITE_POLICY = "Lax";
private static final String RELATIVE_PATH_DEFAULT = "/";

@Value("${auth.callback.frontend-base-url:http://localhost:3000}")
private String frontendBaseUrl;

@Value("${auth.callback.allowed-origins:http://localhost:3000,http://localhost:5173}")
private String allowedOriginsProperty;

private Set<String> allowedOrigins;

@PostConstruct
void initializeAllowedOrigins() {
this.allowedOrigins = new LinkedHashSet<>();
Arrays.stream(allowedOriginsProperty.split(","))
.map(String::trim)
.filter(v -> !v.isBlank())
.map(this::extractOrigin)
.forEach(allowedOrigins::add);

String frontendOrigin = extractOrigin(frontendBaseUrl);
if (frontendOrigin != null) {
allowedOrigins.add(frontendOrigin);
}
}

public void saveCallbackUrl(HttpServletRequest request, HttpServletResponse response) {
String callbackUrl = normalize(request.getParameter(CALLBACK_URL_PARAM));
addCookie(response, callbackUrl, CALLBACK_URL_COOKIE_MAX_AGE_SECONDS);
}

public String resolveCallbackUrl(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

for (Cookie cookie : cookies) {
if (CALLBACK_URL_COOKIE_NAME.equals(cookie.getName())) {
return normalize(decode(cookie.getValue()));
}
}

return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

public void clearCallbackUrl(HttpServletResponse response) {
addCookie(response, "", 0);
}

private void addCookie(HttpServletResponse response, String value, int maxAgeSeconds) {
ResponseCookie cookie = ResponseCookie.from(CALLBACK_URL_COOKIE_NAME, encode(value))
.httpOnly(true)
.secure(false)
.path("/")
.maxAge(maxAgeSeconds)
.sameSite(SAME_SITE_POLICY)
.build();

response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString());
}

private String normalize(String callbackUrl) {
if (callbackUrl == null || callbackUrl.isBlank()) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

String trimmed = callbackUrl.trim();
if (trimmed.contains("\r") || trimmed.contains("\n")) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

if (trimmed.startsWith("/") && !trimmed.startsWith("//")) {
return toFrontendUrl(trimmed);
}

String origin = extractOrigin(trimmed);
if (origin == null || !allowedOrigins.contains(origin)) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

return trimmed;
}

private String encode(String value) {
return URLEncoder.encode(value, StandardCharsets.UTF_8);
}

private String decode(String value) {
try {
return URLDecoder.decode(value, StandardCharsets.UTF_8);
} catch (IllegalArgumentException e) {
log.warn("Failed to decode callback cookie. fallback to default.", e);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}
}

private String toFrontendUrl(String path) {
String base = frontendBaseUrl;
if (frontendBaseUrl.endsWith("/")) {
base = frontendBaseUrl.substring(0, frontendBaseUrl.length() - 1);
}
if (path == null || path.isBlank() || "/".equals(path)) {
return base + "/";
}
return base + path;
}

private String extractOrigin(String url) {
try {
URI uri = new URI(url);
if (uri.getScheme() == null || uri.getHost() == null) {
return null;
}

String scheme = uri.getScheme().toLowerCase();
if (!"http".equals(scheme) && !"https".equals(scheme)) {
return null;
}

if (uri.getPort() == -1) {
return scheme + "://" + uri.getHost().toLowerCase();
}
return scheme + "://" + uri.getHost().toLowerCase() + ":" + uri.getPort();
} catch (URISyntaxException e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.zimdugo.auth.application;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
@RequiredArgsConstructor
public class OAuth2FailureHandler implements AuthenticationFailureHandler {

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);
callbackUrlCookieManager.clearCallbackUrl(response);

log.warn("oauth login failure. callbackUrl={}, reason={}", callbackUrl, exception.getMessage());
response.sendRedirect(appendCode(callbackUrl, "LOGIN_FAILED"));
}

private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}

Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
package com.zimdugo.auth.application;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.zimdugo.auth.domain.AuthTokens;
import com.zimdugo.auth.domain.RefreshTokenRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseCookie;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
Expand All @@ -32,14 +30,16 @@ public class OAuth2SuccessHandler implements AuthenticationSuccessHandler {
private final JwtTokenProvider jwtTokenProvider;
private final RefreshTokenRepository refreshTokenRepository;
private final JwtProperties jwtProperties;
private final ObjectMapper objectMapper;
private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationSuccess(
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);

DefaultOAuth2User oAuth2User = (DefaultOAuth2User) authentication.getPrincipal();
Map<String, Object> attributes = oAuth2User.getAttributes();

Expand All @@ -60,17 +60,17 @@ public void onAuthenticationSuccess(
.sameSite(SAME_SITE_POLICY)
.build();

response.setHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.setCharacterEncoding("UTF-8");
response.addHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
callbackUrlCookieManager.clearCallbackUrl(response);

Map<String, Object> body = new LinkedHashMap<>();
body.put("message", "oauth login success");
body.put("userId", userId);
body.put("email", email);
body.put("accessToken", tokens.accessToken());
log.info("oauth login success. userId={}, sid={}, callbackUrl={}", userId, sid, callbackUrl);
response.sendRedirect(appendCode(callbackUrl, "LOGIN_SUCCESS"));
}

log.info("oauth login success. userId={}, sid={}", userId, sid);
response.getWriter().write(objectMapper.writeValueAsString(body));
private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}
14 changes: 8 additions & 6 deletions src/main/java/com/zimdugo/auth/entrypoint/AuthController.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.zimdugo.auth.application.AccountWithdrawalService;
import com.zimdugo.auth.application.AuthCommandService;
import com.zimdugo.auth.application.AuthRefreshResult;
import com.zimdugo.core.response.RestResponse;
import com.zimdugo.core.response.SuccessCode;
import jakarta.servlet.http.HttpServletResponse;
import java.util.LinkedHashMap;
import java.util.Map;
Expand Down Expand Up @@ -32,7 +34,7 @@ public class AuthController {
private final AccountWithdrawalService accountWithdrawalService;

@PostMapping("/refresh")
public ResponseEntity<?> refresh(
public ResponseEntity<RestResponse<Map<String, Object>>> refresh(
@CookieValue(name = REFRESH_TOKEN_COOKIE_NAME, required = false) String refreshTokenCookie,
@RequestHeader(name = REFRESH_TOKEN_HEADER_NAME, required = false) String refreshTokenHeader,
HttpServletResponse response
Expand All @@ -46,29 +48,29 @@ public ResponseEntity<?> refresh(
createRefreshTokenCookie(result.refreshToken()).toString()
);

return ResponseEntity.ok(createRefreshResponse(result));
return ResponseEntity.ok(RestResponse.of(SuccessCode.OK, createRefreshResponse(result)));
}

@PostMapping("/logout")
public ResponseEntity<?> logout(
public ResponseEntity<RestResponse<Void>> logout(
@CookieValue(name = REFRESH_TOKEN_COOKIE_NAME, required = false) String refreshTokenCookie,
@RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String authorization,
HttpServletResponse response
) {
authCommandService.logout(refreshTokenCookie, extractAccessToken(authorization));

response.setHeader(HttpHeaders.SET_COOKIE, createLogoutCookie().toString());
return ResponseEntity.ok(Map.of("message", "logout success"));
return ResponseEntity.ok(RestResponse.ok(SuccessCode.OK));
}

@PostMapping("/withdraw")
public ResponseEntity<?> withdraw(
public ResponseEntity<RestResponse<Void>> withdraw(
@RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String authorization,
HttpServletResponse response
) {
accountWithdrawalService.withdraw(extractAccessToken(authorization));
response.setHeader(HttpHeaders.SET_COOKIE, createLogoutCookie().toString());
return ResponseEntity.ok(Map.of("message", "withdraw success"));
return ResponseEntity.ok(RestResponse.ok(SuccessCode.OK));
}

private Map<String, Object> createRefreshResponse(AuthRefreshResult result) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.zimdugo.auth.entrypoint;

import com.zimdugo.auth.application.OAuth2CallbackUrlCookieManager;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

@Component
@RequiredArgsConstructor
public class OAuth2CallbackUrlCaptureFilter extends OncePerRequestFilter {

private static final String OAUTH2_AUTHORIZATION_REQUEST_PREFIX = "/oauth2/authorization/";

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
protected void doFilterInternal(
HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
callbackUrlCookieManager.saveCallbackUrl(request, response);
filterChain.doFilter(request, response);
}

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return !request.getRequestURI().startsWith(OAUTH2_AUTHORIZATION_REQUEST_PREFIX);
}
}
Loading
Loading