Features/361 pad #572

Merged
merged 86 commits into from
Sep 22, 2020

Conversation

@lenablind (Collaborator) commented May 25, 2020

Description

Implementation of the function pad for mode "constant".
The syntax is nearly the same as for numpy, while torch.nn.functional.pad is used internally.

Syntactical differences

Although numpy uses different value keywords for the corresponding modes (specifically constant_values and end_values), I decided to use a single keyword (values) for ease of use, since only one mode can be active at a time anyway.
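
For illustration, a hypothetical call mirroring the description above (the exact heat signature may differ; values is the single keyword replacing numpy's mode-specific constant_values):

```python
import numpy as np
import heat as ht

a = np.arange(6).reshape(2, 3)
# numpy: mode-specific keyword
np.pad(a, pad_width=1, mode="constant", constant_values=7)

t = ht.zeros((2, 3))
# heat (as described in this PR): a single keyword `values` for all modes
ht.pad(t, pad_width=1, mode="constant", values=7)
```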

Also, my implementation lacks two numpy keywords, as the corresponding modes are not available in this version:

  • ‘stat_length’ used in numpy modes ‘maximum’, ‘mean’, ‘median’, ‘minimum’
  • ‘reflect_type’ used in numpy modes ‘reflect’ and ‘symmetric’.

Strategy

Hint: Torch allows only one padding value to be specified for all dimensions, whereas numpy offers the possibility to define one per dimension. Therefore, to simulate the numpy functionality while keeping the performance of torch, I decided to call torch once for each value specified in values.

Preparation

  • handle the different types of numpy shortcuts for pad_width and transform them into one torch pad tuple (-> shortcuts: see numpy docs)
  • handle the different types of numpy shortcuts for values and transform them into one tuple (value_tuple) if several values are included (-> shortcuts: see numpy docs); a sketch of the shortcut handling follows this list
  • calculate the gshape of the resulting DNDarray
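
A minimal sketch of the pad_width shortcut handling (not the PR's actual code; the helper name and exact conventions are illustrative), assuming numpy-style shortcuts as input and a flat torch pad tuple as output:

```python
import numpy as np

def to_torch_pad_tuple(pad_width, ndim):
    # Normalize the numpy shortcuts (scalar, (before, after), per-axis pairs)
    # into an (ndim, 2) layout of (before, after) pairs.
    per_axis = np.broadcast_to(np.asarray(pad_width), (ndim, 2)).tolist()
    # torch.nn.functional.pad expects a flat tuple that starts with the LAST
    # dimension: (last_before, last_after, ..., first_before, first_after).
    flat = []
    for before, after in reversed(per_axis):
        flat.extend((before, after))
    return tuple(flat)

# to_torch_pad_tuple(1, 2)                 -> (1, 1, 1, 1)
# to_torch_pad_tuple(((1, 2), (3, 4)), 2)  -> (3, 4, 1, 2)
```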

Actual Padding

CASE 0 : input tensor contains no data

  • Return the empty tensor with the adapted lshape (necessary for remapping in case of distribution and general consistency)

CASE 1 : Padding in non-split dimension or no distribution at all

  • If only one value is specified for all dimensions, pad the tensor with torch as usual; otherwise:
  • iterate through value_tuple in reverse order (since torch starts padding with the last dimension, contrary to numpy) and call the torch pad function with the corresponding value, the pad tuple and the progressively padded tensor (see the sketch after this list).
    In other words, each dimension is padded with the value specified for it in value_tuple.
    This is necessary to provide the numpy functionality (-> hint above).
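
A minimal sketch of this per-dimension padding loop (illustrative only, not the PR's code; it assumes value_tuple holds one constant per dimension in numpy order and pad_tuple is already in torch order):

```python
import torch
import torch.nn.functional as F

def pad_constant_per_dim(tensor, pad_tuple, value_tuple):
    # torch accepts only a single constant per call, so pad one dimension at a time.
    padded = tensor
    for i, value in enumerate(reversed(value_tuple)):
        # keep only the (before, after) pair of the current dimension,
        # zero out all other entries of the torch pad tuple
        dim_pad = [0] * len(pad_tuple)
        dim_pad[2 * i] = pad_tuple[2 * i]
        dim_pad[2 * i + 1] = pad_tuple[2 * i + 1]
        padded = F.pad(padded, dim_pad, mode="constant", value=value)
    return padded

# pad_constant_per_dim(torch.zeros(2, 3), (1, 1, 2, 2), (7, 9))
# pads dimension 1 with 9s and dimension 0 with 7s.
```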

CASE 2 : Padding in split dimension and function runs on more than 1 process

  • Pad only the first/last tensor portion on a node (i.e. only the beginning/end in the split dimension)
  • "Calculate" the pad_tuple for the corresponding tensor portion, i.e. the two indices which have to be set to zero in the original/undistributed pad_tuple, depending on the dimension:
    Therefore: calculate the index of the first element in the pad tuple that has to be set to zero (the following one is the second).
    The pad tuples can hereby be divided into three categories:
    • pad_beginning (first process)
    • pad_end (last process)
    • pad_middle (all other processes)
      This is only a mathematical transcription of the manner in which each tensor chunk has to be padded (a rough sketch follows this list).
  • Balance the tensor and return it
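
A rough sketch of this case (names and index computation are illustrative, not the PR's code; it assumes pad_tuple is in torch order, split is the split dimension, and more than one process is involved):

```python
def local_pad_tuple(pad_tuple, split, ndim, rank, nprocs):
    # Index of the "before" entry of the split dimension: torch pad tuples
    # address the last dimension first, so dimension `split` starts here.
    first = 2 * (ndim - 1 - split)
    local = list(pad_tuple)
    if rank == 0:                 # pad_beginning: keep "before", drop "after"
        local[first + 1] = 0
    elif rank == nprocs - 1:      # pad_end: keep "after", drop "before"
        local[first] = 0
    else:                         # pad_middle: no padding in the split dimension
        local[first] = 0
        local[first + 1] = 0
    return tuple(local)

# e.g. 3 processes, ndim=2, split=0, pad_tuple=(1, 1, 2, 2):
# rank 0 -> (1, 1, 2, 0), rank 1 -> (1, 1, 0, 0), rank 2 -> (1, 1, 0, 2)
```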

Docs numpy: https://numpy.org/devdocs/reference/generated/numpy.pad.html
Docs pytorch: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad

Issue/s resolved: #361

Changes proposed:

  • Additional modes (PyTorch itself offers far fewer than numpy)

Overview of modes (and their differences between numpy and torch)

Equivalent modes in numpy and torch

These might be implemented most easily, though there are some restrictions.
More specifically, only 3D, 4D and 5D inputs are currently supported by torch for non-constant padding. Additionally, some scalability issues might occur for these modes.
For example, padding a 9-element tensor with 'reflect' might already result in a RuntimeError.
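
A hedged illustration of these restrictions (torch behavior at the time of writing; exact error types and messages may differ between versions):

```python
import torch
import torch.nn.functional as F

t = torch.arange(9.0)                # 9-element 1D tensor
# F.pad(t, (2, 2), mode="reflect")   # fails: non-constant modes required a 3D/4D/5D input
t3 = t.reshape(1, 1, 9)              # add batch and channel dimensions
F.pad(t3, (8, 8), mode="reflect")    # works: padding size < input size
# F.pad(t3, (9, 9), mode="reflect")  # RuntimeError: padding must be smaller than the input dimension
```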

| Numpy | Torch | Description | Available dimensions (Torch) |
| --- | --- | --- | --- |
| constant | constant | Pads the input tensor boundaries with a constant value | arbitrary |
| reflect | reflect | Pads the input tensor using the reflection of the input boundary | last 2 of 4D, last of 3D |
| edge | replicate | Pads the input tensor using the replication of the input boundary | last 3 of 5D, last 2 of 4D |
| wrap | circular | Pads with the wrap of the vector along the axis; the first values are used to pad the end and the end values are used to pad the beginning | |

Numpy modes which might result in constant padding with calculated padding values

| Mode | Pads with the ... |
| --- | --- |
| linear_ramp | ... linear ramp between end_value and the array edge value |
| maximum | ... maximum value of all or part of the vector along each axis |
| mean | ... mean value of all or part of the vector along each axis |
| median | ... median value of all or part of the vector along each axis |
| minimum | ... minimum value of all or part of the vector along each axis |

The referred 'part of the vector' can furthermore be specified with the corresponding keyword
(-> see numpy docs).
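
For example, numpy computes the fill values for these modes from the array itself:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])
np.pad(a, 2, mode="mean")       # -> [2.5, 2.5, 1., 2., 3., 4., 2.5, 2.5]
np.pad(a, 2, mode="maximum")    # -> [4., 4., 1., 2., 3., 4., 4., 4.]
np.pad(a, 2, mode="linear_ramp", end_values=0)
                                # -> [0., 0.5, 1., 2., 3., 4., 2., 0.]
```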

Type of change

  • New feature (non-breaking change which adds functionality)

Due Diligence

  • All split configurations tested
  • Multiple dtypes tested in relevant functions
  • Documentation updated (if needed)
  • Updated changelog.md under the title "Pending Additions"

Does this change modify the behaviour of other functions? If so, which?

no

@ClaudiaComito (Contributor) left a comment:

Good job @lenablind, needs a few more changes!

@mtar (Collaborator) commented Aug 31, 2020

GPU cluster tests are currently disabled on this Pull Request.

@lenablind (Collaborator, Author) replied, quoting @mtar:
GPU cluster tests are currently disabled on this Pull Request.

@mtar Thank you for letting me know. Is there a reason for that or to put it differently, are these needed for this PR and if that is the case, could you explain to me why?

@mtar (Collaborator) commented Aug 31, 2020

The CI system that I was setting up recently has a life of its own 😃
It will be important in the future.

@mtar (Collaborator) commented Sep 21, 2020

ok to test

@mtar (Collaborator) commented Sep 22, 2020

rerun tests

ClaudiaComito previously approved these changes Sep 22, 2020
@coquelin77 merged commit a9efe03 into master on Sep 22, 2020
@coquelin77 deleted the features/361-pad branch on September 22, 2020 12:23
Successfully merging this pull request may close these issues:

implement pad/padding function