Skip to content

Commit

Permalink
apply black to all files
Browse files Browse the repository at this point in the history
  • Loading branch information
rom1504 committed Jan 13, 2024
1 parent 5d84e72 commit e46a6c0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ install-dev: ## [Local development] Install test requirements
lint: ## [Local development] Run mypy, pylint and black
python -m mypy img2dataset
python -m pylint img2dataset
python -m black --check -l 120 img2dataset
python -m black --check -l 120 .

black: ## [Local development] Auto-format python code using black
python -m black -l 120 .
Expand Down
61 changes: 28 additions & 33 deletions examples/ray_example/ray_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,37 @@
import argparse




@ray.remote
def main(args):
    """Run one distributed img2dataset download job on the Ray cluster.

    Reads a parquet url list (``url``/``alt``/``uid`` columns) and writes
    webdataset shards to ``args.out_folder``.

    Args:
        args: parsed CLI namespace with ``url_list`` (input parquet path/folder)
            and ``out_folder`` (output destination).
    """
    # NOTE(review): the scraped diff duplicated these keyword arguments
    # (pre-black and post-black copies); this is the single, valid call.
    download(
        processes_count=1,
        thread_count=32,
        retries=0,
        timeout=10,
        url_list=args.url_list,
        image_size=512,
        resize_only_if_bigger=True,
        resize_mode="keep_ratio_largest",
        skip_reencode=True,
        output_folder=args.out_folder,
        output_format="webdataset",
        input_format="parquet",
        url_col="url",
        caption_col="alt",
        enable_wandb=True,
        subjob_size=48 * 120 * 2,  # shards handed to each ray subjob
        number_sample_per_shard=10000,
        distributor="ray",
        oom_shard_count=8,
        compute_hash="sha256",
        save_additional_columns=["uid"],
    )

if __name__ == "__main__":
    # Script entry point: parse CLI args, attach to the running Ray cluster,
    # and launch the remote download job.
    # NOTE(review): the scraped diff contained this guard twice; one copy kept.
    parser = argparse.ArgumentParser()
    parser.add_argument("--url_list")
    parser.add_argument("--out_folder")
    args = parser.parse_args()
    ray.init(address="localhost:6379")
    main(args)
2 changes: 1 addition & 1 deletion examples/simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
import os

if __name__ == '__main__':
if __name__ == "__main__":
output_dir = os.path.abspath("bench")

if os.path.exists(output_dir):
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def setup_fixtures(count=5, disallowed=0):

def generate_url_list_txt(output_file, test_list, compression_on=False):
if compression_on:
f = gzip.open(output_file, 'wt')
f = gzip.open(output_file, "wt")
else:
f = open(output_file, "w")
with f:
Expand All @@ -63,7 +63,7 @@ def generate_json(output_file, test_list, compression=None):

def generate_jsonl(output_file, test_list, compression=None):
    """Write (caption, url) pairs to a JSON-lines fixture file.

    Args:
        output_file: destination path; extension-independent, format is jsonl.
        test_list: iterable of (caption, url) tuples.
        compression: passed through to ``DataFrame.to_json`` (e.g. "gzip");
            None writes plain text.
    """
    # NOTE(review): the scraped diff duplicated the to_json line (old/new
    # quoting); a single call is correct.
    df = pd.DataFrame(test_list, columns=["caption", "url"])
    df.to_json(output_file, orient="records", lines=True, compression=compression)


def generate_parquet(output_file, test_list):
Expand Down
12 changes: 10 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def test_download_input_format(input_format, output_format, tmp_path):
)

if output_format != "dummy":

df = pd.read_parquet(image_folder_name + "/00000.parquet")

expected_columns = [
Expand Down Expand Up @@ -169,7 +168,16 @@ def test_download_input_format(input_format, output_format, tmp_path):

assert len(pd.read_parquet(image_folder_name + "/00000.parquet").index) == expected_file_count
elif output_format == "dummy":
l = [x for x in glob.glob(image_folder_name + "/*") if (not x.endswith(".json") and not x.endswith(".jsonl") and not x.endswith(".json.gz") and not x.endswith(".jsonl.gz")) ]
l = [
x
for x in glob.glob(image_folder_name + "/*")
if (
not x.endswith(".json")
and not x.endswith(".jsonl")
and not x.endswith(".json.gz")
and not x.endswith(".jsonl.gz")
)
]
assert len(l) == 0
elif output_format == "tfrecord":
l = glob.glob(image_folder_name + "/*.tfrecord")
Expand Down
1 change: 0 additions & 1 deletion tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_writer(writer_type, tmp_path):
writer.close()

if writer_type != "dummy":

df = pd.read_parquet(output_folder + "/00000.parquet")

expected_columns = [
Expand Down

0 comments on commit e46a6c0

Please sign in to comment.