Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,57 @@ def LLVM_ConstrainedFPExtIntr
}];
}

// Generic constrained floating-point intrinsic call.

def LLVM_CallConstrainedFPIntrinsicOp
: LLVM_Op<"intr.experimental.constrained_fp_call",
[Pure,
DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>,
DeclareOpInterfaceMethods<RoundingModeOpInterface>]> {
let summary = "Generic call to an LLVM constrained floating-point intrinsic.";
let description = [{
Calls an LLVM constrained floating-point intrinsic by name. The intrinsic
name is given by the `intrin` attribute (for example
`"llvm.experimental.constrained.cos.f32"`). Overloaded intrinsics are
resolved from the MLIR operand and result types of this op.

The rounding mode operand is required for intrinsics for which
`llvm::Intrinsic::hasConstrainedFPRoundingModeOperand` returns true and is
forbidden otherwise. The exception behavior attribute is always required.

This op handles every constrained FP intrinsic that follows the standard
operand layout `(args..., [rounding,] exception)`. The compare variants
`llvm.experimental.constrained.fcmp` and
`llvm.experimental.constrained.fcmps` carry an additional predicate
metadata operand and are not supported.

Example:

```mlir
%res = llvm.intr.experimental.constrained_fp_call
"llvm.experimental.constrained.cos.f32"(%arg)
towardzero ignore : (f32) -> f32
```
}];

let arguments = (ins StrAttr:$intrin,
Variadic<LLVM_Type>:$args,
OptionalAttr<ValidRoundingModeAttr>:$roundingmode,
FPExceptionBehaviorAttr:$fpExceptionBehavior);
let results = (outs LLVM_Type:$res);

let llvmBuilder = [{
return convertCallConstrainedFPIntrinsicOp(op, builder, moduleTranslation);
}];

let assemblyFormat = [{
$intrin `(` $args `)` ($roundingmode^)? $fpExceptionBehavior
attr-dict `:` functional-type($args, $res)
}];

let hasVerifier = 1;
}

// Intrinsics with multiple returns.

class LLVM_ArithWithOverflowOp<string mnem>
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Error.h"

#include "LLVMDialectBytecode.h"
Expand Down Expand Up @@ -4012,6 +4013,34 @@ LogicalResult CallIntrinsicOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// CallConstrainedFPIntrinsicOp
//===----------------------------------------------------------------------===//

LogicalResult CallConstrainedFPIntrinsicOp::verify() {
StringRef name = getIntrin();
llvm::Intrinsic::ID id = llvm::Intrinsic::lookupIntrinsicID(name);
if (!id)
return emitOpError() << "could not find LLVM intrinsic: " << name;
if (!llvm::Intrinsic::isConstrainedFPIntrinsic(id))
return emitOpError() << "intrinsic " << name
<< " is not a constrained FP intrinsic";
if (id == llvm::Intrinsic::experimental_constrained_fcmp ||
id == llvm::Intrinsic::experimental_constrained_fcmps)
return emitOpError() << "intrinsic " << name
<< " is a constrained FP compare and is not "
"supported by this op";
bool requiresRounding =
llvm::Intrinsic::hasConstrainedFPRoundingModeOperand(id);
if (requiresRounding && !getRoundingmodeAttr())
return emitOpError() << "intrinsic " << name
<< " requires a rounding mode attribute";
if (!requiresRounding && getRoundingmodeAttr())
return emitOpError() << "intrinsic " << name
<< " does not take a rounding mode attribute";
return success();
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
Expand Down
116 changes: 112 additions & 4 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MemoryModelRelaxationAnnotations.h"

using namespace mlir;
Expand All @@ -36,24 +37,113 @@ static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
static constexpr StringLiteral intelReqdSubGroupSizeMDName =
"intel_reqd_sub_group_size";

/// Returns true if `id` is a constrained FP intrinsic that the generic
/// LLVM_CallConstrainedFPIntrinsicOp can model (i.e. it has the standard
/// trailing metadata layout: rounding mode and/or exception behavior, with no
/// additional predicate metadata).
static bool isGenericConstrainedFPIntrinsic(llvm::Intrinsic::ID id) {
if (!llvm::Intrinsic::isConstrainedFPIntrinsic(id))
return false;
// fcmp / fcmps carry an extra predicate metadata operand and are not
// representable by the generic op.
return id != llvm::Intrinsic::experimental_constrained_fcmp &&
id != llvm::Intrinsic::experimental_constrained_fcmps;
}

/// Returns true if `id` is a constrained FP compare intrinsic. These have a
/// predicate metadata operand in addition to the exception behavior operand
/// and are not currently importable, but should fail with a clean diagnostic
/// instead of falling through to the generic intrinsic path and tripping the
/// metadata assertion in `convertValue`.
static bool isConstrainedFPCmpIntrinsic(llvm::Intrinsic::ID id) {
return id == llvm::Intrinsic::experimental_constrained_fcmp ||
id == llvm::Intrinsic::experimental_constrained_fcmps;
}

/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
/// intrinsic. Returns false otherwise.
/// intrinsic. Returns false otherwise. Constrained FP compare intrinsics are
/// claimed here so that the import emits a targeted error rather than crashing
/// in the unregistered-intrinsic fallback.
static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
static const DenseSet<unsigned> convertibleIntrinsics = {
#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
};
return convertibleIntrinsics.contains(id);
if (convertibleIntrinsics.contains(id))
return true;
return isGenericConstrainedFPIntrinsic(id) || isConstrainedFPCmpIntrinsic(id);
}

