Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Unreleased

- Rename `align` attribute of `ShaderType` to `shader_align`
- Place `align`/`size` inside a `shader()` attribute instead of being free-floating

## v0.11.2 (2025-08-25)

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use encase::{ShaderType, ArrayLength, StorageBuffer};
#[derive(ShaderType)]
struct Positions {
length: ArrayLength,
#[size(runtime)]
#[shader(size(runtime))]
positions: Vec<mint::Point2<f32>>
}

Expand Down
2 changes: 1 addition & 1 deletion benches/throughput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct A {
arrm2: [mint::ColumnMatrix2<f32>; 8],
arrm3: [mint::ColumnMatrix3<f32>; 8],
arrm4: [mint::ColumnMatrix4<f32>; 8],
#[size(1600)]
#[shader(size(1600))]
_pad: u32,
}

Expand Down
2 changes: 1 addition & 1 deletion derive/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ keywords = ["wgsl", "wgpu"]
categories = ["rendering"]

[dependencies]
syn = "2.0.1"
syn = { version = "2.0.1", features = ["extra-traits"] }
quote = "1"
proc-macro2 = "1"
150 changes: 112 additions & 38 deletions derive/impl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use proc_macro2::{Ident, Literal, Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token::Comma,
Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Type,
Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Token,
Type,
};

pub use syn;

#[macro_export]
macro_rules! implement {
($path:expr) => {
#[proc_macro_derive(ShaderType, attributes(shader_align, size))]
#[proc_macro_derive(ShaderType, attributes(shader))]
pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
let expanded = encase_derive_impl::derive_shader_type(input, &$path);
Expand Down Expand Up @@ -99,7 +101,8 @@ impl FieldData {
}
}

struct AlignmentAttr(u32);
#[derive(Debug)]
pub struct AlignmentAttr(u32);

impl Parse for AlignmentAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
Expand All @@ -116,40 +119,108 @@ impl Parse for AlignmentAttr {
}
}

struct StaticSizeAttr(u32);
#[derive(Debug)]
pub struct StaticSizeAttr(u32);

impl Parse for StaticSizeAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let span = input.span();
match input
.parse::<LitInt>()
.and_then(|lit| lit.base10_parse::<u32>())
{
Ok(num) => Ok(Self(num)),
_ => Err(syn::Error::new(input.span(), "expected u32 literal")),
_ => Err(syn::Error::new(span, "expected u32 literal")),
}
}
}

enum SizeAttr {
#[derive(Debug)]
pub enum SizeAttr {
Static(StaticSizeAttr),
Runtime,
}

impl Parse for SizeAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let span = input.span();
match input.parse::<StaticSizeAttr>() {
Ok(static_size) => Ok(SizeAttr::Static(static_size)),
_ => match input.parse::<Path>() {
Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime),
_ => Err(syn::Error::new(
input.span(),
span,
"expected u32 literal or `runtime` identifier",
)),
},
}
}
}

#[derive(Debug)]
pub enum ShaderAttr {
Align { attr: AlignmentAttr, span: Span },
Size { attr: SizeAttr, span: Span },
}

impl Parse for ShaderAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ident_span = input.span();
let Ok(ident) = input.parse::<Ident>() else {
return Err(syn::Error::new(ident_span, "expected `align` or `size`"));
};

match ident.to_string().as_str() {
"align" => {
if !input.peek(syn::token::Paren) {
return Err(syn::Error::new(
ident_span,
"expected attribute arguments in parentheses: `align(...)`",
));
}

let args;
parenthesized!(args in input);
let attr_span = args.span();
let align_attr: AlignmentAttr = args.parse()?;
Ok(ShaderAttr::Align {
attr: align_attr,
span: attr_span,
})
}
"size" => {
if !input.peek(syn::token::Paren) {
return Err(syn::Error::new(
ident_span,
"expected attribute arguments in parentheses: `size(...)`",
));
}

let args;
parenthesized!(args in input);
let attr_span = args.span();
let size_attr: SizeAttr = args.parse()?;
Ok(ShaderAttr::Size {
attr: size_attr,
span: attr_span,
})
}
_ => Err(syn::Error::new(
ident_span,
"unknown shader attribute, expected `align` or `size`",
)),
}
}
}

#[derive(Debug)]
pub struct ShaderAttrList(Punctuated<ShaderAttr, Token![,]>);
impl Parse for ShaderAttrList {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(Self(input.parse_terminated(ShaderAttr::parse, Token![,])?))
}
}

struct Errors {
inner: Option<Error>,
}
Expand Down Expand Up @@ -196,41 +267,44 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
size: None,
align: None,
};

