
Need to reduce Keras warnings #38

Closed
cmdupuis3 opened this issue Nov 16, 2021 · 8 comments


cmdupuis3 commented Nov 16, 2021

In trying to solve #36, I found that this warning

WARNING:tensorflow:Keras is training/fitting/evaluating on array-like data. Keras may not be optimized for this format, so if your input data format is supported by TensorFlow I/O (https://github.com/tensorflow/io) we recommend using that to load a Dataset instead.

is duplicated so many times in my Jupyter notebook that it explodes my browser's RAM usage and freezes my computer! I'd recommend only printing unique warnings, but I have no idea how to implement that. Setting the model.fit option verbose=0 suppresses most of the spam, but this warning still gets through. I assume that whatever fixes this would also make setting verbose=0 less important.
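For the "only print unique warnings" idea: if the repeated messages pass through Python's logging machinery, a deduplicating logging.Filter is one way to sketch it (the filter class is illustrative, and attaching it to a logger named "tensorflow" is an assumption about where these messages are routed):

```python
import logging

class DedupFilter(logging.Filter):
    """Drop log records whose formatted message has already been seen."""

    def __init__(self):
        super().__init__()
        self._seen = set()

    def filter(self, record):
        msg = record.getMessage()
        if msg in self._seen:
            return False  # suppress the duplicate
        self._seen.add(msg)
        return True

# Attach to TensorFlow's logger so each distinct warning prints at most once.
logging.getLogger("tensorflow").addFilter(DedupFilter())
```

With this in place, a message that has already been emitted once in the session is silently dropped, while new messages still get through.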

@rabernat (Contributor)

If these are normal python warnings, you should be able to suppress them as described here: https://docs.python.org/3/library/warnings.html#overriding-the-default-filter
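For warnings that do go through the warnings module, a message-pattern filter along the lines of those docs looks like this (the pattern below is illustrative; note it is a regex matched case-insensitively from the start of the message):

```python
import warnings

def suppressed_messages():
    """Show that a matching warning is filtered while others survive."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Ignore any warning whose message matches this regex.
        warnings.filterwarnings("ignore", message=".*array-like data.*")
        warnings.warn("Keras is training/fitting/evaluating on array-like data.")
        warnings.warn("an unrelated warning")
        return [str(w.message) for w in caught]

print(suppressed_messages())  # only the unrelated warning is recorded
```

filterwarnings inserts its rule at the front of the filter list, so the ignore rule is checked before the catch-all "always" rule.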

@cmdupuis3 (Author)

No luck so far.


cmdupuis3 commented Nov 16, 2021

This style of warning suppression seems promising, but it's not working:

import warnings

def fxn():
    warnings.warn(Warning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()
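One plausible reason the snippet above has no effect: TensorFlow emits this particular message through its own logger (Python's logging module) rather than through warnings, so warnings filters never see it. A sketch of silencing it at the logging level, assuming the relevant logger is the one named "tensorflow" (which is what tf.get_logger() returns):

```python
import logging

# Raise the threshold on TensorFlow's logger: WARNING records are dropped,
# ERROR and above still print.
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# The same mechanism, demonstrated on an ordinary logger:
log = logging.getLogger("demo")
log.setLevel(logging.ERROR)
print(log.isEnabledFor(logging.WARNING))  # False: warnings are filtered out
```

This is coarser than a message-specific filter: it drops all of TensorFlow's WARNING-level output, not just the array-like-data message.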

@cmdupuis3 (Author)

Adding clear_output(wait=True) inside the for batch in bgen loop helps. It would make sense to add this to the documentation, since I expect this will be a common problem. I'll see if I can add it in the (future) pull request for #36.


jhamman commented Nov 16, 2021

@cmdupuis3 - how would I reproduce the warning you are getting? Can you share a contained example?

@cmdupuis3 (Author)

@jhamman I don't have a simple example of this handy; can I send you a gist of what I have?


cmdupuis3 commented Nov 16, 2021

I think I have something workable...

Using my update to xbatcher here, you can run this:

    import numpy as np
    import xarray as xr

    from IPython.display import clear_output
    import tensorflow as tf

    # Load the modified xbatcher from a local checkout.
    import importlib.util
    spec = importlib.util.spec_from_file_location("xbatcher", "<your xbatcher path>")
    xb = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(xb)

    # Two random 2D fields on a 640 x 640 grid.
    Z = xr.DataArray(np.random.rand(640, 640), dims=['nlon', 'nlat'], name="Z")
    t1 = xr.DataArray(np.random.rand(640, 640), dims=['nlon', 'nlat'], name="t1")
    ds = xr.Dataset({'Z': Z, 't1': t1})

    def train(ds, conv_dims_2D=[20, 20], nfilters=80):

        nlons = conv_dims_2D[0]
        nlats = conv_dims_2D[1]

        # Overlapping windows: each batch is nlons x nlats, strided by half a window.
        bgen = xb.BatchGenerator(
            ds,
            {'nlon': nlons,      'nlat': nlats},
            {'nlon': nlons // 2, 'nlat': nlats // 2}
        )

        # A single Conv2D whose kernel covers the whole window, reduced to a
        # scalar prediction per window.
        input_stencil_2D = tf.keras.Input(shape=tuple(conv_dims_2D) + (1,))
        conv_layer_2D    = tf.keras.layers.Conv2D(nfilters, conv_dims_2D)(input_stencil_2D)
        reshape_layer_2D = tf.keras.layers.Reshape((nfilters,))(conv_layer_2D)

        output_layer = tf.keras.layers.Dense(1)(reshape_layer_2D)

        model = tf.keras.Model(inputs=[input_stencil_2D], outputs=output_layer)
        model.compile(loss='mae', optimizer='Adam', metrics=['mae', 'mse', 'accuracy'])

        # Calling model.fit once per batch re-emits the array-like-data warning
        # on every iteration, which is what floods the notebook.
        for batch in bgen:
            batch_stencil_2D = batch['Z'].expand_dims('var', 3)
            batch_target = batch['t1'].isel(nlat=nlats // 2, nlon=nlons // 2).expand_dims('var', 1)

            model.fit([batch_stencil_2D],
                      batch_target,
                      batch_size=32, epochs=2, verbose=0)
            # clear_output(wait=True)

        return model

    train(ds)


cmdupuis3 commented Feb 2, 2023

Looking back at this, I think this issue is more a consequence of iterating over separate batches, rather than a problem with xbatcher itself. I could see a solution here being useful in the context of debugging an ML model, but I don't think it would make sense as part of xbatcher's core functionality.

Closing for now.
