Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions src/main/java/io/airlift/slice/SliceUtf8.java
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,54 @@ private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Le
private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Length, int position, Slice translatedUtf8, int translatedPosition, int[] codePointTranslationMap)
{
while (position < utf8Length) {
int asciiStart = position;
while (position < utf8Length) {
int value = utf8[utf8Offset + position] & 0xFF;
if (value >= 0x80 || codePointTranslationMap[value] != value) {
break;
}
position++;
}

if (position > asciiStart) {
if (translatedUtf8 != null) {
int nextTranslatedPosition = translatedPosition + (position - asciiStart);
if (nextTranslatedPosition > utf8Length) {
translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition);
}

translatedUtf8.setBytes(translatedPosition, utf8, utf8Offset + asciiStart, position - asciiStart);
translatedPosition = nextTranslatedPosition;
}
else if (position == utf8Length) {
return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length);
}
}

if (position == utf8Length) {
break;
}

int value = utf8[utf8Offset + position] & 0xFF;
if (value < 0x80) {
if (translatedUtf8 == null) {
translatedUtf8 = Slices.allocate(utf8Length);
translatedUtf8.setBytes(0, utf8, utf8Offset, position);
translatedPosition = position;
}

int translatedCodePoint = codePointTranslationMap[value];
int nextTranslatedPosition = translatedPosition + lengthOfCodePoint(translatedCodePoint);
if (nextTranslatedPosition > utf8Length) {
translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition);
}

setCodePointAt(translatedCodePoint, translatedUtf8, translatedPosition);
position++;
translatedPosition = nextTranslatedPosition;
continue;
}

int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position);
if (codePoint >= 0) {
int translatedCodePoint = codePointTranslationMap[codePoint];
Expand All @@ -563,15 +611,12 @@ private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Le
translatedPosition = position;
}

// grow slice if necessary
int nextTranslatedPosition = translatedPosition + lengthOfCodePoint(translatedCodePoint);
if (nextTranslatedPosition > utf8Length) {
translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition);
}

// write translated code point
setCodePointAt(translatedCodePoint, translatedUtf8, translatedPosition);

position += codePointLength;
translatedPosition = nextTranslatedPosition;
}
Expand Down Expand Up @@ -599,12 +644,34 @@ private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Le

private static Slice toUpperCaseAsciiOrCodePoints(byte[] utf8, int utf8Offset, int utf8Length)
{
Slice translated = Slices.allocate(utf8Length);
int position = 0;

// Fast scan until the first ASCII byte that needs translation.
while (position < utf8Length) {
int value = utf8[utf8Offset + position] & 0xFF;
if (value >= 0x80) {
return translateCodePoints(utf8, utf8Offset, utf8Length, UPPER_CODE_POINTS);
return translateCodePoints(utf8, utf8Offset, utf8Length, position, null, position, UPPER_CODE_POINTS);
}

if (value >= 'a' && value <= 'z') {
break;
}
position++;
}

// Nothing to translate in the entire input.
if (position == utf8Length) {
return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length);
}

Slice translated = Slices.allocate(utf8Length);
translated.setBytes(0, utf8, utf8Offset, position);

// Continue with a single tight loop once output exists.
while (position < utf8Length) {
int value = utf8[utf8Offset + position] & 0xFF;
if (value >= 0x80) {
return translateCodePoints(utf8, utf8Offset, utf8Length, position, translated, position, UPPER_CODE_POINTS);
}

if (value >= 'a' && value <= 'z') {
Expand All @@ -615,6 +682,7 @@ private static Slice toUpperCaseAsciiOrCodePoints(byte[] utf8, int utf8Offset, i
}
position++;
}

return translated;
}

Expand Down
104 changes: 104 additions & 0 deletions src/test/java/io/airlift/slice/SliceUtf8Benchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -490,12 +490,24 @@ public Slice benchmarkToLowerCase(BenchmarkData data)
return toLowerCase(data.getUtf8(), data.getOffset(), data.getByteLength());
}

@Benchmark
public Slice benchmarkToLowerCaseTargeted(LowerCaseData data)
{
return toLowerCase(data.getUtf8(), data.getOffset(), data.getByteLength());
}

@Benchmark
public Slice benchmarkToUpperCase(BenchmarkData data)
{
return toUpperCase(data.getUtf8(), data.getOffset(), data.getByteLength());
}

@Benchmark
public Slice benchmarkToUpperCaseTargeted(UpperCaseData data)
{
return toUpperCase(data.getUtf8(), data.getOffset(), data.getByteLength());
}

