diff --git a/CHANGES.md b/CHANGES.md index 303d6849bcb..d833690bcb8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -51,6 +51,8 @@ +- Format files from largest to smallest to improve benefits from concurrency (#4784) + ### Output diff --git a/src/black/__init__.py b/src/black/__init__.py index 079e95cf386..39c2562f523 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -905,7 +905,7 @@ def reformat_one( WriteBack.DIFF, WriteBack.COLOR_DIFF, ): - if not cache.is_changed(src): + if not cache.is_changed(src)[0]: changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( src, fast=fast, write_back=write_back, mode=mode, lines=lines diff --git a/src/black/cache.py b/src/black/cache.py index ef9d99a7b90..7ee06e9b441 100644 --- a/src/black/cache.py +++ b/src/black/cache.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from dataclasses import dataclass, field from pathlib import Path -from typing import NamedTuple +from typing import NamedTuple, Optional, Union from platformdirs import user_cache_dir @@ -99,33 +99,38 @@ def get_file_data(path: Path) -> FileData: hash = Cache.hash_digest(path) return FileData(stat.st_mtime, stat.st_size, hash) - def is_changed(self, source: Path) -> bool: - """Check if source has changed compared to cached version.""" + def is_changed(self, source: Path) -> tuple[bool, Optional[os.stat_result]]: + """Check if source has changed compared to cached version. 
+ + Also returns the stat result that was used.""" res_src = source.resolve() old = self.file_data.get(str(res_src)) if old is None: - return True + return True, None st = res_src.stat() if st.st_size != old.st_size: - return True + return True, st if st.st_mtime != old.st_mtime: new_hash = Cache.hash_digest(res_src) if new_hash != old.hash: - return True - return False + return True, st + return False, st - def filtered_cached(self, sources: Iterable[Path]) -> tuple[set[Path], set[Path]]: + def filtered_cached( + self, sources: Iterable[Path] + ) -> tuple[set[tuple[Optional[os.stat_result], Path]], set[Path]]: """Split an iterable of paths in `sources` into two sets. - The first contains paths of files that modified on disk or are not in the - cache. The other contains paths to non-modified files. + The first contains paths and stat results of files that were modified on disk + or are not in the cache. The other contains paths to non-modified files. """ - changed: set[Path] = set() + changed: set[tuple[Optional[os.stat_result], Path]] = set() done: set[Path] = set() for src in sources: - if self.is_changed(src): - changed.add(src) + is_changed, stat = self.is_changed(src) + if is_changed: + changed.add((stat, src)) else: done.add(src) return changed, done diff --git a/src/black/concurrency.py b/src/black/concurrency.py index 53a61456b63..70e250c0d5b 100644 --- a/src/black/concurrency.py +++ b/src/black/concurrency.py @@ -146,14 +146,19 @@ async def schedule_formatting( :func:`format_file_in_place`. 
""" cache = None if no_cache else Cache.read(mode) - if cache is not None and write_back not in ( + if cache is None or write_back in ( WriteBack.DIFF, WriteBack.COLOR_DIFF, ): - sources, cached = cache.filtered_cached(sources) + sources_with_stats: Iterable[tuple[Optional[os.stat_result], Path]] = [ + (None, src) for src in sources + ] + else: + sources_with_stats, cached = cache.filtered_cached(sources) for src in sorted(cached): report.done(src, Changed.CACHED) - if not sources: + + if not sources_with_stats: return cancelled = [] @@ -170,7 +175,7 @@ async def schedule_formatting( executor, format_file_in_place, src, fast, mode, write_back, lock ) ): src - for src in sorted(sources) + for _, src in sorted(sources_with_stats, key=_sources_sort_key) } pending = tasks.keys() try: @@ -202,3 +207,14 @@ async def schedule_formatting( await asyncio.gather(*cancelled, return_exceptions=True) if sources_to_cache and not no_cache and cache is not None: cache.write(sources_to_cache) + + +def _sources_sort_key( + source: tuple[Optional[os.stat_result], Path], +) -> tuple[int, Path]: + stat, src = source + + if stat is None: + stat = src.stat() + + return -stat.st_size, src diff --git a/tests/test_black.py b/tests/test_black.py index 291dc01421e..f3ca9a6e721 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -1113,9 +1113,9 @@ def test_single_file_force_pyi(self) -> None: actual = path.read_text(encoding="utf-8") # verify cache with --pyi is separate pyi_cache = black.Cache.read(pyi_mode) - assert not pyi_cache.is_changed(path) + assert not pyi_cache.is_changed(path)[0] normal_cache = black.Cache.read(DEFAULT_MODE) - assert normal_cache.is_changed(path) + assert normal_cache.is_changed(path)[0] self.assertFormatEqual(expected, actual) black.assert_equivalent(contents, actual) black.assert_stable(contents, actual, pyi_mode) @@ -1140,8 +1140,8 @@ def test_multi_file_force_pyi(self) -> None: pyi_cache = black.Cache.read(pyi_mode) normal_cache = black.Cache.read(reg_mode) 
for path in paths: - assert not pyi_cache.is_changed(path) - assert normal_cache.is_changed(path) + assert not pyi_cache.is_changed(path)[0] + assert normal_cache.is_changed(path)[0] def test_pipe_force_pyi(self) -> None: source, expected = read_data("miscellaneous", "force_pyi") @@ -1163,9 +1163,9 @@ def test_single_file_force_py36(self) -> None: actual = path.read_text(encoding="utf-8") # verify cache with --target-version is separate py36_cache = black.Cache.read(py36_mode) - assert not py36_cache.is_changed(path) + assert not py36_cache.is_changed(path)[0] normal_cache = black.Cache.read(reg_mode) - assert normal_cache.is_changed(path) + assert normal_cache.is_changed(path)[0] self.assertEqual(actual, expected) @event_loop() @@ -1188,8 +1188,8 @@ def test_multi_file_force_py36(self) -> None: pyi_cache = black.Cache.read(py36_mode) normal_cache = black.Cache.read(reg_mode) for path in paths: - assert not pyi_cache.is_changed(path) - assert normal_cache.is_changed(path) + assert not pyi_cache.is_changed(path)[0] + assert normal_cache.is_changed(path)[0] def test_pipe_force_py36(self) -> None: source, expected = read_data("miscellaneous", "force_py36") @@ -2153,7 +2153,7 @@ def test_cache_broken_file(self) -> None: src.write_text("print('hello')", encoding="utf-8") invokeBlack([str(src)]) cache = black.Cache.read(mode) - assert not cache.is_changed(src) + assert not cache.is_changed(src)[0] def test_cache_single_file_already_cached(self) -> None: mode = DEFAULT_MODE @@ -2182,8 +2182,8 @@ def test_cache_multiple_files(self) -> None: assert one.read_text(encoding="utf-8") == "print('hello')" assert two.read_text(encoding="utf-8") == 'print("hello")\n' cache = black.Cache.read(mode) - assert not cache.is_changed(one) - assert not cache.is_changed(two) + assert not cache.is_changed(one)[0] + assert not cache.is_changed(two)[0] @pytest.mark.incompatible_with_mypyc @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"]) @@ -2293,7 +2293,7 @@ def 
test_write_cache_read_cache(self) -> None: write_cache = black.Cache.read(mode) write_cache.write([src]) read_cache = black.Cache.read(mode) - assert not read_cache.is_changed(src) + assert not read_cache.is_changed(src)[0] @pytest.mark.incompatible_with_mypyc def test_filter_cached(self) -> None: @@ -2381,8 +2381,8 @@ def test_failed_formatting_does_not_get_cached(self) -> None: clean.write_text('print("hello")\n', encoding="utf-8") invokeBlack([str(workspace)], exit_code=123) cache = black.Cache.read(mode) - assert cache.is_changed(failing) - assert not cache.is_changed(clean) + assert cache.is_changed(failing)[0] + assert not cache.is_changed(clean)[0] def test_write_cache_write_fail(self) -> None: mode = DEFAULT_MODE @@ -2401,9 +2401,9 @@ def test_read_cache_line_lengths(self) -> None: cache = black.Cache.read(mode) cache.write([path]) one = black.Cache.read(mode) - assert not one.is_changed(path) + assert not one.is_changed(path)[0] two = black.Cache.read(short_mode) - assert two.is_changed(path) + assert two.is_changed(path)[0] def test_cache_key(self) -> None: # Test that all members of the mode enum affect the cache key.