Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hls4ml/backends/catapult/catapult_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def create_initial_config(
fifo=None,
clock_period=5,
io_type='io_parallel',
write_tar=False,
):
config = {}

Expand All @@ -206,6 +207,9 @@ def create_initial_config(
config['ClockPeriod'] = clock_period
config['FIFO'] = fifo
config['IOType'] = io_type
config['WriterConfig'] = {
'WriteTar': write_tar,
}
config['HLSConfig'] = {}

return config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def create_initial_config(
clock_period=5,
clock_uncertainty='12.5%',
io_type='io_parallel',
write_tar=False,
interface='axi_stream',
driver='python',
input_type='float',
Expand Down Expand Up @@ -131,7 +132,13 @@ def create_initial_config(
populated config
"""
board = board if board is not None else 'pynq-z2'
config = super().create_initial_config(part, clock_period, clock_uncertainty, io_type)
config = super().create_initial_config(
part=part,
clock_period=clock_period,
clock_uncertainty=clock_uncertainty,
io_type=io_type,
write_tar=write_tar,
)
config['AcceleratorConfig'] = {}
config['AcceleratorConfig']['Board'] = board
config['AcceleratorConfig']['Interface'] = interface # axi_stream, axi_master, axi_lite
Expand Down
13 changes: 7 additions & 6 deletions hls4ml/writer/catapult_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,12 +892,13 @@ def write_tar(self, model):
Args:
model (ModelGraph): the hls4ml model.
"""

if not os.path.exists(model.config.get_output_dir() + '.tar.gz'):
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)
else:
print('Project .tar.gz archive already exists')
if not self.should_write_tar(model):
return
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True, arcname='')

def write_hls(self, model):
self.write_output_dir(model)
Expand Down
13 changes: 6 additions & 7 deletions hls4ml/writer/libero_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,13 +884,12 @@ def write_tar(self, model):
Args:
model (ModelGraph): the hls4ml model.
"""

write_tar = model.config.get_writer_config().get('WriteTar', False)
if write_tar:
tar_path = Path(model.config.get_output_dir() + '.tar.gz')
tar_path.unlink(missing_ok=True)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True, arcname='')
if not self.should_write_tar(model):
return
tar_path = Path(model.config.get_output_dir() + '.tar.gz')
tar_path.unlink(missing_ok=True)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True, arcname='')

def write_hls(self, model):
print('Writing HLS project')
Expand Down
14 changes: 7 additions & 7 deletions hls4ml/writer/oneapi_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,13 +971,13 @@ def write_tar(self, model):
Args:
model (ModelGraph): the hls4ml model.
"""

if model.config.get_writer_config().get('WriteTar', False):
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)
if not self.should_write_tar(model):
return
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)

def write_hls(self, model):
self.write_project_dir(model)
Expand Down
14 changes: 7 additions & 7 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,13 +1330,13 @@ def write_tar(self, model):
Args:
model (ModelGraph): the hls4ml model.
"""

if model.config.get_writer_config().get('WriteTar', False):
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)
if not self.should_write_tar(model):
return
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True)

def write_hls(self, model):
self.write_project_dir(model)
Expand Down
2 changes: 2 additions & 0 deletions hls4ml/writer/vivado_accelerator_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def write_driver(self, model):
)

def write_new_tar(self, model):
if not self.should_write_tar(model):
return
tarfile = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tarfile):
os.remove(tarfile)
Expand Down
15 changes: 7 additions & 8 deletions hls4ml/writer/vivado_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,14 +1112,13 @@ def write_tar(self, model):
Args:
model (ModelGraph): the hls4ml model.
"""

write_tar = model.config.get_writer_config().get('WriteTar', False)
if write_tar:
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True, arcname='')
if not self.should_write_tar(model):
return
tar_path = model.config.get_output_dir() + '.tar.gz'
if os.path.exists(tar_path):
os.remove(tar_path)
with tarfile.open(tar_path, mode='w:gz') as archive:
archive.add(model.config.get_output_dir(), recursive=True, arcname='')

def write_hls(self, model, is_multigraph=False):
if not is_multigraph:
Expand Down
11 changes: 11 additions & 0 deletions hls4ml/writer/writers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import os

WRITE_TAR_ENV_VAR = 'HLS4ML_WRITE_TAR'


class Writer:
def __init__(self):
pass

def should_write_tar(self, model):
write_tar_config = model.config.get_writer_config().get('WriteTar', False)
env_value = os.environ.get(WRITE_TAR_ENV_VAR, '')
write_tar_env = env_value.strip().lower() in {'1', 'true'}
return write_tar_config or write_tar_env

def write_hls(self, model):
raise NotImplementedError

Expand Down
29 changes: 27 additions & 2 deletions test/pytest/test_writer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def test_emulator(test_case_id, keras_model, io_type, backend):
hls_model.compile() # It's enough that the model compiles


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) # No Quartus for now
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI', 'VivadoAccelerator'])
@pytest.mark.parametrize('write_tar', [True, False])
def test_write_tar(test_case_id, keras_model, write_tar, backend):
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend)
odir = str(test_root_path / test_case_id)

if os.path.exists(odir + '.tar.gz'):
Expand All @@ -60,6 +60,31 @@ def test_write_tar(test_case_id, keras_model, write_tar, backend):
assert tar_written == write_tar


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI', 'VivadoAccelerator'])
@pytest.mark.parametrize(
'env_write_tar, expected_tar',
[
('false', False),
('1', True),
],
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to add the id=[...] here for the proper naming of the tests

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the comment, done

def test_write_tar_env_override(test_case_id, keras_model, backend, monkeypatch, env_write_tar, expected_tar):
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend)
odir = str(test_root_path / test_case_id)

if os.path.exists(odir + '.tar.gz'):
os.remove(odir + '.tar.gz')

monkeypatch.setenv('HLS4ML_WRITE_TAR', env_write_tar)
hls_model = hls4ml.converters.convert_from_keras_model(
keras_model, hls_config=config, output_dir=odir, backend=backend, write_tar=False
)
hls_model.write()

tar_written = os.path.exists(odir + '.tar.gz')
assert tar_written == expected_tar


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) # No Quartus for now
@pytest.mark.parametrize('write_weights_txt', [True, False])
def test_write_weights_txt(test_case_id, keras_model, write_weights_txt, backend):
Expand Down
Loading