diff --git a/CHANGES.md b/CHANGES.md index 36cbd0039e0..eefcc8cfa71 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -22,6 +22,7 @@ - Fix `string_processing` crashing on unassigned long string literals with trailing commas (one-item tuples) (#4929) - Simplify implementation of the power operator "hugging" logic (#4918) +- Always wrap `in` clauses of comprehensions if they have delimiters (#4881) ### Configuration diff --git a/src/black/brackets.py b/src/black/brackets.py index 44a3c9a2946..f029dec7c7e 100644 --- a/src/black/brackets.py +++ b/src/black/brackets.py @@ -340,8 +340,14 @@ def max_delimiter_priority_in_atom(node: LN) -> Priority: if not (first.type == token.LPAR and last.type == token.RPAR): return 0 + return max_delimiter_priority(node.children[1:-1]) + + +def max_delimiter_priority(children: list[LN]) -> Priority: + """Return maximum delimiter priority in a list of leaves and nodes.""" + bt = BracketTracker() - for c in node.children[1:-1]: + for c in children: if isinstance(c, Leaf): bt.mark(c) else: diff --git a/src/black/linegen.py b/src/black/linegen.py index f907070fe25..6cb68c863e4 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -15,6 +15,7 @@ DOT_PRIORITY, STRING_PRIORITY, get_leaves_inside_matching_brackets, + max_delimiter_priority, max_delimiter_priority_in_atom, ) from black.comments import ( @@ -639,9 +640,43 @@ def visit_tstring(self, node: Node) -> Iterator[Line]: def visit_comp_for(self, node: Node) -> Iterator[Line]: if Preview.wrap_comprehension_in in self.mode: - normalize_invisible_parens( - node, parens_after={"in"}, mode=self.mode, features=self.features - ) + check_lpar = False + for child in node.children: + if not check_lpar: + check_lpar = isinstance(child, Leaf) and child.value == "in" + continue + + check_lpar = isinstance(child, Leaf) and child.value == "in" + + is_wrapped = ( + len(child.children) == 3 + and is_lpar_token(child.children[0]) + and is_rpar_token(child.children[-1]) + ) + max_delimiter = max_delimiter_priority( + (child.children[1] if is_wrapped else child).children + ) + + if child.type == syms.atom: + if is_wrapped and child.children[1].type == syms.test: + continue + if max_delimiter > DOT_PRIORITY: + if not is_wrapped and maybe_make_parens_invisible_in_atom( + child, + parent=node, + mode=self.mode, + features=self.features, + ): + wrap_in_parentheses(node, child, visible=True) + elif maybe_make_parens_invisible_in_atom( + child, parent=node, mode=self.mode, features=self.features + ): + wrap_in_parentheses(node, child, visible=False) + else: + wrap_in_parentheses( + node, child, visible=max_delimiter > DOT_PRIORITY + ) + yield from self.visit_default(node) def visit_old_comp_for(self, node: Node) -> Iterator[Line]: diff --git a/tests/data/cases/preview_wrap_comprehension_in.py b/tests/data/cases/preview_wrap_comprehension_in.py index e457f0e772f..2d51a269f3e 100644 --- a/tests/data/cases/preview_wrap_comprehension_in.py +++ b/tests/data/cases/preview_wrap_comprehension_in.py @@ -70,6 +70,14 @@ ) } +# Short `in`s with delimiters +filtered_trait_instructions: list[str] = [ + trait for trait in (traits_instructions or []) if trait +] +filtered_trait_instructions: list[str] = [ + trait for trait in traits_instructions or [] if trait +] + # output [ a @@ -159,3 +167,11 @@ key_with_super_really_long_name: key_with_super_really_long_name for key in dictionary } + +# Short `in`s with delimiters +filtered_trait_instructions: list[str] = [ + trait for trait in (traits_instructions or []) if trait +] +filtered_trait_instructions: list[str] = [ + trait for trait in (traits_instructions or []) if trait +] diff --git a/tests/optional.py b/tests/optional.py index f9bceb6f9ff..84960481cbe 100644 --- a/tests/optional.py +++ b/tests/optional.py @@ -91,7 +91,7 @@ def pytest_configure(config: "Config") -> None: passed_args = config.getoption("run_optional") if passed_args: ot_run.update(itertools.chain.from_iterable(a.split(",") for a in passed_args)) - ot_run |= {no(excluded) for excluded in ot_markers - ot_run} + ot_run |= {no(excluded) for excluded in (ot_markers - ot_run)} ot_markers |= {no(m) for m in ot_markers} log.info("optional tests to run: %s", ot_run) diff --git a/tests/test_docs.py b/tests/test_docs.py index e09124acb2e..271905260c9 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -64,7 +64,7 @@ def test_feature_lists_are_up_to_date() -> None: future_style = f.readlines() preview_error = check_feature_list( future_style, - {feature.name for feature in set(Preview) - UNSTABLE_FEATURES}, + {feature.name for feature in (set(Preview) - UNSTABLE_FEATURES)}, "preview", ) assert preview_error is None, preview_error