diff --git a/CHANGELOG.md b/CHANGELOG.md index e0eecb0..becf962 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Unreleased -- Rename `align` attribute of `ShaderType` to `shader_align` +- Place `align`/`size` inside a `shader()` attribute instead of being free-floating ## v0.11.2 (2025-08-25) diff --git a/README.md b/README.md index 0fead81..c6fad04 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ use encase::{ShaderType, ArrayLength, StorageBuffer}; #[derive(ShaderType)] struct Positions { length: ArrayLength, - #[size(runtime)] + #[shader(size(runtime))] positions: Vec> } diff --git a/benches/throughput.rs b/benches/throughput.rs index 600cbe3..119a8c6 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -64,7 +64,7 @@ struct A { arrm2: [mint::ColumnMatrix2; 8], arrm3: [mint::ColumnMatrix3; 8], arrm4: [mint::ColumnMatrix4; 8], - #[size(1600)] + #[shader(size(1600))] _pad: u32, } diff --git a/derive/impl/Cargo.toml b/derive/impl/Cargo.toml index fae5b13..659c733 100644 --- a/derive/impl/Cargo.toml +++ b/derive/impl/Cargo.toml @@ -11,6 +11,6 @@ keywords = ["wgsl", "wgpu"] categories = ["rendering"] [dependencies] -syn = "2.0.1" +syn = { version = "2.0.1", features = ["extra-traits"] } quote = "1" proc-macro2 = "1" diff --git a/derive/impl/src/lib.rs b/derive/impl/src/lib.rs index feee0e6..fdba48a 100644 --- a/derive/impl/src/lib.rs +++ b/derive/impl/src/lib.rs @@ -1,12 +1,14 @@ use proc_macro2::{Ident, Literal, Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; use syn::{ + parenthesized, parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, - Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Type, + Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Token, + Type, }; pub use syn; @@ -14,7 +16,7 @@ pub use syn; #[macro_export] macro_rules! implement { ($path:expr) => { - #[proc_macro_derive(ShaderType, attributes(shader_align, size))] + #[proc_macro_derive(ShaderType, attributes(shader))] pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput); let expanded = encase_derive_impl::derive_shader_type(input, &$path); @@ -99,7 +101,8 @@ impl FieldData { } } -struct AlignmentAttr(u32); +#[derive(Debug)] +pub struct AlignmentAttr(u32); impl Parse for AlignmentAttr { fn parse(input: ParseStream) -> syn::Result { @@ -116,33 +119,37 @@ impl Parse for AlignmentAttr { } } -struct StaticSizeAttr(u32); +#[derive(Debug)] +pub struct StaticSizeAttr(u32); impl Parse for StaticSizeAttr { fn parse(input: ParseStream) -> syn::Result { + let span = input.span(); match input .parse::() .and_then(|lit| lit.base10_parse::()) { Ok(num) => Ok(Self(num)), - _ => Err(syn::Error::new(input.span(), "expected u32 literal")), + _ => Err(syn::Error::new(span, "expected u32 literal")), } } } -enum SizeAttr { +#[derive(Debug)] +pub enum SizeAttr { Static(StaticSizeAttr), Runtime, } impl Parse for SizeAttr { fn parse(input: ParseStream) -> syn::Result { + let span = input.span(); match input.parse::() { Ok(static_size) => Ok(SizeAttr::Static(static_size)), _ => match input.parse::() { Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime), _ => Err(syn::Error::new( - input.span(), + span, "expected u32 literal or `runtime` identifier", )), }, @@ -150,6 +157,70 @@ impl Parse for SizeAttr { } } +#[derive(Debug)] +pub enum ShaderAttr { + Align { attr: AlignmentAttr, span: Span }, + Size { attr: SizeAttr, span: Span }, +} + +impl Parse for ShaderAttr { + fn parse(input: ParseStream) -> syn::Result { + let ident_span = input.span(); + let Ok(ident) = input.parse::() else { + return Err(syn::Error::new(ident_span, "expected `align` or `size`")); + }; + + match ident.to_string().as_str() { + "align" => { + if !input.peek(syn::token::Paren) { + return Err(syn::Error::new( + ident_span, + "expected attribute arguments in parentheses: `align(...)`", + )); + } + + let args; + parenthesized!(args in input); + let attr_span = args.span(); + let align_attr: AlignmentAttr = args.parse()?; + Ok(ShaderAttr::Align { + attr: align_attr, + span: attr_span, + }) + } + "size" => { + if !input.peek(syn::token::Paren) { + return Err(syn::Error::new( + ident_span, + "expected attribute arguments in parentheses: `size(...)`", + )); + } + + let args; + parenthesized!(args in input); + let attr_span = args.span(); + let size_attr: SizeAttr = args.parse()?; + Ok(ShaderAttr::Size { + attr: size_attr, + span: attr_span, + }) + } + _ => Err(syn::Error::new( + ident_span, + "unknown shader attribute, expected `align` or `size`", + )), + } + } +} + +#[derive(Debug)] +pub struct ShaderAttrList(Punctuated); +impl Parse for ShaderAttrList { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self(input.parse_terminated(ShaderAttr::parse, Token![,])?)) + } +} + struct Errors { inner: Option, } @@ -196,41 +267,44 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { size: None, align: None, }; + for attr in &field.attrs { - if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("shader_align")) - { + if !(attr.meta.path().is_ident("shader")) { continue; } - match attr.meta.require_list() { - Ok(meta_list) => { - let span = meta_list.tokens.span(); - if meta_list.path.is_ident("shader_align") { - let res = attr.parse_args::(); - match res { - Ok(val) => data.align = Some((val.0, span)), - Err(err) => errors.append(err), + + let shader_attrs = match attr.parse_args::() { + Ok(attrs) => attrs, + Err(err) => { + errors.append(err); + continue; + } + }; + + for shader_attr in shader_attrs.0 { + match shader_attr { + ShaderAttr::Align { attr, span } => { + data.align = Some((attr.0, span)); + } + ShaderAttr::Size { attr, span } => match attr { + SizeAttr::Runtime => { + if i == last_field_index { + is_runtime_sized = true; + } else { + let err = syn::Error::new( + span, + "only the last field can be `size(runtime)`", + ); + errors.append(err); + continue; + } } - } else if meta_list.path.is_ident("size") { - let res = if i == last_field_index { - attr.parse_args::().map(|val| match val { - SizeAttr::Runtime => { - is_runtime_sized = true; - None - } - SizeAttr::Static(size) => Some((size.0, span)), - }) - } else { - attr.parse_args::() - .map(|val| Some((val.0, span))) - }; - match res { - Ok(val) => data.size = val, - Err(err) => errors.append(err), + SizeAttr::Static(attr) => { + data.size = Some((attr.0, span)); } - } + }, } - Err(err) => errors.append(err), - }; + } } data }) @@ -255,7 +329,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { if !is_runtime_sized { let err = syn::Error::new( field.ty.span(), - "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[size(runtime)]`!", + "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[shader(size(runtime))]`!", ); errors.append(err) } @@ -305,7 +379,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get(); #root::concat_assert!( alignment <= #align, - "shader_align attribute value must be at least ", alignment, " (field's type alignment)" + "shader(align) attribute value must be at least ", alignment, " (field's type alignment)" ) } check(); diff --git a/src/core/traits.rs b/src/core/traits.rs index b801ece..b75658f 100644 --- a/src/core/traits.rs +++ b/src/core/traits.rs @@ -148,7 +148,7 @@ pub trait ShaderType { /// # use mint; /// #[derive(ShaderType)] /// struct Invalid { - /// #[size(runtime)] + /// #[shader(size(runtime))] /// vec: Vec> /// } /// Invalid::assert_uniform_compat(); @@ -171,7 +171,7 @@ pub trait ShaderType { /// Invalid::assert_uniform_compat(); /// ``` /// - /// Will not panic (fixed via shader_align attribute) + /// Will not panic (fixed via #[shader(align)] attribute) /// /// ``` /// # use crate::encase::ShaderType; @@ -182,7 +182,7 @@ pub trait ShaderType { /// #[derive(ShaderType)] /// struct Valid { /// a: f32, - /// #[shader_align(16)] + /// #[shader(align(16))] /// b: S, /// } /// Valid::assert_uniform_compat(); @@ -198,7 +198,7 @@ pub trait ShaderType { /// # } /// #[derive(ShaderType)] /// struct Valid { - /// #[size(16)] + /// #[shader(size(16))] /// a: f32, /// b: S, /// } diff --git a/src/lib.rs b/src/lib.rs index 3f7d619..b374be1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,15 +26,15 @@ /// /// Field attributes /// -/// - `#[shader_align(X)]` where `X` is a power of 2 [`u32`] literal (equivalent to [WGSL align attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-align)) +/// - `#[shader(align(X))]` where `X` is a power of 2 [`u32`] literal (equivalent to [WGSL align attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-align)) /// /// Used to increase the alignment of the field /// -/// - `#[size(X)]` where `X` is a [`u32`] literal (equivalent to [WGSL size attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-size)) +/// - `#[shader(size(X))]` where `X` is a [`u32`] literal (equivalent to [WGSL size attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-size)) /// /// Used to increase the size of the field /// -/// - `#[size(runtime)]` can only be attached to the last field of the struct +/// - `#[shader(size(runtime))]` can only be attached to the last field of the struct /// /// Used to denote the fact that the field it is attached to is a runtime-sized array /// @@ -42,7 +42,7 @@ /// /// While structs using generic type parameters are supported by this derive macro /// -/// - the `#[shader_align(X)]` and `#[size(X)]` attributes will only work +/// - the `#[shader(align(X))]` and `#[shader(size(X))]` attributes will only work /// if they are attached to fields whose type contains no generic type parameters /// /// # Examples @@ -70,7 +70,7 @@ /// #[derive(ShaderType)] /// struct Positions { /// length: ArrayLength, -/// #[size(runtime)] +/// #[shader(size(runtime))] /// positions: Vec> /// } /// ``` @@ -88,7 +88,7 @@ /// const N: usize, /// > { /// array: [&'a mut E; N], -/// #[size(runtime)] +/// #[shader(size(runtime))] /// rts_array: &'a mut Vec<&'b T>, /// } /// ``` diff --git a/tests/assert_uniform_compat_fail.rs b/tests/assert_uniform_compat_fail.rs index 9bc0c96..0f7cb92 100644 --- a/tests/assert_uniform_compat_fail.rs +++ b/tests/assert_uniform_compat_fail.rs @@ -7,7 +7,7 @@ struct S { #[derive(ShaderType)] struct WrappedF32 { - #[size(16)] + #[shader(size(16))] elem: f32, } @@ -63,7 +63,7 @@ fn test_array_stride() { fn test_rts_array() { #[derive(ShaderType)] struct TestRTSArray { - #[size(runtime)] + #[shader(size(runtime))] a: Vec, } diff --git a/tests/assert_uniform_compat_success.rs b/tests/assert_uniform_compat_success.rs index 6b553f8..e91cdc0 100644 --- a/tests/assert_uniform_compat_success.rs +++ b/tests/assert_uniform_compat_success.rs @@ -7,28 +7,28 @@ struct S { #[derive(ShaderType)] struct WrappedF32 { - #[size(16)] + #[shader(size(16))] elem: f32, } #[derive(ShaderType)] struct TestStruct { a: u32, - #[shader_align(16)] + #[shader(align(16))] b: S, } #[derive(ShaderType)] struct TestArray { a: u32, - #[shader_align(16)] + #[shader(align(16))] b: [WrappedF32; 1], } #[derive(ShaderType)] struct TestStructFirst { a: S, - #[shader_align(16)] + #[shader(align(16))] b: f32, } diff --git a/tests/compile_fail/array_length_err.stderr b/tests/compile_fail/array_length_err.stderr index e32d00a..340c25f 100644 --- a/tests/compile_fail/array_length_err.stderr +++ b/tests/compile_fail/array_length_err.stderr @@ -1,4 +1,4 @@ -error: `ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[size(runtime)]`! +error: `ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[shader(size(runtime))]`! --> tests/compile_fail/array_length_err.rs:7:8 | 7 | a: ArrayLength, diff --git a/tests/compile_fail/invalid_align_attr.rs b/tests/compile_fail/invalid_align_attr.rs index b521097..ff683fd 100644 --- a/tests/compile_fail/invalid_align_attr.rs +++ b/tests/compile_fail/invalid_align_attr.rs @@ -4,12 +4,12 @@ fn main() {} #[derive(ShaderType)] struct Test { - #[shader_align] + #[shader(align)] a: u32, - #[shader_align()] + #[shader(align())] b: u32, - #[shader_align(invalid)] + #[shader(align(invalid))] c: u32, - #[shader_align(3)] + #[shader(align(3))] d: u32, } diff --git a/tests/compile_fail/invalid_align_attr.stderr b/tests/compile_fail/invalid_align_attr.stderr index 9847db8..9db4751 100644 --- a/tests/compile_fail/invalid_align_attr.stderr +++ b/tests/compile_fail/invalid_align_attr.stderr @@ -1,23 +1,23 @@ -error: expected attribute arguments in parentheses: `shader_align(...)` - --> tests/compile_fail/invalid_align_attr.rs:7:7 +error: expected attribute arguments in parentheses: `align(...)` + --> tests/compile_fail/invalid_align_attr.rs:7:14 | -7 | #[shader_align] - | ^^^^^^^^^^^^ +7 | #[shader(align)] + | ^^^^^ error: expected a power of 2 u32 literal --> tests/compile_fail/invalid_align_attr.rs:9:20 | -9 | #[shader_align()] +9 | #[shader(align())] | ^ error: expected a power of 2 u32 literal --> tests/compile_fail/invalid_align_attr.rs:11:20 | -11 | #[shader_align(invalid)] +11 | #[shader(align(invalid))] | ^^^^^^^ error: expected a power of 2 u32 literal --> tests/compile_fail/invalid_align_attr.rs:13:21 | -13 | #[shader_align(3)] +13 | #[shader(align(3))] | ^ diff --git a/tests/compile_fail/invalid_size_attr.rs b/tests/compile_fail/invalid_size_attr.rs index 64176cd..ee391f2 100644 --- a/tests/compile_fail/invalid_size_attr.rs +++ b/tests/compile_fail/invalid_size_attr.rs @@ -4,12 +4,12 @@ fn main() {} #[derive(ShaderType)] struct Test { - #[size] + #[shader(size)] a: u32, - #[size()] + #[shader(size())] b: u32, - #[size(invalid)] + #[shader(size(invalid))] c: u32, - #[size(-1)] + #[shader(size(-1))] d: u32, } diff --git a/tests/compile_fail/invalid_size_attr.stderr b/tests/compile_fail/invalid_size_attr.stderr index 92229fc..1b5e793 100644 --- a/tests/compile_fail/invalid_size_attr.stderr +++ b/tests/compile_fail/invalid_size_attr.stderr @@ -1,23 +1,23 @@ error: expected attribute arguments in parentheses: `size(...)` - --> tests/compile_fail/invalid_size_attr.rs:7:7 + --> tests/compile_fail/invalid_size_attr.rs:7:14 | -7 | #[size] - | ^^^^ +7 | #[shader(size)] + | ^^^^ -error: expected u32 literal - --> tests/compile_fail/invalid_size_attr.rs:9:12 +error: expected u32 literal or `runtime` identifier + --> tests/compile_fail/invalid_size_attr.rs:9:19 | -9 | #[size()] - | ^ +9 | #[shader(size())] + | ^ -error: expected u32 literal - --> tests/compile_fail/invalid_size_attr.rs:11:12 +error: expected u32 literal or `runtime` identifier + --> tests/compile_fail/invalid_size_attr.rs:11:19 | -11 | #[size(invalid)] - | ^^^^^^^ +11 | #[shader(size(invalid))] + | ^^^^^^^ error: expected u32 literal or `runtime` identifier - --> tests/compile_fail/invalid_size_attr.rs:13:14 + --> tests/compile_fail/invalid_size_attr.rs:13:19 | -13 | #[size(-1)] - | ^ +13 | #[shader(size(-1))] + | ^ diff --git a/tests/general.rs b/tests/general.rs index 1e03d31..473c18f 100644 --- a/tests/general.rs +++ b/tests/general.rs @@ -62,7 +62,7 @@ struct A { arrm3: [mint::ColumnMatrix3; 8], arrm4: [mint::ColumnMatrix4; 8], rt_arr_len: ArrayLength, - #[size(runtime)] + #[shader(size(runtime))] rt_arr: Vec>, } diff --git a/tests/hygiene.rs b/tests/hygiene.rs index 5f34bf4..c18baae 100644 --- a/tests/hygiene.rs +++ b/tests/hygiene.rs @@ -106,10 +106,9 @@ struct TestGeneric< T: 'a + ::encase::ShaderType + ::encase::ShaderSize, const N: ::core::primitive::usize, > { - #[size(90)] + #[shader(size(90))] a: &'a mut Test, b: &'a mut [T; N], - #[shader_align(16)] - #[size(runtime)] + #[shader(align(16), size(runtime))] c: &'a mut ::std::vec::Vec<[::mint::Vector3<::core::primitive::f32>; 2]>, } diff --git a/tests/metadata.rs b/tests/metadata.rs index 6f2bab8..c1458d3 100644 --- a/tests/metadata.rs +++ b/tests/metadata.rs @@ -2,7 +2,7 @@ use encase::ShaderType; #[derive(ShaderType)] struct WrappedF32 { - #[size(16)] + #[shader(size(16), align(16))] value: f32, } diff --git a/tests/pass/attributes.rs b/tests/pass/attributes.rs index 7566dcd..7b24cb8 100644 --- a/tests/pass/attributes.rs +++ b/tests/pass/attributes.rs @@ -4,18 +4,17 @@ fn main() {} #[derive(ShaderType)] struct TestAttributes { - #[shader_align(16)] + #[shader(align(16))] a: u32, - #[size(8)] + #[shader(size(8))] b: u32, } #[derive(ShaderType)] struct TestRtArray { - #[size(8)] + #[shader(size(8))] a: u32, - #[shader_align(16)] - #[size(runtime)] + #[shader(align(16), size(runtime))] b: Vec, } diff --git a/tests/wgpu.rs b/tests/wgpu.rs index a1500b1..ec4c6b3 100644 --- a/tests/wgpu.rs +++ b/tests/wgpu.rs @@ -10,8 +10,7 @@ struct A { u: u32, v: u32, w: Vector2, - #[size(16)] - #[shader_align(8)] + #[shader(size(16), align(8))] x: u32, xx: u32, } @@ -22,13 +21,12 @@ struct B { b: Vector3, c: u32, d: u32, - #[shader_align(16)] + #[shader(align(16))] e: A, f: Vector3, g: [A; 3], h: i32, - #[shader_align(32)] - #[size(runtime)] + #[shader(align(32), size(runtime))] i: Vec, } @@ -123,8 +121,7 @@ fn array_length() { array_length: ArrayLength, array_length_call_ret_val: u32, a: Vector3, - #[shader_align(16)] - #[size(runtime)] + #[shader(align(16), size(runtime))] arr: Vec, }