for attr in &field.attrs {
if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("shader_align"))
{
if !(attr.meta.path().is_ident("shader")) {
continue;
}
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
if meta_list.path.is_ident("shader_align") {
let res = attr.parse_args::<AlignmentAttr>();
match res {
Ok(val) => data.align = Some((val.0, span)),
Err(err) => errors.append(err),

let shader_attrs = match attr.parse_args::<ShaderAttrList>() {
Ok(attrs) => attrs,
Err(err) => {
errors.append(err);
continue;
}
};

for shader_attr in shader_attrs.0 {
Comment thread
teoxoy marked this conversation as resolved.
match shader_attr {
ShaderAttr::Align { attr, span } => {
data.align = Some((attr.0, span));
}
ShaderAttr::Size { attr, span } => match attr {
SizeAttr::Runtime => {
if i == last_field_index {
is_runtime_sized = true;
} else {
let err = syn::Error::new(
span,
"only the last field can be `size(runtime)`",
);
errors.append(err);
continue;
}
}
} else if meta_list.path.is_ident("size") {
let res = if i == last_field_index {
attr.parse_args::<SizeAttr>().map(|val| match val {
SizeAttr::Runtime => {
is_runtime_sized = true;
None
}
SizeAttr::Static(size) => Some((size.0, span)),
})
} else {
attr.parse_args::<StaticSizeAttr>()
.map(|val| Some((val.0, span)))
};
match res {
Ok(val) => data.size = val,
Err(err) => errors.append(err),
SizeAttr::Static(attr) => {
data.size = Some((attr.0, span));
}
}
},
}
Err(err) => errors.append(err),
};
}
}
data
})
Expand All @@ -255,7 +329,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
if !is_runtime_sized {
let err = syn::Error::new(
field.ty.span(),
"`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[size(runtime)]`!",
"`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[shader(size(runtime))]`!",
);
errors.append(err)
}
Expand Down Expand Up @@ -305,7 +379,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get();
#root::concat_assert!(
alignment <= #align,
"shader_align attribute value must be at least ", alignment, " (field's type alignment)"
"shader(align) attribute value must be at least ", alignment, " (field's type alignment)"
)
}
check();
Expand Down
8 changes: 4 additions & 4 deletions src/core/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pub trait ShaderType {
/// # use mint;
/// #[derive(ShaderType)]
/// struct Invalid {
/// #[size(runtime)]
/// #[shader(size(runtime))]
/// vec: Vec<mint::Vector4<f32>>
/// }
/// Invalid::assert_uniform_compat();
Expand All @@ -171,7 +171,7 @@ pub trait ShaderType {
/// Invalid::assert_uniform_compat();
/// ```
///
/// Will not panic (fixed via shader_align attribute)
/// Will not panic (fixed via #[shader(align)] attribute)
///
/// ```
/// # use crate::encase::ShaderType;
Expand All @@ -182,7 +182,7 @@ pub trait ShaderType {
/// #[derive(ShaderType)]
/// struct Valid {
/// a: f32,
/// #[shader_align(16)]
/// #[shader(align(16))]
/// b: S,
/// }
/// Valid::assert_uniform_compat();
Expand All @@ -198,7 +198,7 @@ pub trait ShaderType {
/// # }
/// #[derive(ShaderType)]
/// struct Valid {
/// #[size(16)]
/// #[shader(size(16))]
/// a: f32,
/// b: S,
/// }
Expand Down
12 changes: 6 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@
///
/// Field attributes
///
/// - `#[shader_align(X)]` where `X` is a power of 2 [`u32`] literal (equivalent to [WGSL align attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-align))
/// - `#[shader(align(X))]` where `X` is a power of 2 [`u32`] literal (equivalent to [WGSL align attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-align))
///
/// Used to increase the alignment of the field
///
/// - `#[size(X)]` where `X` is a [`u32`] literal (equivalent to [WGSL size attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-size))
/// - `#[shader(size(X))]` where `X` is a [`u32`] literal (equivalent to [WGSL size attribute](https://gpuweb.github.io/gpuweb/wgsl/#attribute-size))
///
/// Used to increase the size of the field
///
/// - `#[size(runtime)]` can only be attached to the last field of the struct
/// - `#[shader(size(runtime))]` can only be attached to the last field of the struct
///
/// Used to denote the fact that the field it is attached to is a runtime-sized array
///
/// # Note about generics
///
/// While structs using generic type parameters are supported by this derive macro
///
/// - the `#[shader_align(X)]` and `#[size(X)]` attributes will only work
/// - the `#[shader(align(X))]` and `#[shader(size(X))]` attributes will only work
/// if they are attached to fields whose type contains no generic type parameters
///
/// # Examples
Expand Down Expand Up @@ -70,7 +70,7 @@
/// #[derive(ShaderType)]
/// struct Positions {
/// length: ArrayLength,
/// #[size(runtime)]
/// #[shader(size(runtime))]
/// positions: Vec<mint::Point2<f32>>
/// }
/// ```
Expand All @@ -88,7 +88,7 @@
/// const N: usize,
/// > {
/// array: [&'a mut E; N],
/// #[size(runtime)]
/// #[shader(size(runtime))]
/// rts_array: &'a mut Vec<&'b T>,
/// }
/// ```
Expand Down
4 changes: 2 additions & 2 deletions tests/assert_uniform_compat_fail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct S {

#[derive(ShaderType)]
struct WrappedF32 {
#[size(16)]
#[shader(size(16))]
elem: f32,
}

Expand Down Expand Up @@ -63,7 +63,7 @@ fn test_array_stride() {
fn test_rts_array() {
#[derive(ShaderType)]
struct TestRTSArray {
#[size(runtime)]
#[shader(size(runtime))]
a: Vec<f32>,
}

Expand Down
Loading