feature(xjx): new style dist version, add storage loader and model loader (#425)

* Add singleton log writer

* Use get_instance on writer

* feature(nyz): polish atari ddp demo and add dist demo

* Refactor dist version

* Wrap class based middleware

* Change if condition in wrapper

* Only run enhancer on learner

* Support new parallel mode on slurm cluster

* Temp data loader

* Stash commit

* Init data serializer

* Update dump part of code

* Test StorageLoader

* Turn data serializer into storage loader, add storage loader in context exchanger

* Add local id and startup interval

* Fix storage loader

* Support treetensor

* Add role on event name in context exchanger, use share_memory function on tensor

* Double size buffer

* Copy tensor to cpu, skip wait for context on collector and evaluator

* Remove data loader middleware

* Upgrade k8s parser

* Add epoch timer

* Dont use lb

* Change tensor to numpy

* Remove files when stop storage loader

* Discard shared object

* Ensure correct load shm memory

* Add model loader

* Rename model_exchanger to ModelExchanger

* Add model loader benchmark

* Shutdown loaders when task finish

* Upgrade supervisor

* Dont cleanup files when shutting down

* Fix async cleanup in model loader

* Check model loader on dqn

* Dont use loader in dqn example

* Fix style check

* Fix dp

* Fix github tests

* Skip github ci

* Fix bug in event loop

* Fix enhancer tests, move router from start to __init__

* Change default ttl

* Add comments

Co-authored-by: niuyazhe <[email protected]>
sailxjx and PaParaZz1 authored Sep 8, 2022
1 parent 3f34393 commit 0206137
Showing 48 changed files with 1,946 additions and 483 deletions.
4 changes: 4 additions & 0 deletions ding/data/__init__.py
@@ -1,3 +1,7 @@
from torch.utils.data import Dataset, DataLoader
from ding.utils.data import create_dataset, offline_data_save_type # for compatibility
from .buffer import *
from .storage import *
from .storage_loader import StorageLoader, FileStorageLoader
from .shm_buffer import ShmBufferContainer, ShmBuffer
from .model_loader import ModelLoader, FileModelLoader
File renamed without changes.
155 changes: 155 additions & 0 deletions ding/data/model_loader.py
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker():

def __init__(self, model: torch.nn.Module) -> None:
self._model = model

def save(self, storage: Storage) -> Storage:
storage.save(self._model.state_dict())
return storage


class ModelLoader(Supervisor, ABC):

def __init__(self, model: torch.nn.Module) -> None:
"""
Overview:
Save and send models asynchronously and load them synchronously.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
"""
if next(model.parameters()).is_cuda:
super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
else:
super().__init__(type_=ChildType.PROCESS)
self._model = model
self._send_callback_loop = None
self._send_callbacks = {}
self._model_worker = ModelWorker(self._model)

def start(self):
if not self._running:
self._model.share_memory()
self.register(self._model_worker)
self.start_link()
self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
self._send_callback_loop.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._send_callback_loop = None
self._send_callbacks = {}

def _loop_send_callback(self):
while True:
payload = self.recv(ignore_err=True)
if payload.err:
logging.warning("Got error when loading data: {}".format(payload.err))
if payload.req_id in self._send_callbacks:
del self._send_callbacks[payload.req_id]
else:
if payload.req_id in self._send_callbacks:
callback = self._send_callbacks.pop(payload.req_id)
callback(payload.data)

def load(self, storage: Storage) -> object:
"""
Overview:
Load model synchronously.
Arguments:
- storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
Returns:
- object (:obj:`object`): The loaded model state dict.
"""
return storage.load()

@abstractmethod
def save(self, callback: Callable) -> Storage:
"""
Overview:
Save model asynchronously.
Arguments:
- callback (:obj:`Callable`): The callback function after saving model.
Returns:
- storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
"""
raise NotImplementedError


class FileModelLoader(ModelLoader):

def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
"""
Overview:
Model loader using files as storage media.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
- dirname (:obj:`str`): The directory for saving files.
- ttl (:obj:`int`): Files are automatically cleaned up after ttl seconds. Note that \
files that have not timed out when the process stops are not cleaned up \
(to avoid errors when other processes may still be reading them), so you may need to \
clean up the remaining files manually.
"""
super().__init__(model)
self._dirname = dirname
self._ttl = ttl
self._files = []
self._cleanup_thread = None

def _start_cleanup(self):
"""
Overview:
Start a cleanup thread to remove files that have been on disk longer than the ttl.
"""
if self._cleanup_thread is None:
self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
self._cleanup_thread.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._cleanup_thread = None

def _loop_cleanup(self):
while True:
if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
sleep(1)
continue
_, file_path = self._files.pop(0)
if path.exists(file_path):
os.remove(file_path)

def save(self, callback: Callable) -> FileModelStorage:
if not self._running:
logging.warning("Please start model loader before saving model.")
return
if not path.exists(self._dirname):
os.mkdir(self._dirname)
file_path = "model_{}.pth.tar".format(uuid.uuid1())
file_path = path.join(self._dirname, file_path)
model_storage = FileModelStorage(file_path)
payload = SendPayload(proc_id=0, method="save", args=[model_storage])
self.send(payload)

def clean_callback(storage: Storage):
self._files.append([time(), file_path])
callback(storage)

self._send_callbacks[payload.req_id] = clean_callback
self._start_cleanup()
return model_storage
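
For reference, a minimal usage sketch of FileModelLoader built only from the public methods shown above; the model, directory, and callback below are illustrative placeholders, not part of this commit.

```python
# Illustrative sketch: the model, dirname and callback are placeholders.
import tempfile

import torch

from ding.data.model_loader import FileModelLoader

model = torch.nn.Linear(8, 2)
loader = FileModelLoader(model=model, dirname=tempfile.mkdtemp(), ttl=20)
loader.start()  # spawns the worker process and the send-callback thread


def on_saved(storage):
    # Called asynchronously once the worker has written the state dict to disk.
    state_dict = loader.load(storage)  # synchronous read back from file storage
    model.load_state_dict(state_dict)


storage = loader.save(on_saved)  # the FileModelStorage handle is returned immediately
# ... when the task finishes:
loader.shutdown()
```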
133 changes: 133 additions & 0 deletions ding/data/shm_buffer.py
@@ -0,0 +1,133 @@
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch

_NTYPE_TO_CTYPE = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}


class ShmBuffer():
"""
Overview:
Shared memory buffer to store numpy array.
"""

def __init__(
self,
dtype: Union[type, np.dtype],
shape: Tuple[int],
copy_on_get: bool = True,
ctype: Optional[type] = None
) -> None:
"""
Overview:
Initialize the buffer.
Arguments:
- dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
- shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
- ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
"""
if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype
dtype = dtype.type
self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
self.copy_on_get = copy_on_get
self.ctype = ctype

def fill(self, src_arr: np.ndarray) -> None:
"""
Overview:
Fill the shared memory buffer with a numpy array. (Replace the original one.)
Arguments:
- src_arr (:obj:`np.ndarray`): array to fill the buffer.
"""
assert isinstance(src_arr, np.ndarray), type(src_arr)
# for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
# for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
# so we reshape dst_arr rather than flatten src_arr
dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
np.copyto(dst_arr, src_arr)

def get(self) -> np.ndarray:
"""
Overview:
Get the array stored in the buffer.
Return:
- data (:obj:`np.ndarray`): A copy of the data stored in the buffer.
"""
data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
if self.copy_on_get:
data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
if self.ctype is torch.Tensor:
data = torch.from_numpy(data)
return data


class ShmBufferContainer(object):
"""
Overview:
Support multiple shared memory buffers, organized as a dict mapping each name (key) to a buffer (value).
"""

def __init__(
self,
dtype: Union[Dict[Any, type], type, np.dtype],
shape: Union[Dict[Any, tuple], tuple],
copy_on_get: bool = True
) -> None:
"""
Overview:
Initialize the buffer container.
Arguments:
- dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data (or a dict of dtypes, one per buffer) to limit the size of the buffer.
- shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
multiple buffers; If `tuple`, use single buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
"""
if isinstance(shape, dict):
self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
elif isinstance(shape, (tuple, list)):
self._data = ShmBuffer(dtype, shape, copy_on_get)
else:
raise RuntimeError("not support shape: {}".format(shape))
self._shape = shape

def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
"""
Overview:
Fill one or many shared memory buffers.
Arguments:
- src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
"""
if isinstance(self._shape, dict):
for k in self._shape.keys():
self._data[k].fill(src_arr[k])
elif isinstance(self._shape, (tuple, list)):
self._data.fill(src_arr)

def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
"""
Overview:
Get one or many arrays stored in the buffer.
Return:
- data (:obj:`np.ndarray`): The array(s) stored in the buffer.
"""
if isinstance(self._shape, dict):
return {k: self._data[k].get() for k in self._shape.keys()}
elif isinstance(self._shape, (tuple, list)):
return self._data.get()
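
A short round-trip sketch of ShmBuffer and ShmBufferContainer; the shapes, dtypes, and keys are arbitrary examples chosen for illustration.

```python
# Illustrative round trip; shapes, dtypes and keys are arbitrary.
import numpy as np

from ding.data.shm_buffer import ShmBuffer, ShmBufferContainer

# Single buffer holding a (4, 84, 84) float32 observation stack.
buf = ShmBuffer(dtype=np.float32, shape=(4, 84, 84), copy_on_get=True)
buf.fill(np.zeros((4, 84, 84), dtype=np.float32))
obs = buf.get()  # a copy, because copy_on_get=True

# Container managing one buffer per field, keyed by name.
container = ShmBufferContainer(
    dtype={'obs': np.float32, 'action': np.int64},
    shape={'obs': (4, 84, 84), 'action': (1, )},
)
container.fill({
    'obs': np.ones((4, 84, 84), dtype=np.float32),
    'action': np.array([2], dtype=np.int64),
})
sample = container.get()  # {'obs': np.ndarray, 'action': np.ndarray}
```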
2 changes: 2 additions & 0 deletions ding/data/storage/__init__.py
@@ -0,0 +1,2 @@
from .storage import Storage
from .file import FileStorage, FileModelStorage
25 changes: 25 additions & 0 deletions ding/data/storage/file.py
@@ -0,0 +1,25 @@
from typing import Any
from ding.data.storage import Storage
import pickle

from ding.utils.file_helper import read_file, save_file


class FileStorage(Storage):

def save(self, data: Any) -> None:
with open(self.path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load(self) -> Any:
with open(self.path, "rb") as f:
return pickle.load(f)


class FileModelStorage(Storage):

def save(self, state_dict: object) -> None:
save_file(self.path, state_dict)

def load(self) -> object:
return read_file(self.path)
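
A brief sketch of how FileModelStorage can round-trip a state dict (the path and model are placeholders); FileStorage behaves the same way but pickles arbitrary Python objects.

```python
# Illustrative only: checkpoint path and model are placeholders.
import torch

from ding.data.storage import FileModelStorage

model = torch.nn.Linear(4, 2)
storage = FileModelStorage(path="./model_example.pth.tar")
storage.save(model.state_dict())       # written via ding.utils.file_helper.save_file
model.load_state_dict(storage.load())  # read back via read_file
```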
16 changes: 16 additions & 0 deletions ding/data/storage/storage.py
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any


class Storage(ABC):

def __init__(self, path: str) -> None:
self.path = path

@abstractmethod
def save(self, data: Any) -> None:
raise NotImplementedError

@abstractmethod
def load(self) -> Any:
raise NotImplementedError
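
To illustrate the contract, a hypothetical in-memory subclass (not part of this commit): a backend only needs to implement save and load against the path handle.

```python
# Hypothetical example, not part of this commit: an in-memory Storage backend.
from typing import Any

from ding.data.storage import Storage


class DictStorage(Storage):
    _store = {}  # class-level registry keyed by path

    def save(self, data: Any) -> None:
        DictStorage._store[self.path] = data

    def load(self) -> Any:
        return DictStorage._store[self.path]


s = DictStorage("demo-key")
s.save({"step": 1})
assert s.load() == {"step": 1}
```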
18 changes: 18 additions & 0 deletions ding/data/storage/tests/test_storage.py
@@ -0,0 +1,18 @@
import tempfile
import pytest
import os
from os import path
from ding.data.storage import FileStorage


@pytest.mark.unittest
def test_file_storage():
path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
try:
storage = FileStorage(path=path_)
storage.save("test")
content = storage.load()
assert content == "test"
finally:
if path.exists(path_):
os.remove(path_)