Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
// If the substmt is default stmt or case stmt, try to handle the special case
// to make it into the simple form. e.g.
//
// swtich () {
// switch () {
// case 1:
// default:
// ...
Expand Down Expand Up @@ -1165,10 +1165,35 @@ mlir::LogicalResult CIRGenFunction::emitSwitchBody(const Stmt *s) {

auto *compoundStmt = cast<CompoundStmt>(s);

mlir::Block *swtichBlock = builder.getBlock();
for (auto *c : compoundStmt->body()) {
ArrayRef<Stmt *> body{compoundStmt->body_begin(), compoundStmt->body_end()};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you not use compoundStmt->body() here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, thats an iterator_range. I could use that instead of array-ref for the range, but it doesn't have a 'drop-front' type function.


mlir::Block *switchBlock = builder.getBlock();

// Any statements appearing before the first case statement are 'unassociated'
// with anything. So we have to create them FIRST in their own block. After
// that, the 'case' regions will take care of future ones.
if (!body.empty() && !isa<SwitchCase>(body.front())) {
builder.setInsertionPointToEnd(switchBlock);
while (!body.empty() && !isa<SwitchCase>(body.front())) {

auto *c = body.front();
if (mlir::failed(emitStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c))))
return mlir::failure();

body = body.drop_front();
}

// Now that we've emitted ALL of the statements, we can create a new block
// for the actual case statements/etc to appear.
mlir::Block *lastBlock = builder.getBlock();
switchBlock = builder.createBlock(switchBlock->getParent());
builder.setInsertionPointToEnd(lastBlock);
cir::BrOp::create(builder, getLoc(s->getSourceRange()), switchBlock);
}

for (auto *c : body) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
builder.setInsertionPointToEnd(swtichBlock);
builder.setInsertionPointToEnd(switchBlock);
// Reset insert point automatically, so that we can attach following
// random stmt to the region of previous built case op to try to make
// the being generated `cir.switch` to be in simple form.
Expand Down
31 changes: 17 additions & 14 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,16 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
if (hasNestedOpsToFlatten(region))
return mlir::failure();

llvm::SmallVector<CaseOp> cases;
op.collectCases(cases);

// Empty switch statement: just erase it.
if (cases.empty()) {
if (op.getBody().hasOneBlock() &&
op.getBody().front().without_terminator().empty()) {
rewriter.eraseOp(op);
return mlir::success();
}

llvm::SmallVector<CaseOp> cases;
op.collectCases(cases);

// Create exit block from the next node of cir.switch op.
mlir::Block *exitBlock = rewriter.splitBlock(
rewriter.getBlock(), op->getNextNode()->getIterator());
Expand All @@ -322,6 +323,18 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
// 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
// recorded block and conditions.

// First we have to handle the rewrite of all of the 'break' ops to make
// sure they now go to the right place, including the ones in the pre-case
// blcoks.
walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
op.getBody(), [&](mlir::Operation *op) {
if (!isa<cir::BreakOp>(op))
return mlir::WalkResult::advance();

lowerTerminator(op, exitBlock, rewriter);
return mlir::WalkResult::skip();
});

// inline everything from switch body between the switch op and the exit
// block.
{
Expand Down Expand Up @@ -389,16 +402,6 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
break;
}

// Handle break statements.
walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
region, [&](mlir::Operation *op) {
if (!isa<cir::BreakOp>(op))
return mlir::WalkResult::advance();

lowerTerminator(op, exitBlock, rewriter);
return mlir::WalkResult::skip();
});

// Track fallthrough in cases.
for (mlir::Block &blk : region.getBlocks()) {
if (blk.getNumSuccessors())
Expand Down
Loading