diff --git a/reactivex/operators/_retry.py b/reactivex/operators/_retry.py index 10920687..c86502eb 100644 --- a/reactivex/operators/_retry.py +++ b/reactivex/operators/_retry.py @@ -1,7 +1,7 @@ from typing import TypeVar import reactivex -from reactivex import Observable +from reactivex import Observable, abc from reactivex.internal import curry_flip from reactivex.internal.utils import infinite @@ -17,6 +17,10 @@ def retry_( times or until it successfully terminates. If the retry count is not specified, it retries indefinitely. + The retry budget is per-subscription, so combining ``retry(n)`` with + ``repeat()`` works as expected: each resubscription by ``repeat()`` + starts with a fresh retry allowance. + Examples: >>> result = source.pipe(retry()) >>> result = retry()(source) @@ -32,12 +36,21 @@ def retry_( sequence repeatedly until it terminates successfully. """ - if retry_count is None: - gen = infinite() - else: - gen = range(retry_count) - - return reactivex.catch_with_iterable(source for _ in gen) + def subscribe( + observer: abc.ObserverBase[_T], scheduler_: abc.SchedulerBase | None = None + ) -> abc.DisposableBase: + # Create a fresh generator on every subscription so that the retry + # budget is not shared across resubscriptions (e.g. via repeat()). + if retry_count is None: + gen = infinite() + else: + gen = range(retry_count) + + return reactivex.catch_with_iterable(source for _ in gen).subscribe( + observer, scheduler=scheduler_ + ) + + return Observable(subscribe) __all__ = ["retry_"] diff --git a/tests/test_observable/test_retry.py b/tests/test_observable/test_retry.py index b9f4bd7d..9f9ac083 100644 --- a/tests/test_observable/test_retry.py +++ b/tests/test_observable/test_retry.py @@ -193,6 +193,31 @@ def dispose(_, __): with pytest.raises(Exception): xss.subscribe() + def test_retry_with_count_combined_with_repeat(self): + """retry(n) should reset its budget per subscription so repeat() works correctly. + + Regression test for https://github.com/ReactiveX/RxPY/issues/712. + """ + scheduler = TestScheduler() + xs = scheduler.create_cold_observable(on_next(90, 42), on_completed(200)) + + result = scheduler.start( + lambda: xs.pipe(ops.retry(2), ops.repeat()), + disposed=1000, + ) + assert result.messages == [ + on_next(290, 42), + on_next(490, 42), + on_next(690, 42), + on_next(890, 42), + ] + assert xs.subscriptions == [ + subscribe(200, 400), + subscribe(400, 600), + subscribe(600, 800), + subscribe(800, 1000), + ] + if __name__ == "__main__": unittest.main()