diff --git a/include/circt/Dialect/Arc/ArcOps.td b/include/circt/Dialect/Arc/ArcOps.td index 02781979d217..247cf01e38e4 100644 --- a/include/circt/Dialect/Arc/ArcOps.td +++ b/include/circt/Dialect/Arc/ArcOps.td @@ -281,10 +281,57 @@ def CallOp : ArcOp<"call", [ }]; } -def MemoryOp : ArcOp<"memory", [MemoryEffects<[MemAlloc]>]> { +class OptionalMemoryInitializerIsCompatible + : PredOpTrait< + "initializer type is compatible with memory type", + CPred< !strconcat("!(", predicate, ") ||", + "(::llvm::cast<::circt::arc::MemoryInitializerType>($", + init, ".getType()).isCompatible(", + "::llvm::cast<::circt::arc::MemoryType>($", + mem, ".getType())))" )> +>; + +class MemoryInitializerIsCompatible + : OptionalMemoryInitializerIsCompatible <"true", mem, init>; + +def MemoryOp : ArcOp<"memory", [ + MemoryEffects<[MemAlloc]>, + OptionalMemoryInitializerIsCompatible< + "$_op.getNumOperands() > 0", "memory", "initializer"> +]> { let summary = "Memory"; + let arguments = (ins Optional:$initializer); let results = (outs MemoryType:$memory); - let assemblyFormat = "type($memory) attr-dict"; + let assemblyFormat = [{ + type($memory) attr-dict (`initial` $initializer^ `:` type($initializer))? + }]; + + let builders = [ + OpBuilder<(ins "mlir::Type":$memType), [{ + build($_builder, $_state, memType, {}); + }]>]; +} + +def InitMemoryFilledOp : ArcOp<"initmem.filled", [Pure]> { + let arguments = (ins APIntAttr:$value, UnitAttr:$repeat); + let results = (outs GenericMemoryInitializerType:$result); + let hasCustomAssemblyFormat = true; +} + +def InitMemoryRandomizedOp : ArcOp<"initmem.randomized", [Pure]> { + let results = (outs GenericMemoryInitializerType:$result); + let assemblyFormat = "attr-dict"; +} + +def InitializeMemoryOp : ArcOp<"initialize_memory", [ + MemoryEffects<[MemWrite]>, + MemoryInitializerIsCompatible<"memory", "initializer"> +]> { + let arguments = (ins MemoryInitializerType:$initializer, MemoryType:$memory); + let assemblyFormat = [{ + $initializer `->` $memory attr-dict + `:` qualified(type($initializer)) `,` qualified(type($memory)) + }]; } class MemoryAndDataTypesMatch : TypesMatchWith< @@ -905,4 +952,74 @@ def VectorizeReturnOp : ArcOp<"vectorize.return", [ let assemblyFormat = "operands attr-dict `:` qualified(type(operands))"; } +def EnvironmentCallOp : ArcOp<"environment_call", [ + FunctionOpInterface, IsolatedFromAbove, HasParent<"mlir::ModuleOp"> + ]> { + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + + static constexpr ::llvm::StringLiteral fillRandomizedSymName = + "_arc_env_fill_randomized"; + + static ::mlir::FunctionType getFillRandomizedType(::mlir::MLIRContext *ctx); + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return nullptr; } + }]; + + let hasCustomAssemblyFormat = 1; +} + +def CallEnvironmentOp : ArcOp<"call_environment", [ + Pure, + CallOpInterface, + DeclareOpInterfaceMethods +]> { + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + + operand_range getArgOperands() { + return getOperands(); + } + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr(getCalleeAttrName(), callee.get()); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; + +} + #endif // CIRCT_DIALECT_ARC_ARCOPS_TD diff --git a/include/circt/Dialect/Arc/ArcPasses.td b/include/circt/Dialect/Arc/ArcPasses.td index 6a89f69c7b02..d94dedec34be 100644 --- a/include/circt/Dialect/Arc/ArcPasses.td +++ b/include/circt/Dialect/Arc/ArcPasses.td @@ -194,6 +194,14 @@ def LowerLUT : Pass<"arc-lower-lut", "arc::DefineOp"> { let dependentDialects = ["hw::HWDialect", "comb::CombDialect"]; } +def LowerMemoryInitializers : Pass<"arc-lower-memory-initializers", + "mlir::ModuleOp"> { + let summary = "Converts InitializeMemory ops to state writes"; + let dependentDialects = [ + "arc::ArcDialect", "mlir::arith::ArithDialect", "mlir::scf::SCFDialect" + ]; +} + def LowerState : Pass<"arc-lower-state", "mlir::ModuleOp"> { let summary = "Split state into read and write ops grouped by clock tree"; let constructor = "circt::arc::createLowerStatePass()"; @@ -219,7 +227,7 @@ def LowerVectorizations : Pass<"arc-lower-vectorizations", "mlir::ModuleOp"> { This pass lowers `arc.vectorize` operations. By default, the operation will be fully lowered (i.e., the op disappears in the IR). Alternatively, it can be partially lowered. - + The "mode" pass option allows to only lower the boundary, only the body, or only inline the body given that both the boundary and the body are already lowered. diff --git a/include/circt/Dialect/Arc/ArcTypes.td b/include/circt/Dialect/Arc/ArcTypes.td index 9dbfa0498c10..2cc10556bf4b 100644 --- a/include/circt/Dialect/Arc/ArcTypes.td +++ b/include/circt/Dialect/Arc/ArcTypes.td @@ -43,6 +43,37 @@ def MemoryType : ArcTypeDef<"Memory"> { }]; } +def MemoryInitializerType : ArcTypeDef<"MemoryInitializer"> { + let mnemonic = "memory_initializer"; + + let parameters = (ins OptionalParameter<"unsigned">:$numWords, + OptionalParameter<"::mlir::IntegerType">:$wordType); + + let hasCustomAssemblyFormat = true; + + let extraClassDeclaration = [{ + bool isCompatible(::circt::arc::MemoryType memType) const { + if (getNumWords() > 0 && getNumWords() != memType.getNumWords()) + return false; + if (!!getWordType() && getWordType() != memType.getWordType()) + return false; + return true; + } + + bool isGeneric() const { return getNumWords() == 0 && !getWordType(); } + + }]; +} + +def GenericMemoryInitializerType : DialectType($_self) && + ::llvm::cast<::circt::arc::MemoryInitializerType>($_self).isGeneric() + }]>, "must be a generic memory initializer type">, + BuildableType< + "::circt::arc::MemoryInitializerType::get($_builder.getContext(), 0, {})" +> {} + def StorageType : ArcTypeDef<"Storage"> { let mnemonic = "storage"; let parameters = (ins OptionalParameter<"unsigned">:$size); diff --git a/integration_test/arcilator/JIT/initial-ram.mlir b/integration_test/arcilator/JIT/initial-ram.mlir new file mode 100644 index 000000000000..739e08956a69 --- /dev/null +++ b/integration_test/arcilator/JIT/initial-ram.mlir @@ -0,0 +1,155 @@ +// RUN: arcilator %s --run --jit-entry=main | FileCheck %s +// REQUIRES: arcilator-jit + +// Lit testing random values is iffy, but the runtime environment should ensure +// reproducible results across runs and platforms. + +// CHECK-LABEL: - addr = 21 +// CHECK-NEXT: rndA = 707ca895977cf11 +// CHECK-NEXT: rndB = 28e9cfdfcf6b898 +// CHECK-NEXT: fill = cafe +// CHECK-NEXT: rept = 8000400020001 +// CHECK-NEXT: - addr = 0 +// CHECK-NEXT: rndA = 5160879eac03cbb +// CHECK-NEXT: rndB = d78aeb0b84b4823 +// CHECK-NEXT: fill = cafe +// CHECK-NEXT: rept = 8000400020001 +// CHECK-NEXT: - addr = 1ff +// CHECK-NEXT: rndA = 198ecb046b4841d +// CHECK-NEXT: rndB = 357020a9a09635b +// CHECK-NEXT: fill = cafe +// CHECK-NEXT: rept = 8000400020001 +// CHECK-NEXT: - addr = aa +// CHECK-NEXT: rndA = 16b4a44c8c8ce64 +// CHECK-NEXT: rndB = 476fc6a9fd6fb83 +// CHECK-NEXT: fill = cafe +// CHECK-NEXT: rept = 8000400020001 + +module { + arc.define @mem_write(%arg0: i9, %arg1: i60, %arg2: i1) -> (i9, i60, i1) { + arc.output %arg0, %arg1, %arg2 : i9, i60, i1 + } + hw.module @SyncRAM( + in %clk : i1, in %reset : i1, in %en : i1, in %addr : i9, in %din : i60, in %wen : i1, + out dout0 : i60, out dout1 : i60, out dout2 : i60, out dout3 : i60, out addrOut : i9) { + %clock = seq.to_clock %clk + + %cst33_i9 = hw.constant 33 : i9 + + %randInit = arc.initmem.randomized + // Check that identical memories get different initial values + %mem0 = arc.memory <512 x i60, i9> initial %randInit : !arc.memory_initializer<* x *> + %mem3 = arc.memory <512 x i60, i9> initial %randInit : !arc.memory_initializer<* x *> + + %fillInit = arc.initmem.filled 0xcafe : i16 + %mem1 = arc.memory <512 x i60, i9> initial %fillInit : !arc.memory_initializer<* x *> + + %repeatInit = arc.initmem.filled repeat 1 : i17 + %mem2 = arc.memory <512 x i60, i9> initial %repeatInit : !arc.memory_initializer<* x *> + + + + %addrReg = seq.compreg %0, %clock powerOn %cst33_i9 : i9 + %0 = comb.mux bin %en, %addr, %addrReg : i9 + + %3 = seq.compreg %en, %clock : i1 + + %rd0 = arc.memory_read_port %mem0[%addrReg] : <512 x i60, i9> + %rd1 = arc.memory_read_port %mem1[%addrReg] : <512 x i60, i9> + %rd2 = arc.memory_read_port %mem2[%addrReg] : <512 x i60, i9> + %rd3 = arc.memory_read_port %mem3[%addrReg] : <512 x i60, i9> + + %c0_i60 = hw.constant 0 : i60 + + arc.memory_write_port %mem0, @mem_write(%addrReg, %din, %wen) clock %clock enable latency 1 : <512 x i60, i9>, i9, i60, i1 + arc.memory_write_port %mem1, @mem_write(%addrReg, %din, %wen) clock %clock enable latency 1 : <512 x i60, i9>, i9, i60, i1 + arc.memory_write_port %mem2, @mem_write(%addrReg, %din, %wen) clock %clock enable latency 1 : <512 x i60, i9>, i9, i60, i1 + arc.memory_write_port %mem3, @mem_write(%addrReg, %din, %wen) clock %clock enable latency 1 : <512 x i60, i9>, i9, i60, i1 + + hw.output %rd0, %rd1, %rd2, %rd3, %addrReg : i60, i60, i60, i60 , i9 + } + + func.func @main() { + %cst0 = arith.constant 0 : i9 + %cst1ff = arith.constant 0x1FF : i9 + %cstaa = arith.constant 0xAA : i9 + + %false = arith.constant 0 : i1 + %true = arith.constant 1 : i1 + + arc.sim.instantiate @SyncRAM as %model { + %addr0 = arc.sim.get_port %model, "addrOut" : i9, !arc.sim.instance<@SyncRAM> + %res0_0 = arc.sim.get_port %model, "dout0" : i60, !arc.sim.instance<@SyncRAM> + %res1_0 = arc.sim.get_port %model, "dout1" : i60, !arc.sim.instance<@SyncRAM> + %res2_0 = arc.sim.get_port %model, "dout2" : i60, !arc.sim.instance<@SyncRAM> + %res3_0 = arc.sim.get_port %model, "dout3" : i60, !arc.sim.instance<@SyncRAM> + arc.sim.emit " - addr", %addr0 : i9 + arc.sim.emit "rndA", %res0_0 : i60 + arc.sim.emit "rndB", %res3_0 : i60 + arc.sim.emit "fill", %res1_0 : i60 + arc.sim.emit "rept", %res2_0 : i60 + + arc.sim.set_input %model, "en" = %true : i1, !arc.sim.instance<@SyncRAM> + arc.sim.set_input %model, "wen" = %false : i1, !arc.sim.instance<@SyncRAM> + arc.sim.set_input %model, "reset" = %false : i1, !arc.sim.instance<@SyncRAM> + + arc.sim.set_input %model, "addr" = %cst0 : i9, !arc.sim.instance<@SyncRAM> + + arc.sim.set_input %model, "clk" = %false : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + + arc.sim.set_input %model, "clk" = %true : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + arc.sim.set_input %model, "clk" = %false : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + + %addr1 = arc.sim.get_port %model, "addrOut" : i9, !arc.sim.instance<@SyncRAM> + %res0_1 = arc.sim.get_port %model, "dout0" : i60, !arc.sim.instance<@SyncRAM> + %res1_1 = arc.sim.get_port %model, "dout1" : i60, !arc.sim.instance<@SyncRAM> + %res2_1 = arc.sim.get_port %model, "dout2" : i60, !arc.sim.instance<@SyncRAM> + %res3_1 = arc.sim.get_port %model, "dout3" : i60, !arc.sim.instance<@SyncRAM> + arc.sim.emit " - addr", %addr1 : i9 + arc.sim.emit "rndA", %res0_1 : i60 + arc.sim.emit "rndB", %res3_1 : i60 + arc.sim.emit "fill", %res1_1 : i60 + arc.sim.emit "rept", %res2_1 : i60 + + arc.sim.set_input %model, "addr" = %cst1ff : i9, !arc.sim.instance<@SyncRAM> + + arc.sim.set_input %model, "clk" = %true : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + arc.sim.set_input %model, "clk" = %false : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + + %addr2 = arc.sim.get_port %model, "addrOut" : i9, !arc.sim.instance<@SyncRAM> + %res0_2 = arc.sim.get_port %model, "dout0" : i60, !arc.sim.instance<@SyncRAM> + %res1_2 = arc.sim.get_port %model, "dout1" : i60, !arc.sim.instance<@SyncRAM> + %res2_2 = arc.sim.get_port %model, "dout2" : i60, !arc.sim.instance<@SyncRAM> + %res3_2 = arc.sim.get_port %model, "dout3" : i60, !arc.sim.instance<@SyncRAM> + arc.sim.emit " - addr", %addr2 : i9 + arc.sim.emit "rndA", %res0_2 : i60 + arc.sim.emit "rndB", %res3_2 : i60 + arc.sim.emit "fill", %res1_2 : i60 + arc.sim.emit "rept", %res2_2 : i60 + + arc.sim.set_input %model, "addr" = %cstaa : i9, !arc.sim.instance<@SyncRAM> + + arc.sim.set_input %model, "clk" = %true : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + arc.sim.set_input %model, "clk" = %false : i1, !arc.sim.instance<@SyncRAM> + arc.sim.step %model : !arc.sim.instance<@SyncRAM> + + %addr3 = arc.sim.get_port %model, "addrOut" : i9, !arc.sim.instance<@SyncRAM> + %res0_3 = arc.sim.get_port %model, "dout0" : i60, !arc.sim.instance<@SyncRAM> + %res1_3 = arc.sim.get_port %model, "dout1" : i60, !arc.sim.instance<@SyncRAM> + %res2_3 = arc.sim.get_port %model, "dout2" : i60, !arc.sim.instance<@SyncRAM> + %res3_3 = arc.sim.get_port %model, "dout3" : i60, !arc.sim.instance<@SyncRAM> + arc.sim.emit " - addr", %addr3 : i9 + arc.sim.emit "rndA", %res0_3 : i60 + arc.sim.emit "rndB", %res3_3 : i60 + arc.sim.emit "fill", %res1_3 : i60 + arc.sim.emit "rept", %res2_3 : i60 + } + return + } +} diff --git a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp index d6e4f5bcc130..aa1ae43ed626 100644 --- a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp +++ b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp @@ -76,6 +76,43 @@ struct ModelOpLowering : public OpConversionPattern { } }; +struct EnvironemtCallOpLowering + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arc::EnvironmentCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + auto numResults = adaptor.getFunctionType().getNumResults(); + if (numResults > 1) { + op.emitError("Cannot lower environment call with multiple results."); + return failure(); + } + Type resType = numResults == 0 + ? LLVM::LLVMVoidType::get(rewriter.getContext()) + : adaptor.getFunctionType().getResult(0); + auto funcType = LLVM::LLVMFunctionType::get( + resType, adaptor.getFunctionType().getInputs()); + + rewriter.replaceOpWithNewOp(op, op.getName(), funcType); + return success(); + }; +}; + +struct CallEnvirionmentOpLowering + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arc::CallEnvironmentOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), op.getCallee(), adaptor.getOperands()); + return success(); + }; +}; + struct AllocStorageOpLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -630,7 +667,9 @@ void LowerArcToLLVMPass::runOnOperation() { AllocStateLikeOpLowering, AllocStateLikeOpLowering, AllocStorageOpLowering, + CallEnvirionmentOpLowering, ClockGateOpLowering, + EnvironemtCallOpLowering, MemoryReadOpLowering, MemoryWriteOpLowering, ModelOpLowering, diff --git a/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp b/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp index a8c84afda534..35ac908575af 100644 --- a/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp +++ b/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp @@ -25,8 +25,9 @@ using llvm::MapVector; static bool isArcBreakingOp(Operation *op) { return op->hasTrait() || - isa(op) || + isa(op) || op->getNumResults() > 1; } diff --git a/lib/Dialect/Arc/ArcOps.cpp b/lib/Dialect/Arc/ArcOps.cpp index 76e457b8d77d..9a605b4eb285 100644 --- a/lib/Dialect/Arc/ArcOps.cpp +++ b/lib/Dialect/Arc/ArcOps.cpp @@ -634,6 +634,82 @@ LogicalResult SimStepOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +void InitMemoryFilledOp::print(OpAsmPrinter &p) { + if (getRepeat()) + p << " repeat"; + p << " "; + p.printAttribute(getValueAttr()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"value", "repeat"}); +} + +ParseResult InitMemoryFilledOp::parse(OpAsmParser &parser, + OperationState &result) { + IntegerAttr valueAttr; + + if (!parser.parseOptionalKeyword("repeat")) + result.addAttribute("repeat", UnitAttr::get(parser.getContext())); + + if (parser.parseAttribute(valueAttr, "value", result.attributes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + result.addTypes(MemoryInitializerType::get(parser.getContext(), 0, {})); + return success(); +} + +ParseResult EnvironmentCallOp::parse(OpAsmParser &parser, + OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void EnvironmentCallOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, "function_type", getArgAttrsAttrName(), + getResAttrsAttrName()); +} + +FunctionType EnvironmentCallOp::getFillRandomizedType(MLIRContext *ctxt) { + auto i64Type = IntegerType::get(ctxt, 64); + auto i32Type = IntegerType::get(ctxt, 32); + auto storageType = StorageType::get(ctxt, 0); + + std::array args; + args[0] = storageType; // Memory reference + args[1] = i64Type; // Num Words + args[2] = i32Type; // Word bits + args[3] = i32Type; // Stride + + return FunctionType::get(ctxt, args, TypeRange{}); +} + +LogicalResult +CallEnvironmentOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto referencedOp = + symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()); + + if (!referencedOp) + return emitError("Cannot find declaration of environment call '") + << getCallee() << "'."; + auto envCallOp = dyn_cast(referencedOp); + if (!envCallOp) { + auto diag = + emitError("Referenced operation must be an 'arc.environment_call' op."); + diag.attachNote(referencedOp->getLoc()) << "Symbol declared here:"; + return diag; + } + // TODO : Verify argument and return types + return success(); +} + #include "circt/Dialect/Arc/ArcInterfaces.cpp.inc" #define GET_OP_CLASSES diff --git a/lib/Dialect/Arc/ArcTypes.cpp b/lib/Dialect/Arc/ArcTypes.cpp index 5cd342284f2d..a414c1ae82a3 100644 --- a/lib/Dialect/Arc/ArcTypes.cpp +++ b/lib/Dialect/Arc/ArcTypes.cpp @@ -37,6 +37,45 @@ unsigned MemoryType::getStride() { return llvm::alignToPowerOf2(stride, llvm::bit_ceil(std::min(stride, 16U))); } +Type MemoryInitializerType::parse(AsmParser &odsParser) { + unsigned numWords = 0; + IntegerType wordType; + + if (odsParser.parseLess()) + return {}; + + if (odsParser.parseOptionalStar()) { + auto numLoc = odsParser.getCurrentLocation(); + if (odsParser.parseInteger(numWords)) + return {}; + if (numWords == 0) { + odsParser.emitError(numLoc, "Number of words must not be zero."); + return {}; + } + } + + if (odsParser.parseXInDimensionList() || + (odsParser.parseOptionalStar() && odsParser.parseType(wordType)) || + odsParser.parseGreater()) + return {}; + + return MemoryInitializerType::get(odsParser.getContext(), numWords, wordType); +} + +void MemoryInitializerType::print(AsmPrinter &odsPrinter) const { + odsPrinter << "<"; + if (getNumWords() > 0) + odsPrinter << getNumWords(); + else + odsPrinter << "*"; + odsPrinter << " x "; + if (getWordType()) + odsPrinter << getWordType(); + else + odsPrinter << "*"; + odsPrinter << ">"; +} + void ArcDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/lib/Dialect/Arc/Transforms/CMakeLists.txt b/lib/Dialect/Arc/Transforms/CMakeLists.txt index d3690cd408be..047f46adb7cb 100644 --- a/lib/Dialect/Arc/Transforms/CMakeLists.txt +++ b/lib/Dialect/Arc/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_circt_dialect_library(CIRCTArcTransforms LowerArcsToFuncs.cpp LowerClocksToFuncs.cpp LowerLUT.cpp + LowerMemoryInitializers.cpp LowerState.cpp LowerVectorizations.cpp MakeTables.cpp diff --git a/lib/Dialect/Arc/Transforms/LowerMemoryInitializers.cpp b/lib/Dialect/Arc/Transforms/LowerMemoryInitializers.cpp new file mode 100644 index 000000000000..6ee50502d60e --- /dev/null +++ b/lib/Dialect/Arc/Transforms/LowerMemoryInitializers.cpp @@ -0,0 +1,211 @@ +//===- LowerMemoryInitializers.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Arc/ArcOps.h" +#include "circt/Dialect/Arc/ArcPasses.h" +#include "circt/Dialect/HW/HWOps.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "arc-lower-memory-initializers" + +namespace circt { +namespace arc { +#define GEN_PASS_DEF_LOWERMEMORYINITIALIZERS +#include "circt/Dialect/Arc/ArcPasses.h.inc" +} // namespace arc +} // namespace circt + +using namespace mlir; +using namespace circt; +using namespace arc; + +namespace { +struct LowerMemoryInitializersPass + : public arc::impl::LowerMemoryInitializersBase< + LowerMemoryInitializersPass> { + void runOnOperation() override; + LogicalResult processInitializerFunction(func::FuncOp funcOp); + LogicalResult lowerFilledInitialization(InitializeMemoryOp initOp, + InitMemoryFilledOp fillOp); + LogicalResult lowerRandomizedInitialization(InitializeMemoryOp initOp, + InitMemoryRandomizedOp randOp); + + EnvironmentCallOp getEnvironmentCall(StringRef symName, FunctionType type); + + SymbolTable *symbolTable; +}; +} // namespace + +EnvironmentCallOp +LowerMemoryInitializersPass::getEnvironmentCall(StringRef symName, + FunctionType type) { + assert(!!symbolTable && "Missing symbol table."); + + auto lookup = symbolTable->lookup(symName); + if (lookup) + return llvm::cast(lookup); + + auto theModule = getOperation(); + ImplicitLocOpBuilder builder(theModule.getLoc(), theModule.getContext()); + builder.setInsertionPointToStart(theModule.getBody()); + auto callOp = builder.create(symName, type, ArrayAttr(), + ArrayAttr()); + symbolTable->insert(callOp); + return callOp; +} + +LogicalResult LowerMemoryInitializersPass::lowerFilledInitialization( + InitializeMemoryOp initOp, InitMemoryFilledOp fillOp) { + auto loc = + FusedLoc::get(initOp.getContext(), + std::array{initOp.getLoc(), fillOp.getLoc()}); + ImplicitLocOpBuilder builder(loc, initOp); + + auto wordType = initOp.getMemory().getType().getWordType(); + auto addrType = initOp.getMemory().getType().getAddressType(); + auto indexType = builder.getIndexType(); + + auto destBits = wordType.getIntOrFloatBitWidth(); + APInt constVal = fillOp.getValue().zextOrTrunc(destBits); + + if (fillOp.getRepeat()) { + auto sourceBits = fillOp.getValueAttr().getType().getIntOrFloatBitWidth(); + unsigned shiftWidth = sourceBits; + while (shiftWidth <= destBits) { + constVal |= constVal << shiftWidth; + shiftWidth *= 2; + } + } + + auto fillValue = builder.create( + builder.getIntegerAttr(wordType, constVal)); + auto zero = + builder.create(builder.getIntegerAttr(indexType, 0)); + auto one = + builder.create(builder.getIntegerAttr(indexType, 1)); + auto limit = builder.create(builder.getIntegerAttr( + indexType, initOp.getMemory().getType().getNumWords())); + auto forOp = builder.create(zero, limit, one); + builder.setInsertionPointToStart(forOp.getBody()); + auto addr = builder.createOrFold( + addrType, forOp.getInductionVar()); + builder.create(initOp.getMemory(), addr, /*enable*/ Value{}, + fillValue); + + return success(); +} + +LogicalResult LowerMemoryInitializersPass::lowerRandomizedInitialization( + InitializeMemoryOp initOp, InitMemoryRandomizedOp randOp) { + auto loc = + FusedLoc::get(initOp.getContext(), + std::array{initOp.getLoc(), randOp.getLoc()}); + ImplicitLocOpBuilder builder(loc, initOp); + + auto oldStorageOp = initOp.getMemory().getDefiningOp(); + if (!oldStorageOp) { + initOp.emitError( + "Unable to retrive storage reference for lowering of initializer."); + return failure(); + } + + auto stride = initOp.getMemory().getType().getStride(); + auto numWords = initOp.getMemory().getType().getNumWords(); + unsigned sizeInBytes = numWords * stride; + + auto envCall = getEnvironmentCall( + EnvironmentCallOp::fillRandomizedSymName, + EnvironmentCallOp::getFillRandomizedType(builder.getContext())); + auto inputTypes = envCall.getFunctionType().getInputs(); + + std::array args; + auto storageType = StorageType::get(builder.getContext(), sizeInBytes); + // Memory reference + args[0] = builder.createOrFold( + storageType, oldStorageOp.getStorage(), oldStorageOp.getOffsetAttr()); + // Num Words + args[1] = builder.createOrFold( + IntegerAttr::get(inputTypes[1], numWords)); + // Word bits + args[2] = builder.createOrFold(IntegerAttr::get( + inputTypes[2], + initOp.getMemory().getType().getWordType().getIntOrFloatBitWidth())); + // Stride + args[3] = builder.createOrFold( + IntegerAttr::get(inputTypes[3], stride)); + builder.create(TypeRange{}, envCall.getSymName(), args); + + return success(); +} + +LogicalResult +LowerMemoryInitializersPass::processInitializerFunction(func::FuncOp funcOp) { + SmallVector initOps; + SmallPtrSet cleanupSet; + + funcOp.walk( + [&](InitializeMemoryOp initMemOp) { initOps.push_back(initMemOp); }); + + bool hasFailed = false; + + for (auto initOp : initOps) { + auto defOp = initOp.getInitializer().getDefiningOp(); + if (!defOp) { + initOp.emitError("Cannot lower initializer passed as argument."); + return failure(); + } + + cleanupSet.insert(defOp); + + TypeSwitch(defOp) + .Case([&](auto op) { + hasFailed |= failed(lowerFilledInitialization(initOp, op)); + }) + .Case([&](auto op) { + hasFailed |= failed(lowerRandomizedInitialization(initOp, op)); + }) + .Default([&](auto) { + defOp->emitOpError("is not a supported memory intitializer."); + hasFailed = true; + }); + } + + if (hasFailed) + return failure(); + + for (auto initOp : initOps) + initOp->erase(); + for (auto cleanupOp : cleanupSet) + if (cleanupOp->getResult(0).getUses().empty()) + cleanupOp->erase(); + + return success(); +} + +void LowerMemoryInitializersPass::runOnOperation() { + symbolTable = nullptr; + auto theModule = getOperation(); + for (auto modelOp : theModule.getOps()) { + if (auto intitFnAttr = modelOp.getInitialFnAttr()) { + if (!symbolTable) + symbolTable = &getAnalysis(); + auto initFn = llvm::dyn_cast_or_null( + symbolTable->lookupSymbolIn(theModule, intitFnAttr)); + assert(!!initFn && "Failed to look-up initializer function."); + if (failed(processInitializerFunction(initFn))) { + signalPassFailure(); + return; + } + } + } +} diff --git a/lib/Dialect/Arc/Transforms/LowerState.cpp b/lib/Dialect/Arc/Transforms/LowerState.cpp index d659e8870397..6e91be2ed85d 100644 --- a/lib/Dialect/Arc/Transforms/LowerState.cpp +++ b/lib/Dialect/Arc/Transforms/LowerState.cpp @@ -176,6 +176,8 @@ static bool canBeMaterializedInInitializer(Operation *op) { return true; if (isa(op->getDialect())) return true; + if (isa(op)) + return true; // TODO: There are some other ops we probably want to allow return false; } @@ -569,6 +571,17 @@ LogicalResult ModuleLowering::lowerState(MemoryOp memOp) { auto allocMemOp = stateBuilder.create( memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs()); memOp.replaceAllUsesWith(allocMemOp.getResult()); + + if (auto initializer = memOp.getInitializer()) { + auto &initialTree = getInitial(); + auto materializedInit = initialTree.materializeValue(initializer); + initialTree.builder.create( + memOp.getLoc(), materializedInit, allocMemOp.getResult()); + if (initializer.getUses().empty()) + if (auto initDefOp = initializer.getDefiningOp()) + initDefOp->erase(); + } + memOp.erase(); return success(); } diff --git a/tools/arcilator/arcilator-runtime.h b/tools/arcilator/arcilator-runtime.h index ee65bebe4a16..5cb49a063a53 100644 --- a/tools/arcilator/arcilator-runtime.h +++ b/tools/arcilator/arcilator-runtime.h @@ -53,6 +53,14 @@ ARC_EXPORT int _arc_libc_fputc(int ch, FILE *stream) { #define ARC_ENV_DECL_GET_PRINT_STREAM(idarg) \ ARC_EXPORT FILE *_arc_env_get_print_stream(uint32_t idarg) +#define ARC_ENV_DECL_FILL_RANDOMIZED(ptrarg, numwordsarg, wordbitsarg, \ + stridearg) \ + ARC_EXPORT void _arc_env_fill_randomized(void *ptrarg, uint64_t numwordsarg, \ + uint32_t wordbitsarg, \ + uint32_t stridearg) + +// Default implementations + #ifndef ARC_NO_DEFAULT_GET_PRINT_STREAM ARC_ENV_DECL_GET_PRINT_STREAM(id) { (void)id; @@ -60,6 +68,50 @@ ARC_ENV_DECL_GET_PRINT_STREAM(id) { } #endif // ARC_NO_DEFAULT_GET_PRINT_STREAM +#ifndef ARC_NO_DEFAULT_FILL_RANDOMIZED + +#include + +template +static void arc_random_fill_helper(T *ptr, std::mt19937_64 &prng, + uint64_t numWords, uint32_t wordBits, + uint32_t numChunks) { + constexpr unsigned typeBits = sizeof(T) * 8; + for (uint64_t word = 0; word < numWords; ++word) { + auto activeBits = wordBits; + for (uint32_t chunk = 0; chunk < numChunks; ++chunk) { + T mask; + if (activeBits < typeBits) { + mask = (static_cast(1) << activeBits) - 1; + activeBits = 0; + } else { + mask = ~static_cast(0); + activeBits -= typeBits; + } + uint64_t rand = prng(); + *(ptr++) = mask & static_cast(rand); + } + } +} + +ARC_ENV_DECL_FILL_RANDOMIZED(ptr, numWords, wordBits, stride) { + // Yuck + static std::mt19937_64 prng(0x123456); + if (stride >= 8) + arc_random_fill_helper(reinterpret_cast(ptr), prng, numWords, + wordBits, stride / 8); + else if (stride == 4) + arc_random_fill_helper(reinterpret_cast(ptr), prng, numWords, + wordBits, 1); + else if (stride == 2) + arc_random_fill_helper(reinterpret_cast(ptr), prng, numWords, + wordBits, 1); + else + arc_random_fill_helper(reinterpret_cast(ptr), prng, numWords, + wordBits, 1); +} +#endif // ARC_NO_DEFAULT_GET_PRINT_STREAM + // ---------------- struct Signal { diff --git a/tools/arcilator/arcilator.cpp b/tools/arcilator/arcilator.cpp index 2ecfcbe8c0e8..3a3a79efc3ef 100644 --- a/tools/arcilator/arcilator.cpp +++ b/tools/arcilator/arcilator.cpp @@ -356,6 +356,7 @@ static void populateArcToLLVMPipeline(PassManager &pm) { if (untilReached(UntilLLVMLowering)) return; pm.addPass(createConvertCombToArithPass()); + pm.addPass(arc::createLowerMemoryInitializers()); pm.addPass(createLowerArcToLLVMPass()); pm.addPass(createCSEPass()); pm.addPass(arc::createArcCanonicalizerPass());