Skip to content

Add LoRA finetuning support for Contraction#1453

Open
hyjwpk wants to merge 1 commit into
ACEsuit:developfrom
hyjwpk:main
Open

Add LoRA finetuning support for Contraction#1453
hyjwpk wants to merge 1 commit into
ACEsuit:developfrom
hyjwpk:main

Conversation

@hyjwpk

@hyjwpk hyjwpk commented Apr 28, 2026

Copy link
Copy Markdown

Summary

Extends the existing LoRA implementation to cover Contraction layers inside SymmetricContraction.

Changes:

  • Add LoRAContraction in mace/modules/lora.py: applies a low-rank update to the highest-order correlation weights
    (weights_max), with inference-time delta caching and __getattr__ forwarding for compatibility with model inspection utilities
  • Update inject_lora() with a new wrap_contraction flag (default True)

Related

Related to #1450

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends MACE’s LoRA fine-tuning utilities to also wrap Contraction layers (the core computation inside SymmetricContraction), enabling low-rank adaptation of the highest-order correlation weights and ensuring merge/unwrapping logic accounts for the new wrapper.

Changes:

  • Add LoRAContraction wrapper that applies a low-rank update to Contraction.weights_max, including inference-time delta caching and attribute forwarding.
  • Extend inject_lora() with a wrap_contraction flag (default True) and update wrapper detection/merging to include LoRAContraction.
  • Update LoRA merge wrapper-removal test to recognize the new wrapper type.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
mace/modules/lora.py Introduces LoRAContraction and extends LoRA injection/merge logic to handle Contraction layers.
tests/test_lora.py Updates wrapper counting in merge tests to include LoRAContraction.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mace/modules/lora.py
Comment on lines 364 to 372
def inject_lora(
module: nn.Module,
rank: int = 4,
alpha: float = 1.0,
wrap_equivariant: bool = True,
wrap_dense: bool = True,
wrap_contraction: bool = True,
_is_root: bool = True,
) -> None:
Comment thread mace/modules/lora.py
Comment on lines +332 to +356
effective = self.base.weights_max + self.scaling * delta

out = self.base.graph_opt_main(
self.base.U_tensors(self.base.correlation),
effective,
x,
y,
)
for i, (weight, contract_weights, contract_features) in enumerate(
zip(
self.base.weights,
self.base.contractions_weighting,
self.base.contractions_features,
)
):
c_tensor = contract_weights(
self.base.U_tensors(self.base.correlation - i - 1),
weight,
y,
)
c_tensor = c_tensor + out
out = contract_features(c_tensor, x)

return out.view(out.shape[0], -1)

Comment thread mace/modules/lora.py

def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1):
inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True)
inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True, wrap_contraction=True)
Comment thread tests/test_lora.py
count = 0
for child in module.modules():
if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)):
if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear, LoRAContraction)):
Comment thread tests/test_lora.py
Comment on lines 264 to 269
# Count LoRA wrappers before merge
def count_lora_wrappers(module):
count = 0
for child in module.modules():
if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)):
if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear, LoRAContraction)):
count += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants