Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 79 additions & 125 deletions src/psyclone/psyir/backend/fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def __init__(self, **kwargs):
Fparser2Reader.binary_operators)

# Create and store a CallTreeUtils instance for use when ordering
# parameter declarations. Have to import it here as CallTreeUtils
# also uses this Fortran backend.
# declarations. Have to import it here as CallTreeUtils also uses
# this Fortran backend.
# pylint: disable=import-outside-toplevel
from psyclone.psyir.tools.call_tree_utils import CallTreeUtils
self._call_tree_utils = CallTreeUtils()
Expand Down Expand Up @@ -861,78 +861,17 @@ def gen_access_stmts(self, symbol_table):
return result
return ""

# pylint: disable=too-many-branches
def _gen_parameter_decls(self, symbol_table, is_module_scope=False):
''' Create the declarations of all parameters present in the supplied
symbol table. Declarations are ordered so as to satisfy any inter-
dependencies between them.

:param symbol_table: the SymbolTable instance.
:type symbol: :py:class:`psyclone.psyir.symbols.SymbolTable`
:param bool is_module_scope: whether or not the declarations are in
a module scoping unit. Default is False.

:returns: Fortran code declaring all parameters.
:rtype: str

:raises VisitorError: if there is no way of resolving
interdependencies between parameter declarations.

'''
declarations = ""
local_constants = []

# First add the local constants
for sym in symbol_table.datasymbols:
if sym.is_import or sym.is_unresolved:
continue # Skip, these don't need declarations
if sym.is_constant:
local_constants.append(sym)

# There may be dependencies between these constants so setup a dict
# holding a set of all their dependencies. The checks have to be done
# with case-insensitive name comparisons because the dependent symbols
# are not always created in the same scope.
local_lowered_names = [sym.name.lower() for sym in local_constants]
decln_inputs = {}
for symbol in local_constants:
dependencies = symbol.get_all_accessed_symbols()
dependencies = {sym for sym in dependencies
# Discard self-dependencies: e.g. "a :: HUGE(a)"
if sym.name.lower() != symbol.name.lower() and
# Discard dependencies that are not local
sym.name.lower() in local_lowered_names}
decln_inputs[symbol] = dependencies

# We now iterate over the declarations, declaring those that have their
# inputs satisfied. Creating a declaration for a given symbol removes
# that symbol as a dependence from any outstanding declarations and
# adds it to the 'declared' set.
declared: set[Symbol] = set()
while local_constants:
for symbol in local_constants[:]:
inputs = decln_inputs[symbol]
if inputs.issubset(declared):
# All inputs are satisfied so this declaration can be added
declared.add(symbol)
local_constants.remove(symbol)
declarations += self.gen_vardecl(
symbol, include_visibility=is_module_scope)
break
else:
# We looped through all of the variables remaining to be
# declared and none had their dependencies satisfied.
raise VisitorError(
f"Unable to satisfy dependencies for the declarations of "
f"{[sym.name for sym in local_constants]}")
return declarations

