Skip to content

Commit 22e6c2c

Browse files
committed
ARROW: Return Result<std::string> in base64_decode and add validation + tests
1 parent c46246a commit 22e6c2c

File tree

10 files changed

+61
-38
lines changed

10 files changed

+61
-38
lines changed

cpp/src/arrow/flight/flight_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username
620620
std::string& password) {
621621
std::string encoded_credentials =
622622
FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
623-
std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
623+
ASSERT_OK_AND_ASSIGN(auto decoded, arrow::util::base64_decode(encoded_credentials));
624+
std::stringstream decoded_stream(decoded);
624625
std::getline(decoded_stream, username, ':');
625626
std::getline(decoded_stream, password, ':');
626627
}

cpp/src/arrow/util/base64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include <string>
2121
#include <string_view>
2222

23-
#include "arrow/util/visibility.h"
2423
#include "arrow/result.h"
24+
#include "arrow/util/visibility.h"
2525

2626
namespace arrow {
2727
namespace util {

cpp/src/arrow/util/base64_test.cc

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,72 +15,76 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
#include "arrow/testing/gtest_util.h"
1918
#include "arrow/util/base64.h"
20-
19+
#include "arrow/testing/gtest_util.h"
2120

2221
namespace arrow {
2322
namespace util {
2423

2524
TEST(Base64DecodeTest, ValidInputs) {
26-
ASSERT_OK_AND_ASSIGN(auto empty, arrow::util::base64_decode(""));
25+
ASSERT_OK_AND_ASSIGN(auto empty, base64_decode(""));
2726
EXPECT_EQ(empty, "");
2827

29-
ASSERT_OK_AND_ASSIGN(auto two_paddings, arrow::util::base64_decode("Zg=="));
28+
ASSERT_OK_AND_ASSIGN(auto two_paddings, base64_decode("Zg=="));
3029
EXPECT_EQ(two_paddings, "f");
3130

32-
ASSERT_OK_AND_ASSIGN(auto one_padding, arrow::util::base64_decode("Zm8="));
31+
ASSERT_OK_AND_ASSIGN(auto one_padding, base64_decode("Zm8="));
3332
EXPECT_EQ(one_padding, "fo");
3433

35-
ASSERT_OK_AND_ASSIGN(auto no_padding, arrow::util::base64_decode("Zm9v"));
34+
ASSERT_OK_AND_ASSIGN(auto no_padding, base64_decode("Zm9v"));
3635
EXPECT_EQ(no_padding, "foo");
3736

38-
ASSERT_OK_AND_ASSIGN(auto multiblock, arrow::util::base64_decode("SGVsbG8gd29ybGQ="));
37+
ASSERT_OK_AND_ASSIGN(auto multiblock, base64_decode("SGVsbG8gd29ybGQ="));
3938
EXPECT_EQ(multiblock, "Hello world");
4039
}
4140

4241
TEST(Base64DecodeTest, BinaryOutput) {
4342
// 'A' maps to index 0 — same zero value used for padding slots
4443
// verifies that a literal 'A' is decoded as data rather than being treated as padding
45-
ASSERT_OK_AND_ASSIGN(auto all_A, arrow::util::base64_decode("AAAA"));
44+
ASSERT_OK_AND_ASSIGN(auto all_A, base64_decode("AAAA"));
4645
EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3));
4746

4847
// Arbitrary non-ASCII output bytes
49-
ASSERT_OK_AND_ASSIGN(auto binary, arrow::util::base64_decode("AP8A"));
48+
ASSERT_OK_AND_ASSIGN(auto binary, base64_decode("AP8A"));
5049
EXPECT_EQ(binary, std::string("\x00\xff\x00", 3));
5150
}
5251

5352
TEST(Base64DecodeTest, InvalidLength) {
54-
ASSERT_RAISES_WITH_MESSAGE(
55-
Invalid,
56-
"Invalid: Invalid base64 input: length is not a multiple of 4",
57-
arrow::util::base64_decode("abc"));
53+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
54+
base64_decode("abc"));
5855
}
5956

6057
TEST(Base64DecodeTest, InvalidCharacters) {
61-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab$="));
58+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
59+
base64_decode("ab$="));
6260

6361
// Non-ASCII byte
6462
std::string non_ascii = std::string("abc") + static_cast<char>(0xFF);
65-
ASSERT_RAISES(Invalid, arrow::util::base64_decode(non_ascii));
63+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
64+
base64_decode(non_ascii));
6665

6766
// Corruption mid-string across multiple blocks
68-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("aGVs$G8gd29ybGQ="));
67+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
68+
base64_decode("aGVs$G8gd29ybGQ="));
6969
}
7070

7171
TEST(Base64DecodeTest, InvalidPadding) {
7272
// Padding in wrong position within block
73-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab=c"));
73+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
74+
base64_decode("ab=c"));
7475

7576
// 3 padding characters — exceeds maximum of 2
76-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("a==="));
77+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
78+
base64_decode("a==="));
7779

7880
// 4 padding characters
79-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("===="));
81+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
82+
base64_decode("===="));
8083

8184
// Padding in non-final block across multiple blocks
82-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("Zm8=Zm8="));
85+
ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid base64 input",
86+
base64_decode("Zm8=Zm8="));
8387
}
8488

85-
}
86-
}
89+
} // namespace util
90+
} // namespace arrow

cpp/src/arrow/util/string_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include "arrow/testing/gtest_util.h"
2929
#include "arrow/util/regex.h"
3030
#include "arrow/util/string.h"
31-
#include "arrow/util/base64.h"
3231

