Skip to content

Commit

Permalink
Fix minor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmpablo157321 committed Jan 8, 2025
1 parent a728e96 commit 1bd53bf
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mlperf_logging/package_checker/power_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def check_range(self, l, n):
seen = set({})
errors = []
for e in l:
if e < 0 or e >(n-1) or e in seen:
if e < 0 or e > (n-1) or e in seen:
return False

return True
Expand All @@ -34,7 +34,7 @@ def check_equals(self, l):
counter[e] += 1
else:
counter[e] = 1
max_equals = max(counter, counter.get)
max_equals = max(counter)
for i, e in enumerate(l):
if e != max_equals:
errors.append(i)
Expand All @@ -55,8 +55,8 @@ def check_power(self, power_folder, result_files):
power_files = os.listdir(power_result_folder)
node_results = [file for file in power_files if file.startswith("node")]
sw_results = [file for file in power_files if file.startswith("sw")]
node_idx = [os.path.splitext(os.path.basename(file))[1] for file in node_results]
sw_idx = [os.path.splitext(os.path.basename(file))[1] for file in node_results]
node_idx = [int(os.path.splitext(os.path.basename(file))[0].split('_')[-1]) for file in node_results]
sw_idx = [int(os.path.splitext(os.path.basename(file))[0].split('_')[-1]) for file in sw_results]

if len(power_files) > len(node_results) + len(sw_results):
logging.warning("Detected %d total files in directory %s, but some do not conform", len(power_files), power_result_folder)
Expand All @@ -80,10 +80,10 @@ def check_power(self, power_folder, result_files):

result_names = [os.path.splitext(os.path.basename(result_file))[0] for result_file in result_files]

valid, errors = self.check_mode(node_results)
valid, errors = self.check_equals(node_lens)
node_errors = set([result_names[error] for error in errors])

valid, errors = self.check_mode(sw_results)
valid, errors = self.check_equals(sw_lens)
sw_errors = set([result_names[error] for error in errors])

errors_set = errors_set | node_errors | sw_errors
Expand Down

0 comments on commit 1bd53bf

Please sign in to comment.