feature(xjx): new style dist version, add storage loader and model loader (#425)

* Add singleton log writer
* Use get_instance on writer
* feature(nyz): polish atari ddp demo and add dist demo
* Refactor dist version
* Wrap class based middleware
* Change if condition in wrapper
* Only run enhancer on learner
* Support new parallel mode on slurm cluster
* Temp data loader
* Stash commit
* Init data serializer
* Update dump part of code
* Test StorageLoader
* Turn data serializer into storage loader, add storage loader in context exchanger
* Add local id and startup interval
* Fix storage loader
* Support treetensor
* Add role on event name in context exchanger, use share_memory function on tensor
* Double size buffer
* Copy tensor to cpu, skip wait for context on collector and evaluator
* Remove data loader middleware
* Upgrade k8s parser
* Add epoch timer
* Dont use lb
* Change tensor to numpy
* Remove files when stop storage loader
* Discard shared object
* Ensure correct load shm memory
* Add model loader
* Rename model_exchanger to ModelExchanger
* Add model loader benchmark
* Shutdown loaders when task finish
* Upgrade supervisor
* Dont cleanup files when shutting down
* Fix async cleanup in model loader
* Check model loader on dqn
* Dont use loader in dqn example
* Fix style check
* Fix dp
* Fix github tests
* Skip github ci
* Fix bug in event loop
* Fix enhancer tests, move router from start to __init__
* Change default ttl
* Add comments

Co-authored-by: niuyazhe <[email protected]>
Showing 48 changed files with 1,946 additions and 483 deletions.
ding/data/__init__.py

@@ -1,3 +1,7 @@
from torch.utils.data import Dataset, DataLoader
from ding.utils.data import create_dataset, offline_data_save_type  # for compatibility
from .buffer import *
from .storage import *
from .storage_loader import StorageLoader, FileStorageLoader
from .shm_buffer import ShmBufferContainer, ShmBuffer
from .model_loader import ModelLoader, FileModelLoader
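With these re-exports in place, downstream code can import the new components straight from the package root. A minimal sketch (names taken from the lines above):

from ding.data import StorageLoader, FileStorageLoader
from ding.data import ShmBufferContainer, ShmBuffer
from ding.data import ModelLoader, FileModelLoader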
File renamed without changes.
ding/data/model_loader.py

@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker():

    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    def save(self, storage: Storage) -> Storage:
        storage.save(self._model.state_dict())
        return storage


class ModelLoader(Supervisor, ABC):

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        if next(model.parameters()).is_cuda:
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        self._send_callback_loop = None
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        if not self._running:
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                if payload.req_id in self._send_callbacks:
                    del self._send_callbacks[payload.req_id]
            else:
                if payload.req_id in self._send_callbacks:
                    callback = self._send_callbacks.pop(payload.req_id)
                    callback(payload.data)

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load a model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:`object`): The loaded model state dict.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save a model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function invoked after the model has been saved.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError


class FileModelLoader(ModelLoader):

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as the storage medium.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files are automatically cleaned after ttl seconds. Note that \
                files that have not timed out when the process stops are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually.
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to remove files that have exceeded their ttl on disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._cleanup_thread = None

    def _loop_cleanup(self):
        while True:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)

    def save(self, callback: Callable) -> FileModelStorage:
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        if not path.exists(self._dirname):
            os.mkdir(self._dirname)
        file_path = "model_{}.pth.tar".format(uuid.uuid1())
        file_path = path.join(self._dirname, file_path)
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            self._files.append([time(), file_path])
            callback(storage)

        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage
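A minimal usage sketch, assuming a small CPU-only model and a scratch directory (both hypothetical). Since save is asynchronous, the sketch waits briefly before shutting down:

import time

import torch
from ding.data import FileModelLoader

model = torch.nn.Linear(4, 2)  # hypothetical model for illustration
loader = FileModelLoader(model=model, dirname="./tmp_models", ttl=20)
loader.start()

def on_saved(storage):
    # Invoked from the callback thread once the worker process has
    # finished writing the state dict; loading is synchronous.
    state_dict = loader.load(storage)
    model.load_state_dict(state_dict)

storage = loader.save(on_saved)  # returns a FileModelStorage immediately
time.sleep(1)  # give the async save and callback a moment to complete
loader.shutdown()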
ding/data/shm_buffer.py

@@ -0,0 +1,133 @@
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch

_NTYPE_TO_CTYPE = {
    np.bool_: ctypes.c_bool,
    np.uint8: ctypes.c_uint8,
    np.uint16: ctypes.c_uint16,
    np.uint32: ctypes.c_uint32,
    np.uint64: ctypes.c_uint64,
    np.int8: ctypes.c_int8,
    np.int16: ctypes.c_int16,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
}


class ShmBuffer():
    """
    Overview:
        Shared memory buffer to store numpy arrays.
    """

    def __init__(
            self,
            dtype: Union[type, np.dtype],
            shape: Tuple[int],
            copy_on_get: bool = True,
            ctype: Optional[type] = None
    ) -> None:
        """
        Overview:
            Initialize the buffer.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data, which bounds the size of the buffer.
            - shape (:obj:`Tuple[int]`): The shape of the data, which bounds the size of the buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling the get method.
            - ctype (:obj:`Optional[type]`): Original class type, e.g. np.ndarray, torch.Tensor.
        """
        if isinstance(dtype, np.dtype):  # e.g. the dtype attribute of a gym.spaces space
            dtype = dtype.type
        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape
        self.copy_on_get = copy_on_get
        self.ctype = ctype

    def fill(self, src_arr: np.ndarray) -> None:
        """
        Overview:
            Fill the shared memory buffer with a numpy array (replacing the original one).
        Arguments:
            - src_arr (:obj:`np.ndarray`): The array to fill the buffer with.
        """
        assert isinstance(src_arr, np.ndarray), type(src_arr)
        # For an np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten;
        # for an np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten.
        # So we reshape dst_arr rather than flatten src_arr.
        dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        np.copyto(dst_arr, src_arr)

    def get(self) -> np.ndarray:
        """
        Overview:
            Get the array stored in the buffer.
        Returns:
            - data (:obj:`np.ndarray`): A copy of the data stored in the buffer.
        """
        data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        if self.copy_on_get:
            data = data.copy()  # must use np.copy; torch.from_numpy and torch.as_tensor still share the same memory
        if self.ctype is torch.Tensor:
            data = torch.from_numpy(data)
        return data


class ShmBufferContainer(object):
    """
    Overview:
        Support multiple shared memory buffers. Each key-value pair is a name-buffer pair.
    """

    def __init__(
            self,
            dtype: Union[Dict[Any, type], type, np.dtype],
            shape: Union[Dict[Any, tuple], tuple],
            copy_on_get: bool = True
    ) -> None:
        """
        Overview:
            Initialize the buffer container.
        Arguments:
            - dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data, which bounds the \
                size of each buffer.
            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
                multiple buffers; if `tuple`, use a single buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling the get method.
        """
        if isinstance(shape, dict):
            self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
        elif isinstance(shape, (tuple, list)):
            self._data = ShmBuffer(dtype, shape, copy_on_get)
        else:
            raise RuntimeError("Unsupported shape: {}".format(shape))
        self._shape = shape

    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
        """
        Overview:
            Fill one or many shared memory buffers.
        Arguments:
            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) to fill the buffer(s) with.
        """
        if isinstance(self._shape, dict):
            for k in self._shape.keys():
                self._data[k].fill(src_arr[k])
        elif isinstance(self._shape, (tuple, list)):
            self._data.fill(src_arr)

    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
        """
        Overview:
            Get one or many arrays stored in the buffer(s).
        Returns:
            - data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) stored in the buffer(s).
        """
        if isinstance(self._shape, dict):
            return {k: self._data[k].get() for k in self._shape.keys()}
        elif isinstance(self._shape, (tuple, list)):
            return self._data.get()
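A small sketch of how the two classes fit together; the shapes and keys below are made up for illustration:

import numpy as np
from ding.data import ShmBuffer, ShmBufferContainer

# Single buffer: one (4, 84, 84) float32 frame stack.
buf = ShmBuffer(dtype=np.float32, shape=(4, 84, 84), copy_on_get=True)
buf.fill(np.zeros((4, 84, 84), dtype=np.float32))
frames = buf.get()  # a copy, safe to mutate without touching shared memory

# Container: several named buffers behind one fill/get interface.
container = ShmBufferContainer(
    dtype={"obs": np.float32, "action": np.int64},
    shape={"obs": (4, 84, 84), "action": (1, )},
)
container.fill({
    "obs": np.ones((4, 84, 84), dtype=np.float32),
    "action": np.zeros((1, ), dtype=np.int64),
})
data = container.get()  # dict with the same keys as `shape`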
@@ -0,0 +1,2 @@ | ||
from .storage import Storage | ||
from .file import FileStorage, FileModelStorage |
@@ -0,0 +1,25 @@ | ||
from typing import Any | ||
from ding.data.storage import Storage | ||
import pickle | ||
|
||
from ding.utils.file_helper import read_file, save_file | ||
|
||
|
||
class FileStorage(Storage): | ||
|
||
def save(self, data: Any) -> None: | ||
with open(self.path, "wb") as f: | ||
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
||
def load(self) -> Any: | ||
with open(self.path, "rb") as f: | ||
return pickle.load(f) | ||
|
||
|
||
class FileModelStorage(Storage): | ||
|
||
def save(self, state_dict: object) -> None: | ||
save_file(self.path, state_dict) | ||
|
||
def load(self) -> object: | ||
return read_file(self.path) |
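A quick sketch of the pickle round trip through FileStorage (the path is hypothetical):

from ding.data.storage import FileStorage

storage = FileStorage(path="/tmp/demo_storage.pkl")
storage.save({"step": 1, "reward": 0.5})
assert storage.load() == {"step": 1, "reward": 0.5}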
@@ -0,0 +1,16 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
|
||
class Storage(ABC): | ||
|
||
def __init__(self, path: str) -> None: | ||
self.path = path | ||
|
||
@abstractmethod | ||
def save(self, data: Any) -> None: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def load(self) -> Any: | ||
raise NotImplementedError |
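Any medium can back this interface by implementing the two abstract methods. A hypothetical in-memory variant (not part of this commit), e.g. for tests that do not need a file round trip:

from typing import Any

from ding.data.storage import Storage


class InMemoryStorage(Storage):
    """Hypothetical subclass: keeps the payload in process memory."""

    def save(self, data: Any) -> None:
        self._data = data

    def load(self) -> Any:
        return self._data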
@@ -0,0 +1,18 @@ | ||
import tempfile | ||
import pytest | ||
import os | ||
from os import path | ||
from ding.data.storage import FileStorage | ||
|
||
|
||
@pytest.mark.unittest | ||
def test_file_storage(): | ||
path_ = path.join(tempfile.gettempdir(), "test_storage.txt") | ||
try: | ||
storage = FileStorage(path=path_) | ||
storage.save("test") | ||
content = storage.load() | ||
assert content == "test" | ||
finally: | ||
if path.exists(path_): | ||
os.remove(path_) |