Skip to content

Commit c46246a

Browse files
committed
Fix base64 validation and add comprehensive tests
1 parent 9d6df6d commit c46246a

File tree

4 files changed

+121
-73
lines changed

4 files changed

+121
-73
lines changed

cpp/src/arrow/util/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_arrow_test(utility-test
4949
SOURCES
5050
align_util_test.cc
5151
atfork_test.cc
52+
base64_test.cc
5253
byte_size_test.cc
5354
byte_stream_split_test.cc
5455
cache_test.cc

cpp/src/arrow/util/base64_test.cc

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow/testing/gtest_util.h"
19+
#include "arrow/util/base64.h"
20+
21+
22+
namespace arrow {
23+
namespace util {
24+
25+
// Well-formed inputs decode to the expected text: the empty string, both
// padding lengths, an unpadded block, and an input spanning several blocks.
TEST(Base64DecodeTest, ValidInputs) {
  struct Case {
    const char* encoded;
    const char* expected;
  };
  const Case cases[] = {
      {"", ""},
      {"Zg==", "f"},
      {"Zm8=", "fo"},
      {"Zm9v", "foo"},
      {"SGVsbG8gd29ybGQ=", "Hello world"},
  };

  for (const Case& test_case : cases) {
    ASSERT_OK_AND_ASSIGN(auto decoded, base64_decode(test_case.encoded));
    EXPECT_EQ(decoded, test_case.expected);
  }
}
41+
42+
// Decoded output is raw bytes, including NUL and non-ASCII values.
TEST(Base64DecodeTest, BinaryOutput) {
  // 'A' encodes the 6-bit value 0, which is also the value the decoder stores
  // in padding slots. A literal 'A' must still produce an output byte; only
  // '=' may shorten the output.
  ASSERT_OK_AND_ASSIGN(auto all_A, base64_decode("AAAA"));
  EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3));

  // Same check with 'A' next to padding, where a value-0 sentinel could be
  // confused with real data.
  ASSERT_OK_AND_ASSIGN(auto A_one_padding, base64_decode("AAA="));
  EXPECT_EQ(A_one_padding, std::string("\x00\x00", 2));

  ASSERT_OK_AND_ASSIGN(auto A_two_paddings, base64_decode("AA=="));
  EXPECT_EQ(A_two_paddings, std::string("\x00", 1));

  // Arbitrary non-ASCII output bytes
  ASSERT_OK_AND_ASSIGN(auto binary, base64_decode("AP8A"));
  EXPECT_EQ(binary, std::string("\x00\xff\x00", 3));
}
52+
53+
// Inputs whose length is not a multiple of 4 are rejected with a dedicated
// message, whether they are shorter or longer than one block and whether or
// not they end in padding.
TEST(Base64DecodeTest, InvalidLength) {
  const std::string expected_message =
      "Invalid: Invalid base64 input: length is not a multiple of 4";

  for (const char* input : {"a", "ab", "abc", "abcde", "Zm9vZg="}) {
    // Identify the offending input if an assertion below fails.
    SCOPED_TRACE(input);
    ASSERT_RAISES_WITH_MESSAGE(Invalid, expected_message, base64_decode(input));
  }
}
59+
60+
// Bytes outside the base64 alphabet are rejected wherever they appear.
TEST(Base64DecodeTest, InvalidCharacters) {
  // Punctuation that is not part of the alphabet
  ASSERT_RAISES(Invalid, base64_decode("ab$="));

  // A byte with the high bit set (non-ASCII)
  std::string high_byte_input = "abc";
  high_byte_input.push_back(static_cast<char>(0xFF));
  ASSERT_RAISES(Invalid, base64_decode(high_byte_input));

  // A single corrupted character in the middle of a multi-block input
  ASSERT_RAISES(Invalid, base64_decode("aGVs$G8gd29ybGQ="));
}
70+
71+
// Misplaced or excessive '=' characters are rejected.
TEST(Base64DecodeTest, InvalidPadding) {
  const char* const malformed_inputs[] = {
      "ab=c",      // padding followed by data inside a block
      "a===",      // three padding characters (at most two are allowed)
      "====",      // four padding characters
      "Zm8=Zm8=",  // padding in a block that is not the last one
  };

  for (const char* malformed : malformed_inputs) {
    ASSERT_RAISES(Invalid, base64_decode(malformed));
  }
}
84+
85+
}
86+
}

cpp/src/arrow/util/string_test.cc

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -239,48 +239,6 @@ TEST(ToChars, FloatingPoint) {
239239
}
240240
}
241241

