.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
:members: forward, set_context_parallel_group
.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
:members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork
.. autoapifunction:: transformer_engine.pytorch.autocast
.. autoapifunction:: transformer_engine.pytorch.quantized_model_init
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
.. autoapifunction:: transformer_engine.pytorch.mark_not_offload
.. autoapiclass:: transformer_engine.pytorch.ManualOffloadSynchronizer
.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy
.. autoapifunction:: transformer_engine.pytorch.is_fp8_available
.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available
.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available
.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available
.. autoapifunction:: transformer_engine.pytorch.is_bf16_available
.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version
.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability
.. autoapifunction:: transformer_engine.pytorch.get_default_recipe
Mixture of Experts (MoE) functions
----------------------------------
.. autoapifunction:: transformer_engine.pytorch.moe_permute
.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs
Communication-computation overlap
---------------------------------
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
:members: FP8, NONE
.. autoapiclass:: transformer_engine.pytorch.QuantizedTensorStorage
:members: update_usage, prepare_for_saving, restore_from_saved
.. autoapiclass:: transformer_engine.pytorch.QuantizedTensor(shape, dtype, *, requires_grad=False, device=None)
:members: dequantize, quantize_
.. autoapiclass:: transformer_engine.pytorch.Float8TensorStorage(data, fp8_scale_inv, fp8_dtype, data_transpose=None, quantizer=None)
.. autoapiclass:: transformer_engine.pytorch.MXFP8TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)
.. autoapiclass:: transformer_engine.pytorch.Float8BlockwiseQTensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)
.. autoapiclass:: transformer_engine.pytorch.NVFP4TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)
.. autoapiclass:: transformer_engine.pytorch.Float8Tensor(shape, dtype, data, fp8_scale_inv, fp8_dtype, requires_grad=False, data_transpose=None, quantizer=None)
.. autoapiclass:: transformer_engine.pytorch.MXFP8Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)
.. autoapiclass:: transformer_engine.pytorch.Float8BlockwiseQTensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)
.. autoapiclass:: transformer_engine.pytorch.NVFP4Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)
.. autoapiclass:: transformer_engine.pytorch.Quantizer(rowwise, columnwise)
:members: update_quantized, quantize
.. autoapiclass:: transformer_engine.pytorch.Float8Quantizer(scale, amax, fp8_dtype, *, rowwise=True, columnwise=True)
.. autoapiclass:: transformer_engine.pytorch.Float8CurrentScalingQuantizer(fp8_dtype, device, *, rowwise=True, columnwise=True, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.MXFP8Quantizer(fp8_dtype, *, rowwise=True, columnwise=True)
.. autoapiclass:: transformer_engine.pytorch.Float8BlockQuantizer(fp8_dtype, *, rowwise, columnwise, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.NVFP4Quantizer(fp4_dtype, *, rowwise=True, columnwise=True, **kwargs)
Tensor saving and restoring functions
-------------------------------------
.. autoapifunction:: transformer_engine.pytorch.prepare_for_saving
.. autoapifunction:: transformer_engine.pytorch.restore_from_saved
.. autoapiclass:: transformer_engine.pytorch.ops.Sequential
:members: forward
.. autoapiclass:: transformer_engine.pytorch.ops.FusibleOperation
:members: fuser_forward, fuser_backward
.. autoapiclass:: transformer_engine.pytorch.ops.BasicOperation
:members: op_forward, op_backward
.. autoapiclass:: transformer_engine.pytorch.ops.FusedOperation
:members: fuser_forward, fuser_backward
.. autoapifunction:: transformer_engine.pytorch.ops.register_forward_fusion
.. autoapifunction:: transformer_engine.pytorch.ops.register_backward_fusion
.. autoapiclass:: transformer_engine.pytorch.ops.Linear
.. autoapiclass:: transformer_engine.pytorch.ops.AddExtraInput
.. autoapiclass:: transformer_engine.pytorch.ops.AllGather
.. autoapiclass:: transformer_engine.pytorch.ops.AllReduce
.. autoapiclass:: transformer_engine.pytorch.ops.BasicLinear
:members: _functional_forward, _functional_backward
.. autoapiclass:: transformer_engine.pytorch.ops.Bias
.. autoapiclass:: transformer_engine.pytorch.ops.ClampedSwiGLU
.. autoapiclass:: transformer_engine.pytorch.ops.ConstantScale
.. autoapiclass:: transformer_engine.pytorch.ops.Dropout
.. autoapiclass:: transformer_engine.pytorch.ops.GEGLU
.. autoapiclass:: transformer_engine.pytorch.ops.GELU
.. autoapiclass:: transformer_engine.pytorch.ops.GLU
.. autoapiclass:: transformer_engine.pytorch.ops.GroupedLinear
.. autoapiclass:: transformer_engine.pytorch.ops.Identity
.. autoapiclass:: transformer_engine.pytorch.ops.L2Normalization
.. autoapiclass:: transformer_engine.pytorch.ops.LayerNorm
.. autoapiclass:: transformer_engine.pytorch.ops.MakeExtraOutput
.. autoapiclass:: transformer_engine.pytorch.ops.QGELU
.. autoapiclass:: transformer_engine.pytorch.ops.QGEGLU
.. autoapiclass:: transformer_engine.pytorch.ops.Quantize
.. autoapiclass:: transformer_engine.pytorch.ops.ReGLU
.. autoapiclass:: transformer_engine.pytorch.ops.ReLU
.. autoapiclass:: transformer_engine.pytorch.ops.ReduceScatter
.. autoapiclass:: transformer_engine.pytorch.ops.Reshape
.. autoapiclass:: transformer_engine.pytorch.ops.RMSNorm
.. autoapiclass:: transformer_engine.pytorch.ops.SReGLU
.. autoapiclass:: transformer_engine.pytorch.ops.SReLU
.. autoapiclass:: transformer_engine.pytorch.ops.ScaledClampedQGeGLU
.. autoapiclass:: transformer_engine.pytorch.ops.ScaledSwiGLU
.. autoapiclass:: transformer_engine.pytorch.ops.SiLU
.. autoapiclass:: transformer_engine.pytorch.ops.SwiGLU
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.fp8_model_init