From 7b6dfb933b7ea970fdd338bf6d92b804ce21e7d8 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 23 Oct 2025 11:17:45 +0900 Subject: [PATCH 01/27] [tools] Introduce circle2circle (python) It introduces circle2circle (= o2o, where o means circle). It aims to manipulate circles in Unix-filter way. ONE-DCO-1.0-Signed-off-by: Sanggyu Lee --- tools/circle2circle/README.md | 141 + tools/circle2circle/circle.py | 20819 ++++++++++++++++ tools/circle2circle/fuse.bmm_lhs_const.py | 291 + tools/circle2circle/gen_circle.add.py | 170 + .../gen_circle.bmm_lhs_const.fc.py | 189 + tools/circle2circle/o2o.py | 125 + tools/circle2circle/remove.io.py | 99 + tools/circle2circle/remove.unused_tensors.py | 216 + .../rename.io.remove_namespace.py | 78 + .../circle2circle/rename.io.remove_prefix.py | 54 + tools/circle2circle/reshape.fc_weight.py | 114 + tools/circle2circle/reshape.io.py | 78 + tools/circle2circle/transpose.io.kvcache.py | 65 + 13 files changed, 22439 insertions(+) create mode 100644 tools/circle2circle/README.md create mode 100644 tools/circle2circle/circle.py create mode 100755 tools/circle2circle/fuse.bmm_lhs_const.py create mode 100755 tools/circle2circle/gen_circle.add.py create mode 100755 tools/circle2circle/gen_circle.bmm_lhs_const.fc.py create mode 100755 tools/circle2circle/o2o.py create mode 100755 tools/circle2circle/remove.io.py create mode 100755 tools/circle2circle/remove.unused_tensors.py create mode 100755 tools/circle2circle/rename.io.remove_namespace.py create mode 100755 tools/circle2circle/rename.io.remove_prefix.py create mode 100755 tools/circle2circle/reshape.fc_weight.py create mode 100755 tools/circle2circle/reshape.io.py create mode 100755 tools/circle2circle/transpose.io.kvcache.py diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md new file mode 100644 index 00000000000..c6d10283d5c --- /dev/null +++ b/tools/circle2circle/README.md @@ -0,0 +1,141 @@ +# circle2circle (circle to circle) + +`circle2circle` is a tool for 
transforming Circle models.
+
+It includes various filter commands (= passes) to perform specific modifications.
+
+ +## How to Use + +Imagine Unix filter usage like `cat hello.txt | sort | uniq`. + +All circle2circle command scripts read a Circle model from **standard input** and write the transformed model to **standard output**. + +An example: + +```bash +./rename.io.remove_namespace.py < in.circle > out.circle +``` + +Filters example: + +```bash +./rename.io.remove_namespace.py < in.circle | ./rename.io.remove_prefix.py past_key_values_ > out.circle +``` + +
+
+## Filter List
+
+### `remove.io.py`
+
+Removes input or output tensors from a Circle model, keeping only the tensors with the specified names.
+
+#### Arguments
+
+* `io_type` (required): Specifies whether to process `input` or `output` tensors.
+* `--keep_by_name` (required): A string defining the names of the tensors to keep. It supports comma-separated tensor names (e.g., "input1,input2").
+
+##
+
+### `rename.io.remove_namespace.py`
+
+Removes namespaces from the names of input and output tensors. A namespace is identified as the part of the tensor name before a double colon (`::`). For example, a tensor named `module::input_tensor` would be renamed to `input_tensor`.
+
+##
+
+### `rename.io.remove_prefix.py`
+
+Removes a user-specified prefix from the names of all tensors in the model.
+
+#### Arguments
+
+* `prefix` (required): The string prefix to remove from tensor names.
+
+
+##
+
+### `reshape.fc_weight.py`
+
+Reshapes the weight tensors of `FULLY_CONNECTED` operators from an effectively 2D shape (e.g., `[1, 1, D_out, D_in]`) to a strict 2D shape (`[D_out, D_in]`). This is useful for optimizing or standardizing the model structure. If a weight tensor is used by multiple operators, a new tensor is created for the specific operator to prevent conflicts.
+
+##
+
+### `transpose.io.kvcache.py`
+
+Finds input tensors matching the pattern `*key_cache_\d+` (e.g., `past_key_values_key_cache_0`) and transposes their second and third dimensions if they are 4D. For example, a shape `[d0, d1, d2, d3]` will become `[d0, d2, d1, d3]`.
+
+##
+
+### `fuse.bmm_lhs_const.py`
+
+Fuses `BATCH_MATMUL` + `TRANSPOSE` to `FULLY_CONNECTED` when LHS is constant.
+ +#### Transformation Diagram + +``` +BEFORE: + +LHS(constant):[B,M,K] \ + BatchMatMul(LHS,RHS):[B,M,N] -> TRANSPOSE:[B,N,M] -> OUTPUT +RHS:[B,K,N] / + +AFTER: + +RHS:[B,K,N] \ + FullyConnected(RHS,LHS):[B,N,M] -> OUTPUT +LHS(constant):[B,M,K] / ~~ ~~ + input weights + +Condition: +- B = 1 and K = 1 + +Key Relationship: +- BatchMatMul's LHS (constant) becomes FullyConnected's weights +- BatchMatMul's RHS becomes FullyConnected's input +``` + +## + +### `select.op.py` + +Selectively removes operators from a Circle model based on their index range. This filter allows you to keep only the operators within specified index ranges while removing all others. It automatically handles tensor connections, updates subgraph inputs/outputs, and cleans up unused operator codes. + +#### Arguments + +* `--by_id` (required): Specifies the operator index range to keep. Supports multiple ranges separated by commas and individual indices. + +#### Example Usage + +```bash +# Keep only operators 0-181 +./select.op.py --by_id 0-181 < old.circle > new.circle + +# Keep operators 0-10 and 15-20 +./select.op.py --by_id 0-10,15-20 < old.circle > new.circle + +# Keep only operator 5 +./select.op.py --by_id 5 < old.circle > new.circle +``` + +## + +### `remove.unused_tensors.py` + +Identifies and removes unused tensors from all subgraphs within a Circle model. A tensor is considered "unused" if it is not an input to any operator and not an output of its containing subgraph. This helps in cleaning up the model and potentially reducing its size. The script can either list unused tensors or modify the model to remove them. + +## + +### `gen_circle.*.py` + + +These scripts generate test Circle models with specific operator patterns for development and testing purposes. Each script follows the naming convention `gen_circle..py` and automatically generates an output file with the name `.circle` when executed. 
+ +#### `gen_circle.add.py` + +Generates a simple Circle model with one `ADD` operator for testing basic functionality. + +#### `gen_circle.bmm_lhs_const.fc.py` + +Generates a test Circle model with `BATCH_MATMUL` and `TRANSPOSE` operations where the LHS is constant. This model is designed to test the fusion pattern used in `fuse.bmm_lhs_const.py`. diff --git a/tools/circle2circle/circle.py b/tools/circle2circle/circle.py new file mode 100644 index 00000000000..76328683f10 --- /dev/null +++ b/tools/circle2circle/circle.py @@ -0,0 +1,20819 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: circle + +import flatbuffers +from flatbuffers.compat import import_numpy + +np = import_numpy() + + +class TensorType(object): + MXINT8 = -7 + MXFP4 = -6 + GGML_Q8_1 = -5 + GGML_Q8_0 = -4 + GGML_Q4_1 = -3 + GGML_Q4_0 = -2 + UINT4 = -1 + FLOAT32 = 0 + FLOAT16 = 1 + INT32 = 2 + UINT8 = 3 + INT64 = 4 + STRING = 5 + BOOL = 6 + INT16 = 7 + COMPLEX64 = 8 + INT8 = 9 + FLOAT64 = 10 + COMPLEX128 = 11 + UINT64 = 12 + RESOURCE = 13 + VARIANT = 14 + UINT32 = 15 + UINT16 = 16 + INT4 = 17 + + +class QuantizationDetails(object): + NONE = 0 + CustomQuantization = 1 + MXQuantization = 2 + + +def QuantizationDetailsCreator(unionType, table): + from flatbuffers.table import Table + if not isinstance(table, Table): + return None + if unionType == QuantizationDetails().CustomQuantization: + return CustomQuantizationT.InitFromBuf(table.Bytes, table.Pos) + if unionType == QuantizationDetails().MXQuantization: + return MXQuantizationT.InitFromBuf(table.Bytes, table.Pos) + return None + + +class DimensionType(object): + DENSE = 0 + SPARSE_CSR = 1 + + +class SparseIndexVector(object): + NONE = 0 + Int32Vector = 1 + Uint16Vector = 2 + Uint8Vector = 3 + + +def SparseIndexVectorCreator(unionType, table): + from flatbuffers.table import Table + if not isinstance(table, Table): + return None + if unionType == SparseIndexVector().Int32Vector: + return 
Int32VectorT.InitFromBuf(table.Bytes, table.Pos) + if unionType == SparseIndexVector().Uint16Vector: + return Uint16VectorT.InitFromBuf(table.Bytes, table.Pos) + if unionType == SparseIndexVector().Uint8Vector: + return Uint8VectorT.InitFromBuf(table.Bytes, table.Pos) + return None + + +class CompressionType(object): + NONE = 0 + HUFFMAN = 1 + + +class BuiltinOperator(object): + ATTENTION = -9 + RUN_MODEL = -8 + ROPE = -7 + RMS_NORM = -6 + GRU = -5 + BCQ_GATHER = -4 + BCQ_FULLY_CONNECTED = -3 + INSTANCE_NORM = -2 + ADD = 0 + AVERAGE_POOL_2D = 1 + CONCATENATION = 2 + CONV_2D = 3 + DEPTHWISE_CONV_2D = 4 + DEPTH_TO_SPACE = 5 + DEQUANTIZE = 6 + EMBEDDING_LOOKUP = 7 + FLOOR = 8 + FULLY_CONNECTED = 9 + HASHTABLE_LOOKUP = 10 + L2_NORMALIZATION = 11 + L2_POOL_2D = 12 + LOCAL_RESPONSE_NORMALIZATION = 13 + LOGISTIC = 14 + LSH_PROJECTION = 15 + LSTM = 16 + MAX_POOL_2D = 17 + MUL = 18 + RELU = 19 + RELU_N1_TO_1 = 20 + RELU6 = 21 + RESHAPE = 22 + RESIZE_BILINEAR = 23 + RNN = 24 + SOFTMAX = 25 + SPACE_TO_DEPTH = 26 + SVDF = 27 + TANH = 28 + CONCAT_EMBEDDINGS = 29 + SKIP_GRAM = 30 + CALL = 31 + CUSTOM = 32 + EMBEDDING_LOOKUP_SPARSE = 33 + PAD = 34 + UNIDIRECTIONAL_SEQUENCE_RNN = 35 + GATHER = 36 + BATCH_TO_SPACE_ND = 37 + SPACE_TO_BATCH_ND = 38 + TRANSPOSE = 39 + MEAN = 40 + SUB = 41 + DIV = 42 + SQUEEZE = 43 + UNIDIRECTIONAL_SEQUENCE_LSTM = 44 + STRIDED_SLICE = 45 + BIDIRECTIONAL_SEQUENCE_RNN = 46 + EXP = 47 + TOPK_V2 = 48 + SPLIT = 49 + LOG_SOFTMAX = 50 + DELEGATE = 51 + BIDIRECTIONAL_SEQUENCE_LSTM = 52 + CAST = 53 + PRELU = 54 + MAXIMUM = 55 + ARG_MAX = 56 + MINIMUM = 57 + LESS = 58 + NEG = 59 + PADV2 = 60 + GREATER = 61 + GREATER_EQUAL = 62 + LESS_EQUAL = 63 + SELECT = 64 + SLICE = 65 + SIN = 66 + TRANSPOSE_CONV = 67 + SPARSE_TO_DENSE = 68 + TILE = 69 + EXPAND_DIMS = 70 + EQUAL = 71 + NOT_EQUAL = 72 + LOG = 73 + SUM = 74 + SQRT = 75 + RSQRT = 76 + SHAPE = 77 + POW = 78 + ARG_MIN = 79 + FAKE_QUANT = 80 + REDUCE_PROD = 81 + REDUCE_MAX = 82 + PACK = 83 + LOGICAL_OR = 84 + 
ONE_HOT = 85 + LOGICAL_AND = 86 + LOGICAL_NOT = 87 + UNPACK = 88 + REDUCE_MIN = 89 + FLOOR_DIV = 90 + REDUCE_ANY = 91 + SQUARE = 92 + ZEROS_LIKE = 93 + FILL = 94 + FLOOR_MOD = 95 + RANGE = 96 + RESIZE_NEAREST_NEIGHBOR = 97 + LEAKY_RELU = 98 + SQUARED_DIFFERENCE = 99 + MIRROR_PAD = 100 + ABS = 101 + SPLIT_V = 102 + UNIQUE = 103 + CEIL = 104 + REVERSE_V2 = 105 + ADD_N = 106 + GATHER_ND = 107 + COS = 108 + WHERE = 109 + RANK = 110 + ELU = 111 + REVERSE_SEQUENCE = 112 + MATRIX_DIAG = 113 + QUANTIZE = 114 + MATRIX_SET_DIAG = 115 + ROUND = 116 + HARD_SWISH = 117 + IF = 118 + WHILE = 119 + NON_MAX_SUPPRESSION_V4 = 120 + NON_MAX_SUPPRESSION_V5 = 121 + SCATTER_ND = 122 + SELECT_V2 = 123 + DENSIFY = 124 + SEGMENT_SUM = 125 + BATCH_MATMUL = 126 + PLACEHOLDER_FOR_GREATER_OP_CODES = 127 + CUMSUM = 128 + CALL_ONCE = 129 + BROADCAST_TO = 130 + RFFT2D = 131 + CONV_3D = 132 + IMAG = 133 + REAL = 134 + COMPLEX_ABS = 135 + HASHTABLE = 136 + HASHTABLE_FIND = 137 + HASHTABLE_IMPORT = 138 + HASHTABLE_SIZE = 139 + REDUCE_ALL = 140 + CONV_3D_TRANSPOSE = 141 + VAR_HANDLE = 142 + READ_VARIABLE = 143 + ASSIGN_VARIABLE = 144 + BROADCAST_ARGS = 145 + RANDOM_STANDARD_NORMAL = 146 + BUCKETIZE = 147 + RANDOM_UNIFORM = 148 + MULTINOMIAL = 149 + GELU = 150 + DYNAMIC_UPDATE_SLICE = 151 + RELU_0_TO_1 = 152 + UNSORTED_SEGMENT_PROD = 153 + UNSORTED_SEGMENT_MAX = 154 + UNSORTED_SEGMENT_SUM = 155 + ATAN2 = 156 + UNSORTED_SEGMENT_MIN = 157 + SIGN = 158 + BITCAST = 159 + BITWISE_XOR = 160 + RIGHT_SHIFT = 161 + STABLEHLO_LOGISTIC = 162 + STABLEHLO_ADD = 163 + STABLEHLO_DIVIDE = 164 + STABLEHLO_MULTIPLY = 165 + STABLEHLO_MAXIMUM = 166 + STABLEHLO_RESHAPE = 167 + STABLEHLO_CLAMP = 168 + STABLEHLO_CONCATENATE = 169 + STABLEHLO_BROADCAST_IN_DIM = 170 + STABLEHLO_CONVOLUTION = 171 + STABLEHLO_SLICE = 172 + STABLEHLO_CUSTOM_CALL = 173 + STABLEHLO_REDUCE = 174 + STABLEHLO_ABS = 175 + STABLEHLO_AND = 176 + STABLEHLO_COSINE = 177 + STABLEHLO_EXPONENTIAL = 178 + STABLEHLO_FLOOR = 179 + STABLEHLO_LOG = 180 + 
STABLEHLO_MINIMUM = 181 + STABLEHLO_NEGATE = 182 + STABLEHLO_OR = 183 + STABLEHLO_POWER = 184 + STABLEHLO_REMAINDER = 185 + STABLEHLO_RSQRT = 186 + STABLEHLO_SELECT = 187 + STABLEHLO_SUBTRACT = 188 + STABLEHLO_TANH = 189 + STABLEHLO_SCATTER = 190 + STABLEHLO_COMPARE = 191 + STABLEHLO_CONVERT = 192 + STABLEHLO_DYNAMIC_SLICE = 193 + STABLEHLO_DYNAMIC_UPDATE_SLICE = 194 + STABLEHLO_PAD = 195 + STABLEHLO_IOTA = 196 + STABLEHLO_DOT_GENERAL = 197 + STABLEHLO_REDUCE_WINDOW = 198 + STABLEHLO_SORT = 199 + STABLEHLO_WHILE = 200 + STABLEHLO_GATHER = 201 + STABLEHLO_TRANSPOSE = 202 + DILATE = 203 + STABLEHLO_RNG_BIT_GENERATOR = 204 + REDUCE_WINDOW = 205 + + +class BuiltinOptions(object): + NONE = 0 + Conv2DOptions = 1 + DepthwiseConv2DOptions = 2 + ConcatEmbeddingsOptions = 3 + LSHProjectionOptions = 4 + Pool2DOptions = 5 + SVDFOptions = 6 + RNNOptions = 7 + FullyConnectedOptions = 8 + SoftmaxOptions = 9 + ConcatenationOptions = 10 + AddOptions = 11 + L2NormOptions = 12 + LocalResponseNormalizationOptions = 13 + LSTMOptions = 14 + ResizeBilinearOptions = 15 + CallOptions = 16 + ReshapeOptions = 17 + SkipGramOptions = 18 + SpaceToDepthOptions = 19 + EmbeddingLookupSparseOptions = 20 + MulOptions = 21 + PadOptions = 22 + GatherOptions = 23 + BatchToSpaceNDOptions = 24 + SpaceToBatchNDOptions = 25 + TransposeOptions = 26 + ReducerOptions = 27 + SubOptions = 28 + DivOptions = 29 + SqueezeOptions = 30 + SequenceRNNOptions = 31 + StridedSliceOptions = 32 + ExpOptions = 33 + TopKV2Options = 34 + SplitOptions = 35 + LogSoftmaxOptions = 36 + CastOptions = 37 + DequantizeOptions = 38 + MaximumMinimumOptions = 39 + ArgMaxOptions = 40 + LessOptions = 41 + NegOptions = 42 + PadV2Options = 43 + GreaterOptions = 44 + GreaterEqualOptions = 45 + LessEqualOptions = 46 + SelectOptions = 47 + SliceOptions = 48 + TransposeConvOptions = 49 + SparseToDenseOptions = 50 + TileOptions = 51 + ExpandDimsOptions = 52 + EqualOptions = 53 + NotEqualOptions = 54 + ShapeOptions = 55 + PowOptions = 56 + 
ArgMinOptions = 57 + FakeQuantOptions = 58 + PackOptions = 59 + LogicalOrOptions = 60 + OneHotOptions = 61 + LogicalAndOptions = 62 + LogicalNotOptions = 63 + UnpackOptions = 64 + FloorDivOptions = 65 + SquareOptions = 66 + ZerosLikeOptions = 67 + FillOptions = 68 + BidirectionalSequenceLSTMOptions = 69 + BidirectionalSequenceRNNOptions = 70 + UnidirectionalSequenceLSTMOptions = 71 + FloorModOptions = 72 + RangeOptions = 73 + ResizeNearestNeighborOptions = 74 + LeakyReluOptions = 75 + SquaredDifferenceOptions = 76 + MirrorPadOptions = 77 + AbsOptions = 78 + SplitVOptions = 79 + UniqueOptions = 80 + ReverseV2Options = 81 + AddNOptions = 82 + GatherNdOptions = 83 + CosOptions = 84 + WhereOptions = 85 + RankOptions = 86 + ReverseSequenceOptions = 87 + MatrixDiagOptions = 88 + QuantizeOptions = 89 + MatrixSetDiagOptions = 90 + HardSwishOptions = 91 + IfOptions = 92 + WhileOptions = 93 + DepthToSpaceOptions = 94 + NonMaxSuppressionV4Options = 95 + NonMaxSuppressionV5Options = 96 + ScatterNdOptions = 97 + SelectV2Options = 98 + DensifyOptions = 99 + SegmentSumOptions = 100 + BatchMatMulOptions = 101 + CumsumOptions = 102 + CallOnceOptions = 103 + BroadcastToOptions = 104 + Rfft2dOptions = 105 + Conv3DOptions = 106 + HashtableOptions = 107 + HashtableFindOptions = 108 + HashtableImportOptions = 109 + HashtableSizeOptions = 110 + VarHandleOptions = 111 + ReadVariableOptions = 112 + AssignVariableOptions = 113 + RandomOptions = 114 + BucketizeOptions = 115 + GeluOptions = 116 + DynamicUpdateSliceOptions = 117 + UnsortedSegmentProdOptions = 118 + UnsortedSegmentMaxOptions = 119 + UnsortedSegmentMinOptions = 120 + UnsortedSegmentSumOptions = 121 + ATan2Options = 122 + SignOptions = 123 + BitcastOptions = 124 + BitwiseXorOptions = 125 + RightShiftOptions = 126 + AttentionOptions = 247 + RunModelOptions = 248 + RoPEOptions = 249 + RmsNormOptions = 250 + GRUOptions = 251 + BCQGatherOptions = 252 + BCQFullyConnectedOptions = 253 + InstanceNormOptions = 254 + + +def 
BuiltinOptionsCreator(unionType, table): + from flatbuffers.table import Table + if not isinstance(table, Table): + return None + if unionType == BuiltinOptions().Conv2DOptions: + return Conv2DOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DepthwiseConv2DOptions: + return DepthwiseConv2DOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ConcatEmbeddingsOptions: + return ConcatEmbeddingsOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LSHProjectionOptions: + return LSHProjectionOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().Pool2DOptions: + return Pool2DOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SVDFOptions: + return SVDFOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RNNOptions: + return RNNOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().FullyConnectedOptions: + return FullyConnectedOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SoftmaxOptions: + return SoftmaxOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ConcatenationOptions: + return ConcatenationOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().AddOptions: + return AddOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().L2NormOptions: + return L2NormOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LocalResponseNormalizationOptions: + return LocalResponseNormalizationOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LSTMOptions: + return LSTMOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ResizeBilinearOptions: + return ResizeBilinearOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().CallOptions: + return CallOptionsT.InitFromBuf(table.Bytes, 
table.Pos) + if unionType == BuiltinOptions().ReshapeOptions: + return ReshapeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SkipGramOptions: + return SkipGramOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SpaceToDepthOptions: + return SpaceToDepthOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().EmbeddingLookupSparseOptions: + return EmbeddingLookupSparseOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().MulOptions: + return MulOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().PadOptions: + return PadOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GatherOptions: + return GatherOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BatchToSpaceNDOptions: + return BatchToSpaceNDOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SpaceToBatchNDOptions: + return SpaceToBatchNDOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().TransposeOptions: + return TransposeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ReducerOptions: + return ReducerOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SubOptions: + return SubOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DivOptions: + return DivOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SqueezeOptions: + return SqueezeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SequenceRNNOptions: + return SequenceRNNOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().StridedSliceOptions: + return StridedSliceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ExpOptions: + return ExpOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == 
BuiltinOptions().TopKV2Options: + return TopKV2OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SplitOptions: + return SplitOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LogSoftmaxOptions: + return LogSoftmaxOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().CastOptions: + return CastOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DequantizeOptions: + return DequantizeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().MaximumMinimumOptions: + return MaximumMinimumOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ArgMaxOptions: + return ArgMaxOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LessOptions: + return LessOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().NegOptions: + return NegOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().PadV2Options: + return PadV2OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GreaterOptions: + return GreaterOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GreaterEqualOptions: + return GreaterEqualOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LessEqualOptions: + return LessEqualOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SelectOptions: + return SelectOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SliceOptions: + return SliceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().TransposeConvOptions: + return TransposeConvOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SparseToDenseOptions: + return SparseToDenseOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().TileOptions: + return 
TileOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ExpandDimsOptions: + return ExpandDimsOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().EqualOptions: + return EqualOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().NotEqualOptions: + return NotEqualOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ShapeOptions: + return ShapeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().PowOptions: + return PowOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ArgMinOptions: + return ArgMinOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().FakeQuantOptions: + return FakeQuantOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().PackOptions: + return PackOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LogicalOrOptions: + return LogicalOrOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().OneHotOptions: + return OneHotOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LogicalAndOptions: + return LogicalAndOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LogicalNotOptions: + return LogicalNotOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UnpackOptions: + return UnpackOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().FloorDivOptions: + return FloorDivOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SquareOptions: + return SquareOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ZerosLikeOptions: + return ZerosLikeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().FillOptions: + return FillOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == 
BuiltinOptions().BidirectionalSequenceLSTMOptions: + return BidirectionalSequenceLSTMOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BidirectionalSequenceRNNOptions: + return BidirectionalSequenceRNNOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UnidirectionalSequenceLSTMOptions: + return UnidirectionalSequenceLSTMOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().FloorModOptions: + return FloorModOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RangeOptions: + return RangeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ResizeNearestNeighborOptions: + return ResizeNearestNeighborOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().LeakyReluOptions: + return LeakyReluOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SquaredDifferenceOptions: + return SquaredDifferenceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().MirrorPadOptions: + return MirrorPadOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().AbsOptions: + return AbsOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SplitVOptions: + return SplitVOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UniqueOptions: + return UniqueOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ReverseV2Options: + return ReverseV2OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().AddNOptions: + return AddNOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GatherNdOptions: + return GatherNdOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().CosOptions: + return CosOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().WhereOptions: + return 
WhereOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RankOptions: + return RankOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ReverseSequenceOptions: + return ReverseSequenceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().MatrixDiagOptions: + return MatrixDiagOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().QuantizeOptions: + return QuantizeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().MatrixSetDiagOptions: + return MatrixSetDiagOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().HardSwishOptions: + return HardSwishOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().IfOptions: + return IfOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().WhileOptions: + return WhileOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DepthToSpaceOptions: + return DepthToSpaceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().NonMaxSuppressionV4Options: + return NonMaxSuppressionV4OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().NonMaxSuppressionV5Options: + return NonMaxSuppressionV5OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ScatterNdOptions: + return ScatterNdOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SelectV2Options: + return SelectV2OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DensifyOptions: + return DensifyOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SegmentSumOptions: + return SegmentSumOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BatchMatMulOptions: + return BatchMatMulOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().CumsumOptions: + return 
CumsumOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().CallOnceOptions: + return CallOnceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BroadcastToOptions: + return BroadcastToOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().Rfft2dOptions: + return Rfft2dOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().Conv3DOptions: + return Conv3DOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().HashtableOptions: + return HashtableOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().HashtableFindOptions: + return HashtableFindOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().HashtableImportOptions: + return HashtableImportOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().HashtableSizeOptions: + return HashtableSizeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().VarHandleOptions: + return VarHandleOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ReadVariableOptions: + return ReadVariableOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().AssignVariableOptions: + return AssignVariableOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RandomOptions: + return RandomOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BucketizeOptions: + return BucketizeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GeluOptions: + return GeluOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().DynamicUpdateSliceOptions: + return DynamicUpdateSliceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UnsortedSegmentProdOptions: + return UnsortedSegmentProdOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == 
BuiltinOptions().UnsortedSegmentMaxOptions: + return UnsortedSegmentMaxOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UnsortedSegmentMinOptions: + return UnsortedSegmentMinOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().UnsortedSegmentSumOptions: + return UnsortedSegmentSumOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().ATan2Options: + return ATan2OptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().SignOptions: + return SignOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BitcastOptions: + return BitcastOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BitwiseXorOptions: + return BitwiseXorOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RightShiftOptions: + return RightShiftOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().AttentionOptions: + return AttentionOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RunModelOptions: + return RunModelOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RoPEOptions: + return RoPEOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().RmsNormOptions: + return RmsNormOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().GRUOptions: + return GRUOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BCQGatherOptions: + return BCQGatherOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().BCQFullyConnectedOptions: + return BCQFullyConnectedOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions().InstanceNormOptions: + return InstanceNormOptionsT.InitFromBuf(table.Bytes, table.Pos) + return None + + +class BuiltinOptions2(object): + NONE = 0 + StablehloConcatenateOptions = 1 + StablehloBroadcastInDimOptions = 2 
+ StablehloSliceOptions = 3 + StablehloConvolutionOptions = 4 + StablehloCustomCallOptions = 5 + StablehloReduceOptions = 6 + StablehloScatterOptions = 7 + StablehloCompareOptions = 8 + StablehloDynamicSliceOptions = 9 + StablehloPadOptions = 10 + StablehloIotaOptions = 11 + StablehloDotGeneralOptions = 12 + StablehloReduceWindowOptions = 13 + StablehloSortOptions = 14 + StablehloWhileOptions = 15 + StablehloGatherOptions = 16 + StablehloTransposeOptions = 17 + DilateOptions = 18 + StablehloRngBitGeneratorOptions = 19 + ReduceWindowOptions = 20 + + +def BuiltinOptions2Creator(unionType, table): + from flatbuffers.table import Table + if not isinstance(table, Table): + return None + if unionType == BuiltinOptions2().StablehloConcatenateOptions: + return StablehloConcatenateOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloBroadcastInDimOptions: + return StablehloBroadcastInDimOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloSliceOptions: + return StablehloSliceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloConvolutionOptions: + return StablehloConvolutionOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloCustomCallOptions: + return StablehloCustomCallOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloReduceOptions: + return StablehloReduceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloScatterOptions: + return StablehloScatterOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloCompareOptions: + return StablehloCompareOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloDynamicSliceOptions: + return StablehloDynamicSliceOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloPadOptions: + return 
StablehloPadOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloIotaOptions: + return StablehloIotaOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloDotGeneralOptions: + return StablehloDotGeneralOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloReduceWindowOptions: + return StablehloReduceWindowOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloSortOptions: + return StablehloSortOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloWhileOptions: + return StablehloWhileOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloGatherOptions: + return StablehloGatherOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloTransposeOptions: + return StablehloTransposeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().DilateOptions: + return DilateOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloRngBitGeneratorOptions: + return StablehloRngBitGeneratorOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().ReduceWindowOptions: + return ReduceWindowOptionsT.InitFromBuf(table.Bytes, table.Pos) + return None + + +class StablehloPrecisionConfig(object): + DEFAULT = 0 + HIGH = 1 + HIGHEST = 2 + + +class StablehloComparisonDirection(object): + STABLEHLO_COMPARISON_DIRECTION_EQ = 0 + STABLEHLO_COMPARISON_DIRECTION_NE = 1 + STABLEHLO_COMPARISON_DIRECTION_GE = 2 + STABLEHLO_COMPARISON_DIRECTION_GT = 3 + STABLEHLO_COMPARISON_DIRECTION_LE = 4 + STABLEHLO_COMPARISON_DIRECTION_LT = 5 + + +class StablehloComparisonType(object): + STABLEHLO_COMPARISON_TYPE_NOTYPE = 0 + STABLEHLO_COMPARISON_TYPE_FLOAT = 1 + STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER = 2 + STABLEHLO_COMPARISON_TYPE_SIGNED = 3 + STABLEHLO_COMPARISON_TYPE_UNSIGNED = 4 + + 
+class RngAlgorithm(object): + DEFAULT = 0 + PHILOX = 1 + THREEFRY = 2 + + +class Padding(object): + SAME = 0 + VALID = 1 + + +class ActivationFunctionType(object): + NONE = 0 + RELU = 1 + RELU_N1_TO_1 = 2 + RELU6 = 3 + TANH = 4 + SIGN_BIT = 5 + + +class LSHProjectionType(object): + UNKNOWN = 0 + SPARSE = 1 + DENSE = 2 + + +class FullyConnectedOptionsWeightsFormat(object): + DEFAULT = 0 + SHUFFLED4x16INT8 = 1 + SHUFFLED16x1FLOAT32 = 127 + + +class LSTMKernelType(object): + FULL = 0 + BASIC = 1 + + +class CombinerType(object): + SUM = 0 + MEAN = 1 + SQRTN = 2 + + +class MirrorPadMode(object): + REFLECT = 0 + SYMMETRIC = 1 + + +class ReduceWindowFunction(object): + UNSUPPORTED = 0 + ADD = 1 + MUL = 2 + MINIMUM = 3 + MAXIMUM = 4 + ALL = 5 + ANY = 6 + + +class RoPEMode(object): + GPT_NEOX = 0 + GPT_J = 1 + + +class CustomOptionsFormat(object): + FLEXBUFFERS = 0 + + +class DataFormat(object): + CHANNELS_LAST = 0 + CHANNELS_FIRST = 1 + + +class CustomQuantization(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CustomQuantization() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCustomQuantization(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CustomQuantizationBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CustomQuantization + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CustomQuantization + def Custom(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # CustomQuantization + def CustomAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # CustomQuantization + def CustomLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # CustomQuantization + def CustomIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def CustomQuantizationStart(builder): + builder.StartObject(1) + + +def CustomQuantizationAddCustom(builder, custom): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(custom), 0) + + +def CustomQuantizationStartCustomVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def CustomQuantizationEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class CustomQuantizationT(object): + + # CustomQuantizationT + def __init__(self): + self.custom = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + customQuantization = CustomQuantization() + customQuantization.Init(buf, pos) + return cls.InitFromObj(customQuantization) + + 
@classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, customQuantization): + x = CustomQuantizationT() + x._UnPack(customQuantization) + return x + + # CustomQuantizationT + def _UnPack(self, customQuantization): + if customQuantization is None: + return + if not customQuantization.CustomIsNone(): + if np is None: + self.custom = [] + for i in range(customQuantization.CustomLength()): + self.custom.append(customQuantization.Custom(i)) + else: + self.custom = customQuantization.CustomAsNumpy() + + # CustomQuantizationT + def Pack(self, builder): + if self.custom is not None: + if np is not None and type(self.custom) is np.ndarray: + custom = builder.CreateNumpyVector(self.custom) + else: + CustomQuantizationStartCustomVector(builder, len(self.custom)) + for i in reversed(range(len(self.custom))): + builder.PrependUint8(self.custom[i]) + custom = builder.EndVector() + CustomQuantizationStart(builder) + if self.custom is not None: + CustomQuantizationAddCustom(builder, custom) + customQuantization = CustomQuantizationEnd(builder) + return customQuantization + + +class MXQuantization(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MXQuantization() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMXQuantization(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MXQuantizationBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MXQuantization + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MXQuantization + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def MXQuantizationStart(builder): + builder.StartObject(1) + + +def MXQuantizationAddAxis(builder, axis): + builder.PrependInt32Slot(0, axis, 0) + + +def MXQuantizationEnd(builder): + return builder.EndObject() + + +class MXQuantizationT(object): + + # MXQuantizationT + def __init__(self): + self.axis = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + mxquantization = MXQuantization() + mxquantization.Init(buf, pos) + return cls.InitFromObj(mxquantization) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, mxquantization): + x = MXQuantizationT() + x._UnPack(mxquantization) + return x + + # MXQuantizationT + def _UnPack(self, mxquantization): + if mxquantization is None: + return + self.axis = mxquantization.Axis() + + # MXQuantizationT + def Pack(self, builder): + MXQuantizationStart(builder) + MXQuantizationAddAxis(builder, self.axis) + mxquantization = MXQuantizationEnd(builder) + return mxquantization + + +class QuantizationParameters(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = QuantizationParameters() + x.Init(buf, n + offset) + return x + + @classmethod + def 
GetRootAsQuantizationParameters(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def QuantizationParametersBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # QuantizationParameters + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # QuantizationParameters + def Min(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Float32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # QuantizationParameters + def MinAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o) + return 0 + + # QuantizationParameters + def MinLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # QuantizationParameters + def MinIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # QuantizationParameters + def Max(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Float32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # QuantizationParameters + def MaxAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o) + return 0 + + # QuantizationParameters + def MaxLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + 
return self._tab.VectorLen(o) + return 0 + + # QuantizationParameters + def MaxIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # QuantizationParameters + def Scale(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Float32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # QuantizationParameters + def ScaleAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o) + return 0 + + # QuantizationParameters + def ScaleLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # QuantizationParameters + def ScaleIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # QuantizationParameters + def ZeroPoint(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # QuantizationParameters + def ZeroPointAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # QuantizationParameters + def ZeroPointLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # QuantizationParameters + def ZeroPointIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # QuantizationParameters + def DetailsType(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # QuantizationParameters + def Details(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + # QuantizationParameters + def QuantizedDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def QuantizationParametersStart(builder): + builder.StartObject(7) + + +def QuantizationParametersAddMin(builder, min): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(min), 0) + + +def QuantizationParametersStartMinVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def QuantizationParametersAddMax(builder, max): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(max), 0) + + +def QuantizationParametersStartMaxVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def QuantizationParametersAddScale(builder, scale): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(scale), 0) + + +def QuantizationParametersStartScaleVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def QuantizationParametersAddZeroPoint(builder, zeroPoint): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(zeroPoint), 0) + + +def QuantizationParametersStartZeroPointVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def QuantizationParametersAddDetailsType(builder, detailsType): + builder.PrependUint8Slot(4, detailsType, 0) + + +def QuantizationParametersAddDetails(builder, 
details): + builder.PrependUOffsetTRelativeSlot( + 5, flatbuffers.number_types.UOffsetTFlags.py_type(details), 0) + + +def QuantizationParametersAddQuantizedDimension(builder, quantizedDimension): + builder.PrependInt32Slot(6, quantizedDimension, 0) + + +def QuantizationParametersEnd(builder): + return builder.EndObject() + + +try: + from typing import List, Union +except: + pass + + +class QuantizationParametersT(object): + + # QuantizationParametersT + def __init__(self): + self.min = None # type: List[float] + self.max = None # type: List[float] + self.scale = None # type: List[float] + self.zeroPoint = None # type: List[int] + self.detailsType = 0 # type: int + self.details = None # type: Union[None, CustomQuantizationT, MXQuantizationT] + self.quantizedDimension = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + quantizationParameters = QuantizationParameters() + quantizationParameters.Init(buf, pos) + return cls.InitFromObj(quantizationParameters) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, quantizationParameters): + x = QuantizationParametersT() + x._UnPack(quantizationParameters) + return x + + # QuantizationParametersT + def _UnPack(self, quantizationParameters): + if quantizationParameters is None: + return + if not quantizationParameters.MinIsNone(): + if np is None: + self.min = [] + for i in range(quantizationParameters.MinLength()): + self.min.append(quantizationParameters.Min(i)) + else: + self.min = quantizationParameters.MinAsNumpy() + if not quantizationParameters.MaxIsNone(): + if np is None: + self.max = [] + for i in range(quantizationParameters.MaxLength()): + self.max.append(quantizationParameters.Max(i)) + else: + self.max = quantizationParameters.MaxAsNumpy() + if not quantizationParameters.ScaleIsNone(): + if np is None: + self.scale = [] + for i in 
range(quantizationParameters.ScaleLength()): + self.scale.append(quantizationParameters.Scale(i)) + else: + self.scale = quantizationParameters.ScaleAsNumpy() + if not quantizationParameters.ZeroPointIsNone(): + if np is None: + self.zeroPoint = [] + for i in range(quantizationParameters.ZeroPointLength()): + self.zeroPoint.append(quantizationParameters.ZeroPoint(i)) + else: + self.zeroPoint = quantizationParameters.ZeroPointAsNumpy() + self.detailsType = quantizationParameters.DetailsType() + self.details = QuantizationDetailsCreator(self.detailsType, + quantizationParameters.Details()) + self.quantizedDimension = quantizationParameters.QuantizedDimension() + + # QuantizationParametersT + def Pack(self, builder): + if self.min is not None: + if np is not None and type(self.min) is np.ndarray: + min = builder.CreateNumpyVector(self.min) + else: + QuantizationParametersStartMinVector(builder, len(self.min)) + for i in reversed(range(len(self.min))): + builder.PrependFloat32(self.min[i]) + min = builder.EndVector() + if self.max is not None: + if np is not None and type(self.max) is np.ndarray: + max = builder.CreateNumpyVector(self.max) + else: + QuantizationParametersStartMaxVector(builder, len(self.max)) + for i in reversed(range(len(self.max))): + builder.PrependFloat32(self.max[i]) + max = builder.EndVector() + if self.scale is not None: + if np is not None and type(self.scale) is np.ndarray: + scale = builder.CreateNumpyVector(self.scale) + else: + QuantizationParametersStartScaleVector(builder, len(self.scale)) + for i in reversed(range(len(self.scale))): + builder.PrependFloat32(self.scale[i]) + scale = builder.EndVector() + if self.zeroPoint is not None: + if np is not None and type(self.zeroPoint) is np.ndarray: + zeroPoint = builder.CreateNumpyVector(self.zeroPoint) + else: + QuantizationParametersStartZeroPointVector(builder, len(self.zeroPoint)) + for i in reversed(range(len(self.zeroPoint))): + builder.PrependInt64(self.zeroPoint[i]) + zeroPoint = 
builder.EndVector() + if self.details is not None: + details = self.details.Pack(builder) + QuantizationParametersStart(builder) + if self.min is not None: + QuantizationParametersAddMin(builder, min) + if self.max is not None: + QuantizationParametersAddMax(builder, max) + if self.scale is not None: + QuantizationParametersAddScale(builder, scale) + if self.zeroPoint is not None: + QuantizationParametersAddZeroPoint(builder, zeroPoint) + QuantizationParametersAddDetailsType(builder, self.detailsType) + if self.details is not None: + QuantizationParametersAddDetails(builder, details) + QuantizationParametersAddQuantizedDimension(builder, self.quantizedDimension) + quantizationParameters = QuantizationParametersEnd(builder) + return quantizationParameters + + +class Int32Vector(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Int32Vector() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsInt32Vector(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Int32VectorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Int32Vector + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Int32Vector + def Values(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Int32Vector + def ValuesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Int32Vector + def ValuesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Int32Vector + def ValuesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def Int32VectorStart(builder): + builder.StartObject(1) + + +def Int32VectorAddValues(builder, values): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0) + + +def Int32VectorStartValuesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def Int32VectorEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class Int32VectorT(object): + + # Int32VectorT + def __init__(self): + self.values = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + int32Vector = Int32Vector() + int32Vector.Init(buf, pos) + return cls.InitFromObj(int32Vector) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 
pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, int32Vector): + x = Int32VectorT() + x._UnPack(int32Vector) + return x + + # Int32VectorT + def _UnPack(self, int32Vector): + if int32Vector is None: + return + if not int32Vector.ValuesIsNone(): + if np is None: + self.values = [] + for i in range(int32Vector.ValuesLength()): + self.values.append(int32Vector.Values(i)) + else: + self.values = int32Vector.ValuesAsNumpy() + + # Int32VectorT + def Pack(self, builder): + if self.values is not None: + if np is not None and type(self.values) is np.ndarray: + values = builder.CreateNumpyVector(self.values) + else: + Int32VectorStartValuesVector(builder, len(self.values)) + for i in reversed(range(len(self.values))): + builder.PrependInt32(self.values[i]) + values = builder.EndVector() + Int32VectorStart(builder) + if self.values is not None: + Int32VectorAddValues(builder, values) + int32Vector = Int32VectorEnd(builder) + return int32Vector + + +class Uint16Vector(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Uint16Vector() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsUint16Vector(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Uint16VectorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Uint16Vector + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Uint16Vector + def Values(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint16Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2)) + return 0 + + # Uint16Vector + def ValuesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint16Flags, o) + return 0 + + # Uint16Vector + def ValuesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Uint16Vector + def ValuesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def Uint16VectorStart(builder): + builder.StartObject(1) + + +def Uint16VectorAddValues(builder, values): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0) + + +def Uint16VectorStartValuesVector(builder, numElems): + return builder.StartVector(2, numElems, 2) + + +def Uint16VectorEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class Uint16VectorT(object): + + # Uint16VectorT + def __init__(self): + self.values = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + uint16Vector = Uint16Vector() + uint16Vector.Init(buf, pos) + return cls.InitFromObj(uint16Vector) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, uint16Vector): + x = Uint16VectorT() + x._UnPack(uint16Vector) + return x + + # Uint16VectorT + def _UnPack(self, uint16Vector): + if uint16Vector is None: + return + if not uint16Vector.ValuesIsNone(): + if np is None: + self.values = [] + for i in range(uint16Vector.ValuesLength()): + self.values.append(uint16Vector.Values(i)) + else: + self.values = uint16Vector.ValuesAsNumpy() + + # Uint16VectorT + def Pack(self, builder): + if self.values is not None: + if np is not None and type(self.values) is np.ndarray: + values = builder.CreateNumpyVector(self.values) + else: + Uint16VectorStartValuesVector(builder, len(self.values)) + for i in reversed(range(len(self.values))): + builder.PrependUint16(self.values[i]) + values = builder.EndVector() + Uint16VectorStart(builder) + if self.values is not None: + Uint16VectorAddValues(builder, values) + uint16Vector = Uint16VectorEnd(builder) + return uint16Vector + + +class Uint8Vector(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Uint8Vector() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsUint8Vector(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Uint8VectorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Uint8Vector + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Uint8Vector + def Values(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # Uint8Vector + def ValuesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # Uint8Vector + def ValuesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Uint8Vector + def ValuesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def Uint8VectorStart(builder): + builder.StartObject(1) + + +def Uint8VectorAddValues(builder, values): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0) + + +def Uint8VectorStartValuesVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def Uint8VectorEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class Uint8VectorT(object): + + # Uint8VectorT + def __init__(self): + self.values = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + uint8Vector = Uint8Vector() + uint8Vector.Init(buf, pos) + return cls.InitFromObj(uint8Vector) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 
pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, uint8Vector): + x = Uint8VectorT() + x._UnPack(uint8Vector) + return x + + # Uint8VectorT + def _UnPack(self, uint8Vector): + if uint8Vector is None: + return + if not uint8Vector.ValuesIsNone(): + if np is None: + self.values = [] + for i in range(uint8Vector.ValuesLength()): + self.values.append(uint8Vector.Values(i)) + else: + self.values = uint8Vector.ValuesAsNumpy() + + # Uint8VectorT + def Pack(self, builder): + if self.values is not None: + if np is not None and type(self.values) is np.ndarray: + values = builder.CreateNumpyVector(self.values) + else: + Uint8VectorStartValuesVector(builder, len(self.values)) + for i in reversed(range(len(self.values))): + builder.PrependUint8(self.values[i]) + values = builder.EndVector() + Uint8VectorStart(builder) + if self.values is not None: + Uint8VectorAddValues(builder, values) + uint8Vector = Uint8VectorEnd(builder) + return uint8Vector + + +class DimensionMetadata(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DimensionMetadata() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDimensionMetadata(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DimensionMetadataBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DimensionMetadata + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DimensionMetadata + def Format(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # DimensionMetadata + def DenseSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # DimensionMetadata + def ArraySegmentsType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # DimensionMetadata + def ArraySegments(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + # DimensionMetadata + def ArrayIndicesType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # DimensionMetadata + def ArrayIndices(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + +def DimensionMetadataStart(builder): + builder.StartObject(6) + + +def DimensionMetadataAddFormat(builder, format): + builder.PrependInt8Slot(0, format, 0) + + +def 
DimensionMetadataAddDenseSize(builder, denseSize): + builder.PrependInt32Slot(1, denseSize, 0) + + +def DimensionMetadataAddArraySegmentsType(builder, arraySegmentsType): + builder.PrependUint8Slot(2, arraySegmentsType, 0) + + +def DimensionMetadataAddArraySegments(builder, arraySegments): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(arraySegments), 0) + + +def DimensionMetadataAddArrayIndicesType(builder, arrayIndicesType): + builder.PrependUint8Slot(4, arrayIndicesType, 0) + + +def DimensionMetadataAddArrayIndices(builder, arrayIndices): + builder.PrependUOffsetTRelativeSlot( + 5, flatbuffers.number_types.UOffsetTFlags.py_type(arrayIndices), 0) + + +def DimensionMetadataEnd(builder): + return builder.EndObject() + + +try: + from typing import Union +except: + pass + + +class DimensionMetadataT(object): + + # DimensionMetadataT + def __init__(self): + self.format = 0 # type: int + self.denseSize = 0 # type: int + self.arraySegmentsType = 0 # type: int + self.arraySegments = None # type: Union[None, Int32VectorT, Uint16VectorT, Uint8VectorT] + self.arrayIndicesType = 0 # type: int + self.arrayIndices = None # type: Union[None, Int32VectorT, Uint16VectorT, Uint8VectorT] + + @classmethod + def InitFromBuf(cls, buf, pos): + dimensionMetadata = DimensionMetadata() + dimensionMetadata.Init(buf, pos) + return cls.InitFromObj(dimensionMetadata) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, dimensionMetadata): + x = DimensionMetadataT() + x._UnPack(dimensionMetadata) + return x + + # DimensionMetadataT + def _UnPack(self, dimensionMetadata): + if dimensionMetadata is None: + return + self.format = dimensionMetadata.Format() + self.denseSize = dimensionMetadata.DenseSize() + self.arraySegmentsType = dimensionMetadata.ArraySegmentsType() + self.arraySegments = 
SparseIndexVectorCreator(self.arraySegmentsType, + dimensionMetadata.ArraySegments()) + self.arrayIndicesType = dimensionMetadata.ArrayIndicesType() + self.arrayIndices = SparseIndexVectorCreator(self.arrayIndicesType, + dimensionMetadata.ArrayIndices()) + + # DimensionMetadataT + def Pack(self, builder): + if self.arraySegments is not None: + arraySegments = self.arraySegments.Pack(builder) + if self.arrayIndices is not None: + arrayIndices = self.arrayIndices.Pack(builder) + DimensionMetadataStart(builder) + DimensionMetadataAddFormat(builder, self.format) + DimensionMetadataAddDenseSize(builder, self.denseSize) + DimensionMetadataAddArraySegmentsType(builder, self.arraySegmentsType) + if self.arraySegments is not None: + DimensionMetadataAddArraySegments(builder, arraySegments) + DimensionMetadataAddArrayIndicesType(builder, self.arrayIndicesType) + if self.arrayIndices is not None: + DimensionMetadataAddArrayIndices(builder, arrayIndices) + dimensionMetadata = DimensionMetadataEnd(builder) + return dimensionMetadata + + +class SparsityParameters(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SparsityParameters() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSparsityParameters(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SparsityParametersBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SparsityParameters + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SparsityParameters + def TraversalOrder(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SparsityParameters + def TraversalOrderAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SparsityParameters + def TraversalOrderLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SparsityParameters + def TraversalOrderIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # SparsityParameters + def BlockMap(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SparsityParameters + def BlockMapAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SparsityParameters + def BlockMapLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SparsityParameters + def 
BlockMapIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # SparsityParameters + def DimMetadata(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = DimensionMetadata() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # SparsityParameters + def DimMetadataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SparsityParameters + def DimMetadataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + +def SparsityParametersStart(builder): + builder.StartObject(3) + + +def SparsityParametersAddTraversalOrder(builder, traversalOrder): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(traversalOrder), 0) + + +def SparsityParametersStartTraversalOrderVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SparsityParametersAddBlockMap(builder, blockMap): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(blockMap), 0) + + +def SparsityParametersStartBlockMapVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SparsityParametersAddDimMetadata(builder, dimMetadata): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(dimMetadata), 0) + + +def SparsityParametersStartDimMetadataVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SparsityParametersEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class SparsityParametersT(object): + + # SparsityParametersT + def __init__(self): + self.traversalOrder = None # type: List[int] + self.blockMap = None # 
type: List[int] + self.dimMetadata = None # type: List[DimensionMetadataT] + + @classmethod + def InitFromBuf(cls, buf, pos): + sparsityParameters = SparsityParameters() + sparsityParameters.Init(buf, pos) + return cls.InitFromObj(sparsityParameters) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, sparsityParameters): + x = SparsityParametersT() + x._UnPack(sparsityParameters) + return x + + # SparsityParametersT + def _UnPack(self, sparsityParameters): + if sparsityParameters is None: + return + if not sparsityParameters.TraversalOrderIsNone(): + if np is None: + self.traversalOrder = [] + for i in range(sparsityParameters.TraversalOrderLength()): + self.traversalOrder.append(sparsityParameters.TraversalOrder(i)) + else: + self.traversalOrder = sparsityParameters.TraversalOrderAsNumpy() + if not sparsityParameters.BlockMapIsNone(): + if np is None: + self.blockMap = [] + for i in range(sparsityParameters.BlockMapLength()): + self.blockMap.append(sparsityParameters.BlockMap(i)) + else: + self.blockMap = sparsityParameters.BlockMapAsNumpy() + if not sparsityParameters.DimMetadataIsNone(): + self.dimMetadata = [] + for i in range(sparsityParameters.DimMetadataLength()): + if sparsityParameters.DimMetadata(i) is None: + self.dimMetadata.append(None) + else: + dimensionMetadata_ = DimensionMetadataT.InitFromObj( + sparsityParameters.DimMetadata(i)) + self.dimMetadata.append(dimensionMetadata_) + + # SparsityParametersT + def Pack(self, builder): + if self.traversalOrder is not None: + if np is not None and type(self.traversalOrder) is np.ndarray: + traversalOrder = builder.CreateNumpyVector(self.traversalOrder) + else: + SparsityParametersStartTraversalOrderVector(builder, + len(self.traversalOrder)) + for i in reversed(range(len(self.traversalOrder))): + builder.PrependInt32(self.traversalOrder[i]) + 
traversalOrder = builder.EndVector() + if self.blockMap is not None: + if np is not None and type(self.blockMap) is np.ndarray: + blockMap = builder.CreateNumpyVector(self.blockMap) + else: + SparsityParametersStartBlockMapVector(builder, len(self.blockMap)) + for i in reversed(range(len(self.blockMap))): + builder.PrependInt32(self.blockMap[i]) + blockMap = builder.EndVector() + if self.dimMetadata is not None: + dimMetadatalist = [] + for i in range(len(self.dimMetadata)): + dimMetadatalist.append(self.dimMetadata[i].Pack(builder)) + SparsityParametersStartDimMetadataVector(builder, len(self.dimMetadata)) + for i in reversed(range(len(self.dimMetadata))): + builder.PrependUOffsetTRelative(dimMetadatalist[i]) + dimMetadata = builder.EndVector() + SparsityParametersStart(builder) + if self.traversalOrder is not None: + SparsityParametersAddTraversalOrder(builder, traversalOrder) + if self.blockMap is not None: + SparsityParametersAddBlockMap(builder, blockMap) + if self.dimMetadata is not None: + SparsityParametersAddDimMetadata(builder, dimMetadata) + sparsityParameters = SparsityParametersEnd(builder) + return sparsityParameters + + +class VariantSubType(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = VariantSubType() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsVariantSubType(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def VariantSubTypeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # VariantSubType + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # VariantSubType + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # VariantSubType + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # VariantSubType + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # VariantSubType + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # VariantSubType + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # VariantSubType + def HasRank(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def VariantSubTypeStart(builder): + builder.StartObject(3) + + +def VariantSubTypeAddShape(builder, shape): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) + + +def VariantSubTypeStartShapeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def 
VariantSubTypeAddType(builder, type): + builder.PrependInt8Slot(1, type, 0) + + +def VariantSubTypeAddHasRank(builder, hasRank): + builder.PrependBoolSlot(2, hasRank, 0) + + +def VariantSubTypeEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class VariantSubTypeT(object): + + # VariantSubTypeT + def __init__(self): + self.shape = None # type: List[int] + self.type = 0 # type: int + self.hasRank = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + variantSubType = VariantSubType() + variantSubType.Init(buf, pos) + return cls.InitFromObj(variantSubType) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, variantSubType): + x = VariantSubTypeT() + x._UnPack(variantSubType) + return x + + # VariantSubTypeT + def _UnPack(self, variantSubType): + if variantSubType is None: + return + if not variantSubType.ShapeIsNone(): + if np is None: + self.shape = [] + for i in range(variantSubType.ShapeLength()): + self.shape.append(variantSubType.Shape(i)) + else: + self.shape = variantSubType.ShapeAsNumpy() + self.type = variantSubType.Type() + self.hasRank = variantSubType.HasRank() + + # VariantSubTypeT + def Pack(self, builder): + if self.shape is not None: + if np is not None and type(self.shape) is np.ndarray: + shape = builder.CreateNumpyVector(self.shape) + else: + VariantSubTypeStartShapeVector(builder, len(self.shape)) + for i in reversed(range(len(self.shape))): + builder.PrependInt32(self.shape[i]) + shape = builder.EndVector() + VariantSubTypeStart(builder) + if self.shape is not None: + VariantSubTypeAddShape(builder, shape) + VariantSubTypeAddType(builder, self.type) + VariantSubTypeAddHasRank(builder, self.hasRank) + variantSubType = VariantSubTypeEnd(builder) + return variantSubType + + +class Tensor(object): + __slots__ = ['_tab'] + + 
@classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Tensor() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTensor(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TensorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Tensor + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Tensor + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Tensor + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Tensor + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Tensor + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Tensor + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Tensor + def Buffer(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # Tensor + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + 
return None + + # Tensor + def Quantization(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = QuantizationParameters() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Tensor + def IsVariable(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # Tensor + def Sparsity(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = SparsityParameters() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Tensor + def ShapeSignature(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Tensor + def ShapeSignatureAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Tensor + def ShapeSignatureLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Tensor + def ShapeSignatureIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + return o == 0 + + # Tensor + def HasRank(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # Tensor + def VariantTensors(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + x = self._tab.Vector(o) + x += 
flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = VariantSubType() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Tensor + def VariantTensorsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Tensor + def VariantTensorsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + return o == 0 + + # Tensor + def CompressionType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def TensorStart(builder): + builder.StartObject(11) + + +def TensorAddShape(builder, shape): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) + + +def TensorStartShapeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def TensorAddType(builder, type): + builder.PrependInt8Slot(1, type, 0) + + +def TensorAddBuffer(builder, buffer): + builder.PrependUint32Slot(2, buffer, 0) + + +def TensorAddName(builder, name): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + + +def TensorAddQuantization(builder, quantization): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(quantization), 0) + + +def TensorAddIsVariable(builder, isVariable): + builder.PrependBoolSlot(5, isVariable, 0) + + +def TensorAddSparsity(builder, sparsity): + builder.PrependUOffsetTRelativeSlot( + 6, flatbuffers.number_types.UOffsetTFlags.py_type(sparsity), 0) + + +def TensorAddShapeSignature(builder, shapeSignature): + builder.PrependUOffsetTRelativeSlot( + 7, flatbuffers.number_types.UOffsetTFlags.py_type(shapeSignature), 0) + + +def TensorStartShapeSignatureVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def 
TensorAddHasRank(builder, hasRank): + builder.PrependBoolSlot(8, hasRank, 0) + + +def TensorAddVariantTensors(builder, variantTensors): + builder.PrependUOffsetTRelativeSlot( + 9, flatbuffers.number_types.UOffsetTFlags.py_type(variantTensors), 0) + + +def TensorStartVariantTensorsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def TensorAddCompressionType(builder, compressionType): + builder.PrependInt8Slot(10, compressionType, 0) + + +def TensorEnd(builder): + return builder.EndObject() + + +try: + from typing import List, Optional +except: + pass + + +class TensorT(object): + + # TensorT + def __init__(self): + self.shape = None # type: List[int] + self.type = 0 # type: int + self.buffer = 0 # type: int + self.name = None # type: str + self.quantization = None # type: Optional[QuantizationParametersT] + self.isVariable = False # type: bool + self.sparsity = None # type: Optional[SparsityParametersT] + self.shapeSignature = None # type: List[int] + self.hasRank = False # type: bool + self.variantTensors = None # type: List[VariantSubTypeT] + self.compressionType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + tensor = Tensor() + tensor.Init(buf, pos) + return cls.InitFromObj(tensor) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, tensor): + x = TensorT() + x._UnPack(tensor) + return x + + # TensorT + def _UnPack(self, tensor): + if tensor is None: + return + if not tensor.ShapeIsNone(): + if np is None: + self.shape = [] + for i in range(tensor.ShapeLength()): + self.shape.append(tensor.Shape(i)) + else: + self.shape = tensor.ShapeAsNumpy() + self.type = tensor.Type() + self.buffer = tensor.Buffer() + self.name = tensor.Name() + if tensor.Quantization() is not None: + self.quantization = QuantizationParametersT.InitFromObj(tensor.Quantization()) + 
self.isVariable = tensor.IsVariable() + if tensor.Sparsity() is not None: + self.sparsity = SparsityParametersT.InitFromObj(tensor.Sparsity()) + if not tensor.ShapeSignatureIsNone(): + if np is None: + self.shapeSignature = [] + for i in range(tensor.ShapeSignatureLength()): + self.shapeSignature.append(tensor.ShapeSignature(i)) + else: + self.shapeSignature = tensor.ShapeSignatureAsNumpy() + self.hasRank = tensor.HasRank() + if not tensor.VariantTensorsIsNone(): + self.variantTensors = [] + for i in range(tensor.VariantTensorsLength()): + if tensor.VariantTensors(i) is None: + self.variantTensors.append(None) + else: + variantSubType_ = VariantSubTypeT.InitFromObj( + tensor.VariantTensors(i)) + self.variantTensors.append(variantSubType_) + self.compressionType = tensor.CompressionType() + + # TensorT + def Pack(self, builder): + if self.shape is not None: + if np is not None and type(self.shape) is np.ndarray: + shape = builder.CreateNumpyVector(self.shape) + else: + TensorStartShapeVector(builder, len(self.shape)) + for i in reversed(range(len(self.shape))): + builder.PrependInt32(self.shape[i]) + shape = builder.EndVector() + if self.name is not None: + name = builder.CreateString(self.name) + if self.quantization is not None: + quantization = self.quantization.Pack(builder) + if self.sparsity is not None: + sparsity = self.sparsity.Pack(builder) + if self.shapeSignature is not None: + if np is not None and type(self.shapeSignature) is np.ndarray: + shapeSignature = builder.CreateNumpyVector(self.shapeSignature) + else: + TensorStartShapeSignatureVector(builder, len(self.shapeSignature)) + for i in reversed(range(len(self.shapeSignature))): + builder.PrependInt32(self.shapeSignature[i]) + shapeSignature = builder.EndVector() + if self.variantTensors is not None: + variantTensorslist = [] + for i in range(len(self.variantTensors)): + variantTensorslist.append(self.variantTensors[i].Pack(builder)) + TensorStartVariantTensorsVector(builder, 
len(self.variantTensors)) + for i in reversed(range(len(self.variantTensors))): + builder.PrependUOffsetTRelative(variantTensorslist[i]) + variantTensors = builder.EndVector() + TensorStart(builder) + if self.shape is not None: + TensorAddShape(builder, shape) + TensorAddType(builder, self.type) + TensorAddBuffer(builder, self.buffer) + if self.name is not None: + TensorAddName(builder, name) + if self.quantization is not None: + TensorAddQuantization(builder, quantization) + TensorAddIsVariable(builder, self.isVariable) + if self.sparsity is not None: + TensorAddSparsity(builder, sparsity) + if self.shapeSignature is not None: + TensorAddShapeSignature(builder, shapeSignature) + TensorAddHasRank(builder, self.hasRank) + if self.variantTensors is not None: + TensorAddVariantTensors(builder, variantTensors) + TensorAddCompressionType(builder, self.compressionType) + tensor = TensorEnd(builder) + return tensor + + +class StablehloGatherOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloGatherOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloGatherOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloGatherOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloGatherOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloGatherOptions + def OffsetDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloGatherOptions + def OffsetDimsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloGatherOptions + def OffsetDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloGatherOptions + def OffsetDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloGatherOptions + def CollapsedSliceDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloGatherOptions + def CollapsedSliceDimsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloGatherOptions + def CollapsedSliceDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return 
self._tab.VectorLen(o) + return 0 + + # StablehloGatherOptions + def CollapsedSliceDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloGatherOptions + def StartIndexMap(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloGatherOptions + def StartIndexMapAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloGatherOptions + def StartIndexMapLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloGatherOptions + def StartIndexMapIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # StablehloGatherOptions + def IndexVectorDim(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloGatherOptions + def SliceSizes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloGatherOptions + def SliceSizesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloGatherOptions + def SliceSizesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 
0: + return self._tab.VectorLen(o) + return 0 + + # StablehloGatherOptions + def SliceSizesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # StablehloGatherOptions + def IndicesAreSorted(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def StablehloGatherOptionsStart(builder): + builder.StartObject(6) + + +def StablehloGatherOptionsAddOffsetDims(builder, offsetDims): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(offsetDims), 0) + + +def StablehloGatherOptionsStartOffsetDimsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloGatherOptionsAddCollapsedSliceDims(builder, collapsedSliceDims): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(collapsedSliceDims), 0) + + +def StablehloGatherOptionsStartCollapsedSliceDimsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloGatherOptionsAddStartIndexMap(builder, startIndexMap): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(startIndexMap), 0) + + +def StablehloGatherOptionsStartStartIndexMapVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloGatherOptionsAddIndexVectorDim(builder, indexVectorDim): + builder.PrependInt64Slot(3, indexVectorDim, 0) + + +def StablehloGatherOptionsAddSliceSizes(builder, sliceSizes): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(sliceSizes), 0) + + +def StablehloGatherOptionsStartSliceSizesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloGatherOptionsAddIndicesAreSorted(builder, indicesAreSorted): + builder.PrependBoolSlot(5, indicesAreSorted, 0) + + +def 
StablehloGatherOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloGatherOptionsT(object): + + # StablehloGatherOptionsT + def __init__(self): + self.offsetDims = None # type: List[int] + self.collapsedSliceDims = None # type: List[int] + self.startIndexMap = None # type: List[int] + self.indexVectorDim = 0 # type: int + self.sliceSizes = None # type: List[int] + self.indicesAreSorted = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloGatherOptions = StablehloGatherOptions() + stablehloGatherOptions.Init(buf, pos) + return cls.InitFromObj(stablehloGatherOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloGatherOptions): + x = StablehloGatherOptionsT() + x._UnPack(stablehloGatherOptions) + return x + + # StablehloGatherOptionsT + def _UnPack(self, stablehloGatherOptions): + if stablehloGatherOptions is None: + return + if not stablehloGatherOptions.OffsetDimsIsNone(): + if np is None: + self.offsetDims = [] + for i in range(stablehloGatherOptions.OffsetDimsLength()): + self.offsetDims.append(stablehloGatherOptions.OffsetDims(i)) + else: + self.offsetDims = stablehloGatherOptions.OffsetDimsAsNumpy() + if not stablehloGatherOptions.CollapsedSliceDimsIsNone(): + if np is None: + self.collapsedSliceDims = [] + for i in range(stablehloGatherOptions.CollapsedSliceDimsLength()): + self.collapsedSliceDims.append( + stablehloGatherOptions.CollapsedSliceDims(i)) + else: + self.collapsedSliceDims = stablehloGatherOptions.CollapsedSliceDimsAsNumpy( + ) + if not stablehloGatherOptions.StartIndexMapIsNone(): + if np is None: + self.startIndexMap = [] + for i in range(stablehloGatherOptions.StartIndexMapLength()): + self.startIndexMap.append(stablehloGatherOptions.StartIndexMap(i)) + else: + self.startIndexMap = 
stablehloGatherOptions.StartIndexMapAsNumpy() + self.indexVectorDim = stablehloGatherOptions.IndexVectorDim() + if not stablehloGatherOptions.SliceSizesIsNone(): + if np is None: + self.sliceSizes = [] + for i in range(stablehloGatherOptions.SliceSizesLength()): + self.sliceSizes.append(stablehloGatherOptions.SliceSizes(i)) + else: + self.sliceSizes = stablehloGatherOptions.SliceSizesAsNumpy() + self.indicesAreSorted = stablehloGatherOptions.IndicesAreSorted() + + # StablehloGatherOptionsT + def Pack(self, builder): + if self.offsetDims is not None: + if np is not None and type(self.offsetDims) is np.ndarray: + offsetDims = builder.CreateNumpyVector(self.offsetDims) + else: + StablehloGatherOptionsStartOffsetDimsVector(builder, len(self.offsetDims)) + for i in reversed(range(len(self.offsetDims))): + builder.PrependInt64(self.offsetDims[i]) + offsetDims = builder.EndVector() + if self.collapsedSliceDims is not None: + if np is not None and type(self.collapsedSliceDims) is np.ndarray: + collapsedSliceDims = builder.CreateNumpyVector(self.collapsedSliceDims) + else: + StablehloGatherOptionsStartCollapsedSliceDimsVector( + builder, len(self.collapsedSliceDims)) + for i in reversed(range(len(self.collapsedSliceDims))): + builder.PrependInt64(self.collapsedSliceDims[i]) + collapsedSliceDims = builder.EndVector() + if self.startIndexMap is not None: + if np is not None and type(self.startIndexMap) is np.ndarray: + startIndexMap = builder.CreateNumpyVector(self.startIndexMap) + else: + StablehloGatherOptionsStartStartIndexMapVector(builder, + len(self.startIndexMap)) + for i in reversed(range(len(self.startIndexMap))): + builder.PrependInt64(self.startIndexMap[i]) + startIndexMap = builder.EndVector() + if self.sliceSizes is not None: + if np is not None and type(self.sliceSizes) is np.ndarray: + sliceSizes = builder.CreateNumpyVector(self.sliceSizes) + else: + StablehloGatherOptionsStartSliceSizesVector(builder, len(self.sliceSizes)) + for i in 
reversed(range(len(self.sliceSizes))): + builder.PrependInt64(self.sliceSizes[i]) + sliceSizes = builder.EndVector() + StablehloGatherOptionsStart(builder) + if self.offsetDims is not None: + StablehloGatherOptionsAddOffsetDims(builder, offsetDims) + if self.collapsedSliceDims is not None: + StablehloGatherOptionsAddCollapsedSliceDims(builder, collapsedSliceDims) + if self.startIndexMap is not None: + StablehloGatherOptionsAddStartIndexMap(builder, startIndexMap) + StablehloGatherOptionsAddIndexVectorDim(builder, self.indexVectorDim) + if self.sliceSizes is not None: + StablehloGatherOptionsAddSliceSizes(builder, sliceSizes) + StablehloGatherOptionsAddIndicesAreSorted(builder, self.indicesAreSorted) + stablehloGatherOptions = StablehloGatherOptionsEnd(builder) + return stablehloGatherOptions + + +class StablehloTransposeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloTransposeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloTransposeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloTransposeOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloTransposeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloTransposeOptions + def Permutation(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloTransposeOptions + def PermutationAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloTransposeOptions + def PermutationLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloTransposeOptions + def PermutationIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def StablehloTransposeOptionsStart(builder): + builder.StartObject(1) + + +def StablehloTransposeOptionsAddPermutation(builder, permutation): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(permutation), 0) + + +def StablehloTransposeOptionsStartPermutationVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloTransposeOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloTransposeOptionsT(object): + + # StablehloTransposeOptionsT + def __init__(self): + self.permutation = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): 
        # Wrap the raw buffer in a table view, then copy it into the object API.
        stablehloTransposeOptions = StablehloTransposeOptions()
        stablehloTransposeOptions.Init(buf, pos)
        return cls.InitFromObj(stablehloTransposeOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        # The packed form carries a leading root offset; resolve it first.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, stablehloTransposeOptions):
        # Build the mutable object-API instance from a read-only table view.
        x = StablehloTransposeOptionsT()
        x._UnPack(stablehloTransposeOptions)
        return x

    # StablehloTransposeOptionsT
    def _UnPack(self, stablehloTransposeOptions):
        # Copy the permutation vector out of the flatbuffer; numpy fast path
        # when numpy is importable, element-wise loop otherwise.
        if stablehloTransposeOptions is None:
            return
        if not stablehloTransposeOptions.PermutationIsNone():
            if np is None:
                self.permutation = []
                for i in range(stablehloTransposeOptions.PermutationLength()):
                    self.permutation.append(stablehloTransposeOptions.Permutation(i))
            else:
                self.permutation = stablehloTransposeOptions.PermutationAsNumpy()

    # StablehloTransposeOptionsT
    def Pack(self, builder):
        # Vectors must be serialized before the table that references them.
        if self.permutation is not None:
            if np is not None and type(self.permutation) is np.ndarray:
                permutation = builder.CreateNumpyVector(self.permutation)
            else:
                StablehloTransposeOptionsStartPermutationVector(
                    builder, len(self.permutation))
                # flatbuffers builders grow backwards, hence reversed order.
                for i in reversed(range(len(self.permutation))):
                    builder.PrependInt64(self.permutation[i])
                permutation = builder.EndVector()
        StablehloTransposeOptionsStart(builder)
        if self.permutation is not None:
            StablehloTransposeOptionsAddPermutation(builder, permutation)
        stablehloTransposeOptions = StablehloTransposeOptionsEnd(builder)
        return stablehloTransposeOptions


# Read-only accessor over a serialized StablehloDotGeneralOptions table.
# Generated code: int64 vector fields live at vtable slots 4/6/8/10, the
# uint32 precision-config vector at slot 12.
class StablehloDotGeneralOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = StablehloDotGeneralOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsStablehloDotGeneralOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def StablehloDotGeneralOptionsBufferHasIdentifier(cls,
                                                      buf,
                                                      offset,
                                                      size_prefixed=False):
        # b"\x43\x49\x52\x30" == b"CIR0", the circle file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # StablehloDotGeneralOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # StablehloDotGeneralOptions
    def LhsBatchingDimensions(self, j):
        # Element j of the int64 vector at vtable slot 4 (8 bytes per element).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Int64Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # StablehloDotGeneralOptions
    def LhsBatchingDimensionsAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
        return 0

    # StablehloDotGeneralOptions
    def LhsBatchingDimensionsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # StablehloDotGeneralOptions
    def LhsBatchingDimensionsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0

    # StablehloDotGeneralOptions
    def RhsBatchingDimensions(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Int64Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # StablehloDotGeneralOptions
    def RhsBatchingDimensionsAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
        return 0

    # StablehloDotGeneralOptions
    def RhsBatchingDimensionsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # StablehloDotGeneralOptions
    def RhsBatchingDimensionsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

    # StablehloDotGeneralOptions
    def LhsContractingDimensions(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Int64Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # StablehloDotGeneralOptions
    def LhsContractingDimensionsAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
        return 0

    # StablehloDotGeneralOptions
    def LhsContractingDimensionsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # StablehloDotGeneralOptions
    def LhsContractingDimensionsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

    # StablehloDotGeneralOptions
    def RhsContractingDimensions(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Int64Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # StablehloDotGeneralOptions
    def RhsContractingDimensionsAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
        return 0

    # StablehloDotGeneralOptions
    def RhsContractingDimensionsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # StablehloDotGeneralOptions
    def RhsContractingDimensionsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        return o == 0

    # StablehloDotGeneralOptions
    def PrecisionConfig(self, j):
        # uint32 vector at vtable slot 12 (4 bytes per element).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Uint32Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0

    # StablehloDotGeneralOptions
    def PrecisionConfigAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0

    # StablehloDotGeneralOptions
    def PrecisionConfigLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # StablehloDotGeneralOptions
    def PrecisionConfigIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        return o == 0


def StablehloDotGeneralOptionsStart(builder):
    # 5 vtable slots: lhs/rhs batching, lhs/rhs contracting, precision config.
    builder.StartObject(5)


def StablehloDotGeneralOptionsAddLhsBatchingDimensions(builder, lhsBatchingDimensions):
    builder.PrependUOffsetTRelativeSlot(
        0, flatbuffers.number_types.UOffsetTFlags.py_type(lhsBatchingDimensions), 0)


def StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector(builder, numElems):
    # int64 elements: 8-byte size, 8-byte alignment.
    return builder.StartVector(8, numElems, 8)


def StablehloDotGeneralOptionsAddRhsBatchingDimensions(builder, rhsBatchingDimensions):
    builder.PrependUOffsetTRelativeSlot(
        1, flatbuffers.number_types.UOffsetTFlags.py_type(rhsBatchingDimensions), 0)


def StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector(builder, numElems):
    return builder.StartVector(8, numElems, 8)


def StablehloDotGeneralOptionsAddLhsContractingDimensions(builder,
                                                          lhsContractingDimensions):
    builder.PrependUOffsetTRelativeSlot(
        2, flatbuffers.number_types.UOffsetTFlags.py_type(lhsContractingDimensions), 0)


def StablehloDotGeneralOptionsStartLhsContractingDimensionsVector(builder, numElems):
    # int64 elements: 8-byte size, 8-byte alignment.
    return builder.StartVector(8, numElems, 8)


def StablehloDotGeneralOptionsAddRhsContractingDimensions(builder,
                                                          rhsContractingDimensions):
    builder.PrependUOffsetTRelativeSlot(
        3, flatbuffers.number_types.UOffsetTFlags.py_type(rhsContractingDimensions), 0)


def StablehloDotGeneralOptionsStartRhsContractingDimensionsVector(builder, numElems):
    return builder.StartVector(8, numElems, 8)


def StablehloDotGeneralOptionsAddPrecisionConfig(builder, precisionConfig):
    builder.PrependUOffsetTRelativeSlot(
        4, flatbuffers.number_types.UOffsetTFlags.py_type(precisionConfig), 0)


def StablehloDotGeneralOptionsStartPrecisionConfigVector(builder, numElems):
    # uint32 elements: 4-byte size, 4-byte alignment.
    return builder.StartVector(4, numElems, 4)


def StablehloDotGeneralOptionsEnd(builder):
    return builder.EndObject()


# Guard kept broad (generated code) so the module still imports on
# interpreters without the `typing` module.
try:
    from typing import List
except:
    pass


# Mutable object-API counterpart of the read-only StablehloDotGeneralOptions
# table view; round-trips via _UnPack (deserialize) and Pack (serialize).
class StablehloDotGeneralOptionsT(object):

    # StablehloDotGeneralOptionsT
    def __init__(self):
        self.lhsBatchingDimensions = None # type: List[int]
        self.rhsBatchingDimensions = None # type: List[int]
        self.lhsContractingDimensions = None # type: List[int]
        self.rhsContractingDimensions = None # type: List[int]
        self.precisionConfig = None # type: List[int]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        # Wrap the raw buffer in a table view, then copy it into the object API.
        stablehloDotGeneralOptions = StablehloDotGeneralOptions()
        stablehloDotGeneralOptions.Init(buf, pos)
        return cls.InitFromObj(stablehloDotGeneralOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        # The packed form carries a leading root offset; resolve it first.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, stablehloDotGeneralOptions):
        x = StablehloDotGeneralOptionsT()
        x._UnPack(stablehloDotGeneralOptions)
        return x

    # StablehloDotGeneralOptionsT
+ def _UnPack(self, stablehloDotGeneralOptions): + if stablehloDotGeneralOptions is None: + return + if not stablehloDotGeneralOptions.LhsBatchingDimensionsIsNone(): + if np is None: + self.lhsBatchingDimensions = [] + for i in range(stablehloDotGeneralOptions.LhsBatchingDimensionsLength()): + self.lhsBatchingDimensions.append( + stablehloDotGeneralOptions.LhsBatchingDimensions(i)) + else: + self.lhsBatchingDimensions = stablehloDotGeneralOptions.LhsBatchingDimensionsAsNumpy( + ) + if not stablehloDotGeneralOptions.RhsBatchingDimensionsIsNone(): + if np is None: + self.rhsBatchingDimensions = [] + for i in range(stablehloDotGeneralOptions.RhsBatchingDimensionsLength()): + self.rhsBatchingDimensions.append( + stablehloDotGeneralOptions.RhsBatchingDimensions(i)) + else: + self.rhsBatchingDimensions = stablehloDotGeneralOptions.RhsBatchingDimensionsAsNumpy( + ) + if not stablehloDotGeneralOptions.LhsContractingDimensionsIsNone(): + if np is None: + self.lhsContractingDimensions = [] + for i in range( + stablehloDotGeneralOptions.LhsContractingDimensionsLength()): + self.lhsContractingDimensions.append( + stablehloDotGeneralOptions.LhsContractingDimensions(i)) + else: + self.lhsContractingDimensions = stablehloDotGeneralOptions.LhsContractingDimensionsAsNumpy( + ) + if not stablehloDotGeneralOptions.RhsContractingDimensionsIsNone(): + if np is None: + self.rhsContractingDimensions = [] + for i in range( + stablehloDotGeneralOptions.RhsContractingDimensionsLength()): + self.rhsContractingDimensions.append( + stablehloDotGeneralOptions.RhsContractingDimensions(i)) + else: + self.rhsContractingDimensions = stablehloDotGeneralOptions.RhsContractingDimensionsAsNumpy( + ) + if not stablehloDotGeneralOptions.PrecisionConfigIsNone(): + if np is None: + self.precisionConfig = [] + for i in range(stablehloDotGeneralOptions.PrecisionConfigLength()): + self.precisionConfig.append( + stablehloDotGeneralOptions.PrecisionConfig(i)) + else: + self.precisionConfig = 
stablehloDotGeneralOptions.PrecisionConfigAsNumpy() + + # StablehloDotGeneralOptionsT + def Pack(self, builder): + if self.lhsBatchingDimensions is not None: + if np is not None and type(self.lhsBatchingDimensions) is np.ndarray: + lhsBatchingDimensions = builder.CreateNumpyVector( + self.lhsBatchingDimensions) + else: + StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector( + builder, len(self.lhsBatchingDimensions)) + for i in reversed(range(len(self.lhsBatchingDimensions))): + builder.PrependInt64(self.lhsBatchingDimensions[i]) + lhsBatchingDimensions = builder.EndVector() + if self.rhsBatchingDimensions is not None: + if np is not None and type(self.rhsBatchingDimensions) is np.ndarray: + rhsBatchingDimensions = builder.CreateNumpyVector( + self.rhsBatchingDimensions) + else: + StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector( + builder, len(self.rhsBatchingDimensions)) + for i in reversed(range(len(self.rhsBatchingDimensions))): + builder.PrependInt64(self.rhsBatchingDimensions[i]) + rhsBatchingDimensions = builder.EndVector() + if self.lhsContractingDimensions is not None: + if np is not None and type(self.lhsContractingDimensions) is np.ndarray: + lhsContractingDimensions = builder.CreateNumpyVector( + self.lhsContractingDimensions) + else: + StablehloDotGeneralOptionsStartLhsContractingDimensionsVector( + builder, len(self.lhsContractingDimensions)) + for i in reversed(range(len(self.lhsContractingDimensions))): + builder.PrependInt64(self.lhsContractingDimensions[i]) + lhsContractingDimensions = builder.EndVector() + if self.rhsContractingDimensions is not None: + if np is not None and type(self.rhsContractingDimensions) is np.ndarray: + rhsContractingDimensions = builder.CreateNumpyVector( + self.rhsContractingDimensions) + else: + StablehloDotGeneralOptionsStartRhsContractingDimensionsVector( + builder, len(self.rhsContractingDimensions)) + for i in reversed(range(len(self.rhsContractingDimensions))): + 
builder.PrependInt64(self.rhsContractingDimensions[i]) + rhsContractingDimensions = builder.EndVector() + if self.precisionConfig is not None: + if np is not None and type(self.precisionConfig) is np.ndarray: + precisionConfig = builder.CreateNumpyVector(self.precisionConfig) + else: + StablehloDotGeneralOptionsStartPrecisionConfigVector( + builder, len(self.precisionConfig)) + for i in reversed(range(len(self.precisionConfig))): + builder.PrependUint32(self.precisionConfig[i]) + precisionConfig = builder.EndVector() + StablehloDotGeneralOptionsStart(builder) + if self.lhsBatchingDimensions is not None: + StablehloDotGeneralOptionsAddLhsBatchingDimensions(builder, + lhsBatchingDimensions) + if self.rhsBatchingDimensions is not None: + StablehloDotGeneralOptionsAddRhsBatchingDimensions(builder, + rhsBatchingDimensions) + if self.lhsContractingDimensions is not None: + StablehloDotGeneralOptionsAddLhsContractingDimensions( + builder, lhsContractingDimensions) + if self.rhsContractingDimensions is not None: + StablehloDotGeneralOptionsAddRhsContractingDimensions( + builder, rhsContractingDimensions) + if self.precisionConfig is not None: + StablehloDotGeneralOptionsAddPrecisionConfig(builder, precisionConfig) + stablehloDotGeneralOptions = StablehloDotGeneralOptionsEnd(builder) + return stablehloDotGeneralOptions + + +class StablehloReduceWindowOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloReduceWindowOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloReduceWindowOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloReduceWindowOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloReduceWindowOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloReduceWindowOptions + def WindowDimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceWindowOptions + def WindowDimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceWindowOptions + def WindowDimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceWindowOptions + def WindowDimensionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloReduceWindowOptions + def WindowStrides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceWindowOptions + def WindowStridesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceWindowOptions + def WindowStridesLength(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceWindowOptions + def WindowStridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloReduceWindowOptions + def BaseDilations(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceWindowOptions + def BaseDilationsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceWindowOptions + def BaseDilationsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceWindowOptions + def BaseDilationsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # StablehloReduceWindowOptions + def WindowDilations(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceWindowOptions + def WindowDilationsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceWindowOptions + def WindowDilationsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceWindowOptions + def 
WindowDilationsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # StablehloReduceWindowOptions + def Padding(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceWindowOptions + def PaddingAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceWindowOptions + def PaddingLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceWindowOptions + def PaddingIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # StablehloReduceWindowOptions + def BodySubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def StablehloReduceWindowOptionsStart(builder): + builder.StartObject(6) + + +def StablehloReduceWindowOptionsAddWindowDimensions(builder, windowDimensions): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(windowDimensions), 0) + + +def StablehloReduceWindowOptionsStartWindowDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloReduceWindowOptionsAddWindowStrides(builder, windowStrides): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(windowStrides), 0) + + +def StablehloReduceWindowOptionsStartWindowStridesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def 
StablehloReduceWindowOptionsAddBaseDilations(builder, baseDilations): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(baseDilations), 0) + + +def StablehloReduceWindowOptionsStartBaseDilationsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloReduceWindowOptionsAddWindowDilations(builder, windowDilations): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(windowDilations), 0) + + +def StablehloReduceWindowOptionsStartWindowDilationsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloReduceWindowOptionsAddPadding(builder, padding): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0) + + +def StablehloReduceWindowOptionsStartPaddingVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, bodySubgraphIndex): + builder.PrependInt32Slot(5, bodySubgraphIndex, 0) + + +def StablehloReduceWindowOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloReduceWindowOptionsT(object): + + # StablehloReduceWindowOptionsT + def __init__(self): + self.windowDimensions = None # type: List[int] + self.windowStrides = None # type: List[int] + self.baseDilations = None # type: List[int] + self.windowDilations = None # type: List[int] + self.padding = None # type: List[int] + self.bodySubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloReduceWindowOptions = StablehloReduceWindowOptions() + stablehloReduceWindowOptions.Init(buf, pos) + return cls.InitFromObj(stablehloReduceWindowOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, 
stablehloReduceWindowOptions): + x = StablehloReduceWindowOptionsT() + x._UnPack(stablehloReduceWindowOptions) + return x + + # StablehloReduceWindowOptionsT + def _UnPack(self, stablehloReduceWindowOptions): + if stablehloReduceWindowOptions is None: + return + if not stablehloReduceWindowOptions.WindowDimensionsIsNone(): + if np is None: + self.windowDimensions = [] + for i in range(stablehloReduceWindowOptions.WindowDimensionsLength()): + self.windowDimensions.append( + stablehloReduceWindowOptions.WindowDimensions(i)) + else: + self.windowDimensions = stablehloReduceWindowOptions.WindowDimensionsAsNumpy( + ) + if not stablehloReduceWindowOptions.WindowStridesIsNone(): + if np is None: + self.windowStrides = [] + for i in range(stablehloReduceWindowOptions.WindowStridesLength()): + self.windowStrides.append( + stablehloReduceWindowOptions.WindowStrides(i)) + else: + self.windowStrides = stablehloReduceWindowOptions.WindowStridesAsNumpy() + if not stablehloReduceWindowOptions.BaseDilationsIsNone(): + if np is None: + self.baseDilations = [] + for i in range(stablehloReduceWindowOptions.BaseDilationsLength()): + self.baseDilations.append( + stablehloReduceWindowOptions.BaseDilations(i)) + else: + self.baseDilations = stablehloReduceWindowOptions.BaseDilationsAsNumpy() + if not stablehloReduceWindowOptions.WindowDilationsIsNone(): + if np is None: + self.windowDilations = [] + for i in range(stablehloReduceWindowOptions.WindowDilationsLength()): + self.windowDilations.append( + stablehloReduceWindowOptions.WindowDilations(i)) + else: + self.windowDilations = stablehloReduceWindowOptions.WindowDilationsAsNumpy( + ) + if not stablehloReduceWindowOptions.PaddingIsNone(): + if np is None: + self.padding = [] + for i in range(stablehloReduceWindowOptions.PaddingLength()): + self.padding.append(stablehloReduceWindowOptions.Padding(i)) + else: + self.padding = stablehloReduceWindowOptions.PaddingAsNumpy() + self.bodySubgraphIndex = 
stablehloReduceWindowOptions.BodySubgraphIndex() + + # StablehloReduceWindowOptionsT + def Pack(self, builder): + if self.windowDimensions is not None: + if np is not None and type(self.windowDimensions) is np.ndarray: + windowDimensions = builder.CreateNumpyVector(self.windowDimensions) + else: + StablehloReduceWindowOptionsStartWindowDimensionsVector( + builder, len(self.windowDimensions)) + for i in reversed(range(len(self.windowDimensions))): + builder.PrependInt64(self.windowDimensions[i]) + windowDimensions = builder.EndVector() + if self.windowStrides is not None: + if np is not None and type(self.windowStrides) is np.ndarray: + windowStrides = builder.CreateNumpyVector(self.windowStrides) + else: + StablehloReduceWindowOptionsStartWindowStridesVector( + builder, len(self.windowStrides)) + for i in reversed(range(len(self.windowStrides))): + builder.PrependInt64(self.windowStrides[i]) + windowStrides = builder.EndVector() + if self.baseDilations is not None: + if np is not None and type(self.baseDilations) is np.ndarray: + baseDilations = builder.CreateNumpyVector(self.baseDilations) + else: + StablehloReduceWindowOptionsStartBaseDilationsVector( + builder, len(self.baseDilations)) + for i in reversed(range(len(self.baseDilations))): + builder.PrependInt64(self.baseDilations[i]) + baseDilations = builder.EndVector() + if self.windowDilations is not None: + if np is not None and type(self.windowDilations) is np.ndarray: + windowDilations = builder.CreateNumpyVector(self.windowDilations) + else: + StablehloReduceWindowOptionsStartWindowDilationsVector( + builder, len(self.windowDilations)) + for i in reversed(range(len(self.windowDilations))): + builder.PrependInt64(self.windowDilations[i]) + windowDilations = builder.EndVector() + if self.padding is not None: + if np is not None and type(self.padding) is np.ndarray: + padding = builder.CreateNumpyVector(self.padding) + else: + StablehloReduceWindowOptionsStartPaddingVector(builder, len(self.padding)) + for i 
in reversed(range(len(self.padding))): + builder.PrependInt64(self.padding[i]) + padding = builder.EndVector() + StablehloReduceWindowOptionsStart(builder) + if self.windowDimensions is not None: + StablehloReduceWindowOptionsAddWindowDimensions(builder, windowDimensions) + if self.windowStrides is not None: + StablehloReduceWindowOptionsAddWindowStrides(builder, windowStrides) + if self.baseDilations is not None: + StablehloReduceWindowOptionsAddBaseDilations(builder, baseDilations) + if self.windowDilations is not None: + StablehloReduceWindowOptionsAddWindowDilations(builder, windowDilations) + if self.padding is not None: + StablehloReduceWindowOptionsAddPadding(builder, padding) + StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, self.bodySubgraphIndex) + stablehloReduceWindowOptions = StablehloReduceWindowOptionsEnd(builder) + return stablehloReduceWindowOptions + + +class StablehloWhileOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloWhileOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloWhileOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloWhileOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloWhileOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloWhileOptions + def CondSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StablehloWhileOptions + def BodySubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def StablehloWhileOptionsStart(builder): + builder.StartObject(2) + + +def StablehloWhileOptionsAddCondSubgraphIndex(builder, condSubgraphIndex): + builder.PrependInt32Slot(0, condSubgraphIndex, 0) + + +def StablehloWhileOptionsAddBodySubgraphIndex(builder, bodySubgraphIndex): + builder.PrependInt32Slot(1, bodySubgraphIndex, 0) + + +def StablehloWhileOptionsEnd(builder): + return builder.EndObject() + + +class StablehloWhileOptionsT(object): + + # StablehloWhileOptionsT + def __init__(self): + self.condSubgraphIndex = 0 # type: int + self.bodySubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloWhileOptions = StablehloWhileOptions() + stablehloWhileOptions.Init(buf, pos) + return cls.InitFromObj(stablehloWhileOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloWhileOptions): + x = StablehloWhileOptionsT() + x._UnPack(stablehloWhileOptions) + return x + + # StablehloWhileOptionsT + def _UnPack(self, 
stablehloWhileOptions): + if stablehloWhileOptions is None: + return + self.condSubgraphIndex = stablehloWhileOptions.CondSubgraphIndex() + self.bodySubgraphIndex = stablehloWhileOptions.BodySubgraphIndex() + + # StablehloWhileOptionsT + def Pack(self, builder): + StablehloWhileOptionsStart(builder) + StablehloWhileOptionsAddCondSubgraphIndex(builder, self.condSubgraphIndex) + StablehloWhileOptionsAddBodySubgraphIndex(builder, self.bodySubgraphIndex) + stablehloWhileOptions = StablehloWhileOptionsEnd(builder) + return stablehloWhileOptions + + +class StablehloSortOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloSortOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloSortOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloSortOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloSortOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloSortOptions + def Dimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloSortOptions + def IsStable(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # StablehloSortOptions + def ComparatorSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 
+ + +def StablehloSortOptionsStart(builder): + builder.StartObject(3) + + +def StablehloSortOptionsAddDimension(builder, dimension): + builder.PrependInt64Slot(0, dimension, 0) + + +def StablehloSortOptionsAddIsStable(builder, isStable): + builder.PrependBoolSlot(1, isStable, 0) + + +def StablehloSortOptionsAddComparatorSubgraphIndex(builder, comparatorSubgraphIndex): + builder.PrependInt32Slot(2, comparatorSubgraphIndex, 0) + + +def StablehloSortOptionsEnd(builder): + return builder.EndObject() + + +class StablehloSortOptionsT(object): + + # StablehloSortOptionsT + def __init__(self): + self.dimension = 0 # type: int + self.isStable = False # type: bool + self.comparatorSubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloSortOptions = StablehloSortOptions() + stablehloSortOptions.Init(buf, pos) + return cls.InitFromObj(stablehloSortOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloSortOptions): + x = StablehloSortOptionsT() + x._UnPack(stablehloSortOptions) + return x + + # StablehloSortOptionsT + def _UnPack(self, stablehloSortOptions): + if stablehloSortOptions is None: + return + self.dimension = stablehloSortOptions.Dimension() + self.isStable = stablehloSortOptions.IsStable() + self.comparatorSubgraphIndex = stablehloSortOptions.ComparatorSubgraphIndex() + + # StablehloSortOptionsT + def Pack(self, builder): + StablehloSortOptionsStart(builder) + StablehloSortOptionsAddDimension(builder, self.dimension) + StablehloSortOptionsAddIsStable(builder, self.isStable) + StablehloSortOptionsAddComparatorSubgraphIndex(builder, + self.comparatorSubgraphIndex) + stablehloSortOptions = StablehloSortOptionsEnd(builder) + return stablehloSortOptions + + +class StablehloConcatenateOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, 
offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloConcatenateOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloConcatenateOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloConcatenateOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloConcatenateOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloConcatenateOptions + def Dimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + +def StablehloConcatenateOptionsStart(builder): + builder.StartObject(1) + + +def StablehloConcatenateOptionsAddDimension(builder, dimension): + builder.PrependInt64Slot(0, dimension, 0) + + +def StablehloConcatenateOptionsEnd(builder): + return builder.EndObject() + + +class StablehloConcatenateOptionsT(object): + + # StablehloConcatenateOptionsT + def __init__(self): + self.dimension = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloConcatenateOptions = StablehloConcatenateOptions() + stablehloConcatenateOptions.Init(buf, pos) + return cls.InitFromObj(stablehloConcatenateOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloConcatenateOptions): + x = StablehloConcatenateOptionsT() + x._UnPack(stablehloConcatenateOptions) + return x + + # StablehloConcatenateOptionsT + def _UnPack(self, stablehloConcatenateOptions): + if stablehloConcatenateOptions is None: + return + 
self.dimension = stablehloConcatenateOptions.Dimension() + + # StablehloConcatenateOptionsT + def Pack(self, builder): + StablehloConcatenateOptionsStart(builder) + StablehloConcatenateOptionsAddDimension(builder, self.dimension) + stablehloConcatenateOptions = StablehloConcatenateOptionsEnd(builder) + return stablehloConcatenateOptions + + +class StablehloBroadcastInDimOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloBroadcastInDimOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloBroadcastInDimOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloBroadcastInDimOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloBroadcastInDimOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloBroadcastInDimOptions + def BroadcastDimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloBroadcastInDimOptions + def BroadcastDimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloBroadcastInDimOptions + def BroadcastDimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloBroadcastInDimOptions + def BroadcastDimensionsIsNone(self): + 
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def StablehloBroadcastInDimOptionsStart(builder): + builder.StartObject(1) + + +def StablehloBroadcastInDimOptionsAddBroadcastDimensions(builder, broadcastDimensions): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(broadcastDimensions), 0) + + +def StablehloBroadcastInDimOptionsStartBroadcastDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloBroadcastInDimOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloBroadcastInDimOptionsT(object): + + # StablehloBroadcastInDimOptionsT + def __init__(self): + self.broadcastDimensions = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloBroadcastInDimOptions = StablehloBroadcastInDimOptions() + stablehloBroadcastInDimOptions.Init(buf, pos) + return cls.InitFromObj(stablehloBroadcastInDimOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloBroadcastInDimOptions): + x = StablehloBroadcastInDimOptionsT() + x._UnPack(stablehloBroadcastInDimOptions) + return x + + # StablehloBroadcastInDimOptionsT + def _UnPack(self, stablehloBroadcastInDimOptions): + if stablehloBroadcastInDimOptions is None: + return + if not stablehloBroadcastInDimOptions.BroadcastDimensionsIsNone(): + if np is None: + self.broadcastDimensions = [] + for i in range( + stablehloBroadcastInDimOptions.BroadcastDimensionsLength()): + self.broadcastDimensions.append( + stablehloBroadcastInDimOptions.BroadcastDimensions(i)) + else: + self.broadcastDimensions = stablehloBroadcastInDimOptions.BroadcastDimensionsAsNumpy( + ) + + # StablehloBroadcastInDimOptionsT + def Pack(self, builder): + if self.broadcastDimensions 
is not None: + if np is not None and type(self.broadcastDimensions) is np.ndarray: + broadcastDimensions = builder.CreateNumpyVector(self.broadcastDimensions) + else: + StablehloBroadcastInDimOptionsStartBroadcastDimensionsVector( + builder, len(self.broadcastDimensions)) + for i in reversed(range(len(self.broadcastDimensions))): + builder.PrependInt64(self.broadcastDimensions[i]) + broadcastDimensions = builder.EndVector() + StablehloBroadcastInDimOptionsStart(builder) + if self.broadcastDimensions is not None: + StablehloBroadcastInDimOptionsAddBroadcastDimensions( + builder, broadcastDimensions) + stablehloBroadcastInDimOptions = StablehloBroadcastInDimOptionsEnd(builder) + return stablehloBroadcastInDimOptions + + +class StablehloCompareOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloCompareOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloCompareOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloCompareOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloCompareOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloCompareOptions + def ComparisonDirection(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # StablehloCompareOptions + def CompareType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + +def StablehloCompareOptionsStart(builder): + builder.StartObject(2) + + +def StablehloCompareOptionsAddComparisonDirection(builder, comparisonDirection): + builder.PrependUint32Slot(0, comparisonDirection, 0) + + +def StablehloCompareOptionsAddCompareType(builder, compareType): + builder.PrependUint32Slot(1, compareType, 0) + + +def StablehloCompareOptionsEnd(builder): + return builder.EndObject() + + +class StablehloCompareOptionsT(object): + + # StablehloCompareOptionsT + def __init__(self): + self.comparisonDirection = 0 # type: int + self.compareType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloCompareOptions = StablehloCompareOptions() + stablehloCompareOptions.Init(buf, pos) + return cls.InitFromObj(stablehloCompareOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloCompareOptions): + x = StablehloCompareOptionsT() + x._UnPack(stablehloCompareOptions) + return x + + # StablehloCompareOptionsT + def 
_UnPack(self, stablehloCompareOptions): + if stablehloCompareOptions is None: + return + self.comparisonDirection = stablehloCompareOptions.ComparisonDirection() + self.compareType = stablehloCompareOptions.CompareType() + + # StablehloCompareOptionsT + def Pack(self, builder): + StablehloCompareOptionsStart(builder) + StablehloCompareOptionsAddComparisonDirection(builder, self.comparisonDirection) + StablehloCompareOptionsAddCompareType(builder, self.compareType) + stablehloCompareOptions = StablehloCompareOptionsEnd(builder) + return stablehloCompareOptions + + +class StablehloDynamicSliceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloDynamicSliceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloDynamicSliceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloDynamicSliceOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloDynamicSliceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloDynamicSliceOptions + def SliceSizes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloDynamicSliceOptions + def SliceSizesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloDynamicSliceOptions + def SliceSizesLength(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloDynamicSliceOptions + def SliceSizesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def StablehloDynamicSliceOptionsStart(builder): + builder.StartObject(1) + + +def StablehloDynamicSliceOptionsAddSliceSizes(builder, sliceSizes): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(sliceSizes), 0) + + +def StablehloDynamicSliceOptionsStartSliceSizesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloDynamicSliceOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloDynamicSliceOptionsT(object): + + # StablehloDynamicSliceOptionsT + def __init__(self): + self.sliceSizes = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloDynamicSliceOptions = StablehloDynamicSliceOptions() + stablehloDynamicSliceOptions.Init(buf, pos) + return cls.InitFromObj(stablehloDynamicSliceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloDynamicSliceOptions): + x = StablehloDynamicSliceOptionsT() + x._UnPack(stablehloDynamicSliceOptions) + return x + + # StablehloDynamicSliceOptionsT + def _UnPack(self, stablehloDynamicSliceOptions): + if stablehloDynamicSliceOptions is None: + return + if not stablehloDynamicSliceOptions.SliceSizesIsNone(): + if np is None: + self.sliceSizes = [] + for i in range(stablehloDynamicSliceOptions.SliceSizesLength()): + self.sliceSizes.append(stablehloDynamicSliceOptions.SliceSizes(i)) + else: + self.sliceSizes = stablehloDynamicSliceOptions.SliceSizesAsNumpy() + + # StablehloDynamicSliceOptionsT + def Pack(self, builder): 
+ if self.sliceSizes is not None: + if np is not None and type(self.sliceSizes) is np.ndarray: + sliceSizes = builder.CreateNumpyVector(self.sliceSizes) + else: + StablehloDynamicSliceOptionsStartSliceSizesVector( + builder, len(self.sliceSizes)) + for i in reversed(range(len(self.sliceSizes))): + builder.PrependInt64(self.sliceSizes[i]) + sliceSizes = builder.EndVector() + StablehloDynamicSliceOptionsStart(builder) + if self.sliceSizes is not None: + StablehloDynamicSliceOptionsAddSliceSizes(builder, sliceSizes) + stablehloDynamicSliceOptions = StablehloDynamicSliceOptionsEnd(builder) + return stablehloDynamicSliceOptions + + +class StablehloPadOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloPadOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloPadOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloPadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloPadOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloPadOptions + def EdgePaddingLow(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloPadOptions + def EdgePaddingLowAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloPadOptions + def EdgePaddingLowLength(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloPadOptions + def EdgePaddingLowIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloPadOptions + def EdgePaddingHigh(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloPadOptions + def EdgePaddingHighAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloPadOptions + def EdgePaddingHighLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloPadOptions + def EdgePaddingHighIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloPadOptions + def InteriorPadding(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloPadOptions + def InteriorPaddingAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloPadOptions + def InteriorPaddingLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloPadOptions + def InteriorPaddingIsNone(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + +def StablehloPadOptionsStart(builder): + builder.StartObject(3) + + +def StablehloPadOptionsAddEdgePaddingLow(builder, edgePaddingLow): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(edgePaddingLow), 0) + + +def StablehloPadOptionsStartEdgePaddingLowVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloPadOptionsAddEdgePaddingHigh(builder, edgePaddingHigh): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(edgePaddingHigh), 0) + + +def StablehloPadOptionsStartEdgePaddingHighVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloPadOptionsAddInteriorPadding(builder, interiorPadding): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(interiorPadding), 0) + + +def StablehloPadOptionsStartInteriorPaddingVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloPadOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloPadOptionsT(object): + + # StablehloPadOptionsT + def __init__(self): + self.edgePaddingLow = None # type: List[int] + self.edgePaddingHigh = None # type: List[int] + self.interiorPadding = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloPadOptions = StablehloPadOptions() + stablehloPadOptions.Init(buf, pos) + return cls.InitFromObj(stablehloPadOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloPadOptions): + x = StablehloPadOptionsT() + x._UnPack(stablehloPadOptions) + return x + + # StablehloPadOptionsT + def _UnPack(self, stablehloPadOptions): + if stablehloPadOptions is 
None: + return + if not stablehloPadOptions.EdgePaddingLowIsNone(): + if np is None: + self.edgePaddingLow = [] + for i in range(stablehloPadOptions.EdgePaddingLowLength()): + self.edgePaddingLow.append(stablehloPadOptions.EdgePaddingLow(i)) + else: + self.edgePaddingLow = stablehloPadOptions.EdgePaddingLowAsNumpy() + if not stablehloPadOptions.EdgePaddingHighIsNone(): + if np is None: + self.edgePaddingHigh = [] + for i in range(stablehloPadOptions.EdgePaddingHighLength()): + self.edgePaddingHigh.append(stablehloPadOptions.EdgePaddingHigh(i)) + else: + self.edgePaddingHigh = stablehloPadOptions.EdgePaddingHighAsNumpy() + if not stablehloPadOptions.InteriorPaddingIsNone(): + if np is None: + self.interiorPadding = [] + for i in range(stablehloPadOptions.InteriorPaddingLength()): + self.interiorPadding.append(stablehloPadOptions.InteriorPadding(i)) + else: + self.interiorPadding = stablehloPadOptions.InteriorPaddingAsNumpy() + + # StablehloPadOptionsT + def Pack(self, builder): + if self.edgePaddingLow is not None: + if np is not None and type(self.edgePaddingLow) is np.ndarray: + edgePaddingLow = builder.CreateNumpyVector(self.edgePaddingLow) + else: + StablehloPadOptionsStartEdgePaddingLowVector(builder, + len(self.edgePaddingLow)) + for i in reversed(range(len(self.edgePaddingLow))): + builder.PrependInt64(self.edgePaddingLow[i]) + edgePaddingLow = builder.EndVector() + if self.edgePaddingHigh is not None: + if np is not None and type(self.edgePaddingHigh) is np.ndarray: + edgePaddingHigh = builder.CreateNumpyVector(self.edgePaddingHigh) + else: + StablehloPadOptionsStartEdgePaddingHighVector(builder, + len(self.edgePaddingHigh)) + for i in reversed(range(len(self.edgePaddingHigh))): + builder.PrependInt64(self.edgePaddingHigh[i]) + edgePaddingHigh = builder.EndVector() + if self.interiorPadding is not None: + if np is not None and type(self.interiorPadding) is np.ndarray: + interiorPadding = builder.CreateNumpyVector(self.interiorPadding) + else: + 
StablehloPadOptionsStartInteriorPaddingVector(builder, + len(self.interiorPadding)) + for i in reversed(range(len(self.interiorPadding))): + builder.PrependInt64(self.interiorPadding[i]) + interiorPadding = builder.EndVector() + StablehloPadOptionsStart(builder) + if self.edgePaddingLow is not None: + StablehloPadOptionsAddEdgePaddingLow(builder, edgePaddingLow) + if self.edgePaddingHigh is not None: + StablehloPadOptionsAddEdgePaddingHigh(builder, edgePaddingHigh) + if self.interiorPadding is not None: + StablehloPadOptionsAddInteriorPadding(builder, interiorPadding) + stablehloPadOptions = StablehloPadOptionsEnd(builder) + return stablehloPadOptions + + +class StablehloIotaOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloIotaOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloIotaOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloIotaOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloIotaOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloIotaOptions + def IotaDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + +def StablehloIotaOptionsStart(builder): + builder.StartObject(1) + + +def StablehloIotaOptionsAddIotaDimension(builder, iotaDimension): + builder.PrependInt64Slot(0, iotaDimension, 0) + + +def StablehloIotaOptionsEnd(builder): + return builder.EndObject() + + +class StablehloIotaOptionsT(object): + + # StablehloIotaOptionsT + def __init__(self): + self.iotaDimension = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloIotaOptions = StablehloIotaOptions() + stablehloIotaOptions.Init(buf, pos) + return cls.InitFromObj(stablehloIotaOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloIotaOptions): + x = StablehloIotaOptionsT() + x._UnPack(stablehloIotaOptions) + return x + + # StablehloIotaOptionsT + def _UnPack(self, stablehloIotaOptions): + if stablehloIotaOptions is None: + return + self.iotaDimension = stablehloIotaOptions.IotaDimension() + + # StablehloIotaOptionsT + def Pack(self, builder): + StablehloIotaOptionsStart(builder) + StablehloIotaOptionsAddIotaDimension(builder, self.iotaDimension) + stablehloIotaOptions = StablehloIotaOptionsEnd(builder) + return stablehloIotaOptions + + +class StablehloCustomCallOptions(object): + __slots__ = ['_tab'] + + 
@classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloCustomCallOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloCustomCallOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloCustomCallOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloCustomCallOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloCustomCallOptions + def CallTargetName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # StablehloCustomCallOptions + def HasSideEffect(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # StablehloCustomCallOptions + def BackendConfig(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # StablehloCustomCallOptions + def ApiVersion(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StablehloCustomCallOptions + def CalledComputations(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # StablehloCustomCallOptions + def CalledComputationsAsNumpy(self): + o 
= flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # StablehloCustomCallOptions + def CalledComputationsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloCustomCallOptions + def CalledComputationsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # StablehloCustomCallOptions + def CustomAttributes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # StablehloCustomCallOptions + def CustomAttributesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # StablehloCustomCallOptions + def CustomAttributesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloCustomCallOptions + def CustomAttributesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + return o == 0 + + +def StablehloCustomCallOptionsStart(builder): + builder.StartObject(6) + + +def StablehloCustomCallOptionsAddCallTargetName(builder, callTargetName): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(callTargetName), 0) + + +def StablehloCustomCallOptionsAddHasSideEffect(builder, hasSideEffect): + builder.PrependBoolSlot(1, hasSideEffect, 0) + + +def StablehloCustomCallOptionsAddBackendConfig(builder, backendConfig): + builder.PrependUOffsetTRelativeSlot( + 2, 
flatbuffers.number_types.UOffsetTFlags.py_type(backendConfig), 0) + + +def StablehloCustomCallOptionsAddApiVersion(builder, apiVersion): + builder.PrependInt32Slot(3, apiVersion, 0) + + +def StablehloCustomCallOptionsAddCalledComputations(builder, calledComputations): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(calledComputations), 0) + + +def StablehloCustomCallOptionsStartCalledComputationsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def StablehloCustomCallOptionsAddCustomAttributes(builder, customAttributes): + builder.PrependUOffsetTRelativeSlot( + 5, flatbuffers.number_types.UOffsetTFlags.py_type(customAttributes), 0) + + +def StablehloCustomCallOptionsStartCustomAttributesVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def StablehloCustomCallOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloCustomCallOptionsT(object): + + # StablehloCustomCallOptionsT + def __init__(self): + self.callTargetName = None # type: str + self.hasSideEffect = False # type: bool + self.backendConfig = None # type: str + self.apiVersion = 0 # type: int + self.calledComputations = None # type: List[int] + self.customAttributes = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloCustomCallOptions = StablehloCustomCallOptions() + stablehloCustomCallOptions.Init(buf, pos) + return cls.InitFromObj(stablehloCustomCallOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloCustomCallOptions): + x = StablehloCustomCallOptionsT() + x._UnPack(stablehloCustomCallOptions) + return x + + # StablehloCustomCallOptionsT + def _UnPack(self, stablehloCustomCallOptions): + if stablehloCustomCallOptions is None: + return + 
self.callTargetName = stablehloCustomCallOptions.CallTargetName() + self.hasSideEffect = stablehloCustomCallOptions.HasSideEffect() + self.backendConfig = stablehloCustomCallOptions.BackendConfig() + self.apiVersion = stablehloCustomCallOptions.ApiVersion() + if not stablehloCustomCallOptions.CalledComputationsIsNone(): + if np is None: + self.calledComputations = [] + for i in range(stablehloCustomCallOptions.CalledComputationsLength()): + self.calledComputations.append( + stablehloCustomCallOptions.CalledComputations(i)) + else: + self.calledComputations = stablehloCustomCallOptions.CalledComputationsAsNumpy( + ) + if not stablehloCustomCallOptions.CustomAttributesIsNone(): + if np is None: + self.customAttributes = [] + for i in range(stablehloCustomCallOptions.CustomAttributesLength()): + self.customAttributes.append( + stablehloCustomCallOptions.CustomAttributes(i)) + else: + self.customAttributes = stablehloCustomCallOptions.CustomAttributesAsNumpy( + ) + + # StablehloCustomCallOptionsT + def Pack(self, builder): + if self.callTargetName is not None: + callTargetName = builder.CreateString(self.callTargetName) + if self.backendConfig is not None: + backendConfig = builder.CreateString(self.backendConfig) + if self.calledComputations is not None: + if np is not None and type(self.calledComputations) is np.ndarray: + calledComputations = builder.CreateNumpyVector(self.calledComputations) + else: + StablehloCustomCallOptionsStartCalledComputationsVector( + builder, len(self.calledComputations)) + for i in reversed(range(len(self.calledComputations))): + builder.PrependInt32(self.calledComputations[i]) + calledComputations = builder.EndVector() + if self.customAttributes is not None: + if np is not None and type(self.customAttributes) is np.ndarray: + customAttributes = builder.CreateNumpyVector(self.customAttributes) + else: + StablehloCustomCallOptionsStartCustomAttributesVector( + builder, len(self.customAttributes)) + for i in 
reversed(range(len(self.customAttributes))): + builder.PrependUint8(self.customAttributes[i]) + customAttributes = builder.EndVector() + StablehloCustomCallOptionsStart(builder) + if self.callTargetName is not None: + StablehloCustomCallOptionsAddCallTargetName(builder, callTargetName) + StablehloCustomCallOptionsAddHasSideEffect(builder, self.hasSideEffect) + if self.backendConfig is not None: + StablehloCustomCallOptionsAddBackendConfig(builder, backendConfig) + StablehloCustomCallOptionsAddApiVersion(builder, self.apiVersion) + if self.calledComputations is not None: + StablehloCustomCallOptionsAddCalledComputations(builder, calledComputations) + if self.customAttributes is not None: + StablehloCustomCallOptionsAddCustomAttributes(builder, customAttributes) + stablehloCustomCallOptions = StablehloCustomCallOptionsEnd(builder) + return stablehloCustomCallOptions + + +class StablehloReduceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloReduceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloReduceOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloReduceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloReduceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloReduceOptions + def Dimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloReduceOptions + def DimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloReduceOptions + def DimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloReduceOptions + def DimensionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloReduceOptions + def BodySubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def StablehloReduceOptionsStart(builder): + builder.StartObject(2) + + +def StablehloReduceOptionsAddDimensions(builder, dimensions): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(dimensions), 0) + + +def StablehloReduceOptionsStartDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloReduceOptionsAddBodySubgraphIndex(builder, bodySubgraphIndex): + builder.PrependInt32Slot(1, bodySubgraphIndex, 
0) + + +def StablehloReduceOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloReduceOptionsT(object): + + # StablehloReduceOptionsT + def __init__(self): + self.dimensions = None # type: List[int] + self.bodySubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloReduceOptions = StablehloReduceOptions() + stablehloReduceOptions.Init(buf, pos) + return cls.InitFromObj(stablehloReduceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloReduceOptions): + x = StablehloReduceOptionsT() + x._UnPack(stablehloReduceOptions) + return x + + # StablehloReduceOptionsT + def _UnPack(self, stablehloReduceOptions): + if stablehloReduceOptions is None: + return + if not stablehloReduceOptions.DimensionsIsNone(): + if np is None: + self.dimensions = [] + for i in range(stablehloReduceOptions.DimensionsLength()): + self.dimensions.append(stablehloReduceOptions.Dimensions(i)) + else: + self.dimensions = stablehloReduceOptions.DimensionsAsNumpy() + self.bodySubgraphIndex = stablehloReduceOptions.BodySubgraphIndex() + + # StablehloReduceOptionsT + def Pack(self, builder): + if self.dimensions is not None: + if np is not None and type(self.dimensions) is np.ndarray: + dimensions = builder.CreateNumpyVector(self.dimensions) + else: + StablehloReduceOptionsStartDimensionsVector(builder, len(self.dimensions)) + for i in reversed(range(len(self.dimensions))): + builder.PrependInt64(self.dimensions[i]) + dimensions = builder.EndVector() + StablehloReduceOptionsStart(builder) + if self.dimensions is not None: + StablehloReduceOptionsAddDimensions(builder, dimensions) + StablehloReduceOptionsAddBodySubgraphIndex(builder, self.bodySubgraphIndex) + stablehloReduceOptions = StablehloReduceOptionsEnd(builder) + return 
stablehloReduceOptions + + +class StablehloSliceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloSliceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloSliceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloSliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloSliceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloSliceOptions + def StartIndices(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloSliceOptions + def StartIndicesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloSliceOptions + def StartIndicesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloSliceOptions + def StartIndicesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloSliceOptions + def LimitIndices(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloSliceOptions + def 
LimitIndicesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloSliceOptions + def LimitIndicesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloSliceOptions + def LimitIndicesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloSliceOptions + def Strides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloSliceOptions + def StridesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloSliceOptions + def StridesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloSliceOptions + def StridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + +def StablehloSliceOptionsStart(builder): + builder.StartObject(3) + + +def StablehloSliceOptionsAddStartIndices(builder, startIndices): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(startIndices), 0) + + +def StablehloSliceOptionsStartStartIndicesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloSliceOptionsAddLimitIndices(builder, limitIndices): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(limitIndices), 0) + + +def 
StablehloSliceOptionsStartLimitIndicesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloSliceOptionsAddStrides(builder, strides): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0) + + +def StablehloSliceOptionsStartStridesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloSliceOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloSliceOptionsT(object): + + # StablehloSliceOptionsT + def __init__(self): + self.startIndices = None # type: List[int] + self.limitIndices = None # type: List[int] + self.strides = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloSliceOptions = StablehloSliceOptions() + stablehloSliceOptions.Init(buf, pos) + return cls.InitFromObj(stablehloSliceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloSliceOptions): + x = StablehloSliceOptionsT() + x._UnPack(stablehloSliceOptions) + return x + + # StablehloSliceOptionsT + def _UnPack(self, stablehloSliceOptions): + if stablehloSliceOptions is None: + return + if not stablehloSliceOptions.StartIndicesIsNone(): + if np is None: + self.startIndices = [] + for i in range(stablehloSliceOptions.StartIndicesLength()): + self.startIndices.append(stablehloSliceOptions.StartIndices(i)) + else: + self.startIndices = stablehloSliceOptions.StartIndicesAsNumpy() + if not stablehloSliceOptions.LimitIndicesIsNone(): + if np is None: + self.limitIndices = [] + for i in range(stablehloSliceOptions.LimitIndicesLength()): + self.limitIndices.append(stablehloSliceOptions.LimitIndices(i)) + else: + self.limitIndices = stablehloSliceOptions.LimitIndicesAsNumpy() + if not stablehloSliceOptions.StridesIsNone(): + if np 
is None: + self.strides = [] + for i in range(stablehloSliceOptions.StridesLength()): + self.strides.append(stablehloSliceOptions.Strides(i)) + else: + self.strides = stablehloSliceOptions.StridesAsNumpy() + + # StablehloSliceOptionsT + def Pack(self, builder): + if self.startIndices is not None: + if np is not None and type(self.startIndices) is np.ndarray: + startIndices = builder.CreateNumpyVector(self.startIndices) + else: + StablehloSliceOptionsStartStartIndicesVector(builder, + len(self.startIndices)) + for i in reversed(range(len(self.startIndices))): + builder.PrependInt64(self.startIndices[i]) + startIndices = builder.EndVector() + if self.limitIndices is not None: + if np is not None and type(self.limitIndices) is np.ndarray: + limitIndices = builder.CreateNumpyVector(self.limitIndices) + else: + StablehloSliceOptionsStartLimitIndicesVector(builder, + len(self.limitIndices)) + for i in reversed(range(len(self.limitIndices))): + builder.PrependInt64(self.limitIndices[i]) + limitIndices = builder.EndVector() + if self.strides is not None: + if np is not None and type(self.strides) is np.ndarray: + strides = builder.CreateNumpyVector(self.strides) + else: + StablehloSliceOptionsStartStridesVector(builder, len(self.strides)) + for i in reversed(range(len(self.strides))): + builder.PrependInt64(self.strides[i]) + strides = builder.EndVector() + StablehloSliceOptionsStart(builder) + if self.startIndices is not None: + StablehloSliceOptionsAddStartIndices(builder, startIndices) + if self.limitIndices is not None: + StablehloSliceOptionsAddLimitIndices(builder, limitIndices) + if self.strides is not None: + StablehloSliceOptionsAddStrides(builder, strides) + stablehloSliceOptions = StablehloSliceOptionsEnd(builder) + return stablehloSliceOptions + + +class StablehloConvolutionOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = 
StablehloConvolutionOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloConvolutionOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloConvolutionOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloConvolutionOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloConvolutionOptions + def WindowStrides(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def WindowStridesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def WindowStridesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def WindowStridesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # StablehloConvolutionOptions + def Padding(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def PaddingAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return 
self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def PaddingLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def PaddingIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloConvolutionOptions + def LhsDilation(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def LhsDilationAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def LhsDilationLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def LhsDilationIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # StablehloConvolutionOptions + def RhsDilation(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def RhsDilationAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def RhsDilationLength(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def RhsDilationIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # StablehloConvolutionOptions + def WindowReversal(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.BoolFlags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # StablehloConvolutionOptions + def WindowReversalAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.BoolFlags, o) + return 0 + + # StablehloConvolutionOptions + def WindowReversalLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def WindowReversalIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # StablehloConvolutionOptions + def InputBatchDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def InputFeatureDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def InputSpatialDimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + 
return 0 + + # StablehloConvolutionOptions + def InputSpatialDimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def InputSpatialDimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def InputSpatialDimensionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + return o == 0 + + # StablehloConvolutionOptions + def KernelInputFeatureDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def KernelOutputFeatureDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def KernelSpatialDimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def KernelSpatialDimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def KernelSpatialDimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def 
KernelSpatialDimensionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + return o == 0 + + # StablehloConvolutionOptions + def OutputBatchDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def OutputFeatureDimension(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def OutputSpatialDimensions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloConvolutionOptions + def OutputSpatialDimensionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloConvolutionOptions + def OutputSpatialDimensionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def OutputSpatialDimensionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30)) + return o == 0 + + # StablehloConvolutionOptions + def FeatureGroupCount(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def BatchGroupCount(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(34)) + if o != 0: + return 
self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloConvolutionOptions + def PrecisionConfig(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # StablehloConvolutionOptions + def PrecisionConfigAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # StablehloConvolutionOptions + def PrecisionConfigLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloConvolutionOptions + def PrecisionConfigIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) + return o == 0 + + +def StablehloConvolutionOptionsStart(builder): + builder.StartObject(17) + + +def StablehloConvolutionOptionsAddWindowStrides(builder, windowStrides): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(windowStrides), 0) + + +def StablehloConvolutionOptionsStartWindowStridesVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloConvolutionOptionsAddPadding(builder, padding): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0) + + +def StablehloConvolutionOptionsStartPaddingVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloConvolutionOptionsAddLhsDilation(builder, lhsDilation): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(lhsDilation), 0) + + +def StablehloConvolutionOptionsStartLhsDilationVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def 
StablehloConvolutionOptionsAddRhsDilation(builder, rhsDilation): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(rhsDilation), 0) + + +def StablehloConvolutionOptionsStartRhsDilationVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloConvolutionOptionsAddWindowReversal(builder, windowReversal): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(windowReversal), 0) + + +def StablehloConvolutionOptionsStartWindowReversalVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def StablehloConvolutionOptionsAddInputBatchDimension(builder, inputBatchDimension): + builder.PrependInt64Slot(5, inputBatchDimension, 0) + + +def StablehloConvolutionOptionsAddInputFeatureDimension(builder, inputFeatureDimension): + builder.PrependInt64Slot(6, inputFeatureDimension, 0) + + +def StablehloConvolutionOptionsAddInputSpatialDimensions(builder, inputSpatialDimensions): + builder.PrependUOffsetTRelativeSlot( + 7, flatbuffers.number_types.UOffsetTFlags.py_type(inputSpatialDimensions), 0) + + +def StablehloConvolutionOptionsStartInputSpatialDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloConvolutionOptionsAddKernelInputFeatureDimension( + builder, kernelInputFeatureDimension): + builder.PrependInt64Slot(8, kernelInputFeatureDimension, 0) + + +def StablehloConvolutionOptionsAddKernelOutputFeatureDimension( + builder, kernelOutputFeatureDimension): + builder.PrependInt64Slot(9, kernelOutputFeatureDimension, 0) + + +def StablehloConvolutionOptionsAddKernelSpatialDimensions(builder, + kernelSpatialDimensions): + builder.PrependUOffsetTRelativeSlot( + 10, flatbuffers.number_types.UOffsetTFlags.py_type(kernelSpatialDimensions), 0) + + +def StablehloConvolutionOptionsStartKernelSpatialDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def 
StablehloConvolutionOptionsAddOutputBatchDimension(builder, outputBatchDimension): + builder.PrependInt64Slot(11, outputBatchDimension, 0) + + +def StablehloConvolutionOptionsAddOutputFeatureDimension(builder, outputFeatureDimension): + builder.PrependInt64Slot(12, outputFeatureDimension, 0) + + +def StablehloConvolutionOptionsAddOutputSpatialDimensions(builder, + outputSpatialDimensions): + builder.PrependUOffsetTRelativeSlot( + 13, flatbuffers.number_types.UOffsetTFlags.py_type(outputSpatialDimensions), 0) + + +def StablehloConvolutionOptionsStartOutputSpatialDimensionsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloConvolutionOptionsAddFeatureGroupCount(builder, featureGroupCount): + builder.PrependInt64Slot(14, featureGroupCount, 0) + + +def StablehloConvolutionOptionsAddBatchGroupCount(builder, batchGroupCount): + builder.PrependInt64Slot(15, batchGroupCount, 0) + + +def StablehloConvolutionOptionsAddPrecisionConfig(builder, precisionConfig): + builder.PrependUOffsetTRelativeSlot( + 16, flatbuffers.number_types.UOffsetTFlags.py_type(precisionConfig), 0) + + +def StablehloConvolutionOptionsStartPrecisionConfigVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def StablehloConvolutionOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloConvolutionOptionsT(object): + + # StablehloConvolutionOptionsT + def __init__(self): + self.windowStrides = None # type: List[int] + self.padding = None # type: List[int] + self.lhsDilation = None # type: List[int] + self.rhsDilation = None # type: List[int] + self.windowReversal = None # type: List[bool] + self.inputBatchDimension = 0 # type: int + self.inputFeatureDimension = 0 # type: int + self.inputSpatialDimensions = None # type: List[int] + self.kernelInputFeatureDimension = 0 # type: int + self.kernelOutputFeatureDimension = 0 # type: int + self.kernelSpatialDimensions = None # type: 
List[int] + self.outputBatchDimension = 0 # type: int + self.outputFeatureDimension = 0 # type: int + self.outputSpatialDimensions = None # type: List[int] + self.featureGroupCount = 0 # type: int + self.batchGroupCount = 0 # type: int + self.precisionConfig = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloConvolutionOptions = StablehloConvolutionOptions() + stablehloConvolutionOptions.Init(buf, pos) + return cls.InitFromObj(stablehloConvolutionOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloConvolutionOptions): + x = StablehloConvolutionOptionsT() + x._UnPack(stablehloConvolutionOptions) + return x + + # StablehloConvolutionOptionsT + def _UnPack(self, stablehloConvolutionOptions): + if stablehloConvolutionOptions is None: + return + if not stablehloConvolutionOptions.WindowStridesIsNone(): + if np is None: + self.windowStrides = [] + for i in range(stablehloConvolutionOptions.WindowStridesLength()): + self.windowStrides.append( + stablehloConvolutionOptions.WindowStrides(i)) + else: + self.windowStrides = stablehloConvolutionOptions.WindowStridesAsNumpy() + if not stablehloConvolutionOptions.PaddingIsNone(): + if np is None: + self.padding = [] + for i in range(stablehloConvolutionOptions.PaddingLength()): + self.padding.append(stablehloConvolutionOptions.Padding(i)) + else: + self.padding = stablehloConvolutionOptions.PaddingAsNumpy() + if not stablehloConvolutionOptions.LhsDilationIsNone(): + if np is None: + self.lhsDilation = [] + for i in range(stablehloConvolutionOptions.LhsDilationLength()): + self.lhsDilation.append(stablehloConvolutionOptions.LhsDilation(i)) + else: + self.lhsDilation = stablehloConvolutionOptions.LhsDilationAsNumpy() + if not stablehloConvolutionOptions.RhsDilationIsNone(): + if np is None: + self.rhsDilation = [] + for i in 
range(stablehloConvolutionOptions.RhsDilationLength()): + self.rhsDilation.append(stablehloConvolutionOptions.RhsDilation(i)) + else: + self.rhsDilation = stablehloConvolutionOptions.RhsDilationAsNumpy() + if not stablehloConvolutionOptions.WindowReversalIsNone(): + if np is None: + self.windowReversal = [] + for i in range(stablehloConvolutionOptions.WindowReversalLength()): + self.windowReversal.append( + stablehloConvolutionOptions.WindowReversal(i)) + else: + self.windowReversal = stablehloConvolutionOptions.WindowReversalAsNumpy() + self.inputBatchDimension = stablehloConvolutionOptions.InputBatchDimension() + self.inputFeatureDimension = stablehloConvolutionOptions.InputFeatureDimension() + if not stablehloConvolutionOptions.InputSpatialDimensionsIsNone(): + if np is None: + self.inputSpatialDimensions = [] + for i in range( + stablehloConvolutionOptions.InputSpatialDimensionsLength()): + self.inputSpatialDimensions.append( + stablehloConvolutionOptions.InputSpatialDimensions(i)) + else: + self.inputSpatialDimensions = stablehloConvolutionOptions.InputSpatialDimensionsAsNumpy( + ) + self.kernelInputFeatureDimension = stablehloConvolutionOptions.KernelInputFeatureDimension( + ) + self.kernelOutputFeatureDimension = stablehloConvolutionOptions.KernelOutputFeatureDimension( + ) + if not stablehloConvolutionOptions.KernelSpatialDimensionsIsNone(): + if np is None: + self.kernelSpatialDimensions = [] + for i in range( + stablehloConvolutionOptions.KernelSpatialDimensionsLength()): + self.kernelSpatialDimensions.append( + stablehloConvolutionOptions.KernelSpatialDimensions(i)) + else: + self.kernelSpatialDimensions = stablehloConvolutionOptions.KernelSpatialDimensionsAsNumpy( + ) + self.outputBatchDimension = stablehloConvolutionOptions.OutputBatchDimension() + self.outputFeatureDimension = stablehloConvolutionOptions.OutputFeatureDimension() + if not stablehloConvolutionOptions.OutputSpatialDimensionsIsNone(): + if np is None: + self.outputSpatialDimensions = [] + 
for i in range( + stablehloConvolutionOptions.OutputSpatialDimensionsLength()): + self.outputSpatialDimensions.append( + stablehloConvolutionOptions.OutputSpatialDimensions(i)) + else: + self.outputSpatialDimensions = stablehloConvolutionOptions.OutputSpatialDimensionsAsNumpy( + ) + self.featureGroupCount = stablehloConvolutionOptions.FeatureGroupCount() + self.batchGroupCount = stablehloConvolutionOptions.BatchGroupCount() + if not stablehloConvolutionOptions.PrecisionConfigIsNone(): + if np is None: + self.precisionConfig = [] + for i in range(stablehloConvolutionOptions.PrecisionConfigLength()): + self.precisionConfig.append( + stablehloConvolutionOptions.PrecisionConfig(i)) + else: + self.precisionConfig = stablehloConvolutionOptions.PrecisionConfigAsNumpy( + ) + + # StablehloConvolutionOptionsT + def Pack(self, builder): + if self.windowStrides is not None: + if np is not None and type(self.windowStrides) is np.ndarray: + windowStrides = builder.CreateNumpyVector(self.windowStrides) + else: + StablehloConvolutionOptionsStartWindowStridesVector( + builder, len(self.windowStrides)) + for i in reversed(range(len(self.windowStrides))): + builder.PrependInt64(self.windowStrides[i]) + windowStrides = builder.EndVector() + if self.padding is not None: + if np is not None and type(self.padding) is np.ndarray: + padding = builder.CreateNumpyVector(self.padding) + else: + StablehloConvolutionOptionsStartPaddingVector(builder, len(self.padding)) + for i in reversed(range(len(self.padding))): + builder.PrependInt64(self.padding[i]) + padding = builder.EndVector() + if self.lhsDilation is not None: + if np is not None and type(self.lhsDilation) is np.ndarray: + lhsDilation = builder.CreateNumpyVector(self.lhsDilation) + else: + StablehloConvolutionOptionsStartLhsDilationVector( + builder, len(self.lhsDilation)) + for i in reversed(range(len(self.lhsDilation))): + builder.PrependInt64(self.lhsDilation[i]) + lhsDilation = builder.EndVector() + if self.rhsDilation is not 
None: + if np is not None and type(self.rhsDilation) is np.ndarray: + rhsDilation = builder.CreateNumpyVector(self.rhsDilation) + else: + StablehloConvolutionOptionsStartRhsDilationVector( + builder, len(self.rhsDilation)) + for i in reversed(range(len(self.rhsDilation))): + builder.PrependInt64(self.rhsDilation[i]) + rhsDilation = builder.EndVector() + if self.windowReversal is not None: + if np is not None and type(self.windowReversal) is np.ndarray: + windowReversal = builder.CreateNumpyVector(self.windowReversal) + else: + StablehloConvolutionOptionsStartWindowReversalVector( + builder, len(self.windowReversal)) + for i in reversed(range(len(self.windowReversal))): + builder.PrependBool(self.windowReversal[i]) + windowReversal = builder.EndVector() + if self.inputSpatialDimensions is not None: + if np is not None and type(self.inputSpatialDimensions) is np.ndarray: + inputSpatialDimensions = builder.CreateNumpyVector( + self.inputSpatialDimensions) + else: + StablehloConvolutionOptionsStartInputSpatialDimensionsVector( + builder, len(self.inputSpatialDimensions)) + for i in reversed(range(len(self.inputSpatialDimensions))): + builder.PrependInt64(self.inputSpatialDimensions[i]) + inputSpatialDimensions = builder.EndVector() + if self.kernelSpatialDimensions is not None: + if np is not None and type(self.kernelSpatialDimensions) is np.ndarray: + kernelSpatialDimensions = builder.CreateNumpyVector( + self.kernelSpatialDimensions) + else: + StablehloConvolutionOptionsStartKernelSpatialDimensionsVector( + builder, len(self.kernelSpatialDimensions)) + for i in reversed(range(len(self.kernelSpatialDimensions))): + builder.PrependInt64(self.kernelSpatialDimensions[i]) + kernelSpatialDimensions = builder.EndVector() + if self.outputSpatialDimensions is not None: + if np is not None and type(self.outputSpatialDimensions) is np.ndarray: + outputSpatialDimensions = builder.CreateNumpyVector( + self.outputSpatialDimensions) + else: + 
StablehloConvolutionOptionsStartOutputSpatialDimensionsVector( + builder, len(self.outputSpatialDimensions)) + for i in reversed(range(len(self.outputSpatialDimensions))): + builder.PrependInt64(self.outputSpatialDimensions[i]) + outputSpatialDimensions = builder.EndVector() + if self.precisionConfig is not None: + if np is not None and type(self.precisionConfig) is np.ndarray: + precisionConfig = builder.CreateNumpyVector(self.precisionConfig) + else: + StablehloConvolutionOptionsStartPrecisionConfigVector( + builder, len(self.precisionConfig)) + for i in reversed(range(len(self.precisionConfig))): + builder.PrependUint32(self.precisionConfig[i]) + precisionConfig = builder.EndVector() + StablehloConvolutionOptionsStart(builder) + if self.windowStrides is not None: + StablehloConvolutionOptionsAddWindowStrides(builder, windowStrides) + if self.padding is not None: + StablehloConvolutionOptionsAddPadding(builder, padding) + if self.lhsDilation is not None: + StablehloConvolutionOptionsAddLhsDilation(builder, lhsDilation) + if self.rhsDilation is not None: + StablehloConvolutionOptionsAddRhsDilation(builder, rhsDilation) + if self.windowReversal is not None: + StablehloConvolutionOptionsAddWindowReversal(builder, windowReversal) + StablehloConvolutionOptionsAddInputBatchDimension(builder, + self.inputBatchDimension) + StablehloConvolutionOptionsAddInputFeatureDimension(builder, + self.inputFeatureDimension) + if self.inputSpatialDimensions is not None: + StablehloConvolutionOptionsAddInputSpatialDimensions( + builder, inputSpatialDimensions) + StablehloConvolutionOptionsAddKernelInputFeatureDimension( + builder, self.kernelInputFeatureDimension) + StablehloConvolutionOptionsAddKernelOutputFeatureDimension( + builder, self.kernelOutputFeatureDimension) + if self.kernelSpatialDimensions is not None: + StablehloConvolutionOptionsAddKernelSpatialDimensions( + builder, kernelSpatialDimensions) + StablehloConvolutionOptionsAddOutputBatchDimension(builder, + 
self.outputBatchDimension) + StablehloConvolutionOptionsAddOutputFeatureDimension(builder, + self.outputFeatureDimension) + if self.outputSpatialDimensions is not None: + StablehloConvolutionOptionsAddOutputSpatialDimensions( + builder, outputSpatialDimensions) + StablehloConvolutionOptionsAddFeatureGroupCount(builder, self.featureGroupCount) + StablehloConvolutionOptionsAddBatchGroupCount(builder, self.batchGroupCount) + if self.precisionConfig is not None: + StablehloConvolutionOptionsAddPrecisionConfig(builder, precisionConfig) + stablehloConvolutionOptions = StablehloConvolutionOptionsEnd(builder) + return stablehloConvolutionOptions + + +class StablehloScatterOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloScatterOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloScatterOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloScatterOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloScatterOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloScatterOptions + def IndicesAreSorted(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # StablehloScatterOptions + def UpdateWindowDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloScatterOptions + def UpdateWindowDimsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloScatterOptions + def UpdateWindowDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloScatterOptions + def UpdateWindowDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # StablehloScatterOptions + def InsertedWindowDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloScatterOptions + def InsertedWindowDimsAsNumpy(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloScatterOptions + def InsertedWindowDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloScatterOptions + def InsertedWindowDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # StablehloScatterOptions + def ScatterDimsToOperandDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int64Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8)) + return 0 + + # StablehloScatterOptions + def ScatterDimsToOperandDimsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o) + return 0 + + # StablehloScatterOptions + def ScatterDimsToOperandDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloScatterOptions + def ScatterDimsToOperandDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # StablehloScatterOptions + def IndexVectorDim(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # StablehloScatterOptions + def UniqueIndices(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # StablehloScatterOptions + def 
UpdateComputationSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def StablehloScatterOptionsStart(builder): + builder.StartObject(7) + + +def StablehloScatterOptionsAddIndicesAreSorted(builder, indicesAreSorted): + builder.PrependBoolSlot(0, indicesAreSorted, 0) + + +def StablehloScatterOptionsAddUpdateWindowDims(builder, updateWindowDims): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(updateWindowDims), 0) + + +def StablehloScatterOptionsStartUpdateWindowDimsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloScatterOptionsAddInsertedWindowDims(builder, insertedWindowDims): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(insertedWindowDims), 0) + + +def StablehloScatterOptionsStartInsertedWindowDimsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloScatterOptionsAddScatterDimsToOperandDims(builder, scatterDimsToOperandDims): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(scatterDimsToOperandDims), 0) + + +def StablehloScatterOptionsStartScatterDimsToOperandDimsVector(builder, numElems): + return builder.StartVector(8, numElems, 8) + + +def StablehloScatterOptionsAddIndexVectorDim(builder, indexVectorDim): + builder.PrependInt64Slot(4, indexVectorDim, 0) + + +def StablehloScatterOptionsAddUniqueIndices(builder, uniqueIndices): + builder.PrependBoolSlot(5, uniqueIndices, 0) + + +def StablehloScatterOptionsAddUpdateComputationSubgraphIndex( + builder, updateComputationSubgraphIndex): + builder.PrependInt32Slot(6, updateComputationSubgraphIndex, 0) + + +def StablehloScatterOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class 
StablehloScatterOptionsT(object): + + # StablehloScatterOptionsT + def __init__(self): + self.indicesAreSorted = False # type: bool + self.updateWindowDims = None # type: List[int] + self.insertedWindowDims = None # type: List[int] + self.scatterDimsToOperandDims = None # type: List[int] + self.indexVectorDim = 0 # type: int + self.uniqueIndices = False # type: bool + self.updateComputationSubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloScatterOptions = StablehloScatterOptions() + stablehloScatterOptions.Init(buf, pos) + return cls.InitFromObj(stablehloScatterOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloScatterOptions): + x = StablehloScatterOptionsT() + x._UnPack(stablehloScatterOptions) + return x + + # StablehloScatterOptionsT + def _UnPack(self, stablehloScatterOptions): + if stablehloScatterOptions is None: + return + self.indicesAreSorted = stablehloScatterOptions.IndicesAreSorted() + if not stablehloScatterOptions.UpdateWindowDimsIsNone(): + if np is None: + self.updateWindowDims = [] + for i in range(stablehloScatterOptions.UpdateWindowDimsLength()): + self.updateWindowDims.append( + stablehloScatterOptions.UpdateWindowDims(i)) + else: + self.updateWindowDims = stablehloScatterOptions.UpdateWindowDimsAsNumpy() + if not stablehloScatterOptions.InsertedWindowDimsIsNone(): + if np is None: + self.insertedWindowDims = [] + for i in range(stablehloScatterOptions.InsertedWindowDimsLength()): + self.insertedWindowDims.append( + stablehloScatterOptions.InsertedWindowDims(i)) + else: + self.insertedWindowDims = stablehloScatterOptions.InsertedWindowDimsAsNumpy( + ) + if not stablehloScatterOptions.ScatterDimsToOperandDimsIsNone(): + if np is None: + self.scatterDimsToOperandDims = [] + for i in 
range(stablehloScatterOptions.ScatterDimsToOperandDimsLength()): + self.scatterDimsToOperandDims.append( + stablehloScatterOptions.ScatterDimsToOperandDims(i)) + else: + self.scatterDimsToOperandDims = stablehloScatterOptions.ScatterDimsToOperandDimsAsNumpy( + ) + self.indexVectorDim = stablehloScatterOptions.IndexVectorDim() + self.uniqueIndices = stablehloScatterOptions.UniqueIndices() + self.updateComputationSubgraphIndex = stablehloScatterOptions.UpdateComputationSubgraphIndex( + ) + + # StablehloScatterOptionsT + def Pack(self, builder): + if self.updateWindowDims is not None: + if np is not None and type(self.updateWindowDims) is np.ndarray: + updateWindowDims = builder.CreateNumpyVector(self.updateWindowDims) + else: + StablehloScatterOptionsStartUpdateWindowDimsVector( + builder, len(self.updateWindowDims)) + for i in reversed(range(len(self.updateWindowDims))): + builder.PrependInt64(self.updateWindowDims[i]) + updateWindowDims = builder.EndVector() + if self.insertedWindowDims is not None: + if np is not None and type(self.insertedWindowDims) is np.ndarray: + insertedWindowDims = builder.CreateNumpyVector(self.insertedWindowDims) + else: + StablehloScatterOptionsStartInsertedWindowDimsVector( + builder, len(self.insertedWindowDims)) + for i in reversed(range(len(self.insertedWindowDims))): + builder.PrependInt64(self.insertedWindowDims[i]) + insertedWindowDims = builder.EndVector() + if self.scatterDimsToOperandDims is not None: + if np is not None and type(self.scatterDimsToOperandDims) is np.ndarray: + scatterDimsToOperandDims = builder.CreateNumpyVector( + self.scatterDimsToOperandDims) + else: + StablehloScatterOptionsStartScatterDimsToOperandDimsVector( + builder, len(self.scatterDimsToOperandDims)) + for i in reversed(range(len(self.scatterDimsToOperandDims))): + builder.PrependInt64(self.scatterDimsToOperandDims[i]) + scatterDimsToOperandDims = builder.EndVector() + StablehloScatterOptionsStart(builder) + 
StablehloScatterOptionsAddIndicesAreSorted(builder, self.indicesAreSorted) + if self.updateWindowDims is not None: + StablehloScatterOptionsAddUpdateWindowDims(builder, updateWindowDims) + if self.insertedWindowDims is not None: + StablehloScatterOptionsAddInsertedWindowDims(builder, insertedWindowDims) + if self.scatterDimsToOperandDims is not None: + StablehloScatterOptionsAddScatterDimsToOperandDims(builder, + scatterDimsToOperandDims) + StablehloScatterOptionsAddIndexVectorDim(builder, self.indexVectorDim) + StablehloScatterOptionsAddUniqueIndices(builder, self.uniqueIndices) + StablehloScatterOptionsAddUpdateComputationSubgraphIndex( + builder, self.updateComputationSubgraphIndex) + stablehloScatterOptions = StablehloScatterOptionsEnd(builder) + return stablehloScatterOptions + + +class StablehloRngBitGeneratorOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloRngBitGeneratorOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloRngBitGeneratorOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloRngBitGeneratorOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloRngBitGeneratorOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloRngBitGeneratorOptions + def Algorithm(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def StablehloRngBitGeneratorOptionsStart(builder): + builder.StartObject(1) + + +def StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm): + builder.PrependInt8Slot(0, algorithm, 0) + + +def StablehloRngBitGeneratorOptionsEnd(builder): + return builder.EndObject() + + +class StablehloRngBitGeneratorOptionsT(object): + + # StablehloRngBitGeneratorOptionsT + def __init__(self): + self.algorithm = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloRngBitGeneratorOptions = StablehloRngBitGeneratorOptions() + stablehloRngBitGeneratorOptions.Init(buf, pos) + return cls.InitFromObj(stablehloRngBitGeneratorOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloRngBitGeneratorOptions): + x = StablehloRngBitGeneratorOptionsT() + x._UnPack(stablehloRngBitGeneratorOptions) + return x + + # StablehloRngBitGeneratorOptionsT + def _UnPack(self, stablehloRngBitGeneratorOptions): + if stablehloRngBitGeneratorOptions is None: + return + self.algorithm = stablehloRngBitGeneratorOptions.Algorithm() + + # StablehloRngBitGeneratorOptionsT + def Pack(self, builder): + StablehloRngBitGeneratorOptionsStart(builder) + 
StablehloRngBitGeneratorOptionsAddAlgorithm(builder, self.algorithm) + stablehloRngBitGeneratorOptions = StablehloRngBitGeneratorOptionsEnd(builder) + return stablehloRngBitGeneratorOptions + + +class Conv2DOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Conv2DOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsConv2DOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Conv2DOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Conv2DOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Conv2DOptions + def Padding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Conv2DOptions + def StrideW(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Conv2DOptions + def StrideH(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Conv2DOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Conv2DOptions + def DilationWFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + 
self._tab.Pos) + return 1 + + # Conv2DOptions + def DilationHFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # Conv2DOptions + def QuantizedBiasType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def Conv2DOptionsStart(builder): + builder.StartObject(7) + + +def Conv2DOptionsAddPadding(builder, padding): + builder.PrependInt8Slot(0, padding, 0) + + +def Conv2DOptionsAddStrideW(builder, strideW): + builder.PrependInt32Slot(1, strideW, 0) + + +def Conv2DOptionsAddStrideH(builder, strideH): + builder.PrependInt32Slot(2, strideH, 0) + + +def Conv2DOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(3, fusedActivationFunction, 0) + + +def Conv2DOptionsAddDilationWFactor(builder, dilationWFactor): + builder.PrependInt32Slot(4, dilationWFactor, 1) + + +def Conv2DOptionsAddDilationHFactor(builder, dilationHFactor): + builder.PrependInt32Slot(5, dilationHFactor, 1) + + +def Conv2DOptionsAddQuantizedBiasType(builder, quantizedBiasType): + builder.PrependInt8Slot(6, quantizedBiasType, 0) + + +def Conv2DOptionsEnd(builder): + return builder.EndObject() + + +class Conv2DOptionsT(object): + + # Conv2DOptionsT + def __init__(self): + self.padding = 0 # type: int + self.strideW = 0 # type: int + self.strideH = 0 # type: int + self.fusedActivationFunction = 0 # type: int + self.dilationWFactor = 1 # type: int + self.dilationHFactor = 1 # type: int + self.quantizedBiasType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + conv2Doptions = Conv2DOptions() + conv2Doptions.Init(buf, pos) + return cls.InitFromObj(conv2Doptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, 
buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, conv2Doptions): + x = Conv2DOptionsT() + x._UnPack(conv2Doptions) + return x + + # Conv2DOptionsT + def _UnPack(self, conv2Doptions): + if conv2Doptions is None: + return + self.padding = conv2Doptions.Padding() + self.strideW = conv2Doptions.StrideW() + self.strideH = conv2Doptions.StrideH() + self.fusedActivationFunction = conv2Doptions.FusedActivationFunction() + self.dilationWFactor = conv2Doptions.DilationWFactor() + self.dilationHFactor = conv2Doptions.DilationHFactor() + self.quantizedBiasType = conv2Doptions.QuantizedBiasType() + + # Conv2DOptionsT + def Pack(self, builder): + Conv2DOptionsStart(builder) + Conv2DOptionsAddPadding(builder, self.padding) + Conv2DOptionsAddStrideW(builder, self.strideW) + Conv2DOptionsAddStrideH(builder, self.strideH) + Conv2DOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + Conv2DOptionsAddDilationWFactor(builder, self.dilationWFactor) + Conv2DOptionsAddDilationHFactor(builder, self.dilationHFactor) + Conv2DOptionsAddQuantizedBiasType(builder, self.quantizedBiasType) + conv2Doptions = Conv2DOptionsEnd(builder) + return conv2Doptions + + +class Conv3DOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Conv3DOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsConv3DOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Conv3DOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Conv3DOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Conv3DOptions + def Padding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Conv3DOptions + def StrideD(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Conv3DOptions + def StrideW(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Conv3DOptions + def StrideH(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Conv3DOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Conv3DOptions + def DilationDFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # Conv3DOptions + def DilationWFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # Conv3DOptions + def DilationHFactor(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + +def Conv3DOptionsStart(builder): + builder.StartObject(8) + + +def Conv3DOptionsAddPadding(builder, padding): + builder.PrependInt8Slot(0, padding, 0) + + +def Conv3DOptionsAddStrideD(builder, strideD): + builder.PrependInt32Slot(1, strideD, 0) + + +def Conv3DOptionsAddStrideW(builder, strideW): + builder.PrependInt32Slot(2, strideW, 0) + + +def Conv3DOptionsAddStrideH(builder, strideH): + builder.PrependInt32Slot(3, strideH, 0) + + +def Conv3DOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(4, fusedActivationFunction, 0) + + +def Conv3DOptionsAddDilationDFactor(builder, dilationDFactor): + builder.PrependInt32Slot(5, dilationDFactor, 1) + + +def Conv3DOptionsAddDilationWFactor(builder, dilationWFactor): + builder.PrependInt32Slot(6, dilationWFactor, 1) + + +def Conv3DOptionsAddDilationHFactor(builder, dilationHFactor): + builder.PrependInt32Slot(7, dilationHFactor, 1) + + +def Conv3DOptionsEnd(builder): + return builder.EndObject() + + +class Conv3DOptionsT(object): + + # Conv3DOptionsT + def __init__(self): + self.padding = 0 # type: int + self.strideD = 0 # type: int + self.strideW = 0 # type: int + self.strideH = 0 # type: int + self.fusedActivationFunction = 0 # type: int + self.dilationDFactor = 1 # type: int + self.dilationWFactor = 1 # type: int + self.dilationHFactor = 1 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + conv3Doptions = Conv3DOptions() + conv3Doptions.Init(buf, pos) + return cls.InitFromObj(conv3Doptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, conv3Doptions): + x = Conv3DOptionsT() + x._UnPack(conv3Doptions) + return x + + # Conv3DOptionsT + 
def _UnPack(self, conv3Doptions): + if conv3Doptions is None: + return + self.padding = conv3Doptions.Padding() + self.strideD = conv3Doptions.StrideD() + self.strideW = conv3Doptions.StrideW() + self.strideH = conv3Doptions.StrideH() + self.fusedActivationFunction = conv3Doptions.FusedActivationFunction() + self.dilationDFactor = conv3Doptions.DilationDFactor() + self.dilationWFactor = conv3Doptions.DilationWFactor() + self.dilationHFactor = conv3Doptions.DilationHFactor() + + # Conv3DOptionsT + def Pack(self, builder): + Conv3DOptionsStart(builder) + Conv3DOptionsAddPadding(builder, self.padding) + Conv3DOptionsAddStrideD(builder, self.strideD) + Conv3DOptionsAddStrideW(builder, self.strideW) + Conv3DOptionsAddStrideH(builder, self.strideH) + Conv3DOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + Conv3DOptionsAddDilationDFactor(builder, self.dilationDFactor) + Conv3DOptionsAddDilationWFactor(builder, self.dilationWFactor) + Conv3DOptionsAddDilationHFactor(builder, self.dilationHFactor) + conv3Doptions = Conv3DOptionsEnd(builder) + return conv3Doptions + + +class Pool2DOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Pool2DOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPool2DOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Pool2DOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Pool2DOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Pool2DOptions + def Padding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Pool2DOptions + def StrideW(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Pool2DOptions + def StrideH(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Pool2DOptions + def FilterWidth(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Pool2DOptions + def FilterHeight(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Pool2DOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def Pool2DOptionsStart(builder): + builder.StartObject(6) + + +def Pool2DOptionsAddPadding(builder, padding): + builder.PrependInt8Slot(0, padding, 0) + + +def Pool2DOptionsAddStrideW(builder, strideW): + builder.PrependInt32Slot(1, strideW, 0) + + +def Pool2DOptionsAddStrideH(builder, 
strideH): + builder.PrependInt32Slot(2, strideH, 0) + + +def Pool2DOptionsAddFilterWidth(builder, filterWidth): + builder.PrependInt32Slot(3, filterWidth, 0) + + +def Pool2DOptionsAddFilterHeight(builder, filterHeight): + builder.PrependInt32Slot(4, filterHeight, 0) + + +def Pool2DOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(5, fusedActivationFunction, 0) + + +def Pool2DOptionsEnd(builder): + return builder.EndObject() + + +class Pool2DOptionsT(object): + + # Pool2DOptionsT + def __init__(self): + self.padding = 0 # type: int + self.strideW = 0 # type: int + self.strideH = 0 # type: int + self.filterWidth = 0 # type: int + self.filterHeight = 0 # type: int + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + pool2Doptions = Pool2DOptions() + pool2Doptions.Init(buf, pos) + return cls.InitFromObj(pool2Doptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, pool2Doptions): + x = Pool2DOptionsT() + x._UnPack(pool2Doptions) + return x + + # Pool2DOptionsT + def _UnPack(self, pool2Doptions): + if pool2Doptions is None: + return + self.padding = pool2Doptions.Padding() + self.strideW = pool2Doptions.StrideW() + self.strideH = pool2Doptions.StrideH() + self.filterWidth = pool2Doptions.FilterWidth() + self.filterHeight = pool2Doptions.FilterHeight() + self.fusedActivationFunction = pool2Doptions.FusedActivationFunction() + + # Pool2DOptionsT + def Pack(self, builder): + Pool2DOptionsStart(builder) + Pool2DOptionsAddPadding(builder, self.padding) + Pool2DOptionsAddStrideW(builder, self.strideW) + Pool2DOptionsAddStrideH(builder, self.strideH) + Pool2DOptionsAddFilterWidth(builder, self.filterWidth) + Pool2DOptionsAddFilterHeight(builder, self.filterHeight) + Pool2DOptionsAddFusedActivationFunction(builder, 
self.fusedActivationFunction) + pool2Doptions = Pool2DOptionsEnd(builder) + return pool2Doptions + + +class DepthwiseConv2DOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DepthwiseConv2DOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDepthwiseConv2DOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DepthwiseConv2DOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DepthwiseConv2DOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DepthwiseConv2DOptions + def Padding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # DepthwiseConv2DOptions + def StrideW(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # DepthwiseConv2DOptions + def StrideH(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # DepthwiseConv2DOptions + def DepthMultiplier(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # DepthwiseConv2DOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + 
self._tab.Pos) + return 0 + + # DepthwiseConv2DOptions + def DilationWFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # DepthwiseConv2DOptions + def DilationHFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + +def DepthwiseConv2DOptionsStart(builder): + builder.StartObject(7) + + +def DepthwiseConv2DOptionsAddPadding(builder, padding): + builder.PrependInt8Slot(0, padding, 0) + + +def DepthwiseConv2DOptionsAddStrideW(builder, strideW): + builder.PrependInt32Slot(1, strideW, 0) + + +def DepthwiseConv2DOptionsAddStrideH(builder, strideH): + builder.PrependInt32Slot(2, strideH, 0) + + +def DepthwiseConv2DOptionsAddDepthMultiplier(builder, depthMultiplier): + builder.PrependInt32Slot(3, depthMultiplier, 0) + + +def DepthwiseConv2DOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(4, fusedActivationFunction, 0) + + +def DepthwiseConv2DOptionsAddDilationWFactor(builder, dilationWFactor): + builder.PrependInt32Slot(5, dilationWFactor, 1) + + +def DepthwiseConv2DOptionsAddDilationHFactor(builder, dilationHFactor): + builder.PrependInt32Slot(6, dilationHFactor, 1) + + +def DepthwiseConv2DOptionsEnd(builder): + return builder.EndObject() + + +class DepthwiseConv2DOptionsT(object): + + # DepthwiseConv2DOptionsT + def __init__(self): + self.padding = 0 # type: int + self.strideW = 0 # type: int + self.strideH = 0 # type: int + self.depthMultiplier = 0 # type: int + self.fusedActivationFunction = 0 # type: int + self.dilationWFactor = 1 # type: int + self.dilationHFactor = 1 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + depthwiseConv2Doptions = DepthwiseConv2DOptions() + depthwiseConv2Doptions.Init(buf, pos) + return 
cls.InitFromObj(depthwiseConv2Doptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, depthwiseConv2Doptions): + x = DepthwiseConv2DOptionsT() + x._UnPack(depthwiseConv2Doptions) + return x + + # DepthwiseConv2DOptionsT + def _UnPack(self, depthwiseConv2Doptions): + if depthwiseConv2Doptions is None: + return + self.padding = depthwiseConv2Doptions.Padding() + self.strideW = depthwiseConv2Doptions.StrideW() + self.strideH = depthwiseConv2Doptions.StrideH() + self.depthMultiplier = depthwiseConv2Doptions.DepthMultiplier() + self.fusedActivationFunction = depthwiseConv2Doptions.FusedActivationFunction() + self.dilationWFactor = depthwiseConv2Doptions.DilationWFactor() + self.dilationHFactor = depthwiseConv2Doptions.DilationHFactor() + + # DepthwiseConv2DOptionsT + def Pack(self, builder): + DepthwiseConv2DOptionsStart(builder) + DepthwiseConv2DOptionsAddPadding(builder, self.padding) + DepthwiseConv2DOptionsAddStrideW(builder, self.strideW) + DepthwiseConv2DOptionsAddStrideH(builder, self.strideH) + DepthwiseConv2DOptionsAddDepthMultiplier(builder, self.depthMultiplier) + DepthwiseConv2DOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + DepthwiseConv2DOptionsAddDilationWFactor(builder, self.dilationWFactor) + DepthwiseConv2DOptionsAddDilationHFactor(builder, self.dilationHFactor) + depthwiseConv2Doptions = DepthwiseConv2DOptionsEnd(builder) + return depthwiseConv2Doptions + + +class ConcatEmbeddingsOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ConcatEmbeddingsOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsConcatEmbeddingsOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ConcatEmbeddingsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ConcatEmbeddingsOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ConcatEmbeddingsOptions + def NumChannels(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ConcatEmbeddingsOptions + def NumColumnsPerChannel(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ConcatEmbeddingsOptions + def NumColumnsPerChannelAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ConcatEmbeddingsOptions + def NumColumnsPerChannelLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ConcatEmbeddingsOptions + def NumColumnsPerChannelIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # ConcatEmbeddingsOptions + def EmbeddingDimPerChannel(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ConcatEmbeddingsOptions + def EmbeddingDimPerChannelAsNumpy(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ConcatEmbeddingsOptions + def EmbeddingDimPerChannelLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ConcatEmbeddingsOptions + def EmbeddingDimPerChannelIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + +def ConcatEmbeddingsOptionsStart(builder): + builder.StartObject(3) + + +def ConcatEmbeddingsOptionsAddNumChannels(builder, numChannels): + builder.PrependInt32Slot(0, numChannels, 0) + + +def ConcatEmbeddingsOptionsAddNumColumnsPerChannel(builder, numColumnsPerChannel): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(numColumnsPerChannel), 0) + + +def ConcatEmbeddingsOptionsStartNumColumnsPerChannelVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def ConcatEmbeddingsOptionsAddEmbeddingDimPerChannel(builder, embeddingDimPerChannel): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(embeddingDimPerChannel), 0) + + +def ConcatEmbeddingsOptionsStartEmbeddingDimPerChannelVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def ConcatEmbeddingsOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class ConcatEmbeddingsOptionsT(object): + + # ConcatEmbeddingsOptionsT + def __init__(self): + self.numChannels = 0 # type: int + self.numColumnsPerChannel = None # type: List[int] + self.embeddingDimPerChannel = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + concatEmbeddingsOptions = ConcatEmbeddingsOptions() + concatEmbeddingsOptions.Init(buf, pos) + return cls.InitFromObj(concatEmbeddingsOptions) + + @classmethod + def 
InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, concatEmbeddingsOptions): + x = ConcatEmbeddingsOptionsT() + x._UnPack(concatEmbeddingsOptions) + return x + + # ConcatEmbeddingsOptionsT + def _UnPack(self, concatEmbeddingsOptions): + if concatEmbeddingsOptions is None: + return + self.numChannels = concatEmbeddingsOptions.NumChannels() + if not concatEmbeddingsOptions.NumColumnsPerChannelIsNone(): + if np is None: + self.numColumnsPerChannel = [] + for i in range(concatEmbeddingsOptions.NumColumnsPerChannelLength()): + self.numColumnsPerChannel.append( + concatEmbeddingsOptions.NumColumnsPerChannel(i)) + else: + self.numColumnsPerChannel = concatEmbeddingsOptions.NumColumnsPerChannelAsNumpy( + ) + if not concatEmbeddingsOptions.EmbeddingDimPerChannelIsNone(): + if np is None: + self.embeddingDimPerChannel = [] + for i in range(concatEmbeddingsOptions.EmbeddingDimPerChannelLength()): + self.embeddingDimPerChannel.append( + concatEmbeddingsOptions.EmbeddingDimPerChannel(i)) + else: + self.embeddingDimPerChannel = concatEmbeddingsOptions.EmbeddingDimPerChannelAsNumpy( + ) + + # ConcatEmbeddingsOptionsT + def Pack(self, builder): + if self.numColumnsPerChannel is not None: + if np is not None and type(self.numColumnsPerChannel) is np.ndarray: + numColumnsPerChannel = builder.CreateNumpyVector( + self.numColumnsPerChannel) + else: + ConcatEmbeddingsOptionsStartNumColumnsPerChannelVector( + builder, len(self.numColumnsPerChannel)) + for i in reversed(range(len(self.numColumnsPerChannel))): + builder.PrependInt32(self.numColumnsPerChannel[i]) + numColumnsPerChannel = builder.EndVector() + if self.embeddingDimPerChannel is not None: + if np is not None and type(self.embeddingDimPerChannel) is np.ndarray: + embeddingDimPerChannel = builder.CreateNumpyVector( + self.embeddingDimPerChannel) + else: + 
                # numpy unavailable: build the int32 vector element-by-element
                # (flatbuffers vectors are filled back-to-front via Prepend*).
                ConcatEmbeddingsOptionsStartEmbeddingDimPerChannelVector(
                    builder, len(self.embeddingDimPerChannel))
                for i in reversed(range(len(self.embeddingDimPerChannel))):
                    builder.PrependInt32(self.embeddingDimPerChannel[i])
                embeddingDimPerChannel = builder.EndVector()
        ConcatEmbeddingsOptionsStart(builder)
        ConcatEmbeddingsOptionsAddNumChannels(builder, self.numChannels)
        if self.numColumnsPerChannel is not None:
            ConcatEmbeddingsOptionsAddNumColumnsPerChannel(builder, numColumnsPerChannel)
        if self.embeddingDimPerChannel is not None:
            ConcatEmbeddingsOptionsAddEmbeddingDimPerChannel(builder,
                                                             embeddingDimPerChannel)
        concatEmbeddingsOptions = ConcatEmbeddingsOptionsEnd(builder)
        return concatEmbeddingsOptions


# NOTE(review): flatbuffers-generated code. Do not hand-edit; change the Circle
# schema and regenerate instead.
class LSHProjectionOptions(object):
    """Read-only flatbuffers view over an LSHProjectionOptions table."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor rooted at the table referenced at *offset* in *buf*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = LSHProjectionOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsLSHProjectionOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LSHProjectionOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x43\x49\x52\x30" is ASCII "CIR0", the Circle file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # LSHProjectionOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # LSHProjectionOptions
    def Type(self):
        # Field at vtable offset 4; returns the schema default (0) when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0


def LSHProjectionOptionsStart(builder):
    builder.StartObject(1)


def LSHProjectionOptionsAddType(builder, type):
    builder.PrependInt8Slot(0, type, 0)


def LSHProjectionOptionsEnd(builder):
    return builder.EndObject()


class LSHProjectionOptionsT(object):
    """Mutable object API: plain-attribute counterpart of LSHProjectionOptions."""

    # LSHProjectionOptionsT
    def __init__(self):
        self.type = 0  # type: int

    @classmethod
    def InitFromBuf(cls, buf, pos):
        lshprojectionOptions = LSHProjectionOptions()
        lshprojectionOptions.Init(buf, pos)
        return cls.InitFromObj(lshprojectionOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, lshprojectionOptions):
        x = LSHProjectionOptionsT()
        x._UnPack(lshprojectionOptions)
        return x

    # LSHProjectionOptionsT
    def _UnPack(self, lshprojectionOptions):
        # Copy every field from the table accessor into plain attributes.
        if lshprojectionOptions is None:
            return
        self.type = lshprojectionOptions.Type()

    # LSHProjectionOptionsT
    def Pack(self, builder):
        # Serialize this object back into *builder* via the generated helpers.
        LSHProjectionOptionsStart(builder)
        LSHProjectionOptionsAddType(builder, self.type)
        lshprojectionOptions = LSHProjectionOptionsEnd(builder)
        return lshprojectionOptions


class SVDFOptions(object):
    """Read-only flatbuffers view over an SVDFOptions table."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = SVDFOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSVDFOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SVDFOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x43\x49\x52\x30" is ASCII "CIR0", the Circle file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # SVDFOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # SVDFOptions
    def Rank(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # SVDFOptions
    def FusedActivationFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # SVDFOptions
    def AsymmetricQuantizeInputs(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False


def SVDFOptionsStart(builder):
    builder.StartObject(3)


def SVDFOptionsAddRank(builder, rank):
    builder.PrependInt32Slot(0, rank, 0)


def SVDFOptionsAddFusedActivationFunction(builder, fusedActivationFunction):
    builder.PrependInt8Slot(1, fusedActivationFunction, 0)


def SVDFOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs):
    builder.PrependBoolSlot(2, asymmetricQuantizeInputs, 0)


def SVDFOptionsEnd(builder):
    return builder.EndObject()


class SVDFOptionsT(object):
    """Mutable object API: plain-attribute counterpart of SVDFOptions."""

    # SVDFOptionsT
    def __init__(self):
        self.rank = 0  # type: int
        self.fusedActivationFunction = 0  # type: int
        self.asymmetricQuantizeInputs = False  # type: bool

    @classmethod
    def InitFromBuf(cls, buf, pos):
        svdfoptions = SVDFOptions()
        svdfoptions.Init(buf, pos)
        return cls.InitFromObj(svdfoptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, svdfoptions):
        x = SVDFOptionsT()
        x._UnPack(svdfoptions)
        return x

    # SVDFOptionsT
    def _UnPack(self, svdfoptions):
        if svdfoptions is None:
            return
        self.rank = svdfoptions.Rank()
        self.fusedActivationFunction = svdfoptions.FusedActivationFunction()
        self.asymmetricQuantizeInputs = svdfoptions.AsymmetricQuantizeInputs()

    # SVDFOptionsT
    def Pack(self, builder):
        SVDFOptionsStart(builder)
        SVDFOptionsAddRank(builder, self.rank)
        SVDFOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction)
        SVDFOptionsAddAsymmetricQuantizeInputs(builder, self.asymmetricQuantizeInputs)
        svdfoptions = SVDFOptionsEnd(builder)
        return svdfoptions


class RNNOptions(object):
    """Read-only flatbuffers view over an RNNOptions table."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = RNNOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsRNNOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RNNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x43\x49\x52\x30" is ASCII "CIR0", the Circle file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # RNNOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # RNNOptions
    def FusedActivationFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # RNNOptions
    def AsymmetricQuantizeInputs(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False


def RNNOptionsStart(builder):
    builder.StartObject(2)


def RNNOptionsAddFusedActivationFunction(builder, fusedActivationFunction):
    builder.PrependInt8Slot(0, fusedActivationFunction, 0)


def RNNOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs):
    builder.PrependBoolSlot(1, asymmetricQuantizeInputs, 0)


def RNNOptionsEnd(builder):
    return builder.EndObject()


class RNNOptionsT(object):
    """Mutable object API: plain-attribute counterpart of RNNOptions."""

    # RNNOptionsT
    def __init__(self):
        self.fusedActivationFunction = 0  # type: int
        self.asymmetricQuantizeInputs = False  # type: bool

    @classmethod
    def InitFromBuf(cls, buf, pos):
        rnnoptions = RNNOptions()
        rnnoptions.Init(buf, pos)
        return cls.InitFromObj(rnnoptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, rnnoptions):
        x = RNNOptionsT()
        x._UnPack(rnnoptions)
        return x

    # RNNOptionsT
    def _UnPack(self, rnnoptions):
        if rnnoptions is None:
            return
        self.fusedActivationFunction = rnnoptions.FusedActivationFunction()
        self.asymmetricQuantizeInputs = rnnoptions.AsymmetricQuantizeInputs()

    # RNNOptionsT
    def Pack(self, builder):
        # Serialize this object back into *builder* via the generated helpers.
        RNNOptionsStart(builder)
        RNNOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction)
        RNNOptionsAddAsymmetricQuantizeInputs(builder, self.asymmetricQuantizeInputs)
        rnnoptions = RNNOptionsEnd(builder)
        return rnnoptions


# NOTE(review): flatbuffers-generated code. Do not hand-edit; change the Circle
# schema and regenerate instead.
class SequenceRNNOptions(object):
    """Read-only flatbuffers view over a SequenceRNNOptions table."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor rooted at the table referenced at *offset* in *buf*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = SequenceRNNOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSequenceRNNOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SequenceRNNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x43\x49\x52\x30" is ASCII "CIR0", the Circle file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # SequenceRNNOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # SequenceRNNOptions
    def TimeMajor(self):
        # Field at vtable offset 4; returns the schema default (False) when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False

    # SequenceRNNOptions
    def FusedActivationFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # SequenceRNNOptions
    def AsymmetricQuantizeInputs(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False


def SequenceRNNOptionsStart(builder):
    builder.StartObject(3)


def SequenceRNNOptionsAddTimeMajor(builder, timeMajor):
    builder.PrependBoolSlot(0, timeMajor, 0)

+def SequenceRNNOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(1, fusedActivationFunction, 0) + + +def SequenceRNNOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(2, asymmetricQuantizeInputs, 0) + + +def SequenceRNNOptionsEnd(builder): + return builder.EndObject() + + +class SequenceRNNOptionsT(object): + + # SequenceRNNOptionsT + def __init__(self): + self.timeMajor = False # type: bool + self.fusedActivationFunction = 0 # type: int + self.asymmetricQuantizeInputs = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + sequenceRnnoptions = SequenceRNNOptions() + sequenceRnnoptions.Init(buf, pos) + return cls.InitFromObj(sequenceRnnoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, sequenceRnnoptions): + x = SequenceRNNOptionsT() + x._UnPack(sequenceRnnoptions) + return x + + # SequenceRNNOptionsT + def _UnPack(self, sequenceRnnoptions): + if sequenceRnnoptions is None: + return + self.timeMajor = sequenceRnnoptions.TimeMajor() + self.fusedActivationFunction = sequenceRnnoptions.FusedActivationFunction() + self.asymmetricQuantizeInputs = sequenceRnnoptions.AsymmetricQuantizeInputs() + + # SequenceRNNOptionsT + def Pack(self, builder): + SequenceRNNOptionsStart(builder) + SequenceRNNOptionsAddTimeMajor(builder, self.timeMajor) + SequenceRNNOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + SequenceRNNOptionsAddAsymmetricQuantizeInputs(builder, + self.asymmetricQuantizeInputs) + sequenceRnnoptions = SequenceRNNOptionsEnd(builder) + return sequenceRnnoptions + + +class BidirectionalSequenceRNNOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = 
BidirectionalSequenceRNNOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBidirectionalSequenceRNNOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BidirectionalSequenceRNNOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BidirectionalSequenceRNNOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BidirectionalSequenceRNNOptions + def TimeMajor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # BidirectionalSequenceRNNOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # BidirectionalSequenceRNNOptions + def MergeOutputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # BidirectionalSequenceRNNOptions + def AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def BidirectionalSequenceRNNOptionsStart(builder): + builder.StartObject(4) + + +def BidirectionalSequenceRNNOptionsAddTimeMajor(builder, timeMajor): + builder.PrependBoolSlot(0, timeMajor, 0) + + +def BidirectionalSequenceRNNOptionsAddFusedActivationFunction(builder, + fusedActivationFunction): + builder.PrependInt8Slot(1, fusedActivationFunction, 0) + + +def 
BidirectionalSequenceRNNOptionsAddMergeOutputs(builder, mergeOutputs): + builder.PrependBoolSlot(2, mergeOutputs, 0) + + +def BidirectionalSequenceRNNOptionsAddAsymmetricQuantizeInputs(builder, + asymmetricQuantizeInputs): + builder.PrependBoolSlot(3, asymmetricQuantizeInputs, 0) + + +def BidirectionalSequenceRNNOptionsEnd(builder): + return builder.EndObject() + + +class BidirectionalSequenceRNNOptionsT(object): + + # BidirectionalSequenceRNNOptionsT + def __init__(self): + self.timeMajor = False # type: bool + self.fusedActivationFunction = 0 # type: int + self.mergeOutputs = False # type: bool + self.asymmetricQuantizeInputs = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + bidirectionalSequenceRnnoptions = BidirectionalSequenceRNNOptions() + bidirectionalSequenceRnnoptions.Init(buf, pos) + return cls.InitFromObj(bidirectionalSequenceRnnoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, bidirectionalSequenceRnnoptions): + x = BidirectionalSequenceRNNOptionsT() + x._UnPack(bidirectionalSequenceRnnoptions) + return x + + # BidirectionalSequenceRNNOptionsT + def _UnPack(self, bidirectionalSequenceRnnoptions): + if bidirectionalSequenceRnnoptions is None: + return + self.timeMajor = bidirectionalSequenceRnnoptions.TimeMajor() + self.fusedActivationFunction = bidirectionalSequenceRnnoptions.FusedActivationFunction( + ) + self.mergeOutputs = bidirectionalSequenceRnnoptions.MergeOutputs() + self.asymmetricQuantizeInputs = bidirectionalSequenceRnnoptions.AsymmetricQuantizeInputs( + ) + + # BidirectionalSequenceRNNOptionsT + def Pack(self, builder): + BidirectionalSequenceRNNOptionsStart(builder) + BidirectionalSequenceRNNOptionsAddTimeMajor(builder, self.timeMajor) + BidirectionalSequenceRNNOptionsAddFusedActivationFunction( + builder, self.fusedActivationFunction) + 
BidirectionalSequenceRNNOptionsAddMergeOutputs(builder, self.mergeOutputs) + BidirectionalSequenceRNNOptionsAddAsymmetricQuantizeInputs( + builder, self.asymmetricQuantizeInputs) + bidirectionalSequenceRnnoptions = BidirectionalSequenceRNNOptionsEnd(builder) + return bidirectionalSequenceRnnoptions + + +class FullyConnectedOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FullyConnectedOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsFullyConnectedOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def FullyConnectedOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # FullyConnectedOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # FullyConnectedOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # FullyConnectedOptions + def WeightsFormat(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # FullyConnectedOptions + def KeepNumDims(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # FullyConnectedOptions + def AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + 
self._tab.Pos)) + return False + + # FullyConnectedOptions + def QuantizedBiasType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def FullyConnectedOptionsStart(builder): + builder.StartObject(5) + + +def FullyConnectedOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def FullyConnectedOptionsAddWeightsFormat(builder, weightsFormat): + builder.PrependInt8Slot(1, weightsFormat, 0) + + +def FullyConnectedOptionsAddKeepNumDims(builder, keepNumDims): + builder.PrependBoolSlot(2, keepNumDims, 0) + + +def FullyConnectedOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(3, asymmetricQuantizeInputs, 0) + + +def FullyConnectedOptionsAddQuantizedBiasType(builder, quantizedBiasType): + builder.PrependInt8Slot(4, quantizedBiasType, 0) + + +def FullyConnectedOptionsEnd(builder): + return builder.EndObject() + + +class FullyConnectedOptionsT(object): + + # FullyConnectedOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.weightsFormat = 0 # type: int + self.keepNumDims = False # type: bool + self.asymmetricQuantizeInputs = False # type: bool + self.quantizedBiasType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + fullyConnectedOptions = FullyConnectedOptions() + fullyConnectedOptions.Init(buf, pos) + return cls.InitFromObj(fullyConnectedOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, fullyConnectedOptions): + x = FullyConnectedOptionsT() + x._UnPack(fullyConnectedOptions) + return x + + # FullyConnectedOptionsT + def _UnPack(self, fullyConnectedOptions): + if fullyConnectedOptions is None: + return 
+ self.fusedActivationFunction = fullyConnectedOptions.FusedActivationFunction() + self.weightsFormat = fullyConnectedOptions.WeightsFormat() + self.keepNumDims = fullyConnectedOptions.KeepNumDims() + self.asymmetricQuantizeInputs = fullyConnectedOptions.AsymmetricQuantizeInputs() + self.quantizedBiasType = fullyConnectedOptions.QuantizedBiasType() + + # FullyConnectedOptionsT + def Pack(self, builder): + FullyConnectedOptionsStart(builder) + FullyConnectedOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + FullyConnectedOptionsAddWeightsFormat(builder, self.weightsFormat) + FullyConnectedOptionsAddKeepNumDims(builder, self.keepNumDims) + FullyConnectedOptionsAddAsymmetricQuantizeInputs(builder, + self.asymmetricQuantizeInputs) + FullyConnectedOptionsAddQuantizedBiasType(builder, self.quantizedBiasType) + fullyConnectedOptions = FullyConnectedOptionsEnd(builder) + return fullyConnectedOptions + + +class SoftmaxOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SoftmaxOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSoftmaxOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SoftmaxOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SoftmaxOptions + def Beta(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + +def SoftmaxOptionsStart(builder): + builder.StartObject(1) + + +def SoftmaxOptionsAddBeta(builder, beta): + builder.PrependFloat32Slot(0, beta, 0.0) + + +def SoftmaxOptionsEnd(builder): + return builder.EndObject() + + +class SoftmaxOptionsT(object): + + # SoftmaxOptionsT + def __init__(self): + self.beta = 0.0 # type: float + + @classmethod + def InitFromBuf(cls, buf, pos): + softmaxOptions = SoftmaxOptions() + softmaxOptions.Init(buf, pos) + return cls.InitFromObj(softmaxOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, softmaxOptions): + x = SoftmaxOptionsT() + x._UnPack(softmaxOptions) + return x + + # SoftmaxOptionsT + def _UnPack(self, softmaxOptions): + if softmaxOptions is None: + return + self.beta = softmaxOptions.Beta() + + # SoftmaxOptionsT + def Pack(self, builder): + SoftmaxOptionsStart(builder) + SoftmaxOptionsAddBeta(builder, self.beta) + softmaxOptions = SoftmaxOptionsEnd(builder) + return softmaxOptions + + +class ConcatenationOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ConcatenationOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def 
GetRootAsConcatenationOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ConcatenationOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ConcatenationOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ConcatenationOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ConcatenationOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def ConcatenationOptionsStart(builder): + builder.StartObject(2) + + +def ConcatenationOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(0, axis, 0) + + +def ConcatenationOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(1, fusedActivationFunction, 0) + + +def ConcatenationOptionsEnd(builder): + return builder.EndObject() + + +class ConcatenationOptionsT(object): + + # ConcatenationOptionsT + def __init__(self): + self.axis = 0 # type: int + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + concatenationOptions = ConcatenationOptions() + concatenationOptions.Init(buf, pos) + return cls.InitFromObj(concatenationOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, concatenationOptions): + x = ConcatenationOptionsT() + x._UnPack(concatenationOptions) + return x + + # ConcatenationOptionsT + 
def _UnPack(self, concatenationOptions): + if concatenationOptions is None: + return + self.axis = concatenationOptions.Axis() + self.fusedActivationFunction = concatenationOptions.FusedActivationFunction() + + # ConcatenationOptionsT + def Pack(self, builder): + ConcatenationOptionsStart(builder) + ConcatenationOptionsAddAxis(builder, self.axis) + ConcatenationOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + concatenationOptions = ConcatenationOptionsEnd(builder) + return concatenationOptions + + +class AddOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AddOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAddOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def AddOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # AddOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # AddOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # AddOptions + def PotScaleInt16(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return True + + +def AddOptionsStart(builder): + builder.StartObject(2) + + +def AddOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def AddOptionsAddPotScaleInt16(builder, potScaleInt16): + builder.PrependBoolSlot(1, potScaleInt16, 1) + + 
+def AddOptionsEnd(builder): + return builder.EndObject() + + +class AddOptionsT(object): + + # AddOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.potScaleInt16 = True # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + addOptions = AddOptions() + addOptions.Init(buf, pos) + return cls.InitFromObj(addOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, addOptions): + x = AddOptionsT() + x._UnPack(addOptions) + return x + + # AddOptionsT + def _UnPack(self, addOptions): + if addOptions is None: + return + self.fusedActivationFunction = addOptions.FusedActivationFunction() + self.potScaleInt16 = addOptions.PotScaleInt16() + + # AddOptionsT + def Pack(self, builder): + AddOptionsStart(builder) + AddOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + AddOptionsAddPotScaleInt16(builder, self.potScaleInt16) + addOptions = AddOptionsEnd(builder) + return addOptions + + +class MulOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MulOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMulOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MulOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MulOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MulOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def MulOptionsStart(builder): + builder.StartObject(1) + + +def MulOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def MulOptionsEnd(builder): + return builder.EndObject() + + +class MulOptionsT(object): + + # MulOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + mulOptions = MulOptions() + mulOptions.Init(buf, pos) + return cls.InitFromObj(mulOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, mulOptions): + x = MulOptionsT() + x._UnPack(mulOptions) + return x + + # MulOptionsT + def _UnPack(self, mulOptions): + if mulOptions is None: + return + self.fusedActivationFunction = mulOptions.FusedActivationFunction() + + # MulOptionsT + def Pack(self, builder): + MulOptionsStart(builder) + MulOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + mulOptions = MulOptionsEnd(builder) + return mulOptions + + +class L2NormOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = L2NormOptions() + x.Init(buf, n + offset) + return x + 
+ @classmethod + def GetRootAsL2NormOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def L2NormOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # L2NormOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # L2NormOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def L2NormOptionsStart(builder): + builder.StartObject(1) + + +def L2NormOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def L2NormOptionsEnd(builder): + return builder.EndObject() + + +class L2NormOptionsT(object): + + # L2NormOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + l2NormOptions = L2NormOptions() + l2NormOptions.Init(buf, pos) + return cls.InitFromObj(l2NormOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, l2NormOptions): + x = L2NormOptionsT() + x._UnPack(l2NormOptions) + return x + + # L2NormOptionsT + def _UnPack(self, l2NormOptions): + if l2NormOptions is None: + return + self.fusedActivationFunction = l2NormOptions.FusedActivationFunction() + + # L2NormOptionsT + def Pack(self, builder): + L2NormOptionsStart(builder) + L2NormOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + l2NormOptions = L2NormOptionsEnd(builder) + return l2NormOptions + + +class LocalResponseNormalizationOptions(object): + __slots__ = 
['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LocalResponseNormalizationOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLocalResponseNormalizationOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LocalResponseNormalizationOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LocalResponseNormalizationOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # LocalResponseNormalizationOptions + def Radius(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # LocalResponseNormalizationOptions + def Bias(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # LocalResponseNormalizationOptions + def Alpha(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # LocalResponseNormalizationOptions + def Beta(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + +def LocalResponseNormalizationOptionsStart(builder): + builder.StartObject(4) + + +def LocalResponseNormalizationOptionsAddRadius(builder, radius): + builder.PrependInt32Slot(0, radius, 0) + + +def LocalResponseNormalizationOptionsAddBias(builder, bias): + builder.PrependFloat32Slot(1, 
bias, 0.0) + + +def LocalResponseNormalizationOptionsAddAlpha(builder, alpha): + builder.PrependFloat32Slot(2, alpha, 0.0) + + +def LocalResponseNormalizationOptionsAddBeta(builder, beta): + builder.PrependFloat32Slot(3, beta, 0.0) + + +def LocalResponseNormalizationOptionsEnd(builder): + return builder.EndObject() + + +class LocalResponseNormalizationOptionsT(object): + + # LocalResponseNormalizationOptionsT + def __init__(self): + self.radius = 0 # type: int + self.bias = 0.0 # type: float + self.alpha = 0.0 # type: float + self.beta = 0.0 # type: float + + @classmethod + def InitFromBuf(cls, buf, pos): + localResponseNormalizationOptions = LocalResponseNormalizationOptions() + localResponseNormalizationOptions.Init(buf, pos) + return cls.InitFromObj(localResponseNormalizationOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, localResponseNormalizationOptions): + x = LocalResponseNormalizationOptionsT() + x._UnPack(localResponseNormalizationOptions) + return x + + # LocalResponseNormalizationOptionsT + def _UnPack(self, localResponseNormalizationOptions): + if localResponseNormalizationOptions is None: + return + self.radius = localResponseNormalizationOptions.Radius() + self.bias = localResponseNormalizationOptions.Bias() + self.alpha = localResponseNormalizationOptions.Alpha() + self.beta = localResponseNormalizationOptions.Beta() + + # LocalResponseNormalizationOptionsT + def Pack(self, builder): + LocalResponseNormalizationOptionsStart(builder) + LocalResponseNormalizationOptionsAddRadius(builder, self.radius) + LocalResponseNormalizationOptionsAddBias(builder, self.bias) + LocalResponseNormalizationOptionsAddAlpha(builder, self.alpha) + LocalResponseNormalizationOptionsAddBeta(builder, self.beta) + localResponseNormalizationOptions = LocalResponseNormalizationOptionsEnd(builder) + return 
localResponseNormalizationOptions + + +class LSTMOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LSTMOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLSTMOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LSTMOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LSTMOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # LSTMOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # LSTMOptions + def CellClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # LSTMOptions + def ProjClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # LSTMOptions + def KernelType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # LSTMOptions + def AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def LSTMOptionsStart(builder): + builder.StartObject(5) + + +def LSTMOptionsAddFusedActivationFunction(builder, 
fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def LSTMOptionsAddCellClip(builder, cellClip): + builder.PrependFloat32Slot(1, cellClip, 0.0) + + +def LSTMOptionsAddProjClip(builder, projClip): + builder.PrependFloat32Slot(2, projClip, 0.0) + + +def LSTMOptionsAddKernelType(builder, kernelType): + builder.PrependInt8Slot(3, kernelType, 0) + + +def LSTMOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(4, asymmetricQuantizeInputs, 0) + + +def LSTMOptionsEnd(builder): + return builder.EndObject() + + +class LSTMOptionsT(object): + + # LSTMOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.cellClip = 0.0 # type: float + self.projClip = 0.0 # type: float + self.kernelType = 0 # type: int + self.asymmetricQuantizeInputs = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + lstmoptions = LSTMOptions() + lstmoptions.Init(buf, pos) + return cls.InitFromObj(lstmoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, lstmoptions): + x = LSTMOptionsT() + x._UnPack(lstmoptions) + return x + + # LSTMOptionsT + def _UnPack(self, lstmoptions): + if lstmoptions is None: + return + self.fusedActivationFunction = lstmoptions.FusedActivationFunction() + self.cellClip = lstmoptions.CellClip() + self.projClip = lstmoptions.ProjClip() + self.kernelType = lstmoptions.KernelType() + self.asymmetricQuantizeInputs = lstmoptions.AsymmetricQuantizeInputs() + + # LSTMOptionsT + def Pack(self, builder): + LSTMOptionsStart(builder) + LSTMOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + LSTMOptionsAddCellClip(builder, self.cellClip) + LSTMOptionsAddProjClip(builder, self.projClip) + LSTMOptionsAddKernelType(builder, self.kernelType) + 
LSTMOptionsAddAsymmetricQuantizeInputs(builder, self.asymmetricQuantizeInputs) + lstmoptions = LSTMOptionsEnd(builder) + return lstmoptions + + +class UnidirectionalSequenceLSTMOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = UnidirectionalSequenceLSTMOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsUnidirectionalSequenceLSTMOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def UnidirectionalSequenceLSTMOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # UnidirectionalSequenceLSTMOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # UnidirectionalSequenceLSTMOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # UnidirectionalSequenceLSTMOptions + def CellClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # UnidirectionalSequenceLSTMOptions + def ProjClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # UnidirectionalSequenceLSTMOptions + def TimeMajor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # UnidirectionalSequenceLSTMOptions + def 
AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # UnidirectionalSequenceLSTMOptions + def DiagonalRecurrentTensors(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def UnidirectionalSequenceLSTMOptionsStart(builder): + builder.StartObject(6) + + +def UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction( + builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def UnidirectionalSequenceLSTMOptionsAddCellClip(builder, cellClip): + builder.PrependFloat32Slot(1, cellClip, 0.0) + + +def UnidirectionalSequenceLSTMOptionsAddProjClip(builder, projClip): + builder.PrependFloat32Slot(2, projClip, 0.0) + + +def UnidirectionalSequenceLSTMOptionsAddTimeMajor(builder, timeMajor): + builder.PrependBoolSlot(3, timeMajor, 0) + + +def UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs( + builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(4, asymmetricQuantizeInputs, 0) + + +def UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors( + builder, diagonalRecurrentTensors): + builder.PrependBoolSlot(5, diagonalRecurrentTensors, 0) + + +def UnidirectionalSequenceLSTMOptionsEnd(builder): + return builder.EndObject() + + +class UnidirectionalSequenceLSTMOptionsT(object): + + # UnidirectionalSequenceLSTMOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.cellClip = 0.0 # type: float + self.projClip = 0.0 # type: float + self.timeMajor = False # type: bool + self.asymmetricQuantizeInputs = False # type: bool + self.diagonalRecurrentTensors = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + unidirectionalSequenceLstmoptions = 
UnidirectionalSequenceLSTMOptions() + unidirectionalSequenceLstmoptions.Init(buf, pos) + return cls.InitFromObj(unidirectionalSequenceLstmoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, unidirectionalSequenceLstmoptions): + x = UnidirectionalSequenceLSTMOptionsT() + x._UnPack(unidirectionalSequenceLstmoptions) + return x + + # UnidirectionalSequenceLSTMOptionsT + def _UnPack(self, unidirectionalSequenceLstmoptions): + if unidirectionalSequenceLstmoptions is None: + return + self.fusedActivationFunction = unidirectionalSequenceLstmoptions.FusedActivationFunction( + ) + self.cellClip = unidirectionalSequenceLstmoptions.CellClip() + self.projClip = unidirectionalSequenceLstmoptions.ProjClip() + self.timeMajor = unidirectionalSequenceLstmoptions.TimeMajor() + self.asymmetricQuantizeInputs = unidirectionalSequenceLstmoptions.AsymmetricQuantizeInputs( + ) + self.diagonalRecurrentTensors = unidirectionalSequenceLstmoptions.DiagonalRecurrentTensors( + ) + + # UnidirectionalSequenceLSTMOptionsT + def Pack(self, builder): + UnidirectionalSequenceLSTMOptionsStart(builder) + UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction( + builder, self.fusedActivationFunction) + UnidirectionalSequenceLSTMOptionsAddCellClip(builder, self.cellClip) + UnidirectionalSequenceLSTMOptionsAddProjClip(builder, self.projClip) + UnidirectionalSequenceLSTMOptionsAddTimeMajor(builder, self.timeMajor) + UnidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs( + builder, self.asymmetricQuantizeInputs) + UnidirectionalSequenceLSTMOptionsAddDiagonalRecurrentTensors( + builder, self.diagonalRecurrentTensors) + unidirectionalSequenceLstmoptions = UnidirectionalSequenceLSTMOptionsEnd(builder) + return unidirectionalSequenceLstmoptions + + +class BidirectionalSequenceLSTMOptions(object): + __slots__ = ['_tab'] + + @classmethod + 
def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BidirectionalSequenceLSTMOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBidirectionalSequenceLSTMOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BidirectionalSequenceLSTMOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BidirectionalSequenceLSTMOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BidirectionalSequenceLSTMOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # BidirectionalSequenceLSTMOptions + def CellClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # BidirectionalSequenceLSTMOptions + def ProjClip(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # BidirectionalSequenceLSTMOptions + def MergeOutputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # BidirectionalSequenceLSTMOptions + def TimeMajor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return True + + # BidirectionalSequenceLSTMOptions 
+ def AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def BidirectionalSequenceLSTMOptionsStart(builder): + builder.StartObject(6) + + +def BidirectionalSequenceLSTMOptionsAddFusedActivationFunction(builder, + fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def BidirectionalSequenceLSTMOptionsAddCellClip(builder, cellClip): + builder.PrependFloat32Slot(1, cellClip, 0.0) + + +def BidirectionalSequenceLSTMOptionsAddProjClip(builder, projClip): + builder.PrependFloat32Slot(2, projClip, 0.0) + + +def BidirectionalSequenceLSTMOptionsAddMergeOutputs(builder, mergeOutputs): + builder.PrependBoolSlot(3, mergeOutputs, 0) + + +def BidirectionalSequenceLSTMOptionsAddTimeMajor(builder, timeMajor): + builder.PrependBoolSlot(4, timeMajor, 1) + + +def BidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs( + builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(5, asymmetricQuantizeInputs, 0) + + +def BidirectionalSequenceLSTMOptionsEnd(builder): + return builder.EndObject() + + +class BidirectionalSequenceLSTMOptionsT(object): + + # BidirectionalSequenceLSTMOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.cellClip = 0.0 # type: float + self.projClip = 0.0 # type: float + self.mergeOutputs = False # type: bool + self.timeMajor = True # type: bool + self.asymmetricQuantizeInputs = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + bidirectionalSequenceLstmoptions = BidirectionalSequenceLSTMOptions() + bidirectionalSequenceLstmoptions.Init(buf, pos) + return cls.InitFromObj(bidirectionalSequenceLstmoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def 
InitFromObj(cls, bidirectionalSequenceLstmoptions): + x = BidirectionalSequenceLSTMOptionsT() + x._UnPack(bidirectionalSequenceLstmoptions) + return x + + # BidirectionalSequenceLSTMOptionsT + def _UnPack(self, bidirectionalSequenceLstmoptions): + if bidirectionalSequenceLstmoptions is None: + return + self.fusedActivationFunction = bidirectionalSequenceLstmoptions.FusedActivationFunction( + ) + self.cellClip = bidirectionalSequenceLstmoptions.CellClip() + self.projClip = bidirectionalSequenceLstmoptions.ProjClip() + self.mergeOutputs = bidirectionalSequenceLstmoptions.MergeOutputs() + self.timeMajor = bidirectionalSequenceLstmoptions.TimeMajor() + self.asymmetricQuantizeInputs = bidirectionalSequenceLstmoptions.AsymmetricQuantizeInputs( + ) + + # BidirectionalSequenceLSTMOptionsT + def Pack(self, builder): + BidirectionalSequenceLSTMOptionsStart(builder) + BidirectionalSequenceLSTMOptionsAddFusedActivationFunction( + builder, self.fusedActivationFunction) + BidirectionalSequenceLSTMOptionsAddCellClip(builder, self.cellClip) + BidirectionalSequenceLSTMOptionsAddProjClip(builder, self.projClip) + BidirectionalSequenceLSTMOptionsAddMergeOutputs(builder, self.mergeOutputs) + BidirectionalSequenceLSTMOptionsAddTimeMajor(builder, self.timeMajor) + BidirectionalSequenceLSTMOptionsAddAsymmetricQuantizeInputs( + builder, self.asymmetricQuantizeInputs) + bidirectionalSequenceLstmoptions = BidirectionalSequenceLSTMOptionsEnd(builder) + return bidirectionalSequenceLstmoptions + + +class ResizeBilinearOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ResizeBilinearOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsResizeBilinearOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ResizeBilinearOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ResizeBilinearOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ResizeBilinearOptions + def AlignCorners(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # ResizeBilinearOptions + def HalfPixelCenters(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def ResizeBilinearOptionsStart(builder): + builder.StartObject(4) + + +def ResizeBilinearOptionsAddAlignCorners(builder, alignCorners): + builder.PrependBoolSlot(2, alignCorners, 0) + + +def ResizeBilinearOptionsAddHalfPixelCenters(builder, halfPixelCenters): + builder.PrependBoolSlot(3, halfPixelCenters, 0) + + +def ResizeBilinearOptionsEnd(builder): + return builder.EndObject() + + +class ResizeBilinearOptionsT(object): + + # ResizeBilinearOptionsT + def __init__(self): + self.alignCorners = False # type: bool + self.halfPixelCenters = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + resizeBilinearOptions = ResizeBilinearOptions() + resizeBilinearOptions.Init(buf, pos) + return cls.InitFromObj(resizeBilinearOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, resizeBilinearOptions): + x = ResizeBilinearOptionsT() + x._UnPack(resizeBilinearOptions) + return x + + # ResizeBilinearOptionsT + def _UnPack(self, 
resizeBilinearOptions): + if resizeBilinearOptions is None: + return + self.alignCorners = resizeBilinearOptions.AlignCorners() + self.halfPixelCenters = resizeBilinearOptions.HalfPixelCenters() + + # ResizeBilinearOptionsT + def Pack(self, builder): + ResizeBilinearOptionsStart(builder) + ResizeBilinearOptionsAddAlignCorners(builder, self.alignCorners) + ResizeBilinearOptionsAddHalfPixelCenters(builder, self.halfPixelCenters) + resizeBilinearOptions = ResizeBilinearOptionsEnd(builder) + return resizeBilinearOptions + + +class ResizeNearestNeighborOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ResizeNearestNeighborOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsResizeNearestNeighborOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ResizeNearestNeighborOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ResizeNearestNeighborOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ResizeNearestNeighborOptions + def AlignCorners(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # ResizeNearestNeighborOptions + def HalfPixelCenters(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def ResizeNearestNeighborOptionsStart(builder): + builder.StartObject(2) + + +def ResizeNearestNeighborOptionsAddAlignCorners(builder, alignCorners): + 
builder.PrependBoolSlot(0, alignCorners, 0) + + +def ResizeNearestNeighborOptionsAddHalfPixelCenters(builder, halfPixelCenters): + builder.PrependBoolSlot(1, halfPixelCenters, 0) + + +def ResizeNearestNeighborOptionsEnd(builder): + return builder.EndObject() + + +class ResizeNearestNeighborOptionsT(object): + + # ResizeNearestNeighborOptionsT + def __init__(self): + self.alignCorners = False # type: bool + self.halfPixelCenters = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + resizeNearestNeighborOptions = ResizeNearestNeighborOptions() + resizeNearestNeighborOptions.Init(buf, pos) + return cls.InitFromObj(resizeNearestNeighborOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, resizeNearestNeighborOptions): + x = ResizeNearestNeighborOptionsT() + x._UnPack(resizeNearestNeighborOptions) + return x + + # ResizeNearestNeighborOptionsT + def _UnPack(self, resizeNearestNeighborOptions): + if resizeNearestNeighborOptions is None: + return + self.alignCorners = resizeNearestNeighborOptions.AlignCorners() + self.halfPixelCenters = resizeNearestNeighborOptions.HalfPixelCenters() + + # ResizeNearestNeighborOptionsT + def Pack(self, builder): + ResizeNearestNeighborOptionsStart(builder) + ResizeNearestNeighborOptionsAddAlignCorners(builder, self.alignCorners) + ResizeNearestNeighborOptionsAddHalfPixelCenters(builder, self.halfPixelCenters) + resizeNearestNeighborOptions = ResizeNearestNeighborOptionsEnd(builder) + return resizeNearestNeighborOptions + + +class CallOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CallOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCallOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CallOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CallOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CallOptions + def Subgraph(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + +def CallOptionsStart(builder): + builder.StartObject(1) + + +def CallOptionsAddSubgraph(builder, subgraph): + builder.PrependUint32Slot(0, subgraph, 0) + + +def CallOptionsEnd(builder): + return builder.EndObject() + + +class CallOptionsT(object): + + # CallOptionsT + def __init__(self): + self.subgraph = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + callOptions = CallOptions() + callOptions.Init(buf, pos) + return cls.InitFromObj(callOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, callOptions): + x = CallOptionsT() + x._UnPack(callOptions) + return x + + # CallOptionsT + def _UnPack(self, callOptions): + if callOptions is None: + return + self.subgraph = callOptions.Subgraph() + + # CallOptionsT + def Pack(self, builder): + CallOptionsStart(builder) + CallOptionsAddSubgraph(builder, self.subgraph) + callOptions = CallOptionsEnd(builder) + return callOptions + + +class PadOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PadOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPadOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def PadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # PadOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def PadOptionsStart(builder): + builder.StartObject(0) + + +def PadOptionsEnd(builder): + return builder.EndObject() + + +class PadOptionsT(object): + + # PadOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + padOptions = PadOptions() + padOptions.Init(buf, pos) + return cls.InitFromObj(padOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, padOptions): + x = PadOptionsT() + x._UnPack(padOptions) + return x + + # PadOptionsT + def _UnPack(self, padOptions): + if padOptions is None: + return + + # PadOptionsT + def Pack(self, builder): + PadOptionsStart(builder) + padOptions = PadOptionsEnd(builder) + return padOptions + + +class PadV2Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PadV2Options() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPadV2Options(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def PadV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # PadV2Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def PadV2OptionsStart(builder): + builder.StartObject(0) + + +def PadV2OptionsEnd(builder): + return builder.EndObject() + + +class PadV2OptionsT(object): + + # PadV2OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + padV2Options = PadV2Options() + padV2Options.Init(buf, pos) + return cls.InitFromObj(padV2Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, padV2Options): + x = PadV2OptionsT() + x._UnPack(padV2Options) + return x + + # PadV2OptionsT + def _UnPack(self, padV2Options): + if padV2Options is None: + return + + # PadV2OptionsT + def Pack(self, builder): + PadV2OptionsStart(builder) + padV2Options = PadV2OptionsEnd(builder) + return padV2Options + + +class ReshapeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReshapeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsReshapeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ReshapeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ReshapeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ReshapeOptions + def NewShape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ReshapeOptions + def NewShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ReshapeOptions + def NewShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ReshapeOptions + def NewShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def ReshapeOptionsStart(builder): + builder.StartObject(1) + + +def ReshapeOptionsAddNewShape(builder, newShape): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(newShape), 0) + + +def ReshapeOptionsStartNewShapeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def ReshapeOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class ReshapeOptionsT(object): + + # ReshapeOptionsT + def __init__(self): + self.newShape = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + reshapeOptions = ReshapeOptions() + reshapeOptions.Init(buf, pos) + return cls.InitFromObj(reshapeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, 
pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, reshapeOptions): + x = ReshapeOptionsT() + x._UnPack(reshapeOptions) + return x + + # ReshapeOptionsT + def _UnPack(self, reshapeOptions): + if reshapeOptions is None: + return + if not reshapeOptions.NewShapeIsNone(): + if np is None: + self.newShape = [] + for i in range(reshapeOptions.NewShapeLength()): + self.newShape.append(reshapeOptions.NewShape(i)) + else: + self.newShape = reshapeOptions.NewShapeAsNumpy() + + # ReshapeOptionsT + def Pack(self, builder): + if self.newShape is not None: + if np is not None and type(self.newShape) is np.ndarray: + newShape = builder.CreateNumpyVector(self.newShape) + else: + ReshapeOptionsStartNewShapeVector(builder, len(self.newShape)) + for i in reversed(range(len(self.newShape))): + builder.PrependInt32(self.newShape[i]) + newShape = builder.EndVector() + ReshapeOptionsStart(builder) + if self.newShape is not None: + ReshapeOptionsAddNewShape(builder, newShape) + reshapeOptions = ReshapeOptionsEnd(builder) + return reshapeOptions + + +class SpaceToBatchNDOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SpaceToBatchNDOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSpaceToBatchNDOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SpaceToBatchNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SpaceToBatchNDOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SpaceToBatchNDOptionsStart(builder): + builder.StartObject(0) + + +def SpaceToBatchNDOptionsEnd(builder): + return builder.EndObject() + + +class SpaceToBatchNDOptionsT(object): + + # SpaceToBatchNDOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + spaceToBatchNdoptions = SpaceToBatchNDOptions() + spaceToBatchNdoptions.Init(buf, pos) + return cls.InitFromObj(spaceToBatchNdoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, spaceToBatchNdoptions): + x = SpaceToBatchNDOptionsT() + x._UnPack(spaceToBatchNdoptions) + return x + + # SpaceToBatchNDOptionsT + def _UnPack(self, spaceToBatchNdoptions): + if spaceToBatchNdoptions is None: + return + + # SpaceToBatchNDOptionsT + def Pack(self, builder): + SpaceToBatchNDOptionsStart(builder) + spaceToBatchNdoptions = SpaceToBatchNDOptionsEnd(builder) + return spaceToBatchNdoptions + + +class BatchToSpaceNDOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BatchToSpaceNDOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBatchToSpaceNDOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BatchToSpaceNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BatchToSpaceNDOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def BatchToSpaceNDOptionsStart(builder): + builder.StartObject(0) + + +def BatchToSpaceNDOptionsEnd(builder): + return builder.EndObject() + + +class BatchToSpaceNDOptionsT(object): + + # BatchToSpaceNDOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + batchToSpaceNdoptions = BatchToSpaceNDOptions() + batchToSpaceNdoptions.Init(buf, pos) + return cls.InitFromObj(batchToSpaceNdoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, batchToSpaceNdoptions): + x = BatchToSpaceNDOptionsT() + x._UnPack(batchToSpaceNdoptions) + return x + + # BatchToSpaceNDOptionsT + def _UnPack(self, batchToSpaceNdoptions): + if batchToSpaceNdoptions is None: + return + + # BatchToSpaceNDOptionsT + def Pack(self, builder): + BatchToSpaceNDOptionsStart(builder) + batchToSpaceNdoptions = BatchToSpaceNDOptionsEnd(builder) + return batchToSpaceNdoptions + + +class SkipGramOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SkipGramOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSkipGramOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SkipGramOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SkipGramOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SkipGramOptions + def NgramSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # SkipGramOptions + def MaxSkipSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # SkipGramOptions + def IncludeAllNgrams(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def SkipGramOptionsStart(builder): + builder.StartObject(3) + + +def SkipGramOptionsAddNgramSize(builder, ngramSize): + builder.PrependInt32Slot(0, ngramSize, 0) + + +def SkipGramOptionsAddMaxSkipSize(builder, maxSkipSize): + builder.PrependInt32Slot(1, maxSkipSize, 0) + + +def SkipGramOptionsAddIncludeAllNgrams(builder, includeAllNgrams): + builder.PrependBoolSlot(2, includeAllNgrams, 0) + + +def SkipGramOptionsEnd(builder): + return builder.EndObject() + + +class SkipGramOptionsT(object): + + # SkipGramOptionsT + def __init__(self): + self.ngramSize = 0 # type: int + self.maxSkipSize = 0 # type: int + self.includeAllNgrams = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + skipGramOptions = SkipGramOptions() + skipGramOptions.Init(buf, pos) + return cls.InitFromObj(skipGramOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, skipGramOptions): + x = SkipGramOptionsT() + x._UnPack(skipGramOptions) + return x + + # SkipGramOptionsT + def _UnPack(self, skipGramOptions): + if skipGramOptions is None: + return + self.ngramSize = skipGramOptions.NgramSize() + self.maxSkipSize = skipGramOptions.MaxSkipSize() + self.includeAllNgrams = skipGramOptions.IncludeAllNgrams() + + # SkipGramOptionsT + def Pack(self, builder): + SkipGramOptionsStart(builder) + SkipGramOptionsAddNgramSize(builder, self.ngramSize) + SkipGramOptionsAddMaxSkipSize(builder, self.maxSkipSize) + SkipGramOptionsAddIncludeAllNgrams(builder, self.includeAllNgrams) + skipGramOptions = SkipGramOptionsEnd(builder) + return skipGramOptions + + +class SpaceToDepthOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SpaceToDepthOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSpaceToDepthOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SpaceToDepthOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SpaceToDepthOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SpaceToDepthOptions + def BlockSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def SpaceToDepthOptionsStart(builder): + builder.StartObject(1) + + +def SpaceToDepthOptionsAddBlockSize(builder, blockSize): + builder.PrependInt32Slot(0, blockSize, 0) + + +def SpaceToDepthOptionsEnd(builder): + return builder.EndObject() + + +class SpaceToDepthOptionsT(object): + + # SpaceToDepthOptionsT + def __init__(self): + self.blockSize = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + spaceToDepthOptions = SpaceToDepthOptions() + spaceToDepthOptions.Init(buf, pos) + return cls.InitFromObj(spaceToDepthOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, spaceToDepthOptions): + x = SpaceToDepthOptionsT() + x._UnPack(spaceToDepthOptions) + return x + + # SpaceToDepthOptionsT + def _UnPack(self, spaceToDepthOptions): + if spaceToDepthOptions is None: + return + self.blockSize = spaceToDepthOptions.BlockSize() + + # SpaceToDepthOptionsT + def Pack(self, builder): + SpaceToDepthOptionsStart(builder) + SpaceToDepthOptionsAddBlockSize(builder, self.blockSize) + spaceToDepthOptions = SpaceToDepthOptionsEnd(builder) + return spaceToDepthOptions + + +class DepthToSpaceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DepthToSpaceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDepthToSpaceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DepthToSpaceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DepthToSpaceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DepthToSpaceOptions + def BlockSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def DepthToSpaceOptionsStart(builder): + builder.StartObject(1) + + +def DepthToSpaceOptionsAddBlockSize(builder, blockSize): + builder.PrependInt32Slot(0, blockSize, 0) + + +def DepthToSpaceOptionsEnd(builder): + return builder.EndObject() + + +class DepthToSpaceOptionsT(object): + + # DepthToSpaceOptionsT + def __init__(self): + self.blockSize = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + depthToSpaceOptions = DepthToSpaceOptions() + depthToSpaceOptions.Init(buf, pos) + return cls.InitFromObj(depthToSpaceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, depthToSpaceOptions): + x = DepthToSpaceOptionsT() + x._UnPack(depthToSpaceOptions) + return x + + # DepthToSpaceOptionsT + def _UnPack(self, depthToSpaceOptions): + if depthToSpaceOptions is None: + return + self.blockSize = depthToSpaceOptions.BlockSize() + + # DepthToSpaceOptionsT + def Pack(self, builder): + DepthToSpaceOptionsStart(builder) + DepthToSpaceOptionsAddBlockSize(builder, 
self.blockSize) + depthToSpaceOptions = DepthToSpaceOptionsEnd(builder) + return depthToSpaceOptions + + +class SubOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SubOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSubOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SubOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SubOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SubOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # SubOptions + def PotScaleInt16(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return True + + +def SubOptionsStart(builder): + builder.StartObject(2) + + +def SubOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def SubOptionsAddPotScaleInt16(builder, potScaleInt16): + builder.PrependBoolSlot(1, potScaleInt16, 1) + + +def SubOptionsEnd(builder): + return builder.EndObject() + + +class SubOptionsT(object): + + # SubOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + self.potScaleInt16 = True # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + subOptions = SubOptions() + subOptions.Init(buf, pos) + return cls.InitFromObj(subOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, subOptions): + x = SubOptionsT() + x._UnPack(subOptions) + return x + + # SubOptionsT + def _UnPack(self, subOptions): + if subOptions is None: + return + self.fusedActivationFunction = subOptions.FusedActivationFunction() + self.potScaleInt16 = subOptions.PotScaleInt16() + + # SubOptionsT + def Pack(self, builder): + SubOptionsStart(builder) + SubOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + SubOptionsAddPotScaleInt16(builder, self.potScaleInt16) + subOptions = SubOptionsEnd(builder) + return subOptions + + +class DivOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DivOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDivOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DivOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DivOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DivOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def DivOptionsStart(builder): + builder.StartObject(1) + + +def DivOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(0, fusedActivationFunction, 0) + + +def DivOptionsEnd(builder): + return builder.EndObject() + + +class DivOptionsT(object): + + # DivOptionsT + def __init__(self): + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + divOptions = DivOptions() + divOptions.Init(buf, pos) + return cls.InitFromObj(divOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, divOptions): + x = DivOptionsT() + x._UnPack(divOptions) + return x + + # DivOptionsT + def _UnPack(self, divOptions): + if divOptions is None: + return + self.fusedActivationFunction = divOptions.FusedActivationFunction() + + # DivOptionsT + def Pack(self, builder): + DivOptionsStart(builder) + DivOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction) + divOptions = DivOptionsEnd(builder) + return divOptions + + +class TopKV2Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TopKV2Options() + x.Init(buf, n + offset) + return x + 
+ @classmethod + def GetRootAsTopKV2Options(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TopKV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # TopKV2Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def TopKV2OptionsStart(builder): + builder.StartObject(0) + + +def TopKV2OptionsEnd(builder): + return builder.EndObject() + + +class TopKV2OptionsT(object): + + # TopKV2OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + topKv2Options = TopKV2Options() + topKv2Options.Init(buf, pos) + return cls.InitFromObj(topKv2Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, topKv2Options): + x = TopKV2OptionsT() + x._UnPack(topKv2Options) + return x + + # TopKV2OptionsT + def _UnPack(self, topKv2Options): + if topKv2Options is None: + return + + # TopKV2OptionsT + def Pack(self, builder): + TopKV2OptionsStart(builder) + topKv2Options = TopKV2OptionsEnd(builder) + return topKv2Options + + +class EmbeddingLookupSparseOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = EmbeddingLookupSparseOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsEmbeddingLookupSparseOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def EmbeddingLookupSparseOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # EmbeddingLookupSparseOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # EmbeddingLookupSparseOptions + def Combiner(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def EmbeddingLookupSparseOptionsStart(builder): + builder.StartObject(1) + + +def EmbeddingLookupSparseOptionsAddCombiner(builder, combiner): + builder.PrependInt8Slot(0, combiner, 0) + + +def EmbeddingLookupSparseOptionsEnd(builder): + return builder.EndObject() + + +class EmbeddingLookupSparseOptionsT(object): + + # EmbeddingLookupSparseOptionsT + def __init__(self): + self.combiner = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + embeddingLookupSparseOptions = EmbeddingLookupSparseOptions() + embeddingLookupSparseOptions.Init(buf, pos) + return cls.InitFromObj(embeddingLookupSparseOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, embeddingLookupSparseOptions): + x = EmbeddingLookupSparseOptionsT() + x._UnPack(embeddingLookupSparseOptions) + return x + + # EmbeddingLookupSparseOptionsT + def _UnPack(self, embeddingLookupSparseOptions): + if embeddingLookupSparseOptions is None: + return + self.combiner = embeddingLookupSparseOptions.Combiner() + + # EmbeddingLookupSparseOptionsT + def Pack(self, builder): + EmbeddingLookupSparseOptionsStart(builder) + EmbeddingLookupSparseOptionsAddCombiner(builder, self.combiner) + embeddingLookupSparseOptions = 
EmbeddingLookupSparseOptionsEnd(builder) + return embeddingLookupSparseOptions + + +class GatherOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = GatherOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsGatherOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def GatherOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # GatherOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # GatherOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # GatherOptions + def BatchDims(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def GatherOptionsStart(builder): + builder.StartObject(2) + + +def GatherOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(0, axis, 0) + + +def GatherOptionsAddBatchDims(builder, batchDims): + builder.PrependInt32Slot(1, batchDims, 0) + + +def GatherOptionsEnd(builder): + return builder.EndObject() + + +class GatherOptionsT(object): + + # GatherOptionsT + def __init__(self): + self.axis = 0 # type: int + self.batchDims = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + gatherOptions = GatherOptions() + gatherOptions.Init(buf, pos) + return cls.InitFromObj(gatherOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) 
+ + @classmethod + def InitFromObj(cls, gatherOptions): + x = GatherOptionsT() + x._UnPack(gatherOptions) + return x + + # GatherOptionsT + def _UnPack(self, gatherOptions): + if gatherOptions is None: + return + self.axis = gatherOptions.Axis() + self.batchDims = gatherOptions.BatchDims() + + # GatherOptionsT + def Pack(self, builder): + GatherOptionsStart(builder) + GatherOptionsAddAxis(builder, self.axis) + GatherOptionsAddBatchDims(builder, self.batchDims) + gatherOptions = GatherOptionsEnd(builder) + return gatherOptions + + +class TransposeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TransposeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTransposeOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TransposeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # TransposeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def TransposeOptionsStart(builder): + builder.StartObject(0) + + +def TransposeOptionsEnd(builder): + return builder.EndObject() + + +class TransposeOptionsT(object): + + # TransposeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + transposeOptions = TransposeOptions() + transposeOptions.Init(buf, pos) + return cls.InitFromObj(transposeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, transposeOptions): + x = TransposeOptionsT() + x._UnPack(transposeOptions) + return x + + # TransposeOptionsT + def _UnPack(self, 
transposeOptions): + if transposeOptions is None: + return + + # TransposeOptionsT + def Pack(self, builder): + TransposeOptionsStart(builder) + transposeOptions = TransposeOptionsEnd(builder) + return transposeOptions + + +class ExpOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ExpOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsExpOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ExpOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ExpOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ExpOptionsStart(builder): + builder.StartObject(0) + + +def ExpOptionsEnd(builder): + return builder.EndObject() + + +class ExpOptionsT(object): + + # ExpOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + expOptions = ExpOptions() + expOptions.Init(buf, pos) + return cls.InitFromObj(expOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, expOptions): + x = ExpOptionsT() + x._UnPack(expOptions) + return x + + # ExpOptionsT + def _UnPack(self, expOptions): + if expOptions is None: + return + + # ExpOptionsT + def Pack(self, builder): + ExpOptionsStart(builder) + expOptions = ExpOptionsEnd(builder) + return expOptions + + +class CosOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CosOptions() + x.Init(buf, n + offset) + return x + + 
@classmethod + def GetRootAsCosOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CosOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CosOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def CosOptionsStart(builder): + builder.StartObject(0) + + +def CosOptionsEnd(builder): + return builder.EndObject() + + +class CosOptionsT(object): + + # CosOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + cosOptions = CosOptions() + cosOptions.Init(buf, pos) + return cls.InitFromObj(cosOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, cosOptions): + x = CosOptionsT() + x._UnPack(cosOptions) + return x + + # CosOptionsT + def _UnPack(self, cosOptions): + if cosOptions is None: + return + + # CosOptionsT + def Pack(self, builder): + CosOptionsStart(builder) + cosOptions = CosOptionsEnd(builder) + return cosOptions + + +class ReducerOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReducerOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsReducerOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ReducerOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ReducerOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ReducerOptions + def KeepDims(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def ReducerOptionsStart(builder): + builder.StartObject(1) + + +def ReducerOptionsAddKeepDims(builder, keepDims): + builder.PrependBoolSlot(0, keepDims, 0) + + +def ReducerOptionsEnd(builder): + return builder.EndObject() + + +class ReducerOptionsT(object): + + # ReducerOptionsT + def __init__(self): + self.keepDims = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + reducerOptions = ReducerOptions() + reducerOptions.Init(buf, pos) + return cls.InitFromObj(reducerOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, reducerOptions): + x = ReducerOptionsT() + x._UnPack(reducerOptions) + return x + + # ReducerOptionsT + def _UnPack(self, reducerOptions): + if reducerOptions is None: + return + self.keepDims = reducerOptions.KeepDims() + + # ReducerOptionsT + def Pack(self, builder): + ReducerOptionsStart(builder) + ReducerOptionsAddKeepDims(builder, self.keepDims) + reducerOptions = ReducerOptionsEnd(builder) + return reducerOptions + + +class SqueezeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SqueezeOptions() + x.Init(buf, n + offset) + return x + + @classmethod 
+ def GetRootAsSqueezeOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SqueezeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SqueezeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SqueezeOptions + def SqueezeDims(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SqueezeOptions + def SqueezeDimsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SqueezeOptions + def SqueezeDimsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SqueezeOptions + def SqueezeDimsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def SqueezeOptionsStart(builder): + builder.StartObject(1) + + +def SqueezeOptionsAddSqueezeDims(builder, squeezeDims): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(squeezeDims), 0) + + +def SqueezeOptionsStartSqueezeDimsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SqueezeOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class SqueezeOptionsT(object): + + # SqueezeOptionsT + def __init__(self): + self.squeezeDims = None # type: List[int] + + @classmethod + def InitFromBuf(cls, buf, pos): + squeezeOptions = SqueezeOptions() + 
squeezeOptions.Init(buf, pos) + return cls.InitFromObj(squeezeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, squeezeOptions): + x = SqueezeOptionsT() + x._UnPack(squeezeOptions) + return x + + # SqueezeOptionsT + def _UnPack(self, squeezeOptions): + if squeezeOptions is None: + return + if not squeezeOptions.SqueezeDimsIsNone(): + if np is None: + self.squeezeDims = [] + for i in range(squeezeOptions.SqueezeDimsLength()): + self.squeezeDims.append(squeezeOptions.SqueezeDims(i)) + else: + self.squeezeDims = squeezeOptions.SqueezeDimsAsNumpy() + + # SqueezeOptionsT + def Pack(self, builder): + if self.squeezeDims is not None: + if np is not None and type(self.squeezeDims) is np.ndarray: + squeezeDims = builder.CreateNumpyVector(self.squeezeDims) + else: + SqueezeOptionsStartSqueezeDimsVector(builder, len(self.squeezeDims)) + for i in reversed(range(len(self.squeezeDims))): + builder.PrependInt32(self.squeezeDims[i]) + squeezeDims = builder.EndVector() + SqueezeOptionsStart(builder) + if self.squeezeDims is not None: + SqueezeOptionsAddSqueezeDims(builder, squeezeDims) + squeezeOptions = SqueezeOptionsEnd(builder) + return squeezeOptions + + +class SplitOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SplitOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSplitOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SplitOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SplitOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SplitOptions + def NumSplits(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def SplitOptionsStart(builder): + builder.StartObject(1) + + +def SplitOptionsAddNumSplits(builder, numSplits): + builder.PrependInt32Slot(0, numSplits, 0) + + +def SplitOptionsEnd(builder): + return builder.EndObject() + + +class SplitOptionsT(object): + + # SplitOptionsT + def __init__(self): + self.numSplits = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + splitOptions = SplitOptions() + splitOptions.Init(buf, pos) + return cls.InitFromObj(splitOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, splitOptions): + x = SplitOptionsT() + x._UnPack(splitOptions) + return x + + # SplitOptionsT + def _UnPack(self, splitOptions): + if splitOptions is None: + return + self.numSplits = splitOptions.NumSplits() + + # SplitOptionsT + def Pack(self, builder): + SplitOptionsStart(builder) + SplitOptionsAddNumSplits(builder, self.numSplits) + splitOptions = SplitOptionsEnd(builder) + return splitOptions + + +class SplitVOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SplitVOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSplitVOptions(cls, buf, offset=0): + """This 
method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SplitVOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SplitVOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SplitVOptions + def NumSplits(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def SplitVOptionsStart(builder): + builder.StartObject(1) + + +def SplitVOptionsAddNumSplits(builder, numSplits): + builder.PrependInt32Slot(0, numSplits, 0) + + +def SplitVOptionsEnd(builder): + return builder.EndObject() + + +class SplitVOptionsT(object): + + # SplitVOptionsT + def __init__(self): + self.numSplits = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + splitVoptions = SplitVOptions() + splitVoptions.Init(buf, pos) + return cls.InitFromObj(splitVoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, splitVoptions): + x = SplitVOptionsT() + x._UnPack(splitVoptions) + return x + + # SplitVOptionsT + def _UnPack(self, splitVoptions): + if splitVoptions is None: + return + self.numSplits = splitVoptions.NumSplits() + + # SplitVOptionsT + def Pack(self, builder): + SplitVOptionsStart(builder) + SplitVOptionsAddNumSplits(builder, self.numSplits) + splitVoptions = SplitVOptionsEnd(builder) + return splitVoptions + + +class StridedSliceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StridedSliceOptions() + x.Init(buf, n + offset) + return x + + @classmethod 
+ def GetRootAsStridedSliceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StridedSliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StridedSliceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StridedSliceOptions + def BeginMask(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StridedSliceOptions + def EndMask(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StridedSliceOptions + def EllipsisMask(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StridedSliceOptions + def NewAxisMask(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StridedSliceOptions + def ShrinkAxisMask(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StridedSliceOptions + def Offset(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def StridedSliceOptionsStart(builder): + builder.StartObject(6) + + +def StridedSliceOptionsAddBeginMask(builder, beginMask): + 
builder.PrependInt32Slot(0, beginMask, 0) + + +def StridedSliceOptionsAddEndMask(builder, endMask): + builder.PrependInt32Slot(1, endMask, 0) + + +def StridedSliceOptionsAddEllipsisMask(builder, ellipsisMask): + builder.PrependInt32Slot(2, ellipsisMask, 0) + + +def StridedSliceOptionsAddNewAxisMask(builder, newAxisMask): + builder.PrependInt32Slot(3, newAxisMask, 0) + + +def StridedSliceOptionsAddShrinkAxisMask(builder, shrinkAxisMask): + builder.PrependInt32Slot(4, shrinkAxisMask, 0) + + +def StridedSliceOptionsAddOffset(builder, offset): + builder.PrependBoolSlot(5, offset, 0) + + +def StridedSliceOptionsEnd(builder): + return builder.EndObject() + + +class StridedSliceOptionsT(object): + + # StridedSliceOptionsT + def __init__(self): + self.beginMask = 0 # type: int + self.endMask = 0 # type: int + self.ellipsisMask = 0 # type: int + self.newAxisMask = 0 # type: int + self.shrinkAxisMask = 0 # type: int + self.offset = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + stridedSliceOptions = StridedSliceOptions() + stridedSliceOptions.Init(buf, pos) + return cls.InitFromObj(stridedSliceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stridedSliceOptions): + x = StridedSliceOptionsT() + x._UnPack(stridedSliceOptions) + return x + + # StridedSliceOptionsT + def _UnPack(self, stridedSliceOptions): + if stridedSliceOptions is None: + return + self.beginMask = stridedSliceOptions.BeginMask() + self.endMask = stridedSliceOptions.EndMask() + self.ellipsisMask = stridedSliceOptions.EllipsisMask() + self.newAxisMask = stridedSliceOptions.NewAxisMask() + self.shrinkAxisMask = stridedSliceOptions.ShrinkAxisMask() + self.offset = stridedSliceOptions.Offset() + + # StridedSliceOptionsT + def Pack(self, builder): + StridedSliceOptionsStart(builder) + StridedSliceOptionsAddBeginMask(builder, 
self.beginMask) + StridedSliceOptionsAddEndMask(builder, self.endMask) + StridedSliceOptionsAddEllipsisMask(builder, self.ellipsisMask) + StridedSliceOptionsAddNewAxisMask(builder, self.newAxisMask) + StridedSliceOptionsAddShrinkAxisMask(builder, self.shrinkAxisMask) + StridedSliceOptionsAddOffset(builder, self.offset) + stridedSliceOptions = StridedSliceOptionsEnd(builder) + return stridedSliceOptions + + +class LogSoftmaxOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogSoftmaxOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLogSoftmaxOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LogSoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LogSoftmaxOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogSoftmaxOptionsStart(builder): + builder.StartObject(0) + + +def LogSoftmaxOptionsEnd(builder): + return builder.EndObject() + + +class LogSoftmaxOptionsT(object): + + # LogSoftmaxOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + logSoftmaxOptions = LogSoftmaxOptions() + logSoftmaxOptions.Init(buf, pos) + return cls.InitFromObj(logSoftmaxOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, logSoftmaxOptions): + x = LogSoftmaxOptionsT() + x._UnPack(logSoftmaxOptions) + return x + + # LogSoftmaxOptionsT + def _UnPack(self, logSoftmaxOptions): + if logSoftmaxOptions is None: + return + + # LogSoftmaxOptionsT + def Pack(self, builder): + 
LogSoftmaxOptionsStart(builder) + logSoftmaxOptions = LogSoftmaxOptionsEnd(builder) + return logSoftmaxOptions + + +class CastOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CastOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCastOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CastOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CastOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CastOptions + def InDataType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # CastOptions + def OutDataType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def CastOptionsStart(builder): + builder.StartObject(2) + + +def CastOptionsAddInDataType(builder, inDataType): + builder.PrependInt8Slot(0, inDataType, 0) + + +def CastOptionsAddOutDataType(builder, outDataType): + builder.PrependInt8Slot(1, outDataType, 0) + + +def CastOptionsEnd(builder): + return builder.EndObject() + + +class CastOptionsT(object): + + # CastOptionsT + def __init__(self): + self.inDataType = 0 # type: int + self.outDataType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + castOptions = CastOptions() + castOptions.Init(buf, pos) + return cls.InitFromObj(castOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + 
return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, castOptions): + x = CastOptionsT() + x._UnPack(castOptions) + return x + + # CastOptionsT + def _UnPack(self, castOptions): + if castOptions is None: + return + self.inDataType = castOptions.InDataType() + self.outDataType = castOptions.OutDataType() + + # CastOptionsT + def Pack(self, builder): + CastOptionsStart(builder) + CastOptionsAddInDataType(builder, self.inDataType) + CastOptionsAddOutDataType(builder, self.outDataType) + castOptions = CastOptionsEnd(builder) + return castOptions + + +class DequantizeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DequantizeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDequantizeOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DequantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DequantizeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def DequantizeOptionsStart(builder): + builder.StartObject(0) + + +def DequantizeOptionsEnd(builder): + return builder.EndObject() + + +class DequantizeOptionsT(object): + + # DequantizeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + dequantizeOptions = DequantizeOptions() + dequantizeOptions.Init(buf, pos) + return cls.InitFromObj(dequantizeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, dequantizeOptions): + x = DequantizeOptionsT() + x._UnPack(dequantizeOptions) + return x 
+ + # DequantizeOptionsT + def _UnPack(self, dequantizeOptions): + if dequantizeOptions is None: + return + + # DequantizeOptionsT + def Pack(self, builder): + DequantizeOptionsStart(builder) + dequantizeOptions = DequantizeOptionsEnd(builder) + return dequantizeOptions + + +class MaximumMinimumOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MaximumMinimumOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMaximumMinimumOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MaximumMinimumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MaximumMinimumOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def MaximumMinimumOptionsStart(builder): + builder.StartObject(0) + + +def MaximumMinimumOptionsEnd(builder): + return builder.EndObject() + + +class MaximumMinimumOptionsT(object): + + # MaximumMinimumOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + maximumMinimumOptions = MaximumMinimumOptions() + maximumMinimumOptions.Init(buf, pos) + return cls.InitFromObj(maximumMinimumOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, maximumMinimumOptions): + x = MaximumMinimumOptionsT() + x._UnPack(maximumMinimumOptions) + return x + + # MaximumMinimumOptionsT + def _UnPack(self, maximumMinimumOptions): + if maximumMinimumOptions is None: + return + + # MaximumMinimumOptionsT + def Pack(self, builder): + MaximumMinimumOptionsStart(builder) + maximumMinimumOptions 
= MaximumMinimumOptionsEnd(builder) + return maximumMinimumOptions + + +class TileOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TileOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTileOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TileOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # TileOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def TileOptionsStart(builder): + builder.StartObject(0) + + +def TileOptionsEnd(builder): + return builder.EndObject() + + +class TileOptionsT(object): + + # TileOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + tileOptions = TileOptions() + tileOptions.Init(buf, pos) + return cls.InitFromObj(tileOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, tileOptions): + x = TileOptionsT() + x._UnPack(tileOptions) + return x + + # TileOptionsT + def _UnPack(self, tileOptions): + if tileOptions is None: + return + + # TileOptionsT + def Pack(self, builder): + TileOptionsStart(builder) + tileOptions = TileOptionsEnd(builder) + return tileOptions + + +class ArgMaxOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ArgMaxOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsArgMaxOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ArgMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ArgMaxOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ArgMaxOptions + def OutputType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def ArgMaxOptionsStart(builder): + builder.StartObject(1) + + +def ArgMaxOptionsAddOutputType(builder, outputType): + builder.PrependInt8Slot(0, outputType, 0) + + +def ArgMaxOptionsEnd(builder): + return builder.EndObject() + + +class ArgMaxOptionsT(object): + + # ArgMaxOptionsT + def __init__(self): + self.outputType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + argMaxOptions = ArgMaxOptions() + argMaxOptions.Init(buf, pos) + return cls.InitFromObj(argMaxOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, argMaxOptions): + x = ArgMaxOptionsT() + x._UnPack(argMaxOptions) + return x + + # ArgMaxOptionsT + def _UnPack(self, argMaxOptions): + if argMaxOptions is None: + return + self.outputType = argMaxOptions.OutputType() + + # ArgMaxOptionsT + def Pack(self, builder): + ArgMaxOptionsStart(builder) + ArgMaxOptionsAddOutputType(builder, self.outputType) + argMaxOptions = ArgMaxOptionsEnd(builder) + return argMaxOptions + + +class ArgMinOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ArgMinOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def 
GetRootAsArgMinOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ArgMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ArgMinOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ArgMinOptions + def OutputType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def ArgMinOptionsStart(builder): + builder.StartObject(1) + + +def ArgMinOptionsAddOutputType(builder, outputType): + builder.PrependInt8Slot(0, outputType, 0) + + +def ArgMinOptionsEnd(builder): + return builder.EndObject() + + +class ArgMinOptionsT(object): + + # ArgMinOptionsT + def __init__(self): + self.outputType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + argMinOptions = ArgMinOptions() + argMinOptions.Init(buf, pos) + return cls.InitFromObj(argMinOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, argMinOptions): + x = ArgMinOptionsT() + x._UnPack(argMinOptions) + return x + + # ArgMinOptionsT + def _UnPack(self, argMinOptions): + if argMinOptions is None: + return + self.outputType = argMinOptions.OutputType() + + # ArgMinOptionsT + def Pack(self, builder): + ArgMinOptionsStart(builder) + ArgMinOptionsAddOutputType(builder, self.outputType) + argMinOptions = ArgMinOptionsEnd(builder) + return argMinOptions + + +class GreaterOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = GreaterOptions() + 
x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsGreaterOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def GreaterOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # GreaterOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def GreaterOptionsStart(builder): + builder.StartObject(0) + + +def GreaterOptionsEnd(builder): + return builder.EndObject() + + +class GreaterOptionsT(object): + + # GreaterOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + greaterOptions = GreaterOptions() + greaterOptions.Init(buf, pos) + return cls.InitFromObj(greaterOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, greaterOptions): + x = GreaterOptionsT() + x._UnPack(greaterOptions) + return x + + # GreaterOptionsT + def _UnPack(self, greaterOptions): + if greaterOptions is None: + return + + # GreaterOptionsT + def Pack(self, builder): + GreaterOptionsStart(builder) + greaterOptions = GreaterOptionsEnd(builder) + return greaterOptions + + +class GreaterEqualOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = GreaterEqualOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsGreaterEqualOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def GreaterEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # GreaterEqualOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def GreaterEqualOptionsStart(builder): + builder.StartObject(0) + + +def GreaterEqualOptionsEnd(builder): + return builder.EndObject() + + +class GreaterEqualOptionsT(object): + + # GreaterEqualOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + greaterEqualOptions = GreaterEqualOptions() + greaterEqualOptions.Init(buf, pos) + return cls.InitFromObj(greaterEqualOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, greaterEqualOptions): + x = GreaterEqualOptionsT() + x._UnPack(greaterEqualOptions) + return x + + # GreaterEqualOptionsT + def _UnPack(self, greaterEqualOptions): + if greaterEqualOptions is None: + return + + # GreaterEqualOptionsT + def Pack(self, builder): + GreaterEqualOptionsStart(builder) + greaterEqualOptions = GreaterEqualOptionsEnd(builder) + return greaterEqualOptions + + +class LessOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LessOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLessOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LessOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LessOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LessOptionsStart(builder): + builder.StartObject(0) + + +def LessOptionsEnd(builder): + return builder.EndObject() + + +class LessOptionsT(object): + + # LessOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + lessOptions = LessOptions() + lessOptions.Init(buf, pos) + return cls.InitFromObj(lessOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, lessOptions): + x = LessOptionsT() + x._UnPack(lessOptions) + return x + + # LessOptionsT + def _UnPack(self, lessOptions): + if lessOptions is None: + return + + # LessOptionsT + def Pack(self, builder): + LessOptionsStart(builder) + lessOptions = LessOptionsEnd(builder) + return lessOptions + + +class LessEqualOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LessEqualOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLessEqualOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LessEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LessEqualOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LessEqualOptionsStart(builder): + builder.StartObject(0) + + +def LessEqualOptionsEnd(builder): + return builder.EndObject() + + +class LessEqualOptionsT(object): + + # LessEqualOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + lessEqualOptions = LessEqualOptions() + lessEqualOptions.Init(buf, pos) + return cls.InitFromObj(lessEqualOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, lessEqualOptions): + x = LessEqualOptionsT() + x._UnPack(lessEqualOptions) + return x + + # LessEqualOptionsT + def _UnPack(self, lessEqualOptions): + if lessEqualOptions is None: + return + + # LessEqualOptionsT + def Pack(self, builder): + LessEqualOptionsStart(builder) + lessEqualOptions = LessEqualOptionsEnd(builder) + return lessEqualOptions + + +class NegOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NegOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNegOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def NegOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # NegOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def NegOptionsStart(builder): + builder.StartObject(0) + + +def NegOptionsEnd(builder): + return builder.EndObject() + + +class NegOptionsT(object): + + # NegOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + negOptions = NegOptions() + negOptions.Init(buf, pos) + return cls.InitFromObj(negOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, negOptions): + x = NegOptionsT() + x._UnPack(negOptions) + return x + + # NegOptionsT + def _UnPack(self, negOptions): + if negOptions is None: + return + + # NegOptionsT + def Pack(self, builder): + NegOptionsStart(builder) + negOptions = NegOptionsEnd(builder) + return negOptions + + +class SelectOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SelectOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSelectOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SelectOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SelectOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SelectOptionsStart(builder): + builder.StartObject(0) + + +def SelectOptionsEnd(builder): + return builder.EndObject() + + +class SelectOptionsT(object): + + # SelectOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + selectOptions = SelectOptions() + selectOptions.Init(buf, pos) + return cls.InitFromObj(selectOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, selectOptions): + x = SelectOptionsT() + x._UnPack(selectOptions) + return x + + # SelectOptionsT + def _UnPack(self, selectOptions): + if selectOptions is None: + return + + # SelectOptionsT + def Pack(self, builder): + SelectOptionsStart(builder) + selectOptions = SelectOptionsEnd(builder) + return selectOptions + + +class SliceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SliceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSliceOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SliceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SliceOptionsStart(builder): + builder.StartObject(0) + + +def SliceOptionsEnd(builder): + return builder.EndObject() + + +class SliceOptionsT(object): + + # SliceOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + sliceOptions = SliceOptions() + sliceOptions.Init(buf, pos) + return cls.InitFromObj(sliceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, sliceOptions): + x = SliceOptionsT() + x._UnPack(sliceOptions) + return x + + # SliceOptionsT + def _UnPack(self, sliceOptions): + if sliceOptions is None: + return + + # SliceOptionsT + def Pack(self, builder): + SliceOptionsStart(builder) + sliceOptions = SliceOptionsEnd(builder) + return sliceOptions + + +class TransposeConvOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TransposeConvOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTransposeConvOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TransposeConvOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # TransposeConvOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TransposeConvOptions + def Padding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # TransposeConvOptions + def StrideW(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # TransposeConvOptions + def StrideH(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # TransposeConvOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # TransposeConvOptions + def QuantizedBiasType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def TransposeConvOptionsStart(builder): + builder.StartObject(5) + + +def TransposeConvOptionsAddPadding(builder, padding): + builder.PrependInt8Slot(0, padding, 0) + + +def TransposeConvOptionsAddStrideW(builder, strideW): + builder.PrependInt32Slot(1, strideW, 0) + + +def TransposeConvOptionsAddStrideH(builder, strideH): + builder.PrependInt32Slot(2, strideH, 0) + + +def TransposeConvOptionsAddFusedActivationFunction(builder, 
fusedActivationFunction): + builder.PrependInt8Slot(3, fusedActivationFunction, 0) + + +def TransposeConvOptionsAddQuantizedBiasType(builder, quantizedBiasType): + builder.PrependInt8Slot(4, quantizedBiasType, 0) + + +def TransposeConvOptionsEnd(builder): + return builder.EndObject() + + +class TransposeConvOptionsT(object): + + # TransposeConvOptionsT + def __init__(self): + self.padding = 0 # type: int + self.strideW = 0 # type: int + self.strideH = 0 # type: int + self.fusedActivationFunction = 0 # type: int + self.quantizedBiasType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + transposeConvOptions = TransposeConvOptions() + transposeConvOptions.Init(buf, pos) + return cls.InitFromObj(transposeConvOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, transposeConvOptions): + x = TransposeConvOptionsT() + x._UnPack(transposeConvOptions) + return x + + # TransposeConvOptionsT + def _UnPack(self, transposeConvOptions): + if transposeConvOptions is None: + return + self.padding = transposeConvOptions.Padding() + self.strideW = transposeConvOptions.StrideW() + self.strideH = transposeConvOptions.StrideH() + self.fusedActivationFunction = transposeConvOptions.FusedActivationFunction() + self.quantizedBiasType = transposeConvOptions.QuantizedBiasType() + + # TransposeConvOptionsT + def Pack(self, builder): + TransposeConvOptionsStart(builder) + TransposeConvOptionsAddPadding(builder, self.padding) + TransposeConvOptionsAddStrideW(builder, self.strideW) + TransposeConvOptionsAddStrideH(builder, self.strideH) + TransposeConvOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + TransposeConvOptionsAddQuantizedBiasType(builder, self.quantizedBiasType) + transposeConvOptions = TransposeConvOptionsEnd(builder) + return transposeConvOptions + + +class 
ExpandDimsOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ExpandDimsOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsExpandDimsOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ExpandDimsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ExpandDimsOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ExpandDimsOptionsStart(builder): + builder.StartObject(0) + + +def ExpandDimsOptionsEnd(builder): + return builder.EndObject() + + +class ExpandDimsOptionsT(object): + + # ExpandDimsOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + expandDimsOptions = ExpandDimsOptions() + expandDimsOptions.Init(buf, pos) + return cls.InitFromObj(expandDimsOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, expandDimsOptions): + x = ExpandDimsOptionsT() + x._UnPack(expandDimsOptions) + return x + + # ExpandDimsOptionsT + def _UnPack(self, expandDimsOptions): + if expandDimsOptions is None: + return + + # ExpandDimsOptionsT + def Pack(self, builder): + ExpandDimsOptionsStart(builder) + expandDimsOptions = ExpandDimsOptionsEnd(builder) + return expandDimsOptions + + +class SparseToDenseOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SparseToDenseOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSparseToDenseOptions(cls, 
buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SparseToDenseOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SparseToDenseOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SparseToDenseOptions + def ValidateIndices(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def SparseToDenseOptionsStart(builder): + builder.StartObject(1) + + +def SparseToDenseOptionsAddValidateIndices(builder, validateIndices): + builder.PrependBoolSlot(0, validateIndices, 0) + + +def SparseToDenseOptionsEnd(builder): + return builder.EndObject() + + +class SparseToDenseOptionsT(object): + + # SparseToDenseOptionsT + def __init__(self): + self.validateIndices = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + sparseToDenseOptions = SparseToDenseOptions() + sparseToDenseOptions.Init(buf, pos) + return cls.InitFromObj(sparseToDenseOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, sparseToDenseOptions): + x = SparseToDenseOptionsT() + x._UnPack(sparseToDenseOptions) + return x + + # SparseToDenseOptionsT + def _UnPack(self, sparseToDenseOptions): + if sparseToDenseOptions is None: + return + self.validateIndices = sparseToDenseOptions.ValidateIndices() + + # SparseToDenseOptionsT + def Pack(self, builder): + SparseToDenseOptionsStart(builder) + SparseToDenseOptionsAddValidateIndices(builder, self.validateIndices) + sparseToDenseOptions = SparseToDenseOptionsEnd(builder) + return 
sparseToDenseOptions + + +class EqualOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = EqualOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsEqualOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def EqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # EqualOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def EqualOptionsStart(builder): + builder.StartObject(0) + + +def EqualOptionsEnd(builder): + return builder.EndObject() + + +class EqualOptionsT(object): + + # EqualOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + equalOptions = EqualOptions() + equalOptions.Init(buf, pos) + return cls.InitFromObj(equalOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, equalOptions): + x = EqualOptionsT() + x._UnPack(equalOptions) + return x + + # EqualOptionsT + def _UnPack(self, equalOptions): + if equalOptions is None: + return + + # EqualOptionsT + def Pack(self, builder): + EqualOptionsStart(builder) + equalOptions = EqualOptionsEnd(builder) + return equalOptions + + +class NotEqualOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NotEqualOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNotEqualOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def NotEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # NotEqualOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def NotEqualOptionsStart(builder): + builder.StartObject(0) + + +def NotEqualOptionsEnd(builder): + return builder.EndObject() + + +class NotEqualOptionsT(object): + + # NotEqualOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + notEqualOptions = NotEqualOptions() + notEqualOptions.Init(buf, pos) + return cls.InitFromObj(notEqualOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, notEqualOptions): + x = NotEqualOptionsT() + x._UnPack(notEqualOptions) + return x + + # NotEqualOptionsT + def _UnPack(self, notEqualOptions): + if notEqualOptions is None: + return + + # NotEqualOptionsT + def Pack(self, builder): + NotEqualOptionsStart(builder) + notEqualOptions = NotEqualOptionsEnd(builder) + return notEqualOptions + + +class ShapeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ShapeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsShapeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ShapeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ShapeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ShapeOptions + def OutType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def ShapeOptionsStart(builder): + builder.StartObject(1) + + +def ShapeOptionsAddOutType(builder, outType): + builder.PrependInt8Slot(0, outType, 0) + + +def ShapeOptionsEnd(builder): + return builder.EndObject() + + +class ShapeOptionsT(object): + + # ShapeOptionsT + def __init__(self): + self.outType = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + shapeOptions = ShapeOptions() + shapeOptions.Init(buf, pos) + return cls.InitFromObj(shapeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, shapeOptions): + x = ShapeOptionsT() + x._UnPack(shapeOptions) + return x + + # ShapeOptionsT + def _UnPack(self, shapeOptions): + if shapeOptions is None: + return + self.outType = shapeOptions.OutType() + + # ShapeOptionsT + def Pack(self, builder): + ShapeOptionsStart(builder) + ShapeOptionsAddOutType(builder, self.outType) + shapeOptions = ShapeOptionsEnd(builder) + return shapeOptions + + +class RankOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RankOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRankOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RankOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RankOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def RankOptionsStart(builder): + builder.StartObject(0) + + +def RankOptionsEnd(builder): + return builder.EndObject() + + +class RankOptionsT(object): + + # RankOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + rankOptions = RankOptions() + rankOptions.Init(buf, pos) + return cls.InitFromObj(rankOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, rankOptions): + x = RankOptionsT() + x._UnPack(rankOptions) + return x + + # RankOptionsT + def _UnPack(self, rankOptions): + if rankOptions is None: + return + + # RankOptionsT + def Pack(self, builder): + RankOptionsStart(builder) + rankOptions = RankOptionsEnd(builder) + return rankOptions + + +class PowOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PowOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPowOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def PowOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # PowOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def PowOptionsStart(builder): + builder.StartObject(0) + + +def PowOptionsEnd(builder): + return builder.EndObject() + + +class PowOptionsT(object): + + # PowOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + powOptions = PowOptions() + powOptions.Init(buf, pos) + return cls.InitFromObj(powOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, powOptions): + x = PowOptionsT() + x._UnPack(powOptions) + return x + + # PowOptionsT + def _UnPack(self, powOptions): + if powOptions is None: + return + + # PowOptionsT + def Pack(self, builder): + PowOptionsStart(builder) + powOptions = PowOptionsEnd(builder) + return powOptions + + +class FakeQuantOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FakeQuantOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsFakeQuantOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def FakeQuantOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # FakeQuantOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # FakeQuantOptions + def Min(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # FakeQuantOptions + def Max(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # FakeQuantOptions + def NumBits(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # FakeQuantOptions + def NarrowRange(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def FakeQuantOptionsStart(builder): + builder.StartObject(4) + + +def FakeQuantOptionsAddMin(builder, min): + builder.PrependFloat32Slot(0, min, 0.0) + + +def FakeQuantOptionsAddMax(builder, max): + builder.PrependFloat32Slot(1, max, 0.0) + + +def FakeQuantOptionsAddNumBits(builder, numBits): + builder.PrependInt32Slot(2, numBits, 0) + + +def FakeQuantOptionsAddNarrowRange(builder, narrowRange): + builder.PrependBoolSlot(3, narrowRange, 0) + + +def FakeQuantOptionsEnd(builder): + return builder.EndObject() + + +class FakeQuantOptionsT(object): + + # FakeQuantOptionsT + def __init__(self): + self.min = 0.0 # type: float + self.max = 0.0 # type: float + self.numBits = 0 # type: int + self.narrowRange = 
False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + fakeQuantOptions = FakeQuantOptions() + fakeQuantOptions.Init(buf, pos) + return cls.InitFromObj(fakeQuantOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, fakeQuantOptions): + x = FakeQuantOptionsT() + x._UnPack(fakeQuantOptions) + return x + + # FakeQuantOptionsT + def _UnPack(self, fakeQuantOptions): + if fakeQuantOptions is None: + return + self.min = fakeQuantOptions.Min() + self.max = fakeQuantOptions.Max() + self.numBits = fakeQuantOptions.NumBits() + self.narrowRange = fakeQuantOptions.NarrowRange() + + # FakeQuantOptionsT + def Pack(self, builder): + FakeQuantOptionsStart(builder) + FakeQuantOptionsAddMin(builder, self.min) + FakeQuantOptionsAddMax(builder, self.max) + FakeQuantOptionsAddNumBits(builder, self.numBits) + FakeQuantOptionsAddNarrowRange(builder, self.narrowRange) + fakeQuantOptions = FakeQuantOptionsEnd(builder) + return fakeQuantOptions + + +class PackOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PackOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPackOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def PackOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # PackOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # PackOptions + def ValuesCount(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # PackOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def PackOptionsStart(builder): + builder.StartObject(2) + + +def PackOptionsAddValuesCount(builder, valuesCount): + builder.PrependInt32Slot(0, valuesCount, 0) + + +def PackOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(1, axis, 0) + + +def PackOptionsEnd(builder): + return builder.EndObject() + + +class PackOptionsT(object): + + # PackOptionsT + def __init__(self): + self.valuesCount = 0 # type: int + self.axis = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + packOptions = PackOptions() + packOptions.Init(buf, pos) + return cls.InitFromObj(packOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, packOptions): + x = PackOptionsT() + x._UnPack(packOptions) + return x + + # PackOptionsT + def _UnPack(self, packOptions): + if packOptions is None: + return + self.valuesCount = packOptions.ValuesCount() + self.axis = packOptions.Axis() + + # PackOptionsT + def Pack(self, builder): + PackOptionsStart(builder) + PackOptionsAddValuesCount(builder, self.valuesCount) + 
PackOptionsAddAxis(builder, self.axis) + packOptions = PackOptionsEnd(builder) + return packOptions + + +class LogicalOrOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalOrOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLogicalOrOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LogicalOrOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LogicalOrOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalOrOptionsStart(builder): + builder.StartObject(0) + + +def LogicalOrOptionsEnd(builder): + return builder.EndObject() + + +class LogicalOrOptionsT(object): + + # LogicalOrOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + logicalOrOptions = LogicalOrOptions() + logicalOrOptions.Init(buf, pos) + return cls.InitFromObj(logicalOrOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, logicalOrOptions): + x = LogicalOrOptionsT() + x._UnPack(logicalOrOptions) + return x + + # LogicalOrOptionsT + def _UnPack(self, logicalOrOptions): + if logicalOrOptions is None: + return + + # LogicalOrOptionsT + def Pack(self, builder): + LogicalOrOptionsStart(builder) + logicalOrOptions = LogicalOrOptionsEnd(builder) + return logicalOrOptions + + +class OneHotOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = OneHotOptions() + x.Init(buf, n + 
offset) + return x + + @classmethod + def GetRootAsOneHotOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def OneHotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # OneHotOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # OneHotOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def OneHotOptionsStart(builder): + builder.StartObject(1) + + +def OneHotOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(0, axis, 0) + + +def OneHotOptionsEnd(builder): + return builder.EndObject() + + +class OneHotOptionsT(object): + + # OneHotOptionsT + def __init__(self): + self.axis = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + oneHotOptions = OneHotOptions() + oneHotOptions.Init(buf, pos) + return cls.InitFromObj(oneHotOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, oneHotOptions): + x = OneHotOptionsT() + x._UnPack(oneHotOptions) + return x + + # OneHotOptionsT + def _UnPack(self, oneHotOptions): + if oneHotOptions is None: + return + self.axis = oneHotOptions.Axis() + + # OneHotOptionsT + def Pack(self, builder): + OneHotOptionsStart(builder) + OneHotOptionsAddAxis(builder, self.axis) + oneHotOptions = OneHotOptionsEnd(builder) + return oneHotOptions + + +class AbsOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AbsOptions() + x.Init(buf, n + 
offset) + return x + + @classmethod + def GetRootAsAbsOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def AbsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # AbsOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def AbsOptionsStart(builder): + builder.StartObject(0) + + +def AbsOptionsEnd(builder): + return builder.EndObject() + + +class AbsOptionsT(object): + + # AbsOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + absOptions = AbsOptions() + absOptions.Init(buf, pos) + return cls.InitFromObj(absOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, absOptions): + x = AbsOptionsT() + x._UnPack(absOptions) + return x + + # AbsOptionsT + def _UnPack(self, absOptions): + if absOptions is None: + return + + # AbsOptionsT + def Pack(self, builder): + AbsOptionsStart(builder) + absOptions = AbsOptionsEnd(builder) + return absOptions + + +class HardSwishOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = HardSwishOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsHardSwishOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def HardSwishOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # HardSwishOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def HardSwishOptionsStart(builder): + builder.StartObject(0) + + +def HardSwishOptionsEnd(builder): + return builder.EndObject() + + +class HardSwishOptionsT(object): + + # HardSwishOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + hardSwishOptions = HardSwishOptions() + hardSwishOptions.Init(buf, pos) + return cls.InitFromObj(hardSwishOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, hardSwishOptions): + x = HardSwishOptionsT() + x._UnPack(hardSwishOptions) + return x + + # HardSwishOptionsT + def _UnPack(self, hardSwishOptions): + if hardSwishOptions is None: + return + + # HardSwishOptionsT + def Pack(self, builder): + HardSwishOptionsStart(builder) + hardSwishOptions = HardSwishOptionsEnd(builder) + return hardSwishOptions + + +class LogicalAndOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalAndOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLogicalAndOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LogicalAndOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LogicalAndOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalAndOptionsStart(builder): + builder.StartObject(0) + + +def LogicalAndOptionsEnd(builder): + return builder.EndObject() + + +class LogicalAndOptionsT(object): + + # LogicalAndOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + logicalAndOptions = LogicalAndOptions() + logicalAndOptions.Init(buf, pos) + return cls.InitFromObj(logicalAndOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, logicalAndOptions): + x = LogicalAndOptionsT() + x._UnPack(logicalAndOptions) + return x + + # LogicalAndOptionsT + def _UnPack(self, logicalAndOptions): + if logicalAndOptions is None: + return + + # LogicalAndOptionsT + def Pack(self, builder): + LogicalAndOptionsStart(builder) + logicalAndOptions = LogicalAndOptionsEnd(builder) + return logicalAndOptions + + +class LogicalNotOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalNotOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLogicalNotOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LogicalNotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LogicalNotOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalNotOptionsStart(builder): + builder.StartObject(0) + + +def LogicalNotOptionsEnd(builder): + return builder.EndObject() + + +class LogicalNotOptionsT(object): + + # LogicalNotOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + logicalNotOptions = LogicalNotOptions() + logicalNotOptions.Init(buf, pos) + return cls.InitFromObj(logicalNotOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, logicalNotOptions): + x = LogicalNotOptionsT() + x._UnPack(logicalNotOptions) + return x + + # LogicalNotOptionsT + def _UnPack(self, logicalNotOptions): + if logicalNotOptions is None: + return + + # LogicalNotOptionsT + def Pack(self, builder): + LogicalNotOptionsStart(builder) + logicalNotOptions = LogicalNotOptionsEnd(builder) + return logicalNotOptions + + +class UnpackOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = UnpackOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsUnpackOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def UnpackOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # UnpackOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # UnpackOptions + def Num(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # UnpackOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def UnpackOptionsStart(builder): + builder.StartObject(2) + + +def UnpackOptionsAddNum(builder, num): + builder.PrependInt32Slot(0, num, 0) + + +def UnpackOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(1, axis, 0) + + +def UnpackOptionsEnd(builder): + return builder.EndObject() + + +class UnpackOptionsT(object): + + # UnpackOptionsT + def __init__(self): + self.num = 0 # type: int + self.axis = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + unpackOptions = UnpackOptions() + unpackOptions.Init(buf, pos) + return cls.InitFromObj(unpackOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, unpackOptions): + x = UnpackOptionsT() + x._UnPack(unpackOptions) + return x + + # UnpackOptionsT + def _UnPack(self, unpackOptions): + if unpackOptions is None: + return + self.num = unpackOptions.Num() + self.axis = unpackOptions.Axis() + + # UnpackOptionsT + def Pack(self, builder): + UnpackOptionsStart(builder) + UnpackOptionsAddNum(builder, self.num) + UnpackOptionsAddAxis(builder, self.axis) + 
unpackOptions = UnpackOptionsEnd(builder) + return unpackOptions + + +class FloorDivOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FloorDivOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsFloorDivOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def FloorDivOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # FloorDivOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def FloorDivOptionsStart(builder): + builder.StartObject(0) + + +def FloorDivOptionsEnd(builder): + return builder.EndObject() + + +class FloorDivOptionsT(object): + + # FloorDivOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + floorDivOptions = FloorDivOptions() + floorDivOptions.Init(buf, pos) + return cls.InitFromObj(floorDivOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, floorDivOptions): + x = FloorDivOptionsT() + x._UnPack(floorDivOptions) + return x + + # FloorDivOptionsT + def _UnPack(self, floorDivOptions): + if floorDivOptions is None: + return + + # FloorDivOptionsT + def Pack(self, builder): + FloorDivOptionsStart(builder) + floorDivOptions = FloorDivOptionsEnd(builder) + return floorDivOptions + + +class SquareOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SquareOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def 
GetRootAsSquareOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SquareOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SquareOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SquareOptionsStart(builder): + builder.StartObject(0) + + +def SquareOptionsEnd(builder): + return builder.EndObject() + + +class SquareOptionsT(object): + + # SquareOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + squareOptions = SquareOptions() + squareOptions.Init(buf, pos) + return cls.InitFromObj(squareOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, squareOptions): + x = SquareOptionsT() + x._UnPack(squareOptions) + return x + + # SquareOptionsT + def _UnPack(self, squareOptions): + if squareOptions is None: + return + + # SquareOptionsT + def Pack(self, builder): + SquareOptionsStart(builder) + squareOptions = SquareOptionsEnd(builder) + return squareOptions + + +class ZerosLikeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ZerosLikeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsZerosLikeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ZerosLikeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ZerosLikeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ZerosLikeOptionsStart(builder): + builder.StartObject(0) + + +def ZerosLikeOptionsEnd(builder): + return builder.EndObject() + + +class ZerosLikeOptionsT(object): + + # ZerosLikeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + zerosLikeOptions = ZerosLikeOptions() + zerosLikeOptions.Init(buf, pos) + return cls.InitFromObj(zerosLikeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, zerosLikeOptions): + x = ZerosLikeOptionsT() + x._UnPack(zerosLikeOptions) + return x + + # ZerosLikeOptionsT + def _UnPack(self, zerosLikeOptions): + if zerosLikeOptions is None: + return + + # ZerosLikeOptionsT + def Pack(self, builder): + ZerosLikeOptionsStart(builder) + zerosLikeOptions = ZerosLikeOptionsEnd(builder) + return zerosLikeOptions + + +class FillOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FillOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsFillOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def FillOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # FillOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def FillOptionsStart(builder): + builder.StartObject(0) + + +def FillOptionsEnd(builder): + return builder.EndObject() + + +class FillOptionsT(object): + + # FillOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + fillOptions = FillOptions() + fillOptions.Init(buf, pos) + return cls.InitFromObj(fillOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, fillOptions): + x = FillOptionsT() + x._UnPack(fillOptions) + return x + + # FillOptionsT + def _UnPack(self, fillOptions): + if fillOptions is None: + return + + # FillOptionsT + def Pack(self, builder): + FillOptionsStart(builder) + fillOptions = FillOptionsEnd(builder) + return fillOptions + + +class FloorModOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FloorModOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsFloorModOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def FloorModOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # FloorModOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def FloorModOptionsStart(builder): + builder.StartObject(0) + + +def FloorModOptionsEnd(builder): + return builder.EndObject() + + +class FloorModOptionsT(object): + + # FloorModOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + floorModOptions = FloorModOptions() + floorModOptions.Init(buf, pos) + return cls.InitFromObj(floorModOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, floorModOptions): + x = FloorModOptionsT() + x._UnPack(floorModOptions) + return x + + # FloorModOptionsT + def _UnPack(self, floorModOptions): + if floorModOptions is None: + return + + # FloorModOptionsT + def Pack(self, builder): + FloorModOptionsStart(builder) + floorModOptions = FloorModOptionsEnd(builder) + return floorModOptions + + +class RangeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RangeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRangeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RangeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RangeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def RangeOptionsStart(builder): + builder.StartObject(0) + + +def RangeOptionsEnd(builder): + return builder.EndObject() + + +class RangeOptionsT(object): + + # RangeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + rangeOptions = RangeOptions() + rangeOptions.Init(buf, pos) + return cls.InitFromObj(rangeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, rangeOptions): + x = RangeOptionsT() + x._UnPack(rangeOptions) + return x + + # RangeOptionsT + def _UnPack(self, rangeOptions): + if rangeOptions is None: + return + + # RangeOptionsT + def Pack(self, builder): + RangeOptionsStart(builder) + rangeOptions = RangeOptionsEnd(builder) + return rangeOptions + + +class LeakyReluOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LeakyReluOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsLeakyReluOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def LeakyReluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # LeakyReluOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # LeakyReluOptions + def Alpha(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + +def LeakyReluOptionsStart(builder): + builder.StartObject(1) + + +def LeakyReluOptionsAddAlpha(builder, alpha): + builder.PrependFloat32Slot(0, alpha, 0.0) + + +def LeakyReluOptionsEnd(builder): + return builder.EndObject() + + +class LeakyReluOptionsT(object): + + # LeakyReluOptionsT + def __init__(self): + self.alpha = 0.0 # type: float + + @classmethod + def InitFromBuf(cls, buf, pos): + leakyReluOptions = LeakyReluOptions() + leakyReluOptions.Init(buf, pos) + return cls.InitFromObj(leakyReluOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, leakyReluOptions): + x = LeakyReluOptionsT() + x._UnPack(leakyReluOptions) + return x + + # LeakyReluOptionsT + def _UnPack(self, leakyReluOptions): + if leakyReluOptions is None: + return + self.alpha = leakyReluOptions.Alpha() + + # LeakyReluOptionsT + def Pack(self, builder): + LeakyReluOptionsStart(builder) + LeakyReluOptionsAddAlpha(builder, self.alpha) + leakyReluOptions = LeakyReluOptionsEnd(builder) + return leakyReluOptions + + +class SquaredDifferenceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SquaredDifferenceOptions() + x.Init(buf, n 
+ offset) + return x + + @classmethod + def GetRootAsSquaredDifferenceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SquaredDifferenceOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SquaredDifferenceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SquaredDifferenceOptionsStart(builder): + builder.StartObject(0) + + +def SquaredDifferenceOptionsEnd(builder): + return builder.EndObject() + + +class SquaredDifferenceOptionsT(object): + + # SquaredDifferenceOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + squaredDifferenceOptions = SquaredDifferenceOptions() + squaredDifferenceOptions.Init(buf, pos) + return cls.InitFromObj(squaredDifferenceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, squaredDifferenceOptions): + x = SquaredDifferenceOptionsT() + x._UnPack(squaredDifferenceOptions) + return x + + # SquaredDifferenceOptionsT + def _UnPack(self, squaredDifferenceOptions): + if squaredDifferenceOptions is None: + return + + # SquaredDifferenceOptionsT + def Pack(self, builder): + SquaredDifferenceOptionsStart(builder) + squaredDifferenceOptions = SquaredDifferenceOptionsEnd(builder) + return squaredDifferenceOptions + + +class MirrorPadOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MirrorPadOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMirrorPadOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MirrorPadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MirrorPadOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MirrorPadOptions + def Mode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def MirrorPadOptionsStart(builder): + builder.StartObject(1) + + +def MirrorPadOptionsAddMode(builder, mode): + builder.PrependInt8Slot(0, mode, 0) + + +def MirrorPadOptionsEnd(builder): + return builder.EndObject() + + +class MirrorPadOptionsT(object): + + # MirrorPadOptionsT + def __init__(self): + self.mode = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + mirrorPadOptions = MirrorPadOptions() + mirrorPadOptions.Init(buf, pos) + return cls.InitFromObj(mirrorPadOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, mirrorPadOptions): + x = MirrorPadOptionsT() + x._UnPack(mirrorPadOptions) + return x + + # MirrorPadOptionsT + def _UnPack(self, mirrorPadOptions): + if mirrorPadOptions is None: + return + self.mode = mirrorPadOptions.Mode() + + # MirrorPadOptionsT + def Pack(self, builder): + MirrorPadOptionsStart(builder) + MirrorPadOptionsAddMode(builder, self.mode) + mirrorPadOptions = MirrorPadOptionsEnd(builder) + return mirrorPadOptions + + +class UniqueOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = UniqueOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def 
GetRootAsUniqueOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def UniqueOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # UniqueOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # UniqueOptions + def IdxOutType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 2 + + +def UniqueOptionsStart(builder): + builder.StartObject(1) + + +def UniqueOptionsAddIdxOutType(builder, idxOutType): + builder.PrependInt8Slot(0, idxOutType, 2) + + +def UniqueOptionsEnd(builder): + return builder.EndObject() + + +class UniqueOptionsT(object): + + # UniqueOptionsT + def __init__(self): + self.idxOutType = 2 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + uniqueOptions = UniqueOptions() + uniqueOptions.Init(buf, pos) + return cls.InitFromObj(uniqueOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, uniqueOptions): + x = UniqueOptionsT() + x._UnPack(uniqueOptions) + return x + + # UniqueOptionsT + def _UnPack(self, uniqueOptions): + if uniqueOptions is None: + return + self.idxOutType = uniqueOptions.IdxOutType() + + # UniqueOptionsT + def Pack(self, builder): + UniqueOptionsStart(builder) + UniqueOptionsAddIdxOutType(builder, self.idxOutType) + uniqueOptions = UniqueOptionsEnd(builder) + return uniqueOptions + + +class ReverseV2Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReverseV2Options() 
+ x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsReverseV2Options(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ReverseV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ReverseV2Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ReverseV2OptionsStart(builder): + builder.StartObject(0) + + +def ReverseV2OptionsEnd(builder): + return builder.EndObject() + + +class ReverseV2OptionsT(object): + + # ReverseV2OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + reverseV2Options = ReverseV2Options() + reverseV2Options.Init(buf, pos) + return cls.InitFromObj(reverseV2Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, reverseV2Options): + x = ReverseV2OptionsT() + x._UnPack(reverseV2Options) + return x + + # ReverseV2OptionsT + def _UnPack(self, reverseV2Options): + if reverseV2Options is None: + return + + # ReverseV2OptionsT + def Pack(self, builder): + ReverseV2OptionsStart(builder) + reverseV2Options = ReverseV2OptionsEnd(builder) + return reverseV2Options + + +class AddNOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AddNOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAddNOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def AddNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # AddNOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def AddNOptionsStart(builder): + builder.StartObject(0) + + +def AddNOptionsEnd(builder): + return builder.EndObject() + + +class AddNOptionsT(object): + + # AddNOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + addNoptions = AddNOptions() + addNoptions.Init(buf, pos) + return cls.InitFromObj(addNoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, addNoptions): + x = AddNOptionsT() + x._UnPack(addNoptions) + return x + + # AddNOptionsT + def _UnPack(self, addNoptions): + if addNoptions is None: + return + + # AddNOptionsT + def Pack(self, builder): + AddNOptionsStart(builder) + addNoptions = AddNOptionsEnd(builder) + return addNoptions + + +class GatherNdOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = GatherNdOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsGatherNdOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def GatherNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # GatherNdOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def GatherNdOptionsStart(builder): + builder.StartObject(0) + + +def GatherNdOptionsEnd(builder): + return builder.EndObject() + + +class GatherNdOptionsT(object): + + # GatherNdOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + gatherNdOptions = GatherNdOptions() + gatherNdOptions.Init(buf, pos) + return cls.InitFromObj(gatherNdOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, gatherNdOptions): + x = GatherNdOptionsT() + x._UnPack(gatherNdOptions) + return x + + # GatherNdOptionsT + def _UnPack(self, gatherNdOptions): + if gatherNdOptions is None: + return + + # GatherNdOptionsT + def Pack(self, builder): + GatherNdOptionsStart(builder) + gatherNdOptions = GatherNdOptionsEnd(builder) + return gatherNdOptions + + +class WhereOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = WhereOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsWhereOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def WhereOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # WhereOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def WhereOptionsStart(builder): + builder.StartObject(0) + + +def WhereOptionsEnd(builder): + return builder.EndObject() + + +class WhereOptionsT(object): + + # WhereOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + whereOptions = WhereOptions() + whereOptions.Init(buf, pos) + return cls.InitFromObj(whereOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, whereOptions): + x = WhereOptionsT() + x._UnPack(whereOptions) + return x + + # WhereOptionsT + def _UnPack(self, whereOptions): + if whereOptions is None: + return + + # WhereOptionsT + def Pack(self, builder): + WhereOptionsStart(builder) + whereOptions = WhereOptionsEnd(builder) + return whereOptions + + +class ReverseSequenceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReverseSequenceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsReverseSequenceOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ReverseSequenceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ReverseSequenceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ReverseSequenceOptions + def SeqDim(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ReverseSequenceOptions + def BatchDim(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def ReverseSequenceOptionsStart(builder): + builder.StartObject(2) + + +def ReverseSequenceOptionsAddSeqDim(builder, seqDim): + builder.PrependInt32Slot(0, seqDim, 0) + + +def ReverseSequenceOptionsAddBatchDim(builder, batchDim): + builder.PrependInt32Slot(1, batchDim, 0) + + +def ReverseSequenceOptionsEnd(builder): + return builder.EndObject() + + +class ReverseSequenceOptionsT(object): + + # ReverseSequenceOptionsT + def __init__(self): + self.seqDim = 0 # type: int + self.batchDim = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + reverseSequenceOptions = ReverseSequenceOptions() + reverseSequenceOptions.Init(buf, pos) + return cls.InitFromObj(reverseSequenceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, reverseSequenceOptions): + x = ReverseSequenceOptionsT() + x._UnPack(reverseSequenceOptions) + return x + + # ReverseSequenceOptionsT + def _UnPack(self, reverseSequenceOptions): + if reverseSequenceOptions is None: + return + self.seqDim = 
reverseSequenceOptions.SeqDim() + self.batchDim = reverseSequenceOptions.BatchDim() + + # ReverseSequenceOptionsT + def Pack(self, builder): + ReverseSequenceOptionsStart(builder) + ReverseSequenceOptionsAddSeqDim(builder, self.seqDim) + ReverseSequenceOptionsAddBatchDim(builder, self.batchDim) + reverseSequenceOptions = ReverseSequenceOptionsEnd(builder) + return reverseSequenceOptions + + +class MatrixDiagOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MatrixDiagOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMatrixDiagOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MatrixDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MatrixDiagOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def MatrixDiagOptionsStart(builder): + builder.StartObject(0) + + +def MatrixDiagOptionsEnd(builder): + return builder.EndObject() + + +class MatrixDiagOptionsT(object): + + # MatrixDiagOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + matrixDiagOptions = MatrixDiagOptions() + matrixDiagOptions.Init(buf, pos) + return cls.InitFromObj(matrixDiagOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, matrixDiagOptions): + x = MatrixDiagOptionsT() + x._UnPack(matrixDiagOptions) + return x + + # MatrixDiagOptionsT + def _UnPack(self, matrixDiagOptions): + if matrixDiagOptions is None: + return + + # MatrixDiagOptionsT + def Pack(self, builder): + 
MatrixDiagOptionsStart(builder) + matrixDiagOptions = MatrixDiagOptionsEnd(builder) + return matrixDiagOptions + + +class QuantizeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = QuantizeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsQuantizeOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def QuantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # QuantizeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def QuantizeOptionsStart(builder): + builder.StartObject(0) + + +def QuantizeOptionsEnd(builder): + return builder.EndObject() + + +class QuantizeOptionsT(object): + + # QuantizeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + quantizeOptions = QuantizeOptions() + quantizeOptions.Init(buf, pos) + return cls.InitFromObj(quantizeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, quantizeOptions): + x = QuantizeOptionsT() + x._UnPack(quantizeOptions) + return x + + # QuantizeOptionsT + def _UnPack(self, quantizeOptions): + if quantizeOptions is None: + return + + # QuantizeOptionsT + def Pack(self, builder): + QuantizeOptionsStart(builder) + quantizeOptions = QuantizeOptionsEnd(builder) + return quantizeOptions + + +class MatrixSetDiagOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MatrixSetDiagOptions() + x.Init(buf, n + 
offset) + return x + + @classmethod + def GetRootAsMatrixSetDiagOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MatrixSetDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # MatrixSetDiagOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def MatrixSetDiagOptionsStart(builder): + builder.StartObject(0) + + +def MatrixSetDiagOptionsEnd(builder): + return builder.EndObject() + + +class MatrixSetDiagOptionsT(object): + + # MatrixSetDiagOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + matrixSetDiagOptions = MatrixSetDiagOptions() + matrixSetDiagOptions.Init(buf, pos) + return cls.InitFromObj(matrixSetDiagOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, matrixSetDiagOptions): + x = MatrixSetDiagOptionsT() + x._UnPack(matrixSetDiagOptions) + return x + + # MatrixSetDiagOptionsT + def _UnPack(self, matrixSetDiagOptions): + if matrixSetDiagOptions is None: + return + + # MatrixSetDiagOptionsT + def Pack(self, builder): + MatrixSetDiagOptionsStart(builder) + matrixSetDiagOptions = MatrixSetDiagOptionsEnd(builder) + return matrixSetDiagOptions + + +class IfOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = IfOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsIfOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def IfOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # IfOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # IfOptions + def ThenSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # IfOptions + def ElseSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def IfOptionsStart(builder): + builder.StartObject(2) + + +def IfOptionsAddThenSubgraphIndex(builder, thenSubgraphIndex): + builder.PrependInt32Slot(0, thenSubgraphIndex, 0) + + +def IfOptionsAddElseSubgraphIndex(builder, elseSubgraphIndex): + builder.PrependInt32Slot(1, elseSubgraphIndex, 0) + + +def IfOptionsEnd(builder): + return builder.EndObject() + + +class IfOptionsT(object): + + # IfOptionsT + def __init__(self): + self.thenSubgraphIndex = 0 # type: int + self.elseSubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + ifOptions = IfOptions() + ifOptions.Init(buf, pos) + return cls.InitFromObj(ifOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, ifOptions): + x = IfOptionsT() + x._UnPack(ifOptions) + return x + + # IfOptionsT + def _UnPack(self, ifOptions): + if ifOptions is None: + return + self.thenSubgraphIndex = ifOptions.ThenSubgraphIndex() + self.elseSubgraphIndex = ifOptions.ElseSubgraphIndex() + + # IfOptionsT + def Pack(self, builder): + 
IfOptionsStart(builder) + IfOptionsAddThenSubgraphIndex(builder, self.thenSubgraphIndex) + IfOptionsAddElseSubgraphIndex(builder, self.elseSubgraphIndex) + ifOptions = IfOptionsEnd(builder) + return ifOptions + + +class CallOnceOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CallOnceOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCallOnceOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CallOnceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CallOnceOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CallOnceOptions + def InitSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def CallOnceOptionsStart(builder): + builder.StartObject(1) + + +def CallOnceOptionsAddInitSubgraphIndex(builder, initSubgraphIndex): + builder.PrependInt32Slot(0, initSubgraphIndex, 0) + + +def CallOnceOptionsEnd(builder): + return builder.EndObject() + + +class CallOnceOptionsT(object): + + # CallOnceOptionsT + def __init__(self): + self.initSubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + callOnceOptions = CallOnceOptions() + callOnceOptions.Init(buf, pos) + return cls.InitFromObj(callOnceOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, callOnceOptions): + x = CallOnceOptionsT() + x._UnPack(callOnceOptions) + return x + + 
# CallOnceOptionsT + def _UnPack(self, callOnceOptions): + if callOnceOptions is None: + return + self.initSubgraphIndex = callOnceOptions.InitSubgraphIndex() + + # CallOnceOptionsT + def Pack(self, builder): + CallOnceOptionsStart(builder) + CallOnceOptionsAddInitSubgraphIndex(builder, self.initSubgraphIndex) + callOnceOptions = CallOnceOptionsEnd(builder) + return callOnceOptions + + +class WhileOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = WhileOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsWhileOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def WhileOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # WhileOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # WhileOptions + def CondSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # WhileOptions + def BodySubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def WhileOptionsStart(builder): + builder.StartObject(2) + + +def WhileOptionsAddCondSubgraphIndex(builder, condSubgraphIndex): + builder.PrependInt32Slot(0, condSubgraphIndex, 0) + + +def WhileOptionsAddBodySubgraphIndex(builder, bodySubgraphIndex): + builder.PrependInt32Slot(1, bodySubgraphIndex, 0) + + +def WhileOptionsEnd(builder): + return builder.EndObject() + + +class WhileOptionsT(object): + + # WhileOptionsT + def __init__(self): + 
self.condSubgraphIndex = 0 # type: int + self.bodySubgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + whileOptions = WhileOptions() + whileOptions.Init(buf, pos) + return cls.InitFromObj(whileOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, whileOptions): + x = WhileOptionsT() + x._UnPack(whileOptions) + return x + + # WhileOptionsT + def _UnPack(self, whileOptions): + if whileOptions is None: + return + self.condSubgraphIndex = whileOptions.CondSubgraphIndex() + self.bodySubgraphIndex = whileOptions.BodySubgraphIndex() + + # WhileOptionsT + def Pack(self, builder): + WhileOptionsStart(builder) + WhileOptionsAddCondSubgraphIndex(builder, self.condSubgraphIndex) + WhileOptionsAddBodySubgraphIndex(builder, self.bodySubgraphIndex) + whileOptions = WhileOptionsEnd(builder) + return whileOptions + + +class NonMaxSuppressionV4Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NonMaxSuppressionV4Options() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNonMaxSuppressionV4Options(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def NonMaxSuppressionV4OptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # NonMaxSuppressionV4Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def NonMaxSuppressionV4OptionsStart(builder): + builder.StartObject(0) + + +def NonMaxSuppressionV4OptionsEnd(builder): + return builder.EndObject() + + +class NonMaxSuppressionV4OptionsT(object): + + # NonMaxSuppressionV4OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + nonMaxSuppressionV4Options = NonMaxSuppressionV4Options() + nonMaxSuppressionV4Options.Init(buf, pos) + return cls.InitFromObj(nonMaxSuppressionV4Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, nonMaxSuppressionV4Options): + x = NonMaxSuppressionV4OptionsT() + x._UnPack(nonMaxSuppressionV4Options) + return x + + # NonMaxSuppressionV4OptionsT + def _UnPack(self, nonMaxSuppressionV4Options): + if nonMaxSuppressionV4Options is None: + return + + # NonMaxSuppressionV4OptionsT + def Pack(self, builder): + NonMaxSuppressionV4OptionsStart(builder) + nonMaxSuppressionV4Options = NonMaxSuppressionV4OptionsEnd(builder) + return nonMaxSuppressionV4Options + + +class NonMaxSuppressionV5Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NonMaxSuppressionV5Options() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNonMaxSuppressionV5Options(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def NonMaxSuppressionV5OptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # NonMaxSuppressionV5Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def NonMaxSuppressionV5OptionsStart(builder): + builder.StartObject(0) + + +def NonMaxSuppressionV5OptionsEnd(builder): + return builder.EndObject() + + +class NonMaxSuppressionV5OptionsT(object): + + # NonMaxSuppressionV5OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + nonMaxSuppressionV5Options = NonMaxSuppressionV5Options() + nonMaxSuppressionV5Options.Init(buf, pos) + return cls.InitFromObj(nonMaxSuppressionV5Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, nonMaxSuppressionV5Options): + x = NonMaxSuppressionV5OptionsT() + x._UnPack(nonMaxSuppressionV5Options) + return x + + # NonMaxSuppressionV5OptionsT + def _UnPack(self, nonMaxSuppressionV5Options): + if nonMaxSuppressionV5Options is None: + return + + # NonMaxSuppressionV5OptionsT + def Pack(self, builder): + NonMaxSuppressionV5OptionsStart(builder) + nonMaxSuppressionV5Options = NonMaxSuppressionV5OptionsEnd(builder) + return nonMaxSuppressionV5Options + + +class ScatterNdOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ScatterNdOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsScatterNdOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ScatterNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ScatterNdOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ScatterNdOptionsStart(builder): + builder.StartObject(0) + + +def ScatterNdOptionsEnd(builder): + return builder.EndObject() + + +class ScatterNdOptionsT(object): + + # ScatterNdOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + scatterNdOptions = ScatterNdOptions() + scatterNdOptions.Init(buf, pos) + return cls.InitFromObj(scatterNdOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, scatterNdOptions): + x = ScatterNdOptionsT() + x._UnPack(scatterNdOptions) + return x + + # ScatterNdOptionsT + def _UnPack(self, scatterNdOptions): + if scatterNdOptions is None: + return + + # ScatterNdOptionsT + def Pack(self, builder): + ScatterNdOptionsStart(builder) + scatterNdOptions = ScatterNdOptionsEnd(builder) + return scatterNdOptions + + +class SelectV2Options(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SelectV2Options() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSelectV2Options(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SelectV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SelectV2Options + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SelectV2OptionsStart(builder): + builder.StartObject(0) + + +def SelectV2OptionsEnd(builder): + return builder.EndObject() + + +class SelectV2OptionsT(object): + + # SelectV2OptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + selectV2Options = SelectV2Options() + selectV2Options.Init(buf, pos) + return cls.InitFromObj(selectV2Options) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, selectV2Options): + x = SelectV2OptionsT() + x._UnPack(selectV2Options) + return x + + # SelectV2OptionsT + def _UnPack(self, selectV2Options): + if selectV2Options is None: + return + + # SelectV2OptionsT + def Pack(self, builder): + SelectV2OptionsStart(builder) + selectV2Options = SelectV2OptionsEnd(builder) + return selectV2Options + + +class DensifyOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DensifyOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDensifyOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def DensifyOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # DensifyOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def DensifyOptionsStart(builder): + builder.StartObject(0) + + +def DensifyOptionsEnd(builder): + return builder.EndObject() + + +class DensifyOptionsT(object): + + # DensifyOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + densifyOptions = DensifyOptions() + densifyOptions.Init(buf, pos) + return cls.InitFromObj(densifyOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, densifyOptions): + x = DensifyOptionsT() + x._UnPack(densifyOptions) + return x + + # DensifyOptionsT + def _UnPack(self, densifyOptions): + if densifyOptions is None: + return + + # DensifyOptionsT + def Pack(self, builder): + DensifyOptionsStart(builder) + densifyOptions = DensifyOptionsEnd(builder) + return densifyOptions + + +class SegmentSumOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SegmentSumOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSegmentSumOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SegmentSumOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SegmentSumOptionsStart(builder): + builder.StartObject(0) + + +def SegmentSumOptionsEnd(builder): + return builder.EndObject() + + +class SegmentSumOptionsT(object): + + # SegmentSumOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + segmentSumOptions = SegmentSumOptions() + segmentSumOptions.Init(buf, pos) + return cls.InitFromObj(segmentSumOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, segmentSumOptions): + x = SegmentSumOptionsT() + x._UnPack(segmentSumOptions) + return x + + # SegmentSumOptionsT + def _UnPack(self, segmentSumOptions): + if segmentSumOptions is None: + return + + # SegmentSumOptionsT + def Pack(self, builder): + SegmentSumOptionsStart(builder) + segmentSumOptions = SegmentSumOptionsEnd(builder) + return segmentSumOptions + + +class BatchMatMulOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BatchMatMulOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBatchMatMulOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BatchMatMulOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BatchMatMulOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BatchMatMulOptions + def AdjointLhs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # BatchMatMulOptions + def AdjointRhs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # BatchMatMulOptions + def AsymmetricQuantizeInputs(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def BatchMatMulOptionsStart(builder): + builder.StartObject(3) + + +def BatchMatMulOptionsAddAdjointLhs(builder, adjointLhs): + builder.PrependBoolSlot(0, adjointLhs, 0) + + +def BatchMatMulOptionsAddAdjointRhs(builder, adjointRhs): + builder.PrependBoolSlot(1, adjointRhs, 0) + + +def BatchMatMulOptionsAddAsymmetricQuantizeInputs(builder, asymmetricQuantizeInputs): + builder.PrependBoolSlot(2, asymmetricQuantizeInputs, 0) + + +def BatchMatMulOptionsEnd(builder): + return builder.EndObject() + + +class BatchMatMulOptionsT(object): + + # BatchMatMulOptionsT + def __init__(self): + self.adjointLhs = False # type: bool + self.adjointRhs = False # type: bool + self.asymmetricQuantizeInputs = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + batchMatMulOptions = BatchMatMulOptions() + batchMatMulOptions.Init(buf, pos) + return 
cls.InitFromObj(batchMatMulOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, batchMatMulOptions): + x = BatchMatMulOptionsT() + x._UnPack(batchMatMulOptions) + return x + + # BatchMatMulOptionsT + def _UnPack(self, batchMatMulOptions): + if batchMatMulOptions is None: + return + self.adjointLhs = batchMatMulOptions.AdjointLhs() + self.adjointRhs = batchMatMulOptions.AdjointRhs() + self.asymmetricQuantizeInputs = batchMatMulOptions.AsymmetricQuantizeInputs() + + # BatchMatMulOptionsT + def Pack(self, builder): + BatchMatMulOptionsStart(builder) + BatchMatMulOptionsAddAdjointLhs(builder, self.adjointLhs) + BatchMatMulOptionsAddAdjointRhs(builder, self.adjointRhs) + BatchMatMulOptionsAddAsymmetricQuantizeInputs(builder, + self.asymmetricQuantizeInputs) + batchMatMulOptions = BatchMatMulOptionsEnd(builder) + return batchMatMulOptions + + +class CumsumOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CumsumOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsCumsumOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def CumsumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # CumsumOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CumsumOptions + def Exclusive(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # CumsumOptions + def Reverse(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + +def CumsumOptionsStart(builder): + builder.StartObject(2) + + +def CumsumOptionsAddExclusive(builder, exclusive): + builder.PrependBoolSlot(0, exclusive, 0) + + +def CumsumOptionsAddReverse(builder, reverse): + builder.PrependBoolSlot(1, reverse, 0) + + +def CumsumOptionsEnd(builder): + return builder.EndObject() + + +class CumsumOptionsT(object): + + # CumsumOptionsT + def __init__(self): + self.exclusive = False # type: bool + self.reverse = False # type: bool + + @classmethod + def InitFromBuf(cls, buf, pos): + cumsumOptions = CumsumOptions() + cumsumOptions.Init(buf, pos) + return cls.InitFromObj(cumsumOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, cumsumOptions): + x = CumsumOptionsT() + x._UnPack(cumsumOptions) + return x + + # CumsumOptionsT + def _UnPack(self, cumsumOptions): + if cumsumOptions is None: + return + self.exclusive = cumsumOptions.Exclusive() + self.reverse = cumsumOptions.Reverse() + + # CumsumOptionsT + def Pack(self, builder): + 
CumsumOptionsStart(builder) + CumsumOptionsAddExclusive(builder, self.exclusive) + CumsumOptionsAddReverse(builder, self.reverse) + cumsumOptions = CumsumOptionsEnd(builder) + return cumsumOptions + + +class BroadcastToOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BroadcastToOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBroadcastToOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BroadcastToOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BroadcastToOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def BroadcastToOptionsStart(builder): + builder.StartObject(0) + + +def BroadcastToOptionsEnd(builder): + return builder.EndObject() + + +class BroadcastToOptionsT(object): + + # BroadcastToOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + broadcastToOptions = BroadcastToOptions() + broadcastToOptions.Init(buf, pos) + return cls.InitFromObj(broadcastToOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, broadcastToOptions): + x = BroadcastToOptionsT() + x._UnPack(broadcastToOptions) + return x + + # BroadcastToOptionsT + def _UnPack(self, broadcastToOptions): + if broadcastToOptions is None: + return + + # BroadcastToOptionsT + def Pack(self, builder): + BroadcastToOptionsStart(builder) + broadcastToOptions = BroadcastToOptionsEnd(builder) + return broadcastToOptions + + +class Rfft2dOptions(object): + __slots__ = ['_tab'] + + @classmethod + def 
GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Rfft2dOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRfft2dOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def Rfft2dOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Rfft2dOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def Rfft2dOptionsStart(builder): + builder.StartObject(0) + + +def Rfft2dOptionsEnd(builder): + return builder.EndObject() + + +class Rfft2dOptionsT(object): + + # Rfft2dOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + rfft2dOptions = Rfft2dOptions() + rfft2dOptions.Init(buf, pos) + return cls.InitFromObj(rfft2dOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, rfft2dOptions): + x = Rfft2dOptionsT() + x._UnPack(rfft2dOptions) + return x + + # Rfft2dOptionsT + def _UnPack(self, rfft2dOptions): + if rfft2dOptions is None: + return + + # Rfft2dOptionsT + def Pack(self, builder): + Rfft2dOptionsStart(builder) + rfft2dOptions = Rfft2dOptionsEnd(builder) + return rfft2dOptions + + +class HashtableOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = HashtableOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsHashtableOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def HashtableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # HashtableOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # HashtableOptions + def TableId(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # HashtableOptions + def KeyDtype(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # HashtableOptions + def ValueDtype(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def HashtableOptionsStart(builder): + builder.StartObject(3) + + +def HashtableOptionsAddTableId(builder, tableId): + builder.PrependInt32Slot(0, tableId, 0) + + +def HashtableOptionsAddKeyDtype(builder, keyDtype): + builder.PrependInt8Slot(1, keyDtype, 0) + + +def HashtableOptionsAddValueDtype(builder, valueDtype): + builder.PrependInt8Slot(2, valueDtype, 0) + + +def HashtableOptionsEnd(builder): + return builder.EndObject() + + +class HashtableOptionsT(object): + + # HashtableOptionsT + def __init__(self): + self.tableId = 0 # type: int + self.keyDtype = 0 # type: int + self.valueDtype = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + hashtableOptions = HashtableOptions() + hashtableOptions.Init(buf, pos) + return cls.InitFromObj(hashtableOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + 
@classmethod + def InitFromObj(cls, hashtableOptions): + x = HashtableOptionsT() + x._UnPack(hashtableOptions) + return x + + # HashtableOptionsT + def _UnPack(self, hashtableOptions): + if hashtableOptions is None: + return + self.tableId = hashtableOptions.TableId() + self.keyDtype = hashtableOptions.KeyDtype() + self.valueDtype = hashtableOptions.ValueDtype() + + # HashtableOptionsT + def Pack(self, builder): + HashtableOptionsStart(builder) + HashtableOptionsAddTableId(builder, self.tableId) + HashtableOptionsAddKeyDtype(builder, self.keyDtype) + HashtableOptionsAddValueDtype(builder, self.valueDtype) + hashtableOptions = HashtableOptionsEnd(builder) + return hashtableOptions + + +class HashtableFindOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = HashtableFindOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsHashtableFindOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def HashtableFindOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # HashtableFindOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def HashtableFindOptionsStart(builder): + builder.StartObject(0) + + +def HashtableFindOptionsEnd(builder): + return builder.EndObject() + + +class HashtableFindOptionsT(object): + + # HashtableFindOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + hashtableFindOptions = HashtableFindOptions() + hashtableFindOptions.Init(buf, pos) + return cls.InitFromObj(hashtableFindOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, hashtableFindOptions): + x = HashtableFindOptionsT() + x._UnPack(hashtableFindOptions) + return x + + # HashtableFindOptionsT + def _UnPack(self, hashtableFindOptions): + if hashtableFindOptions is None: + return + + # HashtableFindOptionsT + def Pack(self, builder): + HashtableFindOptionsStart(builder) + hashtableFindOptions = HashtableFindOptionsEnd(builder) + return hashtableFindOptions + + +class HashtableImportOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = HashtableImportOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsHashtableImportOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def HashtableImportOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # HashtableImportOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def HashtableImportOptionsStart(builder): + builder.StartObject(0) + + +def HashtableImportOptionsEnd(builder): + return builder.EndObject() + + +class HashtableImportOptionsT(object): + + # HashtableImportOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + hashtableImportOptions = HashtableImportOptions() + hashtableImportOptions.Init(buf, pos) + return cls.InitFromObj(hashtableImportOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, hashtableImportOptions): + x = HashtableImportOptionsT() + x._UnPack(hashtableImportOptions) + return x + + # HashtableImportOptionsT + def _UnPack(self, hashtableImportOptions): + if hashtableImportOptions is None: + return + + # HashtableImportOptionsT + def Pack(self, builder): + HashtableImportOptionsStart(builder) + hashtableImportOptions = HashtableImportOptionsEnd(builder) + return hashtableImportOptions + + +class HashtableSizeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = HashtableSizeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsHashtableSizeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def HashtableSizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # HashtableSizeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def HashtableSizeOptionsStart(builder): + builder.StartObject(0) + + +def HashtableSizeOptionsEnd(builder): + return builder.EndObject() + + +class HashtableSizeOptionsT(object): + + # HashtableSizeOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + hashtableSizeOptions = HashtableSizeOptions() + hashtableSizeOptions.Init(buf, pos) + return cls.InitFromObj(hashtableSizeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, hashtableSizeOptions): + x = HashtableSizeOptionsT() + x._UnPack(hashtableSizeOptions) + return x + + # HashtableSizeOptionsT + def _UnPack(self, hashtableSizeOptions): + if hashtableSizeOptions is None: + return + + # HashtableSizeOptionsT + def Pack(self, builder): + HashtableSizeOptionsStart(builder) + hashtableSizeOptions = HashtableSizeOptionsEnd(builder) + return hashtableSizeOptions + + +class VarHandleOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = VarHandleOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsVarHandleOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def VarHandleOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # VarHandleOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # VarHandleOptions + def Container(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # VarHandleOptions + def SharedName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + +def VarHandleOptionsStart(builder): + builder.StartObject(2) + + +def VarHandleOptionsAddContainer(builder, container): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(container), 0) + + +def VarHandleOptionsAddSharedName(builder, sharedName): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(sharedName), 0) + + +def VarHandleOptionsEnd(builder): + return builder.EndObject() + + +class VarHandleOptionsT(object): + + # VarHandleOptionsT + def __init__(self): + self.container = None # type: str + self.sharedName = None # type: str + + @classmethod + def InitFromBuf(cls, buf, pos): + varHandleOptions = VarHandleOptions() + varHandleOptions.Init(buf, pos) + return cls.InitFromObj(varHandleOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, varHandleOptions): + x = VarHandleOptionsT() + x._UnPack(varHandleOptions) + return x + + # VarHandleOptionsT + def _UnPack(self, varHandleOptions): + if varHandleOptions is None: + return + self.container = varHandleOptions.Container() 
+ self.sharedName = varHandleOptions.SharedName() + + # VarHandleOptionsT + def Pack(self, builder): + if self.container is not None: + container = builder.CreateString(self.container) + if self.sharedName is not None: + sharedName = builder.CreateString(self.sharedName) + VarHandleOptionsStart(builder) + if self.container is not None: + VarHandleOptionsAddContainer(builder, container) + if self.sharedName is not None: + VarHandleOptionsAddSharedName(builder, sharedName) + varHandleOptions = VarHandleOptionsEnd(builder) + return varHandleOptions + + +class ReadVariableOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReadVariableOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsReadVariableOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ReadVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # ReadVariableOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ReadVariableOptionsStart(builder): + builder.StartObject(0) + + +def ReadVariableOptionsEnd(builder): + return builder.EndObject() + + +class ReadVariableOptionsT(object): + + # ReadVariableOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + readVariableOptions = ReadVariableOptions() + readVariableOptions.Init(buf, pos) + return cls.InitFromObj(readVariableOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, readVariableOptions): + x = ReadVariableOptionsT() + 
x._UnPack(readVariableOptions) + return x + + # ReadVariableOptionsT + def _UnPack(self, readVariableOptions): + if readVariableOptions is None: + return + + # ReadVariableOptionsT + def Pack(self, builder): + ReadVariableOptionsStart(builder) + readVariableOptions = ReadVariableOptionsEnd(builder) + return readVariableOptions + + +class AssignVariableOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AssignVariableOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAssignVariableOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def AssignVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # AssignVariableOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def AssignVariableOptionsStart(builder): + builder.StartObject(0) + + +def AssignVariableOptionsEnd(builder): + return builder.EndObject() + + +class AssignVariableOptionsT(object): + + # AssignVariableOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + assignVariableOptions = AssignVariableOptions() + assignVariableOptions.Init(buf, pos) + return cls.InitFromObj(assignVariableOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, assignVariableOptions): + x = AssignVariableOptionsT() + x._UnPack(assignVariableOptions) + return x + + # AssignVariableOptionsT + def _UnPack(self, assignVariableOptions): + if assignVariableOptions is None: + return + + # AssignVariableOptionsT + def Pack(self, builder): + 
AssignVariableOptionsStart(builder) + assignVariableOptions = AssignVariableOptionsEnd(builder) + return assignVariableOptions + + +class RandomOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RandomOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRandomOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RandomOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RandomOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RandomOptions + def Seed(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + # RandomOptions + def Seed2(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos) + return 0 + + +def RandomOptionsStart(builder): + builder.StartObject(2) + + +def RandomOptionsAddSeed(builder, seed): + builder.PrependInt64Slot(0, seed, 0) + + +def RandomOptionsAddSeed2(builder, seed2): + builder.PrependInt64Slot(1, seed2, 0) + + +def RandomOptionsEnd(builder): + return builder.EndObject() + + +class RandomOptionsT(object): + + # RandomOptionsT + def __init__(self): + self.seed = 0 # type: int + self.seed2 = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + randomOptions = RandomOptions() + randomOptions.Init(buf, pos) + return cls.InitFromObj(randomOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return 
# Generated read-only accessor for the BucketizeOptions table.
# Single field: boundaries (slot 4), a vector of float32.
class BucketizeOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BucketizeOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsBucketizeOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BucketizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # BucketizeOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # BucketizeOptions
    # Element access: j-th float32 of the boundaries vector (stride 4 bytes).
    def Boundaries(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(
                flatbuffers.number_types.Float32Flags,
                a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0

    # BucketizeOptions
    def BoundariesAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
        return 0

    # BucketizeOptions
    def BoundariesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # BucketizeOptions
    def BoundariesIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0


# Builder helpers for serializing a BucketizeOptions table.
def BucketizeOptionsStart(builder):
    builder.StartObject(1)


def BucketizeOptionsAddBoundaries(builder, boundaries):
    builder.PrependUOffsetTRelativeSlot(
        0, flatbuffers.number_types.UOffsetTFlags.py_type(boundaries), 0)


def BucketizeOptionsStartBoundariesVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def BucketizeOptionsEnd(builder):
    return builder.EndObject()


# Generated guard: typing is only needed for the comment-style annotations
# below, so its absence is tolerated.
try:
    from typing import List
except:
    pass


# Mutable object-API mirror of BucketizeOptions.
class BucketizeOptionsT(object):

    # BucketizeOptionsT
    def __init__(self):
        self.boundaries = None  # type: List[float]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        bucketizeOptions = BucketizeOptions()
        bucketizeOptions.Init(buf, pos)
        return cls.InitFromObj(bucketizeOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, bucketizeOptions):
        x = BucketizeOptionsT()
        x._UnPack(bucketizeOptions)
        return x

    # BucketizeOptionsT
    def _UnPack(self, bucketizeOptions):
        if bucketizeOptions is None:
            return
        if not bucketizeOptions.BoundariesIsNone():
            # np is the module-level numpy alias (guarded import earlier in
            # this file — presumably None when numpy is absent; verify there).
            if np is None:
                self.boundaries = []
                for i in range(bucketizeOptions.BoundariesLength()):
                    self.boundaries.append(bucketizeOptions.Boundaries(i))
            else:
                self.boundaries = bucketizeOptions.BoundariesAsNumpy()

    # BucketizeOptionsT
    def Pack(self, builder):
        if self.boundaries is not None:
            # Fast path: a numpy array is written in one call; otherwise the
            # vector is built element-by-element in reverse (flatbuffers
            # prepends).
            if np is not None and type(self.boundaries) is np.ndarray:
                boundaries = builder.CreateNumpyVector(self.boundaries)
            else:
                BucketizeOptionsStartBoundariesVector(builder, len(self.boundaries))
                for i in reversed(range(len(self.boundaries))):
                    builder.PrependFloat32(self.boundaries[i])
                boundaries = builder.EndVector()
        BucketizeOptionsStart(builder)
        if self.boundaries is not None:
            BucketizeOptionsAddBoundaries(builder, boundaries)
        bucketizeOptions = BucketizeOptionsEnd(builder)
        return bucketizeOptions
# Generated read-only accessor for the GeluOptions table.
# Single field: approximate (slot 4), bool defaulting to False.
class GeluOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = GeluOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsGeluOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GeluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # GeluOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # GeluOptions
    def Approximate(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False


# Builder helpers for serializing a GeluOptions table.
def GeluOptionsStart(builder):
    builder.StartObject(1)


def GeluOptionsAddApproximate(builder, approximate):
    builder.PrependBoolSlot(0, approximate, 0)


def GeluOptionsEnd(builder):
    return builder.EndObject()


# Mutable object-API mirror of GeluOptions.
class GeluOptionsT(object):

    # GeluOptionsT
    def __init__(self):
        self.approximate = False  # type: bool

    @classmethod
    def InitFromBuf(cls, buf, pos):
        geluOptions = GeluOptions()
        geluOptions.Init(buf, pos)
        return cls.InitFromObj(geluOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, geluOptions):
        x = GeluOptionsT()
        x._UnPack(geluOptions)
        return x

    # GeluOptionsT
    def _UnPack(self, geluOptions):
        if geluOptions is None:
            return
        self.approximate = geluOptions.Approximate()

    # GeluOptionsT
    def Pack(self, builder):
        GeluOptionsStart(builder)
        GeluOptionsAddApproximate(builder, self.approximate)
        geluOptions = GeluOptionsEnd(builder)
        return geluOptions


# Generated accessor for the zero-field DynamicUpdateSliceOptions table.
class DynamicUpdateSliceOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DynamicUpdateSliceOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDynamicUpdateSliceOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DynamicUpdateSliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # DynamicUpdateSliceOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def DynamicUpdateSliceOptionsStart(builder):
    builder.StartObject(0)


def DynamicUpdateSliceOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of DynamicUpdateSliceOptions; Pack() writes an empty table.
class DynamicUpdateSliceOptionsT(object):

    # DynamicUpdateSliceOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        dynamicUpdateSliceOptions = DynamicUpdateSliceOptions()
        dynamicUpdateSliceOptions.Init(buf, pos)
        return cls.InitFromObj(dynamicUpdateSliceOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, dynamicUpdateSliceOptions):
        x = DynamicUpdateSliceOptionsT()
        x._UnPack(dynamicUpdateSliceOptions)
        return x

    # DynamicUpdateSliceOptionsT
    def _UnPack(self, dynamicUpdateSliceOptions):
        if dynamicUpdateSliceOptions is None:
            return

    # DynamicUpdateSliceOptionsT
    def Pack(self, builder):
        DynamicUpdateSliceOptionsStart(builder)
        dynamicUpdateSliceOptions = DynamicUpdateSliceOptionsEnd(builder)
        return dynamicUpdateSliceOptions
# Generated accessor for the zero-field UnsortedSegmentProdOptions table.
class UnsortedSegmentProdOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = UnsortedSegmentProdOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUnsortedSegmentProdOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentProdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # UnsortedSegmentProdOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def UnsortedSegmentProdOptionsStart(builder):
    builder.StartObject(0)


def UnsortedSegmentProdOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of UnsortedSegmentProdOptions; Pack() writes an empty table.
class UnsortedSegmentProdOptionsT(object):

    # UnsortedSegmentProdOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        unsortedSegmentProdOptions = UnsortedSegmentProdOptions()
        unsortedSegmentProdOptions.Init(buf, pos)
        return cls.InitFromObj(unsortedSegmentProdOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, unsortedSegmentProdOptions):
        x = UnsortedSegmentProdOptionsT()
        x._UnPack(unsortedSegmentProdOptions)
        return x

    # UnsortedSegmentProdOptionsT
    def _UnPack(self, unsortedSegmentProdOptions):
        if unsortedSegmentProdOptions is None:
            return

    # UnsortedSegmentProdOptionsT
    def Pack(self, builder):
        UnsortedSegmentProdOptionsStart(builder)
        unsortedSegmentProdOptions = UnsortedSegmentProdOptionsEnd(builder)
        return unsortedSegmentProdOptions
# Generated accessor for the zero-field UnsortedSegmentMaxOptions table.
class UnsortedSegmentMaxOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = UnsortedSegmentMaxOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUnsortedSegmentMaxOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # UnsortedSegmentMaxOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def UnsortedSegmentMaxOptionsStart(builder):
    builder.StartObject(0)


def UnsortedSegmentMaxOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of UnsortedSegmentMaxOptions; Pack() writes an empty table.
class UnsortedSegmentMaxOptionsT(object):

    # UnsortedSegmentMaxOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        unsortedSegmentMaxOptions = UnsortedSegmentMaxOptions()
        unsortedSegmentMaxOptions.Init(buf, pos)
        return cls.InitFromObj(unsortedSegmentMaxOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, unsortedSegmentMaxOptions):
        x = UnsortedSegmentMaxOptionsT()
        x._UnPack(unsortedSegmentMaxOptions)
        return x

    # UnsortedSegmentMaxOptionsT
    def _UnPack(self, unsortedSegmentMaxOptions):
        if unsortedSegmentMaxOptions is None:
            return

    # UnsortedSegmentMaxOptionsT
    def Pack(self, builder):
        UnsortedSegmentMaxOptionsStart(builder)
        unsortedSegmentMaxOptions = UnsortedSegmentMaxOptionsEnd(builder)
        return unsortedSegmentMaxOptions
# Generated accessor for the zero-field UnsortedSegmentSumOptions table.
class UnsortedSegmentSumOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = UnsortedSegmentSumOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUnsortedSegmentSumOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # UnsortedSegmentSumOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def UnsortedSegmentSumOptionsStart(builder):
    builder.StartObject(0)


def UnsortedSegmentSumOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of UnsortedSegmentSumOptions; Pack() writes an empty table.
class UnsortedSegmentSumOptionsT(object):

    # UnsortedSegmentSumOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        unsortedSegmentSumOptions = UnsortedSegmentSumOptions()
        unsortedSegmentSumOptions.Init(buf, pos)
        return cls.InitFromObj(unsortedSegmentSumOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, unsortedSegmentSumOptions):
        x = UnsortedSegmentSumOptionsT()
        x._UnPack(unsortedSegmentSumOptions)
        return x

    # UnsortedSegmentSumOptionsT
    def _UnPack(self, unsortedSegmentSumOptions):
        if unsortedSegmentSumOptions is None:
            return

    # UnsortedSegmentSumOptionsT
    def Pack(self, builder):
        UnsortedSegmentSumOptionsStart(builder)
        unsortedSegmentSumOptions = UnsortedSegmentSumOptionsEnd(builder)
        return unsortedSegmentSumOptions
# Generated accessor for the zero-field ATan2Options table.
class ATan2Options(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = ATan2Options()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsATan2Options(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ATan2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # ATan2Options
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def ATan2OptionsStart(builder):
    builder.StartObject(0)


def ATan2OptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of ATan2Options; Pack() writes an empty table.
class ATan2OptionsT(object):

    # ATan2OptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        atan2Options = ATan2Options()
        atan2Options.Init(buf, pos)
        return cls.InitFromObj(atan2Options)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, atan2Options):
        x = ATan2OptionsT()
        x._UnPack(atan2Options)
        return x

    # ATan2OptionsT
    def _UnPack(self, atan2Options):
        if atan2Options is None:
            return

    # ATan2OptionsT
    def Pack(self, builder):
        ATan2OptionsStart(builder)
        atan2Options = ATan2OptionsEnd(builder)
        return atan2Options
# Generated accessor for the zero-field UnsortedSegmentMinOptions table.
class UnsortedSegmentMinOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = UnsortedSegmentMinOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUnsortedSegmentMinOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # UnsortedSegmentMinOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def UnsortedSegmentMinOptionsStart(builder):
    builder.StartObject(0)


def UnsortedSegmentMinOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of UnsortedSegmentMinOptions; Pack() writes an empty table.
class UnsortedSegmentMinOptionsT(object):

    # UnsortedSegmentMinOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        unsortedSegmentMinOptions = UnsortedSegmentMinOptions()
        unsortedSegmentMinOptions.Init(buf, pos)
        return cls.InitFromObj(unsortedSegmentMinOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, unsortedSegmentMinOptions):
        x = UnsortedSegmentMinOptionsT()
        x._UnPack(unsortedSegmentMinOptions)
        return x

    # UnsortedSegmentMinOptionsT
    def _UnPack(self, unsortedSegmentMinOptions):
        if unsortedSegmentMinOptions is None:
            return

    # UnsortedSegmentMinOptionsT
    def Pack(self, builder):
        UnsortedSegmentMinOptionsStart(builder)
        unsortedSegmentMinOptions = UnsortedSegmentMinOptionsEnd(builder)
        return unsortedSegmentMinOptions
# Generated accessor for the zero-field SignOptions table.
class SignOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = SignOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSignOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SignOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # SignOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def SignOptionsStart(builder):
    builder.StartObject(0)


def SignOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of SignOptions; Pack() writes an empty table.
class SignOptionsT(object):

    # SignOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        signOptions = SignOptions()
        signOptions.Init(buf, pos)
        return cls.InitFromObj(signOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, signOptions):
        x = SignOptionsT()
        x._UnPack(signOptions)
        return x

    # SignOptionsT
    def _UnPack(self, signOptions):
        if signOptions is None:
            return

    # SignOptionsT
    def Pack(self, builder):
        SignOptionsStart(builder)
        signOptions = SignOptionsEnd(builder)
        return signOptions
# Generated accessor for the zero-field BitcastOptions table.
class BitcastOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BitcastOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsBitcastOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BitcastOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # BitcastOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def BitcastOptionsStart(builder):
    builder.StartObject(0)


def BitcastOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of BitcastOptions; Pack() writes an empty table.
class BitcastOptionsT(object):

    # BitcastOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        bitcastOptions = BitcastOptions()
        bitcastOptions.Init(buf, pos)
        return cls.InitFromObj(bitcastOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, bitcastOptions):
        x = BitcastOptionsT()
        x._UnPack(bitcastOptions)
        return x

    # BitcastOptionsT
    def _UnPack(self, bitcastOptions):
        if bitcastOptions is None:
            return

    # BitcastOptionsT
    def Pack(self, builder):
        BitcastOptionsStart(builder)
        bitcastOptions = BitcastOptionsEnd(builder)
        return bitcastOptions
# Generated accessor for the zero-field BitwiseXorOptions table.
class BitwiseXorOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BitwiseXorOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsBitwiseXorOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BitwiseXorOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # BitwiseXorOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def BitwiseXorOptionsStart(builder):
    builder.StartObject(0)


def BitwiseXorOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of BitwiseXorOptions; Pack() writes an empty table.
class BitwiseXorOptionsT(object):

    # BitwiseXorOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        bitwiseXorOptions = BitwiseXorOptions()
        bitwiseXorOptions.Init(buf, pos)
        return cls.InitFromObj(bitwiseXorOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, bitwiseXorOptions):
        x = BitwiseXorOptionsT()
        x._UnPack(bitwiseXorOptions)
        return x

    # BitwiseXorOptionsT
    def _UnPack(self, bitwiseXorOptions):
        if bitwiseXorOptions is None:
            return

    # BitwiseXorOptionsT
    def Pack(self, builder):
        BitwiseXorOptionsStart(builder)
        bitwiseXorOptions = BitwiseXorOptionsEnd(builder)
        return bitwiseXorOptions
# Generated accessor for the zero-field RightShiftOptions table.
class RightShiftOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = RightShiftOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsRightShiftOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RightShiftOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # RightShiftOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def RightShiftOptionsStart(builder):
    builder.StartObject(0)


def RightShiftOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of RightShiftOptions; Pack() writes an empty table.
class RightShiftOptionsT(object):

    # RightShiftOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        rightShiftOptions = RightShiftOptions()
        rightShiftOptions.Init(buf, pos)
        return cls.InitFromObj(rightShiftOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, rightShiftOptions):
        x = RightShiftOptionsT()
        x._UnPack(rightShiftOptions)
        return x

    # RightShiftOptionsT
    def _UnPack(self, rightShiftOptions):
        if rightShiftOptions is None:
            return

    # RightShiftOptionsT
    def Pack(self, builder):
        RightShiftOptionsStart(builder)
        rightShiftOptions = RightShiftOptionsEnd(builder)
        return rightShiftOptions
# Generated accessor for the zero-field DilateOptions table.
class DilateOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DilateOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDilateOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DilateOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # DilateOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)


def DilateOptionsStart(builder):
    builder.StartObject(0)


def DilateOptionsEnd(builder):
    return builder.EndObject()


# Object-API mirror of DilateOptions; Pack() writes an empty table.
class DilateOptionsT(object):

    # DilateOptionsT
    def __init__(self):
        pass

    @classmethod
    def InitFromBuf(cls, buf, pos):
        dilateOptions = DilateOptions()
        dilateOptions.Init(buf, pos)
        return cls.InitFromObj(dilateOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, dilateOptions):
        x = DilateOptionsT()
        x._UnPack(dilateOptions)
        return x

    # DilateOptionsT
    def _UnPack(self, dilateOptions):
        if dilateOptions is None:
            return

    # DilateOptionsT
    def Pack(self, builder):
        DilateOptionsStart(builder)
        dilateOptions = DilateOptionsEnd(builder)
        return dilateOptions
# Generated read-only accessor for the ReduceWindowOptions table.
# Single field: reduceFunction (slot 4), int32 defaulting to 0.
class ReduceWindowOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = ReduceWindowOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsReduceWindowOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ReduceWindowOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # ReduceWindowOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # ReduceWindowOptions
    # NOTE(review): value presumably indexes a ReduceWindowFunction-style enum
    # defined elsewhere in this file — confirm against the schema.
    def ReduceFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0


# Builder helpers for serializing a ReduceWindowOptions table.
def ReduceWindowOptionsStart(builder):
    builder.StartObject(1)


def ReduceWindowOptionsAddReduceFunction(builder, reduceFunction):
    builder.PrependInt32Slot(0, reduceFunction, 0)


def ReduceWindowOptionsEnd(builder):
    return builder.EndObject()


# Mutable object-API mirror of ReduceWindowOptions.
class ReduceWindowOptionsT(object):

    # ReduceWindowOptionsT
    def __init__(self):
        self.reduceFunction = 0  # type: int

    @classmethod
    def InitFromBuf(cls, buf, pos):
        reduceWindowOptions = ReduceWindowOptions()
        reduceWindowOptions.Init(buf, pos)
        return cls.InitFromObj(reduceWindowOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, reduceWindowOptions):
        x = ReduceWindowOptionsT()
        x._UnPack(reduceWindowOptions)
        return x

    # ReduceWindowOptionsT
    def _UnPack(self, reduceWindowOptions):
        if reduceWindowOptions is None:
            return
        self.reduceFunction = reduceWindowOptions.ReduceFunction()

    # ReduceWindowOptionsT
    def Pack(self, builder):
        ReduceWindowOptionsStart(builder)
        ReduceWindowOptionsAddReduceFunction(builder, self.reduceFunction)
        reduceWindowOptions = ReduceWindowOptionsEnd(builder)
        return reduceWindowOptions
# Generated read-only accessor for the GRUOptions table (Circle extension op).
# Fields: fusedActivationFunction (slot 4, int8), returnSequences (slot 6,
# bool), timeMajor (slot 8, bool).
class GRUOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = GRUOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsGRUOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GRUOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # GRUOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # GRUOptions
    def FusedActivationFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # GRUOptions
    def ReturnSequences(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False

    # GRUOptions
    def TimeMajor(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return bool(
                self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False


# Builder helpers for serializing a GRUOptions table.
def GRUOptionsStart(builder):
    builder.StartObject(3)


def GRUOptionsAddFusedActivationFunction(builder, fusedActivationFunction):
    builder.PrependInt8Slot(0, fusedActivationFunction, 0)


def GRUOptionsAddReturnSequences(builder, returnSequences):
    builder.PrependBoolSlot(1, returnSequences, 0)


def GRUOptionsAddTimeMajor(builder, timeMajor):
    builder.PrependBoolSlot(2, timeMajor, 0)


def GRUOptionsEnd(builder):
    return builder.EndObject()


# Mutable object-API mirror of GRUOptions.
class GRUOptionsT(object):

    # GRUOptionsT
    def __init__(self):
        self.fusedActivationFunction = 0  # type: int
        self.returnSequences = False  # type: bool
        self.timeMajor = False  # type: bool

    @classmethod
    def InitFromBuf(cls, buf, pos):
        gruoptions = GRUOptions()
        gruoptions.Init(buf, pos)
        return cls.InitFromObj(gruoptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, gruoptions):
        x = GRUOptionsT()
        x._UnPack(gruoptions)
        return x

    # GRUOptionsT
    def _UnPack(self, gruoptions):
        if gruoptions is None:
            return
        self.fusedActivationFunction = gruoptions.FusedActivationFunction()
        self.returnSequences = gruoptions.ReturnSequences()
        self.timeMajor = gruoptions.TimeMajor()

    # GRUOptionsT
    def Pack(self, builder):
        GRUOptionsStart(builder)
        GRUOptionsAddFusedActivationFunction(builder, self.fusedActivationFunction)
        GRUOptionsAddReturnSequences(builder, self.returnSequences)
        GRUOptionsAddTimeMajor(builder, self.timeMajor)
        gruoptions = GRUOptionsEnd(builder)
        return gruoptions
# Generated read-only accessor for the BCQGatherOptions table (Circle BCQ op).
# Fields: inputHiddenSize (slot 4, int32), axis (slot 6, int32).
class BCQGatherOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BCQGatherOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsBCQGatherOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BCQGatherOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # BCQGatherOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # BCQGatherOptions
    def InputHiddenSize(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # BCQGatherOptions
    def Axis(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0


# Builder helpers for serializing a BCQGatherOptions table.
def BCQGatherOptionsStart(builder):
    builder.StartObject(2)


def BCQGatherOptionsAddInputHiddenSize(builder, inputHiddenSize):
    builder.PrependInt32Slot(0, inputHiddenSize, 0)


def BCQGatherOptionsAddAxis(builder, axis):
    builder.PrependInt32Slot(1, axis, 0)


def BCQGatherOptionsEnd(builder):
    return builder.EndObject()


# Mutable object-API mirror of BCQGatherOptions.
class BCQGatherOptionsT(object):

    # BCQGatherOptionsT
    def __init__(self):
        self.inputHiddenSize = 0  # type: int
        self.axis = 0  # type: int

    @classmethod
    def InitFromBuf(cls, buf, pos):
        bcqgatherOptions = BCQGatherOptions()
        bcqgatherOptions.Init(buf, pos)
        return cls.InitFromObj(bcqgatherOptions)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, bcqgatherOptions):
        x = BCQGatherOptionsT()
        x._UnPack(bcqgatherOptions)
        return x

    # BCQGatherOptionsT
    def _UnPack(self, bcqgatherOptions):
        if bcqgatherOptions is None:
            return
        self.inputHiddenSize = bcqgatherOptions.InputHiddenSize()
        self.axis = bcqgatherOptions.Axis()

    # BCQGatherOptionsT
    def Pack(self, builder):
        BCQGatherOptionsStart(builder)
        BCQGatherOptionsAddInputHiddenSize(builder, self.inputHiddenSize)
        BCQGatherOptionsAddAxis(builder, self.axis)
        bcqgatherOptions = BCQGatherOptionsEnd(builder)
        return bcqgatherOptions
# Generated read-only accessor for the BCQFullyConnectedOptions table
# (Circle BCQ op). Fields: weightsHiddenSize (slot 4, int32),
# fusedActivationFunction (slot 6, int8).
class BCQFullyConnectedOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = BCQFullyConnectedOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsBCQFullyConnectedOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BCQFullyConnectedOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        return flatbuffers.util.BufferHasIdentifier(buf,
                                                    offset,
                                                    b"\x43\x49\x52\x30",
                                                    size_prefixed=size_prefixed)

    # BCQFullyConnectedOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # BCQFullyConnectedOptions
    def WeightsHiddenSize(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # BCQFullyConnectedOptions
    def FusedActivationFunction(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0


# Builder helpers for serializing a BCQFullyConnectedOptions table.
def BCQFullyConnectedOptionsStart(builder):
    builder.StartObject(2)


def BCQFullyConnectedOptionsAddWeightsHiddenSize(builder, weightsHiddenSize):
    builder.PrependInt32Slot(0, weightsHiddenSize, 0)


def BCQFullyConnectedOptionsAddFusedActivationFunction(builder, fusedActivationFunction):
    builder.PrependInt8Slot(1, fusedActivationFunction, 0)


def BCQFullyConnectedOptionsEnd(builder):
    return builder.EndObject()
BCQFullyConnectedOptionsT(object): + + # BCQFullyConnectedOptionsT + def __init__(self): + self.weightsHiddenSize = 0 # type: int + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + bcqfullyConnectedOptions = BCQFullyConnectedOptions() + bcqfullyConnectedOptions.Init(buf, pos) + return cls.InitFromObj(bcqfullyConnectedOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, bcqfullyConnectedOptions): + x = BCQFullyConnectedOptionsT() + x._UnPack(bcqfullyConnectedOptions) + return x + + # BCQFullyConnectedOptionsT + def _UnPack(self, bcqfullyConnectedOptions): + if bcqfullyConnectedOptions is None: + return + self.weightsHiddenSize = bcqfullyConnectedOptions.WeightsHiddenSize() + self.fusedActivationFunction = bcqfullyConnectedOptions.FusedActivationFunction() + + # BCQFullyConnectedOptionsT + def Pack(self, builder): + BCQFullyConnectedOptionsStart(builder) + BCQFullyConnectedOptionsAddWeightsHiddenSize(builder, self.weightsHiddenSize) + BCQFullyConnectedOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + bcqfullyConnectedOptions = BCQFullyConnectedOptionsEnd(builder) + return bcqfullyConnectedOptions + + +class InstanceNormOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = InstanceNormOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsInstanceNormOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def InstanceNormOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # InstanceNormOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # InstanceNormOptions + def Epsilon(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # InstanceNormOptions + def FusedActivationFunction(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + +def InstanceNormOptionsStart(builder): + builder.StartObject(2) + + +def InstanceNormOptionsAddEpsilon(builder, epsilon): + builder.PrependFloat32Slot(0, epsilon, 0.0) + + +def InstanceNormOptionsAddFusedActivationFunction(builder, fusedActivationFunction): + builder.PrependInt8Slot(1, fusedActivationFunction, 0) + + +def InstanceNormOptionsEnd(builder): + return builder.EndObject() + + +class InstanceNormOptionsT(object): + + # InstanceNormOptionsT + def __init__(self): + self.epsilon = 0.0 # type: float + self.fusedActivationFunction = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + instanceNormOptions = InstanceNormOptions() + instanceNormOptions.Init(buf, pos) + return cls.InitFromObj(instanceNormOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, instanceNormOptions): + x = InstanceNormOptionsT() + x._UnPack(instanceNormOptions) + return x + + # InstanceNormOptionsT + def _UnPack(self, instanceNormOptions): + if instanceNormOptions is None: 
+ return + self.epsilon = instanceNormOptions.Epsilon() + self.fusedActivationFunction = instanceNormOptions.FusedActivationFunction() + + # InstanceNormOptionsT + def Pack(self, builder): + InstanceNormOptionsStart(builder) + InstanceNormOptionsAddEpsilon(builder, self.epsilon) + InstanceNormOptionsAddFusedActivationFunction(builder, + self.fusedActivationFunction) + instanceNormOptions = InstanceNormOptionsEnd(builder) + return instanceNormOptions + + +class RmsNormOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RmsNormOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRmsNormOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RmsNormOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RmsNormOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RmsNormOptions + def Epsilon(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + +def RmsNormOptionsStart(builder): + builder.StartObject(1) + + +def RmsNormOptionsAddEpsilon(builder, epsilon): + builder.PrependFloat32Slot(0, epsilon, 0.0) + + +def RmsNormOptionsEnd(builder): + return builder.EndObject() + + +class RmsNormOptionsT(object): + + # RmsNormOptionsT + def __init__(self): + self.epsilon = 0.0 # type: float + + @classmethod + def InitFromBuf(cls, buf, pos): + rmsNormOptions = RmsNormOptions() + rmsNormOptions.Init(buf, pos) + return cls.InitFromObj(rmsNormOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, rmsNormOptions): + x = RmsNormOptionsT() + x._UnPack(rmsNormOptions) + return x + + # RmsNormOptionsT + def _UnPack(self, rmsNormOptions): + if rmsNormOptions is None: + return + self.epsilon = rmsNormOptions.Epsilon() + + # RmsNormOptionsT + def Pack(self, builder): + RmsNormOptionsStart(builder) + RmsNormOptionsAddEpsilon(builder, self.epsilon) + rmsNormOptions = RmsNormOptionsEnd(builder) + return rmsNormOptions + + +class RoPEOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RoPEOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRoPEOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RoPEOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RoPEOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RoPEOptions + def Mode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def RoPEOptionsStart(builder): + builder.StartObject(1) + + +def RoPEOptionsAddMode(builder, mode): + builder.PrependInt32Slot(0, mode, 0) + + +def RoPEOptionsEnd(builder): + return builder.EndObject() + + +class RoPEOptionsT(object): + + # RoPEOptionsT + def __init__(self): + self.mode = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + roPeoptions = RoPEOptions() + roPeoptions.Init(buf, pos) + return cls.InitFromObj(roPeoptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, roPeoptions): + x = RoPEOptionsT() + x._UnPack(roPeoptions) + return x + + # RoPEOptionsT + def _UnPack(self, roPeoptions): + if roPeoptions is None: + return + self.mode = roPeoptions.Mode() + + # RoPEOptionsT + def Pack(self, builder): + RoPEOptionsStart(builder) + RoPEOptionsAddMode(builder, self.mode) + roPeoptions = RoPEOptionsEnd(builder) + return roPeoptions + + +class RunModelOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RunModelOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRunModelOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def RunModelOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # RunModelOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RunModelOptions + def Location(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # RunModelOptions + def Signature(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + +def RunModelOptionsStart(builder): + builder.StartObject(2) + + +def RunModelOptionsAddLocation(builder, location): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(location), 0) + + +def RunModelOptionsAddSignature(builder, signature): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(signature), 
0) + + +def RunModelOptionsEnd(builder): + return builder.EndObject() + + +class RunModelOptionsT(object): + + # RunModelOptionsT + def __init__(self): + self.location = None # type: str + self.signature = None # type: str + + @classmethod + def InitFromBuf(cls, buf, pos): + runModelOptions = RunModelOptions() + runModelOptions.Init(buf, pos) + return cls.InitFromObj(runModelOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, runModelOptions): + x = RunModelOptionsT() + x._UnPack(runModelOptions) + return x + + # RunModelOptionsT + def _UnPack(self, runModelOptions): + if runModelOptions is None: + return + self.location = runModelOptions.Location() + self.signature = runModelOptions.Signature() + + # RunModelOptionsT + def Pack(self, builder): + if self.location is not None: + location = builder.CreateString(self.location) + if self.signature is not None: + signature = builder.CreateString(self.signature) + RunModelOptionsStart(builder) + if self.location is not None: + RunModelOptionsAddLocation(builder, location) + if self.signature is not None: + RunModelOptionsAddSignature(builder, signature) + runModelOptions = RunModelOptionsEnd(builder) + return runModelOptions + + +class AttentionOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AttentionOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAttentionOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def AttentionOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # AttentionOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # AttentionOptions + def LayerIdx(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def AttentionOptionsStart(builder): + builder.StartObject(1) + + +def AttentionOptionsAddLayerIdx(builder, layerIdx): + builder.PrependInt32Slot(0, layerIdx, 0) + + +def AttentionOptionsEnd(builder): + return builder.EndObject() + + +class AttentionOptionsT(object): + + # AttentionOptionsT + def __init__(self): + self.layerIdx = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + attentionOptions = AttentionOptions() + attentionOptions.Init(buf, pos) + return cls.InitFromObj(attentionOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, attentionOptions): + x = AttentionOptionsT() + x._UnPack(attentionOptions) + return x + + # AttentionOptionsT + def _UnPack(self, attentionOptions): + if attentionOptions is None: + return + self.layerIdx = attentionOptions.LayerIdx() + + # AttentionOptionsT + def Pack(self, builder): + AttentionOptionsStart(builder) + AttentionOptionsAddLayerIdx(builder, self.layerIdx) + attentionOptions = AttentionOptionsEnd(builder) + return attentionOptions + + +class OperatorCode(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = OperatorCode() + x.Init(buf, n + 
offset) + return x + + @classmethod + def GetRootAsOperatorCode(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def OperatorCodeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # OperatorCode + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # OperatorCode + def DeprecatedBuiltinCode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # OperatorCode + def CustomCode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # OperatorCode + def Version(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # OperatorCode + def BuiltinCode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def OperatorCodeStart(builder): + builder.StartObject(4) + + +def OperatorCodeAddDeprecatedBuiltinCode(builder, deprecatedBuiltinCode): + builder.PrependInt8Slot(0, deprecatedBuiltinCode, 0) + + +def OperatorCodeAddCustomCode(builder, customCode): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(customCode), 0) + + +def OperatorCodeAddVersion(builder, version): + builder.PrependInt32Slot(2, version, 1) + + +def OperatorCodeAddBuiltinCode(builder, builtinCode): + builder.PrependInt32Slot(3, builtinCode, 0) + + +def OperatorCodeEnd(builder): + return builder.EndObject() + + +class 
OperatorCodeT(object): + + # OperatorCodeT + def __init__(self): + self.deprecatedBuiltinCode = 0 # type: int + self.customCode = None # type: str + self.version = 1 # type: int + self.builtinCode = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + operatorCode = OperatorCode() + operatorCode.Init(buf, pos) + return cls.InitFromObj(operatorCode) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, operatorCode): + x = OperatorCodeT() + x._UnPack(operatorCode) + return x + + # OperatorCodeT + def _UnPack(self, operatorCode): + if operatorCode is None: + return + self.deprecatedBuiltinCode = operatorCode.DeprecatedBuiltinCode() + self.customCode = operatorCode.CustomCode() + self.version = operatorCode.Version() + self.builtinCode = operatorCode.BuiltinCode() + + # OperatorCodeT + def Pack(self, builder): + if self.customCode is not None: + customCode = builder.CreateString(self.customCode) + OperatorCodeStart(builder) + OperatorCodeAddDeprecatedBuiltinCode(builder, self.deprecatedBuiltinCode) + if self.customCode is not None: + OperatorCodeAddCustomCode(builder, customCode) + OperatorCodeAddVersion(builder, self.version) + OperatorCodeAddBuiltinCode(builder, self.builtinCode) + operatorCode = OperatorCodeEnd(builder) + return operatorCode + + +class Operator(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Operator() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsOperator(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def OperatorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Operator + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Operator + def OpcodeIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # Operator + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Operator + def InputsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Operator + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Operator + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # Operator + def Outputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Operator + def OutputsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Operator + def OutputsLength(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Operator + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # Operator + def BuiltinOptionsType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # Operator + def BuiltinOptions(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + # Operator + def CustomOptions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # Operator + def CustomOptionsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # Operator + def CustomOptionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Operator + def CustomOptionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + return o == 0 + + # Operator + def CustomOptionsFormat(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # Operator + def MutatingVariableInputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + 
flatbuffers.number_types.BoolFlags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # Operator + def MutatingVariableInputsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.BoolFlags, o) + return 0 + + # Operator + def MutatingVariableInputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Operator + def MutatingVariableInputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + return o == 0 + + # Operator + def Intermediates(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Operator + def IntermediatesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Operator + def IntermediatesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Operator + def IntermediatesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + return o == 0 + + # Operator + def LargeCustomOptionsOffset(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) + return 0 + + # Operator + def LargeCustomOptionsSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) + return 0 + + # Operator + def 
BuiltinOptions2Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # Operator + def BuiltinOptions2(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + +def OperatorStart(builder): + builder.StartObject(13) + + +def OperatorAddOpcodeIndex(builder, opcodeIndex): + builder.PrependUint32Slot(0, opcodeIndex, 0) + + +def OperatorAddInputs(builder, inputs): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) + + +def OperatorStartInputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def OperatorAddOutputs(builder, outputs): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) + + +def OperatorStartOutputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def OperatorAddBuiltinOptionsType(builder, builtinOptionsType): + builder.PrependUint8Slot(3, builtinOptionsType, 0) + + +def OperatorAddBuiltinOptions(builder, builtinOptions): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(builtinOptions), 0) + + +def OperatorAddCustomOptions(builder, customOptions): + builder.PrependUOffsetTRelativeSlot( + 5, flatbuffers.number_types.UOffsetTFlags.py_type(customOptions), 0) + + +def OperatorStartCustomOptionsVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def OperatorAddCustomOptionsFormat(builder, customOptionsFormat): + builder.PrependInt8Slot(6, customOptionsFormat, 0) + + +def OperatorAddMutatingVariableInputs(builder, mutatingVariableInputs): + builder.PrependUOffsetTRelativeSlot( + 7, 
flatbuffers.number_types.UOffsetTFlags.py_type(mutatingVariableInputs), 0) + + +def OperatorStartMutatingVariableInputsVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def OperatorAddIntermediates(builder, intermediates): + builder.PrependUOffsetTRelativeSlot( + 8, flatbuffers.number_types.UOffsetTFlags.py_type(intermediates), 0) + + +def OperatorStartIntermediatesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def OperatorAddLargeCustomOptionsOffset(builder, largeCustomOptionsOffset): + builder.PrependUint64Slot(9, largeCustomOptionsOffset, 0) + + +def OperatorAddLargeCustomOptionsSize(builder, largeCustomOptionsSize): + builder.PrependUint64Slot(10, largeCustomOptionsSize, 0) + + +def OperatorAddBuiltinOptions2Type(builder, builtinOptions2Type): + builder.PrependUint8Slot(11, builtinOptions2Type, 0) + + +def OperatorAddBuiltinOptions2(builder, builtinOptions2): + builder.PrependUOffsetTRelativeSlot( + 12, flatbuffers.number_types.UOffsetTFlags.py_type(builtinOptions2), 0) + + +def OperatorEnd(builder): + return builder.EndObject() + + +try: + from typing import List, Union +except: + pass + + +class OperatorT(object): + + # OperatorT + def __init__(self): + self.opcodeIndex = 0 # type: int + self.inputs = None # type: List[int] + self.outputs = None # type: List[int] + self.builtinOptionsType = 0 # type: int + self.builtinOptions = None # type: Union[None, Conv2DOptionsT, DepthwiseConv2DOptionsT, ConcatEmbeddingsOptionsT, LSHProjectionOptionsT, Pool2DOptionsT, SVDFOptionsT, RNNOptionsT, FullyConnectedOptionsT, SoftmaxOptionsT, ConcatenationOptionsT, AddOptionsT, L2NormOptionsT, LocalResponseNormalizationOptionsT, LSTMOptionsT, ResizeBilinearOptionsT, CallOptionsT, ReshapeOptionsT, SkipGramOptionsT, SpaceToDepthOptionsT, EmbeddingLookupSparseOptionsT, MulOptionsT, PadOptionsT, GatherOptionsT, BatchToSpaceNDOptionsT, SpaceToBatchNDOptionsT, TransposeOptionsT, ReducerOptionsT, SubOptionsT, DivOptionsT, 
SqueezeOptionsT, SequenceRNNOptionsT, StridedSliceOptionsT, ExpOptionsT, TopKV2OptionsT, SplitOptionsT, LogSoftmaxOptionsT, CastOptionsT, DequantizeOptionsT, MaximumMinimumOptionsT, ArgMaxOptionsT, LessOptionsT, NegOptionsT, PadV2OptionsT, GreaterOptionsT, GreaterEqualOptionsT, LessEqualOptionsT, SelectOptionsT, SliceOptionsT, TransposeConvOptionsT, SparseToDenseOptionsT, TileOptionsT, ExpandDimsOptionsT, EqualOptionsT, NotEqualOptionsT, ShapeOptionsT, PowOptionsT, ArgMinOptionsT, FakeQuantOptionsT, PackOptionsT, LogicalOrOptionsT, OneHotOptionsT, LogicalAndOptionsT, LogicalNotOptionsT, UnpackOptionsT, FloorDivOptionsT, SquareOptionsT, ZerosLikeOptionsT, FillOptionsT, BidirectionalSequenceLSTMOptionsT, BidirectionalSequenceRNNOptionsT, UnidirectionalSequenceLSTMOptionsT, FloorModOptionsT, RangeOptionsT, ResizeNearestNeighborOptionsT, LeakyReluOptionsT, SquaredDifferenceOptionsT, MirrorPadOptionsT, AbsOptionsT, SplitVOptionsT, UniqueOptionsT, ReverseV2OptionsT, AddNOptionsT, GatherNdOptionsT, CosOptionsT, WhereOptionsT, RankOptionsT, ReverseSequenceOptionsT, MatrixDiagOptionsT, QuantizeOptionsT, MatrixSetDiagOptionsT, HardSwishOptionsT, IfOptionsT, WhileOptionsT, DepthToSpaceOptionsT, NonMaxSuppressionV4OptionsT, NonMaxSuppressionV5OptionsT, ScatterNdOptionsT, SelectV2OptionsT, DensifyOptionsT, SegmentSumOptionsT, BatchMatMulOptionsT, CumsumOptionsT, CallOnceOptionsT, BroadcastToOptionsT, Rfft2dOptionsT, Conv3DOptionsT, HashtableOptionsT, HashtableFindOptionsT, HashtableImportOptionsT, HashtableSizeOptionsT, VarHandleOptionsT, ReadVariableOptionsT, AssignVariableOptionsT, RandomOptionsT, BucketizeOptionsT, GeluOptionsT, DynamicUpdateSliceOptionsT, UnsortedSegmentProdOptionsT, UnsortedSegmentMaxOptionsT, UnsortedSegmentMinOptionsT, UnsortedSegmentSumOptionsT, ATan2OptionsT, SignOptionsT, BitcastOptionsT, BitwiseXorOptionsT, RightShiftOptionsT, AttentionOptionsT, RunModelOptionsT, RoPEOptionsT, RmsNormOptionsT, GRUOptionsT, BCQGatherOptionsT, 
BCQFullyConnectedOptionsT, InstanceNormOptionsT] + self.customOptions = None # type: List[int] + self.customOptionsFormat = 0 # type: int + self.mutatingVariableInputs = None # type: List[bool] + self.intermediates = None # type: List[int] + self.largeCustomOptionsOffset = 0 # type: int + self.largeCustomOptionsSize = 0 # type: int + self.builtinOptions2Type = 0 # type: int + self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT, ReduceWindowOptionsT] + + @classmethod + def InitFromBuf(cls, buf, pos): + operator = Operator() + operator.Init(buf, pos) + return cls.InitFromObj(operator) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, operator): + x = OperatorT() + x._UnPack(operator) + return x + + # OperatorT + def _UnPack(self, operator): + if operator is None: + return + self.opcodeIndex = operator.OpcodeIndex() + if not operator.InputsIsNone(): + if np is None: + self.inputs = [] + for i in range(operator.InputsLength()): + self.inputs.append(operator.Inputs(i)) + else: + self.inputs = operator.InputsAsNumpy() + if not operator.OutputsIsNone(): + if np is None: + self.outputs = [] + for i in range(operator.OutputsLength()): + self.outputs.append(operator.Outputs(i)) + else: + self.outputs = operator.OutputsAsNumpy() + self.builtinOptionsType = operator.BuiltinOptionsType() + self.builtinOptions = 
BuiltinOptionsCreator(self.builtinOptionsType, + operator.BuiltinOptions()) + if not operator.CustomOptionsIsNone(): + if np is None: + self.customOptions = [] + for i in range(operator.CustomOptionsLength()): + self.customOptions.append(operator.CustomOptions(i)) + else: + self.customOptions = operator.CustomOptionsAsNumpy() + self.customOptionsFormat = operator.CustomOptionsFormat() + if not operator.MutatingVariableInputsIsNone(): + if np is None: + self.mutatingVariableInputs = [] + for i in range(operator.MutatingVariableInputsLength()): + self.mutatingVariableInputs.append(operator.MutatingVariableInputs(i)) + else: + self.mutatingVariableInputs = operator.MutatingVariableInputsAsNumpy() + if not operator.IntermediatesIsNone(): + if np is None: + self.intermediates = [] + for i in range(operator.IntermediatesLength()): + self.intermediates.append(operator.Intermediates(i)) + else: + self.intermediates = operator.IntermediatesAsNumpy() + self.largeCustomOptionsOffset = operator.LargeCustomOptionsOffset() + self.largeCustomOptionsSize = operator.LargeCustomOptionsSize() + self.builtinOptions2Type = operator.BuiltinOptions2Type() + self.builtinOptions2 = BuiltinOptions2Creator(self.builtinOptions2Type, + operator.BuiltinOptions2()) + + # OperatorT + def Pack(self, builder): + if self.inputs is not None: + if np is not None and type(self.inputs) is np.ndarray: + inputs = builder.CreateNumpyVector(self.inputs) + else: + OperatorStartInputsVector(builder, len(self.inputs)) + for i in reversed(range(len(self.inputs))): + builder.PrependInt32(self.inputs[i]) + inputs = builder.EndVector() + if self.outputs is not None: + if np is not None and type(self.outputs) is np.ndarray: + outputs = builder.CreateNumpyVector(self.outputs) + else: + OperatorStartOutputsVector(builder, len(self.outputs)) + for i in reversed(range(len(self.outputs))): + builder.PrependInt32(self.outputs[i]) + outputs = builder.EndVector() + if self.builtinOptions is not None: + builtinOptions = 
self.builtinOptions.Pack(builder) + if self.customOptions is not None: + if np is not None and type(self.customOptions) is np.ndarray: + customOptions = builder.CreateNumpyVector(self.customOptions) + else: + OperatorStartCustomOptionsVector(builder, len(self.customOptions)) + for i in reversed(range(len(self.customOptions))): + builder.PrependUint8(self.customOptions[i]) + customOptions = builder.EndVector() + if self.mutatingVariableInputs is not None: + if np is not None and type(self.mutatingVariableInputs) is np.ndarray: + mutatingVariableInputs = builder.CreateNumpyVector( + self.mutatingVariableInputs) + else: + OperatorStartMutatingVariableInputsVector( + builder, len(self.mutatingVariableInputs)) + for i in reversed(range(len(self.mutatingVariableInputs))): + builder.PrependBool(self.mutatingVariableInputs[i]) + mutatingVariableInputs = builder.EndVector() + if self.intermediates is not None: + if np is not None and type(self.intermediates) is np.ndarray: + intermediates = builder.CreateNumpyVector(self.intermediates) + else: + OperatorStartIntermediatesVector(builder, len(self.intermediates)) + for i in reversed(range(len(self.intermediates))): + builder.PrependInt32(self.intermediates[i]) + intermediates = builder.EndVector() + if self.builtinOptions2 is not None: + builtinOptions2 = self.builtinOptions2.Pack(builder) + OperatorStart(builder) + OperatorAddOpcodeIndex(builder, self.opcodeIndex) + if self.inputs is not None: + OperatorAddInputs(builder, inputs) + if self.outputs is not None: + OperatorAddOutputs(builder, outputs) + OperatorAddBuiltinOptionsType(builder, self.builtinOptionsType) + if self.builtinOptions is not None: + OperatorAddBuiltinOptions(builder, builtinOptions) + if self.customOptions is not None: + OperatorAddCustomOptions(builder, customOptions) + OperatorAddCustomOptionsFormat(builder, self.customOptionsFormat) + if self.mutatingVariableInputs is not None: + OperatorAddMutatingVariableInputs(builder, mutatingVariableInputs) + if 
self.intermediates is not None: + OperatorAddIntermediates(builder, intermediates) + OperatorAddLargeCustomOptionsOffset(builder, self.largeCustomOptionsOffset) + OperatorAddLargeCustomOptionsSize(builder, self.largeCustomOptionsSize) + OperatorAddBuiltinOptions2Type(builder, self.builtinOptions2Type) + if self.builtinOptions2 is not None: + OperatorAddBuiltinOptions2(builder, builtinOptions2) + operator = OperatorEnd(builder) + return operator + + +class SubGraph(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SubGraph() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSubGraph(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SubGraphBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SubGraph + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SubGraph + def Tensors(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = Tensor() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # SubGraph + def TensorsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SubGraph + def TensorsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # SubGraph + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + 
flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SubGraph + def InputsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SubGraph + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SubGraph + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # SubGraph + def Outputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SubGraph + def OutputsAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SubGraph + def OutputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SubGraph + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # SubGraph + def Operators(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = Operator() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # SubGraph + def OperatorsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SubGraph + def OperatorsIsNone(self): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # SubGraph + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + +def SubGraphStart(builder): + builder.StartObject(6) + + +def SubGraphAddTensors(builder, tensors): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0) + + +def SubGraphStartTensorsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SubGraphAddInputs(builder, inputs): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) + + +def SubGraphStartInputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SubGraphAddOutputs(builder, outputs): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) + + +def SubGraphStartOutputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SubGraphAddOperators(builder, operators): + builder.PrependUOffsetTRelativeSlot( + 3, flatbuffers.number_types.UOffsetTFlags.py_type(operators), 0) + + +def SubGraphStartOperatorsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SubGraphAddName(builder, name): + builder.PrependUOffsetTRelativeSlot( + 4, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + + +def SubGraphEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class SubGraphT(object): + + # SubGraphT + def __init__(self): + self.tensors = None # type: List[TensorT] + self.inputs = None # type: List[int] + self.outputs = None # type: List[int] + self.operators = None # type: List[OperatorT] + self.name = None # type: str + + @classmethod + def InitFromBuf(cls, buf, pos): + subGraph = SubGraph() + subGraph.Init(buf, pos) + return cls.InitFromObj(subGraph) + + 
@classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, subGraph): + x = SubGraphT() + x._UnPack(subGraph) + return x + + # SubGraphT + def _UnPack(self, subGraph): + if subGraph is None: + return + if not subGraph.TensorsIsNone(): + self.tensors = [] + for i in range(subGraph.TensorsLength()): + if subGraph.Tensors(i) is None: + self.tensors.append(None) + else: + tensor_ = TensorT.InitFromObj(subGraph.Tensors(i)) + self.tensors.append(tensor_) + if not subGraph.InputsIsNone(): + if np is None: + self.inputs = [] + for i in range(subGraph.InputsLength()): + self.inputs.append(subGraph.Inputs(i)) + else: + self.inputs = subGraph.InputsAsNumpy() + if not subGraph.OutputsIsNone(): + if np is None: + self.outputs = [] + for i in range(subGraph.OutputsLength()): + self.outputs.append(subGraph.Outputs(i)) + else: + self.outputs = subGraph.OutputsAsNumpy() + if not subGraph.OperatorsIsNone(): + self.operators = [] + for i in range(subGraph.OperatorsLength()): + if subGraph.Operators(i) is None: + self.operators.append(None) + else: + operator_ = OperatorT.InitFromObj(subGraph.Operators(i)) + self.operators.append(operator_) + self.name = subGraph.Name() + + # SubGraphT + def Pack(self, builder): + if self.tensors is not None: + tensorslist = [] + for i in range(len(self.tensors)): + tensorslist.append(self.tensors[i].Pack(builder)) + SubGraphStartTensorsVector(builder, len(self.tensors)) + for i in reversed(range(len(self.tensors))): + builder.PrependUOffsetTRelative(tensorslist[i]) + tensors = builder.EndVector() + if self.inputs is not None: + if np is not None and type(self.inputs) is np.ndarray: + inputs = builder.CreateNumpyVector(self.inputs) + else: + SubGraphStartInputsVector(builder, len(self.inputs)) + for i in reversed(range(len(self.inputs))): + builder.PrependInt32(self.inputs[i]) + inputs = builder.EndVector() + if 
self.outputs is not None: + if np is not None and type(self.outputs) is np.ndarray: + outputs = builder.CreateNumpyVector(self.outputs) + else: + SubGraphStartOutputsVector(builder, len(self.outputs)) + for i in reversed(range(len(self.outputs))): + builder.PrependInt32(self.outputs[i]) + outputs = builder.EndVector() + if self.operators is not None: + operatorslist = [] + for i in range(len(self.operators)): + operatorslist.append(self.operators[i].Pack(builder)) + SubGraphStartOperatorsVector(builder, len(self.operators)) + for i in reversed(range(len(self.operators))): + builder.PrependUOffsetTRelative(operatorslist[i]) + operators = builder.EndVector() + if self.name is not None: + name = builder.CreateString(self.name) + SubGraphStart(builder) + if self.tensors is not None: + SubGraphAddTensors(builder, tensors) + if self.inputs is not None: + SubGraphAddInputs(builder, inputs) + if self.outputs is not None: + SubGraphAddOutputs(builder, outputs) + if self.operators is not None: + SubGraphAddOperators(builder, operators) + if self.name is not None: + SubGraphAddName(builder, name) + subGraph = SubGraphEnd(builder) + return subGraph + + +class Buffer(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Buffer() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBuffer(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BufferBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Buffer + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Buffer + def Data(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # Buffer + def DataAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # Buffer + def DataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Buffer + def DataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Buffer + def Offset(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) + return 0 + + # Buffer + def Size(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos) + return 0 + + +def BufferStart(builder): + builder.StartObject(3) + + +def BufferAddData(builder, data): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) + + +def BufferStartDataVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def BufferAddOffset(builder, offset): + builder.PrependUint64Slot(1, offset, 0) + + +def BufferAddSize(builder, size): + 
builder.PrependUint64Slot(2, size, 0) + + +def BufferEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class BufferT(object): + + # BufferT + def __init__(self): + self.data = None # type: List[int] + self.offset = 0 # type: int + self.size = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + buffer = Buffer() + buffer.Init(buf, pos) + return cls.InitFromObj(buffer) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, buffer): + x = BufferT() + x._UnPack(buffer) + return x + + # BufferT + def _UnPack(self, buffer): + if buffer is None: + return + if not buffer.DataIsNone(): + if np is None: + self.data = [] + for i in range(buffer.DataLength()): + self.data.append(buffer.Data(i)) + else: + self.data = buffer.DataAsNumpy() + self.offset = buffer.Offset() + self.size = buffer.Size() + + # BufferT + def Pack(self, builder): + if self.data is not None: + if np is not None and type(self.data) is np.ndarray: + data = builder.CreateNumpyVector(self.data) + else: + BufferStartDataVector(builder, len(self.data)) + for i in reversed(range(len(self.data))): + builder.PrependUint8(self.data[i]) + data = builder.EndVector() + BufferStart(builder) + if self.data is not None: + BufferAddData(builder, data) + BufferAddOffset(builder, self.offset) + BufferAddSize(builder, self.size) + buffer = BufferEnd(builder) + return buffer + + +class Metadata(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Metadata() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMetadata(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def MetadataBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Metadata + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Metadata + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def Buffer(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + +def MetadataStart(builder): + builder.StartObject(2) + + +def MetadataAddName(builder, name): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + + +def MetadataAddBuffer(builder, buffer): + builder.PrependUint32Slot(1, buffer, 0) + + +def MetadataEnd(builder): + return builder.EndObject() + + +class MetadataT(object): + + # MetadataT + def __init__(self): + self.name = None # type: str + self.buffer = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + metadata = Metadata() + metadata.Init(buf, pos) + return cls.InitFromObj(metadata) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, metadata): + x = MetadataT() + x._UnPack(metadata) + return x + + # MetadataT + def _UnPack(self, metadata): + if metadata is None: + return + self.name = metadata.Name() + self.buffer = metadata.Buffer() + + # MetadataT + def Pack(self, builder): + if self.name is not None: + name = builder.CreateString(self.name) + MetadataStart(builder) + if self.name is not None: + MetadataAddName(builder, name) + 
MetadataAddBuffer(builder, self.buffer) + metadata = MetadataEnd(builder) + return metadata + + +class TensorMap(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TensorMap() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTensorMap(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def TensorMapBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # TensorMap + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TensorMap + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # TensorMap + def TensorIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + +def TensorMapStart(builder): + builder.StartObject(2) + + +def TensorMapAddName(builder, name): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + + +def TensorMapAddTensorIndex(builder, tensorIndex): + builder.PrependUint32Slot(1, tensorIndex, 0) + + +def TensorMapEnd(builder): + return builder.EndObject() + + +class TensorMapT(object): + + # TensorMapT + def __init__(self): + self.name = None # type: str + self.tensorIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + tensorMap = TensorMap() + tensorMap.Init(buf, pos) + return cls.InitFromObj(tensorMap) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + 
@classmethod + def InitFromObj(cls, tensorMap): + x = TensorMapT() + x._UnPack(tensorMap) + return x + + # TensorMapT + def _UnPack(self, tensorMap): + if tensorMap is None: + return + self.name = tensorMap.Name() + self.tensorIndex = tensorMap.TensorIndex() + + # TensorMapT + def Pack(self, builder): + if self.name is not None: + name = builder.CreateString(self.name) + TensorMapStart(builder) + if self.name is not None: + TensorMapAddName(builder, name) + TensorMapAddTensorIndex(builder, self.tensorIndex) + tensorMap = TensorMapEnd(builder) + return tensorMap + + +class SignatureDef(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SignatureDef() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsSignatureDef(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def SignatureDefBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # SignatureDef + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SignatureDef + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = TensorMap() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # SignatureDef + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SignatureDef + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # SignatureDef + def Outputs(self, j): + o = 
flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = TensorMap() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # SignatureDef + def OutputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SignatureDef + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # SignatureDef + def SignatureKey(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # SignatureDef + def SubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + +def SignatureDefStart(builder): + builder.StartObject(5) + + +def SignatureDefAddInputs(builder, inputs): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) + + +def SignatureDefStartInputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SignatureDefAddOutputs(builder, outputs): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) + + +def SignatureDefStartOutputsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def SignatureDefAddSignatureKey(builder, signatureKey): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(signatureKey), 0) + + +def SignatureDefAddSubgraphIndex(builder, subgraphIndex): + builder.PrependUint32Slot(4, subgraphIndex, 0) + + +def SignatureDefEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class 
SignatureDefT(object): + + # SignatureDefT + def __init__(self): + self.inputs = None # type: List[TensorMapT] + self.outputs = None # type: List[TensorMapT] + self.signatureKey = None # type: str + self.subgraphIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + signatureDef = SignatureDef() + signatureDef.Init(buf, pos) + return cls.InitFromObj(signatureDef) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, signatureDef): + x = SignatureDefT() + x._UnPack(signatureDef) + return x + + # SignatureDefT + def _UnPack(self, signatureDef): + if signatureDef is None: + return + if not signatureDef.InputsIsNone(): + self.inputs = [] + for i in range(signatureDef.InputsLength()): + if signatureDef.Inputs(i) is None: + self.inputs.append(None) + else: + tensorMap_ = TensorMapT.InitFromObj(signatureDef.Inputs(i)) + self.inputs.append(tensorMap_) + if not signatureDef.OutputsIsNone(): + self.outputs = [] + for i in range(signatureDef.OutputsLength()): + if signatureDef.Outputs(i) is None: + self.outputs.append(None) + else: + tensorMap_ = TensorMapT.InitFromObj(signatureDef.Outputs(i)) + self.outputs.append(tensorMap_) + self.signatureKey = signatureDef.SignatureKey() + self.subgraphIndex = signatureDef.SubgraphIndex() + + # SignatureDefT + def Pack(self, builder): + if self.inputs is not None: + inputslist = [] + for i in range(len(self.inputs)): + inputslist.append(self.inputs[i].Pack(builder)) + SignatureDefStartInputsVector(builder, len(self.inputs)) + for i in reversed(range(len(self.inputs))): + builder.PrependUOffsetTRelative(inputslist[i]) + inputs = builder.EndVector() + if self.outputs is not None: + outputslist = [] + for i in range(len(self.outputs)): + outputslist.append(self.outputs[i].Pack(builder)) + SignatureDefStartOutputsVector(builder, len(self.outputs)) + for i in 
reversed(range(len(self.outputs))): + builder.PrependUOffsetTRelative(outputslist[i]) + outputs = builder.EndVector() + if self.signatureKey is not None: + signatureKey = builder.CreateString(self.signatureKey) + SignatureDefStart(builder) + if self.inputs is not None: + SignatureDefAddInputs(builder, inputs) + if self.outputs is not None: + SignatureDefAddOutputs(builder, outputs) + if self.signatureKey is not None: + SignatureDefAddSignatureKey(builder, signatureKey) + SignatureDefAddSubgraphIndex(builder, self.subgraphIndex) + signatureDef = SignatureDefEnd(builder) + return signatureDef + + +class Model(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Model() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsModel(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def ModelBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # Model + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Model + def Version(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # Model + def OperatorCodes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = OperatorCode() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Model + def OperatorCodesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def 
OperatorCodesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # Model + def Subgraphs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = SubGraph() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Model + def SubgraphsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def SubgraphsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # Model + def Description(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Model + def Buffers(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = Buffer() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Model + def BuffersLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def BuffersIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # Model + def MetadataBuffer(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Model + def MetadataBufferAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return 
self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Model + def MetadataBufferLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def MetadataBufferIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + return o == 0 + + # Model + def Metadata(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = Metadata() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Model + def MetadataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def MetadataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + return o == 0 + + # Model + def SignatureDefs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = SignatureDef() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # Model + def SignatureDefsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Model + def SignatureDefsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + return o == 0 + + +def ModelStart(builder): + builder.StartObject(8) + + +def ModelAddVersion(builder, version): + builder.PrependUint32Slot(0, version, 0) + + +def ModelAddOperatorCodes(builder, operatorCodes): + builder.PrependUOffsetTRelativeSlot( + 1, flatbuffers.number_types.UOffsetTFlags.py_type(operatorCodes), 0) + + +def 
# NOTE: Generated FlatBuffers object API (flatc --gen-object-api) for the
# Circle `Model` root table — keep edits mechanical; regenerate from
# circle_schema.fbs rather than hand-editing logic.
def ModelStartOperatorCodesVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelAddSubgraphs(builder, subgraphs):
    builder.PrependUOffsetTRelativeSlot(
        2, flatbuffers.number_types.UOffsetTFlags.py_type(subgraphs), 0)


def ModelStartSubgraphsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelAddDescription(builder, description):
    builder.PrependUOffsetTRelativeSlot(
        3, flatbuffers.number_types.UOffsetTFlags.py_type(description), 0)


def ModelAddBuffers(builder, buffers):
    builder.PrependUOffsetTRelativeSlot(
        4, flatbuffers.number_types.UOffsetTFlags.py_type(buffers), 0)


def ModelStartBuffersVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelAddMetadataBuffer(builder, metadataBuffer):
    builder.PrependUOffsetTRelativeSlot(
        5, flatbuffers.number_types.UOffsetTFlags.py_type(metadataBuffer), 0)


def ModelStartMetadataBufferVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelAddMetadata(builder, metadata):
    builder.PrependUOffsetTRelativeSlot(
        6, flatbuffers.number_types.UOffsetTFlags.py_type(metadata), 0)


def ModelStartMetadataVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelAddSignatureDefs(builder, signatureDefs):
    builder.PrependUOffsetTRelativeSlot(
        7, flatbuffers.number_types.UOffsetTFlags.py_type(signatureDefs), 0)


def ModelStartSignatureDefsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def ModelEnd(builder):
    return builder.EndObject()


try:
    from typing import List
except:
    pass


class ModelT(object):
    """Mutable object-API mirror of the flatbuffer `Model` table."""

    # ModelT
    def __init__(self):
        self.version = 0  # type: int
        self.operatorCodes = None  # type: List[OperatorCodeT]
        self.subgraphs = None  # type: List[SubGraphT]
        self.description = None  # type: str
        self.buffers = None  # type: List[BufferT]
        self.metadataBuffer = None  # type: List[int]
        self.metadata = None  # type: List[MetadataT]
        self.signatureDefs = None  # type: List[SignatureDefT]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        model = Model()
        model.Init(buf, pos)
        return cls.InitFromObj(model)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        # A packed buffer starts with a uoffset to the root table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + n)

    @classmethod
    def InitFromObj(cls, model):
        x = ModelT()
        x._UnPack(model)
        return x

    # ModelT
    def _UnPack(self, model):
        # Copy every field of the flatbuffer accessor `model` into this
        # mutable object; vector fields become Python lists (or a numpy
        # array for metadataBuffer when numpy is available).
        if model is None:
            return
        self.version = model.Version()
        if not model.OperatorCodesIsNone():
            self.operatorCodes = []
            for i in range(model.OperatorCodesLength()):
                if model.OperatorCodes(i) is None:
                    self.operatorCodes.append(None)
                else:
                    operatorCode_ = OperatorCodeT.InitFromObj(model.OperatorCodes(i))
                    self.operatorCodes.append(operatorCode_)
        if not model.SubgraphsIsNone():
            self.subgraphs = []
            for i in range(model.SubgraphsLength()):
                if model.Subgraphs(i) is None:
                    self.subgraphs.append(None)
                else:
                    subGraph_ = SubGraphT.InitFromObj(model.Subgraphs(i))
                    self.subgraphs.append(subGraph_)
        self.description = model.Description()
        if not model.BuffersIsNone():
            self.buffers = []
            for i in range(model.BuffersLength()):
                if model.Buffers(i) is None:
                    self.buffers.append(None)
                else:
                    buffer_ = BufferT.InitFromObj(model.Buffers(i))
                    self.buffers.append(buffer_)
        if not model.MetadataBufferIsNone():
            if np is None:
                self.metadataBuffer = []
                for i in range(model.MetadataBufferLength()):
                    self.metadataBuffer.append(model.MetadataBuffer(i))
            else:
                self.metadataBuffer = model.MetadataBufferAsNumpy()
        if not model.MetadataIsNone():
            self.metadata = []
            for i in range(model.MetadataLength()):
                if model.Metadata(i) is None:
                    self.metadata.append(None)
                else:
                    metadata_ = MetadataT.InitFromObj(model.Metadata(i))
                    self.metadata.append(metadata_)
        if not model.SignatureDefsIsNone():
            self.signatureDefs = []
            for i in range(model.SignatureDefsLength()):
                if model.SignatureDefs(i) is None:
                    self.signatureDefs.append(None)
                else:
                    signatureDef_ = SignatureDefT.InitFromObj(model.SignatureDefs(i))
                    self.signatureDefs.append(signatureDef_)

    # ModelT
    def Pack(self, builder):
        # Serialize this object back into `builder`. Vectors are built first
        # (flatbuffers requires children before parents), then the Model
        # table is started and each present field is added.
        if self.operatorCodes is not None:
            operatorCodeslist = []
            for i in range(len(self.operatorCodes)):
                operatorCodeslist.append(self.operatorCodes[i].Pack(builder))
            ModelStartOperatorCodesVector(builder, len(self.operatorCodes))
            for i in reversed(range(len(self.operatorCodes))):
                builder.PrependUOffsetTRelative(operatorCodeslist[i])
            operatorCodes = builder.EndVector()
        if self.subgraphs is not None:
            subgraphslist = []
            for i in range(len(self.subgraphs)):
                subgraphslist.append(self.subgraphs[i].Pack(builder))
            ModelStartSubgraphsVector(builder, len(self.subgraphs))
            for i in reversed(range(len(self.subgraphs))):
                builder.PrependUOffsetTRelative(subgraphslist[i])
            subgraphs = builder.EndVector()
        if self.description is not None:
            description = builder.CreateString(self.description)
        if self.buffers is not None:
            bufferslist = []
            for i in range(len(self.buffers)):
                bufferslist.append(self.buffers[i].Pack(builder))
            ModelStartBuffersVector(builder, len(self.buffers))
            for i in reversed(range(len(self.buffers))):
                builder.PrependUOffsetTRelative(bufferslist[i])
            buffers = builder.EndVector()
        if self.metadataBuffer is not None:
            if np is not None and type(self.metadataBuffer) is np.ndarray:
                metadataBuffer = builder.CreateNumpyVector(self.metadataBuffer)
            else:
                ModelStartMetadataBufferVector(builder, len(self.metadataBuffer))
                for i in reversed(range(len(self.metadataBuffer))):
                    builder.PrependInt32(self.metadataBuffer[i])
                metadataBuffer = builder.EndVector()
        if self.metadata is not None:
            metadatalist = []
            for i in range(len(self.metadata)):
                metadatalist.append(self.metadata[i].Pack(builder))
            ModelStartMetadataVector(builder, len(self.metadata))
            for i in reversed(range(len(self.metadata))):
                builder.PrependUOffsetTRelative(metadatalist[i])
            metadata = builder.EndVector()
        if self.signatureDefs is not None:
            signatureDefslist = []
            for i in range(len(self.signatureDefs)):
                signatureDefslist.append(self.signatureDefs[i].Pack(builder))
            ModelStartSignatureDefsVector(builder, len(self.signatureDefs))
            for i in reversed(range(len(self.signatureDefs))):
                builder.PrependUOffsetTRelative(signatureDefslist[i])
            signatureDefs = builder.EndVector()
        ModelStart(builder)
        ModelAddVersion(builder, self.version)
        if self.operatorCodes is not None:
            ModelAddOperatorCodes(builder, operatorCodes)
        if self.subgraphs is not None:
            ModelAddSubgraphs(builder, subgraphs)
        if self.description is not None:
            ModelAddDescription(builder, description)
        if self.buffers is not None:
            ModelAddBuffers(builder, buffers)
        if self.metadataBuffer is not None:
            ModelAddMetadataBuffer(builder, metadataBuffer)
        if self.metadata is not None:
            ModelAddMetadata(builder, metadata)
        if self.signatureDefs is not None:
            ModelAddSignatureDefs(builder, signatureDefs)
        model = ModelEnd(builder)
        return model
reversed(range(len(self.metadata))): + builder.PrependUOffsetTRelative(metadatalist[i]) + metadata = builder.EndVector() + if self.signatureDefs is not None: + signatureDefslist = [] + for i in range(len(self.signatureDefs)): + signatureDefslist.append(self.signatureDefs[i].Pack(builder)) + ModelStartSignatureDefsVector(builder, len(self.signatureDefs)) + for i in reversed(range(len(self.signatureDefs))): + builder.PrependUOffsetTRelative(signatureDefslist[i]) + signatureDefs = builder.EndVector() + ModelStart(builder) + ModelAddVersion(builder, self.version) + if self.operatorCodes is not None: + ModelAddOperatorCodes(builder, operatorCodes) + if self.subgraphs is not None: + ModelAddSubgraphs(builder, subgraphs) + if self.description is not None: + ModelAddDescription(builder, description) + if self.buffers is not None: + ModelAddBuffers(builder, buffers) + if self.metadataBuffer is not None: + ModelAddMetadataBuffer(builder, metadataBuffer) + if self.metadata is not None: + ModelAddMetadata(builder, metadata) + if self.signatureDefs is not None: + ModelAddSignatureDefs(builder, signatureDefs) + model = ModelEnd(builder) + return model diff --git a/tools/circle2circle/fuse.bmm_lhs_const.py b/tools/circle2circle/fuse.bmm_lhs_const.py new file mode 100755 index 00000000000..8fd7f022db1 --- /dev/null +++ b/tools/circle2circle/fuse.bmm_lhs_const.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 + +import numpy as np +import circle +import o2o + + +def get_tensor_by_index(subgraph, index): + """Safely get tensor by its index.""" + if 0 <= index < len(subgraph.tensors): + return subgraph.tensors[index] + return None + + +def is_tensor_constant(tensor, model_buffers): + """Check if a tensor is constant by verifying its buffer.""" + if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): + # A non-zero buffer index that points to a valid buffer typically means it's constant. + # The 0th buffer is always an empty buffer. 
+ return True + return False + + +def find_operator_by_output(subgraph, output_tensor_index): + """Find the first operator that produces the given output tensor index.""" + for op_idx, operator in enumerate(subgraph.operators): + if operator.outputs and output_tensor_index in operator.outputs: + return op_idx, operator + return None, None + + +def from_buffer(buffer_index, model_buffers): + """Converts buffer data to a numpy array (int32).""" + if buffer_index > 0 and buffer_index - 1 < len(model_buffers): + buffer_obj = model_buffers[buffer_index] + if buffer_obj and len(buffer_obj.data) > 0: + # Assuming data is a bytearray of int32s. + # This needs to match the actual data type in the model. + try: + return np.frombuffer(buffer_obj.data, dtype=np.int32) + except Exception as e: + o2o.log( + f"Could not parse permutation tensor buffer for buffer index {buffer_index}: {e}" + ) + return None + return None + + +def is_effectively_2d(shape): + """Check if a tensor shape is effectively 2D (all leading dimensions are 1)""" + if len(shape) < 2: + return False # Cannot be effectively 2D if less than 2 dimensions + return all(dim == 1 for dim in shape[:-2]) + + +def count_tensor_usage(model, tensor_index): + """Count how many operators use a specific tensor as input""" + count = 0 + for subgraph_idx, subgraph in enumerate(model.subgraphs): + for operator in subgraph.operators: + if operator.inputs is not None: + for input_idx in operator.inputs: + if input_idx == tensor_index: + count += 1 + return count + + +def get_or_create_operator_code(model, builtin_op_type): + """Get the index of an operator code, or create it if it doesn't exist.""" + for i, op_code in enumerate(model.operatorCodes): + if op_code.builtinCode == builtin_op_type: + return i + + # If not found, create a new one + new_op_code = circle.OperatorCodeT() + new_op_code.builtinCode = builtin_op_type + new_op_code.version = 1 # Default version + model.operatorCodes.append(new_op_code) + return 
len(model.operatorCodes) - 1 + + +def create_transpose_permutation_tensor(model, subgraph, rank): + """Create a permutation tensor for transposing last two dimensions.""" + # Create permutation: [0, 1, ..., rank-3, rank-1, rank-2] + perm_shape = [rank] + perm_data = list(range(rank)) + perm_data[-1], perm_data[-2] = perm_data[-2], perm_data[-1] # Swap last two + + # Create buffer for permutation data + perm_buffer = circle.BufferT() + perm_buffer.data = np.array(perm_data, dtype=np.int32).tobytes() + model.buffers.append(perm_buffer) + buffer_index = len(model.buffers) + + # Create tensor + perm_tensor = circle.TensorT() + perm_tensor.shape = perm_shape + perm_tensor.type = circle.TensorType.INT32 + perm_tensor.buffer = buffer_index + perm_tensor.name = f"transpose_perm_{len(subgraph.tensors)}" + subgraph.tensors.append(perm_tensor) + tensor_index = len(subgraph.tensors) - 1 + + return tensor_index + + +def add_rhs_transpose_if_needed(model, subgraph, bmm_op_idx, rhs_tensor_index, + rhs_tensor): + """Add TRANSPOSE operator for RHS if K != 1 OR B != 1.""" + if len(rhs_tensor.shape) < 3: + # Need at least 3 dimensions: [B, K, N] + return rhs_tensor_index + + B = rhs_tensor.shape[0] + K = rhs_tensor.shape[1] + + # Skip transpose if both B = 1 and K = 1 + if B == 1 and K == 1: + o2o.log( + f"RHS tensor shape {rhs_tensor.shape} has B=1 and K=1, skipping transpose") + return rhs_tensor_index + + o2o.log(f"Adding transpose for RHS tensor shape {rhs_tensor.shape} (B={B}, K={K})") + + # Create permutation tensor + rank = len(rhs_tensor.shape) + perm_tensor_index = create_transpose_permutation_tensor(model, subgraph, rank) + + # Create output tensor for transposed RHS + transposed_rhs_tensor = circle.TensorT() + transposed_rhs_tensor.shape = list(rhs_tensor.shape) + transposed_rhs_tensor.shape[-1], transposed_rhs_tensor.shape[ + -2] = transposed_rhs_tensor.shape[-2], transposed_rhs_tensor.shape[-1] + transposed_rhs_tensor.type = rhs_tensor.type + transposed_rhs_tensor.buffer 
def fuse_bmm_transpose():
    """Main function to add RHS transpose before fusing batchmatmul(lhs, rhs) to fullyconnected(transposed_rhs, lhs) when lhs is constant."""
    o2o.log("Loading model from stdin")
    model = o2o.load_model_from_stdin()

    if not model.subgraphs:
        o2o.log("Model has no subgraphs. Exiting.")
        # BUGFIX: the original called o2o.save_circle_model(model, output_file)
        # with an undefined `output_file` (NameError). This filter reads stdin
        # and writes stdout, so pass the model through unchanged.
        o2o.save_model_to_stdout(model)
        return

    subgraph = model.subgraphs[0]  # Assuming single subgraph for now, can be extended
    tensors_to_potentially_remove = set()
    operators_to_remove = []  # indices of TRANSPOSE ops to delete after fusing

    # Iterate backwards to safely remove operators
    for i in range(len(subgraph.operators) - 1, -1, -1):
        transpose_op = subgraph.operators[i]

        # Check if current operator is TRANSPOSE
        transpose_opcode = model.operatorCodes[transpose_op.opcodeIndex]
        if transpose_opcode.builtinCode != circle.BuiltinOperator.TRANSPOSE:
            continue

        if len(transpose_op.inputs) != 2:
            o2o.log(
                f"Transpose operator at index {i} has invalid number of inputs. Skipping."
            )
            continue

        transpose_input_tensor_idx = transpose_op.inputs[0]
        bmm_op_idx, bmm_op = find_operator_by_output(subgraph, transpose_input_tensor_idx)

        # Check if the found operator is BATCH_MATMUL
        if bmm_op is None or model.operatorCodes[
                bmm_op.opcodeIndex].builtinCode != circle.BuiltinOperator.BATCH_MATMUL:
            continue

        lhs_tensor_index = bmm_op.inputs[0]
        rhs_tensor_index = bmm_op.inputs[1]

        lhs_tensor = get_tensor_by_index(subgraph, lhs_tensor_index)
        rhs_tensor = get_tensor_by_index(subgraph, rhs_tensor_index)

        if not lhs_tensor or not rhs_tensor:
            o2o.log(
                f"Could not find LHS or RHS tensor for BATCH_MATMUL at index {bmm_op_idx}. Skipping."
            )
            continue

        # Crucial check: LHS must be constant
        if not is_tensor_constant(lhs_tensor, model.buffers):
            o2o.log(
                f"LHS tensor '{lhs_tensor.name if lhs_tensor.name else lhs_tensor_index}' for BATCH_MATMUL at index {bmm_op_idx} is not constant. Skipping fusion."
            )
            continue

        # Verify Transpose permutation (assuming transpose of last two dims)
        # e.g. for [..., M, N] -> [..., N, M], permutation is [..., dim_N-1, dim_N-2]
        # For a 2D tensor [M, N] -> [N, M], permutation is [1, 0]
        # For a 3D tensor [B, M, N] -> [B, N, M], permutation is [0, 2, 1]
        valid_permutation = False
        perm_tensor_index = transpose_op.inputs[1]
        perm_tensor = get_tensor_by_index(subgraph, perm_tensor_index)

        if perm_tensor and is_tensor_constant(perm_tensor, model.buffers):
            # Get permutation data from buffer using the helper function.
            # BUGFIX: from_buffer may return None (unparsable/empty buffer);
            # guard before len(perm) to avoid a TypeError.
            perm = from_buffer(perm_tensor.buffer, model.buffers)
            if perm is not None and len(perm) >= 2:  # At least 2D
                # Check if the last two elements of permutation are swapped
                # and other elements are in their original ascending order (0, 1, 2, ...)
                expected_perm_prefix = list(range(len(perm) - 2))
                actual_perm_prefix = perm[:-2]

                if np.all(actual_perm_prefix == expected_perm_prefix) and \
                   perm[-2] == len(perm) - 1 and \
                   perm[-1] == len(perm) - 2:
                    valid_permutation = True
        else:
            o2o.log(
                f"Permutation tensor for TRANSPOSE at index {i} is not constant or not found. Skipping."
            )

        if not valid_permutation:
            o2o.log(
                f"TRANSPOSE operator at index {i} does not have a simple last-two-dim permutation. Skipping fusion."
            )
            continue

        # Add TRANSPOSE for RHS if needed (K != 1 OR B != 1).
        # BUGFIX: add_rhs_transpose_if_needed may insert an operator at
        # bmm_op_idx + 1, which shifts this TRANSPOSE (and everything after
        # the insertion point) one slot to the right. Track the shift so the
        # removal below deletes the original TRANSPOSE, not its neighbor.
        ops_before = len(subgraph.operators)
        final_rhs_tensor_index = add_rhs_transpose_if_needed(model, subgraph, bmm_op_idx,
                                                             rhs_tensor_index, rhs_tensor)
        inserted = len(subgraph.operators) - ops_before

        # Create the new FULLY_CONNECTED operator
        fc_op = circle.OperatorT()
        fc_op.opcodeIndex = get_or_create_operator_code(
            model, circle.BuiltinOperator.FULLY_CONNECTED)
        # Set inputs: [transposed_rhs, original_lhs, -1] where -1 means bias not exists
        fc_op.inputs = [final_rhs_tensor_index, lhs_tensor_index, -1]
        # Set outputs: same as the original TRANSPOSE operator
        fc_op.outputs = list(transpose_op.outputs)  # Make a copy

        # Configure FULLY_CONNECTED options
        fc_op.builtinOptionsType = circle.BuiltinOptions.FullyConnectedOptions
        fc_options = circle.FullyConnectedOptionsT()
        fc_options.keepNumDims = True  # Important to preserve batch dimensions from BATCH_MATMUL
        fc_op.builtinOptions = fc_options

        # Add the new operator to the subgraph
        # Insert it at the position of the original BATCH_MATMUL operator
        o2o.log(f"Replacing batchmatmul at {bmm_op_idx} with fullyconnected")
        subgraph.operators[bmm_op_idx] = fc_op

        # Mark the original TRANSPOSE operator for removal (index adjusted
        # for any RHS transpose inserted after the BATCH_MATMUL position).
        transpose_remove_idx = i + inserted if (inserted and bmm_op_idx < i) else i
        operators_to_remove.append(transpose_remove_idx)

        # The tensor connecting BMM and Transpose (bmm_output_tensor_index) is now an intermediate
        # output of the new FC op. If it's not used by any other op, it could be cleaned up.
        # For now, we just mark it. Actual removal is more complex (needs usage check).
        tensors_to_potentially_remove.add(transpose_input_tensor_idx)

    # Remove operators marked for removal (iterate backwards again for safe removal)
    for idx in sorted(list(operators_to_remove), reverse=True):
        if 0 <= idx < len(subgraph.operators):
            o2o.log(f"Removing transpose operator at index {idx}")
            del subgraph.operators[idx]

    # Note: Cleanup of unused tensors and operator codes is a more advanced step
    # and not implemented here for simplicity, but would be part of a production-ready script.
    o2o.log(f"TODO: Remove tensors at {tensors_to_potentially_remove}")
    o2o.save_model_to_stdout(model)


if __name__ == "__main__":
    # Directly invoke processing; I/O handled via stdin/stdout
    fuse_bmm_transpose()
#!/usr/bin/env python3

import os
import numpy as np
import circle
import o2o

# Circle Model Buffer Usage Rules (based on circle_schema.fbs and analysis)
# ======================================================================
#
# Buffer Index Allocation Rules:
# - B(0): Always empty placeholder buffer (sentinel, must exist)
# - B(1+): Dedicated buffers for specific tensors
#
# Tensor-Buffer Assignment Rules:
# 1. Model Input Tensors:
#    - Get dedicated buffer index (e.g., B(1), B(2), ...)
#    - Buffer data is EMPTY (b'')
#    - Added to subgraph.inputs array
#    - Example: input tensor -> buffer index 1 (empty)
#
# 2. Model Output Tensors:
#    - Get dedicated buffer index (e.g., B(1), B(2), ...)
#    - Buffer data is EMPTY (b'')
#    - Added to subgraph.outputs array
#    - Example: output tensor -> buffer index 2 (empty)
#
# 3. Constant Tensors:
#    - Get dedicated buffer index (e.g., B(3), B(4), ...)
#    - Buffer data contains ACTUAL DATA (numpy.tobytes())
#    - NOT added to subgraph.inputs (internal to model)
#    - Example: constant tensor -> buffer index 3 (with data)
#
# 4. Intermediate Tensors:
#    - Get dedicated buffer index (e.g., B(4), B(5), ...)
#    - Buffer data is EMPTY (b'')
#    - NOT added to subgraph.inputs/outputs
#    - Example: intermediate result -> buffer index 4 (empty)
#    - IMPORTANT: Intermediate tensors are NOT constants, so they need dedicated buffers!
#
# Buffer Creation Order (Recommended):
# 1. B(0): Empty placeholder buffer (always first)
# 2. Input tensor buffers (empty data)
# 3. Output tensor buffers (empty data)
# 4. Constant tensor buffers (with actual data)
#
# Key Principles:
# - Each tensor type has specific buffer requirements
# - Model inputs/outputs MUST have dedicated buffers (even if empty)
# - Constants MUST have dedicated buffers with actual data
# - Intermediate results get dedicated empty buffers (see rule 4 above)
# - Buffer index assignment follows creation order in model.buffers array
#
# Reference: circle_schema.fbs - "buffer:uint" field documentation
def create_simple_add_model(output_file):
    """Create a simple Circle model with one ADD operator (similar to add.circle).

    Graph: ofm = ADD(ifm, add_const); written to `output_file`.
    """

    # Create model
    model = circle.ModelT()
    model.version = 3
    model.operatorCodes = []
    model.subgraphs = []
    model.buffers = []
    model.metadataBuffer = []

    # Create subgraph
    subgraph = circle.SubGraphT()
    subgraph.tensors = []
    subgraph.inputs = []
    subgraph.outputs = []
    subgraph.operators = []
    subgraph.name = "main"

    # Create buffers in CORRECT order (output buffer last)
    # B(0) - empty_sentinel_buffer (always existent empty buffer for tensors with no data buffer)
    empty_sentinel_buffer = circle.BufferT()
    # No data assignment for empty buffer (buffer.data remains None)
    model.buffers.append(empty_sentinel_buffer)

    # B(1) - input tensor buffer (no data)
    input_buffer = circle.BufferT()
    # No data assignment for input buffer (buffer.data remains None)
    model.buffers.append(input_buffer)

    # B(2) - constant tensor buffer (with data, 16-byte aligned)
    const_data = np.array([1], dtype=np.int32)  # Simple constant value
    const_buffer = circle.BufferT()
    # Align to 16 bytes as required by circle_schema.fbs Buffer.force_align: 16.
    # BUGFIX: pad by (-len) % 16 so data of any size (including multiples of
    # 16) is handled; the old `16 - len(raw_data)` only worked for len <= 16.
    raw_data = const_data.tobytes()
    padded_data = raw_data + b'\x00' * ((-len(raw_data)) % 16)
    const_buffer.data = padded_data
    model.buffers.append(const_buffer)

    # B(3) - output tensor buffer (no data) - MOVED TO LAST
    output_buffer = circle.BufferT()
    # No data assignment for output buffer (buffer.data remains None)
    model.buffers.append(output_buffer)

    # Create input tensor (ifm) - using dedicated buffer B(1)
    input_tensor = circle.TensorT()
    input_tensor.shape = [1, 1, 16]  # Same as add.circle
    input_tensor.type = circle.TensorType.INT32  # Using INT32
    input_tensor.buffer = 1  # B(1) - dedicated input buffer
    input_tensor.name = "ifm"
    subgraph.tensors.append(input_tensor)
    input_tensor_index = len(subgraph.tensors) - 1
    subgraph.inputs.append(input_tensor_index)

    # Create constant tensor (add_const) - using dedicated buffer B(2)
    const_tensor = circle.TensorT()
    const_tensor.shape = [1, 1, 1]  # Same as add.circle
    const_tensor.type = circle.TensorType.INT32
    const_tensor.buffer = 2  # B(2) - dedicated constant buffer with data
    const_tensor.name = "add_const"
    subgraph.tensors.append(const_tensor)
    const_tensor_index = len(subgraph.tensors) - 1

    # Create output tensor (ofm) - using dedicated buffer B(3) - MOVED TO LAST
    output_tensor = circle.TensorT()
    output_tensor.shape = [1, 1, 16]  # Same as add.circle
    output_tensor.type = circle.TensorType.INT32
    output_tensor.buffer = 3  # B(3) - dedicated output buffer (last index)
    output_tensor.name = "ofm"
    subgraph.tensors.append(output_tensor)
    output_tensor_index = len(subgraph.tensors) - 1
    subgraph.outputs.append(output_tensor_index)

    # Create ADD operator code
    add_opcode = circle.OperatorCodeT()
    add_opcode.builtinCode = circle.BuiltinOperator.ADD
    add_opcode.deprecatedBuiltinCode = circle.BuiltinOperator.ADD  # Fix: deprecatedBuiltinCode must be set to same as builtinCode
    add_opcode.version = 1
    model.operatorCodes.append(add_opcode)
    add_opcode_index = len(model.operatorCodes) - 1

    # Create ADD operator
    add_op = circle.OperatorT()
    add_op.opcodeIndex = add_opcode_index
    add_op.inputs = [input_tensor_index, const_tensor_index]  # ifm + add_const
    add_op.outputs = [output_tensor_index]  # = ofm
    add_op.builtinOptionsType = circle.BuiltinOptions.AddOptions
    add_options = circle.AddOptionsT()
    add_options.fusedActivationFunction = circle.ActivationFunctionType.NONE
    add_op.builtinOptions = add_options
    subgraph.operators.append(add_op)

    # Add subgraph to model
    model.subgraphs.append(subgraph)

    # Save model
    o2o.save_circle_model(model, output_file)
    o2o.log(f"Simple ADD model saved to {output_file}")
    o2o.log(f"Model structure:")
    o2o.log(f" Input tensor: {input_tensor.name} shape={input_tensor.shape}")
    o2o.log(f" Constant tensor: {const_tensor.name} shape={const_tensor.shape}")
    o2o.log(f" Output tensor: {output_tensor.name} shape={output_tensor.shape}")
    o2o.log(f" Operator: ADD")
    o2o.log(f" Subgraph inputs: {[subgraph.tensors[i].name for i in subgraph.inputs]}")
    o2o.log(f" Subgraph outputs: {[subgraph.tensors[i].name for i in subgraph.outputs]}")


if __name__ == "__main__":
    # Generate output filename from current script filename
    # e.g., gen_circle.add.py -> add.circle
    script_name = os.path.basename(__file__)
    # BUGFIX: the old pattern `.replace('.gen_circle.py', '.circle')` never
    # matched this script's actual name ("gen_circle.add.py"), so the model
    # would have been written over the script itself. Use the same naming
    # scheme as gen_circle.bmm_lhs_const.fc.py.
    output_file = script_name.replace('gen_circle.', '').replace('.py', '.circle')

    create_simple_add_model(output_file)
def create_test_bmm_k_not_1_model(output_file):
    """Create a test Circle model with BATCH_MATMUL where RHS K != 1.

    Graph: TRANSPOSE(BATCH_MATMUL(lhs_const, rhs_input)); this is the input
    pattern that fuse.bmm_lhs_const.py rewrites into FULLY_CONNECTED.
    The model is written to `output_file`.
    """

    # Create model
    model = circle.ModelT()
    model.version = 3
    model.operatorCodes = []
    model.subgraphs = []
    model.buffers = []
    model.metadataBuffer = []

    # Create subgraph
    subgraph = circle.SubGraphT()
    subgraph.tensors = []
    subgraph.inputs = []
    subgraph.outputs = []
    subgraph.operators = []
    subgraph.name = "main"

    # Create buffers in CORRECT order (output buffer last, proper alignment)
    # B(0) - empty_sentinel_buffer (always existent empty buffer for tensors with no data buffer)
    empty_sentinel_buffer = circle.BufferT()
    # No data assignment for empty buffer (buffer.data remains None)
    model.buffers.append(empty_sentinel_buffer)

    # B(1) - BMM rhs : model input
    bmm_rhs_buffer = circle.BufferT()
    # No data assignment for input buffer (buffer.data remains None)
    model.buffers.append(bmm_rhs_buffer)

    # B(2) - BMM lhs : constant (16-byte aligned)
    bmm_lhs_data = np.random.rand(1, 4, 3).astype(np.float32)  # [B, M, K] = [1, 4, 3]
    bmm_lhs_buffer = circle.BufferT()
    # Align to 16 bytes as required by circle_schema.fbs Buffer.force_align: 16
    raw_data = bmm_lhs_data.tobytes()
    padded_data = raw_data + b'\x00' * (16 - len(raw_data) % 16) if len(
        raw_data) % 16 != 0 else raw_data
    bmm_lhs_buffer.data = padded_data
    model.buffers.append(bmm_lhs_buffer)

    # B(3) - Transpose perm : constant (16-byte aligned)
    perm_data = np.array(
        [0, 2, 1], dtype=np.int32)  # Transpose last two dims: [B, M, N] -> [B, N, M]
    perm_buffer = circle.BufferT()
    # Align to 16 bytes as required by circle_schema.fbs Buffer.force_align: 16
    raw_perm_data = perm_data.tobytes()
    padded_perm_data = raw_perm_data + b'\x00' * (16 - len(raw_perm_data) % 16) if len(
        raw_perm_data) % 16 != 0 else raw_perm_data
    perm_buffer.data = padded_perm_data
    model.buffers.append(perm_buffer)

    # B(4) - BMM output (intermediate, no data)
    bmm_output_buffer = circle.BufferT()
    # No data assignment for intermediate buffer (buffer.data remains None)
    model.buffers.append(bmm_output_buffer)

    # B(5) - TRANSPOSE output (final output, moved to last)
    transpose_output_buffer = circle.BufferT()
    # No data assignment for output buffer (buffer.data remains None)
    model.buffers.append(transpose_output_buffer)

    # Create RHS input tensor with K != 1 (K=3)
    bmm_rhs_tensor = circle.TensorT()
    bmm_rhs_tensor.shape = [1, 3, 5]  # [B, K, N] = [1, 3, 5], K=3 != 1
    bmm_rhs_tensor.type = circle.TensorType.FLOAT32
    bmm_rhs_tensor.buffer = 1  # B(1) - dedicated input buffer
    bmm_rhs_tensor.name = "bmm_rhs_input"
    subgraph.tensors.append(bmm_rhs_tensor)
    bmm_rhs_tensor_index = len(subgraph.tensors) - 1
    subgraph.inputs.append(bmm_rhs_tensor_index)  # Add to subgraph inputs

    # Create LHS constant tensor
    bmm_lhs_tensor = circle.TensorT()
    bmm_lhs_tensor.shape = [1, 4, 3]  # [B, M, K] = [1, 4, 3]
    bmm_lhs_tensor.type = circle.TensorType.FLOAT32
    bmm_lhs_tensor.buffer = 2  # B(2) - dedicated constant buffer with data
    bmm_lhs_tensor.name = "bmm_lhs_constant"
    subgraph.tensors.append(bmm_lhs_tensor)
    bmm_lhs_tensor_index = len(subgraph.tensors) - 1
    # Note: LHS is constant, so NOT added to subgraph.inputs

    # Create permutation tensor for TRANSPOSE
    perm_tensor = circle.TensorT()
    perm_tensor.shape = [3]
    perm_tensor.type = circle.TensorType.INT32
    perm_tensor.buffer = 3  # B(3) - dedicated constant buffer with data
    perm_tensor.name = "transpose_perm"
    subgraph.tensors.append(perm_tensor)
    perm_tensor_index = len(subgraph.tensors) - 1

    # Create BATCH_MATMUL output tensor
    bmm_output_tensor = circle.TensorT()
    bmm_output_tensor.shape = [1, 4, 5]  # [B, M, N] = [1, 4, 5]
    bmm_output_tensor.type = circle.TensorType.FLOAT32
    bmm_output_tensor.buffer = 4  # B(4) - intermediate buffer (no data)
    bmm_output_tensor.name = "bmm_output"
    subgraph.tensors.append(bmm_output_tensor)
    bmm_output_tensor_index = len(subgraph.tensors) - 1

    # Create final output tensor
    transpose_output_tensor = circle.TensorT()
    transpose_output_tensor.shape = [1, 5, 4]  # [B, N, M] = [1, 5, 4] after transpose
    transpose_output_tensor.type = circle.TensorType.FLOAT32
    transpose_output_tensor.buffer = 5  # B(5) - dedicated output buffer (last index)
    transpose_output_tensor.name = "transpose_output"
    subgraph.tensors.append(transpose_output_tensor)
    transpose_output_tensor_index = len(subgraph.tensors) - 1
    subgraph.outputs.append(transpose_output_tensor_index)

    # Create operator codes
    # BATCH_MATMUL
    bmm_opcode = circle.OperatorCodeT()
    bmm_opcode.builtinCode = circle.BuiltinOperator.BATCH_MATMUL
    bmm_opcode.deprecatedBuiltinCode = circle.BuiltinOperator.BATCH_MATMUL
    bmm_opcode.version = 1
    model.operatorCodes.append(bmm_opcode)
    bmm_opcode_index = len(model.operatorCodes) - 1

    # TRANSPOSE
    transpose_opcode = circle.OperatorCodeT()
    transpose_opcode.builtinCode = circle.BuiltinOperator.TRANSPOSE
    transpose_opcode.deprecatedBuiltinCode = circle.BuiltinOperator.TRANSPOSE
    transpose_opcode.version = 1
    model.operatorCodes.append(transpose_opcode)
    transpose_opcode_index = len(model.operatorCodes) - 1

    # Create BATCH_MATMUL operator
    bmm_op = circle.OperatorT()
    bmm_op.opcodeIndex = bmm_opcode_index
    bmm_op.inputs = [bmm_lhs_tensor_index,
                     bmm_rhs_tensor_index]  # LHS constant, RHS input
    bmm_op.outputs = [bmm_output_tensor_index]  # BMM output
    bmm_op.builtinOptionsType = circle.BuiltinOptions.BatchMatMulOptions
    bmm_options = circle.BatchMatMulOptionsT()
    bmm_options.adjointLhs = False  # Fixed: adjacentX -> adjointLhs
    bmm_options.adjointRhs = False  # Fixed: adjacentY -> adjointRhs
    bmm_options.asymmetricQuantizeInputs = False  # Added missing field
    bmm_options.fusedActivationFunction = circle.ActivationFunctionType.NONE
    bmm_op.builtinOptions = bmm_options
    subgraph.operators.append(bmm_op)

    # Create TRANSPOSE operator
    transpose_op = circle.OperatorT()
    transpose_op.opcodeIndex = transpose_opcode_index
    transpose_op.inputs = [bmm_output_tensor_index,
                           perm_tensor_index]  # BMM output, permutation
    transpose_op.outputs = [transpose_output_tensor_index]  # Final output
    transpose_op.builtinOptionsType = circle.BuiltinOptions.TransposeOptions
    transpose_options = circle.TransposeOptionsT()
    transpose_op.builtinOptions = transpose_options
    subgraph.operators.append(transpose_op)

    # Add subgraph to model
    model.subgraphs.append(subgraph)

    # Save model
    o2o.save_circle_model(model, output_file)
    o2o.log(f"Test model saved to {output_file}")
    o2o.log(f"Model structure:")
    o2o.log(f" LHS constant tensor shape: {bmm_lhs_tensor.shape} (with actual data)")
    o2o.log(
        f" RHS input tensor shape: {bmm_rhs_tensor.shape} (K={bmm_rhs_tensor.shape[1]} != 1)"
    )
    o2o.log(f" BMM output tensor shape: {bmm_output_tensor.shape}")
    o2o.log(
        f" TRANSPOSE output tensor shape: {transpose_output_tensor.shape} (after transpose)"
    )
    o2o.log(f" Model inputs: {[subgraph.tensors[i].name for i in subgraph.inputs]}")
    o2o.log(f" Model outputs: {[subgraph.tensors[i].name for i in subgraph.outputs]}")
    o2o.log(
        f" Operations: BATCH_MATMUL(LHS_constant + RHS_input) -> TRANSPOSE(properly connected)"
    )


if __name__ == "__main__":
    # Generate output filename from current script filename
    # e.g., gen_circle.bmm_lhs_const.fc.py -> bmm_lhs_const.fc.circle
    script_name = os.path.basename(__file__)
    output_file = script_name.replace('gen_circle.', '').replace('.py', '.circle')

    create_test_bmm_k_not_1_model(output_file)
#!/usr/bin/env python3

import sys
import circle
import flatbuffers
flatbuffers + + +def log(message): + """Log message to stderr""" + print(message, file=sys.stderr) + + +def load_model_from_stdin(): + """Load a Circle model from binary data read from stdin.""" + data = sys.stdin.buffer.read() + buf = bytearray(data) + model = circle.Model.GetRootAsModel(buf, 0) + model = circle.ModelT.InitFromObj(model) + return model + + +def save_model_to_stdout(model): + """Serialize a Circle model and write it to stdout as binary data.""" + builder = flatbuffers.Builder(1024) + builder.Finish(model.Pack(builder), b'CIR0') + sys.stdout.buffer.write(builder.Output()) + + +def load_circle_model(input_file): + """Load and parse a circle model file""" + with open(input_file, 'rb') as f: + buf = bytearray(f.read()) + + model = circle.Model.GetRootAsModel(buf, 0) + model = circle.ModelT.InitFromObj(model) + return model + + +def save_circle_model(model, output_file): + """Save a circle model to file using flatbuffers""" + builder = flatbuffers.Builder(1024) + builder.Finish(model.Pack(builder), b'CIR0') + + with open(output_file, 'wb') as f: + f.write(builder.Output()) + + +def handle_cli_args(usage_message): + """Handle common command line argument parsing and validation""" + if len(sys.argv) != 3: + log(usage_message) + sys.exit(1) + + input_file = sys.argv[1] + output_file = sys.argv[2] + return input_file, output_file + + +def get_tensor_name(tensor): + """Get tensor name as string, handling bytes conversion""" + if tensor.name: + return tensor.name.decode('utf-8') if isinstance(tensor.name, + bytes) else tensor.name + return None + + +def process_subgraphs(model, processor_func): + """Generic subgraph processor with modification tracking + + Args: + model: Circle model object + processor_func: Function that processes a subgraph and returns (modified, changes_count) + + Returns: + tuple: (overall_modified, total_changes) + """ + overall_modified = False + total_changes = 0 + + for subgraph in model.subgraphs: + modified, changes_count = 
processor_func(subgraph) + overall_modified = overall_modified or modified + total_changes += changes_count + + return overall_modified, total_changes + + +def rename_tensor_if_matches(tensor, pattern, replacement_func): + """Rename tensor if it matches the given pattern + + Args: + tensor: Tensor object to process + pattern: Regex pattern to match + replacement_func: Function that takes regex match and returns new name + + Returns: + tuple: (was_renamed, old_name, new_name) + """ + tensor_name = get_tensor_name(tensor) + if not tensor_name: + return False, None, None + + import re + match = re.match(pattern, tensor_name) + if match: + old_name = tensor_name + new_name = replacement_func(match) + tensor.name = new_name + return True, old_name, new_name + + return False, None, None + + +def safe_execute(main_func, + input_file, + output_file, + *args, + error_message="Error processing file"): + """Safely execute the main function with error handling""" + try: + main_func(input_file, output_file, *args) + log(f"Successfully processed {input_file} and saved to {output_file}") + except Exception as e: + log(f"{error_message}: {e}") + sys.exit(1) diff --git a/tools/circle2circle/remove.io.py b/tools/circle2circle/remove.io.py new file mode 100755 index 00000000000..9e0830ed12d --- /dev/null +++ b/tools/circle2circle/remove.io.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 + +import sys +import argparse +import o2o + + +def parse_names(names_str): + """Parse comma‑separated tensor names into a list of names.""" + return [name.strip() for name in names_str.split(',') if name.strip()] + + +def remove_io_tensors(io_type, names_to_keep): + """Remove input or output tensors, keeping only specified tensor names""" + # Load the model using utility function + model = o2o.load_model_from_stdin() + + def process_subgraph(subgraph): + """Process a single subgraph""" + if io_type == 'input': + io_list = subgraph.inputs + io_name = 'input' + elif io_type == 'output': + io_list = 
subgraph.outputs + io_name = 'output' + else: + raise ValueError(f"Invalid io_type: {io_type}. Must be 'input' or 'output'") + + o2o.log(f"Processing subgraph with {len(io_list)} {io_name}s") + o2o.log(f"Original {io_name} indices: {io_list}") + + # Build a mapping from tensor name to its index for the selected I/O list + name_to_index = {} + for io_idx in io_list: + tensor = subgraph.tensors[io_idx] + tensor_name = o2o.get_tensor_name(tensor) + if tensor_name: + name_to_index[tensor_name] = io_idx + + # Filter tensors to keep by name + new_io_list = [] + for name in names_to_keep: + if name in name_to_index: + new_io_list.append(name_to_index[name]) + else: + o2o.log(f"Warning: {io_name} tensor name '{name}' not found") + + # Update the subgraph + if io_type == 'input': + subgraph.inputs = new_io_list + else: + subgraph.outputs = new_io_list + + o2o.log(f"New {io_name} indices: {[i+1 for i in range(len(new_io_list))]}") + + removed_count = len(io_list) - len(new_io_list) + return removed_count > 0, removed_count + + # Process all subgraphs using utility function + overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No tensors were removed.") + + # Save the model using utility function + o2o.save_model_to_stdout(model) + + +def main(): + parser = argparse.ArgumentParser( + description= + 'Remove input or output tensors from Circle model, keeping only specified tensor names' + ) + parser.add_argument('io_type', + choices=['input', 'output'], + help='Whether to process inputs or outputs') + parser.add_argument( + '--keep_by_name', + required=True, + help='Comma‑separated tensor names to keep (e.g., "tensorA,tensorB")') + # No file arguments needed; model is read from stdin and written to stdout + + args = parser.parse_args() + + # Parse the tensor names + try: + names_to_keep = parse_names(args.keep_by_name) + o2o.log(f"Tensor names to keep: {names_to_keep}") + except ValueError as e: + 
+        o2o.log(f"Error parsing tensor names: {e}")
+        sys.exit(1)
+
+    # Execute with error handling using utility function
+    # Directly invoke the processing function; I/O handled via stdin/stdout
+    remove_io_tensors(args.io_type, names_to_keep)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/circle2circle/remove.unused_tensors.py b/tools/circle2circle/remove.unused_tensors.py
new file mode 100755
index 00000000000..9a560bca6b8
--- /dev/null
+++ b/tools/circle2circle/remove.unused_tensors.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+
+import sys
+# import argparse  # Removed: script now uses stdin/stdout instead of file arguments
+import flatbuffers
+import circle
+import o2o  # For saving the model
+
+
+def get_tensor_name(tensor):
+    """Get tensor name as string, handling bytes conversion"""
+    if tensor.name:
+        return tensor.name.decode('utf-8') if isinstance(tensor.name,
+                                                         bytes) else tensor.name
+    return None
+
+
+def find_unused_tensors_in_subgraph(subgraph):
+    """
+    Finds and returns the indices of unused tensors in a given subgraph.
+    This function uses the Native API for read-only subgraph objects.
+
+    Args:
+        subgraph: The Circle read-only subgraph object.
+
+    Returns:
+        list: A list of integer indices representing unused tensors.
+    """
+    num_tensors = subgraph.TensorsLength()
+    if num_tensors == 0:
+        return []
+
+    used_tensor_indices = set()
+    output_tensor_indices = set()
+
+    # Collect output tensor indices
+    for i in range(subgraph.OutputsLength()):
+        output_tensor_indices.add(subgraph.Outputs(i))
+
+    # Collect input tensor indices from all operators
+    for i in range(subgraph.OperatorsLength()):
+        operator = subgraph.Operators(i)
+        if operator and operator.InputsLength():
+            for j in range(operator.InputsLength()):
+                input_tensor_index = operator.Inputs(j)
+                # In Circle schema, -1 indicates an optional input that is not used.
+                if input_tensor_index != -1:
+                    used_tensor_indices.add(input_tensor_index)
+
+    # A tensor is unused if it's not used by any operator AND not an output of the subgraph
+    unused_indices = []
+    for i in range(num_tensors):
+        if i not in used_tensor_indices and i not in output_tensor_indices:
+            unused_indices.append(i)
+
+    return unused_indices
+
+
+def remove_tensors_and_update_model(model, subgraph_index_to_modify,
+                                    tensor_indices_to_remove):
+    """
+    Removes specified tensors from the model and updates all relevant references.
+    This function uses the Object API for mutable model/subgraph/operator objects.
+
+    Args:
+        model: The mutable Circle model object (ModelT).
+        subgraph_index_to_modify (int): The index of the subgraph to modify.
+        tensor_indices_to_remove (list): A list of tensor indices to remove.
+                                         Must be sorted in descending order.
+
+    Returns:
+        list: The list of tensor indices that were actually removed.
+    """
+    if not model.subgraphs or subgraph_index_to_modify >= len(model.subgraphs):
+        o2o.log(
+            f"Error: Invalid subgraph index {subgraph_index_to_modify} for modification.")
+        return []
+
+    subgraph = model.subgraphs[subgraph_index_to_modify]
+    removed_indices = []
+
+    # Sort in descending order to avoid index shifting issues during removal
+    for tensor_idx in sorted(tensor_indices_to_remove, reverse=True):
+        if 0 <= tensor_idx < len(subgraph.tensors):
+            tensor_name = get_tensor_name(subgraph.tensors[tensor_idx])
+            o2o.log(
+                f"  Subgraph {subgraph_index_to_modify}: Removing tensor at index {tensor_idx}: {tensor_name}"
+            )
+            del subgraph.tensors[tensor_idx]
+            removed_indices.append(tensor_idx)
+        else:
+            o2o.log(
+                f"  Subgraph {subgraph_index_to_modify}: Warning: Tensor index {tensor_idx} out of bounds, skipping."
+            )
+
+    if not removed_indices:
+        return []
+
+    # Create a map for old index to new index after removal
+    new_indices_map = {}
+    current_new_idx = 0
+    # Iterate over original tensor count of this subgraph
+    original_tensor_count = len(subgraph.tensors) + len(removed_indices)
+    for old_idx in range(original_tensor_count):
+        if old_idx not in tensor_indices_to_remove:
+            new_indices_map[old_idx] = current_new_idx
+            current_new_idx += 1
+
+    # Update operator inputs/outputs
+    for op_idx, operator in enumerate(
+            subgraph.operators):  # Object API: subgraph.operators
+        if operator.inputs is not None:  # Object API: operator.inputs
+            updated_inputs = []
+            for j in range(len(operator.inputs)):  # Object API: len(operator.inputs)
+                old_input_idx = operator.inputs[j]  # Object API: operator.inputs[j]
+                if old_input_idx == -1:  # Optional empty input
+                    updated_inputs.append(-1)
+                elif old_input_idx in new_indices_map:
+                    updated_inputs.append(new_indices_map[old_input_idx])
+            operator.inputs = updated_inputs
+
+        if operator.outputs is not None:  # Object API: operator.outputs
+            updated_outputs = []
+            for j in range(len(operator.outputs)):  # Object API: len(operator.outputs)
+                old_output_idx = operator.outputs[j]  # Object API: operator.outputs[j]
+                if old_output_idx in new_indices_map:
+                    updated_outputs.append(new_indices_map[old_output_idx])
+            operator.outputs = updated_outputs
+
+        # Update intermediates if they exist
+        if operator.intermediates is not None:  # Object API: operator.intermediates
+            updated_intermediates = []
+            for j in range(len(
+                    operator.intermediates)):  # Object API: len(operator.intermediates)
+                old_intermediate_idx = operator.intermediates[
+                    j]  # Object API: operator.intermediates[j]
+                if old_intermediate_idx in new_indices_map:
+                    updated_intermediates.append(new_indices_map[old_intermediate_idx])
+            operator.intermediates = updated_intermediates
+
+    # Update subgraph inputs/outputs
+    if subgraph.inputs is not None:  # Object API: subgraph.inputs
+        updated_subgraph_inputs = []
+        for j in range(len(subgraph.inputs)):  # Object API: len(subgraph.inputs)
+            old_input_idx = subgraph.inputs[j]  # Object API: subgraph.inputs[j]
+            if old_input_idx in new_indices_map:
+                updated_subgraph_inputs.append(new_indices_map[old_input_idx])
+        subgraph.inputs = updated_subgraph_inputs
+
+    if subgraph.outputs is not None:  # Object API: subgraph.outputs
+        updated_subgraph_outputs = []
+        for j in range(len(subgraph.outputs)):  # Object API: len(subgraph.outputs)
+            old_output_idx = subgraph.outputs[j]  # Object API: subgraph.outputs[j]
+            if old_output_idx in new_indices_map:
+                updated_subgraph_outputs.append(new_indices_map[old_output_idx])
+        subgraph.outputs = updated_subgraph_outputs
+
+    return sorted(removed_indices)
+
+
+def main():
+    # Read the entire model from stdin
+    data = sys.stdin.buffer.read()
+    buf = bytearray(data)
+
+    # Create a read-only model (Native API) and a mutable copy (Object API)
+    model_ro = circle.Model.GetRootAsModel(buf, 0)
+    model = circle.ModelT.InitFromObj(model_ro)
+
+    total_unused_tensors_count = 0
+    model_changed = False
+
+    o2o.log(f"Processing {model_ro.SubgraphsLength()} subgraph(s)...")
+    for i in range(model_ro.SubgraphsLength()):
+        subgraph_ro = model_ro.Subgraphs(i)
+        if not subgraph_ro:
+            o2o.log(f"Warning: Could not read subgraph {i}. Skipping.")
+            continue
+
+        unused = find_unused_tensors_in_subgraph(subgraph_ro)
+        if not unused:
+            o2o.log(f"Subgraph {i}: No unused tensors found.")
+            continue
+
+        total_unused_tensors_count += len(unused)
+        o2o.log(
+            f"Subgraph {i}: Found {len(unused)} unused tensor(s): {', '.join(map(str, sorted(unused)))}"
+        )
+
+        actually_removed = remove_tensors_and_update_model(model, i, unused)
+        if actually_removed:
+            o2o.log(f"Subgraph {i}: Removed {len(actually_removed)} tensor(s).")
+            model_changed = True
+        else:
+            o2o.log(f"Subgraph {i}: No tensors were removed during the process.")
+
+    if total_unused_tensors_count == 0:
+        o2o.log("\nNo unused tensors found in any subgraph.")
+        o2o.save_model_to_stdout(model)
+        sys.exit(0)
+
+    o2o.log(
+        f"\nTotal unused tensors found across all subgraphs: {total_unused_tensors_count}"
+    )
+
+    if model_changed:
+        o2o.log("\nSaving modified model to stdout...")
+    else:
+        o2o.log(
+            "\nNo tensors were actually removed from any subgraph. Saving original model to stdout."
+ ) + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + main() diff --git a/tools/circle2circle/rename.io.remove_namespace.py b/tools/circle2circle/rename.io.remove_namespace.py new file mode 100755 index 00000000000..7f22b4fe1bf --- /dev/null +++ b/tools/circle2circle/rename.io.remove_namespace.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import sys +import re +import circle +import flatbuffers +import o2o + + +def load_model_from_stdin(): + """Load a Circle model from binary data read from stdin.""" + data = sys.stdin.buffer.read() + buf = bytearray(data) + model = circle.Model.GetRootAsModel(buf, 0) + model = circle.ModelT.InitFromObj(model) + return model + + +def save_model_to_stdout(model): + """Serialize a Circle model and write it to stdout as binary data.""" + builder = flatbuffers.Builder(1024) + builder.Finish(model.Pack(builder), b'CIR0') + sys.stdout.buffer.write(builder.Output()) + + +def remove_namespace_from_inputs_and_outputs(model): + """Remove namespace from tensor names within the given model.""" + pattern = r'(.*)::(.*)' + + def process_subgraph(subgraph): + """Process a single subgraph, renaming matching tensor names.""" + o2o.log( + f"Processing subgraph with {len(subgraph.inputs)} inputs and {len(subgraph.outputs)} outputs" + ) + renamed_count = 0 + + # Process input tensors + for input_tensor_index in subgraph.inputs: + tensor = subgraph.tensors[input_tensor_index] + was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( + tensor, pattern, lambda match: match.group(2)) + if was_renamed: + o2o.log(f"Renaming input tensor: {old_name} → {new_name}") + renamed_count += 1 + + # Process output tensors + for output_tensor_index in subgraph.outputs: + tensor = subgraph.tensors[output_tensor_index] + was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( + tensor, pattern, lambda match: match.group(2)) + if was_renamed: + o2o.log(f"Renaming output tensor: {old_name} → {new_name}") + renamed_count += 1 + + if 
renamed_count > 0: + o2o.log(f"Renamed {renamed_count} input/output tensors in this subgraph") + else: + o2o.log("No input/output tensors were renamed in this subgraph") + + return renamed_count > 0, renamed_count + + overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No tensors were modified.") + else: + o2o.log(f"Total tensors renamed across all subgraphs: {total_changes}") + + +def main(): + """Entry point: read model from stdin, process, write to stdout.""" + model = load_model_from_stdin() + remove_namespace_from_inputs_and_outputs(model) + save_model_to_stdout(model) + + +if __name__ == "__main__": + main() diff --git a/tools/circle2circle/rename.io.remove_prefix.py b/tools/circle2circle/rename.io.remove_prefix.py new file mode 100755 index 00000000000..b5b200486d9 --- /dev/null +++ b/tools/circle2circle/rename.io.remove_prefix.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import o2o +import re +import sys + + +def rename_input_tensors(prefix): + """Main function to rename tensors by removing the specified prefix""" + # Load the model using utility function + model = o2o.load_model_from_stdin() + + # Pattern to match tensor names that start with the specified prefix followed by anything + pattern = re.escape(prefix) + r'(.*)' + + def process_subgraph(subgraph): + """Process a single subgraph""" + o2o.log(f"Processing subgraph with {len(subgraph.tensors)} tensors") + + renamed_count = 0 + for tensor in subgraph.tensors: + was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( + tensor, pattern, lambda match: match.group(1)) + + if was_renamed: + o2o.log(f"Renaming tensor: {old_name} → {new_name}") + renamed_count += 1 + + if renamed_count > 0: + o2o.log(f"Renamed {renamed_count} tensors in this subgraph") + else: + o2o.log("No tensors were renamed in this subgraph") + + return renamed_count > 0, renamed_count + + # Process all subgraphs using utility function + overall_modified, 
total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No tensors were modified.") + + # Save the model using utility function + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + o2o.log("Usage: python rename_inputs.py ") + sys.exit(1) + + prefix = sys.argv[1] + + # Directly invoke processing; I/O handled via stdin/stdout + rename_input_tensors(prefix) diff --git a/tools/circle2circle/reshape.fc_weight.py b/tools/circle2circle/reshape.fc_weight.py new file mode 100755 index 00000000000..212a188d763 --- /dev/null +++ b/tools/circle2circle/reshape.fc_weight.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 + +import sys +import numpy as np +import circle +import o2o + + +def is_effectively_2d(shape): + """Check if a tensor shape is effectively 2D (all leading dimensions are 1)""" + return all(dim == 1 for dim in shape[:-2]) + + +def count_tensor_usage(model, tensor_index): + """Count how many operators use a specific tensor as input""" + count = 0 + for subgraph in model.subgraphs: + for operator in subgraph.operators: + if operator.inputs is not None: + for input_idx in operator.inputs: + if input_idx == tensor_index: + count += 1 + return count + + +def create_new_tensor(original_tensor, new_shape): + """Create a new tensor with the specified shape based on the original tensor""" + new_tensor = circle.TensorT() + new_tensor.shape = new_shape + new_tensor.type = original_tensor.type + new_tensor.buffer = original_tensor.buffer + new_tensor.name = original_tensor.name + "_reshaped" if original_tensor.name else None + new_tensor.quantization = original_tensor.quantization + new_tensor.isVariable = original_tensor.isVariable + new_tensor.sparsity = original_tensor.sparsity + new_tensor.shapeSignature = original_tensor.shapeSignature + new_tensor.hasRank = original_tensor.hasRank + new_tensor.variantTensors = original_tensor.variantTensors + new_tensor.compressionType = 
original_tensor.compressionType + return new_tensor + + +def modify_fully_connected_weights(): + """Main function to modify FullyConnected weights from effectively 2D to 2D""" + # Load the model using utility function + model = o2o.load_model_from_stdin() + + # Process each subgraph + for subgraph in model.subgraphs: + # Create a mapping from old tensor indices to new tensor indices + tensor_mapping = {} + + # First pass: identify and create new tensors for modification + for i, operator in enumerate(subgraph.operators): + # Check if this is a FullyConnected operator + opcode = model.operatorCodes[operator.opcodeIndex] + if opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: + # Get the weights tensor (typically the second input) + if len(operator.inputs) >= 2: + weights_index = operator.inputs[ + 1] # Weights is usually the second input + weights_tensor = subgraph.tensors[weights_index] + + # Check if the weights tensor is effectively 2D + if len(weights_tensor.shape) > 2 and is_effectively_2d( + weights_tensor.shape): + operator.builtinOptions.keepNumDims = True + # Check if this tensor is used by multiple operators + usage_count = count_tensor_usage(model, weights_index) + + if usage_count > 1: + # Create a new tensor for this operator to avoid affecting others + new_shape = weights_tensor.shape[ + -2:] # Remove leading dimensions of 1 + new_tensor = create_new_tensor(weights_tensor, new_shape) + + # Add the new tensor to the subgraph + new_tensor_index = len(subgraph.tensors) + subgraph.tensors.append(new_tensor) + + # Update the mapping for this specific operator + if i not in tensor_mapping: + tensor_mapping[i] = {} + tensor_mapping[i][weights_index] = new_tensor_index + else: + # Directly modify the tensor shape since it's only used once + weights_tensor.shape = weights_tensor.shape[-2:] + + # Second pass: update operator inputs based on the mapping + for i, operator in enumerate(subgraph.operators): + # Check if this is a FullyConnected operator + 
opcode = model.operatorCodes[operator.opcodeIndex] + if opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: + # Update inputs according to the mapping + if i in tensor_mapping: + for j, input_idx in enumerate(operator.inputs): + if input_idx in tensor_mapping[i]: + operator.inputs[j] = tensor_mapping[i][input_idx] + else: + # For tensors that were directly modified, just check if they need updating + if len(operator.inputs) >= 2: + weights_index = operator.inputs[1] + weights_tensor = subgraph.tensors[weights_index] + if is_effectively_2d(weights_tensor.shape): + # Update the shape to be truly 2D + weights_tensor.shape = weights_tensor.shape[-2:] + + # Save the model using utility function + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + # Directly invoke processing; I/O handled via stdin/stdout + modify_fully_connected_weights() diff --git a/tools/circle2circle/reshape.io.py b/tools/circle2circle/reshape.io.py new file mode 100755 index 00000000000..ce8bc4d20d7 --- /dev/null +++ b/tools/circle2circle/reshape.io.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import sys +import argparse +import circle +import o2o + + +def parse_shape(shape_str): + """Parse a shape string like '[1,16,30,4]' into a list of integers.""" + try: + # Strip surrounding brackets and whitespace, then split by commas + shape_str = shape_str.strip().strip('[]') + parts = shape_str.split(',') + shape = [int(p.strip()) for p in parts if p.strip()] + return shape + except Exception as e: + raise ValueError( + f"Invalid shape format '{shape_str}'. 
Expected format like [1,16,30,4]" + ) from e + + +def is_target_shape(shape, target_shape): + """Check if a tensor shape matches the target shape.""" + if len(shape) != len(target_shape): + return False + return list(shape) == target_shape + + +def reshape_input_tensors(io_type, target_shape, new_shape): + """Reshape input or output tensors from target_shape to new_shape.""" + model = o2o.load_model_from_stdin() + + for subgraph in model.subgraphs: + # Choose the appropriate tensor list based on io_type + io_list = subgraph.inputs if io_type == 'input' else subgraph.outputs + + for tensor_idx in io_list: + tensor = subgraph.tensors[tensor_idx] + if is_target_shape(tensor.shape, target_shape): + tensor.shape = new_shape + + o2o.save_model_to_stdout(model) + + +def main(): + parser = argparse.ArgumentParser( + description='Reshape input or output tensors by specifying target and new shapes.' + ) + parser.add_argument('io_type', + choices=['input', 'output'], + help='Whether to process input tensors or output tensors.') + parser.add_argument( + '--by_shape', + nargs=2, + metavar=('TARGET_SHAPE', 'NEW_SHAPE'), + required=True, + help= + 'Reshape tensors from TARGET_SHAPE to NEW_SHAPE. 
Example: --by_shape [1,16,30,4] [1,16,32,4]' + ) + # No file arguments needed; model is read from stdin and written to stdout + + args = parser.parse_args() + + # Parse the shape arguments + try: + target_shape = parse_shape(args.by_shape[0]) + new_shape = parse_shape(args.by_shape[1]) + except ValueError as e: + o2o.log(f"Error parsing shapes: {e}") + sys.exit(1) + + # Execute the reshaping with safe error handling + reshape_input_tensors(args.io_type, target_shape, new_shape) + + +if __name__ == "__main__": + main() diff --git a/tools/circle2circle/transpose.io.kvcache.py b/tools/circle2circle/transpose.io.kvcache.py new file mode 100755 index 00000000000..1f54fe6d034 --- /dev/null +++ b/tools/circle2circle/transpose.io.kvcache.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +import o2o +import re + + +def transpose_2d_3d(shape): + """Transpose the second and third dimensions of a 4D shape""" + if len(shape) != 4: + raise ValueError("Shape must be 4D to transpose second and third dimensions") + # Transpose shape: [d0, d1, d2, d3] -> [d0, d2, d1, d3] + return [shape[0], shape[2], shape[1], shape[3]] + + +def transpose_tensor_dimensions(): + """Main function to find tensors and transpose their dimensions""" + # Load the model using utility function + model = o2o.load_model_from_stdin() + + # Pattern to match tensor names like "past_key_values_key_cache_0", "past_key_values_key_cache_1", etc. + pattern = r'.*_cache_\d+' + + def process_subgraph(subgraph): + """Process a single subgraph""" + o2o.log(f"Processing subgraph with {len(subgraph.inputs)} input tensors") + + modified_count = 0 + for input_tensor_index in subgraph.inputs: + # Get the actual tensor object + tensor = subgraph.tensors[input_tensor_index] + + tensor_name = o2o.get_tensor_name(tensor) + if tensor_name and re.match(pattern, tensor_name): + o2o.log(f"Found input tensor: {tensor_name} with shape {tensor.shape}") + + if len(tensor.shape) == 4: + o2o.log( + f"Input tensor {tensor_name} is 4D. 
Transposing second and third dimensions." + ) + + # Transpose the second and third dimensions + original_shape = tensor.shape.copy() + new_shape = transpose_2d_3d(tensor.shape) + tensor.shape = new_shape + + o2o.log(f"Shape changed from {original_shape} to {new_shape}") + modified_count += 1 + else: + o2o.log(f"Input tensor {tensor_name} is not 4D. Skipping.") + + return modified_count > 0, modified_count + + # Process all subgraphs using utility function + overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No tensors were modified.") + + # Save the model using utility function + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + # Directly invoke processing; I/O handled via stdin/stdout + transpose_tensor_dimensions() From fb289bb65d9675819296a66241bca0d3ba420e68 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 23 Oct 2025 15:17:30 +0900 Subject: [PATCH 02/27] Add select.op.py --- tools/circle2circle/select.op.py | 389 +++++++++++++++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100755 tools/circle2circle/select.op.py diff --git a/tools/circle2circle/select.op.py b/tools/circle2circle/select.op.py new file mode 100755 index 00000000000..fefa85d62f5 --- /dev/null +++ b/tools/circle2circle/select.op.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 + +import sys +import argparse +import flatbuffers +import o2o + + +def parse_operator_indices(indices_str): + """Parse operator index string into a list of indices. 
+
+    Supports formats like:
+    - "0-181" (range)
+    - "0,5,10-15" (mixed)
+    - "0" (single index)
+
+    Args:
+        indices_str (str): String containing operator indices
+
+    Returns:
+        list: Sorted list of unique operator indices
+    """
+    if not indices_str:
+        return []
+
+    indices = set()
+
+    # Split by comma first
+    parts = indices_str.split(',')
+
+    for part in parts:
+        part = part.strip()
+        if not part:
+            continue
+
+        # Check if it's a range
+        if '-' in part:
+            try:
+                start, end = part.split('-', 1)
+                start_idx = int(start.strip())
+                end_idx = int(end.strip())
+
+                if start_idx < 0 or end_idx < 0:
+                    raise ValueError("Indices must be non-negative")
+
+                if start_idx > end_idx:
+                    raise ValueError(f"Invalid range: {start_idx} > {end_idx}")
+
+                indices.update(range(start_idx, end_idx + 1))
+            except ValueError as e:
+                # o2o.log() takes a single message and already writes to stderr;
+                # passing file=sys.stderr raised TypeError on this error path.
+                o2o.log(f"Error parsing range '{part}': {e}")
+                sys.exit(1)
+        else:
+            # Single index
+            try:
+                idx = int(part)
+                if idx < 0:
+                    raise ValueError("Index must be non-negative")
+                indices.add(idx)
+            except ValueError as e:
+                # Same fix: no file= kwarg on o2o.log().
+                o2o.log(f"Error parsing index '{part}': {e}")
+                sys.exit(1)
+
+    return sorted(list(indices))
+
+
+def analyze_tensor_connections(subgraph):
+    """Analyze all tensor connections in the subgraph.
+
+    Args:
+        subgraph: Circle subgraph object
+
+    Returns:
+        dict: Analysis results including tensor-to-operator mappings and subgraph I/O info
+    """
+    # Build mappings
+    tensor_to_def = {}  # tensor_idx -> operator_idx
+    tensor_to_use = {}  # tensor_idx -> [operator_idx, ...]
+    op_inputs = {}  # operator_idx -> [tensor_idx, ...]
+    op_outputs = {}  # operator_idx -> [tensor_idx, ...]
+
+    # Analyze all operators
+    for op_idx, operator in enumerate(subgraph.operators):
+        inputs = []
+        if operator.inputs is not None and len(operator.inputs) > 0:
+            inputs = list(operator.inputs)
+        outputs = []
+        if operator.outputs is not None and len(operator.outputs) > 0:
+            outputs = list(operator.outputs)
+
+        op_inputs[op_idx] = inputs
+        op_outputs[op_idx] = outputs
+
+        # Record tensor -> producer mapping
+        for output_idx in outputs:
+            if output_idx != -1:
+                tensor_to_def[output_idx] = op_idx
+
+        # Record tensor -> consumers mapping
+        for input_idx in inputs:
+            if input_idx != -1:
+                if input_idx not in tensor_to_use:
+                    tensor_to_use[input_idx] = []
+                tensor_to_use[input_idx].append(op_idx)
+
+    # Analyze subgraph I/O
+    subgraph_inputs = list(subgraph.inputs) if subgraph.inputs is not None else []
+    subgraph_outputs = list(subgraph.outputs) if subgraph.outputs is not None else []
+
+    return {
+        'tensor_to_producer': tensor_to_def,
+        'tensor_to_consumers': tensor_to_use,
+        'operator_to_inputs': op_inputs,
+        'operator_to_outputs': op_outputs,
+        'subgraph_inputs': subgraph_inputs,
+        'subgraph_outputs': subgraph_outputs
+    }
+
+
+def select_operators_and_update_model(model, subgraph_index, operator_indices_to_keep):
+    """Keep only specified operators in the model and remove all others.
+ + Args: + model: Circle model object (Object API) + subgraph_index (int): Index of subgraph to modify (assumed to be 0) + operator_indices_to_keep (list): List of operator indices to keep + + Returns: + tuple: (removed_operators_count, removed_operator_codes_count) + """ + if not model.subgraphs or subgraph_index >= len(model.subgraphs): + o2o.log(f"Error: Invalid subgraph index {subgraph_index}", file=sys.stderr) + return 0, 0 + + subgraph = model.subgraphs[subgraph_index] + + # Validate operator indices + max_operator_index = len(subgraph.operators) - 1 + invalid_indices = [ + idx for idx in operator_indices_to_keep if idx > max_operator_index + ] + if invalid_indices: + o2o.log( + f"Error: Operator indices {invalid_indices} exceed maximum index {max_operator_index}", + file=sys.stderr) + sys.exit(1) + + o2o.log( + f"Subgraph {subgraph_index}: Keeping {len(operator_indices_to_keep)} operator(s): {operator_indices_to_keep}", + file=sys.stderr) + + # Step 1: Determine which operators to remove + total_operators = len(subgraph.operators) + operator_indices_to_remove = [] + for i in range(total_operators): + if i not in operator_indices_to_keep: + operator_indices_to_remove.append(i) + + o2o.log( + f"Will remove {len(operator_indices_to_remove)} operator(s): {operator_indices_to_remove}", + file=sys.stderr) + + # Step 2: Analyze tensor connections BEFORE removing operators + connections = analyze_tensor_connections(subgraph) + + # Step 3: Remove operators in descending order to avoid index shifting + removed_operators = [] + for op_idx in sorted(operator_indices_to_remove, reverse=True): + del subgraph.operators[op_idx] + removed_operators.append(op_idx) + + # Step 4: Update subgraph I/O + # Remove subgraph inputs that were used only by removed operators + inputs_to_remove = set() + for input_idx in connections['subgraph_inputs']: + if input_idx in connections['tensor_to_consumers']: + # Check if all consumers of this input were removed + all_consumers_removed = 
True + for consumer_idx in connections['tensor_to_consumers'][input_idx]: + if consumer_idx not in operator_indices_to_remove: + all_consumers_removed = False + break + if all_consumers_removed: + inputs_to_remove.add(input_idx) + + # Remove subgraph outputs that were produced only by removed operators + outputs_to_remove = set() + for output_idx in connections['subgraph_outputs']: + if output_idx in connections['tensor_to_producer']: + if connections['tensor_to_producer'][ + output_idx] in operator_indices_to_remove: + outputs_to_remove.add(output_idx) + + # Update subgraph inputs + if inputs_to_remove: + new_inputs = [ + idx for idx in connections['subgraph_inputs'] if idx not in inputs_to_remove + ] + subgraph.inputs = new_inputs + o2o.log( + f"Removed {len(inputs_to_remove)} subgraph inputs: {sorted(inputs_to_remove)}", + file=sys.stderr) + + # Update subgraph outputs + if outputs_to_remove: + new_outputs = [ + idx for idx in connections['subgraph_outputs'] if idx not in outputs_to_remove + ] + subgraph.outputs = new_outputs + o2o.log( + f"Removed {len(outputs_to_remove)} subgraph outputs: {sorted(outputs_to_remove)}", + file=sys.stderr) + + # Step 5: Update operator inputs that reference outputs of removed operators + for op_idx, operator in enumerate(subgraph.operators): + if operator.inputs is not None and len(operator.inputs) > 0: + updated_inputs = [] + for input_idx in operator.inputs: + if input_idx != -1 and input_idx in connections['tensor_to_producer']: + producer_idx = connections['tensor_to_producer'][input_idx] + if producer_idx in operator_indices_to_remove: + # This input comes from a removed operator, set to -1 + updated_inputs.append(-1) + o2o.log( + f" Operator {op_idx}: Breaking input connection from removed operator {producer_idx}", + file=sys.stderr) + else: + updated_inputs.append(input_idx) + else: + updated_inputs.append(input_idx) + operator.inputs = updated_inputs + + # Step 6: Clean up unused OperatorCode entries + # Get OperatorCode 
usage by remaining operators + used_operator_codes = set() + for operator in subgraph.operators: + if operator.opcodeIndex is not None: + used_operator_codes.add(operator.opcodeIndex) + + # Find unused OperatorCode indices + unused_operator_codes = [] + for i, operator_code in enumerate(model.operatorCodes): + if i not in used_operator_codes: + unused_operator_codes.append(i) + + # Remove unused OperatorCode entries in descending order + removed_operator_codes = [] + for code_idx in sorted(unused_operator_codes, reverse=True): + operator_code = model.operatorCodes[code_idx] + if operator_code.builtinCode is not None: + op_name = f"builtin_code={operator_code.builtinCode}" + else: + op_name = f"custom_code={operator_code.customCode}" + o2o.log(f" Removing unused OperatorCode at index {code_idx}: {op_name}", + file=sys.stderr) + del model.operatorCodes[code_idx] + removed_operator_codes.append(code_idx) + + # Step 7: Update operator code indices in remaining operators + # Create mapping from old to new indices + old_to_new_code_indices = {} + new_idx = 0 + for old_idx in range(len(model.operatorCodes) + len(removed_operator_codes)): + if old_idx not in unused_operator_codes: + old_to_new_code_indices[old_idx] = new_idx + new_idx += 1 + + # Update operator code indices + for subgraph in model.subgraphs: + for operator in subgraph.operators: + if operator.opcodeIndex is not None: + old_code_idx = operator.opcodeIndex + if old_code_idx in old_to_new_code_indices: + operator.opcodeIndex = old_to_new_code_indices[old_code_idx] + + # Count tensor usage and definition + tensor_use_count = {} + tensor_def_count = {} + + # Count usage (inputs) + for operator in subgraph.operators: + if operator.inputs is not None: + for input_idx in operator.inputs: + if input_idx != -1: + tensor_use_count[input_idx] = tensor_use_count.get(input_idx, 0) + 1 + + # Count definition (outputs) + for operator in subgraph.operators: + if operator.outputs is not None: + for output_idx in 
operator.outputs: + if output_idx != -1: + tensor_def_count[output_idx] = tensor_def_count.get(output_idx, 0) + 1 + + # Add tensors with use_count == 0 to subgraph outputs + added_outputs = [] + current_outputs = set(subgraph.outputs) if subgraph.outputs is not None else set() + + # First check tensors that are in tensor_use_count + for tensor_idx, use_count in tensor_use_count.items(): + if use_count == 0 and tensor_idx not in current_outputs: + if subgraph.outputs is None: + subgraph.outputs = [] + subgraph.outputs.append(tensor_idx) + added_outputs.append(tensor_idx) + + # Then check all output tensors from operators (some might not be in tensor_use_count) + for operator in subgraph.operators: + if operator.outputs is not None: + for output_idx in operator.outputs: + if output_idx != -1: + use_count = tensor_use_count.get( + output_idx, 0) # Default to 0 if not in use_count + if use_count == 0 and output_idx not in current_outputs and output_idx not in added_outputs: + if subgraph.outputs is None: + subgraph.outputs = [] + subgraph.outputs.append(output_idx) + added_outputs.append(output_idx) + + if added_outputs: + o2o.log(f"Added tensors to subgraph outputs: {sorted(added_outputs)}", + file=sys.stderr) + + # Add tensors with def_count == 0 to subgraph inputs + added_inputs = [] + current_inputs = set(subgraph.inputs) if subgraph.inputs is not None else set() + for tensor_idx, def_count in tensor_def_count.items(): + if def_count == 0 and tensor_idx not in current_inputs: + if subgraph.inputs is None: + subgraph.inputs = [] + subgraph.inputs.append(tensor_idx) + added_inputs.append(tensor_idx) + + if added_inputs: + o2o.log(f"Added tensors to subgraph inputs: {sorted(added_inputs)}", + file=sys.stderr) + + return len(removed_operators), len(removed_operator_codes) + + +def main(): + parser = argparse.ArgumentParser( + description= + 'Select operators from Circle model by index range, then clean up unused tensors') + parser.add_argument('--by_id', + 
required=True, + help='Operator indices to keep (e.g., "0-181", "0,5,10-15")') + + args = parser.parse_args() + + # Parse the operator indices + try: + operator_indices_to_keep = parse_operator_indices(args.by_id) + o2o.log(f"Operator indices to keep: {operator_indices_to_keep}", file=sys.stderr) + except ValueError as e: + o2o.log(f"Error parsing operator indices: {e}", file=sys.stderr) + sys.exit(1) + + if not operator_indices_to_keep: + o2o.log("No valid operator indices specified", file=sys.stderr) + sys.exit(1) + + # Load the model + model = o2o.load_model_from_stdin() + + # Assume only one subgraph (index 0) + subgraph_index = 0 + + if not model.subgraphs or subgraph_index >= len(model.subgraphs): + o2o.log(f"Error: Model has no subgraph at index {subgraph_index}", + file=sys.stderr) + sys.exit(1) + + o2o.log(f"Model has {len(model.subgraphs[subgraph_index].operators)} operators", + file=sys.stderr) + + # Select operators (keep only specified ones) + removed_ops_count, removed_codes_count = select_operators_and_update_model( + model, subgraph_index, operator_indices_to_keep) + + o2o.log( + f"Removed {removed_ops_count} operators and {removed_codes_count} unused OperatorCode entries", + file=sys.stderr) + + # Save the model directly to stdout + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + main() From fef581c0ef97c6a189eb4e4f444882cf5e6f0a26 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 28 Oct 2025 13:57:12 +0900 Subject: [PATCH 03/27] Add retype.input_ids.py --- tools/circle2circle/README.md | 6 +++ tools/circle2circle/retype.input_ids.py | 57 +++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100755 tools/circle2circle/retype.input_ids.py diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md index c6d10283d5c..fa654848b68 100644 --- a/tools/circle2circle/README.md +++ b/tools/circle2circle/README.md @@ -127,6 +127,12 @@ Identifies and removes unused tensors from all subgraphs within a 
Circle model. ## +### `retype.input_ids.py` + +Finds tensors named `input_ids` and changes their data type from int64 to int32. This filter is useful for models that need to be compatible with hardware or frameworks that expect input_ids to be 32-bit integers instead of 64-bit integers. + +## + ### `gen_circle.*.py` diff --git a/tools/circle2circle/retype.input_ids.py b/tools/circle2circle/retype.input_ids.py new file mode 100755 index 00000000000..6a75187f2a6 --- /dev/null +++ b/tools/circle2circle/retype.input_ids.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +import o2o +import circle +import sys + + +def retype_input_ids(): + """Main function to change input_ids tensor type from int64 to int32""" + # Load the model using utility function + model = o2o.load_model_from_stdin() + + def process_subgraph(subgraph): + """Process a single subgraph to find and retype input_ids tensors""" + o2o.log(f"Processing subgraph with {len(subgraph.tensors)} tensors") + + retyped_count = 0 + for tensor in subgraph.tensors: + tensor_name = o2o.get_tensor_name(tensor) + + # Check if this is the input_ids tensor + if tensor_name == "input_ids": + # Check if current type is int64 + if tensor.type == circle.TensorType.INT64: + old_type = "int64" + new_type = "int32" + + # Change type to int32 + tensor.type = circle.TensorType.INT32 + + o2o.log(f"Retyped tensor: {tensor_name} {old_type} → {new_type}") + retyped_count += 1 + else: + o2o.log( + f"Found input_ids tensor but type is not int64 (current type: {tensor.type})" + ) + + if retyped_count > 0: + o2o.log(f"Retyped {retyped_count} input_ids tensors in this subgraph") + else: + o2o.log("No input_ids tensors were retyped in this subgraph") + + return retyped_count > 0, retyped_count + + # Process all subgraphs using utility function + overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No input_ids tensors were modified.") + + # Save the model using utility function 
+ o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + # Directly invoke processing; I/O handled via stdin/stdout + retype_input_ids() From 9f6438b26e8e438863bee8e0fbbeabf3e257aacf Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 11:05:25 +0900 Subject: [PATCH 04/27] Add fuse.attention --- tools/circle2circle/circle.py | 592 +++++++++++++++++++++++++- tools/circle2circle/fuse.attention.py | 234 ++++++++++ tools/circle2circle/o2o.py | 40 ++ tools/circle2circle/select.op.py | 66 ++- 4 files changed, 877 insertions(+), 55 deletions(-) create mode 100755 tools/circle2circle/fuse.attention.py diff --git a/tools/circle2circle/circle.py b/tools/circle2circle/circle.py index 76328683f10..dbc0fefca0b 100644 --- a/tools/circle2circle/circle.py +++ b/tools/circle2circle/circle.py @@ -34,12 +34,14 @@ class TensorType(object): UINT32 = 15 UINT16 = 16 INT4 = 17 + BFLOAT16 = 18 class QuantizationDetails(object): NONE = 0 CustomQuantization = 1 MXQuantization = 2 + BlockwiseQuantization = 3 def QuantizationDetailsCreator(unionType, table): @@ -50,6 +52,8 @@ def QuantizationDetailsCreator(unionType, table): return CustomQuantizationT.InitFromBuf(table.Bytes, table.Pos) if unionType == QuantizationDetails().MXQuantization: return MXQuantizationT.InitFromBuf(table.Bytes, table.Pos) + if unionType == QuantizationDetails().BlockwiseQuantization: + return BlockwiseQuantizationT.InitFromBuf(table.Bytes, table.Pos) return None @@ -298,6 +302,10 @@ class BuiltinOperator(object): DILATE = 203 STABLEHLO_RNG_BIT_GENERATOR = 204 REDUCE_WINDOW = 205 + STABLEHLO_COMPOSITE = 206 + STABLEHLO_SHIFT_LEFT = 207 + STABLEHLO_CBRT = 208 + STABLEHLO_CASE = 209 class BuiltinOptions(object): @@ -735,6 +743,9 @@ class BuiltinOptions2(object): DilateOptions = 18 StablehloRngBitGeneratorOptions = 19 ReduceWindowOptions = 20 + StableHLOCompositeOptions = 21 + StablehloShiftLeftOptions = 22 + StablehloCaseOptions = 23 def BuiltinOptions2Creator(unionType, table): @@ -781,6 +792,12 
@@ def BuiltinOptions2Creator(unionType, table): return StablehloRngBitGeneratorOptionsT.InitFromBuf(table.Bytes, table.Pos) if unionType == BuiltinOptions2().ReduceWindowOptions: return ReduceWindowOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StableHLOCompositeOptions: + return StableHLOCompositeOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloShiftLeftOptions: + return StablehloShiftLeftOptionsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == BuiltinOptions2().StablehloCaseOptions: + return StablehloCaseOptionsT.InitFromBuf(table.Bytes, table.Pos) return None @@ -1093,6 +1110,117 @@ def Pack(self, builder): return mxquantization +class BlockwiseQuantization(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = BlockwiseQuantization() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsBlockwiseQuantization(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def BlockwiseQuantizationBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # BlockwiseQuantization + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # BlockwiseQuantization + def Scales(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # BlockwiseQuantization + def ZeroPoints(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # BlockwiseQuantization + def BlockSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def BlockwiseQuantizationStart(builder): + builder.StartObject(3) + + +def BlockwiseQuantizationAddScales(builder, scales): + builder.PrependInt32Slot(0, scales, 0) + + +def BlockwiseQuantizationAddZeroPoints(builder, zeroPoints): + builder.PrependInt32Slot(1, zeroPoints, 0) + + +def BlockwiseQuantizationAddBlockSize(builder, blockSize): + builder.PrependInt32Slot(2, blockSize, 0) + + +def BlockwiseQuantizationEnd(builder): + return builder.EndObject() + + +class BlockwiseQuantizationT(object): + + # BlockwiseQuantizationT + def __init__(self): + self.scales = 0 # type: int + self.zeroPoints = 0 # type: int + self.blockSize = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + blockwiseQuantization = BlockwiseQuantization() + blockwiseQuantization.Init(buf, pos) + return cls.InitFromObj(blockwiseQuantization) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, blockwiseQuantization): + x = BlockwiseQuantizationT() + x._UnPack(blockwiseQuantization) + return x + + # BlockwiseQuantizationT + def _UnPack(self, blockwiseQuantization): + if blockwiseQuantization is None: + return + self.scales = blockwiseQuantization.Scales() + self.zeroPoints = blockwiseQuantization.ZeroPoints() + self.blockSize = blockwiseQuantization.BlockSize() + + # BlockwiseQuantizationT + def Pack(self, builder): + BlockwiseQuantizationStart(builder) + BlockwiseQuantizationAddScales(builder, self.scales) + BlockwiseQuantizationAddZeroPoints(builder, self.zeroPoints) + BlockwiseQuantizationAddBlockSize(builder, self.blockSize) + blockwiseQuantization = BlockwiseQuantizationEnd(builder) + return blockwiseQuantization + + class QuantizationParameters(object): __slots__ = ['_tab'] @@ -1332,7 +1460,7 @@ def __init__(self): self.scale = None # type: List[float] self.zeroPoint = None # type: List[int] self.detailsType = 0 # type: int - self.details = None # type: Union[None, CustomQuantizationT, MXQuantizationT] + self.details = None # type: Union[None, CustomQuantizationT, MXQuantizationT, BlockwiseQuantizationT] self.quantizedDimension = 0 # type: int @classmethod @@ -6712,6 +6840,141 @@ def Pack(self, builder): return stablehloScatterOptions +class StablehloCaseOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloCaseOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloCaseOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloCaseOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloCaseOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StablehloCaseOptions + def BranchSubgraphIndices(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Int32Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StablehloCaseOptions + def BranchSubgraphIndicesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + +def StablehloCaseOptionsStart(builder): + builder.StartObject(1) + + +def StablehloCaseOptionsAddBranchSubgraphIndices(builder, branchSubgraphIndices): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(branchSubgraphIndices), 0) + + +def StablehloCaseOptionsStartBranchSubgraphIndicesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + +def StablehloCaseOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StablehloCaseOptionsT(object): + + # StablehloCaseOptionsT + def __init__(self): + self.branchSubgraphIndices = None # type: List[int] + + @classmethod + def 
InitFromBuf(cls, buf, pos): + stablehloCaseOptions = StablehloCaseOptions() + stablehloCaseOptions.Init(buf, pos) + return cls.InitFromObj(stablehloCaseOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloCaseOptions): + x = StablehloCaseOptionsT() + x._UnPack(stablehloCaseOptions) + return x + + # StablehloCaseOptionsT + def _UnPack(self, stablehloCaseOptions): + if stablehloCaseOptions is None: + return + if not stablehloCaseOptions.BranchSubgraphIndicesIsNone(): + if np is None: + self.branchSubgraphIndices = [] + for i in range(stablehloCaseOptions.BranchSubgraphIndicesLength()): + self.branchSubgraphIndices.append( + stablehloCaseOptions.BranchSubgraphIndices(i)) + else: + self.branchSubgraphIndices = stablehloCaseOptions.BranchSubgraphIndicesAsNumpy( + ) + + # StablehloCaseOptionsT + def Pack(self, builder): + if self.branchSubgraphIndices is not None: + if np is not None and type(self.branchSubgraphIndices) is np.ndarray: + branchSubgraphIndices = builder.CreateNumpyVector( + self.branchSubgraphIndices) + else: + StablehloCaseOptionsStartBranchSubgraphIndicesVector( + builder, len(self.branchSubgraphIndices)) + for i in reversed(range(len(self.branchSubgraphIndices))): + builder.PrependInt32(self.branchSubgraphIndices[i]) + branchSubgraphIndices = builder.EndVector() + StablehloCaseOptionsStart(builder) + if self.branchSubgraphIndices is not None: + StablehloCaseOptionsAddBranchSubgraphIndices(builder, branchSubgraphIndices) + stablehloCaseOptions = StablehloCaseOptionsEnd(builder) + return stablehloCaseOptions + + class StablehloRngBitGeneratorOptions(object): __slots__ = ['_tab'] @@ -18844,20 +19107,9 @@ def AttentionOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # AttentionOptions - def 
LayerIdx(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - def AttentionOptionsStart(builder): - builder.StartObject(1) - - -def AttentionOptionsAddLayerIdx(builder, layerIdx): - builder.PrependInt32Slot(0, layerIdx, 0) + builder.StartObject(0) def AttentionOptionsEnd(builder): @@ -18868,7 +19120,7 @@ class AttentionOptionsT(object): # AttentionOptionsT def __init__(self): - self.layerIdx = 0 # type: int + pass @classmethod def InitFromBuf(cls, buf, pos): @@ -18891,12 +19143,10 @@ def InitFromObj(cls, attentionOptions): def _UnPack(self, attentionOptions): if attentionOptions is None: return - self.layerIdx = attentionOptions.LayerIdx() # AttentionOptionsT def Pack(self, builder): AttentionOptionsStart(builder) - AttentionOptionsAddLayerIdx(builder, self.layerIdx) attentionOptions = AttentionOptionsEnd(builder) return attentionOptions @@ -19030,6 +19280,282 @@ def Pack(self, builder): return operatorCode +class StableHLOCompositeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StableHLOCompositeOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStableHLOCompositeOptions(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StableHLOCompositeOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StableHLOCompositeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # StableHLOCompositeOptions + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # StableHLOCompositeOptions + def DecompositionSubgraphIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # StableHLOCompositeOptions + def CompositeAttributes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get( + flatbuffers.number_types.Uint8Flags, + a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # StableHLOCompositeOptions + def CompositeAttributesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # StableHLOCompositeOptions + def CompositeAttributesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # StableHLOCompositeOptions + def CompositeAttributesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # StableHLOCompositeOptions + def CompositeAttributesFormat(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + 
self._tab.Pos) + return 0 + + # StableHLOCompositeOptions + def Version(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def StableHLOCompositeOptionsStart(builder): + builder.StartObject(5) + + +def StableHLOCompositeOptionsAddName(builder, name): + builder.PrependUOffsetTRelativeSlot( + 0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) + + +def StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, + decompositionSubgraphIndex): + builder.PrependInt32Slot(1, decompositionSubgraphIndex, 0) + + +def StableHLOCompositeOptionsAddCompositeAttributes(builder, compositeAttributes): + builder.PrependUOffsetTRelativeSlot( + 2, flatbuffers.number_types.UOffsetTFlags.py_type(compositeAttributes), 0) + + +def StableHLOCompositeOptionsStartCompositeAttributesVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + + +def StableHLOCompositeOptionsAddCompositeAttributesFormat(builder, + compositeAttributesFormat): + builder.PrependInt8Slot(3, compositeAttributesFormat, 0) + + +def StableHLOCompositeOptionsAddVersion(builder, version): + builder.PrependInt32Slot(4, version, 0) + + +def StableHLOCompositeOptionsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + + +class StableHLOCompositeOptionsT(object): + + # StableHLOCompositeOptionsT + def __init__(self): + self.name = None # type: str + self.decompositionSubgraphIndex = 0 # type: int + self.compositeAttributes = None # type: List[int] + self.compositeAttributesFormat = 0 # type: int + self.version = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + stableHlocompositeOptions = StableHLOCompositeOptions() + stableHlocompositeOptions.Init(buf, pos) + return cls.InitFromObj(stableHlocompositeOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = 
flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stableHlocompositeOptions): + x = StableHLOCompositeOptionsT() + x._UnPack(stableHlocompositeOptions) + return x + + # StableHLOCompositeOptionsT + def _UnPack(self, stableHlocompositeOptions): + if stableHlocompositeOptions is None: + return + self.name = stableHlocompositeOptions.Name() + self.decompositionSubgraphIndex = stableHlocompositeOptions.DecompositionSubgraphIndex( + ) + if not stableHlocompositeOptions.CompositeAttributesIsNone(): + if np is None: + self.compositeAttributes = [] + for i in range(stableHlocompositeOptions.CompositeAttributesLength()): + self.compositeAttributes.append( + stableHlocompositeOptions.CompositeAttributes(i)) + else: + self.compositeAttributes = stableHlocompositeOptions.CompositeAttributesAsNumpy( + ) + self.compositeAttributesFormat = stableHlocompositeOptions.CompositeAttributesFormat( + ) + self.version = stableHlocompositeOptions.Version() + + # StableHLOCompositeOptionsT + def Pack(self, builder): + if self.name is not None: + name = builder.CreateString(self.name) + if self.compositeAttributes is not None: + if np is not None and type(self.compositeAttributes) is np.ndarray: + compositeAttributes = builder.CreateNumpyVector(self.compositeAttributes) + else: + StableHLOCompositeOptionsStartCompositeAttributesVector( + builder, len(self.compositeAttributes)) + for i in reversed(range(len(self.compositeAttributes))): + builder.PrependUint8(self.compositeAttributes[i]) + compositeAttributes = builder.EndVector() + StableHLOCompositeOptionsStart(builder) + if self.name is not None: + StableHLOCompositeOptionsAddName(builder, name) + StableHLOCompositeOptionsAddDecompositionSubgraphIndex( + builder, self.decompositionSubgraphIndex) + if self.compositeAttributes is not None: + StableHLOCompositeOptionsAddCompositeAttributes(builder, compositeAttributes) + 
StableHLOCompositeOptionsAddCompositeAttributesFormat( + builder, self.compositeAttributesFormat) + StableHLOCompositeOptionsAddVersion(builder, self.version) + stableHlocompositeOptions = StableHLOCompositeOptionsEnd(builder) + return stableHlocompositeOptions + + +class StablehloShiftLeftOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = StablehloShiftLeftOptions() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsStablehloShiftLeftOptions(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + + @classmethod + def StablehloShiftLeftOptionsBufferHasIdentifier(cls, + buf, + offset, + size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, + offset, + b"\x43\x49\x52\x30", + size_prefixed=size_prefixed) + + # StablehloShiftLeftOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def StablehloShiftLeftOptionsStart(builder): + builder.StartObject(0) + + +def StablehloShiftLeftOptionsEnd(builder): + return builder.EndObject() + + +class StablehloShiftLeftOptionsT(object): + + # StablehloShiftLeftOptionsT + def __init__(self): + pass + + @classmethod + def InitFromBuf(cls, buf, pos): + stablehloShiftLeftOptions = StablehloShiftLeftOptions() + stablehloShiftLeftOptions.Init(buf, pos) + return cls.InitFromObj(stablehloShiftLeftOptions) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos + n) + + @classmethod + def InitFromObj(cls, stablehloShiftLeftOptions): + x = StablehloShiftLeftOptionsT() + x._UnPack(stablehloShiftLeftOptions) + return x + + # StablehloShiftLeftOptionsT + def _UnPack(self, stablehloShiftLeftOptions): + if stablehloShiftLeftOptions is None: + return + + # StablehloShiftLeftOptionsT + def 
Pack(self, builder): + StablehloShiftLeftOptionsStart(builder) + stablehloShiftLeftOptions = StablehloShiftLeftOptionsEnd(builder) + return stablehloShiftLeftOptions + + class Operator(object): __slots__ = ['_tab'] @@ -19263,9 +19789,16 @@ def BuiltinOptions2(self): return obj return None + # Operator + def DebugMetadataIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return -1 + def OperatorStart(builder): - builder.StartObject(13) + builder.StartObject(14) def OperatorAddOpcodeIndex(builder, opcodeIndex): @@ -19347,6 +19880,10 @@ def OperatorAddBuiltinOptions2(builder, builtinOptions2): 12, flatbuffers.number_types.UOffsetTFlags.py_type(builtinOptions2), 0) +def OperatorAddDebugMetadataIndex(builder, debugMetadataIndex): + builder.PrependInt32Slot(13, debugMetadataIndex, -1) + + def OperatorEnd(builder): return builder.EndObject() @@ -19373,7 +19910,8 @@ def __init__(self): self.largeCustomOptionsOffset = 0 # type: int self.largeCustomOptionsSize = 0 # type: int self.builtinOptions2Type = 0 # type: int - self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT, ReduceWindowOptionsT] + self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, 
StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT, ReduceWindowOptionsT, StableHLOCompositeOptionsT, StablehloShiftLeftOptionsT, StablehloCaseOptionsT] + self.debugMetadataIndex = -1 # type: int @classmethod def InitFromBuf(cls, buf, pos): @@ -19441,6 +19979,7 @@ def _UnPack(self, operator): self.builtinOptions2Type = operator.BuiltinOptions2Type() self.builtinOptions2 = BuiltinOptions2Creator(self.builtinOptions2Type, operator.BuiltinOptions2()) + self.debugMetadataIndex = operator.DebugMetadataIndex() # OperatorT def Pack(self, builder): @@ -19511,6 +20050,7 @@ def Pack(self, builder): OperatorAddBuiltinOptions2Type(builder, self.builtinOptions2Type) if self.builtinOptions2 is not None: OperatorAddBuiltinOptions2(builder, builtinOptions2) + OperatorAddDebugMetadataIndex(builder, self.debugMetadataIndex) operator = OperatorEnd(builder) return operator @@ -19654,9 +20194,16 @@ def Name(self): return self._tab.String(o + self._tab.Pos) return None + # SubGraph + def DebugMetadataIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return -1 + def SubGraphStart(builder): - builder.StartObject(6) + builder.StartObject(7) def SubGraphAddTensors(builder, tensors): @@ -19700,6 +20247,10 @@ def SubGraphAddName(builder, name): 4, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) +def SubGraphAddDebugMetadataIndex(builder, debugMetadataIndex): + builder.PrependInt32Slot(6, debugMetadataIndex, -1) + + def SubGraphEnd(builder): return builder.EndObject() @@ -19719,6 +20270,7 @@ def __init__(self): self.outputs = None # type: List[int] self.operators = None # type: List[OperatorT] 
self.name = None # type: str + self.debugMetadataIndex = -1 # type: int @classmethod def InitFromBuf(cls, buf, pos): @@ -19772,6 +20324,7 @@ def _UnPack(self, subGraph): operator_ = OperatorT.InitFromObj(subGraph.Operators(i)) self.operators.append(operator_) self.name = subGraph.Name() + self.debugMetadataIndex = subGraph.DebugMetadataIndex() # SubGraphT def Pack(self, builder): @@ -19820,6 +20373,7 @@ def Pack(self, builder): SubGraphAddOperators(builder, operators) if self.name is not None: SubGraphAddName(builder, name) + SubGraphAddDebugMetadataIndex(builder, self.debugMetadataIndex) subGraph = SubGraphEnd(builder) return subGraph diff --git a/tools/circle2circle/fuse.attention.py b/tools/circle2circle/fuse.attention.py new file mode 100755 index 00000000000..597ddd7d11d --- /dev/null +++ b/tools/circle2circle/fuse.attention.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 + +import numpy as np +import circle +import o2o + + +def find_operator_by_output(subgraph, output_tensor_index): + """Find the first operator that produces the given output tensor index.""" + for op_idx, operator in enumerate(subgraph.operators): + if operator.outputs and output_tensor_index in operator.outputs: + return op_idx, operator + return None, None + + +def find_attention_blocks(model, subgraph): + """Find all attention blocks in the subgraph.""" + attention_blocks = [] + + # Pattern: 45 operators per attention block + # First block: operators 19-63 + # Second block: operators 84-128 + # Third block: operators 149-193 + # Fourth block: operators 194-238 + # Fifth block: operators 239-283 + # Sixth block: operators 284-328 + # Seventh block: operators 329-373 + # Eighth block: operators 374-418 + + for layer_idx in range(8): + start_op = 20 + (layer_idx * 64) # 64 operators between blocks, starting from 20 + end_op = start_op + 43 # 44 operators total (20-63 inclusive) + + if end_op < len(subgraph.operators): + # Verify this is an attention block by checking key operators + # Should 
start with FULLY_CONNECTED and end with FULLY_CONNECTED + first_op = subgraph.operators[start_op] + last_op = subgraph.operators[end_op] + + first_opcode = model.operatorCodes[first_op.opcodeIndex] + last_opcode = model.operatorCodes[last_op.opcodeIndex] + + if (first_opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED and + last_opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED): + + attention_blocks.append({ + 'layer_idx': + layer_idx, + 'start_op': + start_op, + 'end_op': + end_op, + 'operators': + subgraph.operators[start_op:end_op + 1] + }) + o2o.log( + f"Found attention block {layer_idx}: operators {start_op}-{end_op}") + else: + o2o.log( + f"Pattern mismatch for block starting at {start_op}: expected FULLY_CONNECTED->FULLY_CONNECTED, got {first_opcode.builtinCode}->{last_opcode.builtinCode}" + ) + else: + o2o.log( + f"Block {layer_idx} exceeds operator count (start_op={start_op}, total={len(subgraph.operators)})" + ) + + return attention_blocks + + +def map_attention_inputs(subgraph, block, model): + """Map the 11 input tensors for attention fusion.""" + layer_idx = block['layer_idx'] + + # 1. hidden_states (RMSNorm output) + hidden_states_tensor = o2o.get_tensor_by_index(subgraph, + block['operators'][0].inputs[0]) + if not hidden_states_tensor: + o2o.log(f"Could not find hidden_states tensor for layer {layer_idx}") + return None + + # 2. wq (query weight) + wq_name = f"tico::p_model_layers_{layer_idx}_self_attn_q_proj_weight" + wq_idx = o2o.get_tensor_index_by_name(subgraph, wq_name) + if wq_idx == -1: + o2o.log(f"Could not find wq tensor: {wq_name}") + return None + + # 3. wk (key weight) + wk_name = f"tico::p_model_layers_{layer_idx}_self_attn_k_proj_weight" + wk_idx = o2o.get_tensor_index_by_name(subgraph, wk_name) + if wk_idx == -1: + o2o.log(f"Could not find wk tensor: {wk_name}") + return None + + # 4. 
wv (value weight) + wv_name = f"tico::p_model_layers_{layer_idx}_self_attn_v_proj_weight" + wv_idx = o2o.get_tensor_index_by_name(subgraph, wv_name) + if wv_idx == -1: + o2o.log(f"Could not find wv tensor: {wv_name}") + return None + + # 5. wo (output weight) + wo_name = f"tico::p_model_layers_{layer_idx}_self_attn_o_proj_weight" + wo_idx = o2o.get_tensor_index_by_name(subgraph, wo_name) + if wo_idx == -1: + o2o.log(f"Could not find wo tensor: {wo_name}") + return None + + # 6. position_cos + position_cos_idx = o2o.get_tensor_index_by_name( + subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::mul_1") + if position_cos_idx == -1: + o2o.log("Could not find position_cos tensor") + return None + + # 7. position_sin + position_sin_idx = o2o.get_tensor_index_by_name( + subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::mul_2") + if position_sin_idx == -1: + o2o.log("Could not find position_sin tensor") + return None + + # 8. attention_mask + attention_mask_idx = o2o.get_tensor_index_by_name( + subgraph, "transformers.models.llama.modeling_llama.LlamaModel::mul") + if attention_mask_idx == -1: + o2o.log("Could not find attention_mask tensor") + return None + + # 9. past_key + past_key_name = f"tico::past_key_values_key_cache_{layer_idx}" + past_key_idx = o2o.get_tensor_index_by_name(subgraph, past_key_name) + if past_key_idx == -1: + o2o.log(f"Could not find past_key tensor: {past_key_name}") + return None + + # 10. past_value + past_value_name = f"tico::past_key_values_value_cache_{layer_idx}" + past_value_idx = o2o.get_tensor_index_by_name(subgraph, past_value_name) + if past_value_idx == -1: + o2o.log(f"Could not find past_value tensor: {past_value_name}") + return None + + # 11. 
cache_position + cache_position_idx = o2o.get_tensor_index_by_name(subgraph, "tico::cache_position") + if cache_position_idx == -1: + o2o.log("Could not find cache_position tensor") + return None + + # Find the tensor index for hidden_states (not buffer index) + hidden_states_idx = o2o.get_tensor_index_by_name( + subgraph, + o2o.get_tensor_name(hidden_states_tensor)) if hidden_states_tensor.name else -1 + if hidden_states_idx == -1: + o2o.log(f"Could not find tensor index for hidden_states in layer {layer_idx}") + return None + + return [ + hidden_states_idx, # hidden_states (tensor index, not buffer index) + wq_idx, # wq + wk_idx, # wk + wv_idx, # wv + wo_idx, # wo + position_cos_idx, # position_cos + position_sin_idx, # position_sin + attention_mask_idx, # attention_mask + past_key_idx, # past_key + past_value_idx, # past_value + cache_position_idx # cache_position + ] + + +def fuse_attention(): + """Main function to fuse attention operators.""" + o2o.log("Loading model from stdin") + model = o2o.load_model_from_stdin() + + if not model.subgraphs: + o2o.log("Model has no subgraphs. 
Exiting.") + return + + subgraph = model.subgraphs[0] # Assuming single subgraph for now + attention_blocks = find_attention_blocks(model, subgraph) + + o2o.log(f"Found {len(attention_blocks)} attention blocks to fuse") + + operators_to_remove = [] + + for block in attention_blocks: + # Map input tensors + input_indices = map_attention_inputs(subgraph, block, model) + if input_indices is None: + o2o.log( + f"Skipping attention block {block['layer_idx']} due to missing inputs") + continue + + # Create ATTENTION operator + attention_op = circle.OperatorT() + attention_op.opcodeIndex = o2o.get_or_create_operator_code( + model, circle.BuiltinOperator.ATTENTION) + attention_op.inputs = input_indices + attention_op.outputs = [block['operators'][-1].outputs[0] + ] # Use last operator's output + + # Configure AttentionOptions (empty since it's deprecated) + attention_op.builtinOptionsType = circle.BuiltinOptions.AttentionOptions + attention_op.builtinOptions = circle.AttentionOptionsT() + + # Replace the first operator with ATTENTION operator + start_idx = block['start_op'] + subgraph.operators[start_idx] = attention_op + + # Mark intermediate operators for removal (except the first one which we replaced) + for i in range(block['end_op'], start_idx, -1): + operators_to_remove.append(i) + + o2o.log( + f"Fused attention block {block['layer_idx']}: operators {block['start_op']}-{block['end_op']} -> ATTENTION" + ) + + # Remove marked operators in reverse order to avoid index shifting + for i in sorted(operators_to_remove, reverse=True): + if 0 <= i < len(subgraph.operators): + del subgraph.operators[i] + + o2o.log(f"Removed {len(operators_to_remove)} intermediate operators") + o2o.log(f"Model now has {len(subgraph.operators)} operators") + + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + # Directly invoke processing; I/O handled via stdin/stdout + fuse_attention() diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index 
4dc4bbfc6b4..27f716838d2 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -111,6 +111,46 @@ def rename_tensor_if_matches(tensor, pattern, replacement_func): return False, None, None +def get_tensor_by_index(subgraph, index): + """Safely get tensor by its index.""" + if 0 <= index < len(subgraph.tensors): + return subgraph.tensors[index] + return None + + +def get_tensor_index_by_name(subgraph, name): + """Find tensor index by name, handling byte strings.""" + name_bytes = name.encode('utf-8') # Convert str to bytes for comparison + for i, tensor in enumerate(subgraph.tensors): + if tensor.name and tensor.name == name_bytes: + return i + return -1 # Not found + + +def is_tensor_constant(tensor, model_buffers): + """Check if a tensor is constant by verifying its buffer.""" + if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): + # A non-zero buffer index that points to a valid buffer typically means it's constant. + # The 0th buffer is always an empty buffer. 
+ return True + return False + + +def get_or_create_operator_code(model, builtin_op_type): + """Get the index of an operator code, or create it if it doesn't exist.""" + for i, op_code in enumerate(model.operatorCodes): + if op_code.builtinCode == builtin_op_type: + return i + + # If not found, create a new one + new_op_code = circle.OperatorCodeT() + new_op_code.builtinCode = builtin_op_type + new_op_code.deprecatedBuiltinCode = builtin_op_type + new_op_code.version = 1 # Default version + model.operatorCodes.append(new_op_code) + return len(model.operatorCodes) - 1 + + def safe_execute(main_func, input_file, output_file, diff --git a/tools/circle2circle/select.op.py b/tools/circle2circle/select.op.py index fefa85d62f5..9a49fd60016 100755 --- a/tools/circle2circle/select.op.py +++ b/tools/circle2circle/select.op.py @@ -48,7 +48,7 @@ def parse_operator_indices(indices_str): indices.update(range(start_idx, end_idx + 1)) except ValueError as e: - o2o.log(f"Error parsing range '{part}': {e}", file=sys.stderr) + o2o.log(f"Error parsing range '{part}': {e}") sys.exit(1) else: # Single index @@ -58,7 +58,7 @@ def parse_operator_indices(indices_str): raise ValueError("Index must be non-negative") indices.add(idx) except ValueError as e: - o2o.log(f"Error parsing index '{part}': {e}", file=sys.stderr) + o2o.log(f"Error parsing index '{part}': {e}") sys.exit(1) return sorted(list(indices)) @@ -129,7 +129,7 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to tuple: (removed_operators_count, removed_operator_codes_count) """ if not model.subgraphs or subgraph_index >= len(model.subgraphs): - o2o.log(f"Error: Invalid subgraph index {subgraph_index}", file=sys.stderr) + o2o.log(f"Error: Invalid subgraph index {subgraph_index}") return 0, 0 subgraph = model.subgraphs[subgraph_index] @@ -141,13 +141,13 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to ] if invalid_indices: o2o.log( - f"Error: Operator indices 
{invalid_indices} exceed maximum index {max_operator_index}", - file=sys.stderr) + f"Error: Operator indices {invalid_indices} exceed maximum index {max_operator_index}" + ) sys.exit(1) o2o.log( - f"Subgraph {subgraph_index}: Keeping {len(operator_indices_to_keep)} operator(s): {operator_indices_to_keep}", - file=sys.stderr) + f"Subgraph {subgraph_index}: Keeping {len(operator_indices_to_keep)} operator(s): {operator_indices_to_keep}" + ) # Step 1: Determine which operators to remove total_operators = len(subgraph.operators) @@ -157,8 +157,8 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to operator_indices_to_remove.append(i) o2o.log( - f"Will remove {len(operator_indices_to_remove)} operator(s): {operator_indices_to_remove}", - file=sys.stderr) + f"Will remove {len(operator_indices_to_remove)} operator(s): {operator_indices_to_remove}" + ) # Step 2: Analyze tensor connections BEFORE removing operators connections = analyze_tensor_connections(subgraph) @@ -198,8 +198,8 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to ] subgraph.inputs = new_inputs o2o.log( - f"Removed {len(inputs_to_remove)} subgraph inputs: {sorted(inputs_to_remove)}", - file=sys.stderr) + f"Removed {len(inputs_to_remove)} subgraph inputs: {sorted(inputs_to_remove)}" + ) # Update subgraph outputs if outputs_to_remove: @@ -208,8 +208,8 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to ] subgraph.outputs = new_outputs o2o.log( - f"Removed {len(outputs_to_remove)} subgraph outputs: {sorted(outputs_to_remove)}", - file=sys.stderr) + f"Removed {len(outputs_to_remove)} subgraph outputs: {sorted(outputs_to_remove)}" + ) # Step 5: Update operator inputs that reference outputs of removed operators for op_idx, operator in enumerate(subgraph.operators): @@ -222,8 +222,8 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to # This input comes from a removed operator, set to -1 
updated_inputs.append(-1) o2o.log( - f" Operator {op_idx}: Breaking input connection from removed operator {producer_idx}", - file=sys.stderr) + f" Operator {op_idx}: Breaking input connection from removed operator {producer_idx}" + ) else: updated_inputs.append(input_idx) else: @@ -251,8 +251,7 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to op_name = f"builtin_code={operator_code.builtinCode}" else: op_name = f"custom_code={operator_code.customCode}" - o2o.log(f" Removing unused OperatorCode at index {code_idx}: {op_name}", - file=sys.stderr) + o2o.log(f" Removing unused OperatorCode at index {code_idx}: {op_name}") del model.operatorCodes[code_idx] removed_operator_codes.append(code_idx) @@ -298,8 +297,8 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to # First check tensors that are in tensor_use_count for tensor_idx, use_count in tensor_use_count.items(): if use_count == 0 and tensor_idx not in current_outputs: - if subgraph.outputs is None: - subgraph.outputs = [] + subgraph.outputs = list( + subgraph.outputs) if subgraph.outputs is not None else [] subgraph.outputs.append(tensor_idx) added_outputs.append(tensor_idx) @@ -311,28 +310,25 @@ def select_operators_and_update_model(model, subgraph_index, operator_indices_to use_count = tensor_use_count.get( output_idx, 0) # Default to 0 if not in use_count if use_count == 0 and output_idx not in current_outputs and output_idx not in added_outputs: - if subgraph.outputs is None: - subgraph.outputs = [] + subgraph.outputs = list( + subgraph.outputs) if subgraph.outputs is not None else [] subgraph.outputs.append(output_idx) added_outputs.append(output_idx) if added_outputs: - o2o.log(f"Added tensors to subgraph outputs: {sorted(added_outputs)}", - file=sys.stderr) + o2o.log(f"Added tensors to subgraph outputs: {sorted(added_outputs)}") # Add tensors with def_count == 0 to subgraph inputs added_inputs = [] current_inputs = set(subgraph.inputs) if 
subgraph.inputs is not None else set() for tensor_idx, def_count in tensor_def_count.items(): if def_count == 0 and tensor_idx not in current_inputs: - if subgraph.inputs is None: - subgraph.inputs = [] + subgraph.inputs = list(subgraph.inputs) if subgraph.inputs is not None else [] subgraph.inputs.append(tensor_idx) added_inputs.append(tensor_idx) if added_inputs: - o2o.log(f"Added tensors to subgraph inputs: {sorted(added_inputs)}", - file=sys.stderr) + o2o.log(f"Added tensors to subgraph inputs: {sorted(added_inputs)}") return len(removed_operators), len(removed_operator_codes) @@ -350,13 +346,13 @@ def main(): # Parse the operator indices try: operator_indices_to_keep = parse_operator_indices(args.by_id) - o2o.log(f"Operator indices to keep: {operator_indices_to_keep}", file=sys.stderr) + o2o.log(f"Operator indices to keep: {operator_indices_to_keep}") except ValueError as e: - o2o.log(f"Error parsing operator indices: {e}", file=sys.stderr) + o2o.log(f"Error parsing operator indices: {e}") sys.exit(1) if not operator_indices_to_keep: - o2o.log("No valid operator indices specified", file=sys.stderr) + o2o.log("No valid operator indices specified") sys.exit(1) # Load the model @@ -366,20 +362,18 @@ def main(): subgraph_index = 0 if not model.subgraphs or subgraph_index >= len(model.subgraphs): - o2o.log(f"Error: Model has no subgraph at index {subgraph_index}", - file=sys.stderr) + o2o.log(f"Error: Model has no subgraph at index {subgraph_index}") sys.exit(1) - o2o.log(f"Model has {len(model.subgraphs[subgraph_index].operators)} operators", - file=sys.stderr) + o2o.log(f"Model has {len(model.subgraphs[subgraph_index].operators)} operators") # Select operators (keep only specified ones) removed_ops_count, removed_codes_count = select_operators_and_update_model( model, subgraph_index, operator_indices_to_keep) o2o.log( - f"Removed {removed_ops_count} operators and {removed_codes_count} unused OperatorCode entries", - file=sys.stderr) + f"Removed 
{removed_ops_count} operators and {removed_codes_count} unused OperatorCode entries" + ) # Save the model directly to stdout o2o.save_model_to_stdout(model) From ef8f0c35e3f9bd49e9c38bf8cedb21eb7c3bfd52 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 13:41:47 +0900 Subject: [PATCH 05/27] Add gc.py --- tools/circle2circle/README.md | 4 +- .../{remove.unused_tensors.py => gc.py} | 130 +++++++++++++++++- 2 files changed, 131 insertions(+), 3 deletions(-) rename tools/circle2circle/{remove.unused_tensors.py => gc.py} (63%) diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md index fa654848b68..5677398cd36 100644 --- a/tools/circle2circle/README.md +++ b/tools/circle2circle/README.md @@ -121,9 +121,9 @@ Selectively removes operators from a Circle model based on their index range. Th ## -### `remove.unused_tensors.py` +### `gc.py` -Identifies and removes unused tensors from all subgraphs within a Circle model. A tensor is considered "unused" if it is not an input to any operator and not an output of its containing subgraph. This helps in cleaning up the model and potentially reducing its size. The script can either list unused tensors or modify the model to remove them. +Performs garbage collection by removing unreachable tensors and buffers, reducing model size and memory consumption. ## diff --git a/tools/circle2circle/remove.unused_tensors.py b/tools/circle2circle/gc.py similarity index 63% rename from tools/circle2circle/remove.unused_tensors.py rename to tools/circle2circle/gc.py index 9a560bca6b8..3230a773923 100755 --- a/tools/circle2circle/remove.unused_tensors.py +++ b/tools/circle2circle/gc.py @@ -56,6 +56,63 @@ def find_unused_tensors_in_subgraph(subgraph): return unused_indices +def find_unused_buffers(model): + """ + Finds and returns the indices of unused buffers in the model. + This function works with both Native API (read-only) and Object API (mutable) model objects. 
+ + Args: + model: The Circle model object (read-only or mutable). + + Returns: + list: A list of integer indices representing unused buffers. + """ + # Handle both Native API and Object API + if hasattr(model, 'BuffersLength'): + # Native API + if not model.BuffersLength(): + return [] + + used_buffer_indices = set() + + # Collect buffer indices from all tensors in all subgraphs + for i in range(model.SubgraphsLength()): + subgraph = model.Subgraphs(i) + if subgraph: + for j in range(subgraph.TensorsLength()): + tensor = subgraph.Tensors(j) + if tensor and tensor.Buffer() != -1: # -1 indicates no buffer + used_buffer_indices.add(tensor.Buffer()) + + # A buffer is unused if it's not referenced by any tensor + unused_indices = [] + for i in range(model.BuffersLength()): + if i not in used_buffer_indices: + unused_indices.append(i) + + return unused_indices + else: + # Object API + if not model.buffers: + return [] + + used_buffer_indices = set() + + # Collect buffer indices from all tensors in all subgraphs + for subgraph in model.subgraphs: + for tensor in subgraph.tensors: + if tensor.buffer != -1: # -1 indicates no buffer + used_buffer_indices.add(tensor.buffer) + + # A buffer is unused if it's not referenced by any tensor + unused_indices = [] + for i in range(len(model.buffers)): + if i not in used_buffer_indices: + unused_indices.append(i) + + return unused_indices + + def remove_tensors_and_update_model(model, subgraph_index_to_modify, tensor_indices_to_remove): """ @@ -158,6 +215,61 @@ def remove_tensors_and_update_model(model, subgraph_index_to_modify, return sorted(removed_indices) +def remove_buffers_and_update_model(model, buffer_indices_to_remove): + """ + Removes specified buffers from the model and updates all tensor references. + This function uses the Object API for mutable model objects. + + Args: + model: The mutable Circle model object (ModelT). + buffer_indices_to_remove (list): A list of buffer indices to remove. 
+ Must be sorted in descending order. + + Returns: + list: The list of buffer indices that were actually removed. + """ + if not model.buffers: + o2o.log("Model has no buffers to remove.") + return [] + + removed_indices = [] + + # Sort in descending order to avoid index shifting issues during removal + for buffer_idx in sorted(buffer_indices_to_remove, reverse=True): + if 0 <= buffer_idx < len(model.buffers): + o2o.log(f" Removing buffer at index {buffer_idx}") + del model.buffers[buffer_idx] + removed_indices.append(buffer_idx) + else: + o2o.log(f" Warning: Buffer index {buffer_idx} out of bounds, skipping.") + + if not removed_indices: + return [] + + # Create a map for old index to new index after removal + new_indices_map = {} + current_new_idx = 0 + # Iterate over original buffer count + original_buffer_count = len(model.buffers) + len(removed_indices) + for old_idx in range(original_buffer_count): + if old_idx not in buffer_indices_to_remove: + new_indices_map[old_idx] = current_new_idx + current_new_idx += 1 + + # Update tensor buffer references in all subgraphs + for subgraph_idx, subgraph in enumerate(model.subgraphs): + for tensor_idx, tensor in enumerate(subgraph.tensors): + if tensor.buffer != -1: # -1 indicates no buffer + if tensor.buffer in new_indices_map: + old_buffer_idx = tensor.buffer + tensor.buffer = new_indices_map[old_buffer_idx] + # If tensor.buffer was removed, set to -1 (no buffer) + elif tensor.buffer in buffer_indices_to_remove: + tensor.buffer = -1 + + return sorted(removed_indices) + + def main(): # Read the entire model from stdin data = sys.stdin.buffer.read() @@ -203,11 +315,27 @@ def main(): f"\nTotal unused tensors found across all subgraphs: {total_unused_tensors_count}" ) + # After removing tensors, now process unused buffers + # Use the mutable model directly since find_unused_buffers now supports both APIs + unused_buffers = find_unused_buffers(model) + if unused_buffers: + o2o.log( + f"Found {len(unused_buffers)} unused 
buffer(s): {', '.join(map(str, sorted(unused_buffers)))}" + ) + actually_removed_buffers = remove_buffers_and_update_model(model, unused_buffers) + if actually_removed_buffers: + o2o.log(f"Removed {len(actually_removed_buffers)} buffer(s).") + model_changed = True + else: + o2o.log("No buffers were actually removed during the process.") + else: + o2o.log("No unused buffers found.") + if model_changed: o2o.log("\nSaving modified model to stdout...") else: o2o.log( - "\nNo tensors were actually removed from any subgraph. Saving original model to stdout." + "\nNo tensors or buffers were actually removed. Saving original model to stdout." ) o2o.save_model_to_stdout(model) From 2a9695946b7776b1d6a408c4849f8805b986740f Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 15:07:50 +0900 Subject: [PATCH 06/27] Add with.py --- tools/circle2circle/with.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100755 tools/circle2circle/with.py diff --git a/tools/circle2circle/with.py b/tools/circle2circle/with.py new file mode 100755 index 00000000000..47709da413b --- /dev/null +++ b/tools/circle2circle/with.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +import sys +import pathlib + +def main(): + if len(sys.argv) < 2: + sys.stderr.write("Usage: with.py \\n") + sys.exit(1) + + input_path = pathlib.Path(sys.argv[1]) + if not input_path.is_file(): + sys.stderr.write(f"File not found: {input_path}\\n") + sys.exit(1) + + # Read the binary content of the circle file and write it to stdout + with input_path.open('rb') as f: + sys.stdout.buffer.write(f.read()) + +if __name__ == "__main__": + main() From f72e8073ade5a1f1c9aead5565f051634005d109 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 15:51:53 +0900 Subject: [PATCH 07/27] Add keep_by_id in remove.io.py --- tools/circle2circle/o2o.py | 57 ++++++++++++++++++++++ tools/circle2circle/remove.io.py | 83 ++++++++++++++++++++++++++------ 2 files changed, 126 insertions(+), 14 deletions(-) 
diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index 27f716838d2..bae889b42af 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -126,6 +126,63 @@ def get_tensor_index_by_name(subgraph, name): return i return -1 # Not found +def parse_operator_indices(indices_str): + """Parse operator index string into a list of indices. + + Supports formats like: + - "0-181" (range) + - "0,5,10-15" (mixed) + - "0" (single index) + + Args: + indices_str (str): String containing operator indices + + Returns: + list: Sorted list of unique operator indices + """ + if not indices_str: + return [] + + indices = set() + + # Split by comma first + parts = indices_str.split(',') + + for part in parts: + part = part.strip() + if not part: + continue + + # Check if it's a range + if '-' in part: + try: + start, end = part.split('-', 1) + start_idx = int(start.strip()) + end_idx = int(end.strip()) + + if start_idx < 0 or end_idx < 0: + raise ValueError("Indices must be non-negative") + + if start_idx > end_idx: + raise ValueError(f"Invalid range: {start_idx} > {end_idx}") + + indices.update(range(start_idx, end_idx + 1)) + except ValueError as e: + log(f"Error parsing range '{part}': {e}") + sys.exit(1) + else: + # Single index + try: + idx = int(part) + if idx < 0: + raise ValueError("Index must be non-negative") + indices.add(idx) + except ValueError as e: + log(f"Error parsing index '{part}': {e}") + sys.exit(1) + + return sorted(list(indices)) + def is_tensor_constant(tensor, model_buffers): """Check if a tensor is constant by verifying its buffer.""" diff --git a/tools/circle2circle/remove.io.py b/tools/circle2circle/remove.io.py index 9e0830ed12d..84d36e28559 100755 --- a/tools/circle2circle/remove.io.py +++ b/tools/circle2circle/remove.io.py @@ -65,34 +65,89 @@ def process_subgraph(subgraph): # Save the model using utility function o2o.save_model_to_stdout(model) +def remove_io_tensors_by_id(io_type, ids_to_keep): + """Remove input or 
output tensors, keeping only specified tensor indices (IDs)""" + model = o2o.load_model_from_stdin() + + def process_subgraph(subgraph): + if io_type == 'input': + io_list = subgraph.inputs + io_name = 'input' + elif io_type == 'output': + io_list = subgraph.outputs + io_name = 'output' + else: + raise ValueError(f"Invalid io_type: {io_type}. Must be 'input' or 'output'") + + o2o.log(f"Processing subgraph with {len(io_list)} {io_name}s") + o2o.log(f"Original {io_name} indices: {io_list}") + + # Keep only those indices whose position in the original list matches ids_to_keep + new_io_list = [] + for idx, tensor_idx in enumerate(io_list): + if idx in ids_to_keep: + new_io_list.append(tensor_idx) + else: + o2o.log(f"Removing {io_name} tensor at position {idx}") + + # Update the subgraph + if io_type == 'input': + subgraph.inputs = new_io_list + else: + subgraph.outputs = new_io_list + + o2o.log(f"New {io_name} indices: {new_io_list}") + + removed_count = len(io_list) - len(new_io_list) + return removed_count > 0, removed_count + + overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) + + if not overall_modified: + o2o.log("No tensors were removed.") + + o2o.save_model_to_stdout(model) + def main(): parser = argparse.ArgumentParser( description= - 'Remove input or output tensors from Circle model, keeping only specified tensor names' + 'Remove input or output tensors from Circle model, keeping only specified tensor names or IDs' ) parser.add_argument('io_type', choices=['input', 'output'], help='Whether to process inputs or outputs') - parser.add_argument( + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( '--keep_by_name', - required=True, help='Comma‑separated tensor names to keep (e.g., "tensorA,tensorB")') + group.add_argument( + '--keep_by_id', + help='Comma‑separated tensor IDs or ranges to keep (e.g., "0,2-4")') # No file arguments needed; model is read from stdin and written to stdout args = 
parser.parse_args() - # Parse the tensor names - try: - names_to_keep = parse_names(args.keep_by_name) - o2o.log(f"Tensor names to keep: {names_to_keep}") - except ValueError as e: - o2o.log(f"Error parsing tensor names: {e}") - sys.exit(1) - - # Execute with error handling using utility function - # Directly invoke the processing function; I/O handled via stdin/stdout - remove_io_tensors(args.io_type, names_to_keep) + if args.keep_by_name: + # Parse the tensor names + try: + names_to_keep = parse_names(args.keep_by_name) + o2o.log(f"Tensor names to keep: {names_to_keep}") + except ValueError as e: + o2o.log(f"Error parsing tensor names: {e}") + sys.exit(1) + # Execute name‑based removal + remove_io_tensors(args.io_type, names_to_keep) + elif args.keep_by_id: + # Parse the tensor IDs + try: + ids_to_keep = o2o.parse_operator_indices(args.keep_by_id) + o2o.log(f"Tensor IDs to keep: {ids_to_keep}") + except Exception as e: + o2o.log(f"Error parsing tensor IDs: {e}") + sys.exit(1) + # Execute ID‑based removal + remove_io_tensors_by_id(args.io_type, ids_to_keep) if __name__ == "__main__": From cafd38cc724029b58bf9ae07037163cc800bcfb9 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 15:52:43 +0900 Subject: [PATCH 08/27] Use tico full names --- tools/circle2circle/fuse.attention.py | 8 ++++---- tools/circle2circle/retype.input_ids.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/circle2circle/fuse.attention.py b/tools/circle2circle/fuse.attention.py index 597ddd7d11d..e8cf215614e 100755 --- a/tools/circle2circle/fuse.attention.py +++ b/tools/circle2circle/fuse.attention.py @@ -28,8 +28,8 @@ def find_attention_blocks(model, subgraph): # Eighth block: operators 374-418 for layer_idx in range(8): - start_op = 20 + (layer_idx * 64) # 64 operators between blocks, starting from 20 - end_op = start_op + 43 # 44 operators total (20-63 inclusive) + start_op = 20 + (layer_idx * 65) # 64 operators between blocks, starting from 20 + 
end_op = start_op + 44 # 44 operators total (20-63 inclusive) if end_op < len(subgraph.operators): # Verify this is an attention block by checking key operators @@ -108,14 +108,14 @@ def map_attention_inputs(subgraph, block, model): # 6. position_cos position_cos_idx = o2o.get_tensor_index_by_name( - subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::mul_1") + subgraph, "transformers.models.llama.modeling_llama.LlamaRotaryEmbedding::mul_1") if position_cos_idx == -1: o2o.log("Could not find position_cos tensor") return None # 7. position_sin position_sin_idx = o2o.get_tensor_index_by_name( - subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::mul_2") + subgraph, "transformers.models.llama.modeling_llama.LlamaRotaryEmbedding::mul_2") if position_sin_idx == -1: o2o.log("Could not find position_sin tensor") return None diff --git a/tools/circle2circle/retype.input_ids.py b/tools/circle2circle/retype.input_ids.py index 6a75187f2a6..f0f739f9522 100755 --- a/tools/circle2circle/retype.input_ids.py +++ b/tools/circle2circle/retype.input_ids.py @@ -19,7 +19,7 @@ def process_subgraph(subgraph): tensor_name = o2o.get_tensor_name(tensor) # Check if this is the input_ids tensor - if tensor_name == "input_ids": + if tensor_name == "tico::input_ids": # Check if current type is int64 if tensor.type == circle.TensorType.INT64: old_type = "int64" From 715e1ab74e598f7ebd9928be910728148b7a0c1f Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 15:53:34 +0900 Subject: [PATCH 09/27] Fix coding style --- tools/circle2circle/o2o.py | 1 + tools/circle2circle/remove.io.py | 1 + tools/circle2circle/with.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index bae889b42af..4a7e5d3197d 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -126,6 +126,7 @@ def get_tensor_index_by_name(subgraph, name): return i return -1 # Not found + def 
parse_operator_indices(indices_str): """Parse operator index string into a list of indices. diff --git a/tools/circle2circle/remove.io.py b/tools/circle2circle/remove.io.py index 84d36e28559..14bb4acfb72 100755 --- a/tools/circle2circle/remove.io.py +++ b/tools/circle2circle/remove.io.py @@ -65,6 +65,7 @@ def process_subgraph(subgraph): # Save the model using utility function o2o.save_model_to_stdout(model) + def remove_io_tensors_by_id(io_type, ids_to_keep): """Remove input or output tensors, keeping only specified tensor indices (IDs)""" model = o2o.load_model_from_stdin() diff --git a/tools/circle2circle/with.py b/tools/circle2circle/with.py index 47709da413b..602253883cd 100755 --- a/tools/circle2circle/with.py +++ b/tools/circle2circle/with.py @@ -2,6 +2,7 @@ import sys import pathlib + def main(): if len(sys.argv) < 2: sys.stderr.write("Usage: with.py \\n") @@ -16,5 +17,6 @@ def main(): with input_path.open('rb') as f: sys.stdout.buffer.write(f.read()) + if __name__ == "__main__": main() From 884e34d7f9bcdc07a749cc625b6254f9a43746ed Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 18:04:47 +0900 Subject: [PATCH 10/27] Add requirements.txt --- tools/circle2circle/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tools/circle2circle/requirements.txt diff --git a/tools/circle2circle/requirements.txt b/tools/circle2circle/requirements.txt new file mode 100644 index 00000000000..0a8720d0f2a --- /dev/null +++ b/tools/circle2circle/requirements.txt @@ -0,0 +1,3 @@ +# External Python dependencies for the tools in tools/circle2circle +numpy +flatbuffers From 5081c4535cafcb38f64f5473eada0c9373232d79 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 14 Nov 2025 18:16:53 +0900 Subject: [PATCH 11/27] Update README.md and remove unncessary import --- tools/circle2circle/README.md | 9 +++++++-- tools/circle2circle/gc.py | 4 +--- tools/circle2circle/rename.io.remove_namespace.py | 1 - tools/circle2circle/reshape.fc_weight.py | 1 
- tools/circle2circle/reshape.io.py | 1 - tools/circle2circle/retype.input_ids.py | 1 - tools/circle2circle/select.op.py | 1 - 7 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md index 5677398cd36..d7eb0264bd5 100644 --- a/tools/circle2circle/README.md +++ b/tools/circle2circle/README.md @@ -21,7 +21,9 @@ An example: Filters example: ```bash -./rename.io.remove_namespace.py < in.circle | ./rename.io.remove_prefix.py past_key_values_ > out.circle +./with.py in.circle | + ./select.op.py --by_id 0-181 | + ./gc.py > new.circle ```
@@ -35,7 +37,10 @@ Removes input or output tensors from a Circle model, keeping only the tensors at #### Arguments * `io_type` (required): Specifies whether to process `input` or `output` tensors. -* `--keep_by_name` (required): A string defining the names of the tensors to keep. It supports comma‑separated tensor names (e.g., "input1,input2"). +* `--keep_by_name` (optional): A string defining the names of the tensors to keep. It supports comma‑separated tensor names (e.g., "input1,input2"). +* `--keep_by_id` (optional): Specifies the tensor indices to keep. Supports multiple ranges separated by commas and individual indices (e.g., "0,2-4"). + +**Note:** Exactly one of `--keep_by_name` or `--keep_by_id` must be provided. ## diff --git a/tools/circle2circle/gc.py b/tools/circle2circle/gc.py index 3230a773923..43feb2f18b8 100755 --- a/tools/circle2circle/gc.py +++ b/tools/circle2circle/gc.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 import sys -# import argparse # Removed: script now uses stdin/stdout instead of file arguments -import flatbuffers import circle -import o2o # For saving the model +import o2o def get_tensor_name(tensor): diff --git a/tools/circle2circle/rename.io.remove_namespace.py b/tools/circle2circle/rename.io.remove_namespace.py index 7f22b4fe1bf..976cce6606c 100755 --- a/tools/circle2circle/rename.io.remove_namespace.py +++ b/tools/circle2circle/rename.io.remove_namespace.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import sys -import re import circle import flatbuffers import o2o diff --git a/tools/circle2circle/reshape.fc_weight.py b/tools/circle2circle/reshape.fc_weight.py index 212a188d763..afb5b4a0c6e 100755 --- a/tools/circle2circle/reshape.fc_weight.py +++ b/tools/circle2circle/reshape.fc_weight.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import sys import numpy as np import circle import o2o diff --git a/tools/circle2circle/reshape.io.py b/tools/circle2circle/reshape.io.py index ce8bc4d20d7..c4074901302 100755 --- 
a/tools/circle2circle/reshape.io.py +++ b/tools/circle2circle/reshape.io.py @@ -2,7 +2,6 @@ import sys import argparse -import circle import o2o diff --git a/tools/circle2circle/retype.input_ids.py b/tools/circle2circle/retype.input_ids.py index f0f739f9522..dafddbaf3db 100755 --- a/tools/circle2circle/retype.input_ids.py +++ b/tools/circle2circle/retype.input_ids.py @@ -2,7 +2,6 @@ import o2o import circle -import sys def retype_input_ids(): diff --git a/tools/circle2circle/select.op.py b/tools/circle2circle/select.op.py index 9a49fd60016..e0397369a8c 100755 --- a/tools/circle2circle/select.op.py +++ b/tools/circle2circle/select.op.py @@ -2,7 +2,6 @@ import sys import argparse -import flatbuffers import o2o From b74d98d4848f099a3f091563e0eab416b051cbcc Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 18 Nov 2025 09:58:28 +0900 Subject: [PATCH 12/27] Revisit o2o.py, and fuse.bmm_lhs_const.py uses stdout --- tools/circle2circle/fuse.bmm_lhs_const.py | 2 +- tools/circle2circle/gen_circle.add.py | Bin 6860 -> 448 bytes tools/circle2circle/o2o.py | 167 +++++++++++++--------- 3 files changed, 99 insertions(+), 70 deletions(-) diff --git a/tools/circle2circle/fuse.bmm_lhs_const.py b/tools/circle2circle/fuse.bmm_lhs_const.py index 8fd7f022db1..002d51357c2 100755 --- a/tools/circle2circle/fuse.bmm_lhs_const.py +++ b/tools/circle2circle/fuse.bmm_lhs_const.py @@ -160,7 +160,7 @@ def fuse_bmm_transpose(): if not model.subgraphs: o2o.log("Model has no subgraphs. 
Exiting.") - o2o.save_circle_model(model, output_file) # Save original if no subgraphs + o2o.save_model_to_stdout(model) # Output to stdout for consistency return subgraph = model.subgraphs[0] # Assuming single subgraph for now, can be extended diff --git a/tools/circle2circle/gen_circle.add.py b/tools/circle2circle/gen_circle.add.py index acb9bf33c9b6f6f26b30a4ef747a6374337eca2e..71b097257805b86608025bf2c21c5602d565a69d 100755 GIT binary patch literal 448 zcmZutF$%&!5FBF&NelX2#97ZBWI znB8M%XYWz~tMz7)AjK33BGx1L%)q+CmV6rkL$bmrmPAIVNA>;Fwm@LCfxXiV%6{U> zz89f;X@R(RmTzqDXXG7mAPPbrRbJL4Gt+6s+B&pac``!1L;1YzygcTDLf(REx1Zj2 zB{D)@59*}1Z}v1#d2^}^yBX?v+g>~0(@tX|FN*!Ksm~YY(mIVeGQT}B|K)bIyYyb= MA20Qc2;Usw3)`F}0RR91 literal 6860 zcmc&(TWjM;7Jm1yIO9A--cG99aqm1Byx?4B2s4SB#J#Ze2BFC1SRhMUm6VPd=D+W` zRFx#(x~ExKLxQb3_sdtON+&ImzDZ>=lYlwaT+mua~TgysRKM?ps!`iw#5u zvcf%zlQK@V`;}khz1|62@k6-F7dnN@db!jkd^XXGhKD-U27>3v=mq2%2zzMam0m|` z`D`G{7WfyX8DCLckG4;mCO5oSN0u|KF>ejX30%tRC{~qT|;{!Bw7_ zD#8r@NTx=dsgFl873WzMC7FTg)%^4HcetL;rx0*&sw#ipRN4dspTYat9HK*643ezU zC3YL7cuno<`Fu`Mop^H~ts@A&Ce;e4iX^ok72J~CH5MiIzz>Ci0&p4p{YV)07wj>2 z-`gzjDchE{r8ae1?Nsl_Cd<-1w|Adr5A*5$eBAX`!bGK{fvj=yGiP9Os8)IdnZ~Bt zZdvnF8>FsEE#NdO33!4B9p`I=1>_n~=OZQvR^1!~sYp<4nUn%f*l_|NHjxgQq$^$i z01Mc*7AFW6j9=$)6U8gnSk&;QlZC9AJQ{GZA z25O_bfcJrEEyxN6Mh(0#bKHTx8i-w{bIx>p|m2sznsVdGlr^?D1v6ooo8xV@)3L~IDT^8@LQaG9hUON@%V-|urBYKTC{ z-PX<=z_H0Q4b$su$O~OY74C8Yk|t?X5~iSZCAJO}O6##$91zH7EspXdkvFk^DbDXa z7&L3@RhNb~>j{inPo2^$#4dwtU!GcvZqM3TaG~9zU+XHOYPA{TRTk$Bb`tri`+;i3 zMfF%e|4N&rv)oZAhY4PxU$NWi?t;at#wX9>d~Uhr*Rp7>k)^-J4Pram&TFi&tJ%ZD z&D9*F3#5lYw;O@U11Rn? 
zTmj4%awsP_^(zLW;LGso3>E~}6MTknGeRX;EsqQ-jk4D$t#wO|)p{^8Q<^G5i^L zj-j)L1%I5RH+x|3k|%P7#TFx$5&syuS!Wr2b(S}7G!qij6g(O^GSBf2o=^Y%_Tj@R z{0#y$<{VO;1#Wr3&VGbEIc2(_{|qU;G4Z0F9E z9unNm{&{l^^BMd;eVnrn3+CU7_<@G93N81KbW}&b_s8OXr;~zYxh5xTlVmSDTQ?%m zu-#mC((olT=hKrz2x$SKQDzn4Jb}nM#Nwvgv7FHb|#Kjhm!}9Z53GU>h^v< zx*#c^*;xL$%kG0d6bNPXXvRJ9UWj&?9fD(2`n>9~-GZqrn6|~c5ejAAm~i24i=c21 zyTVzT$HTU4k+hw{Z9oZVj^Kw!b9vE*mE*aThB%PRW2x*g9#195phNF-$7!Y`GsddP z9n;d$&eD1O+?C`K=>u~-ooy?2gY+HgWDESH(uIs5|3;57%p63KagvR^c|_9*~<5%$n&k&5x>9p zI@5NHlP{;9>MSqlCt{u*NN08?r}5ZfWJ8BrP?F+37Ogfsm-5iDD%7?mUoCV|YQEmG z#NU)a`S~Sz8)GZ&W93`4XAK6gvgoqOVsy*CZ4;W~fB>27sctkTdwfJ2hb3Fh?@p@r zND!RtHdZ_MEu)(FxxjVftt2(Kd@VPj-Nz2&36twHS)Wmdv;P0!iB*s2k~$)UZz zLH*Sji;yaDf;_)_;3YiLJp>ao8(xengO5V-$ey@){mhPca`5)lY5o!{`*xzHqGMug zjaRvl{ln&xj_ZH+Ee@d;;k&~INWoNPjh8T`9`6u9=Pqarf3#~&5g7gnJU~slEBtI% zANbo8JEfw;_Kwz-*jYI!aV8hTQtBwLQ@mY)-KE14ZQAjM|N44QNFMDaw;c5h>ipso z>yGeWk|$Qp_xTAD`Z}?F`?N1OUa%PYS0!667LBiEmJo)Np)ka^BhU{+I^%|6UoLG< z@P$A|S25b18>mEhds~+I8sfUdOQZ@+TqZ?@_e8SDg?9q-!T=jbz39wVtHPH)c$rTR zrF@xI_}LI_wcMy8s#c2L?2xV?d diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index 4a7e5d3197d..5453c74ebbf 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -5,46 +5,77 @@ import flatbuffers +# ============================================================================ +# BASIC UTILITIES +# ============================================================================ + def log(message): """Log message to stderr""" print(message, file=sys.stderr) -def load_model_from_stdin(): - """Load a Circle model from binary data read from stdin.""" - data = sys.stdin.buffer.read() - buf = bytearray(data) - model = circle.Model.GetRootAsModel(buf, 0) - model = circle.ModelT.InitFromObj(model) - return model - +def safe_execute(main_func, + input_file, + output_file, + *args, + error_message="Error processing file"): + """Safely execute the main function with error handling""" + try: + main_func(input_file, 
output_file, *args) + log(f"Successfully processed {input_file} and saved to {output_file}") + except Exception as e: + log(f"{error_message}: {e}") + sys.exit(1) -def save_model_to_stdout(model): - """Serialize a Circle model and write it to stdout as binary data.""" - builder = flatbuffers.Builder(1024) - builder.Finish(model.Pack(builder), b'CIR0') - sys.stdout.buffer.write(builder.Output()) +# ============================================================================ +# CORE I/O FUNCTIONS +# ============================================================================ -def load_circle_model(input_file): +def load_circle_model(input_file=None): """Load and parse a circle model file""" - with open(input_file, 'rb') as f: - buf = bytearray(f.read()) + if input_file is None: + # Read from stdin + data = sys.stdin.buffer.read() + else: + # Read from file + with open(input_file, 'rb') as f: + data = f.read() + buf = bytearray(data) model = circle.Model.GetRootAsModel(buf, 0) model = circle.ModelT.InitFromObj(model) return model -def save_circle_model(model, output_file): +def load_model_from_stdin(): + """Load a Circle model from binary data read from stdin.""" + return load_circle_model() # input_file=None defaults to stdin + + +def save_circle_model(model, output_file=None): """Save a circle model to file using flatbuffers""" builder = flatbuffers.Builder(1024) builder.Finish(model.Pack(builder), b'CIR0') - with open(output_file, 'wb') as f: - f.write(builder.Output()) + if output_file is None: + # Write to stdout + sys.stdout.buffer.write(builder.Output()) + else: + # Write to file + with open(output_file, 'wb') as f: + f.write(builder.Output()) + + +def save_model_to_stdout(model): + """Serialize a Circle model and write it to stdout as binary data.""" + save_circle_model(model) # output_file=None defaults to stdout +# ============================================================================ +# CLI HANDLING +# 
============================================================================ + def handle_cli_args(usage_message): """Handle common command line argument parsing and validation""" if len(sys.argv) != 3: @@ -56,6 +87,10 @@ def handle_cli_args(usage_message): return input_file, output_file +# ============================================================================ +# TENSOR UTILITIES +# ============================================================================ + def get_tensor_name(tensor): """Get tensor name as string, handling bytes conversion""" if tensor.name: @@ -64,26 +99,34 @@ def get_tensor_name(tensor): return None -def process_subgraphs(model, processor_func): - """Generic subgraph processor with modification tracking +def get_tensor_by_index(subgraph, index): + """Safely get tensor by its index.""" + if 0 <= index < len(subgraph.tensors): + return subgraph.tensors[index] + return None - Args: - model: Circle model object - processor_func: Function that processes a subgraph and returns (modified, changes_count) - Returns: - tuple: (overall_modified, total_changes) - """ - overall_modified = False - total_changes = 0 +def get_tensor_index_by_name(subgraph, name): + """Find tensor index by name, handling byte strings.""" + name_bytes = name.encode('utf-8') # Convert str to bytes for comparison + for i, tensor in enumerate(subgraph.tensors): + if tensor.name and tensor.name == name_bytes: + return i + return -1 # Not found - for subgraph in model.subgraphs: - modified, changes_count = processor_func(subgraph) - overall_modified = overall_modified or modified - total_changes += changes_count - return overall_modified, total_changes +def is_tensor_constant(tensor, model_buffers): + """Check if a tensor is constant by verifying its buffer.""" + if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): + # A non-zero buffer index that points to a valid buffer typically means it's constant. + # The 0th buffer is always an empty buffer. 
+ return True + return False + +# ============================================================================ +# TENSOR PROCESSING FUNCTIONS +# ============================================================================ def rename_tensor_if_matches(tensor, pattern, replacement_func): """Rename tensor if it matches the given pattern @@ -111,22 +154,31 @@ def rename_tensor_if_matches(tensor, pattern, replacement_func): return False, None, None -def get_tensor_by_index(subgraph, index): - """Safely get tensor by its index.""" - if 0 <= index < len(subgraph.tensors): - return subgraph.tensors[index] - return None +def process_subgraphs(model, processor_func): + """Generic subgraph processor with modification tracking + Args: + model: Circle model object + processor_func: Function that processes a subgraph and returns (modified, changes_count) -def get_tensor_index_by_name(subgraph, name): - """Find tensor index by name, handling byte strings.""" - name_bytes = name.encode('utf-8') # Convert str to bytes for comparison - for i, tensor in enumerate(subgraph.tensors): - if tensor.name and tensor.name == name_bytes: - return i - return -1 # Not found + Returns: + tuple: (overall_modified, total_changes) + """ + overall_modified = False + total_changes = 0 + + for subgraph in model.subgraphs: + modified, changes_count = processor_func(subgraph) + overall_modified = overall_modified or modified + total_changes += changes_count + + return overall_modified, total_changes +# ============================================================================ +# OPERATOR UTILITIES +# ============================================================================ + def parse_operator_indices(indices_str): """Parse operator index string into a list of indices. 
@@ -185,15 +237,6 @@ def parse_operator_indices(indices_str): return sorted(list(indices)) -def is_tensor_constant(tensor, model_buffers): - """Check if a tensor is constant by verifying its buffer.""" - if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): - # A non-zero buffer index that points to a valid buffer typically means it's constant. - # The 0th buffer is always an empty buffer. - return True - return False - - def get_or_create_operator_code(model, builtin_op_type): """Get the index of an operator code, or create it if it doesn't exist.""" for i, op_code in enumerate(model.operatorCodes): @@ -207,17 +250,3 @@ def get_or_create_operator_code(model, builtin_op_type): new_op_code.version = 1 # Default version model.operatorCodes.append(new_op_code) return len(model.operatorCodes) - 1 - - -def safe_execute(main_func, - input_file, - output_file, - *args, - error_message="Error processing file"): - """Safely execute the main function with error handling""" - try: - main_func(input_file, output_file, *args) - log(f"Successfully processed {input_file} and saved to {output_file}") - except Exception as e: - log(f"{error_message}: {e}") - sys.exit(1) From a0146b331483b443fb6817bae9f1da5da8e266cc Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 19 Nov 2025 17:11:58 +0900 Subject: [PATCH 13/27] Add merge.py and reorder.output_move_0_to_last.py Also, update o2o coding format --- tools/circle2circle/merge.py | 286 ++++++++++++++++++ tools/circle2circle/o2o.py | 7 +- .../reorder.output.move_0_to_last.py | 43 +++ 3 files changed, 335 insertions(+), 1 deletion(-) create mode 100755 tools/circle2circle/merge.py create mode 100755 tools/circle2circle/reorder.output.move_0_to_last.py diff --git a/tools/circle2circle/merge.py b/tools/circle2circle/merge.py new file mode 100755 index 00000000000..c59ffbdeedd --- /dev/null +++ b/tools/circle2circle/merge.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 + +import sys +import argparse +import o2o 
+import circle + + +def get_operator_code_key(op_code): + """Generate a unique key for an OperatorCode to identify duplicates. + + Args: + op_code: Circle OperatorCode object + + Returns: + tuple: Unique key for the operator code + """ + if op_code.builtinCode is not None: + # Builtin operator + return ('builtin', op_code.builtinCode) + elif op_code.customCode is not None: + # Custom operator + custom_code = op_code.customCode + if isinstance(custom_code, bytes): + custom_code = custom_code.decode('utf-8') + return ('custom', custom_code) + else: + # Unknown case + return ('unknown', None) + + +def merge_operator_codes_with_deduplication(model1, model2): + """Merge operator codes from two models while removing duplicates. + + Args: + model1: First Circle model + model2: Second Circle model + + Returns: + tuple: (merged_operator_codes, model2_to_merged_mapping) + """ + # Start with first model's operator codes + merged_operator_codes = list(model1.operatorCodes) + + # Create mapping table for operator codes + # key: (builtinCode, customCode), value: new_index + opcode_mapping = {} + + # Register first model's operator codes + for i, op_code in enumerate(model1.operatorCodes): + key = get_operator_code_key(op_code) + opcode_mapping[key] = i + + # Process second model's operator codes (check for duplicates) + model2_to_merged_mapping = {} # model2's index → merged index + + for i, op_code in enumerate(model2.operatorCodes): + key = get_operator_code_key(op_code) + + if key in opcode_mapping: + # Duplicate operator code - use existing index + model2_to_merged_mapping[i] = opcode_mapping[key] + else: + # New operator code - add it + new_index = len(merged_operator_codes) + merged_operator_codes.append(op_code) + opcode_mapping[key] = new_index + model2_to_merged_mapping[i] = new_index + + return merged_operator_codes, model2_to_merged_mapping + + +def create_tensor_map_list(subgraph, tensor_indices): + """Convert tensor indices to TensorMap objects for SignatureDef. 
+ + Args: + subgraph: Subgraph containing the tensors + tensor_indices: List of tensor indices + + Returns: + list: List of TensorMapT objects + """ + tensor_maps = [] + o2o.log(f"Creating tensor maps for {len(tensor_indices)} tensors") + + for i, tensor_idx in enumerate(tensor_indices): + # Skip optional inputs (-1 indicates unused optional input) + if tensor_idx == -1: + continue + + # Ensure tensor index is valid + if 0 <= tensor_idx < len(subgraph.tensors): + tensor_map = circle.TensorMapT() + + # Get tensor name, use fallback if no name exists + tensor_name = o2o.get_tensor_name(subgraph.tensors[tensor_idx]) + if not tensor_name: + tensor_name = f"tensor_{tensor_idx}" + + # Encode name as UTF-8 bytes for FlatBuffers compatibility + tensor_map.name = tensor_name.encode('utf-8') + tensor_map.tensorIndex = int(tensor_idx) # Convert numpy.int32 to int + + o2o.log(f" TensorMap {i}: name='{tensor_name}', index={tensor_idx}") + tensor_maps.append(tensor_map) + else: + o2o.log(f"Warning: Invalid tensor index {tensor_idx} in signature creation") + + return tensor_maps + + +def create_signatures(model, sig_name_0, sig_name_1): + """Create signature definitions for the merged model. 
+ + Args: + model: Merged Circle model + sig_name_0: Name for first subgraph signature + sig_name_1: Name for second subgraph signature + """ + if not hasattr(model, 'signatureDefs'): + model.signatureDefs = [] + + # Create signature for first subgraph + sig0 = circle.SignatureDefT() + sig0.subgraphIndex = 0 + sig0.signatureKey = sig_name_0.encode('utf-8') + + # Create TensorMap lists for inputs and outputs + if model.subgraphs[0].inputs is not None and len(model.subgraphs[0].inputs) > 0: + sig0.inputs = create_tensor_map_list(model.subgraphs[0], model.subgraphs[0].inputs) + else: + sig0.inputs = [] + + if model.subgraphs[0].outputs is not None and len(model.subgraphs[0].outputs) > 0: + sig0.outputs = create_tensor_map_list(model.subgraphs[0], model.subgraphs[0].outputs) + else: + sig0.outputs = [] + + # Create signature for second subgraph + sig1 = circle.SignatureDefT() + sig1.subgraphIndex = 1 + sig1.signatureKey = sig_name_1.encode('utf-8') + + # Create TensorMap lists for inputs and outputs + if model.subgraphs[1].inputs is not None and len(model.subgraphs[1].inputs) > 0: + sig1.inputs = create_tensor_map_list(model.subgraphs[1], model.subgraphs[1].inputs) + else: + sig1.inputs = [] + + if model.subgraphs[1].outputs is not None and len(model.subgraphs[1].outputs) > 0: + sig1.outputs = create_tensor_map_list(model.subgraphs[1], model.subgraphs[1].outputs) + else: + sig1.outputs = [] + + model.signatureDefs = [sig0, sig1] + + +def merge_models_with_signatures(model1, model2, sig_name_0, sig_name_1): + """Merge two Circle models by keeping subgraphs separate and adding signatures. 
+ + Args: + model1: First Circle model + model2: Second Circle model + sig_name_0: Signature name for first subgraph + sig_name_1: Signature name for second subgraph + + Returns: + circle.ModelT: Merged model with signatures + """ + # Validate that both models have exactly one subgraph + if not model1.subgraphs or len(model1.subgraphs) != 1: + o2o.log("Error: First model must have exactly one subgraph") + sys.exit(1) + + if not model2.subgraphs or len(model2.subgraphs) != 1: + o2o.log("Error: Second model must have exactly one subgraph") + sys.exit(1) + + o2o.log(f"Merging models:") + o2o.log(f" Model 1: {len(model1.subgraphs[0].tensors)} tensors, {len(model1.subgraphs[0].operators)} operators") + o2o.log(f" Model 2: {len(model2.subgraphs[0].tensors)} tensors, {len(model2.subgraphs[0].operators)} operators") + + # Step 1: Merge buffers (simple append) + merged_buffers = list(model1.buffers) + list(model2.buffers) + buffer_offset = len(model1.buffers) + + # Step 2: Merge operator codes with deduplication + merged_operator_codes, model2_opcode_mapping = merge_operator_codes_with_deduplication(model1, model2) + + # Step 3: Create merged subgraphs + merged_subgraphs = [] + + # First subgraph (keep as-is, no index remapping needed) + subgraph0 = model1.subgraphs[0] + merged_subgraphs.append(subgraph0) + + # Second subgraph (needs index remapping) + subgraph1 = model2.subgraphs[0] + + # Remap buffer indices in second subgraph tensors + for tensor in subgraph1.tensors: + if tensor.buffer is not None and tensor.buffer != 0: + tensor.buffer += buffer_offset + + # Remap operator code indices in second subgraph operators + for operator in subgraph1.operators: + if operator.opcodeIndex is not None: + operator.opcodeIndex = model2_opcode_mapping[operator.opcodeIndex] + + merged_subgraphs.append(subgraph1) + + # Step 4: Create final merged model + merged_model = circle.ModelT() + merged_model.buffers = merged_buffers + merged_model.operatorCodes = merged_operator_codes + 
merged_model.subgraphs = merged_subgraphs + + # Step 5: Create signatures + create_signatures(merged_model, sig_name_0, sig_name_1) + + o2o.log(f"Merge completed:") + o2o.log(f" Total buffers: {len(merged_buffers)}") + o2o.log(f" Total operator codes: {len(merged_operator_codes)}") + o2o.log(f" Total subgraphs: {len(merged_subgraphs)}") + o2o.log(f" Signatures: ['{sig_name_0}', '{sig_name_1}']") + + return merged_model + + +def main(): + """Main function to merge two Circle models with signatures.""" + parser = argparse.ArgumentParser( + description='Merge two Circle models by appending subgraphs with signatures' + ) + parser.add_argument('first_circle', help='First Circle model file') + parser.add_argument('second_circle', help='Second Circle model file') + parser.add_argument( + '--sig-names', + default='subgraph_0;subgraph_1', + help='Signature names for subgraphs, separated by semicolon (e.g., "prefill;decode")' + ) + args = parser.parse_args() + + # Parse signature names + sig_names = args.sig_names.split(';') + if len(sig_names) != 2: + o2o.log("Error: --sig-names must contain exactly 2 names separated by semicolon") + sys.exit(1) + + sig_name_0, sig_name_1 = sig_names[0].strip(), sig_names[1].strip() + + if not sig_name_0 or not sig_name_1: + o2o.log("Error: Signature names cannot be empty") + sys.exit(1) + + o2o.log(f"Loading models...") + o2o.log(f" First model: {args.first_circle}") + o2o.log(f" Second model: {args.second_circle}") + o2o.log(f" Signature names: ['{sig_name_0}', '{sig_name_1}']") + + # Load both models explicitly + try: + model0 = o2o.load_circle_model(args.first_circle) + model1 = o2o.load_circle_model(args.second_circle) + except Exception as e: + o2o.log(f"Error loading models: {e}") + sys.exit(1) + + # Merge models with signatures + try: + merged_model = merge_models_with_signatures(model0, model1, sig_name_0, sig_name_1) + except Exception as e: + o2o.log(f"Error merging models: {e}") + sys.exit(1) + + # Output to stdout + try: + 
o2o.save_model_to_stdout(merged_model) + o2o.log("Successfully saved merged model to stdout") + except Exception as e: + o2o.log(f"Error saving merged model: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index 5453c74ebbf..2a2eca93918 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -4,11 +4,11 @@ import circle import flatbuffers - # ============================================================================ # BASIC UTILITIES # ============================================================================ + def log(message): """Log message to stderr""" print(message, file=sys.stderr) @@ -32,6 +32,7 @@ def safe_execute(main_func, # CORE I/O FUNCTIONS # ============================================================================ + def load_circle_model(input_file=None): """Load and parse a circle model file""" if input_file is None: @@ -76,6 +77,7 @@ def save_model_to_stdout(model): # CLI HANDLING # ============================================================================ + def handle_cli_args(usage_message): """Handle common command line argument parsing and validation""" if len(sys.argv) != 3: @@ -91,6 +93,7 @@ def handle_cli_args(usage_message): # TENSOR UTILITIES # ============================================================================ + def get_tensor_name(tensor): """Get tensor name as string, handling bytes conversion""" if tensor.name: @@ -128,6 +131,7 @@ def is_tensor_constant(tensor, model_buffers): # TENSOR PROCESSING FUNCTIONS # ============================================================================ + def rename_tensor_if_matches(tensor, pattern, replacement_func): """Rename tensor if it matches the given pattern @@ -179,6 +183,7 @@ def process_subgraphs(model, processor_func): # OPERATOR UTILITIES # ============================================================================ + def parse_operator_indices(indices_str): """Parse 
operator index string into a list of indices. diff --git a/tools/circle2circle/reorder.output.move_0_to_last.py b/tools/circle2circle/reorder.output.move_0_to_last.py new file mode 100755 index 00000000000..1ffecfb8749 --- /dev/null +++ b/tools/circle2circle/reorder.output.move_0_to_last.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +import o2o + + +def reorder_output_tensors(): + """Reorder output tensors: move tensor 0 to the end, shift others forward""" + o2o.log("Loading model from stdin") + model = o2o.load_model_from_stdin() + + if not model.subgraphs: + o2o.log("Model has no subgraphs. Exiting.") + o2o.save_model_to_stdout(model) + return + + for subgraph_idx, subgraph in enumerate(model.subgraphs): + if len(subgraph.outputs) <= 1: + o2o.log( + f"Subgraph {subgraph_idx}: Only {len(subgraph.outputs)} output tensor(s), no reordering needed" + ) + continue + + # Convert numpy array to Python list for proper concatenation + original_outputs = subgraph.outputs.copy() + outputs_list = original_outputs.tolist() + + # Move first output tensor to the end + # Original: [a, b, c, d] -> New: [b, c, d, a] + first_output = outputs_list[0] + other_outputs = outputs_list[1:] + new_outputs = other_outputs + [first_output] + + subgraph.outputs = new_outputs + o2o.log( + f"Subgraph {subgraph_idx}: Reordered outputs {original_outputs.tolist()} -> {new_outputs}" + ) + + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + # Directly invoke processing; I/O handled via stdin/stdout + reorder_output_tensors() From 527daa288d46f2c91f6cf2ed0abbfe6fc0eee3e9 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 19 Nov 2025 17:15:17 +0900 Subject: [PATCH 14/27] restore gen_circle.add.py --- tools/circle2circle/gen_circle.add.py | Bin 448 -> 6858 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/tools/circle2circle/gen_circle.add.py b/tools/circle2circle/gen_circle.add.py index 
71b097257805b86608025bf2c21c5602d565a69d..a7923fc0d1d5a589121a909deaf8c9fe24f379ed 100755 GIT binary patch literal 6858 zcmc&(TWjM;7Jm1yIO9A--cG99aqm1Byx?4B2s4SB#J#Ze2BFC1SRhMUm6VPd=D+W` zRFx#(x~ExKLxQb3_sdtON+&ImzDZ>=lYlwaT+mua~TgysRKM?ps!`iw#5u zvcf%zlQK@V`;}khz1|62@k6-F7dnN@db!jkd^XXGhKD-U27>3v=mq2%2zzMam0m|` z`D`G{7WfyX8DCLckG4;mCO5oSN0u|KF>ejX30%tRC{~qT|;{!Bw7_ zD#8r@NTx=dsgFl873WzMC7FTg)%^4HcetL;rx0*&sw#ipRN4dspTYat9HK*643ezU zC3YL7cuno<`Fu`Mop^H~ts@A&Ce;e4iX^ok72J~CH5MiIzz>Ci0&p4p{YV)07wj>2 z-`gzjDchE{r8ae1?Nsl_Cd<-1w|Adr5A*5$eBAX`!bGK{fvj=yGiP9Os8)IdnZ~Bt zZdvnF8>FsEE#NdO33!4B9p`I=1>_n~=OZQvR^1!~sYp<4nUn%f*l_|NHjxgQq$^$i z01Mc*7AFW6j9=$)6U8gnSk&;QlZC9AJQ{GZA z25O_bfcJrEEyxN6Mh(0#bKHTx8i-w{bIx>p|m2sznsVdGlr^%i<5`gCY;cTaVQ&=7Lt2htqjGD$U77}a>c-{&&a5P^`p zot-&=Lz8D3rq|by7rKlp+~fizP12|&NI~gJOdTkc)?=|aACSvh{NzU>Z(`k2oZfjb zXx7xLE)DI~6BxCgI;B?#T?W^_Jhc|xmbJ6sBD+Pu)>TBsY7@q*EY2J3B=S@D1J#O) z>al+QmG($yxuZ}H61+mcVt3Qs1q)S;PoBm3+;YpWWzkwAOMi{~!*;Zt*H~azvxkS9 zt2s!;+047zO(TK}%0Es_I1XW_@U|Vgp(hkc5LkfW>4ydunV0M;%{$I-E|e^1D6nco z!+ya&pppp}&J(z|B061EcJ&6LU-3|bn1PWn$e8JNH3^MTQBfeI3IJ~VH*nQT(-Q`v zGuv1v1M8H?W(pUbzv_gDcXEHgiMIBZ?H6Rb+k+g^IyIaaiqshle?F(t0iqP=Puyg< z0+=u4P)=~@SByo$m*LYHEC{V9_zclzgi5el9vM;^Wv@|M>y{j=^D zg|jq|hi%y+X*-46fD+Ig!4Hq-@}dnZ$8#wSaUhq+QrTfVo=T2Ehu-Io(@ZC3j8&66 zrlpgerStf?E6F9&2j+M>+g9ub={wTN7Whe}3mHNFjUHo|Ifx?16J=ZdK&Fm%U6=Jc za@A$<$r?$cu@_CUYz!<_Ky!^;i%5LCvRt|_QpXe3VcR3=J5oinmG4QB=UcBMet++E zrtKIfUrs&MSzgdj#5_BY&g@K1c! 
z63al+mS+o$cEUZy#oQCXvTyU)$j=J8A2oA2m;0jf1;T4;x>#7UmH=LASQ?scR7AoO z56+8eTqUo3O8uqI@H-#%Y#Hyakwxl$c7M~->B`w!gW4s^tc2;Bp0j|rLwkLL z`kOHpAywi8d4Bi6OL(Mf2qtDWycbsnABEzPJ#q8;nH}%s;O(i?{3TfS?LMIqDhI`NbvH z9pS$uPpq1+^AjZWbz=MWXU@`QsO14}qx`%gIVMqxILwq*^{V=2>ZW#9E&gKMP z2V`^;qusHADuj2pWtp!bu1mZ^s=&l$QdD?LB#T^kCm^p2uy53h&TO$NeC2~z`Seh# zmuZEc4Z%*!jVhvQrRcp5=?X$F=P>MSkfaQWTdbxG$56>h literal 448 zcmZutF$%&!5FBF&NelX2#97ZBWI znB8M%XYWz~tMz7)AjK33BGx1L%)q+CmV6rkL$bmrmPAIVNA>;Fwm@LCfxXiV%6{U> zz89f;X@R(RmTzqDXXG7mAPPbrRbJL4Gt+6s+B&pac``!1L;1YzygcTDLf(REx1Zj2 zB{D)@59*}1Z}v1#d2^}^yBX?v+g>~0(@tX|FN*!Ksm~YY(mIVeGQT}B|K)bIyYyb= MA20Qc2;Usw3)`F}0RR91 From 13a71b5022a99049dd65a78b817facba6e30a07e Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 19 Nov 2025 18:02:20 +0900 Subject: [PATCH 15/27] Revise merge.circle.py --- tools/circle2circle/README.md | 22 ++++++++ .../{merge.py => merge.circle.py} | 55 +++++++++++++------ 2 files changed, 60 insertions(+), 17 deletions(-) rename tools/circle2circle/{merge.py => merge.circle.py} (81%) diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md index d7eb0264bd5..4b6c4f0b96e 100644 --- a/tools/circle2circle/README.md +++ b/tools/circle2circle/README.md @@ -150,3 +150,25 @@ Generates a simple Circle model with one `ADD` operator for testing basic functi #### `gen_circle.bmm_lhs_const.fc.py` Generates a test Circle model with `BATCH_MATMUL` and `TRANSPOSE` operations where the LHS is constant. This model is designed to test the fusion pattern used in `fuse.bmm_lhs_const.py`. + +## `merge.circle.py` + +Merges two Circle model files into a single model by appending their subgraphs and adding signatures. The script accepts one or more Circle files (currently limited to two). + +- **Positional arguments**: + `circles` – one or more Circle model files to merge (e.g., `in1.circle in2.circle`). + +- **Optional arguments**: + `--sig-names` – semicolon‑separated signature names for the subgraphs (e.g., `"prefill;decode"`). 
If omitted, the script derives the signature names from the input filenames by stripping the `.circle` extension. + +### Usage examples + +```bash +# Merge two models, using filenames as signature names +./merge.circle.py model1.circle model2.circle + +# Merge with custom signature names +./merge.circle.py model1.circle model2.circle --sig-names "prefill;decode" +``` + +The merged model is written to **standard output**, allowing it to be piped into other tools or redirected to a file. diff --git a/tools/circle2circle/merge.py b/tools/circle2circle/merge.circle.py similarity index 81% rename from tools/circle2circle/merge.py rename to tools/circle2circle/merge.circle.py index c59ffbdeedd..6d81d0568da 100755 --- a/tools/circle2circle/merge.py +++ b/tools/circle2circle/merge.circle.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import sys +import os import argparse import o2o import circle @@ -229,39 +230,59 @@ def merge_models_with_signatures(model1, model2, sig_name_0, sig_name_1): def main(): """Main function to merge two Circle models with signatures.""" + # This script merges multiple Circle model files into a single model. + # It keeps each input model as a separate subgraph and adds a signature + # for each subgraph. If signature names are not provided via --sig-names, + # they are derived from the input filenames (without the .circle extension). parser = argparse.ArgumentParser( - description='Merge two Circle models by appending subgraphs with signatures' + description='Merge multiple Circle models (as subgraphs) with signatures' ) - parser.add_argument('first_circle', help='First Circle model file') - parser.add_argument('second_circle', help='Second Circle model file') + # One or more Circle model files to merge, e.g. in1.circle in2.circle ... + parser.add_argument('circles', nargs='+', help='Circle model files to merge (e.g., in1.circle in2.circle ...)') + # Optional signature names for each subgraph, separated by semicolons. 
+ # Must match the number of input files. If omitted, names are taken from the + # input filenames (without the .circle extension). parser.add_argument( '--sig-names', - default='subgraph_0;subgraph_1', - help='Signature names for subgraphs, separated by semicolon (e.g., "prefill;decode")' + default=None, + help='Signature names for subgraphs (semicolon‑separated). If omitted, derived from input filenames.' ) args = parser.parse_args() - # Parse signature names - sig_names = args.sig_names.split(';') - if len(sig_names) != 2: - o2o.log("Error: --sig-names must contain exactly 2 names separated by semicolon") + # Currently only support 2 models + if len(args.circles) != 2: + o2o.log("Error: Currently only 2 Circle models are supported") sys.exit(1) - sig_name_0, sig_name_1 = sig_names[0].strip(), sig_names[1].strip() + # Parse signature names + if args.sig_names is None: + # Use filenames without .circle extension as signature names + sig_names = [os.path.splitext(os.path.basename(f))[0] for f in args.circles] + else: + # Use user-provided signature names + sig_names = args.sig_names.split(';') + if len(sig_names) != len(args.circles): + o2o.log(f"Error: --sig-names must contain exactly {len(args.circles)} names separated by semicolon") + sys.exit(1) + sig_names = [name.strip() for name in sig_names] - if not sig_name_0 or not sig_name_1: - o2o.log("Error: Signature names cannot be empty") - sys.exit(1) + # Validate signature names are not empty + for i, sig_name in enumerate(sig_names): + if not sig_name: + o2o.log(f"Error: Signature name {i+1} cannot be empty") + sys.exit(1) + + sig_name_0, sig_name_1 = sig_names[0], sig_names[1] o2o.log(f"Loading models...") - o2o.log(f" First model: {args.first_circle}") - o2o.log(f" Second model: {args.second_circle}") + o2o.log(f" First model: {args.circles[0]}") + o2o.log(f" Second model: {args.circles[1]}") o2o.log(f" Signature names: ['{sig_name_0}', '{sig_name_1}']") # Load both models explicitly try: - model0 = 
o2o.load_circle_model(args.first_circle) - model1 = o2o.load_circle_model(args.second_circle) + model0 = o2o.load_circle_model(args.circles[0]) + model1 = o2o.load_circle_model(args.circles[1]) except Exception as e: o2o.log(f"Error loading models: {e}") sys.exit(1) From 7dcf908180e3ce5d32e15efcbf66e42ef8c7129e Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 20 Nov 2025 14:29:59 +0900 Subject: [PATCH 16/27] Add type annotation --- tools/circle2circle/fuse.attention.py | 15 ++- tools/circle2circle/fuse.bmm_lhs_const.py | 28 +++-- tools/circle2circle/gc.py | 18 ++- tools/circle2circle/gen_circle.add.py | 9 +- .../gen_circle.bmm_lhs_const.fc.py | 7 +- tools/circle2circle/merge.circle.py | 117 +++++++----------- tools/circle2circle/o2o.py | 62 +++++++--- tools/circle2circle/remove.io.py | 11 +- .../rename.io.remove_namespace.py | 11 +- .../circle2circle/rename.io.remove_prefix.py | 6 +- .../reorder.output.move_0_to_last.py | 7 +- tools/circle2circle/reshape.fc_weight.py | 10 +- tools/circle2circle/reshape.io.py | 11 +- tools/circle2circle/retype.input_ids.py | 4 + tools/circle2circle/select.op.py | 13 +- tools/circle2circle/transpose.io.kvcache.py | 7 +- 16 files changed, 216 insertions(+), 120 deletions(-) diff --git a/tools/circle2circle/fuse.attention.py b/tools/circle2circle/fuse.attention.py index e8cf215614e..bc67fd19bd1 100755 --- a/tools/circle2circle/fuse.attention.py +++ b/tools/circle2circle/fuse.attention.py @@ -1,11 +1,18 @@ #!/usr/bin/env python3 import numpy as np +from typing import List, Optional, Tuple, Dict, Any import circle import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def find_operator_by_output(subgraph, output_tensor_index): + +def find_operator_by_output( + subgraph: 'circle.SubGraphT', + output_tensor_index: int) -> Tuple[Optional[int], Optional['circle.OperatorT']]: """Find the first operator that 
produces the given output tensor index.""" for op_idx, operator in enumerate(subgraph.operators): if operator.outputs and output_tensor_index in operator.outputs: @@ -13,7 +20,8 @@ def find_operator_by_output(subgraph, output_tensor_index): return None, None -def find_attention_blocks(model, subgraph): +def find_attention_blocks(model: 'circle.ModelT', + subgraph: 'circle.SubGraphT') -> List[Dict[str, Any]]: """Find all attention blocks in the subgraph.""" attention_blocks = [] @@ -67,7 +75,8 @@ def find_attention_blocks(model, subgraph): return attention_blocks -def map_attention_inputs(subgraph, block, model): +def map_attention_inputs(subgraph: 'circle.SubGraphT', block: Dict[str, Any], + model: 'circle.ModelT') -> Optional[List[int]]: """Map the 11 input tensors for attention fusion.""" layer_idx = block['layer_idx'] diff --git a/tools/circle2circle/fuse.bmm_lhs_const.py b/tools/circle2circle/fuse.bmm_lhs_const.py index 002d51357c2..d2582c8928e 100755 --- a/tools/circle2circle/fuse.bmm_lhs_const.py +++ b/tools/circle2circle/fuse.bmm_lhs_const.py @@ -1,18 +1,25 @@ #!/usr/bin/env python3 import numpy as np +from typing import List, Optional, Tuple, Dict, Any import circle import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def get_tensor_by_index(subgraph, index): + +def get_tensor_by_index(subgraph: 'circle.SubGraphT', + index: int) -> Optional['circle.TensorT']: """Safely get tensor by its index.""" if 0 <= index < len(subgraph.tensors): return subgraph.tensors[index] return None -def is_tensor_constant(tensor, model_buffers): +def is_tensor_constant(tensor: 'circle.TensorT', + model_buffers: List['circle.BufferT']) -> bool: """Check if a tensor is constant by verifying its buffer.""" if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): # A non-zero buffer index that points to a valid buffer 
typically means it's constant. @@ -21,7 +28,9 @@ def is_tensor_constant(tensor, model_buffers): return False -def find_operator_by_output(subgraph, output_tensor_index): +def find_operator_by_output( + subgraph: 'circle.SubGraphT', + output_tensor_index: int) -> Tuple[Optional[int], Optional['circle.OperatorT']]: """Find the first operator that produces the given output tensor index.""" for op_idx, operator in enumerate(subgraph.operators): if operator.outputs and output_tensor_index in operator.outputs: @@ -65,7 +74,8 @@ def count_tensor_usage(model, tensor_index): return count -def get_or_create_operator_code(model, builtin_op_type): +def get_or_create_operator_code(model: 'circle.ModelT', + builtin_op_type: 'circle.BuiltinOperator') -> int: """Get the index of an operator code, or create it if it doesn't exist.""" for i, op_code in enumerate(model.operatorCodes): if op_code.builtinCode == builtin_op_type: @@ -79,7 +89,8 @@ def get_or_create_operator_code(model, builtin_op_type): return len(model.operatorCodes) - 1 -def create_transpose_permutation_tensor(model, subgraph, rank): +def create_transpose_permutation_tensor(model: 'circle.ModelT', + subgraph: 'circle.SubGraphT', rank: int) -> int: """Create a permutation tensor for transposing last two dimensions.""" # Create permutation: [0, 1, ..., rank-3, rank-1, rank-2] perm_shape = [rank] @@ -104,8 +115,9 @@ def create_transpose_permutation_tensor(model, subgraph, rank): return tensor_index -def add_rhs_transpose_if_needed(model, subgraph, bmm_op_idx, rhs_tensor_index, - rhs_tensor): +def add_rhs_transpose_if_needed(model: 'circle.ModelT', subgraph: 'circle.SubGraphT', + bmm_op_idx: int, rhs_tensor_index: int, + rhs_tensor: 'circle.TensorT') -> int: """Add TRANSPOSE operator for RHS if K != 1 OR B != 1.""" if len(rhs_tensor.shape) < 3: # Need at least 3 dimensions: [B, K, N] @@ -153,7 +165,7 @@ def add_rhs_transpose_if_needed(model, subgraph, bmm_op_idx, rhs_tensor_index, return transposed_rhs_tensor_index -def 
fuse_bmm_transpose(): +def fuse_bmm_transpose() -> None: """Main function to add RHS transpose before fusing batchmatmul(lhs, rhs) to fullyconnected(transposed_rhs, lhs) when lhs is constant.""" o2o.log("Loading model from stdin") model = o2o.load_model_from_stdin() diff --git a/tools/circle2circle/gc.py b/tools/circle2circle/gc.py index 43feb2f18b8..d55250e4c48 100755 --- a/tools/circle2circle/gc.py +++ b/tools/circle2circle/gc.py @@ -3,9 +3,14 @@ import sys import circle import o2o +from typing import List, Optional +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def get_tensor_name(tensor): + +def get_tensor_name(tensor: TensorT) -> Optional[str]: """Get tensor name as string, handling bytes conversion""" if tensor.name: return tensor.name.decode('utf-8') if isinstance(tensor.name, @@ -13,7 +18,7 @@ def get_tensor_name(tensor): return None -def find_unused_tensors_in_subgraph(subgraph): +def find_unused_tensors_in_subgraph(subgraph) -> List[int]: """ Finds and returns the indices of unused tensors in a given subgraph. This function uses the Native API for read-only subgraph objects. @@ -54,7 +59,7 @@ def find_unused_tensors_in_subgraph(subgraph): return unused_indices -def find_unused_buffers(model): +def find_unused_buffers(model) -> List[int]: """ Finds and returns the indices of unused buffers in the model. This function works with both Native API (read-only) and Object API (mutable) model objects. @@ -111,8 +116,8 @@ def find_unused_buffers(model): return unused_indices -def remove_tensors_and_update_model(model, subgraph_index_to_modify, - tensor_indices_to_remove): +def remove_tensors_and_update_model(model: ModelT, subgraph_index_to_modify: int, + tensor_indices_to_remove: List[int]) -> List[int]: """ Removes specified tensors from the model and updates all relevant references. 
This function uses the Object API for mutable model/subgraph/operator objects. @@ -213,7 +218,8 @@ def remove_tensors_and_update_model(model, subgraph_index_to_modify, return sorted(removed_indices) -def remove_buffers_and_update_model(model, buffer_indices_to_remove): +def remove_buffers_and_update_model(model: ModelT, + buffer_indices_to_remove: List[int]) -> List[int]: """ Removes specified buffers from the model and updates all tensor references. This function uses the Object API for mutable model objects. diff --git a/tools/circle2circle/gen_circle.add.py b/tools/circle2circle/gen_circle.add.py index a7923fc0d1d..f3a47bee4c3 100755 --- a/tools/circle2circle/gen_circle.add.py +++ b/tools/circle2circle/gen_circle.add.py @@ -4,6 +4,11 @@ import numpy as np import circle import o2o +from typing import List + +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) # Circle Model Buffer Usage Rules (based on circle_schema.fbs and analysis) # ====================================================================== @@ -53,7 +58,8 @@ # # Reference: circle_schema.fbs - "buffer:uint" field documentation -def create_simple_add_model(output_file): + +def create_simple_add_model(output_file: str): """Create a simple Circle model with one ADD operator (similar to add.circle).""" # Create model @@ -159,6 +165,7 @@ def create_simple_add_model(output_file): o2o.log(f" Subgraph inputs: {[subgraph.tensors[i].name for i in subgraph.inputs]}") o2o.log(f" Subgraph outputs: {[subgraph.tensors[i].name for i in subgraph.outputs]}") + if __name__ == "__main__": # Generate output filename from current script filename # e.g., add.gen_circle.py -> add.circle diff --git a/tools/circle2circle/gen_circle.bmm_lhs_const.fc.py b/tools/circle2circle/gen_circle.bmm_lhs_const.fc.py index a60eb939433..edb5dca2551 100755 --- a/tools/circle2circle/gen_circle.bmm_lhs_const.fc.py +++ 
b/tools/circle2circle/gen_circle.bmm_lhs_const.fc.py @@ -4,9 +4,14 @@ import numpy as np import circle import o2o +from typing import List +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def create_test_bmm_k_not_1_model(output_file): + +def create_test_bmm_k_not_1_model(output_file: str): """Create a test Circle model with BATCH_MATMUL where RHS K != 1.""" # Create model diff --git a/tools/circle2circle/merge.circle.py b/tools/circle2circle/merge.circle.py index 6d81d0568da..4beb7bd21f9 100755 --- a/tools/circle2circle/merge.circle.py +++ b/tools/circle2circle/merge.circle.py @@ -3,34 +3,18 @@ import sys import os import argparse +from typing import List, Optional, Tuple, Dict, Any import o2o import circle +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType, SignatureDefT, TensorMapT) -def get_operator_code_key(op_code): - """Generate a unique key for an OperatorCode to identify duplicates. - Args: - op_code: Circle OperatorCode object - - Returns: - tuple: Unique key for the operator code - """ - if op_code.builtinCode is not None: - # Builtin operator - return ('builtin', op_code.builtinCode) - elif op_code.customCode is not None: - # Custom operator - custom_code = op_code.customCode - if isinstance(custom_code, bytes): - custom_code = custom_code.decode('utf-8') - return ('custom', custom_code) - else: - # Unknown case - return ('unknown', None) - - -def merge_operator_codes_with_deduplication(model1, model2): +def merge_operator_codes_with_deduplication( + model1: 'circle.ModelT', + model2: 'circle.ModelT') -> Tuple[List['circle.OperatorCodeT'], Dict[int, int]]: """Merge operator codes from two models while removing duplicates. 
Args: @@ -49,14 +33,14 @@ def merge_operator_codes_with_deduplication(model1, model2): # Register first model's operator codes for i, op_code in enumerate(model1.operatorCodes): - key = get_operator_code_key(op_code) + key = o2o.get_operator_code_key(op_code) opcode_mapping[key] = i # Process second model's operator codes (check for duplicates) model2_to_merged_mapping = {} # model2's index → merged index for i, op_code in enumerate(model2.operatorCodes): - key = get_operator_code_key(op_code) + key = o2o.get_operator_code_key(op_code) if key in opcode_mapping: # Duplicate operator code - use existing index @@ -71,7 +55,8 @@ def merge_operator_codes_with_deduplication(model1, model2): return merged_operator_codes, model2_to_merged_mapping -def create_tensor_map_list(subgraph, tensor_indices): +def create_tensor_map_list(subgraph: 'circle.SubGraphT', + tensor_indices: List[int]) -> List['circle.TensorMapT']: """Convert tensor indices to TensorMap objects for SignatureDef. Args: @@ -110,53 +95,36 @@ def create_tensor_map_list(subgraph, tensor_indices): return tensor_maps -def create_signatures(model, sig_name_0, sig_name_1): +def create_signatures(model: 'circle.ModelT', sig_names: List[str]) -> None: """Create signature definitions for the merged model. 
Args: model: Merged Circle model - sig_name_0: Name for first subgraph signature - sig_name_1: Name for second subgraph signature + sig_names: List of signature names for each subgraph (must match subgraph count) """ if not hasattr(model, 'signatureDefs'): model.signatureDefs = [] - # Create signature for first subgraph - sig0 = circle.SignatureDefT() - sig0.subgraphIndex = 0 - sig0.signatureKey = sig_name_0.encode('utf-8') + signatures = [] - # Create TensorMap lists for inputs and outputs - if model.subgraphs[0].inputs is not None and len(model.subgraphs[0].inputs) > 0: - sig0.inputs = create_tensor_map_list(model.subgraphs[0], model.subgraphs[0].inputs) - else: - sig0.inputs = [] + for idx, sig_name in enumerate(sig_names): + sig = circle.SignatureDefT() + sig.subgraphIndex = idx + sig.signatureKey = sig_name.encode('utf-8') - if model.subgraphs[0].outputs is not None and len(model.subgraphs[0].outputs) > 0: - sig0.outputs = create_tensor_map_list(model.subgraphs[0], model.subgraphs[0].outputs) - else: - sig0.outputs = [] + subgraph = model.subgraphs[idx] + sig.inputs = create_tensor_map_list(subgraph, subgraph.inputs) if list( + subgraph.inputs) else [] + sig.outputs = create_tensor_map_list(subgraph, subgraph.outputs) if list( + subgraph.outputs) else [] - # Create signature for second subgraph - sig1 = circle.SignatureDefT() - sig1.subgraphIndex = 1 - sig1.signatureKey = sig_name_1.encode('utf-8') + signatures.append(sig) - # Create TensorMap lists for inputs and outputs - if model.subgraphs[1].inputs is not None and len(model.subgraphs[1].inputs) > 0: - sig1.inputs = create_tensor_map_list(model.subgraphs[1], model.subgraphs[1].inputs) - else: - sig1.inputs = [] - - if model.subgraphs[1].outputs is not None and len(model.subgraphs[1].outputs) > 0: - sig1.outputs = create_tensor_map_list(model.subgraphs[1], model.subgraphs[1].outputs) - else: - sig1.outputs = [] - - model.signatureDefs = [sig0, sig1] + model.signatureDefs = signatures -def 
merge_models_with_signatures(model1, model2, sig_name_0, sig_name_1): +def merge_models_with_signatures(model1: 'circle.ModelT', model2: 'circle.ModelT', + sig_name_0: str, sig_name_1: str) -> 'circle.ModelT': """Merge two Circle models by keeping subgraphs separate and adding signatures. Args: @@ -178,15 +146,20 @@ def merge_models_with_signatures(model1, model2, sig_name_0, sig_name_1): sys.exit(1) o2o.log(f"Merging models:") - o2o.log(f" Model 1: {len(model1.subgraphs[0].tensors)} tensors, {len(model1.subgraphs[0].operators)} operators") - o2o.log(f" Model 2: {len(model2.subgraphs[0].tensors)} tensors, {len(model2.subgraphs[0].operators)} operators") + o2o.log( + f" Model 1: {len(model1.subgraphs[0].tensors)} tensors, {len(model1.subgraphs[0].operators)} operators" + ) + o2o.log( + f" Model 2: {len(model2.subgraphs[0].tensors)} tensors, {len(model2.subgraphs[0].operators)} operators" + ) # Step 1: Merge buffers (simple append) merged_buffers = list(model1.buffers) + list(model2.buffers) buffer_offset = len(model1.buffers) # Step 2: Merge operator codes with deduplication - merged_operator_codes, model2_opcode_mapping = merge_operator_codes_with_deduplication(model1, model2) + merged_operator_codes, model2_opcode_mapping = merge_operator_codes_with_deduplication( + model1, model2) # Step 3: Create merged subgraphs merged_subgraphs = [] @@ -217,7 +190,7 @@ def merge_models_with_signatures(model1, model2, sig_name_0, sig_name_1): merged_model.subgraphs = merged_subgraphs # Step 5: Create signatures - create_signatures(merged_model, sig_name_0, sig_name_1) + create_signatures(merged_model, [sig_name_0, sig_name_1]) o2o.log(f"Merge completed:") o2o.log(f" Total buffers: {len(merged_buffers)}") @@ -235,17 +208,20 @@ def main(): # for each subgraph. If signature names are not provided via --sig-names, # they are derived from the input filenames (without the .circle extension). 
parser = argparse.ArgumentParser( - description='Merge multiple Circle models (as subgraphs) with signatures' - ) + description='Merge multiple Circle models (as subgraphs) with signatures') # One or more Circle model files to merge, e.g. in1.circle in2.circle ... - parser.add_argument('circles', nargs='+', help='Circle model files to merge (e.g., in1.circle in2.circle ...)') + parser.add_argument( + 'circles', + nargs='+', + help='Circle model files to merge (e.g., in1.circle in2.circle ...)') # Optional signature names for each subgraph, separated by semicolons. # Must match the number of input files. If omitted, names are taken from the # input filenames (without the .circle extension). parser.add_argument( '--sig-names', default=None, - help='Signature names for subgraphs (semicolon‑separated). If omitted, derived from input filenames.' + help= + 'Signature names for subgraphs (semicolon‑separated). If omitted, derived from input filenames.' ) args = parser.parse_args() @@ -262,7 +238,9 @@ def main(): # Use user-provided signature names sig_names = args.sig_names.split(';') if len(sig_names) != len(args.circles): - o2o.log(f"Error: --sig-names must contain exactly {len(args.circles)} names separated by semicolon") + o2o.log( + f"Error: --sig-names must contain exactly {len(args.circles)} names separated by semicolon" + ) sys.exit(1) sig_names = [name.strip() for name in sig_names] @@ -289,7 +267,8 @@ def main(): # Merge models with signatures try: - merged_model = merge_models_with_signatures(model0, model1, sig_name_0, sig_name_1) + merged_model = merge_models_with_signatures(model0, model1, sig_name_0, + sig_name_1) except Exception as e: o2o.log(f"Error merging models: {e}") sys.exit(1) diff --git a/tools/circle2circle/o2o.py b/tools/circle2circle/o2o.py index 2a2eca93918..130fd120daf 100755 --- a/tools/circle2circle/o2o.py +++ b/tools/circle2circle/o2o.py @@ -3,6 +3,11 @@ import sys import circle import flatbuffers +from typing import Callable, List, 
Optional, Tuple, Union + +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) # ============================================================================ # BASIC UTILITIES @@ -14,11 +19,11 @@ def log(message): print(message, file=sys.stderr) -def safe_execute(main_func, - input_file, - output_file, +def safe_execute(main_func: Callable, + input_file: str, + output_file: str, *args, - error_message="Error processing file"): + error_message: str = "Error processing file") -> None: """Safely execute the main function with error handling""" try: main_func(input_file, output_file, *args) @@ -33,7 +38,7 @@ def safe_execute(main_func, # ============================================================================ -def load_circle_model(input_file=None): +def load_circle_model(input_file: Optional[str] = None) -> 'circle.ModelT': """Load and parse a circle model file""" if input_file is None: # Read from stdin @@ -54,7 +59,7 @@ def load_model_from_stdin(): return load_circle_model() # input_file=None defaults to stdin -def save_circle_model(model, output_file=None): +def save_circle_model(model: 'circle.ModelT', output_file: Optional[str] = None) -> None: """Save a circle model to file using flatbuffers""" builder = flatbuffers.Builder(1024) builder.Finish(model.Pack(builder), b'CIR0') @@ -68,7 +73,7 @@ def save_circle_model(model, output_file=None): f.write(builder.Output()) -def save_model_to_stdout(model): +def save_model_to_stdout(model) -> None: """Serialize a Circle model and write it to stdout as binary data.""" save_circle_model(model) # output_file=None defaults to stdout @@ -94,7 +99,7 @@ def handle_cli_args(usage_message): # ============================================================================ -def get_tensor_name(tensor): +def get_tensor_name(tensor: 'circle.TensorT') -> Optional[str]: """Get tensor name as string, handling bytes 
conversion""" if tensor.name: return tensor.name.decode('utf-8') if isinstance(tensor.name, @@ -102,14 +107,15 @@ def get_tensor_name(tensor): return None -def get_tensor_by_index(subgraph, index): +def get_tensor_by_index(subgraph: 'circle.SubGraphT', + index: int) -> Optional['circle.TensorT']: """Safely get tensor by its index.""" if 0 <= index < len(subgraph.tensors): return subgraph.tensors[index] return None -def get_tensor_index_by_name(subgraph, name): +def get_tensor_index_by_name(subgraph, name: str) -> int: """Find tensor index by name, handling byte strings.""" name_bytes = name.encode('utf-8') # Convert str to bytes for comparison for i, tensor in enumerate(subgraph.tensors): @@ -118,7 +124,7 @@ def get_tensor_index_by_name(subgraph, name): return -1 # Not found -def is_tensor_constant(tensor, model_buffers): +def is_tensor_constant(tensor, model_buffers: List) -> bool: """Check if a tensor is constant by verifying its buffer.""" if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): # A non-zero buffer index that points to a valid buffer typically means it's constant. 
@@ -132,7 +138,9 @@ def is_tensor_constant(tensor, model_buffers): # ============================================================================ -def rename_tensor_if_matches(tensor, pattern, replacement_func): +def rename_tensor_if_matches( + tensor, pattern: str, + replacement_func: Callable) -> Tuple[bool, Optional[str], Optional[str]]: """Rename tensor if it matches the given pattern Args: @@ -158,7 +166,7 @@ def rename_tensor_if_matches(tensor, pattern, replacement_func): return False, None, None -def process_subgraphs(model, processor_func): +def process_subgraphs(model, processor_func: Callable) -> Tuple[bool, int]: """Generic subgraph processor with modification tracking Args: @@ -184,7 +192,7 @@ def process_subgraphs(model, processor_func): # ============================================================================ -def parse_operator_indices(indices_str): +def parse_operator_indices(indices_str: str) -> List[int]: """Parse operator index string into a list of indices. Supports formats like: @@ -242,7 +250,31 @@ def parse_operator_indices(indices_str): return sorted(list(indices)) -def get_or_create_operator_code(model, builtin_op_type): +def get_operator_code_key( + op_code: 'circle.OperatorCodeT') -> Tuple[str, Union[int, str, None]]: + """Generate a unique key for an OperatorCode to identify duplicates. 
+ + Args: + op_code: Circle OperatorCode object + + Returns: + tuple: Unique key for the operator code + """ + if op_code.builtinCode is not None: + # Builtin operator + return ('builtin', op_code.builtinCode) + elif op_code.customCode is not None: + # Custom operator + custom_code = op_code.customCode + if isinstance(custom_code, bytes): + custom_code = custom_code.decode('utf-8') + return ('custom', custom_code) + else: + # Unknown case + return ('unknown', None) + + +def get_or_create_operator_code(model, builtin_op_type) -> int: """Get the index of an operator code, or create it if it doesn't exist.""" for i, op_code in enumerate(model.operatorCodes): if op_code.builtinCode == builtin_op_type: diff --git a/tools/circle2circle/remove.io.py b/tools/circle2circle/remove.io.py index 14bb4acfb72..df4439bb319 100755 --- a/tools/circle2circle/remove.io.py +++ b/tools/circle2circle/remove.io.py @@ -2,15 +2,20 @@ import sys import argparse +from typing import List import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def parse_names(names_str): + +def parse_names(names_str: str) -> List[str]: """Parse comma‑separated tensor names into a list of names.""" return [name.strip() for name in names_str.split(',') if name.strip()] -def remove_io_tensors(io_type, names_to_keep): +def remove_io_tensors(io_type: str, names_to_keep: List[str]) -> None: """Remove input or output tensors, keeping only specified tensor names""" # Load the model using utility function model = o2o.load_model_from_stdin() @@ -66,7 +71,7 @@ def process_subgraph(subgraph): o2o.save_model_to_stdout(model) -def remove_io_tensors_by_id(io_type, ids_to_keep): +def remove_io_tensors_by_id(io_type: str, ids_to_keep: List[int]) -> None: """Remove input or output tensors, keeping only specified tensor indices (IDs)""" model = o2o.load_model_from_stdin() diff --git 
a/tools/circle2circle/rename.io.remove_namespace.py b/tools/circle2circle/rename.io.remove_namespace.py index 976cce6606c..a236c453070 100755 --- a/tools/circle2circle/rename.io.remove_namespace.py +++ b/tools/circle2circle/rename.io.remove_namespace.py @@ -3,10 +3,15 @@ import sys import circle import flatbuffers +from typing import Tuple import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def load_model_from_stdin(): + +def load_model_from_stdin() -> 'circle.ModelT': """Load a Circle model from binary data read from stdin.""" data = sys.stdin.buffer.read() buf = bytearray(data) @@ -15,14 +20,14 @@ def load_model_from_stdin(): return model -def save_model_to_stdout(model): +def save_model_to_stdout(model: 'circle.ModelT'): """Serialize a Circle model and write it to stdout as binary data.""" builder = flatbuffers.Builder(1024) builder.Finish(model.Pack(builder), b'CIR0') sys.stdout.buffer.write(builder.Output()) -def remove_namespace_from_inputs_and_outputs(model): +def remove_namespace_from_inputs_and_outputs(model: 'circle.ModelT'): """Remove namespace from tensor names within the given model.""" pattern = r'(.*)::(.*)' diff --git a/tools/circle2circle/rename.io.remove_prefix.py b/tools/circle2circle/rename.io.remove_prefix.py index b5b200486d9..debee161429 100755 --- a/tools/circle2circle/rename.io.remove_prefix.py +++ b/tools/circle2circle/rename.io.remove_prefix.py @@ -4,8 +4,12 @@ import re import sys +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def rename_input_tensors(prefix): + +def rename_input_tensors(prefix: str): """Main function to rename tensors by removing the specified prefix""" # Load the model using utility function model = o2o.load_model_from_stdin() diff --git 
a/tools/circle2circle/reorder.output.move_0_to_last.py b/tools/circle2circle/reorder.output.move_0_to_last.py index 1ffecfb8749..e05e336d897 100755 --- a/tools/circle2circle/reorder.output.move_0_to_last.py +++ b/tools/circle2circle/reorder.output.move_0_to_last.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 +from typing import List import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def reorder_output_tensors(): + +def reorder_output_tensors() -> None: """Reorder output tensors: move tensor 0 to the end, shift others forward""" o2o.log("Loading model from stdin") model = o2o.load_model_from_stdin() diff --git a/tools/circle2circle/reshape.fc_weight.py b/tools/circle2circle/reshape.fc_weight.py index afb5b4a0c6e..a1bdc8ff7b5 100755 --- a/tools/circle2circle/reshape.fc_weight.py +++ b/tools/circle2circle/reshape.fc_weight.py @@ -1,16 +1,21 @@ #!/usr/bin/env python3 import numpy as np +from typing import List, Dict import circle import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) + def is_effectively_2d(shape): """Check if a tensor shape is effectively 2D (all leading dimensions are 1)""" return all(dim == 1 for dim in shape[:-2]) -def count_tensor_usage(model, tensor_index): +def count_tensor_usage(model: 'circle.ModelT', tensor_index: int) -> int: """Count how many operators use a specific tensor as input""" count = 0 for subgraph in model.subgraphs: @@ -22,7 +27,8 @@ def count_tensor_usage(model, tensor_index): return count -def create_new_tensor(original_tensor, new_shape): +def create_new_tensor(original_tensor: 'circle.TensorT', + new_shape: List[int]) -> 'circle.TensorT': """Create a new tensor with the specified shape based on the original tensor""" new_tensor = circle.TensorT() new_tensor.shape = 
new_shape diff --git a/tools/circle2circle/reshape.io.py b/tools/circle2circle/reshape.io.py index c4074901302..d8f500f565a 100755 --- a/tools/circle2circle/reshape.io.py +++ b/tools/circle2circle/reshape.io.py @@ -2,10 +2,15 @@ import sys import argparse +from typing import List import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def parse_shape(shape_str): + +def parse_shape(shape_str: str) -> List[int]: """Parse a shape string like '[1,16,30,4]' into a list of integers.""" try: # Strip surrounding brackets and whitespace, then split by commas @@ -19,14 +24,14 @@ def parse_shape(shape_str): ) from e -def is_target_shape(shape, target_shape): +def is_target_shape(shape: List[int], target_shape: List[int]) -> bool: """Check if a tensor shape matches the target shape.""" if len(shape) != len(target_shape): return False return list(shape) == target_shape -def reshape_input_tensors(io_type, target_shape, new_shape): +def reshape_input_tensors(io_type: str, target_shape: List[int], new_shape: List[int]): """Reshape input or output tensors from target_shape to new_shape.""" model = o2o.load_model_from_stdin() diff --git a/tools/circle2circle/retype.input_ids.py b/tools/circle2circle/retype.input_ids.py index dafddbaf3db..4f360621cd9 100755 --- a/tools/circle2circle/retype.input_ids.py +++ b/tools/circle2circle/retype.input_ids.py @@ -3,6 +3,10 @@ import o2o import circle +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) + def retype_input_ids(): """Main function to change input_ids tensor type from int64 to int32""" diff --git a/tools/circle2circle/select.op.py b/tools/circle2circle/select.op.py index e0397369a8c..8bfd7e16276 100755 --- a/tools/circle2circle/select.op.py +++ b/tools/circle2circle/select.op.py @@ 
-2,10 +2,15 @@ import sys import argparse +from typing import List, Dict, Tuple import o2o +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def parse_operator_indices(indices_str): + +def parse_operator_indices(indices_str: str) -> List[int]: """Parse operator index string into a list of indices. Supports formats like: @@ -63,7 +68,7 @@ def parse_operator_indices(indices_str): return sorted(list(indices)) -def analyze_tensor_connections(subgraph): +def analyze_tensor_connections(subgraph: SubGraphT) -> Dict[str, any]: """Analyze all tensor connections in the subgraph. Args: @@ -116,7 +121,9 @@ def analyze_tensor_connections(subgraph): } -def select_operators_and_update_model(model, subgraph_index, operator_indices_to_keep): +def select_operators_and_update_model( + model: ModelT, subgraph_index: int, + operator_indices_to_keep: List[int]) -> Tuple[int, int]: """Keep only specified operators in the model and remove all others. 
Args: diff --git a/tools/circle2circle/transpose.io.kvcache.py b/tools/circle2circle/transpose.io.kvcache.py index 1f54fe6d034..68d37964621 100755 --- a/tools/circle2circle/transpose.io.kvcache.py +++ b/tools/circle2circle/transpose.io.kvcache.py @@ -2,9 +2,14 @@ import o2o import re +from typing import List +# Import specific Circle types for better type annotations +from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, + BuiltinOperator, TensorType) -def transpose_2d_3d(shape): + +def transpose_2d_3d(shape: List[int]) -> List[int]: """Transpose the second and third dimensions of a 4D shape""" if len(shape) != 4: raise ValueError("Shape must be 4D to transpose second and third dimensions") From fa07d1bf79c1f77661ef5a9bebaf3926ebc6232c Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 20 Nov 2025 15:49:18 +0900 Subject: [PATCH 17/27] Update merge.circle.py (multiple circles input and buffer deduplication) --- tools/circle2circle/README.md | 19 +- tools/circle2circle/merge.circle.py | 391 ++++++++++------------------ 2 files changed, 150 insertions(+), 260 deletions(-) diff --git a/tools/circle2circle/README.md b/tools/circle2circle/README.md index 4b6c4f0b96e..0b5f2492d4a 100644 --- a/tools/circle2circle/README.md +++ b/tools/circle2circle/README.md @@ -153,13 +153,19 @@ Generates a test Circle model with `BATCH_MATMUL` and `TRANSPOSE` operations whe ## `merge.circle.py` -Merges two Circle model files into a single model by appending their subgraphs and adding signatures. The script accepts one or more Circle files (currently limited to two). +Merges multiple Circle model files into a single model by appending their subgraphs and adding signatures. The script accepts any number of Circle files. - **Positional arguments**: - `circles` – one or more Circle model files to merge (e.g., `in1.circle in2.circle`). + `circles` – one or more Circle model files to merge (e.g., `in1.circle in2.circle in3.circle ...`). 
- **Optional arguments**: - `--sig-names` – semicolon‑separated signature names for the subgraphs (e.g., `"prefill;decode"`). If omitted, the script derives the signature names from the input filenames by stripping the `.circle` extension. + `--sig-names` – semicolon‑separated signature names for the subgraphs (e.g., `"prefill;decode;extra"`). If omitted, the script derives the signature names from the input filenames by stripping the `.circle` extension. + +### Features + +- **N-model merging**: Supports merging any number of input models (not limited to two). +- **Operator code deduplication**: Identical operator codes are merged to reduce redundancy. +- **Buffer deduplication**: Buffers with identical content (e.g., shared weights) are automatically deduplicated using SHA256 hashing, reducing the merged model size. ### Usage examples @@ -167,8 +173,11 @@ Merges two Circle model files into a single model by appending their subgraphs a # Merge two models, using filenames as signature names ./merge.circle.py model1.circle model2.circle -# Merge with custom signature names -./merge.circle.py model1.circle model2.circle --sig-names "prefill;decode" +# Merge three models with custom signature names +./merge.circle.py model1.circle model2.circle model3.circle --sig-names "prefill;decode;extra" + +# Merge multiple models (N models) +./merge.circle.py prefill.circle decode.circle > merged.circle ``` The merged model is written to **standard output**, allowing it to be piped into other tools or redirected to a file. 
diff --git a/tools/circle2circle/merge.circle.py b/tools/circle2circle/merge.circle.py index 4beb7bd21f9..715165b1695 100755 --- a/tools/circle2circle/merge.circle.py +++ b/tools/circle2circle/merge.circle.py @@ -3,284 +3,165 @@ import sys import os import argparse -from typing import List, Optional, Tuple, Dict, Any -import o2o -import circle - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType, SignatureDefT, TensorMapT) - - -def merge_operator_codes_with_deduplication( - model1: 'circle.ModelT', - model2: 'circle.ModelT') -> Tuple[List['circle.OperatorCodeT'], Dict[int, int]]: - """Merge operator codes from two models while removing duplicates. - - Args: - model1: First Circle model - model2: Second Circle model - - Returns: - tuple: (merged_operator_codes, model2_to_merged_mapping) - """ - # Start with first model's operator codes - merged_operator_codes = list(model1.operatorCodes) - - # Create mapping table for operator codes - # key: (builtinCode, customCode), value: new_index - opcode_mapping = {} - - # Register first model's operator codes - for i, op_code in enumerate(model1.operatorCodes): - key = o2o.get_operator_code_key(op_code) - opcode_mapping[key] = i - - # Process second model's operator codes (check for duplicates) - model2_to_merged_mapping = {} # model2's index → merged index - - for i, op_code in enumerate(model2.operatorCodes): - key = o2o.get_operator_code_key(op_code) +import hashlib +from typing import List, Dict, Tuple, Union, Optional - if key in opcode_mapping: - # Duplicate operator code - use existing index - model2_to_merged_mapping[i] = opcode_mapping[key] - else: - # New operator code - add it - new_index = len(merged_operator_codes) - merged_operator_codes.append(op_code) - opcode_mapping[key] = new_index - model2_to_merged_mapping[i] = new_index - - return merged_operator_codes, model2_to_merged_mapping 
+import circle +import o2o -def create_tensor_map_list(subgraph: 'circle.SubGraphT', - tensor_indices: List[int]) -> List['circle.TensorMapT']: - """Convert tensor indices to TensorMap objects for SignatureDef. +def get_signature_key(filepath: str) -> str: + """Extracts the filename without extension to use as a signature key.""" + filename = os.path.basename(filepath) + if filename.endswith('.circle'): + return filename[:-7] + return filename - Args: - subgraph: Subgraph containing the tensors - tensor_indices: List of tensor indices - Returns: - list: List of TensorMapT objects - """ +def create_tensor_map(subgraph: 'circle.SubGraphT', + tensor_indices: List[int]) -> List['circle.TensorMapT']: + """Creates a list of TensorMap for SignatureDef inputs/outputs.""" tensor_maps = [] - o2o.log(f"Creating tensor maps for {len(tensor_indices)} tensors") - - for i, tensor_idx in enumerate(tensor_indices): - # Skip optional inputs (-1 indicates unused optional input) - if tensor_idx == -1: - continue + for idx in tensor_indices: + tensor = subgraph.tensors[idx] + tensor_map = circle.TensorMapT() + tensor_map.name = tensor.name + tensor_map.tensorIndex = idx + tensor_maps.append(tensor_map) + return tensor_maps - # Ensure tensor index is valid - if 0 <= tensor_idx < len(subgraph.tensors): - tensor_map = circle.TensorMapT() - # Get tensor name, use fallback if no name exists - tensor_name = o2o.get_tensor_name(subgraph.tensors[tensor_idx]) - if not tensor_name: - tensor_name = f"tensor_{tensor_idx}" +def merge_models(model_paths: List[str]) -> None: + if not model_paths: + return - # Encode name as UTF-8 bytes for FlatBuffers compatibility - tensor_map.name = tensor_name.encode('utf-8') - tensor_map.tensorIndex = int(tensor_idx) # Convert numpy.int32 to int + # Load all models + models: List['circle.ModelT'] = [] + for path in model_paths: + models.append(o2o.load_circle_model(path)) - o2o.log(f" TensorMap {i}: name='{tensor_name}', index={tensor_idx}") - 
tensor_maps.append(tensor_map) + # Create new merged model + merged_model = circle.ModelT() + # Use the max version among all models + merged_model.version = max(m.version for m in models) + merged_model.description = f"Merged from {', '.join([os.path.basename(p) for p in model_paths])}" + + merged_model.operatorCodes = [] + merged_model.subgraphs = [] + merged_model.buffers = [] + merged_model.signatureDefs = [] + + # 1. Merge Operator Codes (Deduplication) + # Map (type, code) -> new_index + opcode_map: Dict[Tuple[str, Union[int, str]], int] = {} + # List of maps for each model: old_index -> new_index + model_opcode_maps: List[Dict[int, int]] = [] + + for model in models: + local_map = {} + for old_idx, op_code in enumerate(model.operatorCodes): + key = o2o.get_operator_code_key(op_code) + if key in opcode_map: + new_idx = opcode_map[key] + else: + new_idx = len(merged_model.operatorCodes) + merged_model.operatorCodes.append(op_code) + opcode_map[key] = new_idx + local_map[old_idx] = new_idx + model_opcode_maps.append(local_map) + + # 2. Merge Buffers (with deduplication) + # Buffer 0 is always empty sentinel. 
+ merged_model.buffers.append(circle.BufferT()) # Sentinel + + # Map buffer content hash -> merged buffer index (for deduplication) + buffer_hash_map: Dict[bytes, int] = {} + # The sentinel buffer (empty) hash + buffer_hash_map[hashlib.sha256(bytes()).digest()] = 0 + + # List of maps for each model: old_index -> new_index + model_buffer_maps: List[Dict[int, int]] = [] + + for model in models: + local_map = {} + # Map model's sentinel (0) to merged sentinel (0) + local_map[0] = 0 + + # Process other buffers with deduplication + for old_idx in range(1, len(model.buffers)): + buffer = model.buffers[old_idx] + + # Create a hash digest from buffer data + if buffer.data is not None: + buffer_hash = hashlib.sha256(bytes(buffer.data)).digest() + else: + buffer_hash = hashlib.sha256(bytes()).digest() + + # Check if this buffer already exists + if buffer_hash in buffer_hash_map: + # Reuse existing buffer + new_idx = buffer_hash_map[buffer_hash] + else: + # Add new buffer + new_idx = len(merged_model.buffers) + merged_model.buffers.append(buffer) + buffer_hash_map[buffer_hash] = new_idx + + local_map[old_idx] = new_idx + + model_buffer_maps.append(local_map) + + # 3. Merge Subgraphs + # We assume 1 subgraph per input model. + + for model_idx, model in enumerate(models): + if not model.subgraphs: + # Create empty subgraph if none exists (though unlikely for valid models) + subgraph = circle.SubGraphT() + merged_model.subgraphs.append(subgraph) + subgraph_idx = len(merged_model.subgraphs) - 1 else: - o2o.log(f"Warning: Invalid tensor index {tensor_idx} in signature creation") - - return tensor_maps + # Take the first subgraph + subgraph = model.subgraphs[0] + # Update Operator Opcode Indices + for op in subgraph.operators: + if op.opcodeIndex in model_opcode_maps[model_idx]: + op.opcodeIndex = model_opcode_maps[model_idx][op.opcodeIndex] -def create_signatures(model: 'circle.ModelT', sig_names: List[str]) -> None: - """Create signature definitions for the merged model. 
+ # Update Tensor Buffer Indices + for tensor in subgraph.tensors: + if tensor.buffer in model_buffer_maps[model_idx]: + tensor.buffer = model_buffer_maps[model_idx][tensor.buffer] - Args: - model: Merged Circle model - sig_names: List of signature names for each subgraph (must match subgraph count) - """ - if not hasattr(model, 'signatureDefs'): - model.signatureDefs = [] + merged_model.subgraphs.append(subgraph) + subgraph_idx = len(merged_model.subgraphs) - 1 - signatures = [] - - for idx, sig_name in enumerate(sig_names): + # 4. Create SignatureDefs sig = circle.SignatureDefT() - sig.subgraphIndex = idx - sig.signatureKey = sig_name.encode('utf-8') - - subgraph = model.subgraphs[idx] - sig.inputs = create_tensor_map_list(subgraph, subgraph.inputs) if list( - subgraph.inputs) else [] - sig.outputs = create_tensor_map_list(subgraph, subgraph.outputs) if list( - subgraph.outputs) else [] + sig.signatureKey = get_signature_key(model_paths[model_idx]) + sig.subgraphIndex = subgraph_idx + if model.subgraphs: + sig.inputs = create_tensor_map(subgraph, subgraph.inputs) + sig.outputs = create_tensor_map(subgraph, subgraph.outputs) + merged_model.signatureDefs.append(sig) - signatures.append(sig) - - model.signatureDefs = signatures - - -def merge_models_with_signatures(model1: 'circle.ModelT', model2: 'circle.ModelT', - sig_name_0: str, sig_name_1: str) -> 'circle.ModelT': - """Merge two Circle models by keeping subgraphs separate and adding signatures. 
- - Args: - model1: First Circle model - model2: Second Circle model - sig_name_0: Signature name for first subgraph - sig_name_1: Signature name for second subgraph - - Returns: - circle.ModelT: Merged model with signatures - """ - # Validate that both models have exactly one subgraph - if not model1.subgraphs or len(model1.subgraphs) != 1: - o2o.log("Error: First model must have exactly one subgraph") - sys.exit(1) - - if not model2.subgraphs or len(model2.subgraphs) != 1: - o2o.log("Error: Second model must have exactly one subgraph") - sys.exit(1) - - o2o.log(f"Merging models:") - o2o.log( - f" Model 1: {len(model1.subgraphs[0].tensors)} tensors, {len(model1.subgraphs[0].operators)} operators" - ) - o2o.log( - f" Model 2: {len(model2.subgraphs[0].tensors)} tensors, {len(model2.subgraphs[0].operators)} operators" - ) - - # Step 1: Merge buffers (simple append) - merged_buffers = list(model1.buffers) + list(model2.buffers) - buffer_offset = len(model1.buffers) - - # Step 2: Merge operator codes with deduplication - merged_operator_codes, model2_opcode_mapping = merge_operator_codes_with_deduplication( - model1, model2) - - # Step 3: Create merged subgraphs - merged_subgraphs = [] - - # First subgraph (keep as-is, no index remapping needed) - subgraph0 = model1.subgraphs[0] - merged_subgraphs.append(subgraph0) - - # Second subgraph (needs index remapping) - subgraph1 = model2.subgraphs[0] - - # Remap buffer indices in second subgraph tensors - for tensor in subgraph1.tensors: - if tensor.buffer is not None and tensor.buffer != 0: - tensor.buffer += buffer_offset - - # Remap operator code indices in second subgraph operators - for operator in subgraph1.operators: - if operator.opcodeIndex is not None: - operator.opcodeIndex = model2_opcode_mapping[operator.opcodeIndex] - - merged_subgraphs.append(subgraph1) - - # Step 4: Create final merged model - merged_model = circle.ModelT() - merged_model.buffers = merged_buffers - merged_model.operatorCodes = 
merged_operator_codes - merged_model.subgraphs = merged_subgraphs - - # Step 5: Create signatures - create_signatures(merged_model, [sig_name_0, sig_name_1]) - - o2o.log(f"Merge completed:") - o2o.log(f" Total buffers: {len(merged_buffers)}") - o2o.log(f" Total operator codes: {len(merged_operator_codes)}") - o2o.log(f" Total subgraphs: {len(merged_subgraphs)}") - o2o.log(f" Signatures: ['{sig_name_0}', '{sig_name_1}']") - - return merged_model + # Save to stdout + o2o.save_model_to_stdout(merged_model) def main(): - """Main function to merge two Circle models with signatures.""" - # This script merges multiple Circle model files into a single model. - # It keeps each input model as a separate subgraph and adds a signature - # for each subgraph. If signature names are not provided via --sig-names, - # they are derived from the input filenames (without the .circle extension). - parser = argparse.ArgumentParser( - description='Merge multiple Circle models (as subgraphs) with signatures') - # One or more Circle model files to merge, e.g. in1.circle in2.circle ... - parser.add_argument( - 'circles', - nargs='+', - help='Circle model files to merge (e.g., in1.circle in2.circle ...)') - # Optional signature names for each subgraph, separated by semicolons. - # Must match the number of input files. If omitted, names are taken from the - # input filenames (without the .circle extension). - parser.add_argument( - '--sig-names', - default=None, - help= - 'Signature names for subgraphs (semicolon‑separated). If omitted, derived from input filenames.' 
- ) + parser = argparse.ArgumentParser(description='Merge multiple circle models into one.') + parser.add_argument('models', nargs='+', help='Paths to the circle models to merge') args = parser.parse_args() - # Currently only support 2 models - if len(args.circles) != 2: - o2o.log("Error: Currently only 2 Circle models are supported") - sys.exit(1) - - # Parse signature names - if args.sig_names is None: - # Use filenames without .circle extension as signature names - sig_names = [os.path.splitext(os.path.basename(f))[0] for f in args.circles] - else: - # Use user-provided signature names - sig_names = args.sig_names.split(';') - if len(sig_names) != len(args.circles): - o2o.log( - f"Error: --sig-names must contain exactly {len(args.circles)} names separated by semicolon" - ) + for model_path in args.models: + if not os.path.exists(model_path): + print(f"Error: File not found: {model_path}", file=sys.stderr) sys.exit(1) - sig_names = [name.strip() for name in sig_names] - # Validate signature names are not empty - for i, sig_name in enumerate(sig_names): - if not sig_name: - o2o.log(f"Error: Signature name {i+1} cannot be empty") - sys.exit(1) + merge_models(args.models) + - sig_name_0, sig_name_1 = sig_names[0], sig_names[1] - - o2o.log(f"Loading models...") - o2o.log(f" First model: {args.circles[0]}") - o2o.log(f" Second model: {args.circles[1]}") - o2o.log(f" Signature names: ['{sig_name_0}', '{sig_name_1}']") - - # Load both models explicitly - try: - model0 = o2o.load_circle_model(args.circles[0]) - model1 = o2o.load_circle_model(args.circles[1]) - except Exception as e: - o2o.log(f"Error loading models: {e}") - sys.exit(1) - - # Merge models with signatures - try: - merged_model = merge_models_with_signatures(model0, model1, sig_name_0, - sig_name_1) - except Exception as e: - o2o.log(f"Error merging models: {e}") - sys.exit(1) - - # Output to stdout - try: - o2o.save_model_to_stdout(merged_model) - o2o.log("Successfully saved merged model to stdout") - 
except Exception as e: - o2o.log(f"Error saving merged model: {e}") - sys.exit(1) - - -if __name__ == "__main__": +if __name__ == '__main__': main() From 1d79f649e9d5ab77ca2a5c2f9f76d3ff57cf8378 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 07:49:27 +0900 Subject: [PATCH 18/27] Rename circle2circle to o2o --- tools/{circle2circle => o2o}/README.md | 0 tools/{circle2circle => o2o}/circle.py | 0 tools/{circle2circle => o2o}/fuse.attention.py | 0 tools/{circle2circle => o2o}/fuse.bmm_lhs_const.py | 0 tools/{circle2circle => o2o}/gc.py | 0 tools/{circle2circle => o2o}/gen_circle.add.py | 0 tools/{circle2circle => o2o}/gen_circle.bmm_lhs_const.fc.py | 0 tools/{circle2circle => o2o}/merge.circle.py | 0 tools/{circle2circle => o2o}/o2o.py | 0 tools/{circle2circle => o2o}/remove.io.py | 0 tools/{circle2circle => o2o}/rename.io.remove_namespace.py | 0 tools/{circle2circle => o2o}/rename.io.remove_prefix.py | 0 tools/{circle2circle => o2o}/reorder.output.move_0_to_last.py | 0 tools/{circle2circle => o2o}/requirements.txt | 0 tools/{circle2circle => o2o}/reshape.fc_weight.py | 0 tools/{circle2circle => o2o}/reshape.io.py | 0 tools/{circle2circle => o2o}/retype.input_ids.py | 0 tools/{circle2circle => o2o}/select.op.py | 0 tools/{circle2circle => o2o}/transpose.io.kvcache.py | 0 tools/{circle2circle => o2o}/with.py | 0 20 files changed, 0 insertions(+), 0 deletions(-) rename tools/{circle2circle => o2o}/README.md (100%) rename tools/{circle2circle => o2o}/circle.py (100%) rename tools/{circle2circle => o2o}/fuse.attention.py (100%) rename tools/{circle2circle => o2o}/fuse.bmm_lhs_const.py (100%) rename tools/{circle2circle => o2o}/gc.py (100%) rename tools/{circle2circle => o2o}/gen_circle.add.py (100%) rename tools/{circle2circle => o2o}/gen_circle.bmm_lhs_const.fc.py (100%) rename tools/{circle2circle => o2o}/merge.circle.py (100%) rename tools/{circle2circle => o2o}/o2o.py (100%) rename tools/{circle2circle => o2o}/remove.io.py (100%) rename 
tools/{circle2circle => o2o}/rename.io.remove_namespace.py (100%) rename tools/{circle2circle => o2o}/rename.io.remove_prefix.py (100%) rename tools/{circle2circle => o2o}/reorder.output.move_0_to_last.py (100%) rename tools/{circle2circle => o2o}/requirements.txt (100%) rename tools/{circle2circle => o2o}/reshape.fc_weight.py (100%) rename tools/{circle2circle => o2o}/reshape.io.py (100%) rename tools/{circle2circle => o2o}/retype.input_ids.py (100%) rename tools/{circle2circle => o2o}/select.op.py (100%) rename tools/{circle2circle => o2o}/transpose.io.kvcache.py (100%) rename tools/{circle2circle => o2o}/with.py (100%) diff --git a/tools/circle2circle/README.md b/tools/o2o/README.md similarity index 100% rename from tools/circle2circle/README.md rename to tools/o2o/README.md diff --git a/tools/circle2circle/circle.py b/tools/o2o/circle.py similarity index 100% rename from tools/circle2circle/circle.py rename to tools/o2o/circle.py diff --git a/tools/circle2circle/fuse.attention.py b/tools/o2o/fuse.attention.py similarity index 100% rename from tools/circle2circle/fuse.attention.py rename to tools/o2o/fuse.attention.py diff --git a/tools/circle2circle/fuse.bmm_lhs_const.py b/tools/o2o/fuse.bmm_lhs_const.py similarity index 100% rename from tools/circle2circle/fuse.bmm_lhs_const.py rename to tools/o2o/fuse.bmm_lhs_const.py diff --git a/tools/circle2circle/gc.py b/tools/o2o/gc.py similarity index 100% rename from tools/circle2circle/gc.py rename to tools/o2o/gc.py diff --git a/tools/circle2circle/gen_circle.add.py b/tools/o2o/gen_circle.add.py similarity index 100% rename from tools/circle2circle/gen_circle.add.py rename to tools/o2o/gen_circle.add.py diff --git a/tools/circle2circle/gen_circle.bmm_lhs_const.fc.py b/tools/o2o/gen_circle.bmm_lhs_const.fc.py similarity index 100% rename from tools/circle2circle/gen_circle.bmm_lhs_const.fc.py rename to tools/o2o/gen_circle.bmm_lhs_const.fc.py diff --git a/tools/circle2circle/merge.circle.py b/tools/o2o/merge.circle.py 
similarity index 100% rename from tools/circle2circle/merge.circle.py rename to tools/o2o/merge.circle.py diff --git a/tools/circle2circle/o2o.py b/tools/o2o/o2o.py similarity index 100% rename from tools/circle2circle/o2o.py rename to tools/o2o/o2o.py diff --git a/tools/circle2circle/remove.io.py b/tools/o2o/remove.io.py similarity index 100% rename from tools/circle2circle/remove.io.py rename to tools/o2o/remove.io.py diff --git a/tools/circle2circle/rename.io.remove_namespace.py b/tools/o2o/rename.io.remove_namespace.py similarity index 100% rename from tools/circle2circle/rename.io.remove_namespace.py rename to tools/o2o/rename.io.remove_namespace.py diff --git a/tools/circle2circle/rename.io.remove_prefix.py b/tools/o2o/rename.io.remove_prefix.py similarity index 100% rename from tools/circle2circle/rename.io.remove_prefix.py rename to tools/o2o/rename.io.remove_prefix.py diff --git a/tools/circle2circle/reorder.output.move_0_to_last.py b/tools/o2o/reorder.output.move_0_to_last.py similarity index 100% rename from tools/circle2circle/reorder.output.move_0_to_last.py rename to tools/o2o/reorder.output.move_0_to_last.py diff --git a/tools/circle2circle/requirements.txt b/tools/o2o/requirements.txt similarity index 100% rename from tools/circle2circle/requirements.txt rename to tools/o2o/requirements.txt diff --git a/tools/circle2circle/reshape.fc_weight.py b/tools/o2o/reshape.fc_weight.py similarity index 100% rename from tools/circle2circle/reshape.fc_weight.py rename to tools/o2o/reshape.fc_weight.py diff --git a/tools/circle2circle/reshape.io.py b/tools/o2o/reshape.io.py similarity index 100% rename from tools/circle2circle/reshape.io.py rename to tools/o2o/reshape.io.py diff --git a/tools/circle2circle/retype.input_ids.py b/tools/o2o/retype.input_ids.py similarity index 100% rename from tools/circle2circle/retype.input_ids.py rename to tools/o2o/retype.input_ids.py diff --git a/tools/circle2circle/select.op.py b/tools/o2o/select.op.py similarity index 100% 
rename from tools/circle2circle/select.op.py rename to tools/o2o/select.op.py diff --git a/tools/circle2circle/transpose.io.kvcache.py b/tools/o2o/transpose.io.kvcache.py similarity index 100% rename from tools/circle2circle/transpose.io.kvcache.py rename to tools/o2o/transpose.io.kvcache.py diff --git a/tools/circle2circle/with.py b/tools/o2o/with.py similarity index 100% rename from tools/circle2circle/with.py rename to tools/o2o/with.py From 2acd985fbb18e5258b7569384015746f914c6a43 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 13:35:09 +0900 Subject: [PATCH 19/27] Minimize o2o --- tools/o2o/README.md | 39 +--- ... => fuse.bmm_lhs_const.gen_test_circle.py} | 0 tools/o2o/gen_circle.add.py | 175 ------------------ tools/o2o/rename.io.remove_namespace.py | 82 -------- tools/o2o/rename.io.remove_prefix.py | 58 ------ tools/o2o/reorder.output.move_0_to_last.py | 48 ----- tools/o2o/with.py | 22 --- 7 files changed, 2 insertions(+), 422 deletions(-) rename tools/o2o/{gen_circle.bmm_lhs_const.fc.py => fuse.bmm_lhs_const.gen_test_circle.py} (100%) delete mode 100755 tools/o2o/gen_circle.add.py delete mode 100755 tools/o2o/rename.io.remove_namespace.py delete mode 100755 tools/o2o/rename.io.remove_prefix.py delete mode 100755 tools/o2o/reorder.output.move_0_to_last.py delete mode 100755 tools/o2o/with.py diff --git a/tools/o2o/README.md b/tools/o2o/README.md index 0b5f2492d4a..66a5d86c58f 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -14,16 +14,11 @@ All circle2circle command scripts read a Circle model from **standard input** an An example: -```bash -./rename.io.remove_namespace.py < in.circle > out.circle -``` - Filters example: ```bash -./with.py in.circle | - ./select.op.py --by_id 0-181 | - ./gc.py > new.circle +./select.op.py --by_id 0-181 < in.circle | +./gc.py > out.circle ```
@@ -42,23 +37,6 @@ Removes input or output tensors from a Circle model, keeping only the tensors at **Note:** Exactly one of `--keep_by_name` or `--keep_by_id` must be provided. -## - -### `rename.io.remove_namespace.py` - -Removes namespaces from the names of input and output tensors. A namespace is identified as the part of the tensor name before a double colon (`::`). For example, a tensor named `module::input_tensor` would be renamed to `input_tensor`. - -## - -### `rename.io.remove_prefix.py` - -Removes a user-specified prefix from the names of all tensors in the model. - -#### Arguments - -* `prefix` (required): The string prefix to remove from tensor names. - - ## ### `reshape.fc_weight.py` @@ -138,19 +116,6 @@ Finds tensors named `input_ids` and changes their data type from int64 to int32. ## -### `gen_circle.*.py` - - -These scripts generate test Circle models with specific operator patterns for development and testing purposes. Each script follows the naming convention `gen_circle..py` and automatically generates an output file with the name `.circle` when executed. - -#### `gen_circle.add.py` - -Generates a simple Circle model with one `ADD` operator for testing basic functionality. - -#### `gen_circle.bmm_lhs_const.fc.py` - -Generates a test Circle model with `BATCH_MATMUL` and `TRANSPOSE` operations where the LHS is constant. This model is designed to test the fusion pattern used in `fuse.bmm_lhs_const.py`. - ## `merge.circle.py` Merges multiple Circle model files into a single model by appending their subgraphs and adding signatures. The script accepts any number of Circle files. 
diff --git a/tools/o2o/gen_circle.bmm_lhs_const.fc.py b/tools/o2o/fuse.bmm_lhs_const.gen_test_circle.py similarity index 100% rename from tools/o2o/gen_circle.bmm_lhs_const.fc.py rename to tools/o2o/fuse.bmm_lhs_const.gen_test_circle.py diff --git a/tools/o2o/gen_circle.add.py b/tools/o2o/gen_circle.add.py deleted file mode 100755 index f3a47bee4c3..00000000000 --- a/tools/o2o/gen_circle.add.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python3 - -import os -import numpy as np -import circle -import o2o -from typing import List - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType) - -# Circle Model Buffer Usage Rules (based on circle_schema.fbs and analysis) -# ====================================================================== -# -# Buffer Index Allocation Rules: -# - B(0): Always empty placeholder buffer (sentinel, must exist) -# - B(1+): Dedicated buffers for specific tensors -# -# Tensor-Buffer Assignment Rules: -# 1. Model Input Tensors: -# - Get dedicated buffer index (e.g., B(1), B(2), ...) -# - Buffer data is EMPTY (b'') -# - Added to subgraph.inputs array -# - Example: input tensor -> buffer index 1 (empty) -# -# 2. Model Output Tensors: -# - Get dedicated buffer index (e.g., B(1), B(2), ...) -# - Buffer data is EMPTY (b'') -# - Added to subgraph.outputs array -# - Example: output tensor -> buffer index 2 (empty) -# -# 3. Constant Tensors: -# - Get dedicated buffer index (e.g., B(3), B(4), ...) -# - Buffer data contains ACTUAL DATA (numpy.tobytes()) -# - NOT added to subgraph.inputs (internal to model) -# - Example: constant tensor -> buffer index 3 (with data) -# -# 4. Intermediate Tensors: -# - Get dedicated buffer index (e.g., B(4), B(5), ...) 
-# - Buffer data is EMPTY (b'') -# - NOT added to subgraph.inputs/outputs -# - Example: intermediate result -> buffer index 4 (empty) -# - IMPORTANT: Intermediate tensors are NOT constants, so they need dedicated buffers! -# -# Buffer Creation Order (Recommended): -# 1. B(0): Empty placeholder buffer (always first) -# 2. Input tensor buffers (empty data) -# 3. Output tensor buffers (empty data) -# 4. Constant tensor buffers (with actual data) -# -# Key Principles: -# - Each tensor type has specific buffer requirements -# - Model inputs/outputs MUST have dedicated buffers (even if empty) -# - Constants MUST have dedicated buffers with actual data -# - Intermediate results use buffer index 0 -# - Buffer index assignment follows creation order in model.buffers array -# -# Reference: circle_schema.fbs - "buffer:uint" field documentation - - -def create_simple_add_model(output_file: str): - """Create a simple Circle model with one ADD operator (similar to add.circle).""" - - # Create model - model = circle.ModelT() - model.version = 3 - model.operatorCodes = [] - model.subgraphs = [] - model.buffers = [] - model.metadataBuffer = [] - - # Create subgraph - subgraph = circle.SubGraphT() - subgraph.tensors = [] - subgraph.inputs = [] - subgraph.outputs = [] - subgraph.operators = [] - subgraph.name = "main" - - # Create buffers in CORRECT order (output buffer last) - # B(0) - empty_sentinel_buffer (always existent empty buffer for tensors with no data buffer) - empty_sentinel_buffer = circle.BufferT() - # No data assignment for empty buffer (buffer.data remains None) - model.buffers.append(empty_sentinel_buffer) - - # B(1) - input tensor buffer (no data) - input_buffer = circle.BufferT() - # No data assignment for input buffer (buffer.data remains None) - model.buffers.append(input_buffer) - - # B(2) - constant tensor buffer (with data, 16-byte aligned) - const_data = np.array([1], dtype=np.int32) # Simple constant value - const_buffer = circle.BufferT() - # Align to 16 
bytes as required by circle_schema.fbs Buffer.force_align: 16 - raw_data = const_data.tobytes() - padded_data = raw_data + b'\x00' * (16 - len(raw_data)) # 4 + 12 = 16 bytes - const_buffer.data = padded_data - model.buffers.append(const_buffer) - - # B(3) - output tensor buffer (no data) - MOVED TO LAST - output_buffer = circle.BufferT() - # No data assignment for output buffer (buffer.data remains None) - model.buffers.append(output_buffer) - - # Create input tensor (ifm) - using dedicated buffer B(1) - input_tensor = circle.TensorT() - input_tensor.shape = [1, 1, 16] # Same as add.circle - input_tensor.type = circle.TensorType.INT32 # Using INT32 - input_tensor.buffer = 1 # B(1) - dedicated input buffer - input_tensor.name = "ifm" - subgraph.tensors.append(input_tensor) - input_tensor_index = len(subgraph.tensors) - 1 - subgraph.inputs.append(input_tensor_index) - - # Create constant tensor (add_const) - using dedicated buffer B(2) - const_tensor = circle.TensorT() - const_tensor.shape = [1, 1, 1] # Same as add.circle - const_tensor.type = circle.TensorType.INT32 - const_tensor.buffer = 2 # B(2) - dedicated constant buffer with data - const_tensor.name = "add_const" - subgraph.tensors.append(const_tensor) - const_tensor_index = len(subgraph.tensors) - 1 - - # Create output tensor (ofm) - using dedicated buffer B(3) - MOVED TO LAST - output_tensor = circle.TensorT() - output_tensor.shape = [1, 1, 16] # Same as add.circle - output_tensor.type = circle.TensorType.INT32 - output_tensor.buffer = 3 # B(3) - dedicated output buffer (last index) - output_tensor.name = "ofm" - subgraph.tensors.append(output_tensor) - output_tensor_index = len(subgraph.tensors) - 1 - subgraph.outputs.append(output_tensor_index) - - # Create ADD operator code - add_opcode = circle.OperatorCodeT() - add_opcode.builtinCode = circle.BuiltinOperator.ADD - add_opcode.deprecatedBuiltinCode = circle.BuiltinOperator.ADD # Fix: deprecatedBuiltinCode must be set to same as builtinCode - 
add_opcode.version = 1 - model.operatorCodes.append(add_opcode) - add_opcode_index = len(model.operatorCodes) - 1 - - # Create ADD operator - add_op = circle.OperatorT() - add_op.opcodeIndex = add_opcode_index - add_op.inputs = [input_tensor_index, const_tensor_index] # ifm + add_const - add_op.outputs = [output_tensor_index] # = ofm - add_op.builtinOptionsType = circle.BuiltinOptions.AddOptions - add_options = circle.AddOptionsT() - add_options.fusedActivationFunction = circle.ActivationFunctionType.NONE - add_op.builtinOptions = add_options - subgraph.operators.append(add_op) - - # Add subgraph to model - model.subgraphs.append(subgraph) - - # Save model - o2o.save_circle_model(model, output_file) - o2o.log(f"Simple ADD model saved to {output_file}") - o2o.log(f"Model structure:") - o2o.log(f" Input tensor: {input_tensor.name} shape={input_tensor.shape}") - o2o.log(f" Constant tensor: {const_tensor.name} shape={const_tensor.shape}") - o2o.log(f" Output tensor: {output_tensor.name} shape={output_tensor.shape}") - o2o.log(f" Operator: ADD") - o2o.log(f" Subgraph inputs: {[subgraph.tensors[i].name for i in subgraph.inputs]}") - o2o.log(f" Subgraph outputs: {[subgraph.tensors[i].name for i in subgraph.outputs]}") - - -if __name__ == "__main__": - # Generate output filename from current script filename - # e.g., add.gen_circle.py -> add.circle - script_name = os.path.basename(__file__) - output_file = script_name.replace('.gen_circle.py', '.circle') - - create_simple_add_model(output_file) diff --git a/tools/o2o/rename.io.remove_namespace.py b/tools/o2o/rename.io.remove_namespace.py deleted file mode 100755 index a236c453070..00000000000 --- a/tools/o2o/rename.io.remove_namespace.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import circle -import flatbuffers -from typing import Tuple -import o2o - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - 
BuiltinOperator, TensorType) - - -def load_model_from_stdin() -> 'circle.ModelT': - """Load a Circle model from binary data read from stdin.""" - data = sys.stdin.buffer.read() - buf = bytearray(data) - model = circle.Model.GetRootAsModel(buf, 0) - model = circle.ModelT.InitFromObj(model) - return model - - -def save_model_to_stdout(model: 'circle.ModelT'): - """Serialize a Circle model and write it to stdout as binary data.""" - builder = flatbuffers.Builder(1024) - builder.Finish(model.Pack(builder), b'CIR0') - sys.stdout.buffer.write(builder.Output()) - - -def remove_namespace_from_inputs_and_outputs(model: 'circle.ModelT'): - """Remove namespace from tensor names within the given model.""" - pattern = r'(.*)::(.*)' - - def process_subgraph(subgraph): - """Process a single subgraph, renaming matching tensor names.""" - o2o.log( - f"Processing subgraph with {len(subgraph.inputs)} inputs and {len(subgraph.outputs)} outputs" - ) - renamed_count = 0 - - # Process input tensors - for input_tensor_index in subgraph.inputs: - tensor = subgraph.tensors[input_tensor_index] - was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( - tensor, pattern, lambda match: match.group(2)) - if was_renamed: - o2o.log(f"Renaming input tensor: {old_name} → {new_name}") - renamed_count += 1 - - # Process output tensors - for output_tensor_index in subgraph.outputs: - tensor = subgraph.tensors[output_tensor_index] - was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( - tensor, pattern, lambda match: match.group(2)) - if was_renamed: - o2o.log(f"Renaming output tensor: {old_name} → {new_name}") - renamed_count += 1 - - if renamed_count > 0: - o2o.log(f"Renamed {renamed_count} input/output tensors in this subgraph") - else: - o2o.log("No input/output tensors were renamed in this subgraph") - - return renamed_count > 0, renamed_count - - overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) - - if not overall_modified: - o2o.log("No tensors 
were modified.") - else: - o2o.log(f"Total tensors renamed across all subgraphs: {total_changes}") - - -def main(): - """Entry point: read model from stdin, process, write to stdout.""" - model = load_model_from_stdin() - remove_namespace_from_inputs_and_outputs(model) - save_model_to_stdout(model) - - -if __name__ == "__main__": - main() diff --git a/tools/o2o/rename.io.remove_prefix.py b/tools/o2o/rename.io.remove_prefix.py deleted file mode 100755 index debee161429..00000000000 --- a/tools/o2o/rename.io.remove_prefix.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 - -import o2o -import re -import sys - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType) - - -def rename_input_tensors(prefix: str): - """Main function to rename tensors by removing the specified prefix""" - # Load the model using utility function - model = o2o.load_model_from_stdin() - - # Pattern to match tensor names that start with the specified prefix followed by anything - pattern = re.escape(prefix) + r'(.*)' - - def process_subgraph(subgraph): - """Process a single subgraph""" - o2o.log(f"Processing subgraph with {len(subgraph.tensors)} tensors") - - renamed_count = 0 - for tensor in subgraph.tensors: - was_renamed, old_name, new_name = o2o.rename_tensor_if_matches( - tensor, pattern, lambda match: match.group(1)) - - if was_renamed: - o2o.log(f"Renaming tensor: {old_name} → {new_name}") - renamed_count += 1 - - if renamed_count > 0: - o2o.log(f"Renamed {renamed_count} tensors in this subgraph") - else: - o2o.log("No tensors were renamed in this subgraph") - - return renamed_count > 0, renamed_count - - # Process all subgraphs using utility function - overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) - - if not overall_modified: - o2o.log("No tensors were modified.") - - # Save the model using utility function - 
o2o.save_model_to_stdout(model) - - -if __name__ == "__main__": - if len(sys.argv) != 2: - o2o.log("Usage: python rename_inputs.py ") - sys.exit(1) - - prefix = sys.argv[1] - - # Directly invoke processing; I/O handled via stdin/stdout - rename_input_tensors(prefix) diff --git a/tools/o2o/reorder.output.move_0_to_last.py b/tools/o2o/reorder.output.move_0_to_last.py deleted file mode 100755 index e05e336d897..00000000000 --- a/tools/o2o/reorder.output.move_0_to_last.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 - -from typing import List -import o2o - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType) - - -def reorder_output_tensors() -> None: - """Reorder output tensors: move tensor 0 to the end, shift others forward""" - o2o.log("Loading model from stdin") - model = o2o.load_model_from_stdin() - - if not model.subgraphs: - o2o.log("Model has no subgraphs. 
Exiting.") - o2o.save_model_to_stdout(model) - return - - for subgraph_idx, subgraph in enumerate(model.subgraphs): - if len(subgraph.outputs) <= 1: - o2o.log( - f"Subgraph {subgraph_idx}: Only {len(subgraph.outputs)} output tensor(s), no reordering needed" - ) - continue - - # Convert numpy array to Python list for proper concatenation - original_outputs = subgraph.outputs.copy() - outputs_list = original_outputs.tolist() - - # Move first output tensor to the end - # Original: [a, b, c, d] -> New: [b, c, d, a] - first_output = outputs_list[0] - other_outputs = outputs_list[1:] - new_outputs = other_outputs + [first_output] - - subgraph.outputs = new_outputs - o2o.log( - f"Subgraph {subgraph_idx}: Reordered outputs {original_outputs.tolist()} -> {new_outputs}" - ) - - o2o.save_model_to_stdout(model) - - -if __name__ == "__main__": - # Directly invoke processing; I/O handled via stdin/stdout - reorder_output_tensors() diff --git a/tools/o2o/with.py b/tools/o2o/with.py deleted file mode 100755 index 602253883cd..00000000000 --- a/tools/o2o/with.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 -import sys -import pathlib - - -def main(): - if len(sys.argv) < 2: - sys.stderr.write("Usage: with.py \\n") - sys.exit(1) - - input_path = pathlib.Path(sys.argv[1]) - if not input_path.is_file(): - sys.stderr.write(f"File not found: {input_path}\\n") - sys.exit(1) - - # Read the binary content of the circle file and write it to stdout - with input_path.open('rb') as f: - sys.stdout.buffer.write(f.read()) - - -if __name__ == "__main__": - main() From 14c3a21bdd04751c1f70b25a5d8e8ff2971ddcb4 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 14:38:23 +0900 Subject: [PATCH 20/27] Incorporate reshape.fc_weight.py into fuse.bmm_lhs_const.py --- tools/o2o/README.md | 28 ++--- tools/o2o/fuse.attention.py | 10 -- tools/o2o/fuse.bmm_lhs_const.py | 175 ++++++++++++++++++-------------- tools/o2o/o2o.py | 39 ++++++- tools/o2o/reshape.fc_weight.py | 119 
---------------------- 5 files changed, 149 insertions(+), 222 deletions(-) delete mode 100755 tools/o2o/reshape.fc_weight.py diff --git a/tools/o2o/README.md b/tools/o2o/README.md index 66a5d86c58f..0bc4b119cc0 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -39,21 +39,9 @@ Removes input or output tensors from a Circle model, keeping only the tensors at ## -### `reshape.fc_weight.py` - -Reshapes the weight tensors of `FULLY_CONNECTED` operators from an effectively 2D shape (e.g., `[1, 1, D_out, D_in]`) to a strict 2D shape (`[D_out, D_in]`). This is useful for optimizing or standardizing the model structure. If a weight tensor is used by multiple operators, a new tensor is created for the specific operator to prevent conflicts. - -## - -### `transpose.io.kcache.py` - -Finds input tensors matching the pattern `*key_cache_\d+` (e.g., `past_key_values_key_cache_0`) and transposes their second and third dimensions if they are 4D. For example, a shape `[d0, d1, d2, d3]` will become `[d0, d2, d1, d3]`. - -## - ### `fuse.bmm_lhs_const.py` -Fuses `BATCH_MATMUL` + `TRANSPOSE` to `FULLY_CONNECTED` when LHS is constant. +Fuses `BATCH_MATMUL` + `TRANSPOSE` to `FULLY_CONNECTED` when LHS is constant, and automatically reshapes the weight tensors of the **newly created** `FULLY_CONNECTED` operators from effectively 2D shapes (e.g., `[1, 1, D_out, D_in]`) to strict 2D shapes (`[D_out, D_in]`). 
#### Transformation Diagram @@ -79,8 +67,22 @@ Key Relationship: - BatchMatMul's RHS becomes FullyConnected's input ``` +#### Additional Processing + +After creating each fused `FULLY_CONNECTED` operator, this script automatically reshapes its weight tensor: +- Converts effectively 2D shapes (e.g., `[1, 1, D_out, D_in]`) to strict 2D shapes (`[D_out, D_in]`) +- If a weight tensor is used by multiple operators, creates a new tensor for the specific operator to prevent conflicts +- Sets `keepNumDims = True` to preserve batch dimensions + ## +### `transpose.io.kcache.py` + +Finds input tensors matching the pattern `*key_cache_\d+` (e.g., `past_key_values_key_cache_0`) and transposes their second and third dimensions if they are 4D. For example, a shape `[d0, d1, d2, d3]` will become `[d0, d2, d1, d3]`. + +## + + ### `select.op.py` Selectively removes operators from a Circle model based on their index range. This filter allows you to keep only the operators within specified index ranges while removing all others. It automatically handles tensor connections, updates subgraph inputs/outputs, and cleans up unused operator codes. 
diff --git a/tools/o2o/fuse.attention.py b/tools/o2o/fuse.attention.py index bc67fd19bd1..b0387846103 100755 --- a/tools/o2o/fuse.attention.py +++ b/tools/o2o/fuse.attention.py @@ -10,16 +10,6 @@ BuiltinOperator, TensorType) -def find_operator_by_output( - subgraph: 'circle.SubGraphT', - output_tensor_index: int) -> Tuple[Optional[int], Optional['circle.OperatorT']]: - """Find the first operator that produces the given output tensor index.""" - for op_idx, operator in enumerate(subgraph.operators): - if operator.outputs and output_tensor_index in operator.outputs: - return op_idx, operator - return None, None - - def find_attention_blocks(model: 'circle.ModelT', subgraph: 'circle.SubGraphT') -> List[Dict[str, Any]]: """Find all attention blocks in the subgraph.""" diff --git a/tools/o2o/fuse.bmm_lhs_const.py b/tools/o2o/fuse.bmm_lhs_const.py index d2582c8928e..24488719a5f 100755 --- a/tools/o2o/fuse.bmm_lhs_const.py +++ b/tools/o2o/fuse.bmm_lhs_const.py @@ -10,51 +10,6 @@ BuiltinOperator, TensorType) -def get_tensor_by_index(subgraph: 'circle.SubGraphT', - index: int) -> Optional['circle.TensorT']: - """Safely get tensor by its index.""" - if 0 <= index < len(subgraph.tensors): - return subgraph.tensors[index] - return None - - -def is_tensor_constant(tensor: 'circle.TensorT', - model_buffers: List['circle.BufferT']) -> bool: - """Check if a tensor is constant by verifying its buffer.""" - if tensor and tensor.buffer != 0 and 0 <= tensor.buffer - 1 < len(model_buffers): - # A non-zero buffer index that points to a valid buffer typically means it's constant. - # The 0th buffer is always an empty buffer. 
- return True - return False - - -def find_operator_by_output( - subgraph: 'circle.SubGraphT', - output_tensor_index: int) -> Tuple[Optional[int], Optional['circle.OperatorT']]: - """Find the first operator that produces the given output tensor index.""" - for op_idx, operator in enumerate(subgraph.operators): - if operator.outputs and output_tensor_index in operator.outputs: - return op_idx, operator - return None, None - - -def from_buffer(buffer_index, model_buffers): - """Converts buffer data to a numpy array (int32).""" - if buffer_index > 0 and buffer_index - 1 < len(model_buffers): - buffer_obj = model_buffers[buffer_index] - if buffer_obj and len(buffer_obj.data) > 0: - # Assuming data is a bytearray of int32s. - # This needs to match the actual data type in the model. - try: - return np.frombuffer(buffer_obj.data, dtype=np.int32) - except Exception as e: - o2o.log( - f"Could not parse permutation tensor buffer for buffer index {buffer_index}: {e}" - ) - return None - return None - - def is_effectively_2d(shape): """Check if a tensor shape is effectively 2D (all leading dimensions are 1)""" if len(shape) < 2: @@ -74,21 +29,6 @@ def count_tensor_usage(model, tensor_index): return count -def get_or_create_operator_code(model: 'circle.ModelT', - builtin_op_type: 'circle.BuiltinOperator') -> int: - """Get the index of an operator code, or create it if it doesn't exist.""" - for i, op_code in enumerate(model.operatorCodes): - if op_code.builtinCode == builtin_op_type: - return i - - # If not found, create a new one - new_op_code = circle.OperatorCodeT() - new_op_code.builtinCode = builtin_op_type - new_op_code.version = 1 # Default version - model.operatorCodes.append(new_op_code) - return len(model.operatorCodes) - 1 - - def create_transpose_permutation_tensor(model: 'circle.ModelT', subgraph: 'circle.SubGraphT', rank: int) -> int: """Create a permutation tensor for transposing last two dimensions.""" @@ -115,6 +55,69 @@ def 
create_transpose_permutation_tensor(model: 'circle.ModelT', return tensor_index +def create_reshaped_tensor(original_tensor: 'circle.TensorT', + new_shape: List[int]) -> 'circle.TensorT': + """Create a new tensor with the specified shape based on the original tensor""" + new_tensor = circle.TensorT() + new_tensor.shape = new_shape + new_tensor.type = original_tensor.type + new_tensor.buffer = original_tensor.buffer + new_tensor.name = original_tensor.name + new_tensor.quantization = original_tensor.quantization + new_tensor.isVariable = original_tensor.isVariable + new_tensor.sparsity = original_tensor.sparsity + new_tensor.shapeSignature = original_tensor.shapeSignature + new_tensor.hasRank = original_tensor.hasRank + new_tensor.variantTensors = original_tensor.variantTensors + new_tensor.compressionType = original_tensor.compressionType + return new_tensor + + +def reshape_fc_weights(model: 'circle.ModelT', subgraph: 'circle.SubGraphT', + fc_op_idx: int) -> None: + """Reshape FullyConnected weights from effectively 2D to 2D for a specific operator""" + fc_op = subgraph.operators[fc_op_idx] + + # Get the weights tensor (typically the second input) + if len(fc_op.inputs) < 2: + return + + weights_index = fc_op.inputs[1] # Weights is usually the second input + weights_tensor = subgraph.tensors[weights_index] + + # Check if the weights tensor is effectively 2D + if len(weights_tensor.shape) > 2 and is_effectively_2d(weights_tensor.shape): + # Ensure keepNumDims is True + if hasattr(fc_op.builtinOptions, 'keepNumDims'): + fc_op.builtinOptions.keepNumDims = True + + # Check if this tensor is used by multiple operators + usage_count = count_tensor_usage(model, weights_index) + + if usage_count > 1: + # Create a new tensor for this operator to avoid affecting others + new_shape = weights_tensor.shape[-2:] # Remove leading dimensions of 1 + new_tensor = create_reshaped_tensor(weights_tensor, new_shape) + + # Add the new tensor to the subgraph + new_tensor_index = 
len(subgraph.tensors) + subgraph.tensors.append(new_tensor) + + # Update the operator input to use the new tensor + fc_op.inputs[1] = new_tensor_index + + o2o.log( + f"Created reshaped weight tensor {new_tensor_index} for FC operator at {fc_op_idx}" + ) + else: + # Directly modify the tensor shape since it's only used once + original_shape = weights_tensor.shape + weights_tensor.shape = weights_tensor.shape[-2:] + o2o.log( + f"Reshaped weight tensor {weights_index} from {original_shape} to {weights_tensor.shape}" + ) + + def add_rhs_transpose_if_needed(model: 'circle.ModelT', subgraph: 'circle.SubGraphT', bmm_op_idx: int, rhs_tensor_index: int, rhs_tensor: 'circle.TensorT') -> int: @@ -151,7 +154,7 @@ def add_rhs_transpose_if_needed(model: 'circle.ModelT', subgraph: 'circle.SubGra # Create TRANSPOSE operator transpose_op = circle.OperatorT() - transpose_op.opcodeIndex = get_or_create_operator_code( + transpose_op.opcodeIndex = o2o.get_or_create_operator_code( model, circle.BuiltinOperator.TRANSPOSE) transpose_op.inputs = [rhs_tensor_index, perm_tensor_index] transpose_op.outputs = [transposed_rhs_tensor_index] @@ -196,7 +199,8 @@ def fuse_bmm_transpose() -> None: continue transpose_input_tensor_idx = transpose_op.inputs[0] - bmm_op_idx, bmm_op = find_operator_by_output(subgraph, transpose_input_tensor_idx) + bmm_op_idx, bmm_op = o2o.find_operator_by_output(subgraph, + transpose_input_tensor_idx) # Check if the found operator is BATCH_MATMUL if bmm_op is None or model.operatorCodes[ @@ -206,8 +210,8 @@ def fuse_bmm_transpose() -> None: lhs_tensor_index = bmm_op.inputs[0] rhs_tensor_index = bmm_op.inputs[1] - lhs_tensor = get_tensor_by_index(subgraph, lhs_tensor_index) - rhs_tensor = get_tensor_by_index(subgraph, rhs_tensor_index) + lhs_tensor = o2o.get_tensor_by_index(subgraph, lhs_tensor_index) + rhs_tensor = o2o.get_tensor_by_index(subgraph, rhs_tensor_index) if not lhs_tensor or not rhs_tensor: o2o.log( @@ -216,7 +220,7 @@ def fuse_bmm_transpose() -> None: 
continue # Crucial check: LHS must be constant - if not is_tensor_constant(lhs_tensor, model.buffers): + if not o2o.is_tensor_constant(lhs_tensor, model.buffers): o2o.log( f"LHS tensor '{lhs_tensor.name if lhs_tensor.name else lhs_tensor_index}' for BATCH_MATMUL at index {bmm_op_idx} is not constant. Skipping fusion." ) @@ -228,21 +232,30 @@ def fuse_bmm_transpose() -> None: # For a 3D tensor [B, M, N] -> [B, N, M], permutation is [0, 2, 1] valid_permutation = False perm_tensor_index = transpose_op.inputs[1] - perm_tensor = get_tensor_by_index(subgraph, perm_tensor_index) + perm_tensor = o2o.get_tensor_by_index(subgraph, perm_tensor_index) - if perm_tensor and is_tensor_constant(perm_tensor, model.buffers): + if perm_tensor and o2o.is_tensor_constant(perm_tensor, model.buffers): # Get permutation data from buffer using the new helper function - perm = from_buffer(perm_tensor.buffer, model.buffers) - if len(perm) >= 2: # At least 2D - # Check if the last two elements of permutation are swapped - # and other elements are in their original ascending order (0, 1, 2, ...) - expected_perm_prefix = list(range(len(perm) - 2)) - actual_perm_prefix = perm[:-2] - - if np.all(actual_perm_prefix == expected_perm_prefix) and \ - perm[-2] == len(perm) - 1 and \ - perm[-1] == len(perm) - 2: - valid_permutation = True + perm = o2o.from_buffer(perm_tensor.buffer, model.buffers) + if perm is None: + o2o.log( + f"Could not read permutation buffer at index {perm_tensor.buffer}") + valid_permutation = False + else: + # Use the tensor's shape to determine the actual permutation length + perm_length = perm_tensor.shape[0] if perm_tensor.shape else len(perm) + perm = perm[:perm_length] # Trim to actual length + + if len(perm) >= 2: # At least 2D + # Check if the last two elements of permutation are swapped + # and other elements are in their original ascending order (0, 1, 2, ...) 
+ expected_perm_prefix = list(range(len(perm) - 2)) + actual_perm_prefix = perm[:-2] + + if np.all(actual_perm_prefix == expected_perm_prefix) and \ + perm[-2] == len(perm) - 1 and \ + perm[-1] == len(perm) - 2: + valid_permutation = True else: o2o.log( f"Permutation tensor for TRANSPOSE at index {i} is not constant or not found. Skipping." @@ -260,7 +273,7 @@ def fuse_bmm_transpose() -> None: # Create the new FULLY_CONNECTED operator fc_op = circle.OperatorT() - fc_op.opcodeIndex = get_or_create_operator_code( + fc_op.opcodeIndex = o2o.get_or_create_operator_code( model, circle.BuiltinOperator.FULLY_CONNECTED) # Set inputs: [transposed_rhs, original_lhs, -1] where -1 means bias not exists fc_op.inputs = [final_rhs_tensor_index, lhs_tensor_index, -1] @@ -278,6 +291,9 @@ def fuse_bmm_transpose() -> None: o2o.log(f"Replacing batchmatmul at {bmm_op_idx} with fullyconnected") subgraph.operators[bmm_op_idx] = fc_op + # Reshape the weights of the newly created FC operator + reshape_fc_weights(model, subgraph, bmm_op_idx) + # Mark the original TRANSPOSE operator for removal operators_to_remove.append(i) @@ -295,6 +311,7 @@ def fuse_bmm_transpose() -> None: # Note: Cleanup of unused tensors and operator codes is a more advanced step # and not implemented here for simplicity, but would be part of a production-ready script. 
o2o.log(f"TODO: Remove tensors at {tensors_to_potentially_remove}") + o2o.save_model_to_stdout(model) diff --git a/tools/o2o/o2o.py b/tools/o2o/o2o.py index 130fd120daf..73d3eba5dcf 100755 --- a/tools/o2o/o2o.py +++ b/tools/o2o/o2o.py @@ -274,7 +274,8 @@ def get_operator_code_key( return ('unknown', None) -def get_or_create_operator_code(model, builtin_op_type) -> int: +def get_or_create_operator_code(model: 'circle.ModelT', + builtin_op_type: 'circle.BuiltinOperator') -> int: """Get the index of an operator code, or create it if it doesn't exist.""" for i, op_code in enumerate(model.operatorCodes): if op_code.builtinCode == builtin_op_type: @@ -287,3 +288,39 @@ def get_or_create_operator_code(model, builtin_op_type) -> int: new_op_code.version = 1 # Default version model.operatorCodes.append(new_op_code) return len(model.operatorCodes) - 1 + + +def find_operator_by_output( + subgraph: 'circle.SubGraphT', + output_tensor_index: int) -> Tuple[Optional[int], Optional['circle.OperatorT']]: + """Find the first operator that produces the given output tensor index.""" + for op_idx, operator in enumerate(subgraph.operators): + if operator.outputs and output_tensor_index in operator.outputs: + return op_idx, operator + return None, None + + +def from_buffer(buffer_index: int, + model_buffers: List['circle.BufferT']) -> Optional['np.ndarray']: + """Converts buffer data to a numpy array (int32). + + Args: + buffer_index: Buffer index (1-based, as stored in tensor.buffer) + model_buffers: List of buffers from the model + + Returns: + numpy array of int32 values, or None if buffer is invalid + """ + import numpy as np + + if buffer_index > 0 and buffer_index < len(model_buffers): + buffer_obj = model_buffers[buffer_index] + if buffer_obj and len(buffer_obj.data) > 0: + # Assuming data is a bytearray of int32s. + # This needs to match the actual data type in the model. 
+ try: + return np.frombuffer(buffer_obj.data, dtype=np.int32) + except Exception as e: + log(f"Could not parse buffer for buffer index {buffer_index}: {e}") + return None + return None diff --git a/tools/o2o/reshape.fc_weight.py b/tools/o2o/reshape.fc_weight.py deleted file mode 100755 index a1bdc8ff7b5..00000000000 --- a/tools/o2o/reshape.fc_weight.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 - -import numpy as np -from typing import List, Dict -import circle -import o2o - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType) - - -def is_effectively_2d(shape): - """Check if a tensor shape is effectively 2D (all leading dimensions are 1)""" - return all(dim == 1 for dim in shape[:-2]) - - -def count_tensor_usage(model: 'circle.ModelT', tensor_index: int) -> int: - """Count how many operators use a specific tensor as input""" - count = 0 - for subgraph in model.subgraphs: - for operator in subgraph.operators: - if operator.inputs is not None: - for input_idx in operator.inputs: - if input_idx == tensor_index: - count += 1 - return count - - -def create_new_tensor(original_tensor: 'circle.TensorT', - new_shape: List[int]) -> 'circle.TensorT': - """Create a new tensor with the specified shape based on the original tensor""" - new_tensor = circle.TensorT() - new_tensor.shape = new_shape - new_tensor.type = original_tensor.type - new_tensor.buffer = original_tensor.buffer - new_tensor.name = original_tensor.name + "_reshaped" if original_tensor.name else None - new_tensor.quantization = original_tensor.quantization - new_tensor.isVariable = original_tensor.isVariable - new_tensor.sparsity = original_tensor.sparsity - new_tensor.shapeSignature = original_tensor.shapeSignature - new_tensor.hasRank = original_tensor.hasRank - new_tensor.variantTensors = original_tensor.variantTensors - new_tensor.compressionType = 
original_tensor.compressionType - return new_tensor - - -def modify_fully_connected_weights(): - """Main function to modify FullyConnected weights from effectively 2D to 2D""" - # Load the model using utility function - model = o2o.load_model_from_stdin() - - # Process each subgraph - for subgraph in model.subgraphs: - # Create a mapping from old tensor indices to new tensor indices - tensor_mapping = {} - - # First pass: identify and create new tensors for modification - for i, operator in enumerate(subgraph.operators): - # Check if this is a FullyConnected operator - opcode = model.operatorCodes[operator.opcodeIndex] - if opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: - # Get the weights tensor (typically the second input) - if len(operator.inputs) >= 2: - weights_index = operator.inputs[ - 1] # Weights is usually the second input - weights_tensor = subgraph.tensors[weights_index] - - # Check if the weights tensor is effectively 2D - if len(weights_tensor.shape) > 2 and is_effectively_2d( - weights_tensor.shape): - operator.builtinOptions.keepNumDims = True - # Check if this tensor is used by multiple operators - usage_count = count_tensor_usage(model, weights_index) - - if usage_count > 1: - # Create a new tensor for this operator to avoid affecting others - new_shape = weights_tensor.shape[ - -2:] # Remove leading dimensions of 1 - new_tensor = create_new_tensor(weights_tensor, new_shape) - - # Add the new tensor to the subgraph - new_tensor_index = len(subgraph.tensors) - subgraph.tensors.append(new_tensor) - - # Update the mapping for this specific operator - if i not in tensor_mapping: - tensor_mapping[i] = {} - tensor_mapping[i][weights_index] = new_tensor_index - else: - # Directly modify the tensor shape since it's only used once - weights_tensor.shape = weights_tensor.shape[-2:] - - # Second pass: update operator inputs based on the mapping - for i, operator in enumerate(subgraph.operators): - # Check if this is a FullyConnected operator - 
opcode = model.operatorCodes[operator.opcodeIndex] - if opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: - # Update inputs according to the mapping - if i in tensor_mapping: - for j, input_idx in enumerate(operator.inputs): - if input_idx in tensor_mapping[i]: - operator.inputs[j] = tensor_mapping[i][input_idx] - else: - # For tensors that were directly modified, just check if they need updating - if len(operator.inputs) >= 2: - weights_index = operator.inputs[1] - weights_tensor = subgraph.tensors[weights_index] - if is_effectively_2d(weights_tensor.shape): - # Update the shape to be truly 2D - weights_tensor.shape = weights_tensor.shape[-2:] - - # Save the model using utility function - o2o.save_model_to_stdout(model) - - -if __name__ == "__main__": - # Directly invoke processing; I/O handled via stdin/stdout - modify_fully_connected_weights() From d2e664edd07452cf35c512c2422d1ac238e258a8 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 14:41:50 +0900 Subject: [PATCH 21/27] Rename merge.circle.py and fuse.bmm_lhs_const.gen_test.circle.py --- tools/o2o/README.md | 11 ++++------- ...st.gen_test_circle.py => fuse.bmm_lhs_const.tc.py} | 0 tools/o2o/{merge.circle.py => merge.circles.py} | 0 3 files changed, 4 insertions(+), 7 deletions(-) rename tools/o2o/{fuse.bmm_lhs_const.gen_test_circle.py => fuse.bmm_lhs_const.tc.py} (100%) rename tools/o2o/{merge.circle.py => merge.circles.py} (100%) diff --git a/tools/o2o/README.md b/tools/o2o/README.md index 0bc4b119cc0..f6e70db0c57 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -118,7 +118,7 @@ Finds tensors named `input_ids` and changes their data type from int64 to int32. ## -## `merge.circle.py` +## `merge.circles.py` Merges multiple Circle model files into a single model by appending their subgraphs and adding signatures. The script accepts any number of Circle files. 
@@ -137,14 +137,11 @@ Merges multiple Circle model files into a single model by appending their subgra ### Usage examples ```bash -# Merge two models, using filenames as signature names -./merge.circle.py model1.circle model2.circle +# Merge multiple models (N models) +./merge.circles.py prefill.circle decode.circle > merged.circle # Merge three models with custom signature names -./merge.circle.py model1.circle model2.circle model3.circle --sig-names "prefill;decode;extra" - -# Merge multiple models (N models) -./merge.circle.py prefill.circle decode.circle > merged.circle +./merge.circles.py model1.circle model2.circle model3.circle --sig-names "prefill;decode;extra" ``` The merged model is written to **standard output**, allowing it to be piped into other tools or redirected to a file. diff --git a/tools/o2o/fuse.bmm_lhs_const.gen_test_circle.py b/tools/o2o/fuse.bmm_lhs_const.tc.py similarity index 100% rename from tools/o2o/fuse.bmm_lhs_const.gen_test_circle.py rename to tools/o2o/fuse.bmm_lhs_const.tc.py diff --git a/tools/o2o/merge.circle.py b/tools/o2o/merge.circles.py similarity index 100% rename from tools/o2o/merge.circle.py rename to tools/o2o/merge.circles.py From 7f364ed3f78884bd1aebdda62c5d9462e4a36aa4 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 15:02:02 +0900 Subject: [PATCH 22/27] remove.io.py is removed. 
gc() will do automatically --- tools/o2o/README.md | 14 --- tools/o2o/fuse.bmm_lhs_const.tc.py | 2 +- tools/o2o/gc.py | 57 ++++++++++ tools/o2o/gc.tc.py | 69 +++++++++++++ tools/o2o/remove.io.py | 160 ----------------------------- 5 files changed, 127 insertions(+), 175 deletions(-) create mode 100755 tools/o2o/gc.tc.py delete mode 100755 tools/o2o/remove.io.py diff --git a/tools/o2o/README.md b/tools/o2o/README.md index f6e70db0c57..6344270d355 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -25,20 +25,6 @@ Filters example: ## Filter List -### `remove.io.py` - -Removes input or output tensors from a Circle model, keeping only the tensors at the specified indices. - -#### Arguments - -* `io_type` (required): Specifies whether to process `input` or `output` tensors. -* `--keep_by_name` (optional): A string defining the names of the tensors to keep. It supports comma‑separated tensor names (e.g., "input1,input2"). -* `--keep_by_id` (optional): Specifies the tensor indices to keep. Supports multiple ranges separated by commas and individual indices (e.g., "0,2-4"). - -**Note:** Exactly one of `--keep_by_name` or `--keep_by_id` must be provided. - -## - ### `fuse.bmm_lhs_const.py` Fuses `BATCH_MATMUL` + `TRANSPOSE` to `FULLY_CONNECTED` when LHS is constant, and automatically reshapes the weight tensors of the **newly created** `FULLY_CONNECTED` operators from effectively 2D shapes (e.g., `[1, 1, D_out, D_in]`) to strict 2D shapes (`[D_out, D_in]`). 
diff --git a/tools/o2o/fuse.bmm_lhs_const.tc.py b/tools/o2o/fuse.bmm_lhs_const.tc.py index edb5dca2551..2113c397649 100755 --- a/tools/o2o/fuse.bmm_lhs_const.tc.py +++ b/tools/o2o/fuse.bmm_lhs_const.tc.py @@ -189,6 +189,6 @@ def create_test_bmm_k_not_1_model(output_file: str): # Generate output filename from current script filename # e.g., cvt.bmm_lhs_const.fc.circle_gen.py -> cvt.bmm_lhs_const.fc.circle script_name = os.path.basename(__file__) - output_file = script_name.replace('gen_circle.', '').replace('.py', '.circle') + output_file = os.path.splitext(script_name)[0] + '.circle' create_test_bmm_k_not_1_model(output_file) diff --git a/tools/o2o/gc.py b/tools/o2o/gc.py index d55250e4c48..d82404e9880 100755 --- a/tools/o2o/gc.py +++ b/tools/o2o/gc.py @@ -274,6 +274,59 @@ def remove_buffers_and_update_model(model: ModelT, return sorted(removed_indices) +def prune_unused_io(model: ModelT) -> bool: + """ + Removes tensors from Subgraph Inputs/Outputs if they are not connected to any operator. + + Args: + model: The mutable Circle model object. + + Returns: + bool: True if any changes were made. 
+ """ + changed = False + for i, subgraph in enumerate(model.subgraphs): + # Collect used inputs and outputs from operators + op_inputs = set() + op_outputs = set() + for op_idx, op in enumerate(subgraph.operators): + if op.inputs is not None: + for inp in op.inputs: + if inp != -1: + op_inputs.add(inp) + if op.outputs is not None: + for out in op.outputs: + op_outputs.add(out) + + # Prune Subgraph Inputs + # A Subgraph Input is unused if it is not consumed by any operator + if subgraph.inputs is not None: + original_len = len(subgraph.inputs) + new_inputs = [idx for idx in subgraph.inputs if idx in op_inputs] + if len(new_inputs) < original_len: + removed = [idx for idx in subgraph.inputs if idx not in op_inputs] + o2o.log( + f"Subgraph {i}: Pruning unused inputs (not consumed by any op): {removed}" + ) + subgraph.inputs = new_inputs + changed = True + + # Prune Subgraph Outputs + # A Subgraph Output is unused if it is not produced by any operator + if subgraph.outputs is not None: + original_len = len(subgraph.outputs) + new_outputs = [idx for idx in subgraph.outputs if idx in op_outputs] + if len(new_outputs) < original_len: + removed = [idx for idx in subgraph.outputs if idx not in op_outputs] + o2o.log( + f"Subgraph {i}: Pruning unused outputs (not produced by any op): {removed}" + ) + subgraph.outputs = new_outputs + changed = True + + return changed + + def main(): # Read the entire model from stdin data = sys.stdin.buffer.read() @@ -286,6 +339,10 @@ def main(): total_unused_tensors_count = 0 model_changed = False + # Prune unused inputs/outputs first + if prune_unused_io(model): + model_changed = True + o2o.log(f"Processing {model_ro.SubgraphsLength()} subgraph(s)...") for i in range(model_ro.SubgraphsLength()): subgraph_ro = model_ro.Subgraphs(i) diff --git a/tools/o2o/gc.tc.py b/tools/o2o/gc.tc.py new file mode 100755 index 00000000000..9db9a18abe4 --- /dev/null +++ b/tools/o2o/gc.tc.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +import circle +import o2o 
+import sys + + +def generate_test_model(): + model = circle.ModelT() + subgraph = circle.SubGraphT() + model.subgraphs = [subgraph] + + # Create tensors + # T0: Input, used by Op + # T1: Input, UNUSED + # T2: Output, produced by Op + # T3: Output, UNUSED (not produced by any Op) + + t0 = circle.TensorT() + t0.name = "input_used" + t0.shape = [1, 2] + t0.type = circle.TensorType.FLOAT32 + + t1 = circle.TensorT() + t1.name = "input_unused" + t1.shape = [1, 2] + t1.type = circle.TensorType.FLOAT32 + + t2 = circle.TensorT() + t2.name = "output_used" + t2.shape = [1, 2] + t2.type = circle.TensorType.FLOAT32 + + t3 = circle.TensorT() + t3.name = "output_unused" + t3.shape = [1, 2] + t3.type = circle.TensorType.FLOAT32 + + subgraph.tensors = [t0, t1, t2, t3] + + # Set inputs and outputs + subgraph.inputs = [0, 1] # t0, t1 + subgraph.outputs = [2, 3] # t2, t3 + + # Create an operator that uses T0 and produces T2 + # Use NEG (unary) + + neg_op = circle.OperatorT() + neg_op.opcodeIndex = 0 # Will set code below + neg_op.inputs = [0] # Uses T0 + neg_op.outputs = [2] # Produces T2 + + subgraph.operators = [neg_op] + + # Add OperatorCode + op_code = circle.OperatorCodeT() + op_code.builtinCode = circle.BuiltinOperator.NEG + op_code.version = 1 + model.operatorCodes = [op_code] + + # Add default empty buffer (Buffer 0) + b0 = circle.BufferT() + model.buffers = [b0] + + # Save to stdout + o2o.save_model_to_stdout(model) + + +if __name__ == "__main__": + generate_test_model() diff --git a/tools/o2o/remove.io.py b/tools/o2o/remove.io.py deleted file mode 100755 index df4439bb319..00000000000 --- a/tools/o2o/remove.io.py +++ /dev/null @@ -1,160 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import argparse -from typing import List -import o2o - -# Import specific Circle types for better type annotations -from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, - BuiltinOperator, TensorType) - - -def parse_names(names_str: str) -> List[str]: - """Parse 
comma‑separated tensor names into a list of names.""" - return [name.strip() for name in names_str.split(',') if name.strip()] - - -def remove_io_tensors(io_type: str, names_to_keep: List[str]) -> None: - """Remove input or output tensors, keeping only specified tensor names""" - # Load the model using utility function - model = o2o.load_model_from_stdin() - - def process_subgraph(subgraph): - """Process a single subgraph""" - if io_type == 'input': - io_list = subgraph.inputs - io_name = 'input' - elif io_type == 'output': - io_list = subgraph.outputs - io_name = 'output' - else: - raise ValueError(f"Invalid io_type: {io_type}. Must be 'input' or 'output'") - - o2o.log(f"Processing subgraph with {len(io_list)} {io_name}s") - o2o.log(f"Original {io_name} indices: {io_list}") - - # Build a mapping from tensor name to its index for the selected I/O list - name_to_index = {} - for io_idx in io_list: - tensor = subgraph.tensors[io_idx] - tensor_name = o2o.get_tensor_name(tensor) - if tensor_name: - name_to_index[tensor_name] = io_idx - - # Filter tensors to keep by name - new_io_list = [] - for name in names_to_keep: - if name in name_to_index: - new_io_list.append(name_to_index[name]) - else: - o2o.log(f"Warning: {io_name} tensor name '{name}' not found") - - # Update the subgraph - if io_type == 'input': - subgraph.inputs = new_io_list - else: - subgraph.outputs = new_io_list - - o2o.log(f"New {io_name} indices: {[i+1 for i in range(len(new_io_list))]}") - - removed_count = len(io_list) - len(new_io_list) - return removed_count > 0, removed_count - - # Process all subgraphs using utility function - overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) - - if not overall_modified: - o2o.log("No tensors were removed.") - - # Save the model using utility function - o2o.save_model_to_stdout(model) - - -def remove_io_tensors_by_id(io_type: str, ids_to_keep: List[int]) -> None: - """Remove input or output tensors, keeping only specified tensor 
indices (IDs)""" - model = o2o.load_model_from_stdin() - - def process_subgraph(subgraph): - if io_type == 'input': - io_list = subgraph.inputs - io_name = 'input' - elif io_type == 'output': - io_list = subgraph.outputs - io_name = 'output' - else: - raise ValueError(f"Invalid io_type: {io_type}. Must be 'input' or 'output'") - - o2o.log(f"Processing subgraph with {len(io_list)} {io_name}s") - o2o.log(f"Original {io_name} indices: {io_list}") - - # Keep only those indices whose position in the original list matches ids_to_keep - new_io_list = [] - for idx, tensor_idx in enumerate(io_list): - if idx in ids_to_keep: - new_io_list.append(tensor_idx) - else: - o2o.log(f"Removing {io_name} tensor at position {idx}") - - # Update the subgraph - if io_type == 'input': - subgraph.inputs = new_io_list - else: - subgraph.outputs = new_io_list - - o2o.log(f"New {io_name} indices: {new_io_list}") - - removed_count = len(io_list) - len(new_io_list) - return removed_count > 0, removed_count - - overall_modified, total_changes = o2o.process_subgraphs(model, process_subgraph) - - if not overall_modified: - o2o.log("No tensors were removed.") - - o2o.save_model_to_stdout(model) - - -def main(): - parser = argparse.ArgumentParser( - description= - 'Remove input or output tensors from Circle model, keeping only specified tensor names or IDs' - ) - parser.add_argument('io_type', - choices=['input', 'output'], - help='Whether to process inputs or outputs') - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument( - '--keep_by_name', - help='Comma‑separated tensor names to keep (e.g., "tensorA,tensorB")') - group.add_argument( - '--keep_by_id', - help='Comma‑separated tensor IDs or ranges to keep (e.g., "0,2-4")') - # No file arguments needed; model is read from stdin and written to stdout - - args = parser.parse_args() - - if args.keep_by_name: - # Parse the tensor names - try: - names_to_keep = parse_names(args.keep_by_name) - o2o.log(f"Tensor names to keep: 
{names_to_keep}") - except ValueError as e: - o2o.log(f"Error parsing tensor names: {e}") - sys.exit(1) - # Execute name‑based removal - remove_io_tensors(args.io_type, names_to_keep) - elif args.keep_by_id: - # Parse the tensor IDs - try: - ids_to_keep = o2o.parse_operator_indices(args.keep_by_id) - o2o.log(f"Tensor IDs to keep: {ids_to_keep}") - except Exception as e: - o2o.log(f"Error parsing tensor IDs: {e}") - sys.exit(1) - # Execute ID‑based removal - remove_io_tensors_by_id(args.io_type, ids_to_keep) - - -if __name__ == "__main__": - main() From 37617bba663195cf72088ca47fff456d10817e92 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 17:14:40 +0900 Subject: [PATCH 23/27] fuse.attention.py (cos, sin) --- tools/o2o/fuse.attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/o2o/fuse.attention.py b/tools/o2o/fuse.attention.py index b0387846103..c0d7ebe2f4d 100755 --- a/tools/o2o/fuse.attention.py +++ b/tools/o2o/fuse.attention.py @@ -107,14 +107,14 @@ def map_attention_inputs(subgraph: 'circle.SubGraphT', block: Dict[str, Any], # 6. position_cos position_cos_idx = o2o.get_tensor_index_by_name( - subgraph, "transformers.models.llama.modeling_llama.LlamaRotaryEmbedding::mul_1") + subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::cos") if position_cos_idx == -1: o2o.log("Could not find position_cos tensor") return None # 7. 
position_sin position_sin_idx = o2o.get_tensor_index_by_name( - subgraph, "transformers.models.llama.modeling_llama.LlamaRotaryEmbedding::mul_2") + subgraph, "transformers.models.llama.modeling_llama.LlamaForCausalLM::sin") if position_sin_idx == -1: o2o.log("Could not find position_sin tensor") return None From 5b1aabc158dc850e2687f8b89cdb533b666c2097 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 21 Nov 2025 21:03:36 +0900 Subject: [PATCH 24/27] Find attention by pattern + support multiple subgraphs --- tools/o2o/README.md | 46 ++++++ tools/o2o/fuse.attention.py | 302 +++++++++++++++++++++++++++++------- 2 files changed, 296 insertions(+), 52 deletions(-) diff --git a/tools/o2o/README.md b/tools/o2o/README.md index 6344270d355..5a5a2c4a8a5 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -62,6 +62,52 @@ After creating each fused `FULLY_CONNECTED` operator, this script automatically ## +### `fuse.attention.py` + +Fuses attention blocks into a single `ATTENTION` operator. The script automatically detects attention blocks by searching for `FULLY_CONNECTED` operators with `attn_q_proj` weight names, making it robust to models with different prefix operators. + +#### Usage + +**Normal mode** (fuse attention blocks): +```bash +./fuse.attention.py < input.circle > output.circle +``` + +**Debug mode** (inspect operators and extract patterns): +```bash +./fuse.attention.py --debug < input.circle +``` + +Example output: +``` +Index OpCode BuiltinCode Weight Name +----------------------------------------------------------------------------------------------- +0 GATHER 36 +... +20 FULLY_CONNECTED 9 tico::p_model_layers_0_self_attn_q_proj_weight +21 RESHAPE 22 +... +64 FULLY_CONNECTED 9 tico::p_model_layers_0_self_attn_o_proj_weight +... + +Searching for attention block pattern based on weight names... 
+Found start_op at 20 (Weight: tico::p_model_layers_0_self_attn_q_proj_weight) +Found end_op at 64 (Weight: tico::p_model_layers_0_self_attn_o_proj_weight) + +Extracted range: 20 - 64 +ATTENTION_PATTERN_CODES = [9, 22, 39, 9, 22, 39, 9, 22, 39, 22, 22, 18, 45, 45, 59, 2, 18, 0, 18, 45, 45, 59, 2, 18, 0, 2, 2, 22, 45, 18, 39, 18, 22, 22, 126, 22, 0, 25, 22, 22, 126, 22, 39, 22, 9] +Verified: Pattern starts with FULLY_CONNECTED +Verified: Pattern ends with FULLY_CONNECTED +``` + +#### Features + +- **Dynamic detection**: Automatically finds attention block start offset by searching for weight names +- **Pattern-based fusion**: Fuses 45-operator attention blocks (stride of 65 operators between blocks) +- **Debug mode**: Provides operator inspection and pattern extraction for analysis + + + ### `transpose.io.kcache.py` Finds input tensors matching the pattern `*key_cache_\d+` (e.g., `past_key_values_key_cache_0`) and transposes their second and third dimensions if they are 4D. For example, a shape `[d0, d1, d2, d3]` will become `[d0, d2, d1, d3]`. 
diff --git a/tools/o2o/fuse.attention.py b/tools/o2o/fuse.attention.py index c0d7ebe2f4d..415a7d9cb86 100755 --- a/tools/o2o/fuse.attention.py +++ b/tools/o2o/fuse.attention.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import numpy as np +import sys from typing import List, Optional, Tuple, Dict, Any import circle import o2o @@ -9,25 +10,174 @@ from circle import (TensorT, OperatorT, SubGraphT, ModelT, BufferT, OperatorCodeT, BuiltinOperator, TensorType) +# Map builtin codes to names for readability (for debug mode) +BUILTIN_NAMES = { + v: k + for k, v in circle.BuiltinOperator.__dict__.items() if not k.startswith('_') +} + +# ============================================================================ +# DEBUG FUNCTIONS +# ============================================================================ + + +def inspect_ops(subgraph: 'circle.SubGraphT', model: 'circle.ModelT', limit: int = 100): + """Inspect and print operator information (for debugging).""" + print(f"{'Index':<5} {'OpCode':27} {'BuiltinCode':^12} {'Weight Name':^55}") + print("-" * 100) + + for i in range(min(limit, len(subgraph.operators))): + op = subgraph.operators[i] + opcode = model.operatorCodes[op.opcodeIndex] + builtin_code = opcode.builtinCode + name = BUILTIN_NAMES.get(builtin_code, str(builtin_code)) + + extra_info = "" + if builtin_code == circle.BuiltinOperator.FULLY_CONNECTED: + # FC inputs: input, weights, bias (optional) + if len(op.inputs) > 1: + weight_tensor = o2o.get_tensor_by_index(subgraph, op.inputs[1]) + if weight_tensor: + extra_info = o2o.get_tensor_name(weight_tensor) or "" + + print(f"{i:>4} {name:<27} {builtin_code:>5} {extra_info:>56}") + + +def extract_pattern(subgraph: 'circle.SubGraphT', model: 'circle.ModelT'): + """Extract attention block pattern for debugging.""" + start_op = -1 + end_op = -1 + + print("\nSearching for attention block pattern based on weight names...") + + for i, op in enumerate(subgraph.operators): + opcode = model.operatorCodes[op.opcodeIndex] + if 
opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: + if len(op.inputs) > 1: + weight_tensor = o2o.get_tensor_by_index(subgraph, op.inputs[1]) + if weight_tensor: + weight_name = o2o.get_tensor_name(weight_tensor) or "" + + # Look for start: attn_q_proj + if start_op == -1 and "attn_q_proj" in weight_name: + start_op = i + print(f"Found start_op at {i} (Weight: {weight_name})") + + # Look for end: attn_o_proj (must be after start) + if start_op != -1 and "attn_o_proj" in weight_name: + end_op = i + print(f"Found end_op at {i} (Weight: {weight_name})") + break + + if start_op != -1 and end_op != -1: + pattern_codes = [] + for i in range(start_op, end_op + 1): + op = subgraph.operators[i] + opcode = model.operatorCodes[op.opcodeIndex] + pattern_codes.append(opcode.builtinCode) + + print(f"\nExtracted range: {start_op} - {end_op}") + print("ATTENTION_PATTERN_CODES = " + str(pattern_codes)) + + if pattern_codes[0] == circle.BuiltinOperator.FULLY_CONNECTED: + print("Verified: Pattern starts with FULLY_CONNECTED") + else: + print(f"Warning: Pattern starts with {pattern_codes[0]}") + + if pattern_codes[-1] == circle.BuiltinOperator.FULLY_CONNECTED: + print("Verified: Pattern ends with FULLY_CONNECTED") + else: + print(f"Warning: Pattern ends with {pattern_codes[-1]}") + + else: + print("Could not find attention block pattern using weight names.") + + +# ============================================================================ +# ATTENTION FUSION FUNCTIONS +# ============================================================================ + + +def find_attention_pattern( + model: 'circle.ModelT', + subgraph: 'circle.SubGraphT') -> Tuple[int, int, int, int, List[int]]: + """ + Dynamically find attention block pattern parameters. 
+ + Returns: + Tuple of (start_offset, block_length, stride, num_blocks, pattern_codes) + + Raises: + RuntimeError if pattern cannot be detected + """ + first_block_start = -1 + first_block_end = -1 + second_block_start = -1 + + # Find first and second attention blocks by searching for q_proj and o_proj + for i, op in enumerate(subgraph.operators): + opcode = model.operatorCodes[op.opcodeIndex] + if opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED: + if len(op.inputs) > 1: + weight_tensor = o2o.get_tensor_by_index(subgraph, op.inputs[1]) + if weight_tensor: + weight_name = o2o.get_tensor_name(weight_tensor) or "" + + # Look for first block start: first attn_q_proj + if first_block_start == -1 and "attn_q_proj" in weight_name: + first_block_start = i + o2o.log(f"Detected first attention block start at {i}") + + # Look for first block end: first attn_o_proj (after start) + elif first_block_start != -1 and first_block_end == -1 and "attn_o_proj" in weight_name: + first_block_end = i + o2o.log(f"Detected first attention block end at {i}") + + # Look for second block start: second attn_q_proj (after first block end) + elif first_block_end != -1 and second_block_start == -1 and "attn_q_proj" in weight_name: + second_block_start = i + o2o.log(f"Detected second attention block start at {i}") + break + + if first_block_start == -1 or first_block_end == -1 or second_block_start == -1: + raise RuntimeError( + "Could not detect attention pattern dynamically. Unable to find first and second attention blocks." 
+ ) + + block_length = first_block_end - first_block_start + stride = second_block_start - first_block_start + + # Calculate number of blocks + num_blocks = 0 + test_start = first_block_start + while test_start + block_length < len(subgraph.operators): + num_blocks += 1 + test_start += stride + + # Extract pattern codes from the first block + pattern_codes = [] + for i in range(first_block_start, first_block_end + 1): + op = subgraph.operators[i] + opcode = model.operatorCodes[op.opcodeIndex] + pattern_codes.append(opcode.builtinCode) + + o2o.log( + f"Detected pattern: start={first_block_start}, block_length={block_length}, stride={stride}, num_blocks={num_blocks}" + ) + return (first_block_start, block_length, stride, num_blocks, pattern_codes) + def find_attention_blocks(model: 'circle.ModelT', subgraph: 'circle.SubGraphT') -> List[Dict[str, Any]]: """Find all attention blocks in the subgraph.""" attention_blocks = [] - # Pattern: 45 operators per attention block - # First block: operators 19-63 - # Second block: operators 84-128 - # Third block: operators 149-193 - # Fourth block: operators 194-238 - # Fifth block: operators 239-283 - # Sixth block: operators 284-328 - # Seventh block: operators 329-373 - # Eighth block: operators 374-418 + start_offset, block_length, stride, num_blocks, pattern_codes = find_attention_pattern( + model, subgraph) - for layer_idx in range(8): - start_op = 20 + (layer_idx * 65) # 64 operators between blocks, starting from 20 - end_op = start_op + 44 # 44 operators total (20-63 inclusive) + for layer_idx in range(num_blocks): + start_op = start_offset + (layer_idx * stride) + end_op = start_op + block_length if end_op < len(subgraph.operators): # Verify this is an attention block by checking key operators @@ -41,6 +191,19 @@ def find_attention_blocks(model: 'circle.ModelT', if (first_opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED and last_opcode.builtinCode == circle.BuiltinOperator.FULLY_CONNECTED): + # Verify the 
entire block pattern matches the first block's pattern + current_block_codes = [] + for i in range(start_op, end_op + 1): + op = subgraph.operators[i] + opcode = model.operatorCodes[op.opcodeIndex] + current_block_codes.append(opcode.builtinCode) + + if current_block_codes != pattern_codes: + o2o.log( + f"Pattern mismatch for block {layer_idx} at {start_op}: Opcode sequence does not match the first block." + ) + continue + attention_blocks.append({ 'layer_idx': layer_idx, @@ -49,8 +212,17 @@ def find_attention_blocks(model: 'circle.ModelT', 'end_op': end_op, 'operators': - subgraph.operators[start_op:end_op + 1] + subgraph.operators[start_op:end_op + 1], + 'q_proj_op': + start_op, # First FC is Q projection + 'k_proj_op': + start_op + 3, # K projection (after reshape, transpose) + 'v_proj_op': + start_op + 6, # V projection + 'o_proj_op': + end_op # Last FC is output projection }) + o2o.log( f"Found attention block {layer_idx}: operators {start_op}-{end_op}") else: @@ -62,6 +234,7 @@ def find_attention_blocks(model: 'circle.ModelT', f"Block {layer_idx} exceeds operator count (start_op={start_op}, total={len(subgraph.operators)})" ) + o2o.log(f"Found {len(attention_blocks)} attention blocks to fuse") return attention_blocks @@ -171,63 +344,88 @@ def map_attention_inputs(subgraph: 'circle.SubGraphT', block: Dict[str, Any], def fuse_attention(): """Main function to fuse attention operators.""" - o2o.log("Loading model from stdin") model = o2o.load_model_from_stdin() if not model.subgraphs: o2o.log("Model has no subgraphs. 
Exiting.") return - subgraph = model.subgraphs[0] # Assuming single subgraph for now - attention_blocks = find_attention_blocks(model, subgraph) + # Process all subgraphs + for subgraph_idx, subgraph in enumerate(model.subgraphs): + o2o.log(f"Processing subgraph {subgraph_idx}...") - o2o.log(f"Found {len(attention_blocks)} attention blocks to fuse") + attention_blocks = find_attention_blocks(model, subgraph) - operators_to_remove = [] + o2o.log( + f"Found {len(attention_blocks)} attention blocks to fuse in subgraph {subgraph_idx}" + ) - for block in attention_blocks: - # Map input tensors - input_indices = map_attention_inputs(subgraph, block, model) - if input_indices is None: - o2o.log( - f"Skipping attention block {block['layer_idx']} due to missing inputs") - continue + operators_to_remove = [] - # Create ATTENTION operator - attention_op = circle.OperatorT() - attention_op.opcodeIndex = o2o.get_or_create_operator_code( - model, circle.BuiltinOperator.ATTENTION) - attention_op.inputs = input_indices - attention_op.outputs = [block['operators'][-1].outputs[0] - ] # Use last operator's output + for block in attention_blocks: + # Map input tensors + input_indices = map_attention_inputs(subgraph, block, model) + if input_indices is None: + o2o.log( + f"Skipping attention block {block['layer_idx']} due to missing inputs" + ) + continue - # Configure AttentionOptions (empty since it's deprecated) - attention_op.builtinOptionsType = circle.BuiltinOptions.AttentionOptions - attention_op.builtinOptions = circle.AttentionOptionsT() + # Create ATTENTION operator + attention_op = circle.OperatorT() + attention_op.opcodeIndex = o2o.get_or_create_operator_code( + model, circle.BuiltinOperator.ATTENTION) + attention_op.inputs = input_indices + attention_op.outputs = [block['operators'][-1].outputs[0] + ] # Use last operator's output - # Replace the first operator with ATTENTION operator - start_idx = block['start_op'] - subgraph.operators[start_idx] = attention_op + # 
Configure AttentionOptions (empty since it's deprecated) + attention_op.builtinOptionsType = circle.BuiltinOptions.AttentionOptions + attention_op.builtinOptions = circle.AttentionOptionsT() - # Mark intermediate operators for removal (except the first one which we replaced) - for i in range(block['end_op'], start_idx, -1): - operators_to_remove.append(i) + # Replace the first operator with ATTENTION operator + start_idx = block['start_op'] + subgraph.operators[start_idx] = attention_op - o2o.log( - f"Fused attention block {block['layer_idx']}: operators {block['start_op']}-{block['end_op']} -> ATTENTION" - ) + # Mark intermediate operators for removal (except the first one which we replaced) + for i in range(block['end_op'], start_idx, -1): + operators_to_remove.append(i) + + o2o.log( + f"Fused attention block {block['layer_idx']}: operators {block['start_op']}-{block['end_op']} -> ATTENTION" + ) - # Remove marked operators in reverse order to avoid index shifting - for i in sorted(operators_to_remove, reverse=True): - if 0 <= i < len(subgraph.operators): - del subgraph.operators[i] + # Remove marked operators in reverse order to avoid index shifting + for i in sorted(operators_to_remove, reverse=True): + if 0 <= i < len(subgraph.operators): + del subgraph.operators[i] - o2o.log(f"Removed {len(operators_to_remove)} intermediate operators") - o2o.log(f"Model now has {len(subgraph.operators)} operators") + o2o.log( + f"Removed {len(operators_to_remove)} intermediate operators from subgraph {subgraph_idx}" + ) + o2o.log(f"Subgraph {subgraph_idx} now has {len(subgraph.operators)} operators") o2o.save_model_to_stdout(model) if __name__ == "__main__": - # Directly invoke processing; I/O handled via stdin/stdout - fuse_attention() + # Check for inspect mode flag + if len(sys.argv) > 1 and sys.argv[1] == "--inspect": + # Inspect mode: analyze operators and extract pattern (no fusion) + o2o.log("Running in INSPECT mode") + model = o2o.load_model_from_stdin() + + # Process 
all subgraphs in inspect mode + for subgraph_idx, subgraph in enumerate(model.subgraphs): + o2o.log(f"\n{'='*100}") + o2o.log(f"Subgraph {subgraph_idx}") + o2o.log(f"{'='*100}") + + # Inspect first 100 ops + inspect_ops(subgraph, model) + + # Extract pattern dynamically + extract_pattern(subgraph, model) + else: + # Normal mode: fuse attention blocks + fuse_attention() From 296016c01970c30fb0f305faf7142e11dc0824fb Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Sat, 22 Nov 2025 14:13:02 +0900 Subject: [PATCH 25/27] gc.py : Support multiple subgraph (including signatureDefs) --- tools/o2o/gc.py | 121 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 95 insertions(+), 26 deletions(-) diff --git a/tools/o2o/gc.py b/tools/o2o/gc.py index d82404e9880..5a5eeb8c5e1 100755 --- a/tools/o2o/gc.py +++ b/tools/o2o/gc.py @@ -18,18 +18,18 @@ def get_tensor_name(tensor: TensorT) -> Optional[str]: return None -def find_unused_tensors_in_subgraph(subgraph) -> List[int]: +def find_unused_tensors_in_subgraph(subgraph: SubGraphT) -> List[int]: """ Finds and returns the indices of unused tensors in a given subgraph. - This function uses the Native API for read-only subgraph objects. + This function uses the Object API for mutable subgraph objects. Args: - subgraph: The Circle read-only subgraph object. + subgraph: The Circle mutable subgraph object (SubGraphT). Returns: list: A list of integer indices representing unused tensors. 
""" - num_tensors = subgraph.TensorsLength() + num_tensors = len(subgraph.tensors) if num_tensors == 0: return [] @@ -37,18 +37,16 @@ def find_unused_tensors_in_subgraph(subgraph) -> List[int]: output_tensor_indices = set() # Collect output tensor indices - for i in range(subgraph.OutputsLength()): - output_tensor_indices.add(subgraph.Outputs(i)) + if subgraph.outputs is not None and len(subgraph.outputs) > 0: + for out in subgraph.outputs: + output_tensor_indices.add(out) # Collect input tensor indices from all operators - for i in range(subgraph.OperatorsLength()): - operator = subgraph.Operators(i) - if operator and operator.InputsLength(): - for j in range(operator.InputsLength()): - input_tensor_index = operator.Inputs(j) - # In Circle schema, -1 indicates an optional input that is not used. - if input_tensor_index != -1: - used_tensor_indices.add(input_tensor_index) + for operator in subgraph.operators: + if operator.inputs is not None and len(operator.inputs) > 0: + for inp in operator.inputs: + if inp != -1: + used_tensor_indices.add(inp) # A tensor is unused if it's not used by any operator AND not an output of the subgraph unused_indices = [] @@ -215,6 +213,38 @@ def remove_tensors_and_update_model(model: ModelT, subgraph_index_to_modify: int updated_subgraph_outputs.append(new_indices_map[old_output_idx]) subgraph.outputs = updated_subgraph_outputs + # Update SignatureDefs + if model.signatureDefs: + for sig_def in model.signatureDefs: + if sig_def.subgraphIndex == subgraph_index_to_modify: + # Update inputs + if sig_def.inputs: + updated_sig_inputs = [] + for tensor_map in sig_def.inputs: + if tensor_map.tensorIndex in new_indices_map: + tensor_map.tensorIndex = new_indices_map[ + tensor_map.tensorIndex] + updated_sig_inputs.append(tensor_map) + elif tensor_map.tensorIndex in tensor_indices_to_remove: + o2o.log( + f" SignatureDef '{sig_def.signatureKey}': Removing input tensor index {tensor_map.tensorIndex}" + ) + sig_def.inputs = updated_sig_inputs + + 
# Update outputs + if sig_def.outputs: + updated_sig_outputs = [] + for tensor_map in sig_def.outputs: + if tensor_map.tensorIndex in new_indices_map: + tensor_map.tensorIndex = new_indices_map[ + tensor_map.tensorIndex] + updated_sig_outputs.append(tensor_map) + elif tensor_map.tensorIndex in tensor_indices_to_remove: + o2o.log( + f" SignatureDef '{sig_def.signatureKey}': Removing output tensor index {tensor_map.tensorIndex}" + ) + sig_def.outputs = updated_sig_outputs + return sorted(removed_indices) @@ -274,6 +304,40 @@ def remove_buffers_and_update_model(model: ModelT, return sorted(removed_indices) +def update_signature_defs_for_pruned_io(model: ModelT, subgraph_index: int, + removed_inputs: List[int], + removed_outputs: List[int]): + """ + Updates SignatureDefs to remove references to pruned inputs/outputs. + """ + if not model.signatureDefs: + return + + for sig_def in model.signatureDefs: + if sig_def.subgraphIndex == subgraph_index: + # Update inputs + if sig_def.inputs and removed_inputs: + original_len = len(sig_def.inputs) + sig_def.inputs = [ + tm for tm in sig_def.inputs if tm.tensorIndex not in removed_inputs + ] + if len(sig_def.inputs) < original_len: + o2o.log( + f" SignatureDef '{sig_def.signatureKey}': Pruned {original_len - len(sig_def.inputs)} input(s)" + ) + + # Update outputs + if sig_def.outputs and removed_outputs: + original_len = len(sig_def.outputs) + sig_def.outputs = [ + tm for tm in sig_def.outputs if tm.tensorIndex not in removed_outputs + ] + if len(sig_def.outputs) < original_len: + o2o.log( + f" SignatureDef '{sig_def.signatureKey}': Pruned {original_len - len(sig_def.outputs)} output(s)" + ) + + def prune_unused_io(model: ModelT) -> bool: """ Removes tensors from Subgraph Inputs/Outputs if they are not connected to any operator. 
@@ -286,6 +350,9 @@ def prune_unused_io(model: ModelT) -> bool: """ changed = False for i, subgraph in enumerate(model.subgraphs): + removed_inputs = [] + removed_outputs = [] + # Collect used inputs and outputs from operators op_inputs = set() op_outputs = set() @@ -304,9 +371,9 @@ def prune_unused_io(model: ModelT) -> bool: original_len = len(subgraph.inputs) new_inputs = [idx for idx in subgraph.inputs if idx in op_inputs] if len(new_inputs) < original_len: - removed = [idx for idx in subgraph.inputs if idx not in op_inputs] + removed_inputs = [idx for idx in subgraph.inputs if idx not in op_inputs] o2o.log( - f"Subgraph {i}: Pruning unused inputs (not consumed by any op): {removed}" + f"Subgraph {i}: Pruning unused inputs (not consumed by any op): {removed_inputs}" ) subgraph.inputs = new_inputs changed = True @@ -317,13 +384,19 @@ def prune_unused_io(model: ModelT) -> bool: original_len = len(subgraph.outputs) new_outputs = [idx for idx in subgraph.outputs if idx in op_outputs] if len(new_outputs) < original_len: - removed = [idx for idx in subgraph.outputs if idx not in op_outputs] + removed_outputs = [ + idx for idx in subgraph.outputs if idx not in op_outputs + ] o2o.log( - f"Subgraph {i}: Pruning unused outputs (not produced by any op): {removed}" + f"Subgraph {i}: Pruning unused outputs (not produced by any op): {removed_outputs}" ) subgraph.outputs = new_outputs changed = True + # Update SignatureDefs if any IO was pruned + if removed_inputs or removed_outputs: + update_signature_defs_for_pruned_io(model, i, removed_inputs, removed_outputs) + return changed @@ -343,14 +416,10 @@ def main(): if prune_unused_io(model): model_changed = True - o2o.log(f"Processing {model_ro.SubgraphsLength()} subgraph(s)...") - for i in range(model_ro.SubgraphsLength()): - subgraph_ro = model_ro.Subgraphs(i) - if not subgraph_ro: - o2o.log(f"Warning: Could not read subgraph {i}. 
Skipping.") - continue - - unused = find_unused_tensors_in_subgraph(subgraph_ro) + o2o.log(f"Processing {len(model.subgraphs)} subgraph(s)...") + for i, subgraph in enumerate(model.subgraphs): + # Use the mutable subgraph which might have been updated by prune_unused_io + unused = find_unused_tensors_in_subgraph(subgraph) if not unused: o2o.log(f"Subgraph {i}: No unused tensors found.") continue From b6636cb1cceaf3736c7266294aa2b2485e8926fc Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Sat, 22 Nov 2025 14:46:51 +0900 Subject: [PATCH 26/27] retype.input_ids.py->downcast.input_ids.py: multiple subgraphs,pattern --- tools/o2o/README.md | 4 +- ...ype.input_ids.py => downcast.input_ids.py} | 48 +++++++++++-------- 2 files changed, 31 insertions(+), 21 deletions(-) rename tools/o2o/{retype.input_ids.py => downcast.input_ids.py} (50%) diff --git a/tools/o2o/README.md b/tools/o2o/README.md index 5a5a2c4a8a5..be7446d1c5d 100644 --- a/tools/o2o/README.md +++ b/tools/o2o/README.md @@ -144,9 +144,9 @@ Performs garbage collection by removing unreachable tensors and buffers, reducin ## -### `retype.input_ids.py` +### `downcast.input_ids.py` -Finds tensors named `input_ids` and changes their data type from int64 to int32. This filter is useful for models that need to be compatible with hardware or frameworks that expect input_ids to be 32-bit integers instead of 64-bit integers. +Identifies `input_ids` tensors based on the graph structure (specifically, `INT64` tensors that are the indices input to `GATHER` operators and are also Subgraph Inputs) and changes their data type from `int64` to `int32`. This robust detection method works regardless of the tensor name. This filter is useful for models that need to be compatible with hardware or frameworks that expect input_ids to be 32-bit integers. 
## diff --git a/tools/o2o/retype.input_ids.py b/tools/o2o/downcast.input_ids.py similarity index 50% rename from tools/o2o/retype.input_ids.py rename to tools/o2o/downcast.input_ids.py index 4f360621cd9..301cb7cab95 100755 --- a/tools/o2o/retype.input_ids.py +++ b/tools/o2o/downcast.input_ids.py @@ -18,25 +18,35 @@ def process_subgraph(subgraph): o2o.log(f"Processing subgraph with {len(subgraph.tensors)} tensors") retyped_count = 0 - for tensor in subgraph.tensors: - tensor_name = o2o.get_tensor_name(tensor) - - # Check if this is the input_ids tensor - if tensor_name == "tico::input_ids": - # Check if current type is int64 - if tensor.type == circle.TensorType.INT64: - old_type = "int64" - new_type = "int32" - - # Change type to int32 - tensor.type = circle.TensorType.INT32 - - o2o.log(f"Retyped tensor: {tensor_name} {old_type} → {new_type}") - retyped_count += 1 - else: - o2o.log( - f"Found input_ids tensor but type is not int64 (current type: {tensor.type})" - ) + + # Collect subgraph inputs for quick lookup + subgraph_inputs = set(subgraph.inputs) + + for op_idx, op in enumerate(subgraph.operators): + opcode = model.operatorCodes[op.opcodeIndex] + + if opcode.builtinCode == circle.BuiltinOperator.GATHER: + # GATHER input 1 is the indices tensor (params, indices) + if op.inputs is not None and len(op.inputs) > 1: + input_tensor_idx = op.inputs[1] + + # Check if this input is a subgraph input + if input_tensor_idx in subgraph_inputs: + tensor = subgraph.tensors[input_tensor_idx] + + # Check if type is INT64 + if tensor.type == circle.TensorType.INT64: + tensor_name = o2o.get_tensor_name(tensor) + old_type = "int64" + new_type = "int32" + + # Change type to int32 + tensor.type = circle.TensorType.INT32 + + o2o.log( + f"Retyped tensor: {tensor_name} (Index: {input_tensor_idx}) {old_type} → {new_type}" + ) + retyped_count += 1 if retyped_count > 0: o2o.log(f"Retyped {retyped_count} input_ids tensors in this subgraph") From 9226e00ca50ac87baef26d7b92910a0c69171316 
Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Mon, 24 Nov 2025 12:29:08 +0900 Subject: [PATCH 27/27] fuse.bmm_lhs_const.py : multiple subgraphs support --- tools/o2o/fuse.bmm_lhs_const.py | 262 ++++++++++++++++---------------- 1 file changed, 132 insertions(+), 130 deletions(-) diff --git a/tools/o2o/fuse.bmm_lhs_const.py b/tools/o2o/fuse.bmm_lhs_const.py index 24488719a5f..e4a348209ad 100755 --- a/tools/o2o/fuse.bmm_lhs_const.py +++ b/tools/o2o/fuse.bmm_lhs_const.py @@ -17,15 +17,14 @@ def is_effectively_2d(shape): return all(dim == 1 for dim in shape[:-2]) -def count_tensor_usage(model, tensor_index): - """Count how many operators use a specific tensor as input""" +def count_tensor_usage(subgraph, tensor_index): + """Count how many operators use a specific tensor as input within a subgraph""" count = 0 - for subgraph_idx, subgraph in enumerate(model.subgraphs): - for operator in subgraph.operators: - if operator.inputs is not None: - for input_idx in operator.inputs: - if input_idx == tensor_index: - count += 1 + for operator in subgraph.operators: + if operator.inputs is not None: + for input_idx in operator.inputs: + if input_idx == tensor_index: + count += 1 return count @@ -92,7 +91,7 @@ def reshape_fc_weights(model: 'circle.ModelT', subgraph: 'circle.SubGraphT', fc_op.builtinOptions.keepNumDims = True # Check if this tensor is used by multiple operators - usage_count = count_tensor_usage(model, weights_index) + usage_count = count_tensor_usage(subgraph, weights_index) if usage_count > 1: # Create a new tensor for this operator to avoid affecting others @@ -178,139 +177,142 @@ def fuse_bmm_transpose() -> None: o2o.save_model_to_stdout(model) # Output to stdout for consistency return - subgraph = model.subgraphs[0] # Assuming single subgraph for now, can be extended - tensors_to_potentially_remove = set() - # Define operators to remove (empty list for now) - operators_to_remove = [] # No operators to remove by default + for subgraph_idx, subgraph in 
enumerate(model.subgraphs): + o2o.log(f"Processing subgraph {subgraph_idx}") + tensors_to_potentially_remove = set() + # Define operators to remove (empty list for now) + operators_to_remove = [] # No operators to remove by default - # Iterate backwards to safely remove operators - for i in range(len(subgraph.operators) - 1, -1, -1): - transpose_op = subgraph.operators[i] + # Iterate backwards to safely remove operators + for i in range(len(subgraph.operators) - 1, -1, -1): + transpose_op = subgraph.operators[i] - # Check if current operator is TRANSPOSE - transpose_opcode = model.operatorCodes[transpose_op.opcodeIndex] - if transpose_opcode.builtinCode != circle.BuiltinOperator.TRANSPOSE: - continue + # Check if current operator is TRANSPOSE + transpose_opcode = model.operatorCodes[transpose_op.opcodeIndex] + if transpose_opcode.builtinCode != circle.BuiltinOperator.TRANSPOSE: + continue - if len(transpose_op.inputs) != 2: - o2o.log( - f"Transpose operator at index {i} has invalid number of inputs. Skipping." - ) - continue + if len(transpose_op.inputs) != 2: + o2o.log( + f"Transpose operator at index {i} has invalid number of inputs. Skipping." + ) + continue - transpose_input_tensor_idx = transpose_op.inputs[0] - bmm_op_idx, bmm_op = o2o.find_operator_by_output(subgraph, - transpose_input_tensor_idx) + transpose_input_tensor_idx = transpose_op.inputs[0] + bmm_op_idx, bmm_op = o2o.find_operator_by_output(subgraph, + transpose_input_tensor_idx) - # Check if the found operator is BATCH_MATMUL - if bmm_op is None or model.operatorCodes[ - bmm_op.opcodeIndex].builtinCode != circle.BuiltinOperator.BATCH_MATMUL: - continue + # Check if the found operator is BATCH_MATMUL + if bmm_op is None or model.operatorCodes[ + bmm_op. 
+ opcodeIndex].builtinCode != circle.BuiltinOperator.BATCH_MATMUL: + continue - lhs_tensor_index = bmm_op.inputs[0] - rhs_tensor_index = bmm_op.inputs[1] + lhs_tensor_index = bmm_op.inputs[0] + rhs_tensor_index = bmm_op.inputs[1] - lhs_tensor = o2o.get_tensor_by_index(subgraph, lhs_tensor_index) - rhs_tensor = o2o.get_tensor_by_index(subgraph, rhs_tensor_index) + lhs_tensor = o2o.get_tensor_by_index(subgraph, lhs_tensor_index) + rhs_tensor = o2o.get_tensor_by_index(subgraph, rhs_tensor_index) - if not lhs_tensor or not rhs_tensor: - o2o.log( - f"Could not find LHS or RHS tensor for BATCH_MATMUL at index {bmm_op_idx}. Skipping." - ) - continue + if not lhs_tensor or not rhs_tensor: + o2o.log( + f"Could not find LHS or RHS tensor for BATCH_MATMUL at index {bmm_op_idx}. Skipping." + ) + continue - # Crucial check: LHS must be constant - if not o2o.is_tensor_constant(lhs_tensor, model.buffers): - o2o.log( - f"LHS tensor '{lhs_tensor.name if lhs_tensor.name else lhs_tensor_index}' for BATCH_MATMUL at index {bmm_op_idx} is not constant. Skipping fusion." - ) - continue - - # Verify Transpose permutation (assuming transpose of last two dims) - # e.g. 
for [..., M, N] -> [..., N, M], permutation is [..., dim_N-1, dim_N-2] - # For a 2D tensor [M, N] -> [N, M], permutation is [1, 0] - # For a 3D tensor [B, M, N] -> [B, N, M], permutation is [0, 2, 1] - valid_permutation = False - perm_tensor_index = transpose_op.inputs[1] - perm_tensor = o2o.get_tensor_by_index(subgraph, perm_tensor_index) - - if perm_tensor and o2o.is_tensor_constant(perm_tensor, model.buffers): - # Get permutation data from buffer using the new helper function - perm = o2o.from_buffer(perm_tensor.buffer, model.buffers) - if perm is None: + # Crucial check: LHS must be constant + if not o2o.is_tensor_constant(lhs_tensor, model.buffers): o2o.log( - f"Could not read permutation buffer at index {perm_tensor.buffer}") - valid_permutation = False + f"LHS tensor '{lhs_tensor.name if lhs_tensor.name else lhs_tensor_index}' for BATCH_MATMUL at index {bmm_op_idx} is not constant. Skipping fusion." + ) + continue + + # Verify Transpose permutation (assuming transpose of last two dims) + # e.g. 
for [..., M, N] -> [..., N, M], permutation is [..., dim_N-1, dim_N-2] + # For a 2D tensor [M, N] -> [N, M], permutation is [1, 0] + # For a 3D tensor [B, M, N] -> [B, N, M], permutation is [0, 2, 1] + valid_permutation = False + perm_tensor_index = transpose_op.inputs[1] + perm_tensor = o2o.get_tensor_by_index(subgraph, perm_tensor_index) + + if perm_tensor and o2o.is_tensor_constant(perm_tensor, model.buffers): + # Get permutation data from buffer using the new helper function + perm = o2o.from_buffer(perm_tensor.buffer, model.buffers) + if perm is None: + o2o.log( + f"Could not read permutation buffer at index {perm_tensor.buffer}" + ) + valid_permutation = False + else: + # Use the tensor's shape to determine the actual permutation length + perm_length = perm_tensor.shape[0] if perm_tensor.shape else len(perm) + perm = perm[:perm_length] # Trim to actual length + + if len(perm) >= 2: # At least 2D + # Check if the last two elements of permutation are swapped + # and other elements are in their original ascending order (0, 1, 2, ...) + expected_perm_prefix = list(range(len(perm) - 2)) + actual_perm_prefix = perm[:-2] + + if np.all(actual_perm_prefix == expected_perm_prefix) and \ + perm[-2] == len(perm) - 1 and \ + perm[-1] == len(perm) - 2: + valid_permutation = True else: - # Use the tensor's shape to determine the actual permutation length - perm_length = perm_tensor.shape[0] if perm_tensor.shape else len(perm) - perm = perm[:perm_length] # Trim to actual length - - if len(perm) >= 2: # At least 2D - # Check if the last two elements of permutation are swapped - # and other elements are in their original ascending order (0, 1, 2, ...) 
- expected_perm_prefix = list(range(len(perm) - 2)) - actual_perm_prefix = perm[:-2] - - if np.all(actual_perm_prefix == expected_perm_prefix) and \ - perm[-2] == len(perm) - 1 and \ - perm[-1] == len(perm) - 2: - valid_permutation = True - else: - o2o.log( - f"Permutation tensor for TRANSPOSE at index {i} is not constant or not found. Skipping." - ) + o2o.log( + f"Permutation tensor for TRANSPOSE at index {i} is not constant or not found. Skipping." + ) - if not valid_permutation: - o2o.log( - f"TRANSPOSE operator at index {i} does not have a simple last-two-dim permutation. Skipping fusion." - ) - continue - - # Add TRANSPOSE for RHS if needed (K != 1 OR B != 1) - final_rhs_tensor_index = add_rhs_transpose_if_needed(model, subgraph, bmm_op_idx, - rhs_tensor_index, rhs_tensor) - - # Create the new FULLY_CONNECTED operator - fc_op = circle.OperatorT() - fc_op.opcodeIndex = o2o.get_or_create_operator_code( - model, circle.BuiltinOperator.FULLY_CONNECTED) - # Set inputs: [transposed_rhs, original_lhs, -1] where -1 means bias not exists - fc_op.inputs = [final_rhs_tensor_index, lhs_tensor_index, -1] - # Set outputs: same as the original TRANSPOSE operator - fc_op.outputs = list(transpose_op.outputs) # Make a copy - - # Configure FULLY_CONNECTED options - fc_op.builtinOptionsType = (circle.BuiltinOptions.FullyConnectedOptions) - fc_options = circle.FullyConnectedOptionsT() - fc_options.keepNumDims = True # Important to preserve batch dimensions from BATCH_MATMUL - fc_op.builtinOptions = fc_options - - # Add the new operator to the subgraph - # Insert it at the position of the original BATCH_MATMUL operator - o2o.log(f"Replacing batchmatmul at {bmm_op_idx} with fullyconnected") - subgraph.operators[bmm_op_idx] = fc_op - - # Reshape the weights of the newly created FC operator - reshape_fc_weights(model, subgraph, bmm_op_idx) - - # Mark the original TRANSPOSE operator for removal - operators_to_remove.append(i) - - # The tensor connecting BMM and Transpose 
(bmm_output_tensor_index) is now an intermediate - # output of the new FC op. If it's not used by any other op, it could be cleaned up. - # For now, we just mark it. Actual removal is more complex (needs usage check). - tensors_to_potentially_remove.add(transpose_input_tensor_idx) - - # Remove operators marked for removal (iterate backwards again for safe removal) - for i in sorted(list(operators_to_remove), reverse=True): - if 0 <= i < len(subgraph.operators): - o2o.log(f"Removing transpose operator at index {i}") - del subgraph.operators[i] - - # Note: Cleanup of unused tensors and operator codes is a more advanced step - # and not implemented here for simplicity, but would be part of a production-ready script. - o2o.log(f"TODO: Remove tensors at {tensors_to_potentially_remove}") + if not valid_permutation: + o2o.log( + f"TRANSPOSE operator at index {i} does not have a simple last-two-dim permutation. Skipping fusion." + ) + continue + + # Add TRANSPOSE for RHS if needed (K != 1 OR B != 1) + final_rhs_tensor_index = add_rhs_transpose_if_needed( + model, subgraph, bmm_op_idx, rhs_tensor_index, rhs_tensor) + + # Create the new FULLY_CONNECTED operator + fc_op = circle.OperatorT() + fc_op.opcodeIndex = o2o.get_or_create_operator_code( + model, circle.BuiltinOperator.FULLY_CONNECTED) + # Set inputs: [transposed_rhs, original_lhs, -1] where -1 means bias not exists + fc_op.inputs = [final_rhs_tensor_index, lhs_tensor_index, -1] + # Set outputs: same as the original TRANSPOSE operator + fc_op.outputs = list(transpose_op.outputs) # Make a copy + + # Configure FULLY_CONNECTED options + fc_op.builtinOptionsType = (circle.BuiltinOptions.FullyConnectedOptions) + fc_options = circle.FullyConnectedOptionsT() + fc_options.keepNumDims = True # Important to preserve batch dimensions from BATCH_MATMUL + fc_op.builtinOptions = fc_options + + # Add the new operator to the subgraph + # Insert it at the position of the original BATCH_MATMUL operator + o2o.log(f"Replacing batchmatmul 
at {bmm_op_idx} with fullyconnected") + subgraph.operators[bmm_op_idx] = fc_op + + # Reshape the weights of the newly created FC operator + reshape_fc_weights(model, subgraph, bmm_op_idx) + + # Mark the original TRANSPOSE operator for removal + operators_to_remove.append(i) + + # The tensor connecting BMM and Transpose (bmm_output_tensor_index) is now an intermediate + # output of the new FC op. If it's not used by any other op, it could be cleaned up. + # For now, we just mark it. Actual removal is more complex (needs usage check). + tensors_to_potentially_remove.add(transpose_input_tensor_idx) + + # Remove operators marked for removal (iterate backwards again for safe removal) + for i in sorted(list(operators_to_remove), reverse=True): + if 0 <= i < len(subgraph.operators): + o2o.log(f"Removing transpose operator at index {i}") + del subgraph.operators[i] + + # Note: Cleanup of unused tensors and operator codes is a more advanced step + # and not implemented here for simplicity, but would be part of a production-ready script. + o2o.log(f"TODO: Remove tensors at {tensors_to_potentially_remove}") o2o.save_model_to_stdout(model)