diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index d74f782103..6bc4792acf 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -32,6 +32,7 @@ # POSSIBILITY OF SUCH DAMAGE. # ----------------------------------------------------------------------------- # Author: A. B. G. Chalk, STFC Daresbury Lab +# Minor contributions: M. Schreiber, Univ. Grenoble Alpes # ----------------------------------------------------------------------------- """This module contains the DefinitionUseChain class""" @@ -88,6 +89,7 @@ def __init__( control_flow_region: Iterable[Node] = (), start_point: Optional[int] = None, stop_point: Optional[int] = None, + stop_at_call: bool = True ): if not isinstance(reference, Reference): raise TypeError( @@ -114,6 +116,7 @@ def __init__( ) self._start_point = start_point self._stop_point = stop_point + self._stop_at_call = stop_at_call if not control_flow_region: self._scope = [reference.ancestor(Routine)] if self._scope[0] is None: @@ -257,6 +260,7 @@ def find_forward_accesses(self) -> list[Node]: body, start_point=ancestor.abs_position, stop_point=sub_stop_point, + stop_at_call=self._stop_at_call ) chains.insert(0, chain) # If its a while loop, create a basic block for the while @@ -269,6 +273,7 @@ def find_forward_accesses(self) -> list[Node]: [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, + stop_at_call=self._stop_at_call ) chains.insert(0, chain) ancestor = ancestor.ancestor((Loop, WhileLoop)) @@ -291,6 +296,7 @@ def find_forward_accesses(self) -> list[Node]: [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, + stop_at_call=self._stop_at_call ) control_flow_nodes.append(None) chains.append(chain) @@ -310,6 +316,7 @@ def find_forward_accesses(self) -> list[Node]: block, start_point=self._start_point, stop_point=self._stop_point, + stop_at_call=self._stop_at_call ) chains.append(chain) for i, chain in enumerate(chains): @@ -393,6 +400,7 @@ def find_forward_accesses(self) -> list[Node]: [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, + stop_at_call=self._stop_at_call ) # Find any forward_accesses in the lhs. chain.find_forward_accesses() @@ -507,11 +515,17 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): # catch the arguments that are passed into the call # later as References. continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + + if self._stop_at_call: + # For now just assume calls are bad if we have a + # non-local variable and we treat them as though + # they were a write. + if defs_out is not None: + self._killed.append(defs_out) + + defs_out = reference + else: + self._uses.append(reference) continue elif reference.get_signature_and_indices()[0] == sig: # Work out if its read only or not. @@ -912,6 +926,7 @@ def find_backward_accesses(self) -> list[Node]: block, start_point=self._start_point, stop_point=self._stop_point, + stop_at_call=self._stop_at_call ) chains.append(chain) # If this is the top level access, we need to check if the @@ -951,6 +966,7 @@ def find_backward_accesses(self) -> list[Node]: body, start_point=sub_start_point, stop_point=sub_stop_point, + stop_at_call=self._stop_at_call ) chains.append(chain) control_flow_nodes.append(ancestor) @@ -964,6 +980,7 @@ def find_backward_accesses(self) -> list[Node]: [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, + stop_at_call=self._stop_at_call ) chains.append(chain) ancestor = ancestor.ancestor((Loop, WhileLoop)) @@ -982,6 +999,7 @@ def find_backward_accesses(self) -> list[Node]: ancestor.rhs.children[:], start_point=ancestor.rhs.abs_position, stop_point=end.abs_position, + stop_at_call=self._stop_at_call ) control_flow_nodes.append(None) chains.append(chain) @@ -1071,6 +1089,7 @@ def find_backward_accesses(self) -> list[Node]: [ancestor.rhs], start_point=ancestor.rhs.abs_position, stop_point=sys.maxsize, + stop_at_call=self._stop_at_call ) # Find any backward_accesses in the rhs. chain.find_backward_accesses() diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py index d7940c0df5..90321070f8 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py @@ -822,6 +822,42 @@ def test_definition_use_chain_find_backward_accesses_pure_call( assert reaches[0] is routine.walk(Call)[0].children[1] +def test_definition_use_chain_find_backward_accesses_continue_at_call( + fortran_reader, +): + """Functionality test for the find_backward_accesses routine. This + tests the behaviour for a pure subrotuine call.""" + code = """ + subroutine x(a, b) + integer, intent(inout) :: a, b, c + a = 2 + b = 1 + call foo(b) + a = a + 2 + b = 3 + a + call bar(b) + c = 2 + a + b = 1 + a ! Stops dependency + end subroutine""" + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Routine)[0] + + # Start from last assignment 'a' + ref_a = routine.walk(Assignment)[-1].rhs.children[1] + chains = DefinitionUseChain(ref_a) + reaches = chains.find_backward_accesses() + + assert len(reaches) == 2 + assert isinstance(reaches[1], Call) + + ref_a = routine.walk(Assignment)[0].lhs + chains = DefinitionUseChain(ref_a, stop_at_call=False) + reaches = chains.find_forward_accesses() + + assert len(reaches) == 4 + assert isinstance(reaches[2], Call) + + def test_definition_use_chain_find_backward_accesses_ancestor_call( fortran_reader, ): diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index d4d1d6f3c8..68a368f244 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -32,6 +32,7 @@ # POSSIBILITY OF SUCH DAMAGE. # ----------------------------------------------------------------------------- # Author: A. B. G. Chalk, STFC Daresbury Lab +# Minor contributions: M. Schreiber, Univ. Grenoble Alpes # ----------------------------------------------------------------------------- '''This module contains the tests for the DefinitionUseChain class's forward_accesses routine.''' @@ -1161,6 +1162,41 @@ def test_definition_use_chain_find_forward_accesses_pure_call( assert reaches[0] is argument +def test_definition_use_chain_find_forward_accesses_continue_at_call( + fortran_reader, +): + """Functionality test for the find_forward_accesses routine. This + tests the behaviour for a pure subrotuine call.""" + code = """ + subroutine x(a, b) + integer, intent(inout) :: a, b + a = 2 + b = 1 + b = 1 + a ! Stops dependency + call foo(b) + a = a + 2 + call bar(b) + a = a + 3 + end subroutine""" + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Routine)[0] + + # Start from 'a' + ref_a = routine.walk(Assignment)[0].lhs + chains = DefinitionUseChain(ref_a) + reaches = chains.find_forward_accesses() + + assert len(reaches) == 2 + assert isinstance(reaches[1], Call) + + ref_a = routine.walk(Assignment)[0].lhs + chains = DefinitionUseChain(ref_a, stop_at_call=False) + reaches = chains.find_forward_accesses() + + assert len(reaches) == 5 + assert isinstance(reaches[1], Call) + + def test_forward_accesses_nested_loop(fortran_reader): """Test that if we have many nested loops we don't repeat the same reference in the result."""