Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ext/MooncakeCUDAExt/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using CUDA.CUDACore:
is_capturing,
capture_status,
hasfieldcount
using CUDA: CUDACore
using CUDA: cuBLAS
using CUDA: cuSPARSE
using CUDA: cuSOLVER
Expand Down Expand Up @@ -313,13 +314,13 @@ _register_cuda_opaque_types!()
# CUDA @cenum types are primitive types (integer-backed C enums) — never differentiable.
# Mooncake's generic tangent_type @generated function errors on primitive types with no
# registered method, so we register all of them here programmatically.
# Covers: cuBLAS, cuSPARSE, cuSOLVER.
# Covers: CUDACore, cuBLAS, cuSPARSE, cuSOLVER.
# cuDNN enums are handled in MooncakeCUDNNExt (loaded only when cuDNN is available).
# Filter: parentmodule(T) must be one of the CUDA family modules, to avoid accidentally
# re-registering standard Julia primitive types (Bool, Int32, Float64, ...) that happen
# to be visible in the CUDA namespace.
function _register_cuda_enum_types!()
let _cuda_family = (cuBLAS, cuSPARSE, cuSOLVER)
let _cuda_family = (CUDACore, cuBLAS, cuSPARSE, cuSOLVER)
_cenum_seen = Set{DataType}()
for _mod in _cuda_family
for _nm in names(_mod; all=true)
Expand All @@ -330,6 +331,7 @@ function _register_cuda_enum_types!()
end
_T isa DataType || continue
isprimitivetype(_T) || continue
_T <: CUDACore.CEnum.Cenum || continue
parentmodule(_T) in _cuda_family || continue
_T in _cenum_seen && continue
push!(_cenum_seen, _T)
Expand Down
Loading