Skip to content

Commit e706a5c

Browse files
committed
refactor(feature-store): Replace input() with acknowledge_risk param
Add acknowledge_risk: Optional[bool] = None to enable_lake_formation() and LakeFormationConfig. None triggers interactive input() prompt, True proceeds without prompting, False aborts with RuntimeError. Removes all builtins.input mocking from unit and integration tests. Tests now pass acknowledge_risk=True or False directly. Removes one duplicate test that became identical after the refactor. --- X-AI-Prompt: add y/n confirmation for disable_hybrid_access_mode=True, then refactor to use acknowledge_risk param instead of input() X-AI-Tool: kiro-cli
1 parent 4e661a7 commit e706a5c

File tree

3 files changed

+63
-39
lines changed

3 files changed

+63
-39
lines changed

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group_manager.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,17 @@ class LakeFormationConfig(Base):
4545
may break existing jobs that access the table via IAM-based permissions. After
4646
this change, all principals must be granted access through Lake Formation.
4747
If False, IAM-based access remains alongside Lake Formation permissions.
48+
acknowledge_risk: Controls confirmation behavior for risky Lake Formation operations.
49+
If True, skips interactive confirmation prompts and proceeds. If False, raises
50+
RuntimeError without proceeding. If None (default), prompts the user interactively
51+
via input().
4852
"""
4953

5054
enabled: bool = False
5155
use_service_linked_role: bool = True
5256
registration_role_arn: Optional[str] = None
5357
disable_hybrid_access_mode: bool
58+
acknowledge_risk: Optional[bool] = None
5459

5560

5661
class FeatureGroupManager(FeatureGroup):
@@ -375,6 +380,7 @@ def _generate_s3_deny_statements(
375380
def enable_lake_formation(
376381
self,
377382
disable_hybrid_access_mode: bool,
383+
acknowledge_risk: Optional[bool] = None,
378384
session: Optional[Session] = None,
379385
region: Optional[str] = None,
380386
use_service_linked_role: bool = True,
@@ -399,6 +405,9 @@ def enable_lake_formation(
399405
the table via IAM-based permissions. After this change, all principals must
400406
be granted access through Lake Formation. If False, IAM-based access remains
401407
alongside Lake Formation permissions. Both values require confirmation; see acknowledge_risk.
408+
acknowledge_risk: Controls confirmation behavior for risky operations.
409+
If True, skips interactive prompts and proceeds. If False, raises
410+
RuntimeError without proceeding. If None (default), prompts interactively.
402411
session: Boto3 session.
403412
region: Region name.
404413
use_service_linked_role: Whether to use the Lake Formation service-linked role
@@ -496,12 +505,15 @@ def enable_lake_formation(
496505
f"to the table is still allowed alongside Lake Formation permissions. "
497506
f"For more info: https://docs.aws.amazon.com/lake-formation/latest/dg/hybrid-access-mode.html"
498507
)
499-
proceed = input(
500-
"Hybrid access mode is not disabled. IAM-based access to the Glue table will "
501-
"still be allowed. Do you want to proceed without revoking IAMAllowedPrincipal "
502-
"permissions? (y/n): "
503-
).strip().lower()
504-
if proceed != "y":
508+
if acknowledge_risk is None:
509+
proceed = input(
510+
"Hybrid access mode is not disabled. IAM-based access to the Glue table will "
511+
"still be allowed. Do you want to proceed without revoking IAMAllowedPrincipal "
512+
"permissions? (y/n): "
513+
).strip().lower() == "y"
514+
else:
515+
proceed = acknowledge_risk
516+
if not proceed:
505517
raise RuntimeError(
506518
"User chose not to proceed without disabling hybrid access mode. "
507519
"Re-run with disable_hybrid_access_mode=True to revoke IAMAllowedPrincipal permissions."
@@ -514,6 +526,18 @@ def enable_lake_formation(
514526
f"After this change, all principals must be granted access through Lake Formation. "
515527
f"For more info: https://docs.aws.amazon.com/lake-formation/latest/dg/hybrid-access-mode.html"
516528
)
529+
if acknowledge_risk is None:
530+
proceed = input(
531+
"This will revoke IAMAllowedPrincipal permissions and may break existing jobs "
532+
"that rely on IAM-based access. Do you want to proceed? (y/n): "
533+
).strip().lower() == "y"
534+
else:
535+
proceed = acknowledge_risk
536+
if not proceed:
537+
raise RuntimeError(
538+
"User chose not to proceed with disabling hybrid access mode. "
539+
"Re-run with disable_hybrid_access_mode=False to keep IAMAllowedPrincipal permissions."
540+
)
517541

518542

519543
results = {
@@ -805,5 +829,6 @@ def create(
805829
use_service_linked_role=lake_formation_config.use_service_linked_role,
806830
registration_role_arn=lake_formation_config.registration_role_arn,
807831
disable_hybrid_access_mode=lake_formation_config.disable_hybrid_access_mode,
832+
acknowledge_risk=lake_formation_config.acknowledge_risk,
808833
)
809834
return feature_group

sagemaker-mlops/tests/integ/test_feature_store_lakeformation.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import logging
1313
import uuid
14-
1514
import boto3
1615
import pytest
1716
from botocore.exceptions import ClientError
@@ -161,7 +160,7 @@ def test_create_feature_group_and_enable_lake_formation(s3_uri, role, region):
161160
assert fg.feature_group_status == "Created"
162161

163162
# Enable Lake Formation governance
164-
result = fg.enable_lake_formation(disable_hybrid_access_mode=True)
163+
result = fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
165164

166165
# Verify all phases completed successfully
167166
assert result["s3_location_registered"] is True
@@ -197,7 +196,8 @@ def test_create_feature_group_with_lake_formation_enabled(s3_uri, role, region):
197196
offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri))
198197
lake_formation_config = LakeFormationConfig(
199198
enabled=True,
200-
disable_hybrid_access_mode = True
199+
disable_hybrid_access_mode = True,
200+
acknowledge_risk=True,
201201
)
202202

203203
fg = FeatureGroupManager.create(
@@ -460,6 +460,7 @@ def test_enable_lake_formation_fails_with_nonexistent_role(
460460
use_service_linked_role=False,
461461
registration_role_arn=nonexistent_role,
462462
disable_hybrid_access_mode=True,
463+
acknowledge_risk=True,
463464
)
464465

465466
# Verify we got an appropriate error
@@ -501,7 +502,7 @@ def test_enable_lake_formation_full_flow_with_policy_output(s3_uri, role, region
501502

502503
# Enable Lake Formation governance
503504
with caplog.at_level(logging.WARNING, logger="sagemaker.mlops.feature_store.feature_group_manager"):
504-
result = fg.enable_lake_formation(disable_hybrid_access_mode=True)
505+
result = fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
505506

506507
# Verify all phases completed successfully
507508
assert result["s3_location_registered"] is True
@@ -544,7 +545,7 @@ def test_enable_lake_formation_default_logs_recommended_policy(s3_uri, role, reg
544545

545546
# Enable Lake Formation governance with disable_hybrid_access_mode=True
546547
with caplog.at_level(logging.WARNING, logger="sagemaker.mlops.feature_store.feature_group_manager"):
547-
result = fg.enable_lake_formation(disable_hybrid_access_mode=True)
548+
result = fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
548549

549550
# Verify phases completed successfully
550551
assert result["s3_location_registered"] is True
@@ -588,6 +589,7 @@ def test_enable_lake_formation_with_custom_role_logs_policy(s3_uri, role, region
588589
use_service_linked_role=False,
589590
registration_role_arn=role,
590591
disable_hybrid_access_mode=True,
592+
acknowledge_risk=True,
591593
)
592594

593595
# Verify all phases completed successfully

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_group_manager.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def test_wait_for_active_calls_wait_for_status(
440440
mock_revoke.return_value = True
441441

442442
# Call with wait_for_active=True
443-
fg.enable_lake_formation(wait_for_active=True, disable_hybrid_access_mode=True)
443+
fg.enable_lake_formation(wait_for_active=True, disable_hybrid_access_mode=True, acknowledge_risk=True)
444444

445445
# Verify wait_for_status was called with "Created"
446446
mock_wait.assert_called_once_with(target_status="Created")
@@ -478,7 +478,7 @@ def test_wait_for_active_false_does_not_call_wait(
478478
mock_revoke.return_value = True
479479

480480
# Call with wait_for_active=False (default)
481-
fg.enable_lake_formation(wait_for_active=False, disable_hybrid_access_mode=True)
481+
fg.enable_lake_formation(wait_for_active=False, disable_hybrid_access_mode=True, acknowledge_risk=True)
482482

483483
# Verify wait_for_status was NOT called
484484
mock_wait.assert_not_called()
@@ -564,7 +564,7 @@ def test_fail_fast_phase_execution(
564564
with pytest.raises(
565565
RuntimeError, match="Failed to register S3 location with Lake Formation"
566566
):
567-
fg.enable_lake_formation(disable_hybrid_access_mode=True)
567+
fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
568568

569569
# Verify Phase 1 was called but Phase 2 and 3 were not
570570
mock_register.assert_called_once()
@@ -583,7 +583,7 @@ def test_fail_fast_phase_execution(
583583
mock_revoke.return_value = True
584584

585585
with pytest.raises(RuntimeError, match="Failed to grant Lake Formation permissions"):
586-
fg.enable_lake_formation(disable_hybrid_access_mode=True)
586+
fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
587587

588588
# Verify Phase 1 and 2 were called but Phase 3 was not
589589
mock_register.assert_called_once()
@@ -603,7 +603,7 @@ def test_fail_fast_phase_execution(
603603
mock_revoke.side_effect = Exception("Phase 3 failed")
604604

605605
with pytest.raises(RuntimeError, match="Failed to revoke IAMAllowedPrincipal permissions"):
606-
fg.enable_lake_formation(disable_hybrid_access_mode=True)
606+
fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
607607

608608
# Verify all phases were called
609609
mock_register.assert_called_once()
@@ -967,6 +967,7 @@ def test_enable_lake_formation_called_when_enabled(
967967
use_service_linked_role=True,
968968
registration_role_arn=None,
969969
disable_hybrid_access_mode=False,
970+
acknowledge_risk=None,
970971
)
971972
# Verify the feature group was returned
972973
assert result == mock_fg
@@ -1175,6 +1176,7 @@ def test_use_service_linked_role_extraction_from_config(
11751176
use_service_linked_role=use_slr,
11761177
registration_role_arn=expected_registration_role,
11771178
disable_hybrid_access_mode=False,
1179+
acknowledge_risk=None,
11781180
)
11791181
# Verify the feature group was returned
11801182
assert result == mock_fg
@@ -1213,45 +1215,34 @@ def test_revoke_called_when_disable_hybrid_access_mode_true(
12131215
mock_grant.return_value = True
12141216
mock_revoke.return_value = True
12151217

1216-
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=True)
1218+
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
12171219

12181220
mock_revoke.assert_called_once()
12191221
assert result["hybrid_access_mode_disabled"] is True
12201222

1221-
@patch("builtins.input", return_value="y")
12221223
@patch.object(FeatureGroupManager, "refresh")
12231224
@patch.object(FeatureGroupManager, "_register_s3_with_lake_formation")
12241225
@patch.object(FeatureGroupManager, "_grant_lake_formation_permissions")
12251226
@patch.object(FeatureGroupManager, "_revoke_iam_allowed_principal")
12261227
def test_revoke_not_called_when_disable_hybrid_access_mode_false(
1227-
self, mock_revoke, mock_grant, mock_register, mock_refresh, mock_input
1228+
self, mock_revoke, mock_grant, mock_register, mock_refresh
12281229
):
12291230
"""Test that IAMAllowedPrincipal is NOT revoked when disable_hybrid_access_mode=False."""
12301231
mock_register.return_value = True
12311232
mock_grant.return_value = True
12321233

1233-
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=False)
1234+
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=False, acknowledge_risk=True)
12341235

12351236
mock_revoke.assert_not_called()
12361237
assert result["hybrid_access_mode_disabled"] is False
12371238

1238-
@patch("builtins.input", return_value="n")
12391239
@patch.object(FeatureGroupManager, "refresh")
12401240
def test_raises_error_when_user_declines_hybrid_access_prompt(
1241-
self, mock_refresh, mock_input
1241+
self, mock_refresh
12421242
):
12431243
"""Test that RuntimeError is raised when acknowledge_risk=False and hybrid access mode is not disabled."""
12441244
with pytest.raises(RuntimeError, match="User chose not to proceed"):
1245-
self.fg.enable_lake_formation(disable_hybrid_access_mode=False)
1246-
1247-
@patch("builtins.input", return_value="")
1248-
@patch.object(FeatureGroupManager, "refresh")
1249-
def test_raises_error_when_user_enters_empty_at_hybrid_access_prompt(
1250-
self, mock_refresh, mock_input
1251-
):
1252-
"""Test that RuntimeError is raised when user enters empty string at prompt."""
1253-
with pytest.raises(RuntimeError, match="User chose not to proceed"):
1254-
self.fg.enable_lake_formation(disable_hybrid_access_mode=False)
1245+
self.fg.enable_lake_formation(disable_hybrid_access_mode=False, acknowledge_risk=False)
12551246

12561247

12571248
class TestCreateWithLakeFormationDisableHybridAccessMode:
@@ -1305,6 +1296,7 @@ def test_enable_lake_formation_called_with_disable_hybrid_access_mode(
13051296
use_service_linked_role=True,
13061297
registration_role_arn=None,
13071298
disable_hybrid_access_mode=True,
1299+
acknowledge_risk=None,
13081300
)
13091301

13101302

@@ -1623,7 +1615,7 @@ def test_uses_service_linked_role_arn_when_use_service_linked_role_true(
16231615
mock_revoke.return_value = True
16241616
mock_generate.return_value = []
16251617

1626-
fg.enable_lake_formation(use_service_linked_role=True, disable_hybrid_access_mode=True)
1618+
fg.enable_lake_formation(use_service_linked_role=True, disable_hybrid_access_mode=True, acknowledge_risk=True)
16271619

16281620
expected_slr_arn = "arn:aws:iam::123456789012:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess"
16291621
mock_generate.assert_called_once()
@@ -1666,7 +1658,7 @@ def test_uses_service_linked_role_arn_by_default(
16661658
mock_revoke.return_value = True
16671659
mock_generate.return_value = []
16681660

1669-
fg.enable_lake_formation(disable_hybrid_access_mode=True)
1661+
fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
16701662

16711663
expected_slr_arn = "arn:aws:iam::987654321098:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess"
16721664
mock_generate.assert_called_once()
@@ -1709,7 +1701,7 @@ def test_service_linked_role_arn_uses_correct_account_id(
17091701
mock_revoke.return_value = True
17101702
mock_generate.return_value = []
17111703

1712-
fg.enable_lake_formation(use_service_linked_role=True, disable_hybrid_access_mode=True)
1704+
fg.enable_lake_formation(use_service_linked_role=True, disable_hybrid_access_mode=True, acknowledge_risk=True)
17131705

17141706
expected_slr_arn = f"arn:aws:iam::{account_id}:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess"
17151707
mock_generate.assert_called_once()
@@ -1763,6 +1755,7 @@ def test_uses_registration_role_arn_when_use_service_linked_role_false(
17631755
use_service_linked_role=False,
17641756
registration_role_arn=custom_registration_role,
17651757
disable_hybrid_access_mode=True,
1758+
acknowledge_risk=True,
17661759
)
17671760

17681761
mock_generate.assert_called_once()
@@ -1813,6 +1806,7 @@ def test_registration_role_arn_passed_to_s3_registration(
18131806
use_service_linked_role=False,
18141807
registration_role_arn=custom_registration_role,
18151808
disable_hybrid_access_mode=True,
1809+
acknowledge_risk=True,
18161810
)
18171811

18181812
mock_register.assert_called_once()
@@ -1860,6 +1854,7 @@ def test_different_registration_role_arns_produce_different_policies(
18601854
use_service_linked_role=False,
18611855
registration_role_arn=first_role,
18621856
disable_hybrid_access_mode=True,
1857+
acknowledge_risk=True,
18631858
)
18641859
first_call_kwargs = mock_generate.call_args[1]
18651860
first_lf_role = first_call_kwargs["lake_formation_role_arn"]
@@ -1874,6 +1869,7 @@ def test_different_registration_role_arns_produce_different_policies(
18741869
use_service_linked_role=False,
18751870
registration_role_arn=second_role,
18761871
disable_hybrid_access_mode=True,
1872+
acknowledge_risk=True,
18771873
)
18781874
second_call_kwargs = mock_generate.call_args[1]
18791875
second_lf_role = second_call_kwargs["lake_formation_role_arn"]
@@ -1972,7 +1968,7 @@ def test_iceberg_strips_data_suffix_for_s3_registration(
19721968
mock_grant.return_value = True
19731969
mock_revoke.return_value = True
19741970

1975-
self.fg.enable_lake_formation(disable_hybrid_access_mode=True)
1971+
self.fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
19761972

19771973
# The registered S3 location should NOT end with /data
19781974
call_args = mock_register.call_args
@@ -1997,7 +1993,7 @@ def test_non_iceberg_keeps_full_s3_path(
19971993
mock_grant.return_value = True
19981994
mock_revoke.return_value = True
19991995

2000-
self.fg.enable_lake_formation(disable_hybrid_access_mode=True)
1996+
self.fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
20011997

20021998
call_args = mock_register.call_args
20031999
registered_location = call_args[0][0]
@@ -2036,7 +2032,7 @@ def test_raises_error_when_feature_group_arn_is_none(
20362032
mock_revoke.return_value = True
20372033

20382034
with pytest.raises(ValueError, match="Feature Group ARN is required"):
2039-
fg.enable_lake_formation(disable_hybrid_access_mode=True)
2035+
fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
20402036

20412037

20422038
class TestEnableLakeFormationHappyPath:
@@ -2071,7 +2067,7 @@ def test_returns_all_true_on_success(
20712067
mock_grant.return_value = True
20722068
mock_revoke.return_value = True
20732069

2074-
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=True)
2070+
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=True, acknowledge_risk=True)
20752071

20762072
assert result == {
20772073
"s3_location_registered": True,
@@ -2128,6 +2124,7 @@ def test_session_and_region_passed_to_enable_lake_formation(
21282124
use_service_linked_role=True,
21292125
registration_role_arn=None,
21302126
disable_hybrid_access_mode=False,
2127+
acknowledge_risk=None,
21312128
)
21322129

21332130

0 commit comments

Comments
 (0)