Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve test coverage for RegionPolygonFile.py #99

Merged
merged 17 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
/dist
/build
/documentation

*.log
*.log
.coverage
6 changes: 3 additions & 3 deletions fm2prof/CrossSection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pandas as pd
import scipy.optimize as so
from scipy.integrate import cumtrapz
from scipy.integrate import cumulative_trapezoid
from tqdm import tqdm

from fm2prof import Functions as FE
Expand Down Expand Up @@ -374,10 +374,10 @@ def get_timeseries(name: str):

# Compute 1D volume as integral of width with respect to z times length
self._css_total_volume = np.append(
[0], cumtrapz(self._css_total_width, self._css_z) * self.length
[0], cumulative_trapezoid(self._css_total_width, self._css_z) * self.length
)
self._css_flow_volume = np.append(
[0], cumtrapz(self._css_flow_width, self._css_z) * self.length
[0], cumulative_trapezoid(self._css_flow_width, self._css_z) * self.length
)

# If sd correction is run, these attributes will be updated.
Expand Down
81 changes: 53 additions & 28 deletions fm2prof/RegionPolygonFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
All rights reserved.
"""

import logging
import json
from collections import namedtuple
from pathlib import Path
from typing import Iterable
from typing import Iterable, Union, List
import rtree

import numpy as np
from shapely.geometry import Point, shape
Expand All @@ -39,7 +41,7 @@ class PolygonFile(FM2ProfBase):

def __init__(self, logger):
self.set_logger(logger)
self.polygons = list()
self._polygons = list()
self.undefined = -999

def classify_points_with_property(
Expand All @@ -58,7 +60,7 @@ def classify_points_with_property(
for i, point in enumerate(points):
for polygon in self.polygons:
if point.within(polygon.geometry):
points_regions[i] = int(polygon.properties.get(property_name))
points_regions[i] = polygon.properties.get(property_name)
break

return np.array(points_regions)
Expand Down Expand Up @@ -91,7 +93,7 @@ def classify_points_with_property_shapely_prep(
return np.array(points_regions)

def classify_points_with_property_rtree_by_polygons(
self, iterable_points: Iterable[list], property_name: str = "name"
self, points: Iterable[list], property_name: str = "name"
Tjalling-dejong marked this conversation as resolved.
Show resolved Hide resolved
) -> list:
"""Applies RTree index to quickly classify points in polygons.

Expand All @@ -110,7 +112,7 @@ def classify_points_with_property_rtree_by_polygons(
idx.insert(p_id, polygon.geometry.bounds, polygon)

point_properties_list = []
for point in map(Point, iterable_points):
for point in map(Point, points):
point_properties_polygon = next(
iter(
self.polygons[polygon_id].properties.get(property_name)
Expand All @@ -124,7 +126,9 @@ def classify_points_with_property_rtree_by_polygons(
del idx
return np.array(point_properties_list)

def __get_polygon_property(self, grouped_values: list, property_name: str) -> str:
def __get_polygon_property(
self, grouped_values: list, property_name: str
) -> str: # TODO: Can this be removed?
Tjalling-dejong marked this conversation as resolved.
Show resolved Hide resolved
"""Retrieves the polygon property from the internal list of polygons.

Arguments:
Expand All @@ -140,7 +144,7 @@ def __get_polygon_property(self, grouped_values: list, property_name: str) -> st
return self.undefined
return self.polygons[polygon_id].properties.get(property_name)

def parse_geojson_file(self, file_path):
def parse_geojson_file(self, file_path: Union[Path, str]) -> None:
"""Read data from geojson file"""
PolygonFile._validate_extension(file_path)

Expand All @@ -162,20 +166,18 @@ def parse_geojson_file(self, file_path):
# self.polygons[polygon_name] = polygon

@staticmethod
def _validate_extension(file_path: Path) -> None:
if not isinstance(file_path, Path):
return
if not file_path.suffix in (".json", ".geojson"):
raise IOError(
"Invalid file path extension, " + "should be .json or .geojson."
)
def _validate_extension(file_path: Union[Path, str]) -> None:
if isinstance(file_path, str):
file_path = Path(file_path)
if file_path.suffix not in (".json", ".geojson"):
raise IOError("Invalid file path extension, should be .json or .geojson.")

def _check_overlap(self):
for polygon in self.polygons:
for testpoly in self.polygons:
if polygon.properties.get("name") == testpoly.properties.get("name"):
# polygon will obviously overlap with itself
pass
continue
else:
if polygon.geometry.intersects(testpoly.geometry):
self.set_logger_message(
Expand All @@ -186,36 +188,57 @@ def _check_overlap(self):
level="warning",
)

@property
def polygons(self) -> list[Polygon]:
return self._polygons

@polygons.setter
def polygons(self, polygons_list: List[Polygon]) -> None:
if not all([isinstance(polygon, Polygon) for polygon in polygons_list]):
raise ValueError("Polygons must be of type Polygon")
# Check if properties contain the required 'name' property
names = [polygon.properties.get("name") for polygon in polygons_list]
if not all(names):
raise ValueError("Polygon properties must contain key-word 'name'")
# Check if 'name' property is unique, otherwise _check_overlap will produce bugs
if len(names) != len(set(names)):
raise ValueError("Property 'name' must be unique")
self._polygons = polygons_list


class RegionPolygonFile(PolygonFile):
def __init__(self, region_file_path, logger):
super().__init__(logger)
self.read_region_file(region_file_path)

@property
def regions(self):
def regions(self) -> list[Polygon]:
Tjalling-dejong marked this conversation as resolved.
Show resolved Hide resolved
return self.polygons

def read_region_file(self, file_path):
def read_region_file(self, file_path) -> None:
self.parse_geojson_file(file_path)
self._validate_regions()

def _validate_regions(self):
self.set_logger_message("Validating Region file")
def _validate_regions(self) -> None:
self.set_logger_message("Validating region file", level="info")

number_of_regions = len(self.regions)

self.set_logger_message("{} regions found".format(number_of_regions))
self.set_logger_message(
"{} regions found".format(number_of_regions), level="info"
)

# Test if polygons overlap
self._check_overlap()

def classify_points(self, points: Iterable[list]):
return self.classify_points_with_property(points, property_name="id")
def classify_points(
self, points: Iterable[list], property_name: str = "id"
Tjalling-dejong marked this conversation as resolved.
Show resolved Hide resolved
) -> list:
return self.classify_points_with_property(points, property_name=property_name)


class SectionPolygonFile(PolygonFile):
def __init__(self, section_file_path, logger):
def __init__(self, section_file_path, logger: logging.Logger):
super().__init__(logger)
self.read_section_file(section_file_path)
self.undefined = 1 # 1 is main
Expand Down Expand Up @@ -252,8 +275,11 @@ def _validate_sections(self):
),
level="error",
)
section_key = str(section.properties.get("section")).lower()
if section_key not in valid_section_keys:

elif (
str(section.properties.get("section")).lower() not in valid_section_keys
):
section_key = str(section.properties.get("section")).lower()
if section_key not in list(map_section_keys.keys()):
raise_exception = True
self.set_logger_message(
Expand All @@ -272,6 +298,5 @@ def _validate_sections(self):
self._check_overlap()

if raise_exception:
raise AssertionError("Section file could not validated")
else:
self.set_logger_message("Section file succesfully validated")
raise AssertionError("Section file is not valid")
self.set_logger_message("Section file succesfully validated")
Loading
Loading