diff --git a/img2dataset/reader.py b/img2dataset/reader.py index 77d19b1..6b64982 100644 --- a/img2dataset/reader.py +++ b/img2dataset/reader.py @@ -109,11 +109,20 @@ def write_shard(shard_id): else: raise e - shards = [] - # thread pool to make it faster to write files to low latency file systems (ie s3, hdfs) - with ThreadPool(32) as thread_pool: - for shard in thread_pool.imap_unordered(write_shard, range(number_shards)): - shards.append(shard) + for i in range(10): + shards = [] + # thread pool to make it faster to write files to low latency file systems (ie s3, hdfs) + try: + with ThreadPool(32) as thread_pool: + for shard in thread_pool.imap_unordered(write_shard, range(number_shards)): + shards.append(shard) + break + except Exception as e: # pylint: disable=broad-except + if i != 9: + print("retrying whole sharding to write to files due to error:", e) + time.sleep(2 * i) + else: + raise e shards.sort(key=lambda k: k[0])