/// Returns the list of LLVM IR intrinsic identifiers that are convertible to
/// MLIR LLVM dialect intrinsics.
static ArrayRef<unsigned> getSupportedIntrinsicsImpl() {
static const SmallVector<unsigned> convertibleIntrinsics = {
static const SmallVector<unsigned> convertibleIntrinsics = [] {
SmallVector<unsigned> ids = {
#include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
};
};
// Also register the constrained FP intrinsics that fall back to the
// generic LLVM_CallConstrainedFPIntrinsicOp. Compare variants are
// registered too so the importer can emit a clean error for them instead
// of letting them fall through to the unregistered-intrinsic path, which
// would trip the metadata assertion in `convertValue`.
DenseSet<unsigned> seen(ids.begin(), ids.end());
for (unsigned id = 1; id < llvm::Intrinsic::num_intrinsics; ++id) {
auto intrinId = static_cast<llvm::Intrinsic::ID>(id);
if (seen.contains(id))
continue;
if (isGenericConstrainedFPIntrinsic(intrinId) ||
isConstrainedFPCmpIntrinsic(intrinId))
ids.push_back(id);
}
return ids;
}();
return convertibleIntrinsics;
}

/// Imports a constrained FP intrinsic call as a generic
/// LLVM_CallConstrainedFPIntrinsicOp. Splits the call's operands into value
/// arguments and the trailing rounding-mode/exception-behavior metadata
/// operands.
static LogicalResult
convertConstrainedFPIntrinsicCallOp(OpBuilder &builder, llvm::CallInst *inst,
LLVM::ModuleImport &moduleImport) {
llvm::Intrinsic::ID id = inst->getIntrinsicID();
llvm::Function *callee = inst->getCalledFunction();
if (!callee)
return failure();
StringRef intrinName = callee->getName();
bool hasRounding = llvm::Intrinsic::hasConstrainedFPRoundingModeOperand(id);

unsigned numArgs = inst->arg_size();
unsigned numMetadata = hasRounding ? 2 : 1;
if (numArgs < numMetadata)
return failure();
unsigned numValueArgs = numArgs - numMetadata;

SmallVector<Value> args;
args.reserve(numValueArgs);
for (unsigned i = 0; i < numValueArgs; ++i) {
FailureOr<Value> v = moduleImport.convertValue(inst->getArgOperand(i));
if (failed(v))
return failure();
args.push_back(*v);
}

RoundingModeAttr roundingMode;
if (hasRounding)
roundingMode =
moduleImport.matchRoundingModeAttr(inst->getArgOperand(numValueArgs));
FPExceptionBehaviorAttr exceptionBehavior =
moduleImport.matchFPExceptionBehaviorAttr(
inst->getArgOperand(numArgs - 1));

Type resultType = moduleImport.convertType(inst->getType());
auto op = CallConstrainedFPIntrinsicOp::create(
builder, moduleImport.translateLoc(inst->getDebugLoc()), resultType,
builder.getStringAttr(intrinName), args, roundingMode, exceptionBehavior);
moduleImport.mapValue(inst) = op.getRes();
return success();
}

/// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a
/// conversion exits. Returns failure otherwise.
static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
Expand All @@ -73,6 +163,24 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
llvmOpBundles.push_back(inst->getOperandBundleAt(i));

#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"

// Fallback for constrained FP intrinsics without a dedicated MLIR op.
if (isGenericConstrainedFPIntrinsic(intrinsicID))
return convertConstrainedFPIntrinsicCallOp(odsBuilder, inst,
moduleImport);

// Constrained FP compare intrinsics are claimed here so that we can emit
// a targeted error instead of falling through to convertUnregistered-
// Intrinsic (which would crash on the predicate metadata operand).
if (isConstrainedFPCmpIntrinsic(intrinsicID)) {
Location loc = moduleImport.translateLoc(inst->getDebugLoc());
StringRef intrinName = inst->getCalledFunction()
? inst->getCalledFunction()->getName()
: StringRef("<unknown>");
return emitError(loc)
<< "constrained FP compare intrinsic '" << intrinName
<< "' is not supported by the LLVM dialect importer";
}
}

return failure();
Expand Down
61 changes: 61 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,67 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}

/// Builder for LLVM_CallConstrainedFPIntrinsicOp. Resolves the intrinsic
/// identifier from the `intrin` attribute, infers any overloaded types from the
/// MLIR operand and result types, and emits an LLVM IR constrained FP call.
static LogicalResult convertCallConstrainedFPIntrinsicOp(
CallConstrainedFPIntrinsicOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Intrinsic::ID id = llvm::Intrinsic::lookupIntrinsicID(op.getIntrin());
if (!id)
return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
<< op.getIntrin();
if (!llvm::Intrinsic::isConstrainedFPIntrinsic(id))
return mlir::emitError(op.getLoc(), "not a constrained FP intrinsic: ")
<< op.getIntrin();
if (id == llvm::Intrinsic::experimental_constrained_fcmp ||
id == llvm::Intrinsic::experimental_constrained_fcmps)
return mlir::emitError(op.getLoc())
<< op.getIntrin()
<< " is a constrained FP compare and is not supported by this op";

// Build a signature matching what the intrinsic declaration looks like in
// LLVM IR, including the trailing metadata operands. This lets
// Intrinsic::isSignatureValid resolve all overloaded types.
SmallVector<llvm::Type *> argTys;
argTys.reserve(op.getArgs().size() + 2);
for (Type type : op.getArgs().getTypes())
argTys.push_back(moduleTranslation.convertType(type));
llvm::Type *metadataTy = llvm::Type::getMetadataTy(module->getContext());
if (llvm::Intrinsic::hasConstrainedFPRoundingModeOperand(id))
argTys.push_back(metadataTy);
argTys.push_back(metadataTy);

llvm::Type *resultTy = moduleTranslation.convertType(op.getRes().getType());
llvm::FunctionType *ft =
llvm::FunctionType::get(resultTy, argTys, /*isVarArg=*/false);

std::string errorMsg;
llvm::raw_string_ostream errorOS(errorMsg);
SmallVector<llvm::Type *> overloadedTys;
if (!llvm::Intrinsic::isSignatureValid(id, ft, overloadedTys, errorOS)) {
return mlir::emitError(op.getLoc(), "call intrinsic signature ")
<< diagStr(ft) << " to constrained FP intrinsic " << op.getIntrin()
<< " does not match any overload: " << errorMsg;
}

llvm::Function *fn =
llvm::Intrinsic::getOrInsertDeclaration(module, id, overloadedTys);

std::optional<llvm::RoundingMode> rounding;
if (auto roundingAttr = op.getRoundingmodeAttr())
rounding = moduleTranslation.translateRoundingMode(roundingAttr.getValue());
llvm::fp::ExceptionBehavior except =
moduleTranslation.translateFPExceptionBehavior(
op.getFpExceptionBehavior());

llvm::Value *result = builder.CreateConstrainedFPCall(
fn, moduleTranslation.lookupValues(op.getArgs()), "", rounding, except);
moduleTranslation.mapValue(op.getRes()) = result;
return success();
}

