Mark CUDACore.cudaError_enum as having no tangent #1175

Merged: sunxd3 merged 2 commits into main from ksh/cu_err on May 15, 2026

@kshyatt (Collaborator) commented May 14, 2026

Ran into this when Mooncake tried to compile through an error. This enum is definitely non-differentiable.

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1175 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1175/

Performance

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌───────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                 Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                String │   String │   String │      String │  String │      String │ String │
├───────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│              sum_1000 │ 171.0 ns │     1.64 │        1.75 │   0.702 │         3.4 │   6.56 │
│             _sum_1000 │ 981.0 ns │     6.73 │        1.01 │  3760.0 │        42.6 │   1.05 │
│          sum_sin_1000 │  7.07 μs │     3.57 │        1.34 │    1.51 │        11.1 │   1.84 │
│         _sum_sin_1000 │  6.37 μs │     3.32 │        1.78 │   235.0 │        12.5 │   2.07 │
│              kron_sum │ 226.0 μs │     12.4 │         3.0 │    19.7 │       313.0 │   17.7 │
│         kron_view_sum │ 292.0 μs │     10.8 │        5.04 │    25.0 │       303.0 │   12.2 │
│ naive_map_sin_cos_exp │  2.38 μs │     3.01 │        1.49 │ missing │        7.48 │   2.09 │
│       map_sin_cos_exp │  2.35 μs │     3.55 │        1.62 │    1.49 │        6.49 │   2.59 │
│ broadcast_sin_cos_exp │  2.49 μs │     3.12 │        1.59 │    3.57 │        1.43 │   2.01 │
│            simple_mlp │ 358.0 μs │     4.83 │        2.66 │    2.22 │        8.74 │   2.94 │
│                gp_lml │ 176.0 μs │     10.9 │        2.42 │    4.95 │     missing │   5.37 │
│    large_single_block │ 421.0 ns │     5.31 │         1.9 │  4880.0 │        35.4 │   2.05 │
└───────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

Mark `CUDACore.cudaError_enum` as having no tangent

Signed-off-by: Katharine Hyatt <kshyatt@users.noreply.github.com>

@sunxd3 (Collaborator) commented May 14, 2026

Maybe instead of the current two-line change, which registers only one enum, we can add CUDACore to the list at

# CUDA @cenum types are primitive types (integer-backed C enums) — never differentiable.
# Mooncake's generic tangent_type @generated function errors on primitive types with no
# registered method, so we register all of them here programmatically.
# Covers: cuBLAS, cuSPARSE, cuSOLVER.
# cuDNN enums are handled in MooncakeCUDNNExt (loaded only when cuDNN is available).
# Filter: parentmodule(T) must be one of the CUDA family modules, to avoid accidentally
# re-registering standard Julia primitive types (Bool, Int32, Float64, ...) that happen
# to be visible in the CUDA namespace.
function _register_cuda_enum_types!()
    let _cuda_family = (cuBLAS, cuSPARSE, cuSOLVER)
        _cenum_seen = Set{DataType}()
        for _mod in _cuda_family
            for _nm in names(_mod; all=true)
                _T = try
                    getfield(_mod, _nm)
                catch
                    nothing
                end
                _T isa DataType || continue
                isprimitivetype(_T) || continue
                parentmodule(_T) in _cuda_family || continue
                _T in _cenum_seen && continue
                push!(_cenum_seen, _T)
                (
                    try
                        tangent_type(_T) === NoTangent
                    catch
                        false
                    end
                ) && continue
                @eval tangent_type(::Type{$_T}) = NoTangent
            end
        end
    end
    return nothing
end

something like

diff --git a/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl
index 81bf0605d..0ef1ffae1 100644
--- a/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl
+++ b/ext/MooncakeCUDAExt/MooncakeCUDAExt.jl
@@ -12,7 +12,6 @@ using CUDA.CUDACore:
     CUmemPoolHandle_st,
     CuArrayStyle,
     CUdevice_attribute_enum,
-    cudaError_enum,
     cu,
     TaskLocalState,
     task_local_state!,
@@ -26,6 +25,7 @@ using CUDA.CUDACore:
     is_capturing,
     capture_status,
     hasfieldcount
+using CUDA: CUDACore
 using CUDA: cuBLAS
 using CUDA: cuSPARSE
 using CUDA: cuSOLVER
@@ -288,7 +288,6 @@ function _register_cuda_opaque_types!()
         # CuStream contains Ptr/Bool/CuContext fields; without NoTangent, Mooncake
         # generates a MutableTangent that propagates into task-local CUDA state.
         (CuStream, false),
-        (cudaError_enum, false),
         # TaskLocalState bundles device index, stream handles, and library contexts.
         (TaskLocalState, false),
         # CuContext wraps an opaque Ptr{Cvoid} to the CUDA context.
@@ -315,13 +314,13 @@ _register_cuda_opaque_types!()
 # CUDA @cenum types are primitive types (integer-backed C enums) — never differentiable.
 # Mooncake's generic tangent_type @generated function errors on primitive types with no
 # registered method, so we register all of them here programmatically.
-# Covers: cuBLAS, cuSPARSE, cuSOLVER.
+# Covers: CUDACore, cuBLAS, cuSPARSE, cuSOLVER.
 # cuDNN enums are handled in MooncakeCUDNNExt (loaded only when cuDNN is available).
 # Filter: parentmodule(T) must be one of the CUDA family modules, to avoid accidentally
 # re-registering standard Julia primitive types (Bool, Int32, Float64, ...) that happen
 # to be visible in the CUDA namespace.
 function _register_cuda_enum_types!()
-    let _cuda_family = (cuBLAS, cuSPARSE, cuSOLVER)
+    let _cuda_family = (CUDACore, cuBLAS, cuSPARSE, cuSOLVER)
         _cenum_seen = Set{DataType}()
         for _mod in _cuda_family
             for _nm in names(_mod; all=true)
@@ -332,6 +331,7 @@ function _register_cuda_enum_types!()
                 end
                 _T isa DataType || continue
                 isprimitivetype(_T) || continue
+                _T <: CUDACore.CEnum.Cenum || continue
                 parentmodule(_T) in _cuda_family || continue
                 _T in _cenum_seen && continue
                 push!(_cenum_seen, _T)

?
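
For readers without a GPU setup, the filter chain in the diff above can be tried out in isolation. The following is only a sketch under stated assumptions: `ToyCEnum`, `ToyCore`, `toy_tangent_type`, `ToyNoTangent`, and `register_enum_types!` are hypothetical stand-ins invented for illustration, not Mooncake or CUDA.jl names. It mirrors the checks from the proposal (is a `DataType`, is primitive, subtypes the enum base type, is owned by a listed module) and shows which toy types end up registered.

```julia
# Hypothetical stand-in for the CEnum package's abstract enum type.
module ToyCEnum
    abstract type Cenum{T} <: Integer end
end

# Hypothetical stand-in for a CUDA family module.
module ToyCore
    using ..ToyCEnum: Cenum
    primitive type toyError_enum <: Cenum{UInt32} 32 end  # integer-backed enum
    primitive type toyHandle 64 end                       # primitive, but not an enum
    const AliasedBool = Bool                              # primitive, owned by Core
end

# Hypothetical stand-ins for a tangent-type lookup and its "no tangent" marker.
struct ToyNoTangent end
toy_tangent_type(::Type) = error("no tangent type registered")

function register_enum_types!(family)
    seen = Set{DataType}()
    for m in family, nm in names(m; all=true)
        isdefined(m, nm) || continue
        T = getfield(m, nm)
        T isa DataType || continue             # skip functions, modules, UnionAlls
        isprimitivetype(T) || continue         # C enums are integer-backed primitives
        T <: ToyCEnum.Cenum || continue        # only genuine enum types
        parentmodule(T) in family || continue  # skip types merely visible in the module
        T in seen && continue
        push!(seen, T)
        already = try
            toy_tangent_type(T) === ToyNoTangent
        catch
            false
        end
        already && continue                    # leave existing registrations alone
        @eval toy_tangent_type(::Type{$T}) = ToyNoTangent
    end
    return seen
end

registered = register_enum_types!((ToyCore,))
```

In this toy setup only `ToyCore.toyError_enum` passes every check: `toyHandle` is primitive but does not subtype the enum base type, and `AliasedBool` resolves to `Bool`, which neither subtypes it nor belongs to a listed module.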

@kshyatt (Collaborator, Author) commented May 15, 2026

Sure, let's try it!

@kshyatt (Collaborator, Author) commented May 15, 2026

Flux failure looks unrelated?

@sunxd3 (Collaborator) commented May 15, 2026

Yeah, it also appears in other PRs

@sunxd3 sunxd3 merged commit 11b46f4 into main May 15, 2026
88 of 90 checks passed
@sunxd3 sunxd3 deleted the ksh/cu_err branch May 15, 2026 13:36

@sunxd3 (Collaborator) commented May 15, 2026

Thanks, Katharine!
