Skip to content

Commit

Permalink
Attempt at the bucket fix
Browse files Browse the repository at this point in the history
  • Loading branch information
onmyraedar committed Feb 7, 2025
1 parent 42fc631 commit c24b1db
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
12 changes: 10 additions & 2 deletions edsl/jobs/Jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,14 +583,22 @@ def _run(self, config: RunConfig) -> Union[None, "Results"]:
# first try to run the job remotely
if results := self._remote_results():
return results

self._check_if_local_keys_ok()

if config.environment.bucket_collection is None:
self.run_config.environment.bucket_collection = (
self.create_bucket_collection()
)

if (
self.run_config.environment.key_lookup is not None
and self.run_config.environment.bucket_collection is not None
):
self.run_config.environment.bucket_collection.update_from_key_lookup(
self.run_config.environment.key_lookup
)

return None

@with_config
Expand All @@ -613,7 +621,7 @@ def run(self, *, config: RunConfig) -> "Results":
:param key_lookup: A KeyLookup object to manage API keys
"""
potentially_completed_results = self._run(config)

if potentially_completed_results is not None:
return potentially_completed_results

Expand Down
30 changes: 30 additions & 0 deletions edsl/jobs/buckets/BucketCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,36 @@ def add_model(self, model: "LanguageModel") -> None:
else:
self[model] = self.services_to_buckets[self.models_to_services[model.model]]

def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
"""Updates the bucket collection rates based on model RPM/TPM from KeyLookup"""

for model_name, service in self.models_to_services.items():
if service in key_lookup and not self.infinity_buckets:

if key_lookup[service].rpm is not None:
new_rps = key_lookup[service].rpm / 60.0
new_requests_bucket = TokenBucket(
bucket_name=service,
bucket_type="requests",
capacity=new_rps,
refill_rate=new_rps,
remote_url=self.remote_url,
)
self.services_to_buckets[service].requests_bucket = (
new_requests_bucket
)

if key_lookup[service].tpm is not None:
new_tps = key_lookup[service].tpm / 60.0
new_tokens_bucket = TokenBucket(
bucket_name=service,
bucket_type="tokens",
capacity=new_tps,
refill_rate=new_tps,
remote_url=self.remote_url,
)
self.services_to_buckets[service].tokens_bucket = new_tokens_bucket

def visualize(self) -> dict:
"""Visualize the token and request buckets for each model."""
plots = {}
Expand Down

0 comments on commit c24b1db

Please sign in to comment.