Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private async Task<bool> ShouldCreateDefaultCollectionAsync(AutomaticallyConfirm
!string.IsNullOrWhiteSpace(request.DefaultUserCollectionName)
&& request.Organization!.UseMyItems
&& (await policyRequirementQuery.GetAsync<OrganizationDataOwnershipPolicyRequirement>(request.OrganizationUser!.UserId!.Value))
.RequiresDefaultCollectionOnConfirm(request.Organization!.Id);
.GetDefaultCollectionRequestOnConfirm(request.Organization!.Id).ShouldCreateDefaultCollection;

private async Task PushSyncOrganizationKeysAsync(AutomaticallyConfirmOrganizationUserValidationRequest request)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public async Task<OrganizationUser> ConfirmUserAsync(Guid organizationId, Guid o
throw new BadRequestException(error);
}

await CreateDefaultCollectionAsync(orgUser, organization, defaultUserCollectionName);
await CreateManyDefaultCollectionsAsync(organization, [orgUser], defaultUserCollectionName);

return orgUser;
}
Expand All @@ -109,14 +109,7 @@ public async Task<List<Tuple<OrganizationUser, string>>> ConfirmUsersAsync(Guid
.Select(r => r.Item1)
.ToList();

if (confirmedOrganizationUsers.Count == 1)
{
await CreateDefaultCollectionAsync(confirmedOrganizationUsers.Single(), organization, defaultUserCollectionName);
}
else if (confirmedOrganizationUsers.Count > 1)
{
await CreateManyDefaultCollectionsAsync(organization, confirmedOrganizationUsers, defaultUserCollectionName);
}
await CreateManyDefaultCollectionsAsync(organization, confirmedOrganizationUsers, defaultUserCollectionName);

return result;
}
Expand Down Expand Up @@ -278,38 +271,6 @@ private async Task<IEnumerable<string>> GetUserDeviceIdsAsync(Guid userId)
.Select(d => d.Id.ToString());
}

/// <summary>
/// Creates a default collection for a single user if required by the Organization Data Ownership policy.
/// </summary>
/// <param name="organizationUser">The organization user who has just been confirmed.</param>
/// <param name="organization">The organization.</param>
/// <param name="defaultUserCollectionName">The encrypted default user collection name.</param>
private async Task CreateDefaultCollectionAsync(OrganizationUser organizationUser, Organization organization, string defaultUserCollectionName)
{
// Skip if no collection name provided (backwards compatibility)
if (string.IsNullOrWhiteSpace(defaultUserCollectionName))
{
return;
}

// Skip if organization has disabled My Items
if (!organization.UseMyItems)
{
return;
}

var organizationDataOwnershipPolicy = await _policyRequirementQuery.GetAsync<OrganizationDataOwnershipPolicyRequirement>(organizationUser.UserId!.Value);
if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId))
{
return;
}

await _collectionRepository.CreateDefaultCollectionsAsync(
organizationUser.OrganizationId,
[organizationUser.Id],
defaultUserCollectionName);
}

/// <summary>
/// Creates default collections for multiple users if required by the Organization Data Ownership policy.
/// </summary>
Expand All @@ -331,12 +292,17 @@ private async Task CreateManyDefaultCollectionsAsync(Organization organization,
return;
}

var policyEligibleOrganizationUserIds = await _policyRequirementQuery
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id);
var confirmedUserIds = confirmedOrganizationUsers
.Select(s => s.UserId!.Value)
.ToList();

var policiesForUsers = await _policyRequirementQuery
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(confirmedUserIds);

var eligibleOrganizationUserIds = confirmedOrganizationUsers
.Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id))
.Select(ou => ou.Id)
var eligibleOrganizationUserIds = policiesForUsers
.Select(x => x.Requirement.GetDefaultCollectionRequestOnConfirm(organization.Id))
.Where(w => w.ShouldCreateDefaultCollection)
.Select(s => s.OrganizationUserId)
.ToList();

