Empty file added comfy/ldm/nucleus/__init__.py
1,042 changes: 1,042 additions & 0 deletions comfy/ldm/nucleus/model.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions comfy/model_base.py
@@ -54,6 +54,7 @@
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.nucleus.model

import comfy.model_management
import comfy.patcher_extension
@@ -1771,6 +1772,22 @@ def extra_conds_shapes(self, **kwargs):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out


class NucleusImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.nucleus.model.NucleusMoEImageTransformer2DModel)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out


class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
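A note on the `extra_conds` hook in `NucleusImage` above: during sampling, ComfyUI's `BaseModel.apply_model` unpacks this dict back into keyword arguments for the transformer, with `c_crossattn` conventionally passed as `context`. A minimal sketch of that flow, assuming the usual convention (not code from this PR):

# Illustrative sketch, not part of this PR: how the dict built by
# extra_conds is typically consumed during sampling.
def apply_model_sketch(diffusion_model, x, t, **extra_conds):
    context = extra_conds.pop("c_crossattn", None)  # the cross_attn cond above
    # everything else (e.g. attention_mask) passes through as forward kwargs
    return diffusion_model(x, t, context=context, **extra_conds)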
7 changes: 7 additions & 0 deletions comfy/model_detection.py
@@ -663,6 +663,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config

if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys and ('{}transformer_blocks.3.moe_layer.gate.weight'.format(key_prefix) in state_dict_keys or '{}transformer_blocks.3.img_mlp.experts.gate_up_proj'.format(key_prefix) in state_dict_keys or '{}transformer_blocks.3.img_mlp.experts.gate_up_projs.0.weight'.format(key_prefix) in state_dict_keys): # Nucleus Image
dit_config = {}
dit_config["image_model"] = "nucleus_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config

if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
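The Nucleus detection branch above derives `num_layers` with `count_blocks`, which counts consecutively numbered `transformer_blocks.{i}.` key prefixes in the checkpoint. A minimal reimplementation sketch for illustration (the real helper lives in comfy/model_detection.py and may differ in detail):

# Illustrative sketch, not the actual comfy helper.
def count_blocks_sketch(state_dict_keys, prefix_pattern):
    i = 0
    while any(k.startswith(prefix_pattern.format(i)) for k in state_dict_keys):
        i += 1
    return i

keys = [f"transformer_blocks.{n}.attn.to_q.weight" for n in range(32)]
assert count_blocks_sketch(keys, "transformer_blocks.{}.") == 32  # num_layers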
16 changes: 16 additions & 0 deletions comfy/ops.py
@@ -948,6 +948,22 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True

# Auto-detect MoE layers: per-tensor FP8 input quantization causes
# catastrophic error in SwiGLU intermediates (gate*up product has
# high dynamic range). Force full precision for these layers.
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
_layer_path = f".{layer_name}."
_moe_patterns = (
".img_mlp.experts.gate_up_projs.",
".img_mlp.experts.down_projs.",
".img_mlp.shared_expert.",
".img_mlp.gate.",
)
if any(_pat in _layer_path for _pat in _moe_patterns):
self._full_precision_mm = True
self._full_precision_mm_config = True


if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")

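The comment in the hunk above can be made concrete: a single per-tensor FP8 scale is dominated by the largest activation, so small entries quantize coarsely, and the SwiGLU-style gate*up product compounds that error. A self-contained sketch of the effect, assuming PyTorch 2.1+ for the float8_e4m3fn dtype; the numbers are illustrative, not measurements from this PR:

import torch
import torch.nn.functional as F

# Illustrative sketch: per-tensor FP8 quantization of a wide-dynamic-range input.
x = torch.tensor([0.01, 1.0, 30.0])    # one outlier forces a large scale
scale = x.abs().max() / 448.0          # e4m3fn max normal value is ~448
x_q = (x / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale

ref = F.silu(x) * x                    # stand-in for the SwiGLU gate * up product
quant = F.silu(x_q) * x_q
# Relative error is largest on the small entries, which the product then compounds.
print((ref - quant).abs() / ref.abs().clamp_min(1e-12))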
10 changes: 8 additions & 2 deletions comfy/sd.py
@@ -52,6 +52,7 @@
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.nucleus_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
@@ -1189,6 +1190,7 @@ class CLIPType(Enum):
NEWBIE = 24
FLUX2 = 25
LONGCAT_IMAGE = 26
NUCLEUS_IMAGE = 27



@@ -1449,8 +1451,12 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.QWEN3_8B:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
if clip_type == CLIPType.NUCLEUS_IMAGE:
clip_target.clip = comfy.text_encoders.nucleus_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.nucleus_image.NucleusImageTokenizer
else:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
59 changes: 58 additions & 1 deletion comfy/supported_models.py
@@ -20,6 +20,7 @@
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.nucleus_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
@@ -1520,6 +1521,62 @@ def clip_target(self, state_dict={}):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))

class NucleusImage(supported_models_base.BASE):
unet_config = {
"image_model": "nucleus_image",
}

unet_extra_config = {
"in_channels": 16,
"out_channels": 16,
"patch_size": 2,
"attention_head_dim": 128,
"num_attention_heads": 16,
"num_key_value_heads": 4,
"joint_attention_dim": 4096,
"axes_dims_rope": [16, 56, 56],
"rope_theta": 10000,
"scale_rope": True,
"dense_moe_strategy": "leave_first_three_blocks_dense",
"num_experts": 64,
"moe_intermediate_dim": 1344,
"capacity_factors": [0, 0, 0, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
"use_sigmoid": False,
"route_scale": 2.5,
"use_grouped_mm": True,
}

sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}

memory_usage_factor = 2.0

latent_format = latent_formats.Wan21

supported_inference_dtypes = [torch.bfloat16, torch.float32]

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
out = model_base.NucleusImage(self, device=device)
return out

def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect))

def process_unet_state_dict(self, state_dict):
out_sd = {}
for k, v in state_dict.items():
key_out = k.replace(".moe_layer.", ".img_mlp.")
out_sd[key_out] = v
return out_sd


class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
@@ -1781,6 +1838,6 @@ def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, NucleusImage, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]

models += [SVD_img2vid]
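On the `NucleusImage` config above: `capacity_factors` is indexed per transformer block, and its three leading zeros line up with `dense_moe_strategy = "leave_first_three_blocks_dense"`. Assuming the common MoE convention capacity = ceil(capacity_factor * tokens / num_experts), which this PR does not spell out, the per-expert token budget works out as in this sketch:

import math

# Illustrative sketch of per-block expert capacity under the standard convention.
num_experts = 64
tokens = 4096  # e.g. a 128x128 latent grid at patch_size 2 -> 64*64 image tokens
for block, cf in enumerate([0, 0, 0, 4, 4, 2]):  # leading entries of capacity_factors
    if cf == 0:
        print(f"block {block}: dense MLP, no routing")
    else:
        print(f"block {block}: {math.ceil(cf * tokens / num_experts)} tokens per expert")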
97 changes: 97 additions & 0 deletions comfy/text_encoders/nucleus_image.py
@@ -0,0 +1,97 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch


class NucleusImageQwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(
tokenizer_path,
pad_with_end=False,
embedding_directory=embedding_directory,
embedding_size=4096,
embedding_key='qwen3_8b',
tokenizer_class=Qwen2Tokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_token=151643,
tokenizer_data=tokenizer_data,
)


class NucleusImageTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.qwen3_8b = NucleusImageQwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"

def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
llama_text = self.llama_template.format(text)
tokens = self.qwen3_8b.tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return {"qwen3_8b": tokens}

def untokenize(self, token_weight_pair):
return self.qwen3_8b.untokenize(token_weight_pair)

def state_dict(self):
return {}

def decode(self, token_ids, **kwargs):
return self.qwen3_8b.decode(token_ids, **kwargs)


class NucleusImageQwen3VLText(comfy.text_encoders.llama.Qwen3_8B):
def __init__(self, config_dict, dtype, device, operations):
config_dict = dict(config_dict)
config_dict.setdefault("max_position_embeddings", 262144)
config_dict.setdefault("rope_theta", 5000000.0)
super().__init__(config_dict, dtype, device, operations)


class NucleusImageQwen3_8BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-8, dtype=None, attention_mask=True, model_options={}):
super().__init__(
device=device,
layer=layer,
layer_idx=layer_idx,
textmodel_json_config={},
dtype=dtype,
special_tokens={"pad": 151643},
layer_norm_hidden_state=False,
model_class=NucleusImageQwen3VLText,
enable_attention_masks=attention_mask,
return_attention_masks=attention_mask,
model_options=model_options,
)


class NucleusImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(
device=device,
dtype=dtype,
name="qwen3_8b",
clip_model=NucleusImageQwen3_8BModel,
model_options=model_options,
)

def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
return out, pooled, extra


def te(dtype_llama=None, llama_quantization_metadata=None):
class NucleusImageTEModel_(NucleusImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return NucleusImageTEModel_
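A usage note on `NucleusImageTokenizer` above: `tokenize_with_weights` wraps the raw prompt in the fixed Qwen chat template before tokenizing, so the spatial-layout system prompt is always present. A small sketch, assuming the bundled qwen25_tokenizer files are available locally:

# Illustrative sketch of the string the Qwen2 tokenizer actually receives.
prompt = "a red cube to the left of a blue sphere"
tok = NucleusImageTokenizer()
llama_text = tok.llama_template.format(prompt)
# llama_text starts with the <|im_start|>system ... block and ends with
# "<|im_start|>assistant\n"; this full string is what gets tokenized.
tokens = tok.tokenize_with_weights(prompt)  # -> {"qwen3_8b": [...]}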
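And `te()` follows the codebase's usual factory pattern: it returns a subclass with `dtype_llama` and quantization metadata baked in, so the loader only supplies device and model options later. A hedged usage sketch:

import torch

# Illustrative sketch: te() returns a class, not an instance.
TE = te(dtype_llama=torch.bfloat16)
encoder = TE(device="cpu")  # NucleusImageTEModel with dtype forced to bfloat16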
2 changes: 1 addition & 1 deletion nodes.py
@@ -977,7 +977,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "nucleus_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),