Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package com.zimdugo.auth.application;

import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class OAuth2CallbackUrlCookieManager {

private static final String CALLBACK_URL_PARAM = "callbackUrl";
private static final String CALLBACK_URL_COOKIE_NAME = "oauth2_callback_url";
private static final int CALLBACK_URL_COOKIE_MAX_AGE_SECONDS = 300;
private static final String SAME_SITE_POLICY = "Lax";
private static final String RELATIVE_PATH_DEFAULT = "/";

@Value("${auth.callback.frontend-base-url:http://localhost:3000}")
private String frontendBaseUrl;

@Value("${auth.callback.allowed-origins:http://localhost:3000,http://localhost:5173}")
private String allowedOriginsProperty;

private Set<String> allowedOrigins;

@PostConstruct
void initializeAllowedOrigins() {
this.allowedOrigins = new LinkedHashSet<>();
Arrays.stream(allowedOriginsProperty.split(","))
.map(String::trim)
.filter(v -> !v.isBlank())
.map(this::extractOrigin)
.forEach(allowedOrigins::add);

String frontendOrigin = extractOrigin(frontendBaseUrl);
if (frontendOrigin != null) {
allowedOrigins.add(frontendOrigin);
}
}

public void saveCallbackUrl(HttpServletRequest request, HttpServletResponse response) {
String callbackUrl = normalize(request.getParameter(CALLBACK_URL_PARAM));
addCookie(response, callbackUrl, CALLBACK_URL_COOKIE_MAX_AGE_SECONDS);
}

public String resolveCallbackUrl(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

for (Cookie cookie : cookies) {
if (CALLBACK_URL_COOKIE_NAME.equals(cookie.getName())) {
return normalize(decode(cookie.getValue()));
}
}

return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

public void clearCallbackUrl(HttpServletResponse response) {
addCookie(response, "", 0);
}

private void addCookie(HttpServletResponse response, String value, int maxAgeSeconds) {
ResponseCookie cookie = ResponseCookie.from(CALLBACK_URL_COOKIE_NAME, encode(value))
.httpOnly(true)
.secure(false)
.path("/")
.maxAge(maxAgeSeconds)
.sameSite(SAME_SITE_POLICY)
.build();

response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString());
}

private String normalize(String callbackUrl) {
if (callbackUrl == null || callbackUrl.isBlank()) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

String trimmed = callbackUrl.trim();
if (trimmed.contains("\r") || trimmed.contains("\n")) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

if (trimmed.startsWith("/") && !trimmed.startsWith("//")) {
return toFrontendUrl(trimmed);
}

String origin = extractOrigin(trimmed);
if (origin == null || !allowedOrigins.contains(origin)) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

return trimmed;
}

private String encode(String value) {
return URLEncoder.encode(value, StandardCharsets.UTF_8);
}

private String decode(String value) {
try {
return URLDecoder.decode(value, StandardCharsets.UTF_8);
} catch (IllegalArgumentException e) {
log.warn("Failed to decode callback cookie. fallback to default.", e);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}
}

private String toFrontendUrl(String path) {
String base = frontendBaseUrl;
if (frontendBaseUrl.endsWith("/")) {
base = frontendBaseUrl.substring(0, frontendBaseUrl.length() - 1);
}
if (path == null || path.isBlank() || "/".equals(path)) {
return base + "/";
}
return base + path;
}

private String extractOrigin(String url) {
try {
URI uri = new URI(url);
if (uri.getScheme() == null || uri.getHost() == null) {
return null;
}

String scheme = uri.getScheme().toLowerCase();
if (!"http".equals(scheme) && !"https".equals(scheme)) {
return null;
}

if (uri.getPort() == -1) {
return scheme + "://" + uri.getHost().toLowerCase();
}
return scheme + "://" + uri.getHost().toLowerCase() + ":" + uri.getPort();
} catch (URISyntaxException e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.zimdugo.auth.application;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
@RequiredArgsConstructor
public class OAuth2FailureHandler implements AuthenticationFailureHandler {

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);
callbackUrlCookieManager.clearCallbackUrl(response);

log.warn("oauth login failure. callbackUrl={}, reason={}", callbackUrl, exception.getMessage());
response.sendRedirect(appendCode(callbackUrl, "LOGIN_FAILED"));
}

private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}

Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
package com.zimdugo.auth.application;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.zimdugo.auth.domain.AuthTokens;
import com.zimdugo.auth.domain.RefreshTokenRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseCookie;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
Expand All @@ -32,14 +30,16 @@ public class OAuth2SuccessHandler implements AuthenticationSuccessHandler {
private final JwtTokenProvider jwtTokenProvider;
private final RefreshTokenRepository refreshTokenRepository;
private final JwtProperties jwtProperties;
private final ObjectMapper objectMapper;
private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationSuccess(
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);

DefaultOAuth2User oAuth2User = (DefaultOAuth2User) authentication.getPrincipal();
Map<String, Object> attributes = oAuth2User.getAttributes();

Expand All @@ -60,17 +60,17 @@ public void onAuthenticationSuccess(
.sameSite(SAME_SITE_POLICY)
.build();

response.setHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.setCharacterEncoding("UTF-8");
response.addHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
callbackUrlCookieManager.clearCallbackUrl(response);

Map<String, Object> body = new LinkedHashMap<>();
body.put("message", "oauth login success");
body.put("userId", userId);
body.put("email", email);
body.put("accessToken", tokens.accessToken());
log.info("oauth login success. userId={}, sid={}, callbackUrl={}", userId, sid, callbackUrl);
response.sendRedirect(appendCode(callbackUrl, "LOGIN_SUCCESS"));
}

log.info("oauth login success. userId={}, sid={}", userId, sid);
response.getWriter().write(objectMapper.writeValueAsString(body));
private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}
14 changes: 8 additions & 6 deletions src/main/java/com/zimdugo/auth/entrypoint/AuthController.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.zimdugo.auth.application.AccountWithdrawalService;
import com.zimdugo.auth.application.AuthCommandService;
import com.zimdugo.auth.application.AuthRefreshResult;
import com.zimdugo.core.response.RestResponse;
import com.zimdugo.core.response.SuccessCode;
import jakarta.servlet.http.HttpServletResponse;
import java.util.LinkedHashMap;
import java.util.Map;
Expand Down Expand Up @@ -32,7 +34,7 @@ public class AuthController {
private final AccountWithdrawalService accountWithdrawalService;

@PostMapping("/refresh")
public ResponseEntity<?> refresh(
public ResponseEntity<RestResponse<Map<String, Object>>> refresh(
@CookieValue(name = REFRESH_TOKEN_COOKIE_NAME, required = false) String refreshTokenCookie,
@RequestHeader(name = REFRESH_TOKEN_HEADER_NAME, required = false) String refreshTokenHeader,
HttpServletResponse response
Expand All @@ -46,29 +48,29 @@ public ResponseEntity<?> refresh(
createRefreshTokenCookie(result.refreshToken()).toString()
);

return ResponseEntity.ok(createRefreshResponse(result));
return ResponseEntity.ok(RestResponse.of(SuccessCode.OK, createRefreshResponse(result)));
}

@PostMapping("/logout")
public ResponseEntity<?> logout(
public ResponseEntity<RestResponse<Void>> logout(
@CookieValue(name = REFRESH_TOKEN_COOKIE_NAME, required = false) String refreshTokenCookie,
@RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String authorization,
HttpServletResponse response
) {
authCommandService.logout(refreshTokenCookie, extractAccessToken(authorization));

response.setHeader(HttpHeaders.SET_COOKIE, createLogoutCookie().toString());
return ResponseEntity.ok(Map.of("message", "logout success"));
return ResponseEntity.ok(RestResponse.ok(SuccessCode.OK));
}

@PostMapping("/withdraw")
public ResponseEntity<?> withdraw(
public ResponseEntity<RestResponse<Void>> withdraw(
@RequestHeader(name = HttpHeaders.AUTHORIZATION, required = false) String authorization,
HttpServletResponse response
) {
accountWithdrawalService.withdraw(extractAccessToken(authorization));
response.setHeader(HttpHeaders.SET_COOKIE, createLogoutCookie().toString());
return ResponseEntity.ok(Map.of("message", "withdraw success"));
return ResponseEntity.ok(RestResponse.ok(SuccessCode.OK));
}

private Map<String, Object> createRefreshResponse(AuthRefreshResult result) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.zimdugo.auth.entrypoint;

import com.zimdugo.auth.application.OAuth2CallbackUrlCookieManager;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

@Component
@RequiredArgsConstructor
public class OAuth2CallbackUrlCaptureFilter extends OncePerRequestFilter {

private static final String OAUTH2_AUTHORIZATION_REQUEST_PREFIX = "/oauth2/authorization/";

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
protected void doFilterInternal(
HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
callbackUrlCookieManager.saveCallbackUrl(request, response);
filterChain.doFilter(request, response);
}

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return !request.getRequestURI().startsWith(OAUTH2_AUTHORIZATION_REQUEST_PREFIX);
}
}
Loading
Loading