Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,15 @@ public Bundle updateAccount(Account account, UserAccount userAccount) {
accountManager.setUserData(account, key, extras.getString(key));
}

// The refresh token is stored as the Account's password (see createAccount), not as user data,
// so buildAuthBundle does not include it. Persist it explicitly here so that server-side
// Refresh Token Rotation is correctly reflected in storage.
final String refreshToken = userAccount.getRefreshToken();
if (refreshToken != null) {
final String encryptionKey = SalesforceSDKManager.getEncryptionKey();
accountManager.setPassword(account, SalesforceSDKManager.encrypt(refreshToken, encryptionKey));
}

return extras;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ public static class AccMgrAuthTokenProvider implements RestClient.AuthTokenProvi
private final Object lock = new Object();
private final ClientManager clientManager;
private String lastNewAuthToken;
private final String refreshToken;
// Mutable to support server-side Refresh Token Rotation (RTR).
private String refreshToken;
private String lastNewInstanceUrl;
private long lastRefreshTime = -1 /* never refreshed */;

Expand Down Expand Up @@ -506,6 +507,12 @@ private UserAccount refreshStaleToken(Account account) throws NetworkErrorExcept
updatedUserAccount.downloadProfilePhoto();
UserAccountManager.getInstance().clearCachedCurrentUser();

// Handle server-side Refresh Token Rotation: if the response contained a new refresh token,
// update this provider's cached copy.
if (tr.refreshToken != null && !tr.refreshToken.equals(refreshToken)) {
refreshToken = tr.refreshToken;
}

return updatedUserAccount;
} catch (OAuth2.OAuthFailedException ofe) {
if (ofe.isRefreshTokenInvalid()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,48 @@ public void testUserAccountToAccountToUserAccount() {
checkSameUserAccount(userAccount, restoredUserAccount);
}

/*
* Server-side Refresh Token Rotation (RTR) regression test.
*
* The refresh token is persisted as the Account's password and is read back via accountManager.getPassword().
* updateAccount must therefore persist a rotated refresh token via setPassword.
*/
@Test
public void testUpdateAccountPersistsRotatedRefreshToken() {
final UserAccount original = UserAccountTest.createTestAccount();
userAccMgr.createAccount(original);
final Account account = userAccMgr.getCurrentAccount();
Assert.assertEquals(
"Initial refresh token should round-trip through AccountManager",
UserAccountTest.TEST_REFRESH_TOKEN,
userAccMgr.buildUserAccount(account).getRefreshToken());

// Simulate a server-side refresh token rotation by building a
// UserAccount with a new refresh token value and updating.
final String rotatedRefreshToken = "rotated_refresh_token";
final UserAccount rotated = UserAccountBuilder.getInstance()
.populateFromUserAccount(original)
.refreshToken(rotatedRefreshToken)
.build();
userAccMgr.updateAccount(account, rotated);

// The persisted refresh token must reflect the rotated value.
final UserAccount reloaded = userAccMgr.buildUserAccount(account);
Assert.assertEquals(
"Rotated refresh token should be persisted by updateAccount",
rotatedRefreshToken,
reloaded.getRefreshToken());

// Encryption (AES-GCM with a random IV) is non-deterministic, so
// compare against the decrypted password rather than re-encrypting
// the expected value.
final String encryptionKey = SalesforceSDKManager.getEncryptionKey();
Assert.assertEquals(
"Rotated refresh token should be used as the account password.",
rotatedRefreshToken,
SalesforceSDKManager.decrypt(accMgr.getPassword(account), encryptionKey));
}

/**
* Test to get all authenticated users.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.junit.Test
private const val OLD_ACCESS_TOKEN = "old-token"
private const val REFRESHED_ACCESS_TOKEN = "refreshed-auth-token"
private const val REFRESH_TOKEN = "refresh-token"
private const val ROTATED_REFRESH_TOKEN = "rotated-refresh-token"

@SmallTest
class ClientManagerMockTest {
Expand Down Expand Up @@ -346,6 +347,147 @@ class ClientManagerMockTest {
Assert.assertEquals(ClientManager.ACCESS_TOKEN_REVOKE_INTENT, broadcastIntentSlot.captured.action)
}

/*
Server-side Refresh Token Rotation (RTR): when the token endpoint returns
a rotated refresh_token, the provider must update its cached refresh
token so subsequent calls don't reuse the now-invalidated previous one.
*/
@Test
fun testGetNewAuthToken_RefreshTokenRotation_UpdatesCachedRefreshToken() {
val responseBody = """
{
"access_token": "$REFRESHED_ACCESS_TOKEN",
"refresh_token": "$ROTATED_REFRESH_TOKEN",
"instance_url": "https://login.salesforce.com",
"id": "https://login.salesforce.com/id/orgId/userId",
"token_type": "Bearer",
"issued_at": "1234567890",
"signature": "mock-signature"
}
""".trimIndent().toResponseBody("application/json; charset=utf-8".toMediaType())
val rotatedResponse = mockk<Response>(relaxed = true) {
every { isSuccessful } returns true
every { close() } just runs
every { body } returns responseBody
}
every { HttpAccess.DEFAULT.okHttpClient } returns mockk<OkHttpClient> {
every { newCall(any()) } returns mockk<Call> {
every { execute() } returns rotatedResponse
}
}

val userSlot = slot<UserAccount>()
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
every { accounts } returns arrayOf(mockAccount)
}
every { mockUserAccountManager.currentUser } returns mockUser
every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser
every { mockUserAccountManager.updateAccount(mockAccount, any()) } returns mockk()

val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://login.salesforce.com",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

// First refresh: server rotates the refresh token.
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken())

// The persisted account should be updated with the rotated refresh token...
verify(exactly = 1) {
mockUserAccountManager.updateAccount(mockAccount, capture(userSlot))
}
Assert.assertEquals(ROTATED_REFRESH_TOKEN, userSlot.captured.refreshToken)
// ...and so should the provider's in-memory cache, so that subsequent
// refreshes (and getRefreshToken consumers) use the rotated token.
Assert.assertEquals(ROTATED_REFRESH_TOKEN, authTokenProvider.refreshToken)
}

