gh-126317: Simplify stdlib code by using itertools.batched() #126323

Merged (2 commits), Nov 2, 2024.
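For context: itertools.batched(iterable, n), added in Python 3.12, yields successive tuples of at most n items (the last one may be shorter), which is exactly the chunking the removed list(islice(it, n)) loops did by hand. A minimal illustration, not part of the diff:

```python
# Illustrative only; requires Python 3.12+.
from itertools import batched

# batched() yields successive tuples of up to n items; the final tuple
# is shorter when the input length is not a multiple of n.
for batch in batched("ABCDEFG", 3):
    print(batch)
# ('A', 'B', 'C')
# ('D', 'E', 'F')
# ('G',)
```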
60 changes: 21 additions & 39 deletions Lib/pickle.py
@@ -26,7 +26,7 @@
 from types import FunctionType
 from copyreg import dispatch_table
 from copyreg import _extension_registry, _inverted_registry, _extension_cache
-from itertools import islice
+from itertools import batched
 from functools import partial
 import sys
 from sys import maxsize
@@ -1033,31 +1033,25 @@ def _batch_appends(self, items, obj):
             write(APPEND)
             return
 
-        it = iter(items)
         start = 0
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
Contributor: Is it worth using tmp instead of batch? It reduces the diff, but feel free to disregard.

Contributor: I think it's better to have a named variable that says what it is. "tmp" is too generic, IMO.

Author: I also think chunk would be better; should we change it to chunk?

picnixz (Contributor, Nov 2, 2024): Since we use "batchsize", I think it's fine to keep "batch".

+            if len(batch) != 1:
                 write(MARK)
-                for i, x in enumerate(tmp, start):
+                for i, x in enumerate(batch, start):
Contributor: Is it faster to calculate start + i only in case of an error?

Contributor: It's probably faster, because we don't need to materialize an enumerate object, but I'd benchmark that one to be sure.
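A minimal sketch of the benchmark suggested above, using timeit; the batch size, offset, and function names are illustrative stand-ins, not code from the PR:

```python
# Hypothetical micro-benchmark for the reviewers' question; not part
# of the PR itself.
import timeit

batch = tuple(range(1000))
start = 5000

def offset_enumerate():
    # Current code: enumerate carries the absolute index on every step.
    for i, x in enumerate(batch, start):
        pass

def plain_loop():
    # Suggested alternative: loop without an index and reconstruct
    # start + i only in the (rare) error path; simulated here by
    # never needing it, since no error occurs.
    for x in batch:
        pass

print("enumerate:", timeit.timeit(offset_enumerate, number=10_000))
print("plain:    ", timeit.timeit(plain_loop, number=10_000))
```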

                     try:
                         save(x)
                     except BaseException as exc:
                         exc.add_note(f'when serializing {_T(obj)} item {i}')
                         raise
                 write(APPENDS)
-            elif n:
+            else:
                 try:
-                    save(tmp[0])
+                    save(batch[0])
                 except BaseException as exc:
                     exc.add_note(f'when serializing {_T(obj)} item {start}')
                     raise
                 write(APPEND)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
-            start += n
+            start += len(batch)
 
     def save_dict(self, obj):
         if self.bin:
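Worth noting why the branch structure simplifies: the old loop needed elif n: and a trailing if n < self._BATCHSIZE: return because its final islice() call could come back empty or short, whereas batched() never yields an empty tuple and stops by itself. A quick illustration, not from the diff:

```python
from itertools import batched

# batched() terminates cleanly and never emits an empty batch, so the
# consumer needs no end-of-input or empty-batch bookkeeping.
print(list(batched(range(5), 2)))  # [(0, 1), (2, 3), (4,)]
print(list(batched([], 2)))        # []  -- the loop body never runs
```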
@@ -1086,32 +1080,26 @@ def _batch_setitems(self, items, obj):
             write(SETITEM)
             return
 
-        it = iter(items)
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
+            if len(batch) != 1:
                 write(MARK)
-                for k, v in tmp:
+                for k, v in batch:
                     save(k)
                     try:
                         save(v)
                     except BaseException as exc:
                         exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                         raise
                 write(SETITEMS)
-            elif n:
-                k, v = tmp[0]
+            else:
+                k, v = batch[0]
                 save(k)
                 try:
                     save(v)
                 except BaseException as exc:
                     exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                     raise
                 write(SETITEM)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
 
     def save_set(self, obj):
         save = self.save
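Since _batch_setitems() receives an iterable of (key, value) pairs (save_dict passes obj.items()), each batch is a tuple of pairs, and k, v = batch[0] unpacks the lone pair of a singleton batch. A small illustration with made-up data:

```python
from itertools import batched

d = {"a": 1, "b": 2, "c": 3}

for batch in batched(d.items(), 2):
    if len(batch) != 1:
        print("multi:", list(batch))  # [('a', 1), ('b', 2)]
    else:
        k, v = batch[0]               # a singleton batch holds one pair
        print("single:", k, v)        # c 3
```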
@@ -1124,21 +1112,15 @@ def save_set(self, obj):
         write(EMPTY_SET)
         self.memoize(obj)
 
-        it = iter(obj)
-        while True:
-            batch = list(islice(it, self._BATCHSIZE))
-            n = len(batch)
-            if n > 0:
-                write(MARK)
-                try:
-                    for item in batch:
-                        save(item)
-                except BaseException as exc:
-                    exc.add_note(f'when serializing {_T(obj)} element')
-                    raise
-                write(ADDITEMS)
-            if n < self._BATCHSIZE:
-                return
+        for batch in batched(obj, self._BATCHSIZE):
+            write(MARK)
+            try:
+                for item in batch:
+                    save(item)
+            except BaseException as exc:
+                exc.add_note(f'when serializing {_T(obj)} element')
+                raise
+            write(ADDITEMS)
 
     dispatch[set] = save_set
 
     def save_frozenset(self, obj):
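As a sanity check on the rewrite, the removed islice-driven loop and the new batched() loop chunk their input identically; a sketch of that equivalence (old_chunks and new_chunks are illustrative helpers, not code from the PR):

```python
from itertools import islice, batched

def old_chunks(iterable, size):
    # The removed pattern: slice the iterator until it runs dry.
    it = iter(iterable)
    while True:
        batch = list(islice(it, size))
        if batch:
            yield tuple(batch)
        if len(batch) < size:
            return

def new_chunks(iterable, size):
    # The replacement: batched() performs the same chunking.
    yield from batched(iterable, size)

data = list(range(10))
assert list(old_chunks(data, 3)) == list(new_chunks(data, 3))
print(list(new_chunks(data, 3)))  # [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]
```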