
save entity/page dump with zstd compression
Binh Vu committed Jan 21, 2024
1 parent 7944c79 commit 34cf045
Showing 8 changed files with 168 additions and 16 deletions.
43 changes: 43 additions & 0 deletions kgdata/misc/funcs.py
@@ -1,6 +1,49 @@
from __future__ import annotations

import importlib
from io import BytesIO
from typing import Type

import zstandard as zstd


def split_tab_2(x: str) -> tuple[str, str]:
k, v = x.split("\t")
return (k, v)


TYPE_ALIASES = {"typing.List": "list", "typing.Dict": "dict", "typing.Set": "set"}


def get_import_path(type: Type) -> str:
if type.__module__ == "builtins":
return type.__qualname__

if hasattr(type, "__qualname__"):
return type.__module__ + "." + type.__qualname__

# typically a class from the typing module
if hasattr(type, "_name") and type._name is not None:
path = type.__module__ + "." + type._name
if path in TYPE_ALIASES:
path = TYPE_ALIASES[path]
elif hasattr(type, "__origin__") and hasattr(type.__origin__, "_name"):
# found one case which is typing.Union
path = type.__module__ + "." + type.__origin__._name
else:
raise NotImplementedError(type)

return path


def import_attr(attr_ident: str):
lst = attr_ident.rsplit(".", 1)
module, cls = lst
module = importlib.import_module(module)
return getattr(module, cls)
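
A quick editor-added sanity check (not part of the commit) for the two helpers above, using only behavior visible in this diff:

from pathlib import Path

from kgdata.misc.funcs import get_import_path, import_attr

# builtins collapse to their bare name; other classes keep their module path
assert get_import_path(dict) == "dict"
assert get_import_path(Path) == "pathlib.Path"
# import_attr goes the other way: dotted path -> attribute
assert import_attr("pathlib.Path") is Path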


def deser_zstd_records(dat: bytes):
cctx = zstd.ZstdDecompressor()
datobj = BytesIO(dat)
return [x.decode() for x in cctx.stream_reader(datobj).readall().splitlines()]
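
A minimal round-trip sketch (editor illustration, not part of the commit): deser_zstd_records expects a single zstd frame holding newline-separated records, which is exactly the layout the new save_as_text_file writer below produces.

import zstandard as zstd

from kgdata.misc.funcs import deser_zstd_records

records = ["Q1\tEarth", "Q2\tMoon"]
frame = zstd.ZstdCompressor(level=3).compress(("\n".join(records) + "\n").encode())
assert deser_zstd_records(frame) == records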
54 changes: 53 additions & 1 deletion kgdata/spark/common.py
@@ -14,6 +14,7 @@
Generic,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
@@ -22,8 +23,10 @@
)

import orjson
import zstandard as zstd
from loguru import logger
from pyspark import RDD, SparkConf, SparkContext
from pyspark import RDD, SparkConf, SparkContext, TaskContext
from sm.misc.funcs import assert_not_null

# SparkContext singleton
_sc = None
@@ -84,6 +87,8 @@ def get_key(key):
"spark.executor.instances",
"spark.driver.memory",
"spark.driver.maxResultSize",
"spark.driver.extraLibraryPath",
"spark.executor.extraLibraryPath",
]
if has_key(key)
]
@@ -480,7 +485,54 @@ def get_bytes(s: str | bytes) -> int:
return math.ceil(total_size / partition_size)


def save_as_text_file(
rdd: RDD[str] | RDD[bytes],
outdir: Path,
compression: Optional[Literal["gz", "zst"]],
compression_level: Optional[int] = None,
):
if compression == "gz" or compression is None:
return rdd.saveAsTextFile(
str(outdir),
compressionCodecClass="org.apache.hadoop.io.compress.GzipCodec"
if compression == "gz"
else None,
)

if compression == "zst":
compression_level = compression_level or 3

def save_partition(partition: Iterable[str] | Iterable[bytes]):
partition_id = assert_not_null(TaskContext.get()).partitionId()
lst = []
it = iter(partition)
            first_val = next(it, None)
            if first_val is None:
                # empty partition: writing nothing avoids appending None to lst below
                return
if isinstance(first_val, str):
lst.append(first_val.encode())
for x in it:
lst.append(x.encode()) # type: ignore
else:
lst.append(first_val)
for x in it:
lst.append(x)
datasize = sum(len(x) + 1 for x in lst) # 1 for newline
cctx = zstd.ZstdCompressor(level=compression_level, write_content_size=True)

with open(outdir / f"part-{partition_id:05d}.zst", "wb") as fh:
with cctx.stream_writer(fh, size=datasize) as f:
for record in lst:
f.write(record)
f.write(b"\n")

rdd.foreachPartition(save_partition)
(outdir / "_SUCCESS").touch()
return

raise Exception(f"Unknown compression: {compression}")


@dataclass
class EmptyBroadcast(Generic[V]):
    value: V
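
A hedged usage sketch of the new writer (editor illustration; assumes a local Spark context and a writable scratch directory, and the paths are illustrative):

from pathlib import Path

from kgdata.spark.common import get_spark_context, save_as_text_file

outdir = Path("/tmp/zst-demo")
outdir.mkdir(parents=True, exist_ok=True)

rdd = get_spark_context().parallelize(["record-1", "record-2"], numSlices=2)
save_as_text_file(rdd, outdir, compression="zst", compression_level=3)
# expected layout: part-00000.zst, part-00001.zst, and an empty _SUCCESS marker

Note that the zst branch bypasses Hadoop codecs entirely: each partition is written from foreachPartition with a plain open(), so it only works when executors can write to outdir (local mode or a shared filesystem).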
39 changes: 35 additions & 4 deletions kgdata/spark/extended_rdd.py
@@ -22,15 +22,17 @@
)

import serde.json
from pyspark.rdd import RDD, portable_hash
from typing_extensions import TypeGuard

from kgdata.misc.funcs import deser_zstd_records
from kgdata.spark.common import (
are_records_unique,
estimate_num_partitions,
get_spark_context,
join_repartition,
left_outer_join_repartition,
)
from pyspark.rdd import RDD, portable_hash
from typing_extensions import TypeGuard

if TYPE_CHECKING:
from kgdata.dataset import Dataset
@@ -244,6 +246,9 @@ def save_like_dataset(
file_pattern = Path(dataset.file_pattern)
if file_pattern.suffix == ".gz":
compressionCodecClass = "org.apache.hadoop.io.compress.GzipCodec"
elif file_pattern.suffix == ".zst":
# this is a dummy codec for our custom version of save_as_text_file
compressionCodecClass = "kgdata.compress.ZstdCodec"
else:
# this to make sure the dataset file pattern matches the generated file from spark.
assert file_pattern.suffix == "" and file_pattern.name.startswith(
@@ -317,6 +322,14 @@ def save_as_dataset(
"""
outdir = str(outdir)

# if compressionCodecClass == "kgdata.compress.ZstdCodec":
# compression = "zst"
# elif compressionCodecClass == "org.apache.hadoop.io.compress.GzipCodec":
# compression = "gz"
# else:
# assert compressionCodecClass is None
# compression = None

if not auto_coalesce:
self.rdd.saveAsTextFile(outdir, compressionCodecClass=compressionCodecClass)
else:
@@ -473,7 +486,8 @@ def is_unique(
def textFile(
indir: StrPath, minPartitions: Optional[int] = None, use_unicode: bool = True
):
sigfile = Path(indir) / "_SIGNATURE"
indir = Path(indir)
sigfile = indir / "_SIGNATURE"
if sigfile.exists():
sig = serde.json.deser(sigfile, DatasetSignature)
assert sig.is_valid()
@@ -484,8 +498,25 @@
checksum="",
dependencies={},
)

# to support zst files (indir)
if (
indir.is_dir()
and any(
file.name.startswith("part-") and file.name.endswith(".zst")
for file in indir.iterdir()
)
) or indir.name.endswith(".zst"):
return ExtendedRDD(
get_spark_context()
.binaryFiles(str(indir), minPartitions)
.flatMap(lambda x: deser_zstd_records(x[1])),
sig,
)

return ExtendedRDD(
get_spark_context().textFile(str(indir), minPartitions, use_unicode), sig
get_spark_context().textFile(str(indir), minPartitions, use_unicode),
sig,
)

@staticmethod
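Reading such output back goes through the new branch in ExtendedRDD.textFile: directories containing part-*.zst (or a path ending in .zst) are loaded with binaryFiles and decoded via deser_zstd_records instead of sc.textFile. A standalone sketch of that read path (editor illustration; the input path is hypothetical):

from kgdata.misc.funcs import deser_zstd_records
from kgdata.spark.common import get_spark_context

records = (
    get_spark_context()
    .binaryFiles("/tmp/zst-demo")  # yields (file path, raw bytes) per part file
    .flatMap(lambda kv: deser_zstd_records(kv[1]))
)
print(records.collect())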
25 changes: 22 additions & 3 deletions kgdata/splitter.py
@@ -7,7 +7,16 @@
from io import TextIOWrapper
from multiprocessing import Process, Queue
from pathlib import Path
from typing import BinaryIO, Callable, ContextManager, Iterable, List, Tuple, Union
from typing import (
BinaryIO,
Callable,
ContextManager,
Iterable,
List,
Optional,
Tuple,
Union,
)

import serde.byteline
from serde.helper import get_open_fn
@@ -41,6 +50,7 @@ def split_a_file(
override: bool = False,
n_writers: int = 8,
n_records_per_file: int = 64000,
compression_level: Optional[int] = None,
):
r"""Split a file containing a list of records into smaller files stored in a directory.
The list of records are written in a round-robin fashion by multiple writers (processes)
@@ -60,6 +70,7 @@
override: whether to override existing files.
n_writers: number of parallel writers.
n_records_per_file: number of records per file.
compression_level: compression level of output files. if None, use default compression level.
"""
outfile = Path(outfile)
outdir = outfile.parent
@@ -82,7 +93,13 @@
writers.append(
Process(
target=write_to_file,
args=(writer_file, n_records_per_file, record_postprocess, queues[i]),
args=(
writer_file,
n_records_per_file,
record_postprocess,
queues[i],
compression_level,
),
)
)
writers[i].start()
@@ -135,6 +152,7 @@ def write_to_file(
n_records_per_file: int,
record_postprocessing: str,
queue: Queue,
compression_level: Optional[int],
):
"""Write records from a queue to a file.
@@ -143,11 +161,12 @@
n_records_per_file: number of records per file
record_postprocessing: name/path to import the function that post-process an record. the function can return None to skip the record.
queue: a queue that yields records to be written to a file, when it yields None, the writer stops.
compression_level: compression level of output files. if None, use default compression level.
"""
file_counter = 0

outfile = outfile_template.format(auto=file_counter)
writer = get_open_fn(outfile)(outfile, "wb")
writer = get_open_fn(outfile, compression_level=compression_level)(outfile, "wb")
n_records = 0

postprocess_fn = import_func(record_postprocessing)
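The new compression_level only threads through split_a_file and write_to_file into get_open_fn; the codec itself is still picked from the output suffix. A sketch of that call shape in isolation (editor illustration; assumes sem-desc >= 6.7.2, where get_open_fn accepts compression_level — presumably the reason the dependency floor is raised in pyproject.toml below; the path is hypothetical):

from serde.helper import get_open_fn

path = "/tmp/part-00000.ndjson.zst"
open_fn = get_open_fn(path, compression_level=9)  # codec chosen from the .zst suffix
with open_fn(path, "wb") as f:
    f.write(b'{"id": "Q1"}\n')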
5 changes: 4 additions & 1 deletion kgdata/wikidata/config.py
@@ -63,7 +63,10 @@ def get_dump_date(self):
return res[0]

def get_entity_dump_file(self):
return self._get_file(self.dumps / "*wikidata-*all*.json.bz2")
try:
return self._get_file(self.dumps / "*wikidata-*all*.json.zst")
except:
return self._get_file(self.dumps / "*wikidata-*all*.json.bz2")

def get_page_dump_file(self):
return self._get_file(self.dumps / "*wikidatawiki-*page*.sql.gz")
7 changes: 5 additions & 2 deletions kgdata/wikidata/datasets/entity_dump.py
@@ -1,6 +1,7 @@
from bz2 import BZ2File
from functools import lru_cache
from gzip import GzipFile
from io import BufferedReader
from typing import BinaryIO, Union

import orjson
@@ -21,7 +22,7 @@ def entity_dump() -> Dataset[dict]:
cfg = WikidataDirCfg.get_instance()
dump_date = cfg.get_dump_date()
ds = Dataset(
file_pattern=cfg.entity_dump / "*.gz",
file_pattern=cfg.entity_dump / "*.zst",
deserialize=orjson.loads,
name=f"entity-dump/{dump_date}",
dependencies=[],
@@ -30,11 +31,12 @@
if not ds.has_complete_data():
split_a_file(
infile=cfg.get_entity_dump_file(),
outfile=cfg.entity_dump / "part.ndjson.gz",
outfile=cfg.entity_dump / "part.ndjson.zst",
record_iter=_record_iter,
record_postprocess="kgdata.wikidata.datasets.entity_dump._record_postprocess",
n_writers=8,
override=False,
compression_level=9,
)
ds.sign(ds.get_name(), ds.get_dependencies())
return ds
@@ -57,4 +59,5 @@ def _record_postprocess(record: str):


if __name__ == "__main__":
WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619")
entity_dump()
7 changes: 4 additions & 3 deletions kgdata/wikidata/datasets/page_dump.py
@@ -15,16 +15,17 @@ def page_dump() -> Dataset[str]:
dump_date = cfg.get_dump_date()

ds = Dataset.string(
cfg.page_dump / "*.gz", name=f"page-dump/{dump_date}", dependencies=[]
cfg.page_dump / "*.zst", name=f"page-dump/{dump_date}", dependencies=[]
)
if not ds.has_complete_data():
split_a_file(
infile=cfg.get_page_dump_file(),
outfile=cfg.page_dump / "part.sql.gz",
outfile=cfg.page_dump / "part.sql.zst",
record_iter=_record_iter,
record_postprocess="kgdata.wikidata.datasets.page_dump._record_postprocess",
n_writers=8,
override=False,
compression_level=9,
)
ds.sign(ds.get_name(), ds.get_dependencies())
return ds
@@ -47,5 +48,5 @@ def _record_postprocess(line: bytes):


if __name__ == "__main__":
WikidataDirCfg.init("~/kgdata/wikidata/20211213")
WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619")
page_dump()
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kgdata"
version = "6.3.1"
version = "6.3.2"
description = "Library to process dumps of knowledge graphs (Wikipedia, DBpedia, Wikidata)"
readme = "README.md"
authors = [{ name = "Binh Vu", email = "[email protected]" }]
@@ -25,7 +25,7 @@ dependencies = [
'redis >= 3.5.3, < 4.0.0',
'numpy >= 1.22.3, < 2.0.0',
'requests >= 2.28.0, < 3.0.0',
'sem-desc >= 6.0.0, < 7.0.0',
'sem-desc >= 6.7.2, < 7.0.0',
'click >= 8.1.3, < 9.0.0',
'parsimonious >= 0.8.1, < 0.9.0',
'hugedict >= 2.12.5, < 3.0.0',
