-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Give user registered types priority when encoding / decoding JSON #2188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
7a20cbd
b6c080b
c4b3ff6
bd37e27
fdd48db
238e6a1
f6a7719
ccbc8df
3331819
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,14 +23,20 @@ class JSONEncoder(json.JSONEncoder): | |
| """Kombu custom json encoder.""" | ||
|
|
||
| def default(self, o): | ||
| for t, (marker, encoder) in _encoders.items(): | ||
| if isinstance(o, t): | ||
| return ( | ||
| encoder(o) if marker is None else _as(marker, encoder(o)) | ||
| ) | ||
|
|
||
| reducer = getattr(o, "__json__", None) | ||
| if reducer is not None: | ||
| return reducer() | ||
|
Comment on lines
25
to
34
|
||
|
|
||
| if isinstance(o, textual_types): | ||
| return str(o) | ||
|
|
||
| for t, (marker, encoder) in _encoders.items(): | ||
| for t, (marker, encoder) in _default_encoders.items(): | ||
| if isinstance(o, t): | ||
| return ( | ||
| encoder(o) if marker is None else _as(marker, encoder(o)) | ||
|
|
@@ -66,7 +72,7 @@ def dumps( | |
| def object_hook(o: dict): | ||
| """Hook function to perform custom deserialization.""" | ||
| if o.keys() == {"__type__", "__value__"}: | ||
| decoder = _decoders.get(o["__type__"]) | ||
| decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"]) | ||
| if decoder: | ||
| return decoder(o["__value__"]) | ||
| else: | ||
|
|
@@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): | |
| T = TypeVar("T") | ||
| EncodedT = TypeVar("EncodedT") | ||
|
|
||
| # Separate user registered types from Kombu registered types to allow us to give preference to user types | ||
| _encoders: dict[type, tuple[str | None, EncoderT]] = {} | ||
| _decoders: dict[str, DecoderT] = {} | ||
|
|
||
| _default_encoders: dict[type, tuple[str | None, EncoderT]] = {} | ||
| _default_decoders: dict[str, DecoderT] = { | ||
| "bytes": lambda o: o.encode("utf-8"), | ||
| "base64": lambda o: base64.b64decode(o.encode("utf-8")), | ||
| } | ||
|
|
||
|
|
||
| def register_type( | ||
| t: type[T], | ||
|
|
@@ -110,32 +126,40 @@ def register_type( | |
| is not placed in an envelope, so `decoder` is unnecessary. Decoding must | ||
| instead be handled outside this library. | ||
| """ | ||
| _encoders[t] = (marker, encoder) | ||
| if marker is not None: | ||
| _decoders[marker] = decoder | ||
| _register_type(t, marker, encoder, decoder, is_default_encoder=False) | ||
|
|
||
|
|
||
| _encoders: dict[type, tuple[str | None, EncoderT]] = {} | ||
| _decoders: dict[str, DecoderT] = { | ||
| "bytes": lambda o: o.encode("utf-8"), | ||
| "base64": lambda o: base64.b64decode(o.encode("utf-8")), | ||
| } | ||
| def _register_type( | ||
| t: type[T], | ||
| marker: str | None, | ||
| encoder: Callable[[T], EncodedT], | ||
| decoder: Callable[[EncodedT], T] = lambda d: d, | ||
| is_default_encoder: bool = True, | ||
| ): | ||
| if is_default_encoder: | ||
| _default_encoders[t] = (marker, encoder) | ||
| if marker is not None: | ||
| _default_decoders[marker] = decoder | ||
| else: | ||
| _encoders[t] = (marker, encoder) | ||
| if marker is not None: | ||
| _decoders[marker] = decoder | ||
|
|
||
|
|
||
| def _register_default_types(): | ||
| # NOTE: datetime should be registered before date, | ||
| # because datetime is also instance of date. | ||
| register_type(datetime, "datetime", datetime.isoformat, | ||
| datetime.fromisoformat) | ||
| register_type( | ||
| _register_type(datetime, "datetime", datetime.isoformat, | ||
| datetime.fromisoformat) | ||
| _register_type( | ||
|
Comment on lines
149
to
+154
|
||
| date, | ||
| "date", | ||
| lambda o: o.isoformat(), | ||
| lambda o: datetime.fromisoformat(o).date(), | ||
| lambda o: datetime.fromisoformat(o).date() | ||
| ) | ||
| register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) | ||
| register_type(Decimal, "decimal", str, Decimal) | ||
| register_type( | ||
| _register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) | ||
| _register_type(Decimal, "decimal", str, Decimal) | ||
| _register_type( | ||
|
Comment on lines
149
to
+162
|
||
| uuid.UUID, | ||
| "uuid", | ||
| lambda o: {"hex": o.hex}, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,20 @@ def test_register_type_overrides_defaults(self): | |
| loaded_value = loads(dumps({'u': value})) | ||
| assert loaded_value == {'u': "custom"} | ||
|
|
||
| def test_register_type_takes_priority(self): | ||
| class MyDecimal(Decimal): | ||
| pass | ||
|
|
||
| register_type(MyDecimal, "mydecimal", str, MyDecimal) | ||
| original = {'md': MyDecimal('3314132.13363235235324234123213213214134')} | ||
| serialized_str = dumps(original) | ||
| # Ensure our custom marker is used instead of the default Decimal handler | ||
| assert '"mydecimal"' in serialized_str | ||
| loaded_value = loads(serialized_str) | ||
|
Comment on lines
+104
to
+107
|
||
| # Ensure the decoded value is of the registered subclass, not just equal | ||
| assert isinstance(loaded_value['md'], MyDecimal) | ||
| assert original == loaded_value | ||
|
Comment on lines
+98
to
+110
|
||
|
|
||
| def test_register_type_with_new_type(self): | ||
| # Guaranteed never before seen type | ||
| @dataclass() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change gives priority to user-registered types within
JSONEncoder.default, but stdlibjsondoes not calldefault()for JSON primitives (e.g.str/int/float/bool/None) or their subclasses. If the intent is to solve #1895'sSafeString(astrsubclass) example, this implementation likely won’t affect that case; it would require intercepting encoding before the primitive fast-path (e.g. overridingiterencode/preprocessing).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@soceanainn please cross check this and other suggestions