Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Sampler/SequenceSampler/RandomSampler #26375

Merged
merged 4 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/paddle/fluid/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from . import dataloader_iter
from .dataloader_iter import *

from . import sampler
from .sampler import *

__all__ = dataset.__all__ \
+ batch_sampler.__all__ \
+ dataloader_iter.__all__
+ dataloader_iter.__all__ \
+ sampler.__all__
79 changes: 35 additions & 44 deletions python/paddle/fluid/dataloader/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from __future__ import division

import numpy as np
from .sampler import Sampler, SequenceSampler
from .dataset import Dataset, IterableDataset

__all__ = ["BatchSampler"]


class BatchSampler(object):
class BatchSampler(Sampler):
"""
A base implement of batch sampler used by `paddle.io.DataLoader`
which yield mini-batch indices(a list/tuple with length as
Expand All @@ -41,10 +42,11 @@ class BatchSampler(object):
implement or other python object which implemented
:code:`__len__` for BatchSampler to get indices as the
range of :attr:`dataset` length. Default None.
indices (list|tuple): a substitution parameter for
:attr:`dataset` either :attr:`dataset` or
:attr:`indices` should be set, give the whole
indices to sampler from directly. Default None.
sampler (Sampler): this could be a :code:`paddle.io.Sampler`
instance which implemented :code:`__iter__` to yield
sample indices. :attr:`sampler` and :attr:`dataset`
can not be set at the same time. If :attr:`sampler`
is set, :attr:`shuffle` should not be set. Default None.
shuffle(bool): whether to shuffle indices order before generating
batch indices. Default False.
batch_size(int): sample indice number in a mini-batch indices.
Expand All @@ -58,16 +60,7 @@ class BatchSampler(object):
.. code-block:: python
from paddle.io import BatchSampler, Dataset
# init with indices
bs = BatchSampler(indices=list(range(100)),
shuffle=True,
batch_size=8,
drop_last=True)
for batch_indices in bs:
print(batch_indices)
from paddle.io import RandomSampler, BatchSampler, Dataset
# init with dataset
class RandomDataset(Dataset):
Expand All @@ -90,34 +83,42 @@ def __len__(self):
for batch_indices in bs:
print(batch_indices)
# init with sampler
sampler = RandomSampler(RandomDataset(100))
bs = BatchSampler(sampler=sampler,
batch_size=8,
drop_last=True)
for batch_indices in bs:
print(batch_indices)
see `paddle.io.DataLoader`
"""

def __init__(self,
dataset=None,
indices=None,
sampler=None,
shuffle=False,
batch_size=1,
drop_last=False):
if dataset is None:
assert indices is not None, \
"either dataset or indices should be set"
assert isinstance(indices, list) or isinstance(indices, tuple), \
"indices should be a list or tuple, but got {}".format(type(indices))
self.indices = indices
self.sampler_iter = None
assert sampler is not None, \
"either dataset or sampler should be set"
assert isinstance(sampler, Sampler), \
"sampler should be a paddle.io.Sampler, but got {}".format(type(sampler))
assert not shuffle, "shuffle should be False when sampler is set"
self.sampler = sampler
else:
if isinstance(dataset, IterableDataset):
self.sampler_iter = iter(
_InfiniteIterableSampler(dataset, batch_size))
else:
self.sampler_iter = None
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))
assert isinstance(dataset, Dataset), \
"dataset should be a paddle.io.Dataset"
assert not isinstance(dataset, IterableDataset), \
"dataset should not be a paddle.io.IterableDataset"
assert sampler is None, \
"should not set both dataset and sampler"
self.sampler = SequenceSampler(dataset)

assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size)
Expand All @@ -130,15 +131,8 @@ def __init__(self,
self.drop_last = drop_last

def __iter__(self):
if self.sampler_iter:
yield next(self.sampler_iter)

if self.shuffle:
np.random.shuffle(self.indices)
_iter = iter(self.indices)

batch_indices = []
for idx in _iter:
for idx in self.sampler:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
yield batch_indices
Expand All @@ -147,10 +141,7 @@ def __iter__(self):
yield batch_indices

def __len__(self):
if self.sampler_iter:
raise RuntimeError("'{}' should not be called for IterableDataset".
format('__len__'))
num_samples = len(self.indices)
num_samples = len(self.sampler)
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size

Expand Down
232 changes: 232 additions & 0 deletions python/paddle/fluid/dataloader/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import division

import numpy as np

__all__ = ["Sampler", "SequenceSampler", "RandomSampler"]


class Sampler(object):
    """
    An abstract class to encapsulate methods and behaviors of samplers.

    All samplers used by :code:`paddle.io.BatchSampler` should be a subclass
    of :code:`paddle.io.Sampler`. Sampler subclasses should implement the
    following methods:

    :code:`__iter__`: return sample indices iterably, iterating over the
    indices of dataset elements

    :code:`__len__`: the number of samples in :attr:`data_source`

    Args:
        data_source(Dataset, optional): this could be an instance of
            :code:`paddle.io.Dataset` or any other Python object which
            implements :code:`__len__`, so that a Sampler can produce
            indices in the range of the data source length. Default None.

    Returns:
        Sampler: an iterable object for iterating over sample indices

    Raises:
        NotImplementedError: when :code:`__iter__` is called on this base
            class directly instead of on a subclass that overrides it.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset, Sampler

            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class MySampler(Sampler):
                def __init__(self, data_source):
                    self.data_source = data_source

                def __iter__(self):
                    return iter(range(len(self.data_source)))

                def __len__(self):
                    return len(self.data_source)

            sampler = MySampler(data_source=RandomDataset(100))

            for index in sampler:
                print(index)

    see `paddle.io.BatchSampler`
    see `paddle.io.DataLoader`
    """

    def __init__(self, data_source=None):
        self.data_source = data_source

    def __iter__(self):
        # Subclasses must override this to yield sample indices.
        raise NotImplementedError

    # __len__ is deliberately not defined in this base class, because a
    # length is not needed in some cases, e.g. paddle.io.IterableDataset.


class SequenceSampler(Sampler):
    """
    Sampler that walks a data source in order, producing the indices
    :code:`0, 1, 2, ..., len(data_source) - 1`.

    Args:
        data_source(Dataset): dataset to sample; this could be an
            instance of :code:`paddle.io.Dataset` or any other Python
            object which implements :code:`__len__`.

    Returns:
        Sampler: a Sampler yielding sample indices sequentially

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset, SequenceSampler

            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            sampler = SequenceSampler(data_source=RandomDataset(100))

            for index in sampler:
                print(index)

    see `paddle.io.Sampler`
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        # The sampler produces exactly one index per element.
        return len(self.data_source)

    def __iter__(self):
        # The length is read on every call, so each pass reflects the
        # current size of the data source.
        total = len(self.data_source)
        return iter(range(total))


class RandomSampler(Sampler):
    """
    Iterate samples randomly, yielding shuffled indices.

    If :attr:`replacement` is False, a shuffled permutation of the indices of
    the whole data source is yielded. If :attr:`replacement` is True,
    :attr:`num_samples` can be set to specify how many samples to draw.

    Args:
        data_source(Dataset): dataset to sample; this could be an
            instance of :code:`paddle.io.Dataset` or any other Python
            object which implements :code:`__len__`.
        replacement(bool): If False, sample the whole dataset without
            replacement. If True, draw with replacement; set
            :attr:`num_samples` for how many samples to draw. Default False.
        num_samples(int): number of samples to draw if :attr:`replacement`
            is True. When None, the length of :attr:`data_source` is used.
            Default None.
        generator(Generator): an iterable of sample indices. If set (and
            truthy), indices are taken from it instead of being drawn with
            numpy, and at most :attr:`num_samples` indices are yielded per
            pass. Default None.

    Returns:
        Sampler: a Sampler yielding sample indices randomly

    Raises:
        TypeError: if :attr:`replacement` is not a bool.
        ValueError: if :attr:`num_samples` is given while
            :attr:`replacement` is False, or if the effective number of
            samples is not a positive integer.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset, RandomSampler

            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            sampler = RandomSampler(data_source=RandomDataset(100))

            for index in sampler:
                print(index)

    see `paddle.io.Sampler`
    """

    def __init__(self,
                 data_source,
                 replacement=False,
                 num_samples=None,
                 generator=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("expect boolean value for replacement, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError(
                "num_samples should not be specified while replacement is False")

        # Validate the effective value (falls back to len(data_source)).
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer, "
                             "but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        """Number of indices yielded per pass; defaults to the data source length."""
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.generator:
            # Cap the user-supplied generator at num_samples so iteration
            # agrees with __len__ and terminates even on an infinite
            # generator. Stop right after yielding the last index so no
            # extra element is consumed from the generator.
            limit = self.num_samples
            for count, index in enumerate(self.generator, 1):
                yield index
                if count >= limit:
                    return
        elif self.replacement:
            for index in np.random.choice(
                    np.arange(n), self.num_samples, replace=True).tolist():
                yield index
        else:
            for index in np.random.choice(
                    np.arange(n), n, replace=False).tolist():
                yield index

    def __len__(self):
        return self.num_samples
Loading