Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions oxifft-codegen-impl/src/gen_simd/avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! f64 variant: 8×f64 = 4 complexes per `__m512d` register.
//! f32 variant: 16×f32 = 8 complexes per `__m512` register; uses `_ps` intrinsics.
//!
//! All emitted functions carry `#[target_feature(enable = "avx512f")]`.
//! All emitted functions carry `#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]`.
//! Complex multiply uses FMA:
//! - Real part: `_mm512_fmsub_pd(re_a, re_b, mul(im_a, im_b))` → ac − bd
//! - Imag part: `_mm512_fmadd_pd(re_a, im_b, mul(im_a, re_b))` → ad + bc
Expand Down Expand Up @@ -32,7 +32,7 @@ pub(super) fn gen_avx512_size_2_f64() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 4 f64 elements (2 complex numbers).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_2_avx512_f64(data: &mut [f64], _sign: i32) {
use core::arch::x86_64::*;

Expand Down Expand Up @@ -79,7 +79,7 @@ pub(super) fn gen_avx512_size_4_f64() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 8 f64 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_4_avx512_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;

Expand Down Expand Up @@ -150,7 +150,7 @@ pub(super) fn gen_avx512_size_8_f64() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 16 f64 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx512_f64(data: &mut [f64], sign: i32) {
use core::arch::x86_64::*;
Expand Down Expand Up @@ -308,7 +308,7 @@ pub(super) fn gen_avx512_size_2_f32() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 4 f32 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_2_avx512_f32(data: &mut [f32], _sign: i32) {
use core::arch::x86_64::*;

Expand Down Expand Up @@ -348,7 +348,7 @@ pub(super) fn gen_avx512_size_4_f32() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 8 f32 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
unsafe fn codelet_simd_4_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;

Expand Down Expand Up @@ -418,7 +418,7 @@ pub(super) fn gen_avx512_size_8_f32() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 16 f32 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_8_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
Expand Down Expand Up @@ -581,7 +581,7 @@ pub(super) fn gen_avx512_size_16_f32() -> TokenStream {
/// - Caller must verify AVX-512F is available.
/// - `data` must contain at least 32 f32 elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "avx512")] #[target_feature(enable = "avx512f")]
#[allow(clippy::too_many_lines)]
unsafe fn codelet_simd_16_avx512_f32(data: &mut [f32], sign: i32) {
use core::arch::x86_64::*;
Expand Down
3 changes: 3 additions & 0 deletions oxifft-codegen-impl/src/gen_simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {

#[cfg(target_arch = "x86_64")]
{
#[cfg(feature = "avx512")]
if is_x86_feature_detected!("avx512f") {
// Safety: AVX-512F detected, pointer valid for len f64s
unsafe { #avx512_f64_name(f64_data, sign); }
Expand Down Expand Up @@ -298,6 +299,7 @@ fn gen_dispatcher(n: usize) -> proc_macro2::TokenStream {

#[cfg(target_arch = "x86_64")]
{
#[cfg(feature = "avx512")]
if is_x86_feature_detected!("avx512f") {
// Safety: AVX-512F detected
unsafe { #avx512_f32_name(f32_data, sign); }
Expand Down Expand Up @@ -363,6 +365,7 @@ fn gen_dispatcher_16() -> proc_macro2::TokenStream {

#[cfg(target_arch = "x86_64")]
{
#[cfg(feature = "avx512")]
if is_x86_feature_detected!("avx512f") {
// Safety: AVX-512F detected, pointer valid for len f32s
unsafe { #avx512_f32_name(f32_data, sign); }
Expand Down
4 changes: 3 additions & 1 deletion oxifft-codegen-impl/src/gen_simd/runtime_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ fn build_detect_x86_body() -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
{
#[cfg(feature = "avx512")]
if is_x86_feature_detected!("avx512f") {
return ISA_AVX512_LEVEL;
}
Expand Down Expand Up @@ -176,7 +177,7 @@ fn build_x86_64_branches(config: DispatcherConfig) -> TokenStream {
if size == 16 {
if config.precision == Precision::F32 {
return quote! {
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
if cached_level == ISA_AVX512_LEVEL {
// Safety: avx512f detected at runtime.
Expand Down Expand Up @@ -214,6 +215,7 @@ fn build_x86_64_branches(config: DispatcherConfig) -> TokenStream {
quote! {
#[cfg(target_arch = "x86_64")]
{
#[cfg(feature = "avx512")]
if cached_level == ISA_AVX512_LEVEL {
// Safety: avx512f detected at runtime.
let data_len = data.len() * 2;
Expand Down
1 change: 1 addition & 0 deletions oxifft/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ std = ["dep:serde_json", "num-complex/std", "num-traits/std", "serde/std"]
# Threading support (requires std)
threading = ["std", "dep:rayon"]
simd = []
avx512 = []
# Quad-precision (128-bit) floating-point support (pure Rust)
f128-support = []
# Half-precision (16-bit) floating-point support (pure Rust)
Expand Down
6 changes: 3 additions & 3 deletions oxifft/src/dft/codelets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
mod codegen_tests;
mod composite;
pub mod generated_simd;
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
pub mod hand_avx512;
#[cfg(all(test, target_arch = "x86_64"))]
#[cfg(all(test, target_arch = "x86_64", feature = "avx512"))]
mod hand_avx512_tests;
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
pub(crate) mod hand_avx512_twiddles;
mod notw;
pub mod simd;
Expand Down
6 changes: 3 additions & 3 deletions oxifft/src/dft/codelets/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn notw_8_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
#[inline]
pub fn notw_16_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
// --- x86_64: try hand-tuned AVX-512 first ---
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
if TypeId::of::<T>() == TypeId::of::<f64>() {
let x_f64 = unsafe {
Expand Down Expand Up @@ -115,7 +115,7 @@ fn notw_16_simd_f64_fallback<T: Float>(x: &mut [Complex<T>], sign: i32) {
#[inline]
pub fn notw_32_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
// --- x86_64: try hand-tuned AVX-512 first ---
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
if TypeId::of::<T>() == TypeId::of::<f64>() {
let x_f64 = unsafe {
Expand Down Expand Up @@ -156,7 +156,7 @@ fn notw_32_simd_f64_fallback<T: Float>(x: &mut [Complex<T>], sign: i32) {
#[inline]
pub fn notw_64_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
// --- x86_64: try hand-tuned AVX-512 first ---
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
if TypeId::of::<T>() == TypeId::of::<f64>() {
let x_f64 = unsafe {
Expand Down
12 changes: 9 additions & 3 deletions oxifft/src/dft/solvers/stockham/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,18 @@ pub fn stockham_f64(input: &[Complex<f64>], output: &mut [Complex<f64>], sign: S
unsafe { aarch64::stockham_radix4_neon(input, output, sign) }
}

#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
// Prefer AVX-512 when available (4x f64 per register for complex)
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512dq") {
unsafe { x86_64::stockham_radix4_avx512(input, output, sign) }
} else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
unsafe { x86_64::stockham_radix4_avx512(input, output, sign) };
return;
}
}

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
unsafe { x86_64::stockham_radix4_avx2(input, output, sign) }
} else {
generic::stockham_radix4_scalar(input, output, sign);
Expand Down
3 changes: 3 additions & 0 deletions oxifft/src/dft/solvers/stockham/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ pub unsafe fn stockham_radix4_avx2(
/// undefined behavior (illegal instruction trap at runtime).
/// Both `input` and `output` must have the same length, which must be a
/// power of two.
#[cfg(feature = "avx512")]
#[target_feature(enable = "avx512f", enable = "avx512dq")]
pub unsafe fn stockham_radix4_avx512(
input: &[Complex<f64>],
Expand Down Expand Up @@ -659,6 +660,7 @@ pub unsafe fn stockham_radix4_avx512(
/// Caller must ensure the target CPU supports the `avx512f` feature
/// (the `_mm_permute_pd` and `_mm_addsub_pd` intrinsics used internally
/// require at least SSE3/AVX, which is implied by `avx512f`).
#[cfg(feature = "avx512")]
#[inline(always)]
unsafe fn avx512_cmul_128(
v: core::arch::x86_64::__m128d,
Expand Down Expand Up @@ -688,6 +690,7 @@ unsafe fn avx512_cmul_128(
/// Calling this function on a CPU that lacks these features causes undefined
/// behavior (illegal instruction trap at runtime).
/// `input` and `output` must have the same length, which must be 1, 2, or 4.
#[cfg(feature = "avx512")]
#[target_feature(enable = "avx512f", enable = "avx512dq")]
unsafe fn stockham_small_avx512(input: &[Complex<f64>], output: &mut [Complex<f64>], sign: Sign) {
unsafe {
Expand Down
4 changes: 2 additions & 2 deletions oxifft/src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mod traits;
mod avx;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512;
#[cfg(target_arch = "x86_64")]
mod sse2;
Expand Down Expand Up @@ -144,7 +144,7 @@ pub use avx::{AvxF32, AvxF64};
#[cfg(target_arch = "x86_64")]
pub use avx2::{has_avx2_fma, Avx2F32, Avx2F64};

#[cfg(target_arch = "x86_64")]
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
pub use avx512::{has_avx512f, Avx512F32, Avx512F64};

#[cfg(target_arch = "aarch64")]
Expand Down