Add Nucleus Image support #13471
Conversation
📝 Walkthrough

This PR adds a new Nucleus Image diffusion model: a MoE-based transformer with rotary positional embeddings, timestep projection, GQA cross-attention, and expert routing/FFN (SwiGLU) implementations. It integrates tokenizer and Qwen3-based text-encoder support, model detection and registration, UNet wiring (NucleusImage), quantization-aware expert weight loading, grouped-GEMM support, and unit tests covering expert weight packing/loading and mixed-precision behaviors.

🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
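For orientation, the following is a minimal sketch of the kind of SwiGLU expert FFN with top-k routing that the walkthrough describes. It is written from that description rather than from the PR's actual NucleusMoEImageTransformer2DModel code; the class names, expert count, and dimensions are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    """One expert: SwiGLU feed-forward (SiLU(gate) * up, then a down-projection)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)  # packed gate+up weights
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

class TopKMoE(nn.Module):
    """Token-level top-k routing over a set of SwiGLU experts."""
    def __init__(self, dim, hidden_dim, num_experts=8, top_k=2):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(SwiGLUExpert(dim, hidden_dim) for _ in range(num_experts))
        self.top_k = top_k

    def forward(self, x):
        # x: (tokens, dim); each token is dispatched to its top-k experts.
        weights, indices = torch.topk(self.gate(x).softmax(dim=-1), self.top_k, dim=-1)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            token_idx, slot = (indices == e).nonzero(as_tuple=True)
            if token_idx.numel():
                out[token_idx] += weights[token_idx, slot, None] * expert(x[token_idx])
        return out

# Example: route 16 tokens of width 64 through 8 experts, 2 experts per token.
tokens = torch.randn(16, 64)
print(TopKMoE(dim=64, hidden_dim=128)(tokens).shape)  # torch.Size([16, 64])
```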
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/ldm/nucleus/model.py`:
- Around lines 125-134: Guard against requesting more RoPE rows than exist. When computing txt_freqs from self.pos_freqs using max_vid_index and max_txt_seq_len_int, first check that max_vid_index + max_txt_seq_len_int <= self.pos_freqs.size(0). If the request would exceed the available rows, either expand the frequency table (e.g., grow/resize self.pos_freqs on demand) or raise a clear ValueError stating that the prompt is too long and what the maximum supported sequence length is. Update the logic around max_vid_index/max_txt_seq_len_int and the extraction of txt_freqs (and any callers such as _apply_rotary_emb_nucleus) so the check runs before slicing and txt_freqs can never come back shorter than max_txt_seq_len_int.
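A minimal sketch of that guard, written as a standalone helper; pos_freqs, max_vid_index, and max_txt_seq_len_int are the names from the comment, and the exact integration point in model.py may differ. Growing the frequency table on demand is the alternative the comment mentions; raising keeps the failure explicit.

```python
import torch

def slice_txt_freqs(pos_freqs: torch.Tensor, max_vid_index: int, max_txt_seq_len_int: int) -> torch.Tensor:
    """Return the text RoPE rows, raising instead of silently returning a short slice."""
    required_rows = max_vid_index + max_txt_seq_len_int
    available_rows = pos_freqs.size(0)
    if required_rows > available_rows:
        raise ValueError(
            f"Prompt needs {required_rows} RoPE rows but only {available_rows} are precomputed; "
            f"the maximum supported text length is {available_rows - max_vid_index} tokens."
        )
    return pos_freqs[max_vid_index:required_rows]
```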
- Around lines 918-926: encoder_hidden_states_mask may already be a float 0/1 mask (cast earlier by BaseModel) and therefore skips the boolean conversion. Detect when encoder_hidden_states_mask is floating-point and contains only 0/1 values, and convert it to an additive mask (attendable -> 0, masked -> large negative) before placing it into block_attention_kwargs. Use x.dtype for the returned tensor and a large magnitude from torch.finfo(x.dtype).max (the same pattern as the existing integer/bool conversion), then set block_attention_kwargs["attention_mask"] = encoder_hidden_states_mask.
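A sketch of the suggested conversion as a small helper; it assumes the same additive-mask convention as the existing integer/bool branch (0 for attendable positions, a large negative value for masked ones), with the caller then assigning the result to block_attention_kwargs["attention_mask"].

```python
import torch

def to_additive_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a float 0/1 keep-mask into an additive bias: keep -> 0, drop -> large negative."""
    if mask.dtype.is_floating_point and torch.all((mask == 0) | (mask == 1)):
        return (1.0 - mask.to(dtype)) * -torch.finfo(dtype).max
    return mask  # already additive, or handled by the existing integer/bool conversion path
```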
In `@comfy/ops.py`:
- Around lines 954-965: The MoE path matching in the quant_format check uses leading-dot substrings (_moe_patterns) against the raw layer_name, which misses root-level names like "img_mlp.gate" and can overmatch siblings. Normalize layer_name by surrounding it with separators (e.g., prefix and suffix it with dots, or tokenize it) and change the pattern list to bounded tokens (for example ".img_mlp.experts.gate_up_projs.", ".img_mlp.experts.down_projs.", ".img_mlp.shared_expert.", ".img_mlp.gate.") so membership checks run against the normalized layer_name. Update the loop that sets self._full_precision_mm and self._full_precision_mm_config when a bounded match is found (symbols: _moe_patterns, layer_name, _full_precision_mm, _full_precision_mm_config, quant_format).
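A sketch of the bounded-token matching described above; the pattern list mirrors the comment's examples, while the helper name and the example layer names are hypothetical. The loop in ops.py would then set self._full_precision_mm and self._full_precision_mm_config only when this check returns true for a quantized format.

```python
_MOE_PATTERNS = (
    ".img_mlp.experts.gate_up_projs.",
    ".img_mlp.experts.down_projs.",
    ".img_mlp.shared_expert.",
    ".img_mlp.gate.",
)

def is_moe_layer(layer_name: str) -> bool:
    """Wrap the layer name in dots so every pattern match is anchored at component boundaries."""
    normalized = "." + layer_name + "."
    return any(pattern in normalized for pattern in _MOE_PATTERNS)

assert is_moe_layer("img_mlp.gate.weight")                                        # root-level name matches
assert is_moe_layer("transformer_blocks.3.img_mlp.experts.down_projs.0.weight")   # nested name matches
assert not is_moe_layer("img_mlp.gateway.weight")                                 # similarly named sibling does not
```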
In `@comfy/supported_models.py`:
- Around lines 1572-1573: process_unet_state_dict currently returns the state_dict unchanged, so checkpoints with keys like transformer_blocks.<i>.moe_layer.gate.weight are mis-detected by detect_unet_config. Update process_unet_state_dict to normalize MoE-related key names before returning, mapping any transformer_blocks.*.moe_layer.* keys to the new transformer_blocks.*.img_mlp.* names (e.g., replace ".moe_layer." with ".img_mlp." for gate, gate.weight, gate.bias, etc.), operating only on UNet keys so downstream loading finds the expected transformer_blocks.*.img_mlp.gate.weight entries.
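A minimal sketch of that normalization, written as a plain function (in the codebase the hook is a method on the supported-model class); it assumes only the ".moe_layer." -> ".img_mlp." rename is needed, per the comment.

```python
def process_unet_state_dict(state_dict):
    """Rename diffusers-style MoE keys to the names detect_unet_config and the loader expect."""
    out = {}
    for key, value in state_dict.items():
        # e.g. transformer_blocks.0.moe_layer.gate.weight -> transformer_blocks.0.img_mlp.gate.weight
        if key.startswith("transformer_blocks.") and ".moe_layer." in key:
            key = key.replace(".moe_layer.", ".img_mlp.")
        out[key] = value
    return out
```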
📒 Files selected for processing (10)
- comfy/ldm/nucleus/__init__.py
- comfy/ldm/nucleus/model.py
- comfy/model_base.py
- comfy/model_detection.py
- comfy/ops.py
- comfy/sd.py
- comfy/supported_models.py
- comfy/text_encoders/nucleus_image.py
- nodes.py
- tests-unit/comfy_test/model_detection_test.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/ldm/nucleus/model.py`:
- Around lines 1035-1042: The final reshape silently assumes in_channels == out_channels and will fail with a shape error otherwise; change the reconstruction to use the model's output channel count instead of orig_shape[1]. Locate the block around hidden_states = hidden_states[:, :num_embeds].view(...) and the subsequent hidden_states.reshape(orig_shape) return in NucleusMoEImageTransformer2DModel (and any code that consumes proj_out), and replace uses of orig_shape[1] for the channel dimension with self.out_channels (or compute channels = patch_size*patch_size*self.out_channels) so the view/reshape match the proj_out feature width and no longer rely on in_channels equaling out_channels.
⛔ Files ignored due to path filters (1)
- script_examples/convert_nucleus_bf16_to_packed_fp8.py is excluded by !script_examples/**
📒 Files selected for processing (5)
- comfy/ldm/nucleus/model.py
- comfy/ops.py
- comfy/supported_models.py
- tests-unit/comfy_quant/test_mixed_precision.py
- tests-unit/comfy_test/model_detection_test.py
🚧 Files skipped from review as they are similar to previous changes (3)
- comfy/ops.py
- comfy/supported_models.py
- tests-unit/comfy_test/model_detection_test.py
```python
hidden_states = hidden_states[:, :num_embeds].view(
    orig_shape[0], orig_shape[-3],
    orig_shape[-2] // 2, orig_shape[-1] // 2,
    orig_shape[1], 2, 2,
)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
# Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
return -hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
```
Output reshape silently assumes in_channels == out_channels.
The final unpack uses orig_shape[1] (the input channel count) as the channel dim of the reconstructed tensor, but proj_out emits patch_size * patch_size * self.out_channels features per token. The view therefore only lines up when in_channels == out_channels; any other configuration will blow up with a shape error at this .view(...), and the __init__ defaults happen to advertise a mismatch (in_channels=64, out_channels=16).
In practice model_base.NucleusImage wires matching values so real runs are fine, but the defaults in the constructor signature are misleading and a direct NucleusMoEImageTransformer2DModel() instantiation would crash here rather than at __init__. Either align the defaults or reshape using self.out_channels:
🩹 Suggested tweak
```diff
- hidden_states = hidden_states[:, :num_embeds].view(
-     orig_shape[0], orig_shape[-3],
-     orig_shape[-2] // 2, orig_shape[-1] // 2,
-     orig_shape[1], 2, 2,
- )
- hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
- # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
- return -hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
+ hidden_states = hidden_states[:, :num_embeds].view(
+     orig_shape[0], orig_shape[-3],
+     orig_shape[-2] // 2, orig_shape[-1] // 2,
+     self.out_channels, 2, 2,
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
+ out_shape = (orig_shape[0], self.out_channels, *orig_shape[2:])
+ # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
+ return -hidden_states.reshape(out_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@comfy/ldm/nucleus/model.py` around lines 1035-1042: the final reshape assumes in_channels == out_channels and fails with a shape error otherwise; change the reconstruction to use the model's output channel count instead of orig_shape[1]. Locate the block around hidden_states = hidden_states[:, :num_embeds].view(...) and the subsequent hidden_states.reshape(orig_shape) return in NucleusMoEImageTransformer2DModel (and any code that consumes proj_out), and replace uses of orig_shape[1] for the channel dimension with self.out_channels (or compute channels = patch_size*patch_size*self.out_channels) so the view/reshape match the proj_out feature width and do not rely on in_channels equaling out_channels.
Quantized models for this are here: https://huggingface.co/e-n-v-y/Nucleus-Image-FP8-e4m3fn-scaled/tree/main |
Unrelated to the actual PR, but I merged your code locally and tested out the model. I can't for the life of me get anything approaching the example images on their HF. |
Summary
- Adds `script_examples/convert_nucleus_bf16_to_packed_fp8.py` for creating packed scaled-FP8 Nucleus diffusion checkpoints from BF16 sources

Validation
- `/home/bart/.conda/envs/comfyui/bin/python -m py_compile comfy/ldm/nucleus/model.py comfy/text_encoders/nucleus_image.py comfy/model_base.py comfy/model_detection.py comfy/ops.py comfy/sd.py comfy/supported_models.py nodes.py tests-unit/comfy_test/model_detection_test.py`
- `/home/bart/.conda/envs/comfyui/bin/python -m py_compile script_examples/convert_nucleus_bf16_to_packed_fp8.py`
- `tests-unit/comfy_test/model_detection_test.py` run directly with the conda Python:
  - `test_nucleus_diffusers_expert_weights_stay_packed_for_grouped_mm`
  - `test_nucleus_swiglu_experts_loads_packed_weights`
  - `test_nucleus_swiglu_experts_loads_packed_quantized_weights`
  - `test_nucleus_split_expert_weights_still_load_for_quantized_files`
  - `test_nucleus_dense_swiglu_uses_diffusers_chunk_order`
- Ran `script_examples/convert_nucleus_bf16_to_packed_fp8.py`, verifying packed expert keys, FP8 dtypes, scales, and `comfy_quant` metadata

Note:
`pytest` is not installed in this local conda environment, so the focused tests were invoked directly.
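For readers unfamiliar with the packed scaled-FP8 format referenced above, here is a minimal, hypothetical sketch of per-tensor scaled FP8 (e4m3fn) quantization. It is not the actual `script_examples/convert_nucleus_bf16_to_packed_fp8.py` (which is excluded from this review by path filters); it only illustrates the general idea of storing an FP8 weight plus a scale in the checkpoint.

```python
import torch

def quantize_to_scaled_fp8(weight: torch.Tensor):
    """Illustrative per-tensor scaled FP8 quantization of a BF16 weight."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = weight.abs().amax().float() / fp8_max                      # per-tensor scale
    quantized = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return quantized, scale                                            # both go into the checkpoint

w = torch.randn(4096, 2048, dtype=torch.bfloat16)
q, s = quantize_to_scaled_fp8(w)
dequantized = q.float() * s   # approximate reconstruction used at load/compute time
```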