24 changes: 23 additions & 1 deletion comfy_extras/nodes_train.py
@@ -1180,8 +1180,30 @@ def execute(
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
 
         with torch.inference_mode(False):
+            # ComfyUI loads the model in inference mode, which makes every
+            # parameter an inference-mode tensor. To make training work
+            # correctly, we rebuild the parameters here as regular tensors
+            # that autograd can track.
+            for module in mp.model.modules():
Contributor

The model is theoretically shareable among training and non-training elements in a workflow, so this change is global across all consumers of the single shared model.

Multiple ModelPatchers share the same model, so mp.model should be treated as effectively immutable.

The good news is that we recently made this kind of deep clone easy to do for a few other features.

Do you just need your own full copy of the model? Something like this might do it:

(venv) rattus@rattus-box2:~/ComfyUI$ git diff
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c9ad8727..858a7a47 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -324,10 +324,11 @@ class ModelPatcher:
     def get_clone_model_override(self):
         return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
 
-    def clone(self, disable_dynamic=False, model_override=None):
+    def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
         class_ = self.__class__
-        if self.is_dynamic() and disable_dynamic:
-            class_ = ModelPatcher
+        if self.is_dynamic() and disable_dynamic or force_deepcopy:
+            if self.is_dynamic() and disable_dynamic:
+                class_ = ModelPatcher
             if model_override is None:
                 if self.cached_patcher_init is None:
                     raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
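
A hypothetical usage sketch, assuming the patch above were applied (force_deepcopy is not an existing upstream flag, it only exists in the diff above):

mp = mp.clone(force_deepcopy=True)  # private deep copy; the parameter rebuild below would then only touch this copy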

+                for name, param in list(module._parameters.items()):
+                    if param is not None:
+                        try:
+                            _ = param._version
+                        except Exception:
+                            module._parameters[name] = torch.nn.Parameter(
+                                param.detach().clone(),
+                                requires_grad=param.requires_grad,
+                            )
+
+                for name, buf in list(module._buffers.items()):
+                    if buf is not None:
+                        try:
+                            _ = buf._version
+                        except Exception:
+                            module._buffers[name] = buf.detach().clone()
+
             # Setup models for training
-            mp.model.requires_grad_(False)
+            mp.model.requires_grad_(False).train()

             # Load existing LoRA weights if provided
             existing_weights, existing_steps = _load_existing_lora(existing_lora)
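
For reference, a minimal standalone sketch (hypothetical, not part of the PR) of why the rebuild loop above is needed: tensors created under torch.inference_mode() are inference tensors with no version counter, so accessing ._version raises (the probe the loop uses), and they cannot be made trainable; cloning them outside inference mode yields normal tensors that autograd can track.

import torch

with torch.inference_mode():
    w = torch.randn(4, 4)                 # an inference-mode tensor
print(w.is_inference())                   # True

try:
    _ = w._version                        # raises: inference tensors do not track a version counter
except Exception:
    print("inference tensor detected")    # same probe as the loop in the diff above

with torch.inference_mode(False):         # mirrors the PR's rebuild context
    p = torch.nn.Parameter(w.detach().clone(), requires_grad=True)

loss = (p * 2).sum()
loss.backward()                           # works; w itself cannot be made trainable outside inference mode
print(p.grad.sum())                       # tensor(32.)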