Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9982412
Add walkProgramGraph driver
MatthiasReumann Apr 24, 2026
47019cf
Move traits to wire iterator file
MatthiasReumann Apr 24, 2026
adeb7bb
Update documentation
MatthiasReumann Apr 24, 2026
246fba0
Add missing includes
MatthiasReumann Apr 24, 2026
420140d
Fix cast linting
MatthiasReumann Apr 24, 2026
062b4dc
Increase coverage
MatthiasReumann Apr 24, 2026
82d2209
Apply rabbit suggestions
MatthiasReumann Apr 24, 2026
f3ea610
Add missing includes
MatthiasReumann Apr 24, 2026
5fab8b2
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 24, 2026
6843c04
Update documentation
MatthiasReumann Apr 27, 2026
ee8dbc9
Merge branch 'enh/walk-program-graph' of https://github.com/munich-qu…
MatthiasReumann Apr 27, 2026
4472a9f
Merge branch 'main' into enh/walk-program-graph
MatthiasReumann Apr 27, 2026
0e1480d
🔥 Remove unreachable code
burgholzer Apr 27, 2026
b6c5888
🎨 Avoid a couple of copies of the 16-Byte UnitaryOpInterface
burgholzer Apr 27, 2026
a92c33d
🎨 Simplify some of the walkProgram function logic
burgholzer Apr 27, 2026
aebcbb7
🎨 Remove redundant namespace qualifier
burgholzer Apr 27, 2026
8e43104
🎨 Avoid warning about unused return value
burgholzer Apr 27, 2026
f66ed6a
📝 Add changelog entry
burgholzer Apr 27, 2026
ef4ab25
🩹 Fix test case
burgholzer Apr 27, 2026
5950d26
Add clarifying comment
MatthiasReumann Apr 28, 2026
e3e352f
🎨 pre-commit fixes
pre-commit-ci[bot] Apr 28, 2026
7ab99c0
Apply bunny suggestions
MatthiasReumann Apr 28, 2026
cd3c6c6
Merge branch 'enh/walk-program-graph' of https://github.com/munich-qu…
MatthiasReumann Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 144 additions & 32 deletions mlir/include/mlir/Dialect/QCO/Utils/Drivers.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
#pragma once

#include "mlir/Dialect/QCO/IR/QCODialect.h"
#include "mlir/Dialect/QCO/IR/QCOInterfaces.h"
#include "mlir/Dialect/QCO/IR/QCOOps.h"
#include "mlir/Dialect/QCO/Utils/WireIterator.h"
#include "mlir/Dialect/QTensor/IR/QTensorOps.h"