/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,46 @@ llvm.func @wrong_number_of_bundle_types_intrin(%arg0: i32) -> i32 {

// -----

llvm.func @constrained_fp_call_unknown_intrinsic(%arg0: f32) -> f32 {
// expected-error@+1 {{could not find LLVM intrinsic: llvm.experimental.constrained.bogus.f32}}
%0 = llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.bogus.f32"(%arg0) towardzero ignore : (f32) -> f32
llvm.return %0 : f32
}

// -----

llvm.func @constrained_fp_call_not_constrained(%arg0: f32) -> f32 {
// expected-error@+1 {{intrinsic llvm.cos.f32 is not a constrained FP intrinsic}}
%0 = llvm.intr.experimental.constrained_fp_call "llvm.cos.f32"(%arg0) towardzero ignore : (f32) -> f32
llvm.return %0 : f32
}

// -----

llvm.func @constrained_fp_call_fcmp_rejected(%arg0: f32) -> i1 {
// expected-error@+1 {{intrinsic llvm.experimental.constrained.fcmp.f32 is a constrained FP compare and is not supported by this op}}
%0 = llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.fcmp.f32"(%arg0, %arg0) ignore : (f32, f32) -> i1
llvm.return %0 : i1
}

// -----

llvm.func @constrained_fp_call_missing_rounding(%arg0: f32) -> f32 {
// expected-error@+1 {{intrinsic llvm.experimental.constrained.cos.f32 requires a rounding mode attribute}}
%0 = llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.cos.f32"(%arg0) ignore : (f32) -> f32
llvm.return %0 : f32
}

// -----

llvm.func @constrained_fp_call_unexpected_rounding(%arg0: f64) -> f64 {
// expected-error@+1 {{intrinsic llvm.experimental.constrained.maximum.f64 does not take a rounding mode attribute}}
%0 = llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.maximum.f64"(%arg0, %arg0) towardzero ignore : (f64, f64) -> f64
llvm.return %0 : f64
}

// -----

llvm.func @foo()
llvm.func @wrong_number_of_bundle_tags() {
%0 = llvm.mlir.constant(0 : i32) : i32
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,23 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
llvm.return
}

// CHECK-LABEL: @experimental_constrained_fp_call
llvm.func @experimental_constrained_fp_call(%s: f32, %d: f64, %p: i32) {
// CHECK: llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.cos.f32"(%{{.*}}) towardzero ignore : (f32) -> f32
%0 = llvm.intr.experimental.constrained_fp_call
"llvm.experimental.constrained.cos.f32"(%s) towardzero ignore
: (f32) -> f32
// CHECK: llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.maximum.f64"(%{{.*}}, %{{.*}}) strict : (f64, f64) -> f64
%1 = llvm.intr.experimental.constrained_fp_call
"llvm.experimental.constrained.maximum.f64"(%d, %d) strict
: (f64, f64) -> f64
// CHECK: llvm.intr.experimental.constrained_fp_call "llvm.experimental.constrained.powi.f32"(%{{.*}}, %{{.*}}) tonearest ignore : (f32, i32) -> f32
%2 = llvm.intr.experimental.constrained_fp_call
"llvm.experimental.constrained.powi.f32"(%s, %p) tonearest ignore
: (f32, i32) -> f32
llvm.return
}

// CHECK: llvm.func @tail_call_target() -> i32
llvm.func @tail_call_target() -> i32

Expand Down
Loading