-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BaseRandomLayer Abstract Layer and RandomWidth Preprocessing Layer #7345
BaseRandomLayer Abstract Layer and RandomWidth Preprocessing Layer #7345
Conversation
…ctins for preprocessing layers Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
…d handles randomization of width Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! I have a few questions and changes to suggest.
super(args); | ||
} | ||
|
||
protected setRNGType = (rngType: string) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be a normal function on the class's prototype instead of an arrow function.
protected setRNGType = (rngType: string) => { | |
protected setRNGType(rngType: string) { |
super(args); | ||
} | ||
|
||
protected setRNGType = (rngType: string) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the rng type be initialized by an option in the constructor? I think that's how keras does it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TFJS might not support all the keras RNG types, but that should be fine as long as we support pseudorandom and true random.
if (this.seed !== null) { | ||
this.widthFactor = randomUniform([1], | ||
(1.0 + this.widthLower), (1.0 + this.widthUpper), | ||
'float32', this.seed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this.seed
is set and never changed, this code will always return the same value. There is no automatic advancing of seeds built in to randomUniform
because randomUniform
is stateless. You may need to keep track of, and increment, a seed manually.
this.widthFactor = randomUniform([1], | ||
(1.0 + this.widthLower), (1.0 + this.widthUpper) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The if statement checking if seed === null
is probably not necessary. You can pass null
or undefined
as the seed and it will behave the same as if you had not passed anything.
if(args.seed) { | ||
this.seed = args.seed; | ||
} else { | ||
this.seed = null; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This if statement prevents people from using the seed 0
.
if(args.seed) { | |
this.seed = args.seed; | |
} else { | |
this.seed = null; | |
} | |
this.seed = args.seed; |
} | ||
this.adjustedWidth = this.widthFactor.dataSync()[0] * imgWidth; | ||
this.adjustedWidth = Math.round(this.adjustedWidth); | ||
const size: [number, number] = [this.imgHeight, this.adjustedWidth]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tiny nit: Using as const
makes TypeScript correctly infer this as const [number, number]
instead of number[]
.
const size: [number, number] = [this.imgHeight, this.adjustedWidth]; | |
const size = [this.imgHeight, this.adjustedWidth] as const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
resizeBilinear and resizeNearestNeighbor expect a mutable array of two numbers. using as const
assigns the readonly type
const rangeTensor = range(0, 16); | ||
const inputTensor = reshape(rangeTensor, [4,4,1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor nit since rangeTensor
only seems to be used once.
const rangeTensor = range(0, 16); | |
const inputTensor = reshape(rangeTensor, [4,4,1]); | |
const inputTensor = range(0, 16).reshape([4, 4, 1]); |
export declare interface BaseRandomLayerArgs extends LayerArgs {} | ||
|
||
export type RNGTypes = { | ||
[key: string]: Function; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
is a very wide type here. Can it be narrowed to be more specific?
private readonly rngTypes: RNGTypes = { | ||
gamma: randomGamma, | ||
normal: randomNormal, | ||
standardNormal: randomStandardNormal, | ||
uniform: randomUniform | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this constant is not specific to this class, it can probably be moved outside of the class. If that's the case, you might also be able to base the RNGTypes
type off of this constant.
const RNG_TYPES = {...};
type RNGTypes = typeof RNG_TYPES;
/** @nocollapse */ | ||
static className = 'RandomWidth'; | ||
randomGenerator: Function; | ||
private rngType: string; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can reference RNGTypes to be more specific (I know they're both the string
type right now, but that may change if the other change I mentioned can be applied).
private rngType: keyof RNGTypes
Thank you @mattsoulanille for reviewing our pull request. We are getting right on the changes! |
Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
…ss, got rid of rng_types as all random distribution functions seem to be stateless Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
…nd stores random distribution functions as properties Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
…mLayer Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
Co-authored-by: Silvia Kocsis <[email protected]> Co-authored-by: Natalie Umanzor <[email protected]>
We changed the abstract class layer We previously had rng_types handling which random distribution would be selected in the preprocessing layers. We removed these - how we had them defined didn't appropriately relate to keras. As you said, in keras, the rng_types were there to determine whether the op was stateful or stateless - we believe all the random distribution functions to be stateless, so we didn't incorporate rng_types when creating an instance of Some conditionals were removed / rewritten based on your suggestions. Thanks again @mattsoulanille |
tfjs-presubmit (learnjs-174218) failed. The details are not visible for us. Can you please let us know what is the problem? @mattsoulanille |
Sorry, I think I gave some bad suggestions in my review.
Looks like this Safari version does not support You should be able to access the build logs if you join the discussion or announcement mailing list Thanks for making the changes! I'll take a look and review them tomorrow. |
Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
Hey @mattsoulanille hope you are doing well. Just wanted to check in and see how things are looking with our latest changes to the PR. Thanks again! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't get to this sooner.
import { randomGamma, randomNormal} from '@tensorflow/tfjs-core'; | ||
import { randomStandardNormal, randomUniform } from '@tensorflow/tfjs-core'; | ||
|
||
type randomGammaType = typeof randomGamma; | ||
type randomNormalType = typeof randomNormal; | ||
type randomStandardNormalType = typeof randomStandardNormal; | ||
type randomUniformType = typeof randomUniform; | ||
|
||
export class RandomGenerator { | ||
/** @nocollapse */ | ||
static className = 'RandomGenerator'; | ||
protected currentSeed: number; | ||
private readonly seed: number; | ||
randomGamma: randomGammaType; | ||
randomNormal: randomNormalType; | ||
randomStandardNormal: randomStandardNormalType; | ||
randomUniform: randomUniformType; | ||
|
||
constructor(seed: number) { | ||
this.seed = seed; | ||
this.currentSeed = seed; | ||
this.randomGamma = randomGamma; | ||
this.randomNormal = randomNormal; | ||
this.randomStandardNormal = randomStandardNormal; | ||
this.randomUniform = randomUniform; | ||
} | ||
|
||
next(): number | null { | ||
if (typeof this.seed === 'number'){ | ||
return ++this.currentSeed; | ||
} | ||
return null; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need to store the random sampler functions on this class, since they're never called by the class. When you need them somewhere else, you can import them. This class should just keep track of the seed.
import { randomGamma, randomNormal} from '@tensorflow/tfjs-core'; | |
import { randomStandardNormal, randomUniform } from '@tensorflow/tfjs-core'; | |
type randomGammaType = typeof randomGamma; | |
type randomNormalType = typeof randomNormal; | |
type randomStandardNormalType = typeof randomStandardNormal; | |
type randomUniformType = typeof randomUniform; | |
export class RandomGenerator { | |
/** @nocollapse */ | |
static className = 'RandomGenerator'; | |
protected currentSeed: number; | |
private readonly seed: number; | |
randomGamma: randomGammaType; | |
randomNormal: randomNormalType; | |
randomStandardNormal: randomStandardNormalType; | |
randomUniform: randomUniformType; | |
constructor(seed: number) { | |
this.seed = seed; | |
this.currentSeed = seed; | |
this.randomGamma = randomGamma; | |
this.randomNormal = randomNormal; | |
this.randomStandardNormal = randomStandardNormal; | |
this.randomUniform = randomUniform; | |
} | |
next(): number | null { | |
if (typeof this.seed === 'number'){ | |
return ++this.currentSeed; | |
} | |
return null; | |
} | |
} | |
export class RandomSeed { | |
private currentSeed: number; | |
constructor(readonly seed: number) { | |
this.currentSeed = seed; | |
} | |
next(): number { | |
return ++this.currentSeed; | |
} | |
} |
switch (true) { | ||
case this.interpolation === 'bilinear': | ||
return image.resizeBilinear(inputs, size); | ||
case this.interpolation === 'nearest': | ||
return image.resizeNearestNeighbor(inputs, size); | ||
default: | ||
throw new Error(`Interpolation is ${this.interpolation} | ||
but only ${[...INTERPOLATION_METHODS]} are supported`); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
switch (true) { | |
case this.interpolation === 'bilinear': | |
return image.resizeBilinear(inputs, size); | |
case this.interpolation === 'nearest': | |
return image.resizeNearestNeighbor(inputs, size); | |
default: | |
throw new Error(`Interpolation is ${this.interpolation} | |
but only ${[...INTERPOLATION_METHODS]} are supported`); | |
switch (this.interpolation) { | |
case 'bilinear': | |
return image.resizeBilinear(inputs, size); | |
case 'nearest': | |
return image.resizeNearestNeighbor(inputs, size); | |
default: | |
throw new Error(`Interpolation is ${this.interpolation} | |
but only ${[...INTERPOLATION_METHODS]} are supported`); |
static override className = 'RandomWidth'; | ||
private readonly factor: number | [number, number]; | ||
private readonly interpolation?: InterpolationType; // defualt = 'bilinear | ||
private seed?: number; // default null |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The RandomGenerator / RandomSeed should track this seed for you.
private seed?: number; // default null |
'float32', this.seed | ||
); | ||
|
||
this.seed = this.randomGenerator.next(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When saving this layer, should we save the current seed or the original seed we initialized it with?
this.imgHeight = inputShape[inputShape.length - 3]; | ||
const imgWidth = inputShape[inputShape.length - 2]; | ||
|
||
this.widthFactor = this.randomGenerator.randomUniform([1], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
randomUniform
is stateless, so you can just import and call it directly. Same with the other random functions. You don't need to put them on the RandomGenerator class.
this.widthFactor = this.randomGenerator.randomUniform([1], | |
this.widthFactor = randomUniform([1], |
'float32', this.seed | ||
); | ||
|
||
this.seed = this.randomGenerator.next(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try to avoid duplicating state where possible. The RandomGenerator / RandomSeed class already stores the seed. If you need to check the seed without incrementing it, make it public on the class or add a getter.
this.seed = this.randomGenerator.next(); |
`); | ||
} | ||
|
||
this.seed = seed; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this.seed = seed; |
|
||
return tidy(() => { | ||
const input = getExactlyOneTensor(inputs); | ||
const inputShape = input.shape; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just use input.shape
directly. No need to assign a local inputShape
variable.
const config: serialization.ConfigDict = { | ||
'factor': this.factor, | ||
'interpolation': this.interpolation, | ||
'seed': this.seed, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BaseRandomLayer
should implement its own getConfig
that saves the seed
.
'seed': this.seed, |
No worries @mattsoulanille ! Thanks for the review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will take a closer look after @mattsoulanille 's comments and suggestions are resolved.
factor: number | [number, number]; | ||
interpolation?: InterpolationType; // default = 'bilinear'; | ||
seed?: number;// default = false; | ||
autoVectorize?:boolean; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Space between ':' and 'boolean'. You cloud run clang-format on all your files.
* tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5', | ||
* 'gaussian', 'mitchellcubic' | ||
* | ||
*/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the comment for the class? If so please put it on top of the class.
Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
We cleaned up the We took the We set up a Thank you @mattsoulanille @chunnienc ! |
Hey @mattsoulanille wanted to check in on the PR to see if you had any updates. Thanks! |
protected randomGenerator: RandomSeed; | ||
private seed?: number; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The RandomSeed
class already stores the value of seed
, so you don't need to store it again here.
protected randomGenerator: RandomSeed; | |
private seed?: number; | |
protected randomSeed: RandomSeed; |
this.seed = args.seed; | ||
this.randomGenerator = new RandomSeed(this.seed); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this.seed = args.seed; | |
this.randomGenerator = new RandomSeed(this.seed); | |
this.randomSeed = new RandomSeed(args.seed); |
|
||
override getConfig(): serialization.ConfigDict { | ||
const config: serialization.ConfigDict = { | ||
'seed': this.seed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'seed': this.seed | |
'seed': this.randomSeed.seed, |
|
||
constructor(args: BaseRandomLayerArgs) { | ||
super(args); | ||
this.randomGenerator = new RandomGenerator(args.seed); | ||
this.seed = args.seed; | ||
this.randomGenerator = new RandomSeed(this.seed); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
args.seed has type number | undefined
. I think there needs to be a check for undefined
here before constructing the RandomSeed
(or make seed
optional on RandomSeed
).
What is the behavior if no seed is passed? Is it true randomness, or does it just choose a seed randomly, or does it choose the same seed every time? If it's the second of these, then I suggest allowing and undefined seed (or no seed) to be passed to RandomSeed
. Then, you can have RandomSeed
generate a seed randomly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The utility classes for the random sampler classes use Math.random()
to assign the seed if none exists. Here's the example from randomUniform
's UniformRandom
utility class
'float32', this.randomGenerator.currentSeed | ||
); | ||
|
||
this.randomGenerator.next(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this.randomGenerator.next()
increments the seed and then returns the new seed, so you can use its return value instead of incrementing it separately.
'float32', this.randomGenerator.currentSeed | |
); | |
this.randomGenerator.next(); | |
'float32', this.randomGenerator.next()); |
If you want to use the current seed before incrementing it, change the implementation of RandomGenerator.next
to return seed++
(increments after getting the value) instead of ++seed
(increments before getting the value).
…, and added test for RandomSeed Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
Co-authored-by: Natalie Umanzor <[email protected]> Co-authored-by: Ryan Wallace <[email protected]>
Hey @mattsoulanille we removed the seed from |
Hey @mattsoulanille - wanted to reach out to see how things are looking with our latest PR. Thank you! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I'm sorry it took me so long to get to this.
seed: number | undefined; | ||
constructor(seed: number | undefined) { | ||
this.seed = seed; | ||
} | ||
next() { | ||
++this.currentSeed; | ||
next(): number | undefined { | ||
if (this.seed === undefined) { | ||
return undefined; | ||
} | ||
return this.seed++; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to other reviewer: When no seed is specified, the keras implementation does not choose a random seed. It just generates random values. Saving it yields something like this, with the seed set to Null
:
{'name': 'random_width_3', 'trainable': True, 'dtype': 'float32', 'factor': (-0.2, 0.3), 'interpolation': 'bilinear', 'seed': None}
This is truly random, and there is no reproducibility here.
On the other hand, when the user sets the seed, it gets saved to the config.
{'name': 'random_width_2', 'trainable': True, 'dtype': 'float32', 'factor': (-0.2, 0.3), 'interpolation': 'bilinear', 'seed': 1}
@mattsoulanille Thanks for all the help on this! |
@chunnienc Please take a look when you get a chance. Thanks! |
BaseRandomLayer Abstract Layer and RandomWidth Preprocessing Layer
Co-authored-by:
Natalie Umanzor (@numanzor) [email protected]
Silvia Kocsis (@Silvia42) [email protected]
Ryan Wallace (@RWallie) [email protected]
This change is