def gen_decls(self,
symbol_table: SymbolTable,
is_module_scope: bool = False) -> str:
'''Create and return the Fortran declarations for the supplied
SymbolTable.

Declarations are ordered such that any given symbol is declared after
those upon which it depends. Stricly speaking, the Fortran standard
does not mandate this in the majority of cases but compiler
implementations do not always follow the standard.

:param symbol_table: the SymbolTable instance.
:param is_module_scope: whether or not the declarations are in
a module scoping unit. Default is False.
Expand All @@ -952,6 +891,8 @@ def gen_decls(self,
RoutineSymbols) in the supplied table that do not have an
explicit declaration (UnresolvedInterface) and there are no
wildcard imports or unknown interfaces.
:raises VisitorError: if there is no way of resolving interdependencies
between symbol declarations.

'''
# pylint: disable=too-many-branches
Expand Down Expand Up @@ -983,20 +924,18 @@ def gen_decls(self,

# If the symbol table contains any symbols with an
# UnresolvedInterface interface (they are not explicitly
# declared), we need to check that we have at least one
# wildcard import which could be bringing them into this
# scope, or an unknown interface which could be declaring
# them.
# declared), we need to check that we have at least one wildcard
# import which could be bringing them into this scope, or an
# unknown interface which could be declaring them.
unresolved_symbols = []
for sym in all_symbols[:]:
if isinstance(sym.interface, UnresolvedInterface):
unresolved_symbols.append(sym)
all_symbols.remove(sym)
try:
internal_interface_symbol = symbol_table.lookup(
"_psyclone_internal_interface")
except KeyError:
internal_interface_symbol = None

internal_interface_symbol = symbol_table.lookup(
"_psyclone_internal_interface", otherwise=None)

if unresolved_symbols and not (
symbol_table.wildcard_imports() or internal_interface_symbol):
symbols_txt = ", ".join(
Expand All @@ -1014,62 +953,77 @@ def gen_decls(self,
raise VisitorError(
f"Found a symbol '{sym.name}' with a name greater than "
f"{self.MAX_VARIABLE_NAME_LENGTH} characters in length. "
"This is not standards-compliant Fortran.")

# As a convention, we will declare the variables in the following
# order:

# 1: Routine declarations and interfaces. (Note that accessibility
# statements are generated in gen_access_stmts().)
for sym in all_symbols[:]:
if not isinstance(sym, RoutineSymbol):
continue
# Interfaces can be GenericInterfaceSymbols or RoutineSymbols
# of UnsupportedFortranType.
if isinstance(sym, GenericInterfaceSymbol):
declarations += self.gen_interfacedecl(sym)
elif isinstance(sym.datatype, UnsupportedType):
declarations += self.gen_vardecl(
sym, include_visibility=is_module_scope)
elif not (sym.is_modulevar or sym.is_automatic):
raise VisitorError(
f"Routine symbol '{sym.name}' has '{sym.interface}'. "
f"This is not supported by the Fortran back-end.")
all_symbols.remove(sym)
f"This is not standards-compliant Fortran.")

# 2: Constants.
declarations += self._gen_parameter_decls(symbol_table,
is_module_scope)
for sym in all_symbols[:]:
if isinstance(sym, DataSymbol) and sym.is_constant:
all_symbols.remove(sym)
# There may be dependencies between the symbols so setup a dict
# holding a set of all their dependencies. The checks have to be done
# with case-insensitive name comparisons because the dependent symbols
# are not always created in the same scope.
local_lowered_names = [sym.name.lower() for sym in all_symbols]
decln_inputs: dict[str, Symbol] = {}
for symbol in all_symbols:
dependencies = symbol.get_all_accessed_symbols()
dependencies = {sym for sym in dependencies
# Discard self-dependencies: e.g. "a :: HUGE(a)"
if sym.name.lower() != symbol.name.lower() and
# Discard dependencies that are not local
sym.name.lower() in local_lowered_names and
# Discard dependencies on RoutineSymbols (but
# *not* interfaces)
not (isinstance(sym, RoutineSymbol) and
not isinstance(sym, GenericInterfaceSymbol))}
decln_inputs[symbol] = dependencies

# 3: Argument variable declarations
# Sanity check that we haven't got arguments if we're in a module scope
if symbol_table.argument_datasymbols and is_module_scope:
raise VisitorError(
f"Arguments are not allowed in this context but this symbol "
f"table contains argument(s): "
f"'{[sym.name for sym in symbol_table.argument_datasymbols]}'."
)
# We use symbol_table.argument_datasymbols because it has the
# symbol order that we need
for symbol in symbol_table.argument_datasymbols:
declarations += self.gen_vardecl(
symbol, include_visibility=is_module_scope)
all_symbols.remove(symbol)

# 4: Derived-type declarations. These must come before any declarations
# of symbols of these types.
for symbol in all_symbols[:]:
if isinstance(symbol, DataTypeSymbol):
declarations += self.gen_typedecl(
symbol, include_visibility=is_module_scope)
all_symbols.remove(symbol)

# 5: The rest of the symbols
for symbol in all_symbols:
declarations += self.gen_vardecl(
symbol, include_visibility=is_module_scope)

# We now iterate over the declarations, declaring those that have their
# inputs satisfied. Creating a declaration for a given symbol removes
# that symbol as a dependence from any outstanding declarations and
# adds it to the 'declared' set.
declared: set[Symbol] = set()

while all_symbols:
for symbol in all_symbols[:]:
inputs = decln_inputs[symbol]
if inputs.issubset(declared):
# All inputs are satisfied so this declaration can be added
declared.add(symbol)
all_symbols.remove(symbol)
if isinstance(symbol, RoutineSymbol):
# Interfaces can be GenericInterfaceSymbols or
# RoutineSymbols of UnsupportedFortranType.
if isinstance(symbol, GenericInterfaceSymbol):
declarations += self.gen_interfacedecl(symbol)
elif isinstance(symbol.datatype, UnsupportedType):
declarations += self.gen_vardecl(
symbol, include_visibility=is_module_scope)
elif not (symbol.is_modulevar or symbol.is_automatic):
raise VisitorError(
f"Routine symbol '{symbol.name}' has "
f"'{symbol.interface}'. This is not supported "
f"by the Fortran back-end.")
elif isinstance(symbol, DataTypeSymbol):
declarations += self.gen_typedecl(
symbol, include_visibility=is_module_scope)
else:
declarations += self.gen_vardecl(
symbol, include_visibility=is_module_scope)
# Now that we've created a new declaration (and thus
# potentially resolved some dependencies) we go back to
# the start of the list of remaining symbols.
break
else:
# We looped through all of the variables remaining to be
# declared and none had their dependencies satisfied.
raise VisitorError(
f"Unable to satisfy dependencies for the declarations of "
f"{[sym.name for sym in all_symbols]}")

return declarations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def test_module_inline_apply_polymorphic_kernel_in_multiple_invokes(tmpdir):
use quadrature_xyoz_mod, only : quadrature_xyoz_proxy_type, \
quadrature_xyoz_type
use function_space_mod, only : basis, diff_basis
real""" in output
integer""" in output
assert "mixed_kernel_mod" not in output
assert LFRicBuild(tmpdir).code_compiles(psy)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ def test_go_move_iteration_boundaries_inside_kernel_two_kernels_apply_twice(
assert "use time_smooth_mod" not in output

expected = '''subroutine invoke_0(cu_fld, p_fld, u_fld, unew_fld, uold_fld)
integer :: j
integer :: i
type(r2d_field), intent(inout) :: cu_fld
type(r2d_field), intent(inout) :: p_fld
type(r2d_field), intent(inout) :: u_fld
type(r2d_field), intent(inout) :: unew_fld
type(r2d_field), intent(inout) :: uold_fld
integer :: j
integer :: i
integer :: xstart
integer :: xstop
integer :: ystart
Expand Down
1 change: 1 addition & 0 deletions src/psyclone/tests/domain/lfric/dofkern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def test_multi_invoke_cell_dof_builtin(tmpdir, monkeypatch, annexed, dist_mem):
type(field_type), intent(in) :: f3
type(field_type), intent(in) :: f4
real(kind=r_def), intent(in) :: scalar_arg
integer(kind=i_def) :: cell
real(kind=r_def), intent(in) :: a
type(field_type), intent(in) :: m1
type(field_type), intent(in) :: m2
Expand Down
8 changes: 4 additions & 4 deletions src/psyclone/tests/domain/lfric/lfric_field_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def test_field(tmpdir):
" contains\n"
" subroutine invoke_0_testkern_type(a, f1, f2, m1, m2)\n"
" use constants_mod, only : i_def\n"
" integer(kind=i_def) :: cell\n"
" real(kind=r_def), intent(in) :: a\n"
" type(field_type), intent(in) :: f1\n"
" type(field_type), intent(in) :: f2\n"
" type(field_type), intent(in) :: m1\n"
" type(field_type), intent(in) :: m2\n"
" integer(kind=i_def) :: cell\n"
" real(kind=r_def), pointer, dimension(:) :: f1_data => null()\n"
" real(kind=r_def), pointer, dimension(:) :: f2_data => null()\n"
" real(kind=r_def), pointer, dimension(:) :: m1_data => null()\n"
Expand Down Expand Up @@ -177,12 +177,12 @@ def test_field_deref(tmpdir, dist_mem):
assert LFRicBuild(tmpdir).code_compiles(psy)

output = (
" integer(kind=i_def) :: cell\n"
" real(kind=r_def), intent(in) :: a\n"
" type(field_type), intent(in) :: f1\n"
" type(field_type), intent(in) :: est_f2\n"
" type(field_type), intent(in) :: m1\n"
" type(field_type), intent(in) :: est_m2\n"
" integer(kind=i_def) :: cell\n"
)
assert output in generated_code
output = (
Expand Down Expand Up @@ -322,6 +322,7 @@ def test_field_fs(tmpdir):
f6, m5, m6, m7)
use mesh_mod, only : mesh_type
use constants_mod, only : i_def
integer(kind=i_def) :: cell
type(field_type), intent(in) :: f1
type(field_type), intent(in) :: f2
type(field_type), intent(in) :: m1
Expand All @@ -335,7 +336,6 @@ def test_field_fs(tmpdir):
type(field_type), intent(in) :: m5
type(field_type), intent(in) :: m6
type(field_type), intent(in) :: m7
integer(kind=i_def) :: cell
type(mesh_type), pointer :: mesh => null()
integer(kind=i_def) :: max_halo_depth_mesh
real(kind=r_def), pointer, dimension(:) :: f1_data => null()
Expand Down Expand Up @@ -655,6 +655,7 @@ def test_int_field_fs(tmpdir):
subroutine invoke_0_testkern_fs_int_field_type(f1, f2, m1, m2, f3, f4, m3, \
m4, f5, f6, m5, m6, f7, f8, m7)
use mesh_mod, only : mesh_type
integer(kind=i_def) :: cell
type(integer_field_type), intent(in) :: f1
type(integer_field_type), intent(in) :: f2
type(integer_field_type), intent(in) :: m1
Expand All @@ -670,7 +671,6 @@ def test_int_field_fs(tmpdir):
type(integer_field_type), intent(in) :: f7
type(integer_field_type), intent(in) :: f8
type(integer_field_type), intent(in) :: m7
integer(kind=i_def) :: cell
type(mesh_type), pointer :: mesh => null()
integer(kind=i_def) :: max_halo_depth_mesh
integer(kind=i_def), pointer, dimension(:) :: f1_data => null()
Expand Down
Loading
Loading