Skip to content

Speeding up MERRA interpolates

Nick edited this page Feb 15, 2022 · 4 revisions
def wind_rh_Extrapolate(data_total):
    """Fill masked values in every vertical column with a constant value.

    MERRA2 applies a mask to data lower than the land surface (missing
    value 9.9999999E14). Other reanalyses do not, and leaving the mask in
    place would introduce errors during interpolation. Any value greater
    than 99999 is treated as missing.

    ``data_total`` is indexed as [time, lev, lat, lon]. Only slicing and
    transposition are applied before the values are written, so for a NumPy
    array the input itself is modified. The returned array uses the same
    [time, lev, lat, lon] axis order.
    """
    # Re-order the axes to [lat, lon, time, lev] so that the innermost axis
    # is a single vertical column.
    columns = data_total[:, :, :, :].transpose((2, 3, 0, 1))

    for lat_row in columns:                    # each latitude
        for time_series in lat_row:            # each longitude
            for profile in time_series:        # each time step
                gaps = []                      # level indices flagged as missing
                for lev, value in enumerate(profile):
                    if value > 99999:
                        gaps.append(lev)
                    if gaps:
                        # Overwrite every gap found so far with the value
                        # one index past the most recently found gap. This
                        # repeats at each later level, so the last write
                        # for a column decides the final values.
                        profile[gaps] = profile[gaps[-1] + 1]

    # Undo the axis permutation: back to [time, lev, lat, lon].
    return columns.transpose((2, 3, 0, 1))

This tangle of for loops fills the masked values by constant extrapolation along the level dimension. However, it is very inefficient.

This replacement is 7000 times faster (based on 100 repetitions):

import numpy as np
def constant_extrapolation(all_data, missing_threshold=99999):
    """Fill masked values by copying the value from the next level index.

    Parameters
    ----------
    all_data : numpy.ndarray
        Array with dimensions [time, lev, lat, lon]. It is modified in place.
    missing_threshold : float, optional
        Values greater than this are treated as missing. The default of
        99999 catches the 9.9999999E14 fill value.

    Returns
    -------
    numpy.ndarray
        The same array object, with every missing value replaced by the
        nearest non-missing value at a larger level index.

    Raises
    ------
    ValueError
        If any value at the last level index is missing, because there is
        then nothing to copy from.
    """
    if np.any(all_data[:, -1, :, :] > missing_threshold):
        raise ValueError("Missing data for an entire column. This might mean the terrain mask includes all the data. Increase max elevation")

    n_lev = all_data.shape[1]

    # Walk from the second-to-last level index down to index 0. The order
    # matters: level lev + 1 has already been filled when level lev is
    # processed, so one pass propagates values through runs of several
    # consecutive missing levels.
    for lev in range(n_lev - 2, -1, -1):
        current = all_data[:, lev, :, :]  # basic slice: a view, so writes reach all_data
        missing = current > missing_threshold
        current[missing] = all_data[:, lev + 1, :, :][missing]

    return all_data
    

A slightly faster solution could do something like this: https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array

The temperature extrapolation uses `interp1d` from scipy, but has the same troublesome loops:

@staticmethod
    def tempExtrapolate(t_total, h_total, elevation):
        """Linearly extrapolate air temperature into masked levels.

        MERRA2 applies a mask to data lower than the land surface (missing
        value 9.9999999E14). Other reanalyses do not, and leaving the mask
        in place would introduce errors during interpolation. Any
        temperature greater than 99999 is treated as missing.

        For every vertical column that has missing temperatures, a linear
        function of geopotential height is built from up to three levels
        directly after the last missing level index. That function is then
        evaluated at the heights of all missing levels in the column.

        IMPORTANT TIP:
        ! Set 'ele_max' to 2500 m or higher in the configuration file !
        Reason: at least 2 levels with data are required after the last
        missing level to build the linear function.

        Parameters
        ----------
        t_total : array indexed as [time, lev, lat, lon]
            Air temperature. Only slicing and transposition are applied
            before the values are written, so for a NumPy array the input
            itself is modified.
        h_total : array indexed as [time, lev, lat, lon]
            Geopotential height on the same grid; it is only read.
        elevation : mapping
            Only ``elevation['max']`` is read, for the error message.

        Returns
        -------
        The temperature array in [time, lev, lat, lon] order.

        Exits the process with status 1 (after logging the reason) when
        fewer than 2 levels are available above the last missing level.
        """
        # Re-order the axes to [lat, lon, time, lev] so that the innermost
        # axis is a single vertical column.
        t_cols = t_total[:, :, :, :].transpose((2, 3, 0, 1))
        h_cols = h_total[:, :, :, :].transpose((2, 3, 0, 1))

        for t_row, h_row in zip(t_cols, h_cols):
            for t_series, h_series in zip(t_row, h_row):
                for t_lev, h_lev in zip(t_series, h_series):
                    # indices of levels with missing temperature
                    id_interp = [z for z, value in enumerate(t_lev)
                                 if value > 99999]
                    if not id_interp:
                        continue

                    # first level index after the last missing one
                    z_top = id_interp[-1] + 1
                    # up to three levels of height and temperature from there
                    lev_3p = h_lev[z_top:z_top + 3]
                    t_3p = t_lev[z_top:z_top + 3]

                    # interp1d needs at least 2 points
                    if len(lev_3p) < 2:
                        logger.error('Numbers of points for extrapolation are too low (less then 2): %s', len(lev_3p))
                        logger.error('Failed to conduct extrapolation at some points in the output')
                        logger.error('Current ele_max = %s', elevation['max'])
                        logger.error('Higher Value of "ele_max" is needed to reset: > 2500')
                        # non-zero status: this is a failure, not a clean exit
                        sys.exit(1)

                    # One fit per column is enough: the original refitted at
                    # every missing level, but only the fit made at the last
                    # missing level determined the stored values.
                    f = interp1d(lev_3p, t_3p, kind='linear', fill_value='extrapolate')
                    t_lev[id_interp] = f(h_lev[id_interp])

        # Undo the axis permutation: back to [time, lev, lat, lon].
        return t_cols.transpose((2, 3, 0, 1))