feature(xjx): new style dist version, add storage loader and model loader #425

Merged
merged 46 commits into from Sep 8, 2022

Changes from all commits
Commits (46)
545d3e8
Add singleton log writer
sailxjx May 16, 2022
759968d
Use get_instance on writer
sailxjx May 16, 2022
4b8e12e
feature(nyz): polish atari ddp demo and add dist demo
PaParaZz1 Jun 7, 2022
34e5154
Refactor dist version
sailxjx Jun 15, 2022
ade2231
Wrap class based middleware
sailxjx Jun 16, 2022
60e6d85
Change if condition in wrapper
sailxjx Jun 16, 2022
33b53e7
Only run enhancer on learner
sailxjx Jun 17, 2022
6d861b6
Support new parallel mode on slurm cluster
sailxjx Jun 21, 2022
8ea6559
Temp data loader
sailxjx Jun 15, 2022
10943f9
Stash commit
sailxjx Jun 21, 2022
f9c1f85
Init data serializer
sailxjx Jun 21, 2022
9c9b2e2
Update dump part of code
sailxjx Jun 22, 2022
a21270d
Test StorageLoader
sailxjx Jun 23, 2022
b57837d
Turn data serializer into storage loader, add storage loader in conte…
sailxjx Jun 28, 2022
9cd6bea
Add local id and startup interval
sailxjx Jun 28, 2022
fa3bcd8
Fix storage loader
sailxjx Jun 28, 2022
c1360fd
Support treetensor
sailxjx Jun 30, 2022
3a83646
Add role on event name in context exchanger, use share_memory functio…
sailxjx Jul 1, 2022
8bb3295
Double size buffer
sailxjx Jul 1, 2022
10bf11f
Copy tensor to cpu, skip wait for context on collector and evaluator
sailxjx Jul 1, 2022
56cf6ac
Remove data loader middleware
sailxjx Jul 7, 2022
7597787
Upgrade k8s parser
sailxjx Jul 8, 2022
56c984b
Add epoch timer
sailxjx Jul 13, 2022
2fa6aea
Dont use lb
sailxjx Jul 14, 2022
f8c1f41
Change tensor to numpy
sailxjx Jul 18, 2022
074d6ba
Remove files when stop storage loader
sailxjx Jul 19, 2022
f90472c
Discard shared object
sailxjx Jul 19, 2022
11b6f16
Ensure correct load shm memory
sailxjx Jul 20, 2022
a43948d
Add model loader
sailxjx Jul 21, 2022
ddca885
Rename model_exchanger to ModelExchanger
sailxjx Jul 21, 2022
9f5087b
Add model loader benchmark
sailxjx Jul 21, 2022
c5c2909
Shutdown loaders when task finish
sailxjx Jul 21, 2022
ed7acce
Upgrade supervisor
sailxjx Jul 21, 2022
3d71c1e
Dont cleanup files when shutting down
sailxjx Jul 21, 2022
6854e86
Fix async cleanup in model loader
sailxjx Jul 21, 2022
870d683
Check model loader on dqn
sailxjx Jul 22, 2022
7059952
Dont use loader in dqn example
sailxjx Jul 22, 2022
73f706e
Fix style check
sailxjx Jul 22, 2022
c43f8cf
Merge branch 'dev-dist' into dev-data-loader
sailxjx Jul 22, 2022
48713d7
Fix dp
sailxjx Jul 22, 2022
cacb257
Fix github tests
sailxjx Jul 22, 2022
4276ee3
Skip github ci
sailxjx Jul 25, 2022
ba5f638
Fix bug in event loop
sailxjx Jul 25, 2022
134891a
Fix enhancer tests, move router from start to __init__
sailxjx Jul 25, 2022
a51017d
Change default ttl
sailxjx Jul 26, 2022
c896504
Add comments
sailxjx Jul 26, 2022
4 changes: 4 additions & 0 deletions ding/data/__init__.py
@@ -1,3 +1,7 @@
from torch.utils.data import Dataset, DataLoader
from ding.utils.data import create_dataset, offline_data_save_type # for compatibility
from .buffer import *
from .storage import *
from .storage_loader import StorageLoader, FileStorageLoader
from .shm_buffer import ShmBufferContainer, ShmBuffer
from .model_loader import ModelLoader, FileModelLoader
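
With these exports in place, the new loaders and shared-memory helpers can be imported directly from ding.data. A minimal sketch of the resulting public API (the names are taken from the diff above):

from ding.data import StorageLoader, FileStorageLoader    # storage loader interface and file-based implementation
from ding.data import ModelLoader, FileModelLoader        # model loader interface and file-based implementation
from ding.data import ShmBufferContainer, ShmBuffer       # shared-memory buffers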
155 changes: 155 additions & 0 deletions ding/data/model_loader.py
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker():

def __init__(self, model: torch.nn.Module) -> None:
self._model = model

def save(self, storage: Storage) -> Storage:
storage.save(self._model.state_dict())
return storage


class ModelLoader(Supervisor, ABC):

def __init__(self, model: torch.nn.Module) -> None:
"""
Overview:
Save and send models asynchronously and load them synchronously.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
"""
if next(model.parameters()).is_cuda:
super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
else:
super().__init__(type_=ChildType.PROCESS)
self._model = model
self._send_callback_loop = None
self._send_callbacks = {}
self._model_worker = ModelWorker(self._model)

def start(self):
if not self._running:
self._model.share_memory()
self.register(self._model_worker)
self.start_link()
self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
self._send_callback_loop.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._send_callback_loop = None
self._send_callbacks = {}

def _loop_send_callback(self):
while True:
payload = self.recv(ignore_err=True)
if payload.err:
logging.warning("Got error when loading data: {}".format(payload.err))
if payload.req_id in self._send_callbacks:
del self._send_callbacks[payload.req_id]
else:
if payload.req_id in self._send_callbacks:
callback = self._send_callbacks.pop(payload.req_id)
callback(payload.data)

def load(self, storage: Storage) -> object:
"""
Overview:
Load model synchronously.
Arguments:
- storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
Returns:
- object (:obj:`object`): The loaded model state_dict.
"""
return storage.load()

@abstractmethod
def save(self, callback: Callable) -> Storage:
"""
Overview:
Save model asynchronously.
Arguments:
- callback (:obj:`Callable`): The callback function after saving model.
Returns:
- storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
"""
raise NotImplementedError


class FileModelLoader(ModelLoader):

def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
"""
Overview:
Model loader using files as storage media.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
- dirname (:obj:`str`): The directory for saving files.
- ttl (:obj:`int`): Files will be automatically cleaned up after ttl seconds. Note that \
files that have not timed out when the process stops are not cleaned up \
(to avoid errors when other processes may still be reading them), so you may need to \
clean up the remaining files manually.
"""
super().__init__(model)
self._dirname = dirname
self._ttl = ttl
self._files = []
self._cleanup_thread = None

def _start_cleanup(self):
"""
Overview:
Start a cleanup thread to remove files that have stayed on disk longer than the ttl.
"""
if self._cleanup_thread is None:
self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
self._cleanup_thread.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._cleanup_thread = None

def _loop_cleanup(self):
while True:
if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
sleep(1)
continue
_, file_path = self._files.pop(0)
if path.exists(file_path):
os.remove(file_path)

def save(self, callback: Callable) -> FileModelStorage:
if not self._running:
logging.warning("Please start model loader before saving model.")
return
if not path.exists(self._dirname):
os.mkdir(self._dirname)
file_path = "model_{}.pth.tar".format(uuid.uuid1())
file_path = path.join(self._dirname, file_path)
model_storage = FileModelStorage(file_path)
payload = SendPayload(proc_id=0, method="save", args=[model_storage])
self.send(payload)

def clean_callback(storage: Storage):
self._files.append([time(), file_path])
callback(storage)

self._send_callbacks[payload.req_id] = clean_callback
self._start_cleanup()
return model_storage
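
A minimal usage sketch of FileModelLoader, with a hypothetical toy model and directory name (not taken from the PR). The save call returns a storage handle immediately, while the callback fires once the worker process has actually written the state_dict:

import time
import torch.nn as nn
from ding.data import FileModelLoader

model = nn.Linear(4, 2)                                   # hypothetical toy model
loader = FileModelLoader(model, dirname="./model_cache", ttl=20)
loader.start()                                            # spawn worker process and callback thread

def on_saved(storage):
    # Invoked asynchronously after the worker has dumped the state_dict to disk;
    # any process holding the storage handle can now load it.
    state_dict = loader.load(storage)
    model.load_state_dict(state_dict)

storage = loader.save(on_saved)                           # returns a FileModelStorage right away
time.sleep(1)                                             # give the asynchronous save time to finish
loader.shutdown()                                         # files not yet past ttl stay on disk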
133 changes: 133 additions & 0 deletions ding/data/shm_buffer.py
@@ -0,0 +1,133 @@
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch

_NTYPE_TO_CTYPE = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}


class ShmBuffer():
"""
Overview:
Shared memory buffer to store numpy array.
"""

def __init__(
self,
dtype: Union[type, np.dtype],
shape: Tuple[int],
copy_on_get: bool = True,
ctype: Optional[type] = None
) -> None:
"""
Overview:
Initialize the buffer.
Arguments:
- dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
- shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
- ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
"""
if isinstance(dtype, np.dtype):  # convert np.dtype instances (e.g. from gym.spaces) to their scalar type
dtype = dtype.type
self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
self.copy_on_get = copy_on_get
self.ctype = ctype

def fill(self, src_arr: np.ndarray) -> None:
"""
Overview:
Fill the shared memory buffer with a numpy array, replacing the existing contents.
Arguments:
- src_arr (:obj:`np.ndarray`): array to fill the buffer.
"""
assert isinstance(src_arr, np.ndarray), type(src_arr)
# for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
# for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
# so we reshape dst_arr rather than flatten src_arr
dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
np.copyto(dst_arr, src_arr)

def get(self) -> np.ndarray:
"""
Overview:
Get the array stored in the buffer.
Return:
- data (:obj:`np.ndarray`): The data stored in the buffer (a copy when copy_on_get is True).
"""
data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
if self.copy_on_get:
data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
if self.ctype is torch.Tensor:
data = torch.from_numpy(data)
return data


class ShmBufferContainer(object):
"""
Overview:
Support multiple shared memory buffers. Each key-value is name-buffer.
"""

def __init__(
self,
dtype: Union[Dict[Any, type], type, np.dtype],
shape: Union[Dict[Any, tuple], tuple],
copy_on_get: bool = True
) -> None:
"""
Overview:
Initialize the buffer container.
Arguments:
- dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
- shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
multiple buffers; If `tuple`, use single buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
"""
if isinstance(shape, dict):
self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
elif isinstance(shape, (tuple, list)):
self._data = ShmBuffer(dtype, shape, copy_on_get)
else:
raise RuntimeError("unsupported shape: {}".format(shape))
self._shape = shape

def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
"""
Overview:
Fill one or many shared memory buffers.
Arguments:
- src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
"""
if isinstance(self._shape, dict):
for k in self._shape.keys():
self._data[k].fill(src_arr[k])
elif isinstance(self._shape, (tuple, list)):
self._data.fill(src_arr)

def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
"""
Overview:
Get one or many arrays stored in the buffer.
Return:
- data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) stored in the buffer.
"""
if isinstance(self._shape, dict):
return {k: self._data[k].get() for k in self._shape.keys()}
elif isinstance(self._shape, (tuple, list)):
return self._data.get()
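
A minimal sketch of how these buffers might be used (the shapes and keys below are illustrative, not taken from the PR):

import numpy as np
from ding.data import ShmBuffer, ShmBufferContainer

# Single buffer holding a (4, 84, 84) float32 frame stack.
buf = ShmBuffer(dtype=np.float32, shape=(4, 84, 84), copy_on_get=True)
buf.fill(np.random.rand(4, 84, 84).astype(np.float32))
frames = buf.get()                                   # a copy, because copy_on_get=True

# Container managing one buffer per key, configured with dicts of dtypes and shapes.
container = ShmBufferContainer(
    dtype={"obs": np.float32, "action": np.int64},
    shape={"obs": (4, 84, 84), "action": (1, )},
)
container.fill({
    "obs": np.zeros((4, 84, 84), dtype=np.float32),
    "action": np.zeros((1, ), dtype=np.int64),
})
sample = container.get()                             # {"obs": ndarray, "action": ndarray}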
2 changes: 2 additions & 0 deletions ding/data/storage/__init__.py
@@ -0,0 +1,2 @@
from .storage import Storage
from .file import FileStorage, FileModelStorage
25 changes: 25 additions & 0 deletions ding/data/storage/file.py
@@ -0,0 +1,25 @@
from typing import Any
from ding.data.storage import Storage
import pickle

from ding.utils.file_helper import read_file, save_file


class FileStorage(Storage):

def save(self, data: Any) -> None:
with open(self.path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load(self) -> Any:
with open(self.path, "rb") as f:
return pickle.load(f)


class FileModelStorage(Storage):

def save(self, state_dict: object) -> None:
save_file(self.path, state_dict)

def load(self) -> object:
return read_file(self.path)
16 changes: 16 additions & 0 deletions ding/data/storage/storage.py
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any


class Storage(ABC):

def __init__(self, path: str) -> None:
self.path = path

@abstractmethod
def save(self, data: Any) -> None:
raise NotImplementedError

@abstractmethod
def load(self) -> Any:
raise NotImplementedError
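
Because Storage only fixes the save/load contract around a path-like key, adding a new backend is a small subclass. A hypothetical in-memory variant (not part of this PR), mainly useful in tests:

from typing import Any
from ding.data.storage import Storage

_MEM = {}  # hypothetical process-local store keyed by path

class MemoryStorage(Storage):

    def save(self, data: Any) -> None:
        _MEM[self.path] = data

    def load(self) -> Any:
        return _MEM[self.path]

storage = MemoryStorage(path="run0/step100")
storage.save({"reward": 1.0})
assert storage.load() == {"reward": 1.0}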
18 changes: 18 additions & 0 deletions ding/data/storage/tests/test_storage.py
@@ -0,0 +1,18 @@
import tempfile
import pytest
import os
from os import path
from ding.data.storage import FileStorage


@pytest.mark.unittest
def test_file_storage():
path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
try:
storage = FileStorage(path=path_)
storage.save("test")
content = storage.load()
assert content == "test"
finally:
if path.exists(path_):
os.remove(path_)