diff --git a/src/main/java/io/airlift/slice/SliceUtf8.java b/src/main/java/io/airlift/slice/SliceUtf8.java index c2ac510d..0c5fbf56 100644 --- a/src/main/java/io/airlift/slice/SliceUtf8.java +++ b/src/main/java/io/airlift/slice/SliceUtf8.java @@ -13,6 +13,8 @@ */ package io.airlift.slice; +import java.lang.invoke.VarHandle; +import java.util.Arrays; import java.util.OptionalInt; import static io.airlift.slice.Preconditions.checkArgument; @@ -22,6 +24,8 @@ import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT; import static java.lang.Character.MIN_SURROGATE; import static java.lang.Integer.toHexString; +import static java.lang.invoke.MethodHandles.byteArrayViewVarHandle; +import static java.nio.ByteOrder.LITTLE_ENDIAN; import static java.util.Objects.checkFromIndexSize; import static java.util.Objects.checkIndex; @@ -35,6 +39,10 @@ private SliceUtf8() {} private static final int MIN_HIGH_SURROGATE_CODE_POINT = 0xD800; private static final int REPLACEMENT_CODE_POINT = 0xFFFD; + private static final VarHandle SHORT_HANDLE = byteArrayViewVarHandle(short[].class, LITTLE_ENDIAN); + private static final VarHandle INT_HANDLE = byteArrayViewVarHandle(int[].class, LITTLE_ENDIAN); + private static final VarHandle LONG_HANDLE = byteArrayViewVarHandle(long[].class, LITTLE_ENDIAN); + private static final int TOP_MASK32 = 0x8080_8080; private static final long TOP_MASK64 = 0x8080_8080_8080_8080L; @@ -66,27 +74,40 @@ private SliceUtf8() {} */ public static boolean isAscii(Slice utf8) { - int length = utf8.length(); + return isAscii(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Does the byte array range contain only 7-bit ASCII characters. 
+ */ + public static boolean isAscii(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return isAsciiRaw(utf8, offset, length); + } + + private static boolean isAsciiRaw(byte[] utf8, int utf8Offset, int utf8Length) + { int offset = 0; // Length rounded to 8 bytes - int length8 = length & 0x7FFF_FFF8; + int length8 = utf8Length & 0x7FFF_FFF8; for (; offset < length8; offset += 8) { - if ((utf8.getLongUnchecked(offset) & TOP_MASK64) != 0) { + if (((long) LONG_HANDLE.get(utf8, utf8Offset + offset) & TOP_MASK64) != 0) { return false; } } // Enough bytes left for 32 bits? - if (offset + 4 < length) { - if ((utf8.getIntUnchecked(offset) & TOP_MASK32) != 0) { + if (offset <= utf8Length - Integer.BYTES) { + if (((int) INT_HANDLE.get(utf8, utf8Offset + offset) & TOP_MASK32) != 0) { return false; } offset += 4; } // Do the rest one by one - for (; offset < length; offset++) { - if ((utf8.getByteUnchecked(offset) & 0x80) != 0) { + for (; offset < utf8Length; offset++) { + if ((utf8[utf8Offset + offset] & 0x80) != 0) { return false; } } @@ -102,7 +123,19 @@ public static boolean isAscii(Slice utf8) */ public static int countCodePoints(Slice utf8) { - return countCodePoints(utf8, 0, utf8.length()); + return countCodePoints(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Counts the code points within UTF-8 encoded byte array range. + *

+ * Note: This method does not explicitly check for valid UTF-8, and may + * return incorrect results or throw an exception for invalid UTF-8. + */ + public static int countCodePoints(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return countCodePoints(utf8, offset, length, 0, length); } /** @@ -113,7 +146,12 @@ public static int countCodePoints(Slice utf8) */ public static int countCodePoints(Slice utf8, int offset, int length) { - checkFromIndexSize(offset, length, utf8.length()); + return countCodePoints(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), offset, length); + } + + private static int countCodePoints(byte[] utf8, int utf8Offset, int utf8Length, int offset, int length) + { + checkFromIndexSize(offset, length, utf8Length); // Quick exit if empty string if (length == 0) { @@ -125,19 +163,19 @@ public static int countCodePoints(Slice utf8, int offset, int length) int lastLongStart = end - 8; for (; offset <= lastLongStart; offset += 8) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getLongUnchecked(offset)); + continuationBytesCount += countContinuationBytes((long) LONG_HANDLE.get(utf8, utf8Offset + offset)); } // Enough bytes left for 32 bits? 
if (offset <= end - 4) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getIntUnchecked(offset)); + continuationBytesCount += countContinuationBytes((int) INT_HANDLE.get(utf8, utf8Offset + offset)); offset += 4; } // Do the rest one by one for (; offset < end; offset++) { // Count bytes which are NOT the start of a code point - continuationBytesCount += countContinuationBytes(utf8.getByteUnchecked(offset)); + continuationBytesCount += countContinuationBytes(utf8[utf8Offset + offset]); } verify(continuationBytesCount <= length); @@ -152,26 +190,44 @@ public static int countCodePoints(Slice utf8, int offset, int length) * return incorrect results or throw an exception for invalid UTF-8. */ public static Slice substring(Slice utf8, int codePointStart, int codePointLength) + { + return substring(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), codePointStart, codePointLength); + } + + /** + * Gets the substring within byte array range starting at {@code codePointStart} + * and extending for {@code codePointLength} code points. + *

+ * Note: This method does not explicitly check for valid UTF-8, and may + * return incorrect results or throw an exception for invalid UTF-8. + */ + public static Slice substring(byte[] utf8, int offset, int length, int codePointStart, int codePointLength) + { + checkFromIndexSize(offset, length, utf8.length); + return substringRaw(utf8, offset, length, codePointStart, codePointLength); + } + + private static Slice substringRaw(byte[] utf8, int utf8Offset, int utf8Length, int codePointStart, int codePointLength) { checkArgument(codePointStart >= 0, "codePointStart is negative"); checkArgument(codePointLength >= 0, "codePointLength is negative"); - int indexStart = offsetOfCodePoint(utf8, codePointStart); + int indexStart = offsetOfCodePoint(utf8, utf8Offset, utf8Length, codePointStart); if (indexStart < 0) { throw new IllegalArgumentException("UTF-8 does not contain " + codePointStart + " code points"); } if (codePointLength == 0) { return Slices.EMPTY_SLICE; } - int indexEnd = offsetOfCodePoint(utf8, indexStart, codePointLength - 1); + int indexEnd = offsetOfCodePoint(utf8, utf8Offset, utf8Length, indexStart, codePointLength - 1); if (indexEnd < 0) { throw new IllegalArgumentException("UTF-8 does not contain " + (codePointStart + codePointLength) + " code points"); } - indexEnd += lengthOfCodePoint(utf8, indexEnd); - if (indexEnd > utf8.length()) { + indexEnd += lengthOfCodePoint(utf8, utf8Offset, utf8Length, indexEnd); + if (indexEnd > utf8Length) { throw new InvalidUtf8Exception("UTF-8 is not well formed"); } - return utf8.slice(indexStart, indexEnd - indexStart); + return Slices.wrappedBuffer(utf8, utf8Offset + indexStart, indexEnd - indexStart); } /** @@ -181,13 +237,32 @@ public static Slice substring(Slice utf8, int codePointStart, int codePointLengt */ public static Slice reverse(Slice utf8) { - int length = utf8.length(); - Slice reverse = Slices.allocate(length); + return reverse(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * 
Reverses a UTF-8 byte array range code point by code point. + *

+ * Note: Invalid UTF-8 sequences are copied directly to the output. + */ + public static Slice reverse(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return reverseRaw(utf8, offset, length); + } + + private static Slice reverseRaw(byte[] utf8, int utf8Offset, int utf8Length) + { + if (isAsciiRaw(utf8, utf8Offset, utf8Length)) { + return reverseRawAscii(utf8, utf8Offset, utf8Length); + } + + Slice reverse = Slices.allocate(utf8Length); int forwardPosition = 0; - int reversePosition = length; - while (forwardPosition < length) { - int codePointLength = lengthOfCodePointSafe(utf8, forwardPosition); + int reversePosition = utf8Length; + while (forwardPosition < utf8Length) { + int codePointLength = lengthOfCodePointSafeRaw(utf8, utf8Offset, utf8Length, forwardPosition); // backup the reverse pointer reversePosition -= codePointLength; @@ -196,38 +271,156 @@ public static Slice reverse(Slice utf8) throw new InvalidUtf8Exception("UTF-8 is not well formed"); } // copy the character - copyUtf8SequenceUnsafe(utf8, forwardPosition, reverse, reversePosition, codePointLength); + copyUtf8SequenceUnsafe(utf8, utf8Offset, forwardPosition, reverse, reversePosition, codePointLength); forwardPosition += codePointLength; } return reverse; } + private static Slice reverseRawAscii(byte[] utf8, int utf8Offset, int utf8Length) + { + Slice reverse = Slices.allocate(utf8Length); + int sourcePosition = utf8Length; + int destinationPosition = 0; + + while (sourcePosition >= Long.BYTES) { + sourcePosition -= Long.BYTES; + long value = (long) LONG_HANDLE.get(utf8, utf8Offset + sourcePosition); + reverse.setLongUnchecked(destinationPosition, Long.reverseBytes(value)); + destinationPosition += Long.BYTES; + } + + if (sourcePosition >= Integer.BYTES) { + sourcePosition -= Integer.BYTES; + int value = (int) INT_HANDLE.get(utf8, utf8Offset + sourcePosition); + reverse.setIntUnchecked(destinationPosition, Integer.reverseBytes(value)); + 
destinationPosition += Integer.BYTES; + } + + if (sourcePosition >= Short.BYTES) { + sourcePosition -= Short.BYTES; + short value = (short) SHORT_HANDLE.get(utf8, utf8Offset + sourcePosition); + reverse.setShortUnchecked(destinationPosition, Short.reverseBytes(value)); + destinationPosition += Short.BYTES; + } + + if (sourcePosition == 1) { + reverse.setByteUnchecked(destinationPosition, utf8[utf8Offset]); + } + return reverse; + } + /** * Compares to UTF-8 sequences using UTF-16 big endian semantics. This is * equivalent to the {@link java.lang.String#compareTo(String)}. * {@code java.lang.String}. * - * @throws InvalidUtf8Exception if the UTF-8 are invalid + * Note: this method validates UTF-8 only for byte regions it decodes during + * comparison. Invalid UTF-8 in unvisited suffix regions may not be detected. + * + * @throws InvalidUtf8Exception if invalid UTF-8 is encountered while decoding */ public static int compareUtf16BE(Slice utf8Left, Slice utf8Right) { - int leftLength = utf8Left.length(); - int rightLength = utf8Right.length(); + return compareUtf16BE( + utf8Left.byteArray(), utf8Left.byteArrayOffset(), utf8Left.length(), + utf8Right.byteArray(), utf8Right.byteArrayOffset(), utf8Right.length()); + } + + /** + * Compares two UTF-8 byte array ranges using UTF-16 big endian semantics. + * + * Note: this method validates UTF-8 only for byte regions it decodes during + * comparison. Invalid UTF-8 in unvisited suffix regions may not be detected. 
+ * + * @throws InvalidUtf8Exception if invalid UTF-8 is encountered while decoding + */ + public static int compareUtf16BE(byte[] utf8Left, int leftOffset, int leftLength, byte[] utf8Right, int rightOffset, int rightLength) + { + checkFromIndexSize(leftOffset, leftLength, utf8Left.length); + checkFromIndexSize(rightOffset, rightLength, utf8Right.length); + return compareUtf16BERaw(utf8Left, leftOffset, leftLength, utf8Right, rightOffset, rightLength); + } + private static int compareUtf16BERaw(byte[] utf8Left, int leftOffset, int leftLength, byte[] utf8Right, int rightOffset, int rightLength) + { int offset = 0; + int equalPrefixLength = Math.min(leftLength, rightLength); + int ascii64Limit = equalPrefixLength - Long.BYTES; + int ascii32Limit = equalPrefixLength - Integer.BYTES; + while (offset < leftLength) { + while (offset <= ascii64Limit) { + long leftLong = (long) LONG_HANDLE.get(utf8Left, leftOffset + offset); + long rightLong = (long) LONG_HANDLE.get(utf8Right, rightOffset + offset); + if ((((leftLong | rightLong) & TOP_MASK64) != 0) || leftLong != rightLong) { + break; + } + offset += Long.BYTES; + } + + while (offset <= ascii32Limit) { + int leftInt = (int) INT_HANDLE.get(utf8Left, leftOffset + offset); + int rightInt = (int) INT_HANDLE.get(utf8Right, rightOffset + offset); + if ((((leftInt | rightInt) & TOP_MASK32) != 0) || leftInt != rightInt) { + break; + } + offset += Integer.BYTES; + } + + // chunk skipping can consume the full left range + if (offset >= leftLength) { + break; + } + // if there are no more right code points, right is less if (offset >= rightLength) { return 1; // left.compare(right) > 0 } - int leftCodePoint = tryGetCodePointAt(utf8Left, offset); + int leftByte = utf8Left[leftOffset + offset] & 0xFF; + int rightByte = utf8Right[rightOffset + offset] & 0xFF; + if ((leftByte | rightByte) < 0x80) { + if (leftByte != rightByte) { + return Integer.compare(leftByte, rightByte); + } + offset++; + continue; + } + + if (leftByte == rightByte) 
{ + int leftCodePoint = tryGetCodePointAtRaw(utf8Left, leftOffset, leftLength, offset); + if (leftCodePoint < 0) { + throw new InvalidUtf8Exception("Invalid UTF-8 sequence in utf8Left at " + offset); + } + + int leftCodePointLength = lengthOfCodePoint(leftCodePoint); + if (offset + leftCodePointLength <= rightLength && utf8SequencesEqual(utf8Left, leftOffset + offset, utf8Right, rightOffset + offset, leftCodePointLength)) { + offset += leftCodePointLength; + continue; + } + + int rightCodePoint = tryGetCodePointAtRaw(utf8Right, rightOffset, rightLength, offset); + if (rightCodePoint < 0) { + throw new InvalidUtf8Exception("Invalid UTF-8 sequence in utf8Right at " + offset); + } + + int result = compareUtf16BE(leftCodePoint, rightCodePoint); + if (result != 0) { + return result; + } + + offset += leftCodePointLength; + continue; + } + + int leftCodePoint = tryGetCodePointAtRaw(utf8Left, leftOffset, leftLength, offset); if (leftCodePoint < 0) { throw new InvalidUtf8Exception("Invalid UTF-8 sequence in utf8Left at " + offset); } - int rightCodePoint = tryGetCodePointAt(utf8Right, offset); + int rightCodePoint = tryGetCodePointAtRaw(utf8Right, rightOffset, rightLength, offset); if (rightCodePoint < 0) { throw new InvalidUtf8Exception("Invalid UTF-8 sequence in utf8Right at " + offset); } @@ -251,6 +444,26 @@ public static int compareUtf16BE(Slice utf8Left, Slice utf8Right) return 0; } + private static boolean utf8SequencesEqual(byte[] left, int leftStart, byte[] right, int rightStart, int length) + { + switch (length) { + case 1 -> { + return left[leftStart] == right[rightStart]; + } + case 2 -> { + return (short) SHORT_HANDLE.get(left, leftStart) == (short) SHORT_HANDLE.get(right, rightStart); + } + case 3 -> { + return (short) SHORT_HANDLE.get(left, leftStart) == (short) SHORT_HANDLE.get(right, rightStart) && + left[leftStart + 2] == right[rightStart + 2]; + } + case 4 -> { + return (int) INT_HANDLE.get(left, leftStart) == (int) INT_HANDLE.get(right, rightStart); + } 
+ default -> throw new IllegalArgumentException("Invalid UTF-8 sequence length: " + length); + } + } + static int compareUtf16BE(int leftCodePoint, int rightCodePoint) { if (leftCodePoint < MIN_SUPPLEMENTARY_CODE_POINT) { @@ -275,7 +488,7 @@ static int compareUtf16BE(int leftCodePoint, int rightCodePoint) /** * Converts slice to upper case code point by code point. This method does - * not perform perform locale-sensitive, context-sensitive, or one-to-many + * not perform locale-sensitive, context-sensitive, or one-to-many * mappings required for some languages. Specifically, this will return * incorrect results for Lithuanian, Turkish, and Azeri. *

@@ -283,12 +496,21 @@ static int compareUtf16BE(int leftCodePoint, int rightCodePoint) */ public static Slice toUpperCase(Slice utf8) { - return translateCodePoints(utf8, UPPER_CODE_POINTS); + return toUpperCase(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Converts byte array range to upper case code point by code point. + */ + public static Slice toUpperCase(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return toUpperCaseAsciiOrCodePoints(utf8, offset, length); } /** * Converts slice to lower case code point by code point. This method does - * not perform perform locale-sensitive, context-sensitive, or one-to-many + * not perform locale-sensitive, context-sensitive, or one-to-many * mappings required for some languages. Specifically, this will return * incorrect results for Lithuanian, Turkish, and Azeri. *

@@ -296,67 +518,167 @@ public static Slice toUpperCase(Slice utf8) */ public static Slice toLowerCase(Slice utf8) { - return translateCodePoints(utf8, LOWER_CODE_POINTS); + return toLowerCase(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Converts byte array range to lower case code point by code point. + */ + public static Slice toLowerCase(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return toLowerCaseAsciiOrCodePoints(utf8, offset, length); } - private static Slice translateCodePoints(Slice utf8, int[] codePointTranslationMap) + private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Length, int[] codePointTranslationMap) { - int length = utf8.length(); - Slice newUtf8 = Slices.allocate(length); + return translateCodePoints(utf8, utf8Offset, utf8Length, 0, null, 0, codePointTranslationMap); + } - int position = 0; - int upperPosition = 0; - while (position < length) { - int codePoint = tryGetCodePointAt(utf8, position); + private static Slice translateCodePoints(byte[] utf8, int utf8Offset, int utf8Length, int position, Slice translatedUtf8, int translatedPosition, int[] codePointTranslationMap) + { + while (position < utf8Length) { + int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position); if (codePoint >= 0) { - int upperCodePoint = codePointTranslationMap[codePoint]; + int translatedCodePoint = codePointTranslationMap[codePoint]; + int codePointLength = lengthOfCodePoint(codePoint); + + if (translatedCodePoint == codePoint) { + if (translatedUtf8 != null) { + int nextTranslatedPosition = translatedPosition + codePointLength; + if (nextTranslatedPosition > utf8Length) { + translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition); + } + + copyUtf8SequenceUnsafe(utf8, utf8Offset, position, translatedUtf8, translatedPosition, codePointLength); + translatedPosition = nextTranslatedPosition; + } + position += codePointLength; + 
continue; + } + + if (translatedUtf8 == null) { + translatedUtf8 = Slices.allocate(utf8Length); + translatedUtf8.setBytes(0, utf8, utf8Offset, position); + translatedPosition = position; + } // grow slice if necessary - int nextUpperPosition = upperPosition + lengthOfCodePoint(upperCodePoint); - if (nextUpperPosition > length) { - newUtf8 = Slices.ensureSize(newUtf8, nextUpperPosition); + int nextTranslatedPosition = translatedPosition + lengthOfCodePoint(translatedCodePoint); + if (nextTranslatedPosition > utf8Length) { + translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition); } - // write new byte - setCodePointAt(upperCodePoint, newUtf8, upperPosition); + // write translated code point + setCodePointAt(translatedCodePoint, translatedUtf8, translatedPosition); - position += lengthOfCodePoint(codePoint); - upperPosition = nextUpperPosition; + position += codePointLength; + translatedPosition = nextTranslatedPosition; } else { int skipLength = -codePoint; - // grow slice if necessary - int nextUpperPosition = upperPosition + skipLength; - if (nextUpperPosition > length) { - newUtf8 = Slices.ensureSize(newUtf8, nextUpperPosition); - } + if (translatedUtf8 != null) { + // grow slice if necessary + int nextTranslatedPosition = translatedPosition + skipLength; + if (nextTranslatedPosition > utf8Length) { + translatedUtf8 = Slices.ensureSize(translatedUtf8, nextTranslatedPosition); + } - copyUtf8SequenceUnsafe(utf8, position, newUtf8, upperPosition, skipLength); + copyUtf8SequenceUnsafe(utf8, utf8Offset, position, translatedUtf8, translatedPosition, skipLength); + translatedPosition = nextTranslatedPosition; + } position += skipLength; - upperPosition = nextUpperPosition; } } - return newUtf8.slice(0, upperPosition); + if (translatedUtf8 == null) { + return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length); + } + return translatedUtf8.slice(0, translatedPosition); + } + + private static Slice toUpperCaseAsciiOrCodePoints(byte[] utf8, int utf8Offset, 
int utf8Length) + { + Slice translated = Slices.allocate(utf8Length); + int position = 0; + while (position < utf8Length) { + int value = utf8[utf8Offset + position] & 0xFF; + if (value >= 0x80) { + return translateCodePoints(utf8, utf8Offset, utf8Length, UPPER_CODE_POINTS); + } + + if (value >= 'a' && value <= 'z') { + translated.setByteUnchecked(position, value - ('a' - 'A')); + } + else { + translated.setByteUnchecked(position, value); + } + position++; + } + return translated; + } + + private static Slice toLowerCaseAsciiOrCodePoints(byte[] utf8, int utf8Offset, int utf8Length) + { + int position = 0; + + // Fast scan until the first ASCII byte that needs translation. + while (position < utf8Length) { + int value = utf8[utf8Offset + position] & 0xFF; + if (value >= 0x80) { + return translateCodePoints(utf8, utf8Offset, utf8Length, position, null, position, LOWER_CODE_POINTS); + } + + if (value >= 'A' && value <= 'Z') { + break; + } + position++; + } + + // Nothing to translate in the entire input. + if (position == utf8Length) { + return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length); + } + + Slice translated = Slices.allocate(utf8Length); + translated.setBytes(0, utf8, utf8Offset, position); + + // Continue with a single tight loop once output exists. 
+ while (position < utf8Length) { + int value = utf8[utf8Offset + position] & 0xFF; + if (value >= 0x80) { + return translateCodePoints(utf8, utf8Offset, utf8Length, position, translated, position, LOWER_CODE_POINTS); + } + + if (value >= 'A' && value <= 'Z') { + translated.setByteUnchecked(position, value + ('a' - 'A')); + } + else { + translated.setByteUnchecked(position, value); + } + position++; + } + + return translated; } - private static void copyUtf8SequenceUnsafe(Slice source, int sourcePosition, Slice destination, int destinationPosition, int length) + private static void copyUtf8SequenceUnsafe(byte[] source, int sourceOffset, int sourcePosition, Slice destination, int destinationPosition, int length) { switch (length) { - case 1 -> destination.setByteUnchecked(destinationPosition, source.getByteUnchecked(sourcePosition)); - case 2 -> destination.setShortUnchecked(destinationPosition, source.getShortUnchecked(sourcePosition)); + case 1 -> destination.setByteUnchecked(destinationPosition, source[sourceOffset + sourcePosition]); + case 2 -> destination.setShortUnchecked(destinationPosition, (short) SHORT_HANDLE.get(source, sourceOffset + sourcePosition)); case 3 -> { - destination.setShortUnchecked(destinationPosition, source.getShortUnchecked(sourcePosition)); - destination.setByteUnchecked(destinationPosition + 2, source.getByteUnchecked(sourcePosition + 2)); + destination.setShortUnchecked(destinationPosition, (short) SHORT_HANDLE.get(source, sourceOffset + sourcePosition)); + destination.setByteUnchecked(destinationPosition + 2, source[sourceOffset + sourcePosition + 2]); } - case 4 -> destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); + case 4 -> destination.setIntUnchecked(destinationPosition, (int) INT_HANDLE.get(source, sourceOffset + sourcePosition)); case 5 -> { - destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); - destination.setByteUnchecked(destinationPosition + 4, 
source.getByteUnchecked(sourcePosition + 4)); + destination.setIntUnchecked(destinationPosition, (int) INT_HANDLE.get(source, sourceOffset + sourcePosition)); + destination.setByteUnchecked(destinationPosition + 4, source[sourceOffset + sourcePosition + 4]); } case 6 -> { - destination.setIntUnchecked(destinationPosition, source.getIntUnchecked(sourcePosition)); - destination.setShortUnchecked(destinationPosition + 4, source.getShortUnchecked(sourcePosition + 4)); + destination.setIntUnchecked(destinationPosition, (int) INT_HANDLE.get(source, sourceOffset + sourcePosition)); + destination.setShortUnchecked(destinationPosition + 4, (short) SHORT_HANDLE.get(source, sourceOffset + sourcePosition + 4)); } default -> throw new IllegalStateException("Invalid code point length " + length); } @@ -369,10 +691,17 @@ private static void copyUtf8SequenceUnsafe(Slice source, int sourcePosition, Sli */ public static Slice leftTrim(Slice utf8) { - int length = utf8.length(); + return leftTrim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } - int position = firstNonWhitespacePosition(utf8); - return utf8.slice(position, length - position); + /** + * Removes all white space characters from the left side of a byte array range. 
+ */ + public static Slice leftTrim(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + int position = firstNonWhitespacePosition(utf8, offset, length); + return Slices.wrappedBuffer(utf8, offset + position, length - position); } /** @@ -382,59 +711,69 @@ public static Slice leftTrim(Slice utf8) */ public static Slice leftTrim(Slice utf8, int[] whiteSpaceCodePoints) { - int length = utf8.length(); - - int position = firstNonMatchPosition(utf8, whiteSpaceCodePoints); - return utf8.slice(position, length - position); + return leftTrim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), whiteSpaceCodePoints); } - private static int firstNonWhitespacePosition(Slice utf8) + /** + * Removes all {@code whiteSpaceCodePoints} from the left side of a byte array range. + */ + public static Slice leftTrim(byte[] utf8, int offset, int length, int[] whiteSpaceCodePoints) { - int length = utf8.length(); + checkFromIndexSize(offset, length, utf8.length); + int position = firstNonMatchPosition(utf8, offset, length, whiteSpaceCodePoints); + return Slices.wrappedBuffer(utf8, offset + position, length - position); + } + private static int firstNonWhitespacePosition(byte[] utf8, int utf8Offset, int utf8Length) + { int position = 0; - while (position < length) { - int codePoint = tryGetCodePointAt(utf8, position); - if (codePoint < 0) { - break; + while (position < utf8Length) { + int value = utf8[utf8Offset + position] & 0xFF; + if (value < 0x80) { + if (!WHITESPACE_CODE_POINTS[value]) { + break; + } + position++; + continue; } - if (!WHITESPACE_CODE_POINTS[codePoint]) { + + int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position); + if (codePoint < 0 || !WHITESPACE_CODE_POINTS[codePoint]) { break; } + position += lengthOfCodePoint(codePoint); } return position; } - // This function is an exact duplicate of firstNonWhitespacePosition(Slice) except for one line. 
- private static int firstNonMatchPosition(Slice utf8, int[] codePointsToMatch) + // This function mirrors firstNonWhitespacePosition but uses a caller-provided match set. + private static int firstNonMatchPosition(byte[] utf8, int utf8Offset, int utf8Length, int[] codePointsToMatch) { - int length = utf8.length(); + long asciiMatchMaskLow = asciiMatchMaskLow(codePointsToMatch); + long asciiMatchMaskHigh = asciiMatchMaskHigh(codePointsToMatch); int position = 0; - while (position < length) { - int codePoint = tryGetCodePointAt(utf8, position); - if (codePoint < 0) { - break; + while (position < utf8Length) { + int value = utf8[utf8Offset + position] & 0xFF; + if (value < 0x80) { + if (!matches(value, codePointsToMatch, asciiMatchMaskLow, asciiMatchMaskHigh)) { + break; + } + position++; + continue; } - if (!matches(codePoint, codePointsToMatch)) { + + int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position); + if (codePoint < 0 || !matches(codePoint, codePointsToMatch, asciiMatchMaskLow, asciiMatchMaskHigh)) { break; } + position += lengthOfCodePoint(codePoint); } return position; } - private static boolean matches(int codePoint, int[] codePoints) - { - for (int codePointToTrim : codePoints) { - if (codePoint == codePointToTrim) { - return true; - } - } - return false; - } - /** * Removes all white space characters from the right side of the string. *

@@ -442,8 +781,17 @@ private static boolean matches(int codePoint, int[] codePoints) */ public static Slice rightTrim(Slice utf8) { - int position = lastNonWhitespacePosition(utf8, 0); - return utf8.slice(0, position); + return rightTrim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Removes all white space characters from the right side of a byte array range. + */ + public static Slice rightTrim(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + int position = lastNonWhitespacePosition(utf8, offset, length, 0); + return Slices.wrappedBuffer(utf8, offset, position); } /** @@ -453,32 +801,84 @@ public static Slice rightTrim(Slice utf8) */ public static Slice rightTrim(Slice utf8, int[] whiteSpaceCodePoints) { - int position = lastNonMatchPosition(utf8, 0, whiteSpaceCodePoints); - return utf8.slice(0, position); + return rightTrim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), whiteSpaceCodePoints); } - private static int lastNonWhitespacePosition(Slice utf8, int minPosition) + /** + * Removes all {@code whiteSpaceCodePoints} from the right side of a byte array range. 
+ */ + public static Slice rightTrim(byte[] utf8, int offset, int length, int[] whiteSpaceCodePoints) + { + checkFromIndexSize(offset, length, utf8.length); + int position = lastNonMatchPosition(utf8, offset, length, 0, whiteSpaceCodePoints); + return Slices.wrappedBuffer(utf8, offset, position); + } + + private static boolean matches(int codePoint, int[] codePoints, long asciiMatchMaskLow, long asciiMatchMaskHigh) + { + if (codePoint < Long.SIZE) { + return ((asciiMatchMaskLow >>> codePoint) & 1) == 1; + } + if (codePoint < (Long.SIZE * 2)) { + return ((asciiMatchMaskHigh >>> (codePoint - Long.SIZE)) & 1) == 1; + } + + for (int codePointToTrim : codePoints) { + if (codePoint == codePointToTrim) { + return true; + } + } + return false; + } + + private static long asciiMatchMaskLow(int[] codePoints) { - int position = utf8.length(); + long asciiMatchMaskLow = 0; + for (int codePoint : codePoints) { + if (codePoint < Long.SIZE) { + asciiMatchMaskLow |= (1L << codePoint); + } + } + return asciiMatchMaskLow; + } + + private static long asciiMatchMaskHigh(int[] codePoints) + { + long asciiMatchMaskHigh = 0; + for (int codePoint : codePoints) { + if (codePoint >= Long.SIZE && codePoint < (Long.SIZE * 2)) { + asciiMatchMaskHigh |= (1L << (codePoint - Long.SIZE)); + } + } + return asciiMatchMaskHigh; + } + + private static int lastNonWhitespacePosition(byte[] utf8, int utf8Offset, int utf8Length, int minPosition) + { + int position = utf8Length; while (minPosition < position) { + int value = utf8[utf8Offset + position - 1] & 0xFF; + if (value < 0x80) { + if (!WHITESPACE_CODE_POINTS[value]) { + break; + } + position--; + continue; + } + // decode the code point before position if possible int codePoint; int codePointLength; - byte unsignedByte = utf8.getByte(position - 1); - if (!isContinuationByte(unsignedByte)) { - codePoint = unsignedByte & 0xFF; - codePointLength = 1; - } - else if (minPosition <= position - 2 && !isContinuationByte(utf8.getByte(position - 2))) { - 
codePoint = tryGetCodePointAt(utf8, position - 2); + if (minPosition <= position - 2 && !isContinuationByte(utf8[utf8Offset + position - 2])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 2); codePointLength = 2; } - else if (minPosition <= position - 3 && !isContinuationByte(utf8.getByte(position - 3))) { - codePoint = tryGetCodePointAt(utf8, position - 3); + else if (minPosition <= position - 3 && !isContinuationByte(utf8[utf8Offset + position - 3])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 3); codePointLength = 3; } - else if (minPosition <= position - 4 && !isContinuationByte(utf8.getByte(position - 4))) { - codePoint = tryGetCodePointAt(utf8, position - 4); + else if (minPosition <= position - 4 && !isContinuationByte(utf8[utf8Offset + position - 4])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 4); codePointLength = 4; } else { @@ -495,29 +895,36 @@ else if (minPosition <= position - 4 && !isContinuationByte(utf8.getByte(positio return position; } - // This function is an exact duplicate of lastNonWhitespacePosition(Slice, int) except for one line. - private static int lastNonMatchPosition(Slice utf8, int minPosition, int[] codePointsToMatch) + // This function mirrors lastNonWhitespacePosition but uses a caller-provided match set. 
+ private static int lastNonMatchPosition(byte[] utf8, int utf8Offset, int utf8Length, int minPosition, int[] codePointsToMatch) { - int position = utf8.length(); + long asciiMatchMaskLow = asciiMatchMaskLow(codePointsToMatch); + long asciiMatchMaskHigh = asciiMatchMaskHigh(codePointsToMatch); + + int position = utf8Length; while (position > minPosition) { + int value = utf8[utf8Offset + position - 1] & 0xFF; + if (value < 0x80) { + if (!matches(value, codePointsToMatch, asciiMatchMaskLow, asciiMatchMaskHigh)) { + break; + } + position--; + continue; + } + // decode the code point before position if possible int codePoint; int codePointLength; - byte unsignedByte = utf8.getByte(position - 1); - if (!isContinuationByte(unsignedByte)) { - codePoint = unsignedByte & 0xFF; - codePointLength = 1; - } - else if (minPosition <= position - 2 && !isContinuationByte(utf8.getByte(position - 2))) { - codePoint = tryGetCodePointAt(utf8, position - 2); + if (minPosition <= position - 2 && !isContinuationByte(utf8[utf8Offset + position - 2])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 2); codePointLength = 2; } - else if (minPosition <= position - 3 && !isContinuationByte(utf8.getByte(position - 3))) { - codePoint = tryGetCodePointAt(utf8, position - 3); + else if (minPosition <= position - 3 && !isContinuationByte(utf8[utf8Offset + position - 3])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 3); codePointLength = 3; } - else if (minPosition <= position - 4 && !isContinuationByte(utf8.getByte(position - 4))) { - codePoint = tryGetCodePointAt(utf8, position - 4); + else if (minPosition <= position - 4 && !isContinuationByte(utf8[utf8Offset + position - 4])) { + codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 4); codePointLength = 4; } else { @@ -526,7 +933,7 @@ else if (minPosition <= position - 4 && !isContinuationByte(utf8.getByte(positio if (codePoint < 0 || codePointLength != 
lengthOfCodePoint(codePoint)) { break; } - if (!matches(codePoint, codePointsToMatch)) { + if (!matches(codePoint, codePointsToMatch, asciiMatchMaskLow, asciiMatchMaskHigh)) { break; } position -= codePointLength; @@ -541,9 +948,18 @@ else if (minPosition <= position - 4 && !isContinuationByte(utf8.getByte(positio */ public static Slice trim(Slice utf8) { - int start = firstNonWhitespacePosition(utf8); - int end = lastNonWhitespacePosition(utf8, start); - return utf8.slice(start, end - start); + return trim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Removes all white space characters from the left and right side of a byte array range. + */ + public static Slice trim(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + int start = firstNonWhitespacePosition(utf8, offset, length); + int end = lastNonWhitespacePosition(utf8, offset, length, start); + return Slices.wrappedBuffer(utf8, offset + start, end - start); } /** @@ -553,9 +969,18 @@ public static Slice trim(Slice utf8) */ public static Slice trim(Slice utf8, int[] whiteSpaceCodePoints) { - int start = firstNonMatchPosition(utf8, whiteSpaceCodePoints); - int end = lastNonMatchPosition(utf8, start, whiteSpaceCodePoints); - return utf8.slice(start, end - start); + return trim(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), whiteSpaceCodePoints); + } + + /** + * Removes all {@code whiteSpaceCodePoints} from the left and right side of a byte array range. 
+ */ + public static Slice trim(byte[] utf8, int offset, int length, int[] whiteSpaceCodePoints) + { + checkFromIndexSize(offset, length, utf8.length); + int start = firstNonMatchPosition(utf8, offset, length, whiteSpaceCodePoints); + int end = lastNonMatchPosition(utf8, offset, length, start, whiteSpaceCodePoints); + return Slices.wrappedBuffer(utf8, offset + start, end - start); } public static Slice fixInvalidUtf8(Slice slice) @@ -563,11 +988,30 @@ public static Slice fixInvalidUtf8(Slice slice) return fixInvalidUtf8(slice, OptionalInt.of(REPLACEMENT_CODE_POINT)); } + public static Slice fixInvalidUtf8(byte[] utf8, int offset, int length) + { + return fixInvalidUtf8(utf8, offset, length, OptionalInt.of(REPLACEMENT_CODE_POINT)); + } + public static Slice fixInvalidUtf8(Slice slice, OptionalInt replacementCodePoint) { if (isAscii(slice)) { return slice; } + return fixInvalidUtf8(slice.byteArray(), slice.byteArrayOffset(), slice.length(), replacementCodePoint); + } + + public static Slice fixInvalidUtf8(byte[] utf8, int offset, int length, OptionalInt replacementCodePoint) + { + checkFromIndexSize(offset, length, utf8.length); + return fixInvalidUtf8Raw(utf8, offset, length, replacementCodePoint); + } + + private static Slice fixInvalidUtf8Raw(byte[] utf8, int utf8Offset, int utf8Length, OptionalInt replacementCodePoint) + { + if (isAsciiRaw(utf8, utf8Offset, utf8Length)) { + return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length); + } int replacementCodePointValue = -1; int replacementCodePointLength = 0; @@ -576,31 +1020,42 @@ public static Slice fixInvalidUtf8(Slice slice, OptionalInt replacementCodePoint replacementCodePointLength = lengthOfCodePoint(replacementCodePointValue); } - int length = slice.length(); - Slice utf8 = Slices.allocate(length); - int dataPosition = 0; int utf8Position = 0; - while (dataPosition < length) { - int codePoint = tryGetCodePointAt(slice, dataPosition); - int codePointLength; + Slice output = null; + while (dataPosition < 
utf8Length) { + int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, dataPosition); if (codePoint >= 0) { - codePointLength = lengthOfCodePoint(codePoint); + int codePointLength = lengthOfCodePoint(codePoint); + if (output != null) { + int nextUtf8Position = utf8Position + codePointLength; + output = Slices.ensureSize(output, nextUtf8Position); + copyUtf8SequenceUnsafe(utf8, utf8Offset, dataPosition, output, utf8Position, codePointLength); + utf8Position = nextUtf8Position; + } dataPosition += codePointLength; } else { + if (output == null) { + output = Slices.allocate(utf8Length); + output.setBytes(0, utf8, utf8Offset, dataPosition); + utf8Position = dataPosition; + } + // negative number carries the number of invalid bytes dataPosition += (-codePoint); if (replacementCodePointValue < 0) { continue; } - codePoint = replacementCodePointValue; - codePointLength = replacementCodePointLength; + output = Slices.ensureSize(output, utf8Position + replacementCodePointLength); + utf8Position += setCodePointAt(replacementCodePointValue, output, utf8Position); } - utf8 = Slices.ensureSize(utf8, utf8Position + codePointLength); - utf8Position += setCodePointAt(codePoint, utf8, utf8Position); } - return utf8.slice(0, utf8Position); + + if (output == null) { + return Slices.wrappedBuffer(utf8, utf8Offset, utf8Length); + } + return output.slice(0, utf8Position); } /** @@ -613,10 +1068,27 @@ public static Slice fixInvalidUtf8(Slice slice, OptionalInt replacementCodePoint * @return the code point or negative the number of bytes in the invalid UTF-8 sequence. */ public static int tryGetCodePointAt(Slice utf8, int position) + { + return tryGetCodePointAt(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Tries to get the UTF-8 encoded code point at {@code position} in the byte array range. + * + * @return the code point or negative the number of bytes in the invalid UTF-8 sequence. 
+ */ + public static int tryGetCodePointAt(byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkIndex(position, length); + return tryGetCodePointAtRaw(utf8, offset, length, position); + } + + private static int tryGetCodePointAtRaw(byte[] utf8, int utf8Offset, int utf8Length, int position) { // // Process first byte - byte firstByte = utf8.getByte(position); + byte firstByte = utf8[utf8Offset + position]; int length = lengthOfCodePointFromStartByteSafe(firstByte); if (length < 0) { @@ -631,11 +1103,11 @@ public static int tryGetCodePointAt(Slice utf8, int position) // // Process second byte - if (position + 1 >= utf8.length()) { + if (position + 1 >= utf8Length) { return -1; } - byte secondByte = utf8.getByteUnchecked(position + 1); + byte secondByte = utf8[utf8Offset + position + 1]; if (!isContinuationByte(secondByte)) { return -1; } @@ -650,11 +1122,11 @@ public static int tryGetCodePointAt(Slice utf8, int position) // // Process third byte - if (position + 2 >= utf8.length()) { + if (position + 2 >= utf8Length) { return -2; } - byte thirdByte = utf8.getByteUnchecked(position + 2); + byte thirdByte = utf8[utf8Offset + position + 2]; if (!isContinuationByte(thirdByte)) { return -2; } @@ -675,11 +1147,11 @@ public static int tryGetCodePointAt(Slice utf8, int position) // // Process forth byte - if (position + 3 >= utf8.length()) { + if (position + 3 >= utf8Length) { return -3; } - byte forthByte = utf8.getByteUnchecked(position + 3); + byte forthByte = utf8[utf8Offset + position + 3]; if (!isContinuationByte(forthByte)) { return -3; } @@ -699,11 +1171,11 @@ public static int tryGetCodePointAt(Slice utf8, int position) // // Process fifth byte - if (position + 4 >= utf8.length()) { + if (position + 4 >= utf8Length) { return -4; } - byte fifthByte = utf8.getByteUnchecked(position + 4); + byte fifthByte = utf8[utf8Offset + position + 4]; if (!isContinuationByte(fifthByte)) { return -4; } @@ -715,11 +1187,11 
@@ public static int tryGetCodePointAt(Slice utf8, int position) // // Process sixth byte - if (position + 5 >= utf8.length()) { + if (position + 5 >= utf8Length) { return -5; } - byte sixthByte = utf8.getByteUnchecked(position + 5); + byte sixthByte = utf8[utf8Offset + position + 5]; if (!isContinuationByte(sixthByte)) { return -5; } @@ -733,7 +1205,7 @@ public static int tryGetCodePointAt(Slice utf8, int position) return -1; } - static int lengthOfCodePointFromStartByteSafe(byte startByte) + private static int lengthOfCodePointFromStartByteSafe(byte startByte) { int unsignedStartByte = startByte & 0xFF; if (unsignedStartByte < 0b1000_0000) { @@ -778,7 +1250,17 @@ static int lengthOfCodePointFromStartByteSafe(byte startByte) */ public static int offsetOfCodePoint(Slice utf8, int codePointCount) { - return offsetOfCodePoint(utf8, 0, codePointCount); + return offsetOfCodePoint(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), 0, codePointCount); + } + + /** + * Finds the index of the first byte of the code point at a position within + * a UTF-8 byte array range, or {@code -1} if the position is not within the range. + */ + public static int offsetOfCodePoint(byte[] utf8, int offset, int length, int codePointCount) + { + checkFromIndexSize(offset, length, utf8.length); + return offsetOfCodePointRaw(utf8, offset, length, 0, codePointCount); } /** @@ -792,12 +1274,28 @@ public static int offsetOfCodePoint(Slice utf8, int codePointCount) * return incorrect results or throw an exception for invalid UTF-8. */ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount) + { + return offsetOfCodePoint(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position, codePointCount); + } + + /** + * Starting from {@code position} bytes in a UTF-8 byte array range, finds the + * index of the first byte of the code point {@code codePointCount} in the range. + * Returned position is relative to the provided range. 
+ */ + public static int offsetOfCodePoint(byte[] utf8, int offset, int length, int position, int codePointCount) + { + checkFromIndexSize(offset, length, utf8.length); + return offsetOfCodePointRaw(utf8, offset, length, position, codePointCount); + } + + private static int offsetOfCodePointRaw(byte[] utf8, int utf8Offset, int utf8Length, int position, int codePointCount) { // allow position to be at the end of the slice - checkIndex(position, utf8.length() + 1); + checkIndex(position, utf8Length + 1); // Quick exit if we are sure that the position is after the end - if (utf8.length() - position <= codePointCount) { + if (utf8Length - position <= codePointCount) { return -1; } if (codePointCount == 0) { @@ -806,29 +1304,29 @@ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount int correctIndex = codePointCount + position; // Length rounded to 8 bytes - int length8 = (utf8.length() & 0x7FFF_FFF8) - 8; + int length8 = (utf8Length & 0x7FFF_FFF8) - 8; // process 8 bytes at a time // at most this can find 8 code points (if they are all US_ASCII), so this // is only called if there are at least 8 more code points needed while (position < length8 && correctIndex >= position + 8) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getLongUnchecked(position)); + correctIndex += countContinuationBytes((long) LONG_HANDLE.get(utf8, utf8Offset + position)); position += 8; } // Length rounded to 4 bytes - int length4 = (utf8.length() & 0x7FFF_FFFC) - 4; - // While we have enough bytes left and we need at least 4 characters process 4 bytes at once + int length4 = (utf8Length & 0x7FFF_FFFC) - 4; + // While we have enough bytes left, and we need at least 4 characters process 4 bytes at once while (position < length4 && correctIndex >= position + 4) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getIntUnchecked(position)); + correctIndex += 
countContinuationBytes((int) INT_HANDLE.get(utf8, utf8Offset + position)); position += 4; } // Do the rest one by one, always check the last byte to find the end of the code point - while (position < utf8.length()) { + while (position < utf8Length) { // Count bytes which are NOT the start of a code point - correctIndex += countContinuationBytes(utf8.getByteUnchecked(position)); + correctIndex += countContinuationBytes(utf8[utf8Offset + position]); if (position == correctIndex) { break; } @@ -836,7 +1334,7 @@ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount position++; } - if (position == correctIndex && correctIndex < utf8.length()) { + if (position == correctIndex && correctIndex < utf8Length) { return correctIndex; } return -1; @@ -850,7 +1348,23 @@ public static int offsetOfCodePoint(Slice utf8, int position, int codePointCount */ public static int lengthOfCodePoint(Slice utf8, int position) { - return lengthOfCodePointFromStartByte(utf8.getByte(position)); + return lengthOfCodePoint(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Gets the UTF-8 sequence length of the code point at {@code position} + * within a UTF-8 byte array range. 
+ */ + public static int lengthOfCodePoint(byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkIndex(position, length); + return lengthOfCodePointRaw(utf8, offset, length, position); + } + + private static int lengthOfCodePointRaw(byte[] utf8, int utf8Offset, int utf8Length, int position) + { + return lengthOfCodePointFromStartByte(utf8[utf8Offset + position]); } /** @@ -861,28 +1375,44 @@ public static int lengthOfCodePoint(Slice utf8, int position) */ public static int lengthOfCodePointSafe(Slice utf8, int position) { - int length = lengthOfCodePointFromStartByteSafe(utf8.getByte(position)); + return lengthOfCodePointSafe(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Gets the UTF-8 sequence length of the code point at {@code position} + * within a UTF-8 byte array range. Invalid encodings are handled safely. + */ + public static int lengthOfCodePointSafe(byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkIndex(position, length); + return lengthOfCodePointSafeRaw(utf8, offset, length, position); + } + + private static int lengthOfCodePointSafeRaw(byte[] utf8, int utf8Offset, int utf8Length, int position) + { + int length = lengthOfCodePointFromStartByteSafe(utf8[utf8Offset + position]); if (length < 0) { return -length; } - if (length == 1 || position + 1 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 1))) { + if (length == 1 || position + 1 >= utf8Length || !isContinuationByte(utf8[utf8Offset + position + 1])) { return 1; } - if (length == 2 || position + 2 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 2))) { + if (length == 2 || position + 2 >= utf8Length || !isContinuationByte(utf8[utf8Offset + position + 2])) { return 2; } - if (length == 3 || position + 3 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 3))) { + if (length == 
3 || position + 3 >= utf8Length || !isContinuationByte(utf8[utf8Offset + position + 3])) { return 3; } - if (length == 4 || position + 4 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 4))) { + if (length == 4 || position + 4 >= utf8Length || !isContinuationByte(utf8[utf8Offset + position + 4])) { return 4; } - if (length == 5 || position + 5 >= utf8.length() || !isContinuationByte(utf8.getByteUnchecked(position + 5))) { + if (length == 5 || position + 5 >= utf8Length || !isContinuationByte(utf8[utf8Offset + position + 5])) { return 5; } @@ -964,7 +1494,22 @@ public static int lengthOfCodePointFromStartByte(byte startByte) */ public static int getCodePointAt(Slice utf8, int position) { - int unsignedStartByte = utf8.getByte(position) & 0xFF; + return getCodePointAt(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Gets the UTF-8 encoded code point at {@code position} within a UTF-8 byte array range. + */ + public static int getCodePointAt(byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkIndex(position, length); + return getCodePointAtRaw(utf8, offset, length, position); + } + + private static int getCodePointAtRaw(byte[] utf8, int utf8Offset, int utf8Length, int position) + { + int unsignedStartByte = utf8[utf8Offset + position] & 0xFF; if (unsignedStartByte < 0x80) { // normal ASCII // 0xxx_xxxx @@ -977,30 +1522,30 @@ public static int getCodePointAt(Slice utf8, int position) } if (unsignedStartByte < 0xe0) { // 110x_xxxx 10xx_xxxx - if (position + 1 >= utf8.length()) { + if (position + 1 >= utf8Length) { throw new InvalidUtf8Exception("UTF-8 sequence truncated"); } return ((unsignedStartByte & 0b0001_1111) << 6) | - (utf8.getByte(position + 1) & 0b0011_1111); + (utf8[utf8Offset + position + 1] & 0b0011_1111); } if (unsignedStartByte < 0xf0) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - if (position + 2 >= utf8.length()) { + if (position + 2 >= 
utf8Length) { throw new InvalidUtf8Exception("UTF-8 sequence truncated"); } return ((unsignedStartByte & 0b0000_1111) << 12) | - ((utf8.getByteUnchecked(position + 1) & 0b0011_1111) << 6) | - (utf8.getByteUnchecked(position + 2) & 0b0011_1111); + ((utf8[utf8Offset + position + 1] & 0b0011_1111) << 6) | + (utf8[utf8Offset + position + 2] & 0b0011_1111); } if (unsignedStartByte < 0xf8) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - if (position + 3 >= utf8.length()) { + if (position + 3 >= utf8Length) { throw new InvalidUtf8Exception("UTF-8 sequence truncated"); } return ((unsignedStartByte & 0b0000_0111) << 18) | - ((utf8.getByteUnchecked(position + 1) & 0b0011_1111) << 12) | - ((utf8.getByteUnchecked(position + 2) & 0b0011_1111) << 6) | - (utf8.getByteUnchecked(position + 3) & 0b0011_1111); + ((utf8[utf8Offset + position + 1] & 0b0011_1111) << 12) | + ((utf8[utf8Offset + position + 2] & 0b0011_1111) << 6) | + (utf8[utf8Offset + position + 3] & 0b0011_1111); } // Per RFC3629, UTF-8 is limited to 4 bytes, so more bytes are illegal throw new InvalidUtf8Exception("Illegal start 0x" + toHexString(unsignedStartByte).toUpperCase() + " of code point"); @@ -1014,24 +1559,252 @@ public static int getCodePointAt(Slice utf8, int position) */ public static int getCodePointBefore(Slice utf8, int position) { - byte unsignedByte = utf8.getByte(position - 1); + return getCodePointBefore(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Gets the UTF-8 encoded code point before {@code position} within a UTF-8 byte array range. 
+ */ + public static int getCodePointBefore(byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkFromIndexSize(position - 1, 1, length); + return getCodePointBeforeRaw(utf8, offset, length, position); + } + + private static int getCodePointBeforeRaw(byte[] utf8, int utf8Offset, int utf8Length, int position) + { + byte unsignedByte = utf8[utf8Offset + position - 1]; if (!isContinuationByte(unsignedByte)) { return unsignedByte & 0xFF; } - if (!isContinuationByte(utf8.getByte(position - 2))) { - return getCodePointAt(utf8, position - 2); + if (position < 2) { + throw new InvalidUtf8Exception("UTF-8 is not well formed"); + } + if (!isContinuationByte(utf8[utf8Offset + position - 2])) { + return getCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 2); } - if (!isContinuationByte(utf8.getByte(position - 3))) { - return getCodePointAt(utf8, position - 3); + if (position < 3) { + throw new InvalidUtf8Exception("UTF-8 is not well formed"); + } + if (!isContinuationByte(utf8[utf8Offset + position - 3])) { + return getCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 3); + } + if (position < 4) { + throw new InvalidUtf8Exception("UTF-8 is not well formed"); } - if (!isContinuationByte(utf8.getByte(position - 4))) { - return getCodePointAt(utf8, position - 4); + if (!isContinuationByte(utf8[utf8Offset + position - 4])) { + return getCodePointAtRaw(utf8, utf8Offset, utf8Length, position - 4); } // Per RFC3629, UTF-8 is limited to 4 bytes, so more bytes are illegal throw new InvalidUtf8Exception("UTF-8 is not well formed"); } + /** + * Decodes a UTF-8 slice into Unicode code points. + * + * @throws InvalidUtf8Exception if the input contains invalid UTF-8 + */ + public static int[] toCodePoints(Slice utf8) + { + return toCodePoints(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Decodes a UTF-8 byte array range into Unicode code points. 
+ * + * @throws InvalidUtf8Exception if the input contains invalid UTF-8 + */ + public static int[] toCodePoints(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return toCodePointsRaw(utf8, offset, length); + } + + private static int[] toCodePointsRaw(byte[] utf8, int utf8Offset, int utf8Length) + { + if (utf8Length == 0) { + return new int[0]; + } + + if (isAsciiRaw(utf8, utf8Offset, utf8Length)) { + int[] codePoints = new int[utf8Length]; + for (int index = 0; index < utf8Length; index++) { + codePoints[index] = utf8[utf8Offset + index] & 0x7F; + } + return codePoints; + } + + int[] codePoints = new int[Math.max(8, utf8Length >>> 1)]; + int codePointCount = 0; + int position = 0; + while (position < utf8Length) { + int codePoint = tryGetCodePointAtRaw(utf8, utf8Offset, utf8Length, position); + if (codePoint < 0) { + throw new InvalidUtf8Exception("Invalid UTF-8 sequence at position " + position); + } + + if (codePointCount == codePoints.length) { + codePoints = Arrays.copyOf(codePoints, codePoints.length * 2); + } + codePoints[codePointCount] = codePoint; + codePointCount++; + + if (codePoint < 0x80) { + position++; + } + else if (codePoint < 0x800) { + position += 2; + } + else if (codePoint < 0x1_0000) { + position += 3; + } + else { + position += 4; + } + } + + if (codePointCount == codePoints.length) { + return codePoints; + } + return Arrays.copyOf(codePoints, codePointCount); + } + + /** + * Decodes UTF-8 and returns UTF-8 byte lengths ({@code 1..4}) for each code point. + *

+ * Note: This method does not explicitly check for valid UTF-8, and may + * return incorrect results or throw an exception for invalid UTF-8. + */ + public static byte[] codePointByteLengths(Slice utf8) + { + return codePointByteLengths(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length()); + } + + /** + * Decodes UTF-8 byte array range and returns UTF-8 byte lengths ({@code 1..4}) for each code point. + *

+ * Note: This method does not explicitly check for valid UTF-8, and may + * return incorrect results or throw an exception for invalid UTF-8. + */ + public static byte[] codePointByteLengths(byte[] utf8, int offset, int length) + { + checkFromIndexSize(offset, length, utf8.length); + return codePointByteLengthsRaw(utf8, offset, length); + } + + private static byte[] codePointByteLengthsRaw(byte[] utf8, int utf8Offset, int utf8Length) + { + if (utf8Length == 0) { + return new byte[0]; + } + + if (isAsciiRaw(utf8, utf8Offset, utf8Length)) { + byte[] lengths = new byte[utf8Length]; + Arrays.fill(lengths, (byte) 1); + return lengths; + } + + byte[] lengths = new byte[Math.max(8, utf8Length >>> 1)]; + int codePointCount = 0; + int position = 0; + while (position < utf8Length) { + int codePointLength = lengthOfCodePointFromStartByteSafe(utf8[utf8Offset + position]); + if (codePointLength < 0 || position + codePointLength > utf8Length) { + throw new InvalidUtf8Exception("Invalid UTF-8 sequence at position " + position); + } + + if (codePointCount == lengths.length) { + lengths = Arrays.copyOf(lengths, lengths.length * 2); + } + lengths[codePointCount] = (byte) codePointLength; + codePointCount++; + position += codePointLength; + } + + if (codePointCount == lengths.length) { + return lengths; + } + return Arrays.copyOf(lengths, codePointCount); + } + + /** + * Encodes Unicode code points into UTF-8. + * + * @throws InvalidCodePointException if any code point is invalid + */ + public static Slice fromCodePoints(int[] codePoints) + { + return fromCodePoints(codePoints, 0, codePoints.length); + } + + /** + * Encodes a range of Unicode code points into UTF-8. 
+ * + * @throws InvalidCodePointException if any code point is invalid + */ + public static Slice fromCodePoints(int[] codePoints, int offset, int length) + { + checkFromIndexSize(offset, length, codePoints.length); + if (length == 0) { + return Slices.EMPTY_SLICE; + } + return fromCodePointsRaw(codePoints, offset, length); + } + + private static Slice fromCodePointsRaw(int[] codePoints, int codePointsOffset, int codePointsLength) + { + int utf8Length = 0; + boolean ascii = true; + for (int index = 0; index < codePointsLength; index++) { + int codePoint = codePoints[codePointsOffset + index]; + int codePointLength = lengthOfCodePoint(codePoint); + if (codePointLength == 3 && MIN_SURROGATE <= codePoint && codePoint <= MAX_SURROGATE) { + throw new InvalidCodePointException(codePoint); + } + utf8Length += codePointLength; + ascii = ascii && (codePointLength == 1); + } + + byte[] utf8 = new byte[utf8Length]; + if (ascii) { + for (int index = 0; index < codePointsLength; index++) { + utf8[index] = (byte) codePoints[codePointsOffset + index]; + } + return Slices.wrappedBuffer(utf8); + } + + int position = 0; + for (int index = 0; index < codePointsLength; index++) { + int codePoint = codePoints[codePointsOffset + index]; + if (codePoint < 0x80) { + utf8[position] = (byte) codePoint; + position++; + } + else if (codePoint < 0x800) { + utf8[position] = (byte) (0b1100_0000 | (codePoint >>> 6)); + utf8[position + 1] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + position += 2; + } + else if (codePoint < 0x1_0000) { + utf8[position] = (byte) (0b1110_0000 | ((codePoint >>> 12) & 0b0000_1111)); + utf8[position + 1] = (byte) (0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[position + 2] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + position += 3; + } + else { + utf8[position] = (byte) (0b1111_0000 | ((codePoint >>> 18) & 0b0000_0111)); + utf8[position + 1] = (byte) (0b1000_0000 | ((codePoint >>> 12) & 0b0011_1111)); + utf8[position + 2] = (byte) 
(0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[position + 3] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + position += 4; + } + } + + return Slices.wrappedBuffer(utf8); + } + private static boolean isContinuationByte(byte b) { return (b & 0b1100_0000) == 0b1000_0000; @@ -1056,6 +1829,55 @@ public static Slice codePointToUtf8(int codePoint) * @throws InvalidCodePointException if code point is not within a valid range */ public static int setCodePointAt(int codePoint, Slice utf8, int position) + { + return setCodePointAt(codePoint, utf8.byteArray(), utf8.byteArrayOffset(), utf8.length(), position); + } + + /** + * Sets the UTF-8 sequence for code point at {@code position} within a UTF-8 byte array range. + * + * @throws InvalidCodePointException if code point is not within a valid range + */ + public static int setCodePointAt(int codePoint, byte[] utf8, int offset, int length, int position) + { + checkFromIndexSize(offset, length, utf8.length); + checkIndex(position, length); + int codePointLength = lengthOfCodePoint(codePoint); + if (codePointLength == 3 && MIN_SURROGATE <= codePoint && codePoint <= MAX_SURROGATE) { + throw new InvalidCodePointException(codePoint); + } + checkFromIndexSize(position, codePointLength, length); + int start = offset + position; + + switch (codePointLength) { + case 1 -> { + // 0xxx_xxxx + utf8[start] = (byte) codePoint; + } + case 2 -> { + // 110x_xxxx 10xx_xxxx + utf8[start] = (byte) (0b1100_0000 | (codePoint >>> 6)); + utf8[start + 1] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + } + case 3 -> { + // 1110_xxxx 10xx_xxxx 10xx_xxxx + utf8[start] = (byte) (0b1110_0000 | ((codePoint >>> 12) & 0b0000_1111)); + utf8[start + 1] = (byte) (0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[start + 2] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + } + case 4 -> { + // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx + utf8[start] = (byte) (0b1111_0000 | ((codePoint >>> 18) & 0b0000_0111)); + utf8[start + 1] = 
(byte) (0b1000_0000 | ((codePoint >>> 12) & 0b0011_1111)); + utf8[start + 2] = (byte) (0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[start + 3] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); + } + default -> throw new InvalidCodePointException(codePoint); + } + return codePointLength; + } + + private static int setCodePointAtUnchecked(int codePoint, byte[] utf8, int position) { if (codePoint < 0) { throw new InvalidCodePointException(codePoint); @@ -1063,14 +1885,13 @@ public static int setCodePointAt(int codePoint, Slice utf8, int position) if (codePoint < 0x80) { // normal ASCII // 0xxx_xxxx - utf8.setByte(position, codePoint); + utf8[position] = (byte) codePoint; return 1; } if (codePoint < 0x800) { // 110x_xxxx 10xx_xxxx - checkFromIndexSize(position, 1, utf8.length()); - utf8.setByteUnchecked(position, 0b1100_0000 | (codePoint >>> 6)); - utf8.setByteUnchecked(position + 1, 0b1000_0000 | (codePoint & 0b0011_1111)); + utf8[position] = (byte) (0b1100_0000 | (codePoint >>> 6)); + utf8[position + 1] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); return 2; } if (MIN_SURROGATE <= codePoint && codePoint <= MAX_SURROGATE) { @@ -1078,19 +1899,17 @@ public static int setCodePointAt(int codePoint, Slice utf8, int position) } if (codePoint < 0x1_0000) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - checkFromIndexSize(position, 2, utf8.length()); - utf8.setByteUnchecked(position, 0b1110_0000 | ((codePoint >>> 12) & 0b0000_1111)); - utf8.setByteUnchecked(position + 1, 0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); - utf8.setByteUnchecked(position + 2, 0b1000_0000 | (codePoint & 0b0011_1111)); + utf8[position] = (byte) (0b1110_0000 | ((codePoint >>> 12) & 0b0000_1111)); + utf8[position + 1] = (byte) (0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[position + 2] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); return 3; } if (codePoint < 0x11_0000) { - checkFromIndexSize(position, 3, utf8.length()); // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - 
utf8.setByteUnchecked(position, 0b1111_0000 | ((codePoint >>> 18) & 0b0000_0111)); - utf8.setByteUnchecked(position + 1, 0b1000_0000 | ((codePoint >>> 12) & 0b0011_1111)); - utf8.setByteUnchecked(position + 2, 0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); - utf8.setByteUnchecked(position + 3, 0b1000_0000 | (codePoint & 0b0011_1111)); + utf8[position] = (byte) (0b1111_0000 | ((codePoint >>> 18) & 0b0000_0111)); + utf8[position + 1] = (byte) (0b1000_0000 | ((codePoint >>> 12) & 0b0011_1111)); + utf8[position + 2] = (byte) (0b1000_0000 | ((codePoint >>> 6) & 0b0011_1111)); + utf8[position + 3] = (byte) (0b1000_0000 | (codePoint & 0b0011_1111)); return 4; } // Per RFC3629, UTF-8 is limited to 4 bytes, so more bytes are illegal diff --git a/src/test/java/io/airlift/slice/SliceUtf8Benchmark.java b/src/test/java/io/airlift/slice/SliceUtf8Benchmark.java index 4e6ac28c..9f4315e7 100644 --- a/src/test/java/io/airlift/slice/SliceUtf8Benchmark.java +++ b/src/test/java/io/airlift/slice/SliceUtf8Benchmark.java @@ -29,21 +29,31 @@ import org.openjdk.jmh.runner.options.VerboseMode; import java.nio.charset.StandardCharsets; +import java.util.OptionalInt; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.IntStream; +import static io.airlift.slice.SliceUtf8.codePointByteLengths; +import static io.airlift.slice.SliceUtf8.codePointToUtf8; +import static io.airlift.slice.SliceUtf8.compareUtf16BE; import static io.airlift.slice.SliceUtf8.countCodePoints; +import static io.airlift.slice.SliceUtf8.fixInvalidUtf8; +import static io.airlift.slice.SliceUtf8.fromCodePoints; +import static io.airlift.slice.SliceUtf8.getCodePointAt; import static io.airlift.slice.SliceUtf8.leftTrim; import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; import static io.airlift.slice.SliceUtf8.lengthOfCodePointFromStartByte; +import static io.airlift.slice.SliceUtf8.lengthOfCodePointSafe; import static io.airlift.slice.SliceUtf8.offsetOfCodePoint; import static 
io.airlift.slice.SliceUtf8.reverse; import static io.airlift.slice.SliceUtf8.rightTrim; +import static io.airlift.slice.SliceUtf8.setCodePointAt; import static io.airlift.slice.SliceUtf8.substring; +import static io.airlift.slice.SliceUtf8.toCodePoints; import static io.airlift.slice.SliceUtf8.toLowerCase; import static io.airlift.slice.SliceUtf8.toUpperCase; import static io.airlift.slice.SliceUtf8.trim; -import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.SliceUtf8.tryGetCodePointAt; import static java.lang.Character.MAX_CODE_POINT; import static java.lang.Character.SURROGATE; import static java.lang.Character.getType; @@ -64,11 +74,14 @@ public class SliceUtf8Benchmark @Benchmark public int benchmarkLengthOfCodePointFromStartByte(BenchmarkData data) { - Slice slice = data.getSlice(); + byte[] utf8 = data.getUtf8(); + int baseOffset = data.getOffset(); + int byteLength = data.getByteLength(); + int i = 0; int codePoints = 0; - while (i < slice.length()) { - i += lengthOfCodePointFromStartByte(slice.getByte(i)); + while (i < byteLength) { + i += lengthOfCodePointFromStartByte(utf8[baseOffset + i]); codePoints++; } if (codePoints != data.getLength()) { @@ -80,66 +93,501 @@ public int benchmarkLengthOfCodePointFromStartByte(BenchmarkData data) @Benchmark public int benchmarkCountCodePoints(BenchmarkData data) { - int codePoints = countCodePoints(data.getSlice()); + int codePoints = countCodePoints(data.getUtf8(), data.getOffset(), data.getByteLength()); if (codePoints != data.getLength()) { throw new AssertionError(); } return codePoints; } + @Benchmark + public int benchmarkCountCodePointsRange(RangeCountData data) + { + int codePoints = countCodePoints(data.getUtf8(), data.getOffset(), data.getRangeLength()); + if (codePoints != data.getExpectedCodePoints()) { + throw new AssertionError(); + } + return codePoints; + } + @Benchmark public int benchmarkOffsetByCodePoints(BenchmarkData data) { - Slice slice = data.getSlice(); - int offset 
= offsetOfCodePoint(slice, data.getLength() - 1); - if (offset + lengthOfCodePoint(slice, offset) != slice.length()) { + int index = offsetOfCodePoint(data.getUtf8(), data.getOffset(), data.getByteLength(), data.getLength() - 1); + if (index + lengthOfCodePoint(data.getUtf8(), data.getOffset(), data.getByteLength(), index) != data.getByteLength()) { + throw new AssertionError(); + } + return index; + } + + @Benchmark + public int benchmarkTryGetCodePointAt(BenchmarkData data) + { + byte[] utf8 = data.getUtf8(); + int baseOffset = data.getOffset(); + int byteLength = data.getByteLength(); + + int offset = 0; + int codePoints = 0; + while (offset < byteLength) { + int codePoint = tryGetCodePointAt(utf8, baseOffset, byteLength, offset); + if (codePoint < 0) { + throw new AssertionError(); + } + offset += lengthOfCodePoint(codePoint); + codePoints++; + } + if (codePoints != data.getLength()) { + throw new AssertionError(); + } + return codePoints; + } + + @Benchmark + public int benchmarkGetCodePointAt(BenchmarkData data) + { + Slice utf8 = data.getSlice(); + int offset = 0; + int checksum = 0; + while (offset < utf8.length()) { + checksum ^= getCodePointAt(utf8, offset); + offset += lengthOfCodePoint(utf8, offset); + } + return checksum; + } + + @Benchmark + public int benchmarkLengthOfCodePointSafe(BenchmarkData data) + { + Slice utf8 = data.getSlice(); + int offset = 0; + int consumed = 0; + while (offset < utf8.length()) { + int codePointLength = lengthOfCodePointSafe(utf8, offset); + offset += codePointLength; + consumed += codePointLength; + } + if (consumed != utf8.length()) { + throw new AssertionError(); + } + return consumed; + } + + @Benchmark + public int[] benchmarkToCodePointsApi(BenchmarkData data) + { + int[] codePoints = toCodePoints(data.getUtf8(), data.getOffset(), data.getByteLength()); + if (codePoints.length != data.getLength()) { + throw new AssertionError(); + } + return codePoints; + } + + @Benchmark + public int[] 
benchmarkTrinoCastToCodePointsTwoPass(BenchmarkData data) + { + Slice utf8 = data.getSlice(); + + int codePointCount = 0; + for (int position = 0; position < utf8.length(); ) { + int codePoint = tryGetCodePointAt(utf8, position); + if (codePoint < 0) { + throw new AssertionError(); + } + position += lengthOfCodePoint(codePoint); + codePointCount++; + } + + int[] codePoints = new int[codePointCount]; + int position = 0; + for (int index = 0; index < codePoints.length; index++) { + codePoints[index] = getCodePointAt(utf8, position); + position += lengthOfCodePoint(utf8, position); + } + return codePoints; + } + + @Benchmark + public int benchmarkTrinoToCodePoints(TrinoCodePointData data) + { + Slice utf8 = data.getSlice(); + int offset = 0; + int checksum = 0; + while (offset < utf8.length()) { + int codePoint = getCodePointAt(utf8, offset); + offset += lengthOfCodePoint(utf8, offset); + checksum ^= codePoint; + } + return checksum; + } + + @Benchmark + public int benchmarkTrinoToCodePointsLengthFromDecoded(TrinoCodePointData data) + { + Slice utf8 = data.getSlice(); + int offset = 0; + int checksum = 0; + while (offset < utf8.length()) { + int codePoint = getCodePointAt(utf8, offset); + offset += lengthOfCodePoint(codePoint); + checksum ^= codePoint; + } + return checksum; + } + + @Benchmark + public int benchmarkTrinoToCodePointsByteArray(TrinoCodePointData data) + { + byte[] utf8 = data.getUtf8(); + int baseOffset = data.getOffset(); + int byteLength = data.getByteLength(); + + int offset = 0; + int checksum = 0; + while (offset < byteLength) { + int codePoint = getCodePointAt(utf8, baseOffset, byteLength, offset); + offset += lengthOfCodePoint(codePoint); + checksum ^= codePoint; + } + return checksum; + } + + @Benchmark + public Slice benchmarkFromCodePointsApi(CodePointWriteData data) + { + Slice result = fromCodePoints(data.getCodePoints()); + if (result.length() != data.getExpectedBytes()) { + throw new AssertionError(); + } + return result; + } + + 
@Benchmark + public Slice benchmarkTrinoCodePointsToSliceUtf8Baseline(CodePointWriteData data) + { + int[] codePoints = data.getCodePoints(); + + int bufferLength = 0; + for (int codePoint : codePoints) { + bufferLength += lengthOfCodePoint(codePoint); + } + + Slice result = Slices.wrappedBuffer(new byte[bufferLength]); + int offset = 0; + for (int codePoint : codePoints) { + offset += setCodePointAt(codePoint, result, offset); + } + if (offset != bufferLength) { + throw new AssertionError(); + } + return result; + } + + @Benchmark + public int benchmarkTrinoPatternConstantPrefixBytes(LikePatternData data) + { + Slice pattern = data.getPattern(); + int escapeChar = data.getEscapeChar(); + + boolean escaped = false; + int position = 0; + while (position < pattern.length()) { + int currentChar = getCodePointAt(pattern, position); + if (!escaped && (currentChar == escapeChar)) { + escaped = true; + } + else if (escaped) { + escaped = false; + } + else if ((currentChar == '%') || (currentChar == '_')) { + return position; + } + position += lengthOfCodePoint(currentChar); + } + if (escaped) { throw new AssertionError(); } - return offset; + return position; + } + + @Benchmark + public int benchmarkTrinoPatternConstantPrefixBytesByteArray(LikePatternData data) + { + Slice pattern = data.getPattern(); + byte[] utf8 = pattern.byteArray(); + int baseOffset = pattern.byteArrayOffset(); + int byteLength = pattern.length(); + int escapeChar = data.getEscapeChar(); + + boolean escaped = false; + int position = 0; + while (position < byteLength) { + int currentChar = getCodePointAt(utf8, baseOffset, byteLength, position); + if (!escaped && (currentChar == escapeChar)) { + escaped = true; + } + else if (escaped) { + escaped = false; + } + else if ((currentChar == '%') || (currentChar == '_')) { + return position; + } + position += lengthOfCodePoint(currentChar); + } + if (escaped) { + throw new AssertionError(); + } + return position; + } + + @Benchmark + public int 
benchmarkTrinoPadStringCodePointLengths(TrinoPadData data) + { + Slice padString = data.getPadString(); + int padStringLength = countCodePoints(padString); + int[] padStringCounts = new int[padStringLength]; + for (int index = 0; index < padStringLength; index++) { + padStringCounts[index] = lengthOfCodePointSafe(padString, offsetOfCodePoint(padString, index)); + } + return checksum(padStringCounts); + } + + @Benchmark + public int benchmarkTrinoPadStringCodePointLengthsSinglePass(TrinoPadData data) + { + Slice padString = data.getPadString(); + int[] padStringCounts = new int[countCodePoints(padString)]; + int position = 0; + int index = 0; + while (position < padString.length()) { + int codePoint = getCodePointAt(padString, position); + int codePointLength = lengthOfCodePoint(codePoint); + padStringCounts[index] = codePointLength; + index++; + position += codePointLength; + } + if (index != padStringCounts.length) { + throw new AssertionError(); + } + return checksum(padStringCounts); + } + + @Benchmark + public int benchmarkTrinoPadStringCodePointLengthsByteArray(TrinoPadData data) + { + byte[] utf8 = data.getUtf8(); + int baseOffset = data.getOffset(); + int byteLength = data.getByteLength(); + int[] padStringCounts = new int[countCodePoints(utf8, baseOffset, byteLength)]; + int position = 0; + int index = 0; + while (position < byteLength) { + int codePoint = getCodePointAt(utf8, baseOffset, byteLength, position); + int codePointLength = lengthOfCodePoint(codePoint); + padStringCounts[index] = codePointLength; + index++; + position += codePointLength; + } + if (index != padStringCounts.length) { + throw new AssertionError(); + } + return checksum(padStringCounts); + } + + @Benchmark + public int benchmarkTrinoPadStringCodePointLengthsSliceUtf8Helper(TrinoPadData data) + { + return checksum(codePointByteLengths(data.getPadString())); + } + + @Benchmark + public int benchmarkTrinoPadStringCodePointLengthsSliceUtf8HelperByteArray(TrinoPadData data) + { + return 
checksum(codePointByteLengths(data.getUtf8(), data.getOffset(), data.getByteLength())); + } + + @Benchmark + public Slice benchmarkTrinoDomainTranslatorPrefixRange(TrinoPrefixRangeData data) + { + Slice constantPrefix = data.getConstantPrefix(); + + int lastIncrementable = -1; + for (int position = 0; position < constantPrefix.length(); position += lengthOfCodePoint(constantPrefix, position)) { + if (getCodePointAt(constantPrefix, position) < 127) { + lastIncrementable = position; + } + } + + if (lastIncrementable == -1) { + return Slices.EMPTY_SLICE; + } + + Slice upperBound = constantPrefix.slice(0, lastIncrementable + lengthOfCodePoint(constantPrefix, lastIncrementable)).copy(); + setCodePointAt(getCodePointAt(constantPrefix, lastIncrementable) + 1, upperBound, lastIncrementable); + return upperBound; + } + + @Benchmark + public Slice benchmarkTrinoDomainTranslatorPrefixRangeSingleDecode(TrinoPrefixRangeData data) + { + byte[] utf8 = data.getUtf8(); + int baseOffset = data.getOffset(); + int byteLength = data.getByteLength(); + Slice constantPrefix = data.getConstantPrefix(); + + int lastIncrementableOffset = -1; + int lastIncrementableCodePoint = -1; + int lastIncrementableLength = 0; + int position = 0; + while (position < byteLength) { + int codePoint = getCodePointAt(utf8, baseOffset, byteLength, position); + int codePointLength = lengthOfCodePoint(codePoint); + if (codePoint < 127) { + lastIncrementableOffset = position; + lastIncrementableCodePoint = codePoint; + lastIncrementableLength = codePointLength; + } + position += codePointLength; + } + + if (lastIncrementableOffset == -1) { + return Slices.EMPTY_SLICE; + } + + Slice upperBound = constantPrefix.slice(0, lastIncrementableOffset + lastIncrementableLength).copy(); + setCodePointAt(lastIncrementableCodePoint + 1, upperBound, lastIncrementableOffset); + return upperBound; + } + + @Benchmark + public int benchmarkCompareUtf16BE(CompareData data) + { + int result = compareUtf16BE( + data.getUtf8(), 
data.getOffset(), data.getByteLength(), + data.getRightUtf8(), data.getRightOffset(), data.getRightByteLength()); + if (result != 0) { + throw new AssertionError(); + } + return result; } @Benchmark public Slice benchmarkSubstring(BenchmarkData data) { - Slice slice = data.getSlice(); int length = data.getLength(); - return substring(slice, (length / 2) - 1, length / 2); + return substring(data.getUtf8(), data.getOffset(), data.getByteLength(), (length / 2) - 1, length / 2); } @Benchmark public Slice benchmarkReverse(BenchmarkData data) { - return reverse(data.getSlice()); + return reverse(data.getUtf8(), data.getOffset(), data.getByteLength()); } @Benchmark public Slice benchmarkToLowerCase(BenchmarkData data) { - return toLowerCase(data.getSlice()); + return toLowerCase(data.getUtf8(), data.getOffset(), data.getByteLength()); } @Benchmark public Slice benchmarkToUpperCase(BenchmarkData data) { - return toUpperCase(data.getSlice()); + return toUpperCase(data.getUtf8(), data.getOffset(), data.getByteLength()); } @Benchmark public Slice benchmarkLeftTrim(WhitespaceData data) { - return leftTrim(data.getLeftWhitespace()); + return leftTrim(data.getLeftWhitespace(), 0, data.getLeftWhitespace().length); + } + + @Benchmark + public Slice benchmarkLeftTrimCustom(WhitespaceData data) + { + return leftTrim(data.getLeftWhitespace(), 0, data.getLeftWhitespace().length, data.getTrimCodePoints()); } @Benchmark public Slice benchmarkRightTrim(WhitespaceData data) { - return rightTrim(data.getRightWhitespace()); + return rightTrim(data.getRightWhitespace(), 0, data.getRightWhitespace().length); + } + + @Benchmark + public Slice benchmarkRightTrimCustom(WhitespaceData data) + { + return rightTrim(data.getRightWhitespace(), 0, data.getRightWhitespace().length, data.getTrimCodePoints()); } @Benchmark public Slice benchmarkTrim(WhitespaceData data) { - return trim(data.getBothWhitespace()); + return trim(data.getBothWhitespace(), 0, data.getBothWhitespace().length); + } + + 
@Benchmark + public Slice benchmarkTrimCustom(WhitespaceData data) + { + return trim(data.getBothWhitespace(), 0, data.getBothWhitespace().length, data.getTrimCodePoints()); + } + + @Benchmark + public Slice benchmarkFixInvalidUtf8WithReplacement(FixInvalidUtf8Data data) + { + return fixInvalidUtf8(data.getUtf8(), data.getOffset(), data.getLength()); + } + + @Benchmark + public Slice benchmarkFixInvalidUtf8WithoutReplacement(FixInvalidUtf8Data data) + { + return fixInvalidUtf8(data.getUtf8(), data.getOffset(), data.getLength(), OptionalInt.empty()); + } + + @Benchmark + public int benchmarkSetCodePointAt(CodePointWriteData data) + { + Slice output = data.getOutput(); + int position = 0; + int[] codePoints = data.getCodePoints(); + for (int codePoint : codePoints) { + position += setCodePointAt(codePoint, output, position); + } + if (position != data.getExpectedBytes()) { + throw new AssertionError(); + } + return position; + } + + @Benchmark + public int benchmarkCodePointToUtf8(CodePointWriteData data) + { + int totalBytes = 0; + for (int codePoint : data.getCodePoints()) { + totalBytes += codePointToUtf8(codePoint).length(); + } + if (totalBytes != data.getExpectedBytes()) { + throw new AssertionError(); + } + return totalBytes; + } + + private static int checksum(int[] values) + { + int checksum = 1; + for (int value : values) { + checksum = (31 * checksum) ^ value; + } + return checksum; + } + + private static int checksum(byte[] values) + { + int checksum = 1; + for (byte value : values) { + checksum = (31 * checksum) ^ value; + } + return checksum; } @State(Thread) @@ -162,8 +610,10 @@ public static class BenchmarkData @Param({"true", "false"}) private boolean ascii; + private byte[] utf8; + private int offset; + private int byteLength; private Slice slice; - private int[] codePoints; @Setup public void setup() @@ -171,14 +621,33 @@ public void setup() int[] codePointSet = ascii ? 
ASCII_CODE_POINTS : ALL_CODE_POINTS; ThreadLocalRandom random = ThreadLocalRandom.current(); - codePoints = new int[length]; DynamicSliceOutput sliceOutput = new DynamicSliceOutput(length * 4); - for (int i = 0; i < codePoints.length; i++) { + for (int i = 0; i < length; i++) { int codePoint = codePointSet[random.nextInt(codePointSet.length)]; - codePoints[i] = codePoint; sliceOutput.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); } - slice = sliceOutput.slice(); + + byte[] data = sliceOutput.slice().getBytes(); + offset = 7; + utf8 = new byte[offset + data.length + 3]; + System.arraycopy(data, 0, utf8, offset, data.length); + byteLength = data.length; + slice = Slices.wrappedBuffer(utf8, offset, byteLength); + } + + public byte[] getUtf8() + { + return utf8; + } + + public int getOffset() + { + return offset; + } + + public int getByteLength() + { + return byteLength; } public Slice getSlice() @@ -192,6 +661,41 @@ public int getLength() } } + @State(Thread) + public static class CompareData + extends BenchmarkData + { + private byte[] rightUtf8; + private int rightOffset; + private int rightByteLength; + + @Override + @Setup + public void setup() + { + super.setup(); + rightOffset = 11; + rightByteLength = getByteLength(); + rightUtf8 = new byte[rightOffset + rightByteLength + 5]; + System.arraycopy(getUtf8(), getOffset(), rightUtf8, rightOffset, rightByteLength); + } + + public byte[] getRightUtf8() + { + return rightUtf8; + } + + public int getRightOffset() + { + return rightOffset; + } + + public int getRightByteLength() + { + return rightByteLength; + } + } + @State(Thread) public static class WhitespaceData { @@ -213,23 +717,28 @@ public static class WhitespaceData @Param({"true", "false"}) private boolean ascii; - private Slice leftWhitespace; - private Slice rightWhitespace; - private Slice bothWhitespace; + private byte[] leftWhitespace; + private byte[] rightWhitespace; + private byte[] bothWhitespace; + private 
int[] trimCodePoints; @Setup public void setup() { - Slice whitespace = createRandomUtf8Slice(ascii ? ASCII_WHITESPACE : ALL_WHITESPACE, length + 1); - leftWhitespace = whitespace.copy(); - leftWhitespace.setByte(leftWhitespace.length() - 1, 'X'); - rightWhitespace = whitespace.copy(); - rightWhitespace.setByte(0, 'X'); - bothWhitespace = whitespace.copy(); - bothWhitespace.setByte(length / 2, 'X'); + trimCodePoints = ascii ? ASCII_WHITESPACE : ALL_WHITESPACE; + byte[] whitespace = createRandomUtf8Bytes(trimCodePoints, length + 1); + + leftWhitespace = whitespace.clone(); + leftWhitespace[leftWhitespace.length - 1] = 'X'; + + rightWhitespace = whitespace.clone(); + rightWhitespace[0] = 'X'; + + bothWhitespace = whitespace.clone(); + bothWhitespace[bothWhitespace.length / 2] = 'X'; } - private static Slice createRandomUtf8Slice(int[] codePointSet, int length) + private static byte[] createRandomUtf8Bytes(int[] codePointSet, int length) { int[] codePoints = new int[length]; ThreadLocalRandom random = ThreadLocalRandom.current(); @@ -237,27 +746,370 @@ private static Slice createRandomUtf8Slice(int[] codePointSet, int length) int codePoint = codePointSet[random.nextInt(codePointSet.length)]; codePoints[i] = codePoint; } - return utf8Slice(new String(codePoints, 0, codePoints.length)); + return new String(codePoints, 0, codePoints.length).getBytes(StandardCharsets.UTF_8); + } + + public byte[] getLeftWhitespace() + { + return leftWhitespace; + } + + public byte[] getRightWhitespace() + { + return rightWhitespace; + } + + public byte[] getBothWhitespace() + { + return bothWhitespace; + } + + public int[] getTrimCodePoints() + { + return trimCodePoints; + } + } + + @State(Thread) + public static class RangeCountData + { + @Param({"true", "false"}) + private boolean ascii; + + @Param("1000") + private int rangeLengthCodePoints; + + @Param({"1", "7", "31"}) + private int offsetBytes; + + private byte[] utf8; + private int offset; + private int rangeLength; + private int 
expectedCodePoints; + + @Setup + public void setup() + { + int[] codePointSet = ascii ? BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS; + ThreadLocalRandom random = ThreadLocalRandom.current(); + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput((rangeLengthCodePoints * 4) + offsetBytes); + + // Fixed-width ASCII prefix guarantees a non-zero, deterministic byte offset. + for (int i = 0; i < offsetBytes; i++) { + sliceOutput.appendByte('x'); + } + + offset = offsetBytes; + + for (int i = 0; i < rangeLengthCodePoints; i++) { + int codePoint = codePointSet[random.nextInt(codePointSet.length)]; + sliceOutput.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); + } + + utf8 = sliceOutput.slice().getBytes(); + rangeLength = utf8.length - offset; + expectedCodePoints = rangeLengthCodePoints; + } + + public byte[] getUtf8() + { + return utf8; + } + + public int getOffset() + { + return offset; + } + + public int getRangeLength() + { + return rangeLength; + } + + public int getExpectedCodePoints() + { + return expectedCodePoints; + } + } + + @State(Thread) + public static class FixInvalidUtf8Data + { + @Param({"ascii", "valid_non_ascii", "invalid_non_ascii"}) + private String inputKind; + + @Param("1024") + private int inputLength; + + private byte[] utf8; + private int offset; + private int length; + + @Setup + public void setup() + { + ThreadLocalRandom random = ThreadLocalRandom.current(); + if (inputKind.equals("ascii")) { + int[] asciiCodePoints = BenchmarkData.ASCII_CODE_POINTS; + int[] codePoints = new int[inputLength]; + for (int i = 0; i < codePoints.length; i++) { + codePoints[i] = asciiCodePoints[random.nextInt(asciiCodePoints.length)]; + } + setPaddedInput(new String(codePoints, 0, codePoints.length).getBytes(StandardCharsets.UTF_8)); + return; + } + + DynamicSliceOutput out = new DynamicSliceOutput(inputLength * 4); + int[] allCodePoints = BenchmarkData.ALL_CODE_POINTS; + for (int i = 0; i < 
inputLength; i++) { + int codePoint = allCodePoints[random.nextInt(allCodePoints.length)]; + out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); + } + byte[] input = out.slice().getBytes(); + + if (inputKind.equals("invalid_non_ascii") && input.length > 8) { + // Insert an illegal byte to force invalid UTF-8 handling. + input[input.length / 2] = (byte) 0xFF; + } + + setPaddedInput(input); + } + + private void setPaddedInput(byte[] input) + { + offset = 5; + utf8 = new byte[offset + input.length + 3]; + System.arraycopy(input, 0, utf8, offset, input.length); + length = input.length; + } + + public byte[] getUtf8() + { + return utf8; + } + + public int getOffset() + { + return offset; } public int getLength() { return length; } + } - public Slice getLeftWhitespace() + @State(Thread) + public static class TrinoCodePointData + extends BenchmarkData + { + // Uses BenchmarkData setup and exposes the Slice-based access pattern used in Trino loops. + } + + @State(Thread) + public static class LikePatternData + { + @Param("1000") + private int length; + + @Param({"true", "false"}) + private boolean ascii; + + private Slice pattern; + + @Setup + public void setup() { - return leftWhitespace; + int[] codePointSet = ascii ? 
BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS; + ThreadLocalRandom random = ThreadLocalRandom.current(); + DynamicSliceOutput out = new DynamicSliceOutput((length * 4) + 1); + + for (int i = 0; i < length - 1; i++) { + int codePoint; + do { + codePoint = codePointSet[random.nextInt(codePointSet.length)]; + } + while (codePoint == '%' || codePoint == '_' || codePoint == '\\'); + out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); + } + out.appendByte('%'); + + byte[] encoded = out.slice().getBytes(); + byte[] padded = new byte[3 + encoded.length + 2]; + System.arraycopy(encoded, 0, padded, 3, encoded.length); + pattern = Slices.wrappedBuffer(padded, 3, encoded.length); } - public Slice getRightWhitespace() + public Slice getPattern() { - return rightWhitespace; + return pattern; } - public Slice getBothWhitespace() + public int getEscapeChar() { - return bothWhitespace; + return -1; + } + } + + @State(Thread) + public static class TrinoPadData + { + @Param("128") + private int length; + + @Param({"true", "false"}) + private boolean ascii; + + private byte[] utf8; + private int offset; + private int byteLength; + private Slice padString; + + @Setup + public void setup() + { + int[] codePointSet = ascii ? 
BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS; + ThreadLocalRandom random = ThreadLocalRandom.current(); + DynamicSliceOutput out = new DynamicSliceOutput(length * 4); + for (int index = 0; index < length; index++) { + int codePoint = codePointSet[random.nextInt(codePointSet.length)]; + out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); + } + + byte[] encoded = out.slice().getBytes(); + offset = 9; + utf8 = new byte[offset + encoded.length + 3]; + System.arraycopy(encoded, 0, utf8, offset, encoded.length); + byteLength = encoded.length; + padString = Slices.wrappedBuffer(utf8, offset, byteLength); + } + + public byte[] getUtf8() + { + return utf8; + } + + public int getOffset() + { + return offset; + } + + public int getByteLength() + { + return byteLength; + } + + public Slice getPadString() + { + return padString; + } + } + + /** Benchmark state (Thread scope). setup() encodes length code points as UTF-8: the first is fixed to the letter a and the rest are picked at random from BenchmarkData.ASCII_CODE_POINTS when ascii is true, otherwise from BenchmarkData.ALL_CODE_POINTS. The bytes are stored at offset 13 of utf8, which also has 5 unused bytes after them, and constantPrefix is a Slice over that same range. */@State(Thread) + public static class TrinoPrefixRangeData + { + @Param("256") + private int length; + + @Param({"true", "false"}) + private boolean ascii; + + private byte[] utf8; + private int offset; + private int byteLength; + private Slice constantPrefix; + + @Setup + public void setup() + { + int[] codePointSet = ascii ? 
BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS; + ThreadLocalRandom random = ThreadLocalRandom.current(); + + int[] codePoints = new int[length]; + /* index 0 is pinned to the letter a; only indexes 1 and up are randomized */codePoints[0] = 'a'; + for (int index = 1; index < codePoints.length; index++) { + codePoints[index] = codePointSet[random.nextInt(codePointSet.length)]; + } + + DynamicSliceOutput out = new DynamicSliceOutput(length * 4); + for (int codePoint : codePoints) { + out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8)); + } + + byte[] encoded = out.slice().getBytes(); + /* the payload is placed at a non-zero offset with 5 spare bytes after it, and the Slice below wraps exactly that range */offset = 13; + utf8 = new byte[offset + encoded.length + 5]; + System.arraycopy(encoded, 0, utf8, offset, encoded.length); + byteLength = encoded.length; + constantPrefix = Slices.wrappedBuffer(utf8, offset, byteLength); + } + + public byte[] getUtf8() + { + return utf8; + } + + public int getOffset() + { + return offset; + } + + public int getByteLength() + { + return byteLength; + } + + public Slice getConstantPrefix() + { + return constantPrefix; + } + } + + /** Benchmark state (Thread scope). setup() picks length random code points, sums lengthOfCodePoint over them into expectedBytes, and exposes output, a Slice of exactly expectedBytes bytes that starts at offset 11 of a larger backing array. */@State(Thread) + public static class CodePointWriteData + { + @Param("1000") + private int length; + + @Param({"true", "false"}) + private boolean ascii; + + private int[] codePoints; + private Slice output; + private int expectedBytes; + + @Setup + public void setup() + { + int[] codePointSet = ascii ? 
BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS; + ThreadLocalRandom random = ThreadLocalRandom.current(); + + codePoints = new int[length]; + expectedBytes = 0; + for (int i = 0; i < codePoints.length; i++) { + int codePoint = codePointSet[random.nextInt(codePointSet.length)]; + codePoints[i] = codePoint; + expectedBytes += lengthOfCodePoint(codePoint); + } + + /* backing array is 11 bytes before the view plus 5 bytes after it; the view itself is exactly expectedBytes long */byte[] buffer = new byte[11 + expectedBytes + 5]; + output = Slices.wrappedBuffer(buffer, 11, expectedBytes); + } + + public int[] getCodePoints() + { + return codePoints; + } + + public Slice getOutput() + { + return output; + } + + public int getExpectedBytes() + { + return expectedBytes; } } diff --git a/src/test/java/io/airlift/slice/TestSliceUtf8.java b/src/test/java/io/airlift/slice/TestSliceUtf8.java index 93304eb4..27e4be96 100644 --- a/src/test/java/io/airlift/slice/TestSliceUtf8.java +++ b/src/test/java/io/airlift/slice/TestSliceUtf8.java @@ -25,10 +25,12 @@ import java.util.stream.IntStream; import static com.google.common.primitives.Bytes.concat; +import static io.airlift.slice.SliceUtf8.codePointByteLengths; import static io.airlift.slice.SliceUtf8.codePointToUtf8; import static io.airlift.slice.SliceUtf8.compareUtf16BE; import static io.airlift.slice.SliceUtf8.countCodePoints; import static io.airlift.slice.SliceUtf8.fixInvalidUtf8; +import static io.airlift.slice.SliceUtf8.fromCodePoints; import static io.airlift.slice.SliceUtf8.getCodePointAt; import static io.airlift.slice.SliceUtf8.getCodePointBefore; import static io.airlift.slice.SliceUtf8.isAscii; @@ -41,9 +43,11 @@ import static io.airlift.slice.SliceUtf8.rightTrim; import static io.airlift.slice.SliceUtf8.setCodePointAt; import static io.airlift.slice.SliceUtf8.substring; +import static io.airlift.slice.SliceUtf8.toCodePoints; import static io.airlift.slice.SliceUtf8.toLowerCase; import static io.airlift.slice.SliceUtf8.toUpperCase; import static io.airlift.slice.SliceUtf8.trim; +import static 
io.airlift.slice.SliceUtf8.tryGetCodePointAt; import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; @@ -204,6 +208,172 @@ public void testCodePointCountRange() } } + /** Checks that each (byte[], offset, length) overload returns the same result as the matching Slice method on a view of the same bytes. */@Test + public void testByteArrayOverloadsMatchSlice() + { + Slice utf8 = utf8Slice(" \tAé😀Z\n "); + byte[] bytes = utf8.getBytes(); + + /* one guard byte on each side of the text, so the range under test starts at offset 1 rather than 0 */byte[] padded = concat(new byte[] {'#'}, bytes, new byte[] {'!'}); + int offset = 1; + int length = bytes.length; + Slice view = wrappedBuffer(padded, offset, length); + + assertThat(isAscii(padded, offset, length)).isEqualTo(isAscii(view)); + assertThat(countCodePoints(padded, offset, length)).isEqualTo(countCodePoints(view)); + assertThat(countCodePoints(padded, offset + 1, length - 2)).isEqualTo(countCodePoints(view, 1, length - 2)); + + assertThat(substring(padded, offset, length, 1, 3)).isEqualTo(substring(view, 1, 3)); + assertThat(reverse(padded, offset, length)).isEqualTo(reverse(view)); + + Slice other = utf8Slice(" \tAé😀Y\n "); + assertThat(compareUtf16BE(padded, offset, length, other.byteArray(), other.byteArrayOffset(), other.length())) + .isEqualTo(compareUtf16BE(view, other)); + + assertThat(toLowerCase(padded, offset, length)).isEqualTo(toLowerCase(view)); + assertThat(toUpperCase(padded, offset, length)).isEqualTo(toUpperCase(view)); + + int[] trimCodePoints = new int[] {' ', '\t', '\n'}; + assertThat(leftTrim(padded, offset, length)).isEqualTo(leftTrim(view)); + assertThat(leftTrim(padded, offset, length, trimCodePoints)).isEqualTo(leftTrim(view, trimCodePoints)); + assertThat(rightTrim(padded, offset, length)).isEqualTo(rightTrim(view)); + assertThat(rightTrim(padded, offset, length, trimCodePoints)).isEqualTo(rightTrim(view, trimCodePoints)); + assertThat(trim(padded, offset, length)).isEqualTo(trim(view)); + assertThat(trim(padded, offset, length, trimCodePoints)).isEqualTo(trim(view, trimCodePoints)); + + byte[] invalid = concat(new byte[] {'x'}, 
INVALID_UTF8_2, new byte[] {'y'}); + assertThat(fixInvalidUtf8(invalid, 1, INVALID_UTF8_2.length)).isEqualTo(fixInvalidUtf8(wrappedBuffer(INVALID_UTF8_2))); + assertThat(fixInvalidUtf8(invalid, 1, INVALID_UTF8_2.length, java.util.OptionalInt.empty())) + .isEqualTo(fixInvalidUtf8(wrappedBuffer(INVALID_UTF8_2), java.util.OptionalInt.empty())); + assertThat(fixInvalidUtf8(invalid, 1, INVALID_UTF8_2.length, java.util.OptionalInt.of('?'))) + .isEqualTo(fixInvalidUtf8(wrappedBuffer(INVALID_UTF8_2), java.util.OptionalInt.of('?'))); + + int position = offsetOfCodePoint(view, 3); + assertThat(tryGetCodePointAt(padded, offset, length, position)).isEqualTo(tryGetCodePointAt(view, position)); + assertThat(offsetOfCodePoint(padded, offset, length, 3)).isEqualTo(offsetOfCodePoint(view, 3)); + assertThat(offsetOfCodePoint(padded, offset, length, 2, 2)).isEqualTo(offsetOfCodePoint(view, 2, 2)); + assertThat(lengthOfCodePoint(padded, offset, length, position)).isEqualTo(lengthOfCodePoint(view, position)); + assertThat(lengthOfCodePointSafe(padded, offset, length, position)).isEqualTo(lengthOfCodePointSafe(view, position)); + assertThat(getCodePointAt(padded, offset, length, position)).isEqualTo(getCodePointAt(view, position)); + + /* byte position immediately after the code point at position, passed to getCodePointBefore */int nextPosition = position + lengthOfCodePoint(view, position); + assertThat(getCodePointBefore(padded, offset, length, nextPosition)).isEqualTo(getCodePointBefore(view, nextPosition)); + + /* write the same code point through the Slice overload and the byte[] overload, then compare byte counts and written bytes */Slice sliceTarget = Slices.allocate(8); + byte[] byteArrayTarget = new byte[8]; + int sliceWritten = setCodePointAt(0x1F600, sliceTarget, 0); + int arrayWritten = setCodePointAt(0x1F600, byteArrayTarget, 0, byteArrayTarget.length, 0); + assertThat(arrayWritten).isEqualTo(sliceWritten); + assertThat(wrappedBuffer(byteArrayTarget, 0, arrayWritten)).isEqualTo(sliceTarget.slice(0, sliceWritten)); + + assertThat(toCodePoints(padded, offset, length)).isEqualTo(toCodePoints(view)); + assertThat(codePointByteLengths(padded, offset, 
length)).isEqualTo(codePointByteLengths(view)); + assertThat(fromCodePoints(toCodePoints(view))).isEqualTo(view); + } + + /** Runs assertToCodePoints over the shared sample strings. */@Test + public void testToCodePoints() + { + assertToCodePoints(STRING_EMPTY); + assertToCodePoints(STRING_HELLO); + assertToCodePoints(STRING_OESTERREICH); + assertToCodePoints(STRING_DULIOE_DULIOE); + assertToCodePoints(STRING_FAITH_HOPE_LOVE); + assertToCodePoints(STRING_OO); + assertToCodePoints(STRING_ASCII_CODE_POINTS); + assertToCodePoints(STRING_ALL_CODE_POINTS_RANDOM); + } + + /** Asserts that toCodePoints on the UTF-8 encoding of value equals value.codePoints(). */private static void assertToCodePoints(String value) + { + Slice utf8 = utf8Slice(value); + int[] expectedCodePoints = value.codePoints().toArray(); + assertThat(toCodePoints(utf8)).isEqualTo(expectedCodePoints); + } + + /** Runs assertFromCodePoints over the shared sample strings. */@Test + public void testFromCodePoints() + { + assertFromCodePoints(STRING_EMPTY); + assertFromCodePoints(STRING_HELLO); + assertFromCodePoints(STRING_OESTERREICH); + assertFromCodePoints(STRING_DULIOE_DULIOE); + assertFromCodePoints(STRING_FAITH_HOPE_LOVE); + assertFromCodePoints(STRING_OO); + assertFromCodePoints(STRING_ASCII_CODE_POINTS); + assertFromCodePoints(STRING_ALL_CODE_POINTS_RANDOM); + } + + /** Encodes four code points whose expected encodings are 1, 2, 3 and 4 bytes long and compares the result against hand-written bytes. */@Test + public void testFromCodePointsEncodesUtf8ByteWidths() + { + int[] codePoints = new int[] { + 0x0024, // $ + 0x00A2, // ¢ + 0x20AC, // € + 0x10348, // 𐍈 + }; + + byte[] expectedUtf8 = new byte[] { + 0x24, + (byte) 0xC2, (byte) 0xA2, + (byte) 0xE2, (byte) 0x82, (byte) 0xAC, + (byte) 0xF0, (byte) 0x90, (byte) 0x8D, (byte) 0x88, + }; + + assertThat(fromCodePoints(codePoints)).isEqualTo(wrappedBuffer(expectedUtf8)); + } + + /** Asserts that fromCodePoints on value.codePoints() equals utf8Slice(value). */private static void assertFromCodePoints(String value) + { + Slice utf8 = utf8Slice(value); + int[] codePoints = value.codePoints().toArray(); + assertThat(fromCodePoints(codePoints)).isEqualTo(utf8); + } + + /** toCodePoints on INVALID_UTF8_2 must throw InvalidUtf8Exception with the expected message fragment. */@Test + public void testToCodePointsInvalidUtf8() + { + assertThatThrownBy(() -> toCodePoints(wrappedBuffer(INVALID_UTF8_2))) + .isInstanceOf(InvalidUtf8Exception.class) + .hasMessageContaining("Invalid UTF-8 sequence at 
position"); + } + + @Test + public void testCodePointByteLengths() + { + assertCodePointByteLengths(STRING_EMPTY); + assertCodePointByteLengths(STRING_HELLO); + assertCodePointByteLengths(STRING_OESTERREICH); + assertCodePointByteLengths(STRING_DULIOE_DULIOE); + assertCodePointByteLengths(STRING_FAITH_HOPE_LOVE); + assertCodePointByteLengths(STRING_OO); + assertCodePointByteLengths(STRING_ASCII_CODE_POINTS); + assertCodePointByteLengths(STRING_ALL_CODE_POINTS_RANDOM); + } + + private static void assertCodePointByteLengths(String value) + { + Slice utf8 = utf8Slice(value); + int[] codePoints = value.codePoints().toArray(); + byte[] expectedLengths = new byte[codePoints.length]; + for (int index = 0; index < codePoints.length; index++) { + expectedLengths[index] = (byte) lengthOfCodePoint(codePoints[index]); + } + assertThat(codePointByteLengths(utf8)).isEqualTo(expectedLengths); + } + + @Test + public void testFromCodePointsInvalid() + { + assertThatThrownBy(() -> fromCodePoints(new int[] {-1})) + .isInstanceOf(InvalidCodePointException.class); + assertThatThrownBy(() -> fromCodePoints(new int[] {MIN_SURROGATE})) + .isInstanceOf(InvalidCodePointException.class); + assertThatThrownBy(() -> fromCodePoints(new int[] {MAX_CODE_POINT + 1})) + .isInstanceOf(InvalidCodePointException.class); + } + private static void assertCodePointCount(String string) { assertThat(countCodePoints(utf8Slice(string))).isEqualTo(string.codePoints().count()); @@ -877,4 +1047,12 @@ public void testSetCodePointContinuationByte() .isInstanceOf(InvalidCodePointException.class) .hasMessage("Invalid code point 0xFFFFFFBF"); } + + @Test + public void testSetCodePointAtByteArrayInvalidRange() + { + byte[] utf8 = new byte[8]; + assertThatThrownBy(() -> setCodePointAt('a', utf8, 7, 2, 0)) + .isInstanceOf(IndexOutOfBoundsException.class); + } }