Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions src/transformers/models/timesfm/modeling_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,10 @@ def __init__(self, config: TimesFmConfig):
self.post_init()

def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
self,
inputs: Sequence[torch.Tensor] | torch.Tensor,
freq: Sequence[int] | None = None,
context_len: int | None = None,
) -> tuple[torch.Tensor, ...]:
"""Pad/truncate input time series to `context_len` and build a padding mask.

Expand All @@ -605,25 +608,44 @@ def _preprocess(
if context_len is None:
context_len = self.context_len

input_ts, input_padding = [], []

for ts in inputs:
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < context_len:
num_front_pad = context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > context_len:
if isinstance(inputs, torch.Tensor) and inputs.ndim == 2:
x = inputs[:, -context_len:]
num_front_pad = context_len - x.shape[1]
x = F.pad(x, (num_front_pad, 0))
padding = torch.cat(
[
torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device),
torch.zeros(
x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device
),
],
dim=1,
)
result = (x, padding)
else:
input_ts, input_padding = [], []
for ts in inputs:
ts = ts[-context_len:]
padding = padding[-(context_len + self.horizon_len) :]

input_ts.append(ts)
input_padding.append(padding)
num_front_pad = context_len - ts.shape[0]
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat(
[
torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device),
torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device),
],
dim=0,
)
input_ts.append(ts)
input_padding.append(padding)
result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))

result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
if freq is not None:
result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
if isinstance(freq, torch.Tensor):
inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1)
else:
batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs)
inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1)
result = result + (inp_freq,)
return result

def _postprocess_output(
Expand Down Expand Up @@ -653,7 +675,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to
@auto_docstring
def forward(
self,
past_values: Sequence[torch.Tensor],
past_values: Sequence[torch.Tensor] | torch.Tensor,
freq: Sequence[torch.Tensor | int] | None = None,
window_size: int | None = None,
future_values: torch.Tensor | None = None,
Expand Down Expand Up @@ -701,12 +723,20 @@ def forward(
else:
fcontext_len = forecast_context_len

device = past_values[0].device
is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2

inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
if is_tensor:
device = past_values.device
inputs = past_values[:, -fcontext_len:]
inp_min = inputs.min()
else:
device = past_values[0].device
inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

if window_size is not None:
if is_tensor:
raise ValueError("window_size is not supported when past_values is a 2D tensor.")
new_inputs = []
new_freqs = []
for i, ts in enumerate(inputs):
Expand All @@ -719,7 +749,11 @@ def forward(

if freq is None:
logger.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
if is_tensor:
# Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int.
freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device)
else:
freq = [0] * len(inputs)

input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
input_ts = input_ts.to(device)
Expand Down Expand Up @@ -774,9 +808,9 @@ def forward(
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
if inp_min >= 0 and truncate_negative:
mean_outputs = torch.maximum(mean_outputs, 0.0)
full_outputs = torch.maximum(full_outputs, 0.0)
if truncate_negative:
mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs)
full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs)

loss = None
if future_values is not None:
Expand Down
84 changes: 59 additions & 25 deletions src/transformers/models/timesfm/modular_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,10 @@ def __init__(self, config: TimesFmConfig):
self.post_init()

def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
self,
inputs: Sequence[torch.Tensor] | torch.Tensor,
freq: Sequence[int] | None = None,
context_len: int | None = None,
) -> tuple[torch.Tensor, ...]:
"""Pad/truncate input time series to `context_len` and build a padding mask.

Expand All @@ -562,25 +565,44 @@ def _preprocess(
if context_len is None:
context_len = self.context_len

input_ts, input_padding = [], []

for ts in inputs:
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < context_len:
num_front_pad = context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > context_len:
if isinstance(inputs, torch.Tensor) and inputs.ndim == 2:
x = inputs[:, -context_len:]
num_front_pad = context_len - x.shape[1]
x = F.pad(x, (num_front_pad, 0))
padding = torch.cat(
[
torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device),
torch.zeros(
x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device
),
],
dim=1,
)
result = (x, padding)
else:
input_ts, input_padding = [], []
for ts in inputs:
ts = ts[-context_len:]
padding = padding[-(context_len + self.horizon_len) :]

input_ts.append(ts)
input_padding.append(padding)
num_front_pad = context_len - ts.shape[0]
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat(
[
torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device),
torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device),
],
dim=0,
)
input_ts.append(ts)
input_padding.append(padding)
result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))

result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
if freq is not None:
result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
if isinstance(freq, torch.Tensor):
inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1)
else:
batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs)
inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1)
result = result + (inp_freq,)
return result

