diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 3b18915eae5..716d5aa5bd1 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -208,6 +208,7 @@ class GraphBuilder {
   ComputeGraph* compute_graph_;
   VkGraphPtr flatbuffer_;
   const uint8_t* constant_data_;
+  uint64_t constant_data_size_;
   const NamedDataMap* named_data_map_;
   std::vector<FreeableBuffer> loaded_buffers_from_map_;
 
@@ -220,10 +221,12 @@ class GraphBuilder {
       ComputeGraph* compute_graph,
       VkGraphPtr flatbuffer,
       const uint8_t* constant_data,
+      uint64_t constant_data_size,
       const NamedDataMap* named_data_map)
       : compute_graph_(compute_graph),
         flatbuffer_(flatbuffer),
         constant_data_(constant_data),
+        constant_data_size_(constant_data_size),
         named_data_map_(named_data_map),
         loaded_buffers_from_map_(),
         ref_mapping_(),
@@ -298,7 +301,33 @@ class GraphBuilder {
         ref = compute_graph_->add_tensorref(
             dims_vector, dtype, std::move(buffer.get()));
       } else {
-        const uint8_t* tensor_data = constant_data_ + constant_bytes->offset();
+        const uint64_t offset = constant_bytes->offset();
+        VK_CHECK_COND(
+            offset < constant_data_size_,
+            "Constant data offset %lu exceeds constant data size %lu",
+            (unsigned long)offset,
+            (unsigned long)constant_data_size_);
+        // Validate that the tensor's full byte extent fits within the
+        // constant data region. Dims originate from an untrusted flatbuffer,
+        // so use overflow-safe multiplication.
+        const uint64_t max_extent = constant_data_size_ - offset;
+        uint64_t tensor_byte_size = vkapi::element_size(dtype);
+        for (int64_t dim : dims_vector) {
+          const uint64_t udim = static_cast<uint64_t>(dim);
+          VK_CHECK_COND(
+              udim == 0 || tensor_byte_size <= max_extent / udim,
+              "Tensor byte extent at offset %lu exceeds constant data size %lu",
+              (unsigned long)offset,
+              (unsigned long)constant_data_size_);
+          tensor_byte_size *= udim;
+        }
+        VK_CHECK_COND(
+            tensor_byte_size <= max_extent,
+            "Tensor byte size %lu at offset %lu exceeds constant data size %lu",
+            (unsigned long)tensor_byte_size,
+            (unsigned long)offset,
+            (unsigned long)constant_data_size_);
+        const uint8_t* tensor_data = constant_data_ + offset;
         ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
       }
     } else {
@@ -591,19 +620,25 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
   ET_NODISCARD Error compileModel(
       const void* buffer_pointer,
+      size_t buffer_size,
       ComputeGraph* compute_graph,
       const NamedDataMap* named_data_map) const {
     Result<VulkanDelegateHeader> header =
-        VulkanDelegateHeader::parse(buffer_pointer);
+        VulkanDelegateHeader::parse(buffer_pointer, buffer_size);
 
     const uint8_t* flatbuffer_data = nullptr;
     const uint8_t* constant_data = nullptr;
+    uint64_t constant_data_size = 0;
 
     if (header.ok()) {
       const uint8_t* buffer_start =
           reinterpret_cast<const uint8_t*>(buffer_pointer);
       flatbuffer_data = buffer_start + header->flatbuffer_offset;
       constant_data = buffer_start + header->bytes_offset;
+      constant_data_size = header->bytes_size;
+      if (constant_data_size == 0 && buffer_size > header->bytes_offset) {
+        constant_data_size = buffer_size - header->bytes_offset;
+      }
     } else {
       ET_LOG(Error, "VulkanDelegateHeader may be corrupt");
       return header.error();
@@ -616,10 +651,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
         flatbuffers::GetBufferIdentifier(flatbuffer_data),
         vkgraph::VkGraphIdentifier());
 
+    // Verify FlatBuffer structural integrity before parsing.
+    flatbuffers::Verifier verifier(flatbuffer_data, header->flatbuffer_size);
+    ET_CHECK_OR_RETURN_ERROR(
+        vkgraph::VerifyVkGraphBuffer(verifier),
+        DelegateInvalidCompatibility,
+        "VkGraph FlatBuffer verification failed");
+
     VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);
     GraphBuilder builder(
-        compute_graph, flatbuffer_graph, constant_data, named_data_map);
+        compute_graph,
+        flatbuffer_graph,
+        constant_data,
+        constant_data_size,
+        named_data_map);
 
     builder.resolve_layouts();
     builder.build_graph();
@@ -649,7 +695,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     new (compute_graph) ComputeGraph(graph_config);
 
     const NamedDataMap* named_data_map = context.get_named_data_map();
-    Error err = compileModel(processed->data(), compute_graph, named_data_map);
+    Error err = compileModel(
+        processed->data(), processed->size(), compute_graph, named_data_map);
 
     // This backend does not need its processed data after compiling the
     // model.
diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.cpp b/backends/vulkan/runtime/VulkanDelegateHeader.cpp
index 2a235144342..83bc263350b 100644
--- a/backends/vulkan/runtime/VulkanDelegateHeader.cpp
+++ b/backends/vulkan/runtime/VulkanDelegateHeader.cpp
@@ -84,7 +84,15 @@ bool VulkanDelegateHeader::is_valid() const {
   return true;
 }
 
-Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(const void* data) {
+Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(
+    const void* data,
+    size_t buffer_size) {
+  // Ensure the buffer is large enough to read all header fields.
+  // The last field (bytes_size) ends at offset 22 + 8 = 30 = kExpectedSize.
+  if (buffer_size < kExpectedSize) {
+    return Error::InvalidArgument;
+  }
+
   const uint8_t* header_data = (const uint8_t*)data;
 
   const uint8_t* magic_start = header_data + kMagic.offset;
@@ -104,6 +112,16 @@ Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(const void* data) {
     return Error::InvalidArgument;
   }
 
+  // Validate that header offsets do not extend beyond the buffer.
+  if (header.flatbuffer_offset > buffer_size ||
+      header.flatbuffer_size > buffer_size - header.flatbuffer_offset) {
+    return Error::InvalidArgument;
+  }
+  if (header.bytes_offset > buffer_size ||
+      header.bytes_size > buffer_size - header.bytes_offset) {
+    return Error::InvalidArgument;
+  }
+
   return header;
 }
diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.h b/backends/vulkan/runtime/VulkanDelegateHeader.h
index 722f01cbb75..312b08ab747 100644
--- a/backends/vulkan/runtime/VulkanDelegateHeader.h
+++ b/backends/vulkan/runtime/VulkanDelegateHeader.h
@@ -26,7 +26,8 @@ struct VulkanDelegateHeader {
   bool is_valid() const;
 
   static executorch::runtime::Result<VulkanDelegateHeader> parse(
-      const void* data);
+      const void* data,
+      size_t buffer_size);
 
   uint32_t header_size;
   uint32_t flatbuffer_offset;
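Reviewer note: the extent check added in the `@@ -298,7 +301,33 @@` hunk relies on division to avoid multiplication overflow when accumulating the tensor's byte size from untrusted dims. A minimal standalone sketch of the same pattern is below; the free-standing form and the name `tensor_extent_fits` are illustrative only and not part of this change.

```cpp
#include <cstdint>
#include <vector>

// Returns true only if element_size * prod(dims) fits within
// [offset, data_size) without any intermediate multiplication wrapping
// around uint64_t. Mirrors the VK_CHECK_COND logic in GraphBuilder.
bool tensor_extent_fits(
    uint64_t offset,
    uint64_t data_size,
    uint64_t element_size,
    const std::vector<int64_t>& dims) {
  if (offset >= data_size) {
    return false;
  }
  const uint64_t max_extent = data_size - offset;
  uint64_t byte_size = element_size;
  for (int64_t dim : dims) {
    const uint64_t udim = static_cast<uint64_t>(dim);
    // If the next multiplication would exceed max_extent (and thus could
    // overflow), reject before performing it.
    if (udim != 0 && byte_size > max_extent / udim) {
      return false;
    }
    byte_size *= udim;
  }
  return byte_size <= max_extent;
}
```

Because each step verifies `byte_size <= max_extent / udim` before multiplying, the running product can never wrap `uint64_t`, so the final `byte_size <= max_extent` comparison stays meaningful even for adversarial dims in the flatbuffer.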