Keras - how to use class_weight with 3D data #3653

Closed
bsafacicek opened this issue Aug 31, 2016 · 66 comments

Comments

@bsafacicek

bsafacicek commented Aug 31, 2016

Hi,

I am using Keras to segment images into road and background pixels. As you can imagine, the percentage of road pixels is much lower than that of background pixels. Hence, I want to use class_weight={0: 0.05, 1: 0.95} while fitting the model so that the CNN won't predict every pixel as background. But when I do this, I get the following error:

File "/usr/local/lib/python2.7/dist-packages/keras/models.py", line 597, in fit
sample_weight=sample_weight)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1035, in fit
batch_size=batch_size)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 973, in _standardize_user_data
in zip(y, sample_weights, class_weights, self.sample_weight_modes)]
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 387, in standardize_weights
raise Exception('class_weight not supported for '
Exception: class_weight not supported for 3+ dimensional targets.

My training labels are in this form: (number_of_training_samples=10000, number_of_pixels_in_patch=16384, number_of_classes=2). How can I weight the classes in Keras?

Thanks in advance.

@fchollet
Collaborator

fchollet commented Sep 1, 2016

You should use sample_weight instead. class_weight is not supported for 3+ dimensional targets because the concept of class is ambiguous in that case.

@uschmidt83
Contributor

Hi, I have the same problem.

I don't understand how this can be accomplished by using sample_weight since every pixel of a sample requires a different weight based on its class. Or do you suggest to do this by using sample_weight_mode="temporal"?

@bsafacicek
Author

bsafacicek commented Sep 8, 2016

Hi,

I also could not work out how to use sample_weight for class weighting, because Keras requires the length of sample_weight to match the first dimension of the class labels. But the class labels also have second and third dimensions for image height and width, and to weight the classes I need to weight the individual pixel labels, not just the whole image.

Thanks.

@uschmidt83
Contributor

To follow up on this, I got it to work using sample_weight. It is quite nice if you know what you have to do. Unfortunately, the documentation is not really clear on this, presumably because this feature was originally added for time series data.

  • You need to reshape your 2D image-sized output as a vector before the loss function when you specify your model.
  • Use sample_weight_mode="temporal" when you compile the model. This will allow you to pass in a weight matrix for training where each row represents the weight vector for a single sample.
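
For illustration, here is a minimal NumPy sketch of the array shapes these two steps imply. It assumes integer label maps of shape (N, H, W); the toy sizes and the `class_weights` array (values borrowed from the original question) are assumptions, and the Keras calls appear only as comments.

```python
import numpy as np

# Hypothetical toy data: 3 label maps of 4x4 pixels, 2 classes (0 = background, 1 = road).
rng = np.random.default_rng(0)
n_samples, height, width, n_classes = 3, 4, 4, 2
labels = rng.integers(0, n_classes, size=(n_samples, height, width))

# Illustrative class weights (index = class id).
class_weights = np.array([0.05, 0.95])

# Targets: one-hot encode, then flatten the spatial axes -> (N, H*W, n_classes).
y = np.eye(n_classes)[labels].reshape(n_samples, height * width, n_classes)

# Weights: one value per pixel, flattened the same way -> (N, H*W).
sample_weight = class_weights[labels].reshape(n_samples, height * width)

print(y.shape, sample_weight.shape)

# The model output would be flattened the same way, e.g. with
#   Reshape((height * width, n_classes)) before the softmax activation,
# and training would then use
#   model.compile(..., sample_weight_mode="temporal")
#   model.fit(x, y, sample_weight=sample_weight)
```

Each row of `sample_weight` is the weight vector for one image, which is what the "temporal" mode expects in place of one scalar per sample.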

I hope that helps.

@rdelassus

rdelassus commented Jan 4, 2017

Hey @KKOG, I got exactly the same issue, did you find any solution?

@rdelassus

@uschmidt83
you said "This will allow you to pass in a weight matrix for training where each row represents the weight vector for a single sample."

But this is not very clear. How did you build your weight vector? Say I have 3 classes: will my weight vector have one entry per pixel in my image, with values being weight_0, weight_1 and weight_2? That seems like a waste of space, but maybe I'm wrong?

@uschmidt83
Contributor

Hi @rdelassus,

"Say I have 3 classes: will my weight vector have one entry per pixel in my image, with values being weight_0, weight_1 and weight_2? That seems like a waste of space, but maybe I'm wrong?"

It seems like a waste of space for your particular use case, although I doubt that this actually matters much in practice. However, it also allows much more fine-grained control, which is probably crucial for other applications/models.

Sorry for the late reply.

@trogloditee

trogloditee commented May 11, 2017

@uschmidt83 I'm having trouble making this work and wonder if you have an insight.

I have 4 classes in a semantic segmentation task, and my class weights are

class_weights = {0: 0.41, 1: 1.87, 2: 1.1, 3: 7.05}

When I put this in class_weight within model.fit I get the same error as you mentioned above.

Exception: class_weight not supported for 3+ dimensional targets.

When I change class_weight to sample_weight within model.fit
and add sample_weight_mode='temporal' within model.compile, I get

line 528, in _standardize_weights
    if sample_weight is not None and len(sample_weight.shape) != 2:
AttributeError: 'dict' object has no attribute 'shape'

The shapes in the final portions of the model are

conv2d_19 (Conv2D)           (None, 4, 64, 64)         260       
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4096)           0         
_________________________________________________________________
permute_1 (Permute)          (None, 4096, 4)           0         
_________________________________________________________________
activation_1 (Activation)    (None, 4096, 4)           0         

Do you have any suggestions to get this to work?

@kglspl

kglspl commented May 15, 2017

@mptorr I am facing a similar problem but am stuck elsewhere... However, the way I understand @uschmidt83's suggestion you need to use:

class_weights = np.zeros((4096, 4))
class_weights[:, 0] += 0.41
class_weights[:, 1] += 1.87
class_weights[:, 2] += 1.1
class_weights[:, 3] += 7.05

Hope it helps, please let us know how it goes. And if anyone knows more feel free to chime in. ;-)

@trogloditee

@kglspl by reshaping my layers, I can actually use sample_weight---my issue is now how to do this with data augmentation, if you have time look at #6629 and let me know if you have an insight, thanks

@ahundt
Contributor

ahundt commented May 21, 2017

Figured out where some changes could happen to make progress in this direction. #6538 (comment)

@ezisezis

@kglspl @mptorr I tried to set the sample weights as suggested. I have a binary pixel-wise classification task that takes in 100x100 images and outputs images of the same resolution. On the final layer I reshape the output so it is the same as in @mptorr's architecture above. Here is my architecture:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 100, 100, 3)       0         
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 104, 104, 3)       0         
_________________________________________________________________
conv1 (Conv2D)               (None, 102, 102, 32)      896       
_________________________________________________________________
pool1 (MaxPooling2D)         (None, 51, 51, 32)        0         
_________________________________________________________________
conv2 (Conv2D)               (None, 49, 49, 32)        9248      
_________________________________________________________________
pool2 (MaxPooling2D)         (None, 25, 25, 32)        0         
_________________________________________________________________
fc6 (Conv2D)                 (None, 25, 25, 64)        73792     
_________________________________________________________________
dropout_1 (Dropout)          (None, 25, 25, 64)        0         
_________________________________________________________________
fc7 (Conv2D)                 (None, 25, 25, 64)        4160      
_________________________________________________________________
dropout_2 (Dropout)          (None, 25, 25, 64)        0         
_________________________________________________________________
score_fr (Conv2D)            (None, 25, 25, 2)         130       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 100, 100, 2)       0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 10000, 2)          0         
=================================================================

Then I try setting sample_weight to this (where 13000 is number of training samples):

sample_weight = np.zeros((13000,10000,2))
sample_weight[:, 0] += 1
sample_weight[:, 1] += 10

But I get this error:

ValueError: Found a sample_weight array with shape (13000, 10000, 2). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.

Then I also tried doing this:

sample_weight = np.zeros((10000,2))
sample_weight[:, 0] += 1
sample_weight[:, 1] += 10

But got the following error:

ValueError: Found a sample_weight array with shape (10000, 2) for an input with shape (13000, 10000, 2). sample_weight cannot be broadcast.

So now I am confused. But @mptorr, you said you made it work by reshaping the layers. How exactly did you reshape them, and how should I do that in my case?

@trogloditee

trogloditee commented May 30, 2017

@ezisezis
Could you try reshaping the model output so it ends up like (None, dimx * dimy, classes)?
You can see that happening on my code from above (images are 64*64 = 4096, I have 4 classes):

reshape_1 (Reshape)          (None, 4, 4096)           0         
_________________________________________________________________
permute_1 (Permute)          (None, 4096, 4)           0         
_________________________________________________________________
activation_1 (Activation)    (None, 4096, 4)           0         

Then reshape your sample_weight to (N, dimx * dimy).
Also make sure your masks are (N, dimx * dimy).
Essentially this matches masks and sample_weight as flattened tensors. I believe my training images are not flattened. Give it a try and let me know.

@ezisezis

@mptorr I figured it out a bit earlier today. With the output I have in my previous comment, from 100x100 images, the model outputs a (10000, 2) shape per image. I have 13000 training images, so the sample_weight dimensions are (13000, 10000), and it works very well.

@sebastienbaur

sebastienbaur commented Jun 13, 2017

I ran into a similar problem, using categorical cross entropy

Maybe we can just use a weighted version of it? What do you think of:
-K.sum(K.log(1e-9 + predicted_proba) * true_labels * class_weight, axis=(1, 2))
where predicted_proba, true_labels and class_weight are 3 tensors of shape (batch_size, sequence_length, nb_classes).

Note that:

  • class_weight[i, :, :] is the same array whatever the value of i. Let's call that common value x (a 2d array)
  • x is actually rank 1, each of its rows being equal to its first one (e.g. [[1,2,3], [1,2,3], [1,2,3], [1,2,3]] if there are 3 classes and sequences have a length of 4). This is because the weight does not depend on the position (you can change that if you need to).

That way you can give more importance to rare classes

@ahundt
Contributor

ahundt commented Jun 13, 2017

Could someone post a concise example of this, or perhaps a small pull request in the examples directory? It seems like a number of people would find it very valuable.

@sebastienbaur

sebastienbaur commented Jun 14, 2017

Below is what I meant with some code.

It is a bit specific to my use case but it should be easy to adapt I guess.

Just tell me if there is something wrong in it. I think that it gives more weight to rare classes. I may have misunderstood the problem

import numpy as np
import keras.backend as K


batch_size = 32
# all_y: a list containing all your class vectors; each is an array of a given size
# whose components are integers representing a class.
# In my case, I have protein sequences that I represent as arrays of integers. These
# integers represent the amino acids composing the sequence (+ the padding char);
# they are in range(0, 21).
all_y = ...
bincount = np.bincount(np.concatenate(all_y))
n_samples = 20000
length = 500  # my proteins have a length of 500
n_classes = 21  # there are 20 amino acids + the padding character
class_weight = n_samples*1. / (n_classes * bincount)
weights = np.ones((length, n_classes))
for k, x in enumerate(class_weight):
    weights[:, k] *= x
class_weight = K.constant(np.concatenate(batch_size*[np.array(weights).reshape((1, length, n_classes))]))

def cross_entropy(true, pred):
    return - K.sum(K.log(1e-9+pred) * true * class_weight, axis=(1, 2))

@ahundt
Contributor

ahundt commented Jun 16, 2017

@sebastienbaur Thanks, that looks like an easy way to add it in. Be careful though! The raw formulation of cross-entropy in your code can be numerically unstable, as commented in the TensorFlow MNIST example, so that might affect your results with the code above.
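
As a rough sketch of a more defensive variant of the loss above, the NumPy function below clips the predictions before taking the logarithm instead of adding a constant. It also broadcasts a per-class weight vector rather than tiling it to the batch size. The function name and the `eps` default are assumptions; in Keras backend code, `K.clip` and `K.epsilon()` play the corresponding roles.

```python
import numpy as np

def weighted_cross_entropy(true, pred, class_weight, eps=1e-7):
    """Weighted categorical cross-entropy, one value per sequence.

    true, pred:   arrays of shape (batch, length, n_classes)
    class_weight: array of shape (n_classes,), broadcast over batch and length
    """
    # Clip so that log() never sees 0 and the result stays finite.
    pred = np.clip(pred, eps, 1.0 - eps)
    return -np.sum(true * class_weight * np.log(pred), axis=(1, 2))

# Toy check: 1 sequence, 2 positions, 3 classes.
true = np.array([[[1, 0, 0], [0, 0, 1]]], dtype=float)
pred = np.array([[[0.5, 0.25, 0.25], [0.25, 0.25, 0.5]]])
w = np.array([1.0, 2.0, 4.0])
loss = weighted_cross_entropy(true, pred, w)
print(loss)
```

Because the weights are broadcast, the same function works for any batch size, including a smaller final batch.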

@potis

potis commented Jun 16, 2017

Hi,

I am running into the same problem as @ezisezis, using keras 2.0.5 and theano as backend (python 2.7).

My goal is to use U-Net to perform image segmentation, but the regions I am trying to segment are of different sizes.
{0: 75.0, 1: 89.0, 2: 61.0, 3: 56.0, 4: 194.0, 5: 1.0}

I tried to use sample_weight instead of class weight so I compiled the model accordingly:
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=[dice_coef], sample_weight_mode="temporal")

Here is the input data size:

(4717, 1, 256, 256)

The size of labels:

(4717, 65536, 6)

The size of input weights

(4717, 65536, 6)

The last layers of my network:

conv2d_92 (Conv2D)    (None, 6, 256, 256)    33    conv2d_91[0][0]
permute_4 (Permute)   (None, 256, 256, 6)    0     conv2d_92[0][0]
reshape_4 (Reshape)   (None, 65536, 6)       0     permute_4[0][0]

And finally the error I am getting:

ValueError: Found a sample_weight array with shape (4717, 65536, 6). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.

Then I tried the suggestion of @mptorr: "reshape your sample_weight to (N, dimx * dimy). Also make sure your masks are (N, dimx * dimy)." For my task dimx = dimy = 256, but here is the error I got:
ValueError: Found a sample_weight array with shape (6, 65536) for an input with shape (4717, 6, 65536). sample_weight cannot be broadcast.

Please let me know if you have any suggestion or need more information.

@ezisezis

Dear @potis, if you read carefully what I experienced and how I solved it, you will see that in your case the shape of the input weights (sample_weight) has to be (4717, 65536) or, in general, (number_of_images, number_of_pixels_in_img). Each value in this 2D array is the weight of the class that the pixel belongs to (you don't have to assign 6 values to each pixel, only one: the class's weight). Hope it makes more sense.

@potis

potis commented Jun 16, 2017

@ezisezis thanks for the response. I guess I was misinterpreting the N as the number of classes.

@joeyearsley
Contributor

joeyearsley commented Jun 30, 2017

For anybody else struggling with this, this details a formula to get class weights for pixels.
https://blog.fineighbor.com/tensorflow-dealing-with-imbalanced-data-eb0108b10701

Set up your class weightings as described in the above blog, then set sample_weight_mode='temporal' and set up your sample_weights with shape (nb_samples, dim_x * dim_y, nb_classes).

To get your sample weights, multiply each output channel with the corresponding class weight found from the above blog (med_freq/freq_cx where x is an element of Classes).

Finally, sample_weights = np.squeeze(np.sum(sample_weights, axis=-1)); We can do this since the channel axis is one-hot encoded.
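
The recipe above can be sketched in NumPy roughly as follows. The toy labels are hypothetical, and the frequency used here is simply each class's share of all pixels. That is a simplifying assumption and may differ from the exact definition in the linked blog post.

```python
import numpy as np

# Hypothetical one-hot labels: 2 samples, 6 pixels each, 3 classes.
class_ids = np.array([[0, 0, 0, 1, 1, 2],
                      [0, 0, 0, 0, 1, 2]])
y = np.eye(3)[class_ids]                        # (nb_samples, pixels, nb_classes)

# Simplified frequency: share of all pixels that belong to each class.
freq = y.sum(axis=(0, 1)) / y.sum()
class_weights = np.median(freq) / freq          # med_freq / freq_c

# Scale each one-hot channel by its class weight, then collapse the channel
# axis; exactly one channel is non-zero per pixel, so the sum picks its weight.
sample_weights = np.sum(y * class_weights, axis=-1)   # (nb_samples, pixels)

print(class_weights)
print(sample_weights.shape)
```

The resulting 2D array is what gets passed as sample_weight when the model is compiled with sample_weight_mode='temporal'.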

@jmtatsch

jmtatsch commented Jul 11, 2017

"class_weight is not supported for 3+ dimensional targets because the concept of class is ambiguous in that case."

Can someone elaborate on that, please? I really don't understand why that is ambiguous.

@EloyRoura

@ezisezis Thanks a lot for your post. I've been digging into the code for hours and this really did the trick. I would still prefer class_weight to do what it is supposed to, but I don't get how we should use it or in which scenario.
Anyway, I have another question. If the fit(...) function is used, sample_weights can be the workaround, but what about fit_generator(...)? There is no option for that. Do you have a solution in that case?

Thanks in advance :)

@ahundt
Contributor

ahundt commented Jul 13, 2017

One approach to resolving the ambiguity would be to add support for a property we can attach to a tensor or numpy array that specifies the type of data each dimension in a tensor represents (batch, width, height, depth, class, etc).

@EloyRoura

OK! I'll answer to my own comment. My bad, I didn't see the input tuple can indeed be (inputs, target, weights), so this should solve the problem
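
A minimal sketch of such a generator, assuming one-hot targets already flattened to (N, pixels, n_classes) and a hypothetical `class_weights` array indexed by class id. The function name and toy data are illustrative only.

```python
import numpy as np

def batch_generator(x, y, class_weights, batch_size):
    """Yield (inputs, targets, sample_weights) batches indefinitely.

    x: (N, ...) inputs
    y: (N, pixels, n_classes) one-hot targets
    class_weights: (n_classes,) array, index = class id
    """
    n = len(x)
    while True:
        for start in range(0, n, batch_size):
            xb = x[start:start + batch_size]
            yb = y[start:start + batch_size]
            # One weight per pixel: the weight of that pixel's class.
            wb = class_weights[yb.argmax(axis=-1)]
            yield xb, yb, wb

# Toy usage with 5 images of 4x4 pixels and 2 classes.
x = np.zeros((5, 4, 4, 1))
y = np.eye(2)[np.random.default_rng(0).integers(0, 2, size=(5, 16))]
gen = batch_generator(x, y, np.array([0.05, 0.95]), batch_size=2)
xb, yb, wb = next(gen)
print(xb.shape, yb.shape, wb.shape)
```

The generator would then be handed to fit_generator(...) on a model compiled with sample_weight_mode="temporal"; the weights are computed per batch, so nothing has to be precomputed for the whole dataset.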

@JianbingDong

@ezisezis Hi, you've done a great job on this error. But after reading your suggestion, I still don't know how to assign weights to my sample_weight array...
If my model output is an array of shape (pixels_x * pixels_y, num_class), and I set my sample_weight array to shape (num_samples, pixels_x * pixels_y), could this work?
And how should I assign class weights to my sample_weight array?
Thanks in advance.

@rdelassus

Following @JianbingDong's question, does an array with shape (num_class) work, such that classes are weighted, not pixels?

@anilsathyan7

I am facing a similar issue: how do I set the values for the weights (sample and class) if I use sparse categorical cross-entropy?
How can I use class weights or sample weights to ignore labels (void) in a segmentation task (by setting them to zero), where I use sparse categorical cross-entropy as the loss function and the labels are not one-hot encoded? Say my final output shape is (None, 16384, 2) and the segmentation label is of shape (None, 16384, 1) with int values 0 (bg), 1 (fg) and 2 (void). Is this possible only by implementing a custom loss function? I tried the corresponding TensorFlow loss function with weights (a custom loss function with tf), but it became more complex and messier!

@shahriar49

I am skeptical that we need a custom loss function, because Keras embeds weighting and masking in its loss function prototypes inherently (see the code under the "# Compute total loss." comment). I feel that if we supply it with (data, label, sample_weight) and provide class_weight if we need to, there should be no reason to define a custom loss function, but I am not sure.
@anilsathyan7

@aloerch

aloerch commented Nov 24, 2019

Well, this is still an ongoing issue, and drove me nuts for 36 hours, but I've got it figured out for my dataset. For anyone trying to use TensorFlow 2.0 with Keras, and the tf.data.Dataset pipelines, this info may help you, since it worked for me. Here's my basic setup:

Input data: RGB images (png format)
Label data: single-channel png images with 8 classes (0 to 7) (note: in my code I called these masks)
Classes: background / ignore label: 0, 1-7 are classes of interest
Batch: 12 (who knew 2 1080Ti's weren't sufficient, lol)
Model: Resnet50

Goal: I want my model to 'learn' the relevant classes (1-7) and ignore the background / void class (0). To do this, class_weights does nothing... (not even an error when I tried it, it just did nothing). So, I want to use sample_weights with the sample_weight_mode='temporal' to handle pixel-wise 3D weights.

To make this all work, I did the following:

  1. Created the sample_weights based on the labels. Because I want to ignore the background class and not change anything else, I simply did:
    sample_weight = tf.math.divide_no_nan(label, label)

This gives 0's for class 0 and 1's for all other classes. However, you can add weights to other classes by using numpy directly instead, for example:
label[label == 4] = 0.8

would replace the number 4 with your desired weight for the class 4. You could do this for any classes and set others to 1's, or whatever.

  2. Reshape the labels and sample weights to make them compatible with sample_weight_mode='temporal'. The labels are reshaped like:
    label = tf.reshape(label, [102400, -1])

(note: the shape was because my data was (320, 320, 1) and I'm not sure the -1 was necessary or if I could've used 1, but it works. Either way, your labels here (prior to batching) need to be 2D). The sample_weights are reshaped a bit differently:
sample_weight = tf.reshape(sample_weight, [102400])

(note: the shape of the sample_weight matrix needs to end up as (batch_size, 1D tensor) so that it becomes a 2D tensor when given to model.fit(). In my case, I perform batching with tf.data.Dataset.batch(), so the batch_size dimension gets added at that point.)

  3. Created a tf.data.Dataset object containing the input images, labels, and sample_weights.

  4. Modify the resnet50.py file (or whatever contains your model layers) to change the output shape to work with both the label and sample_weight tensors:
    x = Reshape((102400, nclasses))(x)

  5. Set the model.compile() parameter for sample_weight_mode='temporal'

Following the above steps results in a model that trains successfully while ignoring the background class, using the Keras built-in sparse categorical cross-entropy loss function. In that loss function, the sample weights are automatically applied to the per-pixel loss values, so the loss for class-zero pixels becomes 0 and the loss for all other classes is unchanged.
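
The weight construction and reshaping from the first two steps can be sketched in plain NumPy as below. This is only an illustration under assumed names: the random label image stands in for a real mask, and the lookup table `class_weights` is hypothetical. The tf.data pipeline described above would perform the equivalent operations with TensorFlow ops.

```python
import numpy as np

# Hypothetical single-channel label image, classes 0..7 (0 = background / ignore).
rng = np.random.default_rng(0)
label = rng.integers(0, 8, size=(320, 320, 1))

# Per-class weights as a lookup table: 0 for the ignored class, 1 elsewhere,
# with class 4 down-weighted to 0.8 as in the example above.
class_weights = np.ones(8, dtype=np.float32)
class_weights[0] = 0.0
class_weights[4] = 0.8

sample_weight = class_weights[label]              # (320, 320, 1)

# Flatten for sample_weight_mode='temporal' (shapes before batching).
label_flat = label.reshape(320 * 320, 1)          # 2D: (102400, 1)
weight_flat = sample_weight.reshape(320 * 320)    # 1D: (102400,)

print(label_flat.shape, weight_flat.shape)
```

Using a lookup table keeps the labels and the weights in separate arrays, so the integer label values are never overwritten by fractional weights.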

To perform inference, and have a meaningful prediction image saved to disk, I simply remove the last layer added in step 4 above (the one that reshaped the output) and then load the model and saved weights. Predictions / inference images now have only classes 1-7.

Thanks to Keras / TensorFlow for making this challenging as humanly possible, and for allowing this complexity to persist for over 3 years!

@iuria21

iuria21 commented Dec 5, 2019

Hi @ylmeng , thanks for your solution!
I'm using the my_weighted_loss function and I want to ask a question. I'm working on a NER problem; I have a partially labeled dataset, and I want to use this loss function to penalize the model less when it predicts a label where the dataset has none. Could this be useful for me? And do I have to increase the class_weight of the O class or of the others? I'm a bit confused here.

Thanks for your help!

@chrishmorris

You can do this without a custom loss function, by changing the y_true values from eg:
(1,0,0,0), (0,0,1,0),...
to:
(class_weight[0],0,0,0), (0,0,class_weight[2],0),...
The reason why this works is the formula for cross-entropy, which is a (negated) sum of p*log(q). Scaling p has the same effect as a class weight.
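
A small NumPy check of this equivalence, using made-up weights and predictions (all names and values here are illustrative assumptions):

```python
import numpy as np

def cross_entropy(p, q):
    """Categorical cross-entropy -sum(p * log(q)) over the class axis."""
    return -np.sum(p * np.log(q), axis=-1)

class_weight = np.array([0.5, 1.0, 2.0, 4.0])
y_true = np.array([[1, 0, 0, 0], [0, 0, 1, 0]], dtype=float)
y_pred = np.array([[0.7, 0.1, 0.1, 0.1], [0.2, 0.2, 0.5, 0.1]])

# Option 1: scale the one-hot targets by the class weights.
loss_scaled_targets = cross_entropy(y_true * class_weight, y_pred)

# Option 2: compute the plain loss, then weight each sample by its class weight.
loss_weighted = class_weight[y_true.argmax(axis=-1)] * cross_entropy(y_true, y_pred)

print(np.allclose(loss_scaled_targets, loss_weighted))
```

The trick relies on the loss being linear in y_true. Anything that assumes the targets sum to 1 (label smoothing, for instance) may behave differently, so that part is worth checking for a given setup.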

@JoanaNRocha

JoanaNRocha commented Dec 14, 2020

I am not completely sure why this happens but I seemed to have solved this by changing the syntax of the class weights. Instead of using a dict, try this (binary classification example):

class_weight=[1, 0.1]

It worked for me...

@GitHubUser97

GitHubUser97 commented Jan 15, 2021

As this problem was extremely annoying to overcome, I am posting a simple function that you can use to create a proper sample_weight when the output of your model is 3D.
I'm working on a model with output (nr_observations x prediction_horizons x one_hot_encoded_labels); its shape is (n x 5 x 3).
To use sample weights properly, first of all, when compiling the model you need to add the sample_weight_mode="temporal" argument.
Then you need to create a 2D array, in my case of size (n x 5), that holds the class weight at each position where the original training array holds a class. For example, if one particular observation in your training data looks like [0,2,0,1,2] (before one-hot encoding) and your weight dictionary is {0:0.9, 1:1.34, 2: 0.5},
then your transformed observation would look like [0.9,0.5,0.9,1.34,0.5].
You need to do that for the entire training array, so you will have an array of shape (n x 5).

After that's done, you can feed the array to the model during fitting by adding the argument sample_weight=your_array_with_sample_weights.

The function that I created to transform the y_train array into the sample_weights 2D input array is as follows:

import numpy as np

def generate_sample_weights(training_data, class_weights):
    # Replaces the labels of up to 3 classes (0, 1, 2) with the values from class_weights.
    sample_weights = [np.where(y == 0, class_weights[0],
                      np.where(y == 1, class_weights[1],
                      np.where(y == 2, class_weights[2], y))) for y in training_data]
    return np.asarray(sample_weights)

It is based on another post on stack exchange, https://datascience.stackexchange.com/a/31542
Its inputs are your y_train (before one hot encoding) and a dictionary with class weights.
Mine works for 3 classes, but you can easily modify it to work with 2 or more.
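
One possible way to generalize the function to any number of classes is a lookup table instead of nested np.where calls. This is a sketch, not the author's code: the name `generate_sample_weights_any` is hypothetical, and it assumes the class ids in the dictionary are the contiguous integers 0..n_classes-1.

```python
import numpy as np

def generate_sample_weights_any(training_data, class_weights):
    """Map integer class labels to weights for any number of classes.

    training_data: integer array of shape (n, timesteps), before one-hot encoding
    class_weights: dict {class_id: weight} covering ids 0..n_classes-1
    """
    # Turn the dict into a lookup table indexed by class id.
    lookup = np.array([class_weights[c] for c in sorted(class_weights)])
    return lookup[np.asarray(training_data)]

# The example observation and weight dictionary from the comment above.
y_train = np.array([[0, 2, 0, 1, 2]])
weights = generate_sample_weights_any(y_train, {0: 0.9, 1: 1.34, 2: 0.5})
print(weights)
```

The output has the same shape as the input labels, so it can be passed as sample_weight in the same way as before.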

@janwillembuist

@GitHubUser97 I really like your solution and think it might work for my use case of this problem, but there is one concern I have. I am working with a data generator, so I cannot pre-compute the sample weights. Do you have an idea on how to let Keras generate sample weights on the fly?

@GitHubUser97

@janwillembuist I think that the answer by @janbrrr is what you are looking for. You can modify your generator to produce array with sample weights on the fly from predefined classes as he did

height * width, 1). Assuming you have an array class_weights where the index is the class and the value at each index is the weight, you can simply add the following in the batch generation:
...
sample_weights = numpy.take(class_weights, y[:, :, 0])
return X, y, sample_weights

@janwillembuist

@GitHubUser97, @janbrrr Thanks for your help, I ended up creating a combination of your contributions, which solved my problem in the end!

@GitHubUser97

@GitHubUser97, @janbrrr Thanks for your help, I ended up creating a combination of your contributions, which solved my problem in the end!

No problem, happy to help @janwillembuist :)
Would you mind sharing the solution? I know for a fact that sooner or later I will face the same problem when I start working with a generator.
Cheers

@janwillembuist

Sure! I am working with a custom generator, a subclassed keras.utils.Sequence instance. In the __getitem__(self, index) method, I have added the generation of sample weights:

sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int'))

The class weights I specify in the generator's init method, and for generating the weights I am using the second channel of my one-hot-encoded output. Then, you add the sample_weights to the return of __getitem__() and add sample_weight_mode="temporal" to model.compile(). I will come up with a more detailed answer soon!

@janwillembuist

Since I had already posted this as a question on Stack Overflow, where I was pointed to this issue, I gave a long answer there!

@GitHubUser97, and of course others, you can check it out here

@oilaba

oilaba commented Apr 20, 2021

@graffam Did you open any PR about this since then?

@sayakpaul
Contributor

@aloerch could you expand a bit more on where you are incorporating the following? sample_weight = tf.math.divide_no_nan(label, label)

Are you calculating this externally or inside the tf.data pipeline? A bit more clarification would be very helpful.

@sayakpaul
Contributor

sayakpaul commented Apr 23, 2021

For binary segmentation problems, could we do something like the following (inspired by this tutorial)?

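import tensorflow as tf
from tensorflow import keras

# BATCH_SIZE (the global batch size) is assumed to be defined elsewhere.
def loss_function(real, pred):
    # Mask out pixels whose ground-truth value is 0.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = keras.losses.binary_crossentropy(real, pred)
    loss_ = tf.expand_dims(loss_, axis=-1)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    # Normalize the masked loss by the average mask value.
    loss_ = tf.nn.compute_average_loss(loss_, global_batch_size=BATCH_SIZE)
    mask = tf.nn.compute_average_loss(mask, global_batch_size=BATCH_SIZE)
    return loss_/mask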

loss_function is then passed to the compile() call like - model.compile(loss=loss_function, ...).

TensorFlow-Docs-Copybara pushed a commit to tensorflow/docs that referenced this issue May 1, 2021
@sayakpaul
Contributor

Here's a section showing how to get the sample_weight the right way inside a tf.data pipeline: https://www.tensorflow.org/tutorials/images/segmentation#optional_imbalanced_classes_and_class_weights.

Huge thanks to @MarkDaoust for this one.

@arashnh11

Using sample weights for class imbalance is quite wasteful, as mentioned earlier in this thread. One may not notice that with small datasets, but it becomes quite clear with volume data like in my case, where even 48GB GPUs run out of memory. Rather than wasting all that space storing the same value for every element of a class in the sample weights, one can write a customized loss function that splits the loss into multiple sums, each scaled by the single weight of its associated class. Hope this helps those who find this thread and are dealing with large data volumes like me.
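
A rough NumPy sketch of that idea, with a hypothetical function name and toy volume data. It only demonstrates that the per-class sums give the same value as a dense weight array. Whether it actually saves memory inside a given framework depends on how the class masks are realized, which is an assumption left untested here.

```python
import numpy as np

def per_class_weighted_loss(pixel_loss, labels, class_weights):
    """Sum of class-weighted losses without a dense per-pixel weight array.

    pixel_loss: unweighted loss per pixel/voxel, any shape
    labels: integer class ids, same shape as pixel_loss
    class_weights: sequence of weights, index = class id
    """
    total = 0.0
    for c, w in enumerate(class_weights):
        # One scalar weight per class, applied to that class's summed loss.
        total += w * pixel_loss[labels == c].sum()
    return total

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(2, 4, 4, 4))      # toy volume data
pixel_loss = rng.random(labels.shape)
class_weights = [0.2, 1.0, 5.0]

loss = per_class_weighted_loss(pixel_loss, labels, class_weights)

# Same value as the dense formulation used with sample_weight:
dense = (np.asarray(class_weights)[labels] * pixel_loss).sum()
print(np.isclose(loss, dense))
```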
