Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion nnpackage/schema/circle_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
// Version 0.9: GGML_Q{X}_{Y} types are added. Weight compression option is added.
// ROPE op is added. MXFP4, MXINT8 types are added.
// MXQuantization is added.
// Version 0.10: Base up to TensorFlow Lite v2.20.0 schema. RUN_MODEL op is added.
// Version 0.10: Base up to TensorFlow Lite v2.20.0 schema.
// RUN_MODEL and ATTENTION ops are added.

namespace circle;

Expand Down Expand Up @@ -317,6 +318,7 @@ table Tensor {
// set of acceptable options.
// LINT.IfChange
enum BuiltinOperator : int32 {
ATTENTION = -9,
RUN_MODEL = -8,
ROPE = -7,
RMS_NORM = -6,
Expand Down Expand Up @@ -673,6 +675,7 @@ union BuiltinOptions {
BitcastOptions,
BitwiseXorOptions,
RightShiftOptions,
AttentionOptions = 247,
RunModelOptions = 248,
RoPEOptions = 249,
RmsNormOptions = 250,
Expand Down Expand Up @@ -1588,6 +1591,9 @@ table RunModelOptions {
signature:string;
}

// Options table for the ATTENTION builtin operator.
// No attributes yet; presumably the op is configured entirely through its
// input tensors (weights, RoPE tables, mask, KV cache) — TODO confirm.
table AttentionOptions {
}

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
Expand Down
8 changes: 7 additions & 1 deletion runtime/libs/circle-schema/circle_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
// Version 0.9: GGML_Q{X}_{Y} types are added. Weight compression option is added.
// ROPE op is added. MXFP4, MXINT8 types are added.
// MXQuantization is added.
// Version 0.10: Base up to TensorFlow Lite v2.20.0 schema. RUN_MODEL op is added.
// Version 0.10: Base up to TensorFlow Lite v2.20.0 schema.
// RUN_MODEL and ATTENTION ops are added.

namespace circle;

Expand Down Expand Up @@ -317,6 +318,7 @@ table Tensor {
// set of acceptable options.
// LINT.IfChange
enum BuiltinOperator : int32 {
ATTENTION = -9,
RUN_MODEL = -8,
ROPE = -7,
RMS_NORM = -6,
Expand Down Expand Up @@ -673,6 +675,7 @@ union BuiltinOptions {
BitcastOptions,
BitwiseXorOptions,
RightShiftOptions,
AttentionOptions = 247,
RunModelOptions = 248,
RoPEOptions = 249,
RmsNormOptions = 250,
Expand Down Expand Up @@ -1588,6 +1591,9 @@ table RunModelOptions {
signature:string;
}

// Options table for the ATTENTION builtin operator.
// No attributes yet; presumably the op is configured entirely through its
// input tensors (weights, RoPE tables, mask, KV cache) — TODO confirm.
table AttentionOptions {
}

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
Expand Down
42 changes: 42 additions & 0 deletions runtime/onert/backend/cpu/KernelGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ops/AddNLayer.h"
#include "ops/ArgMinMaxLayer.h"
#include "ops/AttentionLayer.h"
#include "ops/BatchToSpaceNDLayer.h"
#include "ops/BinaryArithmeticLayer.h"
#include "ops/ComparisonLayer.h"
Expand Down Expand Up @@ -1556,4 +1557,45 @@ void KernelGenerator::visit(const ir::operation::RoPE &node)
_return_fn = std::move(fn);
}

// Generates the CPU kernel for the Attention operation.
// Mandatory inputs (activation and projection weights) are resolved directly;
// optional inputs (RoPE tables, mask, KV cache, position) map to nullptr when
// their operand index is undefined, and the AttentionLayer is configured with
// the resulting tensor set.
void KernelGenerator::visit(const ir::operation::Attention &node)
{
  using ir::operation::Attention;

  // Resolve an optional operand: an undefined index becomes a null tensor.
  auto optional_tensor = [&](const auto &operand_index) {
    return operand_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(operand_index);
  };

  const auto &inputs = node.getInputs();

  // Mandatory operands: the activation input and the Q/K/V/O projection weights.
  auto input_tensor = _tensor_reg->getPortableTensor(inputs.at(Attention::Input::INPUT));
  auto wq_tensor = _tensor_reg->getPortableTensor(inputs.at(Attention::Input::WQ));
  auto wk_tensor = _tensor_reg->getPortableTensor(inputs.at(Attention::Input::WK));
  auto wv_tensor = _tensor_reg->getPortableTensor(inputs.at(Attention::Input::WV));
  auto wo_tensor = _tensor_reg->getPortableTensor(inputs.at(Attention::Input::WO));

  // Optional operands: RoPE cos/sin tables, attention mask, KV cache, position.
  auto cos_tensor = optional_tensor(inputs.at(Attention::Input::COS));
  auto sin_tensor = optional_tensor(inputs.at(Attention::Input::SIN));
  auto mask_tensor = optional_tensor(inputs.at(Attention::Input::MASK));
  auto k_cache_tensor = optional_tensor(inputs.at(Attention::Input::K_CACHE));
  auto v_cache_tensor = optional_tensor(inputs.at(Attention::Input::V_CACHE));
  auto pos_tensor = optional_tensor(inputs.at(Attention::Input::POS));

  auto output_tensor = _tensor_reg->getPortableTensor(node.getOutputs().at(0));

  auto fn = std::make_unique<ops::AttentionLayer>();
  fn->configure(input_tensor, wq_tensor, wk_tensor, wv_tensor, wo_tensor, cos_tensor, sin_tensor,
                mask_tensor, k_cache_tensor, v_cache_tensor, pos_tensor, output_tensor);

  _return_fn = std::move(fn);
}

} // namespace onert::backend::cpu
1 change: 1 addition & 0 deletions runtime/onert/backend/cpu/Operation.lst
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
OP(AddN)
OP(ArgMinMax)
OP(Attention)
OP(BatchMatMul)
OP(BatchToSpaceND)
OP(BinaryArithmetic)
Expand Down
Loading