def _postprocess_output(
Expand Down Expand Up @@ -610,7 +632,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to
@auto_docstring
def forward(
self,
past_values: Sequence[torch.Tensor],
past_values: Sequence[torch.Tensor] | torch.Tensor,
freq: Sequence[torch.Tensor | int] | None = None,
window_size: int | None = None,
future_values: torch.Tensor | None = None,
Expand Down Expand Up @@ -658,12 +680,20 @@ def forward(
else:
fcontext_len = forecast_context_len

device = past_values[0].device
is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2

inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
if is_tensor:
device = past_values.device
inputs = past_values[:, -fcontext_len:]
inp_min = inputs.min()
else:
device = past_values[0].device
inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

if window_size is not None:
if is_tensor:
raise ValueError("window_size is not supported when past_values is a 2D tensor.")
new_inputs = []
new_freqs = []
for i, ts in enumerate(inputs):
Expand All @@ -676,7 +706,11 @@ def forward(

if freq is None:
logger.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
if is_tensor:
# Tensor path keeps batch symbolic for ONNX; `[0] * shape[0]` materializes batch as a Python int.
freq = torch.zeros(past_values.shape[0], 1, dtype=torch.int32, device=device)
else:
freq = [0] * len(inputs)

input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
input_ts = input_ts.to(device)
Expand Down Expand Up @@ -731,9 +765,9 @@ def forward(
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
if inp_min >= 0 and truncate_negative:
mean_outputs = torch.maximum(mean_outputs, 0.0)
full_outputs = torch.maximum(full_outputs, 0.0)
if truncate_negative:
mean_outputs = torch.where(inp_min >= 0, mean_outputs.clamp(min=0), mean_outputs)
full_outputs = torch.where(inp_min >= 0, full_outputs.clamp(min=0), full_outputs)

loss = None
if future_values is not None:
Expand Down
72 changes: 51 additions & 21 deletions src/transformers/models/timesfm2_5/modeling_timesfm2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,10 @@ def __init__(self, config: TimesFm2_5Config):
self.post_init()

def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
self,
inputs: Sequence[torch.Tensor] | torch.Tensor,
freq: Sequence[int] | None = None,
context_len: int | None = None,
) -> tuple[torch.Tensor, ...]:
"""Pad/truncate input time series to `context_len` and build a padding mask.

Expand All @@ -697,25 +700,44 @@ def _preprocess(
if context_len is None:
context_len = self.context_len

input_ts, input_padding = [], []

for ts in inputs:
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < context_len:
num_front_pad = context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > context_len:
if isinstance(inputs, torch.Tensor) and inputs.ndim == 2:
x = inputs[:, -context_len:]
num_front_pad = context_len - x.shape[1]
x = F.pad(x, (num_front_pad, 0))
padding = torch.cat(
[
torch.ones(x.shape[0], num_front_pad, dtype=x.dtype, device=x.device),
torch.zeros(
x.shape[0], context_len + self.horizon_len - num_front_pad, dtype=x.dtype, device=x.device
),
],
dim=1,
)
result = (x, padding)
else:
input_ts, input_padding = [], []
for ts in inputs:
ts = ts[-context_len:]
padding = padding[-(context_len + self.horizon_len) :]

input_ts.append(ts)
input_padding.append(padding)
num_front_pad = context_len - ts.shape[0]
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat(
[
torch.ones(num_front_pad, dtype=ts.dtype, device=ts.device),
torch.zeros(context_len + self.horizon_len - num_front_pad, dtype=ts.dtype, device=ts.device),
],
dim=0,
)
input_ts.append(ts)
input_padding.append(padding)
result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))

result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
if freq is not None:
result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
if isinstance(freq, torch.Tensor):
inp_freq = freq if freq.ndim == 2 else freq.reshape(-1, 1)
else:
batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else len(inputs)
inp_freq = torch.tensor(freq[:batch_size], dtype=torch.int32).reshape(-1, 1)
result = result + (inp_freq,)
return result

def _postprocess_output(
Expand Down Expand Up @@ -745,7 +767,7 @@ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> to
@auto_docstring
def forward(
self,
past_values: Sequence[torch.Tensor],
past_values: Sequence[torch.Tensor] | torch.Tensor,
window_size: int | None = None,
future_values: torch.Tensor | None = None,
forecast_context_len: int | None = None,
Expand All @@ -769,12 +791,20 @@ def forward(
`config.force_flip_invariance`.
"""
forecast_context_len = forecast_context_len or self.context_len
device = past_values[0].device
is_tensor = isinstance(past_values, torch.Tensor) and past_values.ndim == 2

inputs = [ts[-forecast_context_len:] for ts in past_values]
input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
if is_tensor:
device = past_values.device
inputs = past_values[:, -forecast_context_len:]
input_min = inputs.min()
else:
device = past_values[0].device
inputs = [ts[-forecast_context_len:] for ts in past_values]
input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

if window_size is not None:
if is_tensor:
raise ValueError("window_size is not supported when past_values is a 2D tensor.")
new_inputs: list[torch.Tensor] = []
for ts in inputs:
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
Expand Down
Loading
Loading