diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index aec83d714d..4811f78a6e 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -113,6 +113,54 @@ def _all_supported_types(self, types: Iterable[PrecisionType]): return False return True + def _apply_max_precision_constraints(self, node, precision): + """ + Clamps the precision to the node's max_precision constraints. + + Logic: + 1. Width/Integer: Always constrained to the minimum of inferred vs max. + 2. Rounding/Saturation: Inherited from max_precision ONLY if they differ from the defaults + (meaning the user likely set them explicitly). + 3. Signedness: max_precision signed arg is always preferred. + """ + max_precision = self._get_maximum_precision(node) + + if max_precision is None: + return precision + + new_width = min(precision.width, max_precision.width) + new_integer = min(precision.integer, max_precision.integer) + + # Default modes defined in FixedPrecisionType + default_type = FixedPrecisionType() + DEFAULT_RND = default_type.rounding_mode + DEFAULT_SAT = default_type.saturation_mode + DEFAULT_SAT_BITS = default_type.saturation_bits + + if max_precision.rounding_mode != DEFAULT_RND: + new_rounding_mode = max_precision.rounding_mode + else: + new_rounding_mode = precision.rounding_mode + + if max_precision.saturation_mode != DEFAULT_SAT: + new_saturation_mode = max_precision.saturation_mode + else: + new_saturation_mode = precision.saturation_mode + + if max_precision.saturation_bits != DEFAULT_SAT_BITS: + new_saturation_bits = max_precision.saturation_bits + else: + new_saturation_bits = precision.saturation_bits + + return FixedPrecisionType( + width=new_width, + integer=new_integer, + signed=max_precision.signed, + rounding_mode=new_rounding_mode, + saturation_mode=new_saturation_mode, + saturation_bits=new_saturation_bits, + ) + def _infer_default_type(self, node, type_name): model_config = node.model.config default_precision = model_config.backend.convert_precision_string(model_config.model_precision['default']) @@ -180,14 +228,8 @@ def _infer_common_precision(self, node, types_to_infer, n_ops): bitwidth = integers + max(frac, bias_width - bias_integers) signed = signed or bias_signed - # if max_precision is specified, limit the size to be less than max precisoin - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - bitwidth = min(bitwidth, max_precision.width) - integers = min(integers, max_precision.integer) - - # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. - new_type = FixedPrecisionType(bitwidth, integers, signed) + out_precision = FixedPrecisionType(bitwidth, integers, signed) + new_type = self._apply_max_precision_constraints(node, out_precision) else: new_type = self._get_default_precision(node) @@ -334,15 +376,8 @@ def _infer_bn_precision(self, node, types_to_infer): out_precision_width = out_precision_integer + max( after_scale_width - after_scale_integer, bias_precision.fractional ) - - # if max_precision is specified, limit the size to be less than max precisoin - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - out_precision_width = min(out_precision_width, max_precision.width) - out_precision_integer = min(out_precision_integer, max_precision.integer) - - # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. out_precision = FixedPrecisionType(out_precision_width, out_precision_integer, out_precision_signed) + out_precision = self._apply_max_precision_constraints(node, out_precision) else: out_precision = self._get_default_precision(node) @@ -413,11 +448,8 @@ def _infer_merge_precision(self, node, types_to_infer): + 1 ) new_width = new_int + max(input_1.fractional, input_2.fractional) - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - new_width = min(new_width, max_precision.width) - new_int = min(new_int, max_precision.integer) out_precision = FixedPrecisionType(new_width, new_int, new_signed) + out_precision = self._apply_max_precision_constraints(node, out_precision) else: out_precision = self._get_default_precision(node) elif op == 'multiply': @@ -425,12 +457,8 @@ def _infer_merge_precision(self, node, types_to_infer): new_signed = input_1.signed or input_2.signed new_int = input_1.integer + input_2.integer new_width = input_1.width + input_2.width - # if max_precision is specified, limit the size to be less than max precisoin - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - new_width = min(new_width, max_precision.width) - new_int = min(new_int, max_precision.integer) out_precision = FixedPrecisionType(new_width, new_int, new_signed) + out_precision = self._apply_max_precision_constraints(node, out_precision) else: out_precision = self._get_default_precision(node) elif op in ('maximum', 'minimum'): @@ -487,18 +515,13 @@ def _infer_cat_precision(self, node, types_to_infer): new_width = max(input_1.fractional, input_2.fractional) + max(input_1_integer, input_2_integer) new_int = max(input_1_integer, input_2_integer) - # if max_precision is specified, limit the size to be less than max precisoin - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - new_width = min(new_width, max_precision.width) - new_int = min(new_int, max_precision.integer) - # some logic copied from former SetPrecisionConcat optimizer newrmode = input_1.rounding_mode if input_1.rounding_mode != RoundingMode.TRN else input_2.rounding_mode newsmode = input_1.saturation_mode if input_1.saturation_mode != SaturationMode.WRAP else input_2.saturation_mode newsbits = input_1.saturation_bits if input_1.saturation_bits != 0 else input_2.saturation_bits out_precision = FixedPrecisionType(new_width, new_int, new_signed, newrmode, newsmode, newsbits) + out_precision = self._apply_max_precision_constraints(node, out_precision) else: out_precision = self._get_default_precision(node) @@ -520,13 +543,8 @@ def _infer_dot_precision(self, node, types_to_infer): new_width = input_1.width + input_2.width + math.ceil(np.log2(n_in)) new_int = input_1.integer + input_2.integer + math.ceil(np.log2(n_in)) - # if max_precision is specified, limit the size to be less than max precisoin - max_precision = self._get_maximum_precision(node) - if max_precision is not None: - new_width = min(new_width, max_precision.width) - new_int = min(new_int, max_precision.integer) - out_precision = FixedPrecisionType(new_width, new_int, new_signed) + out_precision = self._apply_max_precision_constraints(node, out_precision) else: out_precision = self._get_default_precision(node) node.types['result_t'].name = node.name + '_result_t' diff --git a/test/pytest/test_max_precision.py b/test/pytest/test_max_precision.py new file mode 100644 index 0000000000..99df65a4b1 --- /dev/null +++ b/test/pytest/test_max_precision.py @@ -0,0 +1,311 @@ +from collections import namedtuple + +import pytest + +from hls4ml.model.optimizer.passes.infer_precision import InferPrecisionTypes +from hls4ml.model.types import ( + FixedPrecisionType, + IntegerPrecisionType, + NamedType, + RoundingMode, + SaturationMode, + UnspecifiedPrecisionType, +) + + +class MockBackend: + def convert_precision_string(self, precision_string): + """ + Simple mock that expects a FixedPrecisionType object or None + to be passed directly for testing purposes, or a simple string parser. + """ + if isinstance(precision_string, (FixedPrecisionType, IntegerPrecisionType)): + return precision_string + return None + + +class MockConfig: + def __init__(self, max_precision=None, default_precision=None): + self.model_precision = {} + if max_precision: + self.model_precision['maximum'] = max_precision + if default_precision: + self.model_precision['default'] = default_precision + + self.backend = MockBackend() + + +class MockModel: + def __init__(self, max_precision=None): + default = FixedPrecisionType(width=16, integer=6) + self.config = MockConfig(max_precision, default) + + +class MockVariable: + def __init__(self, precision): + self.type = namedtuple('Type', ['precision'])(precision) + self.shape = [10, 10] + + +class MockWeight: + def __init__(self, precision): + self.precision = precision + self.nonzeros = 10 + + def update_precision(self, new_precision): + self.precision = new_precision + + +class MockNode: + def __init__(self, class_name, name='test_node', max_precision=None, inputs=None): + self.class_name = class_name + self.name = name + self.model = MockModel(max_precision) + self.attributes = { + 'n_in': 10, + 'n_out': 10, + 'n_chan': 3, + 'filt_height': 3, + 'filt_width': 3, + 'pool_height': 2, + 'pool_width': 2, + 'op': 'multiply', # Default for merge tests + 'pool_op': 'average', + } + self.types = { + 'result_t': NamedType('result_t', UnspecifiedPrecisionType()), + 'accum_t': NamedType('accum_t', UnspecifiedPrecisionType()), + 'weight_t': NamedType('weight_t', FixedPrecisionType(8, 4)), + 'bias_t': NamedType('bias_t', FixedPrecisionType(8, 4)), + 'scale_t': NamedType('scale_t', FixedPrecisionType(8, 4)), + 'pointwise_t': NamedType('pointwise_t', FixedPrecisionType(8, 4)), + } + self.weights = { + 'weight': MockWeight(FixedPrecisionType(8, 4)), + 'bias': MockWeight(FixedPrecisionType(8, 4)), + 'scale': MockWeight(FixedPrecisionType(8, 4)), + 'pointwise': MockWeight(FixedPrecisionType(8, 4)), + } + + # Setup inputs + self.inputs = inputs if inputs else ['input_1'] + self._input_vars = {'input_1': MockVariable(FixedPrecisionType(16, 6))} + if len(self.inputs) > 1: + self._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6)) + + def get_attr(self, key, default=None): + return self.attributes.get(key, default) + + def get_input_variable(self, input_name=None): + if input_name is None: + return self._input_vars[self.inputs[0]] + return self._input_vars.get(input_name) + + def get_output_variable(self): + return MockVariable(UnspecifiedPrecisionType()) + + +@pytest.fixture +def optimizer(): + return InferPrecisionTypes() + + +class TestApplyMaxPrecisionConstraints: + """ + Tests the logic of _apply_max_precision_constraints function directly. + """ + + def test_no_max_precision_set(self, optimizer): + """If 'maximum' is not in config, return precision unchanged.""" + node = MockNode('Dense', max_precision=None) + + input_prec = FixedPrecisionType(width=20, integer=10) + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.width == 20 + assert result.integer == 10 + + def test_clamp_width(self, optimizer): + """Should reduce width if input > max.""" + max_prec = FixedPrecisionType(width=16, integer=10) + node = MockNode('Dense', max_precision=max_prec) + + input_prec = FixedPrecisionType(width=32, integer=10) + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.width == 16 + assert result.integer == 10 + + def test_clamp_integer(self, optimizer): + """Should reduce integer bits if input > max.""" + max_prec = FixedPrecisionType(width=32, integer=5) + node = MockNode('Dense', max_precision=max_prec) + + input_prec = FixedPrecisionType(width=32, integer=10) + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.width == 32 + assert result.integer == 5 + + def test_signedness_inheritance(self, optimizer): + """Should always adopt the signedness of the maximum precision.""" + # Max is Unsigned (signed=0) + max_prec = FixedPrecisionType(width=32, integer=10, signed=0) + node = MockNode('Dense', max_precision=max_prec) + + # Input is Signed + input_prec = FixedPrecisionType(width=32, integer=10, signed=1) + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.signed == 0 + + def test_mode_inheritance_from_max(self, optimizer): + """If Max specifies rounding/sat modes, they should override input.""" + max_prec = FixedPrecisionType( + 16, 6, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT, saturation_bits=2 + ) + node = MockNode('Dense', max_precision=max_prec) + + # Input has different modes + input_prec = FixedPrecisionType(16, 6, rounding_mode=RoundingMode.TRN, saturation_mode=SaturationMode.WRAP) + + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.rounding_mode == RoundingMode.RND + assert result.saturation_mode == SaturationMode.SAT + assert result.saturation_bits == 2 + + def test_mode_preservation_when_max_is_none(self, optimizer): + """If Max modes are default, input modes should be preserved.""" + # Create a max precision where modes are initialized with defaults + max_prec = FixedPrecisionType(16, 6) + + node = MockNode('Dense', max_precision=max_prec) + + input_prec = FixedPrecisionType(16, 6, rounding_mode=RoundingMode.RND_ZERO, saturation_mode=SaturationMode.SAT_SYM) + + result = optimizer._apply_max_precision_constraints(node, input_prec) + + assert result.rounding_mode == RoundingMode.RND_ZERO + assert result.saturation_mode == SaturationMode.SAT_SYM + + +class TestInferPrecision: + """ + Tests that _infer_precision calls apply_max_constraints for specific layers. + We verify this by setting a strict Max constraint and asserting the result_t + complies with it. + """ + + # Define a strict constraint + STRICT_MAX = FixedPrecisionType(width=4, integer=2, signed=True) + + @pytest.mark.parametrize( + 'layer_class', + [ + 'Dense', + 'Conv1D', + 'Conv2D', + 'PointwiseConv2D', + 'DepthwiseConv2D', + ], + ) + def test_common_precision_layers(self, optimizer, layer_class): + """Tests layers that use _infer_common_precision.""" + node = MockNode(layer_class, max_precision=self.STRICT_MAX) + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16, signed=1)) + + types_to_infer = ['result_t', 'accum_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2 + + def test_batch_normalization(self, optimizer): + """Tests BN layer inference.""" + node = MockNode('BatchNormalization', max_precision=self.STRICT_MAX) + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2 + + def test_merge_multiply(self, optimizer): + """Tests Merge layer with Multiply op.""" + node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2']) + node.attributes['op'] = 'multiply' + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10)) + node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2 + + def test_merge_add(self, optimizer): + """Tests Merge layer with Add op.""" + node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2']) + node.attributes['op'] = 'add' + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10)) + node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2 + + def test_concatenate_same_input_precisions(self, optimizer): + """ + Tests Concatenate layer. If precisions of both inputs are the same, + max precision is ignored (see _infer_cat_precision function). + """ + node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2']) + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10)) + node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 20 + assert res_prec.integer == 10 + + def test_concatenate_different_input_precisions(self, optimizer): + """Tests Concatenate layer.""" + node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2']) + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10)) + node._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2 + + def test_dot(self, optimizer): + """Tests Dot layer.""" + node = MockNode('Dot', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2']) + + node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10)) + node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10)) + + types_to_infer = ['result_t'] + optimizer._infer_precision(node, types_to_infer) + + res_prec = node.types['result_t'].precision + assert res_prec.width == 4 + assert res_prec.integer == 2