diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index c0c8cf8e7f0..676236c344b 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -363,6 +363,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { std::vector evalues; std::vector tensors; + std::vector> strings; + std::vector>> string_refs; static const auto typeCodeField = JEValue::javaClassStatic()->getField("mTypeCode"); @@ -373,6 +375,24 @@ class ExecuTorchJni : public facebook::jni::HybridClass { if (typeCode == JEValue::kTypeCodeTensor) { tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); evalues.emplace_back(tensors.back()); + } else if (typeCode == JEValue::kTypeCodeString) { + static const auto toStrMethod = + JEValue::javaClassStatic() + ->getMethod()>("toStr"); + auto jstr = toStrMethod(jevalue); + if (!jstr) { + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), + "String EValue input at index " + std::to_string(i) + + " is null"); + return {}; + } + auto str = std::make_unique(jstr->toStdString()); + auto ref = std::make_unique>( + str->data(), str->size()); + evalues.emplace_back(ref.get()); + strings.push_back(std::move(str)); + string_refs.push_back(std::move(ref)); } else if (typeCode == JEValue::kTypeCodeInt) { static const auto toIntMethod = JEValue::javaClassStatic()->getMethod("toInt"); @@ -385,6 +405,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { static const auto toBoolMethod = JEValue::javaClassStatic()->getMethod("toBool"); evalues.emplace_back(static_cast(toBoolMethod(jevalue))); + } else { + std::stringstream ss; + ss << "Unsupported input EValue type code: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str()); + return {}; } }