diff --git a/hls4ml/backends/catapult/catapult_backend.py b/hls4ml/backends/catapult/catapult_backend.py index 5b493cd944..8ee06fb4a1 100644 --- a/hls4ml/backends/catapult/catapult_backend.py +++ b/hls4ml/backends/catapult/catapult_backend.py @@ -195,6 +195,7 @@ def create_initial_config( fifo=None, clock_period=5, io_type='io_parallel', + write_tar=False, ): config = {} @@ -206,6 +207,9 @@ def create_initial_config( config['ClockPeriod'] = clock_period config['FIFO'] = fifo config['IOType'] = io_type + config['WriterConfig'] = { + 'WriteTar': write_tar, + } config['HLSConfig'] = {} return config diff --git a/hls4ml/backends/vivado_accelerator/vivado_accelerator_backend.py b/hls4ml/backends/vivado_accelerator/vivado_accelerator_backend.py index 128a8a8345..6bb4ed1200 100644 --- a/hls4ml/backends/vivado_accelerator/vivado_accelerator_backend.py +++ b/hls4ml/backends/vivado_accelerator/vivado_accelerator_backend.py @@ -102,6 +102,7 @@ def create_initial_config( clock_period=5, clock_uncertainty='12.5%', io_type='io_parallel', + write_tar=False, interface='axi_stream', driver='python', input_type='float', @@ -131,7 +132,13 @@ def create_initial_config( populated config """ board = board if board is not None else 'pynq-z2' - config = super().create_initial_config(part, clock_period, clock_uncertainty, io_type) + config = super().create_initial_config( + part=part, + clock_period=clock_period, + clock_uncertainty=clock_uncertainty, + io_type=io_type, + write_tar=write_tar, + ) config['AcceleratorConfig'] = {} config['AcceleratorConfig']['Board'] = board config['AcceleratorConfig']['Interface'] = interface # axi_stream, axi_master, axi_lite diff --git a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py index ba7e511995..6b32357f4d 100755 --- a/hls4ml/writer/catapult_writer.py +++ b/hls4ml/writer/catapult_writer.py @@ -892,12 +892,13 @@ def write_tar(self, model): Args: model (ModelGraph): the hls4ml model. """ - - if not os.path.exists(model.config.get_output_dir() + '.tar.gz'): - with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True) - else: - print('Project .tar.gz archive already exists') + if not self.should_write_tar(model): + return + tar_path = model.config.get_output_dir() + '.tar.gz' + if os.path.exists(tar_path): + os.remove(tar_path) + with tarfile.open(tar_path, mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True, arcname='') def write_hls(self, model): self.write_output_dir(model) diff --git a/hls4ml/writer/libero_writer.py b/hls4ml/writer/libero_writer.py index a5be68c81e..c3c6a1b56a 100644 --- a/hls4ml/writer/libero_writer.py +++ b/hls4ml/writer/libero_writer.py @@ -884,13 +884,12 @@ def write_tar(self, model): Args: model (ModelGraph): the hls4ml model. """ - - write_tar = model.config.get_writer_config().get('WriteTar', False) - if write_tar: - tar_path = Path(model.config.get_output_dir() + '.tar.gz') - tar_path.unlink(missing_ok=True) - with tarfile.open(tar_path, mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True, arcname='') + if not self.should_write_tar(model): + return + tar_path = Path(model.config.get_output_dir() + '.tar.gz') + tar_path.unlink(missing_ok=True) + with tarfile.open(tar_path, mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True, arcname='') def write_hls(self, model): print('Writing HLS project') diff --git a/hls4ml/writer/oneapi_writer.py b/hls4ml/writer/oneapi_writer.py index 3c0a778c50..87ee1df2eb 100644 --- a/hls4ml/writer/oneapi_writer.py +++ b/hls4ml/writer/oneapi_writer.py @@ -971,13 +971,13 @@ def write_tar(self, model): Args: model (ModelGraph): the hls4ml model. """ - - if model.config.get_writer_config().get('WriteTar', False): - tar_path = model.config.get_output_dir() + '.tar.gz' - if os.path.exists(tar_path): - os.remove(tar_path) - with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True) + if not self.should_write_tar(model): + return + tar_path = model.config.get_output_dir() + '.tar.gz' + if os.path.exists(tar_path): + os.remove(tar_path) + with tarfile.open(tar_path, mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True) def write_hls(self, model): self.write_project_dir(model) diff --git a/hls4ml/writer/quartus_writer.py b/hls4ml/writer/quartus_writer.py index e0d6338ac3..c23e49da62 100644 --- a/hls4ml/writer/quartus_writer.py +++ b/hls4ml/writer/quartus_writer.py @@ -1330,13 +1330,13 @@ def write_tar(self, model): Args: model (ModelGraph): the hls4ml model. """ - - if model.config.get_writer_config().get('WriteTar', False): - tar_path = model.config.get_output_dir() + '.tar.gz' - if os.path.exists(tar_path): - os.remove(tar_path) - with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True) + if not self.should_write_tar(model): + return + tar_path = model.config.get_output_dir() + '.tar.gz' + if os.path.exists(tar_path): + os.remove(tar_path) + with tarfile.open(tar_path, mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True) def write_hls(self, model): self.write_project_dir(model) diff --git a/hls4ml/writer/vivado_accelerator_writer.py b/hls4ml/writer/vivado_accelerator_writer.py index 7557eee019..e6a2a744d3 100644 --- a/hls4ml/writer/vivado_accelerator_writer.py +++ b/hls4ml/writer/vivado_accelerator_writer.py @@ -405,6 +405,8 @@ def write_driver(self, model): ) def write_new_tar(self, model): + if not self.should_write_tar(model): + return tarfile = model.config.get_output_dir() + '.tar.gz' if os.path.exists(tarfile): os.remove(tarfile) diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index dc5556cb33..7fdb7478f6 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -1112,14 +1112,13 @@ def write_tar(self, model): Args: model (ModelGraph): the hls4ml model. """ - - write_tar = model.config.get_writer_config().get('WriteTar', False) - if write_tar: - tar_path = model.config.get_output_dir() + '.tar.gz' - if os.path.exists(tar_path): - os.remove(tar_path) - with tarfile.open(tar_path, mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True, arcname='') + if not self.should_write_tar(model): + return + tar_path = model.config.get_output_dir() + '.tar.gz' + if os.path.exists(tar_path): + os.remove(tar_path) + with tarfile.open(tar_path, mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True, arcname='') def write_hls(self, model, is_multigraph=False): if not is_multigraph: diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py index 54caec1d11..3b8b752d80 100644 --- a/hls4ml/writer/writers.py +++ b/hls4ml/writer/writers.py @@ -1,7 +1,18 @@ +import os + +WRITE_TAR_ENV_VAR = 'HLS4ML_WRITE_TAR' + + class Writer: def __init__(self): pass + def should_write_tar(self, model): + write_tar_config = model.config.get_writer_config().get('WriteTar', False) + env_value = os.environ.get(WRITE_TAR_ENV_VAR, '') + write_tar_env = env_value.strip().lower() in {'1', 'true'} + return write_tar_config or write_tar_env + def write_hls(self, model): raise NotImplementedError diff --git a/test/pytest/test_writer_config.py b/test/pytest/test_writer_config.py index 90e35e056e..d04655d6c4 100644 --- a/test/pytest/test_writer_config.py +++ b/test/pytest/test_writer_config.py @@ -42,10 +42,10 @@ def test_emulator(test_case_id, keras_model, io_type, backend): hls_model.compile() # It's enough that the model compiles -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) # No Quartus for now +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI', 'VivadoAccelerator']) @pytest.mark.parametrize('write_tar', [True, False]) def test_write_tar(test_case_id, keras_model, write_tar, backend): - config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name') + config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend) odir = str(test_root_path / test_case_id) if os.path.exists(odir + '.tar.gz'): @@ -60,6 +60,32 @@ def test_write_tar(test_case_id, keras_model, write_tar, backend): assert tar_written == write_tar +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI', 'VivadoAccelerator']) +@pytest.mark.parametrize( + 'env_write_tar, expected_tar', + [ + ('false', False), + ('1', True), + ], + ids=['write_tar_env_false', 'write_tar_env_true'], +) +def test_write_tar_env_override(test_case_id, keras_model, backend, monkeypatch, env_write_tar, expected_tar): + config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend) + odir = str(test_root_path / test_case_id) + + if os.path.exists(odir + '.tar.gz'): + os.remove(odir + '.tar.gz') + + monkeypatch.setenv('HLS4ML_WRITE_TAR', env_write_tar) + hls_model = hls4ml.converters.convert_from_keras_model( + keras_model, hls_config=config, output_dir=odir, backend=backend, write_tar=False + ) + hls_model.write() + + tar_written = os.path.exists(odir + '.tar.gz') + assert tar_written == expected_tar + + @pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) # No Quartus for now @pytest.mark.parametrize('write_weights_txt', [True, False]) def test_write_weights_txt(test_case_id, keras_model, write_weights_txt, backend):