Aws s3 uri #193

Merged · 16 commits · May 20, 2023
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -144,3 +144,5 @@ deepparse/version.py
*.ckpt

*mlruns/

*model/
1 change: 1 addition & 0 deletions .release/bpemb.version
@@ -0,0 +1 @@
aa32fa918494b461202157c57734c374
1 change: 1 addition & 0 deletions .release/bpemb_attention.version
@@ -0,0 +1 @@
cfb190902476376573591c0ec6f91ece
1 change: 1 addition & 0 deletions .release/fasttext.version
@@ -0,0 +1 @@
f67a0517c70a314bdde0b8440f21139d
1 change: 1 addition & 0 deletions .release/fasttext_attention.version
@@ -0,0 +1 @@
a2b688bdfa2aa7c009bb7d980e352978
5 changes: 5 additions & 0 deletions .release/model_version_release.md
@@ -0,0 +1,5 @@
# How to Create a New Model's Version

1. `md5sum <model.ckpt> > model.version`
2. Remove the `model.ckpt` file name text from the `model.version` file
3. Update the latest BPEmb and FastText hashes in `tests/test_tools.py`
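For illustration, a rough Python equivalent of steps 1 and 2, assuming a local `model.ckpt` checkpoint (the file names are placeholders): `md5sum` prints the hash followed by the file name, which is why step 2 strips that text.

.. code-block:: python

    # Sketch of steps 1-2: hash the checkpoint and write only the hex digest,
    # which is what the .version files above contain.
    import hashlib

    with open("model.ckpt", "rb") as checkpoint_file:
        digest = hashlib.md5(checkpoint_file.read()).hexdigest()

    with open("model.version", "w", encoding="utf-8") as version_file:
        version_file.write(digest)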
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -314,3 +314,9 @@
increases the performance by about 1/100.

## dev

- New models released with more metadata
- Add a feature to use an AddressParser from a URI
- Add a feature to upload the trained model to a URI
- Add an example of how to parse from, and upload to, a URI (see the sketch below)
- Improve error handling of `path_to_retrain_model`
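As a hedged sketch of the upload feature listed above (the parse-from-a-URI side is shown in the `AddressParser` docstring examples further down): it assumes an already built deepparse dataset container named `training_container`, a placeholder bucket path, and an arbitrary `epochs` value. Per the `retrain` docstring in this PR, only the best checkpoint is saved to the S3 bucket at the end of training.

.. code-block:: python

    # Sketch only: retrain and push the resulting checkpoint to an S3-like bucket.
    # training_container is assumed to be a deepparse dataset container built elsewhere.
    from deepparse.parser import AddressParser

    address_parser = AddressParser(model_type="fasttext")
    address_parser.retrain(
        training_container,
        epochs=1,
        logging_path="s3://my-bucket/deepparse-checkpoints/",
    )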
2 changes: 1 addition & 1 deletion deepparse/__init__.py
@@ -2,4 +2,4 @@
from .fasttext_tools import *
from .tools import *
from .version import __version__
from .weights_init import *
from .weights_tools import *
2 changes: 1 addition & 1 deletion deepparse/cli/parse.py
@@ -50,7 +50,7 @@ def main(args=None) -> None:

.. code-block:: sh

parse fasttext ./dataset.csv parsed_address.pckl --path_to_retrained_model ./path
parse fasttext ./dataset.csv parsed_address.pckl --path_to_model_weights ./path

"""
if args is None: # pragma: no cover
2 changes: 1 addition & 1 deletion deepparse/cli/parser_arguments_adder.py
@@ -108,7 +108,7 @@ def add_batch_size_arg(parser: ArgumentParser) -> None:
def add_path_to_retrained_model_arg(parser: ArgumentParser) -> None:
parser.add_argument(
"--path_to_retrained_model",
help=wrap("A path to a retrained model to use for testing."),
help=wrap("A path to a retrained model to use. It can be an S3-URI."),
type=str,
default=None,
)
2 changes: 1 addition & 1 deletion deepparse/network/decoder.py
@@ -6,7 +6,7 @@
import torch
from torch import nn

from ..weights_init import weights_init
from .. import weights_init


class Decoder(nn.Module):
2 changes: 1 addition & 1 deletion deepparse/network/encoder.py
@@ -7,7 +7,7 @@
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from ..weights_init import weights_init
from .. import weights_init


class Encoder(nn.Module):
19 changes: 10 additions & 9 deletions deepparse/network/seq2seq.py
@@ -4,14 +4,14 @@
import random
import warnings
from abc import ABC
from collections import OrderedDict
from typing import Tuple, Union, List

import torch
from torch import nn

from .decoder import Decoder
from .encoder import Encoder
from .. import handle_weights_upload
from ..tools import download_weights, latest_version


@@ -113,20 +113,21 @@ def _load_pre_trained_weights(self, model_type: str, cache_dir: str, offline: bo
)
download_weights(model_type, cache_dir, verbose=self.verbose)

all_layers_params = torch.load(model_path, map_location=self.device)
self.load_state_dict(all_layers_params)
self._load_weights(path_to_model_torch_archive=model_path)

def _load_weights(self, path_to_retrained_model: str) -> None:
def _load_weights(self, path_to_model_torch_archive: str) -> None:
"""
Method to load (into the network) the weights.

Args:
path_to_retrained_model (str): The path to the fine-tuned model.
path_to_model_torch_archive (str): The path to the fine-tuned model Torch archive.
"""
all_layers_params = torch.load(path_to_retrained_model, map_location=self.device)
if isinstance(all_layers_params, dict) and not isinstance(all_layers_params, OrderedDict):
# Case where we have a retrained model with a different tagging space
all_layers_params = all_layers_params.get("address_tagger_model")
all_layers_params = handle_weights_upload(
path_to_model_to_upload=path_to_model_torch_archive, device=self.device
)

# Our torch archive always includes metadata along with the model weights
all_layers_params = all_layers_params.get("address_tagger_model")
self.load_state_dict(all_layers_params)

def _encoder_step(self, to_predict: torch.Tensor, lengths: List, batch_size: int) -> Tuple:
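The body of `handle_weights_upload` used in `_load_weights` above is not part of this diff (it lives in `deepparse/weights_tools.py`). Based only on how it is called here, a plausible sketch under those assumptions is:

.. code-block:: python

    # Assumed behaviour, inferred from the call sites in this PR: accept a local path,
    # an "s3://..." URI string, or a cloudpathlib S3Path, and return the loaded archive.
    import torch
    from cloudpathlib import CloudPath, S3Path


    def handle_weights_upload_sketch(path_to_model_to_upload, device="cpu"):
        if isinstance(path_to_model_to_upload, S3Path):
            # Already a cloud path object: stream the archive from the bucket.
            with path_to_model_to_upload.open("rb") as file:
                return torch.load(file, map_location=device)
        if isinstance(path_to_model_to_upload, str) and path_to_model_to_upload.startswith("s3://"):
            # A bucket URI string: wrap it with cloudpathlib, then stream it.
            with CloudPath(path_to_model_to_upload).open("rb") as file:
                return torch.load(file, map_location=device)
        # Otherwise, treat it as a local file path.
        return torch.load(path_to_model_to_upload, map_location=device)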
77 changes: 65 additions & 12 deletions deepparse/parser/address_parser.py
Expand Up @@ -14,6 +14,7 @@
from typing import Dict, List, Tuple, Union, Callable

import torch
from cloudpathlib import CloudPath, S3Path
from poutyne.framework import Experiment
from torch.optim import SGD
from torch.utils.data import DataLoader, Subset
@@ -43,6 +44,7 @@
from ..pre_processing import trailing_whitespace_cleaning, double_whitespaces_cleaning
from ..tools import CACHE_PATH, valid_poutyne_version
from ..vectorizer import VectorizerFactory
from ..weights_tools import handle_weights_upload

_pre_trained_tags_to_idx = {
"StreetNumber": 0,
@@ -86,7 +88,7 @@ class AddressParser:
- ``"lightest"`` (the one using the less RAM and GPU usage) (equivalent to ``"fasttext-light"``),
- ``"best"`` (the best accuracy performance) (equivalent to ``"bpemb"``).

The default value is ``"best"`` for the most accurate model. Ignored if ``path_to_retrained_model`` is not
The default value is ``"best"`` for the most accurate model. Ignored if ``path_to_model_weights`` is not
``None``. To further improve performance, consider using the models (fasttext or BPEmb) with their
counterparts using an attention mechanism with the ``attention_mechanism`` flag.
attention_mechanism (bool): Whether to use the model with an attention mechanism. The model will use an
@@ -102,10 +104,13 @@ class AddressParser:
The default value is GPU with the index ``0`` if it exists. Otherwise, the value is ``CPU``.
rounding (int): The rounding to use when asking the probability of the tags. The default value is four digits.
verbose (bool): Turn on/off the verbosity of the model weights download and loading. The default value is True.
path_to_retrained_model (Union[str, None]): The path to the retrained model to use for prediction. We will
infer the ``model_type`` of the retrained model. The default value is ``None``, meaning we use our
path_to_retrained_model (Union[S3Path, str, None]): The path to the retrained model to use for prediction.
We will infer the ``model_type`` of the retrained model. The default value is ``None``, meaning we use our
pretrained model. If the retrained model uses an attention mechanism, ``attention_mechanism`` needs to
be set to True.
be set to True. The path_to_retrained_model can also be an S3-like (Azure, AWS, Google) bucket URI string path
(e.g. ``"s3://path/to/aws/s3/bucket.ckpt"``). It can also be an ``S3Path`` S3-like URI using `cloudpathlib`
to handle the S3-like bucket. See `cloudpathlib <https://cloudpathlib.drivendata.org/stable/>`
for details on supported S3 bucket providers and URI conditions. The default value is None.
cache_dir (Union[str, None]): The path to the cached directory to use for downloading (and loading) the
embeddings model and the model pretrained weights.
offline (bool): Whether or not the model is an offline one, meaning you have already downloaded the pre-trained
@@ -164,23 +169,23 @@ class AddressParser:
.. code-block:: python

address_parser = AddressParser(model_type="fasttext",
path_to_retrained_model="/path_to_a_retrain_fasttext_model.ckpt")
path_to_model_weights="/path_to_a_retrain_fasttext_model.ckpt")
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")

Using a retrained model trained on different tags

.. code-block:: python

# We don't give the model_type since it's ignored when using path_to_retrained_model
address_parser = AddressParser(path_to_retrained_model="/path_to_a_retrain_fasttext_model.ckpt")
# We don't give the model_type since it's ignored when using path_to_model_weights
address_parser = AddressParser(path_to_model_weights="/path_to_a_retrain_fasttext_model.ckpt")
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")

Using a retrained model with attention

.. code-block:: python

address_parser = AddressParser(model_type="fasttext",
path_to_retrained_model="/path_to_a_retrain_fasttext_attention_model.ckpt",
path_to_model_weights="/path_to_a_retrain_fasttext_attention_model.ckpt",
attention_mechanism=True)
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")

@@ -193,6 +198,21 @@ class AddressParser:
offline=True)
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")

Using a retrained model in an S3-like bucket.

.. code-block:: python

address_parser = AddressParser(model_type="fasttext",
path_to_model_weights="s3://path/to/bucket.ckpt")
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")

Using a retrained model in an S3-like bucket using CloudPathLib.

.. code-block:: python

address_parser = AddressParser(model_type="fasttext",
path_to_model_weights=CloudPath("s3://path/to/bucket.ckpt"))
parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
"""

def __init__(
@@ -202,7 +222,7 @@
device: Union[int, str, torch.device] = 0,
rounding: int = 4,
verbose: bool = True,
path_to_retrained_model: Union[str, None] = None,
path_to_retrained_model: Union[S3Path, str, None] = None,
cache_dir: Union[str, None] = None,
offline: bool = False,
) -> None:
@@ -222,7 +242,7 @@ def __init__(
seq2seq_kwargs = {} # Empty for default settings

if path_to_retrained_model is not None:
checkpoint_weights = torch.load(path_to_retrained_model, map_location="cpu")
checkpoint_weights = handle_weights_upload(path_to_model_to_upload=path_to_retrained_model)
if checkpoint_weights.get("model_type") is None:
# Validate that we have the proper metadata; it includes at least the parser model type
# even if nothing else has been modified.
@@ -237,6 +257,7 @@ def __init__(
"See AddressParser.retrain for more details."
)
raise RuntimeError(error_text)

if validate_if_new_seq2seq_params(checkpoint_weights):
seq2seq_kwargs = checkpoint_weights.get("seq2seq_params")
if validate_if_new_prediction_tags(checkpoint_weights):
@@ -501,6 +522,12 @@ def retrain(
logging_path (str): The logging path for the checkpoints. Poutyne will use the best one and reload the
state if any checkpoints are there. Thus, an error will be raised if you change the model type.
For example, you retrain a FastText model and then retrain a BPEmb in the same logging path directory.
The logging_path can also be an S3-like (Azure, AWS, Google) bucket URI string path
(e.g. ``"s3://path/to/aws/s3/bucket.ckpt"``). It can also be an ``S3Path`` S3-like URI using `cloudpathlib`
to handle the S3-like bucket. See `cloudpathlib <https://cloudpathlib.drivendata.org/stable/>`
for details on supported S3 bucket providers and URI conditions.
If the logging_path is an S3 bucket, we only save the best checkpoint to the S3 bucket at the end
of training.
By default, the path is ``./checkpoints``.
disable_tensorboard (bool): To disable Poutyne automatic Tensorboard monitoring. By default, we disable them
(true).
@@ -801,6 +828,7 @@ def retrain(
else f"retrained_{self.model_type}_address_parser.ckpt"
)
file_path = os.path.join(logging_path, file_name)

torch_save = {
"address_tagger_model": exp.model.network.state_dict(),
"model_type": self.model_type,
@@ -821,7 +849,29 @@
}
)

torch.save(torch_save, file_path)
if isinstance(file_path, S3Path):
# Handle an S3-like CloudPath (e.g. S3Path) logging path
try:
with file_path.open("wb") as file:
torch.save(torch_save, file)
except FileNotFoundError as error:
raise FileNotFoundError("The file in the S3 bucket was not found.") from error

elif "s3://" in file_path:
file_path = CloudPath(file_path)
try:
with file_path.open("wb") as file:
torch.save(torch_save, file)
except FileNotFoundError as error:
raise FileNotFoundError("The file in the S3 bucket was not found.") from error
else:
try:
torch.save(torch_save, file_path)
except FileNotFoundError as error:
if "s3" in file_path or "//" in file_path or ":" in file_path:
raise FileNotFoundError(
"Are You trying to use a AWS S3 URI? If so path need to start with s3://."
) from error
return train_res

def test(
@@ -1254,9 +1304,12 @@ def _apply_pre_processors(self, addresses: List[str]) -> List[str]:
res = []

for address in addresses:
processed_address = address

for pre_processor in self.pre_processors:
processed_address = pre_processor(processed_address)
res.append(" ".join(processed_address.split()))

res.append(" ".join(processed_address.split()))
return res

def is_same_model_type(self, other) -> bool:
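The `_apply_pre_processors` change above initializes `processed_address` once per address and appends a single cleaned result after all pre-processors have run. A small illustrative run of that chaining, using the cleaners imported earlier in this file (the address string and the chosen order of cleaners are made up for the example):

.. code-block:: python

    # Each pre-processor receives the output of the previous one; the final
    # result is whitespace-normalized, mirroring the loop above.
    from deepparse.pre_processing import (
        double_whitespaces_cleaning,
        trailing_whitespace_cleaning,
    )

    processed_address = "350  rue des Lilas Ouest Quebec city Quebec G1L 1B6  "
    for pre_processor in (trailing_whitespace_cleaning, double_whitespaces_cleaning):
        processed_address = pre_processor(processed_address)

    print(" ".join(processed_address.split()))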
7 changes: 5 additions & 2 deletions deepparse/parser/tools.py
@@ -1,7 +1,7 @@
import math
import os
from typing import List, OrderedDict, Tuple

import math
import numpy as np
import torch

@@ -134,7 +134,10 @@ def infer_model_type(checkpoint_weights: OrderedDict, attention_mechanism: bool)
else:
model_type = "fasttext"

if "decoder.linear_attention_mechanism_encoder_outputs.weight" in checkpoint_weights.keys():
if (
"decoder.linear_attention_mechanism_encoder_outputs.weight"
in checkpoint_weights.get("address_tagger_model").keys()
):
attention_mechanism = True

return model_type, attention_mechanism
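The updated check above looks for the attention layer inside the nested ``"address_tagger_model"`` entry, matching the archive layout saved by `retrain` (see the `torch_save` dict earlier in this PR). An illustration of that layout, with a placeholder tensor value (the key names come from this diff; everything else is assumed):

.. code-block:: python

    # Illustration only: the checkpoint nests the state dict under
    # "address_tagger_model", next to the new metadata such as "model_type".
    import torch

    checkpoint_weights = {
        "model_type": "fasttext",
        "address_tagger_model": {
            "decoder.linear_attention_mechanism_encoder_outputs.weight": torch.zeros(1),
        },
    }

    attention_mechanism = (
        "decoder.linear_attention_mechanism_encoder_outputs.weight"
        in checkpoint_weights.get("address_tagger_model").keys()
    )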
21 changes: 0 additions & 21 deletions deepparse/weights_init.py

This file was deleted.
