Skip to content

Commit

Permalink
Use allclose for comparing sum of small numbers (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
undfined authored Nov 13, 2024
1 parent 3284742 commit fdbb76e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Old ephemeral checkpoints won't be removed until after the latest ephemeral checkpoint is saved successfully.
- Made GCS uploads more robust.
- numpy.random.dirichlet() does not always sum to 1.0, so allow for a small tolerance in validating domain weights.

## [v1.6.2](https://github.com/allenai/OLMo-core/releases/tag/v1.6.2) - 2024-11-08

Expand Down
17 changes: 9 additions & 8 deletions src/olmo_core/data/source_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import chain
from typing import Dict, List, Optional, Tuple

import numpy as np
from rich.progress import Progress
from rich.table import Table

Expand Down Expand Up @@ -54,12 +55,10 @@ class SourceMixtureConfig(Config):

def validate(self):
if self.target_ratio:
if not 0 <= self.target_ratio <= 1:
raise OLMoConfigurationError("target_ratio must be in the range [0, 1]")
if not 0 <= self.max_source_fraction <= 1:
raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]")
if self.max_source_fraction < self.target_ratio:
raise OLMoConfigurationError("max_source_fraction must be >= target_ratio")
if not 0 < self.target_ratio <= 1:
raise OLMoConfigurationError("target_ratio must be > 0 and <= 1")
if not 0 < self.max_source_fraction <= 1:
raise OLMoConfigurationError("max_source_fraction must > 0 and <= 1")

if self.max_repetition_ratio < 1:
raise OLMoConfigurationError("max_repetition_ratio must be >= 1")
Expand Down Expand Up @@ -195,8 +194,10 @@ def validate(self):
if not self.source_configs:
raise OLMoConfigurationError("source_configs must not be empty")

if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0:
raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}")
summed_weights = np.sum([source.target_ratio for source in self.source_configs])

if not np.allclose(summed_weights, 1.0):
raise OLMoConfigurationError(f"target_ratios must sum to 1.0, got {summed_weights}")

def build(self) -> SourceMixtureDataset:
self.validate()
Expand Down
14 changes: 3 additions & 11 deletions src/test/data/source_mixture_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def test_source_mixture_config(tmp_path: Path, caplog, capsys):
source_configs = [
SourceMixtureConfig(
source_name="1",
target_ratio=0.33,
target_ratio=0.33333,
paths=[i[0] for i in source_paths["1"]],
),
SourceMixtureConfig(
source_name="2", target_ratio=0.33, paths=[i[0] for i in source_paths["2"]]
source_name="2", target_ratio=0.33333, paths=[i[0] for i in source_paths["2"]]
),
SourceMixtureConfig(
source_name="3",
target_ratio=0.34,
target_ratio=0.33333,
paths=[i[0] for i in source_paths["3"]],
),
]
Expand Down Expand Up @@ -62,14 +62,6 @@ def test_source_mixture_config_validation():
source_name="source1", target_ratio=1.2, paths=["/path/to/source1"]
).validate()

with pytest.raises(OLMoConfigurationError):
SourceMixtureConfig(
source_name="source1",
target_ratio=0.5,
max_source_fraction=0.4,
paths=["/path/to/source1"],
).validate()

with pytest.raises(OLMoConfigurationError):
SourceMixtureConfig(source_name="source1", target_ratio=0.5, paths=[]).validate()

Expand Down

0 comments on commit fdbb76e

Please sign in to comment.