From 294280e5a74eb9131abc78bd139c40634e104124 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 3 Apr 2026 17:09:15 -0700 Subject: [PATCH 01/20] feat: make timesfm2_5 onnx export compatible --- .../models/timesfm/modeling_timesfm.py | 24 ++++++++++--------- .../models/timesfm/modular_timesfm.py | 24 ++++++++++--------- .../models/timesfm2_5/modeling_timesfm2_5.py | 24 ++++++++++--------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 219908c1e47c..e0b7518134ba 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -608,22 +608,24 @@ def _preprocess( input_ts, input_padding = [], [] for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) input_ts.append(ts) input_padding.append(padding) result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + result = result + ( + torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), + ) return result def _postprocess_output( diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index ca53ec7dd668..4b722c15e636 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -565,22 +565,24 @@ def _preprocess( input_ts, input_padding = [], [] for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) input_ts.append(ts) input_padding.append(padding) result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + result = result + ( + torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), + ) return result def _postprocess_output( diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index e7b4e799d20b..ad2deb02c995 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -700,22 +700,24 @@ def _preprocess( input_ts, input_padding = [], [] for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: - ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) input_ts.append(ts) input_padding.append(padding) result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + result = result + ( + torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), + ) return result def _postprocess_output( From 007bc1f7ecd579673584470d052ec1ccd1e1d0c3 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Sat, 4 Apr 2026 17:32:43 -0700 Subject: [PATCH 02/20] chore: revert tensor line --- src/transformers/models/timesfm/modular_timesfm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 4b722c15e636..4efcdec43a5c 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -580,9 +580,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + ( - torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), - ) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( From 70ef555ab6819d7f5b9ecbfdf8c8428b7195bb8d Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Sat, 4 Apr 2026 17:34:14 -0700 Subject: [PATCH 03/20] chore: re-generate files --- src/transformers/models/timesfm/modeling_timesfm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index e0b7518134ba..faee12a9ddca 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -623,9 +623,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + ( - torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), - ) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( From b8c0e9d3e8d2a06acc1f7f4d6b3d8b5237cd29cc Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Sat, 4 Apr 2026 17:35:28 -0700 Subject: [PATCH 04/20] chore: regen 2_5 file --- src/transformers/models/timesfm2_5/modeling_timesfm2_5.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index ad2deb02c995..72fc786c866c 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -715,9 +715,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + ( - torch.as_tensor(freq[: len(inputs)], dtype=torch.int32, device=result[0].device).reshape(-1, 1), - ) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( From 2fa9d056fd34fb8d8c1214c1bb9ad643da80c6eb Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 11:47:40 -0700 Subject: [PATCH 05/20] fix preprocess to support single tensor && fix truncate negative check to be onnx compatible --- .../models/timesfm/modeling_timesfm.py | 86 ++++++++++++------ .../models/timesfm/modular_timesfm.py | 88 +++++++++++++------ .../models/timesfm2_5/modeling_timesfm2_5.py | 75 +++++++++++----- .../models/timesfm2_5/modular_timesfm2_5.py | 21 +++-- 4 files changed, 187 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index faee12a9ddca..9519acf55005 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -590,40 +590,60 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. + inputs: Either a list of 1D tensors (one per series, may differ in length) or a single + 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. + The 2D path avoids Python loops and is ONNX-export friendly. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to `self.context_len`). + context_len: Optional context length override (defaults to ``self.context_len``). Returns: - Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. + Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. """ if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( @@ -653,7 +673,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, @@ -663,8 +683,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFmOutputForPrediction: r""" - past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `Sequence[torch.Tensor]`): + Past values of the time series that serves as input to the model. Can be a 2D tensor + (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). freq (`torch.LongTensor` of shape `(batch_size,)`): Frequency indices for the time series data. window_size (`int`, *optional*): @@ -701,12 +722,20 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: + if is_tensor: + raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs = [] new_freqs = [] for i, ts in enumerate(inputs): @@ -719,7 +748,10 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) + if is_tensor: + freq = [0] * past_values.shape[0] + else: + freq = [0] * len(inputs) input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) input_ts = input_ts.to(device) @@ -774,9 +806,9 @@ def forward( if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = torch.maximum(mean_outputs, 0.0) - full_outputs = torch.maximum(full_outputs, 0.0) + if truncate_negative: + mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs) + full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs) loss = None if future_values is not None: diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 4efcdec43a5c..fe79f925d28b 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -547,40 +547,62 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. + inputs: Either a list of 1D tensors (one per series, may differ in length) or a single + 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. + The 2D path avoids Python loops and is ONNX-export friendly. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to `self.context_len`). + context_len: Optional context length override (defaults to ``self.context_len``). Returns: - Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. + Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. """ if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros( + context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device + ), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( @@ -610,7 +632,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, @@ -620,8 +642,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFmOutputForPrediction: r""" - past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `Sequence[torch.Tensor]`): + Past values of the time series that serves as input to the model. Can be a 2D tensor + (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). freq (`torch.LongTensor` of shape `(batch_size,)`): Frequency indices for the time series data. window_size (`int`, *optional*): @@ -658,12 +681,20 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: + if is_tensor: + raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs = [] new_freqs = [] for i, ts in enumerate(inputs): @@ -676,7 +707,10 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - freq = [0] * len(inputs) + if is_tensor: + freq = [0] * past_values.shape[0] + else: + freq = [0] * len(inputs) input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) input_ts = input_ts.to(device) @@ -731,9 +765,9 @@ def forward( if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if inp_min >= 0 and truncate_negative: - mean_outputs = torch.maximum(mean_outputs, 0.0) - full_outputs = torch.maximum(full_outputs, 0.0) + if truncate_negative: + mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs) + full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs) loss = None if future_values is not None: diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index 72fc786c866c..4e58f7f23577 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -682,40 +682,60 @@ def __init__(self, config: TimesFm2_5Config): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. + inputs: Either a list of 1D tensors (one per series, may differ in length) or a single + 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. + The 2D path avoids Python loops and is ONNX-export friendly. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to `self.context_len`). + context_len: Optional context length override (defaults to ``self.context_len``). Returns: - Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. + Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. """ if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( @@ -745,7 +765,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -754,8 +774,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Each tensor is a 1D time series. + past_values (`Sequence[torch.Tensor]` or `torch.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. Can be a 2D tensor + (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): @@ -769,12 +790,20 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -forecast_context_len:] + input_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-forecast_context_len:] for ts in past_values] + input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: + if is_tensor: + raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs: list[torch.Tensor] = [] for ts in inputs: new_inputs.extend(self._timesfm_moving_average(ts, window_size)) diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index 3a912d07946b..32e2e7277d66 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -532,7 +532,7 @@ def _decode_and_project( @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -541,8 +541,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Each tensor is a 1D time series. + past_values (`Sequence[torch.Tensor]` or `torch.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. Can be a 2D tensor + (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): @@ -556,12 +557,20 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -forecast_context_len:] + input_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-forecast_context_len:] for ts in past_values] + input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: + if is_tensor: + raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs: list[torch.Tensor] = [] for ts in inputs: new_inputs.extend(self._timesfm_moving_average(ts, window_size)) From d41c4651279ffb6ae83b13314cd5f5a210f84a3b Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 12:09:42 -0700 Subject: [PATCH 06/20] add onnx related fixes & tests --- .../models/timesfm/modeling_timesfm.py | 11 +- .../models/timesfm/modular_timesfm.py | 11 +- tests/models/timesfm/test_modeling_timesfm.py | 252 ++++++++++++++++++ .../timesfm2_5/test_modeling_timesfm2_5.py | 67 +++++ 4 files changed, 335 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 9519acf55005..2aae2c478313 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -642,8 +642,12 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) + if isinstance(freq, torch.Tensor): + inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) + else: + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + result = result + (inp_freq,) return result def _postprocess_output( @@ -749,7 +753,8 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") if is_tensor: - freq = [0] * past_values.shape[0] + # Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int. + freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device) else: freq = [0] * len(inputs) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index fe79f925d28b..9fbf33299fe5 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -601,8 +601,12 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) + if isinstance(freq, torch.Tensor): + inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) + else: + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + result = result + (inp_freq,) return result def _postprocess_output( @@ -708,7 +712,8 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") if is_tensor: - freq = [0] * past_values.shape[0] + # Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int. + freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device) else: freq = [0] * len(inputs) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 31ed60ce9d5c..c465a0f10697 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -210,6 +210,258 @@ def test_model_main_input_name(self): self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name) +@require_torch +class TimesFmForwardInputVariantsTest(unittest.TestCase): + def setUp(self): + config = TimesFmConfig( + patch_length=32, + context_length=64, + horizon_length=32, + hidden_size=16, + intermediate_size=32, + head_dim=8, + num_hidden_layers=1, + num_attention_heads=2, + ) + self.model = TimesFmModelForPrediction(config).to(torch_device).eval() + self.horizon_len = config.horizon_length + + def test_different_length_series(self): + """forward() handles a list of series with different lengths.""" + inputs = [ + torch.randn(20, device=torch_device), + torch.randn(50, device=torch_device), + torch.randn(100, device=torch_device), + ] + with torch.no_grad(): + out = self.model(past_values=inputs, freq=[0, 0, 0]) + self.assertEqual(out.mean_predictions.shape, (3, self.horizon_len)) + + def test_very_short_and_very_long_series(self): + """forward() works when one series is tiny and another exceeds context_len.""" + inputs = [ + torch.randn(5, device=torch_device), + torch.randn(500, device=torch_device), + ] + with torch.no_grad(): + out = self.model(past_values=inputs, freq=[0, 0]) + self.assertEqual(out.mean_predictions.shape, (2, self.horizon_len)) + + def test_2d_tensor_input(self): + """forward() accepts a 2D tensor and produces correct output shape.""" + inputs = torch.randn(4, 80, device=torch_device) + with torch.no_grad(): + out = self.model(past_values=inputs, freq=[0, 0, 0, 0]) + self.assertEqual(out.mean_predictions.shape, (4, self.horizon_len)) + + def test_list_vs_tensor_parity(self): + """forward() with list and 2D tensor of equal-length series gives identical output.""" + raw = [torch.randn(50, device=torch_device) for _ in range(2)] + stacked = torch.stack(raw) + freq = [0, 0] + with torch.no_grad(): + out_list = self.model(past_values=raw, freq=freq) + out_tensor = self.model(past_values=stacked, freq=freq) + self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) + self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) + + def test_long_series_truncated(self): + """forward() with a long series produces the same output as passing only the tail.""" + long_series = torch.randn(500, device=torch_device) + tail = long_series[-64:] + with torch.no_grad(): + out_long = self.model(past_values=[long_series], freq=[0]) + out_tail = self.model(past_values=[tail], freq=[0]) + self.assertTrue(torch.allclose(out_long.mean_predictions, out_tail.mean_predictions, atol=1e-5)) + + def test_truncate_negative_with_positive_input(self): + """truncate_negative clamps outputs to zero when all inputs are non-negative.""" + inputs = torch.rand(2, 80, device=torch_device).abs() + 1.0 + with torch.no_grad(): + out = self.model(past_values=inputs, freq=[0, 0], truncate_negative=True) + self.assertTrue((out.mean_predictions >= 0).all()) + self.assertTrue((out.full_predictions >= 0).all()) + + def test_truncate_negative_with_negative_input(self): + """truncate_negative leaves outputs untouched when inputs contain negatives.""" + inputs = torch.randn(2, 80, device=torch_device) - 5.0 + with torch.no_grad(): + out_trunc = self.model(past_values=inputs, freq=[0, 0], truncate_negative=True) + out_plain = self.model(past_values=inputs, freq=[0, 0], truncate_negative=False) + self.assertTrue(torch.allclose(out_trunc.mean_predictions, out_plain.mean_predictions)) + self.assertTrue(torch.allclose(out_trunc.full_predictions, out_plain.full_predictions)) + + def test_onnx_export_and_inference(self): + """Export to ONNX, verify dynamic batch and truncate_negative both work.""" + try: + import onnxruntime as ort + except ImportError: + self.skipTest("onnxruntime not installed") + + import tempfile + + from torch.export import Dim + + class Wrapper(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.m = m + + def forward(self, past_values, freq): + o = self.m(past_values, freq=freq, truncate_negative=True) + return o.mean_predictions, o.full_predictions + + wrapped = Wrapper(self.model).cpu().eval() + export_input = torch.randn(2, 80) + export_freq = torch.zeros(2, 1, dtype=torch.int32) + batch = Dim("batch", min=1, max=64) + seq = Dim("seq", min=1, max=512) + with tempfile.TemporaryDirectory() as tmp: + path = f"{tmp}/model.onnx" + torch.onnx.export( + wrapped, + (export_input, export_freq), + path, + input_names=["past_values", "freq"], + output_names=["mean_predictions", "full_predictions"], + dynamo=True, + dynamic_shapes={ + "past_values": {0: batch, 1: seq}, + "freq": {0: batch}, + }, + ) + import onnx + + onnx_model = onnx.load(path, load_external_data=False) + op_types = {n.op_type for n in onnx_model.graph.node} + + # 1. Dynamic dims: input batch & seq must be symbolic strings, not fixed ints + inp = onnx_model.graph.input[0] + dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim] + self.assertIsInstance(dims[0], str, f"batch dim should be symbolic, got {dims[0]}") + self.assertIsInstance(dims[1], str, f"seq dim should be symbolic, got {dims[1]}") + for out in onnx_model.graph.output: + out_batch = out.type.tensor_type.shape.dim[0] + self.assertTrue(out_batch.dim_param, f"output '{out.name}' batch dim not dynamic") + + # 2. No If nodes: all Python branches (forecast_context_len, window_size, + # freq is None, return_forecast_on_context, future_values) must be + # frozen at export time, not traced as conditional ops + if_nodes = [n.name for n in onnx_model.graph.node if n.op_type == "If"] + self.assertEqual(len(if_nodes), 0, f"Graph has If nodes (unfrozen branches): {if_nodes}") + + # 3. truncate_negative: the inp_min >= 0 check must be branchless (torch.where), + # so we expect a Where op in the graph instead of an If + self.assertIn("Where", op_types, "Missing Where op — truncate_negative not branchless") + + sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"]) + + # (a) different batch size AND seq length to verify both dims are dynamic + diff_input = torch.randn(3, 50) + diff_freq = torch.zeros(3, 1, dtype=torch.int32) + with torch.no_grad(): + pt_out = self.model(past_values=diff_input, freq=diff_freq, truncate_negative=True) + onnx_mean, onnx_full = sess.run( + None, {"past_values": diff_input.numpy(), "freq": diff_freq.numpy()} + ) + np.testing.assert_allclose(onnx_mean, pt_out.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(onnx_full, pt_out.full_predictions.numpy(), rtol=1e-3, atol=1e-3) + + # (b) all-positive input triggers the truncate_negative clamp path + pos_input = torch.rand(2, 80).abs() + 1.0 + pos_freq = torch.zeros(2, 1, dtype=torch.int32) + with torch.no_grad(): + pt_pos = self.model(past_values=pos_input, freq=pos_freq, truncate_negative=True) + onnx_mean_pos, onnx_full_pos = sess.run( + None, {"past_values": pos_input.numpy(), "freq": pos_freq.numpy()} + ) + np.testing.assert_allclose(onnx_mean_pos, pt_pos.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) + self.assertTrue((onnx_mean_pos >= 0).all()) + + # (c) freq=None path: wrapper does not pass freq, so `if freq is None` runs. Using + # `[0] * past_values.shape[0]` there bakes batch into the graph; this block fails + # if that regression returns. The (a)/(b) paths above never hit freq=None. + class WrapperNoFreq(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.m = m + + def forward(self, past_values): + o = self.m(past_values, truncate_negative=True) + return o.mean_predictions, o.full_predictions + + path_nf = f"{tmp}/model_no_freq.onnx" + torch.onnx.export( + WrapperNoFreq(self.model).cpu().eval(), + (export_input,), + path_nf, + input_names=["past_values"], + output_names=["mean_predictions", "full_predictions"], + dynamo=True, + dynamic_shapes={"past_values": {0: batch, 1: seq}}, + ) + onnx_nf = onnx.load(path_nf, load_external_data=False) + inp_nf = onnx_nf.graph.input[0] + dims_nf = [d.dim_param or d.dim_value for d in inp_nf.type.tensor_type.shape.dim] + self.assertIsInstance(dims_nf[0], str, f"no-freq export: batch dim should be symbolic, got {dims_nf[0]}") + sess_nf = ort.InferenceSession(path_nf, providers=["CPUExecutionProvider"]) + nf_input = torch.randn(3, 50) + with torch.no_grad(): + pt_nf = self.model(past_values=nf_input, truncate_negative=True) + onnx_m_nf, onnx_f_nf = sess_nf.run(None, {"past_values": nf_input.numpy()}) + np.testing.assert_allclose(onnx_m_nf, pt_nf.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(onnx_f_nf, pt_nf.full_predictions.numpy(), rtol=1e-3, atol=1e-3) + + def test_onnx_export_with_forecast_context_len(self): + """Export with forecast_context_len baked in; verify ONNX uses truncated context.""" + try: + import onnxruntime as ort + except ImportError: + self.skipTest("onnxruntime not installed") + + import tempfile + + from torch.export import Dim + + short_ctx = 32 + + class Wrapper(torch.nn.Module): + def __init__(self, m, ctx): + super().__init__() + self.m = m + self.ctx = ctx + + def forward(self, past_values): + o = self.m(past_values, forecast_context_len=self.ctx) + return o.mean_predictions, o.full_predictions + + wrapped = Wrapper(self.model, short_ctx).cpu().eval() + export_input = torch.randn(2, 80) + with tempfile.TemporaryDirectory() as tmp: + path = f"{tmp}/model.onnx" + batch = Dim("batch", min=1, max=64) + seq = Dim("seq", min=1, max=512) + torch.onnx.export( + wrapped, + (export_input,), + path, + input_names=["past_values"], + output_names=["mean_predictions", "full_predictions"], + dynamo=True, + dynamic_shapes={"past_values": {0: batch, 1: seq}}, + ) + sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"]) + + # ONNX graph has forecast_context_len=32 baked in, so passing 80 values + # should give the same result as PyTorch with the same override + test_input = torch.randn(2, 80) + with torch.no_grad(): + pt_out = self.model(past_values=test_input, forecast_context_len=short_ctx) + onnx_mean, onnx_full = sess.run(None, {"past_values": test_input.numpy()}) + np.testing.assert_allclose(onnx_mean, pt_out.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(onnx_full, pt_out.full_predictions.numpy(), rtol=1e-3, atol=1e-3) + + @require_torch @slow class TimesFmModelIntegrationTests(unittest.TestCase): diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 7a909da6d78c..9dfc06a10fef 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -283,6 +283,73 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): return inputs_dict +@require_torch +class TimesFm2_5ForwardInputVariantsTest(unittest.TestCase): + def setUp(self): + config = TimesFm2_5Config( + patch_length=32, + context_length=128, + horizon_length=8, + hidden_size=32, + intermediate_size=64, + head_dim=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + quantiles=[0.1, 0.5, 0.9], + output_quantile_len=16, + ) + self.model = TimesFm2_5ModelForPrediction(config).to(torch_device).eval() + self.horizon_len = config.horizon_length + + def test_different_length_series(self): + """forward() handles a list of series with different lengths.""" + inputs = [ + torch.randn(30, device=torch_device), + torch.randn(80, device=torch_device), + torch.randn(200, device=torch_device), + ] + with torch.no_grad(): + out = self.model(past_values=inputs) + self.assertEqual(out.mean_predictions.shape, (3, self.horizon_len)) + + def test_very_short_and_very_long_series(self): + """forward() works when one series is tiny and another exceeds context_len.""" + inputs = [ + torch.randn(5, device=torch_device), + torch.randn(500, device=torch_device), + ] + with torch.no_grad(): + out = self.model(past_values=inputs) + self.assertEqual(out.mean_predictions.shape, (2, self.horizon_len)) + + def test_2d_tensor_input(self): + """forward() accepts a 2D tensor and produces correct output shape.""" + inputs = torch.randn(4, 150, device=torch_device) + with torch.no_grad(): + out = self.model(past_values=inputs) + self.assertEqual(out.mean_predictions.shape, (4, self.horizon_len)) + + def test_list_vs_tensor_parity(self): + """forward() with list and 2D tensor of equal-length series gives identical output.""" + raw = [torch.randn(60, device=torch_device) for _ in range(2)] + stacked = torch.stack(raw) + with torch.no_grad(): + out_list = self.model(past_values=raw) + out_tensor = self.model(past_values=stacked) + self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) + self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) + + def test_long_series_truncated(self): + """forward() with a long series produces the same output as passing only the tail.""" + long_series = torch.randn(500, device=torch_device) + tail = long_series[-128:] + with torch.no_grad(): + out_long = self.model(past_values=[long_series]) + out_tail = self.model(past_values=[tail]) + self.assertTrue(torch.allclose(out_long.mean_predictions, out_tail.mean_predictions, atol=1e-5)) + + @require_torch @slow class TimesFm2_5ModelIntegrationTests(unittest.TestCase): From b81657391d1cb223d0099504c15f5326b0c3691d Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 12:45:36 -0700 Subject: [PATCH 07/20] revert docstrings --- .../models/timesfm/modeling_timesfm.py | 13 +++++------- .../models/timesfm/modular_timesfm.py | 17 ++++++--------- .../models/timesfm2_5/modeling_timesfm2_5.py | 21 ++++++++++--------- .../models/timesfm2_5/modular_timesfm2_5.py | 5 ++--- tests/models/timesfm/test_modeling_timesfm.py | 8 ++----- 5 files changed, 26 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 2aae2c478313..5132cc38dbcc 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -598,14 +598,12 @@ def _preprocess( """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: Either a list of 1D tensors (one per series, may differ in length) or a single - 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. - The 2D path avoids Python loops and is ONNX-export friendly. + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to ``self.context_len``). + context_len: Optional context length override (defaults to `self.context_len`). Returns: - Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. + Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. """ if context_len is None: context_len = self.context_len @@ -687,9 +685,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFmOutputForPrediction: r""" - past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Can be a 2D tensor - (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. freq (`torch.LongTensor` of shape `(batch_size,)`): Frequency indices for the time series data. window_size (`int`, *optional*): diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 9fbf33299fe5..96b71a868211 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -555,14 +555,12 @@ def _preprocess( """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: Either a list of 1D tensors (one per series, may differ in length) or a single - 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. - The 2D path avoids Python loops and is ONNX-export friendly. + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to ``self.context_len``). + context_len: Optional context length override (defaults to `self.context_len`). Returns: - Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. + Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. """ if context_len is None: context_len = self.context_len @@ -590,9 +588,7 @@ def _preprocess( padding = torch.cat( [ torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros( - context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device - ), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), ], dim=0, ) @@ -646,9 +642,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFmOutputForPrediction: r""" - past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `Sequence[torch.Tensor]`): - Past values of the time series that serves as input to the model. Can be a 2D tensor - (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. freq (`torch.LongTensor` of shape `(batch_size,)`): Frequency indices for the time series data. window_size (`int`, *optional*): diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index 4e58f7f23577..6e8ca528a723 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -690,14 +690,12 @@ def _preprocess( """Pad/truncate input time series to `context_len` and build a padding mask. Args: - inputs: Either a list of 1D tensors (one per series, may differ in length) or a single - 2D tensor of shape ``(batch, seq_len)`` where all rows share the same length. - The 2D path avoids Python loops and is ONNX-export friendly. + inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task. freq: Optional list of frequencies (returned as a tensor when provided). - context_len: Optional context length override (defaults to ``self.context_len``). + context_len: Optional context length override (defaults to `self.context_len`). Returns: - Tuple of ``(padded_inputs, padding_mask)`` and optionally a freq tensor. + Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. """ if context_len is None: context_len = self.context_len @@ -734,8 +732,12 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - result = result + (torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1),) + if isinstance(freq, torch.Tensor): + inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) + else: + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + result = result + (inp_freq,) return result def _postprocess_output( @@ -774,9 +776,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]` or `torch.Tensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. Can be a 2D tensor - (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). + past_values (`Sequence[torch.Tensor]`): + Past values of the time series that serves as input to the model. Each tensor is a 1D time series. window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index 32e2e7277d66..d81a9152e2e9 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -541,9 +541,8 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> TimesFm2_5OutputForPrediction: r""" - past_values (`Sequence[torch.Tensor]` or `torch.Tensor` of shape `(batch_size, sequence_length)`): - Past values of the time series that serves as input to the model. Can be a 2D tensor - (ONNX-friendly, all rows same length) or a list of 1D tensors (variable-length series). + past_values (`Sequence[torch.Tensor]`): + Past values of the time series that serves as input to the model. Each tensor is a 1D time series. window_size (`int`, *optional*): Window size of trend + residual decomposition. If `None`, decomposition is not applied. future_values (`torch.Tensor`, *optional*): diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index c465a0f10697..fec494458f6f 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -361,9 +361,7 @@ def forward(self, past_values, freq): diff_freq = torch.zeros(3, 1, dtype=torch.int32) with torch.no_grad(): pt_out = self.model(past_values=diff_input, freq=diff_freq, truncate_negative=True) - onnx_mean, onnx_full = sess.run( - None, {"past_values": diff_input.numpy(), "freq": diff_freq.numpy()} - ) + onnx_mean, onnx_full = sess.run(None, {"past_values": diff_input.numpy(), "freq": diff_freq.numpy()}) np.testing.assert_allclose(onnx_mean, pt_out.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) np.testing.assert_allclose(onnx_full, pt_out.full_predictions.numpy(), rtol=1e-3, atol=1e-3) @@ -372,9 +370,7 @@ def forward(self, past_values, freq): pos_freq = torch.zeros(2, 1, dtype=torch.int32) with torch.no_grad(): pt_pos = self.model(past_values=pos_input, freq=pos_freq, truncate_negative=True) - onnx_mean_pos, onnx_full_pos = sess.run( - None, {"past_values": pos_input.numpy(), "freq": pos_freq.numpy()} - ) + onnx_mean_pos, onnx_full_pos = sess.run(None, {"past_values": pos_input.numpy(), "freq": pos_freq.numpy()}) np.testing.assert_allclose(onnx_mean_pos, pt_pos.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) self.assertTrue((onnx_mean_pos >= 0).all()) From 9b3a2c267e8998c9a8f61b483e71755e2a6bb622 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 13:18:38 -0700 Subject: [PATCH 08/20] reset and add fix for torch tensor --- .../models/timesfm/modeling_timesfm.py | 35 ++++++------------- .../models/timesfm/modular_timesfm.py | 35 ++++++------------- .../models/timesfm2_5/modeling_timesfm2_5.py | 7 ++-- 3 files changed, 22 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 5132cc38dbcc..2f19f0182e9e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -640,11 +640,8 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - if isinstance(freq, torch.Tensor): - inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) - else: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) result = result + (inp_freq,) return result @@ -675,7 +672,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor] | torch.Tensor, + past_values: Sequence[torch.Tensor], freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, @@ -723,20 +720,12 @@ def forward( else: fcontext_len = forecast_context_len - is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 + device = past_values[0].device - if is_tensor: - device = past_values.device - inputs = past_values[:, -fcontext_len:] - inp_min = inputs.min() - else: - device = past_values[0].device - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - if is_tensor: - raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs = [] new_freqs = [] for i, ts in enumerate(inputs): @@ -749,11 +738,7 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - if is_tensor: - # Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int. - freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device) - else: - freq = [0] * len(inputs) + freq = [0] * len(inputs) input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) input_ts = input_ts.to(device) @@ -808,9 +793,9 @@ def forward( if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if truncate_negative: - mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs) - full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs) + if inp_min >= 0 and truncate_negative: + mean_outputs = torch.maximum(mean_outputs, 0.0) + full_outputs = torch.maximum(full_outputs, 0.0) loss = None if future_values is not None: diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 96b71a868211..9a0d5e8c3f6e 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -597,11 +597,8 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - if isinstance(freq, torch.Tensor): - inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) - else: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) result = result + (inp_freq,) return result @@ -632,7 +629,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor] | torch.Tensor, + past_values: Sequence[torch.Tensor], freq: Sequence[torch.Tensor | int] | None = None, window_size: int | None = None, future_values: torch.Tensor | None = None, @@ -680,20 +677,12 @@ def forward( else: fcontext_len = forecast_context_len - is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 + device = past_values[0].device - if is_tensor: - device = past_values.device - inputs = past_values[:, -fcontext_len:] - inp_min = inputs.min() - else: - device = past_values[0].device - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - if is_tensor: - raise ValueError("window_size is not supported when past_values is a 2D tensor.") new_inputs = [] new_freqs = [] for i, ts in enumerate(inputs): @@ -706,11 +695,7 @@ def forward( if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") - if is_tensor: - # Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int. - freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device) - else: - freq = [0] * len(inputs) + freq = [0] * len(inputs) input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) input_ts = input_ts.to(device) @@ -765,9 +750,9 @@ def forward( if window_size is not None: mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] - if truncate_negative: - mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs) - full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs) + if inp_min >= 0 and truncate_negative: + mean_outputs = torch.maximum(mean_outputs, 0.0) + full_outputs = torch.maximum(full_outputs, 0.0) loss = None if future_values is not None: diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index 6e8ca528a723..f1f75e826861 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -732,11 +732,8 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - if isinstance(freq, torch.Tensor): - inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1) - else: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) + batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) + inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) result = result + (inp_freq,) return result From 952f56f0e9c74cbb8b08031f6c9bbaf67b66aa97 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 14:16:42 -0700 Subject: [PATCH 09/20] fixed batch tests --- .../models/timesfm/modeling_timesfm.py | 52 +++--- .../models/timesfm/modular_timesfm.py | 52 +++--- .../models/timesfm2_5/modeling_timesfm2_5.py | 52 +++--- tests/models/timesfm/test_modeling_timesfm.py | 166 ------------------ 4 files changed, 78 insertions(+), 244 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 2f19f0182e9e..af74bb8c9665 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -608,36 +608,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - x = inputs[:, -context_len:] - num_front_pad = context_len - x.shape[1] - x = F.pad(x, (num_front_pad, 0)) + # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + # x = inputs[:, -context_len:] + # num_front_pad = context_len - x.shape[1] + # x = F.pad(x, (num_front_pad, 0)) + # padding = torch.cat( + # [ + # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + # torch.zeros( + # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + # ), + # ], + # dim=1, + # ) + # result = (x, padding) + # else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) padding = torch.cat( [ - torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - torch.zeros( - x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - ), + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), ], - dim=1, + dim=0, ) - result = (x, padding) - else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat( - [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), - ], - dim=0, - ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 9a0d5e8c3f6e..0efb8173006e 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -565,36 +565,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - x = inputs[:, -context_len:] - num_front_pad = context_len - x.shape[1] - x = F.pad(x, (num_front_pad, 0)) + # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + # x = inputs[:, -context_len:] + # num_front_pad = context_len - x.shape[1] + # x = F.pad(x, (num_front_pad, 0)) + # padding = torch.cat( + # [ + # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + # torch.zeros( + # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + # ), + # ], + # dim=1, + # ) + # result = (x, padding) + # else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) padding = torch.cat( [ - torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - torch.zeros( - x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - ), + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), ], - dim=1, + dim=0, ) - result = (x, padding) - else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat( - [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), - ], - dim=0, - ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index f1f75e826861..edd301519ca1 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -700,36 +700,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - x = inputs[:, -context_len:] - num_front_pad = context_len - x.shape[1] - x = F.pad(x, (num_front_pad, 0)) + # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + # x = inputs[:, -context_len:] + # num_front_pad = context_len - x.shape[1] + # x = F.pad(x, (num_front_pad, 0)) + # padding = torch.cat( + # [ + # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + # torch.zeros( + # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + # ), + # ], + # dim=1, + # ) + # result = (x, padding) + # else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) padding = torch.cat( [ - torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - torch.zeros( - x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - ), + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), ], - dim=1, + dim=0, ) - result = (x, padding) - else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat( - [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), - ], - dim=0, - ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index fec494458f6f..192d19e4ed2f 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -291,172 +291,6 @@ def test_truncate_negative_with_negative_input(self): self.assertTrue(torch.allclose(out_trunc.mean_predictions, out_plain.mean_predictions)) self.assertTrue(torch.allclose(out_trunc.full_predictions, out_plain.full_predictions)) - def test_onnx_export_and_inference(self): - """Export to ONNX, verify dynamic batch and truncate_negative both work.""" - try: - import onnxruntime as ort - except ImportError: - self.skipTest("onnxruntime not installed") - - import tempfile - - from torch.export import Dim - - class Wrapper(torch.nn.Module): - def __init__(self, m): - super().__init__() - self.m = m - - def forward(self, past_values, freq): - o = self.m(past_values, freq=freq, truncate_negative=True) - return o.mean_predictions, o.full_predictions - - wrapped = Wrapper(self.model).cpu().eval() - export_input = torch.randn(2, 80) - export_freq = torch.zeros(2, 1, dtype=torch.int32) - batch = Dim("batch", min=1, max=64) - seq = Dim("seq", min=1, max=512) - with tempfile.TemporaryDirectory() as tmp: - path = f"{tmp}/model.onnx" - torch.onnx.export( - wrapped, - (export_input, export_freq), - path, - input_names=["past_values", "freq"], - output_names=["mean_predictions", "full_predictions"], - dynamo=True, - dynamic_shapes={ - "past_values": {0: batch, 1: seq}, - "freq": {0: batch}, - }, - ) - import onnx - - onnx_model = onnx.load(path, load_external_data=False) - op_types = {n.op_type for n in onnx_model.graph.node} - - # 1. Dynamic dims: input batch & seq must be symbolic strings, not fixed ints - inp = onnx_model.graph.input[0] - dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim] - self.assertIsInstance(dims[0], str, f"batch dim should be symbolic, got {dims[0]}") - self.assertIsInstance(dims[1], str, f"seq dim should be symbolic, got {dims[1]}") - for out in onnx_model.graph.output: - out_batch = out.type.tensor_type.shape.dim[0] - self.assertTrue(out_batch.dim_param, f"output '{out.name}' batch dim not dynamic") - - # 2. No If nodes: all Python branches (forecast_context_len, window_size, - # freq is None, return_forecast_on_context, future_values) must be - # frozen at export time, not traced as conditional ops - if_nodes = [n.name for n in onnx_model.graph.node if n.op_type == "If"] - self.assertEqual(len(if_nodes), 0, f"Graph has If nodes (unfrozen branches): {if_nodes}") - - # 3. truncate_negative: the inp_min >= 0 check must be branchless (torch.where), - # so we expect a Where op in the graph instead of an If - self.assertIn("Where", op_types, "Missing Where op — truncate_negative not branchless") - - sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"]) - - # (a) different batch size AND seq length to verify both dims are dynamic - diff_input = torch.randn(3, 50) - diff_freq = torch.zeros(3, 1, dtype=torch.int32) - with torch.no_grad(): - pt_out = self.model(past_values=diff_input, freq=diff_freq, truncate_negative=True) - onnx_mean, onnx_full = sess.run(None, {"past_values": diff_input.numpy(), "freq": diff_freq.numpy()}) - np.testing.assert_allclose(onnx_mean, pt_out.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) - np.testing.assert_allclose(onnx_full, pt_out.full_predictions.numpy(), rtol=1e-3, atol=1e-3) - - # (b) all-positive input triggers the truncate_negative clamp path - pos_input = torch.rand(2, 80).abs() + 1.0 - pos_freq = torch.zeros(2, 1, dtype=torch.int32) - with torch.no_grad(): - pt_pos = self.model(past_values=pos_input, freq=pos_freq, truncate_negative=True) - onnx_mean_pos, onnx_full_pos = sess.run(None, {"past_values": pos_input.numpy(), "freq": pos_freq.numpy()}) - np.testing.assert_allclose(onnx_mean_pos, pt_pos.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) - self.assertTrue((onnx_mean_pos >= 0).all()) - - # (c) freq=None path: wrapper does not pass freq, so `if freq is None` runs. Using - # `[0] * past_values.shape[0]` there bakes batch into the graph; this block fails - # if that regression returns. The (a)/(b) paths above never hit freq=None. - class WrapperNoFreq(torch.nn.Module): - def __init__(self, m): - super().__init__() - self.m = m - - def forward(self, past_values): - o = self.m(past_values, truncate_negative=True) - return o.mean_predictions, o.full_predictions - - path_nf = f"{tmp}/model_no_freq.onnx" - torch.onnx.export( - WrapperNoFreq(self.model).cpu().eval(), - (export_input,), - path_nf, - input_names=["past_values"], - output_names=["mean_predictions", "full_predictions"], - dynamo=True, - dynamic_shapes={"past_values": {0: batch, 1: seq}}, - ) - onnx_nf = onnx.load(path_nf, load_external_data=False) - inp_nf = onnx_nf.graph.input[0] - dims_nf = [d.dim_param or d.dim_value for d in inp_nf.type.tensor_type.shape.dim] - self.assertIsInstance(dims_nf[0], str, f"no-freq export: batch dim should be symbolic, got {dims_nf[0]}") - sess_nf = ort.InferenceSession(path_nf, providers=["CPUExecutionProvider"]) - nf_input = torch.randn(3, 50) - with torch.no_grad(): - pt_nf = self.model(past_values=nf_input, truncate_negative=True) - onnx_m_nf, onnx_f_nf = sess_nf.run(None, {"past_values": nf_input.numpy()}) - np.testing.assert_allclose(onnx_m_nf, pt_nf.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) - np.testing.assert_allclose(onnx_f_nf, pt_nf.full_predictions.numpy(), rtol=1e-3, atol=1e-3) - - def test_onnx_export_with_forecast_context_len(self): - """Export with forecast_context_len baked in; verify ONNX uses truncated context.""" - try: - import onnxruntime as ort - except ImportError: - self.skipTest("onnxruntime not installed") - - import tempfile - - from torch.export import Dim - - short_ctx = 32 - - class Wrapper(torch.nn.Module): - def __init__(self, m, ctx): - super().__init__() - self.m = m - self.ctx = ctx - - def forward(self, past_values): - o = self.m(past_values, forecast_context_len=self.ctx) - return o.mean_predictions, o.full_predictions - - wrapped = Wrapper(self.model, short_ctx).cpu().eval() - export_input = torch.randn(2, 80) - with tempfile.TemporaryDirectory() as tmp: - path = f"{tmp}/model.onnx" - batch = Dim("batch", min=1, max=64) - seq = Dim("seq", min=1, max=512) - torch.onnx.export( - wrapped, - (export_input,), - path, - input_names=["past_values"], - output_names=["mean_predictions", "full_predictions"], - dynamo=True, - dynamic_shapes={"past_values": {0: batch, 1: seq}}, - ) - sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"]) - - # ONNX graph has forecast_context_len=32 baked in, so passing 80 values - # should give the same result as PyTorch with the same override - test_input = torch.randn(2, 80) - with torch.no_grad(): - pt_out = self.model(past_values=test_input, forecast_context_len=short_ctx) - onnx_mean, onnx_full = sess.run(None, {"past_values": test_input.numpy()}) - np.testing.assert_allclose(onnx_mean, pt_out.mean_predictions.numpy(), rtol=1e-3, atol=1e-3) - np.testing.assert_allclose(onnx_full, pt_out.full_predictions.numpy(), rtol=1e-3, atol=1e-3) - @require_torch @slow From 8bb48d8f3c06d21154f14cde5bc5b1d29eabcad5 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 14:18:27 -0700 Subject: [PATCH 10/20] add sym logic --- .../models/timesfm/modeling_timesfm.py | 52 +++++++++---------- .../models/timesfm/modular_timesfm.py | 52 +++++++++---------- .../models/timesfm2_5/modeling_timesfm2_5.py | 52 +++++++++---------- 3 files changed, 78 insertions(+), 78 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index af74bb8c9665..2f19f0182e9e 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -608,36 +608,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - # x = inputs[:, -context_len:] - # num_front_pad = context_len - x.shape[1] - # x = F.pad(x, (num_front_pad, 0)) - # padding = torch.cat( - # [ - # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - # torch.zeros( - # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - # ), - # ], - # dim=1, - # ) - # result = (x, padding) - # else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 0efb8173006e..9a0d5e8c3f6e 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -565,36 +565,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - # x = inputs[:, -context_len:] - # num_front_pad = context_len - x.shape[1] - # x = F.pad(x, (num_front_pad, 0)) - # padding = torch.cat( - # [ - # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - # torch.zeros( - # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - # ), - # ], - # dim=1, - # ) - # result = (x, padding) - # else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index edd301519ca1..f1f75e826861 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -700,36 +700,36 @@ def _preprocess( if context_len is None: context_len = self.context_len - # if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: - # x = inputs[:, -context_len:] - # num_front_pad = context_len - x.shape[1] - # x = F.pad(x, (num_front_pad, 0)) - # padding = torch.cat( - # [ - # torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), - # torch.zeros( - # x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device - # ), - # ], - # dim=1, - # ) - # result = (x, padding) - # else: - input_ts, input_padding = [], [] - for ts in inputs: - ts = ts[-context_len:] - num_front_pad = context_len - ts.shape[0] - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) padding = torch.cat( [ - torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), - torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), ], - dim=0, + dim=1, ) - input_ts.append(ts) - input_padding.append(padding) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: + ts = ts[-context_len:] + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) From 6b58107792d5b604376252b5f95c7d42644117b4 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 14:33:09 -0700 Subject: [PATCH 11/20] just a test --- test_onnx_infer.py | 242 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 test_onnx_infer.py diff --git a/test_onnx_infer.py b/test_onnx_infer.py new file mode 100644 index 000000000000..6e83672c07db --- /dev/null +++ b/test_onnx_infer.py @@ -0,0 +1,242 @@ +# Run after `python test_export_onnx.py` so `test_onnx/model.onnx` exists. +# Uses the same tiny config and `torch.manual_seed(0)` as export for weight parity with the graph. +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np +import onnxruntime as ort +import torch +import torch.nn.functional as F + +from transformers import TimesFm2_5Config, TimesFm2_5ModelForPrediction + +ONNX_PATH = Path(__file__).resolve().parent / "test_onnx" / "model.onnx" + + +def tiny_config() -> TimesFm2_5Config: + return TimesFm2_5Config( + patch_length=32, + context_length=128, + horizon_length=8, + hidden_size=32, + intermediate_size=64, + head_dim=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + quantiles=[0.1, 0.5, 0.9], + output_quantile_len=16, + ) + + +def build_model() -> TimesFm2_5ModelForPrediction: + torch.manual_seed(0) + return TimesFm2_5ModelForPrediction(tiny_config()).eval() + + +def test_preprocess_2d_tensor_matches_list_of_rows(model: TimesFm2_5ModelForPrediction) -> None: + """Batched `(B, L)` input must match `_preprocess` on `[row[0], …, row[B-1]]` (values + padding mask).""" + ctx = model.context_len + for seed, batch, seq_len in ( + (11, 1, 1), + (12, 2, 17), + (13, 4, 64), + (14, 3, ctx), + (15, 2, ctx + 40), + ): + torch.manual_seed(seed) + x = torch.randn(batch, seq_len) + rows = [x[i].clone() for i in range(batch)] + a = model._preprocess(x, context_len=ctx) + b = model._preprocess(rows, context_len=ctx) + assert a[0].shape == b[0].shape == (batch, ctx), (a[0].shape, b[0].shape) + h = model.horizon_len + assert a[1].shape == b[1].shape == (batch, ctx + h), (a[1].shape, b[1].shape) + torch.testing.assert_close(a[0], b[0], rtol=0, atol=0, msg="padded time series mismatch") + torch.testing.assert_close(a[1], b[1], rtol=0, atol=0, msg="padding mask mismatch") + + +def test_preprocess_short_2d_left_pad_and_mask_invariants(model: TimesFm2_5ModelForPrediction) -> None: + """After preprocess, leading pad slots are zero in `ts` and one in `padding`.""" + ctx = model.context_len + h = model.horizon_len + torch.manual_seed(21) + seq_len = 40 + b = 2 + x = torch.randn(b, seq_len) + ts, padding = model._preprocess(x, context_len=ctx) + num_front = ctx - seq_len + assert num_front > 0 + assert torch.all(ts[:, :num_front] == 0) + assert torch.all(padding[:, :num_front] == 1) + assert torch.all(padding[:, num_front : num_front + seq_len] == 0) + + +def list_to_left_padded_matrix(parts: list[torch.Tensor], width: int) -> np.ndarray: + """Match `_preprocess` left-padding: short series get zeros on the left.""" + rows = [] + for p in parts: + p = p.float()[-width:] + rows.append(F.pad(p, (width - p.shape[0], 0))) + return torch.stack(rows).numpy() + + +def onnx_output_names(session: ort.InferenceSession) -> list[str]: + return [o.name for o in session.get_outputs()] + + +def onnx_protobuf_input_shape(path: Path, input_name: str) -> tuple[int | str, int | str]: + """Shape of named graph input as in the .onnx file.""" + import onnx + + model = onnx.load(str(path)) + inputs = {i.name: i for i in model.graph.input} + if input_name not in inputs: + raise KeyError(f"Input {input_name!r} not found in graph") + + inp = inputs[input_name] + dims: list[int | str] = [] + for d in inp.type.tensor_type.shape.dim: + param = (d.dim_param or "").strip() + if param: + dims.append(param) + elif d.HasField("dim_value"): + dims.append(int(d.dim_value)) + else: + dims.append("?") + if len(dims) < 2: + return dims[0], "?" + return dims[0], dims[1] + + +def test_onnx_export_batch_axis_contract(session: ort.InferenceSession, onnx_path: Path, ctx: int) -> None: + """ORT must match the **onnx file** (protobuf) axis 0 contract for past_values.""" + pb0, pb1 = onnx_protobuf_input_shape(onnx_path, "past_values") + print(f" protobuf past_values shape: [{pb0!r}, {pb1!r}]") + + inp_name = "past_values" + names = onnx_output_names(session) + out = "mean_predictions" if "mean_predictions" in names else names[0] + + def run_batch(batch_size: int) -> None: + x = np.random.randn(batch_size, ctx).astype(np.float32) + session.run([out], {inp_name: x}) + + if isinstance(pb0, int): + raise AssertionError( + f"Dynamic Batch Requirement Failed: The ONNX file declares a fixed batch dimension {pb0!r} " + f"on 'past_values' (axis 0). Expected a symbolic name (e.g., 'batch')." + ) + else: + run_batch(2) + run_batch(5) + print(f" OK: ORT accepted batch 2 and 5 (protobuf symbolic batch axis {pb0!r}).") + + +def run_onnx(session: ort.InferenceSession, x_np: np.ndarray) -> dict[str, np.ndarray]: + names = onnx_output_names(session) + arrays = session.run(names, {"past_values": x_np}) + return dict(zip(names, arrays, strict=True)) + + +def assert_close(a: np.ndarray, b: np.ndarray, msg: str, rtol: float = 1e-3, atol: float = 1e-3) -> None: + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg) + + +def should_check_last_hidden_state_ort(model: TimesFm2_5ModelForPrediction, seq_len: int) -> bool: + threshold = model.context_len - model.config.patch_length + return seq_len > threshold + + +def test_pytorch_list_matches_stacked_2d_when_each_series_has_length_ctx(model: TimesFm2_5ModelForPrediction) -> None: + ctx = model.context_len + torch.manual_seed(42) + s0 = torch.randn(ctx) + s1 = torch.randn(ctx) + stacked = torch.stack([s0, s1], dim=0) + with torch.no_grad(): + o_list = model(past_values=[s0, s1]) + o_2d = model(past_values=stacked) + assert_close(o_list.mean_predictions.numpy(), o_2d.mean_predictions.numpy(), "mean_predictions list vs 2D") + assert_close(o_list.full_predictions.numpy(), o_2d.full_predictions.numpy(), "full_predictions list vs 2D") + + +def test_variable_length_list_vs_prepadded_2d_differs_in_padding_mask(model: TimesFm2_5ModelForPrediction) -> None: + ctx = model.context_len + data = [torch.linspace(0, 1, 100), torch.sin(torch.linspace(0, 20, 67))] + matrix_t = torch.from_numpy(list_to_left_padded_matrix(data, ctx)) + with torch.no_grad(): + o_list = model(past_values=data) + o_pad2d = model(past_values=matrix_t) + assert ((o_list.mean_predictions - o_pad2d.mean_predictions).abs().max() > 1e-3) + + +def test_onnx_matches_pytorch_all_outputs( + session: ort.InferenceSession, + model: TimesFm2_5ModelForPrediction, + x: torch.Tensor, +) -> None: + names = onnx_output_names(session) + with torch.no_grad(): + pt = model(past_values=x) + + ort_dict = run_onnx(session, x.numpy()) + if "mean_predictions" in ort_dict: + assert_close(ort_dict["mean_predictions"], pt.mean_predictions.numpy(), "mean_predictions ORT vs PT") + if "full_predictions" in ort_dict: + assert_close(ort_dict["full_predictions"], pt.full_predictions.numpy(), "full_predictions ORT vs PT") + + +def run_dynamic_shape_checks(session: ort.InferenceSession, model: TimesFm2_5ModelForPrediction, onnx_path: Path) -> None: + pb_batch, pb_seq = onnx_protobuf_input_shape(onnx_path, "past_values") + ctx = model.context_len + + if not isinstance(pb_batch, str): + raise AssertionError(f"Dynamic Batch Check Failed: {pb_batch!r}") + + for batch in (1, 3, 5): + torch.manual_seed(100 + batch) + x = torch.randn(batch, ctx) + test_onnx_matches_pytorch_all_outputs(session, model, x) + print(f"ONNX vs PyTorch OK: batch_size={batch}, seq_len={ctx}") + + if not isinstance(pb_seq, str): + raise AssertionError(f"Dynamic Sequence Check Failed: {pb_seq!r}") + + for seq_len in (1, 32, 64, ctx, 200): + torch.manual_seed(300 + seq_len) + x = torch.randn(2, seq_len) + test_onnx_matches_pytorch_all_outputs(session, model, x) + print(f"ONNX vs PyTorch OK: batch_size=2, seq_len={seq_len}") + + +def main() -> None: + if not ONNX_PATH.is_file(): + print(f"Missing {ONNX_PATH}", file=sys.stderr) + sys.exit(1) + + model = build_model() + ctx = model.context_len + + print("PyTorch tests...") + test_preprocess_2d_tensor_matches_list_of_rows(model) + test_preprocess_short_2d_left_pad_and_mask_invariants(model) + print(" OK") + + session = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"]) + + print("ONNX: past_values batch contract...") + test_onnx_export_batch_axis_contract(session, ONNX_PATH, ctx) + print(" OK") + + print("ONNX dynamic shape parity...") + run_dynamic_shape_checks(session, model, ONNX_PATH) + print(" OK") + + print("All checks passed.") + + +if __name__ == "__main__": + main() From 8fb1b931583bca246a63e4c5a9b431dff9a552dd Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 14:34:45 -0700 Subject: [PATCH 12/20] covered it --- src/transformers/models/timesfm/modular_timesfm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 9a0d5e8c3f6e..2392bcbdf2ef 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -597,9 +597,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) - result = result + (inp_freq,) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( From 499a88bd1266941b753301d7deab1ca621681d71 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 14:35:39 -0700 Subject: [PATCH 13/20] regen --- src/transformers/models/timesfm/modeling_timesfm.py | 4 +--- src/transformers/models/timesfm2_5/modeling_timesfm2_5.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 2f19f0182e9e..feaa47660414 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -640,9 +640,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) - result = result + (inp_freq,) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index f1f75e826861..c4e367c26167 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -732,9 +732,7 @@ def _preprocess( result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: - batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs) - inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1) - result = result + (inp_freq,) + result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result def _postprocess_output( From 6cac69799c48f6a3ae445da07e864db90ab38755 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 15:16:16 -0700 Subject: [PATCH 14/20] add fix for window_size --- .../models/timesfm2_5/modeling_timesfm2_5.py | 37 +++++++++++-- .../models/timesfm2_5/modular_timesfm2_5.py | 35 +++++++++++-- test_onnx_infer.py | 52 +++++++++++++++++++ .../timesfm2_5/test_modeling_timesfm2_5.py | 25 +++++++++ 4 files changed, 139 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index c4e367c26167..19d702e10eea 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -799,11 +799,13 @@ def forward( if window_size is not None: if is_tensor: - raise ValueError("window_size is not supported when past_values is a 2D tensor.") - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + else: + new_inputs: list[torch.Tensor] = [] + for ts in inputs: + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + inputs = new_inputs if truncate_negative is None: truncate_negative = self.config.infer_is_positive @@ -942,5 +944,30 @@ def _decode_and_project( return point_forecast, quantile_spreads, model_outputs + @staticmethod + def _timesfm_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + + # Pad with zeros to handle initial window positions + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + + # Apply convolution to calculate the moving average + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr + __all__ = ["TimesFm2_5ModelForPrediction", "TimesFm2_5PreTrainedModel", "TimesFm2_5Model"] diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index d81a9152e2e9..d64a9e00ffed 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -569,11 +569,13 @@ def forward( if window_size is not None: if is_tensor: - raise ValueError("window_size is not supported when past_values is a 2D tensor.") - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + else: + new_inputs: list[torch.Tensor] = [] + for ts in inputs: + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + inputs = new_inputs if truncate_negative is None: truncate_negative = self.config.infer_is_positive @@ -658,6 +660,29 @@ def _flip_quantiles(x: torch.Tensor) -> torch.Tensor: loss=loss, ) + @staticmethod + def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: + """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + + # Pad with zeros to handle initial window positions + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + + # Create a convolution kernel + kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + + # Apply convolution to calculate the moving average + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr + __all__ = [ "TimesFm2_5Config", diff --git a/test_onnx_infer.py b/test_onnx_infer.py index 6e83672c07db..cab0efa72960 100644 --- a/test_onnx_infer.py +++ b/test_onnx_infer.py @@ -212,6 +212,56 @@ def run_dynamic_shape_checks(session: ort.InferenceSession, model: TimesFm2_5Mod print(f"ONNX vs PyTorch OK: batch_size=2, seq_len={seq_len}") +def test_input_min_and_type_parity(model: TimesFm2_5ModelForPrediction) -> None: + """ + Verifies the fix for tensor vs list inputs. + Checks that input_min is calculated across ALL rows, which affects truncate_negative. + """ + ctx = model.context_len + # Row 0 is all positive, Row 1 has a negative value. + # If the model only checked the first row (or handled the list incorrectly), + # it might wrongly decide to clamp outputs. + s0 = torch.ones(ctx) * 10.0 + s1 = torch.ones(ctx) * 10.0 + s1[5] = -100.0 # The negative value is in the second row + + stacked = torch.stack([s0, s1], dim=0) + list_input = [s0, s1] + + # We need a case where the model WOULD produce a negative value to see if it gets clamped. + # Since we use random weights, we'll just check that outputs match between tensor and list paths. + with torch.no_grad(): + out_tensor = model(past_values=stacked, truncate_negative=True) + out_list = model(past_values=list_input, truncate_negative=True) + + assert_close( + out_tensor.mean_predictions.numpy(), + out_list.mean_predictions.numpy(), + "Input type parity failed (tensor vs list with negative value)" + ) + print(" OK: Tensor and List paths matched for mixed-sign inputs.") + + +def test_window_size_tensor_vs_list_parity(model: TimesFm2_5ModelForPrediction) -> None: + """Verifies that the new batched window_size logic for tensors matches the list logic.""" + ctx = model.context_len + batch = 3 + window_size = 4 + torch.manual_seed(123) + x = torch.randn(batch, ctx) + rows = [x[i].clone() for i in range(batch)] + + with torch.no_grad(): + out_tensor = model(past_values=x, window_size=window_size) + out_list = model(past_values=rows, window_size=window_size) + + # We expect (batch, horizon_len) rows in the output + h = model.horizon_len + assert out_tensor.mean_predictions.shape == out_list.mean_predictions.shape == (batch, h) + torch.testing.assert_close(out_tensor.mean_predictions, out_list.mean_predictions, rtol=1e-5, atol=1e-5) + print(f" OK: Tensor vs List window_size parity at B={batch}, W={window_size}") + + def main() -> None: if not ONNX_PATH.is_file(): print(f"Missing {ONNX_PATH}", file=sys.stderr) @@ -223,6 +273,8 @@ def main() -> None: print("PyTorch tests...") test_preprocess_2d_tensor_matches_list_of_rows(model) test_preprocess_short_2d_left_pad_and_mask_invariants(model) + test_input_min_and_type_parity(model) + test_window_size_tensor_vs_list_parity(model) print(" OK") session = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"]) diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 9dfc06a10fef..06e1483bf96b 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -349,6 +349,31 @@ def test_long_series_truncated(self): out_tail = self.model(past_values=[tail]) self.assertTrue(torch.allclose(out_long.mean_predictions, out_tail.mean_predictions, atol=1e-5)) + def test_window_size_parity(self): + """forward() with window_size works for both list and 2D tensor inputs and gives identical output across edge cases.""" + batch_sizes = [1, 4] + seq_lengths = [32, 64] + window_sizes = [1, 4, 32, 128] + + for b in batch_sizes: + for slen in seq_lengths: + for w in window_sizes: + raw = [torch.randn(slen, device=torch_device) for _ in range(b)] + stacked = torch.stack(raw) + with torch.no_grad(): + out_list = self.model(past_values=raw, window_size=w) + out_tensor = self.model(past_values=stacked, window_size=w) + + self.assertTrue( + torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5), + f"Parity failed for b={b}, slen={slen}, w={w}", + ) + self.assertTrue( + torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5), + f"Full prediction parity failed for b={b}, slen={slen}, w={w}", + ) + + @require_torch @slow From 92bd304e5e6e09797fd7b9d543d783aad33c041e Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 15:21:58 -0700 Subject: [PATCH 15/20] remove tests --- tests/models/timesfm/test_modeling_timesfm.py | 82 ------------------- .../timesfm2_5/test_modeling_timesfm2_5.py | 37 --------- 2 files changed, 119 deletions(-) diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py index 192d19e4ed2f..31ed60ce9d5c 100644 --- a/tests/models/timesfm/test_modeling_timesfm.py +++ b/tests/models/timesfm/test_modeling_timesfm.py @@ -210,88 +210,6 @@ def test_model_main_input_name(self): self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name) -@require_torch -class TimesFmForwardInputVariantsTest(unittest.TestCase): - def setUp(self): - config = TimesFmConfig( - patch_length=32, - context_length=64, - horizon_length=32, - hidden_size=16, - intermediate_size=32, - head_dim=8, - num_hidden_layers=1, - num_attention_heads=2, - ) - self.model = TimesFmModelForPrediction(config).to(torch_device).eval() - self.horizon_len = config.horizon_length - - def test_different_length_series(self): - """forward() handles a list of series with different lengths.""" - inputs = [ - torch.randn(20, device=torch_device), - torch.randn(50, device=torch_device), - torch.randn(100, device=torch_device), - ] - with torch.no_grad(): - out = self.model(past_values=inputs, freq=[0, 0, 0]) - self.assertEqual(out.mean_predictions.shape, (3, self.horizon_len)) - - def test_very_short_and_very_long_series(self): - """forward() works when one series is tiny and another exceeds context_len.""" - inputs = [ - torch.randn(5, device=torch_device), - torch.randn(500, device=torch_device), - ] - with torch.no_grad(): - out = self.model(past_values=inputs, freq=[0, 0]) - self.assertEqual(out.mean_predictions.shape, (2, self.horizon_len)) - - def test_2d_tensor_input(self): - """forward() accepts a 2D tensor and produces correct output shape.""" - inputs = torch.randn(4, 80, device=torch_device) - with torch.no_grad(): - out = self.model(past_values=inputs, freq=[0, 0, 0, 0]) - self.assertEqual(out.mean_predictions.shape, (4, self.horizon_len)) - - def test_list_vs_tensor_parity(self): - """forward() with list and 2D tensor of equal-length series gives identical output.""" - raw = [torch.randn(50, device=torch_device) for _ in range(2)] - stacked = torch.stack(raw) - freq = [0, 0] - with torch.no_grad(): - out_list = self.model(past_values=raw, freq=freq) - out_tensor = self.model(past_values=stacked, freq=freq) - self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) - self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) - - def test_long_series_truncated(self): - """forward() with a long series produces the same output as passing only the tail.""" - long_series = torch.randn(500, device=torch_device) - tail = long_series[-64:] - with torch.no_grad(): - out_long = self.model(past_values=[long_series], freq=[0]) - out_tail = self.model(past_values=[tail], freq=[0]) - self.assertTrue(torch.allclose(out_long.mean_predictions, out_tail.mean_predictions, atol=1e-5)) - - def test_truncate_negative_with_positive_input(self): - """truncate_negative clamps outputs to zero when all inputs are non-negative.""" - inputs = torch.rand(2, 80, device=torch_device).abs() + 1.0 - with torch.no_grad(): - out = self.model(past_values=inputs, freq=[0, 0], truncate_negative=True) - self.assertTrue((out.mean_predictions >= 0).all()) - self.assertTrue((out.full_predictions >= 0).all()) - - def test_truncate_negative_with_negative_input(self): - """truncate_negative leaves outputs untouched when inputs contain negatives.""" - inputs = torch.randn(2, 80, device=torch_device) - 5.0 - with torch.no_grad(): - out_trunc = self.model(past_values=inputs, freq=[0, 0], truncate_negative=True) - out_plain = self.model(past_values=inputs, freq=[0, 0], truncate_negative=False) - self.assertTrue(torch.allclose(out_trunc.mean_predictions, out_plain.mean_predictions)) - self.assertTrue(torch.allclose(out_trunc.full_predictions, out_plain.full_predictions)) - - @require_torch @slow class TimesFmModelIntegrationTests(unittest.TestCase): diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 06e1483bf96b..04944eb6a636 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -302,34 +302,6 @@ def setUp(self): self.model = TimesFm2_5ModelForPrediction(config).to(torch_device).eval() self.horizon_len = config.horizon_length - def test_different_length_series(self): - """forward() handles a list of series with different lengths.""" - inputs = [ - torch.randn(30, device=torch_device), - torch.randn(80, device=torch_device), - torch.randn(200, device=torch_device), - ] - with torch.no_grad(): - out = self.model(past_values=inputs) - self.assertEqual(out.mean_predictions.shape, (3, self.horizon_len)) - - def test_very_short_and_very_long_series(self): - """forward() works when one series is tiny and another exceeds context_len.""" - inputs = [ - torch.randn(5, device=torch_device), - torch.randn(500, device=torch_device), - ] - with torch.no_grad(): - out = self.model(past_values=inputs) - self.assertEqual(out.mean_predictions.shape, (2, self.horizon_len)) - - def test_2d_tensor_input(self): - """forward() accepts a 2D tensor and produces correct output shape.""" - inputs = torch.randn(4, 150, device=torch_device) - with torch.no_grad(): - out = self.model(past_values=inputs) - self.assertEqual(out.mean_predictions.shape, (4, self.horizon_len)) - def test_list_vs_tensor_parity(self): """forward() with list and 2D tensor of equal-length series gives identical output.""" raw = [torch.randn(60, device=torch_device) for _ in range(2)] @@ -340,15 +312,6 @@ def test_list_vs_tensor_parity(self): self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) - def test_long_series_truncated(self): - """forward() with a long series produces the same output as passing only the tail.""" - long_series = torch.randn(500, device=torch_device) - tail = long_series[-128:] - with torch.no_grad(): - out_long = self.model(past_values=[long_series]) - out_tail = self.model(past_values=[tail]) - self.assertTrue(torch.allclose(out_long.mean_predictions, out_tail.mean_predictions, atol=1e-5)) - def test_window_size_parity(self): """forward() with window_size works for both list and 2D tensor inputs and gives identical output across edge cases.""" batch_sizes = [1, 4] From 3fc56e49aef665f623288aa7d8eeec3b3e72e041 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 15:23:50 -0700 Subject: [PATCH 16/20] tests --- .../models/timesfm/modeling_timesfm.py | 56 ++++++++++++++----- .../models/timesfm/modular_timesfm.py | 54 +++++++++++++----- .../models/timesfm2_5/modeling_timesfm2_5.py | 47 ++++++---------- .../models/timesfm2_5/modular_timesfm2_5.py | 23 -------- 4 files changed, 98 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index feaa47660414..9442c1007ea2 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -718,21 +718,33 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if is_tensor: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + if freq is not None: + freq = torch.repeat_interleave(freq, 2, dim=0) + else: + new_inputs = [] + new_freqs = [] + for i, ts in enumerate(inputs): + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) + inputs = new_inputs if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + freq = new_freqs if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") @@ -811,15 +823,29 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 2392bcbdf2ef..283603633514 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -675,21 +675,33 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if is_tensor: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + if freq is not None: + freq = torch.repeat_interleave(freq, 2, dim=0) + else: + new_inputs = [] + new_freqs = [] + for i, ts in enumerate(inputs): + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) + inputs = new_inputs if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + freq = new_freqs if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") @@ -768,15 +780,27 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index 19d702e10eea..bc7c233609ae 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -891,15 +891,29 @@ def _flip_quantiles(x: torch.Tensor) -> torch.Tensor: ) @staticmethod - def _timesfm2_5_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm2_5_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr def _decode_and_project( self, @@ -944,30 +958,5 @@ def _decode_and_project( return point_forecast, quantile_spreads, model_outputs - @staticmethod - def _timesfm_moving_average( - arr: torch.Tensor, window_size: int - ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: - """Calculates the moving average using PyTorch's convolution function.""" - # arr shape: (T,) or (B, T) - is_2d = arr.ndim == 2 - if not is_2d: - arr = arr.unsqueeze(0) # (1, T) - - # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) - - # Create a convolution kernel - kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size - kernel = kernel.view(1, 1, -1) # (1, 1, window_size) - - # Apply convolution to calculate the moving average - # F.conv1d expects (N, C_in, L_in) - smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) - - if not is_2d: - return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] - return smoothed_arr, arr - smoothed_arr - __all__ = ["TimesFm2_5ModelForPrediction", "TimesFm2_5PreTrainedModel", "TimesFm2_5Model"] diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index d64a9e00ffed..c6cfbde0de04 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -660,29 +660,6 @@ def _flip_quantiles(x: torch.Tensor) -> torch.Tensor: loss=loss, ) - @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: - """Calculates the moving average using PyTorch's convolution function.""" - # arr shape: (T,) or (B, T) - is_2d = arr.ndim == 2 - if not is_2d: - arr = arr.unsqueeze(0) # (1, T) - - # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) - - # Create a convolution kernel - kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size - kernel = kernel.view(1, 1, -1) # (1, 1, window_size) - - # Apply convolution to calculate the moving average - # F.conv1d expects (N, C_in, L_in) - smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) - - if not is_2d: - return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] - return smoothed_arr, arr - smoothed_arr - __all__ = [ "TimesFm2_5Config", From d758c758ee63f542dba7e1944faa89ce1c5b7fc4 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 15:38:47 -0700 Subject: [PATCH 17/20] Remove file --- test_onnx_infer.py | 294 --------------------------------------------- 1 file changed, 294 deletions(-) delete mode 100644 test_onnx_infer.py diff --git a/test_onnx_infer.py b/test_onnx_infer.py deleted file mode 100644 index cab0efa72960..000000000000 --- a/test_onnx_infer.py +++ /dev/null @@ -1,294 +0,0 @@ -# Run after `python test_export_onnx.py` so `test_onnx/model.onnx` exists. -# Uses the same tiny config and `torch.manual_seed(0)` as export for weight parity with the graph. -from __future__ import annotations - -import sys -from pathlib import Path - -import numpy as np -import onnxruntime as ort -import torch -import torch.nn.functional as F - -from transformers import TimesFm2_5Config, TimesFm2_5ModelForPrediction - -ONNX_PATH = Path(__file__).resolve().parent / "test_onnx" / "model.onnx" - - -def tiny_config() -> TimesFm2_5Config: - return TimesFm2_5Config( - patch_length=32, - context_length=128, - horizon_length=8, - hidden_size=32, - intermediate_size=64, - head_dim=16, - num_hidden_layers=1, - num_attention_heads=2, - num_key_value_heads=2, - quantiles=[0.1, 0.5, 0.9], - output_quantile_len=16, - ) - - -def build_model() -> TimesFm2_5ModelForPrediction: - torch.manual_seed(0) - return TimesFm2_5ModelForPrediction(tiny_config()).eval() - - -def test_preprocess_2d_tensor_matches_list_of_rows(model: TimesFm2_5ModelForPrediction) -> None: - """Batched `(B, L)` input must match `_preprocess` on `[row[0], …, row[B-1]]` (values + padding mask).""" - ctx = model.context_len - for seed, batch, seq_len in ( - (11, 1, 1), - (12, 2, 17), - (13, 4, 64), - (14, 3, ctx), - (15, 2, ctx + 40), - ): - torch.manual_seed(seed) - x = torch.randn(batch, seq_len) - rows = [x[i].clone() for i in range(batch)] - a = model._preprocess(x, context_len=ctx) - b = model._preprocess(rows, context_len=ctx) - assert a[0].shape == b[0].shape == (batch, ctx), (a[0].shape, b[0].shape) - h = model.horizon_len - assert a[1].shape == b[1].shape == (batch, ctx + h), (a[1].shape, b[1].shape) - torch.testing.assert_close(a[0], b[0], rtol=0, atol=0, msg="padded time series mismatch") - torch.testing.assert_close(a[1], b[1], rtol=0, atol=0, msg="padding mask mismatch") - - -def test_preprocess_short_2d_left_pad_and_mask_invariants(model: TimesFm2_5ModelForPrediction) -> None: - """After preprocess, leading pad slots are zero in `ts` and one in `padding`.""" - ctx = model.context_len - h = model.horizon_len - torch.manual_seed(21) - seq_len = 40 - b = 2 - x = torch.randn(b, seq_len) - ts, padding = model._preprocess(x, context_len=ctx) - num_front = ctx - seq_len - assert num_front > 0 - assert torch.all(ts[:, :num_front] == 0) - assert torch.all(padding[:, :num_front] == 1) - assert torch.all(padding[:, num_front : num_front + seq_len] == 0) - - -def list_to_left_padded_matrix(parts: list[torch.Tensor], width: int) -> np.ndarray: - """Match `_preprocess` left-padding: short series get zeros on the left.""" - rows = [] - for p in parts: - p = p.float()[-width:] - rows.append(F.pad(p, (width - p.shape[0], 0))) - return torch.stack(rows).numpy() - - -def onnx_output_names(session: ort.InferenceSession) -> list[str]: - return [o.name for o in session.get_outputs()] - - -def onnx_protobuf_input_shape(path: Path, input_name: str) -> tuple[int | str, int | str]: - """Shape of named graph input as in the .onnx file.""" - import onnx - - model = onnx.load(str(path)) - inputs = {i.name: i for i in model.graph.input} - if input_name not in inputs: - raise KeyError(f"Input {input_name!r} not found in graph") - - inp = inputs[input_name] - dims: list[int | str] = [] - for d in inp.type.tensor_type.shape.dim: - param = (d.dim_param or "").strip() - if param: - dims.append(param) - elif d.HasField("dim_value"): - dims.append(int(d.dim_value)) - else: - dims.append("?") - if len(dims) < 2: - return dims[0], "?" - return dims[0], dims[1] - - -def test_onnx_export_batch_axis_contract(session: ort.InferenceSession, onnx_path: Path, ctx: int) -> None: - """ORT must match the **onnx file** (protobuf) axis 0 contract for past_values.""" - pb0, pb1 = onnx_protobuf_input_shape(onnx_path, "past_values") - print(f" protobuf past_values shape: [{pb0!r}, {pb1!r}]") - - inp_name = "past_values" - names = onnx_output_names(session) - out = "mean_predictions" if "mean_predictions" in names else names[0] - - def run_batch(batch_size: int) -> None: - x = np.random.randn(batch_size, ctx).astype(np.float32) - session.run([out], {inp_name: x}) - - if isinstance(pb0, int): - raise AssertionError( - f"Dynamic Batch Requirement Failed: The ONNX file declares a fixed batch dimension {pb0!r} " - f"on 'past_values' (axis 0). Expected a symbolic name (e.g., 'batch')." - ) - else: - run_batch(2) - run_batch(5) - print(f" OK: ORT accepted batch 2 and 5 (protobuf symbolic batch axis {pb0!r}).") - - -def run_onnx(session: ort.InferenceSession, x_np: np.ndarray) -> dict[str, np.ndarray]: - names = onnx_output_names(session) - arrays = session.run(names, {"past_values": x_np}) - return dict(zip(names, arrays, strict=True)) - - -def assert_close(a: np.ndarray, b: np.ndarray, msg: str, rtol: float = 1e-3, atol: float = 1e-3) -> None: - np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg) - - -def should_check_last_hidden_state_ort(model: TimesFm2_5ModelForPrediction, seq_len: int) -> bool: - threshold = model.context_len - model.config.patch_length - return seq_len > threshold - - -def test_pytorch_list_matches_stacked_2d_when_each_series_has_length_ctx(model: TimesFm2_5ModelForPrediction) -> None: - ctx = model.context_len - torch.manual_seed(42) - s0 = torch.randn(ctx) - s1 = torch.randn(ctx) - stacked = torch.stack([s0, s1], dim=0) - with torch.no_grad(): - o_list = model(past_values=[s0, s1]) - o_2d = model(past_values=stacked) - assert_close(o_list.mean_predictions.numpy(), o_2d.mean_predictions.numpy(), "mean_predictions list vs 2D") - assert_close(o_list.full_predictions.numpy(), o_2d.full_predictions.numpy(), "full_predictions list vs 2D") - - -def test_variable_length_list_vs_prepadded_2d_differs_in_padding_mask(model: TimesFm2_5ModelForPrediction) -> None: - ctx = model.context_len - data = [torch.linspace(0, 1, 100), torch.sin(torch.linspace(0, 20, 67))] - matrix_t = torch.from_numpy(list_to_left_padded_matrix(data, ctx)) - with torch.no_grad(): - o_list = model(past_values=data) - o_pad2d = model(past_values=matrix_t) - assert ((o_list.mean_predictions - o_pad2d.mean_predictions).abs().max() > 1e-3) - - -def test_onnx_matches_pytorch_all_outputs( - session: ort.InferenceSession, - model: TimesFm2_5ModelForPrediction, - x: torch.Tensor, -) -> None: - names = onnx_output_names(session) - with torch.no_grad(): - pt = model(past_values=x) - - ort_dict = run_onnx(session, x.numpy()) - if "mean_predictions" in ort_dict: - assert_close(ort_dict["mean_predictions"], pt.mean_predictions.numpy(), "mean_predictions ORT vs PT") - if "full_predictions" in ort_dict: - assert_close(ort_dict["full_predictions"], pt.full_predictions.numpy(), "full_predictions ORT vs PT") - - -def run_dynamic_shape_checks(session: ort.InferenceSession, model: TimesFm2_5ModelForPrediction, onnx_path: Path) -> None: - pb_batch, pb_seq = onnx_protobuf_input_shape(onnx_path, "past_values") - ctx = model.context_len - - if not isinstance(pb_batch, str): - raise AssertionError(f"Dynamic Batch Check Failed: {pb_batch!r}") - - for batch in (1, 3, 5): - torch.manual_seed(100 + batch) - x = torch.randn(batch, ctx) - test_onnx_matches_pytorch_all_outputs(session, model, x) - print(f"ONNX vs PyTorch OK: batch_size={batch}, seq_len={ctx}") - - if not isinstance(pb_seq, str): - raise AssertionError(f"Dynamic Sequence Check Failed: {pb_seq!r}") - - for seq_len in (1, 32, 64, ctx, 200): - torch.manual_seed(300 + seq_len) - x = torch.randn(2, seq_len) - test_onnx_matches_pytorch_all_outputs(session, model, x) - print(f"ONNX vs PyTorch OK: batch_size=2, seq_len={seq_len}") - - -def test_input_min_and_type_parity(model: TimesFm2_5ModelForPrediction) -> None: - """ - Verifies the fix for tensor vs list inputs. - Checks that input_min is calculated across ALL rows, which affects truncate_negative. - """ - ctx = model.context_len - # Row 0 is all positive, Row 1 has a negative value. - # If the model only checked the first row (or handled the list incorrectly), - # it might wrongly decide to clamp outputs. - s0 = torch.ones(ctx) * 10.0 - s1 = torch.ones(ctx) * 10.0 - s1[5] = -100.0 # The negative value is in the second row - - stacked = torch.stack([s0, s1], dim=0) - list_input = [s0, s1] - - # We need a case where the model WOULD produce a negative value to see if it gets clamped. - # Since we use random weights, we'll just check that outputs match between tensor and list paths. - with torch.no_grad(): - out_tensor = model(past_values=stacked, truncate_negative=True) - out_list = model(past_values=list_input, truncate_negative=True) - - assert_close( - out_tensor.mean_predictions.numpy(), - out_list.mean_predictions.numpy(), - "Input type parity failed (tensor vs list with negative value)" - ) - print(" OK: Tensor and List paths matched for mixed-sign inputs.") - - -def test_window_size_tensor_vs_list_parity(model: TimesFm2_5ModelForPrediction) -> None: - """Verifies that the new batched window_size logic for tensors matches the list logic.""" - ctx = model.context_len - batch = 3 - window_size = 4 - torch.manual_seed(123) - x = torch.randn(batch, ctx) - rows = [x[i].clone() for i in range(batch)] - - with torch.no_grad(): - out_tensor = model(past_values=x, window_size=window_size) - out_list = model(past_values=rows, window_size=window_size) - - # We expect (batch, horizon_len) rows in the output - h = model.horizon_len - assert out_tensor.mean_predictions.shape == out_list.mean_predictions.shape == (batch, h) - torch.testing.assert_close(out_tensor.mean_predictions, out_list.mean_predictions, rtol=1e-5, atol=1e-5) - print(f" OK: Tensor vs List window_size parity at B={batch}, W={window_size}") - - -def main() -> None: - if not ONNX_PATH.is_file(): - print(f"Missing {ONNX_PATH}", file=sys.stderr) - sys.exit(1) - - model = build_model() - ctx = model.context_len - - print("PyTorch tests...") - test_preprocess_2d_tensor_matches_list_of_rows(model) - test_preprocess_short_2d_left_pad_and_mask_invariants(model) - test_input_min_and_type_parity(model) - test_window_size_tensor_vs_list_parity(model) - print(" OK") - - session = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"]) - - print("ONNX: past_values batch contract...") - test_onnx_export_batch_axis_contract(session, ONNX_PATH, ctx) - print(" OK") - - print("ONNX dynamic shape parity...") - run_dynamic_shape_checks(session, model, ONNX_PATH) - print(" OK") - - print("All checks passed.") - - -if __name__ == "__main__": - main() From 35238ddb2687519c24ca93791288470de71b5e20 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 15:44:51 -0700 Subject: [PATCH 18/20] formatting issues --- src/transformers/models/timesfm/modular_timesfm.py | 4 +++- tests/models/timesfm2_5/test_modeling_timesfm2_5.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 283603633514..f18d654ad3ca 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -780,7 +780,9 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: + def _timesfm_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" # arr shape: (T,) or (B, T) is_2d = arr.ndim == 2 diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 04944eb6a636..6a703567cbdb 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -337,7 +337,6 @@ def test_window_size_parity(self): ) - @require_torch @slow class TimesFm2_5ModelIntegrationTests(unittest.TestCase): From 80fa774566d66a0352c2d29b29b47d9b095cee87 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 16:08:32 -0700 Subject: [PATCH 19/20] fix tests --- src/transformers/models/timesfm2_5/modeling_timesfm2_5.py | 4 ++-- src/transformers/models/timesfm2_5/modular_timesfm2_5.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index bc7c233609ae..ba7839b8770b 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -799,12 +799,12 @@ def forward( if window_size is not None: if is_tensor: - trend, residual = self._timesfm_moving_average(inputs, window_size) + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) else: new_inputs: list[torch.Tensor] = [] for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + new_inputs.extend(self._timesfm2_5_moving_average(ts, window_size)) inputs = new_inputs if truncate_negative is None: diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index c6cfbde0de04..3c19a5eec672 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -569,12 +569,12 @@ def forward( if window_size is not None: if is_tensor: - trend, residual = self._timesfm_moving_average(inputs, window_size) + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) else: new_inputs: list[torch.Tensor] = [] for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + new_inputs.extend(self._timesfm2_5_moving_average(ts, window_size)) inputs = new_inputs if truncate_negative is None: From d99e1be62c4a8754f3e2aa3fcef5a467ec254913 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Fri, 10 Apr 2026 16:10:19 -0700 Subject: [PATCH 20/20] add other tests --- .../timesfm2_5/test_modeling_timesfm2_5.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 6a703567cbdb..3b22dd240ca5 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -336,6 +336,23 @@ def test_window_size_parity(self): f"Full prediction parity failed for b={b}, slen={slen}, w={w}", ) + def test_input_min_parity(self): + """forward() with truncate_negative=True gives identical results for list and tensor inputs, even with mixed signs.""" + # Create sequences where one is all positive and another has negative values + raw = [ + torch.tensor([1.0, 2.0, 3.0], device=torch_device), + torch.tensor([-1.0, 0.0, 1.0], device=torch_device), + ] + stacked = torch.stack(raw) + + with torch.no_grad(): + # truncate_negative=True will use input_min to decide whether to clamp + out_list = self.model(past_values=raw, truncate_negative=True) + out_tensor = self.model(past_values=stacked, truncate_negative=True) + + self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) + self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) + @require_torch @slow