3332
namespace arrow {
3433
namespace internal {

cpp/src/arrow/vendored/base64.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "arrow/util/base64.h"
3333
#include "arrow/result.h"
3434
#include <iostream>
35-
#include <cctype>
3635

3736
namespace arrow {
3837
namespace util {
@@ -93,15 +92,15 @@ std::string base64_encode(std::string_view string_to_encode) {
9392
Result<std::string> base64_decode(std::string_view encoded_string) {
9493
size_t in_len = encoded_string.size();
9594
int i = 0;
96-
int in_ = 0;
95+
std::string_view::size_type in_ = 0;
9796
int padding_count = 0;
9897
int block_padding = 0;
9998
bool padding_started = false;
10099
unsigned char char_array_4[4], char_array_3[3];
101100
std::string ret;
102101

103102
if (encoded_string.size() % 4 != 0) {
104-
return Status::Invalid("Invalid base64 input: length is not a multiple of 4");
103+
return Status::Invalid("Invalid base64 input");
105104
}
106105

107106
while (in_len--) {
@@ -112,19 +111,17 @@ Result<std::string> base64_decode(std::string_view encoded_string) {
112111
padding_count++;
113112

114113
if (padding_count > 2) {
115-
return Status::Invalid("Invalid base64 input: too many padding characters");
114+
return Status::Invalid("Invalid base64 input");
116115
}
117116

118117
char_array_4[i++] = 0;
119118
} else {
120119
if (padding_started) {
121-
return Status::Invalid("Invalid base64 input: padding characters must be at the end");
120+
return Status::Invalid("Invalid base64 input");
122121
}
123122

124123
if (base64_chars.find(c) == std::string::npos) {
125-
return Status::Invalid(
126-
"Invalid base64 input: contains non-base64 byte at position " +
127-
std::to_string(in_));
124+
return Status::Invalid("Invalid base64 input");
128125
}
129126

130127
char_array_4[i++] = c;

cpp/src/gandiva/gdv_function_stubs.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,15 @@ const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t i
269269
return "";
270270
}
271271
// use arrow method to decode base64 string
272-
std::string decoded_str = arrow::util::base64_decode(std::string_view(in, in_len));
272+
auto result = arrow::util::base64_decode(std::string_view(in, in_len));
273+
if (!result.ok()) {
274+
gdv_fn_context_set_error_msg(context, result.status().message().c_str());
275+
*out_len = 0;
276+
return "";
277+
}
278+
279+
std::string decoded_str = *result;
280+
273281
*out_len = static_cast<int32_t>(decoded_str.length());
274282
// allocate memory for response
275283
char* ret = reinterpret_cast<char*>(

cpp/src/parquet/arrow/fuzz_internal.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,11 @@ class FuzzDecryptionKeyRetriever : public DecryptionKeyRetriever {
8383
}
8484
// Is it a key generated by MakeEncryptionKey?
8585
if (key_id.starts_with(kInlineKeyPrefix)) {
86-
return SecureString(
86+
PARQUET_ASSIGN_OR_THROW(
87+
auto decoded_key,
8788
::arrow::util::base64_decode(key_id.substr(kInlineKeyPrefix.length())));
89+
90+
return SecureString(std::move(decoded_key));
8891
}
8992
throw ParquetException("Unknown fuzz encryption key_id");
9093
}

cpp/src/parquet/arrow/schema.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,8 @@ Status GetOriginSchema(const std::shared_ptr<const KeyValueMetadata>& metadata,
953953
// The original Arrow schema was serialized using the store_schema option.
954954
// We deserialize it here and use it to inform read options such as
955955
// dictionary-encoded fields.
956-
auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index));
956+
ARROW_ASSIGN_OR_RAISE(auto decoded,
957+
::arrow::util::base64_decode(metadata->value(schema_index)));
957958
auto schema_buf = std::make_shared<Buffer>(decoded);
958959

959960
::arrow::ipc::DictionaryMemo dict_memo;

cpp/src/parquet/encryption/file_key_unwrapper.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,12 @@ KeyWithMasterId FileKeyUnwrapper::GetDataEncryptionKey(const KeyMaterial& key_ma
122122
});
123123

124124
// Decrypt the data key
125-
std::string aad = ::arrow::util::base64_decode(encoded_kek_id);
125+
auto result = ::arrow::util::base64_decode(encoded_kek_id);
126+
if (!result.ok()) {
127+
throw ParquetException(result.status().message());
128+
}
129+
130+
std::string aad = std::move(result).ValueOrDie();
126131
data_key = internal::DecryptKeyLocally(encoded_wrapped_dek, kek_bytes, aad);
127132
}
128133

cpp/src/parquet/encryption/key_toolkit_internal.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ std::string EncryptKeyLocally(const SecureString& key_bytes,
5252

5353
SecureString DecryptKeyLocally(const std::string& encoded_encrypted_key,
5454
const SecureString& master_key, const std::string& aad) {
55-
std::string encrypted_key = ::arrow::util::base64_decode(encoded_encrypted_key);
55+
auto result = ::arrow::util::base64_decode(encoded_encrypted_key);
56+
if (!result.ok()) {
57+
throw ParquetException(result.status().message());
58+
}
59+
60+
std::string encrypted_key = std::move(result).ValueOrDie();
5661

5762
AesDecryptor key_decryptor(ParquetCipher::AES_GCM_V1,
5863
static_cast<int>(master_key.size()), false,

0 commit comments

Comments
 (0)