Skip to content

Commit

Permalink
Add edit_gps_trace function (#971)
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Nov 11, 2024
1 parent ee80385 commit cf70cdf
Showing 1 changed file with 272 additions and 0 deletions.
272 changes: 272 additions & 0 deletions leafmap/maplibregl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3790,3 +3790,275 @@ def maptiler_3d_style(
}

return style


def edit_gps_trace(
filename: str,
m: Any,
colormap: Dict[str, str],
layer_name: str,
fig_width: str = "1550px",
fig_height: str = "400px",
) -> Any:
"""
Edits a GPS trace on the map and allows for annotation and export.
Args:
filename (str): The path to the GPS trace CSV file.
m (Any): The map object containing the GPS trace.
colormap (Dict[str, str]): The colormap for the GPS trace annotations.
layer_name (str): The name of the GPS trace layer.
fig_width (str, optional): The width of the figure. Defaults to "1550px".
fig_height (str, optional): The height of the figure. Defaults to "400px".
Returns:
Any: The main widget containing the map and the editing interface.
"""

from bqplot import LinearScale, Scatter, Figure, PanZoom
import bqplot as bq
from ipywidgets import VBox, Button

output_csv = os.path.join(
os.path.dirname(filename), filename.replace(".csv", "_annotated.csv")
)
output_geojson = output_csv.replace(".csv", ".geojson")

output = widgets.Output()

fig_margin = {"top": 20, "bottom": 35, "left": 50, "right": 20}
x_sc = LinearScale()
y_sc = LinearScale()

features = sorted(list(m.gps_trace.columns)[1:-3])
default_feature = "max_signal_strength"
default_index = features.index(default_feature)
feature = widgets.Dropdown(
options=features, index=default_index, description="Select feature"
)

column = feature.value
category_column = "annotation" # Replace with your categorical column name
x = m.gps_trace.index
y = m.gps_trace[column]

# Create scatter plots for each annotation category with the appropriate colors and labels
scatters = []
for cat, color in colormap.items():
if (
cat != "selected"
): # Exclude 'selected' from data points (only for highlighting selection)
mask = m.gps_trace[category_column] == cat
scatter = Scatter(
x=x[mask],
y=y[mask],
scales={"x": x_sc, "y": y_sc},
colors=[color],
marker="circle",
stroke="lightgray",
unselected_style={"opacity": 0.1},
selected_style={"opacity": 1.0},
default_size=48, # Set a smaller default marker size
display_legend=False,
labels=[cat], # Add the category label for the legend
)
scatters.append(scatter)

# Create the figure and add the scatter plots
fig = Figure(
marks=scatters,
fig_margin=fig_margin,
layout={"width": fig_width, "height": fig_height},
)
fig.axes = [
bq.Axis(scale=x_sc, label="Time"),
bq.Axis(scale=y_sc, orientation="vertical", label=column),
]

fig.legend_location = "top-right"

# Add LassoSelector interaction
selector = bq.interacts.LassoSelector(x_scale=x_sc, y_scale=y_sc, marks=scatters)
fig.interaction = selector

# Add PanZoom interaction for zooming and panning
panzoom = PanZoom(scales={"x": [x_sc], "y": [y_sc]})
fig.interaction = (
panzoom # Set PanZoom as the interaction to enable zooming initially
)

# Callback function to handle selected points with bounds check
def on_select(*args):
# output.clear_output()
with output:
selected_idx = []
for scatter in scatters:
selected_indices = scatter.selected
if selected_indices is not None:
selected_indices = [
int(i) for i in selected_indices if i < len(scatter.x)
] # Ensure integer indices
selected_x = scatter.x[selected_indices]
selected_y = scatter.y[selected_indices]
selected_idx += selected_x.tolist()
selected_idx = sorted(list(set(selected_idx)))
m.gdf.loc[selected_idx, "category"] = "selected"
m.set_data(layer_name, m.gdf.__geo_interface__)

# Register the callback for each scatter plot
for scatter in scatters:
scatter.observe(on_select, names=["selected"])

# Programmatic selection function based on common x values
def select_points_by_common_x(x_values):
"""
Select points based on a common list of x values across all categories.
"""
for scatter in scatters:
# Find indices of points in the scatter that match the given x values
selected_indices = [
i for i, x_val in enumerate(scatter.x) if x_val in x_values
]
scatter.selected = (
selected_indices # Highlight points at the specified indices
)

# Function to clear the lasso selection
def clear_selection(b):
for scatter in scatters:
scatter.selected = None # Clear selected points
fig.interaction = selector # Re-enable the LassoSelector

m.gdf["category"] = m.gdf["annotation"]
m.set_data(layer_name, m.gdf.__geo_interface__)

# Button to clear selection and switch between interactions
clear_button = Button(description="Clear Selection", button_style="primary")
clear_button.on_click(clear_selection)

# Toggle between LassoSelector and PanZoom interactions
def toggle_interaction(button):
if fig.interaction == selector:
fig.interaction = panzoom # Switch to PanZoom for zooming and panning
button.description = "Enable Lasso"
else:
fig.interaction = selector # Switch back to LassoSelector
button.description = "Enable Zoom/Pan"

toggle_button = Button(description="Enable Zoom/Pan", button_style="primary")
toggle_button.on_click(toggle_interaction)

def feature_change(change):
if change["new"]:
categories = m.gdf["annotation"].value_counts()
keys = list(colormap.keys())[:-1]
for index, cat in enumerate(keys):
mask = m.gdf["annotation"] == cat
scatters[index].x = m.gps_trace.index[mask]
scatters[index].y = m.gps_trace[feature.value][mask]
scatters[index].colors = [colormap[cat]] * categories[cat]
for scatter in scatters:
scatter.selected = None

feature.observe(feature_change, names="value")

def draw_change(lng_lat):
if lng_lat.new:
output.clear_output()
features = {
"type": "FeatureCollection",
"features": m.draw_features_selected,
}
m.gdf["category"] = m.gdf["annotation"]
gdf_draw = gpd.GeoDataFrame.from_features(features)
points_within_polygons = gpd.sjoin(
m.gdf, gdf_draw, how="left", predicate="within"
)
points_within_polygons.loc[
points_within_polygons["index_right"].notna(), "category"
] = "selected"
with output:
selected = points_within_polygons.loc[
points_within_polygons["category"] == "selected"
]
sel_idx = selected.index.tolist()
select_points_by_common_x(sel_idx)
m.set_data(layer_name, points_within_polygons.__geo_interface__)
if "index_right" in points_within_polygons.columns:
points_within_polygons = points_within_polygons.drop(
columns=["index_right"]
)
m.gdf = points_within_polygons
else:
for scatter in scatters:
scatter.selected = None # Clear selected points
fig.interaction = selector # Re-enable the LassoSelector

m.gdf["category"] = m.gdf["annotation"]
m.set_data(layer_name, m.gdf.__geo_interface__)

m.observe(draw_change, names="draw_features_selected")

widget = widgets.VBox([])
options = ["doorstep", "indoor", "outdoor", "parked"]
dropdown = widgets.Dropdown(options=options, value=None, description="annotation")
button_layout = widgets.Layout(width="97px")
save = widgets.Button(
description="Save", button_style="primary", layout=button_layout
)
export = widgets.Button(
description="Export", button_style="primary", layout=button_layout
)
reset = widgets.Button(
description="Reset", button_style="primary", layout=button_layout
)
widget.children = [feature, dropdown, widgets.HBox([save, export, reset]), output]

def on_save_click(b):
m.gdf.loc[m.gdf["category"] == "selected", "annotation"] = dropdown.value
m.gdf.loc[m.gdf["category"] == "selected", "category"] = dropdown.value
m.set_data(layer_name, m.gdf.__geo_interface__)
categories = m.gdf["annotation"].value_counts()
keys = list(colormap.keys())[:-1]
for index, cat in enumerate(keys):
mask = m.gdf["annotation"] == cat
scatters[index].x = m.gps_trace.index[mask]
scatters[index].y = m.gps_trace[feature.value][mask]
scatters[index].colors = [colormap[cat]] * categories[cat]
for scatter in scatters:
scatter.selected = None # Clear selected points
fig.interaction = selector # Re-enable the LassoSelector

m.gdf["category"] = m.gdf["annotation"]
m.set_data(layer_name, m.gdf.__geo_interface__)

save.on_click(on_save_click)

def on_export_click(b):
m.gps_trace["annotation"] = m.gdf["annotation"]
gdf = m.gps_trace.drop(columns=["category"])
gdf.to_file(output_geojson)
gdf.to_csv(output_csv, index=False)

export.on_click(on_export_click)

plot_widget = VBox([fig, widgets.HBox([clear_button, toggle_button])])

left_col_layout = v.Col(
cols=9, children=[m], class_="pa-1" # padding for consistent spacing
)
right_col_layout = v.Col(
cols=3,
children=[widget],
class_="pa-1", # padding for consistent spacing
)
row1 = v.Row(
class_="d-flex flex-wrap",
children=[left_col_layout, right_col_layout],
)
row2 = v.Row(
class_="d-flex flex-wrap",
children=[plot_widget],
)
main_widget = v.Col(children=[row1, row2])
return main_widget

0 comments on commit cf70cdf

Please sign in to comment.