jgrss/v170 #49
Conversation
src/cultionet/callbacks.py
Outdated
# `num_classes` includes background
'count': 3 + num_classes - 1,
'dtype': 'uint16',
'blockxsize': 64 if 64 < src.gw.ncols else src.gw.ncols,
suggestion: `min(64, src.gw.ncols)`
src/cultionet/callbacks.py
Outdated
'sharing': False,
'compress': compression
}
profile['tiled'] = True if max(profile['blockxsize'], profile['blockysize']) >= 16 else False
The `True if ... else False` part isn't strictly needed. If clarity is the goal, it's probably better to wrap this expression in a function:

def is_tiled(blockxsize, blockysize, tile_limit=16):
    return max(blockxsize, blockysize) >= tile_limit
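With a helper like that, the assignment above becomes (a sketch using the suggested `is_tiled`, which doesn't exist in the codebase yet):

profile['tiled'] = is_tiled(profile['blockxsize'], profile['blockysize'])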
)
rheight = pad_slice2d[0].stop - pad_slice2d[0].start
rwidth = pad_slice2d[1].stop - pad_slice2d[1].start
def reshaper(x: torch.Tensor, channel_dims: int) -> torch.Tensor:
might be good to introduce an autoformatter like black
good idea
src/cultionet/data/create.py
Outdated
train_data = joblib.load(train_path)
if train_data.train_id == train_id:
    batch_stored = True
aug_method = AugmenterMapping[aug.replace('-', '_')].value
The augmenter names should be consistent: just use `_` or `-` everywhere, or better yet use enums.
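For illustration, a minimal sketch of an enum-backed scheme (the member names here are hypothetical, not the project's actual augmenter list):

import enum

class AugmenterName(enum.Enum):
    # One canonical spelling per augmenter; no more '-' vs. '_' munging
    TS_WARP = 'ts_warp'
    GAUSSIAN_NOISE = 'gaussian_noise'

# Callers reference the enum member rather than a raw string
aug = AugmenterName.TS_WARP
print(aug.value)  # 'ts_warp'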
src/cultionet/data/create.py
Outdated
# Clip the edges to the current grid
try:
    grid_edges = gpd.clip(df_edges, row.geometry)
except:
You may want to explicitly catch topology errors; otherwise you may emit misleading warnings when you run into other kinds of errors.
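A sketch of the narrower handler, assuming the failures come from invalid geometries (shapely raises `TopologicalError` for those; anything else then propagates instead of being silently swallowed). `df_edges` and `row` are the variables from the snippet above:

import geopandas as gpd
from shapely.errors import TopologicalError

try:
    grid_edges = gpd.clip(df_edges, row.geometry)
except TopologicalError:
    # Only geometry/topology failures are treated as 'no edges here'
    grid_edges = None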
        window_pad
    ) for window, window_pad in window_chunk
)
pbar_total.update(len(window_chunk))


def create_dataset(
This thing is getting very long. Probably a good idea to find logical chunks to wrap in functions. See here for some guidelines on how to tell when functions are getting too long: https://stackoverflow.com/questions/475675/when-is-a-function-too-long.
qt = QuadTree(df_unique_locations, force_square=False)
qt.split_recursive(max_samples=1)
n_val = int(val_frac * len(df_unique_locations.index))
df_val_sample = qt.sample(n=n_val)
What does this do? Something regarding the spatial distribution of the validation set, but it's not totally clear to me.
It's the spatially-balanced splitting method (see https://github.com/jgrss/geosample). I've added comments on each step to help clarify this.
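For readers following along, a commented restatement of the snippet above (my reading of the geosample `QuadTree` API as used here):

# Build a quadtree over the unique sample locations
qt = QuadTree(df_unique_locations, force_square=False)
# Recursively split cells until each leaf holds at most one sample
qt.split_recursive(max_samples=1)
# Draw the validation fraction across quadtree leaves so the
# validation set is spatially balanced rather than clustered
n_val = int(val_frac * len(df_unique_locations.index))
df_val_sample = qt.sample(n=n_val)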
@@ -134,21 +248,61 @@ def tanimoto(y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
class TanimotoDistLoss(torch.nn.Module):
Same note as above, this probably needs tests.
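As a starting point, a minimal pytest-style sketch (this assumes `TanimotoDistLoss` is callable as `loss_fn(inputs, targets)` and that identical inputs yield a near-zero distance; the actual signature and import path may differ):

import torch

def test_tanimoto_dist_loss_identical_inputs():
    loss_fn = TanimotoDistLoss()  # import path assumed
    y = torch.rand(4, 3)
    # A distance-style loss should be (near-)zero for a perfect match
    assert float(loss_fn(y, y)) < 1e-4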
src/cultionet/model.py
Outdated
# train_ds, val_ds = dataset.split_train_val_by_partition(
#     spatial_partitions=spatial_partitions,
#     partition_column=partition_column,
#     val_frac=val_frac,
#     partition_name=partition_name
# )
can we remove this commented-out block?
src/cultionet/models/base_layers.py
Outdated
# assert dims in (2, 3)
# if dims == 2:
#     ones = torch.ones((1, channels, 1, 1))
# else:
#     ones = torch.ones((1, channels, 1, 1, 1))
delete?
import enum


class ModelTypes(enum.Enum):
This enum class might be useful here, or elsewhere in the codebase: https://docs.python.org/3/library/enum.html#enum.StrEnum
Works nicely when you want to map enum values to strings of their names.
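A quick sketch of that pattern (`enum.StrEnum` requires Python 3.11+; the member names here are hypothetical). On older Pythons, `class ModelTypes(str, enum.Enum)` gives similar behavior:

import enum

class ModelTypes(enum.StrEnum):
    # auto() sets each value to the lowercase member name
    UNET = enum.auto()     # 'unet'
    RESUNET = enum.auto()  # 'resunet'

# Members compare equal to plain strings
assert ModelTypes.UNET == 'unet'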
class SetActivation(torch.nn.Module):
    def __init__(
        self,
Nice small class! Now you can just run the following without needing `torch.nn.{activation_type}`:
SetActivation(activation_type, channels=out_channels, dims=2)
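A rough sketch of the pattern being praised, with a hypothetical name so it isn't mistaken for the project's actual implementation:

import torch

class NamedActivation(torch.nn.Module):
    """Resolve an activation by its torch.nn class name."""
    def __init__(self, activation_type: str):
        super().__init__()
        # e.g. 'SiLU' -> torch.nn.SiLU()
        self.activation = getattr(torch.nn, activation_type)()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(x)

act = NamedActivation('SiLU')
out = act(torch.randn(1, 8, 32, 32))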
def var(self, unbiased=True):
    mean = self.mean()[:, None]
    return self.integrate(
        lambda x: (x - mean).pow(2)
    ) / (self.count - (1 if unbiased else 0))

def std(self, unbiased=True):
    return self.var(unbiased=unbiased).sqrt()
This is interesting! It's similar to `torch.var`.
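For comparison, the built-in equivalent (`correction` is the PyTorch 2.x spelling; older releases use `unbiased=True/False`):

import torch

x = torch.randn(8, 100)
v_unbiased = x.var(dim=1, correction=1)  # divide by N - 1
v_biased = x.var(dim=1, correction=0)    # divide by N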
src/cultionet/scripts/cultionet.py
Outdated
if len(ts_list) <= 1:
    pbar.update(1)
    pbar.set_description('TS too short')
So hypothetically, if I only have `20210101.tif` and `20220101.tif` in my features, this function will continue?
def generate_model_graph(args): | ||
from cultionet.models.convstar import StarRNN | ||
from cultionet.models.nunet import ResUNet3Psi |
Was there a reason we import inside a function for this one? Is it because the imports are only relevant for this function and take a long time to import?
> Is it because the imports are only relevant for this function

This -- this function only serves the purpose of creating .onnx files for viewing graphs. It's called in isolation and there's no need for the imports elsewhere.
src/cultionet/scripts/cultionet.py
Outdated
with open(
    project_path / f"{args.process}_command_{now.strftime('%Y%m%d-%H%M')}.json", mode='w'
👍
    out_3_1=out_3_1,
    out_2_2=out_2_2,
    out_1_3=out_1_3
)

return out
This file does seem verbose, with a lot of repeated blocks. Attempting to reduce repetition could be the subject of a future PR.
@@ -0,0 +1,798 @@
"""
Lots of formatting stuff in this file. Linters will likely get it all.
Squashed commit messages:
* add flake8 to precommit
* add black and flake8 to pyproject.toml
* change flake8 repo
* add install test extras
* simplify checks
* black formatting
* created CONTRIBUTING file
* format
* format
* format
* sync names
* format
* format
* format
* remove unused function
* format
* moved line
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* format
* use StrEnum
* remove StrEnum
* add version comment
* format
* format
* fix: jgrss/refine (#58)
* format
* test
* add missing reshape
* remove edge temperature
* removed edge refine layer
* format
* format
* remove sigmoid
* remove temperature override
* increase lr
* fixed arg name
* add bash scripts
* update docstring
* fix: jgrss/refine (#59)
* format
* fix arg
* use all data for refinement
* add random sampler for refinement
* format
* format
* remove old arg
* format
This PR introduces changes toward v170.