
Add Nucleus Image support #13471

Open
envy-ai wants to merge 5 commits into Comfy-Org:master from envy-ai:codex/nucleus-image-support

Conversation


@envy-ai envy-ai commented Apr 19, 2026

Summary

  • add native Nucleus-Image model support, including Nucleus MoE transformer blocks and Qwen3-VL text encoder wiring
  • preserve packed Nucleus expert tensors for grouped-mm while supporting split expert checkpoints
  • support packed FP8 expert quant metadata and low-VRAM/DynamicVRAM weight casting
  • register Nucleus model detection, model config, CLIP loader type, and focused tests
  • add script_examples/convert_nucleus_bf16_to_packed_fp8.py for creating packed scaled-FP8 Nucleus diffusion checkpoints from BF16 sources

Validation

  • /home/bart/.conda/envs/comfyui/bin/python -m py_compile comfy/ldm/nucleus/model.py comfy/text_encoders/nucleus_image.py comfy/model_base.py comfy/model_detection.py comfy/ops.py comfy/sd.py comfy/supported_models.py nodes.py tests-unit/comfy_test/model_detection_test.py
  • /home/bart/.conda/envs/comfyui/bin/python -m py_compile script_examples/convert_nucleus_bf16_to_packed_fp8.py
  • focused Nucleus tests from tests-unit/comfy_test/model_detection_test.py run directly with the conda Python:
    • test_nucleus_diffusers_expert_weights_stay_packed_for_grouped_mm
    • test_nucleus_swiglu_experts_loads_packed_weights
    • test_nucleus_swiglu_experts_loads_packed_quantized_weights
    • test_nucleus_split_expert_weights_still_load_for_quantized_files
    • test_nucleus_dense_swiglu_uses_diffusers_chunk_order
  • CUDA low-VRAM grouped-mm smoke test with quantized packed experts
  • synthetic safetensors conversion smoke test with script_examples/convert_nucleus_bf16_to_packed_fp8.py, verifying packed expert keys, FP8 dtypes, scales, and comfy_quant metadata

Note: pytest is not installed in this local conda environment, so the focused tests were invoked directly.

@envy-ai envy-ai marked this pull request as ready for review April 19, 2026 01:59

coderabbitai Bot commented Apr 19, 2026

📝 Walkthrough

Walkthrough

This PR adds a new Nucleus Image diffusion model: a MoE-based transformer with rotary positional embeddings, timestep projection, GQA cross-attention, and expert routing/FFN (SwiGLU) implementations. It integrates tokenizer and Qwen3-based text-encoder support, model detection and registration, UNet wiring (NucleusImage), quantization-aware expert weight loading, grouped-GEMM support, and unit tests covering expert weight packing/loading and mixed-precision behaviors.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 8.54%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title 'Add Nucleus Image support' clearly and concisely describes the main change: introducing native Nucleus Image model support to the codebase.
  • Description check ✅ Passed: the description details the changes, including the Nucleus MoE transformer blocks, Qwen3-VL text encoder support, expert quantization handling, model detection registration, and the validation steps performed.





@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@comfy/ldm/nucleus/model.py`:
- Around line 125-134: Guard against requesting more RoPE rows than exist:
ensure that when computing txt_freqs from self.pos_freqs in the code that uses
max_vid_index and max_txt_seq_len_int you check that max_vid_index +
max_txt_seq_len_int <= self.pos_freqs.size(0); if it would exceed the available
rows, either expand the frequency table (e.g., grow/resize self.pos_freqs on
demand) or raise a clear ValueError indicating the prompt is too long and
stating the max supported sequence length. Update the logic around
max_vid_index/max_txt_seq_len_int and the extraction of txt_freqs (and any
callers like _apply_rotary_emb_nucleus) to perform this check before slicing so
you never produce a shorter txt_freqs than max_txt_seq_len_int.
- Around line 918-926: encoder_hidden_states_mask may already be a float 0/1
mask (cast earlier by BaseModel) and therefore skips the boolean conversion;
detect when encoder_hidden_states_mask is floating and contains only 0/1 values
and convert it to an additive mask (attendable -> 0, masked -> large negative)
before placing it into block_attention_kwargs; use x.dtype for the returned
tensor dtype and a large magnitude from torch.finfo(x.dtype).max (same pattern
as the existing integer/bool conversion) and then set
block_attention_kwargs["attention_mask"] = encoder_hidden_states_mask.
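
A minimal sketch of the bounds guard the first model.py comment asks for; the standalone helper name and signature are illustrative, not the actual ComfyUI code:

```python
import torch

def slice_txt_freqs(pos_freqs: torch.Tensor, max_vid_index: int,
                    max_txt_seq_len: int) -> torch.Tensor:
    # Fail loudly instead of silently returning fewer RoPE rows than requested.
    needed = max_vid_index + max_txt_seq_len
    if needed > pos_freqs.size(0):
        raise ValueError(
            f"Prompt too long: {needed} RoPE rows requested but the table has "
            f"{pos_freqs.size(0)}; max supported text length is "
            f"{pos_freqs.size(0) - max_vid_index} tokens."
        )
    return pos_freqs[max_vid_index:needed]
```

Growing the table on demand is the alternative resolution; the explicit error is the simpler sketch.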
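
The second comment's float-0/1 mask conversion can be sketched as a standalone helper (name and signature illustrative; the real code would branch inline before filling block_attention_kwargs):

```python
import torch

def to_additive_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Accept bool, integer, or float 0/1 "keep" masks uniformly.
    if torch.is_floating_point(mask):
        keep = mask > 0.5
    else:
        keep = mask.bool()
    # Attendable positions -> 0, masked positions -> large negative, in x.dtype.
    additive = torch.zeros(mask.shape, dtype=dtype)
    additive[~keep] = -torch.finfo(dtype).max
    return additive
```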

In `@comfy/ops.py`:
- Around line 954-965: The MoE path matching in the block that checks
quant_format uses leading-dot substrings (_moe_patterns) and raw layer_name,
which misses root-level names like "img_mlp.gate" and can overmatch siblings;
normalize layer_name by surrounding it with separators (e.g., prefix and suffix
with dots or use a tokenizer) and change the pattern list to bounded tokens (for
example ".img_mlp.experts.gate_up_projs.", ".img_mlp.experts.down_projs.",
".img_mlp.shared_expert.", ".img_mlp.gate.") so you perform membership checks
against the normalized layer_name; update the loop that sets
self._full_precision_mm and self._full_precision_mm_config when a bounded match
is found (symbols: _moe_patterns, layer_name, _full_precision_mm,
_full_precision_mm_config, quant_format).
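
One way to implement the bounded-token matching this comment suggests (a sketch; `is_moe_layer` is an illustrative name, not an actual ops.py symbol):

```python
# Bounded tokens: each pattern starts and ends at a path-segment boundary.
_MOE_PATTERNS = (
    ".img_mlp.experts.gate_up_projs.",
    ".img_mlp.experts.down_projs.",
    ".img_mlp.shared_expert.",
    ".img_mlp.gate.",
)

def is_moe_layer(layer_name: str) -> bool:
    # Wrap the name in separators so root-level names like "img_mlp.gate"
    # match, while sibling names like "my_img_mlp.gate" do not.
    normalized = f".{layer_name}."
    return any(pattern in normalized for pattern in _MOE_PATTERNS)
```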

In `@comfy/supported_models.py`:
- Around line 1572-1573: process_unet_state_dict currently returns the
state_dict unchanged, causing checkpoints with keys like
transformer_blocks.<i>.moe_layer.gate.weight to be mis-detected by
detect_unet_config; update process_unet_state_dict to normalize MoE-related key
names before returning by mapping any transformer_blocks.*.moe_layer.* keys to
the new transformer_blocks.*.img_mlp.* names (e.g., replace ".moe_layer." with
".img_mlp." for gate, gate.weight, gate.bias, etc.), only operating on UNet keys
so downstream loading finds the expected
transformer_blocks.*.img_mlp.gate.weight entries.
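
The requested normalization could look like this sketch (the function name is illustrative; the real hook is `process_unet_state_dict`):

```python
def normalize_moe_keys(state_dict: dict) -> dict:
    """Map legacy transformer_blocks.*.moe_layer.* keys to the
    transformer_blocks.*.img_mlp.* names detect_unet_config expects."""
    out = {}
    for key, value in state_dict.items():
        # Only touch UNet transformer-block keys; leave everything else alone.
        if key.startswith("transformer_blocks.") and ".moe_layer." in key:
            key = key.replace(".moe_layer.", ".img_mlp.")
        out[key] = value
    return out
```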

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b103fd84-2c38-40af-a02b-b0176508d5d9

📥 Commits

Reviewing files that changed from the base of the PR and between 3086026 and 6c82e9b.

📒 Files selected for processing (10)
  • comfy/ldm/nucleus/__init__.py
  • comfy/ldm/nucleus/model.py
  • comfy/model_base.py
  • comfy/model_detection.py
  • comfy/ops.py
  • comfy/sd.py
  • comfy/supported_models.py
  • comfy/text_encoders/nucleus_image.py
  • nodes.py
  • tests-unit/comfy_test/model_detection_test.py

Outdated comment threads: comfy/ldm/nucleus/model.py (2), comfy/ops.py, comfy/supported_models.py
@envy-ai envy-ai marked this pull request as draft April 19, 2026 02:04
@envy-ai envy-ai marked this pull request as ready for review April 19, 2026 02:07
@envy-ai envy-ai marked this pull request as draft April 19, 2026 02:16
@envy-ai envy-ai marked this pull request as ready for review April 19, 2026 02:17

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@comfy/ldm/nucleus/model.py`:
- Around line 1035-1042: The final reshape assumes in_channels == out_channels
causing a silent shape error; change the reconstruction to use the model's
output channel count instead of orig_shape[1]. Locate the block around
hidden_states = hidden_states[:, :num_embeds].view(...) and the subsequent
hidden_states.reshape(orig_shape) return in NucleusMoEImageTransformer2DModel
(and any code that consumes proj_out), and replace uses of orig_shape[1] for the
channel dimension with self.out_channels (or compute channels =
patch_size*patch_size*self.out_channels) so the view/reshape match the proj_out
feature width and do not rely on in_channels equaling out_channels.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2cf3e7f3-6ed4-42dd-a0f6-f84e7fd7d634

📥 Commits

Reviewing files that changed from the base of the PR and between 6c82e9b and be403d6.

⛔ Files ignored due to path filters (1)
  • script_examples/convert_nucleus_bf16_to_packed_fp8.py is excluded by !script_examples/**
📒 Files selected for processing (5)
  • comfy/ldm/nucleus/model.py
  • comfy/ops.py
  • comfy/supported_models.py
  • tests-unit/comfy_quant/test_mixed_precision.py
  • tests-unit/comfy_test/model_detection_test.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • comfy/ops.py
  • comfy/supported_models.py
  • tests-unit/comfy_test/model_detection_test.py

Comment on lines +1035 to +1042
        hidden_states = hidden_states[:, :num_embeds].view(
            orig_shape[0], orig_shape[-3],
            orig_shape[-2] // 2, orig_shape[-1] // 2,
            orig_shape[1], 2, 2,
        )
        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
        # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
        return -hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

⚠️ Potential issue | 🟡 Minor

Output reshape silently assumes in_channels == out_channels.

The final unpack uses orig_shape[1] (the input channel count) as the channel dim of the reconstructed tensor, but proj_out emits patch_size * patch_size * self.out_channels features per token. The view therefore only lines up when in_channels == out_channels; any other configuration will blow up with a shape error at this .view(...), and the __init__ defaults happen to advertise a mismatch (in_channels=64, out_channels=16).

In practice model_base.NucleusImage wires matching values so real runs are fine, but the defaults in the constructor signature are misleading and a direct NucleusMoEImageTransformer2DModel() instantiation would crash here rather than at __init__. Either align the defaults or reshape using self.out_channels:

🩹 Suggested tweak
-        hidden_states = hidden_states[:, :num_embeds].view(
-            orig_shape[0], orig_shape[-3],
-            orig_shape[-2] // 2, orig_shape[-1] // 2,
-            orig_shape[1], 2, 2,
-        )
-        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
-        # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
-        return -hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
+        hidden_states = hidden_states[:, :num_embeds].view(
+            orig_shape[0], orig_shape[-3],
+            orig_shape[-2] // 2, orig_shape[-1] // 2,
+            self.out_channels, 2, 2,
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
+        out_shape = (orig_shape[0], self.out_channels, *orig_shape[2:])
+        # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
+        return -hidden_states.reshape(out_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]


envy-ai commented Apr 19, 2026

Quantized models for this are here:

https://huggingface.co/e-n-v-y/Nucleus-Image-FP8-e4m3fn-scaled/tree/main

@jtreminio

Unrelated to the actual PR, but I merged your code locally and tested out the model. I can't for the life of me get anything approaching the example images on their HF.

