Skip to content

Commit

Permalink
marhsmallow: remove deprecation warning
Browse files Browse the repository at this point in the history
  • Loading branch information
psaiz committed Apr 22, 2024
1 parent f8b00e1 commit 64664fc
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
5 changes: 3 additions & 2 deletions invenio_records_rest/loaders/marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from flask import request
from invenio_rest.errors import RESTValidationError
from marshmallow import ValidationError
from marshmallow import __version_info__ as marshmallow_version

from ..utils import marshmallow_major_version


def _flatten_marshmallow_errors(errors, parents=()):
Expand Down Expand Up @@ -81,7 +82,7 @@ def json_loader():
pid, record = pid_data.data
context["pid"] = pid
context["record"] = record
if marshmallow_version[0] < 3:
if marshmallow_major_version < 3:
result = schema_class(context=context).load(request_json)
if result.errors:
raise MarshmallowErrors(result.errors)
Expand Down
5 changes: 3 additions & 2 deletions invenio_records_rest/schemas/fields/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

import warnings

from marshmallow import __version_info__ as marshmallow_version
from marshmallow import missing as missing_

from invenio_records_rest.utils import marshmallow_major_version

from .marshmallow_contrib import Function, Method


Expand All @@ -25,7 +26,7 @@ class GeneratedValue(object):
class ForcedFieldDeserializeMixin(object):
"""Mixin that forces deserialization of marshmallow fields."""

if marshmallow_version[0] < 3:
if marshmallow_major_version < 3:

def __init__(self, *args, **kwargs):
"""Override the "missing" parameter."""
Expand Down
8 changes: 4 additions & 4 deletions invenio_records_rest/schemas/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from flask import current_app
from invenio_rest.serializer import BaseSchema as Schema
from marshmallow import ValidationError
from marshmallow import __version_info__ as marshmallow_version
from marshmallow import fields, missing, post_load, validates_schema
from marshmallow import ValidationError, fields, missing, post_load, validates_schema

from invenio_records_rest.schemas.fields import PersistentIdentifier

from ..utils import marshmallow_major_version


class StrictKeysMixin(Schema):
"""Ensure only valid keys exists."""
Expand Down Expand Up @@ -74,7 +74,7 @@ def load_unknown_fields(self, data, original_data):
return data


if marshmallow_version[0] < 3:
if marshmallow_major_version < 3:

class RecordMetadataSchemaJSONV1(OriginalKeysMixin):
"""Schema for records metadata v1 in JSON with injected PID value."""
Expand Down
5 changes: 5 additions & 0 deletions invenio_records_rest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from functools import partial

import pkg_resources
import six
from flask import abort, current_app, jsonify, make_response, request, url_for
from invenio_pidstore.errors import (
Expand All @@ -33,6 +34,10 @@
)
from .proxies import current_records_rest

marshmallow_major_version = int(
pkg_resources.get_distribution("marshmallow").version[0]
)


def build_default_endpoint_prefixes(records_rest_endpoints):
"""Build the default_endpoint_prefixes map."""
Expand Down
8 changes: 4 additions & 4 deletions tests/test_custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from invenio_pidstore.models import PersistentIdentifier as PIDModel
from invenio_records import Record
from invenio_rest.serializer import BaseSchema as Schema
from marshmallow import __version_info__ as marshmallow_version
from marshmallow import missing

from invenio_records_rest.schemas import StrictKeysMixin
Expand All @@ -25,8 +24,9 @@
SanitizedUnicode,
TrimmedString,
)
from invenio_records_rest.utils import marshmallow_major_version

if marshmallow_version[0] >= 3:
if marshmallow_major_version >= 3:
schema_to_use = Schema
from marshmallow import EXCLUDE
else:
Expand All @@ -36,7 +36,7 @@
class CustomFieldSchema(schema_to_use):
"""Test schema."""

if marshmallow_version[0] >= 3:
if marshmallow_major_version >= 3:

class Meta:
"""."""
Expand Down Expand Up @@ -107,7 +107,7 @@ def deserialize_func(value, ctx, data):
class GeneratedFieldsSchema(schema_to_use):
"""Test schema."""

if marshmallow_version[0] >= 3:
if marshmallow_major_version >= 3:

class Meta:
"""Meta attributes for the schema."""
Expand Down
7 changes: 3 additions & 4 deletions tests/test_marshmallow_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
from helpers import get_json
from invenio_records.models import RecordMetadata
from invenio_rest.serializer import BaseSchema as Schema
from marshmallow import ValidationError
from marshmallow import __version_info__ as marshmallow_version
from marshmallow import fields
from marshmallow import ValidationError, fields

from invenio_records_rest.loaders import json_pid_checker
from invenio_records_rest.loaders.marshmallow import (
Expand All @@ -25,6 +23,7 @@
)
from invenio_records_rest.schemas import Nested
from invenio_records_rest.schemas.fields import PersistentIdentifier
from invenio_records_rest.utils import marshmallow_major_version


class _TestSchema(Schema):
Expand Down Expand Up @@ -166,7 +165,7 @@ def has_error(field, parents):
def test_marshmallow_errors(test_data):
"""Test MarshmallowErrors class."""
incomplete_data = dict(test_data[0])
if marshmallow_version[0] >= 3:
if marshmallow_major_version >= 3:
try:
res = _TestSchema(context={}).load(json.dumps(incomplete_data))
except ValidationError as error:
Expand Down

0 comments on commit 64664fc

Please sign in to comment.