Commit 6fec5c4

Fix HookedTransformerConfig rotary_base types
rotary_base is frequently set to floats in the code but was typed as an int:
https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984

Non-integer rotary bases are unusual but not illegal, and HF configs' rope_theta is a float:
https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645

This updates the type to Union[float, int] to prevent beartype errors when loading these configs in tests. Note that beartype doesn't consider int to be a subclass of float: beartype/beartype#66
1 parent 9c5a2a8 commit 6fec5c4
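The beartype behavior referenced above comes down to runtime `isinstance` checks: PEP 484's numeric tower lets an `int` satisfy a `float` annotation in static type checkers, but at runtime `isinstance(10000, float)` is `False`, so a strict checker rejects it. A minimal stdlib-only sketch (the `check_rotary_base` helper is illustrative, not from the repo):

```python
from typing import Union


def check_rotary_base(value: Union[float, int]) -> bool:
    # Mimics a strict runtime check of Union[float, int], as beartype
    # performs against the annotation: either concrete type passes.
    return isinstance(value, (float, int))


# Why annotating `rotary_base: float` alone would still fail for the
# default value 10000 under a strict runtime checker:
assert isinstance(10000, float) is False

# Union[float, int] accepts both the int default and HF float rope_theta:
assert check_rotary_base(10000)      # int default
assert check_rotary_base(500000.0)   # float, e.g. an HF rope_theta value
```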

File tree

2 files changed: +4 −4 lines changed


transformer_lens/HookedTransformerConfig.py

Lines changed: 3 additions & 3 deletions
@@ -194,7 +194,7 @@ class HookedTransformerConfig:
             Defaults to 8.0.
         use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
             computing attention scores. Used by Gemma 3 models. Defaults to False.
-        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
+        rotary_base_local (float, *optional*): The base for rotary positional embeddings in local
             attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
             which use different RoPE bases for local (10k) and global (1M) attention. Defaults
             to None, which means the standard rotary_base is used for all layers.
@@ -252,9 +252,9 @@ class HookedTransformerConfig:
     tokenizer_prepends_bos: Optional[bool] = None
     n_key_value_heads: Optional[int] = None
     post_embedding_ln: bool = False
-    rotary_base: int = 10000
+    rotary_base: Union[float, int] = 10000
     rotary_base_local: Optional[
-        int
+        Union[float, int]
     ] = None  # For models with different RoPE bases per attention type (e.g., Gemma 3)
     trust_remote_code: bool = False
     rotary_adjacent_pairs: bool = False

transformer_lens/components/abstract_attention.py

Lines changed: 1 addition & 1 deletion
@@ -532,7 +532,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: Union[float, int] = 10000,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
         """
