diff --git a/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl index 3ce45951fa..0ef1ffae1c 100644 --- a/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl @@ -25,6 +25,7 @@ using CUDA.CUDACore: is_capturing, capture_status, hasfieldcount +using CUDA: CUDACore using CUDA: cuBLAS using CUDA: cuSPARSE using CUDA: cuSOLVER @@ -313,13 +314,13 @@ _register_cuda_opaque_types!() # CUDA @cenum types are primitive types (integer-backed C enums) — never differentiable. # Mooncake's generic tangent_type @generated function errors on primitive types with no # registered method, so we register all of them here programmatically. -# Covers: cuBLAS, cuSPARSE, cuSOLVER. +# Covers: CUDACore, cuBLAS, cuSPARSE, cuSOLVER. # cuDNN enums are handled in MooncakeCUDNNExt (loaded only when cuDNN is available). # Filter: parentmodule(T) must be one of the CUDA family modules, to avoid accidentally # re-registering standard Julia primitive types (Bool, Int32, Float64, ...) that happen # to be visible in the CUDA namespace. function _register_cuda_enum_types!() - let _cuda_family = (cuBLAS, cuSPARSE, cuSOLVER) + let _cuda_family = (CUDACore, cuBLAS, cuSPARSE, cuSOLVER) _cenum_seen = Set{DataType}() for _mod in _cuda_family for _nm in names(_mod; all=true) @@ -330,6 +331,7 @@ function _register_cuda_enum_types!() end _T isa DataType || continue isprimitivetype(_T) || continue + _T <: CUDACore.CEnum.Cenum || continue parentmodule(_T) in _cuda_family || continue _T in _cenum_seen && continue push!(_cenum_seen, _T)