diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 219908c1e47c..9442c1007ea2 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -590,7 +590,10 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. @@ -605,23 +608,37 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) + padding = torch.cat( + [ + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), + ], + dim=1, + ) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result @@ -701,21 +718,33 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if is_tensor: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + if freq is not None: + freq = torch.repeat_interleave(freq, 2, dim=0) + else: + new_inputs = [] + new_freqs = [] + for i, ts in enumerate(inputs): + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) + inputs = new_inputs if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + freq = new_freqs if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") @@ -794,15 +823,29 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index ca53ec7dd668..f18d654ad3ca 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -547,7 +547,10 @@ def __init__(self, config: TimesFmConfig): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. @@ -562,23 +565,37 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) + padding = torch.cat( + [ + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), + ], + dim=1, + ) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result @@ -658,21 +675,33 @@ def forward( else: fcontext_len = forecast_context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-fcontext_len:] for ts in past_values] - inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -fcontext_len:] + inp_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs = [] - new_freqs = [] - for i, ts in enumerate(inputs): - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if is_tensor: + trend, residual = self._timesfm_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + if freq is not None: + freq = torch.repeat_interleave(freq, 2, dim=0) + else: + new_inputs = [] + new_freqs = [] + for i, ts in enumerate(inputs): + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) + inputs = new_inputs if freq is not None: - new_freqs.extend([freq[i]] * 2) - inputs = new_inputs - if freq is not None: - freq = new_freqs + freq = new_freqs if freq is None: logger.info("No frequency provided via `freq`. Default to high (0).") @@ -751,15 +780,29 @@ def forward( ) @staticmethod - def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py index e7b4e799d20b..ba7839b8770b 100644 --- a/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modeling_timesfm2_5.py @@ -682,7 +682,10 @@ def __init__(self, config: TimesFm2_5Config): self.post_init() def _preprocess( - self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None + self, + inputs: Sequence[torch.Tensor] | torch.Tensor, + freq: Sequence[int] | None = None, + context_len: int | None = None, ) -> tuple[torch.Tensor, ...]: """Pad/truncate input time series to `context_len` and build a padding mask. @@ -697,23 +700,37 @@ def _preprocess( if context_len is None: context_len = self.context_len - input_ts, input_padding = [], [] - - for ts in inputs: - input_len = ts.shape[0] - padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device) - if input_len < context_len: - num_front_pad = context_len - input_len - ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) - padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0) - elif input_len > context_len: + if isinstance(inputs, torch.Tensor) and inputs.ndim == 2: + x = inputs[:, -context_len:] + num_front_pad = context_len - x.shape[1] + x = F.pad(x, (num_front_pad, 0)) + padding = torch.cat( + [ + torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device), + torch.zeros( + x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device + ), + ], + dim=1, + ) + result = (x, padding) + else: + input_ts, input_padding = [], [] + for ts in inputs: ts = ts[-context_len:] - padding = padding[-(context_len + self.horizon_len) :] - - input_ts.append(ts) - input_padding.append(padding) + num_front_pad = context_len - ts.shape[0] + ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0) + padding = torch.cat( + [ + torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device), + torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device), + ], + dim=0, + ) + input_ts.append(ts) + input_padding.append(padding) + result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) - result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0)) if freq is not None: result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),) return result @@ -745,7 +762,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -769,16 +786,26 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -forecast_context_len:] + input_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-forecast_context_len:] for ts in past_values] + input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + if is_tensor: + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + else: + new_inputs: list[torch.Tensor] = [] + for ts in inputs: + new_inputs.extend(self._timesfm2_5_moving_average(ts, window_size)) + inputs = new_inputs if truncate_negative is None: truncate_negative = self.config.infer_is_positive @@ -864,15 +891,29 @@ def _flip_quantiles(x: torch.Tensor) -> torch.Tensor: ) @staticmethod - def _timesfm2_5_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]: + def _timesfm2_5_moving_average( + arr: torch.Tensor, window_size: int + ) -> list[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: """Calculates the moving average using PyTorch's convolution function.""" + # arr shape: (T,) or (B, T) + is_2d = arr.ndim == 2 + if not is_2d: + arr = arr.unsqueeze(0) # (1, T) + # Pad with zeros to handle initial window positions - arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) # (B, T + window_size - 1) + # Create a convolution kernel kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size + kernel = kernel.view(1, 1, -1) # (1, 1, window_size) + # Apply convolution to calculate the moving average - smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() - return [smoothed_arr, arr - smoothed_arr] + # F.conv1d expects (N, C_in, L_in) + smoothed_arr = F.conv1d(arr_padded.unsqueeze(1), kernel).squeeze(1) # (B, T) + + if not is_2d: + return [smoothed_arr.squeeze(0), (arr - smoothed_arr).squeeze(0)] + return smoothed_arr, arr - smoothed_arr def _decode_and_project( self, diff --git a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py index 3a912d07946b..3c19a5eec672 100644 --- a/src/transformers/models/timesfm2_5/modular_timesfm2_5.py +++ b/src/transformers/models/timesfm2_5/modular_timesfm2_5.py @@ -532,7 +532,7 @@ def _decode_and_project( @auto_docstring def forward( self, - past_values: Sequence[torch.Tensor], + past_values: Sequence[torch.Tensor] | torch.Tensor, window_size: int | None = None, future_values: torch.Tensor | None = None, forecast_context_len: int | None = None, @@ -556,16 +556,26 @@ def forward( `config.force_flip_invariance`. """ forecast_context_len = forecast_context_len or self.context_len - device = past_values[0].device + is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2 - inputs = [ts[-forecast_context_len:] for ts in past_values] - input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) + if is_tensor: + device = past_values.device + inputs = past_values[:, -forecast_context_len:] + input_min = inputs.min() + else: + device = past_values[0].device + inputs = [ts[-forecast_context_len:] for ts in past_values] + input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs])) if window_size is not None: - new_inputs: list[torch.Tensor] = [] - for ts in inputs: - new_inputs.extend(self._timesfm_moving_average(ts, window_size)) - inputs = new_inputs + if is_tensor: + trend, residual = self._timesfm2_5_moving_average(inputs, window_size) + inputs = torch.stack([trend, residual], dim=1).view(2 * inputs.shape[0], -1) + else: + new_inputs: list[torch.Tensor] = [] + for ts in inputs: + new_inputs.extend(self._timesfm2_5_moving_average(ts, window_size)) + inputs = new_inputs if truncate_negative is None: truncate_negative = self.config.infer_is_positive diff --git a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py index 7a909da6d78c..3b22dd240ca5 100644 --- a/tests/models/timesfm2_5/test_modeling_timesfm2_5.py +++ b/tests/models/timesfm2_5/test_modeling_timesfm2_5.py @@ -283,6 +283,77 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): return inputs_dict +@require_torch +class TimesFm2_5ForwardInputVariantsTest(unittest.TestCase): + def setUp(self): + config = TimesFm2_5Config( + patch_length=32, + context_length=128, + horizon_length=8, + hidden_size=32, + intermediate_size=64, + head_dim=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + quantiles=[0.1, 0.5, 0.9], + output_quantile_len=16, + ) + self.model = TimesFm2_5ModelForPrediction(config).to(torch_device).eval() + self.horizon_len = config.horizon_length + + def test_list_vs_tensor_parity(self): + """forward() with list and 2D tensor of equal-length series gives identical output.""" + raw = [torch.randn(60, device=torch_device) for _ in range(2)] + stacked = torch.stack(raw) + with torch.no_grad(): + out_list = self.model(past_values=raw) + out_tensor = self.model(past_values=stacked) + self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) + self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) + + def test_window_size_parity(self): + """forward() with window_size works for both list and 2D tensor inputs and gives identical output across edge cases.""" + batch_sizes = [1, 4] + seq_lengths = [32, 64] + window_sizes = [1, 4, 32, 128] + + for b in batch_sizes: + for slen in seq_lengths: + for w in window_sizes: + raw = [torch.randn(slen, device=torch_device) for _ in range(b)] + stacked = torch.stack(raw) + with torch.no_grad(): + out_list = self.model(past_values=raw, window_size=w) + out_tensor = self.model(past_values=stacked, window_size=w) + + self.assertTrue( + torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5), + f"Parity failed for b={b}, slen={slen}, w={w}", + ) + self.assertTrue( + torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5), + f"Full prediction parity failed for b={b}, slen={slen}, w={w}", + ) + + def test_input_min_parity(self): + """forward() with truncate_negative=True gives identical results for list and tensor inputs, even with mixed signs.""" + # Create sequences where one is all positive and another has negative values + raw = [ + torch.tensor([1.0, 2.0, 3.0], device=torch_device), + torch.tensor([-1.0, 0.0, 1.0], device=torch_device), + ] + stacked = torch.stack(raw) + + with torch.no_grad(): + # truncate_negative=True will use input_min to decide whether to clamp + out_list = self.model(past_values=raw, truncate_negative=True) + out_tensor = self.model(past_values=stacked, truncate_negative=True) + + self.assertTrue(torch.allclose(out_list.mean_predictions, out_tensor.mean_predictions, atol=1e-5)) + self.assertTrue(torch.allclose(out_list.full_predictions, out_tensor.full_predictions, atol=1e-5)) + + @require_torch @slow class TimesFm2_5ModelIntegrationTests(unittest.TestCase):