Skip to content

Commit

Permalink
add ZipStorage, support loading tree from storage
Browse files Browse the repository at this point in the history
  • Loading branch information
luizirber committed Apr 22, 2020
1 parent eeed865 commit 1e0889c
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 21 deletions.
70 changes: 52 additions & 18 deletions sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ def search_transcript(node, seq, threshold):
import os
from random import randint, random
import sys
from tempfile import NamedTemporaryFile

from deprecation import deprecated

from .sbt_storage import FSStorage, TarStorage, IPFSStorage, RedisStorage
from .sbt_storage import FSStorage, TarStorage, IPFSStorage, RedisStorage, ZipStorage
from .logging import error, notify, debug
from .index import Index
from .nodegraph import Nodegraph, extract_nodegraph_info, calc_expected_collisions
Expand All @@ -68,6 +69,7 @@ def search_transcript(node, seq, threshold):
'FSStorage': FSStorage,
'IPFSStorage': IPFSStorage,
'RedisStorage': RedisStorage,
'ZipStorage': ZipStorage,
}
NodePos = namedtuple("NodePos", ["pos", "node"])

Expand Down Expand Up @@ -649,10 +651,45 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr
SBT
the SBT tree built from the description.
"""
dirname = os.path.dirname(os.path.abspath(location))
sbt_name = os.path.basename(location)
if sbt_name.endswith('.sbt.json'):
sbt_name = sbt_name[:-9]
tempfile = None
sbt_name = None

if storage is not None:
try:
tempfile = NamedTemporaryFile()
tempfile.write(storage.load('tree.sbt.json'))
tempfile.flush()

dirname = os.path.dirname(tempfile.name)
sbt_name = os.path.basename(tempfile.name)
except KeyError:
tempfile = None

if sbt_name is None:
if ZipStorage.can_open(location):
tempfile = NamedTemporaryFile()
storage = ZipStorage(location)

tempfile.write(storage.load('tree.sbt.json'))
tempfile.flush()

dirname = os.path.dirname(tempfile.name)
sbt_name = os.path.basename(tempfile.name)
elif TarStorage.can_open(location):
tempfile = NamedTemporaryFile()
storage = ZipStorage(location)

tempfile.write(storage.load('tree.sbt.json'))
tempfile.flush()

dirname = os.path.dirname(tempfile.name)
sbt_name = os.path.basename(tempfile.name)

if sbt_name is None:
dirname = os.path.dirname(os.path.abspath(location))
sbt_name = os.path.basename(location)
if sbt_name.endswith('.sbt.json'):
sbt_name = sbt_name[:-9]

loaders = {
1: cls._load_v1,
Expand All @@ -666,17 +703,26 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr
leaf_loader = Leaf.load

sbt_fn = os.path.join(dirname, sbt_name)
if not sbt_fn.endswith('.sbt.json'):
if not sbt_fn.endswith('.sbt.json') and tempfile is None:
sbt_fn += '.sbt.json'
with open(sbt_fn) as fp:
jnodes = json.load(fp)

if tempfile is not None:
tempfile.close()

version = 1
if isinstance(jnodes, Mapping):
version = jnodes['version']

if version < 3 and storage is None:
storage = FSStorage(dirname, '.sbt.{}'.format(sbt_name))
elif storage is None:
klass = STORAGES[jnodes['storage']['backend']]
if jnodes['storage']['backend'] == "FSStorage":
storage = FSStorage(dirname, jnodes['storage']['args']['path'])
elif storage is None:
storage = klass(**jnodes['storage']['args'])

return loaders[version](jnodes, leaf_loader, dirname, storage,
print_version_warning)
Expand Down Expand Up @@ -756,12 +802,6 @@ def _load_v3(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
sbt_nodes = {}
sbt_leaves = {}

klass = STORAGES[info['storage']['backend']]
if info['storage']['backend'] == "FSStorage":
storage = FSStorage(dirname, info['storage']['args']['path'])
elif storage is None:
storage = klass(**info['storage']['args'])

factory = GraphFactory(*info['factory']['args'])

max_node = 0
Expand Down Expand Up @@ -803,12 +843,6 @@ def _load_v4(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
sbt_nodes = {}
sbt_leaves = {}

klass = STORAGES[info['storage']['backend']]
if info['storage']['backend'] == "FSStorage":
storage = FSStorage(dirname, info['storage']['args']['path'])
elif storage is None:
storage = klass(**info['storage']['args'])

factory = GraphFactory(*info['factory']['args'])

max_node = 0
Expand Down
52 changes: 52 additions & 0 deletions sourmash/sbt_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from io import BytesIO
import os
import tarfile
import zipfile


class Storage(abc.ABCMeta(str('ABC'), (object,), {'__slots__': ()})):
Expand All @@ -27,6 +28,9 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
pass

@staticmethod
def can_open(location):
    """Report whether this storage backend can open ``location``.

    The base implementation always answers ``False``; backends that can
    recognise their own on-disk format override it.

    Declared as a static method so it can be called on the class itself
    (``SomeStorage.can_open(location)``) as well as on an instance, which
    matches the overrides that are declared with ``@staticmethod``.

    Parameters
    ----------
    location
        Candidate location to probe; ignored by this default implementation.

    Returns
    -------
    bool
        Always ``False``.
    """
    return False


class FSStorage(Storage):

Expand Down Expand Up @@ -96,6 +100,54 @@ def init_args(self):
def __exit__(self, type, value, traceback):
self.tarfile.close()

@staticmethod
def can_open(location):
    """Report whether ``location`` is a tar archive that ``tarfile`` can read.

    Parameters
    ----------
    location
        Path of the file to probe.

    Returns
    -------
    bool
        The answer given by ``tarfile.is_tarfile``, or ``False`` when
        probing the path raises ``IOError`` (for example because the
        file is missing).
    """
    try:
        # Return the probe result; previously it was computed and then
        # discarded, so this method answered False for every input.
        return tarfile.is_tarfile(location)
    except IOError:
        return False


class ZipStorage(Storage):
    """Storage backend that keeps every entry inside a single zip archive.

    An archive that already exists at ``path`` is opened read-only;
    otherwise a new archive is created there for writing, using bzip2
    compression.
    """

    def __init__(self, path=None):
        """Open the archive at ``path``, creating it if it does not exist.

        Parameters
        ----------
        path
            Location of the zip archive. Missing parent directories are
            created.

        Raises
        ------
        TypeError
            If ``path`` is ``None``.
        """
        # TODO: leave it open, or close/open every time?

        if path is None:
            # TODO: Open a temporary file?
            # Fail with a clear message instead of letting
            # os.path.abspath(None) raise an unrelated-looking TypeError.
            raise TypeError(
                "ZipStorage requires a path to a zip archive, got None")

        self.path = os.path.abspath(path)

        dirname = os.path.dirname(self.path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        if os.path.exists(self.path):
            # Existing archives are opened read-only.
            self.zipfile = zipfile.ZipFile(self.path, 'r')
        else:
            self.zipfile = zipfile.ZipFile(self.path, mode='w',
                                           compression=zipfile.ZIP_BZIP2)

    def save(self, path, content):
        """Write ``content`` into the archive under the member name ``path``.

        Returns the member name that was written.
        """
        self.zipfile.writestr(path, content)
        return path

    def load(self, path):
        """Return the bytes stored in the archive member named ``path``."""
        return self.zipfile.read(path)

    def init_args(self):
        """Return the keyword arguments needed to rebuild this storage."""
        return {'path': self.path}

    def __exit__(self, type, value, traceback):
        """Close the underlying archive when leaving a ``with`` block."""
        self.zipfile.close()

    @staticmethod
    def can_open(location):
        """Report whether ``location`` is recognised as a zip archive."""
        return zipfile.is_zipfile(location)


class IPFSStorage(Storage):

Expand Down
2 changes: 2 additions & 0 deletions src/core/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ impl Signature {
where
R: io::Read,
{
let (rdr, _format) = niffler::get_reader(Box::new(rdr))?;

let sigs: Vec<Signature> = serde_json::from_reader(rdr)?;
Ok(sigs)
}
Expand Down
Binary file added tests/test-data/v4.zip
Binary file not shown.
58 changes: 55 additions & 3 deletions tests/test_sbt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import print_function, unicode_literals

import shutil
import os

import pytest

from sourmash import load_one_signature, SourmashSignature
from sourmash import load_one_signature, SourmashSignature, load_signatures
from sourmash.sbt import SBT, GraphFactory, Leaf, Node
from sourmash.sbtmh import (SigLeaf, search_minhashes,
search_minhashes_containment)
from sourmash.sbt_storage import (FSStorage, TarStorage,
RedisStorage, IPFSStorage)
from sourmash.sbt_storage import (FSStorage, TarStorage, RedisStorage,
IPFSStorage, ZipStorage)

from . import sourmash_tst_utils as utils

Expand Down Expand Up @@ -442,6 +443,40 @@ def test_sbt_tarstorage():
assert old_result == new_result


def test_sbt_zipstorage():
    """Round-trip an SBT through ZipStorage and compare search results.

    A tree is built from the test signatures and searched, then saved
    into a zip archive, loaded back from it, and searched again with the
    same query; both searches must return the same set of matches.
    """
    factory = GraphFactory(31, 1e5, 4)
    with utils.TempDirectory() as location:
        zip_path = os.path.join(location, 'tree.zip')
        tree_path = os.path.join(location, 'tree')

        tree = SBT(factory)
        for sig_file in utils.SIG_FILES:
            signature = next(load_signatures(utils.get_test_data(sig_file)))
            leaf = SigLeaf(os.path.basename(sig_file), signature)
            tree.add_node(leaf)
            # the last leaf added ends up being the query
            to_search = leaf

        print('*' * 60)
        print("{}:".format(to_search.metadata))
        old_result = set(
            map(str, tree.find(search_minhashes, to_search.data, 0.1)))
        print(*old_result, sep='\n')

        with ZipStorage(zip_path) as storage:
            tree.save(tree_path, storage=storage)

        with ZipStorage(zip_path) as storage:
            tree = SBT.load(tree_path,
                            leaf_loader=SigLeaf.load,
                            storage=storage)

            print('*' * 60)
            print("{}:".format(to_search.metadata))
            new_result = set(
                map(str, tree.find(search_minhashes, to_search.data, 0.1)))
            print(*new_result, sep='\n')

            assert old_result == new_result


def test_sbt_ipfsstorage():
ipfshttpclient = pytest.importorskip('ipfshttpclient')

Expand Down Expand Up @@ -521,6 +556,23 @@ def test_sbt_redisstorage():
assert old_result == new_result


def test_load_zip(tmpdir):
    """Load an SBT directly from a zip archive and run a search on it.

    The bundled ``v4.zip`` test archive is copied into ``tmpdir`` and
    loaded from the copy; searching with the first test signature is
    expected to yield exactly two matches.
    """
    zip_copy = str(tmpdir.join("v4.zip"))
    shutil.copyfile(utils.get_test_data("v4.zip"), zip_copy)

    tree = SBT.load(zip_copy, leaf_loader=SigLeaf.load)

    first_sig = utils.get_test_data(utils.SIG_FILES[0])
    to_search = load_one_signature(first_sig)

    print("*" * 60)
    print("{}:".format(to_search))
    matches = tree.find(search_minhashes, to_search, 0.1)
    new_result = set(map(str, matches))
    print(*new_result, sep="\n")
    assert len(new_result) == 2


def test_tree_repair():
tree_repair = SBT.load(utils.get_test_data('leaves.sbt.json'),
leaf_loader=SigLeaf.load)
Expand Down

0 comments on commit 1e0889c

Please sign in to comment.