diff --git a/CHANGES.md b/CHANGES.md index a8932714a92..4e0458c03e7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -25,6 +25,8 @@ - Fix `string_processing` crashing on unassigned long string literals with trailing commas (one-item tuples) (#4929) - Simplify implementation of the power operator "hugging" logic (#4918) +- Parenthesize complex expressions passed as keyword arguments or parameter defaults + (#4925) ### Configuration diff --git a/docs/the_black_code_style/future_style.md b/docs/the_black_code_style/future_style.md index 16a55e4e03e..31f4a72d6cf 100644 --- a/docs/the_black_code_style/future_style.md +++ b/docs/the_black_code_style/future_style.md @@ -20,6 +20,8 @@ Currently, the following features are included in the preview style: "hugging" logic (removing whitespace around `**` in simple expressions), which applies also in the rare case the exponentiation is split into separate lines. ([see below](labels/simplify-power-operator)) +- `arg_parens`: Parenthesize complex expressions passed as keyword arguments or + parameter defaults. For example, `foo(bar=x + y)` becomes `foo(bar=(x + y))`. - `wrap_long_dict_values_in_parens`: Add parentheses around long values in dictionaries. ([see below](labels/wrap-long-dict-values)) - `fix_if_guard_explosion_in_case_statement`: fixed exploding of the if guard in case diff --git a/scripts/fuzz.py b/scripts/fuzz.py index 915a036b4ae..c770da3abec 100644 --- a/scripts/fuzz.py +++ b/scripts/fuzz.py @@ -24,11 +24,11 @@ # Note that while Hypothesmith might generate code unlike that written by # humans, it's a general test that should pass for any *valid* source code. # (so e.g. running it against code scraped of the internet might also help) - src_contents=hypothesmith.from_grammar() | hypothesmith.from_node(), + src_contents=(hypothesmith.from_grammar() | hypothesmith.from_node()), # Using randomly-varied modes helps us to exercise less common code paths. mode=st.builds( black.FileMode, - line_length=st.just(88) | st.integers(0, 200), + line_length=(st.just(88) | st.integers(0, 200)), string_normalization=st.booleans(), preview=st.booleans(), is_pyi=st.booleans(), diff --git a/src/black/__init__.py b/src/black/__init__.py index 6eece7eada2..970aa0e6934 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -634,8 +634,8 @@ def main( is_pyi=pyi, is_ipynb=ipynb, skip_source_first_line=skip_source_first_line, - string_normalization=not skip_string_normalization, - magic_trailing_comma=not skip_magic_trailing_comma, + string_normalization=(not skip_string_normalization), + magic_trailing_comma=(not skip_magic_trailing_comma), preview=preview, unstable=unstable, python_cell_magics=set(python_cell_magics), @@ -787,7 +787,7 @@ def get_sources( path = Path(f"{STDIN_PLACEHOLDER}{path}") if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed( - warn=verbose or not quiet + warn=(verbose or not quiet) ): continue diff --git a/src/black/brackets.py b/src/black/brackets.py index 44a3c9a2946..f76bedda803 100644 --- a/src/black/brackets.py +++ b/src/black/brackets.py @@ -238,7 +238,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf | None = None) -> Prior Higher numbers are higher priority. """ - if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS): + if is_vararg(leaf, within=(VARARGS_PARENTS | UNPACKING_PARENTS)): # * and ** might also be MATH_OPERATORS but in this case they are not. # Don't treat them as a delimiter. return 0 diff --git a/src/black/comments.py b/src/black/comments.py index b3dda1d2a33..94cd32cf8e6 100644 --- a/src/black/comments.py +++ b/src/black/comments.py @@ -77,7 +77,7 @@ def generate_comments(leaf: LN, mode: Mode) -> Iterator[Leaf]: """ total_consumed = 0 for pc in list_comments( - leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER, mode=mode + leaf.prefix, is_endmarker=(leaf.type == token.ENDMARKER), mode=mode ): total_consumed = pc.consumed prefix = make_simple_prefix(pc.newlines, pc.form_feed) diff --git a/src/black/files.py b/src/black/files.py index 77d1a491693..95c3038430f 100644 --- a/src/black/files.py +++ b/src/black/files.py @@ -398,7 +398,7 @@ def gen_python_files( elif child.is_file(): if child.suffix == ".ipynb" and not jupyter_dependencies_are_installed( - warn=verbose or not quiet + warn=(verbose or not quiet) ): continue include_match = include.search(root_relative_path) if include else True diff --git a/src/black/handle_ipynb_magics.py b/src/black/handle_ipynb_magics.py index c84fe6219fb..5307d2620b9 100644 --- a/src/black/handle_ipynb_magics.py +++ b/src/black/handle_ipynb_magics.py @@ -133,7 +133,7 @@ def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str: for idx, token in reversed_enumerate(tokens): if token.name in TOKENS_TO_IGNORE: continue - tokens[idx] = token._replace(src=token.src + ";") + tokens[idx] = token._replace(src=(token.src + ";")) break else: # pragma: nocover raise AssertionError( diff --git a/src/black/linegen.py b/src/black/linegen.py index 06899029469..65087da7520 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -41,6 +41,8 @@ OPENING_BRACKETS, STANDALONE_COMMENT, STATEMENT, + TEST_DESCENDANTS, + TYPED_NAMES, WHITESPACE, Visitor, ensure_visible, @@ -52,6 +54,7 @@ is_atom_with_invisible_parens, is_docstring, is_empty_tuple, + is_exponentiation, is_generator, is_lpar_token, is_multiline_string, @@ -61,6 +64,7 @@ is_parent_function_or_class, is_part_of_annotation, is_rpar_token, + is_simple_exponentiation, is_stub_body, is_stub_suite, is_tuple, @@ -139,7 +143,7 @@ def line(self, indent: int = 0) -> Iterator[Line]: return complete_line = self.current_line - self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) + self.current_line = Line(mode=self.mode, depth=(complete_line.depth + indent)) yield complete_line def visit_default(self, node: LN) -> Iterator[Line]: @@ -446,6 +450,18 @@ def visit_factor(self, node: Node) -> Iterator[Line]: node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) yield from self.visit_default(node) + def visit_argument(self, node: Node) -> Iterator[Line]: + _maybe_wrap_complex_arg_expression_in_parens(node, self.mode) + yield from self.visit_default(node) + + def visit_typedargslist(self, node: Node) -> Iterator[Line]: + _maybe_wrap_complex_arg_expression_in_parens(node, self.mode) + yield from self.visit_default(node) + + def visit_varargslist(self, node: Node) -> Iterator[Line]: + _maybe_wrap_complex_arg_expression_in_parens(node, self.mode) + yield from self.visit_default(node) + def visit_tname(self, node: Node) -> Iterator[Line]: """ Add potential parentheses around types in function parameter lists to be made @@ -844,7 +860,7 @@ def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool: result.append( leaf, preformatted=True, - track_bracket=id(leaf) in leaves_to_track, + track_bracket=(id(leaf) in leaves_to_track), ) # we could also return true if the line is too long, and the return type is longer @@ -1152,7 +1168,7 @@ def _prefer_split_rhs_oop_over_rhs( # the left side of assignment is short enough (the -1 is for the ending optional # paren) if not is_line_short_enough( - rhs.head, mode=replace(mode, line_length=mode.line_length - 1) + rhs.head, mode=replace(mode, line_length=(mode.line_length - 1)) ): return True @@ -1290,7 +1306,7 @@ def bracket_split_build_line( result.append( leaf, preformatted=True, - track_bracket=id(leaf) in leaves_to_track, + track_bracket=(id(leaf) in leaves_to_track), ) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) @@ -1669,6 +1685,29 @@ def remove_await_parens(node: Node, mode: Mode, features: Collection[Feature]) - ensure_visible(closing_bracket) +def _maybe_wrap_complex_arg_expression_in_parens(node: Node, mode: Mode) -> None: + """Add parentheses around complex expression after equals sign + (unless it's a typed parameter): + + bar=x + y -> bar=(x + y) + """ + if Preview.arg_parens in mode: + for i, child in enumerate(node.children): + if child.type == token.EQUAL: + if node.children[i - 1].type not in TYPED_NAMES: + expr = node.children[i + 1] + if expr.type in TEST_DESCENDANTS and expr.type != syms.lambdef: + if ( + expr.type != syms.power + or expr.children[0].type == token.AWAIT + or ( + is_exponentiation(expr) + and not is_simple_exponentiation(expr) + ) + ): + wrap_in_parentheses(node, expr) + + def _maybe_wrap_cms_in_parens( node: Node, mode: Mode, features: Collection[Feature] ) -> None: diff --git a/src/black/mode.py b/src/black/mode.py index 8653a705495..84883bd6902 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -228,6 +228,7 @@ class Preview(Enum): hug_parens_with_braces_and_square_brackets = auto() wrap_comprehension_in = auto() simplify_power_operator_hugging = auto() + arg_parens = auto() wrap_long_dict_values_in_parens = auto() fix_if_guard_explosion_in_case_statement = auto() diff --git a/src/black/nodes.py b/src/black/nodes.py index 90487fb0f18..0f212037730 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -249,7 +249,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool, mode: Mode) -> str: return NO elif prevp.type in VARARGS_SPECIALS: - if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): + if is_vararg(prevp, within=(VARARGS_PARENTS | UNPACKING_PARENTS)): return NO elif prevp.type == token.COLON: @@ -563,11 +563,14 @@ def is_simple(node: LN) -> bool: else: return all(is_simple(child) for child in node.children) + return is_exponentiation(node) and is_simple(node) + + +def is_exponentiation(node: LN) -> bool: return ( node.type == syms.power and len(node.children) >= 3 and node.children[-2].type == token.DOUBLESTAR - and is_simple(node) ) diff --git a/src/black/ranges.py b/src/black/ranges.py index d7e003db83f..be31d742afd 100644 --- a/src/black/ranges.py +++ b/src/black/ranges.py @@ -499,9 +499,9 @@ def _calculate_lines_mappings( previous_block = matching_blocks[i - 1] lines_mappings.append( _LinesMapping( - original_start=previous_block.a + previous_block.size + 1, + original_start=(previous_block.a + previous_block.size + 1), original_end=block.a, - modified_start=previous_block.b + previous_block.size + 1, + modified_start=(previous_block.b + previous_block.size + 1), modified_end=block.b, is_changed_block=True, ) @@ -509,10 +509,10 @@ def _calculate_lines_mappings( if i < len(matching_blocks) - 1: lines_mappings.append( _LinesMapping( - original_start=block.a + 1, - original_end=block.a + block.size, - modified_start=block.b + 1, - modified_end=block.b + block.size, + original_start=(block.a + 1), + original_end=(block.a + block.size), + modified_start=(block.b + 1), + modified_end=(block.b + block.size), is_changed_block=False, ) ) diff --git a/src/black/resources/black.schema.json b/src/black/resources/black.schema.json index b2b1a03cbac..c35fee1fc58 100644 --- a/src/black/resources/black.schema.json +++ b/src/black/resources/black.schema.json @@ -84,6 +84,7 @@ "hug_parens_with_braces_and_square_brackets", "wrap_comprehension_in", "simplify_power_operator_hugging", + "arg_parens", "wrap_long_dict_values_in_parens", "fix_if_guard_explosion_in_case_statement" ] diff --git a/src/black/trans.py b/src/black/trans.py index 8d8ea2e15a3..8655ab9607d 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -2265,7 +2265,7 @@ def do_transform( string_value = LL[string_idx].value string_line = Line( mode=line.mode, - depth=line.depth + 1, + depth=(line.depth + 1), inside_brackets=True, should_split_rhs=line.should_split_rhs, magic_trailing_comma=line.magic_trailing_comma, diff --git a/src/blackd/__init__.py b/src/blackd/__init__.py index 432ca81a5ae..5710aaac457 100644 --- a/src/blackd/__init__.py +++ b/src/blackd/__init__.py @@ -203,8 +203,8 @@ def parse_mode(headers: MultiMapping[str]) -> black.Mode: is_pyi=pyi, line_length=line_length, skip_source_first_line=skip_source_first_line, - string_normalization=not skip_string_normalization, - magic_trailing_comma=not skip_magic_trailing_comma, + string_normalization=(not skip_string_normalization), + magic_trailing_comma=(not skip_magic_trailing_comma), preview=preview, unstable=unstable, enabled_features=enable_features, diff --git a/tests/data/cases/preview_arg_parens.py b/tests/data/cases/preview_arg_parens.py new file mode 100644 index 00000000000..c7f3bd85275 --- /dev/null +++ b/tests/data/cases/preview_arg_parens.py @@ -0,0 +1,54 @@ +# flags: --preview + +foo( + # with extra parens + bar=x if y else z, + bar=x or y, + bar=not x, + bar=x < y < z, + bar=x + y, + bar=f(x) ** y, + bar=await f(), + # without extra parens + bar=-x, + bar=x**y, + bar=x.y.z, + bar=f(x, y), + bar=(x, y), + bar=[x + y], + bar=f"{x + y}", + bar=lambda: x + y, +) + + +@foo(bar=x ** y, bar=x + y) +def foo(bar=x ** y, bar=x + y, bar: int=x + y): + return lambda bar=x ** y, bar=x + y: x + y + + +# output + +foo( + # with extra parens + bar=(x if y else z), + bar=(x or y), + bar=(not x), + bar=(x < y < z), + bar=(x + y), + bar=(f(x) ** y), + bar=(await f()), + # without extra parens + bar=-x, + bar=x**y, + bar=x.y.z, + bar=f(x, y), + bar=(x, y), + bar=[x + y], + bar=f"{x + y}", + bar=lambda: x + y, +) + + +@foo(bar=x**y, bar=(x + y)) +def foo(bar=x**y, bar=(x + y), bar: int = x + y): + return lambda bar=x**y, bar=(x + y): x + y diff --git a/tests/data/cases/preview_long_strings__regression.py b/tests/data/cases/preview_long_strings__regression.py index 123342f575c..d4836df3f88 100644 --- a/tests/data/cases/preview_long_strings__regression.py +++ b/tests/data/cases/preview_long_strings__regression.py @@ -671,7 +671,7 @@ def foo(): "{xxxx_xxx} >> {xxxxxx_xxxx}.xxxxxxx 2>&1; xx=$$?;" "xxxx $$xx".format( xxxx_xxx=xxxx_xxxxxxx, - xxxxxx_xxxx=xxxxxxx + "/" + xxxx_xxx_xxxx, + xxxxxx_xxxx=(xxxxxxx + "/" + xxxx_xxx_xxxx), x=xxx_xxxxx_xxxxx_xxx, ), x, diff --git a/tests/data/cases/preview_simplify_power_operator_hugging.py b/tests/data/cases/preview_simplify_power_operator_hugging.py index 1977396685c..98011322b26 100644 --- a/tests/data/cases/preview_simplify_power_operator_hugging.py +++ b/tests/data/cases/preview_simplify_power_operator_hugging.py @@ -1,5 +1,5 @@ # flags: --preview -# This is a copy of `power_op_spacing.py`. Remove when `simplify_power_operator_hugging` becomes stable. +# This is a copy of `power_op_spacing.py` for testing `simplify_power_operator_hugging`, with output adjusted for `arg_parens`. def function(**kwargs): t = a**2 + b**3 @@ -81,7 +81,7 @@ def function_dont_replace_spaces(): # output -# This is a copy of `power_op_spacing.py`. Remove when `simplify_power_operator_hugging` becomes stable. +# This is a copy of `power_op_spacing.py` for testing `simplify_power_operator_hugging`, with output adjusted for `arg_parens`. def function(**kwargs): @@ -145,9 +145,9 @@ def function_dont_replace_spaces(): view.variance, # type: ignore[union-attr] view.sum_of_weights, # type: ignore[union-attr] out=np.full(view.sum_of_weights.shape, np.nan), # type: ignore[union-attr] - where=view.sum_of_weights**2 > view.sum_of_weights_squared, # type: ignore[union-attr] + where=(view.sum_of_weights**2 > view.sum_of_weights_squared), # type: ignore[union-attr] ) return np.divide( - where=view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared, # type: ignore + where=(view.sum_of_weights_of_weight_long**2 > view.sum_of_weights_squared), # type: ignore ) diff --git a/tests/test_black.py b/tests/test_black.py index bdb41d6a551..68e5972b2bc 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -2479,7 +2479,7 @@ def assert_collected_sources( ) gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude) collected = black.get_sources( - root=root or THIS_DIR, + root=(root or THIS_DIR), src=gs_src, quiet=False, verbose=False, diff --git a/tests/test_blackd.py b/tests/test_blackd.py index a64715cd6d2..a3878f48c7b 100644 --- a/tests/test_blackd.py +++ b/tests/test_blackd.py @@ -131,7 +131,7 @@ async def check(header_value: str, expected_status: int) -> None: "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value} ) self.assertEqual( - response.status, expected_status, msg=await response.text() + response.status, expected_status, msg=(await response.text()) ) await check("3.6", 200) @@ -205,7 +205,7 @@ async def test_cors_headers_present(self) -> None: async def test_preserves_line_endings(self) -> None: for data in (b"c\r\nc\r\n", b"l\nl\n"): # test preserved newlines when reformatted - response = await self.client.post("/", data=data + b" ") + response = await self.client.post("/", data=(data + b" ")) self.assertEqual(await response.text(), data.decode()) # test 204 when no change response = await self.client.post("/", data=data) diff --git a/tests/util.py b/tests/util.py index 0acce4bed2b..ac114a86482 100644 --- a/tests/util.py +++ b/tests/util.py @@ -116,7 +116,7 @@ def assert_format( if mode.unstable: new_mode = replace(mode, unstable=False, preview=False) else: - new_mode = replace(mode, preview=not mode.preview) + new_mode = replace(mode, preview=(not mode.preview)) _assert_format_inner( source, None, @@ -283,10 +283,10 @@ def parse_mode(flags_line: str) -> TestCaseArgs: mode = black.Mode( target_versions=set(args.target_version), line_length=args.line_length, - string_normalization=not args.skip_string_normalization, + string_normalization=(not args.skip_string_normalization), is_pyi=args.pyi, is_ipynb=args.ipynb, - magic_trailing_comma=not args.skip_magic_trailing_comma, + magic_trailing_comma=(not args.skip_magic_trailing_comma), preview=args.preview, unstable=args.unstable, )