Skip to content

Commit

Permalink
Replace error raise with assert statement
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Apr 22, 2024
1 parent c6214b2 commit 3b13b38
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/models/deta/test_modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ def test_tied_weights_keys(self):
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group):
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,8 +2025,8 @@ def test_tied_weights_keys(self):
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group):
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
Expand Down

0 comments on commit 3b13b38

Please sign in to comment.