242-
TEST(Base64DecodeTest, ValidInputs) {
243-
ASSERT_OK_AND_ASSIGN(auto two_paddings, arrow::util::base64_decode("Zg=="));
244-
EXPECT_EQ(two_paddings, "f");
245-
246-
ASSERT_OK_AND_ASSIGN(auto one_padding, arrow::util::base64_decode("Zm8="));
247-
EXPECT_EQ(one_padding, "fo");
248-
249-
ASSERT_OK_AND_ASSIGN(auto no_padding, arrow::util::base64_decode("Zm9v"));
250-
EXPECT_EQ(no_padding, "foo");
251-
252-
ASSERT_OK_AND_ASSIGN(auto single_char, arrow::util::base64_decode("TQ=="));
253-
EXPECT_EQ(single_char, "M");
254-
}
255-
256-
TEST(Base64DecodeTest, InvalidLength) {
257-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("abc"));
258-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("abcde"));
259-
}
260-
261-
TEST(Base64DecodeTest, InvalidCharacters) {
262-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab$="));
263-
}
264-
265-
TEST(Base64DecodeTest, InvalidPadding) {
266-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab=c"));
267-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("===="));
268-
}
269-
270-
TEST(Base64DecodeTest, EdgeCases) {
271-
ASSERT_OK_AND_ASSIGN(auto empty, arrow::util::base64_decode(""));
272-
EXPECT_EQ(empty, "");
273-
}
274-
275-
TEST(Base64DecodeTest, NonAsciiInput) {
276-
std::string input = std::string("abc") + static_cast<char>(0xFF);
277-
ASSERT_RAISES(Invalid, arrow::util::base64_decode(input));
278-
}
279-
280-
TEST(Base64DecodeTest, PartialCorruption) {
281-
ASSERT_RAISES(Invalid, arrow::util::base64_decode("aGVs$G8gd29ybGQ="));
282-
}
283-
284242
#if !defined(_WIN32) || defined(NDEBUG)
285243

286244
TEST(ToChars, LocaleIndependent) {

cpp/src/arrow/vendored/base64.cpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -93,63 +93,66 @@ std::string base64_encode(std::string_view string_to_encode) {
9393
// Decode a padded base64 string into its raw bytes.
//
// Returns the decoded bytes on success; an empty input decodes to an empty
// string. Returns Status::Invalid when:
//   - the input length is not a multiple of 4,
//   - more than two '=' characters appear,
//   - any character other than '=' follows a '=',
//   - a character is not found in `base64_chars`.
Result<std::string> base64_decode(std::string_view encoded_string) {
  const size_t in_len = encoded_string.size();
  if (in_len % 4 != 0) {
    return Status::Invalid("Invalid base64 input: length is not a multiple of 4");
  }

  std::string ret;
  // Every group of 4 input characters yields at most 3 output bytes.
  ret.reserve(in_len / 4 * 3);

  unsigned char sextets[4];  // 6-bit values of the group being assembled
  int filled = 0;            // number of slots of `sextets` already written
  int padding_count = 0;     // number of '=' characters seen so far

  // Index with size_t so that positions stay valid for any input length.
  for (size_t pos = 0; pos < in_len; ++pos) {
    const char c = encoded_string[pos];

    if (c == '=') {
      if (++padding_count > 2) {
        return Status::Invalid("Invalid base64 input: too many padding characters");
      }
      // A padding slot contributes zero bits; the corresponding output bytes
      // are dropped below.
      sextets[filled++] = 0;
    } else {
      if (padding_count > 0) {
        return Status::Invalid(
            "Invalid base64 input: padding characters must be at the end");
      }

      // One lookup both validates the character and yields its 6-bit value.
      const auto index = base64_chars.find(c);
      if (index == std::string::npos) {
        return Status::Invalid(
            "Invalid base64 input: contains non-base64 byte at position " +
            std::to_string(pos));
      }
      sextets[filled++] = static_cast<unsigned char>(index);
    }

    if (filled == 4) {
      const unsigned char bytes[3] = {
          static_cast<unsigned char>((sextets[0] << 2) + ((sextets[1] & 0x30) >> 4)),
          static_cast<unsigned char>(((sextets[1] & 0xf) << 4) +
                                     ((sextets[2] & 0x3c) >> 2)),
          static_cast<unsigned char>(((sextets[2] & 0x3) << 6) + sextets[3])};

      // Given the checks above, padding can only sit at the very end of the
      // input, so padding_count is non-zero only for the final group. Each
      // '=' removes one trailing output byte.
      const int produced = 3 - padding_count;
      for (int k = 0; k < produced; ++k) {
        ret += static_cast<char>(bytes[k]);
      }
      filled = 0;
    }
  }

  return ret;
}
155158

0 commit comments

Comments
 (0)