diff --git a/kombu/utils/json.py b/kombu/utils/json.py index 46326c1098..ad8cf73e2a 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -23,6 +23,12 @@ class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" def default(self, o): + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return ( + encoder(o) if marker is None else _as(marker, encoder(o)) + ) + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() @@ -30,7 +36,7 @@ def default(self, o): if isinstance(o, textual_types): return str(o) - for t, (marker, encoder) in _encoders.items(): + for t, (marker, encoder) in _default_encoders.items(): if isinstance(o, t): return ( encoder(o) if marker is None else _as(marker, encoder(o)) @@ -66,7 +72,7 @@ def dumps( def object_hook(o: dict): """Hook function to perform custom deserialization.""" if o.keys() == {"__type__", "__value__"}: - decoder = _decoders.get(o["__type__"]) + decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"]) if decoder: return decoder(o["__value__"]) else: @@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): T = TypeVar("T") EncodedT = TypeVar("EncodedT") +# Separate user registered types from Kombu registered types to allow us to give preference to user types +_encoders: dict[type, tuple[str | None, EncoderT]] = {} +_decoders: dict[str, DecoderT] = {} + +_default_encoders: dict[type, tuple[str | None, EncoderT]] = {} +_default_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + def register_type( t: type[T], @@ -110,32 +126,40 @@ def register_type( is not placed in an envelope, so `decoder` is unnecessary. Decoding must instead be handled outside this library. """ - _encoders[t] = (marker, encoder) - if marker is not None: - _decoders[marker] = decoder + _register_type(t, marker, encoder, decoder, is_default_encoder=False) -_encoders: dict[type, tuple[str | None, EncoderT]] = {} -_decoders: dict[str, DecoderT] = { - "bytes": lambda o: o.encode("utf-8"), - "base64": lambda o: base64.b64decode(o.encode("utf-8")), -} +def _register_type( + t: type[T], + marker: str | None, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T] = lambda d: d, + is_default_encoder: bool = True, +): + if is_default_encoder: + _default_encoders[t] = (marker, encoder) + if marker is not None: + _default_decoders[marker] = decoder + else: + _encoders[t] = (marker, encoder) + if marker is not None: + _decoders[marker] = decoder def _register_default_types(): # NOTE: datetime should be registered before date, # because datetime is also instance of date. - register_type(datetime, "datetime", datetime.isoformat, - datetime.fromisoformat) - register_type( + _register_type(datetime, "datetime", datetime.isoformat, + datetime.fromisoformat) + _register_type( date, "date", lambda o: o.isoformat(), - lambda o: datetime.fromisoformat(o).date(), + lambda o: datetime.fromisoformat(o).date() ) - register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) - register_type(Decimal, "decimal", str, Decimal) - register_type( + _register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) + _register_type(Decimal, "decimal", str, Decimal) + _register_type( uuid.UUID, "uuid", lambda o: {"hex": o.hex}, diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index 723bd09bbc..1bb795d11e 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -95,6 +95,20 @@ def test_register_type_overrides_defaults(self): loaded_value = loads(dumps({'u': value})) assert loaded_value == {'u': "custom"} + def test_register_type_takes_priority(self): + class MyDecimal(Decimal): + pass + + register_type(MyDecimal, "mydecimal", str, MyDecimal) + original = {'md': MyDecimal('3314132.13363235235324234123213213214134')} + serialized_str = dumps(original) + # Ensure our custom marker is used instead of the default Decimal handler + assert '"mydecimal"' in serialized_str + loaded_value = loads(serialized_str) + # Ensure the decoded value is of the registered subclass, not just equal + assert isinstance(loaded_value['md'], MyDecimal) + assert original == loaded_value + def test_register_type_with_new_type(self): # Guaranteed never before seen type @dataclass()