Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions opacus/accountants/accountant.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ def get_epsilon(self, delta: float, *args, **kwargs) -> float:
"""
pass

@abc.abstractmethod
def __len__(self) -> int:
"""
Number of optimization steps taken so far
"""
pass

@classmethod
@abc.abstractmethod
def mechanism(cls) -> str:
Expand All @@ -66,6 +59,13 @@ def mechanism(cls) -> str:
"""
pass


def __len__(self) -> int:
"""
Number of optimization steps taken so far
"""
return sum(num_steps for _, _, num_steps in self.history)

def get_optimizer_hook_fn(
self, sample_rate: float
) -> Callable[[DPOptimizer], None]:
Expand Down
3 changes: 0 additions & 3 deletions opacus/accountants/gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ def get_epsilon(self, delta: float, poisson: bool = True, **kwargs) -> float:
delta=delta,
)

def __len__(self):
return len(self.history)

@classmethod
def mechanism(cls) -> str:
return "gdp"
3 changes: 0 additions & 3 deletions opacus/accountants/prv.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,3 @@ def _get_domain(
@classmethod
def mechanism(cls) -> str:
return "prv"

def __len__(self):
return len(self.history)
3 changes: 0 additions & 3 deletions opacus/accountants/rdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def get_epsilon(
eps, _ = self.get_privacy_spent(delta=delta, alphas=alphas)
return eps

def __len__(self):
return len(self.history)

@classmethod
def mechanism(cls) -> str:
return "rdp"
24 changes: 24 additions & 0 deletions opacus/tests/accountants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,30 @@ def test_prv_accountant(self) -> None:
epsilon = accountant.get_epsilon(delta=1e-5)
self.assertAlmostEqual(epsilon, 6.777395712150674)

def test_len_counts_optimization_steps(self) -> None:
noise_multiplier = 1.5
sample_rate = 0.04
steps = 50

for accountant in (RDPAccountant(), GaussianAccountant(), PRVAccountant()):
for _ in range(steps):
accountant.step(
noise_multiplier=noise_multiplier, sample_rate=sample_rate
)
self.assertEqual(len(accountant.history), 1)
self.assertEqual(len(accountant), steps)

for accountant in (RDPAccountant(), PRVAccountant()):
for _ in range(steps):
accountant.step(
noise_multiplier=noise_multiplier, sample_rate=sample_rate
)
accountant.step(
noise_multiplier=noise_multiplier * 2, sample_rate=sample_rate
)
self.assertEqual(len(accountant.history), 2)
self.assertEqual(len(accountant), steps + 1)

def test_get_noise_multiplier_rdp_epochs(self) -> None:
delta = 1e-5
sample_rate = 0.04
Expand Down