Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package com.zimdugo.auth.application;

import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class OAuth2CallbackUrlCookieManager {

private static final String CALLBACK_URL_PARAM = "callbackUrl";
private static final String CALLBACK_URL_COOKIE_NAME = "oauth2_callback_url";
private static final int CALLBACK_URL_COOKIE_MAX_AGE_SECONDS = 300;
private static final String SAME_SITE_POLICY = "Lax";
private static final String RELATIVE_PATH_DEFAULT = "/";

@Value("${auth.callback.frontend-base-url:http://localhost:3000}")
private String frontendBaseUrl;

@Value("${auth.callback.allowed-origins:http://localhost:3000,http://localhost:5173}")
private String allowedOriginsProperty;

private Set<String> allowedOrigins;

@PostConstruct
void initializeAllowedOrigins() {
this.allowedOrigins = new LinkedHashSet<>();
Arrays.stream(allowedOriginsProperty.split(","))
.map(String::trim)
.filter(v -> !v.isBlank())
.map(this::extractOrigin)
.forEach(allowedOrigins::add);

String frontendOrigin = extractOrigin(frontendBaseUrl);
if (frontendOrigin != null) {
allowedOrigins.add(frontendOrigin);
}
}

public void saveCallbackUrl(HttpServletRequest request, HttpServletResponse response) {
String callbackUrl = normalize(request.getParameter(CALLBACK_URL_PARAM));
addCookie(response, callbackUrl, CALLBACK_URL_COOKIE_MAX_AGE_SECONDS);
}

public String resolveCallbackUrl(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies == null) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

for (Cookie cookie : cookies) {
if (CALLBACK_URL_COOKIE_NAME.equals(cookie.getName())) {
return normalize(decode(cookie.getValue()));
}
}

return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

public void clearCallbackUrl(HttpServletResponse response) {
addCookie(response, "", 0);
}

private void addCookie(HttpServletResponse response, String value, int maxAgeSeconds) {
ResponseCookie cookie = ResponseCookie.from(CALLBACK_URL_COOKIE_NAME, encode(value))
.httpOnly(true)
.secure(false)
.path("/")
.maxAge(maxAgeSeconds)
.sameSite(SAME_SITE_POLICY)
.build();

response.addHeader(HttpHeaders.SET_COOKIE, cookie.toString());
}

private String normalize(String callbackUrl) {
if (callbackUrl == null || callbackUrl.isBlank()) {
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

String trimmed = callbackUrl.trim();
if (trimmed.contains("\r") || trimmed.contains("\n")) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

if (trimmed.startsWith("/") && !trimmed.startsWith("//")) {
return toFrontendUrl(trimmed);
}

String origin = extractOrigin(trimmed);
if (origin == null || !allowedOrigins.contains(origin)) {
log.warn("Unsafe callbackUrl detected. fallback to default. callbackUrl={}", trimmed);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}

return trimmed;
}

private String encode(String value) {
return URLEncoder.encode(value, StandardCharsets.UTF_8);
}

private String decode(String value) {
try {
return URLDecoder.decode(value, StandardCharsets.UTF_8);
} catch (IllegalArgumentException e) {
log.warn("Failed to decode callback cookie. fallback to default.", e);
return toFrontendUrl(RELATIVE_PATH_DEFAULT);
}
}

private String toFrontendUrl(String path) {
String base = frontendBaseUrl;
if (frontendBaseUrl.endsWith("/")) {
base = frontendBaseUrl.substring(0, frontendBaseUrl.length() - 1);
}
if (path == null || path.isBlank() || "/".equals(path)) {
return base + "/";
}
return base + path;
}

private String extractOrigin(String url) {
try {
URI uri = new URI(url);
if (uri.getScheme() == null || uri.getHost() == null) {
return null;
}

String scheme = uri.getScheme().toLowerCase();
if (!"http".equals(scheme) && !"https".equals(scheme)) {
return null;
}

if (uri.getPort() == -1) {
return scheme + "://" + uri.getHost().toLowerCase();
}
return scheme + "://" + uri.getHost().toLowerCase() + ":" + uri.getPort();
} catch (URISyntaxException e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.zimdugo.auth.application;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
@RequiredArgsConstructor
public class OAuth2FailureHandler implements AuthenticationFailureHandler {

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);
callbackUrlCookieManager.clearCallbackUrl(response);

log.warn("oauth login failure. callbackUrl={}, reason={}", callbackUrl, exception.getMessage());
response.sendRedirect(appendCode(callbackUrl, "LOGIN_FAILED"));
}

private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}

Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
package com.zimdugo.auth.application;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.zimdugo.auth.domain.AuthTokens;
import com.zimdugo.auth.domain.RefreshTokenRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseCookie;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
@Component
Expand All @@ -32,14 +30,16 @@ public class OAuth2SuccessHandler implements AuthenticationSuccessHandler {
private final JwtTokenProvider jwtTokenProvider;
private final RefreshTokenRepository refreshTokenRepository;
private final JwtProperties jwtProperties;
private final ObjectMapper objectMapper;
private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
public void onAuthenticationSuccess(
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication
) throws IOException {
String callbackUrl = callbackUrlCookieManager.resolveCallbackUrl(request);

DefaultOAuth2User oAuth2User = (DefaultOAuth2User) authentication.getPrincipal();
Map<String, Object> attributes = oAuth2User.getAttributes();

Expand All @@ -60,17 +60,17 @@ public void onAuthenticationSuccess(
.sameSite(SAME_SITE_POLICY)
.build();

response.setHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.setCharacterEncoding("UTF-8");
response.addHeader(HttpHeaders.SET_COOKIE, rtCookie.toString());
callbackUrlCookieManager.clearCallbackUrl(response);

Map<String, Object> body = new LinkedHashMap<>();
body.put("message", "oauth login success");
body.put("userId", userId);
body.put("email", email);
body.put("accessToken", tokens.accessToken());
log.info("oauth login success. userId={}, sid={}, callbackUrl={}", userId, sid, callbackUrl);
response.sendRedirect(appendCode(callbackUrl, "LOGIN_SUCCESS"));
}

log.info("oauth login success. userId={}, sid={}", userId, sid);
response.getWriter().write(objectMapper.writeValueAsString(body));
private String appendCode(String callbackUrl, String code) {
return UriComponentsBuilder.fromUriString(callbackUrl)
.replaceQueryParam("code", code)
.build(true)
.toUriString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.zimdugo.auth.entrypoint;

import com.zimdugo.auth.application.OAuth2CallbackUrlCookieManager;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

@Component
@RequiredArgsConstructor
public class OAuth2CallbackUrlCaptureFilter extends OncePerRequestFilter {

private static final String OAUTH2_AUTHORIZATION_REQUEST_PREFIX = "/oauth2/authorization/";

private final OAuth2CallbackUrlCookieManager callbackUrlCookieManager;

@Override
protected void doFilterInternal(
HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
callbackUrlCookieManager.saveCallbackUrl(request, response);
filterChain.doFilter(request, response);
}

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return !request.getRequestURI().startsWith(OAUTH2_AUTHORIZATION_REQUEST_PREFIX);
}
}
7 changes: 7 additions & 0 deletions src/main/java/com/zimdugo/common/config/SecurityConfig.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.zimdugo.common.config;

import com.zimdugo.auth.application.CustomOAuth2UserService;
import com.zimdugo.auth.application.OAuth2FailureHandler;
import com.zimdugo.auth.application.OAuth2SuccessHandler;
import com.zimdugo.auth.entrypoint.JwtAuthenticationFilter;
import com.zimdugo.auth.entrypoint.OAuth2CallbackUrlCaptureFilter;
import com.zimdugo.common.security.CustomAccessDeniedHandler;
import com.zimdugo.common.security.CustomAuthenticationEntryPoint;
import lombok.RequiredArgsConstructor;
Expand All @@ -11,6 +13,7 @@
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;

Expand All @@ -20,7 +23,9 @@ public class SecurityConfig {

private final CustomOAuth2UserService customOAuth2UserService;
private final OAuth2SuccessHandler oAuth2SuccessHandler;
private final OAuth2FailureHandler oAuth2FailureHandler;
private final JwtAuthenticationFilter jwtAuthenticationFilter;
private final OAuth2CallbackUrlCaptureFilter oAuth2CallbackUrlCaptureFilter;
private final CustomAuthenticationEntryPoint customAuthenticationEntryPoint;
private final CustomAccessDeniedHandler customAccessDeniedHandler;

Expand All @@ -31,6 +36,7 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti
configureOauth2Login(http);

http.logout(AbstractHttpConfigurer::disable)
.addFilterBefore(oAuth2CallbackUrlCaptureFilter, OAuth2AuthorizationRequestRedirectFilter.class)
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);

return http.build();
Expand Down Expand Up @@ -72,6 +78,7 @@ private void configureOauth2Login(HttpSecurity http) throws Exception {
http.oauth2Login(oauth2 -> oauth2
.userInfoEndpoint(userInfo -> userInfo.userService(customOAuth2UserService))
.successHandler(oAuth2SuccessHandler)
.failureHandler(oAuth2FailureHandler)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.List;

public record UserProfileResponse(
public record UserProfileDto(
Long id,
String email,
String nickname,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class UserQueryService {
private final UserReader userReader;
private final SocialAccountReader socialAccountReader;

public UserProfileResponse getProfile(Long userId) {
public UserProfileDto getProfile(Long userId) {
User user = findById(userId);

List<SocialAccount> socialAccounts = socialAccountReader.findAllByUserId(userId);
Expand All @@ -28,7 +28,7 @@ public UserProfileResponse getProfile(Long userId) {
.map(sa -> sa.getProvider().name().toLowerCase())
.toList();

return new UserProfileResponse(
return new UserProfileDto(
user.getId(),
user.getEmail(),
user.getNickname(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import com.zimdugo.core.exception.BusinessException;
import com.zimdugo.core.exception.ErrorCode;
import com.zimdugo.user.application.UserProfileDto;
import com.zimdugo.user.application.UserQueryService;
import com.zimdugo.user.application.UserProfileResponse;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
Expand All @@ -22,7 +22,8 @@ public class UserController {
public ResponseEntity<UserProfileResponse> me(
Authentication authentication
) {
return ResponseEntity.ok(userQueryService.getProfile(extractUserId(authentication)));
UserProfileDto profile = userQueryService.getProfile(extractUserId(authentication));
return ResponseEntity.ok(UserProfileResponse.from(profile));
Comment thread
mike7643 marked this conversation as resolved.
Outdated
}

private Long extractUserId(Authentication authentication) {
Expand Down
Loading
Loading