@Benchmark
public Slice benchmarkLeftTrim(WhitespaceData data)
{
Expand Down Expand Up @@ -661,6 +673,88 @@ public int getLength()
}
}

public abstract static class CaseChangeData
{
@Param({"64", "1024"})
private int repeatCount;

private byte[] utf8;
private int offset;
private int byteLength;

@Setup
public void setup()
{
byte[] input = createInput();
offset = 7;
utf8 = new byte[offset + input.length + 3];
System.arraycopy(input, 0, utf8, offset, input.length);
byteLength = input.length;
}

protected abstract byte[] createInput();

protected int getRepeatCount()
{
return repeatCount;
}

public byte[] getUtf8()
{
return utf8;
}

public int getOffset()
{
return offset;
}

public int getByteLength()
{
return byteLength;
}
}

@State(Thread)
public static class LowerCaseData
extends CaseChangeData
{
@Param({"ascii_change", "non_ascii_noop", "mixed_non_ascii_ascii_noop", "mixed_non_ascii_ascii_change"})
private String inputKind;

@Override
protected byte[] createInput()
{
return switch (inputKind) {
case "ascii_change" -> repeatUtf8("HELLO", getRepeatCount());
case "non_ascii_noop" -> repeatUtf8("ö", getRepeatCount());
case "mixed_non_ascii_ascii_noop" -> repeatUtf8("öhello", getRepeatCount());
case "mixed_non_ascii_ascii_change" -> repeatUtf8("éHELLO", getRepeatCount());
default -> throw new IllegalArgumentException("Unknown inputKind: " + inputKind);
};
}
}

@State(Thread)
public static class UpperCaseData
extends CaseChangeData
{
@Param({"ascii_change", "non_ascii_noop", "mixed_non_ascii_ascii_noop", "mixed_non_ascii_ascii_change"})
private String inputKind;

@Override
protected byte[] createInput()
{
return switch (inputKind) {
case "ascii_change" -> repeatUtf8("hello", getRepeatCount());
case "non_ascii_noop" -> repeatUtf8("Ö", getRepeatCount());
case "mixed_non_ascii_ascii_noop" -> repeatUtf8("ÖHELLO", getRepeatCount());
case "mixed_non_ascii_ascii_change" -> repeatUtf8("Éhello", getRepeatCount());
default -> throw new IllegalArgumentException("Unknown inputKind: " + inputKind);
};
}
}

@State(Thread)
public static class CompareData
extends BenchmarkData
Expand Down Expand Up @@ -1123,4 +1217,14 @@ static void main()

new Runner(options).run();
}

private static byte[] repeatUtf8(String unit, int repeatCount)
{
byte[] encodedUnit = unit.getBytes(StandardCharsets.UTF_8);
DynamicSliceOutput output = new DynamicSliceOutput(encodedUnit.length * repeatCount);
for (int i = 0; i < repeatCount; i++) {
output.appendBytes(encodedUnit);
}
return output.slice().getBytes();
}
}
25 changes: 25 additions & 0 deletions src/test/java/io/airlift/slice/TestSliceUtf8.java
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,31 @@ public void testCaseChange()
INVALID_SEQUENCES.forEach(TestSliceUtf8::assertCaseChangeWithInvalidSequence);
}

@Test
public void testToUpperCaseNoOpWrapsInputRange()
{
byte[] bytes = "HELLO".getBytes(UTF_8);

Slice upper = toUpperCase(bytes, 0, bytes.length);
bytes[0] = 'Y';

assertThat(upper.toStringUtf8()).isEqualTo("YELLO");
}

@Test
public void testCaseChangeNoOpWrapsInputRangeForNonAscii()
{
byte[] upperBytes = "Ö".getBytes(UTF_8);
Slice upper = toUpperCase(upperBytes, 0, upperBytes.length);
upperBytes[1] = (byte) 0x98;
assertThat(upper.toStringUtf8()).isEqualTo("Ø");

byte[] lowerBytes = "ö".getBytes(UTF_8);
Slice lower = toLowerCase(lowerBytes, 0, lowerBytes.length);
lowerBytes[1] = (byte) 0xB8;
assertThat(lower.toStringUtf8()).isEqualTo("ø");
}

private static void assertCaseChangeWithInvalidSequence(byte[] invalidSequence)
{
assertThat(toLowerCase(wrappedBuffer(invalidSequence)))
Expand Down
Loading