#include <llvm/ADT/TypeSwitch.h>
#include <mlir/IR/Region.h>
Expand Down Expand Up @@ -52,12 +55,12 @@ class Qubits {
/**
* @returns the qubit value assigned to a program index.
*/
[[maybe_unused]] TypedValue<QubitType> getProgramQubit(std::size_t index);
[[nodiscard]] TypedValue<QubitType> getProgramQubit(std::size_t index) const;

/**
* @returns the qubit value assigned to a hardware index.
*/
TypedValue<QubitType> getHardwareQubit(std::size_t index);
[[nodiscard]] TypedValue<QubitType> getHardwareQubit(std::size_t index) const;

private:
DenseMap<std::size_t, TypedValue<QubitType>> programToValue_;
Expand All @@ -66,9 +69,11 @@ class Qubits {
valueToIndex_;
};

using WalkProgramFn = function_ref<WalkResult(Operation*, Qubits&)>;

/**
* @brief Perform top-down non-recursive walk of all operations within a
* region and apply callback function.
* region of a quantum program and apply a callback function.
* @details The signature of the callback function is:
*
* (Operation*, Qubits& q) -> WalkResult
Expand All @@ -77,35 +82,142 @@ class Qubits {
* @param region The targeted region.
* @param fn The callback function.
*/
template <typename Fn> void walkUnit(Region& region, Fn&& fn) {
const auto ffn = std::forward<Fn>(fn);

Qubits qubits;
for (Operation& curr : region.getOps()) {
if (ffn(&curr, qubits).wasInterrupted()) {
break;
};

TypeSwitch<Operation*>(&curr)
.template Case<StaticOp>(
[&](StaticOp op) { qubits.add(op.getQubit(), op.getIndex()); })
.template Case<AllocOp>([&](AllocOp op) { qubits.add(op.getResult()); })
.template Case<UnitaryOpInterface>([&](UnitaryOpInterface op) {
for (const auto& [prevV, nextV] :
llvm::zip(op.getInputQubits(), op.getOutputQubits())) {
const auto prevQ = cast<TypedValue<QubitType>>(prevV);
const auto nextQ = cast<TypedValue<QubitType>>(nextV);
qubits.remap(prevQ, nextQ);
}
})
.template Case<ResetOp>([&](ResetOp op) {
qubits.remap(op.getQubitIn(), op.getQubitOut());
})
.template Case<MeasureOp>([&](MeasureOp op) {
qubits.remap(op.getQubitIn(), op.getQubitOut());
})
.template Case<SinkOp>(
[&](SinkOp op) { qubits.remove(op.getQubit()); });
void walkProgram(Region& region, WalkProgramFn fn);

using ReleasedOps = SmallVector<UnitaryOpInterface, 8>;
using PendingWiresMap =
DenseMap<UnitaryOpInterface, SmallVector<std::size_t, 2>>;

struct IsReady {
bool operator()(PendingWiresMap::value_type& kv) const {
return kv.second.size() == kv.first.getNumQubits();
}
};
Comment thread
MatthiasReumann marked this conversation as resolved.

using ReadyRange =
decltype(make_filter_range(std::declval<PendingWiresMap&>(), IsReady{}));

using WalkProgramGraphFn =
function_ref<WalkResult(const ReadyRange&, ReleasedOps&)>;

/**
* @brief Walk the graph-like circuit IR of QCO dialect programs.
* @details
* Depending on the template parameter, the function collects the
* layers in forward or backward direction, respectively. Towards that end,
* the function traverses the def-use chain of each qubit until a multi-qubit
* gate (including barriers) is found. If a multi-qubit gate is visited twice,
* it is considered ready and inserted into the layer. This process is repeated
* until no more multi-qubit gates are found anymore.
Comment thread
burgholzer marked this conversation as resolved.
Outdated
*
* The signature of the callback function is:
*
* (const ReadyRange&, ReleasedOps&) -> WalkResult
*
* The operations inserted into the parameter "released" determine which
* multi-qubit gates are released in next iteration.
Comment thread
burgholzer marked this conversation as resolved.
*
* @param wires A mutable array-ref of circuit wires (wire iterators).
* @param direction The traversal direction.
Comment thread
MatthiasReumann marked this conversation as resolved.
Outdated
* @param fn The callback function.
*
* @returns
* failure(), if the callback returns WalkResult::interrupt()
* failure(), if the callback returns WalkResult::skipped()
* success(), otherwise.
*/
template <WireDirection Direction>
LogicalResult walkProgramGraph(MutableArrayRef<WireIterator> wires,
WalkProgramGraphFn fn) {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
using Traits = WireTraversalTraits<Direction>;

ReleasedOps released;

PendingWiresMap pending;
pending.reserve(wires.size());

SmallVector<std::size_t> curr(wires.size());
std::iota(curr.begin(), curr.end(), 0UL);

SmallVector<std::size_t> next;
next.reserve(wires.size());

while (!curr.empty()) {
for (std::size_t i : curr) {
auto& it = wires[i];
while (Traits::isActive(it)) {
const auto res =
TypeSwitch<Operation*, WalkResult>(it.operation())
.template Case<UnitaryOpInterface>([&](UnitaryOpInterface op) {
// If there are fewer wires than the qubit requires inputs,
// it's impossible to release the operation. Hence, fail.
if (op.getNumQubits() > wires.size()) {
return WalkResult::interrupt();
}

if (op.getNumQubits() == 1) {
std::ranges::advance(it, Traits::stride());
return WalkResult::advance();
}

// Insert the unitary to the pending map.
// The caller decides if this op should be released.
const auto [it, inserted] = pending.try_emplace(op);
auto& indices = it->second;

if (inserted) {
indices.reserve(op.getNumQubits());
}

indices.emplace_back(i);

return WalkResult::skip(); // Stop at multi-qubit gate.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
})
.template Case<AllocOp, StaticOp, qtensor::ExtractOp, ResetOp,
MeasureOp, SinkOp, qtensor::InsertOp>([&](auto) {
std::ranges::advance(it, Traits::stride());
return WalkResult::advance();
})
.Default([&](Operation* op) {
const auto name = op->getName().getStringRef();
report_fatal_error("unknown op encountered: " + name);
return WalkResult::interrupt();
});
Comment thread
burgholzer marked this conversation as resolved.
Outdated

if (res.wasSkipped()) {
break;
}

if (res.wasInterrupted()) {
return failure();
}
}
}

released.clear();
const auto ready = make_filter_range(pending, IsReady{});
const auto res = std::invoke(fn, ready, released);
if (res.wasInterrupted() || res.wasSkipped()) {
return failure();
}
Comment thread
MatthiasReumann marked this conversation as resolved.

for (UnitaryOpInterface op : released) {
const auto& mapIt = pending.find(op);
assert(mapIt != pending.end());

auto& indices = mapIt->second;
for (std::size_t i : mapIt->second) {
std::ranges::advance(wires[i], Traits::stride());
next.emplace_back(i);
}

pending.erase(mapIt);
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

curr.swap(next);
next.clear();
}

return success();
}
} // namespace mlir::qco
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#pragma once

#include "mlir/Dialect/QCO/IR/QCOOps.h"
#include "mlir/Dialect/QTensor/IR/QTensorOps.h"

#include <mlir/IR/Operation.h>

#include <iterator>
Expand Down Expand Up @@ -83,4 +86,34 @@ class [[nodiscard]] WireIterator {
mlir::Value qubit_;
bool isSentinel_;
};

/**
* @brief Categorizes the current traversal direction.
*/
enum class WireDirection : bool { Forward, Backward };
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

template <WireDirection Direction> struct WireTraversalTraits {};

template <> struct WireTraversalTraits<WireDirection::Forward> {
/// @returns the forward increment stride size.
static constexpr std::ptrdiff_t stride() { return 1; }

/// @returns true if the wire iterator can continue forward.
static bool isActive(const WireIterator& it) {
return it != std::default_sentinel;
}
};

template <> struct WireTraversalTraits<WireDirection::Backward> {
/// @returns the backward increment stride size.
static constexpr std::ptrdiff_t stride() { return -1; }

/// @returns true if the wire iterator can continue backward.
static bool isActive(const WireIterator& it) {
return it.operation() == nullptr
? false
: !isa<qco::AllocOp, StaticOp, qtensor::ExtractOp>(
it.operation());
Comment thread
MatthiasReumann marked this conversation as resolved.
Outdated
}
};
Comment thread
MatthiasReumann marked this conversation as resolved.
} // namespace mlir::qco
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
ArrayRef<Operation*>::iterator anchorIt = anchors.begin();
ArrayRef<SmallVector<IndexGate>>::iterator swapIt = swaps.begin();

walkUnit(funcBody, [&](Operation* op, Qubits& qubits) {
walkProgram(funcBody, [&](Operation* op, Qubits& qubits) {
// Early exit if we've processed all layers.
if (anchorIt == anchors.end()) {
return WalkResult::interrupt();
Expand Down
38 changes: 36 additions & 2 deletions mlir/lib/Dialect/QCO/Utils/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
*/

#include "mlir/Dialect/QCO/IR/QCODialect.h"
#include "mlir/Dialect/QCO/IR/QCOInterfaces.h"
#include "mlir/Dialect/QCO/IR/QCOOps.h"
#include "mlir/Dialect/QCO/Utils/Drivers.h"

#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Casting.h>
#include <mlir/IR/Value.h>

#include <cassert>
Expand Down Expand Up @@ -58,13 +62,43 @@ void Qubits::remove(TypedValue<QubitType> q) {
hardwareToValue_.erase(index);
}

TypedValue<QubitType> Qubits::getProgramQubit(std::size_t index) {
TypedValue<QubitType> Qubits::getProgramQubit(std::size_t index) const {
assert(programToValue_.contains(index));
return programToValue_.lookup(index);
}

TypedValue<QubitType> Qubits::getHardwareQubit(std::size_t index) {
TypedValue<QubitType> Qubits::getHardwareQubit(std::size_t index) const {
assert(hardwareToValue_.contains(index));
return hardwareToValue_.lookup(index);
}

void walkProgram(Region& region, WalkProgramFn fn) {
Qubits qubits;
for (Operation& curr : region.getOps()) {
if (fn(&curr, qubits).wasInterrupted()) {
break;
};

Comment thread
MatthiasReumann marked this conversation as resolved.
Outdated
TypeSwitch<Operation*>(&curr)
.template Case<StaticOp>(
[&](StaticOp op) { qubits.add(op.getQubit(), op.getIndex()); })
.template Case<AllocOp>([&](AllocOp op) { qubits.add(op.getResult()); })
.template Case<UnitaryOpInterface>([&](UnitaryOpInterface op) {
for (const auto& [prevV, nextV] :
llvm::zip(op.getInputQubits(), op.getOutputQubits())) {
const auto prevQ = llvm::cast<TypedValue<QubitType>>(prevV);
const auto nextQ = llvm::cast<TypedValue<QubitType>>(nextV);
qubits.remap(prevQ, nextQ);
}
})
.template Case<ResetOp>([&](ResetOp op) {
qubits.remap(op.getQubitIn(), op.getQubitOut());
})
.template Case<MeasureOp>([&](MeasureOp op) {
qubits.remap(op.getQubitIn(), op.getQubitOut());
})
.template Case<SinkOp>(
[&](SinkOp op) { qubits.remove(op.getQubit()); });
Comment thread
MatthiasReumann marked this conversation as resolved.
Outdated
}
}
} // namespace mlir::qco
Loading
Loading