if (eligibleOrganizationUserIds.Count == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,27 +261,41 @@ await CreateDefaultCollectionsForConfirmedUsersAsync(organization, defaultCollec
private async Task CreateDefaultCollectionsForConfirmedUsersAsync(Organization organization, string defaultCollectionName,
ICollection<OrganizationUser> restoredUsers)
{
if (string.IsNullOrWhiteSpace(defaultCollectionName))
{
return;
}

if (!organization.UseMyItems)
{
return;
}

if (!string.IsNullOrWhiteSpace(defaultCollectionName))
var restoredConfirmedUsers = restoredUsers
Comment thread
sven-bitwarden marked this conversation as resolved.
.Where(w => w.Status == OrganizationUserStatusType.Confirmed)
.Where(w => w.UserId != null)
.Select(s => s.UserId.Value)
.ToList();

if (restoredConfirmedUsers.Count == 0)
{
var organizationUsersDataOwnershipEnabled = (await policyRequirementQuery
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id))
.ToList();
return;
}

var usersToCreateDefaultCollectionsFor = restoredUsers.Where(x =>
organizationUsersDataOwnershipEnabled.Contains(x.Id)
&& x.Status == OrganizationUserStatusType.Confirmed).ToList();
var restoredUserPolicyRequirements = await
policyRequirementQuery.GetAsync<OrganizationDataOwnershipPolicyRequirement>(restoredConfirmedUsers);

if (usersToCreateDefaultCollectionsFor.Count != 0)
{
await collectionRepository.CreateDefaultCollectionsAsync(organization.Id,
usersToCreateDefaultCollectionsFor.Select(x => x.Id),
defaultCollectionName);
}
var orgUserIdsToCreateDefaultCollectionsFor = restoredUserPolicyRequirements
.Select(s => s.Requirement.GetDefaultCollectionRequestOnConfirm(organization.Id))
.Where(w => w.ShouldCreateDefaultCollection)
.Select(s => s.OrganizationUserId)
.ToList();

if (orgUserIdsToCreateDefaultCollectionsFor.Count != 0)
{
await collectionRepository.CreateDefaultCollectionsAsync(organization.Id,
orgUserIdsToCreateDefaultCollectionsFor,
defaultCollectionName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,4 @@ public interface IPolicyRequirementQuery
/// <param name="userIds">The users that you need to enforce the policy against.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<IEnumerable<(Guid UserId, T Requirement)>> GetAsync<T>(IEnumerable<Guid> userIds) where T : IPolicyRequirement;

/// <summary>
/// Get all organization user IDs within an organization that are affected by a given policy type.
/// Respects role/status/provider exemptions via the policy factory's Enforce predicate.
/// </summary>
/// <param name="organizationId">The organization to check.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy type to evaluate.</typeparam>
/// <returns>Organization user IDs for whom the policy applies within the organization.</returns>
Task<IEnumerable<Guid>> GetManyByOrganizationIdAsync<T>(Guid organizationId) where T : IPolicyRequirement;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,6 @@ public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
return policyRequirements;
}

public async Task<IEnumerable<Guid>> GetManyByOrganizationIdAsync<T>(Guid organizationId)
where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
if (factory is null)
{
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}

var organizationPolicyDetails = await GetOrganizationPolicyDetails(organizationId, factory.PolicyType);

var eligibleOrganizationUserIds = organizationPolicyDetails
.Where(p => p.PolicyType == factory.PolicyType)
.Where(factory.Enforce)
.Select(p => p.OrganizationUserId)
.ToList();

return eligibleOrganizationUserIds;
}

private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(IEnumerable<Guid> userIds, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(userIds, policyType);

private async Task<IEnumerable<OrganizationPolicyDetails>> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ public DefaultCollectionRequest GetDefaultCollectionRequestOnPolicyEnable(Guid o
return noCollectionNeeded;
}

public bool RequiresDefaultCollectionOnConfirm(Guid organizationId)
public DefaultCollectionRequest GetDefaultCollectionRequestOnConfirm(Guid organizationId)
{
return _policyDetails.Any(p => p.OrganizationId == organizationId);
var matchingOrgUserId =
_policyDetails.FirstOrDefault(p => p.OrganizationId == organizationId)?.OrganizationUserId;

return new DefaultCollectionRequest(
OrganizationUserId: matchingOrgUserId.GetValueOrDefault(Guid.Empty),
ShouldCreateDefaultCollection: matchingOrgUserId.HasValue);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,9 @@ public async Task ConfirmUserAsync_WithOrganizationDataOwnershipPolicyApplicable
PolicyType = PolicyType.OrganizationDataOwnership
};
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]));
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser.UserId!.Value)))
.Returns([(orgUser.UserId!.Value, new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]))]);
Comment thread
sven-bitwarden marked this conversation as resolved.

await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);

Expand Down Expand Up @@ -534,8 +535,9 @@ public async Task ConfirmUserAsync_WithOrganizationDataOwnershipPolicyNotApplica
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user });

sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, []));
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser.UserId!.Value)))
.Returns([(orgUser.UserId!.Value, new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, []))]);

await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);

Expand Down Expand Up @@ -908,8 +910,10 @@ public async Task ConfirmUserAsync_UseMyItemsEnabled_CreatesDefaultCollection(
PolicyType = PolicyType.OrganizationDataOwnership
};
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]));
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser.UserId!.Value)))
.Returns([
(orgUser.UserId!.Value, new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]))
]);

// Act
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
Expand Down Expand Up @@ -949,10 +953,6 @@ public async Task ConfirmUsersAsync_UseMyItemsDisabled_DoesNotCreateDefaultColle
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser1, orgUser2 });
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2 });

sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
.Returns([orgUser1.Id, orgUser2.Id]);

// Act
await sutProvider.Sut.ConfirmUsersAsync(organization.Id, keys, confirmingUser.Id, collectionName);

Expand Down Expand Up @@ -988,9 +988,30 @@ public async Task ConfirmUsersAsync_UseMyItemsEnabled_CreatesDefaultCollections(
sutProvider.GetDependency<IOrganizationUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser1, orgUser2 });
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2 });

var policyDetails1 = new PolicyDetails
{
OrganizationId = organization.Id,
OrganizationUserId = orgUser1.Id,
IsProvider = false,
OrganizationUserStatus = orgUser1.Status,
OrganizationUserType = orgUser1.Type,
PolicyType = PolicyType.OrganizationDataOwnership
};
var policyDetails2 = new PolicyDetails
{
OrganizationId = organization.Id,
OrganizationUserId = orgUser2.Id,
IsProvider = false,
OrganizationUserStatus = orgUser2.Status,
OrganizationUserType = orgUser2.Type,
PolicyType = PolicyType.OrganizationDataOwnership
};
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
.Returns([orgUser1.Id, orgUser2.Id]);
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser1.UserId!.Value) && ids.Contains(orgUser2.UserId!.Value)))
.Returns([
(orgUser1.UserId!.Value, new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails1])),
(orgUser2.UserId!.Value, new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails2]))
]);

// Act
await sutProvider.Sut.ConfirmUsersAsync(organization.Id, keys, confirmingUser.Id, collectionName);
Expand Down
Loading
Loading