Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions backend/src/appointment/controller/apis/google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class GoogleClient:
'https://www.googleapis.com/auth/userinfo.email',
'openid',
]
client: Flow | None = None

def __init__(self, client_id, client_secret, project_id, callback_url):
self.config = {
Expand All @@ -39,30 +38,40 @@ def __init__(self, client_id, client_secret, project_id, callback_url):
}

self.callback_url = callback_url
self.client = None
self._setup_verified = False

def _create_flow(self) -> Flow:
"""Create a fresh Flow instance for an OAuth exchange."""
return Flow.from_client_config(self.config, self.SCOPES, redirect_uri=self.callback_url)

def setup(self):
# Ignore if we're already setup!
if self.client:
"""Verify that credentials are valid by attempting to create a Flow.
Called once at startup; raises on bad credentials."""
if self._setup_verified:
return
"""Actually create the client, this is separate, so we can catch any errors without breaking everything"""
self.client = Flow.from_client_config(self.config, self.SCOPES, redirect_uri=self.callback_url)
self._create_flow()
self._setup_verified = True

def get_redirect_url(self):
"""Returns the redirect url for the google oauth flow"""
if self.client is None:
return None
"""Returns the redirect url, state, and code_verifier for the google oauth flow.

The code_verifier must be stored (in the session, in this case) and passed back
to get_credentials() when the callback arrives so PKCE validation succeeds,
even if a different server instance handles the callback.
"""
flow = self._create_flow()

# (Url, State ID)
return self.client.authorization_url(access_type='offline', prompt='consent')
url, state = flow.authorization_url(access_type='offline', prompt='consent')
return url, state, flow.code_verifier

def get_credentials(self, code: str):
if self.client is None:
return None
def get_credentials(self, code: str, code_verifier: str | None = None):
flow = self._create_flow()
flow.code_verifier = code_verifier

try:
self.client.fetch_token(code=code)
return self.client.credentials
flow.fetch_token(code=code)
return flow.credentials
except Warning as e:
logging.error(f'[google_client.get_credentials] Google Warning: {str(e)}')
# This usually is the "Scope has changed" error.
Expand All @@ -73,9 +82,6 @@ def get_credentials(self, code: str):

def get_profile(self, token):
"""Retrieve the user's profile associated with the token"""
if self.client is None:
return None

user_info_service = build('oauth2', 'v2', credentials=token)
user_info = user_info_service.userinfo().get().execute()

Expand Down
9 changes: 7 additions & 2 deletions backend/src/appointment/routes/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

SESSION_OAUTH_STATE = 'google_oauth_state'
SESSION_OAUTH_SUBSCRIBER_ID = 'google_oauth_subscriber_id'
SESSION_OAUTH_CODE_VERIFIER = 'google_oauth_code_verifier'


@router.get('/ftue-status')
Expand All @@ -43,10 +44,11 @@ def google_auth(
subscriber: Subscriber = Depends(get_subscriber),
):
"""Starts the google oauth process"""
url, state = google_client.get_redirect_url()
url, state, code_verifier = google_client.get_redirect_url()

request.session[SESSION_OAUTH_STATE] = state
request.session[SESSION_OAUTH_SUBSCRIBER_ID] = subscriber.id
request.session[SESSION_OAUTH_CODE_VERIFIER] = code_verifier

return url

Expand Down Expand Up @@ -74,8 +76,10 @@ def google_callback(
return google_callback_error(is_setup, l10n('google-connect-to-continue'))
return google_callback_error(is_setup, l10n('google-sync-fail'))

code_verifier = request.session.get(SESSION_OAUTH_CODE_VERIFIER)

try:
creds = google_client.get_credentials(code)
creds = google_client.get_credentials(code, code_verifier=code_verifier)
except GoogleScopeChanged:
return google_callback_error(is_setup, l10n('google-scope-changed'))
except GoogleInvalidCredentials:
Expand All @@ -87,6 +91,7 @@ def google_callback(
# Clear session keys
request.session.pop(SESSION_OAUTH_STATE)
request.session.pop(SESSION_OAUTH_SUBSCRIBER_ID)
request.session.pop(SESSION_OAUTH_CODE_VERIFIER, None)

if subscriber is None:
return google_callback_error(is_setup, l10n('google-auth-fail'))
Expand Down
119 changes: 118 additions & 1 deletion backend/test/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def to_json(self):
second_google_email = 'user2@gmail.com'
mock_profile = {'email': second_google_email, 'id': second_google_id}

def mock_get_credentials(code):
def mock_get_credentials(code, code_verifier=None):
return mock_creds

def mock_get_profile(token):
Expand All @@ -801,6 +801,7 @@ def mock_sync_calendars(db, subscriber_id, token, external_connection_id):
{
'google_oauth_state': state,
'google_oauth_subscriber_id': subscriber.id,
'google_oauth_code_verifier': 'mock_code_verifier',
},
)

Expand Down Expand Up @@ -842,6 +843,122 @@ def mock_sync_calendars(db, subscriber_id, token, external_connection_id):
assert second_connection[0].name == second_google_email
assert second_connection[0].owner_id == subscriber.id

def test_google_auth_stores_code_verifier_in_session(self, with_client, monkeypatch):
"""Test that GET /google/auth stores a code_verifier in the session for PKCE"""
from appointment.controller.apis.google_client import GoogleClient
from appointment.dependencies.google import get_google_client

mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url')

def mock_get_redirect_url():
return 'https://accounts.google.com/o/oauth2/auth?...', 'mock_state', 'mock_code_verifier_abc'

monkeypatch.setattr(mock_google_client, 'get_redirect_url', mock_get_redirect_url)

session_data = {}
monkeypatch.setattr('starlette.requests.HTTPConnection.session', session_data)

with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client

response = with_client.get('/google/auth', headers=auth_headers)
assert response.status_code == 200

assert session_data.get('google_oauth_code_verifier') == 'mock_code_verifier_abc'
assert session_data.get('google_oauth_state') == 'mock_state'

def test_google_callback_passes_code_verifier_to_get_credentials(
self, with_db, with_client, monkeypatch, make_basic_subscriber
):
"""Test that the callback reads code_verifier from session and passes it to get_credentials"""
from appointment.controller.apis.google_client import GoogleClient
from appointment.dependencies.google import get_google_client

subscriber = make_basic_subscriber()

mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url')

class MockCredentials:
def to_json(self):
return '{"access_token": "tok", "refresh_token": "ref"}'

captured_verifier = {}

def mock_get_credentials(code, code_verifier=None):
captured_verifier['value'] = code_verifier
return MockCredentials()

def mock_get_profile(token):
return {'email': 'test@gmail.com', 'id': 'google_id_999'}

def mock_sync_calendars(db, subscriber_id, token, external_connection_id):
return False

monkeypatch.setattr(mock_google_client, 'get_credentials', mock_get_credentials)
monkeypatch.setattr(mock_google_client, 'get_profile', mock_get_profile)
monkeypatch.setattr(mock_google_client, 'sync_calendars', mock_sync_calendars)

state = 'test_state_verifier'
monkeypatch.setattr(
'starlette.requests.HTTPConnection.session',
{
'google_oauth_state': state,
'google_oauth_subscriber_id': subscriber.id,
'google_oauth_code_verifier': 'the_real_verifier_xyz',
},
)

with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client

response = with_client.get(
'/google/callback', params={'code': 'auth_code', 'state': state}, follow_redirects=False
)

assert response.status_code == 307
assert captured_verifier['value'] == 'the_real_verifier_xyz'

def test_google_callback_clears_code_verifier_from_session(
self, with_db, with_client, monkeypatch, make_basic_subscriber
):
"""Test that the callback clears the code_verifier from the session after use"""
from appointment.controller.apis.google_client import GoogleClient
from appointment.dependencies.google import get_google_client

subscriber = make_basic_subscriber()

mock_google_client = GoogleClient('client_id', 'client_secret', 'project_id', 'callback_url')

class MockCredentials:
def to_json(self):
return '{"access_token": "tok", "refresh_token": "ref"}'

def mock_get_credentials(code, code_verifier=None):
return MockCredentials()

def mock_get_profile(token):
return {'email': 'clean@gmail.com', 'id': 'google_clean_id'}

def mock_sync_calendars(db, subscriber_id, token, external_connection_id):
return False

monkeypatch.setattr(mock_google_client, 'get_credentials', mock_get_credentials)
monkeypatch.setattr(mock_google_client, 'get_profile', mock_get_profile)
monkeypatch.setattr(mock_google_client, 'sync_calendars', mock_sync_calendars)

state = 'test_state_cleanup'
session_data = {
'google_oauth_state': state,
'google_oauth_subscriber_id': subscriber.id,
'google_oauth_code_verifier': 'verifier_to_be_cleaned',
}
monkeypatch.setattr('starlette.requests.HTTPConnection.session', session_data)

with_client.app.dependency_overrides[get_google_client] = lambda: mock_google_client

with_client.get('/google/callback', params={'code': 'auth_code', 'state': state}, follow_redirects=False)

assert 'google_oauth_code_verifier' not in session_data
assert 'google_oauth_state' not in session_data


class TestOIDCToken:
"""Tests for the OIDC token endpoint"""
Expand Down
98 changes: 98 additions & 0 deletions backend/test/unit/test_google_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from unittest.mock import patch, MagicMock

from appointment.controller.apis.google_client import GoogleClient


class TestGoogleClient:
"""Tests for GoogleClient OAuth flow and PKCE handling"""

def _make_client(self):
return GoogleClient('client_id', 'client_secret', 'project_id', 'https://example.com/callback')

def test_create_flow_returns_fresh_instances(self):
"""Each call to _create_flow should return a distinct Flow object"""
client = self._make_client()
flow_a = client._create_flow()
flow_b = client._create_flow()
assert flow_a is not flow_b

def test_setup_validates_credentials_once(self):
"""setup() should attempt to create a Flow to verify creds, then skip on subsequent calls"""
client = self._make_client()
assert client._setup_verified is False

client.setup()
assert client._setup_verified is True

with patch.object(client, '_create_flow', wraps=client._create_flow) as spy:
client.setup()
spy.assert_not_called()

def test_get_redirect_url_returns_code_verifier(self):
"""get_redirect_url should return a 3-tuple including a non-None code_verifier"""
client = self._make_client()
result = client.get_redirect_url()

assert len(result) == 3
url, state, code_verifier = result
assert url is not None
assert state is not None
assert code_verifier is not None
assert isinstance(code_verifier, str)
assert len(code_verifier) > 0

def test_get_redirect_url_generates_unique_verifiers(self):
"""Each call to get_redirect_url should produce a different code_verifier"""
client = self._make_client()
_, _, verifier_a = client.get_redirect_url()
_, _, verifier_b = client.get_redirect_url()
assert verifier_a != verifier_b

def test_get_credentials_passes_code_verifier_to_flow(self):
"""get_credentials should set code_verifier on the Flow before calling fetch_token"""
client = self._make_client()

mock_flow = MagicMock()
mock_flow.credentials = MagicMock()

with patch.object(client, '_create_flow', return_value=mock_flow):
client.get_credentials('auth_code', code_verifier='test_verifier_123')

assert mock_flow.code_verifier == 'test_verifier_123'
mock_flow.fetch_token.assert_called_once_with(code='auth_code')

def test_get_credentials_works_without_code_verifier(self):
"""get_credentials should still work when code_verifier is None (backwards compat)"""
client = self._make_client()

mock_flow = MagicMock()
mock_flow.credentials = MagicMock()

with patch.object(client, '_create_flow', return_value=mock_flow):
client.get_credentials('auth_code')

assert mock_flow.code_verifier is None
mock_flow.fetch_token.assert_called_once_with(code='auth_code')

def test_concurrent_flows_are_isolated(self):
"""Simulates two users starting OAuth flows (their code_verifiers must not interfere)"""
client = self._make_client()

_, _, verifier_user_a = client.get_redirect_url()
_, _, verifier_user_b = client.get_redirect_url()

assert verifier_user_a != verifier_user_b

mock_flow = MagicMock()
mock_flow.credentials = MagicMock()

with patch.object(client, '_create_flow', return_value=mock_flow):
client.get_credentials('code_a', code_verifier=verifier_user_a)
assert mock_flow.code_verifier == verifier_user_a

mock_flow_b = MagicMock()
mock_flow_b.credentials = MagicMock()

with patch.object(client, '_create_flow', return_value=mock_flow_b):
client.get_credentials('code_b', code_verifier=verifier_user_b)
assert mock_flow_b.code_verifier == verifier_user_b
Loading