Skip to content

Commit 34a388c

Browse files
committed
GH-49614: [C++] Fix silent truncation in base64_decode on invalid input
1 parent df88383 commit 34a388c

3 files changed

Lines changed: 114 additions & 7 deletions

File tree

cpp/src/arrow/util/base64.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <string_view>
2222

2323
#include "arrow/util/visibility.h"
24+
#include "arrow/result.h"
25+
#include "arrow/status.h"
2426

2527
namespace arrow {
2628
namespace util {
@@ -29,7 +31,7 @@ ARROW_EXPORT
2931
std::string base64_encode(std::string_view s);
3032

3133
ARROW_EXPORT
32-
std::string base64_decode(std::string_view s);
34+
arrow::Result<std::string> base64_decode(std::string_view s);
3335

3436
} // namespace util
3537
} // namespace arrow

cpp/src/arrow/util/string_test.cc

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "arrow/testing/gtest_util.h"
2929
#include "arrow/util/regex.h"
3030
#include "arrow/util/string.h"
31+
#include "arrow/util/base64.h"
3132

3233
namespace arrow {
3334
namespace internal {
@@ -232,12 +233,79 @@ TEST(ToChars, FloatingPoint) {
232233
// to std::to_string which may make ad hoc formatting choices, so we cannot
233234
// really test much about the result.
234235
auto result = ToChars(0.0f);
235-
ASSERT_TRUE(result.starts_with("0")) << result;
236+
ASSERT_TRUE(result.rfind("0", 0) == 0) << result;
236237
result = ToChars(0.25);
237-
ASSERT_TRUE(result.starts_with("0.25")) << result;
238+
ASSERT_TRUE(result.rfind("0.25", 0) == 0) << result;
238239
}
239240
}
240241

242+
TEST(Base64DecodeTest, ValidInputs) {
243+
auto r1 = arrow::util::base64_decode("Zg==");
244+
ASSERT_TRUE(r1.ok());
245+
EXPECT_EQ(r1.ValueOrDie(), "f");
246+
auto r2 = arrow::util::base64_decode("Zm8=");
247+
ASSERT_TRUE(r2.ok());
248+
EXPECT_EQ(r2.ValueOrDie(), "fo");
249+
auto r3 = arrow::util::base64_decode("Zm9v");
250+
ASSERT_TRUE(r3.ok());
251+
EXPECT_EQ(r3.ValueOrDie(), "foo");
252+
auto r4 = arrow::util::base64_decode("aGVsbG8gd29ybGQ=");
253+
ASSERT_TRUE(r4.ok());
254+
EXPECT_EQ(r4.ValueOrDie(), "hello world");
255+
}
256+
257+
TEST(Base64DecodeTest, InvalidLength) {
258+
auto r1 = arrow::util::base64_decode("abc");
259+
ASSERT_FALSE(r1.ok());
260+
auto r2 = arrow::util::base64_decode("abcde");
261+
ASSERT_FALSE(r2.ok());
262+
}
263+
264+
TEST(Base64DecodeTest, InvalidCharacters) {
265+
auto r1 = arrow::util::base64_decode("ab$=");
266+
ASSERT_FALSE(r1.ok());
267+
auto r2 = arrow::util::base64_decode("Zm9v*");
268+
ASSERT_FALSE(r2.ok());
269+
auto r3 = arrow::util::base64_decode("abcd$AAA");
270+
ASSERT_FALSE(r3.ok());
271+
}
272+
273+
TEST(Base64DecodeTest, InvalidPadding) {
274+
auto r1 = arrow::util::base64_decode("ab=c");
275+
ASSERT_FALSE(r1.ok());
276+
auto r2 = arrow::util::base64_decode("abc===");
277+
ASSERT_FALSE(r2.ok());
278+
auto r3 = arrow::util::base64_decode("abcd=AAA");
279+
ASSERT_FALSE(r3.ok());
280+
auto r4 = arrow::util::base64_decode("Zm=9v");
281+
ASSERT_FALSE(r4.ok());
282+
}
283+
284+
TEST(Base64DecodeTest, EdgeCases) {
285+
auto r1 = arrow::util::base64_decode("====");
286+
ASSERT_FALSE(r1.ok());
287+
auto r2 = arrow::util::base64_decode("TQ==");
288+
ASSERT_TRUE(r2.ok());
289+
EXPECT_EQ(r2.ValueOrDie(), "M");
290+
}
291+
292+
TEST(Base64DecodeTest, EmptyInput) {
293+
auto r = arrow::util::base64_decode("");
294+
ASSERT_TRUE(r.ok());
295+
EXPECT_EQ(r.ValueOrDie(), "");
296+
}
297+
298+
TEST(Base64DecodeTest, NonAsciiInput) {
299+
std::string input = std::string("abcd") + char(0xFF) + "==";
300+
auto r = arrow::util::base64_decode(input);
301+
ASSERT_FALSE(r.ok());
302+
}
303+
304+
TEST(Base64DecodeTest, PartialCorruption) {
305+
auto r = arrow::util::base64_decode("aGVs$G8gd29ybGQ=");
306+
ASSERT_FALSE(r.ok());
307+
}
308+
241309
#if !defined(_WIN32) || defined(NDEBUG)
242310

243311
TEST(ToChars, LocaleIndependent) {

cpp/src/arrow/vendored/base64.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
*/
3131

3232
#include "arrow/util/base64.h"
33+
#include "arrow/util/logging.h"
34+
#include "arrow/result.h"
35+
#include "arrow/status.h"
3336
#include <iostream>
37+
#include <cctype>
3438

3539
namespace arrow {
3640
namespace util {
@@ -93,18 +97,51 @@ std::string base64_encode(std::string_view string_to_encode) {
9397
return base64_encode(bytes_to_encode, in_len);
9498
}
9599

96-
std::string base64_decode(std::string_view encoded_string) {
100+
arrow::Result<std::string> base64_decode(std::string_view encoded_string) {
97101
size_t in_len = encoded_string.size();
98102
int i = 0;
99103
int j = 0;
100104
int in_ = 0;
101105
unsigned char char_array_4[4], char_array_3[3];
102106
std::string ret;
103107

104-
while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
108+
static const std::string base64_chars =
109+
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
110+
"abcdefghijklmnopqrstuvwxyz"
111+
"0123456789+/";
112+
113+
auto is_base64 = [](unsigned char c) -> bool {
114+
return (std::isalnum(c) || (c == '+') || (c == '/'));
115+
};
116+
117+
if (encoded_string.size() % 4 != 0) {
118+
return arrow::Status::Invalid("Invalid base64 input: length is not a multiple of 4");
119+
}
120+
121+
size_t padding_start = encoded_string.find('=');
122+
if (padding_start != std::string::npos) {
123+
for (size_t k = padding_start; k < encoded_string.size(); ++k) {
124+
if (encoded_string[k] != '=') {
125+
return arrow::Status::Invalid("Invalid base64 input: padding character '=' found at invalid position");
126+
}
127+
}
128+
129+
size_t padding_count = encoded_string.size() - padding_start;
130+
if (padding_count > 2) {
131+
return arrow::Status::Invalid("Invalid base64 input: too many padding characters");
132+
}
133+
}
134+
135+
for (char c : encoded_string) {
136+
if (c != '=' && !is_base64(c)) {
137+
return arrow::Status::Invalid("Invalid base64 input: contains non-base64 character '" + std::string(1, c) + "'");
138+
}
139+
}
140+
141+
while (in_len-- && encoded_string[in_] != '=') {
105142
char_array_4[i++] = encoded_string[in_]; in_++;
106-
if (i ==4) {
107-
for (i = 0; i <4; i++)
143+
if (i == 4) {
144+
for (i = 0; i < 4; i++)
108145
char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
109146

110147
char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4);

0 commit comments

Comments
 (0)