diff --git a/backend/src/appointment/controller/apis/google_client.py b/backend/src/appointment/controller/apis/google_client.py index f839873d6..c89fe23cc 100644 --- a/backend/src/appointment/controller/apis/google_client.py +++ b/backend/src/appointment/controller/apis/google_client.py @@ -25,7 +25,6 @@ class GoogleClient: 'https://www.googleapis.com/auth/userinfo.email', 'openid', ] - client: Flow | None = None def __init__(self, client_id, client_secret, project_id, callback_url): self.config = { @@ -39,30 +38,40 @@ def __init__(self, client_id, client_secret, project_id, callback_url): } self.callback_url = callback_url - self.client = None + self._setup_verified = False + + def _create_flow(self) -> Flow: + """Create a fresh Flow instance for an OAuth exchange.""" + return Flow.from_client_config(self.config, self.SCOPES, redirect_uri=self.callback_url) def setup(self): - # Ignore if we're already setup! - if self.client: + """Verify that credentials are valid by attempting to create a Flow. + Called once at startup; raises on bad credentials.""" + if self._setup_verified: return - """Actually create the client, this is separate, so we can catch any errors without breaking everything""" - self.client = Flow.from_client_config(self.config, self.SCOPES, redirect_uri=self.callback_url) + self._create_flow() + self._setup_verified = True def get_redirect_url(self): - """Returns the redirect url for the google oauth flow""" - if self.client is None: - return None + """Returns the redirect url, state, and code_verifier for the google oauth flow. + + The code_verifier must be stored (in the session, in this case) and passed back + to get_credentials() when the callback arrives so PKCE validation succeeds, + even if a different server instance handles the callback. + """ + flow = self._create_flow() # (Url, State ID) - return self.client.authorization_url(access_type='offline', prompt='consent') + url, state = flow.authorization_url(access_type='offline', prompt='consent') + return url, state, flow.code_verifier - def get_credentials(self, code: str): - if self.client is None: - return None + def get_credentials(self, code: str, code_verifier: str | None = None): + flow = self._create_flow() + flow.code_verifier = code_verifier try: - self.client.fetch_token(code=code) - return self.client.credentials + flow.fetch_token(code=code) + return flow.credentials except Warning as e: logging.error(f'[google_client.get_credentials] Google Warning: {str(e)}') # This usually is the "Scope has changed" error. @@ -73,9 +82,6 @@ def get_credentials(self, code: str): def get_profile(self, token): """Retrieve the user's profile associated with the token""" - if self.client is None: - return None - user_info_service = build('oauth2', 'v2', credentials=token) user_info = user_info_service.userinfo().get().execute() diff --git a/backend/src/appointment/routes/google.py b/backend/src/appointment/routes/google.py index 0826f5956..5ee4956fd 100644 --- a/backend/src/appointment/routes/google.py +++ b/backend/src/appointment/routes/google.py @@ -21,6 +21,7 @@ SESSION_OAUTH_STATE = 'google_oauth_state' SESSION_OAUTH_SUBSCRIBER_ID = 'google_oauth_subscriber_id' +SESSION_OAUTH_CODE_VERIFIER = 'google_oauth_code_verifier' @router.get('/ftue-status') @@ -43,10 +44,11 @@ def google_auth( subscriber: Subscriber = Depends(get_subscriber), ): """Starts the google oauth process""" - url, state = google_client.get_redirect_url() + url, state, code_verifier = google_client.get_redirect_url() request.session[SESSION_OAUTH_STATE] = state request.session[SESSION_OAUTH_SUBSCRIBER_ID] = subscriber.id + request.session[SESSION_OAUTH_CODE_VERIFIER] = code_verifier return url @@ -74,8 +76,10 @@ def google_callback( return google_callback_error(is_setup, l10n('google-connect-to-continue')) return google_callback_error(is_setup, l10n('google-sync-fail')) + code_verifier = request.session.get(SESSION_OAUTH_CODE_VERIFIER) + try: - creds = google_client.get_credentials(code) + creds = google_client.get_credentials(code, code_verifier=code_verifier) except GoogleScopeChanged: return google_callback_error(is_setup, l10n('google-scope-changed')) except GoogleInvalidCredentials: @@ -87,6 +91,7 @@ def google_callback( # Clear session keys request.session.pop(SESSION_OAUTH_STATE) request.session.pop(SESSION_OAUTH_SUBSCRIBER_ID) + request.session.pop(SESSION_OAUTH_CODE_VERIFIER, None) if subscriber is None: return google_callback_error(is_setup, l10n('google-auth-fail')) diff --git a/backend/test/integration/test_auth.py b/backend/test/integration/test_auth.py index ce37e31c9..826476ba7 100644 --- a/backend/test/integration/test_auth.py +++ b/backend/test/integration/test_auth.py @@ -781,7 +781,7 @@ def to_json(self): second_google_email = 'user2@gmail.com' mock_profile = {'email': second_google_email, 'id': second_google_id} - def mock_get_credentials(code): + def mock_get_credentials(code, code_verifier=None): return mock_creds def mock_get_profile(token): @@ -801,6 +801,7 @@ def mock_sync_calendars(db, subscriber_id, token, external_connection_id): { 'google_oauth_state': state, 'google_oauth_subscriber_id': subscriber.id, + 'google_oauth_code_verifier': 'mock_code_verifier', }, ) @@ -842,6 +843,122 @@ def mock_sync_calendars(db, subscriber_id, token, external_connection_id): assert second_connection[0].name == second_google_email assert second_connection[0].owner_id == subscriber.id + def test_google_auth_stores_code_verifier_in_session(self, with_client, monkeypatch): + """Test that GET /google/auth stores a code_verifier in the session for PKCE""" + from appointment.controller.apis.google_client import GoogleClient + from appointment.dependencies.google import get_google_client + + mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url') + + def mock_get_redirect_url(): + return 'https://accounts.google.com/o/oauth2/auth?...', 'mock_state', 'mock_code_verifier_abc' + + monkeypatch.setattr(mock_google_client, 'get_redirect_url', mock_get_redirect_url) + + session_data = {} + monkeypatch.setattr('starlette.requests.HTTPConnection.session', session_data) + + with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client + + response = with_client.get('/google/auth', headers=auth_headers) + assert response.status_code == 200 + + assert session_data.get('google_oauth_code_verifier') == 'mock_code_verifier_abc' + assert session_data.get('google_oauth_state') == 'mock_state' + + def test_google_callback_passes_code_verifier_to_get_credentials( + self, with_db, with_client, monkeypatch, make_basic_subscriber + ): + """Test that the callback reads code_verifier from session and passes it to get_credentials""" + from appointment.controller.apis.google_client import GoogleClient + from appointment.dependencies.google import get_google_client + + subscriber = make_basic_subscriber() + + mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url') + + class MockCredentials: + def to_json(self): + return '{"access_token": "tok", "refresh_token": "ref"}' + + captured_verifier = {} + + def mock_get_credentials(code, code_verifier=None): + captured_verifier['value'] = code_verifier + return MockCredentials() + + def mock_get_profile(token): + return {'email': 'test@gmail.com', 'id': 'google_id_999'} + + def mock_sync_calendars(db, subscriber_id, token, external_connection_id): + return False + + monkeypatch.setattr(mock_google_client, 'get_credentials', mock_get_credentials) + monkeypatch.setattr(mock_google_client, 'get_profile', mock_get_profile) + monkeypatch.setattr(mock_google_client, 'sync_calendars', mock_sync_calendars) + + state = 'test_state_verifier' + monkeypatch.setattr( + 'starlette.requests.HTTPConnection.session', + { + 'google_oauth_state': state, + 'google_oauth_subscriber_id': subscriber.id, + 'google_oauth_code_verifier': 'the_real_verifier_xyz', + }, + ) + + with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client + + response = with_client.get( + '/google/callback', params={'code': 'auth_code', 'state': state}, follow_redirects=False + ) + + assert response.status_code == 307 + assert captured_verifier['value'] == 'the_real_verifier_xyz' + + def test_google_callback_clears_code_verifier_from_session( + self, with_db, with_client, monkeypatch, make_basic_subscriber + ): + """Test that the callback clears the code_verifier from the session after use""" + from appointment.controller.apis.google_client import GoogleClient + from appointment.dependencies.google import get_google_client + + subscriber = make_basic_subscriber() + + mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url') + + class MockCredentials: + def to_json(self): + return '{"access_token": "tok", "refresh_token": "ref"}' + + def mock_get_credentials(code, code_verifier=None): + return MockCredentials() + + def mock_get_profile(token): + return {'email': 'clean@gmail.com', 'id': 'google_clean_id'} + + def mock_sync_calendars(db, subscriber_id, token, external_connection_id): + return False + + monkeypatch.setattr(mock_google_client, 'get_credentials', mock_get_credentials) + monkeypatch.setattr(mock_google_client, 'get_profile', mock_get_profile) + monkeypatch.setattr(mock_google_client, 'sync_calendars', mock_sync_calendars) + + state = 'test_state_cleanup' + session_data = { + 'google_oauth_state': state, + 'google_oauth_subscriber_id': subscriber.id, + 'google_oauth_code_verifier': 'verifier_to_be_cleaned', + } + monkeypatch.setattr('starlette.requests.HTTPConnection.session', session_data) + + with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client + + with_client.get('/google/callback', params={'code': 'auth_code', 'state': state}, follow_redirects=False) + + assert 'google_oauth_code_verifier' not in session_data + assert 'google_oauth_state' not in session_data + class TestOIDCToken: """Tests for the OIDC token endpoint""" diff --git a/backend/test/unit/test_google_client.py b/backend/test/unit/test_google_client.py new file mode 100644 index 000000000..9670fd44c --- /dev/null +++ b/backend/test/unit/test_google_client.py @@ -0,0 +1,98 @@ +from unittest.mock import patch, MagicMock + +from appointment.controller.apis.google_client import GoogleClient + + +class TestGoogleClient: + """Tests for GoogleClient OAuth flow and PKCE handling""" + + def _make_client(self): + return GoogleClient('client_id', 'client_secret', 'project_id', 'https://example.com/callback') + + def test_create_flow_returns_fresh_instances(self): + """Each call to _create_flow should return a distinct Flow object""" + client = self._make_client() + flow_a = client._create_flow() + flow_b = client._create_flow() + assert flow_a is not flow_b + + def test_setup_validates_credentials_once(self): + """setup() should attempt to create a Flow to verify creds, then skip on subsequent calls""" + client = self._make_client() + assert client._setup_verified is False + + client.setup() + assert client._setup_verified is True + + with patch.object(client, '_create_flow', wraps=client._create_flow) as spy: + client.setup() + spy.assert_not_called() + + def test_get_redirect_url_returns_code_verifier(self): + """get_redirect_url should return a 3-tuple including a non-None code_verifier""" + client = self._make_client() + result = client.get_redirect_url() + + assert len(result) == 3 + url, state, code_verifier = result + assert url is not None + assert state is not None + assert code_verifier is not None + assert isinstance(code_verifier, str) + assert len(code_verifier) > 0 + + def test_get_redirect_url_generates_unique_verifiers(self): + """Each call to get_redirect_url should produce a different code_verifier""" + client = self._make_client() + _, _, verifier_a = client.get_redirect_url() + _, _, verifier_b = client.get_redirect_url() + assert verifier_a != verifier_b + + def test_get_credentials_passes_code_verifier_to_flow(self): + """get_credentials should set code_verifier on the Flow before calling fetch_token""" + client = self._make_client() + + mock_flow = MagicMock() + mock_flow.credentials = MagicMock() + + with patch.object(client, '_create_flow', return_value=mock_flow): + client.get_credentials('auth_code', code_verifier='test_verifier_123') + + assert mock_flow.code_verifier == 'test_verifier_123' + mock_flow.fetch_token.assert_called_once_with(code='auth_code') + + def test_get_credentials_works_without_code_verifier(self): + """get_credentials should still work when code_verifier is None (backwards compat)""" + client = self._make_client() + + mock_flow = MagicMock() + mock_flow.credentials = MagicMock() + + with patch.object(client, '_create_flow', return_value=mock_flow): + client.get_credentials('auth_code') + + assert mock_flow.code_verifier is None + mock_flow.fetch_token.assert_called_once_with(code='auth_code') + + def test_concurrent_flows_are_isolated(self): + """Simulates two users starting OAuth flows (their code_verifiers must not interfere)""" + client = self._make_client() + + _, _, verifier_user_a = client.get_redirect_url() + _, _, verifier_user_b = client.get_redirect_url() + + assert verifier_user_a != verifier_user_b + + mock_flow = MagicMock() + mock_flow.credentials = MagicMock() + + with patch.object(client, '_create_flow', return_value=mock_flow): + client.get_credentials('code_a', code_verifier=verifier_user_a) + assert mock_flow.code_verifier == verifier_user_a + + mock_flow_b = MagicMock() + mock_flow_b.credentials = MagicMock() + + with patch.object(client, '_create_flow', return_value=mock_flow_b): + client.get_credentials('code_b', code_verifier=verifier_user_b) + assert mock_flow_b.code_verifier == verifier_user_b