From 34cf0455dfa5bd769271d03345858cd1f08c3375 Mon Sep 17 00:00:00 2001 From: Binh Vu Date: Sun, 21 Jan 2024 03:55:43 +0000 Subject: [PATCH] save entity/page dump with zstd compression --- kgdata/misc/funcs.py | 43 ++++++++++++++++++++ kgdata/spark/common.py | 54 ++++++++++++++++++++++++- kgdata/spark/extended_rdd.py | 39 ++++++++++++++++-- kgdata/splitter.py | 25 ++++++++++-- kgdata/wikidata/config.py | 5 ++- kgdata/wikidata/datasets/entity_dump.py | 7 +++- kgdata/wikidata/datasets/page_dump.py | 7 ++-- pyproject.toml | 4 +- 8 files changed, 168 insertions(+), 16 deletions(-) diff --git a/kgdata/misc/funcs.py b/kgdata/misc/funcs.py index 484b7dc..6bbe4d4 100644 --- a/kgdata/misc/funcs.py +++ b/kgdata/misc/funcs.py @@ -1,6 +1,49 @@ from __future__ import annotations +import importlib +from io import BytesIO +from typing import Type + +import zstandard as zstd + def split_tab_2(x: str) -> tuple[str, str]: k, v = x.split("\t") return (k, v) + + +TYPE_ALIASES = {"typing.List": "list", "typing.Dict": "dict", "typing.Set": "set"} + + +def get_import_path(type: Type) -> str: + if type.__module__ == "builtins": + return type.__qualname__ + + if hasattr(type, "__qualname__"): + return type.__module__ + "." + type.__qualname__ + + # typically a class from the typing module + if hasattr(type, "_name") and type._name is not None: + path = type.__module__ + "." + type._name + if path in TYPE_ALIASES: + path = TYPE_ALIASES[path] + elif hasattr(type, "__origin__") and hasattr(type.__origin__, "_name"): + # found one case which is typing.Union + path = type.__module__ + "." + type.__origin__._name + else: + raise NotImplementedError(type) + + return path + + +def import_attr(attr_ident: str): + lst = attr_ident.rsplit(".", 1) + module, cls = lst + module = importlib.import_module(module) + return getattr(module, cls) + + +def deser_zstd_records(dat: bytes): + cctx = zstd.ZstdDecompressor() + datobj = BytesIO(dat) + return [x.decode() for x in cctx.stream_reader(datobj).readall().splitlines()] diff --git a/kgdata/spark/common.py b/kgdata/spark/common.py index 4e536ec..e1e0fcc 100644 --- a/kgdata/spark/common.py +++ b/kgdata/spark/common.py @@ -14,6 +14,7 @@ Generic, Iterable, List, + Literal, Optional, Sequence, Tuple, @@ -22,8 +23,10 @@ ) import orjson +import zstandard as zstd from loguru import logger -from pyspark import RDD, SparkConf, SparkContext +from pyspark import RDD, SparkConf, SparkContext, TaskContext +from sm.misc.funcs import assert_not_null # SparkContext singleton _sc = None @@ -84,6 +87,8 @@ def get_key(key): "spark.executor.instances", "spark.driver.memory", "spark.driver.maxResultSize", + "spark.driver.extraLibraryPath", + "spark.executor.extraLibraryPath", ] if has_key(key) ] @@ -480,7 +485,54 @@ def get_bytes(s: str | bytes) -> int: return math.ceil(total_size / partition_size) +def save_as_text_file( + rdd: RDD[str] | RDD[bytes], + outdir: Path, + compression: Optional[Literal["gz", "zst"]], + compression_level: Optional[int] = None, +): + if compression == "gz" or compression is None: + return rdd.saveAsTextFile( + str(outdir), + compressionCodecClass="org.apache.hadoop.io.compress.GzipCodec" + if compression == "gz" + else None, + ) + + if compression == "zst": + compression_level = compression_level or 3 + + def save_partition(partition: Iterable[str] | Iterable[bytes]): + partition_id = assert_not_null(TaskContext.get()).partitionId() + lst = [] + it = iter(partition) + first_val = next(it, None) + if isinstance(first_val, str): + lst.append(first_val.encode()) + for x in it: + 
lst.append(x.encode()) # type: ignore + else: + lst.append(first_val) + for x in it: + lst.append(x) + datasize = sum(len(x) + 1 for x in lst) # 1 for newline + cctx = zstd.ZstdCompressor(level=compression_level, write_content_size=True) + + with open(outdir / f"part-{partition_id:05d}.zst", "wb") as fh: + with cctx.stream_writer(fh, size=datasize) as f: + for record in lst: + f.write(record) + f.write(b"\n") + + rdd.foreachPartition(save_partition) + (outdir / "_SUCCESS").touch() + return + + raise Exception(f"Unknown compression: {compression}") + + @dataclass class EmptyBroadcast(Generic[V]): value: V value: V + value: V diff --git a/kgdata/spark/extended_rdd.py b/kgdata/spark/extended_rdd.py index efeb840..18a842e 100644 --- a/kgdata/spark/extended_rdd.py +++ b/kgdata/spark/extended_rdd.py @@ -22,6 +22,10 @@ ) import serde.json +from pyspark.rdd import RDD, portable_hash +from typing_extensions import TypeGuard + +from kgdata.misc.funcs import deser_zstd_records from kgdata.spark.common import ( are_records_unique, estimate_num_partitions, @@ -29,8 +33,6 @@ join_repartition, left_outer_join_repartition, ) -from pyspark.rdd import RDD, portable_hash -from typing_extensions import TypeGuard if TYPE_CHECKING: from kgdata.dataset import Dataset @@ -244,6 +246,9 @@ def save_like_dataset( file_pattern = Path(dataset.file_pattern) if file_pattern.suffix == ".gz": compressionCodecClass = "org.apache.hadoop.io.compress.GzipCodec" + elif file_pattern.suffix == ".zst": + # this is a dummy codec for our custom version of save_as_text_file + compressionCodecClass = "kgdata.compress.ZstdCodec" else: # this to make sure the dataset file pattern matches the generated file from spark. assert file_pattern.suffix == "" and file_pattern.name.startswith( @@ -317,6 +322,14 @@ def save_as_dataset( """ outdir = str(outdir) + # if compressionCodecClass == "kgdata.compress.ZstdCodec": + # compression = "zst" + # elif compressionCodecClass == "org.apache.hadoop.io.compress.GzipCodec": + # compression = "gz" + # else: + # assert compressionCodecClass is None + # compression = None + if not auto_coalesce: self.rdd.saveAsTextFile(outdir, compressionCodecClass=compressionCodecClass) else: @@ -473,7 +486,8 @@ def is_unique( def textFile( indir: StrPath, minPartitions: Optional[int] = None, use_unicode: bool = True ): - sigfile = Path(indir) / "_SIGNATURE" + indir = Path(indir) + sigfile = indir / "_SIGNATURE" if sigfile.exists(): sig = serde.json.deser(sigfile, DatasetSignature) assert sig.is_valid() @@ -484,8 +498,25 @@ def textFile( checksum="", dependencies={}, ) + + # to support zst files (indir) + if ( + indir.is_dir() + and any( + file.name.startswith("part-") and file.name.endswith(".zst") + for file in indir.iterdir() + ) + ) or indir.name.endswith(".zst"): + return ExtendedRDD( + get_spark_context() + .binaryFiles(str(indir), minPartitions) + .flatMap(lambda x: deser_zstd_records(x[1])), + sig, + ) + return ExtendedRDD( - get_spark_context().textFile(str(indir), minPartitions, use_unicode), sig + get_spark_context().textFile(str(indir), minPartitions, use_unicode), + sig, ) @staticmethod diff --git a/kgdata/splitter.py b/kgdata/splitter.py index 94ebb4f..2864528 100644 --- a/kgdata/splitter.py +++ b/kgdata/splitter.py @@ -7,7 +7,16 @@ from io import TextIOWrapper from multiprocessing import Process, Queue from pathlib import Path -from typing import BinaryIO, Callable, ContextManager, Iterable, List, Tuple, Union +from typing import ( + BinaryIO, + Callable, + ContextManager, + Iterable, + List, + Optional, 
+ Tuple, + Union, +) import serde.byteline from serde.helper import get_open_fn @@ -41,6 +50,7 @@ def split_a_file( override: bool = False, n_writers: int = 8, n_records_per_file: int = 64000, + compression_level: Optional[int] = None, ): r"""Split a file containing a list of records into smaller files stored in a directory. The list of records are written in a round-robin fashion by multiple writers (processes) @@ -60,6 +70,7 @@ def split_a_file( override: whether to override existing files. n_writers: number of parallel writers. n_records_per_file: number of records per file. + compression_level: compression level of output files. if None, use default compression level. """ outfile = Path(outfile) outdir = outfile.parent @@ -82,7 +93,13 @@ def split_a_file( writers.append( Process( target=write_to_file, - args=(writer_file, n_records_per_file, record_postprocess, queues[i]), + args=( + writer_file, + n_records_per_file, + record_postprocess, + queues[i], + compression_level, + ), ) ) writers[i].start() @@ -135,6 +152,7 @@ def write_to_file( n_records_per_file: int, record_postprocessing: str, queue: Queue, + compression_level: Optional[int], ): """Write records from a queue to a file. @@ -143,11 +161,12 @@ def write_to_file( n_records_per_file: number of records per file record_postprocessing: name/path to import the function that post-process an record. the function can return None to skip the record. queue: a queue that yields records to be written to a file, when it yields None, the writer stops. + compression_level: compression level of output files. if None, use default compression level. """ file_counter = 0 outfile = outfile_template.format(auto=file_counter) - writer = get_open_fn(outfile)(outfile, "wb") + writer = get_open_fn(outfile, compression_level=compression_level)(outfile, "wb") n_records = 0 postprocess_fn = import_func(record_postprocessing) diff --git a/kgdata/wikidata/config.py b/kgdata/wikidata/config.py index 2d23b0a..07a344e 100644 --- a/kgdata/wikidata/config.py +++ b/kgdata/wikidata/config.py @@ -63,7 +63,10 @@ def get_dump_date(self): return res[0] def get_entity_dump_file(self): - return self._get_file(self.dumps / "*wikidata-*all*.json.bz2") + try: + return self._get_file(self.dumps / "*wikidata-*all*.json.zst") + except: + return self._get_file(self.dumps / "*wikidata-*all*.json.bz2") def get_page_dump_file(self): return self._get_file(self.dumps / "*wikidatawiki-*page*.sql.gz") diff --git a/kgdata/wikidata/datasets/entity_dump.py b/kgdata/wikidata/datasets/entity_dump.py index 4ad73d4..b903a5b 100644 --- a/kgdata/wikidata/datasets/entity_dump.py +++ b/kgdata/wikidata/datasets/entity_dump.py @@ -1,6 +1,7 @@ from bz2 import BZ2File from functools import lru_cache from gzip import GzipFile +from io import BufferedReader from typing import BinaryIO, Union import orjson @@ -21,7 +22,7 @@ def entity_dump() -> Dataset[dict]: cfg = WikidataDirCfg.get_instance() dump_date = cfg.get_dump_date() ds = Dataset( - file_pattern=cfg.entity_dump / "*.gz", + file_pattern=cfg.entity_dump / "*.zst", deserialize=orjson.loads, name=f"entity-dump/{dump_date}", dependencies=[], @@ -30,11 +31,12 @@ def entity_dump() -> Dataset[dict]: if not ds.has_complete_data(): split_a_file( infile=cfg.get_entity_dump_file(), - outfile=cfg.entity_dump / "part.ndjson.gz", + outfile=cfg.entity_dump / "part.ndjson.zst", record_iter=_record_iter, record_postprocess="kgdata.wikidata.datasets.entity_dump._record_postprocess", n_writers=8, override=False, + compression_level=9, ) ds.sign(ds.get_name(), 
ds.get_dependencies()) return ds @@ -57,4 +59,5 @@ def _record_postprocess(record: str): if __name__ == "__main__": + WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619") entity_dump() diff --git a/kgdata/wikidata/datasets/page_dump.py b/kgdata/wikidata/datasets/page_dump.py index 33e4672..79adcda 100644 --- a/kgdata/wikidata/datasets/page_dump.py +++ b/kgdata/wikidata/datasets/page_dump.py @@ -15,16 +15,17 @@ def page_dump() -> Dataset[str]: dump_date = cfg.get_dump_date() ds = Dataset.string( - cfg.page_dump / "*.gz", name=f"page-dump/{dump_date}", dependencies=[] + cfg.page_dump / "*.zst", name=f"page-dump/{dump_date}", dependencies=[] ) if not ds.has_complete_data(): split_a_file( infile=cfg.get_page_dump_file(), - outfile=cfg.page_dump / "part.sql.gz", + outfile=cfg.page_dump / "part.sql.zst", record_iter=_record_iter, record_postprocess="kgdata.wikidata.datasets.page_dump._record_postprocess", n_writers=8, override=False, + compression_level=9, ) ds.sign(ds.get_name(), ds.get_dependencies()) return ds @@ -47,5 +48,5 @@ def _record_postprocess(line: bytes): if __name__ == "__main__": - WikidataDirCfg.init("~/kgdata/wikidata/20211213") + WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619") page_dump() diff --git a/pyproject.toml b/pyproject.toml index 7354db4..ebe2a6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kgdata" -version = "6.3.1" +version = "6.3.2" description = "Library to process dumps of knowledge graphs (Wikipedia, DBpedia, Wikidata)" readme = "README.md" authors = [{ name = "Binh Vu", email = "binh@toan2.com" }] @@ -25,7 +25,7 @@ dependencies = [ 'redis >= 3.5.3, < 4.0.0', 'numpy >= 1.22.3, < 2.0.0', 'requests >= 2.28.0, < 3.0.0', - 'sem-desc >= 6.0.0, < 7.0.0', + 'sem-desc >= 6.7.2, < 7.0.0', 'click >= 8.1.3, < 9.0.0', 'parsimonious >= 0.8.1, < 0.9.0', 'hugedict >= 2.12.5, < 3.0.0',
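
A minimal, self-contained sketch of the zstd round-trip this patch implements: each part file stores newline-delimited records compressed with zstandard, written through ZstdCompressor.stream_writer (as in save_as_text_file and the split_a_file writers) and read back through ZstdDecompressor.stream_reader (as in deser_zstd_records). The helper names write_part and read_part, the /tmp path, and the sample records are hypothetical; only the zstandard calls mirror what the patch does.

from pathlib import Path

import zstandard as zstd


def write_part(records: list[str], outfile: Path, level: int = 3) -> None:
    # Encode records and mirror the patch's size hint: payload bytes plus one
    # newline per record, so the frame can carry the uncompressed content size.
    data = [r.encode() for r in records]
    datasize = sum(len(d) + 1 for d in data)
    cctx = zstd.ZstdCompressor(level=level, write_content_size=True)
    with open(outfile, "wb") as fh:
        with cctx.stream_writer(fh, size=datasize) as writer:
            for d in data:
                writer.write(d)
                writer.write(b"\n")


def read_part(infile: Path) -> list[str]:
    # Same decoding path as deser_zstd_records, but reading from disk instead
    # of from the bytes handed over by SparkContext.binaryFiles.
    dctx = zstd.ZstdDecompressor()
    with open(infile, "rb") as fh:
        raw = dctx.stream_reader(fh).readall()
    return [line.decode() for line in raw.splitlines()]


if __name__ == "__main__":
    part = Path("/tmp/part-00000.zst")  # hypothetical output location
    write_part(['{"id": "Q1"}', '{"id": "Q2"}'], part, level=9)
    assert read_part(part) == ['{"id": "Q1"}', '{"id": "Q2"}']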