/*
Server-side Refresh Token Rotation (RTR): after a refresh that rotates
the refresh token, the provider's cached refresh token must reflect
the new value so that a subsequent refresh sends the current token
and the per-account lookup matches the rotated value persisted to
the account.
*/
@Test
fun testGetNewAuthToken_RefreshTokenRotation_SubsequentRefreshSucceeds() {
val firstRotated = ROTATED_REFRESH_TOKEN
val secondRotated = "rotated-refresh-token-2"

fun rotationResponse(rt: String): Response {
val responseBody = """
{
"access_token": "$REFRESHED_ACCESS_TOKEN",
"refresh_token": "$rt",
"instance_url": "https://login.salesforce.com",
"id": "https://login.salesforce.com/id/orgId/userId",
"token_type": "Bearer",
"issued_at": "1234567890",
"signature": "mock-signature"
}
""".trimIndent().toResponseBody("application/json; charset=utf-8".toMediaType())
return mockk<Response>(relaxed = true) {
every { isSuccessful } returns true
every { close() } just runs
every { body } returns responseBody
}
}

// Return a different rotated refresh token on each refresh.
every { HttpAccess.DEFAULT.okHttpClient } returns mockk<OkHttpClient> {
every { newCall(any()) } returnsMany listOf(
mockk<Call> { every { execute() } returns rotationResponse(firstRotated) },
mockk<Call> { every { execute() } returns rotationResponse(secondRotated) },
)
}

val mockAccount = mockk<Account>(relaxed = true)
// The persisted account's refresh token follows whatever updateAccount
// was last called with (i.e., the most recent rotated value).
var persistedRefreshToken = REFRESH_TOKEN
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } answers { persistedRefreshToken }
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
every { accounts } returns arrayOf(mockAccount)
}
every { mockUserAccountManager.currentUser } returns mockUser
every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser
every { mockUserAccountManager.updateAccount(mockAccount, any()) } answers {
persistedRefreshToken = secondArg<UserAccount>().refreshToken
mockk()
}

val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://login.salesforce.com",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

// First refresh succeeds, rotates to firstRotated.
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken())
Assert.assertEquals(firstRotated, authTokenProvider.refreshToken)
Assert.assertEquals(firstRotated, persistedRefreshToken)

// Second refresh, ensure each rotation is stored.
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken())
Assert.assertEquals(secondRotated, authTokenProvider.refreshToken)
verify(exactly = 0) {
mockSDKManager.logout(any(), any(), any(), any())
}
}

/*
Non-current user tests the scenario of attempting to make a
network call as the previous user on user account switch, but
Expand Down
Loading