Skip to content

Commit bcc5073

Browse files
Googlercopybara-github
authored andcommitted
Add support for custom DropImpl in CRUBIT_OWNED_POINTEE and refactor IR.
- Support an optional second argument in CRUBIT_OWNED_POINTEE to specify a custom DropImpl method name. - Refactor Crubit IR to group owned_ptr_type and drop_impl into a single OwnedPtrConfig struct. - Apply default value of "DropImpl" at annotation parsing time. - Add GetAnnotationWithStringArgs helper in common/annotation_reader.h and use it in cxx_record.cc. PiperOrigin-RevId: 894250803
1 parent 3000d7d commit bcc5073

22 files changed

Lines changed: 316 additions & 29 deletions

common/annotation_reader.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <functional>
88
#include <optional>
99
#include <string>
10+
#include <vector>
1011

1112
#include "absl/base/attributes.h"
1213
#include "absl/base/nullability.h"
@@ -303,6 +304,30 @@ absl::StatusOr<std::optional<std::string>> GetAnnotationWithStringArg(
303304
return std::string(*arg);
304305
}
305306

307+
absl::StatusOr<std::optional<std::vector<std::string>>>
308+
GetAnnotationWithStringArgs(const clang::Decl& decl,
309+
absl::string_view annotation_name) {
310+
CRUBIT_ASSIGN_OR_RETURN(std::optional<AnnotateArgs> maybe_args,
311+
GetAnnotateAttrArgs(decl, annotation_name));
312+
if (!maybe_args.has_value()) {
313+
return std::nullopt;
314+
}
315+
const AnnotateArgs& args = *maybe_args;
316+
std::vector<std::string> result;
317+
result.reserve(args.size());
318+
for (const clang::Expr* arg_expr : args) {
319+
absl::StatusOr<absl::string_view> arg =
320+
GetExprAsStringLiteral(*arg_expr, decl.getASTContext());
321+
if (!arg.ok()) {
322+
return absl::InvalidArgumentError(
323+
absl::StrCat("Annotation ", annotation_name,
324+
" arguments must be string literals."));
325+
}
326+
result.push_back(std::string(*arg));
327+
}
328+
return result;
329+
}
330+
306331
absl::StatusOr<const clang::AnnotateTypeAttr* absl_nullable>
307332
GetTypeAnnotationSingleDecl(const clang::Type* absl_nonnull type
308333
ABSL_ATTRIBUTE_LIFETIME_BOUND,

common/annotation_reader.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ absl::StatusOr<absl::string_view> GetExprAsStringLiteral(
7979
absl::StatusOr<std::optional<std::string>> GetAnnotationWithStringArg(
8080
const clang::Decl& decl, absl::string_view annotation_name);
8181

82+
// Returns the string arguments of [[clang::annotate(annotation_name,
83+
// string_arg1, string_arg2, ...)]] annotation on `decl`, or none if the
84+
// annotation does not exist.
85+
//
86+
// Returns an error if there are conflicting annotations or if any argument is
87+
// not a string.
88+
absl::StatusOr<std::optional<std::vector<std::string>>>
89+
GetAnnotationWithStringArgs(const clang::Decl& decl,
90+
absl::string_view annotation_name);
91+
8292
// Returns true if `decl` has an annotation with the given name.
8393
//
8494
// Returns an error if an annotation with the given name exists, but it has

common/annotation_reader_test.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
namespace crubit {
1616
namespace {
1717

18+
using testing::ElementsAre;
1819
using testing::Eq;
1920
using testing::HasSubstr;
2021
using testing::Ne;
@@ -151,5 +152,41 @@ TEST(AnnotationReaderTest,
151152
ASSERT_THAT(GetAnnotateAttrArgs(var, "foo"), IsOkAndHolds(Ne(std::nullopt)));
152153
}
153154

155+
TEST(AnnotationReaderTest, GetAnnotationWithStringArgsSuccess) {
156+
clang::TestAST ast(R"cc(
157+
[[clang::annotate("foo", "arg1", "arg2")]] int i;
158+
)cc");
159+
160+
auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");
161+
162+
auto result = GetAnnotationWithStringArgs(var, "foo");
163+
ASSERT_THAT(result, IsOkAndHolds(Ne(std::nullopt)));
164+
EXPECT_THAT(**result, ElementsAre("arg1", "arg2"));
165+
}
166+
167+
TEST(AnnotationReaderTest, GetAnnotationWithStringArgsNone) {
168+
clang::TestAST ast(R"cc(
169+
int i;
170+
)cc");
171+
172+
auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");
173+
174+
EXPECT_THAT(GetAnnotationWithStringArgs(var, "foo"),
175+
IsOkAndHolds(Eq(std::nullopt)));
176+
}
177+
178+
TEST(AnnotationReaderTest, GetAnnotationWithStringArgsFailureNonString) {
179+
clang::TestAST ast(R"cc(
180+
[[clang::annotate("foo", "arg1", 42)]] int i;
181+
)cc");
182+
183+
auto& var = LookupDecl<clang::VarDecl>(ast.context(), "i");
184+
185+
EXPECT_THAT(
186+
GetAnnotationWithStringArgs(var, "foo"),
187+
StatusIs(absl::StatusCode::kInvalidArgument,
188+
HasSubstr("Annotation foo arguments must be string literals.")));
189+
}
190+
154191
} // namespace
155192
} // namespace crubit

rs_bindings_from_cc/generate_bindings/database/code_snippet.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ pub fn generated_items_to_tokens<'db>(
540540
nested_items,
541541
indirect_functions,
542542
delete,
543-
owned_type_name,
543+
owned_ptr_config,
544544
member_methods,
545545
free_functions,
546546
lifetime_params,
@@ -604,10 +604,12 @@ pub fn generated_items_to_tokens<'db>(
604604
None
605605
};
606606

607-
let owned_type_def = owned_type_name.as_ref().map(|owned_type_name| {
607+
let owned_type_def = owned_ptr_config.as_ref().map(|cfg| {
608+
let owned_type_name = &cfg.owned_type_name;
609+
let drop_meth = &cfg.drop_impl;
608610
let doc_comment = format!(
609-
"Wrapper for a C++ {} owned by Rust. \n\n Style guide: The C++ type to which this refers should be wrapped in an `Arc` or `Mutex` if it is not already thread-safe. \n\n THIS TYPE REQUIRES A MANUAL DROP IMPLEMENTATION. \n You MUST provide an `impl {} {{ pub fn DropImpl(&mut self) {{ ... }} }}` block in a separate Rust file (e.g., via `additional_rust_srcs`). Failure to do so will result in a compile-time error: `method not found in `{}``.",
610-
ident, owned_type_name, owned_type_name
611+
"Wrapper for a C++ {} owned by Rust. \n\n Style guide: The C++ type to which this refers should be wrapped in an `Arc` or `Mutex` if it is not already thread-safe. \n\n THIS TYPE REQUIRES A MANUAL DROP IMPLEMENTATION. \n You MUST provide an `impl {} {{ pub fn {}(&mut self) {{ ... }} }}` block in a separate Rust file (e.g., via `additional_rust_srcs`). Failure to do so will result in a compile-time error: `method not found in `{}``.",
612+
ident, owned_type_name, drop_meth, owned_type_name
611613
);
612614
quote! {
613615
__NEWLINE__ __NEWLINE__
@@ -618,10 +620,10 @@ pub fn generated_items_to_tokens<'db>(
618620

619621
impl Drop for #owned_type_name {
620622
fn drop(&mut self) {
621-
__COMMENT__ "IMPORTANT: The DropImpl method for `{}` MUST be implemented in a user-written .rs file (e.g., using `additional_rust_srcs`)."
623+
__COMMENT__ "IMPORTANT: The drop method MUST be implemented in a user-written .rs file (e.g., using `additional_rust_srcs`)."
622624
__COMMENT__ "Crubit cannot automatically generate the destruction logic for this type."
623625
__COMMENT__ "See the struct documentation for more details."
624-
self.DropImpl();
626+
self.#drop_meth();
625627
}
626628
}
627629
}
@@ -977,6 +979,12 @@ impl GeneratedItem {
977979
}
978980
}
979981

982+
#[derive(Clone, Debug)]
983+
pub struct OwnedPtrConfig {
984+
pub owned_type_name: Ident,
985+
pub drop_impl: Ident,
986+
}
987+
980988
#[derive(Clone, Debug)]
981989
pub struct Record {
982990
pub doc_comment_attr: Option<DocCommentAttr>,
@@ -1005,8 +1013,8 @@ pub struct Record {
10051013
/// Functions that get attached either by a trait or from a base class.
10061014
pub indirect_functions: Vec<TokenStream>,
10071015
pub delete: Option<DeleteImpl>,
1008-
/// The name of the owning wrapper type when the type was annotated with CRUBIT_OWNED_POINTEE.
1009-
pub owned_type_name: Option<Ident>,
1016+
/// The owning wrapper type configuration when the type was annotated with CRUBIT_OWNED_POINTEE.
1017+
pub owned_ptr_config: Option<OwnedPtrConfig>,
10101018
pub member_methods: Vec<TokenStream>,
10111019
pub free_functions: Vec<TokenStream>,
10121020
pub lifetime_params: Vec<syn::Lifetime>,

rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ impl RsTypeKind {
907907
lifetimes,
908908
in_cc_std,
909909
)?,
910-
owned_ptr_type: record.owned_ptr_type.clone(),
910+
owned_ptr_type: record.owned_ptr_config.as_ref().map(|cfg| cfg.owned_ptr_type.clone()),
911911
record,
912912
crate_path,
913913
lifetimes: lifetimes.to_vec(),
@@ -1511,7 +1511,7 @@ impl RsTypeKind {
15111511
)
15121512
};
15131513

1514-
let owned_ptr_type = record.owned_ptr_type.as_ref().expect(
1514+
let owned_ptr_type = record.owned_ptr_config.as_ref().map(|cfg| cfg.owned_ptr_type.as_ref()).expect(
15151515
"CRUBIT_OWNED_POINTER annotated pointers should point to a struct with an associated CRUBIT_OWNED_POINTEE",
15161516
);
15171517

rs_bindings_from_cc/generate_bindings/generate_struct_and_union.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,11 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
702702
})
703703
};
704704

705-
let owned_type_name = record.owned_ptr_type.as_ref().map(|opt| make_rs_ident(opt.as_ref()));
705+
let owned_ptr_config =
706+
record.owned_ptr_config.as_ref().map(|cfg| database::code_snippet::OwnedPtrConfig {
707+
owned_type_name: make_rs_ident(cfg.owned_ptr_type.as_ref()),
708+
drop_impl: make_rs_ident(cfg.drop_impl.as_ref()),
709+
});
706710
let member_methods = api_snippets.member_functions.remove(&record.id).unwrap_or_default();
707711
let free_functions = api_snippets.free_functions.remove(&record.id).unwrap_or_default();
708712

@@ -751,7 +755,7 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
751755
items,
752756
nested_items,
753757
indirect_functions,
754-
owned_type_name,
758+
owned_ptr_config,
755759
member_methods,
756760
free_functions,
757761
delete: operator_delete_impl,

rs_bindings_from_cc/importers/cxx_record.cc

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -922,12 +922,36 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
922922
std::optional<BridgeType> bridge_type =
923923
GetBridgeTypeAnnotation(ictx_, *record_decl);
924924

925-
absl::StatusOr<std::optional<std::string>> owned_ptr_type =
926-
GetAnnotationWithStringArg(*record_decl, "crubit_owned_pointee");
927-
if (!owned_ptr_type.ok()) {
925+
absl::StatusOr<std::optional<std::vector<std::string>>> args =
926+
GetAnnotationWithStringArgs(*record_decl, "crubit_owned_pointee");
927+
if (!args.ok()) {
928928
return ictx_.ImportUnsupportedItem(
929929
*record_decl, std::nullopt,
930-
FormattedError::FromStatus(std::move(owned_ptr_type).status()));
930+
FormattedError::FromStatus(std::move(args).status()));
931+
}
932+
933+
std::optional<OwnedPtrConfig> owned_ptr_config;
934+
935+
if (args->has_value()) {
936+
const auto& args_vec = **args;
937+
if (args_vec.empty() || args_vec.size() > 2) {
938+
return ictx_.ImportUnsupportedItem(
939+
*record_decl, std::nullopt,
940+
FormattedError::Static(
941+
"crubit_owned_pointee takes 1 or 2 arguments"));
942+
}
943+
944+
std::string owned_ptr_type = args_vec[0];
945+
std::string drop_impl = "DropImpl";
946+
947+
if (args_vec.size() == 2) {
948+
drop_impl = args_vec[1];
949+
}
950+
951+
owned_ptr_config = OwnedPtrConfig{
952+
.owned_ptr_type = std::move(owned_ptr_type),
953+
.drop_impl = std::move(drop_impl),
954+
};
931955
}
932956

933957
BazelLabel owning_target = ictx_.GetOwningTarget(record_decl);
@@ -1198,7 +1222,7 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
11981222
.unknown_attr = std::move(*unknown_attr),
11991223
.doc_comment = std::move(doc_comment),
12001224
.bridge_type = std::move(bridge_type),
1201-
.owned_ptr_type = *std::move(owned_ptr_type),
1225+
.owned_ptr_config = std::move(owned_ptr_config),
12021226
.source_loc = ictx_.ConvertSourceLocation(source_loc, nullptr),
12031227
.unambiguous_public_bases = GetUnambiguousPublicBases(*record_decl),
12041228
.fields = ImportFields(record_decl),

rs_bindings_from_cc/ir.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,13 @@ llvm::json::Value TraitDerives::ToJson() const {
670670
};
671671
}
672672

673+
llvm::json::Value OwnedPtrConfig::ToJson() const {
674+
return llvm::json::Object{
675+
{"owned_ptr_type", owned_ptr_type},
676+
{"drop_impl", drop_impl},
677+
};
678+
}
679+
673680
llvm::json::Value Record::ToJson() const {
674681
std::vector<llvm::json::Value> json_item_ids;
675682
json_item_ids.reserve(child_item_ids.size());
@@ -688,7 +695,6 @@ llvm::json::Value Record::ToJson() const {
688695
{"unknown_attr", unknown_attr},
689696
{"doc_comment", doc_comment},
690697
{"bridge_type", bridge_type},
691-
{"owned_ptr_type", owned_ptr_type},
692698
{"source_loc", source_loc},
693699
{"unambiguous_public_bases", unambiguous_public_bases},
694700
{"fields", fields},
@@ -716,6 +722,10 @@ llvm::json::Value Record::ToJson() const {
716722
{"is_thread_safe", is_thread_safe},
717723
};
718724

725+
if (owned_ptr_config.has_value()) {
726+
record.insert({"owned_ptr_config", owned_ptr_config->ToJson()});
727+
}
728+
719729
if (!lifetime_inputs.empty()) {
720730
record.insert({"lifetime_inputs", lifetime_inputs});
721731
}

rs_bindings_from_cc/ir.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,13 @@ struct TraitDerives {
706706
std::vector<std::string> custom;
707707
};
708708

709+
struct OwnedPtrConfig {
710+
llvm::json::Value ToJson() const;
711+
712+
std::string owned_ptr_type;
713+
std::string drop_impl;
714+
};
715+
709716
// A record (struct, class, union).
710717
struct Record {
711718
llvm::json::Value ToJson() const;
@@ -730,7 +737,7 @@ struct Record {
730737
std::optional<std::string> unknown_attr;
731738
std::optional<std::string> doc_comment;
732739
std::optional<BridgeType> bridge_type;
733-
std::optional<std::string> owned_ptr_type;
740+
std::optional<OwnedPtrConfig> owned_ptr_config;
734741
std::string source_loc;
735742
std::vector<BaseClass> unambiguous_public_bases;
736743
std::vector<Field> fields;

rs_bindings_from_cc/ir.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use code_gen_utils::make_rs_ident;
1111
use crubit_feature::CrubitFeature;
1212
use proc_macro2::{Ident, TokenStream};
1313
use quote::{quote, ToTokens};
14-
use serde::Deserialize;
14+
use serde::{Deserialize, Serialize};
1515
use std::cell::OnceCell;
1616
use std::cmp::Ordering;
1717
use std::collections::hash_map::{Entry, HashMap};
@@ -1165,6 +1165,13 @@ pub struct TraitDerives {
11651165
pub custom: Vec<Rc<str>>,
11661166
}
11671167

1168+
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
1169+
#[serde(deny_unknown_fields)]
1170+
pub struct OwnedPtrConfig {
1171+
pub owned_ptr_type: Rc<str>,
1172+
pub drop_impl: Rc<str>,
1173+
}
1174+
11681175
#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
11691176
#[serde(deny_unknown_fields)]
11701177
pub struct Record {
@@ -1191,7 +1198,8 @@ pub struct Record {
11911198
pub unknown_attr: Option<Rc<str>>,
11921199
pub doc_comment: Option<Rc<str>>,
11931200
pub bridge_type: Option<BridgeType>,
1194-
pub owned_ptr_type: Option<Rc<str>>,
1201+
#[serde(default)]
1202+
pub owned_ptr_config: Option<OwnedPtrConfig>,
11951203
pub source_loc: Rc<str>,
11961204
pub unambiguous_public_bases: Vec<BaseClass>,
11971205
pub fields: Vec<Field>,

0 commit comments

Comments
 (0)