Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions comfy/ldm/cosmos/predict2.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,14 @@ def _forward(
**kwargs,
):
orig_shape = list(x.shape)

ref_latents = kwargs.get('ref_latents', None)
if ref_latents is not None:
for ref in ref_latents:
if ref.ndim == 4:
ref = ref.unsqueeze(2)
x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2)

x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
x_B_C_T_H_W = x
timesteps_B_T = timesteps
Expand Down
15 changes: 15 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,7 @@ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
class Anima(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
self.memory_usage_factor_conds = ("ref_latents",)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
Expand All @@ -1221,6 +1222,20 @@ def extra_conds(self, **kwargs):
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return out

def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = [1, 16, sum(math.prod(lat.size()[2:]) for lat in ref_latents)]
return out

class Lumina2(BaseModel):
Expand Down