Skip to content

Commit

Permalink
Refactor to process dictionary spec like a pandas.DataFrame
Browse files Browse the repository at this point in the history
Remove huge chunk of if-then code block by
converting Python dict to pandas.DataFrame,
and handle the spec formatting using one route
only.
  • Loading branch information
weiji14 committed Nov 8, 2021
1 parent 9945684 commit e223057
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 107 deletions.
121 changes: 17 additions & 104 deletions pygmt/src/meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,95 +322,20 @@ def update_pointers(data_pointers):

# create a dict type pointer for easier to read code
if isinstance(spec, dict):
dict_type_pointer = list(spec.values())[0]
elif isinstance(spec, pd.DataFrame):
# use df.values as pointer for DataFrame behavior
dict_type_pointer = spec.values

# assemble the 1D array for the case of floats and ints as values
if isinstance(dict_type_pointer, (int, float)):
# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
(
longitude,
latitude,
depth,
plot_longitude,
plot_latitude,
) = update_pointers(data_pointers)

# Construct the array (order matters)
spec = [longitude, latitude, depth] + [spec[key] for key in foc_params]

# Add in plotting options, if given, otherwise add 0s
for arg in plot_longitude, plot_latitude:
if arg is None:
spec.append(0)
else:
if "A" not in kwargs:
kwargs["A"] = True
spec.append(arg)

# or assemble the 2D array for the case of lists as values
elif isinstance(dict_type_pointer, list):
# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
(
longitude,
latitude,
depth,
plot_longitude,
plot_latitude,
) = update_pointers(data_pointers)

# before constructing the 2D array lets check that each key
# of the dict has the same quantity of values to avoid bugs
list_length = len(list(spec.values())[0])
for value in list(spec.values()):
if len(value) != list_length:
raise GMTError(
"Unequal number of focal mechanism "
"parameters supplied in 'spec'."
)
# lets also check the inputs for longitude, latitude,
# and depth if it is a list or array
if (
isinstance(longitude, (list, np.ndarray))
or isinstance(latitude, (list, np.ndarray))
or isinstance(depth, (list, np.ndarray))
):
if (len(longitude) != len(latitude)) or (
len(longitude) != len(depth)
):
raise GMTError(
"Unequal number of focal mechanism " "locations supplied."
)

# values are ok, so build the 2D array
spec_array = []
for index in range(list_length):
# Construct the array one row at a time (note that order
# matters here, hence the list comprehension!)
row = [longitude[index], latitude[index], depth[index]] + [
spec[key][index] for key in foc_params
]

# Add in plotting options, if given, otherwise add 0s as
# required by GMT
for arg in plot_longitude, plot_latitude:
if arg is None:
row.append(0)
else:
if "A" not in kwargs:
kwargs["A"] = True
row.append(arg[index])
spec_array.append(row)
spec = spec_array

# or assemble the array for the case of pd.DataFrames
elif isinstance(dict_type_pointer, np.ndarray):
# Convert single int, float data to List[int, float] data
_spec = {
"longitude": np.atleast_1d(longitude),
"latitude": np.atleast_1d(latitude),
"depth": np.atleast_1d(depth),
}
_spec.update({key: np.atleast_1d(val) for key, val in spec.items()})
spec = pd.DataFrame.from_dict(_spec)

assert isinstance(spec, pd.DataFrame)
dict_type_pointer = spec.values

# Assemble the array for the case of pd.DataFrames
if isinstance(dict_type_pointer, np.ndarray):
# update pointers
set_pointer(data_pointers, spec)
# look for optional parameters in the right place
Expand All @@ -422,19 +347,7 @@ def update_pointers(data_pointers):
plot_latitude,
) = update_pointers(data_pointers)

# lets also check the inputs for longitude, latitude, and depth
# just in case the user entered different length lists
if (
isinstance(longitude, (list, np.ndarray))
or isinstance(latitude, (list, np.ndarray))
or isinstance(depth, (list, np.ndarray))
):
if (len(longitude) != len(latitude)) or (len(longitude) != len(depth)):
raise GMTError(
"Unequal number of focal mechanism locations supplied."
)

# values are ok, so build the 2D array in the correct order
# build the 2D array in the correct order
spec_array = []
for index in range(len(spec)):
# Construct the array one row at a time (note that order
Expand All @@ -458,8 +371,8 @@ def update_pointers(data_pointers):
else:
raise GMTError("Parameter 'spec' contains values of an unsupported type.")

# Ensure non-file types are a 2d array
if isinstance(spec, (list, np.ndarray)):
# Convert 1d array types into 2d arrays
if isinstance(spec, np.ndarray) and spec.ndim == 1:
spec = np.atleast_2d(spec)

# determine data_foramt from convection and component
Expand Down
8 changes: 5 additions & 3 deletions pygmt/tests/test_meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_meca_spec_dict_list():
strike=[330, 350], dip=[30, 50], rake=[90, 90], magnitude=[3, 2]
)
fig.meca(
focal_mechanisms,
spec=focal_mechanisms,
longitude=[-124.3, -124.4],
latitude=[48.1, 48.2],
depth=[12.0, 11.0],
Expand Down Expand Up @@ -110,7 +110,9 @@ def test_meca_spec_dataframe():
depth=[12, 11.0],
)
spec_dataframe = pd.DataFrame(data=focal_mechanisms)
fig.meca(spec_dataframe, region=[-125, -122, 47, 49], scale="2c", projection="M14c")
fig.meca(
spec=spec_dataframe, region=[-125, -122, 47, 49], scale="2c", projection="M14c"
)
return fig


Expand Down Expand Up @@ -183,7 +185,7 @@ def test_meca_spec_2d_array():
]
focal_mechs_array = np.asarray(focal_mechanisms)
fig.meca(
focal_mechs_array,
spec=focal_mechs_array,
convention="gcmt",
region=[-128, -127, 40, 41],
scale="2c",
Expand Down

0 comments on commit e223057

Please sign in to comment.