Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/psyclone/psyir/tools/definition_use_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Author: A. B. G. Chalk, STFC Daresbury Lab
# Minor contributions: M. Schreiber, Univ. Grenoble Alpes
# -----------------------------------------------------------------------------
"""This module contains the DefinitionUseChain class"""

Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(
control_flow_region: Iterable[Node] = (),
start_point: Optional[int] = None,
stop_point: Optional[int] = None,
stop_at_call: bool = True
):
if not isinstance(reference, Reference):
raise TypeError(
Expand All @@ -114,6 +116,7 @@ def __init__(
)
self._start_point = start_point
self._stop_point = stop_point
self._stop_at_call = stop_at_call
if not control_flow_region:
self._scope = [reference.ancestor(Routine)]
if self._scope[0] is None:
Expand Down Expand Up @@ -257,6 +260,7 @@ def find_forward_accesses(self) -> list[Node]:
body,
start_point=ancestor.abs_position,
stop_point=sub_stop_point,
stop_at_call=self._stop_at_call
)
chains.insert(0, chain)
# If its a while loop, create a basic block for the while
Expand All @@ -269,6 +273,7 @@ def find_forward_accesses(self) -> list[Node]:
[ancestor.condition],
start_point=ancestor.abs_position,
stop_point=sub_stop_point,
stop_at_call=self._stop_at_call
)
chains.insert(0, chain)
ancestor = ancestor.ancestor((Loop, WhileLoop))
Expand All @@ -291,6 +296,7 @@ def find_forward_accesses(self) -> list[Node]:
[ancestor.lhs],
start_point=ancestor.lhs.abs_position - 1,
stop_point=ancestor.lhs.abs_position + 1,
stop_at_call=self._stop_at_call
)
control_flow_nodes.append(None)
chains.append(chain)
Expand All @@ -310,6 +316,7 @@ def find_forward_accesses(self) -> list[Node]:
block,
start_point=self._start_point,
stop_point=self._stop_point,
stop_at_call=self._stop_at_call
)
chains.append(chain)
for i, chain in enumerate(chains):
Expand Down Expand Up @@ -393,6 +400,7 @@ def find_forward_accesses(self) -> list[Node]:
[ancestor.lhs],
start_point=ancestor.lhs.abs_position - 1,
stop_point=ancestor.lhs.abs_position + 1,
stop_at_call=self._stop_at_call
)
# Find any forward_accesses in the lhs.
chain.find_forward_accesses()
Expand Down Expand Up @@ -507,11 +515,17 @@ def _compute_forward_uses(self, basic_block_list: list[Node]):
# catch the arguments that are passed into the call
# later as References.
continue
# For now just assume calls are bad if we have a non-local
# variable and we treat them as though they were a write.
if defs_out is not None:
self._killed.append(defs_out)
defs_out = reference

if self._stop_at_call:
# For now just assume calls are bad if we have a
# non-local variable and we treat them as though
# they were a write.
if defs_out is not None:
self._killed.append(defs_out)

defs_out = reference
else:
self._uses.append(reference)
continue
elif reference.get_signature_and_indices()[0] == sig:
# Work out if its read only or not.
Expand Down Expand Up @@ -912,6 +926,7 @@ def find_backward_accesses(self) -> list[Node]:
block,
start_point=self._start_point,
stop_point=self._stop_point,
stop_at_call=self._stop_at_call
)
chains.append(chain)
# If this is the top level access, we need to check if the
Expand Down Expand Up @@ -951,6 +966,7 @@ def find_backward_accesses(self) -> list[Node]:
body,
start_point=sub_start_point,
stop_point=sub_stop_point,
stop_at_call=self._stop_at_call
)
chains.append(chain)
control_flow_nodes.append(ancestor)
Expand All @@ -964,6 +980,7 @@ def find_backward_accesses(self) -> list[Node]:
[ancestor.condition],
start_point=ancestor.abs_position,
stop_point=sub_stop_point,
stop_at_call=self._stop_at_call
)
chains.append(chain)
ancestor = ancestor.ancestor((Loop, WhileLoop))
Expand All @@ -982,6 +999,7 @@ def find_backward_accesses(self) -> list[Node]:
ancestor.rhs.children[:],
start_point=ancestor.rhs.abs_position,
stop_point=end.abs_position,
stop_at_call=self._stop_at_call
)
control_flow_nodes.append(None)
chains.append(chain)
Expand Down Expand Up @@ -1071,6 +1089,7 @@ def find_backward_accesses(self) -> list[Node]:
[ancestor.rhs],
start_point=ancestor.rhs.abs_position,
stop_point=sys.maxsize,
stop_at_call=self._stop_at_call
)
# Find any backward_accesses in the rhs.
chain.find_backward_accesses()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,42 @@ def test_definition_use_chain_find_backward_accesses_pure_call(
assert reaches[0] is routine.walk(Call)[0].children[1]


def test_definition_use_chain_find_backward_accesses_continue_at_call(
fortran_reader,
):
"""Functionality test for the find_backward_accesses routine. This
tests the behaviour for a pure subrotuine call."""
code = """
subroutine x(a, b)
integer, intent(inout) :: a, b, c
a = 2
b = 1
call foo(b)
a = a + 2
b = 3 + a
call bar(b)
c = 2 + a
b = 1 + a ! Stops dependency
end subroutine"""
psyir = fortran_reader.psyir_from_source(code)
routine = psyir.walk(Routine)[0]

# Start from last assignment 'a'
ref_a = routine.walk(Assignment)[-1].rhs.children[1]
chains = DefinitionUseChain(ref_a)
reaches = chains.find_backward_accesses()

assert len(reaches) == 2
assert isinstance(reaches[1], Call)

ref_a = routine.walk(Assignment)[0].lhs
chains = DefinitionUseChain(ref_a, stop_at_call=False)
reaches = chains.find_forward_accesses()

assert len(reaches) == 4
assert isinstance(reaches[2], Call)


def test_definition_use_chain_find_backward_accesses_ancestor_call(
fortran_reader,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Author: A. B. G. Chalk, STFC Daresbury Lab
# Minor contributions: M. Schreiber, Univ. Grenoble Alpes
# -----------------------------------------------------------------------------
'''This module contains the tests for the DefinitionUseChain class's
forward_accesses routine.'''
Expand Down Expand Up @@ -1161,6 +1162,41 @@ def test_definition_use_chain_find_forward_accesses_pure_call(
assert reaches[0] is argument


def test_definition_use_chain_find_forward_accesses_continue_at_call(
fortran_reader,
):
"""Functionality test for the find_forward_accesses routine. This
tests the behaviour for a pure subrotuine call."""
code = """
subroutine x(a, b)
integer, intent(inout) :: a, b
a = 2
b = 1
b = 1 + a ! Stops dependency
call foo(b)
a = a + 2
call bar(b)
a = a + 3
end subroutine"""
psyir = fortran_reader.psyir_from_source(code)
routine = psyir.walk(Routine)[0]

# Start from 'a'
ref_a = routine.walk(Assignment)[0].lhs
chains = DefinitionUseChain(ref_a)
reaches = chains.find_forward_accesses()

assert len(reaches) == 2
assert isinstance(reaches[1], Call)

ref_a = routine.walk(Assignment)[0].lhs
chains = DefinitionUseChain(ref_a, stop_at_call=False)
reaches = chains.find_forward_accesses()

assert len(reaches) == 5
assert isinstance(reaches[1], Call)


def test_forward_accesses_nested_loop(fortran_reader):
"""Test that if we have many nested loops we don't repeat the same
reference in the result."""
Expand Down
Loading