Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class GraphBuilder {
ComputeGraph* compute_graph_;
VkGraphPtr flatbuffer_;
const uint8_t* constant_data_;
uint64_t constant_data_size_;
const NamedDataMap* named_data_map_;
std::vector<FreeableBuffer> loaded_buffers_from_map_;

Expand All @@ -220,10 +221,12 @@ class GraphBuilder {
ComputeGraph* compute_graph,
VkGraphPtr flatbuffer,
const uint8_t* constant_data,
uint64_t constant_data_size,
const NamedDataMap* named_data_map)
: compute_graph_(compute_graph),
flatbuffer_(flatbuffer),
constant_data_(constant_data),
constant_data_size_(constant_data_size),
named_data_map_(named_data_map),
loaded_buffers_from_map_(),
ref_mapping_(),
Expand Down Expand Up @@ -298,7 +301,33 @@ class GraphBuilder {
ref = compute_graph_->add_tensorref(
dims_vector, dtype, std::move(buffer.get()));
} else {
const uint8_t* tensor_data = constant_data_ + constant_bytes->offset();
const uint64_t offset = constant_bytes->offset();
VK_CHECK_COND(
offset < constant_data_size_,
"Constant data offset %lu exceeds constant data size %lu",
(unsigned long)offset,
(unsigned long)constant_data_size_);
// Validate that the tensor's full byte extent fits within the
// constant data region. Dims originate from an untrusted flatbuffer,
// so use overflow-safe multiplication.
const uint64_t max_extent = constant_data_size_ - offset;
uint64_t tensor_byte_size = vkapi::element_size(dtype);
for (int64_t dim : dims_vector) {
const uint64_t udim = static_cast<uint64_t>(dim);
VK_CHECK_COND(
udim == 0 || tensor_byte_size <= max_extent / udim,
"Tensor byte extent at offset %lu exceeds constant data size %lu",
(unsigned long)offset,
(unsigned long)constant_data_size_);
tensor_byte_size *= udim;
}
VK_CHECK_COND(
tensor_byte_size <= max_extent,
"Tensor byte size %lu at offset %lu exceeds constant data size %lu",
(unsigned long)tensor_byte_size,
(unsigned long)offset,
(unsigned long)constant_data_size_);
const uint8_t* tensor_data = constant_data_ + offset;
ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
}
} else {
Expand Down Expand Up @@ -591,19 +620,25 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {

ET_NODISCARD Error compileModel(
const void* buffer_pointer,
size_t buffer_size,
ComputeGraph* compute_graph,
const NamedDataMap* named_data_map) const {
Result<VulkanDelegateHeader> header =
VulkanDelegateHeader::parse(buffer_pointer);
VulkanDelegateHeader::parse(buffer_pointer, buffer_size);

const uint8_t* flatbuffer_data = nullptr;
const uint8_t* constant_data = nullptr;
uint64_t constant_data_size = 0;

if (header.ok()) {
const uint8_t* buffer_start =
reinterpret_cast<const uint8_t*>(buffer_pointer);
flatbuffer_data = buffer_start + header->flatbuffer_offset;
constant_data = buffer_start + header->bytes_offset;
constant_data_size = header->bytes_size;
if (constant_data_size == 0 && buffer_size > header->bytes_offset) {
constant_data_size = buffer_size - header->bytes_offset;
}
} else {
ET_LOG(Error, "VulkanDelegateHeader may be corrupt");
return header.error();
Expand All @@ -616,10 +651,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
flatbuffers::GetBufferIdentifier(flatbuffer_data),
vkgraph::VkGraphIdentifier());

// Verify FlatBuffer structural integrity before parsing.
flatbuffers::Verifier verifier(flatbuffer_data, header->flatbuffer_size);
ET_CHECK_OR_RETURN_ERROR(
vkgraph::VerifyVkGraphBuffer(verifier),
DelegateInvalidCompatibility,
"VkGraph FlatBuffer verification failed");

VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);

GraphBuilder builder(
compute_graph, flatbuffer_graph, constant_data, named_data_map);
compute_graph,
flatbuffer_graph,
constant_data,
constant_data_size,
named_data_map);

builder.resolve_layouts();
builder.build_graph();
Expand Down Expand Up @@ -649,7 +695,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
new (compute_graph) ComputeGraph(graph_config);

const NamedDataMap* named_data_map = context.get_named_data_map();
Error err = compileModel(processed->data(), compute_graph, named_data_map);
Error err = compileModel(
processed->data(), processed->size(), compute_graph, named_data_map);

// This backend does not need its processed data after compiling the
// model.
Expand Down
20 changes: 19 additions & 1 deletion backends/vulkan/runtime/VulkanDelegateHeader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,15 @@ bool VulkanDelegateHeader::is_valid() const {
return true;
}

Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(const void* data) {
Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(
const void* data,
size_t buffer_size) {
// Ensure the buffer is large enough to read all header fields.
// The last field (bytes_size) ends at offset 22 + 8 = 30 = kExpectedSize.
if (buffer_size < kExpectedSize) {
return Error::InvalidArgument;
}

const uint8_t* header_data = (const uint8_t*)data;

const uint8_t* magic_start = header_data + kMagic.offset;
Expand All @@ -104,6 +112,16 @@ Result<VulkanDelegateHeader> VulkanDelegateHeader::parse(const void* data) {
return Error::InvalidArgument;
}

// Validate that header offsets do not extend beyond the buffer.
if (header.flatbuffer_offset > buffer_size ||
header.flatbuffer_size > buffer_size - header.flatbuffer_offset) {
return Error::InvalidArgument;
}
if (header.bytes_offset > buffer_size ||
header.bytes_size > buffer_size - header.bytes_offset) {
return Error::InvalidArgument;
}

return header;
}

Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/VulkanDelegateHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ struct VulkanDelegateHeader {
bool is_valid() const;

static executorch::runtime::Result<VulkanDelegateHeader> parse(
const void* data);
const void* data,
size_t buffer_size);

uint32_t header_size;
uint32_t flatbuffer_offset;
Expand Down
Loading