From 6211f46bf908aa793bc502845afbf469f34e6bb3 Mon Sep 17 00:00:00 2001 From: Brax Team Date: Thu, 12 Jan 2023 16:32:32 -0500 Subject: [PATCH] Internal change PiperOrigin-RevId: 501652739 Change-Id: I5232dd2c1dbc5a3b204d076be8f3d3fe51abf2b1 --- brax/tools/urdf.py | 6 +- brax/v2/base.py | 17 +- brax/v2/generalized/base.py | 87 +++++ brax/v2/generalized/constraint.py | 97 ++--- brax/v2/generalized/dynamics.py | 84 ++-- brax/v2/generalized/mass.py | 39 +- brax/v2/generalized/pipeline.py | 134 ++----- brax/v2/geometry/contact.py | 111 +++++- brax/v2/geometry/contact_test.py | 91 +++++ brax/v2/geometry/math.py | 366 +++++++++++++++++- brax/v2/geometry/math_test.py | 168 +++++++- brax/v2/geometry/mesh.py | 219 +++++++++-- brax/v2/geometry/mesh_test.py | 126 +++++- brax/v2/io/html.py | 16 +- brax/v2/io/json.py | 60 ++- brax/v2/io/json_test.py | 20 +- brax/v2/io/mjcf.py | 144 +++++-- brax/v2/io/mjcf_test.py | 4 +- brax/v2/math.py | 3 +- brax/v2/spring/pipeline.py | 15 +- brax/v2/test_data/convex_convex.xml | 38 ++ brax/v2/test_data/meshes/cylinder.stl | Bin 0 -> 20884 bytes brax/v2/test_data/meshes/dodecahedron.stl | Bin 0 -> 1884 bytes brax/v2/test_data/meshes/pyramid.stl | Bin 0 -> 384 bytes brax/v2/test_data/meshes/tetrahedron.stl | Bin 0 -> 284 bytes brax/v2/test_data/ur5e/meshes/base.stl | Bin 21084 -> 0 bytes brax/v2/test_data/ur5e/meshes/base_vis.stl | Bin 240784 -> 0 bytes brax/v2/test_data/ur5e/meshes/forearm.stl | Bin 53284 -> 0 bytes brax/v2/test_data/ur5e/meshes/forearm_vis.stl | Bin 648934 -> 0 bytes brax/v2/test_data/ur5e/meshes/shoulder.stl | Bin 70084 -> 0 bytes .../v2/test_data/ur5e/meshes/shoulder_vis.stl | Bin 1056884 -> 0 bytes brax/v2/test_data/ur5e/meshes/upperarm.stl | Bin 99684 -> 0 bytes .../v2/test_data/ur5e/meshes/upperarm_vis.stl | Bin 1706034 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist1.stl | Bin 59584 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist1_vis.stl | Bin 806284 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist2.stl | Bin 67584 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist2_vis.stl | Bin 946784 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist3.stl | Bin 7184 -> 0 bytes brax/v2/test_data/ur5e/meshes/wrist3_vis.stl | Bin 45634 -> 0 bytes brax/v2/test_data/ur5e/robot.xml | 79 ---- brax/v2/visualizer/js/system.js | 144 +++---- brax/v2/visualizer/js/viewer.js | 58 ++- brax/v2/visualizer/visualizer.py | 6 +- notebooks/Brax_v2_Training_Preview.ipynb | 7 +- setup.py | 1 + 45 files changed, 1581 insertions(+), 559 deletions(-) create mode 100644 brax/v2/generalized/base.py create mode 100644 brax/v2/test_data/convex_convex.xml create mode 100644 brax/v2/test_data/meshes/cylinder.stl create mode 100644 brax/v2/test_data/meshes/dodecahedron.stl create mode 100644 brax/v2/test_data/meshes/pyramid.stl create mode 100644 brax/v2/test_data/meshes/tetrahedron.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/base.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/base_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/forearm.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/forearm_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/shoulder.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/shoulder_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/upperarm.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/upperarm_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist1.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist1_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist2.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist2_vis.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist3.stl delete mode 100644 brax/v2/test_data/ur5e/meshes/wrist3_vis.stl delete mode 100644 brax/v2/test_data/ur5e/robot.xml diff --git a/brax/tools/urdf.py b/brax/tools/urdf.py index 503b1267..d641d9ad 100644 --- a/brax/tools/urdf.py +++ b/brax/tools/urdf.py @@ -254,11 +254,11 @@ def expand_node(self, else: warnings.warn(f'No collider found on link {node}.') - - if inertia.find("mass") is None: + + if inertia.find('mass') is None: body.mass += 1. else: - body.mass = float(inertia.find("mass").get("value")) + body.mass = float(inertia.find('mass').get('value')) # TODO: load inertia body.inertia.x, body.inertia.y, body.inertia.z = 1., 1., 1. diff --git a/brax/v2/base.py b/brax/v2/base.py index 010dc537..d707d898 100644 --- a/brax/v2/base.py +++ b/brax/v2/base.py @@ -349,12 +349,23 @@ class Mesh(Geometry): The mesh is expected to be in the counter-clockwise winding order. Attributes: - face: (num_faces, num_face_vertices) vertices associated with each face vert: (num_verts, 3) spatial coordinates associated with each vertex + face: (num_faces, num_face_vertices) vertices associated with each face """ - face: jp.ndarray vert: jp.ndarray + face: jp.ndarray + + +@struct.dataclass +class Convex(Mesh): + """A convex mesh geometry. + + Attributes: + unique_edge: (num_unique, 2) vert index associated with each unique edge + """ + + unique_edge: jp.ndarray @struct.dataclass @@ -405,12 +416,14 @@ class State: qd: (qd_size,) joint velocity vector x: (num_links,) link position in world frame xd: (num_links,) link velocity in world frame + contact: calculated contacts """ q: jp.ndarray qd: jp.ndarray x: jp.ndarray xd: jp.ndarray + contact: Optional[Contact] @struct.dataclass diff --git a/brax/v2/generalized/base.py b/brax/v2/generalized/base.py new file mode 100644 index 00000000..000ae9ba --- /dev/null +++ b/brax/v2/generalized/base.py @@ -0,0 +1,87 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=g-multiple-import +"""Base types for generalized pipeline.""" + +from brax.v2 import base +from brax.v2.base import Inertia, Motion, System, Transform +from flax import struct +from jax import numpy as jp + + +@struct.dataclass +class State(base.State): + """Dynamic state that changes after every step. + + Attributes: + com: center of mass position + cinr: inertia in com frame + cd: body velocities in com frame + cdof: dofs in com frame + cdofd: cdof velocity + mass_mx: (qd_size, qd_size) mass matrix + mass_mx_inv: (qd_size, qd_size) inverse mass matrix + contact: calculated contacts + con_jac: constraint jacobian + con_pos: constraint position + con_diag: constraint A diagonal + qf_smooth: smooth dynamics force + qf_constraint: (qd_size,) force from constraints (collision etc) + qdd: (qd_size,) joint acceleration vector + """ + + # position/velocity based terms are updated at the end of each step: + com: jp.ndarray + cinr: Inertia + cd: Motion + cdof: Motion + cdofd: Motion + mass_mx: jp.ndarray + mass_mx_inv: jp.ndarray + con_jac: jp.ndarray + con_pos: jp.ndarray + con_diag: jp.ndarray + # acceleration based terms are calculated using terms from the previous step: + qf_smooth: jp.ndarray + qf_constraint: jp.ndarray + qdd: jp.ndarray + + @classmethod + def zero(cls, sys: System) -> 'State': + """Returns an initial State given a brax system.""" + return State( + q=jp.zeros(sys.q_size()), + qd=jp.zeros(sys.qd_size()), + x=Transform.zero((sys.num_links(),)), + xd=Motion.zero((sys.num_links(),)), + contact=None, + com=jp.zeros((sys.num_links(), 3)), + cinr=Inertia( + Transform.zero((sys.num_links(),)), + jp.zeros((sys.num_links(), 3, 3)), + jp.zeros((sys.num_links(),)), + ), + cd=Motion.zero((sys.num_links(),)), + cdof=Motion.zero((sys.num_links(),)), + cdofd=Motion.zero((sys.num_links(),)), + mass_mx=jp.zeros((sys.num_links(), sys.num_links())), + mass_mx_inv=jp.zeros((sys.num_links(), sys.num_links())), + con_jac=jp.zeros(()), + con_pos=jp.zeros(()), + con_diag=jp.zeros(()), + qf_smooth=jp.zeros((sys.qd_size(),)), + qf_constraint=jp.zeros((sys.qd_size(),)), + qdd=jp.zeros(sys.qd_size()), + ) diff --git a/brax/v2/generalized/constraint.py b/brax/v2/generalized/constraint.py index 26ab92e2..b2b05217 100644 --- a/brax/v2/generalized/constraint.py +++ b/brax/v2/generalized/constraint.py @@ -14,15 +14,15 @@ # pylint:disable=g-multiple-import """Functions for constraint satisfaction.""" -from typing import Optional, Tuple +from typing import Tuple from brax.v2 import math from brax.v2 import scan -from brax.v2.base import Contact, Motion, System, Transform +from brax.v2.base import Motion, System, Transform +from brax.v2.generalized.base import State import jax from jax import numpy as jp import jaxopt -from jaxopt.projection import projection_non_negative def _pt_jac( @@ -92,13 +92,13 @@ def _imp_aref(pos: jp.ndarray, vel: jp.ndarray) -> jp.ndarray: def jac_limit( - sys: System, q: jp.ndarray + sys: System, state: State ) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]: """Calculates the jacobian for angle limits in dof frame. Args: sys: the brax system - q: joint angle vector + state: generalized state Returns: jac: the angle limit jacobian @@ -111,8 +111,8 @@ def jac_limit( # determine q and qd indices for non-free joints q_idx, qd_idx = sys.q_idx('123'), sys.qd_idx('123') - pos_min = q[q_idx] - sys.dof.limit[0][qd_idx] - pos_max = sys.dof.limit[1][qd_idx] - q[q_idx] + pos_min = state.q[q_idx] - sys.dof.limit[0][qd_idx] + pos_max = sys.dof.limit[1][qd_idx] - state.q[q_idx] pos = jp.minimum(jp.minimum(pos_min, pos_max), 0) side = ((pos_min < pos_max) * 2 - 1) * (pos < 0) @@ -123,28 +123,26 @@ def jac_limit( def jac_contact( - sys: System, com: jp.ndarray, cdof: Motion, contact: Optional[Contact] + sys: System, state: State ) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]: """Calculates the jacobian for contact constraints. Args: sys: the brax system - com: center of mass position - cdof: dofs in com frame - contact: contacts computed for link geometries + state: generalized state Returns: jac: the contact jacobian pos: contact position in constraint frame diag: approximate diagonal of A matrix """ - if contact is None: + if state.contact is None: return jp.zeros((0, sys.qd_size())), jp.zeros((0,)), jp.zeros((0,)) def row_fn(contact): link_a, link_b = contact.link_idx - a = _pt_jac(sys, com, cdof, contact.pos, link_a) - b = _pt_jac(sys, com, cdof, contact.pos, link_b) + a = _pt_jac(sys, state.com, state.cdof, contact.pos, link_a) + b = _pt_jac(sys, state.com, state.cdof, contact.pos, link_b) diff = b - a # 4 pyramidal friction directions @@ -163,66 +161,42 @@ def row_fn(contact): lambda x: x * (contact.penetration > 0), (jac, pos, diag) ) - return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(contact)) + return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(state.contact)) -def jacobian( - sys: System, - q: jp.ndarray, - com: jp.ndarray, - cdof: Motion, - contact: Optional[Contact], -) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]: - """Calculates the full constraint jacobian and constraint position. +def jacobian(sys: System, state: State) -> State: + """Calculates the constraint jacobian, position, and A matrix diagonal. Args: sys: a brax system - q: joint position vector - com: center of mass position - cdof: dofs in com frame - contact: contacts computed for link geometries + state: generalized state Returns: - jac: the constraint jacobian - pos: position in constraint frame - diag: approximate diagonal of A matrix + state: generalized state with jac, pos, diag updated """ - jpds = jac_contact(sys, com, cdof, contact), jac_limit(sys, q) + jpds = jac_contact(sys, state), jac_limit(sys, state) + jac, pos, diag = jax.tree_map(lambda *x: jp.concatenate(x), *jpds) + return state.replace(con_jac=jac, con_pos=pos, con_diag=diag) - return jax.tree_map(lambda *x: jp.concatenate(x), *jpds) - -def force( - sys: System, - qd: jp.ndarray, - qf_smooth: jp.ndarray, - mass_mx_inv: jp.ndarray, - jac: jp.ndarray, - pos: jp.ndarray, - diag: jp.ndarray, -) -> jp.ndarray: +def force(sys: System, state: State) -> jp.ndarray: """Calculates forces that satisfy joint, collision constraints. Args: sys: a brax system - qd: joint velocity vector - qf_smooth: joint force vector for smooth dynamics - mass_mx_inv: inverse mass matrix - jac: the constraint jacobian - pos: position in constraint frame - diag: approximate diagonal of A matrix + state: generalized state Returns: qf_constraint: (qd_size,) constraint force """ - if jac.shape[0] == 0: - return jp.zeros_like(qd) + if state.con_jac.shape[0] == 0: + return jp.zeros(sys.qd_size()) # calculate A matrix and b vector - imp, aref = _imp_aref(pos, jac @ qd) - a = jac @ mass_mx_inv @ jac.T - a = a + jp.diag(diag * (1 - imp) / imp) - b = jac @ mass_mx_inv @ qf_smooth - aref + imp, aref = _imp_aref(state.con_pos, state.con_jac @ state.qd) + a = state.con_jac @ state.mass_mx_inv @ state.con_jac.T + a = a + jp.diag(state.con_diag * (1 - imp) / imp) + b = state.con_jac @ state.mass_mx_inv @ state.qf_smooth - aref # solve for forces in constraint frame, Ax + b = 0 s.t. x >= 0 def objective(x): @@ -230,15 +204,26 @@ def objective(x): return jp.sum(0.5 * residual**2) # profiling a physics step, most of the time is spent running this solver: + # + # there might still be some opportunities to speed this up. consider that + # the A matrix is positive definite. could possibly use conjugate gradient? + # we made a jax version of this: https://github.com/david-cortes/nonneg_cg + # but it was still not as fast as the projected gradient solver below. + # this is possibly due to the line search method, which is a big part + # of the cost. jaxopt uses FISTA which we did not implement + # + # another avenue worth pursuing is that these A matrices are often + # fairly sparse. perhaps worth trying some kind of random or + # learned projection to solve a smaller dense matrix at each step pg = jaxopt.ProjectedGradient( objective, - projection_non_negative, + jaxopt.projection.projection_non_negative, maxiter=sys.solver_iterations, implicit_diff=False, maxls=5, ) # solve and convert back to q coordinates - qf_constraint = jac.T @ pg.run(jp.zeros_like(b)).params + qf_constraint = state.con_jac.T @ pg.run(jp.zeros_like(b)).params return qf_constraint diff --git a/brax/v2/generalized/dynamics.py b/brax/v2/generalized/dynamics.py index e8304a9d..83cb9a4b 100644 --- a/brax/v2/generalized/dynamics.py +++ b/brax/v2/generalized/dynamics.py @@ -14,48 +14,39 @@ # pylint:disable=g-multiple-import """Functions for smooth forward and inverse dynamics.""" -from typing import Tuple - from brax.v2 import math from brax.v2 import scan -from brax.v2.base import Inertia, Motion, System, Transform +from brax.v2.base import Motion, System, Transform +from brax.v2.generalized.base import State import jax from jax import numpy as jp -def transform_com( - sys: System, q: jp.ndarray, qd: jp.ndarray, x: jp.ndarray -) -> Tuple[jp.ndarray, Inertia, Motion, Motion, Motion]: +def transform_com(sys: System, state: State) -> State: """Transforms inertia, dof, and link velocity into center of mass frame. Args: sys: a brax system - q: joint position vector - qd: joint velocity vector - x: link position in world frame + state: generalized state Returns: - com: center of mass position - cinr: inertia in com frame - cd: body velocities in com frame - cdof: dofs in com frame - cdofd: cdof velocity + state: generalized state with com, cinr, cd, cdof, cdofd updated """ # TODO: support multiple kinematic trees in the same system - xi = x.vmap().do(sys.link.inertia.transform) + xi = state.x.vmap().do(sys.link.inertia.transform) mass = sys.link.inertia.mass com = jp.sum(jax.vmap(jp.multiply)(mass, xi.pos), axis=0) / jp.sum(mass) cinr = xi.replace(pos=xi.pos - com).vmap().do(sys.link.inertia) # motion dofs to global frame centered at subtree-CoM - x_pad = x.concatenate(Transform.zero(shape=(1,))) - pidx = jp.array( + parent_idx = jp.array( [ i if t == 'f' else p for i, (t, p) in enumerate(zip(sys.link_types, sys.link_parents)) ] ) - j = x_pad.take(pidx).vmap().do(sys.link.transform).vmap().do(sys.link.joint) + parent = state.x.concatenate(Transform.zero(shape=(1,))).take(parent_idx) + j = parent.vmap().do(sys.link.transform).vmap().do(sys.link.joint) # propagate motion through stacked joints def cdof_fn(typ, q, motion): @@ -88,11 +79,11 @@ def cdof_fn(typ, q, motion): return motion - cdof = scan.link_types(sys, cdof_fn, 'qd', 'd', q, sys.dof.motion) + cdof = scan.link_types(sys, cdof_fn, 'qd', 'd', state.q, sys.dof.motion) ang = jax.vmap(math.rotate)(cdof.ang, j.take(sys.dof_link()).rot) cdof = cdof.replace(ang=ang) cdof = Transform.create(pos=com - j.pos).take(sys.dof_link()).vmap().do(cdof) - cdof_qd = jax.vmap(lambda x, y: x * y)(cdof, qd) + cdof_qd = jax.vmap(lambda x, y: x * y)(cdof, state.qd) # forward scan down tree: accumulate link center of mass velocity def cd_fn(cd_parent, cdof_qd, dof_idx): @@ -130,20 +121,13 @@ def cdofd_fn(typ, cd, cdof, cdof_qd): return cdofd - cd_p = cd.concatenate(Motion.zero(shape=(1,))).take(pidx) + cd_p = cd.concatenate(Motion.zero(shape=(1,))).take(parent_idx) cdofd = scan.link_types(sys, cdofd_fn, 'ldd', 'd', cd_p, cdof, cdof_qd) - return com, cinr, cd, cdof, cdofd + return state.replace(com=com, cinr=cinr, cd=cd, cdof=cdof, cdofd=cdofd) -def inverse( - sys: System, - qd: jp.ndarray, - cinr: jp.ndarray, - cd: jp.ndarray, - cdof: jp.ndarray, - cdofd: jp.ndarray, -) -> jp.ndarray: +def inverse(sys: System, state: State) -> jp.ndarray: """Calculates the system's forces given input motions. This function computes inverse dynamics using the Newton-Euler algorithm: @@ -152,11 +136,7 @@ def inverse( Args: sys: a brax system - qd: joint velocity vector - cinr: inertia in com frame - cd: body velocities in com frame - cdof: dofs in com frame - cdofd: cdof velocity + state: generalized state Returns: tau: generalized forces resulting from joint positions and velocities @@ -171,13 +151,15 @@ def cdd_fn(cdd_parent, cdofd, qd, dof_idx): return cdd - cdd = scan.tree(sys, cdd_fn, 'ddd', cdofd, qd, sys.dof_link(depth=True)) + cdd = scan.tree( + sys, cdd_fn, 'ddd', state.cdofd, state.qd, sys.dof_link(depth=True) + ) # cfrc_flat = cinr * cdd + cd x (cinr * cd) def frc(cinr, cdd, cd): return cinr.mul(cdd) + cd.cross(cinr.mul(cd)) - cfrc_flat = jax.vmap(frc)(cinr, cdd, cd) + cfrc_flat = jax.vmap(frc)(state.cinr, cdd, state.cd) # backward scan up tree: accumulate link center of mass forces def cfrc_fn(cfrc_child, cfrc): @@ -188,7 +170,7 @@ def cfrc_fn(cfrc_child, cfrc): cfrc = scan.tree(sys, cfrc_fn, 'l', cfrc_flat, reverse=True) # tau = cdof * cfrc[dof_link] - tau = jax.vmap(lambda x, y: x.dot(y))(cdof, cfrc.take(sys.dof_link())) + tau = jax.vmap(lambda x, y: x.dot(y))(state.cdof, cfrc.take(sys.dof_link())) return tau @@ -207,17 +189,8 @@ def stiffness_fn(typ, q, dof): return frc -def forward( - sys: System, - q: jp.ndarray, - qd: jp.ndarray, - cinr: Inertia, - cd: Motion, - cdof: Motion, - cdofd: Motion, - tau: jp.ndarray, -) -> jp.ndarray: - """Calculates the system's motion given input forces. +def forward(sys: System, state: State, tau: jp.ndarray) -> jp.ndarray: + """Calculates resulting joint forces given input forces. This method builds and solves the linear system: M @ qdd = -C + tau @@ -226,19 +199,14 @@ def forward( Args: sys: a brax system - q: joint position vector - qd: joint velocity vector - cinr: inertia in com frame - cd: body velocities in com frame - cdof: dofs in com frame - cdofd: cdof velocity + state: generalized state tau: joint force input vector Returns: - qdd: joint acceleration vector + qfrc: joint force vector """ - qfrc_passive = _passive(sys, q, qd) - qfrc_bias = inverse(sys, qd, cinr, cd, cdof, cdofd) + qfrc_passive = _passive(sys, state.q, state.qd) + qfrc_bias = inverse(sys, state) qfrc = qfrc_passive - qfrc_bias + tau return qfrc diff --git a/brax/v2/generalized/mass.py b/brax/v2/generalized/mass.py index 63dca1e5..a4eae70a 100644 --- a/brax/v2/generalized/mass.py +++ b/brax/v2/generalized/mass.py @@ -13,16 +13,18 @@ # limitations under the License. # pylint:disable=g-multiple-import -"""Functions for calculating the mass matrix.""" +"""Functions for calculating the mass matrix and its inverse.""" import itertools +from brax.v2 import math from brax.v2 import scan from brax.v2.base import System +from brax.v2.generalized.base import State import jax from jax import numpy as jp -def matrix(sys: System, cinr: jp.ndarray, cdof: jp.ndarray) -> jp.ndarray: +def matrix(sys: System, state: State) -> jp.ndarray: """Calculates the mass matrix for the system given joint position. This function uses the Composite-Rigid-Body Algorithm as described here: @@ -30,9 +32,8 @@ def matrix(sys: System, cinr: jp.ndarray, cdof: jp.ndarray) -> jp.ndarray: https://users.dimi.uniud.it/~antonio.dangelo/Robotica/2019/helper/Handbook-dynamics.pdf Args: - sys: system defining the kinematic tree and other properties - cinr: inertia in com frame - cdof: dofs in com frame + sys: a brax system + state: generalized state Returns: a symmetric positive matrix (qd_size, qd_size) representing the generalized @@ -44,7 +45,7 @@ def crb_fn(crb_child, crb): crb += crb_child return crb - crb = scan.tree(sys, crb_fn, 'l', cinr, reverse=True) + crb = scan.tree(sys, crb_fn, 'l', state.cinr, reverse=True) # expand composite inertias to a matrix: M[i,j] = cdof_j * crb[i] * cdof_i @jax.vmap @@ -55,9 +56,9 @@ def mx_row(dof_link, cdof_i): def mx_col(cdof_j): return cdof_j.dot(f) - return mx_col(cdof) + return mx_col(state.cdof) - mx = mx_row(sys.dof_link(), cdof) + mx = mx_row(sys.dof_link(), state.cdof) # mask out empty parts of the matrix si, sj = [], [] @@ -79,3 +80,25 @@ def mx_col(cdof_j): mx = mx + jp.diag(sys.dof.armature) return mx + + +def matrix_inv(sys: System, state: State, approximate: bool = False) -> State: + """Calculates the mass matrix and its inverse for the system. + + Args: + sys: a brax system + state: generalized state + approximate: if true, iteratively approximates matrix inverse + + Returns: + state: generalized state with com, cinr, cd, cdof, cdofd updated + """ + + mx = matrix(sys, state) + + if approximate: + mx_inv = math.inv_approximate(mx, state.mass_mx_inv) + else: + mx_inv = jax.scipy.linalg.solve(mx, jp.eye(sys.qd_size()), assume_a='pos') + + return state.replace(mass_mx=mx, mass_mx_inv=mx_inv) diff --git a/brax/v2/generalized/pipeline.py b/brax/v2/generalized/pipeline.py index e544d62b..a7f16853 100644 --- a/brax/v2/generalized/pipeline.py +++ b/brax/v2/generalized/pipeline.py @@ -15,62 +15,18 @@ # pylint:disable=g-multiple-import """Physics pipeline for generalized coordinates engine.""" -from typing import Optional - from brax.v2 import actuator -from brax.v2 import base from brax.v2 import geometry from brax.v2 import kinematics -from brax.v2 import math -from brax.v2.base import Contact, Inertia, Motion, System +from brax.v2.base import System from brax.v2.generalized import constraint from brax.v2.generalized import dynamics from brax.v2.generalized import integrator from brax.v2.generalized import mass -from flax import struct -import jax +from brax.v2.generalized.base import State from jax import numpy as jp -@struct.dataclass -class State(base.State): - """Dynamic state that changes after every step. - - Attributes: - com: center of mass position - cinr: inertia in com frame - cd: body velocities in com frame - cdof: dofs in com frame - cdofd: cdof velocity - mass_mx: (qd_size, qd_size) mass matrix - mass_mx_inv: (qd_size, qd_size) inverse mass matrix - contact: calculated contacts - con_jac: constraint jacobian - con_pos: constraint position - con_diag: constraint A diagonal - qf_smooth: smooth dynamics force - qf_constraint: (qd_size,) force from constraints (collision etc) - qdd: (qd_size,) joint acceleration vector - """ - - # position/velocity based terms are updated at the end of each step: - com: jp.ndarray - cinr: Inertia - cd: Motion - cdof: Motion - cdofd: Motion - mass_mx: jp.ndarray - mass_mx_inv: jp.ndarray - contact: Optional[Contact] - con_jac: jp.ndarray - con_pos: jp.ndarray - con_diag: jp.ndarray - # acceleration based terms are calculated using terms from the previous step: - qf_smooth: jp.ndarray - qf_constraint: jp.ndarray - qdd: jp.ndarray - - def init(sys: System, q: jp.ndarray, qd: jp.ndarray) -> State: """Initializes physics state. @@ -82,35 +38,17 @@ def init(sys: System, q: jp.ndarray, qd: jp.ndarray) -> State: Returns: state: initial physics state """ + state = State.zero(sys) + + # position/velocity level terms x, xd = kinematics.forward(sys, q, qd) - com, cinr, cd, cdof, cdofd = dynamics.transform_com(sys, q, qd, x) - mass_mx = mass.matrix(sys, cinr, cdof) - one = jp.eye(sys.qd_size()) - mass_mx_inv = jax.scipy.linalg.solve(mass_mx, one, assume_a='pos') - contact = geometry.contact(sys, x) - con_jac, con_pos, con_diag = constraint.jacobian(sys, q, com, cdof, contact) - qf_smooth, qf_constraint, qdd = jp.zeros((3, sys.qd_size())) + state = state.replace(q=q, qd=qd, x=x, xd=xd) + state = state.replace(contact=geometry.contact(sys, x)) + state = dynamics.transform_com(sys, state) + state = mass.matrix_inv(sys, state) + state = constraint.jacobian(sys, state) - return State( - q, - qd, - x, - xd, - com, - cinr, - cd, - cdof, - cdofd, - mass_mx, - mass_mx_inv, - contact, - con_jac, - con_pos, - con_diag, - qf_smooth, - qf_constraint, - qdd, - ) + return state def step(sys: System, state: State, act: jp.ndarray) -> State: @@ -124,54 +62,26 @@ def step(sys: System, state: State, act: jp.ndarray) -> State: Returns: state: physics state after step """ + # calculate acceleration terms tau = actuator.to_tau(sys, act, state.q) + state = state.replace(qf_smooth=dynamics.forward(sys, state, tau)) + state = state.replace(qf_constraint=constraint.force(sys, state)) - # calculate acceleration terms - qf_smooth = dynamics.forward( - sys, state.q, state.qd, state.cinr, state.cd, state.cdof, state.cdofd, tau - ) - qf_constraint = constraint.force( - sys, - state.qd, - qf_smooth, - state.mass_mx_inv, - state.con_jac, - state.con_pos, - state.con_diag, - ) # add dof damping to the mass matrix # because we already have M^-1, we use the derivative of the inverse: # (A + εX)^-1 = A^-1 - εA^-1 @ X @ A^-1 + O(ε^2) mx_inv = state.mass_mx_inv mx_inv_damp = mx_inv - mx_inv @ (jp.diag(sys.dof.damping) * sys.dt) @ mx_inv - qdd = mx_inv_damp @ (qf_smooth + qf_constraint) + qdd = mx_inv_damp @ (state.qf_smooth + state.qf_constraint) + state = state.replace(qdd=qdd) # update position/velocity level terms q, qd = integrator.integrate(sys, state.q, state.qd, qdd) x, xd = kinematics.forward(sys, q, qd) - com, cinr, cd, cdof, cdofd = dynamics.transform_com(sys, q, qd, x) - mass_mx = mass.matrix(sys, cinr, cdof) - mass_mx_inv = math.inv_approximate(mass_mx, state.mass_mx_inv) - contact = geometry.contact(sys, x) - con_jac, con_pos, con_diag = constraint.jacobian(sys, q, com, cdof, contact) + state = state.replace(q=q, qd=qd, x=x, xd=xd) + state = state.replace(contact=geometry.contact(sys, x)) + state = dynamics.transform_com(sys, state) + state = mass.matrix_inv(sys, state, approximate=True) + state = constraint.jacobian(sys, state) - return State( - q, - qd, - x, - xd, - com, - cinr, - cd, - cdof, - cdofd, - mass_mx, - mass_mx_inv, - contact, - con_jac, - con_pos, - con_diag, - qf_smooth, - qf_constraint, - qdd, - ) + return state diff --git a/brax/v2/geometry/contact.py b/brax/v2/geometry/contact.py index 28553990..b31996ee 100644 --- a/brax/v2/geometry/contact.py +++ b/brax/v2/geometry/contact.py @@ -15,13 +15,14 @@ # pylint:disable=g-multiple-import """Calculations for generating contacts.""" -from typing import Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, TypeVar from brax.v2 import math from brax.v2.base import ( Box, Capsule, Contact, + Convex, Geometry, Mesh, Plane, @@ -36,6 +37,9 @@ from jax.tree_util import tree_map +Geom = TypeVar('Geom', bound=Geometry) + + def _combine( geom_a: Geometry, geom_b: Geometry ) -> Tuple[float, float, Tuple[int, int]]: @@ -188,6 +192,61 @@ def capsule_face(face, face_norm): return capsule_face(face_vert, face_norm) +def _convex_convex(convex_a: Convex, convex_b: Convex) -> Contact: + """Calculates contacts between two convex objects.""" + normals_a = geom_mesh.get_face_norm(convex_a.vert, convex_a.face) + normals_b = geom_mesh.get_face_norm(convex_b.vert, convex_b.face) + faces_a = jp.take(convex_a.vert, convex_a.face, axis=0) + faces_b = jp.take(convex_b.vert, convex_b.face, axis=0) + + def transform_faces(convex, faces, normals): + faces = convex.transform.pos + jax.vmap(math.rotate, in_axes=[0, None])( + faces, convex.transform.rot + ) + normals = math.rotate(normals, convex.transform.rot) + return faces, normals + + v_transform_faces = jax.vmap(transform_faces, in_axes=[None, 0, 0]) + faces_a, normals_a = v_transform_faces(convex_a, faces_a, normals_a) + faces_b, normals_b = v_transform_faces(convex_b, faces_b, normals_b) + + def transform_verts(convex, vertices): + vertices = convex.transform.pos + math.rotate( + vertices, convex.transform.rot + ) + return vertices + + v_transform_verts = jax.vmap(transform_verts, in_axes=[None, 0]) + vertices_a = v_transform_verts(convex_a, convex_a.vert) + vertices_b = v_transform_verts(convex_b, convex_b.vert) + + unique_edges_a = jp.take(vertices_a, convex_a.unique_edge, axis=0) + unique_edges_b = jp.take(vertices_b, convex_b.unique_edge, axis=0) + + c = geom_math.sat_hull_hull( + faces_a, + faces_b, + vertices_a, + vertices_b, + normals_a, + normals_b, + unique_edges_a, + unique_edges_b, + ) + friction, elasticity, link_idx = jax.tree_map( + lambda x: jp.repeat(x, 4), _combine(convex_a, convex_b) + ) + + return Contact( + c.pos, + c.normal, + c.penetration, + friction, + elasticity, + link_idx, + ) + + def _mesh_plane(mesh: Mesh, plane: Plane) -> Contact: """Calculates contacts between a mesh and a plane.""" @@ -205,35 +264,23 @@ def point_plane(vert): (Sphere, Plane): jax.vmap(_sphere_plane), (Sphere, Sphere): jax.vmap(_sphere_sphere), (Sphere, Capsule): jax.vmap(_sphere_capsule), + (Sphere, Box): jax.vmap(_sphere_mesh), (Sphere, Mesh): jax.vmap(_sphere_mesh), (Capsule, Plane): jax.vmap(_capsule_plane), (Capsule, Capsule): jax.vmap(_capsule_capsule), + (Capsule, Box): jax.vmap(_capsule_mesh), (Capsule, Mesh): jax.vmap(_capsule_mesh), + (Convex, Convex): jax.vmap(_convex_convex), (Mesh, Plane): jax.vmap(_mesh_plane), } -def contact(sys: System, x: Transform) -> Optional[Contact]: - """Calculates contacts in the system. - - Args: - sys: system defining the kinematic tree and other properties - x: link transforms in world frame - - Returns: - Contact pytree, one row for each element in sys.contacts - - Raises: - RuntimeError: if sys.contacts has an invalid type pair - """ - - contacts = [] - +def _geom_pairs( + sys, x +) -> List[Tuple[Optional[Callable[[Geom, Geom], Any]], Geom, Geom]]: + """Transforms geometries and gets contact functions to apply to each pair.""" + geom_pairs = [] for geom_a, geom_b in sys.contacts: - # convert Box geometry to Mesh - geom_a = geom_mesh.box(geom_a) if isinstance(geom_a, Box) else geom_a - geom_b = geom_mesh.box(geom_b) if isinstance(geom_b, Box) else geom_b - # check both type signatures: a, b <> b, a fun = _TYPE_FUN.get((type(geom_a), type(geom_b))) if fun is None: @@ -251,6 +298,28 @@ def contact(sys: System, x: Transform) -> Optional[Contact]: tx_b = x.take(geom_b.link_idx).vmap().do(geom_b.transform) geom_b = geom_b.replace(transform=tx_b) + geom_pairs.append((fun, geom_a, geom_b)) # type: ignore + + return geom_pairs + + +def contact(sys: System, x: Transform) -> Optional[Contact]: + """Calculates contacts in the system. + + Args: + sys: system defining the kinematic tree and other properties + x: link transforms in world frame + + Returns: + Contact pytree, one row for each element in sys.contacts + + Raises: + RuntimeError: if sys.contacts has an invalid type pair + """ + + contacts = [] + + for fun, geom_a, geom_b in _geom_pairs(sys, x): c = fun(geom_a, geom_b) # type: ignore c = tree_map(jp.concatenate, c) contacts.append(c) diff --git a/brax/v2/geometry/contact_test.py b/brax/v2/geometry/contact_test.py index 31b2fc1f..9d23a347 100644 --- a/brax/v2/geometry/contact_test.py +++ b/brax/v2/geometry/contact_test.py @@ -19,6 +19,7 @@ from brax.v2 import kinematics from brax.v2 import test_utils from brax.v2.io import mjcf +from etils import epath from jax import numpy as jp import numpy as np @@ -214,6 +215,96 @@ def test_capsule_mesh(self): ) +class ConvexTest(absltest.TestCase): + """Tests the convex-convex contact function.""" + + _BOX_BOX = """ + + + + + + + + + + + + + """ + + def test_box_box(self): + sys = mjcf.loads(self._BOX_BOX) + x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size())) + c = geometry.contact(sys, x) + self.assertEqual(c.pos.shape[0], 4) + self.assertTrue((c.penetration > 0).all()) + np.testing.assert_array_almost_equal(c.pos[:, 2], jp.array([0.39] * 4), 2) + np.testing.assert_array_almost_equal( + c.normal, jp.array([[0.0, 0.0, -1.0]] * 4) + ) + + _BOX_BOX_EDGE = """ + + + + + + + + + + + + + """ + + def test_box_box_edge(self): + """Tests the edge contact for a box-box collision.""" + sys = mjcf.loads(self._BOX_BOX_EDGE) + x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size())) + c = geometry.contact(sys, x) + # Only one contact point. + self.assertGreater(c.penetration[0], 0) + self.assertTrue((c.penetration[1:] < 0).all()) + # The normal is pointing in the edge-edge axis direction. + np.testing.assert_array_almost_equal(c.normal[0, 0], 0) + self.assertGreater(c.normal[0, 1], 0) + self.assertLess(c.normal[0, 2], 0) + + _CONVEX_CONVEX = """ + + + + + + + + + + + + + + + + + """ + + def test_hull_hull(self): + """Tests generic convex-convex collision.""" + sys = mjcf.loads( + self._CONVEX_CONVEX, + asset_path=epath.resource_path('brax') / 'v2/test_data/', + ) + x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size())) + c = geometry.contact(sys, x) + # Only one contact point for an edge contact. + self.assertGreater(c.penetration[0], 0) + self.assertTrue((c.penetration[1:] < 0).all()) + np.testing.assert_array_almost_equal(c.normal[0], jp.array([0, 0, -1])) + + class MeshTest(absltest.TestCase): """Tests the mesh contact functions.""" diff --git a/brax/v2/geometry/math.py b/brax/v2/geometry/math.py index c5bf5940..f5103348 100644 --- a/brax/v2/geometry/math.py +++ b/brax/v2/geometry/math.py @@ -15,7 +15,10 @@ """Geometry functions.""" from typing import Tuple + from brax.v2 import math +from brax.v2.base import Contact +import jax from jax import numpy as jp @@ -95,9 +98,9 @@ def closest_segment_point_plane( # get the line-plane intersection. We then clip t to be in [0, 1] to be on # the line segment. n = plane_normal - d = jp.dot(p0, n) # shortest distance from origin to plane - denom = jp.dot(n, b - a) - t = (d - jp.dot(n, a)) / (denom + 1e-6 * (denom == 0.0)) + d = jp.sum(p0 * n) # shortest distance from origin to plane + denom = jp.sum(n * (b - a)) + t = (d - jp.sum(n * a)) / (denom + 1e-6 * (denom == 0.0)) t = jp.clip(t, 0, 1) segment_point = a + t * (b - a) @@ -181,3 +184,360 @@ def closest_segment_triangle_points( tri_pt = jp.sum(tri_pt, axis=0) / jp.sum(mask) return seg_pt, tri_pt + + +def _project_pt_onto_plane( + pt: jp.ndarray, plane_pt: jp.ndarray, plane_normal: jp.ndarray +) -> jp.ndarray: + """Projects a point onto a plane along the plane normal.""" + dist = (pt - plane_pt).dot(plane_normal) + return pt - dist * plane_normal + + +def _project_poly_onto_plane( + poly: jp.ndarray, plane_pt: jp.ndarray, plane_normal: jp.ndarray +) -> jp.ndarray: + """Projects a polygon onto a plane using the plane normal.""" + return jax.vmap(_project_pt_onto_plane, in_axes=[0, None, None])( + poly, plane_pt, math.normalize(plane_normal)[0] + ) + + +def _project_poly_onto_poly_plane( + poly1: jp.ndarray, norm1: jp.ndarray, poly2: jp.ndarray, norm2: jp.ndarray +) -> jp.ndarray: + """Projects poly1 onto the poly2 plane along poly1's normal.""" + d = poly2[0].dot(norm2) + denom = norm1.dot(norm2) + t = (d - poly1.dot(norm2)) / (denom + 1e-6 * (denom == 0.0)) + new_poly = poly1 + t.reshape(-1, 1) * norm1 + return new_poly + + +def point_in_front_of_plane( + plane_pt: jp.ndarray, plane_normal: jp.ndarray, pt: jp.ndarray +) -> bool: + """Checks if a point is strictly in front of a plane.""" + return (pt - plane_pt).dot(plane_normal) > 1e-6 + + +def _clip_edge_to_planes( + edge_p0: jp.ndarray, + edge_p1: jp.ndarray, + plane_pts: jp.ndarray, + plane_normals: jp.ndarray, +) -> Tuple[jp.ndarray, jp.ndarray]: + """Clips an edge against side planes. + + We return two clipped points, and a mask to include the new edge or not. + + Args: + edge_p0: the first point on the edge + edge_p1: the second point on the edge + plane_pts: side plane points + plane_normals: side plane normals + + Returns: + new_ps: new edge points that are clipped against side planes + mask: a boolean mask, True if an edge point is a valid clipped point and + False otherwise + """ + p0, p1 = edge_p0, edge_p1 + p0_in_front = jax.vmap(jp.dot)(p0 - plane_pts, plane_normals) > 1e-6 + p1_in_front = jax.vmap(jp.dot)(p1 - plane_pts, plane_normals) > 1e-6 + + # Get candidate clipped points along line segment (p0, p1) by clipping against + # all clipping planes. + candidate_clipped_ps = jax.vmap( + closest_segment_point_plane, in_axes=[None, None, 0, 0] + )(p0, p1, plane_pts, plane_normals) + + def clip_edge_point(p0, p1, p0_in_front, clipped_ps): + @jax.vmap + def choose_edge_point(in_front, clipped_p): + return jp.where(in_front, clipped_p, p0) + + # Pick the clipped point if p0 is in front of the clipping plane. Otherwise + # keep p0 as the edge point. + new_edge_ps = choose_edge_point(p0_in_front, clipped_ps) + + # Pick the clipped point that is most along the edge direction. + # This degenerates to picking the original point p0 if p0 is *not* in front + # of any clipping planes. + dists = jp.dot(new_edge_ps - p0, p1 - p0) + new_edge_p = new_edge_ps[jp.argmax(dists)] + return new_edge_p + + # Clip each edge point. + new_p0 = clip_edge_point(p0, p1, p0_in_front, candidate_clipped_ps) + new_p1 = clip_edge_point(p1, p0, p1_in_front, candidate_clipped_ps) + clipped_pts = jp.array([new_p0, new_p1]) + + # Keep the original points if both points are in front of any of the clipping + # planes, rather than creating a new clipped edge. If the entire subject edge + # is in front of any clipping plane, we need to grab an edge from the clipping + # polygon instead. + both_in_front = p0_in_front & p1_in_front + mask = ~jp.any(both_in_front) + new_ps = jp.where(mask, clipped_pts, jp.array([p0, p1])) + # Mask out crossing clipped edge points. + mask = jp.where((p0 - p1).dot(new_ps[0] - new_ps[1]) < 0, False, mask) + return new_ps, jp.array([mask, mask]) + + +def clip( + clipping_poly: jp.ndarray, + subject_poly: jp.ndarray, + clipping_normal: jp.ndarray, + subject_normal: jp.ndarray, +) -> Tuple[jp.ndarray, jp.ndarray]: + """Clips a subject polygon against a clipping polygon. + + A parallelized clipping algorithm for convex polygons. The result is a set of + vertices on the clipped subject polygon in the subject polygon plane. + + Args: + clipping_poly: the polygon that we use to clip the subject polygon against + subject_poly: the polygon that gets clipped + clipping_normal: normal of the clipping polygon + subject_normal: normal of the subject polygon + + Returns: + clipped_pts: points on the clipped polygon + mask: True if a point is in the clipping polygon, False otherwise + """ + # Get clipping edge points, edge planes, and edge normals. + clipping_p0 = jp.roll(clipping_poly, 1, axis=0) + clipping_plane_pts = clipping_p0 + clipping_p1 = clipping_poly + clipping_plane_normals = jax.vmap(jp.cross, in_axes=[0, None])( + clipping_p1 - clipping_p0, + clipping_normal, + ) + + # Get subject edge points, edge planes, and edge normals. + subject_edge_p0 = jp.roll(subject_poly, 1, axis=0) + subject_plane_pts = subject_edge_p0 + subject_edge_p1 = subject_poly + subject_plane_normals = jax.vmap(jp.cross, in_axes=[0, None])( + subject_edge_p1 - subject_edge_p0, + subject_normal, + ) + + # Clip all edges of the subject poly against clipping side planes. + clipped_edges0, masks0 = jax.vmap( + _clip_edge_to_planes, in_axes=[0, 0, None, None] + )( + subject_edge_p0, + subject_edge_p1, + clipping_plane_pts, + clipping_plane_normals, + ) + + # Project the clipping poly onto the subject plane. + clipping_p0_s = _project_poly_onto_poly_plane( + clipping_p0, clipping_normal, subject_poly, subject_normal + ) + # TODO consider doing a roll here instead of projection. + clipping_p1_s = _project_poly_onto_poly_plane( + clipping_p1, clipping_normal, subject_poly, subject_normal + ) + + # Clip all edges of the clipping poly against subject planes. + clipped_edges1, masks1 = jax.vmap( + _clip_edge_to_planes, in_axes=[0, 0, None, None] + )(clipping_p0_s, clipping_p1_s, subject_plane_pts, subject_plane_normals) + + # Merge the points and reshape. + clipped_edges = jp.concatenate([clipped_edges0, clipped_edges1]) + masks = jp.concatenate([masks0, masks1]) + clipped_points = clipped_edges.reshape((-1, 3)) + mask = masks.reshape(-1) + + return clipped_points, mask + + +def manifold_points( + poly: jp.ndarray, poly_mask: jp.ndarray, poly_norm: jp.ndarray +) -> jp.ndarray: + """Chooses four points with maximal area within a polygon.""" + dist_mask = jp.where(poly_mask, 0.0, -1e6) + a_idx = jp.argmax(dist_mask) + a = poly[a_idx] + # choose point b farthest from a + b_idx = (((a - poly) ** 2).sum(axis=1) + dist_mask).argmax() + b = poly[b_idx] + # maximize area of triangle + qa, qb = poly - a, poly - b + area_0 = jax.vmap(jp.cross)(qa, qb).dot(poly_norm) + c_idx = jp.argmax(area_0 + dist_mask) + c = poly[c_idx] + # maximize negative area with each edge against fourth point + qc = poly - c + area_1 = jax.vmap(jp.cross)(qb, qc).dot(poly_norm) + area_2 = jax.vmap(jp.cross)(qc, qa).dot(poly_norm) + min_area = (jp.stack([area_0, area_1, area_2]) - dist_mask).min(axis=0) + d_idx = jp.argmin(min_area) + return jp.array([a_idx, b_idx, c_idx, d_idx]) + + +def _create_contact_manifold( + clipping_poly: jp.ndarray, + subject_poly: jp.ndarray, + clipping_norm: jp.ndarray, + subject_norm: jp.ndarray, + sep_axis: jp.ndarray, +) -> Contact: + """Creates a contact manifold between two convex polygons. + + The polygon faces are expected to have a counter clockwise winding order so + that clipping plane normals point away from the polygon center. + + Args: + clipping_poly: The reference polygon to clip the contact against. + subject_poly: The subject polygon to clip contacts onto. + clipping_norm: The clipping polygon normal. + subject_norm: The subject polygon normal. + sep_axis: The separating axis + + Returns: + contact: Contact object. + """ + # Clip the subject (incident) face onto the clipping (reference) face. + # The incident points are clipped points on the subject polygon. + poly_incident, mask = clip( + clipping_poly, subject_poly, clipping_norm, subject_norm + ) + # The reference points are clipped points on the clipping polygon. + poly_ref = _project_poly_onto_plane( + poly_incident, clipping_poly[0], clipping_norm + ) + behind_clipping_plane = point_in_front_of_plane( + clipping_poly[0], -clipping_norm, poly_incident + ) + mask = mask & behind_clipping_plane + + # Choose four contact points. + best = manifold_points(poly_ref, mask, clipping_norm) + contact_pts = jp.take(poly_ref, best, axis=0) + mask_pts = jp.take(mask, best, axis=0) + penetration_dir = jp.take(poly_incident, best, axis=0) - contact_pts + penetration = penetration_dir.dot(-clipping_norm) + penetration = jp.where(mask_pts, penetration, -jp.ones_like(penetration)) + + contact = Contact( + pos=contact_pts, + normal=jp.stack([sep_axis] * 4, 0), + penetration=penetration, + friction=jp.array([]), + elasticity=jp.array([]), + link_idx=jp.array([]), + ) + + return contact + + +def sat_hull_hull( + faces_a: jp.ndarray, + faces_b: jp.ndarray, + vertices_a: jp.ndarray, + vertices_b: jp.ndarray, + normals_a: jp.ndarray, + normals_b: jp.ndarray, + unique_edges_a: jp.ndarray, + unique_edges_b: jp.ndarray, +) -> Contact: + """Runs the Separating Axis Test for a pair of hulls. + + Given two convex hulls, the Separating Axis Test finds a separating axis + between all edge pairs and face pairs. Edge pairs create a single contact + point and face pairs create a contact manifold (up to four contact points). + We return both the edge and face contacts. Valid contacts can be checked with + penetration > 0. Resulting edge contacts should be preferred over face + contacts. + + Args: + faces_a: An ndarray of hull A's polygon faces. + faces_b: An ndarray of hull B's polygon faces. + vertices_a: Vertices for hull A. + vertices_b: Vertices for hull B. + normals_a: Normal vectors for hull A's polygon faces. + normals_b: Normal vectors for hull B's polygon faces. + unique_edges_a: Unique edges for hull A. + unique_edges_b: Unique edges for hull B. + + Returns: + contact: A contact. + """ + # get the separating axes + edge_dir_a = unique_edges_a[:, 0] - unique_edges_a[:, 1] + edge_dir_b = unique_edges_b[:, 0] - unique_edges_b[:, 1] + edge_dir_a_r = jp.tile(edge_dir_a, reps=(unique_edges_b.shape[0], 1)) + edge_dir_b_r = jp.repeat(edge_dir_b, repeats=unique_edges_a.shape[0], axis=0) + edge_edge_axes = jax.vmap(jp.cross)(edge_dir_a_r, edge_dir_b_r) + edge_edge_axes = jax.vmap(lambda x: math.normalize(x, axis=0)[0])( + edge_edge_axes + ) + + axes = jp.concatenate([normals_a, normals_b, edge_edge_axes]) + + # for each separating axis, get the support + @jax.vmap + def get_support(axis): + support_a = jax.vmap(jp.dot, in_axes=[None, 0])(axis, vertices_a) + support_b = jax.vmap(jp.dot, in_axes=[None, 0])(axis, vertices_b) + dist1 = support_a.max() - support_b.min() + dist2 = support_b.max() - support_a.min() + sign = jp.where(dist1 > dist2, -1, 1) + dist = jp.minimum(dist1, dist2) + dist = jp.where(~jp.all(axis == 0.0), dist, 1e6) # degenerate axis + return dist, sign + + support, sign = get_support(axes) + + # choose the best separating axis + best_idx = jp.argmin(support) + best_sign = sign[best_idx] + best_axis = axes[best_idx] + is_edge_contact = best_idx >= (normals_a.shape[0] + normals_b.shape[0]) + + # get the (reference) face most aligned with the separating axis + dist_a = jax.vmap(jp.dot, in_axes=[None, 0])(best_axis, normals_a) + dist_b = jax.vmap(jp.dot, in_axes=[None, 0])(best_axis, normals_b) + a_max = dist_a.argmax() + b_max = dist_b.argmax() + a_min = dist_a.argmin() + b_min = dist_b.argmin() + + ref_face = jp.where(best_sign > 0, faces_a[a_max], faces_b[b_max]) + ref_face_norm = jp.where(best_sign > 0, normals_a[a_max], normals_b[b_max]) + incident_face = jp.where(best_sign > 0, faces_b[b_min], faces_a[a_min]) + incident_face_norm = jp.where( + best_sign > 0, normals_b[b_min], normals_a[a_min] + ) + + contact = _create_contact_manifold( + ref_face, + incident_face, + ref_face_norm, + incident_face_norm, + -best_sign * best_axis, + ) + + # For edge contacts, we use the clipped face point, mainly for performance + # reasons. For small penetration, the clipped face point is roughly the edge + # contact point. + # TODO revisit edge contact pos (for deep penetration) with same perf + idx = contact.penetration.argmax() + contact = contact.replace( + penetration=jp.where( + is_edge_contact, + jp.array([contact.penetration[idx], -1, -1, -1]), + contact.penetration, + ), + pos=jp.where( + is_edge_contact, jp.tile(contact.pos[idx], (4, 1)), contact.pos + ), + ) + + return contact diff --git a/brax/v2/geometry/math_test.py b/brax/v2/geometry/math_test.py index 4960f103..50bea696 100644 --- a/brax/v2/geometry/math_test.py +++ b/brax/v2/geometry/math_test.py @@ -18,8 +18,10 @@ from absl.testing import parameterized from brax.v2 import math from brax.v2.geometry import math as geom_math +import jax import jax.numpy as jp import numpy as np +from scipy import spatial def _get_rand_point(seed=None): @@ -43,10 +45,32 @@ def _get_rand_triangle_vertices(seed=None): return verts[0, :], verts[1, :], verts[2, :] +def _get_rand_dir(seed=None): + if seed is not None: + np.random.seed(seed) + r = np.random.randn(1) + theta = np.random.random(1) * 2 * np.pi + a = (np.random.random(1) - 0.5) * 2.0 + phi = np.arccos(a) + x = r * np.sin(phi) * np.cos(theta) + y = r * np.sin(phi) * np.sin(theta) + z = r * np.cos(phi) + return jp.array([x, y, z]).squeeze() + + +def _get_rand_convex_polygon(seed=None): + if seed is not None: + np.random.seed(seed) + points = np.random.random((15, 2)) + hull = spatial.ConvexHull(points) + points = points[hull.vertices] + points = np.array([[p[0], p[1], 0] for p in points]) + return points + + def _minimize(fn, sample_fn, lb, ub, tol, max_iter=20, seed=42): """Minimize a function, roughly, using the cross-entropy method.""" - # We can alternatively use scipy.optimize with non-linear constraints here, - # but we are avoiding a scipy dependency for now. + # We can alternatively use scipy.optimize with non-linear constraints here. assert lb.shape == ub.shape, 'bounds need to have the same shape' np.random.seed(seed) @@ -262,5 +286,145 @@ def test_closest_segment_triangle_points(self, i, j): self.assertLessEqual(test_dist, expected_dist + 1e-5) +def _check_eq_pts(pts1, pts2, atol=1e-6): + # For every point in pts1, make sure we have a point in pts2 + # that is close within `atol`, and vice versa. + if not pts1.size and not pts2.size: + return True + elif not pts1.size: + return False + elif not pts2.size: + return False + if pts1.shape[-1] != 3 or pts2.shape[-1] != 3: + raise AssertionError('Points should be three dimensional.') + eq = True + for p1 in pts1: + eq = eq and jp.any(jp.sum(jp.abs(p1 - pts2) < atol, axis=-1) == 3) + for p2 in pts2: + eq = eq and jp.any(jp.sum(jp.abs(p2 - pts1) < atol, axis=-1) == 3) + return eq + + +def _clip(clipping_polygon, subject_polygon): + """Returns the clipped subject polygon, using Sutherland-Hodgman Clipping.""" + polygon1 = clipping_polygon + polygon1_normal = jp.cross( + polygon1[-1] - polygon1[0], polygon1[1] - polygon1[0] + ) + clipping_edges = [ + (polygon1[i - 1], polygon1[i]) for i in range(len(polygon1)) + ] + # Clipping plane normals point away from the clipping poly center (polygon + # points assumed to have a clockwise winding order). + clipping_planes = [ + (e0, jp.cross(polygon1_normal, e1 - e0)) for e0, e1 in clipping_edges + ] + + output_polygon = subject_polygon + for clipping_plane in clipping_planes: + input_ = output_polygon + output_polygon = [] + + starting_pt = input_[-1] + clipping_plane_pt, clipping_plane_normal = clipping_plane + + for endpt in input_: + intersection_pt = geom_math.closest_segment_point_plane( + starting_pt, endpt, clipping_plane_pt, clipping_plane_normal + ) + starting_pt_front = geom_math.point_in_front_of_plane( + clipping_plane_pt, clipping_plane_normal, starting_pt + ) + endpt_front = geom_math.point_in_front_of_plane( + clipping_plane_pt, clipping_plane_normal, endpt + ) + + if not endpt_front and not starting_pt_front: + output_polygon.append(endpt) + elif not endpt_front and starting_pt_front: + output_polygon.append(intersection_pt) + output_polygon.append(endpt) + elif not starting_pt_front: + output_polygon.append(intersection_pt) + + starting_pt = endpt + + if not output_polygon: + # All clipping points outside the subject polygon. + return jp.array(output_polygon) + + return jp.array(output_polygon) + + +class ClippingTest(parameterized.TestCase): + """Tests that the clipping algorithm matches a baseline implementation.""" + + clip_vectorized = jax.jit(geom_math.clip) + + @parameterized.parameters(range(100)) + def test_clipped_triangles(self, i): + subject_poly = jp.array(_get_rand_triangle_vertices(i)) + clipping_poly = jp.array(_get_rand_triangle_vertices(i + 1)) + + poly_out = _clip(clipping_poly, subject_poly) + + clipping_normal = jp.cross( + clipping_poly[1] - clipping_poly[0], + clipping_poly[-1] - clipping_poly[0], + ) + subject_normal = jp.cross( + subject_poly[1] - subject_poly[0], subject_poly[-1] - subject_poly[0] + ) + poly_out_jax, mask = ClippingTest.clip_vectorized( + clipping_poly, subject_poly, clipping_normal, subject_normal + ) + + self.assertTrue( + _check_eq_pts(poly_out, poly_out_jax[mask], atol=1e-4), + f'Clipped triangles {i} did not match.', + ) + + @parameterized.parameters(range(100)) + def test_clipped_hulls(self, i): + # The hulls are all in the x-y plane, unlike in `test_clipped_triangles`. + subject_poly = jp.array(_get_rand_convex_polygon(i)) + clipping_poly = jp.array(_get_rand_convex_polygon(i + 1)) + + poly_out = _clip(clipping_poly, subject_poly) + + clipping_normal = jp.cross( + clipping_poly[1] - clipping_poly[0], + clipping_poly[-1] - clipping_poly[0], + ) + subject_normal = jp.cross( + subject_poly[1] - subject_poly[0], + subject_poly[-1] - subject_poly[0], + ) + poly_out_jax, mask = ClippingTest.clip_vectorized( + clipping_poly, subject_poly, clipping_normal, subject_normal + ) + + self.assertTrue( + _check_eq_pts(poly_out, poly_out_jax[mask], atol=1e-2), + f'Clipped hulls {i} did not match.', + ) + + +class ManifoldPointsTest(parameterized.TestCase): + """Tests manifold point selection.""" + + def test_manifold_points(self): + poly = jp.array([ + [0.99999994, 0.14842263, 0.39985055], + [0.8585786, 0.00145163, 0.39985055], + [1.0, -0.14551926, 0.39985055], + [1.1414213, 0.00145174, 0.39985055], + ]) + poly_mask = jp.array([False, True, True, True]) + poly_norm = jp.array([0.0, 0.0, 1.0]) + idx = geom_math.manifold_points(poly, poly_mask, poly_norm) + self.assertSequenceEqual(idx.tolist(), [1, 3, 1, 2]) + + if __name__ == '__main__': absltest.main() diff --git a/brax/v2/geometry/mesh.py b/brax/v2/geometry/mesh.py index 78e4a27c..9d067b69 100644 --- a/brax/v2/geometry/mesh.py +++ b/brax/v2/geometry/mesh.py @@ -16,9 +16,14 @@ """Useful functions for creating and processing meshes.""" import itertools +import logging +from typing import Dict, Tuple -from brax.v2.base import Box, Mesh +from brax.v2.base import Box, Convex, Mesh from jax import numpy as jp +import numpy as np +from scipy import spatial +import trimesh _BOX_CORNERS = list(itertools.product((-1, 1), (-1, 1), (-1, 1))) @@ -27,42 +32,206 @@ # The faces of a triangulated box, i.e. the indices in _BOX_CORNERS of the # vertices of the 12 triangles (two triangles for each side of the box). _TRIANGULATED_BOX_FACES = [ - 0, 1, 4, 4, 1, 5, # front - 0, 4, 2, 2, 4, 6, # bottom - 6, 4, 5, 6, 5, 7, # right - 2, 6, 3, 3, 6, 7, # back - 1, 3, 5, 5, 3, 7, # top - 0, 2, 1, 1, 2, 3, # left + 0, 4, 1, 4, 5, 1, # left + 0, 2, 4, 2, 6, 4, # bottom + 6, 5, 4, 6, 7, 5, # front + 2, 3, 6, 3, 7, 6, # right + 1, 5, 3, 5, 7, 3, # top + 0, 1, 2, 1, 3, 2, # back +] +# Rectangular box faces using a counter-clockwise winding order convention. +_BOX_FACES = [ + 0, 4, 5, 1, # left + 0, 2, 6, 4, # bottom + 6, 7, 5, 4, # front + 2, 3, 7, 6, # right + 1, 5, 7, 3, # top + 0, 1, 3, 2, # back ] # pyformat: enable +_MAX_HULL_FACE_VERTICES = 20 +_CONVEX_CACHE: Dict[Tuple[int, int], Convex] = {} + + +def get_face_norm(vert: jp.ndarray, face: jp.ndarray) -> jp.ndarray: + """Calculates face normals given vertices and face indexes.""" + assert len(vert.shape) == 2 and len(face.shape) == 2, ( + f'vert and face should have dim of 2, got {len(vert.shape)} and ' + f'{len(face.shape)}' + ) + face_vert = jp.take(vert, face, axis=0) + # use CCW winding order convention + edge0 = face_vert[:, 1, :] - face_vert[:, 0, :] + edge1 = face_vert[:, -1, :] - face_vert[:, 0, :] + face_norm = jp.cross(edge0, edge1) + face_norm = face_norm / jp.linalg.norm(face_norm, axis=1).reshape((-1, 1)) + return face_norm + + +def get_unique_edges(vert: np.ndarray, face: np.ndarray) -> np.ndarray: + """Returns unique edges. + + Args: + vert: (n_vert, 3) vertices + face: (n_face, n_vert) face index array -def box(b: Box) -> Mesh: - """Creates a mesh from a box geometry.""" - assert len(b.halfsize.shape) == 2 and b.halfsize.shape[-1] == 3, ( - 'Box halfsize should have a batch dimension and have length 3, ' - f'got {b.halfsize.shape}' + Returns: + edges: 2-tuples of vertice indexes for each edge + """ + r_face = np.roll(face, 1, axis=1) + edges = np.concatenate(np.array([face, r_face]).T) + + # do a first pass to remove duplicates + edges.sort(axis=1) + edges = np.unique(edges, axis=0) + edges = edges[edges[:, 0] != edges[:, 1]] # get rid of edges from padded face + + # get normalized edge directions + edge_vert = vert.take(edges, axis=0) + edge_dir = edge_vert[:, 0] - edge_vert[:, 1] + norms = np.sqrt(np.sum(edge_dir**2, axis=1)) + edge_dir = edge_dir / norms.reshape((-1, 1)) + + # get the first unique edge for all pairwise comparisons + diff1 = edge_dir[:, None, :] - edge_dir[None, :, :] + diff2 = edge_dir[:, None, :] + edge_dir[None, :, :] + matches = (np.linalg.norm(diff1, axis=-1) < 1e-6) | ( + np.linalg.norm(diff2, axis=-1) < 1e-6 ) - n_boxes = b.halfsize.shape[0] - box_corners = jp.array([_BOX_CORNERS] * n_boxes) - vert = box_corners * b.halfsize.reshape(n_boxes, -1, 3) - face = jp.array([_TRIANGULATED_BOX_FACES] * n_boxes).reshape(n_boxes, -1, 3) + matches = np.tril(matches).sum(axis=-1) + unique_edge_idx = np.where(matches == 1)[0] + + return edges[unique_edge_idx] + + +def _box(b: Box, triangulated=True) -> Tuple[np.ndarray, np.ndarray]: + """Creates face and vert from a box geometry.""" + assert b.halfsize.shape == ( + 3, + ), f'Box halfsize should have shape (3,), got {b.halfsize.shape}' + box_corners = jp.array(_BOX_CORNERS) + vert = box_corners * b.halfsize.reshape(-1, 3) + box_faces = _TRIANGULATED_BOX_FACES if triangulated else _BOX_FACES + face_dim = 3 if triangulated else 4 + face = jp.array([box_faces]).reshape(-1, face_dim) + return vert, face + + +def box_tri(b: Box) -> Mesh: + """Creates a triangulated mesh from a box geometry.""" + vert, face = _box(b, triangulated=True) return Mesh( + vert=vert, face=face, + link_idx=b.link_idx, + transform=b.transform, + friction=b.friction, + elasticity=b.elasticity, + ) + + +def box_hull(b: Box) -> Convex: + """Creates a mesh for a box with rectangular faces.""" + vert, face = _box(b, triangulated=False) + return Convex( vert=vert, + face=face, link_idx=b.link_idx, transform=b.transform, friction=b.friction, elasticity=b.elasticity, + unique_edge=get_unique_edges(vert, face), ) -def get_face_norm(vert: jp.ndarray, face: jp.ndarray) -> jp.ndarray: - """Calculates face normals given vertices and face indexes.""" - face_vert = jp.take(vert, face, axis=0) - # use CCW winding order convention - edge0 = face_vert[:, 0, :] - face_vert[:, 2, :] - edge1 = face_vert[:, 0, :] - face_vert[:, 1, :] - face_norm = jp.cross(edge0, edge1) - face_norm = face_norm / jp.linalg.norm(face_norm, axis=1).reshape((-1, 1)) - return face_norm +def _convex_hull_2d(points: np.ndarray, normal: np.ndarray) -> np.ndarray: + """Calculates the hull face for a set of points on a plane.""" + # project points onto the closest axis plane + best_axis = np.abs(np.eye(3).dot(normal)).argmax() + axis = np.eye(3)[best_axis] + d = points.dot(axis).reshape((-1, 1)) + axis_points = points - d * axis + axis_points = axis_points[:, list({0, 1, 2} - {best_axis})] + + # get the polygon hull face, and make the points ccw wrt the face normal + # TODO: consider sorting unique edges by their angle to get the hull + c = spatial.ConvexHull(axis_points) + order_ = np.where(axis.dot(normal) > 0, 1, -1) + hull_point_idx = c.vertices[::order_] + assert (axis_points - c.points).sum() == 0 + + return hull_point_idx + + +def _merge_coplanar(tm: trimesh.Trimesh) -> np.ndarray: + """Merges coplanar facets.""" + if not tm.facets: + return tm.faces.copy() # no facets, return faces + if not tm.faces.shape[0]: + raise ValueError('Mesh has no faces.') + + # Get faces. + face_idx = set(range(tm.faces.shape[0])) - set(np.concatenate(tm.facets)) + face_idx = np.array(list(face_idx)) + faces = tm.faces[face_idx] if face_idx.shape[0] > 0 else np.array([]) + + # Get facets. + facets = [] + for i, facet in enumerate(tm.facets): + point_idx = np.unique(tm.faces[facet]) + points = tm.vertices[point_idx] + normal = tm.facets_normal[i] + + # convert triangulated facet to a polygon + hull_point_idx = _convex_hull_2d(points, normal) + face = point_idx[hull_point_idx] + + # resize faces that exceed max polygon vertices + every = face.shape[0] // _MAX_HULL_FACE_VERTICES + 1 + face = face[::every] + facets.append(face) + + # Pad facets so that they can be stacked. + max_len = max(f.shape[0] for f in facets) if facets else faces.shape[1] + assert max_len <= _MAX_HULL_FACE_VERTICES + for i, f in enumerate(facets): + if f.shape[0] < max_len: + f = np.pad(f, (0, max_len - f.shape[0]), 'edge') + facets[i] = f + + if not faces.shape[0]: + assert facets + return np.array(facets) # no faces, return facets + + # Merge faces and facets. + faces = np.pad(faces, ((0, 0), (0, max_len - faces.shape[1])), 'edge') + return np.concatenate([faces, facets]) + + +def _convex_hull(m: Mesh) -> Convex: + """Creates a convex hull from a mesh.""" + tm = trimesh.Trimesh(vertices=m.vert, faces=m.face) + tm_convex = trimesh.convex.convex_hull(tm) + vert = tm_convex.vertices.copy() + face = _merge_coplanar(tm_convex) + return Convex( + vert=vert, + face=face, + link_idx=m.link_idx, + transform=m.transform, + friction=m.friction, + elasticity=m.elasticity, + unique_edge=get_unique_edges(vert, face), + ) + + +def convex_hull(mesh: Mesh) -> Convex: + """Creates a convex hull from a mesh.""" + def _key(mesh): + return (hash(mesh.vert.data.tobytes()), hash(mesh.face.data.tobytes())) + key = _key(mesh) + if key not in _CONVEX_CACHE: + logging.info('Converting mesh %s into convex hull.', key) + _CONVEX_CACHE[key] = _convex_hull(mesh) + return _CONVEX_CACHE[key] diff --git a/brax/v2/geometry/mesh_test.py b/brax/v2/geometry/mesh_test.py index a108c91d..2d44d375 100644 --- a/brax/v2/geometry/mesh_test.py +++ b/brax/v2/geometry/mesh_test.py @@ -12,30 +12,144 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint:disable=g-multiple-import """Tests for mesh.py.""" from absl.testing import absltest -from brax.v2.base import Box +from brax.v2.base import Box, Convex, Mesh, Transform from brax.v2.geometry import mesh import numpy as np -class BoxMeshTest(absltest.TestCase): +class BoxTest(absltest.TestCase): def test_box(self): + """Tests a triangulated box.""" b = Box( - halfsize=np.repeat(0.5, 6).reshape(2, 3), + halfsize=np.repeat(0.5, 3), link_idx=None, transform=None, friction=0.42, elasticity=1, ) - m = mesh.box(b) - self.assertSequenceEqual(m.vert.shape, (2, 8, 3)) # eight box corners + m = mesh.box_tri(b) + self.assertIsInstance(m, Mesh) + self.assertSequenceEqual(m.vert.shape, (8, 3)) # eight box corners self.assertEqual(np.unique(np.abs(m.vert)), 0.5) - self.assertSequenceEqual(m.face.shape, (2, 12, 3)) # two triangles per face + self.assertSequenceEqual(m.face.shape, (12, 3)) # two triangles per face self.assertEqual(m.friction, 0.42) + expected_face_norm = [ + [0, -1.0, 0], # left + [0, -1.0, 0], + [0, 0, -1.0], # bottom + [0, 0, -1.0], + [+1.0, 0, 0], # front + [+1.0, 0, 0], + [0, +1.0, 0], # right + [0, +1.0, 0], + [0, 0, +1.0], # top + [0, 0, +1.0], + [-1.0, 0, 0], # back + [-1.0, 0, 0], + ] + face_norm = mesh.get_face_norm(m.vert, m.face) + np.testing.assert_array_almost_equal(face_norm, expected_face_norm) + + def test_box_hull(self): + """Tests a polygon box.""" + b = Box( + halfsize=np.repeat(0.5, 3).reshape(3), + link_idx=None, + transform=None, + friction=0.42, + elasticity=1, + ) + h = mesh.box_hull(b) + self.assertIsInstance(h, Convex) + self.assertSequenceEqual(h.vert.shape, (8, 3)) + self.assertEqual(np.unique(np.abs(h.vert)), 0.5) + np.testing.assert_array_equal(h.unique_edge, [[0, 1], [0, 2], [0, 4]]) + self.assertSequenceEqual(h.face.shape, (6, 4)) # one rectangle per face + self.assertEqual(h.friction, 0.42) + + expected_face_norm = [ + [0, -1.0, 0], # left + [0, 0, -1.0], # bottom + [+1.0, 0, 0], # front + [0, +1.0, 0], # right + [0, 0, +1.0], # top + [-1.0, 0, 0], # back + ] + face_norm = mesh.get_face_norm(h.vert, h.face) + np.testing.assert_array_almost_equal(face_norm, expected_face_norm) + + +class ConvexTest(absltest.TestCase): + + def test_pyramid(self): + """Tests a pyramid convex hull.""" + vert = np.array([ + [-0.025, 0.05, 0.05], + [-0.025, -0.05, -0.05], + [-0.025, -0.05, 0.05], + [-0.025, 0.05, -0.05], + [0.075, 0.0, 0.0], + ]) + face = np.array( + [[0, 1, 2], [0, 3, 1], [0, 4, 3], [0, 2, 4], [2, 1, 4], [1, 3, 4]] + ) + pyramid = Mesh( + link_idx=0, + transform=Transform.zero((1,)), + vert=vert, + face=face, + friction=1, + elasticity=0, + ) + h = mesh.convex_hull(pyramid) + + self.assertIsInstance(h, Convex) + + # check verts + vidx = [0, 3, 2, 1, 4] # verts get mixed up by trimesh + np.testing.assert_array_equal(h.vert, vert[vidx]) + + # check faces + map_ = {v: k for k, v in enumerate(vidx)} + h_face = np.vectorize(map_.get)(h.face) + np.testing.assert_array_equal( + h_face, + np.array([ + [2, 4, 0, 0], + [0, 4, 3, 3], + [4, 2, 1, 1], + [1, 3, 4, 4], + [3, 1, 2, 0], + ]), + ) + + # check edges + np.testing.assert_array_equal( + np.vectorize(map_.get)(h.unique_edge), + np.array([[0, 3], [0, 2], [0, 4], [3, 4], [2, 4], [1, 4]]), + ) + self.assertEqual(h.friction, 1) + + +class UniqueEdgesTest(absltest.TestCase): + + def test_tetrahedron_edges(self): + """Tests unique edges for a tetrahedron.""" + vert = np.array( + [[-0.1, 0.0, -0.1], [0.0, 0.1, 0.1], [0.1, 0.0, -0.1], [0.0, -0.1, 0.1]] + ) + face = np.array([[0, 1, 2], [0, 2, 3], [0, 3, 1], [2, 1, 3]]) + idx = mesh.get_unique_edges(vert, face) + np.testing.assert_array_equal( + idx, np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]) + ) + if __name__ == '__main__': absltest.main() diff --git a/brax/v2/io/html.py b/brax/v2/io/html.py index 74c08218..d1863325 100644 --- a/brax/v2/io/html.py +++ b/brax/v2/io/html.py @@ -18,18 +18,18 @@ from typing import List, Optional, Union import brax -from brax.v2.base import System, Transform +from brax.v2.base import State, System from brax.v2.io import json from etils import epath import jinja2 -def save(path: str, sys: System, xs: List[Transform]): - """Saves trajectory as an HTML file.""" +def save(path: str, sys: System, states: List[State]): + """Saves trajectory as an HTML text file.""" path = epath.Path(path) if not path.parent.exists(): path.parent.mkdir(parents=True) - path.write_text(render(sys, xs)) + path.write_text(render(sys, states)) def render_from_json( @@ -52,16 +52,16 @@ def render_from_json( def render( sys: System, - xs: List[Transform], + states: List[State], height: Union[int, str] = 480, colab: bool = True, base_url: Optional[str] = None, ) -> str: - """Returns an HTML string that visualizes the brax system and trajectory. + """Returns an HTML string for the brax system and trajectory. Args: sys: brax System object - xs: world coordinates to render + states: list of system states to render height: the height of the render window colab: whether to use css styles for colab base_url: the base url for serving the visualizer files. By default, a CDN @@ -70,4 +70,4 @@ def render( Returns: string containing HTML for the brax visualizer """ - return render_from_json(json.dumps(sys, xs), height, colab, base_url) + return render_from_json(json.dumps(sys, states), height, colab, base_url) diff --git a/brax/v2/io/json.py b/brax/v2/io/json.py index 88270bfa..0de8cbf9 100644 --- a/brax/v2/io/json.py +++ b/brax/v2/io/json.py @@ -18,11 +18,15 @@ import json from typing import List, Text -from brax.v2.base import System, Transform +from brax.v2.base import State, System from etils import epath import jax +import jax.numpy as jp import numpy as np +# State attributes needed for the visualizer. +_STATE_ATTR = ['x', 'contact'] + def _to_dict(obj): """Converts python object to a json encodeable object.""" @@ -47,27 +51,61 @@ def _to_dict(obj): return obj -def dumps(sys: System, xs: List[Transform]) -> Text: - """Creates a json string of the system config.""" +def _compress_contact(states: List[State]) -> List[State]: + """Reduces the number of contacts based on penetration > 0.""" + stacked = jax.tree_map(lambda *x: jp.stack(x), *states) + if stacked.contact is None: + return states + contact_mask = stacked.contact.penetration > 0 + n_contact = contact_mask.sum(axis=1).max() + + def pad(arr, n): + r = jp.zeros(n) + if len(arr.shape) > 1: + r = jp.zeros((n, arr.shape[1])) + r = r.at[: arr.shape[0]].set(arr) + return r + + def compress(i, s): + mask = contact_mask[i] + contact = jax.tree_map(lambda x: pad(x[mask], n_contact), s.contact) + return s.replace(contact=contact) + + return [compress(i, s) for i, s in enumerate(states)] + + +def dumps(sys: System, states: List[State]) -> Text: + """Creates a json string of the system config. + + Args: + sys: brax System object + states: list of brax system states + + Returns: + string containing json dump of system and states + """ d = _to_dict(sys) # fill in empty link names link_names = [n or f'link {i}' for i, n in enumerate(sys.link_names)] - # group geoms by their links + # key geoms by their link names link_geoms = {} for g in d['geoms']: - link = 'world' if g['link_idx'] is None else link_names[g['link_idx']] - link_geoms.setdefault(link, []).append(g) + link_name = 'world' if g['link_idx'] is None else link_names[g['link_idx']] + link_geoms.setdefault(link_name, []).append(g) d['geoms'] = link_geoms - d['pos'] = [x.pos for x in xs] - d['rot'] = [x.rot for x in xs] - d['debug'] = False # TODO implement debugging + states = _compress_contact(states) + + # stack states for the viewer + states = _to_dict(jax.tree_map(lambda *x: jp.stack(x), *states)) + d['states'] = {k: states[k] for k in _STATE_ATTR} + return json.dumps(_to_dict(d)) -def save(path: str, sys: System, xs: List[Transform]): +def save(path: str, sys: System, states: List[State]): """Saves a system config and trajectory as json.""" with epath.Path(path).open('w') as fout: - fout.write(dumps(sys, xs)) + fout.write(dumps(sys, states)) diff --git a/brax/v2/io/json_test.py b/brax/v2/io/json_test.py index 330d5532..55ae48e5 100644 --- a/brax/v2/io/json_test.py +++ b/brax/v2/io/json_test.py @@ -15,32 +15,28 @@ """Tests for json.""" import json + from absl.testing import absltest from brax.v2 import test_utils +from brax.v2.generalized import pipeline from brax.v2.io import json as bjson +import jax.numpy as jp class JsonTest(absltest.TestCase): def test_dumps(self): - sys = test_utils.load_fixture('ur5e/robot.xml') - res = bjson.dumps(sys, []) + sys = test_utils.load_fixture('convex_convex.xml') + state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) + res = bjson.dumps(sys, [state]) res = json.loads(res) self.assertIsInstance(res['geoms'], dict) self.assertSequenceEqual( sorted(res['geoms'].keys()), - [ - 'forearm_link', - 'shoulder_link', - 'upper_arm_link', - 'world', - 'wrist_1_link', - 'wrist_2_link', - 'wrist_3_link', - ], + ['box', 'dodecahedron', 'pyramid', 'tetrahedron', 'world'], ) - self.assertLen(res['geoms']['world'], 2) + self.assertLen(res['geoms']['world'], 1) if __name__ == '__main__': diff --git a/brax/v2/io/mjcf.py b/brax/v2/io/mjcf.py index eb354fdc..182e9dfa 100644 --- a/brax/v2/io/mjcf.py +++ b/brax/v2/io/mjcf.py @@ -16,7 +16,7 @@ """Function to load MuJoCo mjcf format to Brax system.""" import itertools -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, TypeVar, Union import warnings from xml.etree import ElementTree @@ -24,7 +24,9 @@ Actuator, Box, Capsule, + Convex, DoF, + Geometry, Inertia, Link, Mesh, @@ -34,6 +36,7 @@ System, Transform, ) +from brax.v2.geometry import mesh as geom_mesh from etils import epath from jax import numpy as jp from jax.tree_util import tree_map @@ -41,6 +44,9 @@ import numpy as np +Geom = TypeVar('Geom', bound=Geometry) + + # map from mujoco geom_type to brax geometry string _GEOM_TYPE_CLS = {0: Plane, 2: Sphere, 3: Capsule, 6: Box, 7: Mesh} @@ -70,7 +76,10 @@ ((Capsule, False), (Box, False)), ((Capsule, False), (Mesh, False)), ((Box, False), (Plane, True)), + ((Box, False), (Box, False)), + ((Box, False), (Mesh, False)), ((Mesh, False), (Plane, True)), + ((Mesh, False), (Mesh, False)), ] @@ -110,7 +119,8 @@ def _find_assets( meshdir = meshdir or _get_meshdir(elem) fname = elem.attrib.get('file') or elem.attrib.get('filename') if fname: - assets[fname] = (path.parent / (meshdir or '') / fname).read_bytes() + dirname = path if path.is_dir() else path.parent + assets[fname] = (dirname / (meshdir or '') / fname).read_bytes() for child in list(elem): assets.update(_find_assets(child, path, meshdir)) @@ -118,7 +128,7 @@ def _find_assets( return assets -def _get_mesh(mj: mujoco.MjModel, i: int) -> Tuple[jp.ndarray, jp.ndarray]: +def _get_mesh(mj: mujoco.MjModel, i: int) -> Tuple[np.ndarray, np.ndarray]: """Gets mesh from mj at index i.""" last = (i + 1) >= mj.nmesh face_start = mj.mesh_faceadr[i] @@ -129,7 +139,7 @@ def _get_mesh(mj: mujoco.MjModel, i: int) -> Tuple[jp.ndarray, jp.ndarray]: vert_end = mj.mesh_vertadr[i + 1] if not last else mj.mesh_vert.shape[0] vert = mj.mesh_vert[vert_start:vert_end] - return face, vert + return vert, face def _get_name(mj: mujoco.MjModel, i: int) -> str: @@ -194,6 +204,92 @@ def _get_custom(mj: mujoco.MjModel) -> Dict[str, np.ndarray]: return custom +def _contact_geoms(geom_a: Geom, geom_b: Geom) -> Tuple[Geom, Geom]: + """Converts geometries for contact functions.""" + if isinstance(geom_a, Box) and isinstance(geom_b, Box): + geom_a = geom_mesh.box_hull(geom_a) + geom_b = geom_mesh.box_hull(geom_b) + elif isinstance(geom_a, Box) and isinstance(geom_b, Mesh): + geom_a = geom_mesh.box_hull(geom_a) + geom_b = geom_mesh.convex_hull(geom_b) + elif isinstance(geom_a, Mesh) and isinstance(geom_b, Box): + geom_a = geom_mesh.convex_hull(geom_a) + geom_b = geom_mesh.box_hull(geom_b) + elif isinstance(geom_a, Mesh) and isinstance(geom_b, Mesh): + geom_a = geom_mesh.convex_hull(geom_a) + geom_b = geom_mesh.convex_hull(geom_b) + elif isinstance(geom_a, Box): + geom_a = geom_mesh.box_tri(geom_a) + elif isinstance(geom_b, Box): + geom_b = geom_mesh.box_tri(geom_b) + + # pad face vertices so that we can broadcast between geom_a and geom_b faces + if isinstance(geom_a, Convex) and isinstance(geom_b, Convex): + sa = geom_a.face.shape[-1] + sb = geom_b.face.shape[-1] + if sa < sb: + face = np.pad(geom_a.face, ((0, 0), (0, sb - sa)), 'edge') + geom_a = geom_a.replace(face=face) + elif sb < sa: + face = np.pad(geom_b.face, ((0, 0), (0, sa - sb)), 'edge') + geom_b = geom_b.replace(face=face) + + return geom_a, geom_b + + +def _contacts_from_geoms( + mj: mujoco.MjModel, geoms: List[Geom] +) -> List[Tuple[Geom, Geom]]: + """Gets a list of contact geom pairs.""" + collidables = [] + for key_a, key_b in _COLLIDABLES: + if mj.opt.collision == 1: # only check predefined pairs in mj.pair_* + geoms_ab = [] + for geom_id_a, geom_id_b in zip(mj.pair_geom1, mj.pair_geom2): + geom_a, geom_b = geoms[geom_id_a], geoms[geom_id_b] + static_a, static_b = geom_a.link_idx is None, geom_b.link_idx is None + cls_a, cls_b = type(geom_a), type(geom_b) + if (cls_a, static_a) == key_a and (cls_b, static_b) == key_b: + geoms_ab.append((geom_a, geom_b)) + elif (cls_a, static_a) == key_b and (cls_b, static_b) == key_a: + geoms_ab.append((geom_b, geom_a)) + elif key_a == key_b: # types match, avoid double counting (a, b), (b, a) + geoms_a = [g for g in geoms if (type(g), g.link_idx is None) == key_a] + geoms_ab = list(itertools.combinations(geoms_a, 2)) + else: # types don't match, take every permutation + geoms_a = [g for g in geoms if (type(g), g.link_idx is None) == key_a] + geoms_b = [g for g in geoms if (type(g), g.link_idx is None) == key_b] + geoms_ab = list(itertools.product(geoms_a, geoms_b)) + if not geoms_ab: + continue + # filter out self-collisions + geoms_ab = [(a, b) for a, b in geoms_ab if a.link_idx != b.link_idx] + # convert the geometries so that they can be used for contact functions + geoms_ab = [_contact_geoms(a, b) for a, b in geoms_ab] + collidables.append(geoms_ab) + + # meshes with different shapes cannot be stacked, so we group meshes by vert + # and face shape + def key_fn(x): + def get_key(x): + if isinstance(x, Convex): + return (x.vert.shape, x.face.shape, x.unique_edge.shape) + if isinstance(x, Mesh): + return (x.vert.shape, x.face.shape) + return -1 + + return get_key(x[0]), get_key(x[1]) + + contacts = [] + for geoms_ab in collidables: + geoms_ab = sorted(geoms_ab, key=key_fn) + for _, g in itertools.groupby(geoms_ab, key=key_fn): + geom_a, geom_b = tree_map(lambda *x: np.stack(x), *g) + contacts.append((geom_a, geom_b)) + + return contacts + + def load_model(mj: mujoco.MjModel) -> System: """Creates a brax system from a MuJoCo model.""" # do some validation up front @@ -311,45 +407,11 @@ def load_model(mj: mujoco.MjModel) -> System: elif geom_cls is Box: geom = Box(halfsize=mj.geom_size[i, :], **kwargs) elif geom_cls is Mesh: - face, vert = _get_mesh(mj, mj.geom_dataid[i]) - geom = Mesh(face=face, vert=vert, **kwargs) + vert, face = _get_mesh(mj, mj.geom_dataid[i]) + geom = Mesh(vert=vert, face=face, **kwargs) geoms.append(geom) - # create contacts from geoms - contacts = [] - for key_a, key_b in _COLLIDABLES: - if mj.opt.collision == 1: # only check predefined pairs in mj.pair_* - geoms_ab = [] - for geom_id_a, geom_id_b in zip(mj.pair_geom1, mj.pair_geom2): - geom_a, geom_b = geoms[geom_id_a], geoms[geom_id_b] - static_a, static_b = geom_a.link_idx is None, geom_b.link_idx is None - cls_a, cls_b = type(geom_a), type(geom_b) - if (cls_a, static_a) == key_a and (cls_b, static_b) == key_b: - geoms_ab.append((geom_a, geom_b)) - elif (cls_a, static_a) == key_b and (cls_b, static_b) == key_a: - geoms_ab.append((geom_b, geom_a)) - elif key_a == key_b: # types match, avoid double counting (a, b), (b, a) - geoms_a = [g for g in geoms if (type(g), g.link_idx is None) == key_a] - geoms_ab = list(itertools.combinations(geoms_a, 2)) - else: # types don't match, take every permutation - geoms_a = [g for g in geoms if (type(g), g.link_idx is None) == key_a] - geoms_b = [g for g in geoms if (type(g), g.link_idx is None) == key_b] - geoms_ab = list(itertools.product(geoms_a, geoms_b)) - if not geoms_ab: - continue - - # meshes with different shapes cannot be stacked, so we group meshes by vert - # and face shape - def key_fn(x): - id_fn = lambda x: ( # pylint:disable=g-long-lambda - (x.vert.shape, x.face.shape) if isinstance(x, Mesh) else -1 - ) - return id_fn(x[0]), id_fn(x[1]) - - geoms_ab = sorted(geoms_ab, key=key_fn) - for _, g in itertools.groupby(geoms_ab, key=key_fn): - geom_a, geom_b = tree_map(lambda *x: np.stack(x), *g) - contacts.append((geom_a, geom_b)) + contacts = _contacts_from_geoms(mj, geoms) # create actuators ctrl_range = mj.actuator_ctrlrange diff --git a/brax/v2/io/mjcf_test.py b/brax/v2/io/mjcf_test.py index 9cd290b3..cef081c6 100644 --- a/brax/v2/io/mjcf_test.py +++ b/brax/v2/io/mjcf_test.py @@ -124,9 +124,9 @@ def test_load_humanoid(self): self.assertIsNone(plane.link_idx) def test_load_mesh_and_box(self): - sys = test_utils.load_fixture('ur5e/robot.xml') + sys = test_utils.load_fixture('convex_convex.xml') n_meshes = sum(isinstance(g, Mesh) for g in sys.geoms) - self.assertEqual(n_meshes, 14) + self.assertEqual(n_meshes, 3) n_boxes = sum(isinstance(g, Box) for g in sys.geoms) self.assertEqual(n_boxes, 1) diff --git a/brax/v2/math.py b/brax/v2/math.py index bfaa5477..60257c35 100644 --- a/brax/v2/math.py +++ b/brax/v2/math.py @@ -19,6 +19,7 @@ import jax from jax import custom_jvp from jax import numpy as jp +import numpy as np def rotate(vec: jp.ndarray, quat: jp.ndarray): @@ -314,7 +315,7 @@ def normalize( Returns: A tuple of (normalized array x, the norm). """ - norm = safe_norm(x) + norm = safe_norm(x, axis=axis) n = x / (norm + 1e-6 * (norm == 0.0)) return n, norm diff --git a/brax/v2/spring/pipeline.py b/brax/v2/spring/pipeline.py index ee7d6a70..ed174e5c 100644 --- a/brax/v2/spring/pipeline.py +++ b/brax/v2/spring/pipeline.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Physics pipeline for fully articulated dynamics and collisiion.""" # pylint:disable=g-multiple-import +"""Physics pipeline for fully articulated dynamics and collisiion.""" + from brax.v2 import actuator from brax.v2 import base from brax.v2 import geometry from brax.v2 import kinematics -from brax.v2.base import Motion, System, Transform +from brax.v2.base import System, Transform from brax.v2.spring import collisions from brax.v2.spring import integrator from brax.v2.spring import joints from brax.v2.spring import maximal from flax import struct + import jax from jax import numpy as jp @@ -31,10 +33,6 @@ @struct.dataclass class State(base.State): """Dynamic state that changes after every step.""" - q: jp.ndarray - qd: jp.ndarray - x: Transform - xd: Motion def init(sys: System, q: jp.ndarray, qd: jp.ndarray) -> State: @@ -49,7 +47,8 @@ def init(sys: System, q: jp.ndarray, qd: jp.ndarray) -> State: state: initial physics state """ x, xd = kinematics.forward(sys, q, qd) - return State(q, qd, x, xd) + contact = geometry.contact(sys, x) + return State(q, qd, x, xd, contact) def step(sys: System, state: State, act: jp.ndarray) -> State: @@ -112,4 +111,4 @@ def step(sys: System, state: State, act: jp.ndarray) -> State: x, xd = maximal.com_to_maximal(xi, xdi, coord_transform) q, qd = kinematics.inverse(sys, x, xd) - return State(q, qd, x, xd) + return State(q, qd, x, xd, contact) diff --git a/brax/v2/test_data/convex_convex.xml b/brax/v2/test_data/convex_convex.xml new file mode 100644 index 00000000..cdedf61d --- /dev/null +++ b/brax/v2/test_data/convex_convex.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/brax/v2/test_data/meshes/cylinder.stl b/brax/v2/test_data/meshes/cylinder.stl new file mode 100644 index 0000000000000000000000000000000000000000..3a953be24a0ff2dd5561cc35a044b6d82cb0f129 GIT binary patch literal 20884 zcmb`PZRj>fb;dV7R&0w^n)JmY#1{;$6->ZHsy_E*i^gCH#+o7;Um6k!SfvQ11p4A& zpcLwdmInHw!F&+&B@|L9DCD`wgiiW7af7xY++aG+}X@;S7xJHj8 z=4i}N}2LKJdh?XHY!c_ap6pUGCnlc9CgLKO9w)5xQUbMI?! zx?fD$5xeXe=PfGBTy3tBiW0JwhictDWexZ}ry==q)a9C2U3KZLzyFqPRMZ=28JWQd zqedbsLEF#hC0;@l*@hiQEhy=&CiaYI*A}#b(H@Q>d&j7Pi$N5kV(sSk=>666ARFyr z3|+D6zouhHU9S50UwmAau2{8?nT8tYSH?x5P(phSc3>Q6P@&1V`LSDmkRx)+o-KCtXcw4{yDu+pb0F z8KXqo*q_1Z8Y*#Hc{Qd-*zIr@A8L*SgwfuJ-6Tb@o0FL!3m-)Nb+EZ;PxhH#~IR;kRzP&V9dR`q#r_7k}+}>t1kEc&#x>`KYE}0GL7R|4U1?*HT+VK?4e<-tU<<0 zzxL4ahrj%vmj*hFy1e7f|9H6UbuU_Uf+(^bT>ZCceDKL!=C1{VCo`!}Qb*s_aV@5lzeYc8%dQX0-H#D4q8cmFyM$mFp zWTU3nsEm~vM4@&!b;T0aZmz2bUj0Yq?(2X4sL*(w!=8^<{6Ts6(Kk3IQfMuxv3)z1 zc<=k*D{d~o`ByiujM{j0p)aDKHzeAf6R18KNMm!IGY@LOFZ0;cE>YBjCu4+$UlYOb zaVbzb-y_ud+#ac}!05H^o{~l-3N|&?x-!e}YW@Dvv-POZPHX1X7z^ei6RxfaGnuTB zMp3fw|EYgC{M+rtN7O-a4f>)6uBb*ZbQ=@t3w<|5=#&b zRE<7;q)r;t&QZ}9@|dYTO}hTnx5U-FKz&QLt~s^4nQ30CGSyS7D>4uX*#@Ns{1QbC zW7tvKPxY(6M*w&&EvQD`DuFNlfSxr_wO!lQPX#xdF^69 zIzPkH+OMpjE?0d1^5rj|_mUm^`WyfK_{m3J8ybSPpSr-;L^kvcx*m0*FS5_)L0_<$ z$0}e}EzlrpMuaW%D!3T*OB5o^byW~=(x7&3K`_{wY(dkAHeW2Vj&eKLvb&z!qs@1W zh}}_r1y;ne|Y@*Z7)B;fZxe_i}xD&%M9j04cN@11yt?BeyV!m@Jqe3 zc@(JAZ?9dig(lK|*TjHcV&~eWI-;6!{&c@`buy^u@mSV6@~Y^IpLf?TY6N`o(=Wu?6;{N`t-_31;B|4^o=x zLvKG`iBRodfreHrV>N;N*73?2L?MrvdUEeal?LKX8q}VqFFbbF;eX$9S213-yV@b{ z&bps>d%otrdk#PS^zR8{huxhIeU$yGJW5);$M@3YQ zk!L@u+9UGt5@;mtz()&`0gc}Np%MLR=NB5xV@R~?1Acq`$WxzRyFDn?0`t|1&3C$Wjh@WrR&Y%X_wqvwt)EMY`iorZiY8b1g zJm?qq!EoL+Si;O1ZNmnmpkzi=GwQ0fN9_!jKx3vS#|{QEOeoY_Zu5d{XQF;I4d$_9 zC^U<#7Sz>Dy&YrZ!4hUpPt)}5pM3xF5ztO%c*ljtY^Si^`lcbE-3ToX>d_mXLv{ zHVl@KfoP2pQ5}_ipxh2U*}Zn%9+~eLU~7Lyb?6v*R-e+5#}a5H?S^f?>jJ;McHJ*} zzCoD>{laD*;XFPHGuEz@$vt?Zw*fpqEy94J-IogDXg~r%AhA+Vv9GYP4vTy|J}E_Z=Rc_ZBSu8YIKoz;|^F zGUI&C!}s922L1l_Ph45P{KDtCw#&q?nixbuZxau{ z%Sgzk-iWGSA8HS`BNEwN!|l-?`m2U7XTa9zqqgz$Nu$EAb~)9~?iXxxEZ^;?`aWo& zdXI6xwszJ`Dl|lCM_vB;2UnKQ-u#qbQAF)DPBMtfc^JI5pIYEI)emzD_x@1}{A#D5 z$37$4(=E8IpAADh&USrxJuqJC;r!aKEEgMhvDU}ag$e83mWH9*b&itsJOZ9_9$jOR zBdpJ+fBD-#KfLp2-{|L6HfndVU9^ifZtNJ?xp#syL|o{L8uZjCK`daTbtYG=j;;3F ze)4kzXw1-9YgZ)g2KGFX2DJ@1j=EqyFkv2)p1!YeWv+UXj?UtlaAFUJ8Dkdi_YB*u z8}H&!FEfau-dRRLBxxOm{MKGlRafvelMTIl)mRM z)hE7s{L1g#xkXLd!Lhli&d>0a`<3szbPf8w=Ei%D|8(t}PH4dIIgJR+KB{UL>J$Ca zQ_h@`fX&Qv9tD0Aifj+dB5ETX)Mv6YBidsTH%i=&2xa%$b=#b`_SiANoXgslZb!w9vJjDZsR%~1^r8-2h?*)qbMIekHcnQyKSFtWxP zZEo}Qd=Ssrdapc@xxICE)UyNSMg~Ues3iNS_AlL|$iQw?W5n<3JEvL`>I506!wrKa zWH4G%YGGZ0v45gi8?UClJ~E)OwuiCu&K()3tK1%`myp3|a~@PD12vMPO79Nh&Z*Vp zm>DFE)PiI{n*?ifmXLuN&dR9jB}C!PvuVdLbQ^sTZ<+z{$<`{O_CQpLalSpEo*48r zG%T__heIBTJx#W+RBR`wC>LwnL)>L_-E~_%v1lgV{M(p#==Ug~G+rIz?xT@muv1jJ zqI$_hznsHcD62+=-*Xy6qFv`?P%mpB^Argf%zV8*s%jU$5{m2*){a3GqBV_(%14|) zxgB*SyVtJUqpo%gur&s&Z5-oKSwuSgy$rSOj(zR(FO=te`s$V8a51n3n23U&VVMVF z?=>yI%tPMC;j@epFZHs95o#mVf=KAMw~gzq0!GxhtDfj3WJAx|BYL9O18b3NoL_s1 zf=e3ody-+f_^yHP94J;>ZR-t{BHLEQdMf9({vMg9cMZ7fxMKv%SM=RidiHO70A0J~ zQAJT6@HImtu^YDi9`E~jyjEAoKCa+B01_;7mXHzmkDPf`U9kl38cg&=k^IIMG}KF$ z7~wW82o3eCp2z-u28*cO4WO>KCdV6;^yJsaq8)U_7wT?@hU^}d+nE6rXGmn21G@&L zQJ}H4nbU@WYgnop9qgYw22mVi$jajt z?t>}9!EgOenH50~zYwv0gv(Nrp$o667>g5;BnAhCvjg%8ZEWp!9>` zcIe6OdAL0?*I$(k*nJPkJ&nwr0l#_%=`(_jhW%+r_21SW@H^f8@UNGTz4{aG7wmvG zR%S%z&LG<!|k6-wx0Ynoi)Y`?Sxb zvVQnC-1_}Bp~&v&78UQpKUp5wriF*NVd8pKe%?(}v-f5O+tUB3zv zqXKn~A;-v^S0{Qt*_q)hYTeWAbh-xNY0YC1IZBQjyFaybjLO#bAR}^e6vjj+3`Ri- z{r0+&_9iO&gBf_=KjhOF)MByn}X81}}YeMN_7CzzHRM9x}*OcKf7;Gw2u3zxEo2)?vE_>s)k3gqx9}q4x~vCHBDAtp@Vw@B4VqKhHoOsWqKp z{&nQI$f?;cUG7V_;}hGGJ4q`Pc7i|I%AXh_L68 zqY`P+cP8|W2f45{g4PNL+34;1vmI3omf%{M@Xo|Q%g79(sNE<*EYRq0Zzw8#W4%Y4 z@3s1$Li&3vdZV#F=h|Gg3a4w(7unSI?{gNt!z&hciu7H}u~UTe7M0$;>I|UHG0bl; z)H*eMdC!2&7E~zXs0-HvU}%@Szf(5QGW=U~ql66fvCK&Q3XHT)AZ6^sd5Z~!47Qu$ZzSzE`>+JyQpJpFw*~WgFW5TPpXZSp9zDp+26Hq zO)mNdnshkUawNiet1B{q>i;X3Ms1Fd_Gd6OKkFJ>yN{y&y=3A1sOpKK;UyRolc&HN zT$x9_($IH<#~2JxYaSj6O6ZqSwO$fJb-|1^&D*hN7(vOXdb7U&KSgPd3s1Qx>YFe9 aO}YLK{XxEXZR+p!{(+J{D;8t3bNBz;VQRJj literal 0 HcmV?d00001 diff --git a/brax/v2/test_data/meshes/dodecahedron.stl b/brax/v2/test_data/meshes/dodecahedron.stl new file mode 100644 index 0000000000000000000000000000000000000000..1a3f9f69a16a7e61769680a8489d155d6f456f67 GIT binary patch literal 1884 zcmb7EJ#Q0H5Zo#y9len#LXilOB8rgWGl-H8ku)jNAUjpQ#<$ zRCKiYFOk^U+j*X6B}FV*cX~TBJF|OtufCk0%|FhjqoeuR$>_!L>~uPtZ#>)F-Wjc5 zeEKkY`o!+_dz|R;)xr4vx3%%t@~h+j_u;f#e113q#Lo)P2g*xlydD@$g)E|R9T>aL z*mDLEIQ9E-dk;DSO0|fFikufSI$B{A?I(=RjE=xiUAeE>oLS5;BcS1~)ru@JuRIT? z8D<0q+-XI6XpYVdGXewda8>`Dmns1>IC6KrqQxQHFiQzS=c<&|JJjs~Z`^XQQ~4>TBF4M*tNDow5q`ZvD!KYzb3yjnFA6=-G0bo+PCc>d_e zmR5v;DO9DHZtqzlwV;5YE4`0=L!bOIAV*|L-PP$P!&CS9Wnjvv&`qj@l8?}TKG{CM z42=UibA~9%JIE&yS_MjUmW*IE?;sgbas3Hd(!-VH9gIyEt9)aW#N43Oe=TvD`uP%oPkZG_lr59uPT}S2aFQw A9{>OV literal 0 HcmV?d00001 diff --git a/brax/v2/test_data/meshes/pyramid.stl b/brax/v2/test_data/meshes/pyramid.stl new file mode 100644 index 0000000000000000000000000000000000000000..94de845b7f9886b0d5b1c8f2b7ba509a47a954fa GIT binary patch literal 384 zcmb`A!3}^g3>J AF#rGn literal 0 HcmV?d00001 diff --git a/brax/v2/test_data/meshes/tetrahedron.stl b/brax/v2/test_data/meshes/tetrahedron.stl new file mode 100644 index 0000000000000000000000000000000000000000..68fed6b855a9e18a2535f742c8fbca17f0288d35 GIT binary patch literal 284 zcmZQzpe|s68mIR&#a{2{lYI;f4f`QNaM~Uy2E`y5m^e_!eoP%8mB^|d4gjoMDmefE literal 0 HcmV?d00001 diff --git a/brax/v2/test_data/ur5e/meshes/base.stl b/brax/v2/test_data/ur5e/meshes/base.stl deleted file mode 100644 index c5881046676ce1ca5c88bbc482e05473ffb616bc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 21084 zcmbW92T+wq+xHj6g1z@HYBWXxY0AC37E}~bu>uFNV?z`hMG@=-1qD%o*kemn6e*T- z?z1uWZtziKH@2Wg#fXXBZ_m+}{P)T0%sU@5j?C=-_S(LC{r9@*KcH8mF{49APZ>UK z^r%K7A{+UPo;tSI@Db$y^JkL{Ay>!O&ux6Ze=Z@hBJ71hH*WLSn^xJjXVuA4TKgku z`pmd9S^1r*vF)aGeP%#%79lNt*Jp(_Z0=$qFjwTu-%KB$E~WSD_RpLc=&m)opGvbk z4$CCuP^Wv@>wm9cB`_D#SC5A2Q_ppzcERN(@iDDcPIl5QXNwI?LBhRCs(y9LXl=5? z1xdJc%E*bUHz39S9lQq?>CApUYMPQ;>KsdwII-?S46%EkwPaT%Y%QXKhFM-IeIi=i>C_ z;s&KA4+iz4wK7NO4_I?YLVODs5chwQBcf!+kMhhEPOUtpW9M;tY@N60mwnS5laF6N9)6-*aeq|zE&0!_nBNgC2BvD%3`w>)WYZMmEY znx*Vpa@}Fw)ioF~(&YTZvIhSma2_Kq=T+fC*H9i~Z3oUD%(Zy_c-sG15B=8bK5_)l z_YQHLw2)f~%!TtbAC$o;2{B6ZVx&y+k@`!VA;QJIggf z+Sh32?sGH4IunsV8%Pr}_YY%Mq42b}9Cx40C1p+RkEYQ(QCoMbef3vZM$Y7z2Ueoa zX*a!3xk!3?>mWy@)%CH#tLa|fvg56@Waoz@-8b&EHY&^C5myw>V?w_BIFO$oVk3HW zHLyib=F{8Jwe=q^W@xYb_oGjbUN?Cbo}@LRyU?^}C3Va1g#4N7%R|SP78#BF6#{c@ zNt#aW^W5}(&9Y=)`M#|t)_>Dibn!Z9thT~M&lnt~SLvFp;R;q)ZE~#$;ntX#vzPj_ zXJ;npEi2qNTC86>ma326dCFMln=I|b4=eTT8Y4`Z?m5y%bm3HyST@r9q+5H|jjYfY zkKb$xxtyeh8IpD9Qd9Tm5v;Ty1x?xq8!sP4sJZ_GmBq z_tT3X`_ylQybjCDJ=)FHIurK{j?*vqsG_fZKhf!C9olvEW%0MOPV@v%e@n9 zVKmzwJ)y9sfa8dQqT(g&`W#It3leu{VVmg2eP$&TMl^C$V=7tO}C8sToL5< zpu!k8ORZi^^{10t22#7K%nA1{T=Rq&UQHJnKF%!K(ZG6b4A#ev7^zwM9mhr;kFs~d z`1|j+vK4vNOo<^Y=;?CnwGMTEcEo)!eO?mv?D&gzsb8|JRYUR0_18+*AiOma<-N_W1DG5?I8MT?zU)u#ys)ENbXoNkYTQpa;_H2wQk9!5+8q? zle?_gWoxY*XIIiakH^uT4^x!g2X`exPCOmtwy8t1bx-`YdwV+VO1S=(OmV{70Pf&~ zoN=?`)9U+)^I>a@NW{jsp&hEn=rI{d8q$Q^8vT$h>^ekro#oG-M%SkEp3K*Ol>3!z zhxVsjdzN-AJ64%OldE=U5?9In-X}dSJdb9F58t&Ta86~a<@_LUR znY*u6cWR_@pC@Gf(2+d8Y4crYZeT(z(6KYFZx=s8n2>>uqWs;+VpHl>|LP8vyh=U$o#KiCt3W04xTJG(;Bh3T4 zOcCwc`m>QEMw-0k$!6K3Le}M&?)OO62ffJF9=9E1@>#G_Up2sTleu^2ogUORO0*r< zhTU6|Z8RQD)yJ+=bE$LJCU}ZSTpMn#)YPB#l;tN%I~E&wL&AHIypiQk7f(7lnf8R^-&tN7f5qc+*2#B`U6p6f_}BFTT{E+!K+&lH|U7ZI<}(wjOO= zwl2f6uvWbrr|Zuy<#=e;xoe7;ErxjAqF(Ek7%>I!(W(d8XVy&7z2$9M=w6p0fj00~ zLrB7mNn&W1E#|>-5-=GJ>V z&USU~UO(g^5(hRDUsv{52+TF+w4+{X{-@K2kk0!%i|B8fi*)yw#tMZI`o0}*`naW) zmAfd~BP5{KpuLkze$oJtz`L)y25lSp36q1JNX{v1Lc+x7>VIu*t{=RbtRXGe`Tkyl z{ZLzs^)+ai3u(1hlM{SI+@DoMvj}^J1oj)=;RzYJ(OYbuQ%hVqbC+USNMK2Xj6UTc zLgl~11qSv}p5LLeCoSs$dlY+G?(JLc#r5%;XyxOl5SUBdOSZ}T!qD@c;+was4a}7; z{j}=QU97eVsp5XdJo5cWVW=C(=F0nEM(rSLN`dt3UVXkTX}4aRet5Q6AL3=%ePWKK zm_zo5ibirpA%Xi9?qh_EX_R3miQ~o1yM79RaD6*q@J8LJDp+OON|nSfd+=R zRJ^^)Te$Y)=W!yWFsB($uiu4V|Gti?M0GEE>HGw`a88njK9%Z6m0VJTFD=oBA7LRH z5}3<>L{pkh=h4vLEn3FGj)A$yPkZ~NlB4*6j%S!dkb!NNKE`*l&XyFURf4}9$dd-!@STkV8K$6(N>XRo{H!GX zWW&*$^MMUoaqohoM)Vr7MHsk>VBa=EvJuMb4=6Z+9c%e z50m+?3o^_{8lCXK6ujxHikIKDDf~gf2J?Md(}*dkCt^v2{4}AROIY|CtGA8w6}81H zah2%R!*SHzr>xZ5_cbiV?}hktmz2Man1VALOCsdt$*sBOE&Hu~)hs7K@0J%p8(fxq zD=2KLsy2FA3*oS+6F-?z*@VQ%Yu@_2__1_uRI-LNA>mHF#Q0Hlxc$x39$1#m_82`c z;~Tn2_Gs5l6ZAV@dC?h(mNjS_?=1>!YjbU79>o+SetI`ue_G$2=2y#*M5l#ZJhOUZ4=USi>nwqxE=TWI#9!0NMoA_p?RL7X1h8(34mVdhMS_xDsRGlGZ{T=;7t0*2kwLM#2uY;*kz$#kqY>wu% zc(-xH@l^Wie74p=h2>*H+h z|7tp~=Ok+dHn4_RlH3!EcQem)n9lb+@K*?I2hxN*n&DtB{q-!~y6YSdC$D%@{)Uy* zU)~_G?x^Sx;<2%FDV95_yfM+OQlG(+2w7_U-Mr)25PsEHCC~<{RD`^Ib!4w&g$S#vwx@8e zK6|GHERbOe5_WQxjs1Lo zB&1EqSh4Z@Eq%6bkA~WkdK#23IZE8-XINmUK_T$`kh+il;^_aJ z=f!7_7Co*WW&LIcFwBMaR2$#)87J=N=CPn1=RB|`s~UW(Yo-eHT*H+AQoMF77$z3} z{gOpFOfq6F92<26f1YL#{Z1C=KYrC%!(2$?XcAKD$QaQ)=q!8kw6xN4B(S!Gq^}9) z8=sssKbYjg9z`Tl>i3)J`Lk>dH8MQm2r0KYl)p>=(cEoEC5ERL(khYtB9u4x+G1Yg zQBSdfbl>1)deJf4)T+R8XBfEGPLw%0Sa`Ils(C!At-t@UKp*attf9_;I+MI14|#4L zba=Rk$e(9K0&}U_QcLMK*Op$r{VW5^D{tLJb;~+O%^~I8L8l(ltN(N(tJ}rd{Pf{; z5w8U*S_bMDSdu)Wp1sx&wvQ11DebS=KmzR%@N7+`WGEI8(+of0EGEudGN};L_X3N)%yRA2yyI=EXn1VL2w({hh z6D)kqr_ANvwo_h4k?8D_pm%=q)>LAU#aDYBG+bOV{%P(tev+c%U@lendD&>Zm}Yy< zoV})&312EP7t(S(B8Q5j`A^L|w^e4S?_mF^#Gw&GMTK?+WDPzNfifY1J*?P&ieO7 zZ}f96d5Wd3bwtk)1H%;5uTYsHWN&kC5nZf`$ZF!R5NHEwxz00Z3tQK-x@d9I1ND*} zq3Qa(xq-&z&z38S4k7QK#)#dG_nCZKPcvdJR8LU@l4{h87_n+(yjJ~mBMk{uRgos- zm*82#x#2~<*Od|s39KQiMuhC%JX356tfTKL>8=o%3zcF*-kgaRUR}3qi)z^_`Yzg2 zi4Xl|GpAxxL|s2Wi&s>1vfznxe!_9h=iV~?+9J#OG2y~&_G-=)v0znihPhB@Qi;eT zQe(&s6+_M$7^a|}gtiHZTDpmS7&BQob$Lw>p1q@uIJ8t>Aw5&1(Hjfpg5+ifk6W@#8-fa;fwziKg>H!&P%qkjxg2+W1rnSB3PF@*p9 z_6dv3XAE^QB!Wi7QQO6p>DL(+Pk8_1;oSRB0o$p7l-) z*Dx2-*d{_m-LZUDjkD~oyA8u!SSyvdw9uQ|#8&3{Z~ZJU4OIGpRv!@cJ!~H#^DlYw z#jbUDSGivyA@?X+=CkI2G$9Apdhlhj&G@8}*4K<88#tOs%X9Fm#s>{)&g&2GP;8*q zhcqD-TxW;_2YRutySNe|geqP@@l?IS)M46^*Ou|P{dcrT+C$jLt1~<>1>cHr?;>PF z=1k!fT1fqp%QGY}m-<#zGcinmwknc0k$U3FcL(=)e>>6YpQHMrY8iKr&t`R6PUrmw zRAYC}j@CRAR?_D!a<$cbp_ZW%Q)^Knw%GG!n|dn*s%og0$!L*WP2BLc=PjgafGMc3 zp>0Ckt8OxHzZ=5eg!WOQcTmN-aQ2+(Xj&2-?_wFjPEQw`gT_qb-nP#vrXYa|k__Fx zwZZ)I#uWbX(p)1Fs6eSvNX|UbY|5X`P4^pUNMN~0%Uk%7d~@Bvp?p(+H+CYqq;6`s zm?p^+2UR0g_sO?CW?pX|$ESV&j$#Vh!;)nDZrTNN!}{a+kZ-FQvA!6?q!N?59y512 zH-VRTJL7>VXah?kB-yUE7%Vx#cksiFl# z0&`(6$d~HpwZ+reUi`kiS)#g(vkyxmWN`nN=IxDw_@Akz6dPDBwomSGGfRmU?fiKu zhbks43;POf%Q%3?f1B&S9?IK`1{xArF4FRiCBHNmQwNCGhnujzD_v->3*q{`av2(` zY8Y8X$d6k}@m$A#V);;ig+T2MX+kEitizLf^%QURU!+(TMsJ~Q83`y)n>zBATrS+8 z*g)+K?GaM9T@(Ip5fd4M{1gIJJfx-nF}n%>>IM^qg?$tPYlSo+1$F!Il<=zJ>5yl1 z3XU!A78JU@FK;Pa*IHyIZyC(f-*Q?p>-57xKzatd1|v7nCW@Pc9CVgq$b^`)VEps!f@vNX48=x>RjD|&;#TA@ApN_Hbq%y{^Solo{xY@o`C_M}?t z6ePYX@r3==)LkL4O-Rdl(S7&KQO+Yo$#oSOzJQ}9=R4$u$^ENDeN`EYCtR`p@8$Ipnie2<-Y5_*}S`0D+K1ky^@f=Q%mrT1%t)(C|iap zxYwd>LRzQfv-V?$icX{47^WbhdUriW<*~6Rf<>!$WfTt*+kqt!(&X<=EPBusk-oKp zw#@Ui*4Q>#uS}XbA%UJAA$I2$Gv+x%+? zMjlcRsrP?+X^8nglGjTO)0b`!V39nPM*dd8)yi}14&nX{@ws^?bY9*_?cN}Z24J&4g1gyWqm8@y zF-$=MZ4=V=xF6@$OYl>RA5iq2OZAAM)jGAI7v$Nldcx(CeR#sqDmM8cXZrNFH2UY|di%e!wOKEcH1vwm^CqNwnMtBB zehZuaq!z>1NqphMk_d4cI!R>YuV;p0UolL9+*`of-cK?jfw@o*AY|jL2yxb723s`Vq^Jf+tAt;vQM`DkGv*%$7?f`To z&XbBM1iqzW3=<*WyH4Z2SLT_kMbu|lmO_ASLYDs>&JR9nV}5wkU9o}Xs(VSo`{{g^ zud})QF9DY5=%RfI+kx#PbrQX~)j!@$0>>F?3W zIgPsZ&(cr>Qq_+WhWpefIfAp{ev0mcv`Vy3Pd5!|62)0f8Ce7IH)sQCLb8VXusf%s zc-+8P+Q!I1H3+Io|57CgoT&nJ~vrmNUlIOwJ*bgL7 zDMEXMlq4J3?aNd6@mWTtR#-05ge*8)z-sp~@C{P4#PbeMwIh9OY1Ml1wCr9>r0kN{ z&sg@RAf7R~0mD-mwJeo5R;z&BP6^`oYcyn-f;O-ud6SXxq5*QPhsbq~gpAD9cFQ$s zIs34-@@6SxkcXUd;0vD{nDj<<8AoExg?bX|O!9j*`A)Pz>aP~<6$y;kL6t(rS%x(e zq1!!poifso0woP;m3U^y_&!53VK>NCiIYL^w`uwas(sZ@cXPLRMfz&a7frAiqNW<5=8=Spyw^`aKClp87*6g&|{B0uw10kkCrjpIc~h1{3>P0 zsWL|VW(;#7O~|?bIPtly-Gs;WJP%C4nqW!t-N2;|AAix4^MwZGeuZ}`8DUN*e;!#v z7Rm8WmH#fvodHu&|5lal#=gz@sQeE6hkJg?`v>0Vup~n4n@00ptuE=y#TDu_KbdI3MQ}FFZzNHtHM93Y_5PqF+V)dR$4-(44cO#?;sqYfX>mJ+0>>9RH2z*yTTAm+1 zQM`Eab$Vo=J9}6qoxaZ3Jyy%x61GFVi$+$Q&HHtKuI2P7R7MbKl?dMu!)>w*G+|sL zBSsPX?&w5}k92=Ebkv)Uc_&7E@B0!tAYZXncFj>provdLLGxma1 zCVrmAf7`Q&-KkELzCv0hsveuhe_0;IZsyoCB^Qhh_M5yR_nyhK-gjov%gULMz+S)_ z67ud;G;fexf_ZtmD(%2tKw4_ZooDi8*Nd^j+Eo;90tvKDNW0F@>5iYG_?^Bjl=B=< z&XwP1d;F(SIyG2hwqL%;rz_&4_)ksSGEBiU9A8i5E#v4h>T_&1H#DtmL;}xO{Hj^b zcANKfXX)Ae)Uq=kNMNmymT|~Ehv@!Z(R^6bG7O_t@y7ACmHYm=^U~?2CYINXv^N_} z`u&+aXNHS%FTr-G#EQ!`=!5Oi{K~hs>TL=7MH!1!HkA0cI~me*VT8pFG|%`>76EEj1RBP@5GYx_dEIn7{+xhs0}!W+Pd zUr%elEL=u&aqzn#l&CWLVos@@aP9+o7(46{GP za{CAK_&k5b2G$DiD{_acbcod<<9WcC3JQU@C9I+Rn!Lt3 z%klKWm_&@f`!a#ixcGJUmkHDW)YIS}My(Z=tNLnRBJdWjzVLn~Vx`8-@vAND1ynr$ zv0x!E7vAC2d*VNgO5nGavKNZ}R>wd2|Jw$ts{glIVSEWjW`EfR>i8JZY$akpaj_gz z@OvEWkuTdo?@smA{$bRfPtF0r{WWSGfDAUaBSQCn}cX%MXq@+WwUHe+W#$F_$9< zf9vHRMy(a@a9G2C@>|-0U%23yqwOLh)498FB7=;VMO%5 z2w5Y`Pb|ki!FB#G;u9bVOp(7A%`05F!IueKwKzk+OyIo28CpcdT4o8Juh^#l z+X&)J#L+CWA%9V43GQ$h4{If4jVwR098<7&v4)@WTC!UR+{5wR;D05oHgI3U)A`E; zrl6{dXGu}5tTvE94M5d>{$bRr<$r>UmmTs_8errWl2?oPQ_@7^1FFU-LLH&mTyciv2VMlwf5z0 zYFgJG8qzXRU{yQ4{?J&qFu9*ZpjO6`iq-~XV^u&8vhgrFufC{oQx+XJSPXfyEyF0) zZqFnQ?a4&?xLkVWdGlD`qkSa;wYsm_mSNOs#_vS3k$5gZTRSyB8y`4ausvx9em&Wy z+Luq((B6eJs}Iz#k*c;VluSg>wovU#qq*$zt^NWfNXRy(_&?Jl|7p(l%!w8SBVSvM zXO$n{MEj@mN$asXscOWX*4mGhlUCbj+tjviTa%4%XJ6`l=6+D$rgV`A)attCweUKybP^+n4%OtYLmm;F+g%;v!YADNF%T8RiglqTfrmBZ$ zy|=!r5~y{gdJKurt96VCH|sHS=6kE6%*kpg>Nc}sKPu-FC_!R=NT6xsaYA0Q(Xn}) zxW2wQI}v|OkH6Z{EW3;E25b0*R%W{$UwYQM??)@sM%RgF$;R2YJw%NeAJnY`@VtxM{~nDr>%Xq5F@vlx?5)lo!j z8aGHZoAXw!w@4;XYv|8-lQ>~horo(Hhl|TU?O6@qUfS}!$>wukmg|)sb|ulQ)2c~_ zGPb@-GzrU?gBgl4dH!%QuX?ts73(PxsI|@|*(3@jB&93L-t_*Wa==tpXy-%Kt5kw! zyaDM zUUd+8A2w%?)Mk8r%nz&42N`oh)fZFpn|*%Hm8mh4@|*SO)R+-brbK>m^kNA6CsQU+ zOAG&DzTwgig%XjeoYw=2MX~3Py75cXcUX5g}G%4nBO?U3Dxv#cI4) zhKsJM=ikYujeAeK5)tM(TrYj#t=d0ephTcnBfAjO#;5PyiKuvcr8eeSBpbVa2oIZ? zKf`#g^PMEMLXRx-|7_nxHTv0$4CB39sAD1#yT-57zKjZGudAg<1ZpL;buft~*-1n^ z_1U0KaZYD7<|Omq^|Sqr_j%}>1T|lbuYZrgRCV^IB(=;~KmQKzwyBxV3}RZ_gR0|@ zk7|p8!zBW>?kE5CH{b2Q6NuP-ek?1NE1201?#p*qooO+?A3eWxQS)XtvlxBzrb&18 z#-N!NqaI-nMqAZjm}yz6ve*g>&8p{bIW}0 z-RKsoW4|YD9+a)l_K1@R)C#6miD4ruu>}!b?iJ*g z!oh6J-zI$I#)7I*k8_hOD*O1ywAZvoBh~d%LDi^-+Z+oK>pmCa8&n$>-ywpd1PPiy zrVT}HM8r6UYCJI}ifw&Vomab2SvC6HvrIX)d9$*rQI9*Ls;Y-amNjj(98;BuGHa^w zRpUZgc=H+(fm*+_E1QIkM>!(0-ZbW?*EMHHmK5gW?>114J}>K#L%oyN->g%{qQdIr zTK*>CwXiS|A3~e)tphD~1Zp+N(NHzoYV!|;h~rbb@>;jT*xPg0*~6Ue)KR|l zZqJ==xyPHU##~u{_nt+q+|IPI;LQPg!)5-A;S(R`V&(X4h7u%f+c!6DTt9J-h@F#S zdEdlzR^4_qE4?mGO{ei{$<@}y5v#_=@uZ+MmfBS& zP%Ans&Lo~+XhX!fQNwt4uno(+o2h1OOft)!V__%%g!u`oF?J65TKd#P^SN6_UCRjgm1Oo!ldLo8 zt_6}!!llgx`f_%Q8O{sk$yWb*%#;Y!sCMGD!pZgN8QjRJ(pUN|0FJE8ev6&e1}|`(m+t zeRLE%HuIDozP6<~c5@|Nw6=fH!5q7`zqeRjKXou|M7P>X#N9oyykbHq3yL}|5vVov zKueQ&nR=0k7WU0~`GUdhOLu$mVIDW@F=Dp8cJN`CS&x$;U#xXkaFfVNx7QS9Vsa$! z7iz=u_i+#?LE^{jFtZ*B$zRCEP*Izo{}#!2ZTX?I6`TwcOrM2nL zYM3@AZ!b+YDiy86-=>7J!$)KSwT|B^ZxSVjdlNC+s~~rDO=s1tOuTGtWA^jMe8HM^ zVli{<-acDb>#(4hITGs?_9dc|R|)RC_@mk>I$R=9tGb_!S&!URgNg7z{E#`6X~0-? zXR)=%5%WE0`4pwCy!hN=%)uk)nre9(95LUCk*;lsxL))rTk$v-8!@x1KnW6etSV=U};CX36J|p7UO;8Iwn?I)?=B)n2$|{b<(0bEwePEIk@ytCnCnZ zpT%4+gt9xKy(I#*a_vkqOZ96*ED=k8e^qzCjAVS;U@@}aLVx4ETQ(?3TV2$}w0BAL z&?et-Hi?x_jNRS)!(Y|;mx5S!ry&x7S`D@>@HgJnS7MWh*qbm&U)Z!XpZPFQ+NB2V znPz=|Wt;jfsjKFEZ?(11t8MC%gV98U7OAC|9@|b#E%8f72@=ztcbjzWDeZ|UTjh`* z=n%{beCaN5e}**f!W3mm_B1`SPECGbv!_4_60(iiS5>`DWULr=F;hoLM)ezJNrG)7 z|7SngD{QuY&^?T2H?Ws%AR+GuYme-pAMFq)4)`9{QG&$Ft-nlrai1Ge&0FcVqUBS2 z=2WM(z&#-DSn)|T2dmqOy1pOOi5q1Cwd5USfrJ%$;JL{~a+Z)#G`s3_ZPQI8XcaMIulO_q)d4d6r%*Dpt5{*sh}liR8F4 zri~BD)yT$8cTcf(dUN*fOg(|SU)*`)lN4oMu}Y%s>JRF{JQj&SE!-8;NL+bYzgRYm zKMZ)GqXY@LR1dwD>IFN+iRp`1NCaxx|MD^Gao*F7>QRB;(&O7i@UqoUO9X1k`*WwJ zoAlM4(+3Lja zr3Ffmz}>x~+;?gu<`;5crB2EOYN5wMQF_)))1#u}#G=t-bbKzBuyxG_W<9dY<)nH{ zTbV=LDHzV(k~T^NYN7u@QMPwJqOWZgD}1Z<&{2X!{DuhA#;SS`>8o}!v!M7gDuj=F zzfeaB`Wf&^igM?AN3rv05PO;XwPXVc^lKQQeH^}YN6*vQ9f-wtAD3@ z^xNM+D%IziZOyWe8fX0jilE!Y)y=8TjAW`a0 zXVacd!avkjb2{c1$x$?RL+j}%L4OH8Nm1r^?js`0WvlJNmPv)Q4Exwy40&PC z-q(}~)ItxEqBNLyP5(;sqvpo18kVZmh&Z$Cg<_-s$M3WHQX#RYbp)U1Xs2UIkdXa8 z6P$1BEvT(BzHQM^f<#baf@yDPL;-56sdb79PBvC8*rK5X{XY04MY%U~sMt~WgPOXe zh>j8@(9fhOw`&d+Tc6}&JrBtQYN01gQ6B%ktN)~WY&y@SQrRy}GRyA1zyJUEufzyP z5!p6^SBtaPup~&x{woK|eSJLH@Vi;ciV`IHIV79*3f5oarzmH}JBfaU!ufiyF&31d z&kCPJez^4EA~?c^rP$xHpaco@&(ZGAbEr5%yVTJ1XBLzof&MZ@c~SVW9@H{c+q<5`)3B4_>vY)gkm3?)dQ-;dUCV|t5Uv*$4D z&H#x(E%Zzp@AkX8LTz<&=v5Uzab1|R=rk;ZuiR9ZVM&mX{fTpr zT-R5$juq#_3o(=+(LbT1X>ZkvP1IJ?!itDFZejey=NN_(^daJt6vgv$lDM_;tvWd0 z0m%ju=#5mAxl?KNGd$?MkO!atwa3RCzVhMBHs9@G(K|e35N9${udIegGhTgg=5vV2msY_R#uV>NN z-Qv2Qp#+Izn~Iw@(jDuPjVh5x^y`k{e8-%33?=BN#wXD#!{1d5Cg0MkTfrP9NT45H zQPu~S5%7&0R8|&^KWGF!!Pd#3m`@wTnno?Vpue3t%MQhQ9&ILG1(BFvpt{x~LZ9NT8M+Hz>T`Q~#%ZoXB-4lc5BO=mINE8^f!0A{%~L zP4ut%X>V|}Fh>c-4e&{dGUoI~eRcyIHn~eLjuIp=E}|&aKCRQ|9|&ci?Wc}QSU6N$Vbl$;3z>tjz=x~9;^PP@5k+4cNt2MPzJbJ4152M ztA5unsh!Es;1g4XqXc76_$2y9{jJGz`)8{&l@yK=Bruw#DBGIUWX-?Wvwe$)NCawO zgpP=9x*A6Qt2aKk7?z6Whgo)G4x&9p`F`A$=|#eMNnVVj1PP3{8L?Db=1=45TC1B3 zB}mZxH|_oVzIwK`6YDoJgiric(s*C}C&9QHK1osBXVL9ekcVj~V`}ch42$JNTajV}kf3MX5jZ z535opTP<0k14juG7+s{8SJgkv_!=mROrRD3FxCuc zk0NmiJJ{6X;k?}0KMW;EVEj{2wEL6TP8wH!nQ06qNYETR)}LrkQAV_Q%C^(~BR%#xLkSXcyjK6Qj(O1fs$uF9h7u%b?Q7cmcisN2 z-FM2j3*j4{T{70~|4A@ji%(LN(~~&gc|V9PT8o~;@r%zRSt&o6KrM`5 zE6QPQ7mFZYU%6A$q*BrT!z`h($3c6FQrbtMHDVawc{q)s1PM9rymnnWJ4^5On;rug zN|2ykpK0&k{ouTs3U5I1`Q8`)F?L@6Nif!oPg0aaW!m$p)84AhzY_^jZzM2ktthsG z+wor4?ODo6nLsU!#8VV`_))VS54~DSrJ{YUSwdsKjP~d)+QXg?wh!kUH;rc~K?39I zigM}51(r;HhVOT4NfCD>W+rqr?IoIihJ2%{b2h&c&oyi!LkY&uWg92+B=E=gqbLV* zr(^>O%nzW*&xFpr=H1=w(9sQ2<^U2H;ir7Kae4XaA1%a{w{;n6VF}51V%Ed(&nQY1 z_4k28c6|Ty9uk3CavnhOuMRwpZ16v}3?)dAKghI^Io|MJCDmvA^QMaYNnr=WfAyaP z^8oNkit=mQ5Uy>>R{bCNNj8wcyaPq~`nm@%P%MKjJyKgDPzy67$bTj7u*&pqPtQHU z@L&B`D)Pgb{~P{2w5KS!b2;*2EhBh`K7CY_AR*^7j5u|VrP8|ndD1frN{}GmqG>N5 zH~gZHb~y68og;Xr#nJGK{wKkF27HpDqtHZhBL-q|G zEfJ`N86Gr0>fL9~)K)JJm9k=~$iHiraOi>l8EB8bob4Pri>7zU{E8Y%kdX5-Z0+tb zKU%}tzTz56knl`QFzsD5{q1vJ72zK#&KVzlKtl=UV&Ib$WlB~GZyHGBs&^S3B}iaa z2fcj`q;SJCaj2|JpcZDTD2jcB+h#q!WpB~&xmZGqA(-_r;uPc^bSln03xxA(w~FdW zpcdwxP|o%)TIbTYHn8h24JAlWT*S2TZyYDCniHR#62d=(wJ_p1|4A^v1fQfR)AuFw z;`c(>C%@HF)(aAt&!Q;jeUiET>*g$E*(!-ZEzGzf@5j(9%I#x#s5^|o5*PgS?$OD*ix`B=oBq$1K+WR*?I7`pVyEuh$7vBX& zB6=EmTY$a|h1%@_2|R+k5VmTVw_SvHC?{!M%Cy(y9v$ap6asD+tBl$VhqSW}vV zl`q9erJ^{eSwbUziuP!foL6{Rim_~Xwpm9B5|{@>*`J&Cv!&El^S+PMQG&$ghJL2K zf8)-54`#DBv=(h|e?~_M=KJ826y;(-F!ve|$~p~kma>qL!2BZ0toa(mKhj#XS$Ua2 zEzE#Yl$sN_vN5#2s@rm&R4R&}n`M7Vae1^y@2l3YnGu^9&mQO~K|;=3`c`fQ<7A`V ziq$$wkf59Z)1E_5BcGvm${+Te{0uebP)rphn74#aQj|$6J-A4VWSZhHP=W;JWl?_O ziZcB14?Td26F20<|zBO;K_U?Z7TlJxZvT7;yBy0@(Pq7f%#>WIdDN`zv|7*X@zt!+98RJzh)sT7Kdc@yA;WMh7l$nVH=3G)V{JK?p z`A;a{n#)p8? zu@)pytMJ@Jt>@BIwNKhO%F8I-*hiENReA7@`4*HQA?JR0F0Lf@ez)*dYnP}FK-@nR>(pW*v$Q`yxML4P2@;swpeU=~IEasK{`^4SOo>3PhI@R? zQSv>+IDNK#cwS*wkMqwPPcoDsA?E?)t(ij{ug-W(nTHaAT0L4-GTW-vS~s$BzVRP@ zeCI%J&2@*N1PRPLpcA>2`}koj=Zj)KO9W~;Ty!?)`9IcbWTR)P7y8W#_4&)w?-)vu zkfZt~UtQL_(md}{BL_zUwFdrjF#BU!50z}J{BcKLdf1O&xopc(f`pt0aI4HQ-D|EN z52O=GNT8M+4c}mMLcdb69$!8fTXtW^U7+=hx*wFO}vfLE_}# zD^|O;scPbV<3!cM`$&FF;}`b3N(5@%dbr(cjNNK( z#yKy$s9Cz*W}w4Ltf#i@wfzO!64vm!?c63G>sS-fd2F?YDp=N={Bs?Rb5 z_}W7iI7*Pf_%KBp(yFrW8v^*}Lv9j*T3=tyu;iywo%J%tZtak2EU{4lxA$`AC_zGw zrQR6ZgSl8ZPrqAMB2ep(Z<^&2jjPvX+Zi=))tx==8Ne&jJ{=`UVC#-6eh{I7*O^;GJRx0i9oGR0R`1O3sco$r~S#s z^|ZyyZpn#Ka9yO5vcX;Wf|32qn1BYg>2*-caMf& zDG{h8N8(QZF39u#@a1#Pos$UEim%c@HNJxjIv8hPTX%KhDRt`dQC;^ilpwKpRD^20 zcdLE6OKZ3z-HYrz!7SG6ur znU<5G1c_0ZJyhde657T%!&~&dJHP5tpRZr=T}25J7-OUUaQEPwz6SFAanmINwZiu% znr-#9k#Xv`(6cH$t+*dAJbi|W5+vk!)YiXM`K5>T_{E#P79>y$qgjg5wWt@@Y6tM| zQT7snTGh%Ysm3=-j5kgek5;|;{OgP-Jo430f<)Sq1hXD5Mqj0q#m|pc;R84N@z6%W z8cL9mV@2Js`0&Uxe%ycRJ&8cA$OYZa{wUbo-%5L&E9LoL+B${pYFI3Yw?;KaOJa5`CfDfwXrlSN2Igay>Z8?gTaDH)i2Z=zf zRb@M<#=dq@`WCXWVzx7H(mjZGs@O_L2@)7LqH_~pO7b@!19{HT6C?t)*1d`}-@Ca_ zULYHz>zCqfHU{!smnP^aK|+qPlb%Q&}PZ&7Z(=qBf?*JT|gNMO8$zG}a}us>@! zzwjYTB2X)?L3wlRPQL6#HUh@IXCul4@nr>`=qN!#j%FO3e}`QpU2B@TSkQG$dV|7i5<1Zx^tpLa~Q5lEm`w^6xOW3RTmcM#c# zyL+C!N%!aTm)Z-IAR)&MYQNjhzB<(7xiX7O1Zwd&&&)NPLj~hB`Grn9*y=8UeEY5f z0wqXboI+9FdMspXPE)p4aB+!1E!n3(G+`O*UEadiS)C*TwX|-BEyifq>qSwiTxO1D zO&9y|BRAXxN|30$eW}G*U-{fK&V~oI9n8vhr#;S0CxH?qOQ~i0< zH$DO-NTBasQ3kPdmWdJd_>EE3Bm%WAzdU`wuyMx4@adn*b)kljsK>V@f2~G;|M;rZ7^3-napXY(xcVz{E5+v|GI`Tl?E~0m%HFeMycZonPyd{r_ z3B~nzHw#yXc}WCn6}malYRr!-AG=Yh#{bhww;b~0Jr8*Zlpt}i;A*oT8=iM0V$IA@ z-I>-`dkdEpC_w__6pB*7ag^SK+N%8mCy79<{CiJWjrnn;xN*+pi-5qNK$p&wkt`9rl&wtqz6(~Uh?^K|3+oP`QKJ?zb+}J@P zP%E}p3Dak}XJ&1((R#!w{mbh5JTyGFKnW6f{{VS8$3M`c^49001HS178K6;IKS+8Kq649;mXF^bo!p3-TIE!SBZ=1cAIhad2H=89VJM}=je;| zFDg#l2;?t!E|UnEcx2)*)!Jn@m zHc28-Ykk3v=Ge7gipw%1x;619_v3r6dBiY88*y3?Es2-)dwden$nd zz@PJz3rp)LK|(%%o;1(Pj7>bY9@UUQEsPr~%IEJDM2W6}eAULoQbZ96j1`jie4wk= z@X1hdp?(d0*VjQ-t>$)h_^%l4N>fksf9t>=TBDC1))#Ku)z@c?@5lWMQ?wQ%hlo#e z+v(MxtTweQnb8{3g=-JA&T-nV?xou#DWW)kzBcjLFp=|WL4ELr*CvrMvxzo!;YsuV z13z19+curBdUG+w0-Sbno*CXA;xW?OUf41 z4-k=2eWCW;p|we5(v6+d2ZfveH<(Yl$n&k}L(jEjDnY@8|YsP!q_!z9RCZ|&JV#{7RwTt=se(CT<-rY5jCCNqbX?`ZiBudX__ymQb6@?tU<%LhD46c+lt&?a%Z7tf9Z7 zQIbDT5w&R3I8*CXV2ou|M1pDKlaHh2?V28%F-lh5FJhsU{u+0^Z}KpaP@AhuMt3%~ zc<>?%(w$RdO&j81S0Wzf^VO}?AJLrIjYKSba7CS8$j2n!HFZ}L z2h}tGpO#fromjk{NrXfh5ye}l9%yODhl!<;St?6#F>C%}Vs(jLV69{l&-NJc!R2{g zYg1;Whyimdvv$pYm?Qkml~A?(oct#Jv^!JV)Ti4iX(Vb(j7VJ37aO&=RP!>~e$0OO z4wKkevxAzt^Q!s(>$8#S@~(8(BoRlHNFtuOPST=?Shh5fW&N0JYJL9RRYf|f`3{pf z7Gp%Po-DT2jMla99L9$1dy!$xsAjzq)k%+nO*(g>c=cJS5R(`kVBChd?Qc_S6{`8p zN#$6%NA@Pst5C9v^lFinVT^XSBSvhZ*^{0YVr||95|% zq|Ulj$KN3Gh9?nmF8is)nCAm0m1Th~BP<4yqjz@|Y43S|Ok!Ve!^WOP(W=pdb9hNM zJ1fm3dX;Ri-kd$t{J*bLS2f4+85YBab*d4Y=u>#I`kRQMcZ;y5FK(Gy)d#guk+#i~ zW)h$07?Fni8(ymM)K+1gJekYmf~rv}k0ln>uTi#1e{Il6Em-7_#psVHtKrjMRON|T zs%2r%)YgTon8edc<|mKs8ErE<+(= zbYKx?%#V3Hr>c_!+L>CdULUX^UHL?WX(MBtaT}uj{t_nfzEq5w)T@h0RN6b&g7o#- zZB63AS;N13Zhkp7mqy9iPySXPYocnrKX#42p3%`Q*`z1=TuRR|Jjo<#-#bm-LBBT5 zm@z+obj+!FjgB|9Hg>Q`w0NIj*2D7k3cYQ36{$lm+5T+V$(XUa4H)?2TO%_PR zn4AUmgts+Jt%`?AOLTf^ZW7~N4Byg*@~_nVRP!Q^_q9q7%bSG%Om7Y8*K=!_rF!CR zc#482UsL-KvGCSMZS|^RW?Rj5_m$|!bIY5=wGW2h=hQTdYJ55CmM*RPtVuFQ-xta= z%Iv(%{J-OwPTJ|HW#-t;=WN`JH+faxS6hr-YR*yRtj4<}^kcZb z?|iEHW*AmCNkh8z(B%gV!rI%|I~SULSl_VSmPNHokak|UrwaPM%9`;wRrR=H?Cz$0 zTA&x|K=&F{=_?Vah5JOxvvj|qUho_u7F~(Z2laby)-ZSLrkeeq!xp2ThqR8;E`%I0 zTV!la$9%ElDwR#HTq_xYP8kVt%kQY z?^+92PPvNm_M`NA3%Z&To@>1IOwJBwsTL2|Vx4H)%^X)*c7~sN;}*_q{%{kY{K-89 z1Zv4M>T`?ctmDKKai-xui@#%n`K)&N^I9s{CYdv}LP%$Sk4s7BESNj6BgH0G%=2Oq zv;we+m?jaZHRx=DISYIj8xfP*ms+rAsdOiKgTaz-3B5)y-zQp(nrCe?R+f!(&SR08 z-_%1+{Uidl%HCXR67ItAs^zS-k2!RR*7J>ym;7+3g(HUY0O~wuE&KP^UCMTp2-Hfy z@XTWL|FXZvT9kQwWF7A~iH3jO#q-~eW*-cyT34I?J-69btHuOrIXe_H+i>PU!}t2g zFArbg>>}KpOGpH2$!~`G(al(SuOXt=vb);ehAqu!&A)L$daL2vjv{g11hC;6{hux%Me9}KOSsLAvEOyMHx>Mon~ojv+U z1Zv6elGWXHb^Q`oVf!XTjO*=e_Eb~4j~=!9dApdid0nV+FXz%J<=Km`ZsJnOK#4#t z`3+a^^cFR7*)Y-2rhtxpfNu$To_D-=N4-bisQ3XjBm%Xja2s>%=4%m5rJ8v-gkSmO zB=W8e*0=ZRZua?SCx0u>E_p@;W;EpAYq*GcH9u=epqBg{%vex~ovc4hjBE1U66+pk zj-}WyZKT;H+VLRtOuDPn?KNfimys`n|*?DKW!jk}7scb?1klqtZr&+RMukkLDg zbA@ggAE&dg(ZAJcV-qC;wb0{DS*+XtutEbD>%J2rB%eOcBJ3wxh4(AWd$icByKE1T z2-L!vMqg)c%WM8|5{cbP3Vb`GmOP^Z_ZQ)1O1p^hHue&MTJrZiU|dON{b!hXUv`2v zZEy#34tANl#ftM$o>3cHIhyOMthHk#0=4=m9nJp8-^aM!_{hG_`hx=H+2H7bk|$OA z_EoNEeSIrcJ>J9EpSO8#r)#a<81I!L5ejJG45l+b>G{QEfRrR^4e$h zf*hs4II86Hz#dG#}*|)n=V>iH#GnZ_r4->CfUXoVDSQ0EDor*7ejdiT)BI>uRB=GHwJ&11!S|^G$Hv58;xG|=V zM4%Rasp*cbCm+X`!{Mq%Q3@aJq1SMB z7yJ8LM6APVtFhO&3y#*XMRE5+KC;yd^;VjzxZ&U{5vYYNO+J}d4YXaAhKOA&VkK`a z_CMNFl!k3{iNnQhMc$^BMF1;ddgzK=sjZO({9%=YNZs!#+KtR5(DRFR!V=Q_pkK8b;_o4zyS&qpKrO5v-Di2dortX8 zTQBRg@WRg2c#gZes26CmMQ}yL5}!jJe`dySml!O@e0ZEOzBH zehundo`xcLwW~OjT~#7b>s;w1ZS9*>_3nmV^o3&yWgA~FmJ&7QhKt-qmRpcOty`&y+KgP=)wIpViK=E>xkzaIN4>Mp zgTHE9$*f`Nk#6eD?Y`#!RlSR+J37lA}s~jgHd%P0=3*SoHb+5>Au7`F%r~< z4vxJ|W(j%?juIsBT!^Bi)9vLSzV>GCZv{#OYT+pl@-sAx5XbekdZpXn*ywM;n&IC$ zcrK?}!MmARkEFCm7ChNQ`|~+LVnoiBy3K?_90}CIR-ip+$x3=>K1>AOcaS0t7+sL# z29H=*y(sO3!@rG`A`PhZ&xuZE#N=t1acb)yJ69p>-9_h|&l!6Av4nD|DzqyumRVgz zf}WEjfm+zJG;6Cm>Yg;Cdaj%=#Un6+A;%`}E_M*tX#ZiUS(y8{xR`UI{Jm-_Mi;PW z>F!S(Mfg8(64A@cNCaxhk%mP}Y(@I_E&Apz4f&4y-%W4-6^BqM9)YKBXe8Epqvu<` zT5o(dQX)_bPu);n-w;n8=D=H0lmy!i>qHrJL96tz5o`76l`;Io2l~A$8do-6 zT~&-=;E5Q@zz*oE2RC1-_ZgTd5vYadTgbQc%11x;=#FKh-!MM0--`_6%fh=RsHlaf zRp{;0b-A`G|5J;d+enE(E$=f98OHZx^%vvBh}QD8c7WClfd}Ns5B34JI^F!Wo^JRX zXv4OgAHv@$%}nn{T6~-upF7*%cn=2TNKo<2hN5II9>+eW*)r!h{Urjm@B{{BQ$=jk z+^AH2`f@3{gb^idX+=pGJw|Kj;v(z@gz`;y3z)0dW!;lhjP}SiPy4mUKkT)WxO%3s zM4;Bg#=ig7>7LP6=PVbQ@xD6qu?s)9@216Q#h#Ozt9Y^lPhgN&ZO~6vAx|GxCA+mm zpcbB|P?UB@rfM~*=Ci6akfLN*TO4nc8#R2q|DrE$BBf^t$4C{nf?SV3)#j+_w1ykk zy@5oa7WO8^ims)yHs@woKF#jUr@Y-`F-Brj?{@0=j2UMC_blB-Y6Y56!)~(Hd7RW_ z`eivJu&uBS=}d#ed20#d?eo)JiuPfI4||iwRTq0U?Yo-@imcC@ziVpx+j8EFQ!)C6 zBa|}eJ~v}~(_Mu_y8wwmEvysmyquz}59be~6U-h`L=j7ZEloEFq&8vh0;c|F6pEJk&M0@|AO8)=@$={d0uPcG{ruK$=Q5vYZJ zEs8srisxg0Y|(ce7{kzqj9PdihUWS6o_zD=Rk~yDsuF=(eb@IeedImNvoh($MsSY} zU)9nN*I4v|NoG&Aan4M~6DoL;gyx6iaGvgyhs}PwKtlqx@FWSX`gT@l#*XFo&r%E{ z?%1Lq2m6^3-WDOo$*?Ql^y_sQ$*jPqVv0D084^AD;6@Z|1e5w?v>8o(rM3 z=*@!67$wtYu9Kqr7|F-pr0;pN20VrC{BYgx#4r+veIvK-)VpCk^RAO{UAjU=0<|!X zM>hhU9>|?Xd{a|n`%3Xaj8n+<7+bP8|8&}hm7BX(B2WwCkreaFlf>`$?7@nLtkhA0 z1fFl9Uy5zljqh68pKUI1N+M7TPpi-z3@&bZ=x1illrkW&MRDFxmfo{IypmX}A80T@ z!{`f+G3%JiGXd)=9_WAWZ`Tgt4!l3;9#a)jUK z;*CTyyWtZl(BFzDJMaVsy+sr6u{l}2*{8YD5`kKBtcXRwP#ar@iKL4;b@arehZcSP zbo*Y=eCAu!T?|jHBhW*OUSNz((R#OdJS(x^RV-S`Bm%WCW}ztkLThVto6~Qy&K@q% z--;0v%-^T?)p$F7;Tl_3xA#zqKrPJ1R}@Frm-_E@Hf(3jE>b2rdQxTYpvRWOddA*w z>MAWkB2Wu`i*#1TVX?lZO|-uLK(dr|k6N-naa7@!`u2eSdgCX3B|_4INE78i=1Wp@ zP!u3-lar2pfPI7cEVK$=yIH+6Vu;xCyro2-7UpPBu5-w3%VEk+*mf{NM_(~|YcZdN z;No49zsnna)$`r#Gj_1(Yvz{t&d32V4OpDt>lU!3l(EmusuShrZO zRXI>1Pzyb`l$jaOK~&Dui>>_mT*?baeQj7W!f7H}yV- zibbDGsM+~k1ZLf%7Wzo(mb4oFBKP&FmY1`uNCax3KbiXc*fl-hN>{P@eMKoYf$b*y zcPHLY(|u`I#4J@M0=2Mb=}mq5h}C2BQ1Q{hPsh=K9(v5JrqdixQ>{kUWcS6c5`kKn zWkt8Kziy#@SvgER%k3g%VB;LbF{UU*n)MY~?|ZYHg}X}G>X=pT-n^UX3ok##$cXRY znKRY+Qr-lS-p|_qgWvOK^@!LF-n5HDfdUqo()9z4^HCv>Ut7^ zTG%)AZeKD~-0~~QrjNX@)~eq_GjdfLx)qgfm%}^V=*;M@LE_4CH&)0klp%py+v>)e z@2lj)#_53uo%@T+lX|c>A;AnKNGz7_hs#l`r#SqhKWpzWLn2TM?{rg?{keyUq7JtD z>b-BITkEida^xp`!*H=`w4L7Gez#Ox)WXq2I~K|qSbES^3@ZFg!&by`EJv`4dIyME z6j$5eR$F?PpceKlc|UY}-FQ^+946Bg(qU@)}g3C z{nYFfVR^P%I>&-<1{`A)`DxNdj6Pd{jjWcN;~jJuE5b+chF>GnO;l4C_#mwJgcXXEtzzb8oqYT0Y(smGu?06YjJ0v=!$zMyl}ym!f2JcGZpZ24T(jO9X1+i80C@tm`dm zQN+DR|MLv{03)Mv)be@n>S9tMckx+ll?c?rzM<9YYCk=E(@+t5ZIuOUi*1GH$|xr_ z#ZFXhu~OewD^$8g5Mz)SW20Q>P$%*G*joMOTR(|FEsT=VN!GM{;uYo4SN&N)inybe z9JL&syO_vF@7)ihb4UbgVa${6fG8KFzod6!r4ox(9DUe6c-oEHYE3JBKAlP_zInYw zpcbC4qm}-zTDp;^AJF)^bS4h_1ACKZZS-q>S==Vwad!)jF+q$}$`Rf_hfeArmM_su zG>Dc6)WSOt$!FQ}xxekMp6uuQ!P4z}c&{GD3MtFmf1UP`c96Gb$^>d*B-Z!^msR@5 zPAm2I&0?h+7cnM`kw5aud_AF0r?~Tpn;j$qwXj8KcjsfH*CQXmtSZMLmXF_A3(p_)4h-W{@WJap??pK5k8D< zVw_u1PJDIJOTJm9=RMdn* zhDii!;SGj#dU;b#wsfNnyLxVjbhjbKw&hr=S3nBk87) zrQ_Jj``^?}5d$OwwJ@$qIom}?Yc{k?{Yq!hv8S*`4Ce&a7SF*dN*TJDV*}MAcx?lT zKrK9XOJC>5fqc&nTUKtnJIA;z-Y6)?Am7fb!OzFpuwNZ(Nd#(P1e3hyH68e`+$;5) zI~qv$Kw{Jyql=1?dH5~sadfS|=tB#MKrM{jQs%(GH7te3)v{&Pq+A2klH;{!Uyo(y zS9ww{SuKe`EsXZk>a|lBKV$z`huB&#bJ*Q|q>T;>xgw*avu8 znj*Z(UgkN-Qhl8z0=4jTwxU=UO*1p%4>^^P&QxQcV@83ZxL(f3XZ~^%E8077jJRTK zTW+h@t4s2v`L6VS8OCnGU@42Ma(uP6&%ZJ{~=be!5 zfy8Jy#%^i<5!as=OiN~!(-_0MCQ%Ec?{q_K;(4_RtrL3(*JapKIF|6dJDq(!zeYxt*~F4y76I94(3S5^U#kz#+$IsIg{@9I$W>MO z0*i-8`MW{N%RnuR_R>vLgDUW4G^5_%KPnNZg=2yGy!cf!FC(L%!f@W<*u|6nigJAN zdDXZnme9bK{lcRSXbWqJ8YXT$MkL_EPozlk)g zD6jOZC+j%5k#v6|#?SG*BAqdrQGrK)i(@+{`bh+8$u>&()Z?AfpIC0Js7Aj^@_+YM zq86SQQymkH-dcH6hff6L}%oq8{;+yfdEqdxF z4&;>x)RMCkjwkx@aUY#TyJc&o+z;$c%!Z)0YE8cuQpQC%sAD7oweaK{W$*l`z>HH6 zr{)av!_kLh4DYp2lzCN}FyjuPwS7YUkw7gx?@lpcy&N;vK8H7!mF}{@%!>ahRr4U0 zx8@MxwL6E7cUfTjV7t+IgMdiZvgIH3?9jo|32Dq>z%%uVQt8)QHRw?RmYzOTB2WwO zRiqVivqUlW)(QRb-PT(9Imz0K{@c}tCxg=`w@A{ubl{qvkO3lgYx@HyQ;6tP_$vvmUb$QE~OOecydY53fT4n4FtUAC*= zR#vy7FH2Ec&T1w`HZ3Lwz6g^D)Ixg{`RN!c9LR>d?#{xO#+hZ0DH3DZv?0N?G51zc z>CPfWX}-3hxX6o%s~vkt1m0}qf2Ff_j7r$H^CF5(sGkamnsrKw=OrEar;RRJ18Seo zJFBZ7y7`zk{Cl}eJAFmz@ieFSnXkBb^vqQvaGyEpLVm4x+;%m6E~8wv?u~MYLoG^) z5l&?}N|3-EEA4BGW$En~6c@=(l_UbS3OxE@HCj3{Zy4G56%fWF>XsC%FO_6}3MZ;Y zJth~-YkA_|T{X0VR8hSqqkZ0;?jhaXW;s2`?7^BXjk{j#n1^1WgO5-)`wEm` zt^wZTMw!dq$d58l~#%YTQ6`(JoG(HKCHoU9Pr32@-Nt-)rz3 zy=rZ5;k34zM4;BXJtxijRcmV8rd#OLEIsX^kLcaGia-exm_->lel?9zk1E^_OjZL~F_xK?xFg9~}|T&*^?PULvEbn?#^ivwgPaobUD3 zxP$K7^V9m@Ha=qY`LY5fNZ<{2^t-PGvh;7cy~N#jB_#s2s^loC8DqgZr4HF>X7^OD zbk9fpbuBJXf`lAJ-r7LBE1LBfbk(>SZ{WskeQs+X@xswTpacnw zTGJW#ocYDAT=e^A`*KPIYSmY2n>Mnq8+Ym5dzw#_it-T;mMH=yNZ?I;lp%DegeXKd z96es^NT61&#ewFV+ID?jvN1kSaq(bmCGo0dmW~o6R=)k= zrj6X$#tnaqMwJz>Csh*N7G>%vK?38dbk6Ivr|_=hEq3qOEfJ`7B)qlR=lWOUPQdzm zJj8LDA7w*#>L@`1Zxp06?lY>2YOB1({%!Lm0<}stjW+vZW#0{Cqqj#DQGv$pqL_I) zN|2Bvil=te5gY${irc??Nd#&=-X3Gt5 zzs6gXUS35aPz%=&w1x}w70EQNhUBTNqXY>#N>Sp{j(OP)w1`?oyoA=f1VaL~aK%A4g@n`*4Te<`*-N4ru0K#qjukEP ztRZgG*j@Z)ibSARO82(r*j+PFSZM4Ps4muzt|aWer!bTtfw3vNy>(h8alfXwxcpDXDC4e?}wy&)tb7CO7w2uH}aT7pw`pVA=+j7 zMh%^jOQo^9*iBTVc|P#iQHByERVLy17GddKRS_1K=zk!NN4qr77(nRkLWz_J3|Q)c+Vy6)h^qJ ziiN$z#_YTt3Dm+hBjtsAU(wr!`-s6;oH?!sah)jp^!xtF(3zc=n7^;0M4;A3-)rW# zcCu|tYOAYR`}ISSJ|etk1&$IVFor;TgW=2dQaQYYZ;h%Ffm#b6?=aU4_Xl<$8*P)8 z=&kPfh}$K6I7*Pf`!gv9nJ`>Wr~de6Q%fRH>r=uMvmTzT8`&7=GgM#R$w$1KR)eDi z2{~@isz!ib&elti*FYjr>&})C)5eFk@nqx5C5s+U@7>E0bva6qz&Hi1!VhI@!)aa8 zdYLK_s5PK(R)(>LvkK#M=6ugzTJ#v2AN2biC_w`60i|E4m@+_nOy8)2i3UoUd--C_zH@=@-yv`tPN0)R}8cB2cUGHD7;YFVZL2u#s=p ze18{Om;4D=IZBYg7y_M2xin3kL2cD}AeRW#y2$^UHrmD-Z8dn7NbAD_UuYF_O$(fT%C1TRm~UgkAYw(2q>bUfQX2wfN=I6S_QESJ5dQ+ z6h*=A?iTB{J2-Oo9$V~qQD3{eyY;Twzvn*pJNG{y*Z0FZyJoLhYi8{zb@^Q%UNkI# zp#=%72vAp*H)GYt8@%~BwT?uf%Kybav$!5728u4UmLFp%uj|8oPt|5L35=>i*i^JnJBh zKJd3t6*KjZaXtQ>*N}YF%huFa<7@D=ZB-dskifh+<-#Z4P#x$!*yCMAi9prK{G}9e zJr+DSlaC!Quc=QdhWq%)nV|&}*O+ZzPpLva3Lmjy`RQD_~Z` znoRExP`5W)iIlhEtBM4wuBU|<=PK@=*s1sBO-VMI;>0OkKdERzLe5kF9$$g2qd87_ zpW70Fs@ic4jB_<0JRkXZ*4UYeInLJBH&wJCfq8MUd$}50Msw$?o<}7DRde)~7!HaW8fA3R}^f_W-b$@$&?d>HE+;>{NYgqe{*73Q*Nx9n*%a})7yN_{h0kidL0 zWjR}_tlAzLcb~V=kU$k?zvwu&hp-$vj-{90YM2W{6?RiLnf9-&&+gKF^ar< zV~iz@UyF{`!3)Zq79`|djj!xystI&oX)d)T z0#ygDtuW%N;w8G0j}HG%S1-|+SouILjus?v7p2MMj~5;eb*W7 zG|ipIIfY1-1|+bCK;L$Xo2z$^8_LhDZg0hk8|KC2_3e}2&(-$@4CQ-01I z_!KSp>`CS1@D!=$Nj>{#XYgY?rtCj5wVzajL}DFPNo`M^RK`V$udB?D8^yP8EF4&o zmD91#i7L!9($}DxjpDD}-)3+AR94DaBGJA}U+oV0sFhlq-Z=fJf*jDcLN@E4C=sZ- zyQint@BB$+psKT5q?N|>V%qw;qkiBXADk=Ah1ZIjUpBO!u zn-95I-21MQGQvn4Z{FEBS5^Cnth04U5^pr}es-fTxe|dY%y`pk+~G-l#jIMEUypL7 z%r+9!nzS*_Rrq(2JuhI>o%?_Ew-ir#ClRQ^%stIg^L6J}KIXR^OM53}=#j`f&{%tw zds6YRc~3sN{-|i*mFeacW7J;bYMzQOy`q zB!)-zb+*(BaFGa9VO@mkA4H@CRVP(r?I#kVv zKh3FS`H-ea6)Yrz^E(>XeARSOC3{f60-t);-|}@%tVEy+D|8fv2Ug(a?OZHx=EX{t zI3yZr^u;N92F;&Dwa?SOFn=)WUUsLt?IZ$KSZSm+sm_IY@0u=_EWdVAMG=W zUdFFlakeiJu~mrBGc6xa%LJxz@*!LAI6Kov$Gv}R&VwDx>i&3%7&XGl>7663~h zv504|l@ZAlIPRp(y6F{o_L5AX3Zp2hhTA$S554>>|6R_Iq9r7n^;{ex?yEtk#Br>t z{!@AV!QT>9dbC8K3Zp@)s}1<6IQMv+9aL(x6!jtTxw(rup3YV61|qT?TEDeg;myhH z@jXUJ1gbEqHJKh4YpvcW;b*zmYlIZNBH`XN#Vn446%rBh&;kq8lkfa3A6^WR2vlKo zPCLB|FHkeax>#(9KnoH!=T?}-bH1{)i1e@K>Z<>rd)dLWQzQaaILn}YMF(`X;v_%I zv3V)dYyyeadyg8&(S*LUMqhnf|3syBq?W@odP)SUaJEF>WZm^d-TW=TW!lu9(yR!H z<(|*X;@PL_)2;w1uAxPFG_waW^z>GzxMZ+-)4_gHP(Sgr>8N%LAH za86CNqJ#C>gEj>$XMFvn*)|flgX$Q+s>mxbqn}l`A#X(|+b8F+8=e-u0 zBKO!t_T!|#C8Fhdi9i)*sc4mV<3whj;bQ3=cV5ayAyJGnUgF-}Jwjx>p0*mqEWPe! zKQFmnB2b0dHd@)=d=NW7vzEnLWxbSDL&EEA59KBKxMXDRzKj~qR2r{N)|B&bsKOi| z&1*A;GuM#?ETLsmrR*RQg|GE7?kk@zV$EIR&2*;cI$8ex;wllS!qt1UF8oe9vo{qv zP$A1zTGxk!g;ofPC#rJLHJV3xbbg+w^>pq z6^Y)%QjF_y!*qdGvDW>T!G4COTa=u1DO-yw%=^;4`!$0N_KVK`<~df%2qV#yvnw#e?iJRqY773(mJwjquG$b@Ap@*UXa#VBH>>q$v9WMaA7lz zJ|jjlheJ&QGM6xEEhnllt4{O5jFIf&>53L#Gn2CCNW4*d8|P}pib_M)ivFzl-jS9&_4i5ys<6sI z*{HewS!%9D7oM4$?5dbASIHG*w>lHYPRyP8xVMB-X7OH#Ys=CEf;QdlB(fId~v*K#BgIy ziHiG?*#FdkC@0HxyFL@(usw*$n=fmMTds6LnP zW&g<-AQ7m-RSWd4ov=e~e7u(BK>7e_bpjH{X$6Fc;i?Z5D9Q{kQos&*;A#i`OTX_W;MPdYU-&eenCVkO5RR!=>?#K{u&d89<33RjU(XQR-1 zs{2D{iyskaK_Y$7-w+Wq*e8e;DQ(8(C?nR~&pz~Ov_zl^SHn>M#3?ySP@C7;JBUDw zL`pc&=&K{MTF0H0QB4bRHk!Nt|)Q5b0n~|Z1zIVxP`ec|y zplV>uR3p!DCsM2R9!q%TodozLW{-jn)i7n`Kpa;vLjB&1gdb?4b4(VRpC?B zlG)4GMoYVEkifk_^n7Gi;r4q>2Y#=PmiGK0arPe<<5xY|U7LPYP$_@j>3z+Bg|7o8 z0#&%ilhzs(@aKDay~=j}5h$(xM50h9A5FYJ@`kySk0-8S{DYrMcAKJJ5`n4>N9!BE zs#a;S?lrPr7>tU%M+0`bz@R&B1|E@VOv><`2bj2?0CVW8M{FY1S z92i=VD0)81_*Lhoixt189Xs%T_D&Ypz6B)$Rk+%i#**?Kc%?u;%e$n43@u31?%7Ne z?+@QHV&(FGD|_)no&7DBIz5oqA)^XcP16XM)r+?oT)^_J*#l{HG!jh{%GMn5M@ z(eAS)0#&$IPK=2|d6$|~_BRWhE$xXz;{LJ(O}wKb9D9>qLX5N`9*LNZiG~lmPi4tRoy5`HL%W+DFzcbT>K;{C1*ED>$I;wx zZ102lyFHcagGgNel59L5MV5&^Uj1%n@GpUL4qiB9ClRQ^-HsIRKFi>rY>oui?s!|; z2Z;pM3TaQn(C_+hUq`m@s5{5Z|K!8vwf8?xDjj+TXs^fF8vm|-R8$(Gm;R>))^lWg ziqztmv+}ZkW){THPffI%y=@fn@A82nN0!*PgjOh{2AgzK;h3S_&|_PUP?;JeYk~9D zSyui!sqCSzgHy&U_>SKChdnD^w-(3h0{%kzS5@wGOW#+=kxjc(QzB5ccIqpu@ONjY zsAlZj_*Jd+*Pitq=)p0QhQ}lyM}_yF)ISZK*v0`>B?48qPS_hq-ql8A?s()cGmmg$ zu}>Ov%p2obSn_Fci1<~B86r>ps>3bwKW1-Mqzt2;#sANOgnX`Awcz??lLMPvHi%=* z1J9d$gVL^HJ;98dz`fN|2#DD6unyGrC~jiDsjynxz)yj)w@_sB2a}@ zQd&9C+l{yOcVs@Ff9cB`hZ#P~EiSB~KbhDWRF${zwr4NW{z?R@u-Z%O^yBOC*qx4S zWBtSW`}}PTA73AzwPyBhZ2Y^;+jmy1L5pwd1@IqqJId<~aXu{K zq_2LcY+vK=+O4c@EzzN;ajt4i$h6{`0F$Z2%0~RtGzZpqrdc9Tg=-t=8`2MAXjM%` z7Itd8x%W;tRVvJMG5+n=r?!GIBXxUI3h;Mz?b*>@RU`sc7|BwlVMeaLo}R&5SN&OX z$*aaQ`2287<@DC?#=p-VZ=_&EPTyBMa#|ns)1Ix$S5G2Pg;784o54N z?AA`>?`HXQRM4MHTFHtoQr|R`)>2^WM#mW5{ADYPALk` zqp6$j%0(en$2hPO_Dv-MRhSQ;+IICes!Hc7r+YN($7Y$uJ2jzBcV+NDO^tuon$TOp zyav5-+}f#8fP%(hV8=gL{N;5K@HRP$%8Z2lO3cTV4iioWWG@$bV6TPT=k zqHod7Kd&ko4$NnC9f?2{X1`3PZW{}*#We1k%#~P5s;eT7!?LTEa^YA>nH9{* zQ3slX1zB;5WLEcgl?YT}rjTNvBc+%rz>%?5WtjfGrtx>*x2~d~Kbh##wG3MqV9x^1 zm6ZroVaAi<-6EcB2}MhL&b?M+i!?BN4D`sSjH(l2{5vMMw1Rn9>R(c-IxGCkp4pCg zClRQ^>@e*tZd03yShV4vi|Xax&5ggi-}SB;{mDe{=sN6HO?x)6;=dAsD$H`5OtWt? zRw2fTbze1GU7ys=@R8qPzIk>~d*k1&V-K1!_b$He6wbE(wr5W&&6NmLVMTy?xAkny z{_JyPt#|#S@Egg--|d~;HDvnnMC0G*+7>orErg!h{7qPIx_7;XY?KI8VTFcf6F+0v z2YS~Q9P(GoT+qk(yFS}Ptg8+u8UH?3@M;d$rf3gVMjUJIYtKY?7$i`I6*JnA>lDuF zH*jKuYmC=lROn>kNyLZQ4vZ6nkh?uLe0%;Dc=EV&Xfhw$$QZI{PHP{h9M;3nmp&ojU z89x5mT2!m=6l?r@R-4aOtlgSS4UhR4_ioIA#}a`mtSHkyVY@2KLEO9bOgziJj`4TH z*H_TcpG*|n<;D`}ect$=d=h~wtc+78mLo19is{e^DTCH_Qjejrk zZl~e;391zhU#ePY=CyN1s6?O&S9MUL@h_w^m!!T)g#6DyY&zD($JqwEN(VX ztxYl9){4;*fht^`Li->iSLnxRY>&F8aGaT7d@V=F85>sW{b}4?AIl^HRXAIsoJ`rD zx^21x`)_|$X(o(o?r>I2b%Py+dD~DY*7-|$jw+0)Wg^OA;>YWIv-aJKbF?6VF+5e5 zX8ZEa^scpU^FYVg8RK1ymQAK~+9TX&mm^!Y{jNly3gdaY<_pYxGucZ4o3ZQ+@8}gj*B;EP@lIFfht`0 zOj+kvalCFHNA_S#xP}&t+cD~=4*cnH#yl!vyG)=8f33+>ZFekpYvsgBy~N zsc6AD4$hEhFF-~u-ks(jE$#9o0#*2HO{QHA6?g}FI}bfmhFnJzT zGom7|Orz+`r?wJ-D$JzQ+S;Y1xF@|oyl2}pwBS4nXJJ$~SXqLvk8@&WV#-Pcs_@rR zlyUT#o|NgphQ-uiI48q-6wbnEmRjPeevYnSlULOx0#!Iqq*WHKQ}scVFYyhDWH?{K zIS$T{=qpOgC+o{7ZeM#YOd?Q)b2z$p8%)r)``ELUU7IqT%i#P2XD>904jZop(-r)1 zp|M1u3g=g}o8#;PC73xek9v*SjOl-jSrMz8q~Hump5wIWwOBco-+@(`(?lXrg>_Yv z>0pzxYI!=2!bcl3wBQ^EXGoMO>R(#*ptxkik{F3V75-Yv!(~iY3(>fHaAhc)xon>i z0r0vV6r6?0^C+#wOf{0;AItBDO9ZO0mQFLs+#Tvc8cUAPHZ!!~JPK!Fv}WSOHZ_&v z#O=$1B?49WYfYvqZm-okG{Qwt{uSqBIFG_v7_Alwe4$>TBcC7RB@w8?c_OU@yyU|E z(C9P5u@J-g63%gOhD04UCRAiS_EVn0*+wEzg>yKX3I7|+(kYTDVcl)45c_MD`$@8edRZt|&dG2dg|jgFmT5)<_HnBt zi!Kr+5vamhAMHx{Q;*f69GUKaQpfoc&T(*tL|?ppMzd|o0X*JwOd?Q)b2u6k-RxNm z)eY?06z5nUz!{P}$H{w3*P@<1dt9xgM4$?57ADgL{f7FTuE+1Meo_SsS9xPKi>l!V z52`uz)H>V@;;6z|5zeV-w^K)pDzdd@dj(1as$?GxPI#+1bgu4BZOE}|fpaQ(R&@W3 zkNW*nId*wdtVEy+D=Acc?d+kwHQTd9t*KOr!@U_;m7`H~&@D}54ce@Vvs(J5DHo82kI z=|z$fx-C~JkSIENX2uL;`#)^fKi*NyXhEWMg=D4S;#_6-hcmRcc2bg!z9Dmn{UTYe7~s}k@3Ewx9{X&NM4)PCN~-nn`dp={Q%@pFu|>M`tpwd}OA1E|5^b(! zS?{mTRsMb7iHOZP9^B{s0R7bVAYQjzc}*PQvKjTXSMDB~IF47jp4!R~53R;Oxr)nk zPa>Qu1@hk4QuJFEFOC)@w*IZ7=}U8!32z;VIFPejxn8ZkzScF3g@y!%2p@l*Cn>Kh z{|OO3Tz}9tDqp~Sh|bl8LWv61KknC3_eAc{?~h1kXhGui$|mMD%X5{i^aLUXc-CV} z&-c?Oac{o7iN7+EybpWotlh0qRS`ZK?60oXXfX>h!BNl&r?`t)s7{{R!r1 zLE>8>dqo^aZx%>I^|c{do25haPfdsO(e<}ms%GUXb3>9fzfRk7gpb*klQg^+Xmxbu zLi&t7ef8!)Yi|cvbiMoHM9>%>p?OulU^`G9FI9Knx`dL4f>7%)?p{I76pOwB7 z=Qfi66Iw_gGpvk`79=uWCmQsZC!E&lFZCSBOKr2!=M9z#RE;ntY2rAV%s50u^*y7w zHtnlcdUm3Q79^s)lQnUNb~}Gf5M%860Yf<$biUdFkKx-GsrId0uxzVWV&-oDmM6)jW7 zbuoPOPFig4zoLubqg(66L=-=f%$J?|s(H`IR?&h)mRmc6-q27dLhP7ay~$QD$*-79&cEC75)LKw)QU2Js(iVTnixfs_X#31 zCXTo5@>|na+cUHvab;_e;e*#KOGLI?93QgJMjtg)CQ$V@H^{giWqbtD<{*9j^zK)! zS*`~|3ljI+)G&N>9P36z1=`=0k!7o^!(;+gK8ZDqbG3S=AlfQ^e0qY7UV38yLkklA z_bVEAX#cr>M6??GU4MPhR-gZ^BP-bIoAutjT&389hDxJ5dB$jP?pzC{=8rt9xaQaH ziWS(S9WUvJhT7 zc5c-3_xP>VUDcnV1&L>wQw$%SUUnm*?ZS<^Zzb~KBonCWa&n5{qxd&L)a{wB`#%1q zjVzbS(1Jwdwn(dZ+a(ti>#4ukdh0*Bl+f3_N@Hj_7W*tmj92|SBr0>7JwQIW)dcnUPnD6xKW-*do+1Ol}`RA8e97k;521@3` zUuJO}&8IgYA6_}7+4sJ`v|r_$F|;7zYFkoSL3dBzmmx$PukX)#-8WTVI1tkFg`7RJ$~Gc#DmGc8N@&3S(4yt36I<1uJ~j+Ii(l(JK-dm(nwMas+#K z!$wb!tf+hJPBdblnH%$RlwrM%U$V4zh&6RqZ{u9e`W9@ZjPQhdJ-8I=+(D1RKU4%C4P(i~-PM-+!;XA~J{i1i& z+9z!}T9638@Wecq?#czWO^9%bw_%-r71v9wmI+j4I6XCHpBnLriJOQTezKhq_L~j0Fjt zi_qHIz!2Vc@MG)FA-)_f5x?sgzofuRM@^ow{QlC2S6zME>g7dW)`htY5~YTR8np9$ zJ35YWw;J(|W6Ep)KY^-cy+VyQXSo7}iSW&B!^c;iW3{KQDriCC^MqK#N4tOC((`d? zcN;$8!c}W#q)edd;iy<64puG-!u?A(t`8q-oju{5jus>w%Ct6o3_5X!2=n!Byw>eU z*467|0#(!PTN{z(sdIv8yC{j5Ty)FoN?%w(3lfiQI%(n#UEFB{5d#Vh=2IK|XD!uf zwT_nkU3wV5ByLKE72|p8u@#xd-9`?!mOI-=iu94d7@k(1b;#iA8b|G3W~zpfx35RC z;bV+P%0W2_uX8(tZ*!evegE;U87)XWT9jna5AKYotikx43|{k7d9C=G8xnykjOR_J z^3kLDrjIkMZbx58kv_ObPFn?@>H!XtYbzog#1?q-bVY0X6cq1?3h zmi6oN*;1sB1jg|6jVcd+?mXL3yS_M(VWijma7DvM{6;@Tj`}M3SK-cmXIeM>8_m#y zL`?DG+QZp2(@#(oT3<3NQ+K;|&3a|;FoqG*($8)T)M7@w~y>#=FLDn9b z=~A?W1ja$MSLUA?AzNl&vu=5o!7zF$I4>wf_(*?DIUmSB5HbI-vT0g>>-c|1OPL8I zFy}x!z03Qn*Ow^T?MovWS{9UVY8F1|^LC1ynK+v^QQbNHgSFk`RE8EL-prqA7U$}x zn}{sq{jaItqn)(H<@>O?ANQKYle5RYz4G?Vb;HMtxy==fyJ(GYK@NU z>wG+%m3{TdYWY!(q2+u+u;C-YuCyXYWJNDGW0&UKvJQD&mZ1fS0PkocGN?S%MCa;x zjgD+jBPVTXM1F>r4x^eGKKi}9Z64phnGqAst#ijrJ)K@8Fx7pIwa~R&Dq4`p35z%A zd7%~}I$IK0{`2KEJDb}QfvPtF@rI8=(*)r?vp*X(ezvtzi@hpZkf@*7)$nmQU>*@w z)BCgj&8}MOB+CS=w;{t<(%6%J03qE{p^E~TDMw=!71mDjEDL$ah?=nscv zY5zyFhV0T#N9}MWSB{b1{o4Mv!I=Hk=YCMV3fl z%t-6{x;$0)U%h6{AJvm%r1yU25#u@E6VO(ZqrTybb#>{Rq1Lx^QaDVl2U`!nWy6zT97Cn|II3XVU&9#8cQCnF2;RR zJE)0Ann(nyFvmep?ZtQc@I48tTfg>F_5%sbYfxr;Xoi}e*;!q>U?9UcK$^!av-l+w z8g^6Uw?XL5&1$}KgVoW~`!cj3(YpOUgT58po{qz+Zc%>??4YiBne!uJZj2h*>p znJW|2St;G5Hx3f`E-{%dUMh&%wb)1L?Tf2_DDq|TB>(n z)YVM0KPtZ*2dfhIl%YTK!TfS- zg8Dr!O`2^Zf%9nE%@KD~d-T3TsHzE9~$Py;*q|ws=FRR9!+Mam*US z$2wo}Hdy#{pg#0eWfpv1kqA^_rHQHl0~hI$-`&{x76BYBNQ`mL{oh&WBC1O%8yD;4 zt(Dn<->|}lX`l3F$fd?z6RFw!+VJ1ZEU3#wHT;s)- zxj1pOAn_utyy2s4w#b((DD_I;8sfoP{Oc?csKQJLt$>)Dk53{W zg`2Bz^J*{FIOUR#79{S!j5U1Bz9}*+#X7k2Xy(Dbp8QuLP=(nR>dalnn|~%B+pq1_ z(SpR|wAO}?YtKZ6CF5od{!#T{KcDQE2vlL*MXM$!*XB0Iyx5Kl^QDLk35>O5T$gs*0ndhVs-kFZ0$Lm@;T9B~m-^ZYXLq+wq?1wOJ-Bp=2 zH3v!rs^l!r_0AEzHH{^WogZpwL85=hB*Vv~cu}uy?-RwvbABrS6Nx|-=7Xp{=oH0& ze|BS&ZU;)4A|y^vOE!Fr3`(GSZLgN$yevKETZ-0^2vlL#h<1QZ4&{T^da-t6wwTd^ z#PJ{hJC1;S;(Nl=SB3LMK_2W~i_H>&D$E+0OvgGfUW%=|Mi+xk}$d@?d>~ zD@X*YFf%2-AQ{Aqi2JH%yoweizTE5bKOfITrCNz(yyl+DEX&$OB2Xn~ysFU`1lE6X zqrH~1RJ0%wGoqd0!|RHuR1exvpPxKfna#+aEfJ`~sx?IbrEBw0@?qJyT}2BLiD^v@ zA6?o#Bp;1*f5XQtr=1dkDy&-5zM>~yd=ot%_WjSQXhFi|ZKUC&nS!|{QrRCjFcz>nrsX5WtGmk3m0E|WSa`xoX}U)`AN*Wyx^ z6NzzeDjGh9*w!W=dDU(CjA##*Rl1l&pbF=rw8x~xd;QuGFBa(NAk9*d!1*Wbia&Bp zuQs<(!u)zT>@^JdlAFO z%f6!G{&!-AR+g@Z&&p_tKo#b%X}8npRUzU%SSBe-%4{Ryx<5EX9LJxA!bjmjJ3>B$ zdoZ_+;Szx=%yQHGyxL6VD2==78OG3p#FjCC4Ii!gh;x-?o~t~0>dY$riIE6YVUC<) zp9v+^IyA>w`Z`+5o+GjFK~u9hj_5ODU0>n5MOEiG5B6dxlL%B{9>rv8krAifJ>~kl7ZH5_-$WAwmUlMw^xs-dS+IF);>iaA&mhy;J)<9Ms ziKT0Q8}!6|Vg-cD;1lY_y_MOe3qBHoD$JwMPQ7^#Rh#ne>_k8nDZ_#U=1$1R$*-zk z8Hy7-I54zy@^n?i8Qo=9OOZ1yx$|vV<8@xF!C5bA&|4}pJmt1qWW0o^T11{BMn`!=0&1Qp& z79^rK#~XC@_M!?9xwaN7NOQGTzy6U3RAEJc;;Y7W**&_ig2qi!(Sk(Azg-O<0jj7Q z>pJcovUvlR_c>_Cz3_)tIVu>zDNYB z=k!WzHSRvFpDU692&Zjq%2vlKCmcD5bAIw6W-Pwhl zF;cb`3CzP%{bQ8}D>l0_%eJ|xV`lS2MfSflI-sZ~XE_hIc4nbJ+?aR37ac7~*e3ZK z^qqLohu*V{BTJ`u?UZ64B?47gX`&};M^TnY&-uL-1vy%faOqvu@KMc0bhBFGU5;f_ zPUiKyq7s2BtW(igQmznN?&QM0gqD{oSV(NCVQ=`jzao%)*dH@7`xYMTpO>X20##Ve zqPhWn!fZMnhq}<2qXmhSD+LW7dya_JR8u0}sN((6T6dKQRLNDci32aFH|f6W^uv>* z1&J$#o*M7klp$g@mEX8a>SV1l%NsACDXONtHMxs%|@C_;|BY ztiRe~w?i!-K`Nzb@os($CrIK+0d0 z+Fn;8P=$GT%Fu7`pk{C{Hgar`l%YogbMGe8#>cjb-JHrSv~B~A+3XQjxJ5jvGv_2r zYs#o@aK3^zlH%RCT2UNx-bnnnXuA>b&X_Io?Y)Ln(6T-{vpUmbB?48LBd7dT+zoA3 zMHjXsvZ0hcM*{QWw4QpNowZv^I{OphtxJ1};FrjK?!J#IVEwQlovj(^p`!(T;4#q^ zoP0K99#yODA5@VDRAH|;THo$`Ib_|h;cWlM6df%{;8~|mZ2cdbom-?a+TSh_sKVV` z)W2lvGjr2|Y0UF*l8zQ6uogmRcVU?FiF}+3?jaGV!d+st3*vU9Qr#wvHE7jGM+*}2 zef4_WB&9L=C=}gUB2a}LJxr#_POX)a)RgSGrJ9ZwB=A0_ei@FVl@oN${i;`z2vp(z zIEnyV$0{?;>8x{4HytfV;PXhci4#kd-^uChpueL;pi16%H+Jp{rH_)%T2-j1qXh{Z zSLm8|+@n;Yu5m&-&fSZ zY{*x`$Bj8JG$c@kdpqf=J@H!^MLufXf2W}Z34HI;zU_84>MOeDDRc590#(>mN9>ic zQ+JY&5so*cemh8D{6Om&z79ftuGdCG3lbPJQb(W> z)l@N-bnNRc5vY=PFQ;Fwu8QZQOgAs7{}mD#`-&BGUg|aSk^cQ(EB-B1;f`$TlX}s| z@X_ze4XLjd5*WkNSB~wx)qUin$vjmeP$lp59zD%h70;l3<7g{dkifYJRZLR;=Of`+ ztq>$oh5Oy$ zJr1ZF$sjLPj925{m=&}jfpc}S^WN3)VK-=iM4$>gFVJ{((%rZo>+7Z}Xh8yV4%GL< z>ZGotbG6nyPa;qycXa4<-^DmrvyLxR(1HZ!PH649x3uNAZ)f!Q5;20JRoSn}&=A*m-Is<0D=$yCIzm@$@2uW?@LHHZY}nW)BcGoLz$ z&egf6wDWP=lwmj2DMGF#`ji&F6Myyi0(YYGl zxxGZ73VY<3Og$>CRYuXfwu(&`6)i}}x$vmqLCR=)pFeEbRU%M@oqK2>OW_*I*{F0D zaokHq3lf-nr<(AkcV;o}+Mn{22vo@(fgTV3W)|^PTMr)0Wh-K$YAD>Pg`));tQqf6T4w zB@w8??!EL@EB@3vmCn_JK|U&4kiZ%n?EszRtyw6hE?1?RM4$@0XVI#c{=QmU@^P+h zbrmg0V6D((nzpXHCXVCH#mW+aD!D7y{^EVK0P@lBM^zOqNMOyA;`S-iH4!J~HLM^J zsKSoTCewcUUZ1!gp2@B%S|kFhy)^UMzri?;-|zzQ#RGo#g1+aGD)Xe=3bq?w8q zB(Ss9s1l9|c2vo^k(wa55(M1km+~i0VEl9|vJ#{f3{5Ua6K?@SNhKcr!+It%J z)vNo3Bmz~~Ka=+Ar`9m8$DLP2q`sU;;94x2N5xk+&Q-tTOQc?$sKPG{(A)W1O3B*vO=>pv><`2Lg^XQTk9h4GqCS|i9i+h`=#04*Gjs`$?TmxLqiJ^ zxIUFuY$Z7uW69iZt0e+e*oBx@&}BLsu z>24B%D*O@=&5CZf)I{tviF@d1K?3)iP&bkq_L`_o%#4=_RAH|`s`lwlS_3+cAxpe< zv>+kZ2V30OZ`23F-g!v`s<3WoGBx{dwTjp$qFqg?qKE|63h8^0{=bw_BZ9cL#g|of z`V;c@eXi0vFH!1km-e)o`S+h(WgLBdn$B*Euga_S3b*yD&d`EA?{LpJJsu?@b5zGY3SWM>0z(TD$1mJ4ul$>< z6l^K_MwwU5Q7dOCeB{>(3@u2=yW-ni+oYPa81MF@j6|TSwd)`A6WddYTCXAb_&H>& z+UuT~S2}FZ(1HZ^FQN{ZOOL91!c^|QvY14mYL#A6IZxN4!2hf`q)g|48gR z_2yBP7o7b;MFLe3_i8GcpK_IBc2&qnxy5hPa<3RqUhqyu3li9eiB_bPFtP4s!+5E) z*Chf~xa*&usBr~Z$0}w%G5(%Jpz2M{5JgC=<*H~w0{4>A{@R!2*wL$IUVr2ci9l7DZ*wJyuE)T|qUS(F z#j?yNPvs6@wyJ1B0(%jdO#bnXEQP+A(ev4Ci9prDYVD09Pkka*8>V$|VGUY`a9v-h zq6G=;ML^Ng@+$1VPiAg&c(6pE>U&%_5vDTI?T+^lr&2Os^;uVP}-Z!}xW5qC}wT{k~-5 zTrK%5ddk|ld9wY!VLWx|L~8sYQy3W zYOpOk70zxnGNT0v?At*5iekLkh$AZh*dts+0#$#eCMjav4IUv@8?M|}gVjt|c;hNrMR{8OYRdpm#gS^@;t~(R$R~jGyS= zUPlWOyUukqu18>vtwe<0qQKdGyy zgz|lppGgF&j-RV)+$BZ&i+jOTWA~fulew?5bp7yrbM91{pCK3hRwhqV+rM{$yGU8kig#YCR5X0F?vfSjCZoHED@+GylstD z%nkloEV@R-HH_6SW{2|28!L0PAc4Kg=^5;msW0su&X-u7Bmz|tLvoFCmHtF@jYtk1 zuTQ^e=2y46a%YSX?4db#JJxCVc!FD!N8Y8NE-Ragp(BbxU%zAc4I!O{T@Q z&*_J%gmQg+A&Ed${oUm?aXp5V3lPVVa#623TH$Vg?KoPHz+Rg4{n~s_^yG6Yj|l#y zBY`UH{%h(RP>q1%lxkUx-j11?Nc2sNbFluR};^9 z+wP)kL{!E%y*X$6%gZM^T9Ckgoz%BB)rO}8g>u`a7bF5zQ5C~9@tkMHiN2hZHWuQ2 z&Z~TZ<&usTB(Ptn$&_@UICq;CO5I4Q#|Zo_RBcIUtciGcW%qX!Ew$-hhWnaBc$X+$ zM+*|zuhV3DlIF<$|C)JL*)Vy>Tt}JbVY^wfC%&2vk+dZLNt}(Za50=~vC# zRDpkT3gs2|jMveEM86B2jC;5EbJ5AOSfxrlXr00>fs=H!AR%}Da<1*p3sJnA|3hnu zK-IAtJ&k)etwJXGXgJB8H=3aE8k5@SXhA~m7TPMvgHJuJaxI{OM4$@0|I5d3q9?a; z4&_D89@Ef*1b#1ozRhvPlTR)Y#&cp_Bmz}az9j$eTrEwc`QX=K-aO|3eVcQUn}!x7 zu(v6_Kc;!{@vB1lRi7*~5~%t)Hd*^z=#&z>W)jV7vufAkN0$ck@5dIH(Sn5B^>lu3 zZ@#{a${QD0svv=?N5vA2JF(zp(V3!(>c!9hWxSu?5(OgIJfmC~-`dhf zMFLe_O$o+Wa^{%mO!1d_aLW3QUbNu;*MOc#<>0Lg%W|PQy-gZBC|(rO^BpQ$kZ2SeWt^*wAkmqkXyH=)%j-~X^Lv|$79{Z73Ume23i15i!+HGr zV-kTX_pbGgUnRaePx}X-72(CsDg3d=X%#I<;I|d1-_jKmpZ%Z8Z|Zj?0#*Cp)G(g& zb`{*o$B;$8^l9tD_~C~)RkR?1-&Qc0w$*;7S1uCHD-?ex5vY2nx)|r`I}@ELCT@DI zU(Ho{L;KGvT9ClLyp;X8mZt|z4COKJ@-rk*r4KA_jBrD}RPxay?>1F@g8AlywhS#u zU|&U=v2;4Bm%JLn$Gj*h5vV#H_uYs^cW)9M7ZCuWN1M`?zq^r;3mEK zb;{5mbC3vB)fj)pDq;p@a!c}Y?Z+DZ`MxlI(6079{XH6tu^EPY2ySA)HSMb&&{E?fpK*h68NnPn&UiK8*;Q~ zIKMH+Ln2V++%qIZ#BiJQgpai+>xFczAHokU@?~g2Lhd&0wReo77F77EhczSuRap6> z)jP>a%9%3M(JF`b(t~$YVWp9xrJr5&8`rkylscQn&vqMTozv`;Qu<9#ZOB7UYaac3 zY4^Sw&Mc`jH+9Dkcyb`8>SABccGuP7qfRMD*E?wgu6byJNE_p&VRTNB<IMswfR>A@Y_+*eSA@x$6CiH7fd14Ya*W_K@s&M7;mPvHmZ$fC_O zapWVHJv0XtjyHVdmEdN4+mVmKWAqvqvvYQC8^&JFjIxR&Z~wcu5?y|(ajp*E=&s7qf^R|LP<)WI`u=uxjOt`vVzx%=4wvcRj-fdauT!w z%znl`v+yyiN(W_n+gWBo{4CsE!Ky4pWOdxx?wez+3ALkG^+Z=CH|&&>Z*?uj{zyqh z5TmGyh;>7gY2Nf^OzF2hXH{xBw)S&0eM9Dya@jVYax?9K7%mcr+6T45GG=I!XB{b_&HIJ;We zb{ggYsh9T^dmi^~g5Gvmq%?QNIqsP~n-A5(0o&f2v-Yoo8y!83rWqSJOgGBSN9pGA z%_Ra=uTo!Gjq9{TWIsNSE5W$ue!YTEsFXv-?67o4W{FpDQ2`I`4tm5AEotf@ybBm!0aD=jvj$A3DB>|n=& zPxNi}6Z8(3+DmyK%;?BRUbM(dJ+|AvoL(tCBm!08?Q@M$!u^fNM(rHz&W_ZYppWQR zpJNpObMNwzhicVX#G(yadL~t|prU}P70;>~Pr;NvqSD~l>5mezc!RdN`Y0*)idoeQ zbzRKj`5$vr$deU&|{ekZc%o>{h4`W{f zu2b2)|3;>86h)Z{ndifC&b#+x&O8@FW|1MvP|2JjilT`$XQ(7(IHz;=-ldX?WN0L% zq6leFnd*P9_3nMTAK&lPzw7$0?_Acp_j<;)?)9u^?>9EH)KdStVq3f_jqV7(nEzas z?1QPhZuU6(*obkjJnApHW}e?U*GLP&T6}7V_XjVl7mT}My}zYpMJq;eT;lT}c{~Qly+_}l$bhfC4U@bn~K_%{Jk6`H&uliFL zer3f$j(vQNf>VS0#|4L$P42kDuf9?Q{ zug8x(eo^EsbHl)(*^@8feSlZ2e2@{G@8Zp{#T|oPJp*s!E5*&eyP9ZpzI1CbXVbz~ zIybYL<#sqD1Bd&qdGM%t)gSs@A;S_zaGr>GzO7Wy?ecs5pUO5h->)pJBl=^C>wGb# zjM_NcyPm_DOf2?S!K;He|6c1K9G7ZkER5i62^HQIXZ?O{w|MOe-E2Dj@o|={tK#dr zJIltNQLgpWy0#8yGB|ZE_k+Lvr%nFVExK9=*5bSbw?$I6`JbO$>rY7QgRjc1Q(ki5 zn!XO(vk?V9e#f7ad%b_|$W0c4wKy-q+fEYwYi-v9H<;THE{zQY*+-X3_cy#H{Q zm;8r{j4_Q19?O)TAMQWWVdi8b<_#a#%aIo zi}(n5=U>NlbNGH4@{gSt+{~&kc`dv2v+j{Gf^WUa*=tGM9N6T4SG2Lk6PX+GO)i}2 zm#%GAe>L0h+oO&dbhMU6;TxJ)aF}KCO)i{3URKTYzwA~2OwAgGC5*6b#Lsp%v+vmI zooHRm;-bty`IZ%ab)iEa<2K#m4Jw+;biT2xvZ4ifzlt-HwGo~6bu#?mcGG-aT3CwwLoC7(HA$?y{fKvf4g1(velNr1vM8 ze76Pn`U>N%$#O4w&7LY_-C-qwsW%NMQmD!%-@zam#Z#1^8vm^0@nBa_2$WXTfEm=l?`6Y z-9Xn>r=_Jl&Ma(1$EQ-wt#vo~2QMjZAy|uVk>Eb{@lEc%MP~bT4h*&O56(IGegJ-} z@%|<5_GP=hc-9@ktOBFe8~nL^sK?n0SNmhJJ=-2~-~VcgUo2^~Rn=#NZDT~Pcl=*| z9+Ks>7!@!}Wp>6DQ`|`TB-3BDVY>fpyP+1}WrWShe!a9x@bxYC`2BBAu_}tp$T&kn zKW?uX{C3NbteQ2PfF+FJGY6awR&N`mbbQr+_>0^DGcV>>% zlVEt|*?!3b6)gm7aa{zrMQT?GQg7PiKhdM6VF@#DX5_Kh+1)jRcQdB=vDGV`x4V^1FoIb>-V=V|Ie+!)m;Bjg zuyv--oSOMOzTLKBwZE~(;H=#1##y|b5qug6hWy!5uWz39{!e3Xv(Cbq<1+Wf{ur># z8;i50ZhLG5YcZe4dw08iXW=W0{6!Cpw@x|vWYa!%p7cd^cYEK5{CVl)tT$hn74vy1 z-Wi>m?7nf=CI0^XcUd+V!Tegjag{iY(NY>ATDqBpS5qvrf zUNXLysn+*Z%;0Y8eFJ95%(Zdi)#Mga;PpXSyPn!_v2aH4c`|;-{p@HnfBanko*}cW zGh}Aoe19Ktu<0mM?Z;RBcGuOn-UVPD&HNf~r8K|8%(()i zQeXA&+gHgt7iEOa!k>ENyj%b8b^gX1?eliNVa#Xm*dM$0yU*g(;QMcHv~CqMZ?_S% zCm(jd95pCw_U<78OBlgb2YiLLbr1K{oiBQ;=igbnLXT}v&IJ-+V`*!to0L~ zEh8dNL|*%3CFky!rUj>pUgEHX5w@rP{fbkmAO1Zp7{A0zy$)j`&yo>fEj~LDe>FG0 z_Sb1ar8br0j9@JpZF!Q6K5jJ5Sm6#nJ}UUFemj@7I0D#+Zq*06?I+$Dn9=V!EMbHl z4VJu+>s6V-BejDJMzGdBgJ7)_S02uk5kWuEs%gjjd*l1x9&8=;oXc9w z4Q<40pS+y)cD{+huB#q!S;7crnX%Y~ij%yJsZG3*t)>}9u$Ft{bM>WF@EmB>%l?YA z*D6m78n0jJu!IqOQjA*BkFTT^Ju@|U>!;0DU7q=;tbu=ep8P`>=i3`52G4vm&t(ZC zY%ZGAph#NrbJK#>hg!xN!CJC9#D>fhZZx*&lj>pDzE!kl8cP_#r_uQ0;_b)dZ@xD* zSX9`TRie+5^}|{=du>~6i~HkSXZ>Fnwl|DmE%vUlSiW&J(yl_Q9v}KhCQBG$^WA&{ z2Bj_fYg#a{PDUCdSc`oyDh=1a;pEvmH8^$Y&(>`LW??c$GFCD&V13OSxk2Ut^`f-Eg1b2zM4T2MsURxzmJpNICBlgYX16FsfYslEZHZlWuLG#%b%3>`1xtU zSKVGoV+3n~^XQW#&f^C5fDfE`%clk(-*Uk^$LI5Q@f0#v5?}E92Bqt{WBX4IW{iB& zV+kW{rhhu)CHLpnV}eO@uJIYcS`@Wm9H5uJeR;@uGiD3wmE)DvJ@UCEQJ(1b&5e)w9J0cNR|P|2}q~ z&sv;K*oay`lyUFfcUN$*bXlJzjNnWr7W1EZC%*LaX~E1r+p;2_2rlQb7T@ZK#d6gs z;r=z`&fxwwEq#{oJ(-Brg0Z0=6~Afi-m-CY@YRK9eU>nS^Aem>-?!U6y18_)ykV_? z5v)ZX95(8{TFWW-?9^axt~WiFFv8ApmTX=XziaKZ;CxD|M-j)mf5mkQ~HpVJh{lr9$oL0qzb1K}j=&?Lid}7%eTQm7tti`?yw~#KG z8edoD&ft$VlYRD4e5c6vS3UDSlC^NhZ9$7mUh`SP2+ol3>%bp$@fMaZ5LDmZGhhU3 zflEY;3_bs3&*H2*vZe;Dj=k%#gb|#>p+49#_PVf)X1CUro<}^6gLTi6&ya0@g)b=ky-MB_+}Cfv!xBdD zc{09{(jm9k=)Gyd@SVHj*n{f(bZ61c5&L|%*(X2x&Dt~%Hm>>6Wdv)9N05F2j3N|XyqTQ#WzRrYqZY~ z@ISt}S8!pDow2YM=OVbhv+ptQ#|I|`;|JX3^UC1*puI{?#7lWEADtS!wC82<4B6|4 z5nP+ZyZZ54v-W*DEqD;`*~0rMLgFE7@%=m8XIMDT`{9o%!EgAr7?$vj#%K&>-Qm}{ z9(v2$^6@=EuH}DuEMWxKCh_%)e5?G4U3v%IPVE!FtMehxVy$RqW!<4}(Bqe^)!U{8 zv7Wz3G*E;{7-82r?>zOIcO>`JVBDn}J&7Ap^kOZ(w}!9Q)jQ;W(6L95d(t|e5v(O~ zL!Kn9Ms`y9F$K{<|Ja7-82^5A=P=`=rdYU`e;S z9x)l~UN~1tV=?!x;lYrFuY0$&-0d@hwZ!X4KV%i*4W8IvUZZtWg4;5?c&x?MP#dxI zsmuMk)u#kMf1WDbC;M59;JPaAz1~tRC|bWm@aTt)h5KY9Sc~{0S~axnz23LCO$)w1 zb7vMy7{QfTyv?z{Suk?IwdSEJNj@W33w&N}OQww9=N$R4_h;v6!5;%^WwC@2b}e=F z-=qB8dG88N)Ea6bSWDQqjINx|;kP*+ujF6);yrlxwYX#~*%C(ZDIU&CW={+5el96q zspBM%5#Q|`={5iDedkc&!@wqSFQq#^_!R#GqwK z@t+pk2-e~;#Mdc@jSqf#FDbsd>*LnwGGfg&H>-__6J=d>1Y%ugQoMH;8^Kz<7GkmU zw~YvXfQ@zK_E@Wg5vRMue_<59-gW>s3XdHT+&(`kKKddAYw<3?31rbeL1m0p+Q(CBv7k|=y4n!5SYW)}+!CJhVu@Yyt4ZZ|o|5s%KmM~)Aj(Te2vU`P73*xzL z|BGNP_DYE7xs!uK8%khV(6g_;p=UJZYAVU@i9fIN8QLh>eaIWtK3a_`3(P#D>gK0&zWTT#R5X zj&{h^usTX&&dXx8u!Iq7ux7=EtTns`S$>cI;ZCe}tVl+%7RO-xdIVN^vA2`rHL>bh z!ibgFdtyWO41V7LJE>X5{~}n6`6JGKuvb7vC2|u}bXoDz{x~_F@mlFG8>ucC#a6 zB_2>cgOPnMo`EHdfL~D?lht2kd{xH0iB_F}_hAHU!7Hf^@qqG{P_hxfI{}`KC5(VC zR2$;Mz`}DUo245O2jLSL!CGLNYC}9A-ejHG)~te!fnS#iSi%U5k=jVa547J@hTmNc zZ^{VP0{>MTa#o`;%19h+1K-FJMqsYhhWIdW(V~6Ktt*n^Zz4u9g0(Q$YD4@P-sTuP z!c=`fDgJ!fJ;D>SC5*tDRU6{NVzK?VjW9o9wKsxqWCUwr&FYiH1IA+Tg=W(qN%4yC zbSz;6e4*M9ABHcI0&#q4QhYhQDI-`5drxi139B+0BiskOS9nRbgc0zCYNIiHV*)Yn zvHwM|7Fd)%DRG7jU#K>Ok+Fmk@P%sQ*qnDaD2Fr7OW-~iAy^CCPHo7EcZk4z4DT>W z>ol@+M8b&3-VXT21R|f86i>eh!CFyk6DRcGiDnMQ3OrFNacW&8jEKffd{|;*n@qsYrzvO1Z%+;>WGLxQ`Tq7 zpda9gh9!(3zY-!oOYlSs!CLTzYD4^)@UJ`bLbC5(VC)cqlS!xw?j<3npS8dwO{ihREKGi7~dP=Y5KmM{XoFp7gA zVsU~eS_synXa@v%qB|CIj#z70!U*`nD4qi$v9{8O$`dUFYf%ga0zAH6|MUfHk zg=$0gOoC-Rmf(pNg0+Z!00Exp)_Y!=j8v*4#^Nvnd`|h5?0w~lZgs>DFiQ);T3~}w zoeBu>L}zq@CmNP80(`E1$PDlK z=(T}Q6dR0SEy?}GyNU-?p6H)S@I{F(B^K>V(-Kb9~8&sH1a!_+4R-z9jWgO z37%*nSPMQxZHPZhoZN{G@I>)0*%C%b&Ma%UcVjsnlq@^si57yj_zW8-5te|FUQ$4k}_L|NyX$m9aB(kd~?~3*I2@cXdR0W z!~UrCfc9hA(oR-CSc^x*ZWZ=eDGrkNs+d|7xCw~7(i-%*T0t4ftEtfO0U)`!;UvKEhs z-4CpxxU4u?Ay}u`XFY7&MzDktd}@T>iF={8_T%Bje_QK{wRlAECAOKDD5C8gZ=*1I zX_d?xR)n@t=ALcidjJa_O0263JuDlHz=y%JnP?R(M7Go z4WaCP&h7ob5Q!6(oca-=MT z#Hqpm$10K4uj3|jk6)a+v**j=OX|E>W;3^!wYY!wd_k;5RiDsfDz2+y`2$nEH*oB~F{BlOH7WdCaNW_$9NzWx> zM$cv2xSg;uu5Z@w+c2qX^HI4i#WB8ZerKw z`{8HZA3I{}t^L6WiJ0;vnN8e#%>lt$JR)|h@CG8&x$`Vu#=bXXeTds(OL0cB~v_P6S;D45k&jU zM^d^TYN3dM57hQ{9&_~XZ+@zE5n@owS5gGA@6vml-1uj0eQxdjDV}JGJ+W)sxa#cp zjRi4h;BU#}-|Xy&*1e-sJ+w^!7R0tE8#~gfBTv^ss|GG#*+dW{S}oknZSQ<-N40Ub zRojbfd_VnRZPm~lXJ)*2Z3ESMWNUqoTWGiH^}n9gvHEgT{&ciT5i$ZYUeZDt0UME4 zrbZ)aRl&R;Y?1bhmLPegZA7gpAEXPS#vLuwx$SJvM*MhV9Ys`mZdwM9XqyvtbfmXq zCT&D=?n|{*58e9o7PQ}qT7~xO-_jp@bdUBZtYg)%(mjpQD%IMbKiT8Cu@Swu)YSPn z*f6;Xw?t54Px@xJD&B9u_TyykeT~r%wIL&VbZ8y@`wpx=+Xh;ti2Ob8+cLLCJ=KzV z#%k2RWo~T5W3_J6R_(2GDg`Uskrghh1t>*Gt8B!UCzfxP`IxoAq+kYBOL~nN)W2m# zZbxfnCCb0ym;PEV*R}il3zZv1<0GRSt=$b6WgGEwTGNJNV?fi{8LSnx@Ryy9E%dr=X72M}rWL_{2`nSxBJ%iCI zX^2%KJ4yDPtS_|C3+Ka0>fc3udt&Vd%M+{p;Chqo+hq+i+D5cqF~~ zwI?HPBk~M;RQE^lQJ%)EB{4)=DDz??es8NSwqCpJ|QMvG=tw?zDQeDyvJ2r@N8{~>@04f-KtUr z3hPQNw76@NXf4T#j8?v^NV`?{{*|mYh8K7{dw+;lv_E8D+K69Y_-D@9<3#&IY}ttO zo%`q>yyL#S>HOTtOX%EKUNoBXiruQFq8~bIVhhjKQ4m{p&y$Y~)t-M;tzr{l7c$Q> zva+jWq{W7f*g0ddu96e^lQP8fi- zh-Z!Cl|6$)w?C5T$IKrZOC*-Dl2s^CAY8{6#q&yMnbSc%CtEeQ#|r0C({*_nwS$ zrc|+t7J{{Y8`D7%m$a!5#3KXux%&pTH5)4)bJ+%K@%-aE6X|u$Ed{SM^OhfUS;B~M z+q!6<+SF2C{8H?x3X1ygjL8JLf5o0a;@wV*oPC>rQt~70)|HfenBT_!9 zrTzH%mK4}n{b0G^5A515qszODU@c|?vDimdN(D`^Ygbm=y;jQdbxwHuZ&0$#@}ja@6=0^dm4`kh-evD#fmuvX;3 z#n<4A??*2+%NNuzJ$L3aEMWw^j}#U9L#2J$S{lCcvuAif4~TEyZ;JlvP1dr9`2 zEnx(_j@l4!fiJ$-u5QG`tuI-_X9R26@qFa+W+uJ3 z*4LE^Si%VAjHm)U{fR#j-bcI;OBeyKqdgaIfp4C*yF6Hq7*%TeJ3b>=i%;=lv1{Hd z8vOK1iaES*o6izP!0V_D@h|uqd%M;_HF%$^7S*s2tOc&9V6?t&+Lih!j&5eSY;G(Ttb;$PZB4Grdogq3esUB3`)x^xqRmWusBfvqmAHp_K zgDh4tIGtY32q(1=q!syK;g0yaam8c4aM91eNg2Uf%+B#`s^tg$dl1!r1}9|+Bfvp* ztb_^TJ4Ndj`B!6IT@Oym2-afOk8e#LSn9RJx)Kk>5=KP6T|5Wk#~TlOL$R*9lyWTJ z&ItC}%6;6B`xvZtM?9H43%k7Ovrs)mM{WdM@K=t zg>oMw+-Jzi*$yLE3#(miB%%Shj|sVtocw1?7y+-N{gAz{+{c96$7cjKY6X^_!=g>?z2s2@yC5(XA(VolR*GxDFGhqwCT9Tv59u!{_i-D5_ z$Kicyf4S7Ec`yQ=LJW30nx(Vh^sFFurD`nXu*C8Nps#xsMxi zpKOG9A^0XoPKv}s!0Tu~#MdbIaYOE7Si%T+9Yu)eQ10V~+{dtl5%44*3i0cseFnk{v4j!uIyzS3Ei@DM!%Wyhuok>d#C?E}^S3Y)4p_nn zcpdGzcni&h{V)^u8NphS2Nz$XnQ#zh!U0Pd0k5O|5O1NGa1ds~*}h$##afYX7hj{9 za1ds~K1&z@ucJ1^TWBU6gqd)5G>~VpR^-8@+N+sx5N5*J(Lf}O0J~Hh;)U=#KZph+ z5Di{-{)q;HV6DjK3y+M&mThbjjNGH=qS?M(B#e-G5Z@)-5np0!bl!h5aV}~hSS#}R z!XuH_;`~jx=x}gSmM{VwRBZ@%#P8kV{O!6C?aT|Y!WM$HBA+ijQZr#+GT}vs+GhK9 zkuU-rRL4rVV=Q*Z@!4LO30nx(ilTvdaLt6BFcY?XJ0sW!$6_PqKC5+3=Z2;k7(Gu$ zQF>i{NQzf=?(BHeA~!j6deqau^Ip>vHhNY*qX^?ywGgaTjHq3|irRRtWqTm@_nEGU_qTnN!U)#FsCwc> zWi)VF+v;ZR#}hZr+`X(U6tVcLQ7MdItpociYpVwAX#hmtHmkKC2a{(tVFYUl z2bG>nY{XdgGKzR+=OrmDVMKv=Vg@ndK(w2eM-dg0T3QI!LX1@#5>3IWYv)zOx8<6r zu!Io}E^DCuz=#8}?w4z|=budbY6~M+3sG4|S7IZ+q|>^W+W07Nny`cs(>ul0#*u~B z0`W?xN?Nf_sn#)QF3L! z6qYbz;yvxu21Xno~swaBlCC*vlCC5(7^VmGyc5eK4m zz2|hSsx2+jgb}QTIIg2B(G|Qdq)>Ri%5W4U9MtDI+^8qW_MO7J{{Y z3T{#aMjVL2y$UL#!|Bo)Y=gBB$F=7YfAIygv17I8+kPLqnI(*X&r%yNA8Q4~^1KDq zMvd*UCX8S$cs%X7#9zEIGPkKBMt6KRoh6Joesxc^(fO?KtNMGYDWcl-XDkG3!51c4 zHBoMVcCU9%eaUUzI%Tj8)KK9oZ27 z3lHtd`IY!gyiwKiN}Z1x@0tvjFajQ0doI2hzqR%KNZnCKer#tUSPLFn5#m{qb2h&> zLwdeu+{$#e!CLUp+FJ3;I5l|u61DO5%4X>-VMG)Y#k1nOoK;&X;Vn?0dr>kTd;vgef3m#fWSA3@YMwO#Ie|1ambe1pz z9$Lprd@+1t>J5rmS?B7;j9{%OCW>bTLtgW(_G9479hxwLwd^zGF~#~RV*lY0jakBo zC?<+$#aYJDri!R|zF-O?SSyOC;$>s8_XZ~rZI%|uj+`Q4L=+RnvjXwk<=XS@8PBKl zvseorT1QuWCf=Kv_r12N{GM`6Si*=XCW@v1vmzQa&ZRv+y)s`DMz9t5L*O?XPYjSSyNf;#p&{!0njITqX9u7%n-yQy#rOZd4_OayxcVoRqEin#BZ^u~-}EqG|PA$~a)JATg% zig@zPm5o`#2zY4ix%e==h25;bB1#rr(1;PN1rMz@#LFVDJwHSd&oB5qnI(*fVxstB zFoL6xvR92!AI$HO0l(@*0G@s03>`gbBmfxYUEdachl>5O14 zKC_O+z+TnHFS&0{X9*+V3)MyYGwCd01on#B zNJKTTS6zwktt*nj2-box)DaO6sO(jH-tg@s=`3Lc_KMm_L^ZHiMKrEH%0jRfe4!%5 zhbeni#4BHxZo)QL3%*c$F8)l}tM>fp)dL!_gb`TzYD0XOvRAe7!B?>iMz9ukl=fV_ zld@Mu+)?(~6qYaozEEw5*HiYYh#yOwu@J0lSWC{WWX{E( zDSOqPpMR=t3QHIPmZ>(xhbep2mALnF8^Ky~ZY4IvpDBA)M2EILQdq(W>?pM%K1|uG za?vY?_puPH1z)H(#Gfg9)u{IUTN6@P!U*_6wIM!C*{hDOL^VdR7S}xWOh#AxlZa|8 zVFY}k+7Pd&XEKU-dSz7$!CLTziVz>BXEII1Uyb-VkA+|@_(EM*;sNzcMiEc<`%_L> zvL%dwFH{?esD?8cMQlF(k(}aWBUlTpTt`f-h8rFmpYV(SE$r zY*q#%SPQ;TdoH|Q&tw#FXN_FK`m!aAfG<=ViKvD%8AY5K+0sI=78t%hNm#m`$tdEJ zDb3Sa!U*_6wULNwIFr$ypM2|JV@9wR#}z%3Q5&~?>R2&~5ge~HD^eS8G+&v58FHc; z6h=~8A|oK@R$=^bG%HdY`P(*2VF@FmRVjVbtVj`$y^za7uolKhZAd>gD^kSnL3vVG z!iZ>9O5ZdqQt$IMRw5%<3(;EpCu62rk@h^ZckdLIFhb6)B(}?#X;!3&s;jQq!U)!i zR;5I2&5E=im%P;>gAuGH=Tt7hO#pCyctb1PX_65};1QpDT8RJRbU6)`7?*qRk- z&-1s*m15Zld7Ee<@@vl}-fLE*Hue@vN@oc_H{x0nu{A4FMC!-sTNuGwi2Q0p;=N`? zis&?D)fSd80+C;PE)iR^B1PoCWZ`B;uofb}+K}k2S&<@czW(qgmM|jXS`x7}E7GwN z4#x=Al9~Ibf`zO|5$!?_7fBc))fkyUiSe2hDI(%<5rVZMt|dN9vm!-gTv|vuT!dgP zsm92^OZ3*PNPAwWa?K_zVFV(-wn}`MW<@Eo5_i|0Vj)-ykzYqwqPJ#63D&o0XbMXh z5pgYv*m~Bd{kZ&v6O9?cS{zsOtWP<^p;t{4mM}u9F*3U1_4KSy`*Hogb{2xQIQrnG zQ{h(5xJ7y6->jczu(v&VmU!t0w>0x+Zrc%ma?Z65Jhc9M!p7{I7r9R_Ihgt48xssm z7y*B)2=SBhYl36kfu%FkdY`@1u!Iq%mbdY0?${BZP*V^K=6>&1efCu5?cIi02-do^ zZf9*_^31A0oLW%CJXN7t-1()ugHvB$y zV59ZG&T3=W6~!$CYheegjYO<1aAAnKan-p@bLcylC5#wQvA6d8=?%%S@qO|{Q}L%` zX)Eg7Zy{II2L>U`>Cd4?aGKN*w%uAHV1)b~X<^SN`k~Re z@YW7yRKxt4Z?3u~Upq<*-+gA`B>)!8nx+giUP2-6cg0-;c)rNR){3h%18SWdVa^~8n zCIu{E#F){E9gH~w;t$Mq^9jjmn0pJsT8IK_L%cT-)jTgVUO5wUAFzZG1wL-9{lFXn zA#>gBgVJf3dkeu@ayBM&F8&_O5_A1Yyyk1)9k?}M$;Akn%><&x$P&Tp14_Jh;nq$T zLeP;G;)dD~e~#|oRZlsJG8I)RiD)Z0t+kHl`)+=2qX{!#_lO1(k zgNeb{6`RF()ws%%nRBjH8?7D_&wwZpOnESG zym7%+4kK6#Q9x}-Ou<(j-_!2<;29`!|N+&PAxgtWeFoD zT~S3_^=W%SynDKPa2NK+=!HcLBUlSQU2P;dL)DXAf(KG7XO2Bt%&>$Jmpq@C^Sj0f z%h-y!egQTx_ZEV+;M3Jcf-_9PT(?J5Tkt|h!xBbBYwTF^b+93GU4PZ7Ow7H7V68t( zr)bYHM?l<PYVy!67!!rPJWKuWtY^}mW z3sFGFN_>AThA3b%5OF#oDzJnRQKS~{onQceBC6ewsK5x;ilVmo{saRk3>&{!|AzsH zgb`7s7Vn*409*Fi3_uX91)r|{kR1R7=DIx?fJ6nBT#S&}Od!f)t}lQAVD1A($lsAx z6z9eF;}`PYOfY~w!~jIX2t)y$5AojW8I%D?RN!Z^mdv$8HSy%J*rQRA2;aAqr@##Cxk}aO)>LgN0zNh#80{$4Su#tDI{So*|n7 zh=dV{0%}8iF}?#?xRv)%A~&!QtR*Lp(sS|su^4g#f6wEZ8w4z21fqZ<#FOJLXUVy~ ztX+u;EMWwqfVN7!cPxe|5cHgpoF-9$5v&!(dGY;VODBs3!pI~lu!Iqa0%}9Nx9;=c zIqVPYa|^**QPdXSkBkL#{eG#;G>HlS1Vp+Y0`%N3< z&y>87C5+%mt+|2o&cp-*5auZBN1_5^g(GJG;^X1D_22(t04dl}k{7as5r_hcPA~xE z2JTBcG&isitOcL0HpCa(41gsUBgD@q5XcP-a)SoQ4GbgX???;#Uv2!u07hc0kQ*45 zFal9P`yt+2a|2}n$PFw6YY9J;F-$N3u+`xRXLURL`5v+wB ztgRC7t)9Vzxq*dXE%KRP=hc!3wS;7eHW<`i6*WAEVYNNS<&k{x;3TUguduwiB#8XRDU<7MnZ>kOP{rKJR zlU>YUtX+u;EMWwqfZ7o6Z8HEyuom{F+7RE5dMy~hCPV{?3M^qnw8q5WYi{5kOymX@ zg0&C@v{mB0Z3e(LSPT1KTP2=ca|0*L4FZ-h0#QJ1h%eUDLGSXA0WgBK;Oo`KKMY_1 z<{YPk0ZSMGKdcDxjKaH`HsK63NAPVUI5Pz?ygZ{q6>7a#REkptBhj?#29Snpm z;dIbKuoj|#wn{uXYAnAdYwnEG!GI-7dUNM!>&oKO`z>ZV(j4 z=?6{+9Y(O0oDhkJlUz=7gWwP3I5-`2Si%TI0kt9do8|^Vn|Cxfa2df`$m-Na!YAT% zFu>_xeVh)uEV&pVem;S~>0p4nkHBP~RML_d@P;B+u}6wv^ugN7xHKpv?5 zko-+g2ldoI@wTo*OrL?mGd5%bMfTL^xY~6rf*on2>5hGh+kHw?>0uy z!SoGF7y+NI_QV$}(>JqbD$}Ld!iXpaizip6Z@z|a z2h-1v=OST*RD5Kt#Ct2#H_s)QzJ*{d_;l@u_+n-HCS>~A3_xtKRur|xlPl9V5(mNb zvl)O$7y+NIHpCYz(>DbYOy5GVmduar5Ao#6^xaPrOg}rGi-ZyI>DqJg#me;Er_LzT zw-Brq#bEKv%Jki>IN1i%&yMFJVFY}-_CtKJp7nVx6HMPiuoir}_Cx$~q6*-{mq-5jEhxp>e8Guv)_ z;H@AP0rN6unoq{SDv&MlgqG#5nos8sW!g)UT%xzJ67IgUf;+RPkzQ_1Z%-VYtIuk z`r%HB`DQ{HbIp+#T$V7RT(drEqf-glQLlYi+6+6>+FW`=F$=+3*j?HW@vQioRsWvm zvsneqLo+7HzRQ*{;^i}aymc6>tn%U+uB~;A`RwTqrh2cz*{AaIEY`yA(kF>$#di(X zb~ZQ8EMQ7c8RfHt5yMh?tBo!niig{O_9`>ra7)wljhPmLwcw$3tYkf7C6*{>ZW>g< zG^}}2tNOBjM8XL8TWyv2$yn^M8%vmnFoV_d?e|&2h>G31YpdFy7Tgx;dOlVx zU<7Ny->MDqld;&tvv#|sVPogZB?Fc)qREAJYGc<@iE8T`HF6tbCEl>2mW5y~_*=Ch zeiAcSrIGtPY;>xT9I%8D-=;QG8#@d41>)?Ad%Xj&ar+zfEd*=9->MDqleoV-c&X>W zM%6t{0+ujhU`k`PF(FSMAXeYM$lu(ik*W4UB@4k?@X$Jg;wNLVm95wL>#*9#Kbou+ zcbP$vFyhDRwRN?xeC9?VuK4Z^Usn74T$fk~)`Gv)eu$rp#p)F(6?C0jz?5oRJYWeU z_E)=3ZTvn=SYPrJxr3QliT4lP>NA41q_z^RH+&aoO{3r=ti*K_m-{SX1pKYGO8g{# zF=S@-pijX%@NSR$EMdg*!q;l6<~A0@v19dvR+zz|y+>xB%FDA@3%Qiq5dRyCHSU%g z9D`zvLjI&SB<}(Dd4IU~(~&7=(GABu zmN4SJmKD5rTJMVg9i#!#VBYmXheMZ`$@Q{aMz9tn#Lp zA>-0kNv4GNKiU)zwqeds<}YPf!U%Y1wSf@_qTDT8{q*lznjKH%wh*iptzF5jz{n;n z_x~DP!CbMXk(>zYY7hw{ka20N60!E|6Oa3{Yww*|-LQla4N9bFt1w5fQQ@YM-croR zqsQu72-f0=8jEGdb9vQ%EP`(vcDG_JBRD3)#>ApQg&keZ=wI?UJO>YtF7J)$uq%E| zHL%H3WxX{+cE|nma*K7{;v&JZ54)S@Xa9&Zg0+rsEa#OTxI2EwNtyE(Qd$I;Px{zB z=3H<$Pb}mOnfX!t)bg3m^16AwcaQImZ>loa`K@#zuhp(S@oR?71mff46@vIxHO;x) z@3_Gm`Mff(?}{(lJI6xslkkp7>U6(N*E`JI1#!1>najN8*YAoa@0jf{+D0s{m=s)p zM;o(l@gjHn>0(~|fnD*R(kD5Lwh^-@Hwm6NT-5wB`7H~w3iLd=*BKk4D zO4s0y4j;OElJ~kSVT9fDt}F5eudM23UTECc9a^-ccckO4_~-5jhg;b5i>tjgJ$J`f zb{`HK-G;Uao-dnZDxF&7vV;+A8^7ZEe$k-Zr=84opI_$wvapPI;qG1Wm1lZ8Y>%IW zyXQZ&3A*>b()9RyfQ4YKuO^oES`OY5&nhG_>d$$X2f253H4`gucX%zaJzg=f*xoBT z1QXn=O_{YHTPuzcwvEH9Z}%#%+wDG3c!bG+GR@of)Smdc$I3Z>Y;Egpy?1x~=Tl{! zWk02PHJ|w?zTvX+Xw{gy8Cf4sY;S&PR?4u15pULQ=iS_SSA0}|LA<`=R93sC1I+HU z(=H=etN*UXUiA^X;zNg50AjGt6pP(<_YSXUmyV|8Tg6T34!yk%<9EgTZ!Y6-KX?>k zv8PV#_DphdQ|`Hr7J{|*59#ft-nS=SY)NVKdQ1Kjx}v%^)M#HmU$^-}77N|K5GL1m=Y1oV9pn@#f@|QSMT3>R#a1 zVA&d3gj<87D+?2D4KA&J3!{z2ZU|<&&x6Ul4{pr})&gVJ*d^Q=tG(t)cO;n1*5c*# zrju|SkuU;WT5Ss-#+w%T_Pa9>)mDRBvxE`g(%LHF!0GWUX07k*$-c2*?MVl6Ocok3wCu~?_alFjqu%b1$iuhW}OGJ_&v1h{m>uYl;4 z>k{*#aO!X0un?>TF0Jz+d>CWZwzv^a{c`?N0ZSMGE*PmKZ3%E5SSSwnI z!l3a^+{ESXQt+!G;MOc*1o*GEN*FY9oF^W4uLj>81a8d|Mu1C4{0fNTJx4misn3C1 zGlI3irPYS;VT^9A8y&H63f!6{i~yHb8^VFX+b27|*jNj0%?Q>4msT6XhjB8|^_X+{ z_fyP|{&0^ai~yHbPA&Wkk+W`=|L512n7QEAj9{%O-w_UsH(z$H_n!o(zNYLspAoDD zF0HK+290mfVXn`Fc@#?+5oIsJXz@M{=K4H1^)BR5j9@J=W?H)~FxM|(&f6o8VhJOn zY)Ke2Y|N-0Jcaq#0&dL`Mnw6JaA17T5_A0l=Ho=qfAT1K7Hgq~pgk9EjW-S#b`M5@ zQ#W^VSvQXufio;U%Mxyl46@3V!S?kX&Hghln|kQ1+A&+c?t z!U&vnX{+S43TMbgY6N+4;&nStzF5MDDBqD2DpcI>Kkv(l*TW6}wh*j^lPaN^ZxKu^OGMsO6x z8;3mxJDYLh)e0wHj9@K}sQ9Mr={(*Lm>6JXe0Gtz4rk zc!P)TiZ9$Q+z0uSzXo|!s~Kf|Xl=rk+Qc=;rx?Ln$dYsh zC7;sTg!@LKHla6uWd=pU2sxjYilXE_TAOgg+JuE*E#yzy56OYFHlcYGY7>Sfi~y5W z8zce`Y@OV*?}2mJ8nK)@14Ab-+UN!Fw{2mJ8nK)@14Ab--HOWvb52h{K4 z=77%#) z@L9qLrp=PCnu~&d5pFZs7INw9_1r{a>9C)$7maYdXx$4Q5J%=kUu$LJqrCOSG%jZ zuj+^H^{IP(mN3HZIqFd+tVj9CpPaBB<#7vs)`dX%-|7-8E$ zJ<5@K)RV}kkUu$LJ<3DH#0l$B9`YwAtVf|$s7J-adQ`v?Mj(H3!g>@Cr&gSbS4BOl zQT(*e2-ZS9%K2@$)}w$p*rTO;{OoAcyvnma@+T*(M|qXb&2jpd{3zZvHJ|rx(K${l zR2o~{TL6gj2M@YWOuWMU`F8z)5y+n$$)hAYVl@12yWiatgHO1}?sfxIEu63(rv^il2-f16)!IZ5)+VyEBH`9x+{(O# zDa&mhiBZBrv^EihwF$kM8A%ucF0C<2IIz|xg0MDWAy^AsTH}iFVXaN*eIL{&3`-aR zF0D3%(c+hya6TQ@CM*PNA*<7V2!qzzgdf%>q)L`8VFdWEwn`YZ@+&{AO&FFi0$f^K zB^+356JA)Gun?>TF0D3%4{L403u_aGC5!-PTt=`K7&EP1)Fw<=n{Zjeh$veUKCHC~6V@iO^C*!pBFcA!18Z%< zNNr*d>QekH*5daYv^HVF+JyDy10zroP);q3R%;U`tWEg*EY^}+xUx5eL2GTogtZBu zC5(_;xWe>>4{L40gtZBuC5(vj9pS)Qo6vifs7+W1){^!q2-X6Z)}9L=*4l&{)+Pd$Falg!ca(5otxdRLZNfsZ7Pz$bL-?@PCY-P~ z5wL_2atl{Zri3+XZNdv{6BdHCaGs?$gj?&)0bgzo3}4$hU2*!U&vnsSP=y(whT8cyqu)uvV1G$oUo`4sPLwHwS!{ zFajrCYC}$_^yWYyHwVtFyf-^5l4r42lod%t)tdv_k2-^vT5mouf}^O`qnxlFW!Trs zeLl&fgae~Gpm~(=UbO1;AMvmrWg%Eg?(@mLKFO!>3wcG(yL}3L>^kW!3@S02M@c5( z)}QN~?Yk#FFDb8s{7LgD$tuA5X8q!RSp1Yb{WI5;jL&s0@ApwW*YkNT1V0Jil)X^T zz4`SC=BJ(mUAfOEd6Z;C3)AN~3#aaiAFh+%u@Qc~``w^$KeOV?&s>%;f^8$7;|5?@ z!7`}!Xyzr^3fto+;nW#70K*EFgELiVF@Gbu|fsQ4J%kiZU9Q= zB^eR75OpKXyd>|@3YPmvqJm{u!U(pl6)ZQbU>VehH1m@DhVAi_w1VZfNK~*a1Z&9+ zK*_u$JJJf4BNePqk%95*V|%>Dw1VY?6)eLNM%Xqo7LG744cz5!=rh5|c(j~zXw9B@ z(hq6Q>EmTJ^ICSSts^%8CG(Q33BTv{SV`0J&-Uij=-2gbU?gFL+yIo!OL7wAoX%3} zf@L9)acSly`4pqIf@Q)AmWw)v6IQSsMzHNz?9%Zi%zN*3FmKOP%BoM^DsoQSH&pOR^QV$4`pIe$6x5 z%-(sq`{a<h5mI*6ZF3%L(utygaEE86+tQlkkw^l1yCahpt2-f16 z#VJm^=~eMf*`S`6)}Yei$DJ=40G3!4@q&&UMt2*_? z>n5Bj+z<$FEf9VG6*0VTN_=b60qbGQyY89$;^VLBw|=J=?!f*3e+nlp38-|MP0tnFR%!)zztd*7@}dAj7kAWnUip0=;b zpnB5l#xJHe82ja*dY8Yl470}nanB3Z%Imz|y~FzJJ1_9?Z0GM`E!V@=zu4gMk=}2{ zNMkgz*n&MD(wg5No?B@s&M!}O&D2P$W8POeW-!rUb zy+RnLuut{Sr+?`=w_Rp#^ZvH$`aN^szx0FWnP)TSR?XVg)cLqmy?T4LjdRd`r_#W* zdNQjyBe?C|6RV^iR>{90@NBQZu)_7kRybC5#a7NXq?gfpl|AxExQ=s1V0@h8``gxC zf|ZW->zMVk>kh$s&KZGzc+tATYR~%b)qeQKZr-PEm3rmZ4|4E7Z(q}r_3&^1(hu$} z=3Q1yU0K69wt*i4`Cl*pmNSCa5@s_i z@~Al@&~^`PuPZ)Rd~c)wT)X0xrMF^B{9-f;Iok+ZqU$Pb;eYJf@VRwmzjF)!(}u;&gOdA0Xf@XF2WHm zXG9ol*JG7bp+(ZtZWv_RzLM8@?bnCm@!|_CM#fCd=85b%IqupvEcQwshy99;*f#oJ z|B3XfX=~d~GUs<}%$i$qNxaRQOC0vvlZV}xb$!a>c>NcJ4X_PfOW{gXMsC|+1jkpl zku!pQt*pfV=Ayz0qo~HSnX?UkuB`khL&_P!+^|8y2d@|Aku!pKDtpwN5gd_2Po&K0 zUlG~-lSd(Eghc*`e@eWKc&UvT{P8LG*}HE{`}l`}*4pLudc0yO$HVGLdT#bbtAsr# z^9S1ouaa3^_ikwT-k|lu-y}0ag!_-t(TdBQzGohakIS_|xM(($Jvw8VQ>gxu z_~0+6;+I%1>_6;WpQl>d+JRGpMT3*$GK#l+`nbn7o_&5}TAQ^G#Xl-9YpmmlWK(d# zU#T^pX=q9=@8(Q}q({mihvP@JWrG~qwehNwCv4l=%iLP?+O)C_@|!+?%yu>c(Y-`T z$%?Y$hwbfnK7_qQ9**bS?l~h&zmzy~r+Q%djy?VN9M>O2dUc(NAy z?cRcXXZFb|=J8zB(k@)p+#HHG)19-S>Ntma3D0aSR^@yZGo|shY27YsZp}F(c>Zy| zTf3_Nz_eX%{?uXSx&@8AT!Z$+CqG->;dh^Lx=HUlZ#}OU+5JHW3&C3U+sjqH zo$mj#zLd$8(bQC3P|rJ;zB~T%LlqsiXCp2@KHXn^zcE)yGS-XB?;u#PdC$@b7%gm}kxv zH!NX2@anuXjY!M&Sa%m^;UegJ3jx-st()ZC*j6#`Y}K5JLZxr&$*0Xt^99P z)7$vRHdaHc{Ep@P_^gSh(vt&SmN3F@)yv;k#hr=E%+Hy5%qxv6d+TsY{@Rh64%>6j zRMwtz1fo%g0-scz1I5#Gk8hS`e;?dKL+^g7fduYZz<=pgb{YDw(Myi zj9GA*DLHq(o9C(W-hu)6M#AHb9ByHYu@$uE!&f(gjT4ub3Tl38%r7VBxh!D>+r~YB zzFmR~a|@a6dz-uekFM_y%i`$XUK?UZPy|IlvG=Y#&&~{3u=j?&*Vuy*1;t*#Zp4B; z_Siv@vb%PP3bxo2Ta3L%qlxu9vq9f`e%JN=r5ESi&zYH>**!D&IcKl}1=_=3qPQhb zH>>K6cEaQOJcbC=qBsVqU-D+lM04GDWWpy=vu*q&nLJ``r&ar$Rxjo|qCNa2ih#H=&ycHDu3ki- z7G_jx+CsNzVeAz$Ok zwz2HZYFQV@Wlfaz&mF^2^Gq*>5`3=^(+-t%lBZ0v0*9S-_+9%=L7>*{zls|!{ra=9 z&DxTU%YNgn4cC7U^h^1U%lyEv^;W$06*cbq2eT$U`mmh2eT)*lv!qiQO*?_tE2I;(FoTd?XwTiXBE61r2*NB;IDWz)Y= z;^A_pAW(~9aO_`_u`^WrbA_JTmk=waxGD&WQL?pUq!8Mp8&=2FO+Ecjl(>A`S4RmV z-s}%I>eZlF=bjbF#=q4!CcpR+C3gDl4nzqeDDKL(C*!5)woRW2$?`@|{q3tnM4%Q$ zgW01lqqwM=)wH+C2HB`PUUMP>wXn_%MYk3I+r9(zQN8?iQGy7HYO`&~Xf(>sy!Q1jeW?>?qF;OOCd2+|+SCgtQpZ1x60u={I!X|Mqd-;CLSCfoeIF$p-7D&i ztCTm|)HFq=bBJhn$H#_VMOBqOW*Ym_^9VLiR z-@7k+FH4d0YE7MsPFx4LJ}8>U_NirD3*Ab>=cdRNwf@}+1%X=l9-}z2x;d?qo`v)e z-Ma{^T81^RD)gLecs$q1u?Sbc5-zYNedE#wa~U1df8H z`IpLL&Up5dqwKngg7RaGuhsDWY24KCnKX{=xkwpww>uk&=Q^@{H)WJw-=^N?!e2d; zql>Eu)WZHzcIM6R&7l#2j*c7Do+AQVNSQDDdm4Y|SsduTHcBk~*J<32k7ggD{TQ|i z&$2YFPVrzPDmcPXfcziS!ak{pxVMvxl@;qbo{|4vB+6;zD-gq$uktf+pDRhb|V@srfnE3?tXP+xo5-uK%aXO>~!R z*6vTIohy>NS*SH;?G$5(Uf%zMNX+}t91^uRa9fT37S@eH1dfQN9ozrdTsvTR z%GqiCl(9lYyOasWgGYe{aQ-z2zTD< zZu^$AyhDC#Fa1>^nY;_ThZ2dEbjW+M|1NzB$_zU>I#qtv7VZd3q1}SrUilTHy@tK# zkY76bdfI=NUpmSFtQ(t>3y-+S}N_CHFeC zx3O)=9R}Ldv|CG9s@&srzVRgm5vZlKR;xe3)!ySAU8^fLXm4ZNllu_Go|ff-uf4}P zH?(dlN)Ulhlh~z3yOQJ{XMX36$tXdD`pn-xK4E{)Z->lH!Doxl9Osy(9nM+7SaLK< zIDe~|`qdYedjvd5!?Q)Ic(SgIz5j@;-YoU2pD1nNj8RXc!pd*ATea)EJx-j7*ut+~ zxZDX)-d&YjDROVnY4~sjfm&F(hj!bZ=ThYUqgA&iPLv>m@|d97TlI}esmGqvF16B# z6hxqw+VjfZ)s5vxqQq}!UM8ai5jcjL)}h~$)TXbaM7cBBl(E9mRS{3Wyi1*YF-pX| z`dLBXh$33k5^rv@caZmgeWa{fT&akrywolgQ>EX2v(bwb^4D#@UHaBo`3tLO(FtU| z^-k%xzj>QC6(zXN@t0^tHJ_O*{q`-~O|gLpd}B~OfH76=9m@}uV*~NGP>cL^dv*O+ zWwTr@!}i<91}7#X0=3jP!>1XaY{cd~?E+DP2<084IY$TDyP|s6PNXVrhq__v8}8-! zB~H16?2~aj1tpXP@vYJrRs^Fgp$=Y7xr4mdXPbgRE!sKSZ&5j0G;PXyH;44}M>o|{ zQGy8CIob%hUnG|F7r#KcgLJg4mWmQY;Cr5CNwF?gyUS0UT{h&jqATY&ej1*Ir7fGR z)YwuhG|lWV#0tG#HD&q63no^SA`jp8^rf$hzeGF8$9=86wa%t2(Vm$o!QaJ-P{jPS z7-il2(1U&Hx=TTz7TVUd-+PB!Z=S!bH#$qWiM6x_l=d-N7j4bTx^*+=Z}K+AMmtyn zdH0Y}GJP19kW=!@voE$Zv`&;8YjZWv0^Qo8g;i8Hx6y@#kTcoPyJZ5I>MSYb5_c_ zt-6U-%TNnPf$pd(GSPZ+?WUvLD{m7eIBxh$#LAHOy*lW}z!ndB8fBLH8GKs|3l4M| z*eYbqG_BNuD9fen4~}O;>ML6KyDEbCbDqhcTg1SzO#a-!S`AoJLDMQfs>kDrKYx9? zg@QmWI?IA;3yuGn%nuQNUMeZbA_k6S^5+KHQxQKunaq9X2{GTlxPm|}I)`PM{5d_3 zXm*ch^5+&Ya4eHQH_#q&&@7WbCj{~5d@%9ndp~$uC_#j3V=T)dekCTZ+>5dnF>p*W zagx14d-zLKvnK8_uTM-|7GmTOfm(F7$Yd1;`5m$ObDqhcn?ek&kbxLF>rRzC97rg=OU9o zH&KEJv~Ba}B9lKiWrpY1iaBTxf63<0g=FIVug^0Ofm*T-!2fD&5r57lfBu0OIg}tm zZB@lt1I2|p8NC0WWx0%Yk*ij&Cd~6wbNlvCv`fxNckzx^FCtJ&&D1-0Fi3RXS4>R0 zJDJO97rAQX+Ca35XmiI$#2gl)#6uNZA+N_WeqZ5=5vrMm;Yp0$u)yf%3^%GnB@xlu!>LsN&M+qX-RuO;BTl~kL3mNSqSFK!4*g{&ZcF*Msu=#VI z$)5|9AOdaM{5jwDAAc_B%<5ZyHrm5qviWnC$)76-)FK}as(288&NBIPjuJ$ut?D#6 zN;FHT764f0e&# zO_?N)AI@;NW&2%0pcb9Ku}uD)Y|#0eczwm@&rPh|MCWfTlRr0z=VqDwIU(r$O&s{y z=Fb&`qNTKo&fi1{I){_zao7?>sI4kBWrXOMm^EdFOR`dx>*Q!pR;+y)*6PX#11no; zT3SH3*j2rO{%mrDQez7d*gs9%_FI@(U~Eiz-fgMU4@BT7(A}~fx`@tn4!4NTwNdCoW+(|H_f(fJ#b=Wz0@OhqiO(O=9uwJ%Vf$DsrfYOSz? zUHgd=w|`7o(*B7uR)`2sD9H9LEWnQXPoO7QWaW2aS9dR`e@L2wKrI{vP1`-^zRmoU zzE&@g&fl11+2t8Sg_x;Ka#L@IdKz^8X8&EDy3wv^*fEFXai%n1nv4=e(D@t7tNjY5 zAW#ciYqR>gWc6Q6aa9m>(q*?@o=TxTo7LAPt6%qyZz@USo6+i zEp>VNQ6gl$6D4Q^f5~Poo$~afW3QRXC_x0a)@Cg!0)4Wm5?WP3pq8x6C1;O3t*}{3 zd+tU%4F!Q(>MSAFGD*(%=|6o+MhPNtzR{j@^q&9D6X^sGSCp*&BUhI^UBen{Hmh%+ z-1V(|D+MKpkaZdUR~e5Y2keu(j?d4gpac=J2GalPi&AugQ${DO8ge8R5vYZ`Fghda zlGQ0AAX+yGHc)~HI`y>s^xp{!#bxL+E~CrSO{vsNTOv<5(Z^LieTvJ_GvhK81ZvUS z$o{4O&cZ1=L6^}91NtPVpihl9(6-Id>yo3NF)lC_=MvU7RA)QIGuWq5%fpvAQGy7H zBKTj8P8rW&pGH+}=%pY~i^j66a|4=Xq(<54ReZ4^?mdtV&brpBh~@D*l{v@P&_M3@1LKZ#{=(+JSn|u%E!PV zmaL%hwT2;`*s)BG-cS)e+CAf~SH-5ByVTc03EIFKgj8F-<}Us>>%dlaAEh8r3vCk{ zzAT)VNXnKxiO%D&dSlt${n@D14z~D|w}I6cHErnoRKEMm5oh}+!<9Oah)`{uy19Yx zO{$d=9zRCuIU;Zrs0Oxo4&FWc1IM$A6D_P^nbEc}`*eK*>!$e`i8VU2+=V-{Uq47r zcw{qg-eP<$$ z9KApZB8c;bYF5P2v&pAzj$T2amW-*Ddjsi>+Z;V#T*T(+1##HlDj3O_TDdon-Z63M z_TE7H@-|1$2ko{wdIf=6XwT;8d9%JYN3T>iLWJ5X;^=wZ#x_T<)QuwMn05M(qi0xq zNz)j`)VAMabM#6TC`4fYG);_J%H!&6N|}BwOz8(Aa1^Na=OS0J`^dP|JCi~MR{Nqo zfxS17-oIMSia2@^^4aF-1!~crz}_23|6N7Uo`Bz(V{`NZC5TX~S&eD`gjdPHQNQH6x8%tgv&|T-qZ*y)M+kQLuSuU3%g5ehzbz_ig)8(!V9| z+xC#9-%H-N{degD*}Svug>Q75?^Gm!kkMe$3zy!M%{$v(c*}n4?3h~Vt4VL1yl>lv z^nuZy%{%MT3y;12#fi1;P)lj8WcBUupnv~mk{SJ01q$mFef_RTkBw|N%iFx#v!@jt z9Em(5=3B~Lp#dCj(})!{#9Mp%%2(6$|AVMlVTcs^c1pqAPy;!<_Fn<(Dsg_Cv>_HIIY{8)ny>$urmYKq)VEP0#P zzPDD^3zxeJVi&*Fio;*Bxm1VTO|)iiiVZ~IleM{2dsoybFg6K)3$B|wi6|YAkOn!Ek&Dswc~e7_SvQ&P)o+` z%6TQdTbq4NmcD4grkcD-z1^nkxf5P|1&)T+lPoH8=g zFJ!J0XFINQoX1oHJL_J(6-5rDHTqk}QN+AUVHZ0tfLm58>@)ne-G4a(K9bHCF9#}LjD#azU~>M7fbtditd&@?=ozUj9dEl zhQm?#x+{`gQ?ii4sI$)mvf|hab`9o^xr$3>jbY zt-c^?A%{RQkRicL?m2UdtWK05LhZR0;bd27H?ikfIeVtb`6I^!#|nRm&P9uL5&P~w za@O0BQ(XCwkGb5@SlXe_y^HRP~{*@yF>*pmlDa>A$^JBkX9!;bB z_DNrH`uq5lX&;{|2>e~Fc}KS{R39a-79PP)uiB*`Pz&v8T1=a8@!;|L1KC{0D^&!s za$wK3-t6B_4wh7L0JSQ;d$dsMG_{cw&FL=c0`G8ukCw3t0wR9rvNS$M}b#Rj6) zo_BpUT%4NKH>Ke1R0V-rYCq0}OcH^e_B%QRK4bYl`?Au}6Ii~LjTtfSER&a4YFlbr z(=`#|_0ZWVyI1HO>oOuj?Z*V~iK6DvTaKLZ-W(;^bNnS*AMsJ5{F0%8h4n@X`$l{{ z%jD~o)>4I7nztt-R|F174ptD@C)Gye>N94)kL8>fwhy&Z1O7;rwG$S^FOXFPztvSc zyx7H%6%vyCVhpObKpFQDxyw5#1K*Mv_Y!eJ4>qutNg&lykO<1aCq%m!Ev>aK<(;?P zy(~xmF@{9E@(Gd%SzXSqc=L99lx^eIF3BPnn$(|d9M#piXoXs50xKa@4>;s}Qelby zQ*K#fNTgD8;N_D#M)dTB>|MhT9;fax^VqP-N&UlyDG1cs=26Xv=oQP-9rEkkZekv5 z)%7h&r?z$yR~uF|LZ+=`{kxT6s3r5Czg92Mv}*Z3nv2KzQbiM_18)Wh(LRkQT1Vjb*=1~z(3|E zDis89CKBWOt=<7u9^N?0x_)4S!*%;ErB#SPdzzLS8fg{Vy2){>Lzfol1B9`%M zZ6yMVB02g*SY_S1IJ&M1RuC8`rA9jC+&981`DM5xEl-k(zl-$%Fbkh5@NXMxZTcx! zQkP9<6a;GZ*&kp$U$cx|{w%ALP6_U5bxz;!nC|{rsRw{sh^EoqUBlY(vChF+?dn?j zwTecI!znXoF$`zSRlPwDzSM zSQpO3B#+uxTB#s_2&^WcY3|MPSk)*0npF5yCu_>QYQ~zTu`KabNdsGmS@;y!*RHgc z)PI)a!`KE&Jpe?gt&%rQ_53m;*`jrVneenO?DZjMp=u-YK!&MTY3}@3Yp7xa5!hNy zo4+{G++2Ob!4vN(n*(z7aVDZ|O`Dx>vl+j5xnp|yF-oftf%Y`*Zs$s7*0Pa~5?vyd z%zMnVLEBWzqDGk2uVNEcZ2dBx(9F+>Y7xVHLUe{}5LcI`O=~g2s-OEa3;1U@$2@vO ztB4=hjkDTpu5K2HVjT0RF$+8GKwe|^un&4v+7x=X-<#RM$}%{K6<8IZY#;(hgd$R!kF?H|d&|n!_*oe%MBup5 z*ExQOe%{95cMdkZ)&P0Z`K_7&<#`PcwH9YD$O~M~CQyP190j`T z_kB@oWb+REbW%f6WJL~RSIl|++2C0WM;9~lHEm9x=GNfTfqY)CLIQsaZQvM^jplw< z&~!I(`f_d|Z?w+5k&R8S;cDdAm#%N|p2bufdb2E6t%`r}oimuyD%3(WML^WuZN6#U zUpT*C#);Ln@4}YnG2xeU858bg=mqP~WN4dw(W!r%>UB?JR`m>hL(XY5uNJ!$F@5iZiRFhUbJW5X(#eq>jW17NG_7^^%GRO1 z0bD&`YeyMTnhMlg%dWxG2B3 zfOs{&f^p)`8GT6I2@G4J+IaWyxLG?kNc{AptAaqStDAJ=<CG9v**bYzyK1@~k}d3rJGkCNwfGY!)<@l3mRA`c{3Ur;d^PKA0f_L7-Ms|B=R~@G|UnAy4W@PFGLsOc{-by=);+f(U$?bRO{em3gJyCbope z3K6KKW_nvGS*&io3h}skT?9&Sq}8#?ZyYzbN1kP)_77E_Anxw*X;Q4gw$0|)MT^QPs3z~vjr1zF`Q1q~U=_%*ZmeZFZ%i9LVmcMIkr8!x@;SPwds=_N~Y;&n>&j+?EWs$cwbSNtpvazk6pR<6S6gvBt+x^S0fuZZyASd&*oF zjj}NR9Q%Z$plLo0^H?QMhVZiQ+gSL!i(WM`CiNT7hEDhWKQ`ja9W(DP$jV>V9BQF; zM4)ZDQK!dtbFA;*?3b}4l~y4FTPSa$DPT29O3;(vw6|6do@LaZbYAyclf%Ff#c`vP zqIC_dN{vH#SP55c^hhu`#*ECS=9!(@qrscEtDVv z=K|ej`l_&1;2h&MKQ~oo2_o>x(mG$3+j_ChP3-RNXDzrm)Ao?1x4P1MmeFi!x}GP< z)%ZVS^^RH3F5l&i(==stQ47(;jPljyz@dG`vOnWZ%=+BdW~R|UBSUYNk;_2anpP~T zinVZ{n|QVMxrq`)s5YwBnq_7?@Vyul?Q5bI+9N;L_IKrXoo=bTUDRsr>kz~EX4C(t zsm9HE8M;RWFQw=BOXLBpO*N|omK51aRk5r|uEyz3>H6-1G$Wf9F&@U#|KCqF$TPM* zWce+oQ=HU%)5{YO5UPiF>;Aw_>Pjewp;!N2KFz1=MhMJwTd`R9>Jm(CT~&0=1f4 ztYC!Dh=jiyPlzgG)69Jbh3Ik9)j|m(h<%4j2*bwjH&?m@i_uH+DG1cU=TFg}QGc2p z938}iT0fik4Dbn(XABht8m+x$vbeS)ZiLQ8X1$LQ@n3YN>PggFe%oaPNC@sY`KXec+nFK2i0(rF)IC62F3p2Ij5d0tN&FpHkx?+x==o$FuWb9&daP=W|~ z^R&EynxYX2QD@g#Gm$*x*5oN;53%QnrqM0-$_&Y!L!?ZsV<7^yaQ$oAhkXfb=%L%@ z&ky49u)O=lhu5LM_b5RS{_`D_e7$ zm*zF+)>Zm}PYwG-XUNN&Sl1uCVk^t|C}V|M>gayy(cIE+X6N1yiYN%wdRfrZa1ATN zd`gd`e&o8@!s<9!;|Gg-S|~vTK21&QTDi9sykjeC6!+Rh1ZrV^wWd}5r-SwUc?mwK zf~$oR%zakJ>cILj)|(%?vPqxTDd+9@1o6pgTEkWmqWz7Uj-fRgn3z#DZlpg8B_{6I zn{^D#)FgIkVR=6G_B2PGagj>)a_XZxtk05#Y2UR7+maskD;4OgtCM zWa5-eX39mb(n9=`y45k}yP}E>L||)eCQcl3v6(m}7nv9amdV5^nanm5Cu*FwnK%=3 zoDqTcG;P6$5u#_?(}8Q}Cn~wfI1@3;m_~Q=DAE7%1jotsyOdTT0`1vMoQMtBESWeZ z7a3>1I@^hf<4arEOq>!`OglOznYf89{Tb#&(s{|i2_igWW#IW79XV$Do=x&+9|tUA zkAAJqXbpagZJ>I6Eh9us`IZikkOm3@wJe^j`hvF1!|!U%+96g&r%IUrL!@~i}#gQa*+{%dC8g< zH!6=PTIF0)-o2f~*{Rjo@z7Wn@~$Mq7Ggdx#j3R}Ewc2O<={UwP%?`Vp|*;cIKKO% z&BQ5LzmlPr>qE{$)kb>2S-$H;6K96cP{jr!u(dW5$FJnDnK&iu7iS{o^V&=t@3+)u z;*?e)0`1vM98arcGjU1|FUHfWv%N>_FtIqc5leX$YeqNsV;(JIShj|`fol*~7e&Mm zbc!pR3-SvKDp;5;iMf{b>ABG6$$ZR~o{kVzN5rMWf?P;NrBG1KlY@)tH zX%!-{h4dY4-%>OUY{jN@^|E5tnb6orN^JdtZ6-24 zI(W@3&Q`yxh^KmL0MXcq% zMn#P2b7!*0F&TQx3AqUIaX~8oz06>7`bJw5bH*^2Ohphs$R$6Rba=9fEx}QsbCPfQ zs+99nbSLlYR!^+`zL@ob6=B_;pV7w$OjDkR+K&&NO+MjB9WnM%5e0!-ts*M0vBY4F zv9qw_d=q%-2SFmaQhp01h)`Sgc=2I=FndGsq;?Jkfm*QvH7ubxt)c}z~4=L2KWRqf0%NS_b2k= z9XpD?&95j3)WVULkpmC-g_EJe{ig*cTE|x5x}oY9NzZuJMxi3S!fXYBTIz}tFCtl^ zzb1+&f7~%}zT)b_KG}>S%VZP{98r9N*h1P(xHyF0z-*%N=|^U%4>MV>qUZFsj+_R* zCDc}V-*pqNoeK$G($n(oJ&SeReO_<3JBNb6U(&SS68G^9c{+*0+pa1$@OSaOLh;l^ zyIXOK&zPStY%q!QVwwD)l9z}%e^i;GK~?djeHA`yWNjrolg>t1CMRfAJ~f4H?fy}p zI4dvB#FNvySbtxAYxXd`IX(jpnE|ryM zjtKlEO{-bcPyFcloW%!+Sm*1FX9fDXv-;->D$hV2-49*7M8E#I_{l5H6a;Ej%{`2* zx>SkP*Nf2THfdi%xU}%+TlUnqP=W}2nwqvN(ctxo8$Z8$xPm|}H7~V8n!$bkUCid? z8?HP-)WWP*D#@|?6VDdx&zmN9RS>A9K93RA%ZiwfUHSTd0xkSqoH5ux%1f<#f`?@g zYq@f$G86G61ZO2>L9Cp~tNB)C(+>S0(7Q*!{lzrKa=x3xW>D^grfoWOf&V)*D{r1{ zkU+^lgX^#r-EG#A;V;oC&Of!N-dRoFa7a~wnT*7Leap3+*l!vuSN^5`XnSt5G5J+3 zF|w&IcR5&1$>&4_W^EFqIQ}aCZt`EOaF&4rvsaru=*+&{?966I`q+2Y{4c-Od&eF= zagoNK6d$2v2UwvZ6_Jw?5M4y=B!R?1l6h~l`>>ENLZe&>hb zY<=<&f!U8b-fftSBl|M`FFwkpu9&zpn?M9=71d`kw;t#8tUtLBB4Sz* zG5Tj+yeqMVqZZmAU)uJm<+qr2UdCj;g#6vs$P^S`DtVh5Ovhapr$f1EHB1Zr6=C$Jk7N9JfDW4!L1Df!`(1*!Ga6KScZ5 zJ~w&Z2V6AnoJAl4wLad~ZGZmpdKo)t7JkGh6s;!;^m7p?K?LzswudbJFHKv$?h^lg zP#sZpsGEX7t*!kl*q;8>KzUQsv;Jwk@NpsDUT_sCK?HGLkc+&#>^?rPQ3K&SCy#qBR+HzPwH*v_zds~Voof@umoP>-i}T}f3sgfC|WSmnpR``F#c&+e-XFv z1V`)GDm4T2&x8GVwZ?<&cu_>4mbyOPzqrFnk%xTYWhBG-it7>kWHX9PGKvr1{KjxZ z@d;uJ=?yodpg6M0Lo})BsjOP!?wE|ZBSw*3o}GjFksmy1*erU2y-Vg7*(u-mgMVFt z5=0O;#WER1LJ*_KGZ{sWJ;a_PnksD6e#JdrW)o|I>k34m7QSQX38rbrk-`tn&w2V= z#6j6!vGkNMixM*fH7#3488Px;Yo6-o6f=$#WRAM-tXYTg?BD5mS+;_obUkH?f~ND` zuC>MfZMpf2_SKc1qYXq;&d-LC!rS))D{^Hk$J1<|?tUy|{9IN!i0X7eCZeVdKiNSn z?cb80I-Fgg1QE}A`O^J0k*w6q#x%Nz(g%toH}A3iA1)~whnS=IYNZb=)gzS6UM^$k zM-*-&4hFQ~hl}SGsD<`a#I1Ji#En`Zd{Ia)ff7Wht$J{;w0Q8@!KXK6%2?sdz;UB^ zh5}7QmEy12lnfu`d7u_PS;{wP*-Xp`aOKxZdMgOjigWd3!-tn)Zq-Lp&s(l-Bs#Vz z%-8q#6(~W3I!kI~=_RJ_-@}T|e8&-iT9~6s_pv1 z`hI%CZg#z?a1;27wJ#^ zWai`ow*pLGI**lbMhVoy`@Pj%%TLuF^YY#1anT>6# zbwF=dGn#DFGqZ}Eu`zsCY8!!XI7H(+pW;i3XA`wnMf0TXZIpLAB1%86WZUkvaU$8s z?&D7LDw0d;k7TTkCVy!B|#f;`Yx~-W#i7x> ze93GghMu)-m3$WHImzcB|DKzk3)Ljw+E`qC7tNpdb`vN;1olbOn(LwBYVus3yYdtM z`O6$uF5iCL`98bi^`UJ!x?Sv9vUA2GjuP|}(SM{GjlZ-JcNfR-1>3VK{Xi{jA^E%S z28%j;8oz&LGjBa`F{}GBK|ftRL2s2bk8SfhqHCTv^*wjy+dV(j>Ned?-g9+7G5X{* z9&_|CM+qXZf0W~Xte<#rbQ+(X|FBsxD4MN(lb~06lFe95HcogR(R&rXnF_y@O7Fzi zgT+*R8b5e{vxyQ!EbFtFP0G1nH@YPd;={Nu;_b3|{7k<4%2_ztE4Y0+`!qdXzYyVV zP)_ya*5XaeXkOVhhlOXBsP*;9(wZSe^==_z+J`xOUw?N6fm(7U$yq{ch7c!THxifY z#qepHR0L|_nXIP0i76@07K-K{I|N$AK5mD3G5hrYHQV+v`n*2CQzK{N49+Ylvo5D{`dNi>Zm^?d1S zcWe9P7@qT4BL#t47#m3ydk6bl72Cvc_rD!V3?w2j{*tm%PWoHznnd$nzdDr2OhnLp zvs?AwOnlR}igo$R9O_3E1%X;pQ}#ptP19yfsA7F_jp6IxR8|nEg|Vtseea^f>eC^b zPd(vpp#%}*XpE!{*;1->o6l**`Of7#d{qQ$Vf-bXK-TGI{jqgA-*)((xw2Wb{acT! z=P(ko&oTZOctl?|FuTz*V!Cm>eS+?|CEq2vTeP;W<%r?o<#SkaZc6JkBNZ)->y_0h zL#&eXqWR)bcLjl3*h0EzG)st8cUTNRG1J{b2_n>fI8XGkem*;cAGw}pV%xD#8T;oO zEoQ~*yJp^^x9IHVKUkA5&Ec+EP1Yc2h0%3)ynesS@`E|uR@i;2I4jY4VeK-b--IK2 z4ZjR0)gx>3gH^Ba6kaFnB2%=cFEetsi`SE0)zei34{B}Yo*Ba%6wM*dZ=P<1-AK@v zR`+IK$4&Ytn3YCfocf(Xo&qupJr5X*ODG%wN3U7!RJawTQX5=MKq zKOI)H^U-`|tiOUlEzJ4TG`TLmr}eS+q`yE3B4AdM4Y@A<`EPwF2-JeLLWt7$D_cjZ z$MAELt0)N6!fZ#%CeK25-Frs!%kx@`(Z72eVrIM^5Ivn?-ld$Gvgb4!WTU|Ximk05 z%{@0a7AQePz(Z~fn-Z@VIxlw<<3&lU$famrJSty&Y~#;UC=%NO;H z6uxtd)XViSL4RCn1w)T~&*o(3uCejD&xPe=W9+`FEcH}T{^)!p`PBcDAVT%qN4|Z) zs=f5#Pbsny5vbMNSzCYGIYCdHzKm>Kt@xC=G%L!ZyF>_-AmU`YLyv#GNRK`&TQzCB zgU<-^kJe??nC1&&2@zUv z3ZGWJB=_nSrXWzOeL#X=&mQpXSxWM1alI7; zYJK%367WcTynxvw-=5U538GusBeUkGuv zpQm^*#)IGbt+_x6BJSwHc0XhtBTdUwwzN1x{cw9E6a;F~YPN0sx4RqSQ^xMc$U#Ec zM(?150jee7w zi9Cru^iC|IAW%!bJ!Ltb*5Q)dreZJN4+L@jl%7M0pdWPMYNuW0&I*`*?4 zAms$e=g~Z;LKS;ET^x=_}(iH@1!5)rwki`lP7B$-z;+u1wR(6nxn00Rfdl0=? z4;pZX5S@}oiXS%=<~3XWq##fWcdwe(ks|0K8~X5?(VLWgEh0uAT4K-Ldv*4cjifqZ zV(<^ed82F76a;Er4PVMOgvIN%=FcKT{1w!0 z$nMSa^axZCsFmx7Wy~0ypcgG%Pa+P7i#TUdexQkiqXZEpt}nOWi96zJ6QXO=2ywWh zFL(a@Cqo2kSx1($5nU4WT~W#8!$mKM5Tnj}^W=3m8A=f0UUDVt{_i4vbk==@7{51C z+`Z?^yVeg-5U5o@cm->_I$r;G%Jh1a?esEIT%7IA|CvP*10X@fqbVy{-Wp5ve!owm zTL-r6ixefCe%v>|kAVo(s`1NmX0Zglf8_%y^gOObieGyd%QLH^3=CACxOgr#eQxaQ)^51+bYpac;gQLP)j{8Np2b@mYpibd!L?Wq687rSCBq~Yt2Nm@w&rE;aAz0A6@;E zf^e4?Ba9D-QpqvyTLam<&K7)clt*^d>+$m9(>FXMz zQT~E_(#V?1i5DWWUv0}mTQAlx-|kC@Tblz!_cRZFt$VP7KrQv$CvQe6u_jk>p6#hn z&fyRdln})BU0bC8v~e`qXt}xsd3Hv5+kITb^g-@?@tID_`4u9r{ZQ1N zyL+z8AVhG;L*Dn0JO8s{9|eJ0c=|a+r+!QD8k#cA8DZk5z<$X??n2>6Jk}( zsl0OTV*JyrFa?2Hg@?`7FAq)71G_CH#OHiddCiSQx$B}Z3nhq%Xgb^8bLz>lgm{}z z=ic;*;C3W?CR}%E#@0Stcak$PuelE_BP6@YAf`|qEYTJ9xL06U&Vnp-< zR?MR$@8A=mAW*Aa%zo#9Iq`Zf8!`RrH8wlSi&xkgVW9*O>wijcHleXP=f9F{tlX5J z?LX|zPri&)5U5q#@C^&m zfNKH#$HEZ`0=2sPI8wbb_Uom(%I6XF`y=B_8GjyLHo`&)BF5&b_NBE8Q!*{X;k z4)ew=A3kAQxPm~fr;&3~%>nWH@E&qrjn|o3E6I!3*%@x31Q7vEqEmzWFVY{bm-DK{ z$?4|k`$hPbR$&SPwYD@#NZmdsLBC+Hk3HL_nL(S1^2c4mER-OkC_#i;75HBDC+2|x{`~F5UJ3%W z@T{M>yKc|SJ-dqXN>h6&EC3?5cos9{S;pw?^1ZuiXb!9Ajbi-BhfWFtwWLSzzwbx8 z#yPF&72SBg4xKENAmV6x6=N~^?Yr+zA{%Ficv^oCEW`(=H&+m-C4B{HL!M(&?KacX z@@U}Bf7sI8LJ1AIsjTTj|~@T*N#ZUVJb zHsi(UPL?CAC|}dZ#X<=p0$=kVds%(2 zc=8&HpD7#*A}W@RF-`<8(erhXetYSigDs<0Vg9*5nu0(rmCe`@I@nq|z>5$0oMxf~ z5r>{HFyuRNZ}~g4{}|vi()#n67e7{Vvw}daXRQ_+>&GSNTej~fL}6{DwIfSWo;+i- zi4sIiJhH@)@5B~wB~#tHRG3xawl9D7W}1RPt*1?w8lxsB=(%Rkkv2lZti2n3c)H&- z6D5e?2V)KS9gOvtJIHNS!mZMxB!9hACeb$MP9o_okM1%X-@fH9-GwL>GUm_uHC-KHlDC5YH{V!4q(@5FU&9XjQIkbkHx zpt+kPh$8~EE_7OIgoGvNy%ObV)EGi+`?n}BCuE})HKZc>#+~7 zK7N{lK&`Re78|GUCg|nU5(sgwVVJcjp%{;PJB_0R5jkR(81ihX$N5A;U+rZgxOFADErdifT5KF1Z9vTXw&bwfwsh{Fjm{mx@|nzY^pj z5aLUYPL|xCZzqldC5RwS!v5CIBaVVjYp*o7Ci!}B$x$EzwUG0qidoy5TQy_c`0UG` z0wsu;(0;USV^;BCvQa2!kQGhe^D{xE6$EO*SDO$aje@KV6+9^}qm)1iA})0AYe>JH zwU)d=Y1f)o!FEM>ZC9IV`04>L)B;~Z+L&0w3XgT;&7Rn7hK(Rw3gXnew)Xq#`6bDJ zjH^=0nlRd(d%h16_*FgRyO3>(vtC#kNS5*4y*o?g1lvu&I$sx zkX@$O!B`jT-as#|5ALM!&WLFI%-a~-aRFZJaw7W7gOsMX-F>_!9H zakrs(If_Gm^28jJy#u3e?=idSYkE4<29ei&?LRA20QCkb*$1HlBZ^ z&OdlS|D&|zIa_Q^Gh1Eo;pO$g0wsuexc0Wa#}V%&&-q8O?dF}nrFrpt!xaQ-MX%kT z+T(SCey{0bLTqlZ#e6lTIR7qrgg^-*KAuTPmEVshhb3QIZ|Zb2^i)ZH_rdoH0=0-? zNbR#SUe7*Q^3MB5O)dy1YW_9JpPSe28&&`H<6~+?C3eTXD#@X#l*`)Q=sUP7X-&u** zy;<}hX&fbpz$*5d_UiRnbM3?H#vknmD>sIOWUXQBDYKNV_b92{9zruQx~}EXcqX6H zrL;gTta^&Khfs|a=5M`?n96_Jty=+JC5`>xFJ-%auWsNi4^&|V?rAuk-`fgK<$O(S zV+*VKQ_kSG8!E5@Dpm4~ePZ^UQHX8y?x!G7OTBx*v+ei1bM{`wqmXcGPE<`6Flz~0 zJhz0=pl@~7NQ-5)LQ5K$6G`$;r_dMu3 z(RfW!w*O@r=>M*6ryx)ZD?-s3neZ2d#!uz#cj>~_t0bF8Pw?~a)frX?!Wuc`0l0XJ ziyJsEx2CbcDlnM!iP@l3M|j3#zTs?M7BHZnfs1^-Y~n&isR#MbWuhZ$F1PmwAAh)3t5M=vcBl#H9f7nyI1hZiD~Bh z_+Cc7B|oxDE#p&5yzXpt8Sx`K8gMxk>-~~XJ->tXba))U*y4qWHFr@9(G-`_sGD`5 z*>c|KtjHaZs*BxD~9N*358^8Ri##usbTs#$_`_)o9S}~uO^Wr~0RnAM$ z2G)M1ejIOY74Nj1|COLAXELaTk)U*&N0lsA^M|qgXsQ~AjIqa96O@uOM?Epqy2bJV zbJP<^)WYXayQ1dz&8P13c;Lh?%9$l<;R!444XU@c(lb`_*Y`E6Tiw>i+inZln5ZwQ zSm_v_6WvT&t%H?qNgOXQ`K5^n)KZ^d#UV|t`t?_F*RHvg`qWt67=KCAMwM)36}`QZ z_g|D%X%!-{g;YzgU^DBf8OPmcxmww`G%y;Cr#D%On^I?6wXydOvn~`{!^?i>W#P9E zt61YN(S4B3xmER_Rou6xmtq4ESPz@xUT+3lrwhgLGQNcr1Zt^uo$v1u7Tui0-?YvEV@-$ z*++HvzhfF7;&`#yjOC!uDz7K8ro11m+wC?f1SF3$7~B#c93X8-3h1fdSbTe5X*la zukM5qfj5az&CE)gb=_|{zZTF&*$JZ-Mq^X7$QXBPa)ns_Ly?Bc831bGtt*;#{iLt8 zebQ3itdB!Mpq3i5y>@YH%VT0JAG|1w5~q(5_;~9Iox{})wI_nDck7S7IdYxQ5V`SG100wsvRQP;G=tdMnb%35yR z3KFP=cZR8m>+=g*p3$rL(kTrDYGFTA#EILvM2S5s`Jt{&mCRzyCSKk!Cu@6VAzSIu zjQUafft!d+BO9+82$Uc~&D;KGqld^EwT^rH2P+8F!n|b4$PIM0Th)F-Go@8%4>PB! z0?qGN`M}hr{9HsgC0ZRb`mwSARqJV1ia)(|i7$QGU&$uNYw1Q7x~nAar~9h zO*oGIqTlzX%FH`@D76em4SAyPEWeZmd}>Q$bs#mXSU6}IPye}U07ABtsV8^hK}gx4|}n#F+Z|R?f235eC?k- zglor@yx5@A93_ZQZA3JsT>bg6{Dp5CM+vq{t&7wDbkG0S)pf_`lzjb1?+HOj)Cdw1 zBtqJ~GovPY7tw+sLZU_VNJyfyTC|5v5S{2f&%JlCtX@`U_0_x8>wE6A*6%y-U*dDl z_nABG%&9ZvFImdlpE)EUkW0=Bud=PH$a8BUpEBX1l;w>GtV_07s*UL&?$5RI(Uo5E zD~T7V= zT@9ol0)1!HMu&!U7ps*r8lB`e;CJ!HH1a1lZY~P6S;B{1%PZZ)hV7MZ*P7i@X0!7B zzT3!lZr4m4y1#_)Et6kDAeVd-+t6bjM6Vl5_`@2{^?W)Q7Vwp;FG!ioc&j7c-AKN^ z20a8zvhvAw4@n5*!kT)@hdUN77Ai~lJnbcyR(6}4dX4u zeXbu>TZGMKbCXRk&glU;ME7Y+__OR*#`@(C)m44dSWw0G+P&TRS*Mb-S=i#H8s0BV z=Xs?sywpA`fBUM9gg`F5Q<`+2D{uL;aaQhrwVi}OF8o@HrB&67d``h7Tnp|hA&^V% zhuet5+;+^$uRM=1kb*tOd|0}P?f50WX3heB#@bc7c^L2el@SHsJ>V@T|H9v$3YRiv z5rOsYq?g<;CI-eY;D?4blJbzT#(jZw(`}Jy_3dpPEFq8!?*^ba=bP4|WV;1? zhb4!U-;Vbv;Qgr;;3SxAQz(LyQ6zPC@B8$Uc&#J*+fDhS8ur8n$B~cQv}2= z;S1GP5(2q!RZ=}pa#>OHZaTlxHQ2!K#pr^3mv^0Iu2coFgzw5wB?NL|pDdQizYjGk z)LP0L27G2MI>l&TW7638Kif!^*O(JRw=~}#Y^+hZ}=6$+GNgfvTcwnrU zGG#kW)^q1x!e{g!AR&-TEFEF<%Qc(5-xx!h{=uj5&>TD>V==E>Bu+vg*X5MTYK38GtiVKbKGH{DVbSDoKbRCJ zkb(%TeYaR%gdOIi?^$`}R}qp`g!Uj>mZUS>H~5XsR$eGluEItxtQn`+!~q+Re{AI? zbM}&I!x2%c)C%=uWEzVu)SLRz<@GXN^Y6vnX=0RwKrXC(r_ANHv-sA(to)AXFV(ms z;=}i$YFzI$_D`4o)W*~W2|VH1Vm{?cjD$cgtl_6S=<>DZkB?ia25XR1tB(jAb+U2B z`SJJ{i+R=WgCzuV$@5WlXq51=S$Vbj%OracExjFCQLJabH1;a@43de1K1YcMgROj* z_A5sUBIIh>eYInS2qH~p{s;+yTvM0FuoBnPSjQ=$)W+8jF(RUmm7iHOnj-}fSQAV4 zExi~bo=#iL53j5sA&{%S=U`Th)^6g&SCln5c+X(5@wSzBtKiR(f(Wc|wOIIqIC1vT zV(#4fIzt3RT7w62B5Xe>W!Vp#{mUPtAs`ORIz6=vPXeIibD9(_A z2&`8nogrtuIDE~@PyhaMA0m*;`{EF`wR0L9)zO#oUiW_+CQ8#v)S5eJNI?WvxRT$e z{4nv=u=2mEuaywURUzMC)^$u8tCRA9+IX=cPF(!HnBQu4Q9}wMuwK<-**krRxVq8G z2hORiBLcZ%+Q+bcm($prCM|6==Vu3tsm~YlTCD?gq#yz-Tqz^|S&WDy4WPl8F%kl~ z&bEkR{YR63xAQO5#*oJ`!nu`|ci1>aM+zda!j=3!xue7#ZsoZ~Wl9L-+M3>jC6A?* z_;@?D@%xh~5pmhd-);Lr|f=yW{H)btkhGoKhT1}S~)rgFK6?l&*c5^=_MhMtMut*YOeS+){gboOkdw} z?s|PO_n`M{k%9=Um7}`5i8Hy=RkArJ^_LLHmDqQ%IMj8omP9o>Vecoh1dT4Mjk6aadxnQ6hvUH9M#vZ z`oIQcTKQ^+I0=DV_fylAu2)Dh6f`SeT-YUc?#*J}ZGD`96hvUH9K}C27iHFlR=(-Q zFbRQNMatFLGnaIqr6tWO)5j}H{i`!P9A+Q|5m+ln?|)p%t%V3HcYZG;kP9o{=%$dS zZd%`&BonX52;{0Cn~}YZB(-aR+4GLeE^4LSDPpoJ&Oizxu=0(fGI4LT?j(16xylIS z>e{2SZStHn=03`t!EHsl=}!jJDtSAQUIYIh1rbB|a9gk0vshN1SV9ylqv|#1kQUx< zuz?grV67b4wI9an>SQZ_d@M$)iNk6){3UuTrNm@CilX6(Lk37~Afjc7<+kpZ(%9{< zX6@2zYlhzI)?(g$f0Tg~M96tE?{?et*)_->{LoWEAeWpcGj;3ly5CAGw{`AeAO#Va zmqWQYn{VjPpDy8bUxZ5tHMq4K>>TCca*V_Y)%%kTtXn%`?*nCa?dn2{rC)$ z?Ga6*j8;9Z{K%Q*I#LjUl|6KWz^`#e!k>%zD5FDfn0LWx2_#x%o(?c5BgF~@&TO5Jy=^p5ezNy{r|fk zZm`2Jqfg}}{D@e~@E$lsr0GMn^6S&sh_h~_U)|3WZ)`rbnD_K6C?SvwD|<-#92jOq z^jpG5I5_6 zuIn!TRFah?Z$8EZ0^7TBuoDMHcQ{wvL$=c zVYrdEPOR-qwG4KT3+h`@nX5X(scV(pp2V<4qsh+w1zUH6gq`^BcZi zt=YM#a0!82WBu}IhjV1Gs&5(*VoX3)LoM5$O{~|zD6_#`%fD#_(*wP=EUkn#J0hLU zNej~Gok;D+Ym$WM`$aJ(KH9(*4)K-{$mP1UwwByBomI#J|9bqXhwda zts9fcDtQ*6Hk>~7H^v=o#jaOArO%w(O*>dDgB2XV)mG(od+m>rnQZ3pi?-&gJJ`{& zgD%oOkAKqLm@~H_OLh9DBLxvRn>5o#*wWdx{GSN1pi)QUP0mnuuR#$5zXZADuPVK; zs!^guM>gbXePQt|srdw?v*?CFtZHm!?YAbG?CCHsc5ZFa~%ofXr~Hqx9U#KRJeSluyS^$+=*vm9U2 znHuxn=6kQ^`Bykb`eg);tdobgAEmCB->>1^}kSncz|ifR$}Om?Np04?NeMf+EU z?lq;)u9RLpbYLiJ+;6mj6hySt2C40xGuX{aQG`%qtBVy&J?VznU}F3jJjijiA>qQ`b`p7qX0%Av=+d3io2->M_dJgdUpI{Qd3LgQD- zh(*!$gnP;iZu!MUdUYDJ{4pL$nf(FHMWHGwylR~y21fgki%tV(o(yyv&`NwfCsH)? zNan*g{2|$cXanw>5oSlc4!=x@xUZc=o2e=MLEBe4QV=1xaqUAVG1w`EKV0!jM+zeF zYe`;h?<9i8j^wjq-|%|auNc$<=6Udcj?@xJK?JsEv6P-yQ}}%h<5jx*NfjelC4y3$YBO?H7i;6H^O*~S z1yT@!Qky(??>dR=mA2@9rO)YDiGmd>ayHewC2k_sYc?#w3y_}M3Nx{zo$(sTvT!N=6iQuSfpi6z==1j6jHOeJ3*+eU6y9V;L8)Q{x#(oDBk26l$ToQ2!gxdbquY8Jfn0L^$A?&Ed{<}E zyBXyStdPX&5k%8Vv0n9zNbeM$|3Mi8DTt8Uh&23+zDIoN)mt9}Z6^ndmqv9VY|f0T z8rnYO9c)n7$XZjCKYmn8vWgIaUrTjG_i7rIZg%9qyO9nDRi;>3ik2SLc~!4&ygOTs z$0h_DNI?XCEoC~LtY9>In#}c2s)03p*dEsWQT_R(5=NJQQuvGGjis76oG6K1<+Kpt%nH&0d_Z0r? zyKYi#A0j}aQX9sb8+v=XkK@`c8G&5r1EkD>p*!@HcgcLmuqesbhY0jcQsigVd_9(S z?U=t~Brhj&!C69I)g?VtKQ@J4iD?%nA&?8bxRe_;dIw*7YdmkM$&m*7s{c+)4YRDs zhg|+HuX7=Ve>u=qie?}J^UEmOXZy-uY)Gc+%nnla4sy|1!hW1S)Q`bUii;`LQ~CGJ zArb<)Fsq03_KL+t@5mG$|2#y>-a!QDDAdNujc#H<%M^Z~jUpkCOSW;2##aP_Mw zFUkny!u%qOrR$hRqFgQ~{^END$tps-rs8hF%702{ZC4kjuS$Q~K!krQ%P(*@34vT_ zy^s{`(@wNLLHDYq7nZU=5rK9)Mcli$7e(ur+J+ zKrW0MQmwx8U~#5OSw1p(m1Nx^V#MFQ*}1fIR^HpZ)2Y70Ady_53NKM_orFLxjDJ$a z`mJ~oyPym|@n@nG=S0MlDbdWUat5ochEN+>pW;R4TvzUOeVl|qE{x;Sdn{*0i_+Um z@~v&FO7U7mIKGKtOHZV;rzM_|CpGQEC~-QuA`hzKBO#Crqs#P$f|4L|A1%WZ_GU1*wf=nPbK6Q&sr(9NQ!Wa$#JZ?qSI}PE4UY;vXg++Jh8CjGH!) zX$3P_??+}HfYav$(f(vv-m+(xeTYCVjP{%F4E737iM_SE0wF`MfF_tB0@s4l*xgJz&yQKiACw`-r1SdM)mvg;_`*c^t#F@ z34vUgsbaA_j2R?4_?P9bt=35yDTs&->c-~2N@qpr)ew52d*L9#^Hk*X;?_tAc>!Zxy-*!j_s zM-=}g`7#hu>v2=|eNH;7+sljsWaQ{BBAS=xrJuf&5XgmI3XA2|+qU9WUnidL&msaT zh$#3EXBCTQuuJQTFnZPMd0Vl~(TUf-D?R?QOZH4GYUM5pEpX-C+trsm6X`?_Vd{#-1xKP z{t^PYWY5Ha%(?vRDHqg!>aRI7SP2nDZPc8R#`}MzJARK0l@Q2<*>vRV zE8UuFgUfPTZ*^sC+BP<;uohwwoMu7OpgnvHqw?E`r>S7UN=%kAQ$as@WXlRzpEekjoy4| z5^3(*pA>1#T>)!vj>e>mOH9ecXb?ccp61ahH&jz+h59wXPnQrsu9ljONW zL}Gj%&9__zn|YuSwXq~DuQ9f$6F+`2OhO-bb;FN^X+cb zPj=!v+++lDp@+?4`SLZ&$oWrEo;muAsma(qDW-Q25ta9L(_B3=*k2pBP#g2s_cd1bo+fzmvqtEZG%gNVR?Vzj0eGuVeNkH|KtU4FE2uV4jUw{lGhfn1nnL$?k!Nie!o zsp*~3D;ZJ{kz?jSt)F8CTh^;Awb9)r!T6(DS?;uFy@Wt6%$cEV-2CH=)6sOtZ}S1! zNI^ua^nqFsodyr)yJb^!>F)%i2ld1KOrKqdKrYO(q1+Gu1Y=llS6<-OY7HrfICU{b zbEeZ^%GlS`#+fP!M*d^vxV8B*34vU4_RibBqwV+B?%q*dM+zeTY9Fob^UPr4YEv8e zED3R~e;M97r%^wv{_^!ql(@$toZ_qvV-QV@X|H&mmzu&rVGmTs;sA|a3q{qPpc zOP9vRr4MC!uC(%!haM5Kzdbvoq4CuyNIn@i34vU)zkOtMW#dlHiu`JsdXm2#J?ZG> zChz&EQpTJf75F21y9OzUkTbR%UzRW$T`R^rd~YEkkPH3dbnB4FWfUmo#5-!?22v1# z8Di!=ce#zRG`ejEbdnIrg?@3mOK;JA{mJ*Td=9-`gA_zy-Wt6JS>SK|$*_{VO65oi zfn17q##1hHat= zBHy}tgoHpY*{6T5)+=po-Ew?g zt;TjwzRf#9$~{MfoauBe)?1r)q#|#%NJbzR=GI#*{bKysq(92@oof;dq#y!wrYL{+ zSRIyoXn9^~y^KIEtP!wShF#3ZU*vG*nJ30bwE&2~EGw#5PkP4^m1=zRLK%TvSnELV zJNIeAOEmD{C8v*+Y8ntBXKcM`*Ot3IF3(TJ$_V7bnhnyzFQxGddCKx3XNO9)8Hm_C zFGF2Ueuiq_&79N~d**T0xI8a;Uq&Do)}D~QyIJSm+~`GwE(4?*6GUKU9lfes@c<95 zT#}!7E?1=>7wzVsktXt4u6@XNu5#cf)n4YV{YgOt=HgK;z{eYW@cHWeY<#4IKrW0l zSuCN4bBbqm^q(M zdABwA(S>*eDTu(`MLvD66u!3J1a$!IT0|h%vu4fJ=_|9?qMXsxMzQiO`Q|nGvlpBm zV;}_)nAb_w0|#pHO+7lQr)wlg2;{S7m1hu6REsYh9!v5`-)o_~G1re#299ZU>EVj;O-b_(5ptRU@X^!&kTuTXo zT)0=rYtXiu$hq%@y7Vt0<-H;TGmL3;Lo10M&QopuUN<(7f{4xw%Coby+OMDUp|7f( z&{`b4m#pr;SJJ?}hI?@SvS60AA&cez+ldgHyZMQ#16!*fm-|a?;P2k)SA&VHEOxDN zB|^;5{l%Keqk|no{Uii(VGP_q##26sy?SeMSQa|NPSJvyQiJ$-h?a`{oF=v)Og%aJe<~k-@8+nB?NL|e3+`ROAHqGho`FU zS9j@1LBwIT7yE5h77N+3fZ90KCP6${>#mG&uAyU$w$6hAtbTeHo6@9+4dbfxy5sDj zV&BsqO8c?vq&Dz(F``Tsz56IJxAhhEP=_%(QV@Y{lfUH6a8YDkdUkcoFVa^bBK>9* zTS{xUMMw|o$ENCIMbfE3YTYr7B?NL|pXhB_=Oi)HXPIhvuhp=J$c52fdauZuC}OU) z+q*dVf`$}Ctj;lznUZ=*nRxOujC+zKF3mrr28LZ#5P@76?WH!JCy5##Ztp9Tw$FwX zMC=(kknN&XGVM!7FuedaI7zIG8>DWUn92}=T=m^z*^JCA=9=#$d2zyuCW$iJAVa?OqZBxiCIV z@!BV$B0Z~sdgx9CapjXAi(Qw+_J46@7$3ywA|alS5}h1}sm&4war`Z8BS*_VOr4*_ zYS&4juiCbJr0{DTqf9+FK|&xGMwF@E;MfpRbkPd6yYzIyRpLROIk%EY% zSsj^;#?ayK)6_X@rx}ikb;O3=jyRq)JE#of{Z*Y`J0NZ zA6lw4MwXWl$c3wiazB{czCA6I)vWij4u^W%p^xs_rrfR5!qoZ(`^RdL4d!MypPOVL z1rb;`LpQ#6*`yWC<)+?qkP*n$ApA}CKx-C@n;k=K)O~hGJD96owrgynffPibRZMp} zwe``LO`oLJsGJ}nkgG!f=C-ONwFl=ib-34g!}Vpyeh>bvjxmse2(<1g*LgsS&Mu8q zH}x4QA&{%pvl+Jk>7vE%SoO{o=Vc$Y#2ZLK1lARjA8z4By}MV4UEf6na<#6!-ByUC zsO6ohA>XdDM-Q9bVqf1IgAAk~0=+Gi2jF-?pBOS-U8F@z2;>@n;*M zXL|FUQ?rlVi8PRc2&{Ib+>abz_2-Q$t6Qf>NC@Q0?fJ!KuB$lKL^fiK!>9U+QmoJVAO#U}`T+6XS7_j;z2xR0Qx~08dZ@9bZ?qEczEMXCA}~%tw`RW` zZtV3)RPP?1DIt*S^xY_}Ij!BaLZ+U0=FCXr${!9&rP%R0QV@ah7P|AycdXHNNq4nR zzQz&)xxDWT(9$!=x=SlUZKU}n7^fz>C{y;;(vgA)j2lrF(v(D_c~W1stLGIB5y&+< z??BBwYfF4KHOsTPl8nUkxyp$8+cl&h0^?s4D|(n@cr8j+N5^fmAp*Jb)6E*qsEs2# zmy>20`aQ|`?$Je=GQER>6hvSgkK%)alZ-1Px~R>Si3}0Qb+kgP7DscwCDznMGu@Jm z9Dd`Kzk01@NI?X~BPnOwX`E4nWQof?7mf(zx>mfueeHg}Y3ibLhbI{2U0Ns`wp8Ot zK?KHCDfj%;D5Li8VXDWfArb<)sxpyv8_i$#Ql`#7%aMWzjMvj!zxR3?`5$uiz{EEa0=e#d zXs($$LrQB?PdxIhhf)3adrC>y4;(3oz`O(6QNCfu#`k^IR-T0gB9JTML_O_y8ms%+ zrk*(VeS4#Blch?Y>5c*^h`{^_x+m4GsnKdhkeboCyo5lmvIXjB>q&3F^2(Lk*y_^6 z*xliUvVUC#ffPhw9ti0SLxPPFgBz=V_p2o#kZVVks`jx;+F@4SuQ}~+G_G}O@7B70 z0x5{VDtwD&c?~Z^U3_0D9L^;Ka;4@grM+5DXH<&-YU4#2FGFefm*N%81yT@!88BqA zv~e|FOz5Y64sIbKkgIrBUd@^2W8vC{)W*C{rHrH7vXy+JS_z2)_YX*!JeEb*^BGm% zRa76o36~JawK)8-Et%#cExi@BaqvZ6>P~R}9yxQ5hk3b3{FtdvEl9~tfoNw!^2d>6S2;^!tZkeq! zjn%_hJ*W-$mAc+_$zA38{DA@~h`zNv?4%w0nQ<5(2r7s&Q(5>c`tG zb6vHG>%<){R#sQkA0v>02+YhSi~jLKe*W-5W%S5+34vUD$1PLMb1>IzbARMLKaW%7 zQvIX&aDfy=V1_YiOBD|A6Vq+V%GAoDl~W%Hfm}JdKUU52Bj8t4UahQqpVug;sCicP5lBG(^-7DGhJ2wZY7X{2+Y8zRdTnw2rZJX++M;Z1ajf6F%%zk_7Yk* zo6=;Uj6klg6|364QD^p>C#U0{V6pq``0Px#S^_DEkUd3Rw+0BG@k7*#jcQ9CNJOAF zk}@;*6yT>bLPRLt0a2!AoEm*0i|ydF4~AD#U!KpRo9p{C+-c^mDmlcL0}aK2Iw$!2 zn0m}aL{BToJ{|XCCL&~xE0e!!{PPOpTy%(7wBP_MW{a^S4sJQFT-*?CN1U4IucF73 zVt4)(#j(s#ar8t_8~Tk4EseD!#_o;JE>UD4OQE*Mx633=e|Uw$f;{KKxz=iZtA7YP zdzm7jRdQ<4YueZm%Y$y)Fit@^!dD%|!h4}~1IBb6DTo*|yR-dkccq#ZD$I)`|}6-<*4%Z-|9ZdM?*A5 zfym~J*u~Em3>DL|avKpwA^0n7ymfdbg)ScE#5}&hff!dx#h_sH8aRu*_!mqr7hq zmh0jn`<$N~+leY(P_ukmXRbbOWT=>Z*+J415i!Z*v3>QFt7V=aDNR>vh3RFkM3;(U zjALfDIp@cM=#qf=O7=EP>4SO~_58$L+rT-EM4c<{;_u~_+U8fel^Mtojan47QT(Ic z;8B1mI{%xFzCQHHjri=X?WJGZU(NK~72fnwADkQ@@>TsVA&_fIRC)WT-=*D5x7_Tm zZFs)$6BV)xNgfvTgt*2Iu(#2_odd~|-4V5nRm}s$x>W@v1af60M%mj)>lZ((JXcNs+_<}50CgwfY4f7SDmM|3~u zAi;JOlMu*7^3lFRJH2f}{qVW!Zk+BLEG7kvl{|Oolas${*?}s?xG{mE-=%R90=eeT z?rfjI0~5{D;MvYq?B?W9alT*$fv4ps$4tdsl`+Aw|L^4Nabp`BG$>G9s7QCDfL@CT z)SRihW_BFEHYHRH8C^mIT|A_k`@`IK5e~~#bG;6t{fN39od(v4ywLLiQQV=5q{ATs zHBOSEZ7y)1j6m_*<|4v*y)!d^cSFelRz4yxGbN`fS5WVw+z<5<)fohcgbNN50=ZC& zqIXQ*JmznmLWO6+zc}g%lP_0c=2+#7ulE0RhQY5aqTk*?(b4Z8j@kqwP$x3;01AuV zy8}eTyK9nOf@qAfQIza$VR1P-R6KmUkbAZ2$js5@MNg|u)0^8zLHqbx#aJQbI{TIt z9=?IX`RY236hvU;k1AQUQsPXXhGN_D6sZkFe;Ux;J_>i2nLe^>lPilNcm2hSjWan? z5P?xrlGN+UigGza#eotfIqFnPAM|E@uVk^rx5@uczw&6&NX#!(PJBzO&r$P21V(#F zrw;HDhe(#xz1Km~$WRNzc)i84Y)Ex6cv+yh@{DB3Up?(T96}PHP*RkAbXWC>uwiZk z)d_d<6pu*fU0t=lgg`FLpRicm?w1$GrUi)`gtxerQC{pGcUsaD zQ3J%B8OlQ*{el~d=$*##Cnc>9H9c8}E7ag4&)qpdgyymssP!Qi<`>buQ5%2fc~+?+ zwV#8e^`WMRSx@A-`xwp-(2k0mTFO98?{vBCcFFLHG%{K1%l$ou`#)5~pVQo>3_3($ z-kQZSbW9u@)IUh9yI#u}Q`g-t)t=pn)hY+SQA{aOsS>@*19^)U%L5+IW;+Flrp;$rmL$%vHYwT<^V1H^zrK@tMFF!n;xjQd_juo5H++T0~gAGLdo zz0l3Wg^L)A`-X}~eHL@n?z?vAsF`xN0G%bWroS$_j?uxzRb)5*O^O>J0%I2BznYoD zICwEoRGs-svJKE4K+iqpi%wmk*P)ZMcyc*O3rEcxeap0K59HS)l0(I_y=4Vz-mh!M z*>$Q@fkb@=^&ooxU~u7w^wi)EctM}A_ozsMNrC3)^pLqg4o zs%t8DpvwCIG2ujY34vUaAH}l2{5*c%N;f*5c9e8D)Z5BsvxTac(SX1|0$F7#4Sjr%Df=J}TstzM4fs7Il`gL)7}eujFAyMck?cGcbz0=ZCA zqp|X=Ay)n#EIO7rEqT09*TT$0I?uoOinNshV#@vih7{DPP}j0pW)|=jeaRmD8ZIM{ z3%zaRr`}9AolFf8S6WTe(36H9w(l_m?3$WGYtu^2^3 zBl0%$7u63J(2;^V9BP0TOUWUAqMw(a$g#Sxgg`EA+hS=ws)~r*7bsS}AE%=}h?*X1 zfK&w$<}PwQ3lVv5Dp`Ve?kD#1nt%QnmD{Ycy zi3s!*QdZcw>%9C*g}gX<4b%rw(?boAG_ta{c=((k(X>W>34vUwaZ=Ckr}LNop@1LkPF*3>j!=K>u!o@anMupa-xolRUQ;kEc%v}y5KJ^&hU|{ zIZzixO_pwHZuFXU>k=fUysIH0kPAJmR9QSKdC!OIfnt%fuT))wx+rS0bgC6;wkPgG zpt$+TM?xSM>e|#slA?dwT}xc{sv~7IqMnF4DtWy2XX&G!2Z;vxT+BKD=@&uIt*j^R z4cM$#ToNc0e^&{CT$t}ik?N*-jlJIkh1-bxI(pJjn?x;>YLTz!GCtf)*N%4}@2vh}f4G;FWr4an>hP3n5YRzSs}mqx_g9w?$c51Xs^Lvo zrFBUO6~l^EkbL@RLC7BZ78@Vfw%iF6wkJN4PanB3Yk^)@2`X(YpW&gL`Lm@+IbBb? z(8|~M&%;C6m5((vW6B(L-smK@VnsbIG;prkccc0GX^u}VMWSauWm4|ahRInVs>oIo znYU>YGZBUA*VK9?C$Z!S-?vn)QINP;K!4CXzS^k$BEy!6snO|`|bHGt$v?5>egzvZC)PJ zwZrx2sYN>ErLXEWb*!ja-=g^T;Jh-mI@d3s&)2SjYWUs-%xr^RrBzSp3)s_ZO?~Dq zn@aw~6Coo-71g3VOG)KO8T#8iJK}Huo$4C;TWp=^)2HPwLeyF@)M(Tzr_#Z79e+n( z61Q`e&(ycK)jZSZY27}~QM;erZ6Y-PpNKzbjGnc1RA|gJ^Q*4<<=a9p&e)zhS(!O* zFCJ>y=6Fp@o4k4Htv!_e`8&w`mU4>$#PX+SsO3>?5u2TcCN+_s6?* z-%4LiJ-^oC9^>@djrr=$^2{Ad6YA}midGHL+iuzCW~=!Ej=kqb-QJGMaq9Wl`3tmh z^d+rMPBX_x+ePEkoqk{LM?D(d$J9oA53V0x^6!Y2pEJ#deKN<8&|3&i8nWY`Ng_x4 zoXTNO7v@yS%FOW@nsb|C_SSrZ$NZwe3@f|;wYJB*{ssNW=A^Ibz%9yOAtMd!p*auc z4AQ)4R+@)PJHxk{UuUtbEq6l;*w#zQH7?N@UMSEOeQS=|>&zm})^T5Uv?WcgTg0mQ z-QJT;tMY$67iScU*>l~T+ZtOZZlSrgVV~rFT*zJ5zIM%jf2`fDw07yH>~{9G+aWMg zGyi{ySKVsTNhanlEygM?3hgYT-i~@~+m$(L=E|X3e_GkJAGew-y-=N+?0rvVEA1>o z(C8ZX>v@}NS5G{($i6!h-UZlx%!hqM%{8!#Ru9c8jqdZaNh0^^V8ywbi-v0nM^s){ zRligAN}y9elUy0L=5$DR~pdn#oqpqEB1cKJ^w^| z&(_aTX+x{UTtl=*ea+ECDIrS+OBEM-ebS=PYB5sUp4(REXyjJvgNEx^Uc2VZ*}8Zu z>uL9njhbpRC0Xrt3mB~o8?6jqb5_k2_+Q%;qfbrchibXo5zKG5eJuqqwsZab!XN9( zj;8Er{z6|FQc`j07%5Dtv#)&%Uy?#B^qGq`yyaXq*Y}C6eTJ*{XH+BK!Vf3XSScMN zjV%S-l??Vml2<{mCThR5o~zz3?W(yvoT%kHdS^=@AxX4F6c9j(tZN?~1()e&OMwA*HFC)joq?F77mCiMEMkNjJMJ zaWCN^x~lHVrnmJCnkn|*8pTza`*K?rPG5NOmbrEZwYsO(nqAs1OK?ZwE;8j8$#%O` zq8HN%(WBhUU{lg{|97w1Pk?2o{cR&jQvY&g5hFQlBRS@4N(;&_A@v~p`A=!! zc6E_mCT`qbS2a2Rli>f#t3Nu}<&}d^XFGF+o10!n#n<*c~`&o=+z3qOP z_P)7Z%%XnOqi@+hx3r>X4iHHBvM7&sY{p!5ccD zMKQn5Z1unWaq`q1`|j#jy%KWMWeirlO)am3j`>{o*(KPp=Dcke=?KWrd zr#!YxQ6s8bPQ{6&C`z1zCu?eR=`7iNC()Fm+W(}e`8At+I4UVLFF(e_t~1=w!bAB02cTE~U=G=a=xHp+*GqsF|%NJ-rWXX@Swq~5rSI_cL>L%DF z+)pWks~x2h)!iMSm+`#amGU(AG}rbsVeSU<|#n>m5({fR7FSU$7CAa zmmlpqoOx#0bvRrx@+#?ak?Adc$|~Bdu?Fr=><8{cUbeb7{jV#J; zI^BLqAG?NxFAK`A3DzS$afrvuU~{#jTte-Y_M@pkZZRdtzpG>_wLucXz;gn}O&%-9 zopp^#fAmt8wXY-a4F9uRq-N?qtaUr~LtnA$+7`<~k_b(F-IWUTb*NLB8Z7P9t>zc~ zI0t{|SG4y;q;EcR1_iZl+5q-7I^e)i`zkq}R?pPiO`E~j)Nbr;=v{j{+S~Xc6Yam- zao}#d6umIVuEUup^^Z1o+zS_LXL~3EKIb!g{&U>X{=j*n6Q{Aa7GCIJ@C(YP_J8s% zXnuo5CNo~sX_GEd{`C;H3 z;2!)TUALM@i{(n!MB~C755+p_foVd-{70JRrp6<$0YGeQ0ymslcx59J{ z=jO*5^`i}Q1#a@nr;ywdsNI; zI&s>Ua%Df(n_UK5ETrcOQ(l>RFRqLqE8K1s%d4d0&H5sV<|A)S57VywsZHSiK<(pS zOLzup8D<+O2hV+rw6BtfT3d4^nilepm1xdf)1|?74FJ~?+Mn{e+SE7ED3iOI($e@_ zdi}JJgH@YMqx&sEH?6Rr8t1=fZAOxlX|YJQ0nRhdKlP(eOEINgm{RtX>^0ctIh2_* zcv!73c@W4$f8;*fSItZ56d`*JeBFOjo>ZTwW)`$c9t5f%bgyqDPaddr&Fd+74N$_N z2f?l_8HaKdQzHM>0DgK6P}ZUhrWF4Z1tvvn(eaJU{>$CsyQG1F3BDQlGHQCibvl)m6xSu ziyqrBwebHu7F4}lDTTK#6r>y-F5BnmnZR>}ZcnN{Qk;pYr1S}&C|TR6^`YfW^-+`R z8vR^?l$WKFxoh171ORA7HP~03{a{ZdoEeR z=&eJmm|mSel4QIZ8Kx9Z^N=iIlvijK(<|Utl8i!k1C^oyg(XWEx$sQ1`!Yl#b5TrB zBx(TYlRZji+Aij;)E*(*wLLrS$o7tyqdu)*WvKO2-cnqmD84O7X;$F2WY?ni z3*`#ktv+F-acM%B()~~}Ni4hX&F<^LwTm_`-S5+PjlL~4Ov&|Jwo<>G*kLnedz1b1 z7+S3Kvct#YditIqW$G!}N<}W5P0ALjoL?-@S55IhAzP_WE0wWJABPoo@5fJzm9(Xm z14P7K59MICY^6^AdzRgrtk?IpiWV!yKeC33HoUZQ;Obh*N<{=(taNKckyCtgu^`32 zi)?e6dX;H^(Amh){v;{dcc?g3Fi3enY`tW2qCJR`h+=m|U5rl6qLo#(Wt$UsE!v+{ zw;eJ-7&F3@z`my>n=`^;x_zup@5`%Ud-PSsC+Y341uDmj$~Grz$Y{6G5f#DpT_1{? z9%0j~_S5D>uZp}9Lr;VoJwJIWpWS7flV;8?6U~`K9hKf!yHi?Zln7I5c9d<-+iHD# zKg@F-?N7?5CSPU1$D*dMPqI1DUY7eozA+v>%E_)LqD_W6F6zGI(_fR{=(9OUnJ`+m zIdKlq{-j$^_p~&|(24WCn`}X%o`BY*`I2%z(Z#E@GW>7Zf<#XUT9f2&ZyhOq-Rr5G zYWqS*%LM2AhsJ65KhO(DQDcp|C4-cfwd$KX!%rVO`rgr+q|@O3P~*4LfyzH~*GLv5 zN&{Jwx$D>2m^8eB(kjtzNB#6PqO3)0(qf5L*YOb26Fry8_F&|QTdFBb&b^yu+6g~x zL(1g%rImN#vIl?IpY|9O+J;nd=#>-eT2m6XlnWP1=L4xXq~kMppX(TQYx>V=n*J&1b{ z^?CAgzBX^TkN^AkK%pKgg46+-w2eG26=t#gAAGA|7ytkO diff --git a/brax/v2/test_data/ur5e/meshes/forearm.stl b/brax/v2/test_data/ur5e/meshes/forearm.stl deleted file mode 100644 index 3a02f8b0ef78e60f87bda07902a01342586e8f42..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53284 zcmbT9Wn31|_xDFdNfne7QBYApLgllKURzf8t z1iKUKYoq=Le)oejKD>Rz%RTQoGqZcm&Ybg^S^q(OY@;H^M@$Z#5)p127H{hp5j(n1 zXc+qc{O6#M$z(Lgpw@#RFS+#P!zw8PdZ_A>PG_O+!vE;vZu`cm|V z5ILEMqg$IvCNq5HC!w}-oRXtnT9U26a-lzjQcldoPwVT@#*Z%vMB`;<^88ma$={!) zv}|uD8~rR%OCql;2whH1#Lay&yqo2976@ns{UKzbJBzVSH>3emt(11JEs59PK>0?W zY^AfFCD~+;<>6~{6vI?=vZHH|+|4GJ`|yZPB7?0Gv5oI0t#CimV~EY!Al@6}wONZ} zNJVjwJUFJW1|h$pF>2c%iFit_GoFxuR`6~jw7)Wu9liclyY$^Xjh<5s`McRwo>P8A zBM`tJLYsGV#LM*B;g~(Ag!^$wHmkn>K2Es`yf+ACPg|CP&>Xx%XFXSOy@0EMkivk@zUZJ%Q$w`q;#5tFk1@sY4L z5x;3(B;~IfBtt8BpT$}|-jIkZQ!h)`#}AUB6{O+QgU~+h4170Rm)c&jR^XN9MJDoj zr%vQb*IWfYF$mRrGZVLNHdJ-Zv!Q#BM#}XnTgnx;hpF(%7oU=z@v|_Bs7nXcxh%nF z36}f2Yq%WoI!pQ;rz;_(-VVgm?&-2^uYX9rtxaTuw7w+C;F_}Wqp|!u-;WGlo~68+ zU?fksA54r*ukqeJ?jm8mUTv6jur;atQd15}Z%Fcnj*#~nTgWk|{;D<(b7hgRIBAY2 zul8dBrsoA>z8dZeI(4Xj|-u>3p|9b*%n;%Bti2S~N7-36; zTLf@hQH_Y?ak%?8U)Exj^qlzP5xF3?~do}Xo_R*7s_yqf&?rHp%$l|J*vj5 zheuYX!7*yrz|+-Zl)h!gKjT%~wkH#gMB?c4#C{guKfsdKX_=;lqfAp2CWrqzAt9I9-moNutP&Hk zTRI0wp)uUgADY?!)cs3JFgTw9jb*Cq&cc{B!3^1P%P72l$t7lz2+NRWKHU%EiR zk`Q_pFbUUd{#9*YeM}%=UqOEeWn5pAQOWOz{Z`Y=d!IwqE&T1Zxgt&R9Qr{$`!z_8 zd)b#1S8i4}^Z&nPbB_N$7IzrOUaqRE)|+*n!e1)<1&4MWOdjsDmE&vW@cETa>P+nS zw4Is~&|G+b40eo?%lie%o7#>L-X#eAcQjI2|27WKJdr@473>A^DT#`U#bahl>Tc^@ zLdziydjTQm4vR^EerUFai zbD!G9UjE#NmHk~6mJ1135<=aJQZxSkO=L$JeDut#VW~FJM3Kfv&npeP`AAkiQRLg# z9K|UwK|R+Yii~%;$!oPI_A41_IGerLcwZVk>xOET5J_qU+)!=~dZC7DB1u~RYYOZU zgoZ6yL(dGI%GPdm5Z)h2j|f$Mio^+z%4tF!M*>^uxY0|#o-;<( z1ip~OcJ#S69uK{lPJ^)lfmT<8`p7Hxzs=mYIqe_f&}K7yB&{iHQ>Md08(GMm8ija$ zRnpbJtf3rZZ?4Av=&Gh~cp@!2)m)uac!-ZBYhE#4ro>mUbaE`V7XmPCMzSe+v zzkMyeGT5Nm!;DBBs|u;pu)flwYt8v68egke8R@6TbS6j)UV+5BFI%O1C+#)iB{rPs zI(3N>Q$CS7%Y$jTZiQx(pEucir?EWRVTp9kwI%5n`&=r2yin41XivTrxbk~Gc44}* zvE^7cGaM@^;RVuYyHJu4oTnJSeIWfp5hQ>MJv1DcF9R6T9D_xA0s+D@gMZ?r$6X=a>g8GjO5s z3iJnO%LrZV+y|dZwqj!&4JGh-g*1HL_}cklFC4SWnpOPiDG>0qtUP+Q@L`j8GfBIQqqEIvda^OBfR&0pDE+ zT`Y8`?r0jjrWZn?Robv_YWimNF!7fH;2YFnzTd%*`JU`Ys_v8h3Q+D&E2z6+?(TL1mpw)-4NQqh@zTK=s0Q)NfR0M_tf zy0E%iylI6r@<|8gi)x-p?mzL8dpyEYq;o@Z2@?spvVb)2-JZGCk%`SSwkN}R zd!|E_b}OH^c^^qt;M|VC&ksbU|KYQ~arG@T;Y<$_hS4v)tLU`NseHWZzdf(o$BY*K z$*`<+*{nLMHtfe`1^Py4eaQa^?>;?*T0x@i-*Txc+Gc*Ti2FG9|A?!-dL%3_9a-$_b_#L*=q`VMc$obm$dv)2}1Z|#WT6-an2O;`Oag5>yLC7ejd zPB`kQ7xs?#P~eIVuJnrc)URGGExukgzuGiqr_$ZV2v6$L18XxEOPwu)?)kh{OWHc)%S#&I47(3Pxv*BSwg@$9R0p4aVvAGV8sZH}nbLZbAot)a)ijY8 z{KwSYBdXr#M(bwblsea?q&Wt{9fVepMyO?tSe)qmoA_95P@q+QxRE?+ju)XXZwkH< zYQ8EG=gs>NNZ?d&jRqxK*-7fP7|9zMpX^=T? z$2YYq{(Wc#Q3m`iul{@4RP8^p@J@Z5stFg{hMCX`juPT8%DHYh*)l7U^%?g`h3gvl zHmDgGtbKF58`-pf9)IKXn1u+rlE@CHJ{JgRC9YWX8#&7Dy3SQb~g%c+)%t?n@3f-d!!nD3McP9bA{C%T=nsI;-^G*qWeyDzfZAt_>5@MVVSAi zBTf!A z*s3k>q^ZAyFtyvR>ZLYOiBp*Z-x+=mr`6Imp=z`7Rc{;}+uAH(*as#~H{J3=Ar;_&+Q-DsUt#=^J+2^in!S(%YD z@x~oabh_7OVa*AxT;7<Bz!fWB8+?n!IZd`xY$<8s6-bCH;rx!1aGME9 z^k@1tExZB=xVlH^gx@6m`q433WaBNg16o0v&-5n^#XCC=qizS@P#Du&)IgFeX5Emy zW`6xAp18}-3zs)EV4WY`rI}kjWWy=HH8z>sgw-}&g(KwEvjN`j(21Rq4+=A6I4d4E ztDao%^?tSGw-r2h7pl*5EE$-M_TMEC&`Mkn#{M+Hnsghst+644YyI{kUq}lLyK674 zd7_H7a`7|6SHJwE$8Os*xQd7IfVweO^47x@niG#-@u;=)s z*Z!vXBINU8B0gf_LXTY9DBM99r|jhu zCZ~*jBKeU(|bLhLhjp*PP0do7?7-{mS+gfNPMsbGDPQbs?ThP!27lr6K zw1PB`U!6<9r)vHtNB`^Wm!h^0(KAOY8HJmY?%A-12|kDYv)FT_h=eh|_KWtHAnrp1Ia zpNuaAwgXy08lf#dca%7@NVen4BLc7JBrcGi7kH7Ek3wWv5<>4<>f?zw0$7cVKQfH< zgcsD54Q_=Hb};9kNYT@dUbxd_SGFkLhD7zNBO7lRO8WDhDU9NX(ZO*h<8fT4BlL5* zg91lUIQPA1R8KB_--p~rRrB3$U!w8c>@@nPv$a4#E10`Q$mP2ij)-w$_fCnCA~+hn zNp2yBd^xLLjJqKr^gcQe|H!D#;zv9s@CuAA{u3$M(FgmswqO^oJ|!?B1Zj~tGgyb! z_#Vvuo-+_Kqr;PQNPqL;q+D}N$dV#-P2&q6&9s&WQMvzA6LRx*6wcMYnB_0H3b)glc z5gI)@f<65Fhz{%GL$3eVfvju4MJ?<(K+dS)MamN=sS*9(OJe-$RPs2M>rZIA06jA1 zZx1qg{{iiqnM{HNED52!28wLv634denMUDk+uVIb#*?NlWNLYi1S6~nEuCri+F5@XH6NS{X!rQNqI`6&8!PaI2_^GEyW zcC3(zg9OZyBV$6MAkD|bkNc}*OHP{yrgUxmT(ga@QX9Qa72-*vvgV#a zW;J2C$W$O;{7Q^Qy*N>sT$K&4Q*tm3T0wte#4A26sY4a9tK*(D=aTMHRjtnMOj95s z*6PptwVqW`wT`M)CbWWUXVFJfcbrj0cq}!`Tt29Nb!@4xwU~W{(KdwsHpot`s?}73 zT|zq`EfSAia?e%y7t#O?jcM&Iw=Km%&H4uVsm8 zRkfh)vdiWpjc^HRH~NCL2{H`F>iT{#bbIh+7owt zt%qTJ4-#-i#^0RFYvUR6FR6RiU?FlhY|IhKbzv{{_{BR)$=!`o5s$Wg4Xs+;?Z00K z`!}yGv++R!0jFXYVOZ6CGT!8$!}vfHT*=mlG(I} zbhw5w*|fPTIyiQAQ=BPD*l5cV1xDyXhZ@Q4eh12#QF)3Oft>!TBc5=_6gQuLNr>D* z0+z&QebN$Q*JBHuCWe z2UWB(SBSPDw4)>**R(k=<$M|@T!93Pw()VdB_|S6r1zS!LL3ef&^OnrrB1=)dmkoO zCazIp3$5hhl61AzsG9;U5{F(+#v?{nlHs?+SRbTeL=U0FjA;C9-)oZmgeow;2MN(f z`yt4H;qK9T+KjCC2(<3LOdwZ)3|(@v>hQG2q+t5kXL zXR|bGr$PeSmhz{h=^uL`8zJNU z4RD{Ay7X{^ei&LoeWc^ka3i!OC;<|X)mRz(}`O83rr5Zj-(dCL?29#XE9j-duc zV_Dzcmn5$bFQq!C`jVNgBcw6OWzzo6A>{ksNNIX4bNRkc1mUraZH3B%?z-%IsmvN1 z?2tltxl3anbr7=CV&3^!@+4)6UwtMA_GbN-)RCg@ZSyWFYDpk15~JSLl;`fPL6S%! zTRZBT_xYS4*{1dc(jv!5!i~J-AC1P5jc!e}3*R=8zo(DmwQ{!WjD2R#qvy6-F&HVh zHoBSI*s?^5PPZYV3aP1=H$LZ(Ob<7*5HjwNfF&X1c0EbiZaADyJtg)P?4RDRv!(XA zJmY^bfY)lg`$w;;k=iJvK_=9yz4r8yKYtIBEe$OQ)UEON;Pw=+-<%lXV44Z9Ktg;w zYdh2^3C5XF(FO_Vo5$O&=6F?Uqb}&0XTli(EEm!|Qq*#wH;>k0>#6lJ;T7ml zoRT>ERYsroh2gYquW6(8jTpQVJFhjFVYfhPQ2tewp6Zi%kB3X1uPk^wE~t*| z`1DKkbx8{bt=y)yA@g@!l4eLoL?r6@db1OUU(!v>YB5+7NJHOz-$d7@>KLwv+xzi> zpd*9%h*e%X-UfHOkqDQ$GD5aLm#W+DC9);QAFD9i1Fay;_l^2#>#>X}5An&?SjxzGyI{EgFgmF)0&CaXAX zBxIN1T6RR6kD8G@Hx+Rvn791&{U6e+F^Oz-&*_2>XeG7-S41mTDbv_r*+7VqK?3?i z=;n_Z$`)xF8>9VBU_4161lnvEcd52)qJ`5y#(3%pihHalOKct_Cxd)%tPFIkkn@ z4qydo{*=s%lv`|0WciCzG_W20lAlOBRz;F;W3MR#JaeTxD2k9iRq?Aq^?E90=`)z# zn=XR>3R*#$$Kh`DR(`La!CnP-6$oerby-|hJawHizt>deoY6>yS0Ev_+`l|nG5ZnE zq_*!>p|3#8A&tj)n_)PDsF#_|u+$mA0KRmeoIZOCY=kd{%ou7)icN z&s89e(9E5`=#7uV*|5Tz61)N$b4nPAzJE$u+C~##c7A0Az0;&ZQhJ05 z1hf+KwX^!RW9Gx{vB!`DLOvVjxzAbEt=3}2?7hW(6dJm)nztWI%eHDTj9IMO$Mf|` zvsCj>#Y$oGTWZf}jk@hal?v&eu@lqI_##EL!+7AsY&E?lmKW=n2zh5%65nch#fh!| z@k)BXUcs;|NQ=a=rH(B1)N3jAI2Ix-kboueS@?2=tvL8hTA$~HpZ2Iw3ws2~%lF+@ zzSVrChVdGnYhLtEOy+TO%+QpF(tuVXF=p*|^|;Lv)iA6`pm{rd`!bzi)#R>PYIoLAX6s zBtQbzmMeY@v#aBx^K0(%fCSVwK$`!~Px~v)rq{;zKKWxvyme)ofrdfyzLozfAg)h2 zLwCiRV;ctz?%OI#&FUN^&+#u)p#Dz$eO#PfN~5>Fm9E{7!k70ZsEZ6u)HMBa1=8@> zg%GtrSKV@}vwz%S3!BHh$gGO>zF1zaz}i;r{}_s`HVva`p<>28?8890kwun7w7XPs zB^=s+BF<~ro|azTOkw^M5-@j` zj0W$!`G_`orXvv03TEGVj-}NQe0$&)dj9zf3hN6Av3Jvp2jfe{nRM;An}VJJ60i^X z9gGOXrCU#&H31p206o%Be~X%x%M-Mo4-Xi_iXE zNqa5)OCSNQVD_Ew?OPd#yKI_F>5}>cDpDXV5;vYt$KTBpsXAl43JIuAfiyyIerK(zx0e|f7?o7k3vGcA7c`}C=EvsVc**)3R&+?eE-J-%MjAz^DSWq z2A}CK{jRXheVMu2am8g!v6OltnAp6?7px%7wQ46VakxV}w$&?7AfVNhtB<9wZTv}Q zXaRq7e(mgq>rZ!O-OBW3m~n?$_LKRwWy5wu2_03H6E4d!#$m1l*sqgMWOxN;^stEXAj}iWLs^BG3x1SjAeI4XuY= zP7GorhF_In@50u?k`P)Dt13>l# z)Rx)R{HDMsZfl)v$^XPaGTY;};2R;EF3s^oFFn?4O+9SczCa37dlI9ww-rdhl6Xvi zZf$IQA&sty>@H*hVC4V&(-di*MNg7DtSaYa^t=gPY?DqkyW9i_T*xS83?~?w^SFbhf9To^2KM)X)klVEAtk zMd2q7C+OCkB??r!Kr2Wiv|(I0ZrCoJuKBz|AfOejA&;b<@WnS;rBjR7HE`gfkJ9>< zok*d6p<qVsGwo#|3Q7a&xp*N|5Fn6C3|1mv(a-gD{*f? z;mHKU0RcuxDhdL z6F?m2WeGJz=#QH>TQaRJ8*-TK@8{WAd8|%D-ZZmfxk(0)$`wLB_yFDfH`r*gSw0^Ey7+dDC*QNtmH~w~RxAY{1 z@m)w{*YYLqb6usSwgal7&Ru#k>lTPbtS%}>gv_8m?UsGx$;2uSn!RZNxIQoG+{*~g)$1qB_nf;8Ww7~g=7d_I<)drbraTETwf zir6G^7nHmJD>G-KbXY-Z0r>&W{sf(c5hJ={g{yO%y zlC!rl+aKvHD6kZ$*{y`@w<_F~`l>Iq8nR$JFe^242KYu8yS)cYY+(r^|VdC-z} zS)45Gxn@Ql&Q?hM4%d_i{#AMg%|d+^RnHaQs`FEUS|w-&Y2H_Jr&cSr^4g5>fCN-MLAvTJ z3%%-=kDl7!4W54kt>7L^gp!WxR1*b9zqW%{AOZa$G-0eR{+LuLm0Jf1y8*9nIVc_Q z=p_HVSgJr}6`vtD*TsEqR7x7(APf~$&{4jGN%73f23dDPR>ipQN7(pXy_j{_PwNJD=JwGKO{IBg1&_T`QjS`GV>-j`!>P|>;+i;;BHiKiOby57Bv$|#f=&$znI;Sje ztR=J0jKfeP2npEJe9w>DV}(4|p>^;mp&hX0kVfdu%QU5rwGQ2w7b6hR3i{@IGe|66 z`SYH1;_Y_jD6KEAY=-5`xLblMP}C0QnZ)A96^8QQ)ZGeH7(!YklD3KF8O zulcnYeDQC#bXdA5lm!W~Rwu(7-~pjeq;5|K2xB7rrS{yDEL99}RF_CqPvVJz&Uj?$ zOQ~}kSs-q8&6T#l_fl_JmJ0i;`5m0z1y7#$M4}G%!k%iEq_0x%mLcl5=|%tS4L2J+ z5}*0>LOQ@tc!B%DPj%3h+u35d;BvmuBM5c8Gy*3-xhv@fepd#S7|B16&Q(YEC{Q2) zOF~Fa9f^0HxGU9q`JX_*{(&@)4jLmI5t>PwOdKkViQ(JNNd2CSQqK>$^Up{5$9lWI*Gtp<*)(pmm)XM(#-VI+OV z01V58>PC?W_i@BGui8ot>-u1L*Pzvh>kp;&>t?B2|5QbB9MdgvyMQw3ZlE8APXi>x zx7wuEHaO)|x%BYAo)}s|e9utW-T_aPxQixjBqv5 zBh%C9zAa+J3r1B$ViEPhGdv`E> z?2i@*XazN}2p!mEju$S=p{bqP3+F9$fBs4O)bF^I-LnIMc_V%zP_st(!ng}`v`Kpg zRi)4h(g+=HwpXe1>K?Ul5m#PtwY8w)tWoTU9&4_bXGZlIih;PzLJ@tsu?M zXSj7tnX~B_l`e_|RR2PnwzOU?GbeBGe=*6HF^d}O%9CPJUFY_eZ1G5>_1w*|s+G!;%DCOsue2nIO74meH z(zMe6>GP_1fq+(58!eP3yCDSqs1`cC)9Dhj%_65l~_PxoM zJ$iEE8eQeG@W%4N>_ohCM1TrU4*XHxSMK&CNS@c>%|GW3 zzA?6z4pb)M7opKws8)q4R!H-GUdEAl`N^r8`zsz&cm?_s&uCosW-uQ5$X#-%Z^)nz zNQ=bU_jkxVe!j@sUuOtZuI@ExO3uCvlIx#tM?@8?@tJkn8-EhkH`f|WN}}#(`%%-i(d@?IaJ9}(zOatGtCQ|O`dwr#GrFJe)!%0) zJSC8T&lu0}F2hV&Z%HlZ*)wPb)xRP!;&f*=O!Jqk1zHNVf<2nmp(S~It*t!ruM=;F z{KuWyyWQ1T>3J~NKhO%&{G>t`7xvn?NoLk|4TDyY-WFj>=H$zAvlSZdqeEwf4O~-? z&R^0>s1>BecJ!Xri%olRnK~Z&Ord5M#(IAk`;x&uv{H-NmwDuF{DqOsr1oLja7?lg zdxchFBy~mNV0J!~(vYpW6kdS@^oP)^C*jQCRsijPH(T%lHPn#iJ4O;iSn%Efv>E+K z;T?olkmgE`7+-ec$6-2hpst{Gh6Egy_$=dKAJ*Jq2sDXqhV49symkzr|kr9~}`TXig$q@NRMKs%f|FSZ@UL#GX>EUFbL9tSQ`Cm-`zZrAv>97!%ZRq3y6O{cM@4l8NWMdS_jX&-0zBbozey$>`jlEnCKT5&wv zVKahW(l!*LQILQo@snGc1he?;+N@*m9}>(inDsIt3y1V2k%rg)$uy+@p1^jrK1*Mo zuAxFJST0o7@yNu(AofJ5!n)q%u%7tc-bc!)$ybX%Ys&r=KG_pBH3BP z`QZ!&GnO+{;5mu#)=L!R*cSpP+skevoy$<+E)ir1 zLTUcgjF?_Nrhfd!)dE{{j3p zV({piP%b234f%Qc_hZ=gR{xPnBhCqRhXhm}@)^MKIF|o5PpVlmNr6|IJ2fPuhVeZV z?{EC0Ph8eSX1c*mvE_gT3e*Tf0-lk>&o#I>nVsyBFPZ+?E4Cc?fHmYXvNo}7fpklH z6Te&NUFZYWmg`Hby0gzy9}?e`rWk7L68RpGtoi-i6u$t+QIFJm%Nrz#ts_@xtDT=0~%-`N7gvLG$CeAsLcmg0Sf zq)%@r+z&`V?I>5hY;t7rJ0FswWEsONka$;^uFi<;KzdtMjc|$Geb|;YJ?NtZJE2y( z^rN~x(}(mnEL6mH?D;i_*^G9jjdXM|w1RyF?-N2-Zuzq@Huf|mLtnTbkcK_N74bVH z8fp^H;wtMY@YE`Yu>GoI-6%4xZ=MqL<&;`udpH>zS#@sI$0yy$gTsj|u;w-io}~o| zSQ0`(=gw%8{vbFc=Uu>}cO5<`dDQ$LJK_hIZ^IUy>+gB~#RII=|%lFD0mRZ{i zDH1JjE!=ZRh|g8$0EHQNJ(5i7wZ>4h3kfmPaHm9JbMZsz;7lihfaO9Op$CCZZ2qgk zQgah5)N1D1EY;%icy&OJ(tothvqxI6w*8AG-64Gi9Wo?fzj6JphaHPexG#Mt-36^P zB*dNGC-RKhmr=!}eZ&9^%Yp>Fvk0BFZNauqFC{ZJcNXqBB*go%N7}A5y!c$oEs4c2 zf}Is1No~t?)!moMg{uj*R#00j5{LTSqiWrc(t*rKLB9)W_=NJ^H8mf}Ap~Q;?yWG?1w#ccYZX;}HpHle zsw}{r%u4sFGY9%VIOqX2yU+^K{2jG2PIh*yhwJnkB=~^DO51~83;DNeu9GSsm(MSh zT6FQoKi0Oz&cQjyjDvG{886;vc#7U`e9fTNQgd^?sF8{E|bhA z({TM8`qGM*SpDvKS2zPe%w{;O@ushL+2Nf(eDJj&Z?#$cEWPQ&%9W*!pI4tLJKw8Z zX}`!>)in!}@BgaG6}7!_g1pvu$Bw5u2(^L)tRbK8c6vtM{%MTgIr$2G1qtYze~IRX z7k+xZ5v%$4FohaXhj1r(&zv#pvL*E;s8!`B(D-%4DPb)c9z9P`u|fi#K*V)qwi=8- zHfICJOrr4IAZP_?K8G7S2tP-N9sD+n!Yi75JPzxI{_b*PYl+An^-&-V7%0PC?=JKrnN3PuD)B5acf?)AMk z`%EhYwHc(LvJRn$ou3rDS8ZAJFC&arQ^{>|Fmc>fD5woZmEO3S@0EtO?b!HSBkUTz zUz+O_Ow>kv#sX-OsF|RTm$zugO6zg@Vo!mJ4Y880r=g0U3Rdgxu6S$R?s)s zmy8a;UwEy&KbZ^dfCN;~A~dvHAuVbc$tK;m5YFF)s#48Ve*Vk&79`fgK;|R0n-0r) zGMF7*)`mh=DYSw#*Iym`L9cEb%!cw0Ye6eWLzNZZak%gZy`UY-b{IqnwSqKMR&k~5 zo>W?v5W$-0EfEx3&|*d z*I~&&?+e;NXa#A0>gk~X_R6dd>vG41LMx~p6qSGtngy}9lWMX8jVXmzkcPTNghtMf zV4nVGXpN|qN+(S{GGkdE5?G6W`2xge#F)&Chy7U4&DzZ5zXgI7q@fa+E6S)MjZ}<_gr-LIRe=_rx6-%zkvP!~D}N1OnCs(tPK=hSDAz z#;`jnZ54PPa@2()_4(N_GPu(%Mbse&hpwkN6_ePk)FuJ}SB;QHsPFvQbb#j+R_{t{ zVKoY^AdS$#uM^qMkezf&*ZOpD&7s8MS(LPW&ms*}%0e9}Lf;Kz+2npZsiWUQL17A` zj<6*DrKXK>?A+-P+CI2M*pUl0yCM;8H2yQ3S-I6m*3RPrVrGb!>}Zd z3^x56uyL2XS?hrF6k0(Ys;FXZ()K64d&P_0Mc*k@uIfJ4s(pS>R5Q+PRz>3Mt8H}8 z5g+z0(~NDi9HLH>L)5L^ZmI8G{nfzN2h@7ama35E&#vEJYQDmWo&Q#s!7I=o)bjE( z&!VW@|J+nIZoy(fDZA=rg=WB92h#n_xqr@R>>rsZ4~v<=!p5AXP(up|sJZ3mc&!Oi z{&SkdPCs8DoZk)!@$}Y}FXky%Y)7+UCzB|=0txs94*vewK2OQD8_ljgP7=P&0SQ=J zK3nqqOa9CaXRV9wQK)s@XV+NUsuGd3qTQLI?)B#G0@}JqH#W`V2Zg=6#`Ls0#{Gdh zbLUG>@eWS7pGrOd2D6+`=Y_WcB%p7;H+;($OjOu3)*4=Gya7ALz}BzN|@k zs_-O2EAiRAHgUYNaG?ha9^lBJo))TOGd*fbBkZHq4g(8#uBfs>p%O9Ao3$`BVo*;D z&&gzm9!PC}FH}#S@a05CXnnl&r4y4{{GxErAhd!sLObu7V!d29wqaY5uvUYoO+%W` zsqM|M?;l$vqv~t~6hUO5z@Vw#4gvu6lV+ zwj!#Er+dw2H@=>i9ACbX;VA|1D>0BpC~9LO+c20(Np%8*b2*_Eq!BXXC#ozo%#bEW z4U(aENWd>-^0RkL#<4NsXJ}YwE8*=7-^lAam=b@lZln+YqANdNG&YX)%IQL9_AnPz z*U$>m2$k-ez?z;*roAUH3Fd_0I~>ydn=fV)+1SuC^lrD~0s*aHHjBr|`p;z9+gsB> z*K-oIg7*gM#d&;pWCH8$*om(3oFHW4AT1Kp3udqp^ZL+49X|hOZXR+`f#-Gd-DkyptM$|)?a!n@#W*BjNqiT?>93g&x%T{*yG|x7 z3lgGofUVw_`H>T+ybLp;4`>Bxp21q$f;}#6gx{xsP@uAVf(DUUJ%i*`riFrvowvhi z82cG?N5Y@P?=(OHo`k`D1Px;PeM+SEh4q9S3nXCPh<|mO?-4d4ccrm)EQDMU)YC(n zYlJ)YXA63iN{6qT3O-<_NhA)tw87i$o=aKblAzKn+xkRmoZCyyx?22?&eFoN4W2)u zSaO;z5>U$rX}(t`LBU#WndI`!Nl<1&E2z@s(GTwdSU=*Cw0%fJ48MsJ@k~dqeQ&MG zz7;C)`#%W%vhu^lckf8V&Z`BR*zCzPYjYff=OlKE zj?{ekY$8AF5>24ejh}w8(?a>cclH;yN)WVB@N7d!BQ&e-B)Oz;H*vcv>aU;`^v%@? zU79L~E=;DIW=|EgYLI~bcnrY9T={ixCY|dsmBIQ#0;=8+!iBZ)->%2Wk5R!4s-_+W z)|LOvY$UfTl8JamW8X$Pc=YWfAXwRg*yld@hSO`phr(P_M=Z` z#0fs26{PvD_|AHCYm^=Pe5nfy+BRIx-%?X^-J}J%_OqonqCfvCYrZRK(8FGP`qNHr z(kD~iyGL>z<;q95tR5#X1ZmxBl+()%o{91)pkme^9>NlbW zOFfx?`!>Ql3KFm+{;rK&O;`44#_Fx|WY8+l?WFqrzkZpW^Xn4WhX_60pG_~VcV^cP zHf3;y1qrxj;P0q}({zq!JC@?tOjxTy0;N>GPObGwy+*B!H0mlzWBb42&La8(}m*w$c!g%#TJ})hM6(nEX6HK5g zls_d~*2(vm%)%Zip9r*qBZEln&pNH#?$bcNU^-E#73>9A694Ap8m-d#!C>ic8?hbG z3bu(yAd_9B$ytebS5G?v^|X+H{`fA4mL6XGdtSKvl5Lr=?vM~i(eMq%YUiJc*mbv^ z3KfRX>X@4@Ar(P#`R=a#7nNz*lr89QD|hVYD<~SFm8eh5*k{ZZud|i{J^UH8f<7`{ zeN>Zvd&z5QU+&}Mfd;JAjuo0|MFRu^T8TcgL!4;muPSl&PhfDYg=*9JEk>#|(OWL8 zH-`JT`qHiHj9=!kEjt6A186oONqWVVd3xp-{yBGWN0VVyC-$-{riU`1eoxT#0h*u6 zHS>t3D_3$nTy;bXPe6n|;0cEOd)q@gXohkhL-a$nuvX9t(tPdgQ=obHJCQx8+@*zA zpg*WWPbk<~6wPX3FOxO-ci^S?3tK?x_XS05zsl58bKxu|~Bq_dMsOTl|$ z2Nei9OGv|S?eWp)DbhUPEnnbwQ}_xWB;egdNVg;;qw1c^N}@Nkg1Q`$*gti1#s^N+ zoOD7U78`H#ws49fBM098=lkdG4L-?!0}~k1S}L%2VUM0;+oT^y!^yL)Rc8p94)Im) zt)9j{najc%jnE3x{KTxks}&m_4m9*ADkUoPlHp|(;ae$VI1%w27$ z=Hq;UfL3CJW$mkYw(ExuUC-JnFxCSJSVO*x?qED?d&`=BNoyfQabUTS=5x`hGuVr# zy3}T4GX-9O{$R9)&&X;zuzO3r@W9)H6{sqOidN&4hSiafSfd-fuQop(RvrD&Q_R}K zD=_8&_oZ{?K#SYe#Qwdy9iSDYp+6oycW%pq0zI+I!8bB2%P`fP?BF`0ZZ^3Jr1^U= zz88B_#|$sG=p?iQT0xrc5;7aa{;Y|khdR^{G>}mDXkcBFL}qp;6u+XeQ+s5b$qFixuj}WW9IEvY{R{#)S;Oy5YS50+&%dcNYWgG(lo40sZ4I9{;-ybbDJ_{^32+q zkDj}eA$_WJWJfZFkhE2YiEfst@v4CJHTwR%QSBsP}aA z)v|L{?fAxy(+11mN#wyq!3QK@AM(h=iRH@O$|KS`%`^-(iL(oSXueZ>^=9DRfAonV zY`(I(`$2V>|8!xmJ|v)qjqhE0x>VVC?v-?7{4_xW2?^+%$9+ntDQPahq;39jf)WrC zP)*78i9eRoHx>N*fF9k1z8YWBU9I{``pbo7!davIy8ySm>D(MKb3zgrSH{hV%Ov_;bc`vK;0rg!}QV+95CS*eLW?VPAbQ; z^Z2jYR85!!)pVkQ4*SL6w^R=iLW4E z1>oNsYqRUQZ3SH)B%nf&e@p1%I9&5VI{iLH70&*IHGwoj$^E180Y^%gp*q6(qR>ig z`H+Bv%JNNpn3~PchXdNNQrnf%@x~*`_x#)c=*a3`O;Jkg3}T;e-WRSwLQuM)jGO^V z;}+vsLD^mf&J7_U&SZwiE>`*;jA9iRjtO&7NWhZ#H;WBi@z(@b_SbWapb%=2^HU0$ z5JIeqZ~UW;ikDsSMy^}x(s!%^HCK>;N-nNRoz)UgNpxktz4i+PRE0sB#}m^`uP-a3aNb{Y=ZB5z9fL?5Hte>C? z!AF0pfvrMF(U%(vJdch?ohzK#rlR&Nd~TqiPJn8m55^5i)=8q4y&cUJLVx--XV&#K ztm12&kQ0VVCy_AP?!gi+w`Tr9{RD*&w1WQlKH(ES?APfgth?_h3ay|*NYoVhA`f=6 zHP_a**dm<#1!$26`m>*H4Ox%PO|*?J6{^D0ynD3u6Nk_W-;E{HFzCMBa0EApGj8|5twqR3^c?;jDf@gO@ z8ll_gCMcKgw_qNVy#=iSw1PArsgKQ-pYL;E=iXu=>I`GZkmh4ah@(83bzpWs-30<# ziSgYHM`y`K8Qu(!aS>V$tsu=cMbVCO>cnnr+4g1vaqo|9;OyRYyA^z6;BRG}o})wWO{6dNA+$jRgW)LEi|i{?AeN z9X60%OEh87N-y`XCOF`P`smRzRjk$48S!%9%AV|&v5nAjNWhYKoZ-VbW&DS(%qPTB zP>VnUYE8I?J~>9&aic3UKWxdM6?`g1;>~AwR&RtUHa9C&V3yTjPp#@T;C7P&Ma&7W zGjG7UFZ9Bi^H+qtGbCV1{G^coY}v^M9dVnHJA|E+uv|#iN~@pT-_=1qoM; zBC$r>2X{GlUCLQ)g-4wHCS8kKqP~bNQeZ9>=2-c@Z8r%gm&sDyg>5j*>cZS6EQzZo zJ1e;8#Tn_iYiogkR-%vO$sMrg-~tkS%uUGG!i+1Vx!SO!Cw^A<9wFP>3MZaGD@Y?W z$+Ht)IoFeJkGB)F=gq!;mYM{6l5MRE6?i@i|4K}rKXz0-XbXQmfq+VKNF(&gqboM# zUxS|Y#ty?PuvVfEdu>;I-{dK|*w;?DACQLkhQ9}cLvZ8z#nSh@S{R;F1kWZK?pIqz zwI-@M`Bghv7qfnTL2ru5(6k0 z3(-Uo6#>Pw&)xzGSYn|mD0V@OST71TFtO(n8+MHZ6ni6T63?C)Ma6{B6hy>WVi$XC zSh<7ueV+SYe0crD<6ggMyPR2T&HP7A^=Y7ExtN#7k`!gWLy%s3)^)X3lZHB$i&jX> z(}5F5=y7TH*dVLNL}_YVvaPxB_-3QL<2emSTJEUY9;OGBKVXOD{lM`-E2L%gqas-E zwd^)aTwRY4Xoannb2$GYy74egjjLIc^aBYxRwvs<>+k3Pu6EmWi6}}TK~=d@){oII z^eIqNn%*Y5QAnVMl-#?NX|Hei_clJPj+tn^)X%7A&NshPdz!j1)Ur{OU4K^71N~R> zClE6T}+)pU2C5j?OoRA4p?t&o=MWUOlFU;Gio8!ZYIn3qQaOOmG*h|lFy z9>XTg4G^8Oo12}}e9eD$@5GR%`SYE>w~?nVxSE&N^AVVfM;h}~iqd39bKT2*srsfz zfWRwA&^olHCOx*Ou?v@m}+ zrA3uYp{RF#q0W(sLwm6r^_BkFL63!ydfT8ZUt(iUhsqwW8bU;k`|Kd!Md^ zz;cn6N+EqU-RryS#wz=Eq~~abqabGhE&J)VxffXNF^vW4#^5QQNXzvl$$j-D@*wx( zAV+~$aHO#$c?Y+A%fE9t!M8maEl}TOj7tq;LGxXzZABPMnf$ZivS6qBSCYTX#I?UV zgRef9!N1Lo7Ptxl2`ovj{*)2Znp2PQVs#Q(0fE|7NK2LGSL4{~lwCZs$y8DYu7IIL zL-)Hp?2e--n*61}oH^#}jgm4WJbjhvUsBavzUO=G#GFR;MbVI^L=PFQkd_h1_MZuB zR$Z9yG#98~gam3M%iP_uy?oPUTalLUOInrr=(O>+iLTx`+P=zKgYu2Vd}3-n(b?8p zV7W+O{}kntl_C-o-tZe`UIML<#@5PHFwR#K6PI7$ofmW$sLF&iC9=-E;jwk{c)Zjm z$EN`aEJ;Q`rW;-5sc-rx2iz@~&wgIM&+^NaaI=rqq$(?fism=#c3CPLZMh;WXoWQ9 z)unE2m*d^Ct|jOt>)si71y}4~N%AY2UWa{A7N+JDC6K%`S|P0{Q4^c<5StllVE!Co zr3V{HrQzly_Tw26G*g}H;>>^5OzKE^iX>h^0(0Kdn<#H24*WTv+4&C;m@}^yYRf|1 zo0#`!n;FfYUl`d=cs!}Y)$(pcbqEQnA#~ufix@U8js2D7E6@s6hiDz=vTKO6zu&MQ zyMz-0TS!%)@_s4hHPv}+<^3qqk1YQK#*qe^IXib)72T-D|Lo<~{%2Xt;>iNbLIT$W zNX_MvpZJQKd)S%#Q_1L}71DB*MdEzz{_!5-&E(JY!qKhO(TyY7MYjtYM%d_H%Wl^m zo2NDyCeE&WMY4TJQ+=P~2{+6?ZwwYwbXy%WBWQ)Rq6})(T&vk?lvuA_Cv{-#jS@f9 z9jna@4Hp5eUT7G5Lo1{e<+JzGw1|{Y(LVJ9=`9jilGGiv-KpJM?3Bs5Y#Y4SDO7Id)+(QjM4_WR$?qTXg*N2?lJ z?bH{!q0E0~)jIBYg=$G*W|3dgN=GZ~2W{1iYoBYSetu%VoS&c-(s-Zb>~(B4ZSX&x z#n_IW$sI%jb768lwM|RTJ{>CD_LggrTdj@yHj!-Y^b6Xxh@QsoYeU(Zl432^ur~^l zMzNQ+l^Lwem9N?E!?VQEkzZ&BHuW=9sX%h1RiTDiGR&UI`w`m3QXyBQBxi0mVV)dm zT-zYm@7|kj*)56q#px!r!g485oYc`$*EvLt53wWJzOG?wEJvS?V15-<)vVU{i#FwW zM~lqgE@^0m8Fzg0r3b0qVST0y7hRp65dz1Jj+Kki#%N$WQDl9ZMb;W%eYoyGuHG3n zoY_WB5zEW>kvgzkN^E#$CH^YQ;L2-%J*V3dqvx(b)@4+&mNYuq(3=FZia#%Co7$~6 zT-$0)bFa*EzO}X%3*`>!QRV(R<{goyga|3(Ei4)Q#pi?c<*Ctzd4-t;hRb&X%EkKR zzU@=jxYsE^Ud~78sOf?PmZT{0>*|VA^IUl9mLNi4HWX>8Zsy5guomv(lRR(evl=|(;Y3(Y4#P%rNsK`yVlWoF`5yW1psHu;Za~bu0_5XcoD{t5rP4@vgePSuDvxVvg2Y z>trNzU&G@;WxjSv(M+BcbWiOO7)S00<{c^VnWkM>B&={=dtpu@#GG=rR}y{DsyKoi>qe!^(=pWqPi=;{LM@q z?-UYPpNuE^_vY&t=d*;(@uVL}V4tM_QTh{K-LQ}inLL%;kJN8R8p%yMvyD41YP608 z-tx3t=FWfHHA%-62}sZ{Lrzf&AMn^>*l5gNa=xX@8s;AoBLx z2;9%Np~G8~_q3}<>+;tYy0`Dl%Zb|3^Z~;1)`mp3(F$p~iuFlb^CyiXg;RDpdTa)FXAE71DC=(k@%RquL-oHEX7xWz+ju^H`0w{ZOiuQ_Uoq4{dwA8G5;-UrjImy!EXX`kkSO*>5{2`i_ z`$InSaG#~472ZKgZ0+vILR#!#!(Y)aHPRbeAMw7VnvHBDT1cY~FVq&~UFQ1vp@cvyoYhKY zyjQVo&9PE``B{|c(4)7->aw3&J}ZD>`~o8(awhX3$o$2AMfjHuB|9yUro^mHF6J4x z>WEo6VFLGHAc1=}qO2%uL{?{x0S-#tawOTcoQH16pyLPoHr2vBSt%A%P{yNa~eM!6oVR409m5WR5 z9~cMg_GYH+a~f(wDvCDEOGIwZ;8Wyz4R{3!TzM#0Q+0gG7rnI-A5MD`-@s1p67nn_ z$YvCsCw_uld3NjxucKOv@0h2KRyK=%Fm_EH$cDDRP=&~L+s5B!ju)N(PSjBA0~IOj zy9y()#|ZYiP321R(rG8PyLD=b8|!?Cf4;PFvif0ZAS=nbprIEoPr@7~qgL|lf#7I= zay}=XC4;m)19WPkmQZk%x0@5D<986Pkd~Ra**Tis{NvoqHjI1|k-(CqhEU35t==6Q zas4k3(kdiSdrMK;Hh8SLdejo`=RAli7ZRwRr6?QC*4p246;kF%FA_C~OUpD3_87>n z^}480kKX&drVYH=ObjgdmD#fYSz!!-e%q_t=JJL2-9)WEPWq(ni;e7+;jDRxRDy=5 zWW|B)MsPtmv#(#dhJNy$9PT^GP1xOVB2Nj@_{1p6yaG>N_s)1RwS$F33ed03UG=>& z-g^Y|xmbAu*zJr!9#Jtutlf2mL?h4&X?ZSL=t6#eb&yy)p&IE25*XJ|lw6M;{Bzz{ zG?4qYvF+Gfq-D?lIn37$4i$@5UexfZMk}OcZaX@W4>>SF-1Xc_dX5CXD{^J#FN^t` zvT>qsTpFnZt&o{N>Quo6nh6Np(CXRTdkg|}Vk%=8s(v1F@bH(CqR%8zb5*QCslv>t@j27+Z3FGx+ zGsb<;3TZ`IeKyHh^6wxqKB<9R5AeTO555yl$1^N%sztCimU~xVv&L+QzMhHp_fVZIwNj_#dc%N{Q4TmY4%NR;d2{5_Ht~ zL@T6aWP9IUb^D@1W9_8`qA!ZGVx;A_{ox4ny(T-^PhuXKg`*YHavzJ`T=NR&U2J(O zC2-b{wEV5j->cfye5%$tPg{jn*gv^9{AjE>@y>AJ(ElYzzsAk-($q2EidB4LYeH{H z?#O)}YxWs7PW-aHi0BldrVr9m582n%+`4v@IFxpkthqxgq~+cv*GO%{?9sx{HJ{Xh z1l1`z`f#@P)^CV#?tho)$)FX|@*IAkS`L zX0l~tHI+@us2a;q1G{pcuz`F|&#AyCj-|6654YdU8H3U`OkFCtAQuqdNPKb+lN+k4woLaS^F|% zFxL;pkg-B5q@{lQ=v;oG-wSoe@CgEE^vNRC=z1;OyfJMALk~diD^geThd#Ge>yfb} z+lK^ZWu-T1^qKW*?r_+*gdc-%Lj_hEI;Qcp-LX?pCztO2YzuX66Qpj!+6vZS2PZ7`G&aZ~#qTV6AO)m&bx z(b>}D^g7)4R&CMzdWeo?A%S%(igmDyh?RHndoMd3pPxrRWE)5O1TocAq~Y@>_tAH$ zC-zTq6C0P8Yxpc7fhEZ~!zw$G>}@SR*tqCug}xen+7IX3iKJIn;&3e&^0XsOiS>J* z^V*Z$M9KqO;@u&EZ-Jcqq(A2+m)*oKQMNi>K?3`yC^PEc<@eqC2n)+6Js%uWVt9QM z!A2FGt@1|QNp=*S*13t~w#DR)LV}Lft(Nt~z#=!%I{lS~R@f3s96aSAidQ*_H3yE8 zcM#v;c@y3l`&aa5pDwA~$(pmNtymo3B$l?iL|TRARv}8h7HeZ}N|)jf4dXueo>Ri# z#a~!Gu@bE}(^jDs(lWz))3#s15ZFTb}m~W}$ZUeU$sp&otz{4!MY^Mcv7l0pD|^r54hh*8KE$ z4x(2?FCDModq>~eWw%@NMJW!Vy55V_fixwGCNAUW$D2jJ3}+Gp!0*+A_Ct(M9|p1M zbt*l`yJN<2u}l>$TDBr}pcT^cG}=eI`Ooj&ME|c^lRA(<&r<4H?l*XtyN&q$qMwdd z=-pAm-hYMM0a{B8FY7}*NTg}cTNyEY&ux1V*{>(*2U;O5cR@68=WA9A9?@yMj^9C? zO}rrjt91u!+j=v=v2=Qb}Df{Q)S$^_HG$GK6_Pp_j_L|dx9NubA zl#W*Di%_EJ+E27Y9`(fc9>F9!h&1*|uDW-3)(*|I7M*&8kvcH8NWG$j7cScQ28zhb z4kum_)`;*qxQ2Z7_cgp_jNTE+qE6amx6+mXQd zoT9{Cy2WjLBgM{M<49RZPz~(qd#j7@X9bHzi~mJtOK63(%-1@6Ce|F*#k!6^l3xRy z=}}_liBH5QA-b6P_*V_D;O`aw`Y8YRFHx6&B<^suHE)r-yE}e+Li>92E4A|PH){*E znui;G{Qvg$r{t6P-+vzor8vdz_1w0$N%yF)Ul=Sf2|TuW|6jIeWX>1NT3zgM+gPS3fqn}C2*|JYJt7Y$39K> IuRLq(Kf)s8ZvX%Q diff --git a/brax/v2/test_data/ur5e/meshes/forearm_vis.stl b/brax/v2/test_data/ur5e/meshes/forearm_vis.stl deleted file mode 100644 index 5210dbeff4ebba72d4890a78d26c9537b1cf6c5a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 648934 zcmb@PWms0r7xuRWVxXuXsHlVi7@&d(&z_+y3~WqnY_YKwRP4kq?CwJ0*?X{CJa#vr z9*^Dm&dhoL=e_6q>nE;jt)FYpp4n^ewRSw^%m4rVx3`9-#crOVce@!&_Fo#wwwQN< zy~s>GE^-u!+7qu3=yh$@o#4%! z$Qm8OiTISKdgGEqXwsmSrmt_3rA|PC-fw?^k*j)J%lw83`tk2U#;;zhl-77-)+Y!D~zz1u_*hpZqCeS0bddd;7fB8f5sMGVi;#QN&% z1ls&nam&a>rF3yrZR@_ae!bzKi;-up$z=>#=Pv(lRsWV%(}rvar;FW3ksi;T8CsC& zbihFuN9CE2i^tejG@P>QiR5Q%l|Zlk)7>R8IQu4#QE^&l>s@|SH3AL=Z+zk}kE+7H zxz-V_eC3f$d$`5=cuA1_`*(1EPSo5IODot6Ci#=oOtc_T+uv6gGk9sn7EV-~+mjY* zn@EOrIie8g6+1Xc&*a3x?fp36QZCk7_S*#7dsA#Mizf1@%zCwzBZ>15l}GihNp4Q` zt}vK3DnFcrd>v$>1&>gTVU!(2rT|oiwdWZJc~sZ4^XjjcmXb%3`|U@|yLS2I z-@OfcUDIaF38Y;-Mv*RMD^j!|kvX)KJgS+8KJpkdodT&_*#uJUWCevluYW4#mq%sY zA!3A$^(7zqbspt0n{0ocDX;Tye=6xeQnTf?-JpGWeQl$|^6&fis&V4y(_FMtvxTJk z@hFNGB#!pUwuqJ3gOumQqK3I>w~Xav;pe6bfnK?64ol*Cx`^>~%UFHjl(E$IO=;3> zVTwGeklVGCBS{{YWD(c=gs*it@$K#@vZ3}uGWkp(6YAgv&kE8nvc*y9R7li56YIPWW@7KtH`@zm4r1uY!!1b<3Tk&X=%1q9My`RV$Q3z%)@dyjUrc# z78ET=tS^7qD)R2j17gn4#pPo5la`TFE>Q}BUNt6WTg7}N*A_AI4R~)Xn>3d0y75Af zJmX-9qk7A)sKcT84RKWGuG#AymXtEYQGI$Xu8-(3f$VtXaN<3;B1H=lYkl(@V(mU! zC$5j89sOCKQ_D$}2oHrouWMsVN#gcu5o5>Veu@?()EFbZny_ToNu+$|0t$g%VQ0Mzajss5i5NXRx-h#( zvE<~LJ7oV`e;IvG?Zp<1rxKYpIk! zTqW0bioMpgZ7e%kdK8Hsal}Lm62JQg$)n16CidEy1G=$>QuHl{iMN_j65o#WLg|_mq*p6$SteohJ!q+=oFz#syIZj3_6@_J?hNR zf<&zZcexUM{|H@Duzm=$w=O5I!<`fYy=rW8kVM!l5u zvqg-aTZ%E8>#?MbwIM5=l_{eyI_+s-JeAn+%z&L!KnSBbc@#HbC ze~dC7&KOJI$Cnd&S7tb&cT4X|ky&lP5-`&Qd*9`a2lo(ZuGEl5<2 zO_4E93>WIPYyYk0kdg@`<$Q02K(FRDZRN;^Wr`TR^EM)12d*H|3kNdP=ctfTU-O*U zvI&`6W)vy%euzS#7uF1#_Oa047LnVF*jF^sg2Yu%54i^~B7`5+^3XD5rSlM4?(+h( zmfvA%IWXTxsRtH6$&{L5!zpogqsnHJUrm!qxruEVTCh$~mHyy%S>&I@WyF3>v_hcQ z@ynTVcRsm9#Hf1tlD?1U-BCO9nu#m3S|!TlD0oBVrLk9T&$~R1B%kTxfYQ9 z9qTc)Afd+a9PdD@@x1%2N&|&JuQEk{%A;B|NyLbqT7_2OmDk02H5gj3n!s8_(~kI8 zrEBiRk|#rJDg=7r(Q4Y8S#9XR5((t>zAq+LCRiI_Wua+zn?_TI!9z**xvvU=URW*h z{UfoNRIgt3-UXutiI}Hey4Zu4Jj>wOXM4MNI{U9glKrlLj`|!GGU{thTUU7`Jy32G zS@FeAAy(k1o~ya9JgWH-yDb>+>>n?AR8MrVpU)ZEnU<+EoP-^`P0)gb8lzpG z7<#eXBr@dZErmd@gGaq&ZQ#~P#AtcB30>cA6mfZ0fT9Jf39Lml?e(RmRKFHWUiB%c z5a@+Rt7&Dfl%aXKzAE=Qh+<`ewE z-*vPgVdwW#W{e9jO6t7wnxBujZ`?{~U>rq#jtUv|wWdwzRgffd&0uq9xI&;8)(o1K znDfAL<}Y63j4eyhg2b*qcY?(gMYF^CQB`!QV;DT|zW8ecnbhY_`u{Ab`m5M_$o(pE zzNjYt{$*pP7hXq3ju}MJf`qE{=QXx9Yn)zA4s08&5a^X%#=|Ps?!vAj#fnIpDnpUn#Io7y!BB`}Xr&!BiHG#E=rcHnC z%Ffsf-WALnp+RCiu4PRDqy!$Rdzb@LV-ud=n` znWg%0a_z$<11(6XF|L#z!Nl5~lP6Ol(5qhIP+3iUJR)L*Js-rx+RaQHWTFM@53FAJ z{*iYGJAWsZI5`eh2=v0E)wKIDZP*RofgG~si;1-iRufo@@Riu174x)PKuX^Dq!8$Z z^%c+UJ?zbo)5p@M%k~;*L4vCcdDgp1mGSA*Gt*)43i7*aG($y?ni>_krq$o`!gO>U zMUq`QC(?wUUzguN!sNT@NQ^UW~A%SLyY&8-5_Ip?5jyT4`|h_7Ui&gZpWZ8iXX+*| zwpl(O45QU}Ulc7!;7+7zyX#x^fh|JmvdSufUV~P|P82OjV8>I_cAVTzislQa z3)dG?2=vOmCQHU>x2YVD(WCrs@{PwB>s5%N1qsx1nwDDeI%&yc1eNPN1ri+v*%7KPHS#UCjIdKG@>qKi5Ivi}3$ajLcZ zM4Z-z(|QH&6SN?K+E>$>`xT(8BE#sZxHN@8ul8PEavWu*pW-pZWs;U+JvA539NTC?dUI0nxBv3pMuYIB+x5= zXqXbf*iyh11jCSvp#fz-m&{78(_(HE%c#2~4vffnHdPYTA<0mFd-Q zA#`eZhC-m%DSqGRLQaAn_+k@R0Os1tw9y$#WADr{(1L_Vbcj4xwSvX&oISDv&Ea!? zINwYIElA+qEKR#Q)t!3rb2Yk+yNLvPZ8=&;&iVD6zC4DtV|n^4F`NdUC}*Mt3A`(( zX*nld=x}}wR_ZlUA<(PowCZvki}Q+}<;1$BDCOgr)_9bO79_Cez$^Ov4s;wJ$INY; z6#~7IUb@LSUtUY}EWNkc(@A{J?KW&M(Sii_ICwW|e{Q;h&-n%K3krc=*lpqWop0o( z`5uSS8EsVpy?&Q-kRvZxT=Yd}cYaQi&V<_#uv{bQ2ZP^hjV*FD_PnsP@ zC;h0v(1L`jeTKgZHm)a!((p2s6au|EO}m{ga=XuIF^&f3Ya7#eZtwfhlc5C()LQ(G z`(JkE4}P7$sqL;1=rwFav{j6wzl%7#BV!7f)%d!y&v0jGK>}4K?~xttVUi`G)O)b2 zLZDafAFHhkUM1)scZqdnnz82gF(I_Ys9`3EHA=@W@zwa zqe7rp>$%kpv99jsZ00crx)fy%W^uh6yU9cg5?ERCU9DYdmd&q^QyHTb0=@2>tYe5d zkJ>Bt!PCWESU#@dJ}-(l(Sih4#hSLybYqpchTGcMO(D?he!CDkj>w-Hk8vc$!6|QMn%PTUWMdz-VX&`}K6>Eme`PkoAJi`*!$&=OM*LlI|GYqsKf&Bn} zr@>l*ec|(Q{Qdz866m$9MwoGjE15~DLSMBSR+;tL8A78gS}kZn0y`o6PPKO>#+FlR zSMwho3H0KPb2%TrrqGF-R=G2wsj~`Y>1aU$`%L_*ol=&Wd_KlhZ%U9ruW!ZuWsJA` zgx)<<&{fWP+Sf(|El6NrOVj#Yab&d`h0)%g3lsvqy1IGEIoCD{-EPcwWPkJA-tNwP zf)*t3&NyFJw+gT?eC_rao~{t+HSD>Ioby%ZMEwzIXU7WgYjEk&!vrlz;2nP6nHc@q zyv=j_i#PWb0=;_eu$3_?RSUv;}IMGF$xv(~gf4N{Hm zJhx|r@OldVFZ9Abil*6)-)*?giKJSt(iAO7s2%#}(Oaz!aSdtjY*&RqFYLBx+Sg~l zgU0PGLaRS$srF?Mel7`fAFk*tB)&Gj9lWY!qJCzb&{yA2 zCh2RZY#{A#3|0vALjB0^8_bR*r9ZABOMM3_IuVIq)uSyU#^7&tc#NvA8j|P^Yl+?F z0SbX$sA~B$MRA+R^?z29mbN_=y^BQL@zoY_7FKN*`s%OH&E)CK#U%Z`N}w0&d45zK zv&o#j$s{YLt)knJ@VIkHjwAb#&{sKSo{@Xq*OJnc+9(8iVO^qWT5u71;>Ajm;aOLy zKaeQ!;=5%VKUa$%3Vk*Aq&;2KWi4?_j!+2n!fJ`@lKNis!PAu_EwZ{&MV~Z3r#>ZY8Oa+liqCiT;<|WQVx= zfm~_hs1WFd^)Bxx4sJo4mE^1B&wHh=MdEC~>M}-&n?hfO_G&>hV%Cz0A5;Rpux{71 zhfliFVQ*KG>_6~3yl)54Up13S zwR8QIJ_8b+{X^t9y5tf465sTGRIIC%vi%eSy|61HDwZL1G@tXX?awJ479_A&qiOCl zhtVUm7nAT)d33aVcMg-IUDRw(Fm{kMP479FzS)&bZijdiv>-7tMwfJxzC!yf>^+#y zUYtz6HBbrk!cLf`bwAjL+E1BF93C%Gx@t&F{#Z-KI59zJpZ$gU(5X!plD4f?0==*Y z$op_DJJNUF$)sJ4n@X<_i5WIty0``#4;9+yeW&)+v-4VF{Jy3T=!LyZt}>h=Y4=wv z3G>QN(Sn5Yc^5ej%Mzh|T%Sf#+a-(1!DA|cUf7e>v=d_jXu#fiWKs)HiWVeZEwGg_ z_Sg#TGp2d~ZLu<${H>`3dSP!_({6omrG*yFB`FpQMGF#>2WH8e>3E=o@yiI)jeWZvCXKxm&Un@^IL9oLd~?>Z?2df`0_ew{B)A~&*D66)Vq zxi^7?7UeF-5j8_-A7|UCWX-0SKaczf9;)19K?3gxXcY)k7iB3+pRhrPgo6K2_zJp_Cm(3ld-QddV1l28*2dCa4i}nwm@= zI;sSEVV%fpoQEA*Xu@2wH}0lV40(lI%MFDRKdL#ON}{lRjdPa@^UQjak2#(3UEJURZ_m=Y1XyW&^g&BX>`DDfKxL z*c;$)pQR3AS=ZN+oLdEzu7Rz6xFN2=hm$v^VaI~k=jnr4!ToDUFg>qyKad#hA0p|L z+(I+huO7@+aShkiMJ3P+J1?3xVo-mU=gms8AhVy+#X+KQN*x*F-`>)l;C?K5)&_F2 zRbPcbFYGz-+W7K2}~A8g65w?#yjoGGVJz6#~7mccy9n zms_#iJePbkJ}EsiByL}DlQ9ndzQJQutka6EYP^u>?>;I7dSMTczc*1KR9+v&3zcMO zLE>cr2YIf-rKV2pMw#8?wIpzgqe7q;_A)iC^o`2Qb@v)_>TWfL79>Jme76c+GDg-0 zS58-Ev%9S&AsH%xUf5yf@5Jdv*zM;l$smuqO1Bn?Id?9}o#DK!aqi5vXLqM=Aoe99 z6au}lx2$Q6az8V#-d{zI`9&)|WF)ZT&F@ri+F}-6m`u{D#44TbX4Rr)c0T^4j)8r6 zP3zpPv3V?MF6q%}kkW5QqClhD>2i-+E9%-$?;Dxd!WWV^+u{@gz3_el@77+PX7s9* zOmZ(*t8FB(juuZYkE40dT-Ki~yltTJj&TgvE#bV{3f!26T3gfR8RPjI4Oad9xRQzr zM*{V=rY+bzn)$uDtmmp;lVHZFwhNWKTaTJwBrK9bF-$f?s7h@xIRb^z92|sL37!y46&jh9WfkZ3&_UVEi%V}P7 z-o2NW(7(HJdAUlU7wQ#FD|EhynfX_o;hb-rqGXUjjlz4#@$byiTMJnlrgvqy`VL)O zZ52lne8ksKS7K7mPP55!A8Wo7gOrts1g_J`s-0KW=!NLhO6PksN2CoEfV!n zR%&1;gx8!$qsWW2F6NlV!x&nS7}F}+B52=Y;(k=r`6#k-T8!!5UnS5BJ0Y4@IB+|$ zUg%Ry^Nsj2)b`VjQQ%W zN}v~ZDR>{QQCpgNsEc{`#5WTyNYtKOUB+1SQtX2sz1z|Ob7Rb>_f-PDuuH+8pxfAs zigo3FJk3N464-Olw4Hs1(1bBDW{YJ*OtdVH3Xw;0zi=)q<}vCiq@HyXfr65HM5?JGkcihht z_vtaFM~xm7E4r%lrpVRjm=j)36bJ;)(5}%reNP2r2p%XnOjbuY+#F*{As|0$X0_19`%@Ed& zud7eyU=uA!j7_cc?-AY>J+cSGhcI!jie69&^g;!wY40ZVVh6T&G1-K46D>%fp5xt* z{B@bbkQnpz9v6lx%*)k59tp{x%RuGCuk-O$*_nM^Oo#0?8CsAik^hIB&4(XFz8bf% zD!Vf~#iA&U5UI;aC@8Vs9 zA4&R%X_@BRzI~Kk4GG+@H0|=i6!K&DdGpAljtqAh&mt+7&jS&6; zi<04xz}&@i;<3K8`iAr7sp)Sln8{Xl43jYiKim_nDzXxzVyWlYOf%PUcY+oqoKtj3 zzqui5PS&s&wPc((UEMqs0==-~scALl$I!?Lndarg^9WjyIN<9qV~jg3`l5q6bf#@j z@;R?CUm?&7^RamUqXjL-=lrtIB_&fMfjOAp|M(q7FC02=M!(2S(URwfi#(E;Umq-L zrVf8vmDZh{Y0g?(hN1o)o3|^vDg=6AH>7jZd66$K|L5dSN$G)7pIeLiF*Orp?596fH>j3`~_P z-!`xskFh8FE4h2@yxGR5zCxfE?p^$ypUwuEkd$d=`LPe6A|a)r-Q+Og??y44&CWAL{} z@?Xu76En>ngX>YWAkpZrEO~_Nh?sM?=U>f%e9p&u*H;Mi!Y&@~){b;ywRj8?Qj?+u ziOc6~4RIFU`-!!?@Us&u#jo?%&s74wu#2Z@wLVv68Dsdlnp=jV1&N^irDTjAcg1}J z%hYP@3LnRX3a$!)Uf9Lcw919T*up89X6V)26fH>j*7K4vyf%nlUy%c0Osw6UmU$Ec zy|9bN*KSM;CUW}~=Zgd_NSqGzmoYY^hoeNl0~LZBDc)O;l_i)Cy1HTY>&Ii=c00+pJk#eHnY zviF=fx98t&qI#LXzPce+c+(17tg4b3y*ZMdPRKNuxV|*ef<&<9F6o1vMExj;e{DmkhnLQ*XMj3g=>m;^*h^MB7Vz)*pqWA zfnHd7^Qvg;dosC-_`KE+6D>%r>CWqO9;1dX>hn$GK9KJ`MuO*0g+MQ?yfv-BdOKQ< z$Jm_jrim6LI{LcH7-WRFKE{_WK*czITv7@2!pfU>io#3MP#$CKid`mJkoa_s*XMj3 z{=VW}fE_1G(vd5JSRS%lA9h%2eFSyJ|BaslIs5vRzUjz24rJt#V$a5a@+E zm{+NV>(B-nzHG#lr;0*G0<|x{UwgupwwxctGBTGERFm~$c@Kb(iK( zX;ZIYrUjiKXh8xsBY)l}el8idD2TQAryNCvwBiKs0q_`wM%K|$6Y(dVMicTvtXRW9nYSNgNdOKMQGO!(AJ=8~yrB6L+Pt;pGZINl+EUHx2>q6G=31MSnr z5$-uH`l9>G+Z`^o-k&vzt)&p?h5ZNqC7nrUjGM&*SkGEDl+FYae@^i(1&@*NUEFuR z6nVnvmlDL%600f%dZD)CU4u1Y<{^HrCZ|_YR38$k=lHuzn*-RBXkV6X$wP3ryfrJ# z5JxgDVNbBSbJ{r6VLr!wS!(G^7PKIdGl6$PINkD;coS(yx$11^`XF|w$PtA=FYJV9 z8rfBqt>V{56TbxpT98QJ!22aU#<4!4Us9!EHCAL*5SzA4CD02yA(|F2#EWI}825^K znP@>`uU(Lg(X_SrdQ!ha-mLP9AQrh)CD04C1XmepZmdbwVAk*2Bt>N)f%-wys>c;z zWI+(CTlA)hyZ)9AykEjco2yT^ukVvrpy#p66euCSQxvyitGPDCpEc6UDFk|J{FN+MI0hvx{fJKc|>zL86K8 zo#2RBiF%_GA$%NvR~Ty<)@=+M_9{dn&#-Xxi2yj_~sv zp?!|0l+mZIAH&`{1u6u3p{n8UT5e9z%P$?nMjr4p(SpR~>usgK(m z&psa*su1Xfnn=@Pyg%y~hs86xqnC*mB({9Zlrb*f7whWffg`%!eGJ=h(n}%G3zd_m zCAZ4f^YC+3?@Cz{El56gdZ9Yw&p=kP$a8hZW0e&x zNEG#rkmIP|Y75^7`*#i@mHC`IELy7&=!NP`({B6hdqxQu!&mk`r0=-n#SNP{dGJ8xs?%N9FgN#-dcRy@bt~`9=$$F}%&I$Dr;{?uR2$C5DdrKvYNcu1aihv(U-5a@*(RnxZi z8cbU9yt_B)Zyha2^nKTNMJmP|5OVaZZ+zb8gAFs-p#o zaz{$ZbCvIv_{!$!LD8f?U%QzNix4ExOI65=GsEP$3f@|spaqGSb@R({S2_|(}>uB-lIIw%BsVa>o-Vv*~*Bi}!~ z&%_b5AaSyKlFaQsKgBbUcdsneOKlv(zNJ-B2=v0rfPN z^1a&_w&)#I2=v10Mbq+b2@jkdI)*h4j38)1qDh0*y6Jxavrw;?@s7SC7hV?z}J zy|8-W*Lkix*6BQ7Md@J#El8vk^RS8-KemZ;HGJMptJnv>&ZY{1URb^G@8ax@HpKps zf4+sF1&M<-lB{AM-0(xp`Gi+{k#AcmL(KWgVxM?UtlB=rOyJ|Nx>Z#O^ujJTf1Y|#J5%iES0@!C zXhEW4mX{ny`Vf)ZZ&mAL7UbjD@8PHr=!IQw{>_1i!RAXoj@fVi*3p85&mVs|j`@Sd z9#v#TZ#f@NPd-ry^ujK;rakL3$}BV@p0%&BPDcw8edCFYaruDQYiqcTFvVV5;z5c+ zpcmHdnl`ywf?0!~t8=6Lb+jPiR4QDatJ?13>&At4)iFc3_W9}j!GZ*OVRg^n04=PW zqBdyz;ZiVKkjU|kkmDFuSM;ySHV-vR^6UIY^i738FRbqQ*A@zfn?n1{d%Dhw79_s; zh01Z%KP)ud=zw6e7~jwBXRlHS^up?%AJz2Q<^`^Ir_HQwpaqHG)F3&IS}#QpVE5i& z^AoQ>#y1I82=v10p1(@=*xww<$Fb(o0s}2b3>x7pW87y#Q$L+P$PD89$G}umA<#?h z8q{>|E^|rG7I{pxAaQJ+yF6EBc~KiI$<@nz!E@q~xO@tMUf6%&mGIS;<}iMJJaDU^ zbS9Aa^U^_%W795C<2aNlX{L7LHRpqO2L4~@g^ZdV(sRWOtc_D z-`|m%;d+MXWpvY`4Y5akC_hdi&Uv@gNCQZT=0=+usq*!ctj1Qm1oO^Y6 zMK-2KvDaf$O{KzyPQk9gC8?C7te3V~h|Gl-1Q>#o?(8|}VNJj^JT@Xr~7)fQHFKC|-6b9Kq4 zh|Zrp`{_&zU2kp~ff_U5R9jiDJXGq7*A3jEDade^%{sH_|)Xku_gaUm?(|(jX6u zxISLJ74^rymCfpOD^Y^YzS@jpm4a2#bngzqdygjS9>+yhbdkj#P7H8hao1WX1bVGl z-XVPhClZed;_0m0=Fr>DEVFtwipm?6e*EggR&i8s1H{>_XM4{~zv{?3*!U_0dcASa zkwnH$F@xuQADGieMzPD|&Jt9`sKVP!&u@rpaGIT1S60tbERpA{zg-GYB+zU4i&BQj z45eely0TsL)qKY1d}Gokf~pwlYk}S}#{Szvg?om3GizICHsV`0K?@RUjG41+*gBqf zdk&gRP!(f5{Ga%{WG8~y!v~J+L%xFwfnLrH0}OF?lV%9r{;N-Mww8}0*6xCiiWL>| z#4+J=9O=`=bv~_1BrEclGwXG*C_xJn-H#C&qn07A^PBw$dwbWJ<;qxQq6$OR_xeka zBqK(8jP zFU#v=Y{?4zsN$|gnj-JU8ig3D7gS3R%(ikId7Fy;+<)|Ja|)mHu)jhWT9DY4l_F!r zd=dM3*<+Q-CQjsN4H>ElQ~;<~_#21If9iq!>^cS)Q*;Rus5tn0+a->YCp?R8*yhjh z4hCuzRb?#haEwe_pQpqfAh zfO>^rwS6efb=QlXe0bM_79>z{XxfKz_SEF7B=}?y!EQEcl&>^E7uTTARFUC|^DlDk z&U9wY?~Ea6K|<9f6)xFOG9rrgnmCDIHyh*O|HPlg8Er$CrA4s|l~e+~t^|3@qxyA4 z>;|Py_)ynpjx6-;bAlFBFQ{txb4x+r$bG)6B|X`s5a@+R%irzWl$Va?8E)qNvWi+l zC4)MM*PIKF60zgFb<85Duh0vX8oxJ@{X#!o#DU$&YDQ5_paMYsz&l>~v-Icpo!R>6 zMhbymsBrj`%ihn8gTIQieD|AD>};Y&QB}s{rCG)@ex29&(^w(UYw1uAtJu#EN$t~l zaALrjt|OVvhj@a={*?8qQTne$eJH7?Bh8=KI$P27p?HL61|`WXdatce zeXi%am=`-bvnyG3DO!+Fk7~q;t!B|2M`jzND+FpER8CxF*bU+D`$RE!e>;UhFFacQ zgwX0d%=W4?bKmMpQAeTnL8Zp;Z5KUiidyaRvS$P>NTAlzv||10veJA%&j^`FP=BCK zK>fhKhukEb?YrU3zT_FN5a@-PNYm18Ix&1{JKmr+L;Zm|0rdla3MnVUjN?bOGKxP- z28so}P~q^}VEiv5byO7FzOWd>9to-%)GNGJ^YS(go)g#1j8q8py7>K0y2uQhW%r{= z?Q+Ilu9o(PIWp|3U_9)d@OMny1{x>ZII~VQn=rH>p~h&Fzo`CJZ&oMn<^+n_CvIYg z;O*sC=zTKk>X?c6{*l@(Xjhh-<=oE+6fH=!bnaDSZ=n_XLs_fUJX+M)dEU*EbyX$M zOU-cY9}Z6|-ORHNzv3xckoeTML;5Oygem+8d0+I1z4@$3-8xMtj#CKqQZwAJ&h}|PhjmA}75}$NeUv`iEWfv2z~8@zD+GFBcIEHgRhn(Ked3c=sl*5+ zCn6CzI>{==@w>Cg)I+vpnql+&(hmOVqY&tYIhcQ?An>wjb2ReM-G_ZBT9Dusz4a6y z$G=?zmz}n(OXE)gznt1B1bSgU=KIGETb5Cs9NJK_osy}M;Cs7_QBU?V`e%Bwt+Ag1 z+CJmo5{4QXy)bL@zG$*1Gj95%)pm-YXh9-*UVcNIs|xm_cJ7#$u$8O)((cUiQ3&)> z>)I`WgoOlrIP`Z5ABq+vxDGPJoS&6loL{|LvDcna0p8E-6#~7m#^qmJ)LXIeM?PtP zJc>}XAW#Bd+n%sTlV+)LL)mMQD7Ye2o~u@W#O@q5Y9uQ)+b=C+oysqiR@E< zoYsk@?#>P<7pXD)zt9VJYW@xE-JRHkyFO`oY&C`!B$9u)%X4*kvFMkqd)ts%r~0L} zjdoQC^iq2nTQeH6PO-MBPZqi|v>?H2Tsh~LWQYF6azDm;GzvHo6Q~gAg?$ZvXY`66 z+w#jNt-V*E(qTcOxb>&Z?WaD8`utbA6I-^zFRkF;O%(#YP<`;X2gW+F374V{QI{x1 zmmt9_dYO0s?HU|u{niX?^D3Zf!EOqHUZ{h3pE~-j`6@j8P-=m03@u3TE`^No?;VrK zvAay0R^I|X>=>jF=!L3=KOt0jx0(IMCvDV%K}r_~iJ2K~<+)njM)ZB&1h+BMH~OXR zX+2sY&`VV^8FsDBJeDtqa>tHlXhEV>?z`zCxBq*0=}enhhO>9QfPZ>UPzdxw6~+5- z#vDV;`RXnc6n%w6#)#=P#2#h&DDHV_^=Ii9cKfB-$EpN+VV8<$(eyd`4Ml8tqu9iI85 zb>-I}T9Dw^qdZstUgy)7`O!&p{L(zT1S$l2VTX;sEqld}4jY7e;3(kNV_%jPkX%7m7xWRQ@`A0jMWQ7-nC5YL|Z=y zPm5`!DFk|97mx4Gshw!#q2$0@@BWx*LBfBZuP)Zr07ub1=Q|EYd0 zD|beb;QOu|M=iOZ*Pl0@er0}XtvB5=@c%+D>`Y3{K=&I>4&AMF!$1oXTz$y-SYZlX z!j%lweR2Z+iY}@nfnHqIz@5?3oH$Z{Kwy&tMRl|waV{ZTZ&h)Ho-0AzW9jHNmd+|| zS7&->KY|2$aWyE0_lIGnLtCD2Rla&j$74@U(BZsb}NEl6;^D`Pa1 z_ljP38cK7jgrs@gQ3>?IjwgQ#$vKuTpAnR{aKIy__lg8|KlwBCj6fnM0T zmP`?;D0EaVj=66mGgcji?pZJYW2kf>78 zf&{O3WsHCC*Z$m|NuF147Vz_TAB8|Kya&y{u@!ik_`mW=+rn#Sv>?IjXqnrcyNLR{ zFPlxu@%p?c?*SlzUf7N1&+k^AO{Vxo9CG440JI=+c&LX(!~9Z& zTADaf-y)6*fnKNv_%rnGMfhzpXBy;ILR3Zn=?^5HymHXRe0?o=0F;X!+zairZ> z{xOk2FH{5k$^H)R)G?1UjZOMxq6LY^r`=_Y=~aa;aXnp=MqYBHWtL?s1bU$w5Z?;( zqh(9F(7DbRO|&4fFWy(i*q$z){`}z)LTg4j(P*+lAE zh(5!LGEHgehECMad8dH{dZ8NNPjcLEOt)yxw8i=DqWkfm{y?IoUxbXY^pp4{h9}bZBasE>iNX+={FJtr=A$nw~p22j?=MuDP{2|dJ`;R~` z+^P9pfG&Y_d>I!y`p15P79<>gc*z)FcZ*)%lQTXvhF^nWS8@~ry>O@IUuZJD>D=EX z`L_qN30jaizNM6mamh(MAr#ZF6#a0@k&?d)h@RztY687*r`ELaLC*B9lMDUO#*U%| ziK8v^%NTv1i8mCc`sbs$8#&Q|N6RV%df`sZEBdl|Y4?fF^v!H}w@+p`p$(AuyenJg ztCVr#X{wy<_ejkjC1{@+zM_ZxAAw%DQ)^o8tGCH4p4)TBRui3VNeFF##FxOsGDg@R z@y1AKkFBH>zXmV+5QRW5+^KoRa&Z$`>Ec3L4iBbiL84&QBpIWco6rn9`}H8D_?++c zudfj3g*!FB-8~T6?*lPnB;XO<2AwU_uQi`l@M^LZBDkJ>fOZyWvNw^Ss-wS`&&E zBs}wVOc(pdK>0@1vK!gPC!Vir)KLla!n-H@+4I;e<2cW|o+}zrv>?%ItcO*^SSQ~M zo7%6Z+4D~cI^MN`LZBC)S=o*H_pVI70o}|=#a-xupt@oP|Fa;`yGN3Y;aXD6dG+;M z%vgSXI7R6SfnIp`gx{~-zS+FP^X~Ub7NK5czfIh)LE>Bg!!kzScVg{MD{{~LaL$p| zE9k2b=!JJrxB}>N$E?9)Y}vxE54c~0#QX=@GDh?M;u`!qC?7k^_w)G$$|(eT@r+^o zzq>M3U*}=TT$i|%cNLl8KMN9Xd*?UAe0<3ixujL@Qv5gdoM`xt0t$g%c=v?=CM$pX z+Mn+qLuT5Es>l$vzPMk5L{3^M8Dsc!@doIh1yz_gzs^JMKPO0_7v4SLzn13a&GPZ< zd|jt3f)*t1*m=ts7siU*ek&q~z2b8oy#1g;pcmdf;qP-+tHX>RC27C$2MAh_m}M6r zV{D;fk9sqtE^Ewl`_I-B6#~8RehdF48n19xjOX^;-NzBMAW?oikukb77JKbFyT)uL zuMHY+DXI|Yg?CT5ZlBzQy&vsNR|h>3&tJ*;5cg8hs~&$7NzA$D2Jsg5d)p|MtU1z@ zCp$&=<39quxL%Q~=ifUfHujBKCf7c>itZ3f=06J(t(u3*7*%Y=^X)N38?qx@!?j=A zK&Ykv2=u}`8~od!ck46p`)!RD)EAv1L)2vAUJ4SGcL&KB-I|N%MT?9LVM6a_ELfos z=!JJS_?ObZ5cY-V_J#eHi;kC!A?~FhaiOoTj4}DG(4vlKYqB7o+Z!#&R0#CKI~)9& ziH&}2G+(<{n_e{0f<)g`cNrr#Mri6bkshoVKUce|{ZR&Cy_Se^0CbY;{p z6D>&i+;)&Lo*omm!MBv6Y%jkC&s!Zu&r;60xch=$ypoaG_1{kF;?ef(3fEV4Q;Lg< z-_;rc3LUq#Uo{_kE25<5O+ z$`}nEidt>?KgZ0AysjJqO<+qy%Z$+EKZRz9FNF1 zkqXT*kMi99tXQ~0pcmfR;NRJ{%rdL-y!-lc2tx}J!GUdMjOvb}c0L_h#cbZ#iLN^s zsSxOecQ$yx zxMqmj-B{^NAb~vx{{6(u7MjbbLcg87YrzTv>yOUW!}Vq@SLpY)iZ18w0v7&zUf#6c zt2_h=^ujuU>#OBKETTpgx_NYN(HH&i9yaF0s2bshi1G7+_y)v)z!3KPO=a4#=RFG& z=!Ln9KZVq^l_hN97}_$AZeqs+N9*xzO{!;)75azgp?n4>opIF{bHct!9TWQ(NS}AU z9W01ddBvMy=k_P)Zk&iLRKvt>1k#flM_U9@uY-7B?c&G1`b$n2L6uDGEg+ry^=e7z zc4F)+aIGE|B4QlPv)YQYhIEg4A(AM#NYu3@ zI_pH7t8xc3gK^f7)(3=3;^qjkpts~ite*qH zDg~>eAz8MP_(O$S+LCX%-h&f)U8@qT2eAH_!BXXXTniIA@%heX`k+N)Xc7AWg6baW zrZ1*QqCkvzSO2fRS1e^XQFLK2L9LE-_c!j6xSbWk37_bz*7cm2H7S^&Dn=FFXmiI@Z-eb?d1uoNu-C2OqIle z)?(zJmwhnA40bK;Ku{YZUEye!B#JyO$BEp ztG>4>=Db3_O*(2Sq*J22B+>b^SS6o)YnkGx_WbIjWBx#TdwqXNG<>~?6JMVOm=2G5 zjH4%lF=t>d@f{f^iLfw{;d-2_ZQkZFPAsTxpt3+E(|1XTB)WXoto$9t<<%uow8R_( zbpp~0v>-|BxGTO~U$C&J9C^Q?hJm|2($CISmqeD0P%?LBYltf6rd%mOd(yn>3X#tEA)B&!Yw_&){=h@ zNmyi2bBRsQ7W9!rVLHyHghHTK*Rf%`SXc9MiTxw6SY>jc!x*|#KW3aO<}NMGX5X~d zJ5pW#Jv?fQ6?=>PS?A)OGQ$mvG86*6u;a;9c!5Cbw>QCTaj*i7*lep${JBCeS=nCy zi~kbhjsq+74L0`r>2ogfZ;wSHUrq3BMR$_|^y@Z9a<@#F?#4&Hpy{4qJPYbL`i*N% z*On+u!!Q1AQ3x2VnyC*zb0*Fm`TpVj#@HL-AV)heB9}5IbsPojITDe5LJPi72=u~! zvZkGyUYjn6PcrW|s~~>EO@H}|?~wd=wec*dBlrKiF&$oKia9>0fI^_x;JjY)T&2f} z>-?!i z)N*?Iv@FYJe%_|;DyN58QY~w~tk9o1iM)ICNn_eCVT!rIzX0XGYbV#0_+@lF)9Sgp zxIKoJT$x}lyLLzbN8rK61^JGkDkHY2a+C^O1Anfj)(2xUf(cg+MP5V>(v*p>W}j_sF}?WN%{?ku z*#~j&Q+KtzZ zz42^=b$2$KPaDf>k&*|+|E0|gHTJZ+oi6^~zJVGw=D61uBX?LQlzwHK+4{Sl?xt-8 zl|Zi+!*2&awM13`OjDOa~!R@J9im=!Hnl~ey(H)TU|Rv?st#B z`Wko^)NurFUu1UX*T>b8O%(#YaAq}aTY=g73om!Gd!@0=ztQbrG1|v`w0IWOaa`!u zf_U8aFfV(@D+GGE42`ykbCqTUXFC z!<`<_f;x_e0X5ChwF}U>CtD~4df{yHw}cW}kX^l;&A`?1thaBpMXdJee6{0QP{$F| zb0=9h!QK4TZ-7Fe*Q}$fE#itQDSyYk^}Q{m>6=9J+xp%NEl3>rx>{y!%PH~8?@{zQ z(dJGwFYall5a`vTc&5zlWnPQb?)cco6g$qxq=Jfu!z_W<3hxx1j5Bilb3^U;I!!TgHK{jQBd;3~oMrfI_}{zF>rC}oDv>&$xIy<`#V>MXxH z@hqt0Xg0V6eUit=JnYd}A<*lt^AEYMd|rusb>^U(F^7-5(K%hwiMV2LO#Cg={%1%f z$HMgPCJV#dj`cDgEng*#f04ff3(&%?swo6|9e(*;j&Iasaa5CIqb(Wys7}reRk{Xv zrqv!m+g2rLYp24rU#bg3eT6&Dmx2ybuN3?ua(mbXFPc+_ul6l@6#~7^X1Yo3Yvd7M zm2dCvri&eC#}i%YOyH5IeTI?^DpQenH!sY?Q19XqsxfRP2GIn*O3Gy3G?73r?>^P# zuKn-trM5Q5(A5>3%@^0MEBXp|>Ry$?b&*S!$B28|MOKWZFFL!MaYqa5xa*^rU(GP7 zO9JJ$KgX?{tUuxN(Wgy-(qX~T;{5Y^FxEn!-?gDD*=-ZmJ&s8o$ME)zXr79BY2A4R;T#o$NYLj3meXtNF3E^!_G|mlxS`&KAgrjc9)uJ*Yptm;MsQR zV&p4ZM(F$()2%|oXY3SRoX(Y}neP&(nd>SIPzdzmzvFHdbz-8Ii1FfWdD5THx#i9+ zrALOdsdl{L_(wT5^4*|QizXC%NZ5g5t=#1}zRwm%bz{&wy=6!NdZ2Y{g+Q;CciRR3 zZ%lkYFLKuW9?HKJp4pXlAgOZXgZVF`{}?|-{@po$9i9K?y;W3w{v}1lvbtg(_Sdi^ zb9{^D3V~i3tFxpQeY8Y$ifT72N`$_$X67Q;t;5l(ox9AB@AXOiIu95bP0J*9kX_jo z2O@Ml3+g!bRxDZjTIIZSG3}rb=!LV%yUyFc=?VO(_UtI3^b_$&aE2&DYRM_{gPLW{gl2aUR}7x_`AsS;!IJi zW%@R!*lEmc?k=^j{yIciG5q&Z+wU`b@sa;(*gzq0UE!Gc{!z%nD%~z+9&O`AuUvF7 z#K^N&ey{{DvNiq>S6>}h)$+Z46h**5K*|Ee?hZJ6M!-f?3~a9*h+SA$Z0zn1L`>}9 z>>2gi-Q9tC?ap`2zV~~7&%FPg&u5IAgHiXw0|dmDy;W2RpfQR7|@Z z%Vb56zglH8ut!LT8(U63x4ZEL4c{_kAPS$BZvOf^PUb3=UzA-?n(faWX7y;+<|!e? zI~Z0^m+TX6uUj$9R8diU8>Olpo-$lsx*8C15xzVX;XOsQ=h7=`{Rd~ni16U z$qmYCIyKH1U3I>h5H73HIv@5yFw}8T1)lZ3f;EEQMylQ&F}I=o9+ZKnFFrwIAPU=L zvlTluOvcpn6Aw0BG~+5)vs#{)zHJ|vJ>2@8XK>=a=z4VGL+u!3S0^T4OAVBRN97lZ z%LWQ$AZpE$s@7B0+^Bqp7ms_(!@=W3JENDt5@f79P}Q0xm!Bxch>V@(<6+~)#Tlvp_HLn`XDWGZk*rU2aKP`WRbR$3UM7gnS_8!BUOSyfk+p(wh zzM5RIiF`;p+z!$#k%1_j3-qR5$<8u&*f`-7SK%+?9i&a)WKs z+Y#KnnVe02d*8ip?8rbA&KNpHj*FJvQbq~?{dcfpj|YqU^$6873xKZm% z?h94Kh2&xU#me*s&I}xXJU619dQPxVXQ@wfWY8Fh!gDH{ty+#LX3*DSqV=g^67@Dz zqjdG-^}t5-t?is**o46v15v1*P*jv^lLKf^TsESt*1M>x={m#s%D!|{k2B9PIJd?? z6si<-=H(hH)EYdxc%=j963!SrF{8M)tKB-gyLEe+#y}LF(9yRV6Z)8=X$|&%A1$%9 zsAcG?LA@Vkg%8d4BCi^13`C*kK&SfO{VYvp;K^4e&J1iVp0Uwar|sTCodT?8&ol<2 z@Jx^L$Z9>Z)Q`)5=hjZ*uy;`@q`M|7_6Je+8Nz_G9FVR(;adS#@{}%a0#fud+2BO*oco@oOaC27sl}ndh zW-!^P#0JH+6Gdz-Dur|s*RU+xOy@W=zBZH?_hT>U?MQQ;&6ZQk?GhGi>v{) zBc^nIk@<2(i8>7Sg5Hi2C2R4KH-kjAak(@GqOL#9XQ>}MGb6Q-DlP_ROUdQ42H%)6y^ue?&Gt|7(qKj`M+$fwS z`cvi3RFj{jy*+RF6#{iJ)cvsE=)I!!QG9!-8z04@GzOw@Z0H8eN<;X*8x(UMDy`)j zpzfzDce66~;o-T9i;9JoXbeQ*87p~9m-F&sG!vIc8(QW7o(W?X0ljgUBY=1LI$FG& zV{j}%25NHTg*!rdKGJ<^O>;0jqeZPxSJh6GkD#~PvhqQ<$7l>hp(aQ7M!hP}GSLV+ zCCoH&3~+2PZ;ihCmQtFj{Hsqlr)msDVa5figKnWrY2j6FtYGQBTxW&gU&B zv)TRgi>sgdX$(Z2XrJG%EN6ZC>Z;8)eOd_1uzna{T13y%!`ABAB8C3S#{(#K&N|tH zqh^I|(%bRwRTl1EI0OItLt%}9D1BZzT1;m1tK=7hzV_p&A!0A+?RdH(GduM%NGz>6 zRAV3tH97iP!45YzY|SWfpz=VDCCIopqD6?(WV#Ji`EZ}oF6`MhcAV%`V6etORAJFF zRoUSB7US6>SJ7r-f;bU1fMW@&wW!HaZq#yCxy|>{qUf-}8Us=Iv}Dh7K4qB-X5h!1 zgSFfboF%%RxcPNvzKce1&X&5Clk>~E=W~jf8=)ElQK)Te_2ytBj=hMztLurMXTLSoysH0aGmU{LR3qt)_1mW{MK}v-OE~+mM=@)bEa%Q+ zRz7v5t$`W?QJAG`vt6`D8|oC`#>`+Xw-!eb^?kb6`C5ulo@LDoH94wWs3+$|peJ!Fo7fupP zV;~CW7)4&U;$-OCAh9r>tTa3?>bt1>k>1{*k6c?lzgWL_iN-(__95lz=MR)W=&9b8 z%+9fP)TwplZWeQu3+c4)*!Ccefhg<|I-3~3yo&NH7iJA(*o!#BF}IrDV;S9XkID*b zJdm z%!>Q9lLqWjUF!>PQ9*7Z6(@E0dyRo8R0=6SaZgiIok4z{7R+!rz!imAI&iZVx$ z`tjYllEy$3X0g-l-%X0kePa^D?8#*%s+p+up^8U2^iJVYMZ(MIyD!K<6zZjv`w zl~Odd#10dC6g5j!sVD<`cvIPn?8lV)Z|%rH6zYC-CR{wbP<_?+Q96dR4|^1I-RUg# zauzEue0YbS1~L$Zx%KphLa$Pms+JJCUdxEb5k$3~?$RsN*it#e?TrNHr=bp~tF;yC zM##D3^&R4CGzOwjv!-vwFQ_6`(@x#Hd><|A9-jpDQabrLR!6Gy+V1yP3Dj8e3H3S- z+&5$nvTDcpO^tyleass>dCM{N{Y3e-^&~2rs6gu)@|n9YMMPYFkx=&!je#i4zoY!C z-cQB6bK``kM@xw%$Z);B%rWKHVxz$sD=VzPIuXAlL3n-Wsxc6SIhXYHx_SwFKRboV z^G=01daRLwtwP))mKe9n)G*LTrQ0vIkEgv-I>3u?Y7G6I^kOua~KS*`|o$R1_) znT2zN$xRJ(2BOgSrTmgnyG_@`aJeu(h+_#dA|K>;l%qOccT(@sFFsyb}(nXQF!`JE#G@xT=cFc7tc(`u>=|DKa$ti^@^xmK1_zq`OUBd8Fz}fm}`-s;^v#$> zUhP>;7ToI47>L4?ak8ARUFjb0FnOxLK87X8*mfP|&3w>nvS)sDY(ODV; zQRr*Y2qxr^a%8w{`gsP!5@e*xidJ8ZF0I}TcS^}4-(?DyC2B@$3`C(%MEmZ`ymI&J zQ0cz3zUKQN<7ath)lu5|>h!!!1!c~JQ0ag4se%86DD<`HuCT`i0>#NChx2Bj_f{e%Rp;kK*|5PXLRd*DS?hCkF z^|pXO2BPp}obGfo17(#`;WD{YpuiGjc=Zpq>bR1q-kf|Do>NB54wH-X4AB^f>RrB! zsn+>|N$MS}+%Ep|`j&9{y2W6Pfhat+CZ9UNUB1s7F4^WpfhEYm(_W<-WS6yRPpsc* zt-umwjF_0;YR9$b>ZGCi7bp3h)ljLP$<8LmfF7VbE zh(b$EcbM*4B(8l7mlJw>NGw4{@Tlch9b0;wIRBA`(lGarffuUCMrp z@2u|ljeo|BUyHe1-J+<*Koq{MKz{qMVB^i2P?_1Su*4E%9M5zkMAcF7Pt}g1^D?t` zYESG?Kw}{4)z4;j)pCAT^=_-3nOV%YYVyTooq;HPFMx6f*L7o0j&hm4Zh*uRWZ-)N z^qrqW?b)`K;j&DLToOx=QL@W&>#6PzRP*X||0S#~`B!r)cxenoVFo^ZYir^X*05@r zd=~B{u>=|4x1F`>$htzUkMb3cu=UHT$@V>6GzOyZw3ogI8FYj-q`AAcy^F*WWVn6( zZdcMB(`-z0Y|u1HVG;@R)H{J4;1||ETT5QBk;Xi4xkT-R7a6W4%NS}V&ReG}PWLG}`AQn|f@woE2BPq!k=FUL zocsl?^MUDyYA1@w_!S#$)lvMhvV-l{1o4kcs>#2uk**J?(}=?JPkI-*e?C5)e7KX@ z@@Xfd$f%h$)at7pdzC%kKfe&)LVJ6=ZBw++K;C0%H901! zAwveD^b^Ium*nLs!C^AHbA9dYOk|8aT+yn-HHY$dUyk*qJKN~Cli#y72BPp?Qo2`d zx;MW|e#wmMGqtx&kr8&hsMS}`!<64%=eireJUdiY8jz|n5QT5G(ybBAUHMOn)soBZ zWmtlYmw)E8>L|Wc#UIJ_KeJ-^N~9}G*75ivT=+ILebstB@oY6>es z>%3^otQ;AL!gnZawxrW3>@e-^E4ydqSb~h}fqSevGVD}wZIv<;S>p0)vVdz&je#gU zkF?o*i%et(Hdd1hzWH-3LB{?<j zjw*W|{=|cy@EIb;y{s+KT4H@V<4lq#@4Kmsm?GOjdtd@7*pJptt~pFGha%5^>o5s zfmQ~s8t+)fV!T|XEa#`tE_~plIFaqM&Op@qg~1l%WFB=N*^n1~xbPlAZx;6vXl1ZI zozcm;A1||~i}>Q(S7RWmX;`RL$KYbh4)*LlnER9;EoygiHP8Z}#aZEB-D0d=px&Y@ zlpM#4QD5B-&7v_7wIzSJ#dv3Z)B9;m4_k*DGN0QqE5f5XfaYMEBleL(jVN&K17_Jd!6C#jP>b^?FHKL%sab?4QFm>3`E_# zQPipLrf>r+RIG54Z4?rOM;j+&lk&GxxH)J|P!x)U$-Fiuoh(nw<Um$ z{P6FSn4jU(&O!EcSb}nEdM+ zL$wao!Q;P+T8syVnu!Ak{z13z4G?AaUe_3i3ja{iVq`0(*2lYZ{bm1>qeWzy5)8FB zRPN?)rx2D^$FC zR}8fnXCl)(C^qQaS4O|?BYu|bt1%FDe}1sVm=&V-M4$T6veck4;%LNYf$9*dQ8SyA zu^0(=)xNv6cQaY*VIR@q%NLD-sHv;+TZ|^P)ZQLD#6#{#KSr#mP+Ov^f+~)0rXSXr z_x(#*oXbw0vfq~hLKdl`F%b3k+F6Scc}m&AJOh%&`V#a$C%xT|iU%qz%|A5rH@ca4FlFrVfYBp zCHsppYr9JjQ^j0vY{&ntPf3Vl2BsRy)k z+@yBgvF`}M|4n>Q(R#vy+on6v5280;ru4AJW7N*W!QK7BEmY3e^HR~1(Y+;G<-z=l zZJ=oR;f#qT$Ux7DGJXRGai1@v#dhbj8Us=2+mUJz*4Jvsr^=c3;s0LFQr&vO4F@uX z=>FC0-90Tv;g5w=N0$k;P(RORv!l19Ggdqpz>T8)MRc~33`>xK-Wlazee2JQr5h!h z*h**&M4|skhO<}~>#62M&rQYuoj)_!dcwIbL+t2l(M=)Ut6SqSy?uJe`IY%Cl)2%K z9la&pBl}XZK7aJFkEj%!nPUku&^x0q1C@>BhkuP0-CDY73`C*-XtUL5R^FQJ3D*K0 z_`jWxd00<)IW2>s`*7oO<+2!Td1Z%d*&o(!aO#<_fgYL82sswOSC;KB%qfgx2{O2v(PH&%Lt`KceN_5}LS@@OPxW_%1OIo!)3eqSe*In8K!1+DM8g^w<;i}$ ze;B}e)@g2`W>jPbdStrS7cpxB`yMn${2JbaV+k_Q1EhFxKmzMYzR#eCJv9cR&_|_N z5suv~WZB6Q4EO64Jla2>GO!FZ)%h&_20li4lCXh9S4K zN)4F1#JJX`#{b6%ejXoO!)t~(kTj6ve<1_&4(Po0Nb9QGij5Xq<_*>uh(bTu@{mO@ z{{1WG7>HWp7G!wAPS=lT7x0E?8!03!>Xk+QLhY`$TTI8(Wh#z54o1vxG6HT@eEbMAjp-g>I2z(5C5)|1SvzQwLbEaw+@w%W1MWlk#D3=5@f2F4#cV^D0YESP?vI6inG!xCg* zyhDEbpdQkfl%kW_bOxd@BBNN%(bIaWcIOMG;{TG}wVp7~xomblDmpcPu$=KVUR2uR zW?+5+MqxT5zC@h7az9Q4e{edsM zIGBlIwYbc2@)C`Ch~Ie=85ntC+(!30e;g!>=NLlkT1yL79wHRgLdb_l`$ zmCcy-B%)?zhaTPCJ^RN$bGPmGko^xTS}5hWVXUSz$~EgE|5`RkbZM|pUZMIXC^;Acvz6rtqCEwS2!uI6*>Jelfd$=q_-s{=XeiHKw zFk;mi#gfBi``dA%-7}qmD6gv?)|l^Dp)||CwlQnWx2(!x#}Z_a4{tr8^5N;erAj4b zzKvak%i!`785pHvj7#drXGpHaSuY&+)Y2Oshwgc3#DQwjB9m9H~TXY`eA_RF{YKo5@cW;P1(3(o{GYyM+>I~ zEj0$BFw&>k`RkexwYS%5+`*3jOYx`mgev~jqwUUl>Wk#>@uErNeiHKwFuK$ar5M$a+E*N7H>v?vqWW^*!G{ zSYjNGPl*3yvn6hrWRBR}MYx}j(-??CZGf`EoJaE9E7}G38Me*%;_73B_-$n!Q_olJ zebCGLU3b#$kPD`lG4xOh>)rJ(?LOm%4&u4HtugZZj%Cw-X0cFf&NsF@E@U-qhf~<8 z<#p}Iru}NyU>>$D-X8EJo?!_x^g7}@#qh8!z0LDmPclSdeI@#5Fk<$nutq&*kmcOA zyffdm(v$TmbcJCFGQPb|XFctuh6{=D zePe4z2X=5`8IC2$c=lkwRYy#rXT&(OHh^cCRZSc=!ZZe=3LV^FPoX+?Hu*}77PqtU zGGk-e0GE0kOOWx!b)HqnnAuL`Wr&+Lejuc%_@1|g#z2%u$uaiAG#)Qkxe()K_lqn? z#+Iz^whkOikntis(yHU9b(Y#@%^uciURJRof2_tp)WC+N>}RQtCCr2BXcaV?UA2 zCH~|$auq43{;SI-Ut{O+;C(A;X7t?QYap9WEw{%SnRYj2mtPLySb_|_j=^4IML_2@ zM&@CII6b>V)uH;t%d@&et*9D<+>Oz_&ZYieI~vkSMw8?5_WL^rb3|Q;xR9z=YG`}! z|Bo?x>^Af5yY|7OXuacqAp_e-C+_c$nt6Mq8RrX(&{~cRY?IBlXUj=(+`qT^kybB0 z6}G8bfvFC)Y8Ux9Q}326bVcM%^kge(^Qb<&(6^OUbxKKyO0<+B0ysxsBGzOwdEp|86c$B(Rff)C$_mhuq zRAZ|Ojb>Pa46+hd9Rtbw(D$R3#L0ENuba>IwA2`gO52#jRI9dNU6r{^t5voQNHXmY zw<@b<{nuh~zUr^+KBii=b?mG6BAXQYmqX;Uy-nHJ-LVFiAVaTX-q1lZq?2U_#s8lD z=+2rBvR9^BwXeq*j{p7B+D3n+(zl&<4VKL>#M_U2i5G|>+h?j(d!qE=RGmTASzc<~ zK6r85HWN#bf$gLF6AO%xKl-H^4Zj^Vu>={|Ci>Dwml)YUptrer{YmYquuWtiO|@z} z-Ja;6+t`D^yyoP zv8`Br`MPNcYe%aWOOWxunSXV$NsMZ@?b7XwxA?TZh{ixv#scRY$_`Gxdx;p6rj?O( z-?nG@rk9aef{g#I;Nc4%5aU^)d@^WE2~na{n8rX~`{f-!!tGi3yw;hXH zR8L|FGXA#{O{r%y$?F^NQ;b|uLJX?VQez-$Vev86&fru(BQc&lydVxe>%_(t>>#lO z8UNc+^R>)OjCg;aMo1;dTE@{Snn)Qh4JuEm_t@D9BbNqjd*k@bKI{ubDm-t`E zz$c_{d1UT2U(#$RTg(5owGvB^K{MZ~EgWTc`e6wcL2qo|U6jizUdQ6>Qa!b&1*;+Lj9AmqwKk$7uB;15tGg ze6qaPx(Vvsx#Qw`ynLT_>{OV$#1drCPGqR@$n}@n;e5)sNzLxn~QEI|g@6|0W(mW}F7o^g$?*Ub`pTWSnMksUPDs-5RKpRC#z@(<0P zX-4SxqXtHCWMvGs2D@&W8H|yo&Gw6Y=}dk}X5EKdm5s7$S1~6tus+K=^M9?gvV;FU zFEaF}@;LfVY)z`cudco((0;t`aL%!f@*_8Az3do$?u}#q<`mYx$~Nk&G2fi!@>8|= z`}7+%2BKc&&1fn{!-kWH@glgB_)jHrheH&jR7gGP#c6BkyH2k%6e;SF&4-Ri7%^>29hPIr#e-1~+PK z7AK4^4)tG8H>KDw6n3#tS$9paAAj}5(Zx(*&-}L%qv^_wJTleJBUj$k7>Js-^M)fG zF;-u_OpJFa*VvmiHF(2A?*#4)lbc+#tMRy<>$HQ7+-X<6+h^k6)K@O)H?nT0Yw={d zqY)X1>NtO@#fTr~L_6H8IcJQ+A#r@MQzy}5Ty?t|k51#Vn20KIxVl}9N4_X!)wa9# zG5Xz(p2C8er9+}BXqsI;pYjgFZ|@VIzije)2?yOpteJzIHqVsxIknQdunaG&uxc$Ti4 ztdc#SJ~+Pj8f*RTG}hV7cWRR(fqJ+4mXFj|iQ*Qk^)!U%&mvG2JV|jwqV>hU(qNjUtG#I;0!>cYSCKMA7=N7(DF` z&Gy-qyoFjHlg@b9X|`LnH(b9cmDY##yRgf|V4C@c8jpmWi)ckXmA>MW%Lu;qentZs zh@$mjF%I7-N`1AjyO+#Ak(d0{$Kox(Im}$Yrp!#SDh7k zJ8cx_X}_xW%~^|55k)iKVhld0qU{E=&Y07y#PPl%o!FIe)g5X)TAy+^5JfZJp~mCa z{OrU?pQVqvfcC53MmLRtsJIj7Qx#+Q1f`dhz3Aak>m#h4x1jYAqGrSXcRt4AvPJi) z-yb{r8J#O7g{VHKv?aLDQ|7Xdr~G4kZ9cEsP>q2ov=Zczwd^at){Wywo((k{`sJ|f z$480*5v5y$lcRcxKF{L#Zm-1-WFQJHgUxpGlLtFS>w`B>XPzsQ-Le_cS1PHsV*OUO zM7IVpiP>3QvL8P#boVD@f%EQS0Nu08bN9x;3~RwAjA?c^t17+e>303M~VD*`a+p zV-c;7YCfK9Y3(xB>TEwXn~AnWw+3Sib~M!bn6xRo#y}KW2Agf+vTJPf8^(*!xZLr- zV2wwqV^akA zg8Mw(d6wOi=iXhL*RL^DV;~Bx1my^~?92XL9|p~K%YJO3GYCZK)_~Rr8%675%*{o1 zWFQJHgS9?{S|0_Qr!#1MST2SG)EJ0D%Ru=hnKQ`p zWIx7Nz0PR1TYJ>03uhc?OLS{c;6j>M{xF1}%9@@d15va-to^FcV|6~bQfv~RXurx? z!G)ueinavzd3rmu@ldgBM{OSE>8UXgg;s*%sB1&bm9#$cM)hS%1y<|hdaAd9DBT+D zD6rU(=@pGf#aw^$pVvMIm2BOw{n`W)^anmx9t-T-6%Gf~r zRiWU_j1?bi?N^`gXEbr|(s!TdJ(n8y$$sp-mr-LN3U?yP^1k_-^_K>?>_a1dGZ`DL-=TWs>VQ+&x?$vvLCayPp0+ZnUIUGFFBVjENW;cSku34 zu>4);$zSd0`_h*ka%Q%6hD<@N3>raeH)#JT-hsP>zRtUR%P5vn%V(Efsxc6SYo0v) zI}b!P8o@GGelr>mYbI82@ydZ~USH=lg5t-0YXmtm5Je+s?F{Q`WiUxeeQTGZ>l{|- zdoJzVpjh=!RwTSCbc-FmOp4WVw=#>+8a(yZP0*NIYw%SOR|D6hzM@`rUuq_08o`qa zI%^C>;p(C_c-JN`wKsU5IgbR5xz+MV$M)KBb?GbW&OlF@{aFZK+#*$DAj;)sMnidh z3oFf{HCVK5K{|J^^ImL3_p+QL@>i4%_zQzd83j375;+4UlX1n{~ zNrpYa;7ckM*BFSxy^A8@U-3p{juAXh?kpnbplR0J4K*?u=po~dYO|GIzrrXPHi8fK z`e7miQRu_deSLN>nPk-AA8&SLxPRcjf+s2DQ~TwUna+jtTko@I3`FUt8P(PVur)tw z^L73GrC43l+5tNIdKi8BTuN1al_8g})+RdDpVx=2FH)ZuYo1wSV9T+6^v-DMGe&lr zB~v_}nYc&cPL%WYHEWM~lzNFq@a(=c`n>!Ie%)DU3`8x<^x0yBRY@VnA0ugOK~lLF=DEj~yYC+WA`g zOrCD@_R_skNkw_|TgERGooFBfQ5YxCEjN)_*iX+9Tpo`%E-uVy{TC}e*@0>lMlX~* zSmcLMmb@k37b`RdqEIih*-~G26$e7<@S=v7M4N%O1kV%bMu+cNL{*x*JALzM3`FUA zd&XY=eDCktJjM)QpZdF7yZ$2YYL0n>y-hWDXE!;Tify94YLc1%Md!hO3;SpcL}8EE zY(C?k8l%WMhqXFmpj|?HI^b)Xv^=8iS$FRIDIHwS7AL}4UD-}-Iblnrd@&Qoa)V_btMJb|FS zWJIyk9*W$%B2BI*gB4r}oC^Ok-1n+PqGoxL`vLAInIhnNk{HO2H zj%rt$!$0a^;KmhZZW@mzK0gd(APTibde6)5#WUFJ@aPX+1zLUd22fR}_c{IZ@juRn z^TZEXGzOw{ow{CPfGAD-Zm9wOoc3MIZYP|}YS!p;$+A&dFL`NgBA;P;AJHyJ^M7xKu~tt}ya9d`bc=J4M~eGN5tyN{k8 zJf)S7^}9~XMX9u-TDo6y+_YdiyDM2-P9ybv+3-mwG7yFFC+#KmvWPieBY6M#c$0P+ z>%U5vGwrByVpL37-i3acYTrHfY=y=^6zZ^aR`hpQ)|2+#X@(a^zYYC4+EJ~uyMMB& zj{L~Nl+Wsx=c!RJvui*)kOqW)6y2CRci* zF%U)juEiKu{~Iyx{PZ#2)7&*i=HwWSV$7*$I^~`@)?7ZjCZD@2P-7qpBV@`QY~NIP zwRh)L!y`Gy%ZQ>K^}l>Cy50E1da<;aJ2&ctX$(YRq)&ePyv8!>^Gr5n^cOAP2Qzul z>sN2yHMO>72 zzB}7Ka^wm*A z#*wAy?m~I{b6Qc6*}XLeqI6IHSo?FLqg!3Rc)$yW{wMnCdK|TU{Wp=WK~28!{V9!s zDC~9mwo~eGQR|eQ%U6N&)3GvEU!Az)X<)zUEw55GhbX0W;l7gOA@(Tw)7x{4vs_^TAU&MA9}p5BN|-V;~B9oo@NB z7AcQ*ap&F1)5i=d^xM(vx7jvUiIQVWdh#)kmzl^w6h;?LvKU_6M?Bq}qHrElcJ%7XGOE;Ew%RGb#Eb+yUBZksy7jboE}1fLCOi4b&=`oq{4)9$UD>Z9 z-K{w+|AR(a9vNncwT~WSX+ZwHR32IFJ6lD-wV5n*S0^ni3?na{;q}g&&n7+bk#~xC zHp#_`Oke-dvjtGUBu$3yBdgTP;P0Od(-?@tlO~FU^W?Jf8Pbw-I53vQd$TG*o z{kL^^6Uu(U*b>{MKh>sgtHpJLrbb*&1B;rZq;%JQJK&gWnXP(9X{Ck`qw%5 zwT)w~+@YtB^~@vk$Sn^*WhBuW3?IVUzGu9^?z|cUQFv-i@!;1S_S@uT)P0$gq1J=B zQK*F3Z1Q}0M=L&pPmlK27>Lr9BBvn9&XQ+&>djSw`UPq|I8HX(&|i<3Q|pF2|6NmK zAW9#Pl5>W$ai{G(*V{l&9;B5icX7VEiT$RxJZt$J^wx{Pd%r8JF%X5lKzqr9cg+2D z9Uh+Rm>|#MKkWtEq(9Yczbri2qbC2e_iBxSD1FQym2=}`+c)HOmbDUCg0lqGFVbYj zd|_Y5HRd~-&(au(!l$L%LZ25fX3aPFyPd^3wh+f3l|q|s;-wsHuxaOupA^v;h{9f{ zFO%{}9@EX8w`tUrVNNRQKByYe-d;6|?j7{xJZ+hl#fm6Yrs(dS2|;}BzUsW^H$M|K zJw&1UW%=#=U%y=-15v0m+HCd$M@;3nzx29i;C#g~M-`9m9ju*XrqYTkejq(V2BOIK z{ZB8Ycd!=ti8bW6-&*X&apj<%h-xIgg}rc~`D#!dzM^~{je#iCed)}r%w*B$c?iGS zAb{iPBr2ZzQ+c`Xp|5!`9^{o>V;~CKWclr;vT8%Z4;y&Oii)Q`6AxXmi2>BRCvVNy z7>L4oOmEkmtISh=&tm-==I3}?j9D-E)+XKG{)Wz-7tLTZcNrQ3QTPrfy?MX%Gu!fN z2HlU+L_4X*w$l3SB>&p@pYBWNIEmh(8+qYnR+m?P`+p2X>AL8g zuUA=}F^&1x5|W`>hsq}QI^C+`XR_hd$qWDTSYsec@2m4C1Lae)YCWqAXXFoBr^WMp zi#f30^p>ZaSV-=F$++_ePmO^n>;>||d+cUyDVjLBOhcMU@jZ zQu2_qEH^vZJ^9I!Q5pkLsKe4Xa-IF@O@!+F=%^r$3O1rp>!o*>4m-2Lr6PD}7C()F zC{)!+b$)TzRC`JB@FOP9R~&QHtSM*vQ+lC%pQM>I z=cv8Ye#P_1tv&1T(_;sk$UqdHSx}BwWL{bIm7OQnn69pj zvu{(9kMZaRy+OWaQK~)@o0f}}i!!ZaqFV!Pog<^d6c1y{%@lSyp&0qp<-fO+72T3q zx7PHIF6H2?1z(G zMm%21t{lqAu>=`?U2j@-9J;OY+kN&2$~!IBva*>gY79hS4ikC$qw>jikCw7?VU;+R zAfs7>-Bulb39qS+=6%!4W3AS*!k;5F2BOgWpgUzI`R^`eY1RJF&TEmebV*ICj#&k>Q61@i^cRoY ztz|zWVl)P#(ATnV_!FnoC$g6lV>p%|BVy9!5H;rK^QyN}=p`fb#nEM~XK+7_fhgUF z`##t-dwpNa>Rj!kozWxX#G%T;YJEfx_Mti!p6hP7%t>T@PWROqh&owkU8;Jj0U1;t zz}&^%jFS15v#byMX$(Z6*H3S_m55^Hwyb7@Z^d#fK}MzaWgV)He0SB`H7&bDu{#0F z*~R5OH3p*eXkt(Qt!%^Q)ojdx_FAlljBCHfI@G?hzi1|^!%b{uEo&@iQ-^oZ7>L5y zkZxPJ@sY*1U&+=@i_&6eWZ>I8^hM7xA6cEq)$CU7D2^q_xa_vsYI)=EPqaQh($_qy z_gcnUIomY`qA=r|z7o^VpWmiE(Uum>u>=|QvRrhiF`xEGy;qcdTLs=aV={~A7pO51 zg;|_euhPmvZI-dND;Io^oWj@{0YYar8f<*TkP&`^+(V07&YL?tMq4Gs%-)qA&M6YFgP94y+B}AbzMJGSOfABT5Urk?oP|GPo zMy})TRvnMBRiipQ8^rL9R7bO?{WS)nP%oogg9$NwVYWop+-U&A5@hg$-c}vGdL5&; zop{xL^o8SPY}`wSfeb|HI-KjzK73rkwQOf4Q_CYmMtNU9tB#%lYe{b}+M*xly;re- zur6Az8lq5#B@dti-E|+anq@xI*~AiL^xfoR)iH>vw;F%#>B&9%Eo1pMMG0geifr(I zsygi@@v*!kwd28)1_Dcv@nEWlRflX>*-l?@aqYkl3|Pg|g0^Z5M4?(wnSe{%^PQCv z8NFq#on|0o+$9%Nt&erp6R3`%cN_B5>WQr8!jBpQQDjA}*&aDfy=!^-bQBLPvYag* z_(@|R3g5f5*~UhM@U)cG>~j8`5=)TLCeKZ4zgp1XimHQR^KLDcvyWqR2BPrY9J(nz zG9ParyOQ-?Ur}NSGRSIMd8Gfm!Bgu=AfLG}krl~WQDO-)HV)Wr?InesKG9QUG1Bwm z71pt}yCXCPqVN=uW_x-kKJ(2|Hn&Ek#1dqX-L-o6)dQ8^zUuM;=Fw~|i@DQAV;~Ao z7j3p#^A55M50<@`0$RN9G)iEhn-BVQkM}O9;&0036UW~>-6rOp~EcxD_`KL={ zr6#W|QtDW04-kn~?+EwkRF%YGnrM4Pm8lQp^+3+HLwOnUpRCyh; zPwk10eJVTJcSSdI&&g#h{o8&T15spy|I4kWtgs24&C4!{EaF*Ti6zL`l0V7vOTIPq zqdK~bjuPiCFJo)g_tY4O($DBmPj4W){=Jr!=nyNh1R2wO%G%W$TwU6W>R425tH?7u zk(KM;USl9?dcZilYWbWVZp5e}w~FV1%h|5c9W(}_Fh_$j&(41o^^#Y!D*;hj4hu5+ z*4k{u&Oj8tuSnkOB7fPXd=k6&Ay{GwGVl!*+7tKWl+DwY zvP)U*5=)SArNBk&sm3%^??L)_SCsEttYrmu25JmMVFo_wqV+4uFHuX`<@xy}mLP-d zu2o0(dFox6&7`Q5e0_*X2bU z8B;2W?8gBuN<{|7xTFY&{2|M}SjxPX9Mp2vkU@6Wda47(BI&8fo=auV7eDQcVf93Dy{hihb#6F_sNcb*yQz!${iH zoo_CAieb!wF;1hNYg1KUEjys%sNG|ej7t4u`LW$nV<4*R%Q6ndh+3VU7>8}i=8asj z{7i!b0!xr_+9f$v_0{aED!TjN_t;zz*_tQz_@yxr^>BANhrF7?HW(hnSkR!7=y27I zA1T{NVhJ)5W+bPozN$Jz)sgP=DRJu0?mUF;G*BbKc}1gXsCEpDT9`^+_)$}Ax*5xt zJx|mah@!b*G2+@P6{kpt1FWH6Ecbu2&cqUA&}bT}&`07~X8_F@bgo zt(wtuZHQ|5To+}XQ*R`iM{dXRucoOn5Je+vSB%=vl^y(+F4-6m6wCV;JHW6688kK_ zs^yd7RUPlr?J^GM>dqT%I-@ZV^)|koT`>+f&OwZ#u2os@OKyC9)pN(Ml=%Wk&P?2&J;lB_y^hG(lV%(Un>WJ?sS;~MI{;80UKtCGy6|h zEFCEZZR=*$7>MdU&DGjV%Ec(Z{cYDh;s@;|<(|1SEI|gYf4a+J&}R{_vL&D5T|;9a zDmAN{S%qfFyNwg6j%I^vOSL{uUSG_x1R46yP<4kN%aPKZ-&%5zVcd_q1g>VAgXE5IBRD}_d;nIHL_L`K2d-0DrKGtEI|gYW?F-XeMRG^-T1bVM>Gba z@@5(DP_rZ=+=c4c(QLVR{MLoH@UA1V1Q~ZD#yd99nAaMi7@wVlXmT`$KhEOKP}#&? z0#`G=MR#nMcujlaxKE6rxBShNQFYgtR~H|(>l~ zX3gEg=Tz)G{r+{4klL1)JohKV5@g^?qnkX--4}x@wdQy7j@1~5T5{;SW!3!Mwo@J5 z<~3nkMs(+&Ty`?l#c)01x}lp2>(^%aNt0?7xgX7>J^gwPt(OFm;--smCog^g$ag-^B|oL598tU1mREEe^Eev1i9>3`Bjs|J|DH z3)*g^I-Y)QEb>r0-Zb1HFm}Qfj4O@y_5wA9(_rOaC2I^s(a2iu_*~+b-DazN|4&wD zUt9j+{&j{W$iS6GXI^zCu!2Kc^0jVh8Us<=UT?A5ardLDqv?S+Vi>hN=j|x*pnoP) zc`RK+r#kj}d02MIb!?bpdUSTn4z3OkQ@tDUSWJohgYV5*M`Ivr&`xKIvCMTUG5YM- z&Rz_lxjQZo!xChmeo618?vw0Ca(A9wzOcqX6h&c{a`%r~`}=G@S>dw_KN;^Xf9Aey z&A4rUUw35p`)K{XczdU#%-RpuZ2!7=Cq32jv_LuRwi|D@J)6cr)U92YEXLYJ*NG9! zmWe0i`=lQ;MPLas29AGkt=baZR2_3}%obT`)z*j~qcIS5Z`DPMu`pQG(Qw)pP z_diE7EI|fp);8PGR7+a^l>*^9jk54#C+S;c~q@Dk*ws)HhkCKXBq=hDGg^>3>Qz8N#1E*9?^ux{Q36n z3`>xKXH#@V2dkH{1Pa0*>BsGZ!q}Wd56rm z*RJOA+0Abqd+y(|D=*_j@AK5V`F%IC$_-od(P@7PEJ23rmOWNGI`TKf__kyW%Sy39 z6pbV@5H)DjEsJrc)j47`9yD4krV+e7XBNW}WZ>x{skM71iM#8%^I|iXXbeQ*nJ4Mg zJKwM#dt33E3mUNwJu_L}#FtOA?X+sG-)YMm*lBfHp3KPT26j?EZa-rYkK0f#XFZL9 zC|a8qW6J#5#K^p3yZBBc=rb;lz!GGjN@1;9@$ahD7>Gjcgl0+HAvT7*rSJC5Eb;p# zYt)SX_v|Mhe6VKXzNrp7&3vnObA3vsr>g#SAIsbL51!Cvt;RqUts9GxaO^HIGR&AI ziVmXPr~7DuCCEUXiE_`!FA=Y4CeBWtq%jbMT9&#^w~@%arxm|@=nbP4@$slsG$vKSd8Dz)qVO&^YXB?fwAnL24GP;?aC*&oY)7n~`u1R1E- z(+*c+GV>x0AiC&Mje#gUL$KLeKWrvTe{tc-RT~%{)BG(Z#Vc;fUQ}36pCKEyHCk@F z=Ei@rNR5H0RrUNW-N)kSbeqk-U+IsYsbUvFsXxOf95u@1Q~kEhbEq4 zP5Q;~#nS>c2BPrHlhna>Cs=k`QE%ff7+8XS33@f8GklQj$j=_!tK@Nwfhc@h$_n%T zOxYrB_^S5l1ePEJy)&C_e$~_Bz=LjlU7kB8G7yEmPRhi9DA_oZE1&u%T3`t>a4wM7 zSC1m~4c+;W^z9jXUg%|@k3#P>6e0bX>_^0t#u@`r=nazMwPLMUk~5aSnb=g>4{L0d z2Z;>ywJ4*p^)100wdO-wOx75PLa$ccI#fYM(mKz$|EcD^A_IL!y4|>~tL*T>h1YQ} zEHRcrFB82%%7B>dBI%VX{<}aSje#hPuPE1`$#_vNZ+HGGZneM?WS}=l-%GtaROBbM zc6io>8Us-nC(>79$`4}IXx}XqJx|33){d-VOJwMNV*H}fY}DWwUUKI$je#hPcWLgH zX<&Y(zUuV2nXI0(aQ1k!HPnZoJdRXpMoW1*fV6D<953*_S-a zTQ80~{?68&Z-~1numl;XS<{)<+T9Kt_0{odhcyPGuI#KGtgQ2>f~pQ@*Io9FyeK6P)=WUzs~|okbx1B z&DKBL9m9Osjc2WNT4Nvzd!4e1zelqR>0J4|!BG-Rkb!f7)Q_y~MQw@=I<-h)82w;u zfRP1d4%}-j!blx_>6xf85QXs-eJiE-1XeBL4}Sjm1%V~VzzB)Hy3l$mt9`f?Z&~ZM z#y}KCmh>&Ub%pruH!gg^T~~=p3dU_1U(r1i;}cmMvT7c^#|dO03gas}U7B>5<$2VG z4?ML}V;~CkAj&T}IbVz>pW1zKoYMO2Y6nmn5;FA2tA6}SF`tx)q?eXD_#XpNsB2L) z5mS;+%w*?1dKWUi7Wx}Ej;C1PSHshu^wa%~0<>4Ixui~C8@+4EKc~yUk2yCou>=`- zZe+80Wo^mjh>ZNen1&hyQRHh`Cs;o&t2&;QxX4@ryYt+cPdKmy8F>Chbu2y1)cMES z)w?wYqR7{>B1n%D3#bm~I!A>tdvKrJd+k_)3_Qo9b7%T~RKAlj{9(HN8Us;y22E#P z3tRAeWis;dKh_!Ku^7tYBtLibM$tK%Avh%iKr(l@Iq{2{QDS z@7{Jw6zdbiZ;uSn7>L3XezFGbPgrs7;87P$EW!93qkGGH70P=JEO}gGAPS$BbW!im zOnI-5+oxk#f(%qCY_^cfrGlE@2#n}a53syf@vrx)F%X5S2(74hS6KK%vL81mGAuy`>Kb$>N62lK zO}65J^QLMHM4`H4v%Qgt>;;`oT-$q`p^l5%32GdaS@V85yGI)Gh|s?@2BJ_~rg(5$ zoKQML2RUD02{Lr`Bk%pO;zLq*UUK$Ije#iC$Z3Zw*T7gv=LQYZnsGhlKSYlJy-Vbd|LmD3`>wa(l%K4m51VpSbSUH3R9|L(yTdK?j0f($&br>|%4IOtHTc1Hcv z8Us<78%KFE*OSad@~Mx7H{xcLIJZgofSPzyuh#o8Q4BrA1iX3z5HXj*ZV^n15xBfSs5wKv#L63eaLMz zr1*Sg=flcd`mY2Tddou-1C8Z$s-L>(l*T|5<|@*cv-)2!icyqWJ?4bK5>&iUFQd2U z+>RR~$WQ!yQfdrD;nR}B68rz?I_tQqmhbN$TT~PSu)q%N4$j%LE$nNzU|xf;1HCFL zc6)8F9oU$Nz&U%y4ir?du{*Ba$+Kn)`kDLtKL7Dzecos7>6taNXQ$!!?$G)?%SVP5 zBrxtrqguwBM)39i`pu4U3V|xjb=rx`+L52k>ZCglZmpvQ39JirPXB&aCU=qBS$1%Y z>0yivBW$!+biFNeq4oLuvSA8=DvUGIjSjbF@bZuS^hHUR7+R3PNFkj`oiT^^veWKT zn;Qy&DvX!X7xEHfxr_*V#cg7^%ZkxRnw>#BG1t07Jf=l=J$c`1g+LX}?M&k5p75*z z|FB3AsF%DoR_^xwvLKUPQ8AN)n23U{zAmU5xv_2l^6_NjBn8P*(m zwRbnd*}X3-?Kp?)IX9oy<9;~nkw?zRR{M_EZyze-@P9McKnoH$f0TQWmM$03@u2gV{P939rb$yHM66SLZIr>m0h+_N=HPIWb(mzQN3pJ zIqxPbeRQ-SfwK=v$DuU(+9YRdT)3}7pbGb(DH@e8jec}>kC<{UehPsqm->6LNXh%cm$myl zquuY*XoW!4Hn(k7>Eq~Lc~*^$@MV+w$HzD)kJizG1kP&ci&I{~TKR<;?IWj;R|r%k z4XhF)eLOB%fP7qMBean{;$!?~j@QwG1kMKNw7UJ6(P2SG`>r@@z3^M8dQ{?)UHS-2 zl%uoD?&HR^fcTiW*bp5pNT?$YwQq{OX0*%pjs&Xyyw=zu+sEn?xyZ+t3lsTqTSjXl zwRf~2fg>E9rrLRi*BKV?-FS0=LZE8b;`xqF55rkNNhk8L*CmcGd+lM3qV|p!B-9Zn z?241{yZ74r?x;zkAF_oU256HM^)npepg^SJOvJ_kHB! z<_FFr|FGuV^}?kab&P8iha=o|B^SJiDKm*sDiOdJmOARYJ}z zhV0>r&Xi{$pVb;Dc3-|0eV5uhT96ofJ(nTt?u32vbo5e>fnvnsxR_>X9Ski0?@z12hhLCXW_kR(wn-wRf~2fxC+q3$<^N9GuZUuuQOq z1gid;l+W}LT4*8d_EGybeLU7jXlOwKcjahrA}&M}4$El27IVx%0#!5yn?9;%kvnmY z*bq@`V0=v4J;x2SAVK5quU$a7ml`ZQI-27-M*>xqisv+BJKxbKvj6b_YPUFvhI_ocUEQZB%(XT#f)s_B+!C{x>xjNr7p(K%V>{3 z;-wI%np*mz>7(3HdEO_@bX}xt86UGO$xEOG3EXL-9h1F9WzS%*v&lyxP&MVR z-tx>{_kW6t+JoX_%EkByv><^yTNINyo<@Yz81-$NuR@@zUEoyH$D01~G;WWbX@uvU zi!r&QeFa*Oz*B*=|IyM;?^x!tZ>}>?Ay9RqL{-y=b8Y$d!0-#vJfPcf?@u!a3bY`B z=Y{F5#^=5~>Oyes@zmatK$URYW{!h#f%1KWHc{R9v;lE3X>L(_2MZFon?>(tlnFLY z%{0ezg+NvGz$)IdUcJ60cb5jU5k^QKb37MlK>~MyX!aU&Oxs5D64~C7K-ISrm&|#| z@Y!D13lUs^Oe%KfFT5U9GyZkTgw9w$$p%?|L=U*t?Ukmp}rpalus3$a+Lob=HT^oy%G zaz;^wKoyPJW?uD}ADeY;it0^ao!6Hq)`Y|(HN}B z(Ye|;d0Jg0`sz)J1xJ^jm`0!l2|P7RQG;3o^KU4{FayBf|iaBR=n#@!fxE}UICA?FOwYJcV<9{CB+!8a5FRcAbNnM9Up zawb#c;bE4=5h!w4!ws|`v0$^C5pa#xgZJcH@e}nkY~eP4(cL{lAy7qcL>L)~XfRm5 z6W4gdUBmZeSMk1_Md=fn_Z@3!%o?kTpE7ESo6S78@gU zDT}h^#MR{2?c1y}9sj(SYqbmmZ)ihr`-yjLJ~Ld`qKfj+CJAr)RyLipv~4rC3=R?( zRtb(4BpmfOTjf(dZz+8Y&#>EQb}B%O^*Nyss9ILNvhDCyDv$Eg$E(11JpHMzBKulh znYUwG!8%P}osN6V-?r;6>K>@65U3j7#>s3S2|eW6pvL@27BM79{FUE@qXh}Iw!b`n znC+qRXth1BLZE8(rOamg$W?Ya`FK#e4iBOB@uzNu#Wc(zl%QFSFG7dsOfC?;HtW@te|t?gYd)4jXY&P(h&tPqOo{$z(NLD#C% z$M>Lj>@KC_`U+j?cd@Ntou;?q&pc+G+IJV#BWo%IswmgZ_Hi#kjs`;(MDl#pKE9@R zVQ4`@t?g7E^i_7VJQM;|R32vgD85R*gB8qm7IC7hSQqe)VcZ96KGta(2RGMb-krLO z&PN|B1ga?4&GvEky?nc$Ey>QqZv~0PYa%#WkWg!T6`LFH8xSBu{}u{?s%fRB*kpMW z43Iun=Bq9KxSmC=SK(=Ug!>sR+uO`O^B4lm!uSA2LG%Fu!YwgP&$ zuirL4wNQ|l`>?S>pz3_%FJ?{rv|IWZlx-`MJ>0(TjW}A6P}>K6OPTFB9w2rfYoQRR zYB%h&L-q`%LgYJG#|lm1YpJKUbiB)P{ei6uTY<&WZtFy@-}e)BJ|`&zsvh|7GSl%s z%}1&wRk}1{vWK(g-OA8{ghgo|ix0GAxx)O#xsV8jKvm(KJIxlgO_R4Lh2*WRx4)T1 zR7jh|aejrf8Y(B7EWsDeNp;I~I1P<67!PBZ13*L@=&$L2d&gMk4e_n*aC zeBF%ZI1zPZmg9bG4)gEs$-1M%Bo}jxs#Q(i>E!%%zkIh(M5Oan2vog&pV1^%xX&h{ zt;*Q@7{*iL`YD61_Ko!O~=$?r_KH|}>ETaFu zZA!cZ35?ZH4B*5vmhEJKIM#L=%kVzIYzNI>U3R=W_Qq`6XA4I-BFnxr+n{xhyp`)l z=7ntKpdhizB~&3$RkQyElPEv-3K3Zw(3=T&>0XX%TRB>gz{m^5Eax8OOOJOIZh3Yp z1gbEOVzErASDX(W6d*b+aIn}c^USjKS)JaPG^e5Ych^$cjn{`7n)T|@E%~-n9}iFd z>T96Lnl(}(P&KOCJd>ExMy@%1Pn2f=Wb_waqYiVlD8%r_w$E3>nch~;UPmk}%^G?5 zi{ox8F$z=%|Cn#KsQ#OzkAxxN>{L1@VdQNn(1HZ6km;=y9?4!iKj7(Jv{qKaNZ`7i z&VG!Y&XZ_VOZu{$xff0{XLp~^opsb}a=^^1Pjg;6+%Ft3Q@-)Od_O8`=LG)gMvyoa z{I^1&s?L!lljvH+TnW<|mL46u3zyM$jus?v6;8K!Or1sdA@>(|N~~1~RGqnY-OQ_J zBc+ead0lkB&zZ%AhD8Kgkib?>N0#(%hO`>2H`PNp0DvMY|!630L?<9s6By#L{XO{WPB69cjLCM3chOe*Kl@P8F zsKOmnddIzc5xp$U$cmNAuI!8=f%~`=Z?99F9j5W)${Gi!zQiGCceV$aH0n#tzuOGV zsZpDE$d%Wvl5)>$NfA#rl}3ZiLuM&dtNh^d+=k0gT$o%`x#o0sL{8v zL)MbNW=kJk-d5!vMf-}XgW?ndRe3tkcgWtS#~10N<=@+R$n7j5zM7AY79?;LLX_T4r3t8me-OQr1 zxobXuF>Rwtpo)47Gk340mpgI&B6Ely>77K4#YJ?qAb~q{lvgL+#q?Yc`1pQhm7O>w zaNm$-;VaIw;S~K?{_>dPOa8ouoIB?|v^*N4I2Z$>nB|j0Y~8gWF*u`82vp5Io!6XY zxYQg;yGt_?esC9xek7LvYC{VW7|Ws-l|7yKL=mqHOTQ}ws&LmyzNvSbrKhz)#A6R_ z9^FDI+sDl2C+(Of^lh z*;hjq&gW=Xzv4%Eno2CulNAD0m_CYympjNBdK)h6f&|V>Etc$ovFzDae{q&4 zCpNi(~Sv`&Z@d08) zhSO|Z|BGg%u}Gmj2BryriALu$t_Gb?6qni_Q3zByA5AplqJQ3$ZzA2jQ%`IA%uf_^ z{lU?Kg!)vMuQb&xRNIa1)I-2;p$d1;Xb0<%KO2AJ4u4&-o3cxW^Jtu3(=5E-buI0i zKB8CUYaD+IRk%)|Sw`Jt?JMyn>u2`%VG`@wgwakL{-iP``-(j2O+{Y$swOo@SfvXp~&$6W-A22vT^zRz35U9d+ zBE1 zRWs8e<2y*;N`~IGY`K9|D&sHWt4>u2RAJ1C>eb$o`i}@FF+IDJKnoJ;swk~XDSgoU z%p%HHB~XQtFnY)2@;NU1-Mh1o+wQsLHP=oqxBm9V{SVwdp>^%8BRqln-ArE(Dg>%% zY&X|Y=fkGaN!gQUf3R!RKJJ(K>Ocz;xC=z%;QK^2_i2Bzqv^j2fhzenkDOEc)Rd=j zUsvd;H!Pk-jA>HeKnoIhf|gFCR5{08sYS&$D`wD`YNouw zoS+b>!jh%$TGE#(Mt;pKdS863oGQjuBd#VXW?8MeKIUo`Q840>LZAxQ?eyL7GJ9DL zYUg5nZ)HV~1g?rH0(ox}3!%Pb-IR6;fhvqYSS$^DZsrj*e)v9Y$I*fWuI}lLsz1Yd zG{s9~Z;Aw}Fe*ZKH0}@OTZpKfxSC^(6(b0^(x>R)%CY?Lmq78%bGbsG3gf=iGt^(h zF1770dgVUE(1L^-1Mq1womGF?U(|oNULjD0@oQSKoX)HlpgwWTv{E`+kib*Z6jhsY z++n3Yu?zK$-EMow$W>8mx=$PTQgJ7i_M==6IT{5Ai4n0!6arO^29}JGqusZG@=c@+ zCBE8v&}t%2{11i}BydNW-hHJQx}51x%#uz=0#$f+nBE3#QA~5&=r2k)ILG_-zi95> z=k4pRVVdxl=v#EBT{U`(P7G~(L?KXB>u92R)}U(%x$o?Kt)B6T=AuJt|6piALVc=8ajIj)igwU5_uf!RpDFS&rPdvll zLKVhmXvf`}Y?P$P_G+Jx3V|woT6$x|@3{8vN1#~aagd`035+vQEm@GD{Y4|rq2aL# zfhsH~TH|#7$V;5Q!ym7&tDJSlXaq(ns9w$fz`taDz(>8TqY$XV2ok+_w`mC9K+&jv zFC7dmNML-1q8}l>dHKx#qF2&hg+LWX!f3Va+@4kZ9w>Sa+|1E}1jd;tQq;_kZgQtN z+{kc+Kov&x=mf+IdY{NYK=h6Z<+z`TktvKh(GKs;#cXy`5Z%ZgtPrT8xt+OA{JHzO zWZ^PyrJg$Z@ff)i_sfC=#+)ej+Ij=Gd-)5`>Qfa0RW!FV<5xd-k(cEwDb`YyGwtfk zI$DrW<5wO4tbwd&io&=c;)<_z5W3OfLpOc*vDN%VsUr%3s`)oE7}js$Ostb%i5YW{-p0KW z&Ube_$k2j>8kxA`)m-oXZx|oslU7FpRnx3*95Nl#kM1KMZ8JC5=Y8G8yB15UqXh|! zgwTC51*+>k7KU-pJB1YjRTv4Ob7Yxn=qF3Bp)Z*iQ3zC>Se0NtRlok1>8YZ>-PM|bY7D;78{Q13$}hpf|% zg>vt%U*aB?;nqq%hf*%dv#cYt8-{vzF)XJwPE)mGR#UZQh#R z4d1|DUk=jIf&|8ZES7r#{+gJ$j-MGfR3T8+qG(lb=_5~ZdDgk2EkG->CyZN84$;wq z1jc9RoPJxop>5I4Fa>K6WbEVN1i_v0#*NHJz)CS_U;AM_B%b( z(!N0$e^9@^N@1~V8ec@zh+oaG=d7U+sH(r~n(1SI2l*Ag7UK$w3O_dSyM?Rk zXh8zw8nmvhY8A$uFdlm_pF*Gtqark(Ct8I^hqXLM_xuWhs=D_+*~)$jXNwle^CdgB zwi3rHZKf~Sy=Q1aV#KJ-+U2j|?B5acR9N|Xtwotm8+g=$_Y5sa;CUA6OWySmAHBEo z`7P;70q|R>!c%A#%hz;0#e(=P{Cj_uK-Jv}IW;*N%>Lv>PxT-%P>j#Lji(>plA#3& zjD*m48kP(YH43fa*PYua1gZ{qb<<>RU)5ZmFilr{sCc%01Mgb6uZ9*RFy=t9@I6CC z$#!db;LUytfvT8e?wTx*&zt8`^g~M+EIyap%3IJWa5C&5Yc_y zMs8RQ11(6XYv=HF{l)1fyLqqmLpTzss?aBg>0^zX+;Mkz_7^`A!+CJCp&Ttp;98x| z6utEmNl9UR==DPifhrnBa=y$It)Tc#CA zpo&IqvkeBfkzczR*0i~(mOYF=?37lZ1qob7TP%UuY6x+DHGe<3h(e(1%%%jhMcE$R zARk9rbuqi_CSGGfVSyGTaOF*H&`Mwa8oHi09BWkwR26*~Zj*icr;_q))IY`Zi-)H- z@PZAj0xd}3N}s;;SLOqEp0$qW9MoDNQ1#*HWLp^d*ilBFnv8t=o@ese!25Z%7HB~N zPj6B5{Oxqp7^%WmE&c{9(Qy{AA65_qmt?rpc? z1v78ur9K8K1gdD%Hpjse-{jZ1K7DJ&={uvm?9~AREl6NIfX=-R?rT(Sxrxu*Jwzc; zW%sKPBYU6jUF5sS@p=0iK9e@`(BvTkElA*6-C}XA6|HS%yLp6dkV2p;ZJ$KD^iej4 zOvm8{hIVoEMqaMr5P=pX)HUa_bwk+M=DT?*YVSy(>f!W;4(a1SW_dU6mBvF^y0hWD zcy@n*79?GVYhSM%eankxjV^vnm$bd>A*p4vzJoHV-ZOR{`yD$s%iu0JUPd91Kr^JW(Sihx)aHqTpXc;5Kj^8;)$sc~dlUjyH13+j&o^J%b?l{E+pgt_(^UdhWoqU$ zWXj79l<&%PYZRn!E4!KQJa0v7klz+0SZ6nLT^s*OzK_Gw4b+QWU&*VLZ_Uwy1fJ@k zH;cCp)$`EWz}Vq$Ab~1eQ&aAq7@~LF9maFc4=~Vz1g@hkmXZZy+2lTrMX_cd7_OrR ze%Rq~$+(kMeD=<9am#u0@A~nV926alIBx9QA0pCsKI^#qFpnWsZo5v{`##QY{v8>i zJ8sv^V??Ce$+~?$M)Q(UW6l{3wuFd1TZ$@%B@j2%jb=m6>JTDxYjJO}@LfI@-HlGygu)GoDg@eb`ys^#dWIWcuSq z=F=CfQZ>AOZevirhNkL9182jYbgT}$G#cXRgg@}K@o>K@^b$D6RTN2mq z$kpA8jHh^uY)!-*=l2{(PAU)6hsWrzHYyME@5;V%JfHdK9V*;jDveNa_;Y@~x_qI4X5OFpcdJ2RIjfg;PFn-l_7$*bp@ z{w8IXy=ZjaS?tn}GNNm@%KF&ZlWa1D&$Bye=RGT$e>dKlQ_IdK+4@m=d{`z=o^{-H zm}So3B^sVBp%AEg@$Yt+atrU$DU5jm4!M zi3)+Lm6czc>DbanMp%}8-pPVkfVe)ns}jXQ)#gkY44GHYJIc6E)B5|_$>)J0IsERvsMN zOTEvZor{YXPs)f2=?}3g&p$Y1dDK|A%Z9B>Z68yI=Mj%Llon|M-zx;FJeOZ_cqY>t z=XxTQN2M!XqQi=G;&j=pI$DsJpZAKR%WsD>awa+E?98M#?moK3ZsIuOQa6EY# z&e|S&L&UNhUwLdhA8}VZ%W##7YeTiw7Tuo6uTT$nx6liPKvm50NV7$Wr_ZSbZ+RB6 z6{mLo`pIdd*ed32_6AE`l@_eF+JnVT*n5VBh(15gD+H=Mw}shduCM${Mn8(qIm~;% zE+x89zg4{X4zoPkZ+hz(G-Q&QqYIW~F#bKZ-63oHjQBV7R1F4x;Ookl5?;Gp6#`X- z>YQ`P7Uh2J0uhm`;(5B=KH^2auN-Ik*srL4$)D-}<+GFsDL zkBwh;axZEhWv6ykRz;|yny<lzs%9&~79quQ>?@|a^R^v*RmkEwcc=B|}$z9w_GZa=w4A64|cR+ZYvt5+xNNT7<^ zhe?#1yNt%DY99}>o3(?)u9F*$k>gy<-z}TIv;$SKmvd;cwh#NOH0AEz_eWWULUb~7 zVt$SUs?vC8G6}yj6NuRGsH8shb1C6F?=Yt}X!e7RGHi9AYK7|uvxjpTzlDg^`&{)a zdS&`(Q0m6M} z1nqTG2vmKqb{^)iZucWWvHs_dEOIHpq@oL(R!jcN%) zukiZRUxlW5$1rk-eV5w%Y>wZ~U-xS)8Wzu}MD9>UW3AZ+XB|pMeTiK^WY0wHqtE>_ zM$4EmyUeRJRq`l3k=px|%NXHs+Z-bHzE=rU(MWCfO0E0JH(zQVapvjj28sHXC>|X! z$($S9WEl-qVc$h_pD~?yR@%Fp%OVv5RoL^<+-LD?e!EFoaqC+yaWC$i={@t7iw;z& zJzT(7D#5d*#r%HB90^phSuY)edeIw`_fY$Aqx0N^ggKm&FnqdPcjoAwA7>|K?E_oFYOn|*k( z2@X_Y-$lExJtJ9(wEkk~fKCd5D(v~_J>j&u^~H_KiluX3^ZPHZnEh3zX>krzsXg4i zz+^V%d}-17V-A4?s$QOsu*ov-+2CI)^Xtw<*=TEk@NwoGqXsx*b&PFj*6A%3U8vO- z^EsszbP5uc?DSSVL?BR=ZO=u!%#jt9-H4d*(%UMi4N{BJYR80`YlBat+_gJpo5V;} z&RKaiYF*~qVALT=>`xzIi>ChS?_nx|Dx59TI4It;IZH~3*%MuLs_kZ8eW-rjhW)(S zQ*R%gOaI|jLfF=%Q3zBmoA}bUjryzHWumBWf4jMjR_1PI@#T3(J>W|9m~zzXkKE?2 zC3%Hei<4@-kDKRS8rToSa~nMDOjcL~JYPsrSv* zL@Y1nYT)b@du?33(k&hzi|Q=4f$)nOry+qV-1((7&g#zG<4I;w#J9CxEn5S}1xiQt z(K)p81rzN1=r%`yr3Q=&iZ#y z2vh}^ePz~?@@@8#kAHWpH2P4_kRy4ffukC#aK%ce^csd5X=%k$wAN~cKov%G=&mp~ zS8WV^QTIk^PaW4dIDf#MUz&@q8E4Er-9TjjP*fpMh5NX4D(qun-e+)Qu`*A79b+UI zJHcI9`X=k(Q0~;iM+AquDg>%`M_UB_=bbDPvtOmlKx{)=&skz5mw0>|N)q&1O(+Dc3AbmPf+sl?IMD`ttpzs?Ys@ z{}U<7ba1tHH6lbb+8e46xaPz-6SecxQG8}PdRx}w!f<@WaToUu=?2Vpr}$l3d37Ha zr6GYT9Oo^Tyvsa9>uF7d_treJ7d7V$avg;#+|#4z;E5vQ#04L5;pt%6Q~x4Rg*$W> z%h<91u&t9ziFFey3+fYXvJC#dkk)W|Ti;fR%H~YXY{t>pIcC14FDGZOkoww|w|C1#H z-51u*_~hiKwR}u(DA>=}-m@x?=1#Mh0)fly6*IeOawb#st=xTm*txat@^5D0wy`tY zanV_mquTT3O>J{|ZcUDA?`qz%VlL2kMz1!|V_IYtPvf^M1gbC(>5h~v?X*{v@|q_i zG+Z^}EF&t-WOMy-(jjkF{j|+hlk1Peg@r<(Dsp*6lgQCq?yL39vE0#be2|!W{0PSx z63(1(pNe8K^NVQx|MnM6R-ID_RN)>H-EQ10n;x{!M>LLBSsm{m(W|UyFrRk)Kvx5?}%&IV2J61#d> zb)0`<*l z6EQYhUyd^`EE_dd+IMwHV{YSi!0|vab}70QmR*bJ#@{bi72pblyW8&Rk$vp`$fy|G30u1Rks!z zwh1f`oL|$;6yH}H8EAbTpEy$?P=&3TX8L_5@xRLjiP-528P*0=;W~kOxaGZgw)FH4 zR_%QXfhx>ti^VwFiEHDthy~qR>6nX{ySS#J6=a4j#@D~Ih@c2_iT&4DUhSQm*sZxbVf%8Ag}su~ig!ks2M-Bzj(JCmh}sCg{6P9wcTMj-PR%wb^eD(7J>p9(lL z_rwOm=V~E^Koy>DrKodT5ADv`2BKm>G3691u0L?~LTArUWMZ*yO~f>jUm;M1PfLA? zJujQQDnN{P=M0}6RT%xG+J3qY%Xc?ObeOb7AyB2(tLx*Jn|*tw{L>ALf8$yWS9!=y?A`Hw$ z%w3EQ)3>(1Og5~v<6d*~dWAq0uBqvclq+RL_e>Q;_c}{CwkTZPw@-35_IQuE3$dkG^y&1AmK zKdeOlu{^M?&>dT8b~~0+WFqwY357rvwnw@%GjKeQ>EBp1eV$XteGZHus4;-$D|Yct zbPm9Iu#-Zd3in><&a;DA#fsBDg8dt&oc+Mlzcgx_D;9YU#A2y%?tcEr|&@}o`7jmpQ_+cKP@+wz48=tsAK_np{l(X``IWQhsKOYK#j+Vbc=~``?GxTV-+`trKG)j%IjN8>yleOf>!QBdhDvZxqEDHh`X>YZ%VpMNU zU`b)U!k8SL3UjGpwEj>=BvoSyfhw#$RNHe$Xget#8!9(2Fol?Mnk}1A9J#Mf(ZTEM zw00D^3w$$KAy9=8JvxQdWeK~Prh@3@Qchr9!qj3ci~9Cbb=dQ9WyI>lS_*+Gj7-uw z%Yx_G!2TsgR_EdZcd9U6hB+<2U;CY{9#uwkyOLcYP^G4$+=auuz{DWYBGAjgokNV` zsgb)QZ)5qRS^i?el>~)A74A*a%`JDsc;&bTVr*$=<;*sotyUvm1HRkzY!A~3_UI<> z{U)O(d(qJN@eVxCfO~}&%g4ga^$ms7ixGv^Dg>(V#07mrVbnVQtUzN?OE?LPK;e_% z6Ve@Y=LYc;YZ{0ii}ET2s_^tT%>Y(5)5~YhApDN5;Th-T(0))l`sOd~z%;2(RXU-) z{`N$AaVf1;Ay9>JKYHh91=F{Mln|YUl;LPW0{3nyUNXO~{-RhZQLlL|0|`{&2^czs zR4|c!r1hXnTtV$xraWf!^z)iq%4r&`J+xxUBiQWn^#1v%lL~<<+_k2!P8aOUljxk~ z-E_GHjrzZ$fLJza?#hUS{|+Bg(XivG#X&I*Alj6l=< zSN9@VY0BN&=RFzbBC0TQOXJ|qGpuDXiiM}Ws3CzWELqy;e5mX3?MjOecm9zx{a>f1 zP^CtlJA~BLo7FBQPFJWUR};SoRAGIi_XoGM)aFo6J@n3Q4ci@-ImVOedlO$LniRVLlKG~X#|nWeb;mu8Yf(KzOE0nJ zNgP895?EvC)cqM}z08lYqQZ%<3V|v$uX;VrtCghje8!(gloPlZ-&bq8B*0w&(F&2t(YBfUd`XE7zoa8T}^PW}+RN?6hdL#Gp1eV9Q zvAEB33Yxc>aRzzb7*+VrhQ(66+fMe8;y!$%lR}^h&p^|Ot#9M>_sgSxI$2GKIy~1Cj-6cI(k6ats#cDjLr=tZ4d|JApsfb1Yb4N)Q zx1+O;&yFfQ>rOpfa3bsVsS|G?dMN~|p4I=uO!@b<*=b*G|Ar`5EXJL^HUf22;f{$) zZ1M_ay;thI#*?8sT961Ic+M)z<7OXudZ+fqF6>$U=h~$Vr=tal z(1GXdztgdn-eYN=KdpTSeb*9I7^hZ=rXiF0my^Mcr__4Uf&}HEL;Cnx9{#l_aP~IJ zcy)NBjw(Enq!LeeUFM$_t=HaD>qQF^RC`PxvgXsbYaXBEorCK0XC3>>TJp<+1fEW! zJz;kz(X;Ap?P~W93V|v-(Mr2AOqQ@^qE=w~{@DgktZ6RV-RaJQo958* z{6F;)hRmNVt(V)VrJ2OfcFvqii$?bgu+^nC9e)?!1yFrFORglwZ0NvqCRdVe@RxTa zR39If)DunXlw*f>6qfD$mvaUPM)#sJ#sts_YwJ3|S?-o{tl89}Q!8)E;>P;@58>k?MO>n&@@lQ~+DNX^3&`x`pwrm4Lng>#@Vb`_bRoC?EetxDAG z)lb+nEY)tFy3EmnL~31lcz6N%i0IT;tdFYCwS2zjF zy^_C$DvbNms;FcOacqv2@9OHJ5U5IR1(#PC@~OJ!5Ig(&@yEQm!1qd0+x+b279+K# zja_z^h?N^Pk>h!8w(V6Zfxm@5@H95{SJY~^C3N8beyyYss8W5bYgU-Ppj3jLtY2TC z1qpmw+G+UogZtjEz!q-mtfT`Yx0ojS_SxV>{&&63e3+w`KnoJ7{c({lSq*v4GMeA7 z;>rH*9;n23G0LeD(X~SP)pHt;rq+uVBxrmz$Eb(&^*6d}vb-;MO?O|rLai4qNT@Ng zl|y$M{ixOYQ|mx#$E& zFDMv!sr8}?<8Uf*D)urP?zB|<(X*c{k6*uq1l0xe|FY&&bnwhcmg7Kuo~KJ+SxbId zkiaM&-M3_M(zCw0p?wVQpb)6S7$Du)+tEdTHOtBmc(qUnR8cE1egABOI~~pS0)6~= z*6uk3M$)L4&}9A;II_S&ZPO&2SITz2>3JzV@MUgRX0s++)UV${A9$YLY_)o7t3?7; zs*ggA>g%p^O0d0+3d?r>%R3VIv~=gdmd?6twST<*!O?bqcW8kr5_vlQ8Q|ERbqV@BT7g+qH!sxn%l(j6XaP>+f&WFOa`pq76fZlGV!a>we{-6? z476m`F9K7GL}Zz5F{}UiNmJ}Kut2r7&%({}Kmt`V9cxZRr99R7$Nj1oE_Zyd%w4n~ zA#&qqis87q(4J2j*3g z3OTf1u4m2P-90^{_B{0b|M0Ppb^PT6e@i~e+nv`;0)5kdRM`8fsU?UOB*K?Wag^Ad zS`+_&1gg|pvOzB#nd$@Ehf>FsTC(Ad-!B5|4-z3mo;co|yXTnSkU*7eE9(;u zrVwTFdj3jBR;tzartddXIE~r|U;FcKSr-arKb(^Cx%rk=OKnkD6S4NGMC(;2ei3Lv zqBqsV9Or)eqmfg8{ww83pbDRs&hK`+^NYY|&r3b(Ww)Pw6#f#$;g^bGLA3u`#Kc4E$@Laz}9Q1*N%!N~DE~Y%y|BpbGI@Ug)uz|li z+FT^Lrq$60j#qmwOm>{Fdw})H!DyCI)%_D|_Nb;lVtz6kzGS(rVwJ0`qX#qacb9n> zHW~Sr|EG6l2ItXvQgJ=abTP$8x))x0 z-lf^~xOvG8El6O#(RqfKUV8Oq+4ZW^RRUEDLqFLfuO+7Vm|MiAZz=eUt(%(6(1HY( zEWKAmL@SSHtl|`vKvl_=nY0p%&!_l^-qKWmI3j}0qP#*25^BBbxwWajdt?MVJy|7C zRjsXycChH#6dy-fd;P_d;p{M_11(5kJ*K&6o%VX=rNh~%2`YiAI#xGr>AW*3K3sNn z)%XASppBW7%+P`awl4aH^vPvhR3H;e7wHor$6m&uObPwAR+%w zod?ldjhFrOJw$w_^7t>2s*?Ywk8Gy+7)?)QP5xjEOCeG%_`5Pqsq-$nL750j`TrqM zCDZqG;Ze0bxGWE}AR)`<_gFhsYp+jTIGo>_r4p!WNo9W5ayZ3D_~xeiwc!!`FCx%_ zgsfw~``s;66OWCE;IF7AB7rJdLqq0@6dxN3+jQ3g&-lEVKWk#D1qsc?82RDA`+;QJw}w(p5jCHSK&#; zg=~XpK>|lO`ohuBJo>$x#YN%KDuJpv>Ju*wj!yA$*W(kL^1P;a6P(P@f&`Ai7EA88 z7QITikEqk?xI&<+P=lqmv=RguN|+8!N^Oi4${d^4E!Ya0J6e#yxgEWI=ANHD9oa<;q?(8Xs!9%;V0|+>)kntIkG6CXexflw6qDW@M)*2kac+DvLjqM#d@tI! zY)$piX?q7ArFS&bffgiioTu-S?drhC)$J%`-;M;Tv@4opSAAVTIQGqfO~&Ua4{@gly0s53()P}OT@oMYf0F)2RIFDN87=FczM)0_b)Wi5~%v7 zr!~6t;VC}y->NTc&CYSz2GN29u1n~CQ6kn6k)3=XfvO5Evl+(^A58J#HL{gh#+Py% z_0(uV0@t7Pjloy5LzwfG|{L9`%&t73WIYbP<7 zh>iatP<6pQuaWTJSc;GEW_}`UNw6`A2(%!9>vpcNx6#_B&5I8h#u|HlMk1r!J6#hkU*96eO^1A;$tDD{KlACtTW93(1L`_ zx8E_c_#d5w&-hyGeW*&HO6K~c>v1VQ4pNyr9bC%ftPd?n$nsB($$LJa5()m71P2(}5NwWIaxe0n%!M2&(N&&h(K$m8{cEctVN~*#_7B znO2uG0JI<>+vD%J&p#$_;qfCQ>!uT<^Qr4%0z7iAW!&sWgrkPoyVp~lFL5;6R21zpxeBv2)L*WP2V zr1%h*Zg5x6I_ByQEl6NYj&}8{KH#6K{Gr!=6>Ij?slSD)*RLiyO259!o>q`EnOE1h z^QSx8>0vGJF|;5dN1xQ#BE82Fww*Vw*+u_rXq-Zz>TnIs5tej~l`bWH^!DmS-yZ0t zH+y)Fp#=$yWzrp6rTXzq^*ZRAXtsm|s^p0Jb#zinUQL}GZ8TZstIL@TT9Cl#4&D4k zGl1FaeD!K6qjTzSp-RpS!VjhT$Z`Fn?c^Rm(+66Rz?e0iz+U&!7Da?VJrxqD`Z{oe zb@$d(AETV}vra?0=(krVGqfOqQFS^wP#{0sHL8owsb@d}Rqtum7hWaRN5ZU`Z0dh884n&%$EyI5U+w4Qi=3&74t30#$OxY43I|wSDY~XYC%e)j!QSVa`QUEl6OD zgKi_~dx|w(=A(xRHTr=B?u$^|C+{cbOYxF>G%vy5LKQ|r=$8LJaqR89Hu}7vc!ii= zYqrf2ewEE^rW_0HyKUe~&wwiIW2i-Se#&Fkx6)s~-_N2OY;>eMcZKC?tQ*(| zV_!+R`znuc+rab@BQ~>sO^-X8*1gPD5P|(X(sUci7B_Kk^B;Qp&=>;=RH1K+Wv9Ko zICZ>~-uK{EmazA}W1#;5C!d>WhbQ&iaMM z+gbF3ACBg&5}A#15l3qDN4F>a)mRvV3hLHs{~Ac33VqW(+oRiw4sY(TaXWW1zc)@s z)@!M#6NT3RR(}?5KQy5?Ao7-#lv3;>QjhtKJ*>bAY zI1fU96o;cPHNOfw$l7Gj$dN!5`ldPcAb*i=bVGKoQ8erS#m&&lonZ&5b>a*V{n6LC zZVngy_l{?GdgkXypbC9kEcyKhh$CkwYwtEXSoPq1MzwXP*i!0ca9)c3=mxXkzr;AF z_u6bnIgSLX&^LV>aN!tncM5Z?tfq1NU8Gmk%4d|zeu547L$=x@mx6@T*vZEHP@&9_ zQ6>HTp6Mrd9W6%hap9KaY8+KFg4~S1!j7`mEq~_QZ$f^rcz3%-JoMo{Wu}jW%(vew zmIf_`iEr(u@LG8)aI_#XyKYWnU6aFR9?}i`i@J-Ihu8B;uN=x62UWPfqH|;({Y1NC zN9lXNJ{&Db>|T-Cm?Q*qF-L>F@BR=a(kJmk9}<+c8meSH{=I@6Jf?%F)g+zR8>Mqp z`8WFNs1|Clwp0cx5s;^z_-Jt!>x{k13K9v~9)GWd7dN4|Kl^!#LCys^T9A0x?ye)- z-*(pC9C4B#mlYS&l@#-49#B@osFMB4?-hOf59LML2{px;tc5vRkm$7Jh+{90W`#^2 z<16PDS^C<7yznd|M)vH@%()*u{!<_aVcLRjus?Nd|d1BtQy54%u&r!a|17D z4-#EIwBvnec{+?Vk*pz=6h>RtWgh0({4_F!c(-XU&(W}lxYBzYLkkiZN3mEgdA{Jb zTrEXRl_(`%f={9nAswFZ7j=7xu=PbaT982B^qrs3zPwSPexmwPyAmnFG|BPtcl@eq z-unFIsbJAPq7+995&?Br+6ztEpHk)xlTH{X_Y4pVhaOPkSE!QX{O>qizs)0z!SQ2- zr@N~;14y+XapbQlwbOOo&w88X5&A5Hc6Hw%;rDf~IRi)~P=!%Ly7R3660N^&oLGA! z4Mz(SW7daT%go--hL|&ezW$|H5#Ir#$NU{i1QJyk%cL_!?c1^`nMMoOPi&>YleMO!po0PaHs^o0&cYJsG_#Ldkhe4vkw*BUuGu47byX?WX zT%RM^mqyZu``U}F+Q&{}?}yC{Ef|4TebgwNR&SWCk;pdBU}eMi+nVHzW)tXnF)og@ z#gaL;h#s}8s&J~dg;~;GvZX&B!_E00{U z>;=TIzS~)3z#E&}L5k2(N#Py>`lD59OHH5QR7P~nblX5eQNh#F%rc<49u%C+(+}Ih zo@LCSoqNU2+JO5i=#TC&z1>ECR6V=s*6gu?1gg-t#gb`ES3T#DBm89CZg$8!yH>67 zAy%GhHSX!4Kbl$o=&e6+zRfpo`febBD)eo!bpGB)kGnsKSMZEvb0YF+tM?xz=|`ldcH?HK)D#8hL=d`%&~h9 zb4M|CfAGJ=?_J9#{4d?*$x_Yjqs~z^rE*@aPtDWJtA%bzAHNCDpaJ^BQ5o3LzxFA+ zmPp7n{oZ|@Q!-fpcBc^wxm$^&1&N96@@QZFq4RL&Y-!8jKlLqXCb8k&xw89;D%>lk zH(4zM^v1U~vQwvDSBsU>V%^Df{P0$hPu(SO41fminfiF8VVM zPmUHOE_%GO4fy+lx$>r)Rj<|1`wYsbN2HHd&H3JUz>ds$}V zRi0>H`E|T#vGSws&5}g(w=j3HG-;Ju=aaolg{AhP+{my`Pfx^>^GG_ z6_%#Ovade3U%7bS-ip$({*m2weZpn)x3D}^qVkSe_QX5O?YrrzkU$msqn(E1zV_>b zHrNx0z~4ps&bOtugBh--q{HFA*OY=`f;88(Mn<~V=A8beXI%3{Kdf@devi_Dv7W=F`dNd_uQN8?5*bn`OKov%G=+6F0aWRt|_r2+~w`g(nyzSiin4jKZ)K}~7zia-C z={|jy_sjnfsKT00w^N>4ZU1=3$1(8Bc!9Au%+av7J#Eg_ucf5CPr3T`&Yw@(Lns|+ zK?2K$W*O_f?c4I_a(Gg`LLW%CZkNN>F>`8OOs9AeI4^jY^yj+8Wbm$XB})r2`36p+CC6WLbFalMVXWS5WSbFZj?NQv2ugVpL5n z!SNk3dH23O+wT7#0##W46y^L@#e3(el6Hd#{9W`%HPP|!l+^MlQ0bmylWma}=dBW` zQe%BDJ*wNr{Tc54igFht6OME4hX1w`W;yTITn)d^M^nl?u)$&*o7vEM`khLk3S$m* z_G3+XTkBV~tvRXHUS6BSaOry3{4IHqarT^^l?nzIDv(#J>xwWG1N6m2wXSH-%q+AilfD2c=qp1YZg2bxk z?%K#B$5Tc%?=N1Co*gIF9!Uh%CCz*EDbs@UQk8Hz_^0DiNEh!#G!CK#3G_{0aH%rU zF~0k*n6n?p3sj*$oR`u%s*R6h>2^nq>`Tys1l2dqr_7m@bQCC--m%lKmpzp73P)!g zk*UUL3s1(U5HbD>tR*V0v(BUPppndReErY)2gYR7@`!$U%<6U~-YR{d1qm!!c`jp( zHO!i1Jx&Cw&>zN5=>GN+ZLBNZR#@NBQ=tV3>Zu)vJfc!+$;dm!ty_()RynFs@9bFo z;^%mWaTt}D<=n@*@%?pcJ|fV91p216b3ZS8>#I4evKQU?>V>UqqQmsAtTSPRqi+e( zz&B}$@V@yU0$1M3_-L72u%W%}@eI~kM0{V9PWznmXM4tVCi=EmCU2>1@3G*S_m2Pg zKtg#!%M`1heay`h-qHuE(4V^Ith05Y-PR(9wImT}L1Jq@XRTQ1p_KBlUbh98X9~bL_@<(FiSc`ip_g}qmG;Vdl{9P&=V{6WY)beO> z&KfqnkU96kT?^bLp zd+?)wv>a=ZKoveM-Ro>AY)x2@#wL4F?4z)^!{{*G++wM2bw1~8ldB@MATh074oC9> zscVA|j-}out-$W|F}BqxGI+a|E~oGrJ#V_s9+bd zJ7Zvrt%#^th=C#{0*c+;b&YFxT|r`ZMiCTS1Ou$=wL7kzzd6VAx!>3P{yUGy`}I6C zvoSmSn)l8mG*kiXk2@tT>y2oiWLi1itw>ZNwv8s}qbIYyAc^Djo0%k233rk!|~HgRNn!!7#K zNy+4Oy_N=!AD%Q$(~qr6At$T16te(zyzKdnKYH>C(_%G@Ac3PA@gCLKjqgZo$xqFF zts;S{<&IbMv@5BkSXEOQ?N7Mz#dAIRgSsIEBS_##O?FIyOa3rZN{e;p?iZB4a0wr;R^*F3?l>|IKbKo! z^MTXuvr2lBju9lVhT{2=6w33b8&l&?k+SA<9{MtxyLvr0=EhAW7(oJ`T_R&E zXfpreHp$qy(8@ppRkzE#7#>5i$b0iSe|LB?-=CST+RUh@;qz`uh_eya<}4Z5uBCyr zcn-!#aL*U*wR+kXf)OO-GM>zy#NVgel70=3rJ)MT8`HJ4@u>QFGPt$r=zh$J;dzH@ zl0NRO2}Y2>k%4%sEuX;Ob{a$4bQ@ zN!i4{O-lo5v9juz5bn`2j|}eBL&FFX|6dsk{dnGL7kW47us1#pN<4X{ciWju#)|V? zw8;oN{~Pm)>arS0sZ-kEd{9k&t3*bkiZ&>_4L!c_@sWd z>B{!42-XD4GqzpROO&`qvP8e4C`+b$^Aihd(+2Mv5sV;#&n`u2vAqr#3n|b}6FcZg zpsIDb<9fW`brRm#R7S%~75UbFEofYi;RGW{;B#KAv&{?hw?Td`-c zLmr!(9YUvx9k}p+v4#P8Lv_|TpR6|bqVvk9vjIDX(BQ1O1S3e`m`G7V$qwf7G@QOT zQQbfSRerbX=;U%fscwD_mM%Dz*{6=9*_B2Sj39w-q}W$xUMRb48%vL*_EwQV)tfGN z)2_C-N%}T49jm5;-Wc}n;^~$7WeNJ9O})FMg-p3gX7+ArAgw6-#};9s{U*`DG0_@E zkdVvh^tPRGFn%(v9$eEv6_yu%xP0ZAUnQ9$Je{QfP=q@kmTR^>C?Iu~`Q+YzOEyIs%7sEsNv;r zkUU>g87Gg~(5}zA(h5H92}Y0*Bf$UW!o>Fm-(9K4)Xwz%`_wcfP^G;-r7r$jWc6h!G&n-A&wBpC(;zMo(0-lL%C0|9e|~x%MjQe8*IV%Xvzt z_OD2X21O8@6MQ}Yjk>>SE_o-OmPjkgssSx%?2$^eb*26UBS^?)TshU9_MTaamOTAc zM-`S=E??1-_gx{^%}-}_b`P4^Lwsp^elEcX66mmsubo1N()*d|#B1Gj(>U=zZ)r%l zo%Yf95-IYgg&`B3jY8tpG1ECOS1O2h$^A-V=H#nL zpbBd%cChamMb~azL!4jFR`Gwah9$n2(w;|Nu#Ea%+Jw>Ia#Kk0PtyoSkU%F)QOcZ* zrAvK-$oH%)6$w;D4KJg8o_mfoG@tV^`zBKFgKxEZ{&zJohWpQNUwh73o0)iqybxnt zq{Uv?^+wP=o&MBPZ|4woiBLpIrv2S=n@zOy4?=dgiZXp|LfH7{mSN%AW6M zuSK;;C!@tP3Tg4}YfKcap1XoQI1r^_1PQr}?$Za;b*;S@2XbZOYW+mJz^!<^$-s=hxWPnu|Zul=92~<_Gxu>o^n?~B1N8yWnyr}cgTKs0$J_I93h!Nob&5|o~-I~)g=Nj?g z&8t-;Q03urPHiwW)pAGKU#>#$e{Ie8=YP<|81BD35{IHk)q^yJ920jG(qhfbrVjM! zIxk-GpR)ubNXTWB%Y8+{wIRGabv4Ae?Z3R0BF84H>D`VIU;CCu@`-h7u_nouJ1V4j z0s1bnCol3jfgph@tgZOoxtlG$7S)^kbV^La2ohLB)7Sk+NYSN(c(;4I2-XD4`?|E9 zYA-S%oNAcvk77N}kjwU?x$mCV8b**n2TA03{a8$f#z%AO$1W0qs(SNn)dHoHES}6J zpQ+@bTR0yaHiTdV3G|%AidiK)k|!VI_~0$W(~v;b^=aFDq7oCy>szKLr|p=(wV7*2 z@c>uRyFdg9^fE<8WAJM2pNfdljl{Ab~2ctrhiIo=GIqTt+Lab!_X? zalEWwMbpUg|7H}U`s$C)ClS|DEe)i_?h*g4Wm{~*d9+VCf{{Og`2Wgq`_7KV4dgjT zsDXZB?y1@OwS~vXpJFV6PN6uuU9YpX8wT*Y#pe@@Ac3_N`-}!&V7Ze=@v!)ZJ{Um) zYp5uL!^_h)e;zVwj_p9vFF{`;<4$SgV$<_P%QRhymutAt1I`DGA!2oWBv3Wm%EegS z^Q?s!UiuCr}W^P5yBc`*&rKbEEztBt{}o)z-?@crZMR^t)_o`I1c! zS^TLybzR~xiV-B_S{*9ilpkMlLhYE-L?TdSzr@w(IUtMV`kUHuXX$QI;gi3{J;EtE zcj&O)tKnjFpLU9TIAf~SzQ2!=*j*9Ym~mqz0#$hjN*km8%piYdnTU1q9{kF^dqxji zABqtqJ~t|3=*u$5z+R?WZ5{8!(-L18c^BO&Mv(B{;bJ&NpCS!Ln1}$ER8sT)4bAFZ zC`FeO+k|B+%C7P`##)bhM)!~ei9pq_%B75ZN7BhsM^hP_3e@05NliAiHKiCq0&6SY z?z;FA&x6Oc_AYT0U0HN%w|6dX6zYGR^uK62R+Fl>=2b(FGlyqZbg!e8G0juV(Ti&o zRqL+X7_XimC$)rsDb8SpxukO9R5IZ67>a&7`r_p>pXd{7>tv1jd@L<7i`+jskqplt zD-o#5ypf}i6dC`&iD2b3TX5|u%PT@DMkJ#4Q~hlz-Ev*Uc>H8bHosv#u6s*FCa8|q z&DEbBO0}Hx&OiOhUq>>@uKU3h`w#59-WT1bUmuf9TAnZ+E8*#r#lOWWs4~vt(JD(>}ZX+o6$k-1KX;DQ#Q7OsNjzLu`lJDgnkQc`XQ;Z;iwN;e7 zOQqO`ju+X25g`(RDxYG3`X5`9$aZsoHLKrH_PkjW-Zs%gB2cw>SY>@(u_RL3d`FdR zz_sa4x5%u=!4#ir*e0y4$VhpgBG%Z>A=8~kNCc`j2A)ehnUrX0d7I%CS^Lb!JpXY+ ziV-B_S}9>wSj8P){NqY@iI7yAveK&knP@p5Zu!aHlWH6yp+Vy)jty`$kn*8w+R*q! z;vll9#cl;&DLz&Air#uij6|U7pI6(wcSR;zdTMP)b^Y_0COm6ZQ;HEJesR+O{} zg|!KF1~H$Vu@Zr*>05XB>{^sa?wDs3^9%TE1M4^BZ`*rPj39xv73&(*7R#h(<&w`C z!zqq;ag_ORnvFX9Ya-cZettZhaMqZ-`6wyUc&tRA>fW?=>WqXW%k!gqNo(?t@{TRF z3!)f7BKMe`I_t5Ri8tT1=R!IVBh!uVJmf7As0s}2pwd zhPXZqr}(~%Z^ctHt+dkfIVov=;*1(uhHhXR*(ZPD$3evIR3Ys^DkJ@cw{YWKacxw{ zam)DX?${dq?(F%*Wnwdm@8MXUV}`v}rg1uXaNcy+?wL}bU%b15q~EGTF@gk^ttgH9 zZ6bS%Sg|>u#z_RKX6HF*3C+?i_eZ5+f0FgTtl6DK<0wXuklQh{W(w|o>h*K2vlWGch*M!lW95UhkCiu!=DPWzuyyz5hO}@ zmew4KW?0VP$72ns^K@$#=h93fP&LNUMH}`w)ADrg*5@c&w8B;!JT{D?8-;Bekn5rq z$vkbjYq#H?!{VF|stY!arx-y3Yb$1DwEFZ)?>HmoqnAXWO0LzsTJmJT2Hfni*(I7RHqNj&0Ud2QgP1%3O(hkReJCh?bN^66+$4LaL{+;iv zE!cR@d>xBD+#KrAoNBYh`!k^!LBgw}t5z%RjHOl`{`}5dto_*vr@j>ZR&-TUKNitm z=44w&G96NrH0|+cRlSu!F@i+PL-yM4E|)B`;ajpZNr~c($&G6v5`n7AM~i9a?6NKQ zVC0O6td!$0_AN4;qUVh+_wxGh)JbZtrB=-%Dv*?W&g65MIEg^jzNp>WjAmN3@ z><|1Gs^nVT`Z$JcjMK@ZnBfwEs)S{$)W19BS&nY&dpXAGcJo=G=TQ`At#EFt^0o15 zyhFa_jpfd%n#{V!3s$Q`kVK%$uVZDkywy$8%RB;b_bX0X?>|8bhmWEdL83uzqAtwN zw~StyUl6&n&w+kB*ok5Ui5YjstFNl(6A!a5y6wRsqpkG{=2tR`;*1-%3CkAlpmTdt zV?cA-ak4vIf3|=cA=c*qB1T21!m`E6+z-O^S$ns#7jq*h&U50t=8_^10jVMNts9W-%wE5$1THYqKAx&7y36(zhx28m(ibQ+p+kWPg@#gVv*^wPc zR`DII%w5q_Lrwwa1YBkd^%c|eEN`iWdwtgWz3at zk=fIY2DQG-8VspMF@glvR%Bq;*u`2^=|Lu02{#J<3sw41Yr|@DHmTppbk~;J6whAu zcOhq+MpKL+@x9G2J>MsXTol<5it@s6qSHssXN>}UDP~P!o3Ly}IXvn!X|pMeRXOM{ z5vV$=+8V=eUm~8m>3pnwcZB^L*_brhG@4=r39PNjjVjHb9hpP^nCB}IsA^i>&S-G$ z5=k+ik4G(>`QKB@kQSvo(JG#f#y6{rW@ih%*E#b`8D2IQEl(V`&G*>SNtdk_c2`)|;X{8C{AN8d!#{T-$|W1PS!46{T`QE1or`jp3uz zr0B7tPl^7ec%COT;L)=y8Ux+EBm!0F`zngZo0()m$+>DDn;6O8MFQPuMR{3k2l>0u zO!Y*$2+4Iu0=-`0s>+gD>EneFUE3j0S>%(_Lw*Y}%#cDG@NkbZjpXs5)f-U0=D*AbMNVJ(ymm1fO*223xk=i(&*_PV`6> zrHB0&cB~j@>!{dK0R9VASX=RRX5vV)!e~l7OGZnzLIRJ7$URNyiEa zbcYpXL)$pAYQhrICt(c5ym0iB(f3u9P66Y|hpP)oRE@C`fvQ&7x*jtt#qu7NJNXXV zS5e_p4^fH{BxFap&$Ne(TrAA%_V<scZv~oXwkV9 zd6uJ3G2b~Q_^zHrB2b036*E5@29o4+8_Bqrp;E1oz+))Bqj!oS&T0FJ?}-u8u|fh} zZn5UT>pkplmsXrP+e@x9y0z%_DvFK{{k=JEs0(HnIIB&l0^pbPC0u_fPJr{e1?rIe*4V1gg+s6?fF* z=|;ShSjRMDq~xR`fsUux;b>23R_pajzBV&(vPr08Xe9jAi1c%qLdeS6191gdIx zpQ73}Ic6F4EiRWp_PnjgyY?p(BS^?Dr*oazoRDp$+9;xU>`=+s0>21;u5}801I>lm@8>h4EwdYFXcI+sM zxt!?3p1k0w)epU3nPaK6sT+Cw@V&ZZdK|?F66mmsuQ?w6B6qaEwQBDJB?46;W1X}K zaThGkUHiN5$;-8+iGFAR#R&SMShg68Iv>`qJgKeS%TAEWKmvVO;mK6aW;vmDtk2uw z5`ii_B4XS=m@-p_!dcfiidnbl4x{fYa@CG{QRCD_;&8}L@^{hGz1*sx##66YTE1<} z1LE0e6N$OjhhhZ%R`g*NWqAF)#4&Lg`7mshM4(EpRlm(MSK^|Bo)*ohIn?+A%N6}q^Jvie7HcDCFUmNX(-a%+)5 zA6DcPrEFlce^%k@?WPieDs-b2C1%tCZNlrK)UA=f2#DAEo1zy}wKGAU! zfhzP!#SUPpVa5~o2Wv$mC4UzQ^hgyYcYOioP_QlU_OqfypbA}Fu_oa6$!cJHWBze+ zGm77_pbs1Ruw&ZQr8h0}gGbYE`n>12v;$KTC`OP#hgH1SmYAsTDY1)mJrgbws0wxW zORIfYe8FQL0i4}>)9Bv7hhB~Lpcp}y6FpKzx!e4wKJwuXQu4ki1O5wDSX=S^;8;y7 z-{u-S)p&$dD_GQ|5)^>G_M4&2u<1js0DEa5hNBk57+(0P6EFj$R+j0 zvtIu+qJeQ96eCE;9`c|iE0|-Q%G6`9mqegji(aqT)vC)%w((MBI%uz_M4$?5D|Xc_ z@j!d}M;PnaGDd0#5_k+1<;Q4Oav^#TtJiC+bgYm-ms_mBf8ipV&h2Uad!7{CWpr!N zzZTEhPovqj-RsEC3!#!vjRd;GVm>Ok3fX7`}2@%1!?zR zS5q;9M1^_Q#$2%z=G_7T!oO-VBZT*vnolwpwM;<*RdTI%ybR%G`3(~IuwEKQkf`Bi zXB@nIiF|f2l~HCR_Ki(jU8^g1P0f4bnVeL9%z zDG{JTZ&^_mXH4cE!|G{M&lK{;2>$PZWv)ihz$_AOK34bVP3HNTQAT*S?u7)Z(6bg_ z3C)&4reCvkteIL3O##E1gdPk${4f6{+z10 zRyGGG@>?CMvxxJ(y|D~b$+@)??PIuWrCDrwp$6UN&XoexHk z(0xi7qm$Cf!B?hslzSJ(H$@#`y}MSI2vngjE@lDNgz--zS7uqI0?8ObVwfd=cZIo( zIL~0-{M&uDbCrG!2~?p2E#9^V1@lFHMaGsf*c&5A9CNZVdNt>SzB9F>$FKnYrGYiC zGhm8DpbFh+v4X&$03OiHnop$>-WWlmL9y@p#x;iJ3^v)^kw1?r&sVftClRPZUtCef z#dPGx?(%%FQz37RAhFx?slLWhw_Ll|4KKdKy#fE)!tg=@Rp<^YN`rb{e0q8V{&4f% zRE!{zS0-2QdN|c`pU7lhj@D;?p3K)yTOkpslJj@JZkwxziM2(2Y}=$^1c})F+r4iLOCMjM(9GXJH~aWxjTo`KYAv z)_DGHS|5o(75e*%@_L^}9xaXMzIj)ZF@nU1`*vcz@I;bsex7?*T11w7aHKS;fDf=dowRzBppPi8W2qB?49Gs*70wvD0(+ z5o7o{vN{zbNF04VMO`Qy`rl>Pw|hZ+I$rwT6 z_RK%j%OZ>U_xW%fYD>$E>dm_rZg~s|RLRcvgORp$NpNreWdKXY2oj%i64i?3QY`oR z&ka@R#W`*H;XTX~2~?r0E_Rn6R)tnS)|NY5?BR_OBsKh)|5vM z*dP(8LhoHsdhcjLLxwixnb|4H7(t@_xcll+kx%{mK7YNi4JDVW@}D-DNl2gyU3Kw& z)Y>+*OnOy*|5E1^j3BWn`J-C5nqe7VohZX^Am(Q=b#8(B+e!NT3Syti>ufFGA_tw$|5bEE}rYl?iXo)3?oQv5gFLCL7*C+)X)?=i9i+3Iw*=*|B+^_bW+bh zDCmt5BwG8qYUNVTSl-oMbWNbKXP;?3B9D0?fhu(C6=j%z0!`;nw3&OSd1C~L8cUqD z?VHb8`jQ$;;%Mx%?&QVSEfRq$^!ml>V#DHSarfTDvAcgNMv!zD@98jvGPFzRXFP)c8OWsokkUNqyt^k zk}-lr(F5<)zP)lS_j%^8mQ*k5MjKshaSRDmVTQT*zV>}fnpoP6j<)hi!3YuwYwxJ^ z%vIu8(sZBu&Tc^2a4%|m%+?bLRN<_Hm}!V_A(TnsgK+NO(Osq}sQ-VL9hPe-))Z5q;^N zf0I3tKo!n%2w!wtQF=bUFU`A}l8g}~4$fVrR_mN+IUoHWJs|gg45J-~w>^pks&Mv1 z%#!uLPxb}{(~_HQ(=dWW&2{6|8Y0X4_pu^XlZZ>pFxt}bXDSk?l5@5@t~*A0e;G?p zx2=5?BS`3Fh*&Qpp9Gr6qQMVlkbk{m=Z=8KDOcn{4WA6vxGImKG1U<8Tk zs|%=!BE$T*N4DaQ8yPTnB3&Di)*J~`;S7z)zuNCcg6B`9+7lxgBS`oZnC~-Nd zlWF6n3nT(nILjfPoRi8L{@Ii1mT7j$7(v4AmtUH-m}6OJ?ssh-7G@`}PNIG1czPj$ zDx7Z-tC4)SVkI*s(YrU-reXw%eVbP=&J`VxDu^4z|5i1U14} zq+$e#DH*}$8I#}RtL*;QS;rP5=*28UB2a~MA>w)7>l&LfX#}0K_flhwAW@~rVm)7E z#6P)h@&K;(FUY&~=ua0fwM#|1kF6_Q zc(=uVH2wXrR3uP^GaHI><)`pjy6oe2olM!PUy43uUpQ?)1h_vLpw^nyHxi= z0#%q*E>^}HQ->2PN@H$0cwz*J+)X$1?IMSMwfR^rAE)u+4Ls;yy@xhI0#!IKBIXUk zH9n-i2c6YtXcLSe;S}^zj~Dsv7tCdRuh@3mak00&hLXCy{ zyfK0V&LM~ylTc?m*VwBM>f41UjdC`!e9w~V!fA{t=3;aV$|AeP?7}fu7bC^}+lGW+jX@zrNnY0oqJN(8EK z4ncgePIF1kh!e)b#1Q_(C}|A%et~=y=K$L@j+Zh{eY#*dx^5wx$&bz(jL+R7I7X1b z+KRjkw*ow9{Y&kWRiH$mO0HFV=jMEK7cVm9iibp?3g=tI++9FZ&PKcz>(w;k=*r`~ z!S>jm2{_EbX|R=)vp!eaCqH(3@if32Vj02pDsUq=cHzN2|++N$0=9v=VE>Bm!0c zdN>>FCY>VP%b19F7w(Y3U*>25`N0x_so{8Tr>mX!41==@`w zu(pcQEV3})5?@BFeCjU|s52$LBwW)6pwYKNp{T8y3P1|#H{&A+^p_h$u@YQi*A7ZLiiOa?L z?Mri6#gm;R0#&6WiWrByxP|y=GnHgkR>@X8^pof`waZoicd`0 z$r_}#=NLf(YpW<#GolImQi1#mjFt#g$+cR!DTBB*9!s9EVG@C=oX{`&$0LSi3|FC$ z9bYw|Fi-fqCdb(YoXx0I=!w4DStsL1n~qi3_4#bX_=W8L{m~MED%<55djE$hmh<6c zeTXb9c9_^t4dNIi=96m?475fdB3^tW9(JqvI%s#`^<3l+{kSNw=i@qQv+498M z*|#()JAXWRu{~NMP<8e1*?P^zVvp&{rZQX)t|k3G+R)R(+i{E_p*sJmmyS#(o&Po6 zA2o_JW1o9>=U?4yIJ)vU3$T91ue2sN6D?1J%B@swZ+0MQaUzyuSTN^It?e;dU`;qBuYWd^QfyUJVwP{u?l?YUAthdd_BQ??dd=zuSoxMoG4u_a_eKbc; z7d_;EGDFnsHV8k*~>`h5;NS zNF+3zqJHmjjHH^6)myvGO#R|Q{jD^KK$ZQOb?RZUir^$iQyEG1cM-eyyV&wYLpb`a z=5qnYtQSY05`EF|70#O1 z+e}N(!0Kjc)$3asIa?AqMv(aXhNBi@pFxV5ulB|nQ`y{*Agw2jkqA^(-CM?7hG{gY zC~;n%^z@ES>cG^799>6rIc@j4Xf*STWn7XUU5O@MTI&<&+n!?tiPCv4+UOgnNu3Z= z86Rt{V1s90OIuewS|U)j=clvgk#m}yHTSz!vhJ~*1{=KpI2kMvsJeH`S<6j2V;PtH zdg4RdCa%%GUGU)OY+{?RwqgyT7kA0y58>L{|CD+43~RS6T`s_va~El-2$?`E=#;e^$Rir^SQ zqQ~qv>gr~>WPo`LSM%OH<{sLX#Ri6Rj36=bfuaTPx75a;?W8QsYs$M4;+T>D%h7qgTmn^V{9t7Gu~-yFM(qc_haO5?EW21Nq68 z7ufcSWyCg+2vmLFb4wkt^{T~T(PQV4v15;tP5pv7`mN~icKdlijT5_0|9;zkJKUCd zSFA>A&^VD({eKZ8p6I94Za=S)!RD*|VTrTs!R)Ur*rlCBpz2+TgKFz?H%LwMJvguM zzgpm5+sW9j<2X8-=%lWkQ9+$q>?VmCW4fd6H~yk+TD*sR3Jl{ILE?0+D0PBuzGVck zZbA%e6(_#?InYERP?cF!RsYzTZ+UWVd>+W|4J^%XOli+Cg2b$rQEK|BF^JnzNPbwro*W}xEjf8?g+iPNc|lh)~J9?>U+aEu^v!F!I+*jYEpIC=Xz z>xgZ_+A2!RVFg$vM=!GMNUTJls^5fNdbcq*NL%w=8@7Haxs~#S`8xOI7(qg=mDBRq zYOHZ zOXfRwVzJd?Bmz~pu3gg)7QIG(n)`OoQpsdou`KrTbzhDVB(5vR^|M{X-s|Sf<#zq7 z_`w$MN%`0&9DPb`6V_HyYR)gj4-9@qijMb`2voH?`9d#rIM>qh=C@msH}e{>6{BJ} zMv%bTiZ@}u&SX-6H)}d+yhNbtQnBZHv2D3zt9a8BbHayQd66Ee<-eXY)l4&mrcqt}OivRFshTPH2}LA6)Gu^b~vptmU2VHs1KUn$+t zn7f)v1gg-#6xnT$RsP2}ccb+<4~amP>}-$8vZf1<_9A{GdT{jgB#%t;6Aw3DNM`@c z(EP4NN$wyL=*fv)r^`9=9>IN?`-qMbfhu$g#TRai3-ipS?N}$;TOv?}eyQ01qw9Wh zV&5Ec_Q!CJUKsjG=+7z2k+*Y*+Nm3usujE~e&Kx7?S)rGuC1Qk_c4EwW_mdI-9gC zjqP|dOsW+U=*cNcV9oVx{Xco^NQZtBfhs&AigNtPGWIK`Iv@0;DaV{$bPCY}6kj5h z>c)P|_?Jz04U`B}*+lo%YajWa8`W!HzLwl~DtS|bugOOaHrjU=djnhz;g}bUjwhBU zT%7kS(>ksa>#J{0kO)-qMCY_cD-ua!JJWN}=24h7zL^bOy}LKZ2oka{TJg%al(J27 z4bM#p5`ms2I-Vj==8a>j-Hoe8)>@fB6}C@NoRiCw4S#myoy*vB%<@Ga79Cba`O>_a z@u2l$5pj73L|;h5!%<;i7y^S-QA8+nlVx`ast zs?Z}9UnqQbWkE*@(C57ZI7X0=YvpvPE~{+gL>)VHmIzeYMzuHZk-N-1hTAiJr&0F1 zAF1~smSY6HS1j8!ueOn#zUs?6RkfGOKo1#@h?vb7dY^2#;m(&mbLN=kiw-n8#fp-e z>c~>=`I8w_V|dKKqw3CIj4!wx&wVO4jGll)!uLeV1?U;VyIVK3%&U|p+?ln7Lz zw=8z<-SLEs8#s}ysvao0&Pbr+ttbaPYS83|X{5+}cgfjC0{vvM3RLS4EV*!DUM{aG zNBHJN_9wY#_f$%b7}c)@pm93x1` z-s{A7yV#D)MR-x44ibT?DgGc93fhs)e;``bf z&B(inHR%p}jbpAWI?(7Xi+3Dkk zA9KmbrX!@5BY_UIa9DmU(YFk3Olz!Y#xd6wU2gQRMIKrCB>lk9{$%CrScyPY?&Lmc z0kdzCO8%x}^>XQIHTLin7N*B=j39yjwaD?xUaa*S;>y3eca{irgwanH^XGvvdWAg! zh>i9i*$Ph?ZI++O9)SWha=1CF;heIxasAC3MlI^M#$ zYjsIm({-aMBUf^^kw8CLWNa1u$LR8Uk62YcOd?Q)&b^|v`Ph(EbNWWurhbw`j|95h z!b5(&o#YpM#@-e0B{{prcZNz<>VBG1ggZI%=)b8G;}4+L$Ml3|6u;0;e8gFzR9#( z^M5tL8lrP6uB$bJ`Ke|nSSfueLjqM|A7;ZLFx}E}{~FOeC}%qR)xR*q98s}HvN7OG zhFOK)vdEwdkLG9ROlP6Rz8DxmLN3Fh=s50l{t&aSQbR)mRaise6s?%ZZ@j6_{&hL4 z;{Rd|(R&x`c)g3}!{~D(q0dx?1ggY7%*Mz)r!6fX+f;l*pEJsssPxxRg);&&v3Ywu zpL&$53#WK9j36QQVK(XxI%_E-dC6oRdCg7RzP`4B5hO4VU98=vPvQgaT9YnslT(pE z73Ob?-MQi;dGpL_B&OqRh6Jj_p3KG+n+uj&&8jqk_gE7{wmnfbRAG-J6XpAl<#UFu zAsznpXBa^u!`0Rp_~#`{8MkVO^K#SnlD%ZLiYn}*WTKcN{=di!p7YW5 z7)Fqg%P2O+k8ewNq3_4m@<9SsSVKj57+;6q`amds;;7>PVhwTppeO@xROD}mwxH9- z1v4a2C3aZW=a#r(Y5ACn<@kM-1q_TJfiu`*6~VV(Sy-MQE!busLjqM|S7kjj zD$i2J$O=E$O5cIB;8+jS9@76skdSA?AA9Aqo-0D=r8UcakU$lV=Y$XEmd-+t451H3 z&SglTO65tMB5*HDF{Oqpm~eHzQT5kdd_(1KwEiOJ<3r6qXX zw3KoG-Wb;Y=LDKwcdHMoa7--|N$=hnNuMUs#IJP>j37}w^l|E&`8UZ<^YdJFF2kB^ z3Z%7bd}kOz0^a~ce|7MI{v>}g&3U|5M*>yyyTLz&($r!LqG=ufB~lrvN_PwK361|> ztpaK-*Lqf;Ouu_n)=-7-YBEuvOl?vkD4rf&Wf&MiqI)lU)xYTf%4pQPG8wikmQJl% zm0<)4d`}cB{B@a5D)x+_e=dKmB7rJ=|5TJSkt@mj^pP~}=u)W+REZsyRokZdmh*93 zeLw=&4x?+5CmZ-iFGgYN@gjL<6~5hzy*aAfCLPxerOAEDF^nJ~m(lHeGT9k8md=al zq9K7Qtf6SfnWEIDUthY_f4YkQi#0@-LabDH+=hPm*p)7`YR8a3l^6}Gx29aTw7kVF zR~oXeGp*Iy(+49+pa&#+xLTFzt$GdU-7>mqWclBJp-PMf)d~l$TFS5~(1a4}X0%G7 zTi&QbXG$iL`?sMT-dCl?)*RF^f<%`;KB@0kE-xHQrhRNR0gWnX{EI`b;JsP<{oZQ)p*+afiJoCFjqwt zdc87na#}nsyzGtEpsf?b2ohp6sCmbpv6OMF%_O>Y`9rNsD`sE>3Cw#IZtdsEG{k?j z>Rm6@2MJW6lPHsP;`JzDH&lK|+ilwJ#;PrRA~9yVCJ4*8JJBsSM}54!n7#2IlBy z70#FmKT&*}Q0j{vfBIy$JgF*2eG z{V}ICuX4b~G_w5fzmO24LG`n1gfxxit;{eF$riA&3mLz zGw^@0hB(77zC{k5Le%YH+z6X!8jJo{cU1K%|H!AtvqVeFE5~c3MpQiabNi`c1PRQB z5IYH+da6|)6vbyfSiq1#Rk6{Jz5iOBXc_OWxt^p&geCBZ&oc~EVeW@ayzj^LIu+x^ z8j5WgMvz!rDWUH6Hd3FSkKD(7+1PRPkQ50k6JL7MbbUypn7$sy|K&X5_xlAx0(}N#<2Zyj=1eLRB;dfdtSTs_96>7K7zAjc81lB zOXTCCZR)DYMU8Kf*`!4LI*Ar*);#V+m3^~ViJ%hP`kAW{zUVBek)NSvhnF>mzR4mZ zE?iPOT`p_16u+;hToh%rj<%*>Htl1}{o8R=-MMA{6AccbN^-$e*Ll_ zyxfnFAs?#o9j)v0Q^p;A@UN>RH*13GdOAzL)$uwh-hYHztK)6G(xAc=Y5YNwKkc`;!e0_z6Jh9t5{o$Hi(zDA< z6>C`a*c-jo@?4@$pCt(2E}V26RE~!)ZOk!(gxroNUEYuieGBqO2R%8S13WiHJ&PGF zCuftZCDw{ECZ=2>pCTOj^ZiZvg>warn=Uyd`^h2|Ylvlw(QEx=l6>BlCpK>(5vV%) zk7AUw%OTF~mxwav#&;$*T1Ea~ODpaeeMm1m_6G4^-%mQCc%;Q13ZElLuYF~Ck}ead z!m}yz0DNbYDOaoTMH9R^Mv&ON{f=@%TZl&*kD>5)#XINv=;FNL+Gf(RLP9oF7@ob8m zyVN1X>gq-Iz4=gn*xlI(8FZfPh)PwFknh@Qy`#yDx98Zl-9tD=@Q%WN68Fc6nsn8( zEv#4X2GTu%n9;;6xpX zK$TppbXJ4b>6yz09jYd^9EoR@os3D9E|R(l2SgdJ0cGjs#OrK!SY3%g6&^Q{tM=5H z_C52OWvq0Q`a~qY^|Uwa#r~X=S8o+%-1x(hF7$oFE;e(Q2vo^u(6g=s?ew@1?=rZC z)QckV|6L{F?e&+*rTOX2UHQP$vwh+X@jan$L-n7+?%r1~-6UftG*AQI)%6a(cayBV zS5MUH>#75WUw$sTeIbM|nq5@8F+H2m4{OzFQE%0|^K!}iL$lQ8Jqu{(oN@@gxL6Ql zBC?pDUn<)CQ5+*kxSwYqBm za$i^FU+f$5Rw-HPkWSahqV6Nq6)AVrZK9uVo-skS*?U`k`uZxd+Gy(Ahn?ljvu=6r z#u{;qAc03gQBJQ9WsMel^6Zvw{KwO1^;?U4vaY?ay8rfgbx7xY;x)LH8r*P|`c!qdf1mz(*obg-2aniR-gi>2ar7 zm%+WItK`tuGTOI*^W;+XG!?Hm;YPi-;!{_hVzFYCaf~1la@tv|JMcUi5o+p7b}0&1 zDi+}XdN$#^=GkfUE?go#N^Dp0$}bsZuT8mriR||?-L0iZ%S&)=Bdg{S{YD{Z-JBBD~1wSL|IJTA#-I8`0< zBjf)ucGhuSC12dXw%B5U*sb7NAPD!)wYzmK?81PxTfy#JJJ(!uO`3aWKv1!}ySsIb zUG+J`*9)KP`g@+|4_>c%anAdm$uo23o|!N2Q8Um+pjMe<2aKXcF0xBA1F7U+M&?&z z&K%+XnfnV|rM~K4!kF~;6{h^=rzXs6V+Xj<1$Hq>J zSrAIM2wty5M4rI`G5bFy@Q_A<0wqY)cs?rT7y6B=hI^|JG5U~`Hu#?d{PFYtw)HNq z&g{8z@u;X??H|X7r3n%!K>}-?eouHutePR?L4M+!y*-e?wxQq1ywYHdk4z^;L-&w`rB%;f2R| z>mB_BN|3PE)wh<7{Zdbi<3{Qrff6LJ)~S~$1^lKgjpd6f2Z=&+CK-VzF0%;VhAO@x z*?WoqV=r}5?qhs#LO+2LB(N4Jo_bwP?Nn!z-<=+0TlL|(wAy|b^V`Ej7LZ4-8CvXp zV4QK|{PGezfm-(W!S`qW@f+JSp1)olD7JpyZ4@njiEUrn#rAeL?cxe!LdMH1xJYZV zp$Df`%O%9~cQXP7O7Lw0%S7*kcZ&NtdB*bixyjd3=*bRI&sK^p$j?YUeeh ze$#{e!4o@yS~Qxq6z?=`B`I) z#)_Ziad53oON=2=M)%%=2X+n;SNrAFvOGD;mZaylRk3}&+iBZ#V>*q&U+xErev7kf zwP{zS-1e<1&V%?QMJeQz$t-grj&D93XtRL?&cW0ZAF1YW@^BptJAqobI;S1_%V+&| zzTL(11q>8brWEp(-nnnZ%C@%&tOX)cw~8s1{}``zsGsyD$@PQpANIO>AO@Meb06jR zCiSzeuW%)U?W8EJCfzhvZj0w7M+Dkh5DBb9in$A{X`U~BfR}mO-&VUwV7)0ygN^l# zT1SuYbD8?vR{AtjRm8Xdnd*lO)BQpQvpvol8pN zlN4o?S|gc22@?IArPcI8XTICOk|TjySR?dXjjdxntz2DJ&-wN5dNt`;3)8ev&rZ7X zlU%KNb@-}kmaJ=d^uLVF)ERt~Xzin6=! z!=JVw5=C+>+G~<`m3@JJjrPj4?)WYKF`1w?FsxRiyciPI2`bD<0rG9Dy3FY!hsxG!ifFs(3sDvTz@XDc7WYE z|4989omcZKaDc_leoU*37h5M7Z>k$SFyn7YYf($!&K8;RP`)7$}%irbGCS*FKHpm^#hGccoc6_<6ZX6uLdhgDq(Jx3&@f73V1n`Ba zJjC|w`L)~C6WPo&=hQUk^J}}+M0T{{S@j70KF-=d4>Ff*Hz`-W7F{GTe`Z`Xu z8#l!-+q5v2u5T&rqDNWZ5ADO)$~C35E(c5c%DHxcS&D3A+c?YoTBo(XIej)!@60s6 zBhACu{2V25ruPVA&M!+6@xW!Yxwy2dpSj~9P+}v_m-o%q zAdCgJk~X}D%`hkJZ>Kjen89YF!dxHUyq;mKOMkmPMaiB;75!JVWL7RPsD)+nQfJIA7ExNL<h2;g|XN^vUUqT&cf^D z%E)%s8Dk?*E4q@V?rjjxhQ|~q;_qdT%pHH!SCv-50wqYK`&?UnGd7&fne!Cc8;cB+B-Ft6rFGF!g;_BBJuv5rt-D)sI%#z>z>L zpAI+Ge~QGg7K5CLXzx;&&NQ@VgZH}#kHAzKead3zHfGYc9LuORFzL?rCF!(k-BN4U za&p$Bj1v(p&(#rsE^WtNRdf|7LE`RVC++1n!Opp*C8DdZCY}~p$a=hX6PS0jhdES~ z-}?n%8+$U#k-*mv@R?13ip4qY3sYivF-I0t;dI4T2TIY=3D8# zI(kxW?Ns)7HgEk~@^E9vv=R;fSfK52%Wfl3t8ILCtyb$ewxiTnB5ur8MX9QVdA7-A z#K^(YIwQ0n+sSf&G+uXfQR=@VZR=f}ad7QR@g?o<^V?ko z`|zZhZC1m1>Cl9wYOe4IR;PlTo$r2$iJXQY+Xx9eDib-sBT%bqw>@f&Xj+&2ou7y&Pl|~$Yuxov|9FlPBu0)ntme8H#lj}#B_be8anY(o z1-<3lFdKnddt2;L|JV`9_EeTO2LHK-_j$U7r;F`sTessX9N%OpvS!$8eyY|KzIj($ z+bSGa{Wvoyit~p{ywj#JyumMBYy@hd*QS%+X&&*pI|6y>fQ|wsNZ{P1DDe@u_^WBn z_(zwnHUhQKYttAU`}9Sa~p;2=mnS0^Oo)1ZvrPVo<@^d{w`{_}rnr1->s~3u4<)T*ki1 zyl47te0KRhHUhQq{Yp{7g4Xi!*;DI1dN#4Wf8a>Pkw!OQ&RNG-bV;Wh85`LM)IyI+ zw{p1)UL&%p9(dw0M+p)*HWj6J6~Q~XHrFH0-?I^@g+7?Rm}GYo3qtp?Zzrk>97Q;4 zaaBw2oWFXBGyfLU(*^G0NT8N|y*s_J7scWf)OY{7*G8Zgjz_w6XzdbhL39|;u=F?4 zdRJaE*69FC`}wh2wp?EG-IoORtj!ZO>vLzbW{G%KqUhhWhCAo_i#D}T6uend0=~X9@6>NM?eS6N; z)M_QNlBv%Vac6sZJ^Ps<{M7SGHe%GS{N_ZzL>6hDwbAtJ5j!SnT~4;uJt8uS3vtt8 z`nC#VX_}VQerZ!Sdb1YB*1apKt?N=Px|dfN+cQPJJAckILHlyG4llN4v~XYNXJl<1 z&iKvZT3^w@C~SnYN-@Q>y=mNxPWi%FnuBuYoG@symSbp7{^in0@hsyYW8>#==5@WO zHraKCv3z_4TT`!yjixcUyp!hYwvA8DK1dvE`@%Rd-(U@M=g?YS`P-P*++_VCvub%- z-ZiG2h+zX}$TybqPFeJ*7DIXIeobucf%TU8FBkL9&Vy{#0J)1($|tK{bKz{Ber*GR z5+v-k8~8kkom>~nn>zKe5vY|)cQxPFOJuQ?WP3~x*u#33Xv7o3`U&Qg!JPhwU_JfQ zYyTY0X1?zg$3C`KG_R|f&8N3wSwzRwR1PEEV3vOMd_FVPZvrJqVDF)IV)yL&v;IN+ z@8yle8@eMwoK0X=IzLvi9N0q@Wp?iLdfw?x`Kw#4Yy@iAZQQCig?)NAo`+{=zWSA?8{m1giKWO98aBe_oj9D-ghlT z+h22=ZF0o3QNO=au}t<{wK<d)RQg9mRT=%%}M@xoLR1$FOeBxitU$hm996qS)x=a_tlSHn-mNxuU-r zQ(U0r&E?%jvlEfbobIOKljytiZD;+`FYj3~*K#%+NW7@9$EXw@$wpn0E19viQt9bt zIO#jzc!;{)mKr(sMbK}Q71B^EyyQ|N^k@VNin8%M*oZuDMi?HK!r6@jMKrWcvs$kh_BEusKG#2ss1vUlMu~8?GrX8~_wxv2 z{kd>&u_G5sX!s;L^&2vS z72VT7?{qt>%?1)BgHsu$H-@pGCen-cP4Hlz6PxI(`a0V@HEJPE-|l7>Vr|n^)oU)R zU?Wh=t=#mOut3UT0ohCHyqTg+uj{T~ovI0(0dQuRU(hqU6wMNuXG>zO^H}Xk32%M5 zM`eK$BUG;v~cd|UKs@eK3_IAJE8%ClY!>)wMUA1xPLz$L$f|;|x0Nad%BiNkz+Q{>l z!7ld8PDJ486|BpsB0R_5fi?oQ(8thC>7qST23Kak7aMMy2XPj~^3m_5{t>`>6`85k zPq5FNsAVsCg{qbGNn2BCxqtT&==;#ajqB-T#_kd9!K<`X4!2bmb>;GVwPLwyHUhQ$ z#%DKk`^T}Ht-sN%*5^_oZPUf!Y)$o%qR8U1(Y^bJvG??y1KS_H677ptVA_ILtJxpb zhS>ZS682uwHtkJqXohucYrr5Ifm-M*>AOLJ(`?DYB5YQb-U9tAdIR)hin43>Ew&)r zOU*q`PaA<+_Fkfe=g~hDsm(SVZXmFoQS0==H^!5N2K(cp+{L*%tBAh6NM6?cLM%FO)4`aA7_7NkTa)OX1eLgJ(sqpS)gsk!CJr^lHV^8_vJ> z2oktsgp!{Ti_&bji&Vyd4?7J^(C!hsvu#drItO`83~m9jQFrpCmVHsdBvmibk|q3e_sU& zD=oD!GMi@S??hH2Q1UZEOlUwh9u&q! z9$TWINaxZ}4-T*7iG-DwS{TbstJe#`V$`$!+LorreJK{hlu^Z1BO=xO2VKk(rxMw} z^Ny)q<6X>zGl}f)>hipE--E&8TGjx~x%&na3Dm;q_@CzrbK?4}+Yh-K_{>W*FW{*= zt!nFu`L)uLu+B|%za2uk`dFeJ*HU*EiQdZeR9=B>q330$%JD`u|HOSqi;q1g7rL0 zM``2yxQTJ>B@er3sV^n6s-sUO6Eg0dzJFW|)~}b0)4V$G4n-T3tEA|7*^jNXgD9(w z(2pOR%uZ(96^aroEeR{i+tD79ctOOQ9|+X4(lWN5?*1&@M}K(eA~UCL2}OyOmd~Y@ zO*$js=vPCD$Vk29M*_91w4+}QDA!WYMSXWhn$@8wvC{In)Q^+SIylA$^>$tM_8$q< zveJ(6(dC(!F2~@gv$I1{Vx{GCt;l6L9%-MNh;Bb2P|Hd?PKNC&QBcoI{_5$eF`+22 z((<|FyON`(Y28kQ^l*|u$%YI5{9#S{@C{jMZX@-w1U zl5eNkfXVqI{lX?T!lI=X%?!y=>`C)SGJ%qx5so=(0PQTwc`!PwXDAX@T58b@2d4*+ zK*`St$6R~-{u(Cd^SH{cp-5P1sYNq&a>T!)Ac2yf5$8|HH3O|nm|T~98t{1^5>{Gj z(TXGae1D@1>na7M2Bk_ptp^DVUA)%$3d7lr6wyFH6R2@ve#@;ZAa+OBJHows+JVxpAvQdyzU77=%mTK--Gp`82xFbC1J&N$~4`x{XYoQ zveJ%7@|S7)=-0bm;w#oJ3Pp*Pmd~{!K^->!7lB$<+F`>hw1qzBytDZ8!h}$iSZVoO zE5g)aW7obGdYuc-Li>S0Ei3J?;gZfvKcCi%e%qi;C`zofe6AIl>#&iLhzCSu{((R( zEA6mx>%6o6W?u`D!LL{-O02Yet`*Dcupx-Z9oj;~`PvE8veFJ4kGoxBJJR$KGygoZ z4<%MwKG%v7cGyTH!XTo>4+LsiX@`x9lP9xlbw`MRo1XhnVx{GCtw?Eyjci0v-x61E z*a_6K(heK$9m5w#s|`tV{pk2 z1Zr7nhYj*2Li&>G`&W6Q#7fKOl3z(W)8Mc{zC=i0@;S~OB$cCHXHvTWUeeN zD{ZSQM)P1&IjmDcmX>_3tT%aL$kFDquB6ZUkw7gg?P&8G)Pi5TUSc#4Cbfrk`pD9< z`jtFam31tnHmK3DojdAiPF zLykYWHhI0#PN0^RcGwUTC-dy}M(Ay>dR9S+m6p$yeqNs1bJ+M_1Zr7nhmE$=_L>uD zt3Dm-KesoXKw7r_NF?-^QeJDXf&ZsHnAR9@< zz#joOMWHk ze67O<`4YYVPYBeq(heK4W@Qc2td_KH2Y;3PT=I`eXOtZ_Wa~*U`Xhl_R{DE;&>M^N z)HDw!y?=n;O@6L?-;qAwVS{}8PYKkr(heImtLc}5_i8i`CcQ_=XGw{rC7(<4RniG* zhYd-*{Rx3uRyxH-*`K`KSz7YB(w=;~bJ(abdZdUafD?Lfg0;Z%OUiA3P0w7v~Tg|Rh$r8J}!W$~;4t;Et0QQ-QbXp|takKR%x zy}F!?PEZ5<*p}^M#2=Zj$KaVLJi)b;-sqbh4`Uw+%9x1`&N*1$w&TR50mpp-WWAv*FqFTGJbH8y>&k7{hT z$xgUcc+VnMb`?+ByfaXO1lp!w06Sbzcb(rrR4&ynX@{<>62;cva?#{IT-P+w?81gT zwsV*C+hKeJ$2{aq7Z;Vx8; z5+tw|=#0t!3i=RdUom&UA_FBzV2#iYy;noM%qmyWG&~iTd%AnG3AQREordR4&^FzM zc(A@cB2RuXIa_{Pu8_d;Dax^0z4eZFuku-yHkop7x=ZH$?B(qbDxS^3Gc*)!ILcpd z`TQ+kxTPLP0<|!=G(J4K>E*k=;jjO$uAu}8%pvvN*TH(ahP%u{X+Nq%FS%&bS|qZT z(I?d`?!RD@NwlY%#PLTT7$aNbw8-&x6H;Y5UBg?as?u zG-b&nsTT3eIDSi4ls(si#lg^6Qy)3i4++%Te%QsVO>NWh+gYkB`n3Y#ek8{1=~BaH z1GUDUaxu3MQ8!B7j5n%SpxDr63$H$-ne11l)e|uy8f{aQ_km&J!o3;%;+U05-yp4W zAc19~7~$ic#HZM=ywDMM?OI;N%wCo=|FoGkj7g;)YFZB2t5=KO0q_m>6Ygb$Jf9obT3~4GoLo$I{SR-`KD@$Eb^MsqIxzN)_pca;o z&O!d=A=+}LV?#INNq9uF4J7N@bD5R~p{A0nc%E;3Pq1m=)pamL*CO7G6u z>0LoO#wyA5yo9T4=!8-##urhP#jawJPJx_VTpy&P1c};R&&ROQSJ@r!a^%}@4QS#c zPmqt=9i*q7<7)0ao5+fcI%bP6k-j}84j}zp9xlbi@HHZjjC=vM6VfWCID*>fh?A z!5*Juqsq#5+tNgl(!)|;sAQwTNL@MB4-iTat~y zZMmwob>eJ6c#w_~Bb0M!?CN zKDG03<`o*GqXY@*E9L!gj)*OhuMigslj9hhvxi^)!x1RCb7Fp7dvz256 zB}hn*DsK^VeAREgb3tTCWzL#WnRrRV!)_lpsMqMoaNmH#4M*>`{`lZaSq#1wfB z-R@&~YLp;BUP(>ySDouk^2z1F*&kF_NT3#tc{RnCjBu$RIqbVnLB9O*3=mK{il=1bHPj#b2#!H8E1l&RLHzJAqm>=96NI9KNLU?`wSe*X67m zjSrL{L0(Bs@mKBrr~B;vg)>)b4dNG)kw7hsjaQk}D)?MbVpcmt-~iI;b@$-QDZ2J3(1Non~NK;BK~*h5AFB}mB8 z)zjA^#&<}Jue{G-y2|2nJUHg9#s3pI*$eZyTHd>yF^R{w@5+pt? zTdr1o9r2wgmA6rHIgmgt`Q0cbE(7ZdB}m9#=`_uXd?oZAY5~xLE6XIkrS13V(#-ouk&i@yITI6FOP6i2-AVGe`v^U}IOU6&gn<^c#AIBcoiM09^N|2E6aB+%76su9&=Vt_JVLXdHS13V(R%d2A^}zRB zVabs|E&7g^bbG6#97v!9iCoKF|33)S!gv-%dH*`4u26!6v?tfQj$B=>lTuempq8xl zlsFkI2TG8T^(I%Z|MhT4pqA_lRs@W_Jy3##Y=60W{jUuqP)m-2lsFlmyD9SrN|2EK zSgv0GYXb??l4CO^P6q1=B}mBeC|9rlwSfd`VLXeyu26!6^eb{@`Cl7IpqBKmDRDAb z4wN7v{i9s<{nrK(sD<$?_Hv*E3F+tMitN8OkU%Xt`=rFlU^!6oBk_)Qzy50j37NLE zFrGzG&Tjn0%kn-bK|;>EDdqTI1Zq(oLYxc|C_#eiEol}0-*O;+l%fm*T-<+{W{{ID)bmLMVPO|DBE#1HF|WCFEhU$EZc>~)0_BxL)itgni=c8~Z# z@3qJ2y%wW%WIA`3FcuPA;)i&j_1epDmF^RWrCm;pxj~}C_xF!`^JVnT=uOwWr_fH( z3wonpv)EPhZl1^*?LTSzUT@Vx^I`HE{gIVs#NV`EvYiN&;1_Nwr@IrDo8AE_ zUjdR@N+7_O^4+D zq8F5S@6+GE0-yv5xp$WG6*=e9z1}Tp_oF1)KmxU-myoOQv2$d6$*hrwlXoCdf`r^V zv%dV&u5;rp5sfHUJ!!8G3DlC_PQFdJ6p{B11{PT96XeF(-;O=~WC;>-@67rtt0>#u ze(`ao9pUmn5U4eP_Q+OWjAT7}xRQ;=8TKdV3MEL$y))~pER~%0MPJjt&u-d{LISmD z#gX(zzp-`!ve7wkz2|1?i8rWrQG$frZABV~Mn-0?yJwd5=+-%?v1l~G|zI#-{;c{zJSV-O`s z$i1_audjDPw|W;W#@P~Tb0knpqdl@p>BHH}0D0?a*Doi$1l5)N5{Cq8p;uCr%J-Wk z@4aFi33(#Z@upnTPd}MJEgJL5F(HalfBymROW)fAB}kBuN!m4V_$xUEn|`k=Bv6aSJj8zZWv?Gu zioUC@q27)XB*aSqL~JQn6TEdV&JNSAGfI#k zA7iBWtJPk+y)P8vjL|qp0<~z&LnMb8$Rak7FPTEV1SLq2k4f59bNH)i{^ujo7IF9z zBv6aSJVbKjF18+{*4L%swF z)WSFr`n7_@ar(?5X?f3+QR?cv`7~GW1MH>y6gBLn8)6NeiNaP*ks(GD0#Ky!lBqHs%5PhWI)|ex?FGQonZ-uM2@L(d# z^ziTxHp)H?(TC+fVffsR@fBtCp@(BD=D%uyK!wrM`)5VymJ=TzeIw-?_ZY*f3&z6l!&w>#DUboXF07 zD`Vgqp2pxmbWUavGwQn=)zCUV7jsK9&esrq#H#q1M@`E|q67)FO{=~yA-Z`Wozb>h zDjy_J%bu&q$07RP^UoWtbDB|@_ihthwT*WUvA~PIKa}J6ybwJkWPtfL;X*V%3yJs~ zuG*`Q2ic4hZ4LSbuma=sTYsnJD;_N~X_iS|A!GYTQZM1OI$vD%DAlg>jbMFIP^?*h z^W`WcPz%@dit^8*5&Fn76Z!DDU;Lgob=F?-{jA;5)oSOudA0k?64;xrvy#4_%l9b9 zT)XyKuzu`ttl51)=NRm}m@BkRu{e{*>OJa~;Z^j;HUhQK9^Kx0FF+sl@DiUB+c*Z- z)NKZ*)!tv??BB>kDz4}0_k^?c(zDmS$NMfjWs@Kg_cgP&aC8Z~~ z@%yL#jKOE2*7D*xw1R2&vqif%5fM>hg#PE4DSS>&JAqnQCPi7@Yp}lZ{(7FPMhy)m zNMNng&0qN%>(_GT7yoRot_^MSL497{WRJFAQLz?&cYmr*rMHQF^Y2rx$}er9Z~W>k zzHM@eK?xH0&Y&n)?*6I|-aF+ zQLIay^cud`WIwfDb!!yK-g#xF_3pPPuKLV&jm0|mye3MJz_$z;tzJ-Xle3|y)wr@B zYN5S_-m$9J@+g+!lDx6EMh;J%euY?+eX-C$0=2Mw^!tNFJ@mzORnf+OyM_`Zuny^W z!zaeG1_ua*^LR-O#|di6sx? z{Q9q=U+VHGq@`(@RV{TVjP*Jw?*wezx(mCKI7&?WW21@$YT>&N-K`$Kls#-YRCxAw zHc^5EwiDeaqy5gVv>qm$rtk1W2@=>&^b0OS4XqyCvsLRv15%U`IrFkR3;KxgUWp8!i}bwB#r@{@4P(2b3sYUC zS`ex>EEOXDUa;0g$=0p0)!tAJD@-n`4PF)#^=H>GHav}dugzl&HYYWs8y!AXG*Qwv z8^K~&%_+XP3Paj%LoFX_UWytg($qet zp#%xcE&UdR(@LJ<+d(Z?2TfTs;AcQEsTB{=t*8t>R?lG ze0fF#wVEm0j0uY(*`~oxKM>k^C$S*1tien1!Y}?}q|cY|Q#Mn#Iq0NAvlBU#f}e8Z4 z=C@~pExsH>CEql;kI3nIf%nPXQLf>VcT!Lbb4xdD6zn8Me^tbwL7(J0F}Z|DV2vos z#i^skv1KiJ(5fOa>kH>IPxu^Qud_~3_pHusx<$sb-6a;u^)Bo!HKLngv$mMSkKNm3 zlOXYXBWII-HJTk*yPC>jj2t1BewoNu_@}lJsD)rv z=V!N(Z$-&ZqO)X+LqyM6eNDI9_oGp(V}Dn(_U1!OUsUS{!hKVSI5oAcnfu~F`5u-0 zEF@+Pay93K9bz3KYEZ87UkwovU7gIV{he(DYSo$VYVIqM$Oh#sOT@B&LqwmOrHn$C znuVhc)M~N7)x20CktI&cNws_MQ;2x%*4;RkE_G5}ImackJSPsTsO3wiBkJc(Wbr4e z5%Kp>D(tgs0MW7bz>S~5hImFr}{7S^i!(+si!xg!jzJrFXhOODq$Y*X| za)5mrK7oi-n?uCItToKbJHrf=AaQ!Bt9iXEec!n%_Zgz-7cf_De_-CKv{Xd`wRXL5 zF>4(<$SN)AOGNoDA!5Xvxu)qK>W2~}?B%HUEkwMnbu(H$I41^6j#@}7%DllNMex1h zd`QKB=(ca2%|Ft_vzqG`+4>mvN<|5|+)Jo^Z}U!zPRER1l+JwSB3PN>+$b2F)y(e_ z$Ku9rF|3j1h?DtzdZ6gK^Do|Pb9z6NAkn*M4pY7lwkq+5LBFozxKD=nEo>)H%bu&L z>R>T%&?Y`(-cSQ2NMQNsE7_EW!h2=`k^kRu#*)dO4cB2N8};y%VR<|G-7({|N5+f{ z23xo6Hsy+Lr4T#mE~JY?Tl%2{3Cn{!zFOYQ-%ixhoW%RkUwu&v?O`3#`ykzN(}V6} z>3ybiG!m$Vb*LzVs`-kCTPlc44W~rst9IV#9vjVmO?+iw4(++Bx3I9dRldHMd!u-i zHTTK<*7-Nec-$b0UD=(=yd4~C+*%dIUVZyYB`-*K*2q8aCJD)F21%fn-Nq3YcM;l- z3ASvPff6JzxAZ=;eTCqqQy^NLypB~QEpUq{`yAQ+;W-fV@e|_Fh_z(Wm zKnW6qYfm?FbckR_L$aE5M)AleUUozm(XnblKO|7A;ke~Sn|BecXRk~|$o9TX?f%!q z0yYA*unraFmm+D!d}n`AesFFRB}f#WyV3CVk7N}_W*{4}X>y5@3tET;i_^y7SQZkvYhxHC>)Hj(-FME#6r*_BEg|_$v`ybq*LC48XGRG( zkIvOlf&|*5-yWd5ujCzA6-HK!kTE`1bZm!K%OXe87xxRB3;bZ?=Ekz7>)a5Lox9s4 zNMLShoaf2U{>n2>yw8)y4`X(7w0h)sKRk?W&6(c}SvkuuAUKR=pXNsG@nZO1t^U8k z;zh&)0}0f^+|nE5w@%vfKSBgA!F^GJ1m;jtzGmOTzMUK-0-aX)#bpgprydGtGrBvQ zX3z5Kz*6CC*VVkT1(UvW%D$i|bIuQE6?=^o%*W=Ei0g(rP2mPG#I)71q32zI-4HWPD8_ltI>lOReTv3)^N z8-ZGwTRI_KJ-1$9UvrW0GCn4#Jy*Sgqgcl--;C??HmPdwNS3E#203n$zN~A2U;Enr!ycsXA|~Y6DBsnR%Yg)ziB1-8Eu)`aTwT0M$n1x? z!kl!jdqFLEBAVs#eMvr1N#mu{pUAy1uYnRI$nzydCpgMcrC(D$J|Le^SS};v<~Mcq zYm-GCKVmcte6G&vVXy)XuFKg6B2wrEL0#SpSiHztIU6KPkg(g>a!=KJ9`+E2hyLk{ z5+pFUiZVV+Up?&3S-$C*Tg=Ad*|m|&<5-J|8;qw_GisOj(eG7_jga#>#Nt$?TZiQB zU#FuMB=7Ia5+Z@QrPb?*j{25@sYO_xBJ!;$xg1DfjZk~E3D(Che`M}ldp`#E?jJmJ z(ORCR-&I%qfA|W(n+EIEH$5?Tlz%MOKFKjqNTk`DPun=_06V*KIJKa>!>a|Y!sA-r zmaErf0<~x@k@R)I(SoZljMB3;>%za44Dm&+IahLN`Fh5)BYeJGeE{J_wNfl*HU-f9g12=TWNXX*l{;5l|1RDDU?`gX-}3f zCDvKqG)4FC{)a#}PiVnz!TK)GnI`5vW7j?(`5TZBBu|5(dNIXxBkCH80KW#zsS6@nP^`Y5N zzdJ004-%+_HKHiD_X@M|tzdokUtRW6$&>SoHZZsJ-C%Vme(c34y@AKe%1EFVwGEsm zr(VLhkzcG;taJ=Ykf0WX(*xAzT()`FQ-i7?fm&AD(dH4)EG$5 zbWfC6X_+gU8Q^ry{~}P!N;_s8nxllA)w;^Nalb!HO02ZBK{Fhj5|TuTpAe{Jr5!V; ze4D4YdN!Qy_(h48mNsanhSN^}i$E;O{o^dJ;Zayr>sC&$c`&1S?yO*`7KItTo@O`0uEnv) zv}vb(-c*zYgp-*kH$8uNq^6o*9dp8f?L#>&A%aOUCCn(d_i= zXS70oIM7KH&pDdsifJm`@@+M$PmW}*`=m2*zoc2n0%O^R2&NawOtbUmj9d7x{qm?@ z^ls}_w51U;DV)9j$JvZoKfqA0g|m#ka@c5Eg%>@_hcxIb=DfbG;d7C&+o*Xxk9vLF zXt610D~4~?TMFdzTOJU`^4xMW-!Iq{U6a10?rB+oaurm3t2R0)SmeL9M8kLT1tA;# z_S0{@PTG{;M4F;^j^xv4eI3XHOVkmKTkKTx8Ii1UsedSM3!Aqd z8+M}X6Hk5a!wvk_l%fI&)Ixg{`*E|Bp3%ReXxC;v!#AupJ5y<%bRWpJb_b2O(_X0| zYYmzQFH^3%rt7IUXmsDqvhX&?_Y8A~leRu2yx8@E;f@gYm?sk##T|_f=8jqWI zhNA=t+_R?Ne2Hr=_SK7vspy_tpbgy9N19ekuR4guF)w-WfF2yZ58AVrV{nruqK-#u z@qWr$j(!QX?0d-dj+7JrdzmQGXr&(#sMV!gobhOV6cg|NrQRNM-dh}<>@KRrRp%%{ z!rmUE=N1tcX4e+`E2p;+sD-UbJKGcTvC8**h@nfbvG)m0)i)G*^jq=VCiYj{#Z#1= z39fqYT#d!C$)jz%TsXGT6VX2P+s*7^g+SqUu^FE?C9B&0Mi`rTD4)&7$iJGZu_MD- zP_JCncl*wJ$j;a3AhPA1%1<_(s-|xc!CcQ}HBsw_J+gZN?4j3i;l971&Blb$=lnMQ z8OHA3bG6x{ebIz8`oq>8#O=fTY@QSC*=gr>dw;v)!!g)aK=jHqwiP=MW{H_vT2T7J65Tt*!rw{l2p-e>KZr zVBRr@MOH6XcN~ad6;%1%Abk$HG4k1HR`1AQ+dSAfU3Q}{MNs9sl+WDd>1*WH!&&hY zxrr#T;u7~C*<0MKa>m4&6X!&`FY)cTfsg7qQ2gtbmLq{$=!xhCd*7}9+ZXg43P0u*Ds>lu zU2bXkT>N5Uw{d<~YnHX@V!kW$0D-*^X9gTyiqfFrdiG`V0Flb2FvlK_S~xZprOJV8 ztXaCA;&Jy}HUhPdu^&&zA+W9=7~H#)74Dv zeu&kpe>w)=aq!DH?K&&t#j<*}d;+Rnib7%*eE}b{KAxqWUo9GG+Gki@PS0!liJBvv zc$Mf!YF+xCDHc6Y-+NzF+h31nUP15GLLbhkm6k`dsCge$`bE!6s@`NnIuYt~mFYE1 zbx%+fyOAx8hFVC|nYdA@#DFaAMYcS}`S0a684;amwe;Re`+cQh`1FoqFG{D;J~dox zG%Op*o;;Voy_{I0tw?=w4!=I`1$+MdA7jbyCJQ}%&K6l?k1x64t%?gni;DbBf;mcr zS9&vJHNj$+?pN_iiZZ`ZV}4Zp&93(ApyOBlfM4^NzvfC{?ltDC7zaY}47XDAwx=@- z|Cg-D~nk|kEsJM%2_s?Ko`r@9Rt#$R# zQG$d$A|+vYGm+}ZVzbSf+&bF8T0olKgbPpM*`~G;vo^lxCx4r6tmqWMvbV~qVFVRM zR#CSaeAh4FOdHY8H?zQK!z1fv`E8>wW9o@Q8pbVBl-`mcw%ntQIGN)s$EZ?_Vzoz@ zuBq+GvJ|f&qHY!y7#ECM7|W|D{ktAz@3S`*^>*LpD8b%mkN8d9t(MM)NAOOYopjW~ z=x94psbp_GYIY<3;pidO+>KMOGw;MNK$n>g@+0^ZAb@$HnTM3&J z*vSbqZRbtwvFAsIHPt4EIEjyQ+6t6ll)pVrf9+>FV^X1|adTS!DA$FqT3MB%UQQKM z`8-!GYk@?zJCB=+=R#-FkloOMhmlGX4{$fa=7J4PRXX5Bib6Bw&qT0$* z`lXErBg;}eZQo;sZLzLl%?HN(J~WKo+AU*iOFTcQp5Aqz9SiHJHxGO6d%JrWD>$jJ zhT{XDL^1S7RGuk~MH>~FJxO$+VONegYOwts&w@$Kk!j&aDSg^|+q(Q;`9y{La1k&O0MY8lyov7qBJWA{3QhAA%Z)VyE)WTUpQ6f7w(IZ}^ z6+D=-a)|L}r)om9|cDmmg>Q~D?HPxdxd^2bI z=F(AuBNb0zC`!@!N7>mcK4Q*`90Cc{!rW3sJ;OFYE%c}qL;vq7)^uMVk@4#e4GGk;N5oHb zTgO`u7$mBuxu&__&t}L7{@aD!G_z3-WAno>)_>^YoJ;J4@EFJI?eY~l+Z5L^nj3ALoOU;+=AJP2*GhT*yuiquqC(!5qR6-5 zCeB5+5e#byif><*S4?!PCff9fVo2cF#vIaj=T4bL-V@xO`clXr8BX~*`!C(&MC ztD9_iLN`&exHm%|g?-H4uZnqZV7W>U5Vx|Guo0+b_dd@yX5$<0bP*ZOJm*+fm@AB2 zPg?Kl1!hG1yh$4wg>&DN*{)idI`Pb9hj+An9_;h1xF~ghF27gFTlBBC&QSYDu=RCQ znSK8;jG7rES-qLhjmecv<7!+ayM6Pup(v~VZ7HU7SY}R-&MtPuDdyJPG3?FY^@eY> zV%`pmVIDU&7(=H#H^R$Dv%9xVB7zTU;$oHD;_UUm_~<1UjM`J8S;2nyjqd9%7z?7J zS&REmjPrR+qpg1wd)@8}5h0uW#fIXq`R(T|Y`Y)N*FHD86)>0?d)gRtL@{qXFxj|# zu|#}n*hBv#aFN-2{Ug3CQPHAv#ISAt>y6vj6>Zta7&f5YMx(K}llEbk!Aj1WLd1Y~ zzltm)e$oCekwvJkAB+XZ2}InW)3X2GE)qQ@P1+c2DQrPD!&Qr@eUR<%G%JSQSYC(dJY->1Dz?*zqcgRB zE_>L5^+&5!80V_hh}p+_51tf_G|lI9T6X4>=|+co9enU)Dcajh=U}J0C9-+0WhhsZ zCkE?dHe4_x&mJ={l0$pvqTN4tklnb^!6;tZMGL-tfE_>AfrycvgY}nfE|~qQxofC} zr%dgHd1sW~b4e%OTcFlW69IIbpo`E&3C@W46&}-e@%0ITcYoY{+9)3<*!5xA%s5w zaW5R{wRuc6WqjPceUQy-)kg%)Eza8A7_F8|6Tzl-%5LJW%F@7fYD~FE)@5UQlYTM1 zL=nAu6&G={RVslJB+xe1)#C@O!iM_d-uf#XZA|*No~kyXTSz-Qn-z=HQ4{^c+4%Qz zkL=YgH__LtfY=z9%J%&OzejuySZBywW!fThH8wi6=r*^52p@mSLcqi+34C+hZTlrYGH2a++9g;G2y7YSnN^Hw#$h$dOIq4X(m!<%`fKEU1B3p3-@H{o{7)_qR^eq zeBNM%;~p~gSJUdR_%(~(=r0Zs>9%a;>!#c@QGx{Zd{d4UXBsPt()Pq?;pg3i-w%A{ zi=GTUT(+3p=G24nY|h%*M0}hXEYgG?H{VUqsUd;;FcZhPnvc&MV!oGEA{KfDi(F@} zn6J(?u@R_+-c?cjpZ*`ZzB(?e?c{#N6 zlf_YqUsigVmbN?l{DL3e=qO4=$xu;RU6uHJ*bcLue}2>JVG_&za=HCv3PsxZMpGox z4!g_znW)-u6nZ*ZZOH#!z1yvK3-9CIp12{Q6d@o==yJKSQFSv^_g@=CV%f2>Lj zL}B~L!}PZx(Q|R6EqVAVD;%^(hibr7LjXWDZ= zk7jNzo@S+knT;+LV%Y9-9%jz7C+#Im(Os`+K4M%byM#Y$Fv@UWH%=hxdSpXmX1N#^ z(m$7($kxM)+VyeBXw-f6+uQjgck7ifbWi*XP2}a{a(QL2n zjUl&u(D~VBO>gYF(y|68a-<6B4ab?l^4ra(YESx3mgd!ufHU^rQ=-`9w|Pw*S<<2vmhc;`7u$z#87I{W=a|g+ zF?c;6%(L^l??*}uL}7WfpXZv%D?EH|yVNd7Vjv28K~Y*nAK?pHWrq}RaiBWE|?N)V9iGe6QWzbD3-G!fY zTVNYwWS7n#$Y}l{lS9KT*cU;w-TtzxIG*E!tz#x9iGiq})tt@PM1$Glt+T=IYlB6F zz!-b}3uib=kbyd%&NzP$6hkh@`JW5AA?*gpKs`_2M9N5Cwq8-fu>Qr~#Z_zh7g=ki zW#z+EtMa#O;=`>-HREM@JP?IFO{b!Kjm4R9{nWF43Ja7V1J6EmQxsN#*QvIU53E00 zphkWYJ=@-{a17fY>S@Y4G2+^3b7sG}{CM+m(i%iY@_=mi$Pdx%h+_Ex}&t%ENa>=LGd4M&Gj?+9@<9vdkH*qzM&Ic z>%nns`IpVc#qKHU#vKW)*AD9>ymDK2{qm}0+u(ePKoqiFb!RP;Q#|Xs&6~Q_9 z#0Y*pNvE_4v13;?;mAN#+aDfU*IAUE!@7$OH(qiaDI6PVWEG!} zl!3o>#*HP?yW{c5j=Ry{~wun0Z?{Na>l568O>hQeL%xk~v zf993*+?lM`wlI?&8ZU6hAww=Bc4`&Yw|sqLpii(g+p(|km&nuUWF^Yya>-~>MrI%i z+eDF*XG*gtQ}$|}>n978AVaQI^&Rh6>!EeDqE~x~S)M~n(*nMFDT+nlXNy8yS?yKGJ#_H!Yb+a8ju}o??Y6h$!&0WvN z`s5eg)LeZFTQ_&<{Z6}@<8N!X#BrkN%s>70iF?A;F}Es7BZ!PzH?OHzgJRjc+?Itr zo8r|P?)YrqK{0D6K?d#Aj`A$~n!YDoxqz10bBGv_bc*-9pVfA=9p$N+TEINHrcmt# zO=DPv^##m&_wUtONUvuZGgy9>ipZhe&ofLs8lEIE5GDIrzQ{e5dDQDD-XH$Jzr-G~ zy`fA)LmqjV=)?3W&ly{W+4RM9A8Y4)QrDk-{@Ozn-}XXcAgVyoFk4G{yW2Ra05RI` zufu-6=pw$i`NR{w=G*R2WNG#?UM8ZDttcxNhOnAl+Y9%Y?|gF2EUNWtVYbO27J?$r^@<(p7qH0tDeq2XZcjKx`3`Eu6ef0B>m|m-E!j1Cia4&d^)A+W&sc3QF1>=2{I14nQHgy3GBYN|u$4DC{?S+kRa^U;Hzl*fr}RM+q`M`D{|7F2}K=eooZ#wXNOt*4uo< zqKQ8w2BNU1Nhf-LV&TW~lJ)8%utyO!esO@>;A;%?ERvhbST;9`eW_AJY_DHkVjv28 zn!YADVV3s$>DNZnPNAaO{-DTWSz}o1MKnjjdX?;8C9aK^He^~1Usq{_#6T3?mI4b)aO)gb5$D@cbB-znPnf zZKCt}v!7IK#nT?9&yk<@;MYDC2{ ze%YBAUjkk9JCQSZqwh^52BPR?OHF&(qD*YNnZCdjyiGqBk#+qwwM0?^>wNg8Va<|@ z%i~#-#jlMGtFEbC$|q8`(0eo!FAr{||BTws>sI{8k%6c_Q7P&%)y_7aI6{osBfqmR z&!UX(+j~e>2DVAIWPWUJrr&tipPPGJ#rC-=YO|4vtiv5{d}?t`O>s@2#G}^@xs2qK zUG-A7hWt>0cN`gr!uK!odYQD7CD%H~|9sd*I^*C7qK!&l8lUA%ulgPNqx`J}))(6$ z_wM$P;=1ScMto-nU1A{0Og0?fS6geOP+$43DWOmLF^zA#?I%!z3>+KMC7u2BfTF{B zQ|2Qv5QXDJrwP5A9#QW#KXu)xG;n&S49LM=7 z&62S%<5`J@?+vRAH7t>}n)8Ghf0t{e-|2meS3Pt}VjxN`<91w#?(^c9nfc)b6<@~i z^-X^B+8qQIuaTohF3pH(&NO|2{x@$KkCJ zQP?IL^YNkj#4bCwTs-*EWfPC~Om*)&36F zoA=MaFMT>-q68Vg(@FTigu`sl7aNsPdDsL!YxZV*>be{p8Hkems@&f}`hpD|_*Lg@ z93{voqvY1ge@S9xK3LhT{)iZ*XUx5rFKy^6F%TucTJ9jmjXaC_%wK#tN|14$-Z|@3 z5v*&LwN%D<-$8nY!IAvO*WMBXQMgtpKeZmHcfFLrw;t=vQG$%8rxmSah{@iXF;qtJ z1zq*z!jJf#I|WC zg%9waQPCVF$iO|3ZU;}d&^?x2;$F3HNDM@w&ZnEAt4t4mo5b%WuCCI=%k$mpl^XMn8T;Zn;Dup zb7UY2=P_j%UmC18-W+GnTA72R1R2wv@@l@XlGw=~1!$I3_nx31YL$cg^=>0E5QX!Y z?p~Wu&^Nm%eBzn*93{y3Nh_*)8^KzX?o4HTnlMWLcBT?vb$zMCKorhnx+$7CN*@_h ziMPJJl%oV0+h|2KIBK%k(0Now*Pem;z$@K&Xw!I!fhe5E)VrGl_1F!~c$W2cjuK>` zj-frON>}}p8p<0+K9d-T!kKTK)f(&FGPmaI)_F)4GNRBfrFhO?YwMrA2J&?Yr6mTU zP-{~ee#hCMyVs1luZKuy1GFsheoj%w`W$0=j{I6=o}m&0QS!}Qsmx4g_s`f0%quUj zCa7!ic89WsUI^C9%)i>Y)+?Q8@neu2wBbf8ZWr z78wxAQGyJ-HKg--+Yx$>O{dKL8}>^KMB(@=%HZZB^gkb;HG^jD=O{r2-Wn>(y#oDo zTg{9-Pr-W<15r3*NMD7t(S0t=HoL^Eu-%jizWzWU$EQdeRi3ioy^0<^T=Enj}kEL2yZ1R3&OU(w5j z^==){sv&zDNeo1xH9$ApV?u>7c8vM>!+irIkD{b$It|5UIz@TVv58ra>kMNP}1 zk7+zu@-UQraZZ*A5_PtQ@(D{;N(@BF{yufSju##)L-_p#2Thb913ibxbKq)!QSH%H z9(&8xvJ=xikI*AY_RMU0y1#HLd5OPUT9Ts#85q?{cT2II#X;K-etJPIiGe7LiKV=? z8JO_x?Jmsi2RQA{2CtI9x;*)8ptmXdV$n=IQ&xSh6aRi4ozar!KYfhepYy|J`w&ZpZ>^5yUPi^R&O?cFJBz{?)_O)LY;rkDn|2=hVR zmSU@CUV#$4J#T({mc2%^7}k82)w>y+GG1m?GtsJ@o5VmAmaQnRfz!Bu*Ji@?Lk@v8 zK~(1d_azq{S1San3cvyHX;n zAAV1oi8w>4jK3;~WV*9#KldT8SbCF@b~lMAc_wZt>LF4Be8kTA-z5g3R&`HznN~WPQU!&1zz%TXumAM4kH(V3cVX%kq_SqcYqN zH8&TMznc4HUxuTGqlv44ZqLK=m{Gk4iJ(mvIWiC>&-Std3!C1JhKN7Tos<}e!q*e} zf{R}_(O~{y&CbPN`cL1d1wGqq%>{{t^RGsQJhTk{IQ0M)T&m z`%5|z8Faof|2=L09x+NRqw_g^6iQ+s3iUC)a}r|!-2s$rNXH>)fn6@Gwic=||S~T!maS$c%Y8&qci@`OV`2>B7i4tVcnZexYf0PYe^Od|4MlB2#Uk|k~2Yvn~ z?R|*C9hH0s0|tqqS~%bS)^X-cA6v4P&9rkS_M79BD*Wq)^O1i#&b8?yg$%Oi&3{kA zQ*(_LJ>vqo$4bZ9Ih}zh>^H}0TdXqz`OCSE^LhG6A%o5(=D#QUaz7@BGUa{vZ>e_x z=?p|+zflBF=TI?WdTq0Jp*|+|3HB9wwQo7U;>k&kt{yB|10v?%+VIqAqgiGe8e z{i56Rpb^5e(FOBq_kEIA7&6G#Hvc`{{Z*&GaDI@1*BEe9VjyZW&B{Ja4EE-a?NlqD zf?dSst7UnPx_@)@pu^~zx${$ut~u*Fw?8fyTqC;#Pe=&8hbzhP`+fl>KTAX6k z9BXIKjpM}73-=JWjvnHVtHnwTMB&V*d}Oz}iDmAyc;$R|trKMWDnf?bjwU?=MZJsN z_`YWG5(7~<^BpH)(e-jSemM0coW9nOf&E6lgA+!HCkd5!^?Q!fb~*!5xK`*cW8x@5 zo-h2y6~~D_eXSt_N7M2h3=*FTIdSD_7wZlnoq;HM4KD0CK|H>lg^%-ZW!+4qYfEI{ ztW=aDpMyov>0M^|Dw!k(qHqN(%J472!uk0w)9Y_1>xLy=TOtEj7oBT^L&fep+094G zyI8k4=?p~Snx`nb-xH-FzK?uWxuRl7oQD{Xe%F#pR#NIYp zggrQ`bvu}TVnGW~j=M{_Qb)At@sh{feI=b(5GA+6c~?u(t=LvRf7n;4<%p7X$wulOAJKedZZ}Rs5ZQP^LXC;L!dy*`!QMlwOYooksI@w*l%=`dTS~_kS)8& zCz=bCAmjGVEQZU_7`9-wmFuPbmN@R(G^;rOwSmMy6!x^D?Knulo3sFMDzfcW3O`_nhKU+g1W4$UwgXMR~Mlg?21sE-`d+ zTZw@v*-vIt+t&Q)fo|g2$P|v_j6T9Rnu>CG`+hz)>+d|^bO)*Bh{9f=JS@IzOy?cW z;_;^rk_QBqA=hduTcfR=>nt{B?*f12 z^Gd`9W7vpLsSNZlLJxdJIW^bY98^1o=eCUyC_#qoQ&eP@yI$^SE{e6xFEC~leX-Do zO;J>Bxp}gTyEweHm9&508pIw^l!tpX{#SS|kuAK1#6T4481ffg@RD~cbe|`U`$f9T zK#htsMp5eZbQ6A;F7j1xeI*8>aHm$3KG)qu@%dNy*5fqyBf~o&E7v{eCwmPO-|D!SC_#pNM|OYO8~*;VCjJh6&yj&Bw7bYJ<9IBm zJjbHJcPGidLI&CriZXz=<)Y6Zao|!6M+x5bp}j&;jr&}*m?NQL$bsKf^cTgODfH`A zl&yQNXqnsg7PE?{NS?BYLa$kh47+e&YwJ8hj2pe5qXZeU$M1!;nOU7GJ;g8aZzKky zFfXg3OdmT<^SoY0^bxugD}yl&@~ZVL5oB(tnad3B6)MH!APS=+=b&S3&=Dnwy?3*Ai2F2kQ&yLb= zfIEP^M?Ka0h^2Xh%s(PhIL6oEjR)R2kUzuce&YQ10`x_w>l`J>kk1CQ3w06pOVxPN z!S50SQD~i$4L7r;X!B+VXFWc1)NqKx(+l~k9n(b9$A|gjnJ$usLlo*qt9Q4Uk+bMd z|3qOI~!&`^h zo)(E=+bcPE@^t9K>Ni)8vX|I+6tN&8>IGd_Z2l72HX5}Mma5z-_ zUd+iX_xCI53xar4kZmsIF|B``4IjA8Mv;1yza)oC1N)5FlD5v2r=(-#5w@$!b$|30 z!OS%D?Mz~0uqh*D8E&8jPEm?**TDB>?s#9_c>`od~`UgA|z|HKS=&GC1Ruo~4DshDw5 z&Yu|Fb+UNjT92#wR@sn&D7g$X!zAIJqZ=>zvABT}WVmrpbK>%2tlZ>jH0Iv@Ld59i z56tDwcB#lf6!rprqiRZT@!-uJzVTZM0`(=zvK@ zYv&j?t5jhPe~G?++B$=tQ_nB%6)37x%hj}&Kef+OUk-|6u8(tR?hU%DeKt`Z;l_EX z9ZApg={ce*2;bmGEYY?}eNrZ#EzF!<%Ib$~ivMWUNFOjeuXy~W8}rDq0T8Uyas)gw17u+>>WvgrEo9w0e zUOd4j+ncB{AM$GrFCAy1QJ}iCT7IqAGtwnzhEN&pw}$A)e!XPw3+ZS>RQ#;`+T-TO znUa6(e;7k2PSLaN&cbV!YpCIGAp>hmnT`Gm)&KN0jpcdVZMd`4|J6&o=X;X%X;EIq z+EUiKVt~%|K&Gkb(OOo#-iN&4p3N6PjE+ts((q;t}4hQUz^e%b_~{Zl|l$j8_GsS@CN* zHS#tztQhXxAM#KcWhYhP{OvEsT%TaQZJ99p$(PZrOsfJ~-fo!-fB}Z`@22kbgap<%?*tll-J~W%OX?Nn(6@{Doazp2x^Atf$^{#U}O4+ZeWV zzmqn$=_W^fZKgX;TG2i;)VIB3S?+QQ}cUl2=LZ$=3h{JYEGss9=Te3QyIBF2Wg?nuI7`Tp?ZnS zGi|*nPBnkq0ve*CzPi*dNckDI*031suDa=aqMd9`(XFL+;3-9})uBOU^}$1({qMZg zb(A1u_HBy!%a_1Pwzxy>Xg~Nk+h5MEHheimZ!*29`ZhY6weIJpjo&{*9kV%xwfrN8 zHt|S+x_osEo4Ce}YIWdJcRfmp@ISfiC$pmBlQUAxVD$y6ZyjgNFWz7coEE57RJ<$Y z0GScCi^_P|v6~*d+HULKI-`yfWSpa{m9IVRtlleYUg;g1vYWL6#NgTa^``&iDJfFP zOT&JP+?!do^h(W9mM&>hb_vrajI-z4I!?#>B8qYrIr5jJYM<9)A?x)0rET$dnSm(m z5z0qaei1vf?||8XV)jvjj7tG$Y*y~2(SDZpS=W7zHfQflbHn2hY0Pn)Wk$~Wp;|Bh zQliqMdiu$1S#3d->*Z`c%DD@3344LEX7{he25j{++Ol9BTZ<^HA>AkLsmX4~g=#PQ zOx96?47pZ^W|!9Yjmxh+>`+(7xrAqDpO9;6^2k`0|8Hwv{nf)&XIZbQGj_Dov;MkP zZBMySR$R@j;mI3U7g-rC^61C+71Hj_X(P>cWMJ78~gPwpv|fxjzj zxXb2o7MS#hk@dn5iGe6wH;S@rQ)B&zPe-*$R3RPvBrOMorE5!!|D>*{8O_>PN}#@4 zU81EP_QzWFV&{Ak15vpC73E{S?)tit)77?-ioUc^X02#c6g#(kt%`a@W=!_zpdW4c z(SKmCoDu_3SRQ%f9tqKvytV9e@}AN>YUS6ud`MtZE*!A+5BJn=4T@rY;;!4|`;yXC z1N9b{v)LloU1wOU?NxJYR>s9XCAz5VopNiVx5TpnA(kz=g7Rf_8kxc9NB`oQz`nvc zM!p}H{q?&y!&SErm2{LK1ACqB+dZ4>_IuZD&8mAzy^AfEwP<(WQ1-A>TajZ;2ED?K zOzIz$nPTEqFKtSmOzNKBV%X);-rC(pYB&7#PN1tR0~{HO;D?8y~|uDs)bRyG7KlisE&04J+_%9$%R>NMa!B-PzRK z0JoP|r|pxg_On8#d-0~qV7+isfa>>#yt(VSYlxE9;P+M!Sli_txZ9kb5(81Vf+>IE z(w2I2afJ`bb(!THn4)S8?CfX3Gb*0d4@O*5Gf)=lKVJVyb9eDwRo8dq7Y~)q?3aID z+1u{@sUiw9+Ed)!pfYr#FDpu}jg=UP!rKJ8Rl8DLS7+*bS*hz+OC zOAJI|kI*-&?&*4~`CoXwUQZbI52EPqUA;lszfTu@O1U3p!KfNX$>YJ8t}R=6o-QXsI{TQGyIyX%y8M z7R_EJWfmXIh7to&a;;{Td&1miZs94H+v`6mZ;UVH#qciXq9ICNAATt=x>L|f-s4?E ziGe6Q0Z?Yud>gd^a{@%6VnuY+gQ%}ikCDf3OcZ>}c>)YUs(3s5K3ie*=Ccxc!Q zSVQting5O*5}CxY$Fz?C15wyM@-KPfqF=2O&c(X`{W|6MsJ0}I1$$=J5GA*x%wHw- z`g>>d`!{My3`F6Xf$m;C+R^tQSC~JGy6U*Iv`inf=4_)6QoSQ*jn3D$2;$m zbIPirp0Vt`Mj3;?2B=n6;L6RdFKrzC^AgkB4ipjHUrV>-h{C;uzI@X74O^I96^B26 zv0fU|@7s|f&#Q6E-1YXeDha;jsl-4O&PwY$)#deSiz(`UiK62!gD7uWA90kErPlSY zH0Bd7>iXcyUwP!s$I|YMDBOK0U--nX%zQ9f=%%=n^Nnd>(%rnOMNDM^b$&%Lj{9%mgVucUbuL5v;Y-S}?9OHfQ4Wxx1Y zBAkD9Ucpd;42ne35-SVlH*YhQQIZ&iiiLCcoH7GZs8=YL-Xh8ekQmMLcI?kkf(-es z=-t9Vy>DzZ*GI|>M4@)2tf#B`>D~!<`Sjy2HIyI&-=^sP>S!PRK&dNy#;5j@rbZOL zg(=F;3`2Bt+ce|ds)E+=(%osdEuXW^{ zS>dY8Korg~M;wdZab+bw^uR*bBy9^ zDLag^IFNfM3}+}o2F0ss+de0;PC;Q*MkFyp=LT{$S!N&#=NRoDD@N&$KP}?gR9}V? zWZ-)wc}Ek&?aLyb5+XAYg>#HNIWmpVPx>6k`m*_>O%Zcf0=!xjpq(tse54wl*C73(zA0$BEWC3(`N97{{xL68Mz`lZCg>59Tl2ZuGc%MR19g}6#xg>hxx3=5sKlpEP%D1LE^82 z6S*3cEHMy;esZ)suL}}i14i%~A5Lp1K?cprbWbR~|^4!8uf(-P|_en)n|zm)0LEW>#9t&Rm~uKTVlbvXw5RAxic<>lp4X zyk}GrmG(VlbdEC8ZaL`0n(iA%r%B`Ad(r*V@(A0H8p313b%qjT$R2cU%9I!V+Z7WN zi=LJkh(hl@@~xe%iRbi1wg#IwN}hm-lKuLI->)gwRVXWFMGcV{h(bR?iVUN#&@L)6 zKn(XfD4lEZ{R8I$-Gz_xWtVn!6#=I|N(@BdTu_w9o4e@E8#U!iDt%;ly2Cp%`G#fE z$flH!Yy_`9ESGe;LloYAQ8p{i^au5mdFG=zr86g@@N`eFOSvoP%F=5*xIsaQfhfE| zAl<&Jy%xGDR5XccY2e(&*5W)SkH%#aOfSmo)NcK06=x!{aaPip4{gQYJ$l7&xp&i% zk#9p5qtK}sw&q9yty=lhc9rtwxNNgte!^PlyxiVqVs%M3iGiq%H)q)!)7um~W-(eF zo5EwZj}tZ4XXE(0ea@CPKDfuSzq@&9^IA4Fst=B3S;txTKF&v`@>@GviVNkk>iE0J zuC<+NxIc!uj<;U*^X@4pUO5#MBcpFIx-~HJ8SSi1mBA7_cFS*y)`l zF%X6Ruymieu$kz`a*9_crm&$?Q;aNb1`FG9T}71aDLW;osc8SjRdgsnMq(fezr9V- z13z@0dFViq#h`DJ!}oS^Y-C+B|7WOpo4vRF<@mk+=-J(Rsh4@Z@=10wERQOC#J3+5 zD(>wcZGJnNMMDOna4(_!q7y>JrwP4HV?Z_yCCDIq&K%eID2w)Sr5&ej_h2#qVP1Y> z=}w7(DBMfvJJpF3#IYsK=nJbE815N}A{)iD&V!2_=fSHhMhTTpcUh+SN(@BNi3~ij ze%Ei3_2&8*6qg#$L&W@qd_7fdnDieS8s`dDD<-?TeMc72yU3jRnbZE zz(xk$ikSc22>;m9Rn&Ok<7!MTfWo=7yVtHIyL3%GL14^22P>;tO=<%x(`B#V$mf#ZG5s z$Uv0$PY<)qlw=k^stV1LS&=uVE_gu4fdDa*U@sRQCA2BPTh z%uKtjr<~+NM~DHrPMKLt>}TkuhYa+GrLQOr7$K&k*wb`dyEIsM*Y05c*|xC6Kot7i($^iMf`#Ji zZ(hn@Nb;>k297^v=BAn0=Hg0Y-_TGEe+yAK#}p;Qt>)tTUs=_Za``21!@`txCgNEF zTdp5f(bJJ`yzX`t4Zby^)7@J}xgSk8`rh)^N4wP$n^KHY;}V%$mt*S4xhY0J-OjQe zIZb0ex91cwInOzBlenOv#=&`vZKC(uaudY*kShH8*J=zIh(f(WH}tj%qW7*!{MWA4 z8A_0WtAH$I%IQ=sJdp2C94;{sg?fczMstrA*=GfEvaqZhuXIg@3|s}237T?K?I;<} zo8(v_F%X5CC^a|Lxgz2G05MR440(O%Cwqu)84q&X&S;5&C|ScTI@d$=KYW0fIvXwR zoXC(zFyU?s@geXMU*>T`VjzlgBO9Ld4F%^rPiV|9Eov%~G;eXHxtk4lVT_fLb^9`p zZlZWQ`fg2&p$r*_lG_oHwYA7p>;ZS4CnO6QQF5&!P5KURf+9+FyDu>ig%*yYBc`VeNd6{e8M%OJWpUV##pb&pNDng-j#=mZ38wivl+{2Q?AuF7hwgSx=Y)qRdcHQeq&AqKu99ZV9Z}EX$X3*_YDde1-S6 zHx=sYzvnTH-LKMft1Cq)AOoWnC`&>3cV4#nGNbwC9+F3V;>;OF_I|NUFOl;&yrOGgQ61{^0V6GBrl?{QOp z=)Rl8Kor)N&YaEM#g@?8^hNa6I>vcmJLKM-cfd`wA9T*f&dCf!VI+y7ys3~!ti4)V z^Ih3Sicmm?95J$}v#WUV;-uPke>;hRD2!I1nfR>=pSf?ge^RMn9cMMJqSq%*+8>mR zVGnv(C$Icj2l96(NAMZ>M(Zd+21a>M^g#6kJYks5HyJ~8jNo_}<70$TPL`4j@@jV? z8X29Q#;~Fp-K`Zh>>hu$s~KOkyqCm46wWb4se0lpMNX~cRX&cEB402r0NY1z^lvJ0 zXYUAp@wbr@15wyM%2ZgvLrhrvl`r`0c-v0D#X-$UuhvGIACRJ$g+0Xn8Q=J?`(y^9 z=w;kU^O;hV58Op8ja|y=>+hCtRJsHiIN#{I8A&dp`;V18?2E&jDxHBSxv$*c?&81Y zKF_NR?4sis;Mm9`$o>vCi&iOTmS_?x%@XWg^y{>G*UP;1If{3jKp%JpPf}!{KRv}F zm-Est=Z)sW#>fmr$^N2GCsq)Bra$ES!aQ{x1Kgvq-xS3*yo}g5GM0C^Ue@wLN*^g? zpkARD#o|7~wZ%++yo$qPE1iKT>}iVTs8m$UYhRzA*;H5349Gw&p(wrdqr7^J^~TKp zL#5e{qlWQc^bLieV|-I7Ycc)6mE5@g6RSO$I9 z*(&D60X-X zN(@AyJx4dG+v53$DrH39$z^qvAOo#4%0X8>l5d}yO>FP?i;fax;Lbp~?thKqXFR^~ ziCQy>fhgRsX#ZH=f}eWcP1tIsFtqd#g*yYin3P^+HuuOQO1iX>bRudyv}!4jaHdtJ z`>WidOzAd~Wr+;*Xr-1L8_hN@0b)da5s85)yz8YefVrL5VMY@4( zU6+49HbJa-oSWg8BMNUO>2+ynTYmp-CQ+nGJIMmTJp*I#6lK$qAhT|(l4jwyp*rew zL}5G@ot@hZGym@Uz{ou@R7VLiF#1$cx@;J2W{l(J))65R15t7WXysWO*!*EFM8`vl zjy{R_HX-}m4d=VqxuSzbf8SvY8Hhs9L%L_^9>p4L`$d#a&Z?sX88na6{gV|XvC3N( z-_c)u^!b}115xOUN72iI=~aGD;v1gk)X}RDUmEb8f})qFwbI9b+0J(y`oNHZD0~Z{ zw{7pPdWRWR_{(iyBwsZ|;d=|^%lJSZTV?_Nbm?`8fhgJ2=}p5v`t93e%>C0+Bu^(~ z$Ub);a_wdpavkC09(R)>4$+4mV?nIUEdyE6hAa7>BZf&)jfg_;csgy*s>t@;UBx{c zjg%OOLN9;H0`($--@7x34;nL6vafJ`pk7gwAubE~`a8L`CMn}22BJ`RQD(d`Px#sB zZM=TF_Bz@Lh(i03R_(m^Jdg67H)`Q8S(eB^jY@ZxZ{PB@`Lcz`cZG8aBM;`I2qo^t>+;15x|MeHb&&C1D<{^E)#%uGU=bY4x;0~@chjo;dY<5!^Z>+;d{ zJspu~zcx5&Cm&(*NGa8@oN;|;q=v8eX{e8rAKEZe5WQfX4xv}eJw}mV4oeI~$v!N7 zDg^1it18=q17=BnWav47<&h7|^MU%HV(SxEUAe&!g%Neucd@OAx*Ye{(yhTKrw;nd zd^6*sCg#*pf($vf&)nHsul#zuQS-c;#6Xl4QD>Er8XMl=R%3n8;{3km=t4S5{>!kk zBBU~|X7bf1C1t9fKeMmIu-IuRjQOJ|I%0GuMz;Sj5M}*s)ym2`MfX?t3hOmz#nfnQ zG?L;6G4@qvn4?_viHlN;}ti(V3(hVT>xB=v&`lwdb8Myx$JgQGyJ)3@tW=RULQB zzqQdvVjv1@NSa~hOg8s~zyIxqK{`e$cjLW#uYI8P;5qGgHtk z;mwB~HLlPsK?yRjhGburJ<12Z_)@E5&`^niC~O}^7GH4_yGy6ecKlteA+0D!4O7?P z+$pXi<>5=?*#gIkN|zu5@3JUc=)JO`Aq)WuMF~JT7O8(2R*s1L(wzr8GusOHy%>D%}eU<*Z*57IR3Ug=B ztwF2SqSK(YhHFcAiGe80=t22?ehCuSKgJprTxLqSMKDhd=C7eU{Y;@^<)sG3%1wtf zlwc+uIU~^9*bs5(MVxV-pU_Z(49qP=U$FlYB7SMQ-stx}(nLlY>dE4$#M!pgY(4c? zpQfEtO%4%@9#t^hSDZ9af{c?Xp5}SqXjXah4r0U?4H9QM2Kq1AFpKA%;$d1n>K*Q4 z{9Z3L)6l6=%?$h{%4l7%u_(~)d1T^%!Xj-9T2?-6q>YXBx9auDF*d)eQ3mA}>Csv| z+j!8}=<6;}f{f-LQnR=$*cU;Jf<7HY#ckjH2mX>%VjybT8pToD`=_^9Wn}gh%QhdY zag63y+Pqqv^SNRDZOx_5^f#7ARt6zEwkr_hPfTPW%KF>t5vxs9t5QWo?vN|dp|ct} z7-_$YY~xXa(YR4G+iQD6jA<9#M8hVFjDvexOAJKO&fuuws@~QY?C-D7C0frn?K@ty z6(~W*y3T7I8ECGq%Sep&caHF3<0?kEg$)&!{;=xXB+}JP(Mv5|}%7#6LSg z!;N{R8d4mjZ6;se!#8SUE17|{!uHWi1MN6(?m2cffxnC8HMuiB(i*32A3Uk$587`u zXaCUs&(|c+FpwYvv)j% zSvj4)XU<0qrwKDyYwrrS(Sw5|2BK&;P_5ilUi)2%ac0I$w)F5UKdX1K3`9|1saDIa z@uzs4isNKeL3LNlu+5_*;mgoyhS^sdIOEzZ4cW zFUntB>?2Tu4BCBED4dok=6;o|m3Rlf7%{D@P0e(N_+(^|$p)R?Ab%uu{XoysV1zp#G9*FCDX;m{1}zlAlFYgMjbWBp;<=aH*YwP^b9Vhu@GI({iti*DWC zTF>(?));lx&C!munjk9elu@var9~_HbkHAmtq?Ppv?xmc%dpy)%J`WzP%rX0_TgXUgpPFZcFKbG*vTV2d zyXPlb8UiKANIQ9@o(Jn)9IMaER>0gkY#B!hGVo+c-U&ZLbpNh(jGj%SOl07h6VHYe zJ+L}N|0{QlvE}py6D7!~JjF|^YmZ_BM$bs39jEP&isWS?j&@C`y>IvhhX%OYKd*-N z^-b?G4x4(wKI{FwNRz2$EJjt+RvjhCKwJBNGY~a>_yvbeZIw;=8OD8b)arEZnA#{o z25AfjBh?D`b^o!?s@07O3Gv84)Sva68`jAF*l0zQV>?iSj5MunS>*qlfvC^U5soru zgj*T@viW{;)at^6=y;SM1MOhB3}hhc#ls{=8H0UZs4@d3$eBS3}hhc>#Gc=#kkUH6_qh}*p&op1nswb z`=JCGvb}qD!c7Nb+%rFkfvCklayrV$b*U|tVf&yt7$s+A@8-LI$EfJLhwhv4&st zqll`z2^FlqYO2hsg%V_-tu40j~ z92)M!XusOXKvde=3|{$=%D@$c5@e*UN9&yMzZr;1(=iL0DW=Ro2{O=H{@)BlrD@k< zU2^=V43r=vO+ThueXrW2?op#?k3t5bXzz0D1{sQ2-+jftLJ2ZPfDKPO4(uPuKvWH7 zw8ILYH^BM=7?yz&WT18azZr-c_Rk)i+Kzhv>>qb&|3C>cXuopo+NpLs?&rur6!oEM z>D~WxJklk|pnh}wf9iN(%kj4mMYG4jNS#->pQ8jBXlu)Tg$zW|DsYtXf3A;o2{LFN zJN`e_-o`KumjyXi?f25Y)~ZdHAcNMUaUh$3C-D5F}erFW5m5@euN`@b28qFusK#zn{eA+CJfYwaIx*F;;p zLAnGPXw}MPAOlgfcR7Bkuie@QJ&T3cw#vw~)m9xP$Ur-oUiE9HGS1!ivv$t(-$GPc z*-ae#`LbL$eXTNj{!u->9hL+cY2{gOI&`xg^~13{KiHH|3x5kyX}!MnqGgwS*td&X zGgV9(l-06J9JR6}$iUYcy3=30)4?d&Fq_0cRNCxGwXYhDI#J6iW6y|M@hCw?+W4ow z{NOA>2BOmDbgG?*43r=vZ62rEiHqL`J1omWdHpQQ(y?-^-$GQ{+WbH5T?fOGAR}!( z{-5@4Is;K@IwsX_$9A9u8MI$Hekt`;=6^8|m8L8I&+~`lw=4-V()45M`Qxx#>KSJi zopJED5Jh{JV>d`WM`5i{f($&j(e2=?{WY{{x$Jj;hUy2hwv4}WHHO`r>t{B6_c`%P zzZiCBzMuK(&1e6#zbTr-`&DAAfZzPj6C*G#kNNETS=NzIRmSEsYmGR|4%6R_Lp{u! z4^Oj4v;1i1wCl(04@ZsiPyRbpM+q|8{g&U%KI625O_^J^tg?5^IX$r!m4OUIMLx=F zcFK6lQ8qEC9S6^4vRUm2{BwO`m8Q4YS89h<_g(y!qa9Y8(*7oIs_+bn@!Dl;jp**H zyv*dKCmp5Z?_zlr6X(1(v5#9Nqrvx3{a85<hfl(hvt@182qFVf z*oPD`lA*Z2==zIoFEQ|Uk-hz=mzt|gYI&4Z#oo|={q6;7^N*oAN|3Rs|KGNcS8g)5 zP^;y7*NwJ-jBnztF^_Z0tJVACILjX9XVMsGt2Q2Y^iQ{>yxO8OC)s%Vf6E7Qq0Oaf z|AaxrKnXHvWHoDKQ+tB{{g`x6RmlL5@ck$ z<)-~|43-`_V%^qT$A?nPsc3`F7hQ`FX>R*A8$ zQPz&2<9x+=Tw=vz$jg7;MckSXFs1v>0 z8*7%t{=+B`UOBON$|pz5k%6d8w>8823VCwOwrbV9&}#dvcaQvis2z86x|q8!op;QV zu|6(l#gZ2szbn@&{Bali=u3Cv*ZhZpC~OnmQ@2>25Vf+F%_N4^?`f-O$BwJ^azAgf z6k1)fF3ED`XWW^}P7dvZ5@g6L>fS?Do!)PM;@WSaI&0{m9r|>dRi5H!{&yZ*Ix%(b z+Da|66?)Y1MRju_4V|?o?jM0pnl<$Qx#{H@TwWRoW zBWCX9e;5^)eDP~s{a1hMn_D!_YPRs34oR&Wb<_+kZ?ZMCifG2D2hRWg52O8gn}6g5 z4~M=&2{QDC^VDsXes`2j-$U&OJf9P+b&e8b;QA*|rxKNIf8LtrC<9SgUS}T{ZT!!qf7&tSOC`0({ss z83FloX!8yxvt=`^*}kM%Z+UTlkMq z9d$U42kLx9xzXaX-+;cm9Al0WWZ+z&@9?@+v#r{?-_ee>53*?4MkhJe9nC&%S{>o2 z3HBlR0eoI*+niyx-Lf)}fhf6m*EfsztJkWJ-J0#zu9mh7zdgFq*3v1~u|8zmr(?ri z@pX24+bu1M5@evBCo5y`e0y=Xx%Nv`E7WjFr)L{4@5HfJq}5PAlDE2dLHnv53+&cD zh!SL23y(gd3%S;SN=Z+qU3cx;f1q(QuT8VI}s(wK#f896R$r@TzlxYgMs=3 zQNL}wY%~eJ_D{=0yeitogFH5C1W|&F^FOW{#dlq0mr2`^ZeR7re)@GI+cTQG$Usz^ zTE~rvORlk%4!h)G^^8WfWKT68X=-F3YJjt0EJ?clPpw{+d6xL(nT}np_qvg`>gD|HmR&-9rKYSh9lv#;*k^T^H+SrC4f4HEtK>HRDI+;(h0Qmi zgZ({~ff8ga?fa)1?_{#8jxv^LE$j_kMmpwI?jojfCQl+OMRt_crp`MO9lvYU*1d1S zKV>YwU)Ek_%ut6VgA7E;^Qu<;+xC<(4b_-0sWUN6XP_>z=Jdb1y})t7+xYFTwsXC&bm(_IoMEI;YmEX{}w$ zkEm1oC9-&$8Pq2A-;=sw>}iU8xtzr~S6wm8f96$Mt&m}@f`7++?CVDM7lF6@tW}E=WH`NgY+Rjau&EAP z^zTLs6X&~iO|EigRhzTeNb{B$T3p~&6jCI8X zMX!Y|76x{qfZg5gHLzQ$XLm;tu&_}C?Cw@fcxRUf?{Dw>A0N)|`^?7d%$eD9&aR#F z#$v~x540e0xLk((qM%CpPSfY8kct+Y;9t6DS&@I~D|+Q=-RDFvQ}Z7+pE}#K44L3* z={E2|j5a%?1qq`s`lIdJ*p6T4o3-ZydkOXoZMMttB?RHroG+FOarH#;NA^6JV?iP^ zqLZ9dAdQ6T&-uyb7h*h5d5ZV={lMQsRSDaga_O)%eSYMzwTFkq?yG9jdm<93(q`>{ zS4FFw>&Bk0>Y)>;LO;<+%#*6d9_-ahUwNSg3G`8U_CTLi=H??kMNfVnX#wn91aX@MPsL+B0K3)84n~x5eEWJwoAAu@-HU;5y z!CU5@(Lc?*xhI3ai$|zEoBxjOGpAKG=e0Ye_bap@@lVVp@oA--*4|!bUpaBtz^Ae4 zybbgKPVdhp&1%L-C%zUz|AEJ4lrAgoymQ)eiC-(UAc5mC|BlJ;G*j1%$@*N279?;C z<~cvBDqC*9++o&6YHiHOxd%7{px?!JIfHM-?t77aKafC`HY(+ec7pJG;T)64KjFH6 zh1UwNk2Xu>c;vhmQcyp0&nX+skF4`Yjs=OT2Oo(QB9pDVIH$7KniDgJ{U3oU?TY;C z=^v?h)!bt()BQxWAaVKkdvQU;by8a2%Q*63ulYos6W-4JIWHYoSc-mniOlAfcK3@) zrS2x_J5!pk{qLG{46!xety{>X`R!;yLi5|TovCc^_4B3ArnD=oymMCcITj>veZ}7* z$1N7S+a{QAe>SeXkia#ZAcRbIHn&?sE!xTpXW{8QlfS?Zi|$k5tj%*R7uuWYhI|%n zY)1# z>+^~ow~I^eM(cAT-^*B^VCt~Fu6|T#K|*`x|Mf+~#&}zLH)*V|inJBBw%XE0eViF^ z%-3eldZyTv+xlL{ao60^+Iv??MNVkXW6nAgdk^2!pDUZAKIhW?D{m9wpPRfU zmDH~lUK2bgZB+W#14!BX!$dOb>plZokT9P4(+_rX@fU|JYx(HIkL1ssinFEWYv-(9 z_>PxN9eG#3UzWCfjzS9(5q`P!iY~SEIqK8Y>T<`j4J@5_8%UsPfAs=NeEw6`mGJnc z2k@SC2AleHSfA43hIp4%M%aL;BjeNWccu zh9i$HJ+nQ_KRq0K{>EE_#T^}^tn+!kjI}1yp}Kl6!5)q(ZI$}({?&l=;pR%c*XgS} zRK4%pSj>Bee+7k)T}EP7mr~}nbH7-=@$*0n650wmXFr+OPaL}`ddqJ0M)2~^>Fmw#2aWtO*3a<~6m2_u1PT)wC8_Qo8!#LcYj zY~%gFdw?T2f8$=^nK`zIyYA^Dfhruqd3|cH>FV@nEojSOPV{AMH@WS|H1eW?w}Kwo zxSk#5QCX?P=cN~4-F=JSrVhR=(Ioy2X|y0wdEI1LnvzPI6xO`gkyS0~{PJeHV2%?- z3$8yLwk?&H+`UDVH(DEavr^Rk8cQwfR|B%VG@mMi~rgOsbN zwUPL(E!)s(CJE?Skfx3(tdvi=M8@+oJW{BLvS8jtawEuFd2rN0p$jgOIr@35=;Fu5 zJ#S8`r|}$jupqH>o};qCJ4vVcHxBFYc-J5w;`X?pfk0K8298RgDT(~`T04(NE*;s| zC>xUgvoJ*q5~bfcDL-GHCx!I;@pN@}7GHj^^lV-pikAIe&dQ^O=kz0)_n??Ed-*xi z?19!sN?cbK@}s{vwt9Yw79^YoIxC6w&sxu8gH3NXD`SC3#GeESR0S4uR{ZE$QuL12 z#^WK~*yWf$O58A8iWVg9&nu=3d6GzKm(%XY+g0AoA@Zc!=%5od*SA&j2A|aL6pqyH zcDa;SW@f!s4cB_GQYX{Zd25POv><^q1J7!#QI`$yx~rbr<76OERXgUZ{Hvr&BJ_TB z@>OMK_vov-%Aq7h3ldjry_Cz^C`7KOoku{A;w-35Q8xUbBTdgrm)mmhWv7R?f@{vm z4@vUXd=_$1e+GYEuw|wmrCDwMWgxU5VRH>Q;X3_>rK&lSY{}GmNs|_J zXs?ue>W(tSXpz*f;$gmXQ;qj(TpMim$R%YpW!7hKaDhg&W{O2UwXh^b3ld-bY^ACb zPLe>qUlnxlr;YZ{lS2H8(uC*5q~s0fbPM{Q`JXvSmN(}~P5plS=+}nEwJAc}yo*z` zAW>3sl=3xBvflIb_$E|7A4`f|Do76w$}b%$d4&w&XDB_-BaQ2NRey?d@5v)I?r@b9 z(Vz1R_TJQ$@5;ID=lSPgL1J;{FR|u~Yu0D*`$-dBYxZF#z>Rx$cwP=eT!kFO_r-^CTA zL%ZW*hoBVe*lzpj85uXxl^*<3lA;BP>-)Ef@9W+ol0LToJu`r0EN)G6Z*-(o~{g?DzQSZYZwSrt^62fu@s2p z73H+^u=KC1T-s}9%IR7ZEl4C-4#mEBcY>^`tv%;2T685z>%CaRSFW^sb!Sm~*0fO{ zJ(;;1i8$*^FO z==G%wb41rqQAGWrwXt(s44GH03Ja`Mi=qXI>Ic_}v(`kDF#Q?q*`N+}Te4YY@2XRL zI`Qp7z4-;DW(Q7@?Rpz_rRvdp55}wYylYXkAmLiiUK;cv&N`nf^Sx+;j6UQ=3s>4? zyMwfO?L~5wUq$qmehn-v1>d@4ZDZAS4;tNVA?ezs7DWpZ=$-M(L5J$npSFuhvV)s} zKovg!ylRAxD-AmpK^}jrWf&Qdz-OL+CFW^4y1B{?a^BgE;**N&GviZR>x3gE&tDK% zcMpmdBsQ0LB-Y-QtdFkz%O~e%sMiNppga6Uik=?s^`URat7eC9P`kw!qi`aeMq1`$0KX@R$v{z%pgmP ziuB`{{K^PE?xypx1mE4FZ^z&2pRU06t!PGyRN^%O!Rtc;{W+fh@ywltTV6}veasZS zFpRQ6ze^BiFL7l9&3{Yd8#Oj~WJsW|#dG^+JyqTB3hdzpk>ae6`{G6q;N``$YMX<$ zjPcrrXh8z^-uc&{uD(<=7MEhjUR5&?sIoh`Pp;G{hD17OqmNW%nHtddq1yO%H;NV{ zYDcV-oAT#;k3P1So3lqv=vbWHZz)o=;CnUnSNM1(&Qm*Om0$@Yn;1M9B+y6UKEt#L z>Z1M?__rD)!}~#eM~nUnuhE#~tp?J{?8r`P@MMrcA4L%6ceW?NTWYX9gGD-Y>n-y; z?y)5Cu>@ClI3DwFlY2KOSw||c3e}nzMrtH*3>JjS19Q_BNAIbVE=Y7tA@W_xrUyEm-@h|-?St=`UZ6%sk%Srk*KSB2yFy3d|%wp2<`rAb%VFJs*`e+#v7HiWEmq+tFF>)cJ!em-40l`+V0Ca&DfveY}1Mk?t~li^+LF|Y{U}2%4pAE;@wwjw>f##;@A67v>;K_!(L(U~M01Yt$T>g@aW!|JzDH4XbENZ?M0AS`Mcrxr=7!A4fDYM7mIKF4{N*X(EB zQs=_%Y+PAF8`>N+_2(;-fA}s3?re6NIo({LV5Ifxu1Hxa9rzsx~aY;BQ#5Ac1QU{&o4hhe^RNW!PHhMuzt| z_-?|uVkuZ>Cb3(5Uah^YFGULyjqmOgUvY2Ak?*kbFCrYjONx{)#%gY@L(zf+zDeP+ zYCDq1V(!B^xivNrsG9Oj76(^}weFEExU-pTymwe_72A)Z1qqK&N#eM?7SdNAeZCC2 zL2C6%R$c63#=uR1TkA4LmB0vgwYk&OJs5NYaf4# z79_q^Eh_o9h__y=U^iddy!$+9%Q+XrP7(SHxCg}j#L-7c-4kC(T1Q_)q8GUu0YX(gz94Fb+1u6na|g7f>7l6Cgt-f6HWh7i`p-!X}Q9G zYc@Y#^nK8a;y%OoqU@3XMdG=@(=gZK?2NN0ukSO;jj^eVNYCMQ4J$~T2XSr4_jlh4 zY`B-KrdDoAaTgfleT+N8C!eoZkG7nluHoOxMGF$RQ_QQ8cz#v8rF~F)mXHj4$VlKi zkzcFTb(PAq%CYq)T2kCk!q`gV-qNp5-@W&4I;VbaKa^*i{JZOf1djRq>nf!@CCiP< zEKi^2hS>lK9OrqRqT0^n=YfXIZ(JS2Ito>|LgIPC8-9@R+Z2ZyK@=@W;7W$qVx3ii4p;XnPR&{v);LJu z>V?-e2q;5qZayzqt&R~8#y#ZRpHGu4Pm9#3+fa%YByeAgXQk|}M3)_JL;P}6 z!;TjcxR=Fq^&j0K4r}$Ym(1OAwO5< zq0Kr|!yYmcxI4@}0J@&s^r}xoELACvOE_L(+%s>(Cz!ndW=}&-wJ{K=iVf=^zO+pv z?)p49q?8RIJ2=#Zk5qpoS^z~h0*r!GUcc2Bq+df5ohNr~hElX3fg>NUs`6`5WDO-@!@uF{_8|+a-Slg_6#IoPe$Lu+GVi(t0Je1shTwNpY15}2(b2zL91 zv06P(k#m+#3KFPl@UoC{cFGlUu8-D6zH0+m#R(5coffS`ElwsU^5|QptuiI*Dyg;H z{GVKzHRS`@BD*JK$mIG|RuKF<7@N#r+`@<4ikT9Od=wJ_aHq=B9r7l(QsPLSYT{tI? zEs$b;qMG)u#-?3tLLc9pC!qxijKAdZ3~?2isd5V%KFn1?0#&D8M9Y3pQ^?l7+SMIg z!I6c(??AtAswQE~=IMtKGC6dM4CJE$(){Zxl^xlOcOB^HZ&f6;AYp7{b?>KYiK>CL z9vP9hl;P1JNHi|zxdz$i|L~?)K^b?(fT7|C5)#=6&@|m zWAVwuR`m;@(&i3^Y#=1?2nC_-?|3z^`%pS=)@})p1lzl_qpN&iZW?*Sdko(LczjMR zEDoi8m->llK?38f`J0A$+tj=JM$^R$F3U)uYGAqA^7V`~@=_m53fx_*HrO(TCTFdX z(1HZUSqs9NbKzjK5?YWj#%$M} z=%rqX38B{}T$7PN)#66SEXtB}lBnOo9r2fxT~$Nri2Ac6%s8CC&eJk4GM&VZG%H98 z!j4W470=}stqR3sSkAp)|PD9T9&Q31b^$?dp@hcSGpt>7^A^VS5i22x6Zu z>7@8*t?%aT*pAeXpG3>rMM`Kv0yE!u{lr%jNw4$?H1qyw1qoEmFuRNAa;1}P{G51I zuD{lih^Av{g!sr?iyh3#X$UynNt9Qn5w~7u1!-QJ>i%X@Z}}*ib@7#i79@;qbgp%d z6dW8x3&tMMMutCGG>e)|6_+`rkxBeMW7|CbYG@oux-yKu>uW2c1qp1A@9CG%MK7oK zpaZiUwXx(+)=ur(JH_#*Zjo7hkD)H3k!N1;Tkc=+9Y1y)Kv{CdAfhs&& z-rI*iCO0MpQXB8N+L-v~JdnU66ojcZjyqZ(^$`ap1QUa*WT3g0Q-( zBR%r110DRsMM4V_d<6fKb-};pal1Tq>*Pb{<(i@(fvR3z&x+H_rjWV%nIG;}jb1v_ zgeFW_B%uWfE&BW4ya^?p@(sp(ZcodnJJMdJg$fd=N_Z_uU4`qUVD>0F zw=0dyu%(}eMM!8t0^^efq3Xc^`l@pty6*H283|M^m}M(PC1179wMT{wpteOH5c{_i zC5(G*IHiy@W!)8Wm3sh43xfUjKw7T-Q__6-ideKDVQj-QVJMBb8c$~an5Li#+p~>y zkh=bz<5Ner38us75n_LRyMz`b{u7PcaWp*|vyD_A(n~=CRo7gdq=3Sg$QOO2p1~&2 zm94``#nYEXz8?6KgW>(bS^9W9iA?2B5z;))Ql3EfhJ}-f!(v6WAYp8yWA_lc=evmv zZSAk13fl|+SX@f#d4U|#dwbT!5E|lJiZqwcOK3sDnE&xKUnspC^gwEDH&{UeRli(I zO5d{1kx%+a-E3bN^%?6YMUGi+#+>}|D@#fPx1A+Txxa%nkAdtHO4n8{C4K*VOF|10 z#x|0#htaOHb1Q88Dj8MS-biyvsb=#tq?dl?kM@Sr{5yP=S^g&^v>>6iol|Xte~0($ zRJw6Qs?z%Ed<6+q@ztc1@HoNx)bdp-<*QV6_xb4(T9DAL^uHB0(jJ7+N4x5(Cp&v+ zE8ai9g{t=ME>gnMcx%6k*f@c1^$J&Qe!EFpFUiU1&l=|>?e;uPF7bYaG>?>RHIe#F zoTR$D<+Gp#31b_FCJ&>uQ=Gc2uD_x^6FFJ@6JHdP=&qAwDIY(6I2Dm*y*g#>S98-w z(0#8Ct6e5v)SvSl0#$gl{5zu$M$^Z)H>=ZkjWbjxKmw1DfAgh$AicQkiAo)oNO&Z8 zgw4O&N?}78IlNt)qo&*(K;KS&p!WKn&x{r%v@z!2tdy>$dQg*1ZnkVhcLfPlJsz1$ z3T&*Bs=Kr{W`6EUi{opH~f*2-Iio!+=qK>}52lfH_l$|%HB zf1+vzwx&b&lwt4Q9FesdASWxflkE#pEGm=R%ghSWg0LmmU(~-}c~*R2QGymEjBPA( z^`?b+?BKwiGxb?5C)a9EdaC%cfJGmvv2DI~b`xpxlv?bo&oRSMA%X4jU7T7K>4Oa| zm{_!io{^c8aa81Ww3sTz=yOKiyXVALC1Xij-LniUUX7ld)Pzk+J)x}!|Kup4N}Gdo zj#dyx$9vGPJxt8G-83!d_>Tn%JVO2rx8sh~Wk(10vuY^U59)mv>>7RLH}lkE!x00cFzQ`TTvYqBv3`|SBjpeqO4=dl1CZjK)c@T z)3JpTT9D9uyPVuL{!VzxQxfp5KihO{zJdg*W~>VrFIJ8sJM}qg$Nh6;W22$0PwCa7 z=Dp_R_H>BqDt=ubNiJ}o8figzHt!4>88DRj;*35-R6nBB_%?SgCMN1qoEy_#8J^Sbaj@zZ8VY`z}d= zwxLX!wnD;s8QVU2np1|I(9afWL1=#JwG`tO%I@wsCZh!jV;j#s^C>${hj*SIj6Q6s_Z^%2?w6HZJ3?9e z5gSCbAfc_>bE=f^nAG?6)bf@P_N}k4f&{ACe-h+l7f)FC`tEOSr*5u1nT>eAL_!M^ z+AhVv{VM9Ob?U69V_8anea-o&c2mhio#bbW`In~o^I;_3`)*b@78}j_AHO7_1&K=m z?((!sk=7Be@2H7t#q$%``8l%;zlAD1S}haloO*q55Zl;PU#0%3F16EZs{Gz5iagF< zrT#fu{`H6{ykStc2}C1qoF72Aq|L*~D7U{MwVMtORYsIu^Jrp#=#(6UqP1Q7NZtv2P>H zti_Tc3KFRL8FWkj@iCSxny#J4i3Q%QM(;YzYt4Q$*0NYT`I)@;vxQ`E4*+TYo^zXt zbsO%%W=@%<&F6nwN5a^~?)j})j8KLZIe1dvbypJka+5S}q7- zD^7c=8+SI>-^AtADG7X1Tp1CUK=SbUgMDyTf`-OhiQr8W*eSmLD4tY9e-oEOpbFdL zJ&_Y-nuV*47FL(gf`sv?>h_(&Ty~XG567;Q@JRUThP*=(C^+GOKCz<>;ysm;#x`MDI(}x{wU5+;1aY#Z7 zM&{x#@lo`5M^=1>1HDt|W(-=8z$jfonDwqF+g31(T)VeGL;_V9bIYUK7W86;qJELV zwXYfoRN@V6o&m0*ER2lP%DqeD8-_`ciuCKd<79@<(>i!!_F!#Cb zsGI9g3tEtf92X&1$bZXv=EJV!VH1V~(7g7GL?lr4uJ%ItxDBt-$a^S{Roj+_xy%Wm ztIvlU2vlLVJpabwvgc}Uhd{dE(sK)1kT7O#B^*do>+~B$k8i#rq6LZeF8$>iymDkN zK4S1U)jbnbue!r%=EP1i5~#vD0fG=yd^eBG7)4XI9um=lL{hVAGMSJ@n*Ob|Q6YJ! zI=AL%>a~5Kfk2fpL!tTaP__8G3G`y-bTe9z!0ZVA2J7S`^<(3Sw8hpsB3h7WIPsAs z?*smZrk{3HlAV|8aVCTow%;iufhuFP;Xl7hNM|pF(GE9i$ld%Fni@AqCr|jegwc%{ z)x&GnB;J;_UC2egNg@)cvMc3fnNls?+OO8vFRYZ|M7@dwMYJGcj1ZdKIk%eUdV1RN*=CEV_e_m47@!>3LgM2`xw%&pf{ABGTm3I2uyro%nP|6Vau98VTWJEk;n? z`j}soPVhVyy|<665JDE#n#kX?s1~#!fw{ecaCwd=>3uzf<|#f-L;_XrRo9vQ?Q*Ww z9Dg%8xH*Kb4mvEN1qqB*|&5dhx^bTckPfB^dpilpH4R8I(%$>DOvoXdYhqqB}h<9+FUn zvAjm&qU{T^>|=j=Ey`?03lbPv$D=Hg;>qn+!|48-d_ITPQH61G{C;dbMHbu}Mhp48 zG7zZ3zQDg$@O2y6a&R=Q7Z+zi3lez$1;O)vZMr{~i5A}BB+m@JAhxW3Q=jQE^3WJ< zxKOD~w;yUwlUJS-kwDd)??=V6lTxftRQnDU>EVPH^!IVuj20v?lbzQ{Iqg8ZT<``!~8q$ z=`HE|UFB%!`71;uP=%Rq{A-&}TGH+t%F)W=QUigiE8p*lLHCla_v1iYFWNq$E}ho1 zos1SF@Xqqx+Fv3)QNJb~xz1NY3li94cy0A{6X=q=;lypBomBfnQE6KDB$Cv;u7dHj z4d2>H{aRikO9yJ#s`k7wwEoFWq)Vc`j20v?XPnO;EraR(ttW_ImnablR2|%uU+SO# ziuKHsZw#j=hAG6UKp7KS3immSm#5H{3FV38wMRq?5*Pz52u{fpX^q|^Np-2Ifj|}B8}3=|3ZrgGtHs0KJ`x@~ z9wBCi@`%P0VYK*ev$XnsZy7B}>?rQ6SI5s@;}l&QMibZcR$N!^6_G#{X1MaHc6}JF zkh`lgapyYdhqda z5iLlpiq9`Kta^$h^U;-8DvSuC{kA8mW9n9xkwBF(cd7KWK)S;FsX8s%RzeFBH`ep| znY|cc&9pXpeeO#=_qu;rY%_hu{^oxSop=}N!w*SFr6KgS?}D$K4Egot&e=#*-0Sfhls5?YYJoIf6y zF}*U)cfC3L)asjz79_TfIx1?lRPGMedWo$EBaJ%@XJZ@LOZn`ZiQm6Ql4DKlDwz9N z>ep0p+wmw;Lhrlz3Z#<1QwOn~OUuY;K?3Wua8G9B2cD19m!-3#A`+-d4__xn@QM!E zBV1g>S8}OS9~Pc=%8V8yFh5cdmiMrywc|Q5mjAYaK$Y>#zicW*FAeX^R_)Cc(SihC zLqX{Cqz|#ROk$b#f5~^Qd5e-Z(a5*vzr$pq6G=8u)|k(E02*|mS7ezZJikjRN=khb%dMj*D6i3VGjqHI@WAsI?MBG zwLXgZSjHTzlQ%a=jq8N5Pjyrg2~=U;Ezb-qzFykMzk|Ij@u-17mA3Zj;1X%Q=iR5( zC5a`bu;LEE`l=|$f&^Ye{tSL}Bzb>KVIRu(kkEnz-Wx#}8n8~OJ|vWtn7LcRv&D15 z8cDoD?7qix2MgbST~bUz3lg(@iKXT56XfL(?TLDRYmZ8Lj$&%-Jks~>edSL6QDj$> zx(ZeY>ASqD{A^DonGmQEiS?JNt2U2g?QG04T9ClXW`fYM>kxH7!%1vqy*(llsQPl_ zhNYgEb9J+9d#V?{PGYaxwl<>$39LXR2#dSfs)f%@WufJV8wgYx&%C_7qx!h@RA%dc zOGFD2cn!I~JO8G7bl_m-e0RRsf6^3rVq}!w8?agnUPB%!Tl%J&JZ&(0_HwF-79_CJ z9M9VFIH48@4rY74^)VxXD!kI%XINc>xy4iVWY2m_xe_PkmF=^~aIBJt)o*z8PT^9l z<|Z^a_zczh)sn@9*mbAQtYpD45?YYJdO3W?DU_=At22n5`m|6%0#$gQ z`1+$@hU#;$H*0obvV;~S@ZJbQwb2b(#^L(xj`LCpuP9zOV_xd9CLWAOR53so6<#qA`C4p^{^C=b!THc`CWd>!S&R zb9`5Jwt9Z{&$_lUT9Cj>jC`LupgpS;QH(w893disDyJ(iV5TmAi+rImn=sgw?XVOx5U4Vq`J;%Yth7@#cAl^1(SihCLqRCFdko9fZlgLb zCPOU#)j`ROJxz}D_cd5^V)@iU%I9CF$lpD+YgON66dO@zzdC2{7zhy<$eKJnT+?Sh!?my7DCdKSYSM8a5^f^F=}d|!W1m&|=-#@|8}-c8=d zuc@qum!LXlt&P3Iv*oUyPas|RJcw08upWpYbiOv3wR3K(9{s*TL;_V<-Gj$p4V=tc zrTVL*2E8{BsKTqxpW40S*`A+s)p_y$7PKINS6UF}-kQjK-wY!YH&&AHtrb>sHNL@` zJbwx+l~IlOJfAM31qowyvB#T3*{|W3r506Bh)AHS)RSULzd7fuz5OuHwX8VhnA9ew zh#4(NU{zP1f6jf9*CDygxpF;=MFLd|Ul&t;?K?vn>Q8NjYhf(w#yT;1vW zG_}VC;``04{Bym8v^~#x^7Ecq`B}Q8^m+GLQawY9nRxkNoRp{c1NBC?0c_UZhvpuM zC&FH8@;+@VfA|CfP46~@Ez%=`T^nKNR%+Wg)C1A(du<-dvr-4!xlzk|tcY0{Nm zPt-1?zu}&D2&yf5ct`3gjNap4oC@|(b}YQFp1d@GRqfd#HYWIlenl~E?R(d&mPJt~ z$fDocnSYM&sI++ZOI2(4VrW4EBhq*+@6tBv*{UWcZLDb^P-TqqA`{;zKQ9W*PYf^+ zsIos(LvH60Y3=Rn4=1S0U+h+!&*{O+r{0!Fw72Nzf$=xp`+SxUZb$2 z>gc+?So2Z2m73cyXY+m-=kE+Aahw9M2tEHMZ0~z*k%-hC( ztnC`uvS&Z%k)`c?S<`j4O10fr^`jcTy^xYN@{0bum~qT=u)MO=i1q8p?ndnm1gci= z_$9aQbKQE+=dZk}noCY03uAgSv><^Q$2`YYr=(b)!7?;id>P1s(M>@n_tgOxAv0PF7wFY(zcZ5_G5Ji zEfZe^rs{1}dJrNu<&`3|HZa>w5V}-eK~7f5#j~*68wga*8?-~rjJah!kB#gfQvL1& z5>=oJYc*9B2P{m{uP&-EkBw*K#=Rnwh!fr8Qp-S~3bWGq_YF22CIjNmkVgx;8wgZw zUz;L!u6fgX=Cd!nB!$OY2y5lXY83x0noB0@XI^&lQ!&{+yJukj7q9A=?;|;u{{UIs zrM-bbRh}S0Dwlc9+OJY+7O8rq3b|jTJG?(y0GM{F4Eez3Htqb zaI=Keugw|#u3-+1AUHaDQ}s@6_0Y9i1_D*a%$Lu8-$EXGR;rx zSA@SG^mC$Tdlpk~+50lIAYsg7d3(s7241MA`d#od5UBD=bddIZJ!S2QJ-;WBrM32{ z9fo#iK1=S2FQ!^ZB%k$h?lWdubUinRJZb()jmhfDzTKNI4xAaKU#p+-5#r97(fWPH z90Bh8yxBu6f4Q>V4S8%T5U4Un_N%_#iC54U)#+n51A(eGsa-|N_hhpDtBy-&ia#PN zuy0;IhPf8KFpMe}g#J#)tY824KL zfdjjHE+5M`v66v670#mky_#)ab};3*+TYd3FmvMkiqWRLs^el;_R9Z)nz6sKfj||` zyF7zuOnEk@Y^r#$MoYsy$bDSp(By0%7vmOrMWBS{Z1efe(xru!3_mq1T5MMEINPV&#~Z5<;9>Py>M~oOk(Gz(1Z=AC9%B!!C&I>%j=Q)uvmz z2Z?irv33G|a#(Hiq9`q%N(=<5^5$JBfA&wcj#md}byG9@R;KqxNbFI=+H&C2G`)=# z(Ou;>6M3u>pDnR^0pE>UURQlKstT=Fh8hS|jdRH>&uf}9<4lh)MX$VGPbzN{S@)L( zrLSwR=xuz@Z!49!c2z$T%sv%__a)2H$LTXjFGp%1P=zBZudoqZo|?MPA~WuY3@u1t z_NgHBi>*L^cI!bhr%(fdDjb7(MB`;gT3}v&R=!So!&rjjhcQ2@-@p_ibiStsZEIs7 zP^G;Ksp%GLy&vuBHC0cQHnHL#T^V|*=)Gco1CN1RG+3S7sR0{Xy{3Uc)ye)H<;=C& zK0Lp=x7}HwxL3_uR)e9Jicb;7S@Q~Jw>?-&+)B0O0(S#}svFH56q^gDtmD=78I4#> zi7{%e7q0C52}dQU_-Qiqt64#nF>*Qjg(o}EfvUgfx*G^op|>mu1yj9Q^Er*xjBYg; zT9Cjwjve?%phX+F83DFWn6Us&dZ`m!G`P8B1J(4yZ*gxzq4#t_&?m zU_}@HwatX(YWGwRs{C{}5U9eDi03YCTB%yD*QH5et_&?mU_}=}_|Rg!x}S)`6OYh6N-v8@OYU(u5U3)pO_ut* zvVVzZ93I=FgtVh{*Ev^)79_BCiXglSE})QngzoltHxQ`8bpns;8aB z6Pf0}T4%S8>jA7F!ZUJH8;~&ryy*QHcLRYc*PAuO5xg!=wqNqV&W@)0=b<&Hxia*l z(euW-4gCG!Q%8DY&qosW)6GDj3jJ#ydp@)rU3n^%s6}fq^tMrjH5CM5!0Ae~63^)7 zz0p9R3TGc)eeXaOx~*^mnLe@xLkkkd+6jXz)}TSN_Y=mi6%wez*@s8(#5AU>eq0oz z`s!oJpZzx6hr^6~zH6|G(7bQ9$gNWKvE&bdD)iy`ShBb-T~T(0I=oB`hWknA6=McC z{|bi5otFB(L%kjDZXi&F`%FAqmDu0)-3w0l%yXeR4vYZVBs?dk$UktfEQ+-|AhuMsCWH{5K2Vl%moY#1h8abpf+a2s|AW#)` zKSb_RHhbrme}gAwv$`j}3G=Jy#L$9-G2`&Y;V5f}ppNoM&748o6+>AZ3swua~ zUM+HDXh8xq9Ql_pvo5M}Z5pt@WM>0`D%?-zZ+N%curr@3vV1+97_Mt@4;k|cc^$|r zg;->}l5FQEX9IyMT<-|N&4-m)m43O|8nY9_bq%hBj9GL(VhvU{?|U`Z4`%~`DqQdI zoVW{3*sCR5)gaZ8p-+V#8RoGG!k0G9*^;`!itx#i-ABD`awMAPHMt-&;N&nHgsaR`oUF^u^Qs@**rV*R5^rVM#)+Zlg(|Fg#NTuN<4ao(-Kb1RcQz2H!d+GVR;FbO>YO-R9VI$3 z^xJTk6KnnOJeJp#MnBo0o)2;{5U4_bj(}ith)FRc?x`tW0}4 z3yiFAHV~-7oe-W)?zWi(yESDQ&m0-<=sTQnHk|RS45?I$ns|V1H-0$C#bx@oQ1gebhIrF^=p!@qal!`9OO}#5PNy8?d zC&PSeE3ckA=~3|7-xwXwy^Og%=Ig6kQ^E8aa`ynmp$I6MMQd|~63lhd` zjtym6(EbGz$+CU|MFLf=Vg#w;fveV|if|Wc{kv~TpQk?wT9ClZ9{z>AF=iS#{w+xl z$}|wDDqre@*skSuYa5@p*P;hT6rg2C2^1|zRIh(m?32&ypu z$4EH1mZ#5xOVLyDSp+ReU@j0}d7Uduf5n%i`FiH2NT90mhqL0a$~Ub?wf==Y?d#`G zkC**P(1JvX#YeAAiR(Wj6_^1 z0NVFAAFm)2I<<1L#WODZY>mXRb&r&hZv*J0b4CJHm~ku!Yv1jam#RH!+&@_aElA)I z@^=%(nyX7?KRT`}|8giC2_9jdfU0s!|D5|VCCs9>YTJ_D|0qz*IUkU3fxQ1w_H)<% zQ7@Ss&3+yQfmG_OUb)IWruhQJ-^JW;V;kqrE>q{s??C&E&mw3+!q~>p(3NV1ejVxh z$P9uO17XO#51K8j?b2G(=0mawo(JaY<2mulc*uSybH^g$Lu3lg_0@02e-&l#_3 z^?Ix>uSDpu6IldRSf#*7tT|`H%5SbqfANtTEl6Nh1Ret!Xv18F)}deD=b}iUsz$gX z)A*cm*KtU3HgZ*EDqs9b(1L`N#_Q})x?%0@XZpLao0*kpteQnog;gVrME}Xvm^jCg z+A3KDEl6Np3m%{7S&e;h;;-SN1&Rc!rha)QH!PGht0jzW$PR|*rVal2PSAn`)&t>w zd)4}EZ&W6kZO%iHKvm5R-(`o;Yt}R0JFO{O<$jw82fh=uAfdFgQC{4*VkKTQYsT75 zyG=@FeI=;Esv$-qd~<8ovg8pmJ4>KwK_X*dL1l9B9N(wQ#5Sy9<0E8ay)1$%tQ2D; zd`ul#soS%8+-nv=3ldZEJ1E{h7p>N}57W@iT|tLc$pLnyY};A^eM=_6Xm#!BTDeozRoYV!+wP>5XcgFzZH$ykB zZQjkG(Yi{T?QWW4ti-}1wygf{dhAg3ANs78^IO=4k*E-sm$fZZm*wNLGg^?qd?Wtt zvu%&nR9nJkpEnYy!lUKiPwuotmhHc^rV{R z)tb3C_@;XqIcJN%i}^-8CoaFNCf;eu#_$miuMet>ZKM)e9nLFrIaJF<(SigXE#HC6 zyIgIt(x3Hs^wV%2n1hGsBnZBtZPbU$yRef10!0fFn#Vb`Z}wH-*J{WKInAaA`^|Dw z%reyc>)Je9QuDkuPd1*{<;?cNp9i&3zuoP^w6Pt37xVm#ZInxVta!f*U=1U)2wIT9 z3_kAl4SA&WSRTM`&k%GEIp=s$Wjv|^<5T2uxq35q$Daf(NZ|PiLY!Kb*beH--h9bT zF`rXgA(>WWpRKko(Y}kGedex@=1C!2da<6ve(39uoa=+XYb3(D&XN|n_GF#;nWF^> zV_xe1J!QzE7TwwM3q}G}c(gnouInz+OKr!ltjQ#pk*lp(3ti9d;o8b_Na>v3o)kTa z+z9E+E*JZuuLpC^7Jt`BtgJtlZ0_BOB{j~auTpb<3khtSAC;6rjz)X3W@814`Oex} zz2e#I@k3jM+eT-PD}vBy%x)4l%$JSc@zZb~n8S_d#NW8PT_Im5H)HNW0!0fF#x}hE zxkCJ-nz1^(@1g|>%mC*ZpurjBf`b=xjUZmDKZpImq6HGY5UyB|5%2Y(lTNf6c-w4)6iYp^N2dNy8BR2kd2 zKBpZ0T&gsSKb~bcDy$%YN62fQ&96nrI2L4U_^JpkNMMZvL1?<7CT+MnKl6;s-i^vR zTU2R#XMd;X%+B|=x1brz6V-WqeSlRQw0*xZZm0BXrS0fV8_%rI;P)XSoqgz?S~~Ct z!Qa9*j70B3Ci=1OTh*&gE{YZ;ux&nb?(v}wKP0Mx=TE~?Vf70iUbw7-4Y#-B^_9@p-J;2w4Xh8xilkl%Kx9d!spX{sVtC77UoO8UWG9J}H zHGqDk1=Pd^zx93Uoa03T&rcAVc=V(ZL5-BmKtbPc&pBIE;oducmpY{<`CWq1YU@Se1FM-M5K@1!DuD%^+XM-}7{Tk=(z7+$-Vg?=imj7fTU03c0IDxZ1W!iRrn^J*GOp}YI)Pi zEoSn62vkjTFQ$aMo+Ul?qdFZlC$?X^2eEfKffgh%(vjE8C^R5;P@ZVLMgtP4!h9pX zd!9AW{I1;2m<|7F1641#IVuDEFIcZttn)wSA(aotR^tR(kiZN+UOnm4cJsJZ`~QzX zRqsoMmBKwPkyxJX$oJu{bTHTXUCetwZv!nzU@SL}vS@wJ)J4s!Un?X~)!oTf>Ce9h znSIaef0}36Fk*{&Yi1Zj3lbRTz~7%=7-Km%sFI~CKPn_pg}GTg4!Qp*b7ZT0mLdNk zP*t|pZ`q~(HPS#o^M&Upo8@06^fu6f1mm3Y>Drip4nrrdl z1X_^5Y%czsPb_5iulxT5sw_*_%BPmyB0G78E1#V=<(6OU?j~0L7RJzm1V)kYyooP$ zEpwkl#A^Kt2~=VHKF=PwU&2yt@?~=fibuIsLI=0On+4}YiEuG zsxYg7*V214-Q*yZGimi+kU-Vm^||D7JJYOv_x0zerahip%-YC+79=p@hG$Nm4)=B) zw%V-q5+qQyHBhw7w@J75-J*FL#9oiKSRV1Kixwm>5?By!wA$kxQ@XK5%MnKcRT$SR z2$N<{h?U<<7VTOgfvQ9gC$m!#{uM_3sjaf)kf~bhS7tLmDzqShQ9XiSahhd0xw)l9 zyH-e`3iAqi+|t;6u{+NOTeSHD2~@q9W+w)3P9q+f+Vjyi>yY`U-8}L6moSDFBrtM_ zuOPRNh^>BQxmg<-kU$m2Ckw*ZuY=4Lm5=7c-$nve7&*khwwXTAyk+^P|07TpmM~77 zJR;S4b!%NbVA}R}hThxJf&|6^@xD9fgK6FLm1a+V9!Q`HIjJv8=pen+`4zk^7i$~hxm zl)K*~lb&n+>Tt8A^5^>_#7+R1x)9sL5A$n<79=p9iN9BS z5Mo*NCL&hr?MR>squ6=eQsNEM;jKL_3;si(sz#&tVzF-5$x8h^GCq!ygYy-Y+kFjV zXfY6weJTj^eYeUf_U|lFAB+U%s4gy%TROGyDmkgoQJcnB6sx~pWUl`!jG+YyjFcCI zR*$ck+mD@WcH`Fy2~=SeJKuq9B#4`eY%+c0cMu6w&AL=bO51(OI*Q7hzF3ZjxmYv; zEl6Pg7mrVDxW{tnh@Is#KPn_ph1p!Zo_^Ommg&LkV~78TKvj|9MWyJJh-)V*WwUSh0C^GlFG)!F-wlD(2(?TJl2W}92SanR>lv><^Q$Gn2T%V5*O z@{#(rLIPEPT`DZCITL5Sy5k=Yu)N9bZ_-v1Xh8z=hj}#BgWR#H52xv~8WO0&$U45K zKYNimIOTb)<})CHsx?b&BwwGC)}z{5#Ws4W>Agu?N1+7?%)sD#0JrzX_&2U#*8E*0 zP=)nCcy#pi+LpQ-tD3dZ2MJW=FY{L1F;cP4&TU3NF~9oDMsChW2DBi7@qRp4rl+IW zrRhpb%zp?}VdgadruWNJv3cGe^wxYhYy(wK#Is_?vsml3suemkc0z|``aMSr5|~NC zzaW`LOt+qo{67L!?OQ~M+M9;#S&gjmGP&I!pc8080y9Z?t(WU&)4(CK|BpZw=7k7C zd%w>yL6sN%AAzb*4?@L(ccQFgVxRMkE$yG#iqrX+h!!L;XIu~_MpU<)KkFox`wxMt z&9-&L^A97fXI?1pIMax|1>_x^KnoHWf63okhFvw?JG$PYc``_#3Zp4`yy$^^=G*r- z=<7iwP&ISOHS>GlNb721vU6Fp&&^S~*M}A)u(lA7*gEPTTlb;8{(K;Ts=F62d!H?s zbLO|t-;aHD%|+I(6XscJ=WgAR0YhDIgknnz(X6e#B(mJ-^bSr6c|FT5CR%k&2>*a8tVYItB zV(}Jzrw9pDjY(=IPt6x)9jRyM>1}yY!d4#5uN7L5z>H(Q9$d4@@}88@#}Xt^h51H& zhd%JQ}+|FeYEvjWfr~`+i82azL$X(BrxBI?|!7^w{%%OPu~MT z0#%qjBnTr5xmXgHPS8DMBv5s3!%|zgD%?M*?K>{Ra)l`s;hN*-ma1f4u-eLk*7%!w_!)6b;MIA}ovGuioA^1DWC z?(Kp4ZY>h1!d+ZJIKJ$>W&ia~v8TB2gBBz((~5sx#r8pL{LQ=;Z7e|oRhawEvomLJ z{U2Rt9hTMdz5Q*0q9`h13j!*j*a6R;83V9IP*Ft0K*bL1LQFj7v0Lmqb^?3PsECTK zScu)-E$TZDhez*y-tX@ZuZzodf1Wk7b4{$Z6H~93&3b7y>W#>Q=2B@ZTr0`+==O0#)Y!|FdK3zl6ntgxR0(kbkoy)xB+Tdkrkf>i0$YUys?68hpPhh6palu@ zb$sTDW$&OX^aBZ0neUrFI~vgkT97c`kM9;+vLOC10#)WG#-DwbNT3A?^K)f>RZA9x z)xbaJ90^q69#7-7ixwoVG`^vA-fG!limgHdRk&AGR?;b{{yyhuLE>4h<62anE7|vl zgIhvv%kv5eRN*dNS+N%gv>?&mYNvL~*Rl`!O%cD}K9E2a?j1HBE3_a{|La0+MPo~* zMTYYqA4s4IcTFn_5@SRAF|6tVv$U|93yog2ZCGTBefMEIwA3zwtXRK>}5n zvmvYXG#UT9RcJw?ec(k?sihVlO^5#J2NI~lY!z86?6X$Y(sQ&RvGb*?>f~kd@qZDh z!mJzPb%horf(`tX7S4g1B zygTVn4i*wO zGx|Ua67#Gss_7jqc@tJ`{=~aTpvt_T=})HKifToEzdz7|MEJ*t>eIKD_u!p=f8K*g zpbE1Djn@@gkf=B0v+5OP$%44tvDojp9SKxnwxe9N1OhEc?Dxn;b|qW#CK`756SpIQ zD$Kz&_5&?Q^s1GQ6v}6beej$kfhx>CHTpmc5=Tzkl3S`JZvuTFfhx?RHTpmc5?$}x zk&oRhSrF(02~=S=u&l7r>d#mWEl8XiScC*kwB${o4i!s*L-{mS5q&??JR6VLmr0k1Tl;IF3RB zRk-WT*blTIVZO%Zm9UICu^&jF%D5BI&=1(fh=loGc^6^Hv_K#DTd2Z)i^hJS1qt&# zKRoH*%#Qz?Ko#z$G!ke*!u;$?tY+EEiTywVRpw`C<#NTA9bOIY@Z8-1V!33H4Ru*kBL9DN{xD&tOaB+!C{Io6u-#j@*uLByZ;AQGs; zJ^M0Fw4U*M>)Y78~W#JM+*|>v3i~i%Y8nj{GT}h zBv6H!Rx&p{_n#{fEl8N>9M(RuJg=s&nfbftNTAA?ONOHfv>;)gJMk@JdA560`7_2r z0#%sjC#%Dh{L^!^AYq;ZavAY&)`iTd>i*~GB7rJnZXZ6c(1HZcy~~};IlriR;(D_$ z8&7kkQI#a8ru%7RQ(yg5g8RWm@(XFr2Knk#^-eP970qmJ-WxR_;X2uTs26h|m1-bR z)&9)3B-23oeKkodO7&9Ktm}2grffgW3ncE*YB(g)3yppCuil4SKmIXh+gZ3tjr(bs~qUj^9u5 zrq*-RBh#+Ye6k-2IYy~DJ<@6MwZ6Ja=cvyXUCTZnPaK9apZ9WfvRDSYt(wz zuhI)k%|6BzYb(-sB-6klXLw4>oFuv8A-cP{uU_Px6?xd@FnuLk5~JrNVeW_Mzx|l> zv6(ocKcmh|Qw;>Ftg3udTm9s;dzO!ok4yRcli_Scms7l!Yeh9P#{s%U_G4&9pjx5t zfvlDccvew${kT8d$B+V1oK%WrKZ>Lp2vqHSlxf~lFQffgi6iY=zK?f<3a7r-{` z&_`|Tz#0eTG7tt;4R_O$M*C^rC1xMLHWkqwTXteDt2!bq8n3ax#JC3^SC^s-<8?gWcel4W%9Xnw#O(fBl7?al}B_}?>X9y zi3hXJrmQOO{Frj#!;(MdHV~-7XP2Vnn`*^Au5xGZ2iXdIF5&aa`NC1H;MGJrX1n=( zJm3zj>7o*>>IxeJfvT31)3n5!@{KdXe9p(bF3M&PD$dTFvK44S;&$sxT0=Sh_}9m# z%(84mfkJHA?ED4-RT+=&YK8A;bVb(rC|;&4v!7p(HL$i7XhGuC-g{c!<65?l&pEx= zg0Z<+(6xL90#$`lUTF)T>sd!#QNn{gSkR%IEO=P~ffgh}Zp#(rt9rJNef?^(F#*qL z=@T{v0#!ltziF@X3OXU{2{-0lJr-*vzfm*f7ymIkr#`sOA=*~nm)Gnc-*(xL48 zqheAMcCNBUcizi15U85sY^@g`blCD%d(xN{KaoUlrT*e*L8AH`YkgAWVH%Kif7IFB ziuD?_l6Fu2W*|^C?1_y&a>o&xob^_7Z5hal=Uz#tjrqmVg2aT#{Cd6(N3wlHhJ~;z zC5O_p8@?L|RBg{!P>=bPOz)jBUsoOUcJhl0O=y$GiWuC?POsV|C98+1x_`z_U&&Ik zeMCL($RepizyI@%BY~=sXAA3n7agPd2AX~RZQGG89HG!TtA22_ATc4Vi0(fBShkPk zLY>(6S`*2P)tLqYRrqd~Uov|V!HRz#Pgsdh94$zAx)jy>K0luAZe!kiN~VE8)uwb?J&>NsinZjo z&!)9uwRVr@r-~{fe5Z}RJnn4P-HEEXjq>Ytge_g5^P#gd4FswJ+vn81zMjv1 z*FNjmgmvqz^ZK4YI9iZclAcq4tX#6VoGE0en2H6-0RAH1UYfb)T z!$?dyajd$nz?c(Nmm2QXo>sh??L+VXm2S4LB0gK?HxQ^w8M0cNe&;Hkn)QzA@#QmJ z(zvpy{@7Na1&PLDjTZgZ;$w{GZ5q&^j@Y$6zkxtiwU4v3^m*5^?~lF5@6%Ujsj&6= z$D3XHAlw+J-Gibv>f#U7B+y(+wm$p>Ud|j_+--CZe&!G2%T8iV%tpr+- zxD(-_U8!}QreygD%+rw;8q-EBK9b8opz7DO%q07f*Xiu6XZy8)jx@PoJF)KDPmUHO zF7N-DG;G@S?DNsq$(hn&VPd?;cLRZ{F5fpLHQj$b`yO1s%9&1?6(lRf*b1~Daq-;N zq^)PKXJ5OwkF6&6d_zRP?KTDiRc^Jtl3LujPRFH~-%%T0J8-Wr(#Jeo5xrbda}Sl< zzU7(1u~R)J|6kTN{5QVZ=oZFyQ1s)Bty)w89pXU7b^4|V2){|Xi^VO9by zNVxfUs<&%g&yGd!@L7DFUNor60FIDogOtMaRqfSGMAQrY%YGIFt1RKvjip1<5SanQR|R#|8`EyuQ4MgN=beRp|zH#Np1V z?03|Z&7q>!>$<#Dg&!O(NXR&dQm20SHr@ZtK%gphmRwQ1bv*n2&`Wg^-u_d4edqn; zXhGs!>7u0k9ZL+ClB<&_^I#6C(m8+mx6{G`t^OHBPXGEmeGSwka7WEu!m;k<^TY;$xLDXF&X zMnVA*ck`kek(NYV<}<7j)6FB{HX1zM2!RJgS^s_lX7ek^#ioUW@|hYj1{ zY#>mDa}J79XVZFGw;5qKU7ZA4kZ^i7O1tvt0Npv&e5~4XA>Yn*+1zjk1A!`>rBD>R z7boa&y&_Y5-2~2Apeo|)8m+!}LUunQ#=f8)3(K;q6I~1hs&HOIRyMM!%+AhwK+l$P z6ga}h@p;Jc@7g{&bGI<79}x{|u|xGv&~20cLx2j$(J~Y1Qe74hw9oXtOf`X{EgY-u zn^shB#!|EUF}j8?D?4SFULa=`1A!_WQ^}PEw|eaPeHUJI&v}_-aeM+&wcAt5| zbY}yBDjX*$$`dbJ!khbv#*S_R=Z!Jmaar}cHkC6Z|K5XBQ!46P8U~01OI-{Es&GD8 zuD+g`uRq`7C!EvD35<|XHKO`IrVVM=vwiem@5HOVYarfCb1@L8!Z=9QjS5@M(_&PU%@dw5w8$=OOe#WtEAD91bC_W@VA#Tgd83y4pB#KuqO!n1I$@US}rGc>h=}qf#X9IyMjDr+q;qH3kL53GC@wuEp z3le^p3X<;+le2wv8%>07xdGJF&&5EX3L_`^J;=^JqD4|1ZFQoYz&Hw3J8oJN8hSX} zhpBCK(R9pqIiKiaAW((tMvCIDRu)0)lIX7yt^!w)P&GAIrh4f;&-PK8x(n}n8FcGT z7XyK+qrtD#ske3dD=V_Byu?-9nstp{9_l90f`oi(tIn77Y#$^1ONn}p?Lp~)col#2g6|CR zUXj!iqqi?po3 za;$X~1`GVI@31*)?8t-J=i{Z{E*{aP8ha<>L?lqvwbBT6{JR748w+M1O55f9=&Ks6 z!3Z~j>yD_(oe`{FY-5QNTi6Zdd3`Cf{ZZCHpbFz_`9jg)zijt7a=HCM4Uw4uOQ1m|S18j)O^*gZte_^Yz0=%E@2RAG!F zGjgB5<(D7d1kw;EPkk$I4%ZA8g{3w**L)j*&M-#&`6z}89h-Y}Y5A& zv6<@dLb+MDhG7CNNMK&NTpzR@MUL=1Y*Ip~fj||m4=c*)Gc(94--4`wa~pvcBrq>s zQEsk?C+_ka3}N+q8VFS3Xi{dpPAEw~HZ9C1R}2(rK?28ya({_kCwh8xZnpPPYk}i4 z9Nl3CxT0*jSDSY0Cf{m-eGLSvaNH*2#HMky^pjL-6(~45K)Kxm;b2N4G=xiWRWgOjYoAi_B zD>0qMeGC?8K>{0WV0>|pI%FXs0)b89nz2UD&fg^WRVf9T#xo7o(Ru3`p z(#{=?1l0S)^+;L8Bh5eKISz({J6b^9j)*1_D*Mx+rV6rR=A%c^!mrE|b9ZI$R%| z+;xh!zfd}zdDk2*9of*5E}xdl*DUI8AW(&?a*FcLkFm68cPCM%pPxVr61c`DpCyxW z(ND|X@~h1v1g=`)DjC*?RFonMqney#Ea`81_D($KcgsGYFVB!`55>7J4)cJ3eL%3JwjR4@m_mg;cGg- zU9*dUKo!om$n()`AW!S_oDZxQCeVTeRu`1d65GFd)QkIk=g?3CfhwGzQIw_!$MX~Z zI!u1w8RttK_x5duYt5t{;Xh8yNy2&%xA*aac;iz|(&nqNQ zh3hRcrrvKS*0e07@A@8WezyN!;Xwi`v?UAT;K%ffe1LQZocdy~iH^s{h4c!HfuyIU| zb;=ZF`SguEy3sk>Zhi*?fhxl|T3Nb3hTk8sl|Eb?C2+iqDy&l`-{&PJ@CofQ=)Uja z1_D($7Lk>1k9+X7>;Vm}+C|{_6~`P{txHiVzRS%Ir&zJU#oHMOR2f&7n!AqCTX)(@ zD;^vu(1I1ZaHUC6HhcN$KdQc?k85`kXh8zUJBs3TX%SuG|AqR;2Mdh0F$%}{T2cC) z2&8x3z2W22!UR6`@oj+hh2^f|Zg)uWN9%aWq=5zkRrqF*tCl~WlN;l2^I?4>1zM26 zH-n;(^eBDB=NN7~c(}l|f1G2-3b=AU(d5AkroZ8zpM)C-RN-@8?n%Aq!;>dk(U5T? z1TM>>EGFK&haY2vnlJ)7oVap z@T|{LYLzhod{JA*Lup&&(#%76;SAcSjlsRN+WZ zz6WPn(}F9P(8Q~Q1X_^5YfMqvl=w=m3k0yX>E}3Jaj3#^o1E+0vx&M!9i@3PItm=c zp$e;Y$)3knp^1fSQ>P*$3%<=*q?o%Aw`qa%m* z5NJUHN0W+DAha-RclsJ_da$NI3lcasly~jxcC3ScAa~?HImYUSNFO3{nM;=Ol|HNg zlm{%P0^?7dzrrX}Rh;IY;|0zRF%YQ2=Zd1Nebk31`C8}Px0)Aew9*CK%`T+@+f@aAM5 z(MjR+l%WDGNZ`DmqI7=hCTto{=dYDI0xd}3yq}^t+g1=c5Bl)G+N%O%LsVfrC%^Mk zDwTKey{z_Y(??)bgwYb#vz1TLx*hrSlR3zhJ0lGQs&Gw5&RITR#Jh}{O1$qy3$!4C zv6hSgrk52{lgm*LD?fo2Brw*J)%VtUi0zlA)3=l4SJ_~N2UQq3$?;(GvcjAhaA|B+ zfpHY7@LeJ!FHIBkh$JLXg|*`3h$TFRW#*bgDon4CfEFaMwTd!x$q@GS z%Y1r!Q4<||j{Y#8LQ(QeAI@HOoj_OhsH368->d4GC0XeLzKN z(R={Aba@{wIOLRnA3u^te#{S#cFw8qr2jZn$UuT z(MRL{ZP~((w(R2l&xx2Din%$st4EG&KeuBmJLO|7VwP!0pbA^7D18@3uruwx()TUa zn9za*wor~gZZ~BCt&d$|1=FPNMIhU%!0V>%1UhV zV{tpZG|cqH997KblHX(b?#jI9__0!FY8wbtVLq?inc1ceyT7G6>oF{^iWVe{t?HIg zjm7`0!#r-3($Rth=6@;5`l^gYs-8@%d%%PQs_=-&oiz;|S*`|wtm>R-6ZRI50`3-5 zl&iN3vfAz2uz>Xa8d{LRjBNSlw9CW7c6VURYncfNRAD}z+%2;54Q>0W3!8N`EeSKo zac>;bigIbmbK1|R8*8?HvxXKVj6T%shv-$Oeys3>RT^f(;~qcsZT>n}Dt#W(hw=LU zhE^eg{^Z_4yV*2aAIgq&+O1(mKkijT-*O(#W+R>cV-Ra_ZM+FBNT5Ht1{t7I|7D|@ zTp`u4G6CjYVn(H+9Gp{!&bd08RqF0!AW(&^m6hwSkEe@|3}>H{ulk||32dQ!bN0PU z90rYJZ8|JAVN0-um}Ragfi^42#HulD)`{gBT9Cl%4ssV#k~0~5I)+Vql4wE#Rfe2c zW%8?w+J(n4?0%4~hSex=e=5>4W@u-v-@FjRF82IeGgu(M*63r(&|JLFm9ebhj71t& z)xZp3gKx$4<1Ky2sBz5k*4jk0Ab~BE6_h>t@ZzOLvL!iIXjmNt^ODiGyhVSc7!eHYxLfxhK4@z68A(779XdNoK#3lg}l zt|*O<9_4x4^<{hRT<}E#Rd^I+#gpT?MC#BmR-j@P4J+H=eim%4TvxNoBL;2j!0zRo zqoD-}Y@wp;Us+rv?Q70fwBMy+g&y3OgTCb`)!$x>-PwvYZL%N%El8k0Ip2P>vZ!#) z#HQ?@t6`-f+*OXgW&M&uZldcXKlb+9S`%83K!0-Y$8{o($9u79d4+~iDn_rkQ(nFe zn$#2Be^q3g+H6Qd0#(>rdCsG%376S**^n3Y6VZYMwop+%1+*0-JZ#yEi&b@usnH+q z7L%FY>spC?t?ilLfPyMokie=_G6G2HB<^}-(q5msYe=99YhlU#)LkOPr9$6mRBRyw zfhw%2pePfk^cIgZPt$9sw`y3s3wP5YEi+$2dJ4A&a{uqrW+t>CVf69!O02q1?#(8&Ac6j5M7DOMNNU`L&J9b`aP)(tBHV>2*CrN> z63d1K(U*G{Cn13AZlBS~oQbEl6MsCE*eyKKslgof2Gi9IK%}%vO6Bi;*h(dE}j5z)0H zPtn(89bHS+L(HziEGW4zW7a6)Hz0`XArrE$D+_@t^KrAZR(?6F#aMBof<5=Wyw8Ld zB(Q~YrCJ{{%D*=bt%C%ra7VNJ zc4l%XQK@1kH-9(L{KTs?hfM;3}3Z^%J%mx0;YZ73N?nN`?61!at$8 zcz%1Sh885uk-?v8x`*aFh_pk2B0xQqhy*Kg}mfzE_D@-r&gOVHv{VlnIoJ( zRe{SL%O%D|hlxJhBeP;3iv$Aom77G&CLOFIGS6TmXG)6dvS4zT`U<)y)Nq!Of(M4_ktr#); z$Wje0NSMcPe`=CXc$k-{88KqcKzkDssKWip^4po#E6HIW`31g5PO5oiYN?@)x#&pC z85XODq=CygF?Fbqjus@0KK_|qht|I`THMKBH*40wQVSk)>e08H3y-Z!Tb~>)IwdyA z8qr%?g#`MOGj~O2({o~|c(`YE*6fF+rajgMFpu;tzU6oOI*+H-E)EynzBVkmzbBk;vn*F+DGS7=yzOI1PHjG5-c6*kidMjalvWSirB(RQ#eD3~}o29?% zAUbK92?~PTtlsT1qGq$4S@YDEFN0v-CDMwr>{wIgUapjQ6L3mH3lc^jIli=GhdSjG zXDctsnxVIR;RGwim}j6ZzU33HTU!?2&sOx<|2AvB-C{uk{VB@%cabdo)&uUoXsKqN z)3gTrta~$k<(g1^VP1P;%4gE`H=%mW{qiJko2TBpcc|Q3Ydty_?^{VD-&4nk7j=Dz zeSRnX*%za#d2LE|1QycoJ~9*ES2rYQ@+}~~)-mGc>TBA=%Wd?>X`yAhr3+33r zK0=(||BX6qexu>PL8;P%(0ZH9PLqJqayHw7=n}*HUv57uR~S zRWA#+5W8v=XN7HJH6&0q&bpRfNH~&#FUYY=sa=xs-t8it~#S@Cddv zEI>ocwq|FvOAXr)+igt@t(AE!$D8th?v7%4`>oXaJeZ`Nn%>1KMbaRtcF3@KO2__k)e3k1kEN zlb>(LvHU)Yfj|}ZN$#d98A}|GjANa*Ej9EO31iQjK98Ui?v7w21)P`l6UKw|BiNutc{TQtjM^t;vGsk7HLA!FseVkDBAJfmbqH#YKWAr&o1tPeY+-WVQ2n%*$C>Z#pMx-NMD zi|qcF;kQtQt(9x^dkVAX{ezipop=p@7hBlicu^ue>k##(*#|3InAx2QX1Ttd)zE^( z`vwl=TS{#*s*k@UqKhT8iAOuV*sL9dEx%kqM<4jRNXyz(X_Z<0Z4;Z{ZiR*xB#b_qoe5zVa^+{W&synK z2L_Uz{qL$3PlW2&LiIxfa`KE7*|5q>wyJ4_AG59D#^QXt8v231Yb2~2jA9Oh+tP(O zKdUROyO32Q=c{^11HEb0K4iko*QV34q56Z8-O03kz14kLXVAH8Pgd{Hb$VTBDz*yg z?=!uo)Hk_}ZNwEKaf`iBC2h@HpJq$gSTdGO*GHYZniQvbu@qOb#Q-fd{lq-Q4rrH$zx8xBQ}Dy_y=TkTymcZHu*I-m;$Tl3xKDT9CkaMMlUa zO|08NZ}IrrKZzLQpvv>4GdW+cnqG9R`P$w7u`~<0)>wRbIm!?fp$elY`TcWQ2eRI0 ze_=K4ppF(KUd87nxA*wzPC5LfkHH%ib~rd(W?ruU@1+R1tSR)w`&R&&%jqa*eFKJ|3ZQw3D3 zlKu1=C46+P#ANmQr0)9AtoLBA2i@rNStG?A5~QI8iC@c$s~K8B%xBWA+2Z+-5~|B(R0(TkdGwTZ%vVFh*2=IzU_O)>*q)u$P`3 zo?QSJIFsxS)2=v+~@bUVX0MD!N@0zYaa zGfeu_3q)UIQ&um$vzor;$$umLu50Z?KcCj3zWV}S90Oph@R#IX=O;D9q$0J&t9fxM z{w}IoG^nAs*;86yI@a9tlJO%1Jr&Kn+;vmG*6ga+KQ>o8IxSQm>Cs1TyQ-iz(T(UA z`bOy=-Ro&lkIh8X&aq-*U@m^t+u2}2BH=(^ef#$4yc^q`D}y~SgtKYgps<}LL#KYV1X8a`+xnyk*t=Je^tGw!(Qqi+(TmbBF$ zm8q`p-&mF?w(feWjPS=i%ujAwIY?r%YkTqJ+bz9ZYb$m%hmYQ>)oaye{2+C67dO4r zfD7uz7kkxw&ph>=-Kvv=Be^Q)WGeP#3v$@0>%V4D2Tu>O<$V-6-TAb7)YFHIjA=(a z<(ffW`M0NMESc;WCyBK?uF;coFKMpnJ=un{C)Drp;biku4>JAgG1aYHEcwvKi##7U zNxgM9j9kCyLlotqM;e{haiDh1yB9;%s&Z4+q)M@5(lwv|5dPyQ(jo(jYYTb|XZTx4 ztiAlpw4rz`8PTGl^s#iwNZQSAkLF!xgn>X+o62uZ9d~sgoo+XfgsJ2x)+BYcR&z=y z#ok_WsZa7b#F8dI?9{w|luWK1OGcgamp$Kidl*|lcAKsbooXOZWjt0N6MC@6JNl?$ zdu~#+RQc>dDp!al4jqrHc+?f;!l*WE#Q7KM(!~WCT98<4=S0d>ls(_^OSWpx+G?zA z>56JXP+gY$Q%-WVV+094P?Vr*YQ%SSv_~u{-lCZ7dB9>vc5lv1?ah~#1_D)hW@Svh z!I9nXGg}LO)zWazk-#%6?~ew9*xwD7(4#ww@mgELNwtWTYPDxV+cT&oSzomfS-E4J z=Cz}loX~V70Zm*qS)bwaXcpANm8Q5ZCrF^GN6FUYT=D}ou2CLITzNE{O&ZZ(``{Qu zZ(@j|H<ooBUXRig&+Lr-!t zBv7?+MOE_i%E_dIon}g+WTp13$Cm;8M*Js=79`3=dXi~zh14_Sd&zS?&Ak@8m1Dci zl_|rJK$X>@qGah;KW&h^qaa+Zt=WKPF3J{Y`teVCiLE;B<9{79=wLud7XBYx?d=NtMI_w{x`l z@elm{+(rfhRh~=cstu}+)mq(IAqmapCT;uq3%48AfT0D6Xc?EBuXiD7-u<(ZsQRWa zos?uPPM>dOAW-%5)Gt%U-M_T{ZMR8+)#^`2ys;M57PVq%L1IRg?aeIxQiVI!HLv6}Nvw1m%K|qZCf>)N6TDvW3U}QSK-TQ7BfnZW zNnRxZX(QQEuQJ435gG|p;Z zoR8^Qg{&>tlK8B;tld85LM{yoAQkTA)f>(&P9FZ~MZ61t)#MI{yLDK;jXKSj;waF9 zM2`Kor0K(8(rZd_T~R`B*JL&N9;MCHVg>?LYo_EO<0|zhS2q-vgkIW#c}+>C&34rh zXh9+|{kH0RZ~!S6UQQC{u4dAMJJM*+%MAru3Yw0n`;~6Qbh^5JW&Kt)c4su{Q@ENW zyobh7N%S&>arg}N zy49Fh_!em(P}RiajCyH;#v)17zr=sVsql4n6rawFDPtKKevIR4d! zc-gun;S*K9an259ua6a>+olcWNT4d^9U%uMHYXS3XUg+onl*q8+Y~@Ek{22XR5`Y< zP2Mi(O^Ssql*Fx*Emqt4$?kCoePf9EtjYluVcJ_Ls-&$W|2~J=owcljt?O zy&R7ZUisKRc?KPtv5mWC(x+)QhMpsVcY*YAsvCPnE6`dMuN(S-{)|4xJ*-b(<*{NH zH?4m(=0J16p){LxJ}M-Yc~pzE;Ps4k6FVSCT~Qvt^mfNWfhAd@cSjmm9avGv`(M_-x-pCvyI7wd-WDyehcRr5dbLa}shm{XaNOjV<*bV`Uu{3x z6x2%K?>6+!RCjpE2zjghe|(g#YhqR3SEl!ic^IBPNEpv~^;y+e;r?^U)7UzOI|>QB zSLBHP#7Ne#?_|==J;ZSB;`Lg{zdji%@0G^G6nU%~ri@{qT?&zRu4N1as_?lY<3#Uq zY^V1KGGgxq9W6-gE9*yE$*6zp$XPPJnzks0rIo5lj()QvNT3SeaI!wb3pszB7(twS zEhK0`0%Lf2UAfwecJ(gMW!q|sPsK8{Dklb!W(!>P(j&iXbLDuT-ga00)X?1e);a^o z)W${RmAIo*TM;$?tgJA0m{$p}qigrPiO+*KT2$?tdc136vOVpDR%M;5{-R|&;yllM zU7h>VQarUgK_>>?HH_7eHWKw^<;b$l3b8&b3J7dT?#5;G?(ucVhYjWQ8xITX$MUu# z3)a}lR`okpS)6LLg|=K$!9bwO*pJu0MvKh*H_0!zg7n{nRtvk?SznyD zyP7m_ul99ZpgyzcWwrgf-ST{_Up-7L=)6mtbt;acB_gL>eT$!BYWJ|G7X7P{K4FYd zyRBU(HZ=&iVuxiq_J8ymE9KgJGBDGW)-!bxsfk4%I zM^AlCn!Oq~sGsz4`-i_6UuiiHYfzA(1&K#{%jrKB*=zmgJd!c?ZZ+Fpa z(`tU-sk(tc)%qu&HQR8e?k{WhQKnT9QMm0PJ}lXnp#=%-pUkva{FD2wImEY?YiuA; zW$Z^nonLyT%QyMQ*hu!eR8wu?=sTvNA^FK(mr_Y9FRx0hlBtl8b;Xm$I22aL=5ru& z1lfOpXghU2-LO26+ia|-?`x@&Jc|}wH<6G*P6$+9b&0{6uU%0oZRcnX-*VdC1XL0cE!n;6Ll`j`6KAY_M z>p?jU^FFA;IUu=qx1p_YO84e39IO~xkifGkzl*%Bo~YEg8~?tujDbKEUJG((P2?M1 zGh-h=A!iMC+U(FAmM&F8w-zRNHjU?F)UAE|&bve0^s9q`Koy=R8TB>U#m5Eb5WhM# zWvIey*FGdr>s#PU((8R!9 z>2lMw?$6e#NXvYjS1x>6+t+;4rcQ?I3RQTG$=W*sQ}mm`bxCy97}25fwz|3O29rrw z8tMoAKP0WtqlubSUk}Q+Blv`=B3iS6+khBqhvlFZxw>kDuGa|dm9vW3A1zTxm4EWbtPH-L|- zS&MEyJVaninih%B_D2pSPp+E&)2iaW0sMOFEwoajz5-QvtRD4q&|aqZCw6w`o_8;y z>#1Ak(U@=1;`AX$ZHkQB=D(&o9v?hzGAAzPI(NEOgElPH&CrilPG_~XrD92^PVR=| zCU@zbuov~37oqi>S_o8OZ&U4mXkKy+e0FPl*{ZWaB}K@J%Opr|BG7__@qK=EZ5$^T z9H>o3w4vvC+^~O&@?&+VSX^Tn8Qmj~=ylaw_ttxoRTB?t`zw3uQ8Ln7@?B`ST;oDqK4=ei`$xQrAjExVF-`y%nYT#=Lh9>vWp5vs|CN;5+N5785J7s*m@J8hONNU#5lhLJej_+qkdWgKeT?~7 zO&BDJ96?cHN5pyZ&$ja%`++KKp{#Nf)IZ~}-Re=^H@I1-6*Dqs5=JRD_X3}Yb(bDMlRCg{POHR7`seQ);=mG6ol2z@# z$+*O6#TXH>v>bhXD36X28LHODHPXkOqa@SKPWHU`gJB}NP!l?%SuMjG2UU0s<=)!Y zJ;j)J1L=thdkqAtFp^c2{u_FTik(N(Jst-*T9Ck_F3;eFw&GHwI65}`lYu~$G2UG= z!C#zooJP4`RGB~(7<>oI8OSYPd8fpYJb7sd>)h&%_SeL3YQ~xZ1mB`nhHlk**R~-; z`nwVHu9tBvnuL)x$5s-Y1;i1Oaa7c`@Nm{<=W4a@j>!}WRN=aa91kAv&;DXIdh4?X z37Kl&EvjG!t09<+q%P> z_wVazAW(IAdp@$#XP^3W|6@ta`&F5pu<_>ix7226K_cmErdnm38yQx{PM(h?EA3gw zlc9WD%_ar{RW9il)eV8G)%YKIC2_NyJu5dLl)qZggrNlq<6K|PfXDRf$7p`Bct-<) zDx4{l^>^ol(p8Tp@o%1kSnQKmrgpmys`3l!hFM3Pjg((&?!ifqNmF^q!J!5MRr_j0 zB=pHRr9Lg|E&Fl0C?!kchw&obBN$qci0tp4RN+B!vVDw65?A^k(r;Fn%8&4&3@v?I zuT5HZd6!ysO+|vgBx?@zh#{f9cJt7_0}QP~mAdqusmRmDYOxdMHLrfZCeYE>w(@v~ zC$FkU$c64~_&W0HlRN>66tZcM>7|YXoEN!&qoQ@V;iNl#fS?gZT z2!FjYoPNzHXqdA^BG#!9iF;LsOnJ3UlbQE3yRqh(OQ`GnodyC`I47ql!}|ANugh+w zy9>@XtO6imoVnX`v7Q**D^_2yyaIFFSyWGzFX27Sm9Nh)0B5G7|dbpH_A@;%e!T83xy0)`O=rm16svl{KQ6Boo z60yWBI8{ykP)onmB#3OOx!CZWml4^wp~5=2i0XHArkqjr(JRWiF#oO-RE!xgmQa-8 z8AC;jM!D7Ft7jPqRN>fAp20K0B5$!HT8i8eh{pOGPd|~XZC?GK*1BPoiiELMt&T*Ay3t$}X_qLfuvK_O%=LGp#jyytq+(0r4KqbZ z8;P|~+{Nu0MR;^9GrVi@Eo!_&{E;ju#18&jrclJT6&dnb8aO|70= zEdHx9Pq}Y=JL4O9u)`^>uS*MZzIH`P_*ZfitVBURq(O7T`y5s1Pky&A%t_4evOzo7 zy(L2n5|=)m)Q-s2ljn;)q>rsu4#NHLL#>u|D+7TlV?XM~@8iyS<9Pb$?uOC!wQzs! zVg1|cj`yAf`y}(|CvW1hHa&S_Qg4P9B2F->$X0bG+sWeK!J1G07>4^V za9>7epWME^ykp6a41Y;{DE3gDrFGS>#Tp4zwXRVsX>o!-shG8sLp{1w>mO5BKeQ`` zp#_QHjJrvl{kxFnr_8%Wep!#;Q6*CKFU5y5Y%i)zy{2fcufoXi5#~KCu8y%h+U|pP zZTWD979{4zf4FG2FFyKjKqK z&gqAZ63s6LXorTjrTFCR7G7V^l-ClQkpDj;FIn41%xlGnM^q}o@crFXUE7U1Qd~H&TWR8x|kw6uW6XbjF#5j?*Yzm33I8R3l5;(?HlwIpOvxo{eX?Sp3 zj_*DGmiaSIwhR4WRR=N#t%gTN#$9u$CGbB)j z`#@yeUO$xGeC5tNKU5f6kib}5&N`>pW+^_S_|S(H3KEF_1H%Rw2vlKwZQcQqgKuv%J6xPDvZwM96(wPelu?v z_f&=%2vp&dNKx``b>UVMy?94wD1)4RADJB#LN*c=?Sz zSl{qOj(ImJYiejp)hLp7&X_SISALFe;&!f4%y-@t1A!Slg@*0W=9lS6IyR}VoA*H4 zi`9`0SWNYW95b9=EOpSY`P3o7`^bMXs+yh~BPJGgp)1F)G2u?qulIWBpBy@(H+a=~&79<&$czc|~F0nygulw;bo&QDy%!kN)9&W3u2x6-lf+R#g;W zrP#T|V;uJd<9=aWQI;dGZyw^yfD)`qt$Q4Q3spEspgKKP3qKZkTQTvvFkaIRlbYHS`Pg5Sz7FGLh2xG%i3 zUx?l-*D3AV8fU#E@2-1ni_{uKmy^dTJb8rhTR(``JMqMX1geZz`_x+fM5VEZc-RbE z6@M4$KJVJ-mn&S-Qa+k@k{>VHLd^3mE^byVY{Gr)4L&ek#1+tA<@D5#*!$}h)AQ(c zwwj4ATf)VoXTP{j`LzZg_`68UJ*l*lIB@<4Z_zthMGF!}AMJLH5gY60;5mw)B=}^& z-r`YElLsL)(A&kt zhd@73?AbsAfvVFl-fL+O8_KK{^Nz+F2d?wE&wGgVrg7xSy_4GFm2LHROPzG=Egl8) zw>kddy{8Tk5djwr1giFJjMTP|>aJT4b(F1Ie{D5C@qCcTc(9eA1qsYhk$dh2ZRMr^ z?j!E)3Z&R7q`wb}*P=Yy=q;aT$bR@VO5vvdQLN7GiX5L&_~gX5sGQ5-)}r3%$w-hU+kKmNQWlICnABW3zItP=dS*vR7u8kH0u32$+x$mx*Lsxz2xnf$oyZ?``vyQ8(>H7XL zK`^n!K*d%JP~n`tXY5W4R4@^{TPgdpJ1+~d6%z~2-h05f6tNH~ySux6*EwE1&u{Mc ze&7H2uzue&vnSTXp1tOdhCNvBwdKeswynL`xY)&$A%UtQb^P=Tj`OtO=NBkfvtEr5 z)BUy>N48XEXh8z|yvg+Xet*&S#yR8e#V!(os!#b_=!L>gX@i?vDUA)`{-T=k*66ai zBSQ-k=wqlyZEGW5?{XB5ZVM#>RTU=G&}S;S_2KJJQW`%>sA6nnSutYEXoeOfus54b zUpG`0-a#J1yWl#BKvj+j&bs+q6+Pr+G^NpSQ!(+dY9o2(1PC;VcIn7OUWW?&1@j%El!itK%(x0Tza*qMfJ(| zdeL*-m{D0+b5|0rJVk8FXv6YbgH*I2fxX#edbK=I)JS^D zhPdUH{0Dl81nH|SbJhS6P%;az>k=gqsFMAawRtOXa99amh27H8f&|V6O{TZ~8j6ia z%JWy;MOv|-3TJ?{7k;CZNN(GRm;T*FB2b0v9s1tcx4c-@Q02X&0}Qkvf%8Z5y8{Y| zupJF~*Zk`w0#)*8;PNSt(Rx`Y(fXh>$M5cHH)&WU8*r6|- zHas#ci&)aK9nW%mgEZg8@k&nP!qQNYeMt}1Ma#mlp0QoYE%DvTN#c8tG=0kfe}?BC zQHAZ8&L@@~FN$3$VASh2RU%N;ZD$w#eaDN|E(ZrvYcD@wgjg~l#5g)=E<+0v=)p~< z(PIaQYYA(N^Zg<`Li{d)hC?e`Tb`!38R*nR!wDf1%sL0%U zpY6*ijYm05BI8*-F*W=fLkkkF9BJU(J#WNwZGsvgwYr!`5+qo4cGo^~L| z#zq-CE;SMnH}i0`AkpRTUfRh5?)v#PSt(Z@KHH6~)tiagZq6JnNMOIAT{XYbM)$wk ziweu~bL=@t%Y=t=n^qD%#MwzJ7|sWAyuuNV>O5BoqlH%uk!!9upIXdMo6+==woCO_ zhQ05ht=Qq9-yL(xQd=!=t^C(B&1*~>g<_2Bg^K?2V`(Owz`48g(@7W zO{U0~eZ>83_gPrSItE&hz%L0*rV3?SiqVM2xiPgj!?_6d2|PtW(Jvkb zi=zhvRiDrr5`ik56VaEftAfS+ojz>;^nxl{@N^^gG`hX6(*}l9dvH zDm)uW^WB&*W&NFS5m)TGGP_P`>+fU3^?5H!OEF6@YKY0y;bm5JuO2RLY|f%0fhvql zVlqAcF^-*gm?Sy`O;^!^QB^)|9-usK5TsAaX505E`fjRG_w5LA_SJN0bjDqrMi&Zc zaTWUL`kq_VcHKf8jk{AOiLjDw8CsCQ+)^Jr+F74*a-z8A8X^&>dSddiWPQ?CKl?I` z(s*lqs-0auNpw3mhM@(C?n4Jy9yK1SKYefeBKJeRB5ZoS@gnVJh(w@jP_-asRU?0W ze6;Nx^42_~)%yR0iBHL!q|(K5Lz>QVZYavmR|*zsvwo{+K|(Iwu;1HR&XghIO!{Dk zsiUe%>ywJp{&sr%H%};6{W79hojoJPywYb?RACMBXXli5+q&rH^F`^&hrW(xMRtr3 zS3?h|Xh8zw@X(G|ayI@gu!lI*BuXMs^`KiG)iXuabA{UWEW4G+&aI7mh~-yzs%Sw1 z;}_9-P?X{0LYs?R`-V#dsxU4Ry%RoBmA7rB3eN&oi9nV0bPaXVsIvNp;bkdT{n{`- zzi2g4bjCdeRT$GrCfcuP!>{-j7pD_Dsc1n0V@Oet>N1*JTz4A-%6(FrSMR1CFD10+ z|F}vKtTyZ#pq{zfQ*-=Lk#aR?(pY}Lb))emjy7j{_-5U%e@?wtK8+<@n3<3MY9WODiWx|5}~gk?yG$0 z)$$^BTw8_~B>ovxP91uzlwP=)t=Co$tj6!1tuOZUY%CF|!q!7`>S7+e&E2}9dMh7> z79{c?%B^;=)Yq@7KPZh=IkIt=*Mq4z;Afj)+Iyt>_IGfs9AE^QAmd_v{XJLOc%y85&BQQDAKhm>Yc zP4%x+w^8X9Kf8xb-akZKb?(LRcd^{?m*`uvOS_rnL|?IN)ov*b{9QR$J^F89+k^Xx z8O5U|pNPc0Wr51<{q6P2`!`Xpsx&CXu5TGCR@{h@2vlL+&=}Rrfwd19BA%^0A&mw| zJidR;d}d-N{ZQe>l!oV#BKkFDgeW{ZRw7U(*ZH2rEPRoF2YSb{lwpZt`CuC}naZ6g z%Ke;Fab&?(i9i*$e0r}|wh_0kEiQaQrn6Ztids43la@8?f`%&WkMtgAW&lqNdTy*J z*i#}jk|e0#$7q z^iw3i9l6a(YZ?Ws6h2|Z`)UN4!IApqoOyzB=$-Ks<5PKUE3^) zjUU^WCt1cDSnjB*5udIcWS!Nq>uo0>CRK9aZ+>*;*;Zze2vn_|VNz!wX|Co;G!xNr zNIn*nB~YwhbeDA-=%{F$x!%yNtcLv@y#&2!I69VmD|k9dvp_@pWI|r8pwzh^s2!}aPd zb|gW+>?(luDhJFQmFnwK9A(Ve9 zbJV)CV=;+9Rj#J})$fxWl%QueFWR+V7{3*=Naghpt7t(2`#*i7_@$t5Yt});B~DXu z{{r9Y*k(Sii-QkYB)Zci}gT^J}PT{$GJ?r_~Ek7_?Zjn^N% zA0R?vuQA+DKo#~Clc~*bU$ynK08w_Ii6en3>}j-TS;t+ScD}0^{Lo1vP=$NTCeye^ zj>eahjm6r5Wu;x`a?LAPH@_=Hv5PiXa0i;^43BAk70N{HtcudkHWIi~Osk0ji;cpI zyv3a6_?YU_&kwd$vYfNcM*Wy1kjFKm}#Fwoma<}B^ z%zUk!iUg`K3Nh_+)|$*~1rA}^H`UV6f`mM(eVXE66dGSgSo?c%+`Grp#k!%lGUuzP zFNQW1FCxluv|!8VFej&SGo_-g-26M1i0ZS{zV$kaVXIv@T9CjtMl1TGquJ|mt;IO^ zd=i1GJ0I671Jg?BN&70$8#5aZ|JIzw{iHO0AF9uz zd>ko)>m!C1B+%zmOpes9e9G{P#>>=Li9pqgnl04E>wLB9Eo@`$%Ajt1(5mys-8DxT zT9A-Og8>5u@VXV&8n?PeN(8EUKJKDUcz4y>{QNLVBf@t$f6rzc6_mLQElA*4Vlw%r zjN>IDOBl|%rb+~=a3rD`^4<_$qFpZby}%JEo+XYY*t@6?z7OV(AG0d`)w+-c8Cb6#UPpJm6ZURdbRe!K!+Unx|F*Hx?;lhV|7ZW#P@^fq@sKWL}-+8Sp z#J9&56SMEY|}@DTzQ8_B4u3weKp+GrxwI z)XY`tQAl`iTdkZwm_`3ybw8ycj*ei?;byVtvAaZ|3OxqRmb8&<^G35M9_2211|-5q zXIHi*<=3jG+x?C()UE%D7$Q2Z9j!y$pb(qsy(>|iXe zUsf2O8%ulosCxR#&H8!gR_(Req?k-DH;Nfv`AUjPZJSF3s_=~fom*;NhWBdYNjq!N z2F6uu_%WY4_#;d<|8xp`7M-v4<^11(5k{C<-u+S!Faz1o(qT2W0RP=zZ1 znp02w$o#p=H-4~A02tAtPZiA_zmhD z#9cbvWxQ^fjus?vw4--kBL?$ZZdX{{<25uSP=(R|X_dO5C*L0Pk+tem(7;?FE#Ez` zC!iazxXY1$%zHya3lee~?zvm@)%8EKo4E`r&K-UkglXG$Cff5i+e|#)x-}Aks)QGn zREGmKRP~;1?=}C@X1v<B@P;iumRdO0nhOS~w3)JBc(kcmDFXL*v_0zRVspx^~>LRw4aD&K$%-H76pTv1cF&a^Y zG@TYPy;RqqPG$pY4ivcV|2FHArT3a)>e;+ab&T*xaVBy*=&P0m@n)~?8hA4d?o8kw z2c26o?Nj`x4B~BEQw`jWKozD(b82%9miDO!pLzV7M4$?DOSzhTn{C_7xS3}WSRYt! zSPHbC*z`2Z^RpV?{jRJ;pbBe~zGR)?$$wri%sbAxY+y}c+)24zt%zvO8@?&YbH2?Y zwJU7-*rurlKU7iAgclI4dirvV7mU6gBNEdbZpvi!YW`B9+rpMoHbJfcenABjK}Mu;|<5@%jwch}|?C5l+2NYF@}UU6R;m%Lq@v%fs0VeOlbH*)`O zEP7H^B2a~quIbB{q9MHOr!DH`W~CX%OTal5&b4UobyEm$`N@&>(qmO5P=%{7idy$| zEFU&y3Hv(Zu8#2%P=()l(&@G)6Zqd3=Ccz^GpuMq0>>*lO?4=cr?_8ceICf)jG_v? zHqGHKh3KW1>(;OIO*MXhj9=m6*U}W}c2NtX^|{e}QtrQ`evV&MVV|e3YwET){Q4YV zq5FmiwBYwq@}A`Zze#G^;H~Q1=rAb?E$)S4gf2Q&Tx6Nr;p|1VSuUAC6~<_#S;hn> zqvy8oYRZ9NDe@Oa`;v2&KS!WZHD#0Pesz*Wpi2G{Dc8bDJojJu^ct(iNb`1VR~Vn3 z?yk8rmUo=>TaWs=P$E!;>k>L`_;45xbS`T&Uv@xRe;|R;=II@Mu0ZZnZMuAKO# z--+XIp$b@gP35;k@yXQy7tDfyY8KIARO9ZO$3viQZ zN5krB?3UYxd#(Wzfhzo-o9;20>|*SBrW^eq4UxXv#;>sD=#ygrQXfTz&ttN2|ZqU4cs%&3y2p#@{b;VCpa>6CYzsJC=7%jjLtKnoIhrpy*2 z<(F}N?iKaJ)&MEiB*uZnh>{equjV2(bw_Ta{QD3ostSGyk6*~s7gb-Tu$5PH8IKai zO9ZO$rX5;s&ktg;;+yU|UA}_{3B1LC+HQjk=Kb-mKC$os=?)%L$@j`6r{7{_Vm|5< zhV_^3l|cfdnb8VzYZ0Dz`bFKpS6hid72bkF(JvMi9!jrFvc3)CU~p{ z&xkptM+`Gd1gbFd8=b9P#(1uRt@W}Ks!9Z^@ID^efxOd-=lR`I583F-(Siit|3k6D zD0WbEi2yw=y0~;R5E2;kj=tt_3F6g1zR)biA4qpQp$a4MQB>gj1Nh4Yr`0-3(xf{S z@OA|GCV`NRzu6S8Sas5>04dfz#_X4G!Kn2fQBhOp z`;HK3K?3h2p!g++hp1hfzhY4(`$_N5aV?9hNt&134B;mO%NPzzYb$sn0ndBk8-ALX zOb_KPTfWw(o%mZt0#*2~kiPU@c!qVaeV8@4(OtR`0oR=POmt=MziO zg#6^ALR+wjH?s3hS-MCodQ{1K0KW$=S2yj;$=Sz0~NsGdQseu?0*}+O-6`2pE|Sp3(Dz>YEM(0-rrJCg|U8URdnx^ zUNAgF?5>iFg%5h9jU^&-n45|btmLRe_3Mt(C&Yz{6Z$(f;np#&4&|!x8xIv#7_*4> zcjI>JxkgeX-IZ&Ydz7D6n9{i0h^eT;)e@cF`CLL9J7cnlP4CazdG)kTtusws+Nmks zK=QwlpX7Dao;4Me;b((I^hIx0xK~e0Yf8hJ2vlJ#EIP5}JXu}4VY2A7vxsN$E$66I;!6&N>cAVc~w-QAEVLvU^eF4KU5Up zQ`CulYpF?;#wUmU3f^RfccRf7te);{cvOgZ>GIB!vs7bsH>Gj6_5lk zrfhM&DPruE1A4%}X*f*Tso{Mrcoz)4cXk}c`bCZx*UnZjs3oelmRK?H&KWR7j4pXrB2YE5OgZ~=>}zYg*Yd00R?FreDpFJz zquBQVtF5&^&)rChvn|J`z8!x_FVbL0qTf4F>GDaF~wNMkf3 ztI=6+vLjsl@Ly-yLBwS$tNG3cEU3cxXEYam=D^;J2;~zg79aM`)l1E)t-Y>voUClS zSJ(b`!3L{nA1-Qln)-TaC_nn_uyp6(q?IM?s%FU_l%TT(?0=^`%C6G8iPLlsZTw_T z(L5OXr9DqB+1t|B-<4Fa+Q;pG?@71Cl>NM_tv<(PJTD@aGW4}G24AzXAtR+HV! zYqq}^%xXwP*0)YZM2BELGsTf9eH`_j^yDLV=2csrEuh;}Zzw17h@5WS4G86bsV8;x zWQ*zM1>0?#l}=>QGwY^X(^u54Hn+T|7Z?#Hx^7u&#g_3VcN6>57T>x(7TX)`0cacb zDyzf9kP%CC><3iOcH-8iEK+2KMcf|JGphO$~JP7<6i~QWJei z?*!#*L^v<8fTIP8{^@C&mK3Ru-f4?T?sK$+ z=-%eM+Q`2ZM@!!5Ec$^Hk?Q9hg;e|{lW88SCL&&}SFV_ulm-&*+vL%6({n6)n2XY= zu%e!DKANNk&#KMQf<)~Wh4t22qu z9_8CkluA3IPE9Gy(Sk(Fy%qIwXCu{F>6eID)255?tld*NF*A=uplapF>bkAY=M@%2 zbZydC6ncL`?bY@bLkkk2)jjp^Gb7c#7bA%9oHtlRHeG3X7o8*#sB#_Wt=sB+^n$rW zO!+ZJyzA}8Y;POTf<(KIT({ME>l0OpaLh4UG+og^ryz<9El6xwXVGsbM5!_c%PfxMZw;4Ud#M=9bgWNSL--^dB*iYX6k3M67HvK^&jaT*+85Um{Rd_uo1%`fEB7 zeZ2>YUe!x7?;#fq$pXJSZ)6R7ox7|%OvK+8MvAHfyX);L9W*2h5TisLy>&#SI$+v5 z+jC6qE4Gb3sm=*}X`tm@NJahO$w)QqYl0^Il4)DP?xI1NbxOna4gxJmOpYmSuXFQZ z+xG^S3bhk2^IcV!&o3;{f<&39!n*$+@`-_8h?sJ+srWssklLknIf+2k?L7tTb^b0s zi%wr>{;DS;jNNJ_QCpw|iJmlij-}Rqa<}b9wmE}p3D+Opl#(Gn5`n6;)DL#wUgEax z+rF$bUB%S}iE5L$mI5tEyxN$ib+<;U%h%e@Wi-2xPrMk{RXKR7lSH5@`d{CEwufys z;n@9_QO>m(n=m?1paqE;vzFU^d-rBuM2z42)A(9=fZq6VPk|OB93rALTb=8-YY~~033;diQ(1Juzo?O~ZJyM;JW?O$8-me(e7A4gt!-FLPRlS`vthRn`+k>T@ z)KkfNt*ZysD?363T9DZLdUtFjwc{+;8xnDTQJB8(SX;$CBwQj;wS3!pv(2}!-e&u5 zB6q7wYQ={M>bbcTffy`Eyl-^K(rsU)YG}4x-8_0pP0H|A9`BL~RQ2?8RBXQe&>l`i zT-kAKjz=-}_kr;OEl5;6I$E*$_GmXBB5I}9XN`Qj=;!I1HMAhn*sLfUHbtr)UbgzE zIDZ}UjXI^S-Z@5~rFrTSWdOC;m^@xO{t}I~#{|n0yhPz$hD&K6@w!`-Vyp9gu^yDh z#bLME{n1HkVtRjp79`x_Pbv+fBGu?Em56Ba^&9igD5PF~(@P>y6pbef*hsbM&+J6(SW%xh?A$?7``3{OR8{L$M77m<4i{S# z!4+HD@FjDKv&-)a3A7+F$giSmt8gNePW)Sgp$VgtXpR4L&IV=&VN{le8wmMH}I+Tb# z6~^;h`wyvwmMt*Qf<)>Hu6`ijwRcAmB1GCGUc;k_vZ(ARi9nTKrAGETZ)3ViGl0j# zL-@#Ur&PV8vw;>Q%6FuFEtUgtfvhcsLIbeQ7kCp6HXw9BKQ1&NVm8mlYF!;N_oNQC$F z5FYgHocd#7PKK5&v$)!y>fl3kZRIXhTfvN?`b?tQ?x7@ZfQ9X4K zKOD73Juv1nLkkjl^H*0FQa@i(K_}ulMa3`2J1VPQ-j@hejhI)-Ugz7KE)$WblOJ~* zl&E^TI&!oiv7}uo)syCMH}5?sV&3BRe0!}{%CXCZB?48={_W@QH`!J!F16}$x8o() zkjR=GEl8aHE1SK}BYG56=?$+}V?HBxxL)RiJ4Xu=tIT=S)8yMvzqj2!lyaaNFLvOR zdbE|o(Xy}ndu1i9)gseMsQ62C^Wxv7cw2S7l7DYYDGekpwY_fl?V)3B>%s0n^72y^ zuc&p8`Es-%F>duKR@QGBA`z$>Rc?x6^X*qp*jDuMW7e}< z)laEeW{u%!LBb{6N7+FAd_@8$;&Rdyc5B!g%ZoabBmz~XyZsr@cfGUiOw8PPL7g$A zJlilRn4<*=dG{l0|H>?=@*MsAn_!L>B<}e|SziA`Bj*>}4HXM})lgaYBz5obFo{4F z?wyc_`)4al3Gm=rKxK|4TH*L#%KR_Us&`sp6-$A>P*D9?Gxs{Y(1n^DEl6Mvshi&& zz<%|q#rt%v#nFPq!{I+HzpEWpyZV-*=NPiZiG2$8;Jb=?O9ZO$nM@|Xz$PlcUx(j4 zT92aziLxyUS=YK8R@YUxwcS&v-YC~j*W%x^`A7t+aQgb@^exnj9@i;2sdoed1r}4YpO~;jZY9qhzRmoYM9Kmc&W!SfhybsGMUOgK4Ns=;=w77eNjpVI;EoRU!Ex7&ZsR<;(aMFR1&N~{o@k#xN2}9&JSHzX zcTG-Fz|VvC4$dkOsKOl`iln2RWd*Bxanqin99ti@{K0|OmETvR)yrROJ4MZ2pJ&yp zc=9V_i*mFef!>a0`WJt(Vv{_0`|i0U0#)dbX}@HgGhbA`4nG?4nqeBKLO)NtACsJT zyT-Nn<4&I?0#zx~vZ`~bm8@NPjGiOfu>$Y;vpP@6ag3n_i9His)Phb&)%kCx5)tZH zf~U3il^s&J&HIm6T1{M`0hT>0EiB2b0%2l`fj zRW1ImffuJpY!ZR0Uxmu4E*^)~I!C73=04MWxXV@#zWK!&6)i~MoQS^jy5Px+#(MDO zzuFl%oBSMBTHSv)TK)C1poTL*8l!5~=S7Cp=Avt~jus?vc1C9nBE5LlbT3}&N}!Gw zBrv!1omZ3(&uyv2ul;4xkU*9E9PZ!N={qLX;#-G%3tau??DgK7{YbQ$cBZtBD^`k0 zUaX0pG|`hcs!&g$1&LZ(w7J>QXthus+g?Tq&s=J9;adE`CLf7F6|O%i3e?_M)wzi$ zk8D&|paqFXfi;vzJr1is^Se?SBiH3(x@%2dZb5yCKozcD$+utW!zOO7!AH#Yl-9LK z;QG^K8rW|hyYsO&Z!o*6KnoHf?N%zosvcD%Ryt9x+{3o9mNPwgQqIZ}fht^o(#}M| zi_A*nV3Bqu1X_^5)g*bJODSwuAx~a!Y$1UbBsvwisq9XVR*&p^q0#-$SAMX@_dNIw zEtf=~3fG1danYKGf9d4O7cG8dpaltB8`6s2I}eZOSciK&erBKriLg(O>acuA)xY-H zqQkuXSc>cSYVZSN6D0ywxHhD>`T>_#tH3G)El6PBq%Vu7RN)O%J-M*Xl6ox? z*t;kyS8fk+_F_#Qx_5?w`>oA_OX_~$SZ2{->bYN z0#)c;Z86CUijVK>@B%;57|uU%UV?9d=uE1qplC3+CjUFSYWqNNjFVQlFI)t?qx?kkas>{xSYs zO}_s042eJ$z7e9g+Cg>2bDCK`J6u|NSA+z<(=nM&dwYtz$zJ?u?Nk*lNaP<>MsM@= zi2Am~Zu;K&Z8fv_)vPA}ysVi-pbFpVP@Pw{h`+f9cYWMIM+*|t8yb_#98Yokjwc`Q z_enXdC03@J#(US|t0uV_NJuJp-$MKCah6Ne-!RZGp}(@y+W%#cRUPuZ z4)wv0A3O82xhO*H;#v}cs=aecsAJk#Rj)DCXg6x$+TmQCZe<^D&e3t6fa%G^zKHI; z^N5`M`unodYzb8p2NhETwp-P4{Vve_D*AanKCwjxtJ^QHG~dNp=8c=T=na=uEqg7U z+Esg2fd8JDo!9(XULsJn<4&yNIKZm*=xDp&`Qy9qtjXL|wy0ZUi9nT-dzn3r*_&){ zk-vovVJE9cu{+lWOEdka{i`TD##_~I$BU@AHnc?xU1B(C*ID1K{v6jFxQ?nnAWTy! zSA1ArN+W(plF?y!b{^QbfkdF{;hO){|colVWF zucF`mIy_K2R4BW?WwKTMc;2>K>;JTh82)$>ZUq;$y@d~OCO0qRsM+|Ew)@;I$oFv zw?mg%{wGl^wsv=cv#{jwU?t?eRjt3pS;u)0ttNW+;Y-7Bu||zLGi>eHUJsWlsLonr zRn5h19`1cb4_>TIhWw5i5oa+ujCmTK2ke1 zILE>V59YXkf&C71NcWg*YoTxQ%g%$&^pFTtz1;uA`tqAqJv*s5^{Bl=R~YUUQdr1_ z0FFHsRoEYCH)?Nyk?qw%=DShO6%zPMbjE9S8&R(KHCCrw2E%?ASTmoVLT#zr{bB#n z2ZO)4i4R}1^U0+PaqQ1XV{fJtbnTr)KXIEiDCR8@s6wwq>+@$F#EDtwnUayhaI{2k zfH|bK^U98*(bH^v%9&{rfhw=HMf7hp;?(JupIUo?QT@fj)wkH7WnWZG1657568d$k zRjpU4Bi(Osa^OG_nk|`?bbhAeXo;%Y-HPbt$cwg_;XvN!hwn&{rQiXUf5~ALRY+sn zGz*{JfzCPS;Nwj*4g4+iaOnAH*SS=X@XK|JomNwIBv6Iv(Q~vaAo8C}W@c7TpdUOQ zvqP&yb^b8*_kZ}rhwEzz{a1EgCSMkT{uOETu5>rxiuU4S`wZsU{-S{dsxXJNR!hq- z`qn?s(hP5bGgkD@m_vI1QT&SW+BF-0SGl1?psMcW5Us>&tJf;y&6qgJp@We*!@!I>rg60K7GeE2B$boO|sv%ud)6+WMQ?)BU2 zEb80!d)d_vYn6gMe<-%U%lRBu7XQp`|NqvcPUL+$cPZ#?J1HBRV2MTx64$2|vJ(Mi zs}s@PC#Rj*d{!nBMtxrM#$emv%BMhg;qM_jhrh)1Peh`4al(Vj+dgyE?V6mOZbxbBuMQy=O{AnfSOMBRBP~B&*^Z&6RLG8`{|Fm_smRRaZS9=X!ZNp>G zf`nz}cRO*a-+m%8oJ-k>0_ATXM*>yUZ*+TooUrw%TMv%ei52g<$DjoX>YH}r%n@6! z&7UK|{v7WgEjx}DB&K*3wiBKs3ledD%&9%LcJ)MGd=v>(k&m$x+wI<`xSromjJcG2 z1T9FA=d%;*44Y3ZE=KRNrIC8C)?u_DVJt3gC(d;nOn!H$)5EGZ;@zqpb&)_-X4>QJ zo*LT~T9C+mzI<0~<44c-MeJqO`G9>iu%~XbAn_pVPw7_knNCErUFYoWYVVi@$MCmM zm03?t9oG`!_PvumjXZlS(P%*;vn~8n%pCw}P zyAZplE>+ma8wpfpw)Ph|YInmErJEl8C4*ZVXdTI4^{Kmt|M|MhUUf(jF{_;T}o zw&w`*a*06;64zWS*Rl1%0eg!QktHhfh>hsmF1rs}kQlv)**$fG7PdF;Ip3A4XCpq* z7d%LyihQM=_>-$;)o0hY5&edhi$Mz#^n8jfSAWXt>eUtYT*YgaXtW?Pv7E1+X!OZb8NTG8R^8<*Z4ne_tipMkNThALKXE*)mFN;87IAI7;NbA-?CV~ z+z;beqFCp+A3(FE<*7#Ohzgd2bT=}#SKM1bnoe01U2iz%{18*+@>q@*B;+*y`u@dm zNX=n(IM`cC19u0JHkq6@Sj6|aQ9i$_dvNUAxX*z54<=KMiXNhN@vml2iUN)Vs_^;f z-S(i8V#B0`-qUxrl5&Lv=8#U^FX|`823TXV4tT}TtKmKa=8*1>y3{0tRCk57yF zmCq!lfx8q))7$VO;o?EBT$YHWC>_T>BrvzMhincLeRH%j4<7kOB5;=iX*z#3VURd@ zwvwg!=u`toecb)Pv?cypFu#a@Ams`POpi`%O(`pSPM#2xwyl}KITh|q;NA&6c{finC}&Me?!EOS z0#$OZs)wfCqV$ z%OU0-Jjc@c$4Ds+{9Rlj(f2sh<5{)a%Ph^a4wnd2;VR5z`Z*{&k9fH>=6W%Ii9i+Z zozQn)ORDhWF0SUr=eR%%61c}fx%yO>k1FVHNv4|@(1Jwf=<{bK+@f|H9@INmtka+p z5`ik*d9g)k8pvB(?wOC)x-RX>Ab~4yid$P<+wC zVZ6|zd*(aUzNtu{N=~Eyl5pOxLQkKYPe-U|K>~AYJEuR8j|ni$-MU|AnI&qQP2f8g zOq;&ciyg#sEE!>*+A2kQ4kYlIXs!0VE*~-3)iS)47suHPzPrJ+X?^~rHFpm3h)pI; z<6bBan|FA7Ozy>@(i{#6d?p&Bj`U|cD;BdTb0$j!s<1@pCgrgShQIqZ!@Yk1zp{0W z+5Jd5TRFLddi8J%bNYvL)|>t=vukvnAEVOQhp@IZ3wK|$j8&?6&p5cFJO4Cmr1y;) z87v{VgWCS!{Mcu7gY1<5i)eGJf#F{DgK;atpX2YgO5W++vslJ|A^3wF;;zFE?OE9# zyu8?C-sqHWPXlu(=W6}nWyYEj_n6bB?!0Byg~wM#r?GvMtImh#oAVa>Z@Fp|Jf1br z_J$4F+MO?Iwk`ubORn}W&&iLa-Bfn&>ds##ZHj%G@+Vi%I<1PG zTk*f8@nG;SJ?jh#K18K^!gN!6^Ih0Wx7yP*ZB(u}d+C0Eab0T?8TW7L{+^IvI8@$d z9O@UqGx}Z8hChn46B~AIi6*m)ImiR()0!B#zJgH-CxFuBK|z2(RFQG z%XoXPkd}!(&l?z{-98v8iT)h(jn>Xi z)7T5_)dhc!CEBP?vevcfpK_ClagG73b6FoWRP$v4iK4t$D8N+_kljH9qEfe)FB(uj0))=2H0leSF6y+|pD{pF7 zm_wPU>s*M%c>OYXDSv)&^$o?=u5A7TX_+{CsDUx4+y|roMSqUvgS1RExt>GhP1~(S zSM0$XJbS44sd08+f+Z~z#~N=jBIaCRt&thMJm9w@eUWy_&VG#K`jeM}1!1-tFC zpid0)SnEBqW(GU&*Ga8ZDI#`Gdb<5@I!#sS6-%DcPaNEOLqP&n=ws-uWmGtKqnI+y zXKb^e1&KB>yJ8EM$V|iGNH`zUp`8Ax?sWwTRH5giljI+qc!O_UxizSta=P@|<8u$C z+4GAY6=~Y1uIt1dg8g`0;3W%MkdV_jI_Rl!_DUf49@S5&QaQprD%YR?8tycNOoVqGHj41hyF3fz)#d#iNI~ z{Lgg-2~=V0q4Omlop?8sANPG%*@8Z5<(x;F(mL)>ZO*=-UAa%0rySDx-5$<-jp)id zUI{vkL`zU*{{834&-8z4Tulz-4F>g7CO=Ekj+_7dE_&@u;!hff2EI0$HSWiUF6yVD z1qt-v^i}-$C&r7S{rTK{&peSpRp$Tyl6QPee^!!Bz&s+iILfZypTXu=HGvwJiNnW*1KPSksSCW8ZDWB+j4~? zHSOXIe#&~J2a3@m(TWx%!d^er_UFpXRm;{d*~s*M!s7Ne8VOWo{{1IctM-TU2G?8Z z*B0$PjF!y5ZMnkv3Z19k8_o}|ZL2pKKHiKLB(NRR_qD;_m0o?WSt4$S^S?*jHDAk~ z&PI3cr0(&2Y%cjSogHr0NyXkpYn-qXN{_iKl+*WUJp~>BiM4qG%=>#}u#~p8{$c)= zWW5xz-6!H#I7eUGe4B{foIitQ>)J`JIrw#Km9pu~jnbx*96gh*TSIsKUj(YqSJDp4 zr1O@elefmiQyS=_@R=~T6zgTgdyAK!qNIJ5p9AxbUV=_w&pT+jHEO03{wbWJ1&Q#z z3y(Dzmchn%usug&pZ8Yhrmk9x@8KNti_g)s_L7)s6v(1wCmZ2YXsxx#su$MfKVKUCH#QiE|z)nSP9RCTNW6GtW=0G=8ra5~#}j`%fR-_0dV_GKvm}d|Kw^m`Kx>nTIt8gU*Yd&{%y+%&V&E>tII3@_$#y^L4IDz z9CuBoSQ?$9XmtMCIp{F?S|xK7&ivb!6PZXD@5~dTx^j<+5f-!{firSC{kiLjA)Na2 zYdxNMqAK%m+jHQomLgNojBHm1jpv2V$D#!ZY?Tz-(DK1}eZLQPs&G?50#(>T>1)}1 zQ?(5v-o@Ve9L}G&d9I!98^@Nm>!j|Cd8D0uCs=9nqOx!Q<+DmTK08p0d=<`dbPgJJ zLCaA)j#X~cN&VXHyw-O?96Mt7K6Rg+u+;izeasc|qDY_$$6dN@v)+1TX?8~~{cSjZ zuGEJ~szMWJ|1Aoc(4zjs+)XmLYk@`7~O9>|*YfowQvYkc%`8l@D7_XgQb-`yB zc{sEnfhA%xt)3iY^{L%M`GqU+_4%c z)Z$j1JT*Q!);azXouOaQ*qW!*DlPVdoCXq@Tk3=Prz)vow=K3&4eL4F;j)(AG@gB@ zkxZ`h<6r)T4KDbD)os#3ldmQbT9Ioje6a)hCZca7{}4l zYejr)&7Em1PmxY4jy@DAJvmkFar}iIdn1J7nDKGXMf1zpG}f<7CsiIl4*qm!CkFMf z^xY96&3(ET%xAeYJ)K?7BF}>;#;@q2A0PKZO@AFKjkP1%2gI%!n!#3-@1)`=N~i7< zf9o%NPAM7BLpfSDcXG6fwdwX7iSYU8?W^-XV*)RwXJ0f*dU7N_WGi6Z;+D?l#{1G( z+jZJ*!;xM;{@6WCB2b0pMrGB!s+wik8pEf`IF6<0*8i=~i*{*DT=Z433@MG|x-9ZsOaG<0tgla}v2!(iRrxt$3!Pz4lR6s_{{(SNed`4$%laE>?94!4 z)whR(#i)^PZ!vT$qV+zz$I7v-yD44~)N``0i=`Ir*3NI@E6sf=fZ~ZtCY^FZ0krwLdd7sBAXkC zsp~8Ap)PY{0vo5X%lm!R0^jn*bSad^BC~c>e^$$7R%)cNrcZ1{ky$Op5!WMzsdE;N z<%7Q)_NBBn;anOU>1o@utn_s=%NN*ARQ&#%;cxx8(9N9PHI2<9)zamM%s+h7SeZy$ z8oH}LZ=Pe3(xfAOht<^0a;SMaYgfxxomkh?^0aTd{dWi6Eo+%eJKk08xmvqH<(FFD zWC44#aV(#o0c9*p3T3bzjCQ#X{5CiD%V3F}Y-!|~UqC!t?oBsapJTml--~S(p2ohv z>ZsN^^*(k3wX1{o+N;5HHpfXv9Czt-^p5}aq6Jf~D>mPL&*s}<#(`GH1^ z&77;*<_W*%Jy$Zv*S%LBE14dZY7l%o<{f_*)1$bx`}$k;ihoA6%$b)h4Y#c~mCO+? zu*hX4a}=dFKcg#IQz!iy)sR3H&b8XPS`-p?1D7_QbkI~hh)mg{(!5l5)rYC8C&5LI~**mH7Sevr5y1vilg&LecOqwmZ ze7~*Li;B0Wf%7052ZJjZ&r26lD@+XGXt}xNrFQFl9CP~RtI9L7Zj(dx+>?qMC)bB? zoQr&Jc3&G3DOkb$)Vp97MbWJE@#@MqO^gO{Q>2+Cs?x&JG>?VC{!Av*iQzrjy@v;l z&=rFv0#yZlE^6(P1?v}Q^VH|+jbdN!bTRy$$4UgMa3)JTUIo^$`zLC%xZa~U)sjurFVt&0wNTrMQQMg)$HYzEeT_&1ZVx=O+wjsgnL%Z=D z>)^QKV=ZY(Jlmh-E6rr-?UKh?qi3#mx=+)g(%Jy)qqg=^8&CU;vBzy`>?$yYZ3(zy zthm{SWBDw0NYWga2{xGgJeC5TZ_k~<78w1ErC0iJEE7~&%n6!*gE$st9|y}O{%!ou zrWw6z_T@cOpK2B9-JNr>PO9(nXPU!hVIM8A{^_oiule{)>mFm!l!hFCcm9M#t>h)a zrtj{k%4rNbSC!vBTU_(o&-jtU3EIhNacs>sUp2D887;@iI6Luo$+OzE2XX8n%^xTR zMACIue#Q;<#k)1fI=~tXs&Pq6+aCY#sCK%`N2BTH6lUJvlw%!){`Eq;H6fl=T-Z^q z*8HWmwO;(cGk``FN{fTVmRS$BYss+=b_{%@&Al9F|J_-UueDJBIQxi7!^ab6)!|@4 zk?na0-e%=3Enlc$A88gg(dW80+E1`PG`?az(Op8{J$P!?Xyb1OcaCL(Wi>tJj8?cr z92-XS2RgTO#e=)&ThG$pSK;`ySUyqv6SURc;+e@l)}EVJfVcekk&XCVRU%L&msRSl z!CFRHGgpt-p_%3U-27N@72&l$E64Wwn_twv&Jy;S43+}TGIsam%Y1vN2~ocowhZ6dr?ooE zv_((j*+ctmDJ1ZKv9(@X z-mnTqKfU!x%V~|bw~R3Z?r1mO#j^vX>eBv^R{Bmn>ujI(ZM&9FTpB5M5o zR2yOZ$qBY&+L?ItS>NyAFK(opcw(1(+LD37KIhboJKEF6e@YiyC5`7V!bKx%jCSu` zUS;9ZI@VK-)7kLXE!FBfYFo_>==(3P4(h7rp4OF9(^>9>medl*X9*XkQ-}1S(^L&D zNThjtS$)FNS@B`*iTK(kOcc2HOnqIkfJC6G_nMm4H*x9g(05xJbW55TKd}_s@HL5U zHU3`<5@FSATJsv||JFyR@ngkc%SQGyud{*#s#d?PU>)-#ozq|3Pim(yVPwd_Cip|^RVnG5`qC-*Zea{T$Yft0t#z0Y|-7}V@ z;P_ayAW@-VQR{rK3^sPYEsa*MT8Om~rMY>|KN5keZeH1~f&DYsyA<1VxV~>Dc9tm1 z6BDah(1Jw&5!tP8`(M4)QQ;Nvm9QZiWOR9jhX{npONdTa_GT;9Wi79@fu z=rJp=X4c2VCRO$4|Ah0odB#fws#f30=JP8pgEh9N!RGqt;aWJKwb9*z79>UvdFwrq zh(Eq$`cDT|)Edgy#f*^%RNeeq(yU#|e2!riva`hNp?u0}cMDpOI9IiV*^da@`jggb z)z`6FAII{eH^xZ>s?yB!%&}SqOS8A#OB>r)#~%O(6~Ekec`RCxsD1Li*`J6%X{>6b@Vk?1@b^0ni9prz zlKCvjtumN?&(^L=yHWtmh1Ge1EDkn{533){p;(tuv3yA?hCZL{yTHJxjKtED@<@=3Yxgi!HK* zl|6FdB75~b^UQ5glu&3z6lIIZzW?qsPtV(#_x=6Xr*pn@?l$w>bMCqK zvQPpE7!RQPahL_TWu$=Iub0ROs1?@Eiv9Drp!!?#{?%MOfgcN{k+ia{6G@A zq7w@xkhnQpWk1bG6ThUwmT6jK^&`@Bm7Gm$F_AshDhV4(yO zS$$h6j#5JNTPqgs7S|tJLcBd@$Ox!qF*8EZx^sT@Ty$I24v}lJlq?C!b$}8`^gg~x zfhdtBjopU3-iwNrKBVKS#w;YDR@9Fa#khg_(tMuY3-IurIPIzrSu(WT0ZJf|Z>}LU zHhQ0LLwwHlYW{6@JfYfUT;JCpNkUf4qkBw%!yVR$Ewe53a? zn(yl4012qoWRbs7aYVn&fc!dLqJR=e$RqT|)8>%-Wp~6I$xW1yfLa?>R?4n* z3#xx>Hx8Uj(w}9B>)O7thZ0CcoU~Gw*D0v}UVZ%#NX{*rEM5%j>i`L;1${V%X?-b( z4EGu#ruMzWLJ1_G-zJYh!YFXUxN=S%fow}7kn)JI-uZ5-J|>z7B$Tvzn$Kw?kTBLu zW8s&(i38U6;ZHUR!BEn+cP3|$nuPAwcIV|0;fm#cii(zJK~UGCKvf~q>Gvv*)571kWqw&l^+TumgE-tvdW zJRd8zppjHaKrI-1XPBN%_M&AsdLd65NrmlW)vtn6M^e?XPyh}GV@Q9X6!4jCdxU7 zM(%#m$Xz!Yxq~BawC)p58@b#5ET5~1-1R8yL@Gw6iSg!LF&s;6!)|ff$Q|q(dggLx z8yuOt8=c#0j@$RBL*KaN?1agixb%hTYdk6hJCigP_n}EJ*SOR z%IlF_T7b6GsM;07zX(cJetgMkqiXlJmvZW;nr00$q|%VyEN6+m$4K!+&g*irG%{#B zk<-S=U~e*vZ`n81(mI|b^D%?H*W_~A7@5CmKBtY5!M>p@4QD#wta>f+Cj3)YDp-#; z_1|(e(ZLaK!pUd*`@+)JI!ZTOCDdPf8e46YKs(Z5|cjo@2Os1oe2CZT1w)idOS zpTfyJo5R9^iH+@{1QJKbRS8uWlce}1z1=Y7I=CNVC^F+xLbcOpGl zGkm+>HV#T4@jakIIJ`5t+E**f9!|WUHp7*}2P&Zi5+4>-2=8|#SI4h5)#^@;$8mU$ zZG?<~TEBL^7U)jQYX75Eb4!BVG4{XK$pK0r0VB8cxAsjvlA|{ce~efpBcRsf{^f#e zNpkhPz4IFbvMP2Q9_KiZhIw<0W>>4aAK7co_k;PFyb*w(F%J|DS3bmo|(N@9Yf1 z6AVLHDEV*VR^1d7^F}>N(%KdYPOrkT-GEgxLQQM6GRumDP6jE}zFM7E-Br(TgyVmU zwkn_m5+2Wrgjtl(#Fh>$_`x@85r&^_za%4|)~a`T!fc%sbVnNP;Y1&0T@A&%JN00p z1QNsk<_Wr#(Aa3u^9VZDc`-h)ZjOwQX`u%~?f1#mbL!k>(J13JjjDy;v1fn;5(fur z2#t*seLkV#L%i{z7abiS0kztd+!NkpYu}G-%TMUL8K zV`S@tWCYZTkIfQRoz>REd!s2nGKTJd4PV7V2_)oeEcQd_p4Wj^xX%}bjDT8ZO)>?q z5bgW%Gspo?nbs3$kKC?+5=hK3$`p7?XxeI`(FE)!=;5~`GGqkQTCRUfFd3?S&!fkT z$2&^&apltCN+^Luw=Xw^VNS_th~%q%eK-eO7u`mum#vc#Q0w)abYUA^o6@X*EIT+OQ}8aL#i z1QH#trwO@qCZo>%p{uVI;kZ*-0_SCGC?lYj{hBo4-1emE?{4&kaQw)~mCyZb>;NT@ zfa`d4$9nxRTrumn%DQB+jDT73MhfZ#(!zTWcQ@% zJ}(Phf^9SQh{Cx$_K<*De|y{%ESe>u2~yPH=a41z7QTJr)vnECHlS9lwT5UXec@sn zd*g$jK8b&>>MEcF5{qkP2ocZdeT!0T>Hc#sTsG*t7{03q3nh?P{qL6WyM&+MXYfI81X{L;TT6^Z)7UrKP)#FO2>xMtNv?Ps+JFrj!3B8%Og}@7>x~(D)bj3ad zlXnef$_S`+$@#7jvqoEwja@A9l0lR>(1nE(NWj%zhIz5d5Gzi2kbP_B$Ox$Acrr)G z^48Yl+&v@wul-nJe2uk(5=bnI$q~*{VvE#+=6z@uapU>q-1Z%bkbqk0biKB*GpU}* z4DFSG9v3Yj9`mzVD1pR$^IXBrl~iZx{od??R3$;gOs75v38Yo0gF3S$MB~rDBh$>Ey@si%KYgL}KSxLh@o#9Zy{U zqb+%KYAC67`jU)*T8>@c3T6AWZ&bUUtx2ssXVOWGQa~F}tIYPTkiDOv1<%y8rCpzR z5^&gxJZT+i2PKfmOM55i-_rK^slhz4_v%KH^mi+u1QMIG-U*hJ(EJW&{_r6E7(-GT zl`12k*3B=Kg5_U=n2+jGbs0R4L`Bpi6^HE=Py&fVUn>OzN@%{Kt6Tb!0ae$I+SZD)kU77n)m}EzZcj{7g??7YJ1rx-9{YqIV zfdur(=n2Yc%Sr1!3&lG#+!as)3Fw8<4D5H|q}Q1wVY_85CH>{(wBB9j&NNQz-PK-` z#yz=|RJ{f{GdG+xIN3-w#6DRGC6JiDD2-Elcc<>F=ev_fg^_KRZ+Q2LQ!)Z-HT{{! z*;gf1f7@M0hmvN;Oi<^@t1Og2Vqv{>uErls^Iu4QbU%w0+!tg7)LPW%CTDD!Tpbs+ zJ?~F~=3hX;ZS*-PfrL?qTU?EIHz|EOu{vIY5|{}BB%oGttqkt}_wJws5`BDbb87ES zv+^@9XE=!|)_8YN0tx!-!_|0qG0m-sakf2vGHsm_5>U(8{VvxtS?kxHaY!&yEIi?L%q9{#}{P;)Y>^Rm#gvaGX4qTlDG5mO=72j5=iKAxm*PG2Q{l*a^ zgm~)R^`hP#B%oGs4pZ;${?Pce&YfaIbf_(b?MTWdzi^ zweTa?je2*Vf2&9PN`rYgs9!p|vEW9$Oah)ei4wypacjN$*8rah#xrwRWU=qO9VDRE z!Y)?0}Rs8iY+hhdPGMHV( zE!?X8#o1XLj=e@PeCqiu7D^ydO~wnpyZo>fLp5)OdGL z0*RQS_gsy4_lG!RQs26_(1CE?A9j`3n8ltAL-yb8{e z`hy)MugLdpS1c^YWZRMJG6HJlAAZfHtRvMk0E=T5*f@9qsn_s@0!ko3qmi68mzUk`(FGR=0o=Ty^|48>%8h7cbNKW8ZY(dpy%S3>C;K8JEIg(0*S>} z?{RAHP7^iAKh=h;ygZClx+lm8sI@XSi&J}dnhd%VH`pAvN0*O}_HH7B( zJZ>N&H=V2qdE46o5>U&dNhVk0-L0YCoip|B9zEWxfD%a1Xe3wT-HoN*9iwAN-sL38 z2&gsv#Z9ipyDJsxR(|AI<#lmtw7HCcT8WYAoZP!B zqTbz=f!kPU18O-gNat$2yOq?ttMx$R-9ZT?phre;=%(J?t^Er{2R9D|lt2P{Vf3Dc zqha`}TeNE0{U}ufH#_CoVFl=orJu6*{~BkDZ3d21K>})3-fpR^!iCi~t~GbVjY>?& zINl6F2_%-)Wt0vy!&GDAZcBSSca=4{U)5a&38*z?U9qC?KCO*y76#b1!x%E5i8+E2 zNE9wup}2EfYh&U^BlOEGm^`vEL}jK%iKs3$qEEKUD2%<&6L$-t?RVxgCZ z3KCGu*UgR{PDhMplz4RsA!5Ql)$p$E5R^b7t4lxjHXU!89^CJ>l&q*9iMB@W=0%sT ztlOh}=~+;#-kZ+slN_2yCXFk-`)*|Q{3iGu(nE*C8?%kJY6z%xLZ4yP8Cx3SwYdQ) zHyeWo*_oqh8;c!MhiV9@)qUa$2didUV%c*e@qNc&99dBhU3^gAp8c06l?rNQO|74} zM@LJTWjDfojYi`^g@aY--eE=Oj`=8wzB%X3o+$P?YqVb8WR&%0YHTyi{_Fklajqv; z_c@e6A~e5+a_n)fjj>He;F;5!;-#${`5^gcHe-}1bR;w~8?D1pS~ zF1?jM==qtNwkp4{6lYI~L@kDSs2~Bg;1`3g(RT^KSFZ2l2XtwNpac@|E1~)NSi3~4 zHsnUB_I|Ara)Oei)@^maN?4kaB$Y%QTP586tNA}Yg*I+6j_bTxY+z!9pac@vV?GGe z1|(P4BVgTZ9M~aSY??bt1qrBisZ*seX<>5p*v-{-!w=t^klUZk5R^c|t>Ue;+NWu& zX@>Ur+fHkuKh{zO38+<-`&yWk`5zk{4e*sNV~9aYXF7kXkwBu;r{}`w-`aZQ|1v^1 z+6R+*jSbO2y(hv8lN71EP^&iGyB52 zeecde$+s8j!b`f#P~ARW&u<8OXm4u14Aw`($T$?uo6uhzD0!@xCNv4w^tpQVY#%*+ zQA5nSu$Wx8+Kf``8zCrx#4@@brAznVY5L>Qy4l2~eKxX=9i@T<)EaW_mS9Nt*J^C& z*Ks3{KbYX3Ic5k-Ao29wZ7H&?v2oDIp4{1Pjg4(ARgi#M`)c1641=^b$~zj6hh4_t zd{t)zC6GWb9|((`NOe8>zBdvB>DR}$?()0m*n;B-S@PKE3%@G|_0n+ILq zT7JbvM!+Z0yM(-ch#iux@wvD9GJ>@!7tH^tQ204HP0toNx?0$f7A!42&(dtJSkgIL?BwHvmxkOE zPC2Ha*n2Mg-aGe&yuK+YkB+Nu&KZLDuVhqrt9q1tD*7t8+&&^!RV>15Jo1DP?-bOY zO5n(VG`-Q&&RHz=JR)|zJ0C;Itv(q7pPP&Vtp3}^`t{c0raL!PGv5Vas8#qXSrBR^ zqt;zqcn>i}IOa_=Ii{*hwe;c?@zxYS&Zi&{+hh@;%c>+4zu1L`MChv&p`%F>8ck_> z)5gDVY^_626I#*}noRHYQZi!-QCm98s1@t2d^4#K{iUT^b=*a1Q<9H9NyM+r)r#QW z|LiO22~CiIGjcd{rsuccN>*lU>CLIn(u1?Uh|ykbiz$U@XlECGxPE(fU;w=@yN$Y3 zC1=h$-06BovH=M=1B5igaAt>;X}4@xU0NzAfy9wcqu6gc1?V}g6P=5Cv=S~2=*>R* z9gd--KeA;z*lOmVaQ;YlVb4Cm^}pfG71CW#u)J`#3`&;XJ<}fpVQWbtqA8@ za85*TYu>eyoqqJB!+2T`NI)$(Ct{dKOLe(R^Vq~ZdVV4#pcd>WnyWT)L*lrZF|0YY zajkf(!?NN$)S#;iKcme`hplwRs{SRbx~8vKcQX$K(*HBe*W&sPPou6&XN5rtoXHj) zw@@shJ4Q6WgEwLhu#1`+aq8BE1f0q4Uh`97K<(9+`lJ8ajtBWeQ@Pz$yQ!-N;~Q#A~+;G3V7mo2Ic6VpE)PX)SNF%d z39tCp%h#eaeoL{N@fO9!v^?ZO-*ZU77NPqfAGbl3W?%Vi<6v1^!S<1rmU(<63^%27 z>Q1kv`4y*~Q#aO6XKe}F9Gp9aBJ^F23HjB3#8zCrx1f2(QHFN5o>u2MHcG+V4 z^P^OdfLe4;#MR8HOY6Ad8+1;6;Hfmf`i}$>bRHzFv1sO3z37~}oX)9NoRQ{N{~@53 zuI^o~W=`$g$pC-sI))5B+8IF!B=lcC;M8;K){+N#%p8K<37)A<#rol|$#tKiggYRwzvNaxi0JDZZO!?6lVy0?4J zY3I~*7RA-fsc+Fab!$4OuG}Tful^&k{qlj+&Z+Nwt>WZ!>ahON{LQQ=6_ik~h||ug zsqe$p%&9lgIdvz?&8VHRQ=wsI?>SHK(3aZ<3->Ih_m$Pv_KA z#LfsxAW>QMoKw%KHGSUwm63Rk&Z+G`)KmLWoOVtPweCDD;%es9FX^26A~PCKzBfeW zQklbP=hW1X;Me2VIkl$G57Iex>CUFO|7gjN`i~@2k-=%_)YMzz)N^VLLH#I9 z{V3${R`R3%Bk5L_&S~e=)LZ(`+e~vm-aHV~bbQH9u|LAz|F{tsl2QL)r>X+fG{j2m1KCpNMt zDZg7{NPJnZ!>JAY< zG%FB~s(fU`>iar;)aYb1m(JlJO?MBhos6E&KZGKU=3|GBncNjGf=cMjvT#&6ms^|W zk*;;&+YWxoMVX4#zk_W@Zb4S-(^1~ILY1bh;`715;(rOxPtjR&=p z5jW?(;9mYyp-b+rGJEu7qb`3@cTXGK!>%^Y>YvFiY)??5Zy=%rlig$ld=kT)e;m!fKW4_=L3B5Ec@jsaB_V6xg&&fW%Aq#|^`&iP zdpDUI^Djx7Wzt_9(@wZ(rWa?v*bzS%DRLoUNob{;D-UUpcZ7TNEvb5jJZ0Kpl$B#5 zW-Rf??gR zOp*Ri&u+^YV!yJ;LLz+GHC(AjvLx}_?9L6}drSH|u(m7rz4a{=OKFC=VsOj8e&{}l zcrn$Jo12$`CQ(VBiQ~CCEi)tnK8fZ=g~j*nrDAQ=rGgSjKzr22(s}VYdhXIKtB`;t zx$k4j1t;C8u18VJ;)JOV7o;0E|IVAFWH?Q2BZ|8!+xls03tNO?Mqm4E7r9}BbT4b< za97TK=WVHm_s4s1hdyLVZB^Q84L830PPL7^!5(n9nxbR8Y|bI(e|L4 zXff@T)K=X@Q*PK7O{w0^YR2t6ruiF|mY%Q?b<)m1+E=PM91Wx}S^1rMcj{3B?a{mF z7OYFCrTrpOvlcQANV>)x&akF7nX2%v-1jlUNd@V&Rv3&UCXjVQtR$as?E8)$&$*RZZd-l z`mL!)?U^a8d+j^br5g0yAYr(vi}Xf8El9_DwdW3;%&I1$)~@R3w6%xSR&b2MvFo>W zBKK{)W*pNwwbLN`sJtsusUQKh;M+`3Te$0<5b(0RullWp1k{4>CXGg=CfJ29H+a1Z>@KX0yZlU`@LPnVPb{)3h}&ovCX#G?gX-(o(@u4c|NX z7SL4yzo~X(l}qf@1njBXCs+laa9ipdXq#c0%)Otus@HRAlt2k2QZFCu_o+=5vhh`a z6DuRP*q0CgN$DpuJo0l zF|rp!?a~j&OCu3#L0Ud`tG--~4;h4|n#1od9AofFbXUrYYw?ZTY$O6o%sZ~)!re4) zl)^ul^NrP(D&oV6enWnY_@7ci8<1w02mLxEWIA<~Y`|y1+Ctm(#&oB)iM=DvOT@r+ zj>7g!ns48{qF|xVXCk%FgH@c+g`RVFQ~Kq!-n7Q<_8CoEZDmfdt6FP*y>_qJ!)|)1 z`P;LVIZG7Uuc-Rlo)qFQ-2@J`ZoS#dn(fwn+vP-)fcV5KKNF-rhc=+rWd9rN$q}h& zwKR5Dop7})ZCpzt*6o_k)p;UH??lg~`fN1K5mo;hT+&N(P)^g#4CwiApO4v1cOEJc zP_l5xAP0-DNm6a$lNiSO#-x7pt-Yl$_2xU{g)@INy+*Z!{!cZ3!#1R&{rswUo%1WD zn!^@_TK_Hw3wrOgU(uMy+w2O4VadjvC&yW?ce+&e!+zH6ua}y632X&=9&+4C#h1SC zrKlPtpcZUtnqj&)GGWo}jtT1b96lG)3uoM9JJ{Tm+K}#d9@8%I#J^REY67-9w6Vy0 z0jKvt)9Un0*^V9w8_wL2Y(NPlVBau|t=Z-T^M0Eo8?fDBf0%4BOtd+bR9&i1hIJAO z=be+rZl88Xl}>HxiJi2c<8{iEi~4D_4hs_mn__LLS~Z-Xu>8fWKI(eFKIn09u%P{Q zhCNHa45200yZ_Xe`U4WM&rQ}Y6wLN%`k#6%{T22r0==YC!M=s`PUq!9A1Cd5&YWIm zUu)~{e(F-ecDJ2S#U314x*KE+VB~qy%0kwK^n^?n#TH^G? zBmMT&*(j9?wxa!UmBMzsrrltx({Z&d$2PUqC+U|HwrFXSy21c@XM=i1E^pn!!Jf7) z6U!t5N+2=Iq?r)1PuqjpyRO(-+4q+E{QKPA!kICew-bJG;CsU`^M@RZKhUw^b#+_8 zc84trTbkbQ+)g)PLz`OCR}{8M_K4}M)hSK;z&51bM3c7h(I-Ns{(xf?whtWB^w!Xu zk@45rZ_;RoEvjSgpd8gxQ(H*Wab-Nf_So(IQd>a@Bw!n=&m0(T_h7S+)K*Xe3HTP! zY{QWw?4vBsN~MA|Ogv%4{WH_lJo|z^w`_rCj!M_n7LG~uXuDDxB~X&nmEp>+X=Y*Y zNpy|gt33X+;ZCV_p(L%MKbz>Tu>rp^4AV6yC9&bi7m^J~KpSx6Gt6}oWmo@d8%?P| zuR$B|Nepv0p6M5Je}gniU|Ye_*L`MZVZ>|gNE~6_BEclTn`8ryxcct`q?4J|qaBV= zhI!a}bYi`=rTanD!S{9V+1=VjlkN$U^w z8+Em5f_k(=>#$Vc+WQG}Xv9KYLb?j@V3Gah(dN=9fhB?M(|b(=Zd`xOZwb9!-c_{c z@pOrR5=cC_`;*<#@0L_T8uy8s8{dD7LYl+DQo)kI`qA^-cPRT6O$wCi0VR-ttw3*A z8MqiPd6v$%+_DFK?6+8X^h*JnHqe#NN?fdLxT665nBvA)ky*-6dM;~UU3L6wD&GO4 z=j-`-G<^q~7 zlMzr0#@^{FnZpK^EjtCbE&hcH^fxJMITazt^1*z)CHs|uW<}_$=)@0O{YYWvScFzC zbE5SKULTBE?ysB6KHrFeQ#I?yfQkQJ*;Nj_36^^@UzEWc6d)ujLhI*j9|H(R5$%P&Aj`>6Lc7_WSQ%Hn5{9}ZqTT-CpE zwp^$knu09Xd-A5EzDTD#gw7gE34^&0&<52kvU%?m83DCmP5?ax%EDC`er38coUYfx zybPF&5p4BMQ0EBmZK00bO_(_`F>cmNMKN7bgzNG!*W^;y5+T|_8`=KT_@lylce