From 2bd7c3ed35ce27d5ffeb9010e23da8d9ebb3ffa7 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Sun, 22 Dec 2024 18:38:13 -0800 Subject: [PATCH] Replace `len()` Zero Checks with Pythonic Empty Sequence Checks (#4298) * style: mainly replace len() comparisons with 0/1 with pythonic empty sequence checks * chore: more typings * use `df.empty` for dataframe * fix: more `df.empty` * format * revert partially * style: add back comments * style: format * refactor: `dict[str, str]` Co-authored-by: Tian Xia --------- Co-authored-by: Tian Xia --- examples/spot/lightning_cifar10/train.py | 2 +- sky/backends/cloud_vm_ray_backend.py | 2 +- sky/check.py | 2 +- sky/cli.py | 50 +++++++++---------- sky/cloud_stores.py | 2 +- sky/clouds/gcp.py | 2 +- sky/clouds/kubernetes.py | 2 +- sky/clouds/service_catalog/common.py | 21 ++++---- .../data_fetchers/fetch_azure.py | 2 +- .../data_fetchers/fetch_vsphere.py | 2 +- sky/clouds/utils/scp_utils.py | 6 +-- sky/core.py | 6 +-- sky/data/storage.py | 2 +- sky/jobs/core.py | 4 +- sky/jobs/state.py | 4 +- sky/jobs/utils.py | 12 ++--- sky/optimizer.py | 6 +-- sky/provision/aws/config.py | 4 +- sky/provision/gcp/config.py | 6 +-- sky/provision/kubernetes/config.py | 14 +++--- sky/provision/kubernetes/network_utils.py | 2 +- sky/provision/kubernetes/utils.py | 4 +- sky/provision/lambda_cloud/lambda_utils.py | 6 +-- sky/provision/oci/query_utils.py | 6 +-- sky/provision/vsphere/common/vim_utils.py | 2 +- sky/provision/vsphere/instance.py | 13 +++-- sky/provision/vsphere/vsphere_utils.py | 2 +- sky/resources.py | 8 +-- sky/serve/autoscalers.py | 4 +- sky/serve/core.py | 8 +-- sky/serve/replica_managers.py | 2 +- sky/serve/serve_state.py | 2 +- sky/serve/serve_utils.py | 21 ++++---- sky/serve/service_spec.py | 13 +++-- sky/skylet/job_lib.py | 2 +- sky/skylet/providers/ibm/node_provider.py | 4 +- sky/skylet/providers/scp/config.py | 2 +- sky/skylet/providers/scp/node_provider.py | 14 +++--- sky/task.py | 2 +- sky/utils/accelerator_registry.py | 2 +- sky/utils/common_utils.py | 2 +- sky/utils/dag_utils.py | 2 +- sky/utils/kubernetes/gpu_labeler.py | 2 +- .../kubernetes/ssh_jump_lifecycle_manager.py | 2 +- tests/test_yaml_parser.py | 4 +- tests/unit_tests/test_storage_utils.py | 2 +- 46 files changed, 143 insertions(+), 141 deletions(-) diff --git a/examples/spot/lightning_cifar10/train.py b/examples/spot/lightning_cifar10/train.py index 0df6f18484b..14901e635ef 100644 --- a/examples/spot/lightning_cifar10/train.py +++ b/examples/spot/lightning_cifar10/train.py @@ -163,7 +163,7 @@ def main(): ) model_ckpts = glob.glob(argv.root_dir + "/*.ckpt") - if argv.resume and len(model_ckpts) > 0: + if argv.resume and model_ckpts: latest_ckpt = max(model_ckpts, key=os.path.getctime) trainer.fit(model, cifar10_dm, ckpt_path=latest_ckpt) else: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 8974a0129bd..26118ad2de2 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2626,7 +2626,7 @@ def register_info(self, **kwargs) -> None: self._optimize_target) or optimizer.OptimizeTarget.COST self._requested_features = kwargs.pop('requested_features', self._requested_features) - assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}' + assert not kwargs, f'Unexpected kwargs: {kwargs}' def check_resources_fit_cluster( self, diff --git a/sky/check.py b/sky/check.py index ee5ea77234b..1ab92cb1af6 100644 --- a/sky/check.py +++ b/sky/check.py @@ -127,7 +127,7 @@ def get_all_clouds(): '\nNote: The following clouds were disabled because they were not ' 'included in allowed_clouds in ~/.sky/config.yaml: ' f'{", ".join([c for c in disallowed_cloud_names])}') - if len(all_enabled_clouds) == 0: + if not all_enabled_clouds: echo( click.style( 'No cloud is enabled. SkyPilot will not be able to run any ' diff --git a/sky/cli.py b/sky/cli.py index 12f77e9f6c9..dca45e81164 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -114,7 +114,7 @@ def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]: glob_clusters = [] for cluster in clusters: glob_cluster = global_user_state.get_glob_cluster_names(cluster) - if len(glob_cluster) == 0 and not silent: + if not glob_cluster and not silent: click.echo(f'Cluster {cluster} not found.') glob_clusters.extend(glob_cluster) return list(set(glob_clusters)) @@ -125,7 +125,7 @@ def _get_glob_storages(storages: List[str]) -> List[str]: glob_storages = [] for storage_object in storages: glob_storage = global_user_state.get_glob_storage_name(storage_object) - if len(glob_storage) == 0: + if not glob_storage: click.echo(f'Storage {storage_object} not found.') glob_storages.extend(glob_storage) return list(set(glob_storages)) @@ -1473,7 +1473,7 @@ def _get_services(service_names: Optional[List[str]], if len(service_records) != 1: plural = 's' if len(service_records) > 1 else '' service_num = (str(len(service_records)) - if len(service_records) > 0 else 'No') + if service_records else 'No') raise click.UsageError( f'{service_num} service{plural} found. Please specify ' 'an existing service to show its endpoint. Usage: ' @@ -1696,8 +1696,7 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, if len(clusters) != 1: with ux_utils.print_exception_no_traceback(): plural = 's' if len(clusters) > 1 else '' - cluster_num = (str(len(clusters)) - if len(clusters) > 0 else 'No') + cluster_num = (str(len(clusters)) if clusters else 'No') cause = 'a single' if len(clusters) > 1 else 'an existing' raise ValueError( _STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format( @@ -1722,9 +1721,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, with ux_utils.print_exception_no_traceback(): plural = 's' if len(cluster_records) > 1 else '' cluster_num = (str(len(cluster_records)) - if len(cluster_records) > 0 else - f'{clusters[0]!r}') - verb = 'found' if len(cluster_records) > 0 else 'not found' + if cluster_records else f'{clusters[0]!r}') + verb = 'found' if cluster_records else 'not found' cause = 'a single' if len(clusters) > 1 else 'an existing' raise ValueError( _STATUS_PROPERTY_CLUSTER_NUM_ERROR_MESSAGE.format( @@ -2470,7 +2468,7 @@ def start( '(see `sky status`), or the -a/--all flag.') if all: - if len(clusters) > 0: + if clusters: click.echo('Both --all and cluster(s) specified for sky start. ' 'Letting --all take effect.') @@ -2800,7 +2798,7 @@ def _down_or_stop_clusters( option_str = '{stop,down}' operation = f'{verb} auto{option_str} on' - if len(names) > 0: + if names: controllers = [ name for name in names if controller_utils.Controllers.from_name(name) is not None @@ -2814,7 +2812,7 @@ def _down_or_stop_clusters( # Make sure the controllers are explicitly specified without other # normal clusters. if controllers: - if len(names) != 0: + if names: names_str = ', '.join(map(repr, names)) raise click.UsageError( f'{operation} controller(s) ' @@ -2867,7 +2865,7 @@ def _down_or_stop_clusters( if apply_to_all: all_clusters = global_user_state.get_clusters() - if len(names) > 0: + if names: click.echo( f'Both --all and cluster(s) specified for `sky {command}`. ' 'Letting --all take effect.') @@ -2894,7 +2892,7 @@ def _down_or_stop_clusters( click.echo('Cluster(s) not found (tip: see `sky status`).') return - if not no_confirm and len(clusters) > 0: + if not no_confirm and clusters: cluster_str = 'clusters' if len(clusters) > 1 else 'cluster' cluster_list = ', '.join(clusters) click.confirm( @@ -3003,7 +3001,7 @@ def check(clouds: Tuple[str], verbose: bool): # Check only specific clouds - AWS and GCP. sky check aws gcp """ - clouds_arg = clouds if len(clouds) > 0 else None + clouds_arg = clouds if clouds else None sky_check.check(verbose=verbose, clouds=clouds_arg) @@ -3138,7 +3136,7 @@ def _get_kubernetes_realtime_gpu_table( f'capacity ({list(capacity.keys())}), ' f'and available ({list(available.keys())}) ' 'must be same.') - if len(counts) == 0: + if not counts: err_msg = 'No GPUs found in Kubernetes cluster. ' debug_msg = 'To further debug, run: sky check ' if name_filter is not None: @@ -3282,7 +3280,7 @@ def _output(): for tpu in service_catalog.get_tpus(): if tpu in result: tpu_table.add_row([tpu, _list_to_str(result.pop(tpu))]) - if len(tpu_table.get_string()) > 0: + if tpu_table.get_string(): yield '\n\n' yield from tpu_table.get_string() @@ -3393,7 +3391,7 @@ def _output(): yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' f'Cloud GPUs{colorama.Style.RESET_ALL}\n') - if len(result) == 0: + if not result: quantity_str = (f' with requested quantity {quantity}' if quantity else '') cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.' @@ -3522,7 +3520,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r # Delete all storage objects. sky storage delete -a """ - if sum([len(names) > 0, all]) != 1: + if sum([bool(names), all]) != 1: raise click.UsageError('Either --all or a name must be specified.') if all: storages = sky.storage_ls() @@ -3881,8 +3879,8 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): exit_if_not_accessible=True) job_id_str = ','.join(map(str, job_ids)) - if sum([len(job_ids) > 0, name is not None, all]) != 1: - argument_str = f'--job-ids {job_id_str}' if len(job_ids) > 0 else '' + if sum([bool(job_ids), name is not None, all]) != 1: + argument_str = f'--job-ids {job_id_str}' if job_ids else '' argument_str += f' --name {name}' if name is not None else '' argument_str += ' --all' if all else '' raise click.UsageError( @@ -4523,9 +4521,9 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool, # Forcefully tear down a specific replica, even in failed status. sky serve down my-service --replica-id 1 --purge """ - if sum([len(service_names) > 0, all]) != 1: - argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len( - service_names) > 0 else '' + if sum([bool(service_names), all]) != 1: + argument_str = (f'SERVICE_NAMES={",".join(service_names)}' + if service_names else '') argument_str += ' --all' if all else '' raise click.UsageError( 'Can only specify one of SERVICE_NAMES or --all. ' @@ -4898,7 +4896,7 @@ def benchmark_launch( if idle_minutes_to_autostop is None: idle_minutes_to_autostop = 5 commandline_args['idle-minutes-to-autostop'] = idle_minutes_to_autostop - if len(env) > 0: + if env: commandline_args['env'] = [f'{k}={v}' for k, v in env] # Launch the benchmarking clusters in detach mode in parallel. @@ -5177,7 +5175,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool], raise click.BadParameter( 'Either specify benchmarks or use --all to delete all benchmarks.') to_delete = [] - if len(benchmarks) > 0: + if benchmarks: for benchmark in benchmarks: record = benchmark_state.get_benchmark_from_name(benchmark) if record is None: @@ -5186,7 +5184,7 @@ def benchmark_delete(benchmarks: Tuple[str], all: Optional[bool], to_delete.append(record) if all: to_delete = benchmark_state.get_benchmarks() - if len(benchmarks) > 0: + if benchmarks: print('Both --all and benchmark(s) specified ' 'for sky bench delete. Letting --all take effect.') diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py index ee1b051d32b..e9c111c56ac 100644 --- a/sky/cloud_stores.py +++ b/sky/cloud_stores.py @@ -133,7 +133,7 @@ def is_directory(self, url: str) -> bool: # If is a bucket root, then we only need `gsutil` to succeed # to make sure the bucket exists. It is already a directory. _, key = data_utils.split_gcs_path(url) - if len(key) == 0: + if not key: return True # Otherwise, gsutil ls -d url will return: # --> url.rstrip('/') if url is not a directory diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index c0f22cc860b..ff200f84147 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -830,7 +830,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: ret_permissions = request.execute().get('permissions', []) diffs = set(gcp_minimal_permissions).difference(set(ret_permissions)) - if len(diffs) > 0: + if diffs: identity_str = identity[0] if identity else None return False, ( 'The following permissions are not enabled for the current ' diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 65b50042aba..f9242bd77aa 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -139,7 +139,7 @@ def _existing_allowed_contexts(cls) -> List[str]: use the service account mounted in the pod. """ all_contexts = kubernetes_utils.get_all_kube_context_names() - if len(all_contexts) == 0: + if not all_contexts: return [] all_contexts = set(all_contexts) diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 29df92d7535..0fce7c25f6a 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -270,9 +270,10 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str: candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9) candidate_loc = sorted(candidate_loc) candidate_strs = '' - if len(candidate_loc) > 0: + if candidate_loc: candidate_strs = ', '.join(candidate_loc) candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?' + return candidate_strs def _get_all_supported_regions_str() -> str: @@ -286,7 +287,7 @@ def _get_all_supported_regions_str() -> str: filter_df = df if region is not None: filter_df = _filter_region_zone(filter_df, region, zone=None) - if len(filter_df) == 0: + if filter_df.empty: with ux_utils.print_exception_no_traceback(): error_msg = (f'Invalid region {region!r}') candidate_strs = _get_candidate_str( @@ -310,7 +311,7 @@ def _get_all_supported_regions_str() -> str: if zone is not None: maybe_region_df = filter_df filter_df = filter_df[filter_df['AvailabilityZone'] == zone] - if len(filter_df) == 0: + if filter_df.empty: region_str = f' for region {region!r}' if region else '' df = maybe_region_df if region else df with ux_utils.print_exception_no_traceback(): @@ -378,7 +379,7 @@ def get_vcpus_mem_from_instance_type_impl( instance_type: str, ) -> Tuple[Optional[float], Optional[float]]: df = _get_instance_type(df, instance_type, None) - if len(df) == 0: + if df.empty: with ux_utils.print_exception_no_traceback(): raise ValueError(f'No instance type {instance_type} found.') assert len(set(df['vCPUs'])) == 1, ('Cannot determine the number of vCPUs ' @@ -484,7 +485,7 @@ def get_accelerators_from_instance_type_impl( instance_type: str, ) -> Optional[Dict[str, Union[int, float]]]: df = _get_instance_type(df, instance_type, None) - if len(df) == 0: + if df.empty: with ux_utils.print_exception_no_traceback(): raise ValueError(f'No instance type {instance_type} found.') row = df.iloc[0] @@ -518,7 +519,7 @@ def get_instance_type_for_accelerator_impl( result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) & (abs(df['AcceleratorCount'] - acc_count) <= 0.01)] result = _filter_region_zone(result, region, zone) - if len(result) == 0: + if result.empty: fuzzy_result = df[ (df['AcceleratorName'].str.contains(acc_name, case=False)) & (df['AcceleratorCount'] >= acc_count)] @@ -527,7 +528,7 @@ def get_instance_type_for_accelerator_impl( fuzzy_result = fuzzy_result[['AcceleratorName', 'AcceleratorCount']].drop_duplicates() fuzzy_candidate_list = [] - if len(fuzzy_result) > 0: + if not fuzzy_result.empty: for _, row in fuzzy_result.iterrows(): acc_cnt = float(row['AcceleratorCount']) acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else @@ -539,7 +540,7 @@ def get_instance_type_for_accelerator_impl( result = _filter_with_cpus(result, cpus) result = _filter_with_mem(result, memory) result = _filter_region_zone(result, region, zone) - if len(result) == 0: + if result.empty: return ([], []) # Current strategy: choose the cheapest instance @@ -680,7 +681,7 @@ def get_image_id_from_tag_impl(df: 'pd.DataFrame', tag: str, df = _filter_region_zone(df, region, zone=None) assert len(df) <= 1, ('Multiple images found for tag ' f'{tag} in region {region}') - if len(df) == 0: + if df.empty: return None image_id = df['ImageId'].iloc[0] if pd.isna(image_id): @@ -694,4 +695,4 @@ def is_image_tag_valid_impl(df: 'pd.DataFrame', tag: str, df = df[df['Tag'] == tag] df = _filter_region_zone(df, region, zone=None) df = df.dropna(subset=['ImageId']) - return len(df) > 0 + return not df.empty diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index 4aef41f9c90..00768d5c6bb 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -134,7 +134,7 @@ def get_pricing_df(region: Optional[str] = None) -> 'pd.DataFrame': content_str = r.content.decode('ascii') content = json.loads(content_str) items = content.get('Items', []) - if len(items) == 0: + if not items: break all_items += items url = content.get('NextPageLink') diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py index 216e8ed9b4f..c08a56955a0 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py @@ -534,7 +534,7 @@ def initialize_images_csv(csv_saving_path: str, vc_object, gpu_name = tag_name.split('-')[1] if gpu_name not in gpu_tags: gpu_tags.append(gpu_name) - if len(gpu_tags) > 0: + if gpu_tags: gpu_tags_str = str(gpu_tags).replace('\'', '\"') f.write(f'{item.id},{vcenter_name},{item_cpu},{item_memory}' f',,,\'{gpu_tags_str}\'\n') diff --git a/sky/clouds/utils/scp_utils.py b/sky/clouds/utils/scp_utils.py index 3e91e22e6d9..4efc79313c5 100644 --- a/sky/clouds/utils/scp_utils.py +++ b/sky/clouds/utils/scp_utils.py @@ -65,7 +65,7 @@ def __setitem__(self, instance_id: str, value: Optional[Dict[str, if value is None: if instance_id in metadata: metadata.pop(instance_id) # del entry - if len(metadata) == 0: + if not metadata: if os.path.exists(self.path): os.remove(self.path) return @@ -84,7 +84,7 @@ def refresh(self, instance_ids: List[str]) -> None: for instance_id in list(metadata.keys()): if instance_id not in instance_ids: del metadata[instance_id] - if len(metadata) == 0: + if not metadata: os.remove(self.path) return with open(self.path, 'w', encoding='utf-8') as f: @@ -410,7 +410,7 @@ def list_security_groups(self, vpc_id=None, sg_name=None): parameter.append('vpcId=' + vpc_id) if sg_name is not None: parameter.append('securityGroupName=' + sg_name) - if len(parameter) > 0: + if parameter: url = url + '?' + '&'.join(parameter) return self._get(url) diff --git a/sky/core.py b/sky/core.py index 9f1288d7fb6..36b3d45b849 100644 --- a/sky/core.py +++ b/sky/core.py @@ -732,7 +732,7 @@ def cancel( f'{colorama.Fore.YELLOW}' f'Cancelling latest running job on cluster {cluster_name!r}...' f'{colorama.Style.RESET_ALL}') - elif len(job_ids): + elif job_ids: # all = False, len(job_ids) > 0 => cancel the specified jobs. jobs_str = ', '.join(map(str, job_ids)) sky_logging.print( @@ -817,7 +817,7 @@ def download_logs( backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend), backend - if job_ids is not None and len(job_ids) == 0: + if job_ids is not None and not job_ids: return {} usage_lib.record_cluster_name_for_current_operation(cluster_name) @@ -866,7 +866,7 @@ def job_status(cluster_name: str, f'of type {backend.__class__.__name__!r}.') assert isinstance(handle, backends.CloudVmRayResourceHandle), handle - if job_ids is not None and len(job_ids) == 0: + if job_ids is not None and not job_ids: return {} sky_logging.print(f'{colorama.Fore.YELLOW}' diff --git a/sky/data/storage.py b/sky/data/storage.py index d3d18a9d18f..2247e4545f0 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -1067,7 +1067,7 @@ def add_if_not_none(key: str, value: Optional[Any]): add_if_not_none('source', self.source) stores = None - if len(self.stores) > 0: + if self.stores: stores = ','.join([store.value for store in self.stores]) add_if_not_none('store', stores) add_if_not_none('persistent', self.persistent) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 1348441a5bd..3718d0ac67c 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -347,8 +347,8 @@ def cancel(name: Optional[str] = None, stopped_message='All managed jobs should have finished.') job_id_str = ','.join(map(str, job_ids)) - if sum([len(job_ids) > 0, name is not None, all]) != 1: - argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else '' + if sum([bool(job_ids), name is not None, all]) != 1: + argument_str = f'job_ids={job_id_str}' if job_ids else '' argument_str += f' name={name}' if name is not None else '' argument_str += ' all' if all else '' with ux_utils.print_exception_no_traceback(): diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 9a5ab4b3cad..31dcfcfd5eb 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -591,7 +591,7 @@ def get_latest_task_id_status( If the job_id does not exist, (None, None) will be returned. """ id_statuses = _get_all_task_ids_statuses(job_id) - if len(id_statuses) == 0: + if not id_statuses: return None, None task_id, status = id_statuses[-1] for task_id, status in id_statuses: @@ -617,7 +617,7 @@ def get_failure_reason(job_id: int) -> Optional[str]: WHERE spot_job_id=(?) ORDER BY task_id ASC""", (job_id,)).fetchall() reason = [r[0] for r in reason if r[0] is not None] - if len(reason) == 0: + if not reason: return None return reason[0] diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 267c205285b..e5bbced997c 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -234,11 +234,11 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: if job_ids is None: job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None) job_ids = list(set(job_ids)) - if len(job_ids) == 0: + if not job_ids: return 'No job to cancel.' job_id_str = ', '.join(map(str, job_ids)) logger.info(f'Cancelling jobs {job_id_str}.') - cancelled_job_ids = [] + cancelled_job_ids: List[int] = [] for job_id in job_ids: # Check the status of the managed job status. If it is in # terminal state, we can safely skip it. @@ -268,7 +268,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: shutil.copy(str(signal_file), str(legacy_signal_file)) cancelled_job_ids.append(job_id) - if len(cancelled_job_ids) == 0: + if not cancelled_job_ids: return 'No job to cancel.' identity_str = f'Job with ID {cancelled_job_ids[0]} is' if len(cancelled_job_ids) > 1: @@ -281,7 +281,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: def cancel_job_by_name(job_name: str) -> str: """Cancel a job by name.""" job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name) - if len(job_ids) == 0: + if not job_ids: return f'No running job found with name {job_name!r}.' if len(job_ids) > 1: return (f'{colorama.Fore.RED}Multiple running jobs found ' @@ -515,7 +515,7 @@ def stream_logs(job_id: Optional[int], for job in managed_jobs if job['job_name'] == job_name } - if len(managed_job_ids) == 0: + if not managed_job_ids: return f'No managed job found with name {job_name!r}.' if len(managed_job_ids) > 1: job_ids_str = ', '.join( @@ -541,7 +541,7 @@ def stream_logs(job_id: Optional[int], if job_id is None: assert job_name is not None job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name) - if len(job_ids) == 0: + if not job_ids: return f'No running managed job found with name {job_name!r}.' if len(job_ids) > 1: raise ValueError( diff --git a/sky/optimizer.py b/sky/optimizer.py index 2f70dd39429..c5a631c213b 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -188,7 +188,7 @@ def _remove_dummy_source_sink_nodes(dag: 'dag_lib.Dag'): """Removes special Source and Sink nodes.""" source = [t for t in dag.tasks if t.name == _DUMMY_SOURCE_NAME] sink = [t for t in dag.tasks if t.name == _DUMMY_SINK_NAME] - if len(source) == len(sink) == 0: + if not source and not sink: return assert len(source) == len(sink) == 1, dag.tasks dag.remove(source[0]) @@ -1298,7 +1298,7 @@ def _fill_in_launchable_resources( resources, num_nodes=task.num_nodes) if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint - if len(feasible_resources.resources_list) > 0: + if feasible_resources.resources_list: # Assume feasible_resources is sorted by prices. Guaranteed by # the implementation of get_feasible_launchable_resources and # the underlying service_catalog filtering @@ -1310,7 +1310,7 @@ def _fill_in_launchable_resources( else: all_fuzzy_candidates.update( feasible_resources.fuzzy_candidate_list) - if len(launchable[resources]) == 0: + if not launchable[resources]: clouds_str = str(clouds_list) if len(clouds_list) > 1 else str( clouds_list[0]) num_node_str = '' diff --git a/sky/provision/aws/config.py b/sky/provision/aws/config.py index 6a8c77eafed..ffa87c3a011 100644 --- a/sky/provision/aws/config.py +++ b/sky/provision/aws/config.py @@ -279,7 +279,7 @@ def _has_igw_route(route_tables): logger.debug(f'subnet {subnet_id} route tables: {route_tables}') if _has_igw_route(route_tables): return True - if len(route_tables) > 0: + if route_tables: return False # Handle the case that a "main" route table is implicitly associated with @@ -454,7 +454,7 @@ def _vpc_id_from_security_group_ids(ec2, sg_ids: List[str]) -> Any: no_sg_msg = ('Failed to detect a security group with id equal to any of ' 'the configured SecurityGroupIds.') - assert len(vpc_ids) > 0, no_sg_msg + assert vpc_ids, no_sg_msg return vpc_ids[0] diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index a8292669a7c..a99267eb0b9 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -397,7 +397,7 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, operation = compute.networks().getEffectiveFirewalls(project=project_id, network=vpc_name) response = operation.execute() - if len(response) == 0: + if not response: return False effective_rules = response['firewalls'] @@ -515,7 +515,7 @@ def _create_rules(project_id: str, compute, rules, vpc_name): rule_list = _list_firewall_rules(project_id, compute, filter=f'(name={rule_name})') - if len(rule_list) > 0: + if rule_list: _delete_firewall_rule(project_id, compute, rule_name) body = rule.copy() @@ -624,7 +624,7 @@ def get_usable_vpc_and_subnet( vpc_list = _list_vpcnets(project_id, compute, filter=f'name={constants.SKYPILOT_VPC_NAME}') - if len(vpc_list) == 0: + if not vpc_list: body = constants.VPC_TEMPLATE.copy() body['name'] = body['name'].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) body['selfLink'] = body['selfLink'].format( diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index 370430720f0..0fe920be9d6 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -232,7 +232,7 @@ def _get_resource(container_resources: Dict[str, Any], resource_name: str, # Look for keys containing the resource_name. For example, # the key 'nvidia.com/gpu' contains the key 'gpu'. matching_keys = [key for key in resources if resource_name in key.lower()] - if len(matching_keys) == 0: + if not matching_keys: return float('inf') if len(matching_keys) > 1: # Should have only one match -- mostly relevant for gpu. @@ -265,7 +265,7 @@ def _configure_autoscaler_service_account( field_selector = f'metadata.name={name}' accounts = (kubernetes.core_api(context).list_namespaced_service_account( namespace, field_selector=field_selector).items) - if len(accounts) > 0: + if accounts: assert len(accounts) == 1 # Nothing to check for equality and patch here, # since the service_account.metadata.name is the only important @@ -308,7 +308,7 @@ def _configure_autoscaler_role(namespace: str, context: Optional[str], field_selector = f'metadata.name={name}' roles = (kubernetes.auth_api(context).list_namespaced_role( namespace, field_selector=field_selector).items) - if len(roles) > 0: + if roles: assert len(roles) == 1 existing_role = roles[0] # Convert to k8s object to compare @@ -374,7 +374,7 @@ def _configure_autoscaler_role_binding( field_selector = f'metadata.name={name}' role_bindings = (kubernetes.auth_api(context).list_namespaced_role_binding( rb_namespace, field_selector=field_selector).items) - if len(role_bindings) > 0: + if role_bindings: assert len(role_bindings) == 1 existing_binding = role_bindings[0] new_rb = kubernetes_utils.dict_to_k8s_object(binding, 'V1RoleBinding') @@ -415,7 +415,7 @@ def _configure_autoscaler_cluster_role(namespace, context, field_selector = f'metadata.name={name}' cluster_roles = (kubernetes.auth_api(context).list_cluster_role( field_selector=field_selector).items) - if len(cluster_roles) > 0: + if cluster_roles: assert len(cluster_roles) == 1 existing_cr = cluster_roles[0] new_cr = kubernetes_utils.dict_to_k8s_object(role, 'V1ClusterRole') @@ -460,7 +460,7 @@ def _configure_autoscaler_cluster_role_binding( field_selector = f'metadata.name={name}' cr_bindings = (kubernetes.auth_api(context).list_cluster_role_binding( field_selector=field_selector).items) - if len(cr_bindings) > 0: + if cr_bindings: assert len(cr_bindings) == 1 existing_binding = cr_bindings[0] new_binding = kubernetes_utils.dict_to_k8s_object( @@ -639,7 +639,7 @@ def _configure_services(namespace: str, context: Optional[str], field_selector = f'metadata.name={name}' services = (kubernetes.core_api(context).list_namespaced_service( namespace, field_selector=field_selector).items) - if len(services) > 0: + if services: assert len(services) == 1 existing_service = services[0] # Convert to k8s object to compare diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index b16482e5072..29fcf181edd 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -230,7 +230,7 @@ def get_ingress_external_ip_and_ports( namespace, _request_timeout=kubernetes.API_TIMEOUT).items if item.metadata.name == 'ingress-nginx-controller' ] - if len(ingress_services) == 0: + if not ingress_services: return (None, None) ingress_service = ingress_services[0] diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 4c23a41161a..87ccd6b105d 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -583,7 +583,7 @@ def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType', node for node in nodes if gpu_label_key in node.metadata.labels and node.metadata.labels[gpu_label_key] == gpu_label_val ] - assert len(gpu_nodes) > 0, 'GPU nodes not found' + assert gpu_nodes, 'GPU nodes not found' if is_tpu_on_gke(acc_type): # If requested accelerator is a TPU type, check if the cluster # has sufficient TPU resource to meet the requirement. @@ -1526,7 +1526,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str], def find(l, predicate): """Utility function to find element in given list""" results = [x for x in l if predicate(x)] - return results[0] if len(results) > 0 else None + return results[0] if results else None # Get the SSH jump pod name from the head pod try: diff --git a/sky/provision/lambda_cloud/lambda_utils.py b/sky/provision/lambda_cloud/lambda_utils.py index 4d8e6246b6d..cfd8e02ad23 100644 --- a/sky/provision/lambda_cloud/lambda_utils.py +++ b/sky/provision/lambda_cloud/lambda_utils.py @@ -50,7 +50,7 @@ def set(self, instance_id: str, value: Optional[Dict[str, Any]]) -> None: if value is None: if instance_id in metadata: metadata.pop(instance_id) # del entry - if len(metadata) == 0: + if not metadata: if os.path.exists(self.path): os.remove(self.path) return @@ -69,7 +69,7 @@ def refresh(self, instance_ids: List[str]) -> None: for instance_id in list(metadata.keys()): if instance_id not in instance_ids: del metadata[instance_id] - if len(metadata) == 0: + if not metadata: os.remove(self.path) return with open(self.path, 'w', encoding='utf-8') as f: @@ -150,7 +150,7 @@ def create_instances( ['regions_with_capacity_available']) available_regions = [reg['name'] for reg in available_regions] if region not in available_regions: - if len(available_regions) > 0: + if available_regions: aval_reg = ' '.join(available_regions) else: aval_reg = 'None' diff --git a/sky/provision/oci/query_utils.py b/sky/provision/oci/query_utils.py index 47a0438cb21..8cca0629305 100644 --- a/sky/provision/oci/query_utils.py +++ b/sky/provision/oci/query_utils.py @@ -248,7 +248,7 @@ def find_compartment(cls, region) -> str: limit=1) compartments = list_compartments_response.data - if len(compartments) > 0: + if compartments: skypilot_compartment = compartments[0].id return skypilot_compartment @@ -274,7 +274,7 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]: display_name=oci_utils.oci_config.VCN_NAME, lifecycle_state='AVAILABLE') vcns = list_vcns_response.data - if len(vcns) > 0: + if vcns: # Found the VCN. skypilot_vcn = vcns[0].id list_subnets_response = net_client.list_subnets( @@ -359,7 +359,7 @@ def create_vcn_subnet(cls, net_client, if str(s.cidr_block).startswith('all-') and str(s.cidr_block). endswith('-services-in-oracle-services-network') ] - if len(services) > 0: + if services: # Create service gateway for regional services. create_sg_response = net_client.create_service_gateway( create_service_gateway_details=oci_adaptor.oci.core.models. diff --git a/sky/provision/vsphere/common/vim_utils.py b/sky/provision/vsphere/common/vim_utils.py index 33c02db8feb..bde1bc25cf0 100644 --- a/sky/provision/vsphere/common/vim_utils.py +++ b/sky/provision/vsphere/common/vim_utils.py @@ -56,7 +56,7 @@ def get_hosts_by_cluster_names(content, vcenter_name, cluster_name_dicts=None): 'name': cluster.name } for cluster in cluster_view.view] cluster_view.Destroy() - if len(cluster_name_dicts) == 0: + if not cluster_name_dicts: logger.warning(f'vCenter \'{vcenter_name}\' has no clusters') # Retrieve all cluster names from the cluster_name_dicts diff --git a/sky/provision/vsphere/instance.py b/sky/provision/vsphere/instance.py index 787d8c97f62..2075cdb9c36 100644 --- a/sky/provision/vsphere/instance.py +++ b/sky/provision/vsphere/instance.py @@ -162,7 +162,7 @@ def _create_instances( if not gpu_instance: # Find an image for CPU images_df = images_df[images_df['GpuTags'] == '\'[]\''] - if len(images_df) == 0: + if not images_df: logger.error( f'Can not find an image for instance type: {instance_type}.') raise Exception( @@ -185,7 +185,7 @@ def _create_instances( image_instance_mapping_df = image_instance_mapping_df[ image_instance_mapping_df['InstanceType'] == instance_type] - if len(image_instance_mapping_df) == 0: + if not image_instance_mapping_df: raise Exception(f"""There is no image can match instance type named {instance_type} If you are using CPU-only instance, assign an image with tag @@ -218,10 +218,9 @@ def _create_instances( hosts_df = hosts_df[(hosts_df['AvailableCPUs'] / hosts_df['cpuMhz']) >= cpus_needed] hosts_df = hosts_df[hosts_df['AvailableMemory(MB)'] >= memory_needed] - assert len(hosts_df) > 0, ( - f'There is no host available to create the instance ' - f'{vms_item["InstanceType"]}, at least {cpus_needed} ' - f'cpus and {memory_needed}MB memory are required.') + assert hosts_df, (f'There is no host available to create the instance ' + f'{vms_item["InstanceType"]}, at least {cpus_needed} ' + f'cpus and {memory_needed}MB memory are required.') # Sort the hosts df by AvailableCPUs to get the compatible host with the # least resource @@ -365,7 +364,7 @@ def _choose_vsphere_cluster_name(config: common.ProvisionConfig, region: str, skypilot framework-optimized availability_zones""" vsphere_cluster_name = None vsphere_cluster_name_str = config.provider_config['availability_zone'] - if len(vc_object.clusters) > 0: + if vc_object.clusters: for optimized_cluster_name in vsphere_cluster_name_str.split(','): if optimized_cluster_name in [ item['name'] for item in vc_object.clusters diff --git a/sky/provision/vsphere/vsphere_utils.py b/sky/provision/vsphere/vsphere_utils.py index faec5d54930..51f284b0fc6 100644 --- a/sky/provision/vsphere/vsphere_utils.py +++ b/sky/provision/vsphere/vsphere_utils.py @@ -257,7 +257,7 @@ def get_skypilot_profile_id(self): # hard code here. should support configure later. profile_name = 'skypilot_policy' storage_profile_id = None - if len(profile_ids) > 0: + if profile_ids: profiles = pm.PbmRetrieveContent(profileIds=profile_ids) for profile in profiles: if profile_name in profile.name: diff --git a/sky/resources.py b/sky/resources.py index 5184278e02e..68d1b6f9ea8 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -661,7 +661,7 @@ def _validate_and_set_region_zone(self, region: Optional[str], continue valid_clouds.append(cloud) - if len(valid_clouds) == 0: + if not valid_clouds: if len(enabled_clouds) == 1: cloud_str = f'for cloud {enabled_clouds[0]}' else: @@ -773,7 +773,7 @@ def _try_validate_instance_type(self) -> None: for cloud in enabled_clouds: if cloud.instance_type_exists(self._instance_type): valid_clouds.append(cloud) - if len(valid_clouds) == 0: + if not valid_clouds: if len(enabled_clouds) == 1: cloud_str = f'for cloud {enabled_clouds[0]}' else: @@ -1008,7 +1008,7 @@ def _try_validate_labels(self) -> None: f'Label rejected due to {cloud}: {err_msg}' ]) break - if len(invalid_table.rows) > 0: + if invalid_table.rows: with ux_utils.print_exception_no_traceback(): raise ValueError( 'The following labels are invalid:' @@ -1283,7 +1283,7 @@ def copy(self, **override) -> 'Resources': _cluster_config_overrides=override.pop( '_cluster_config_overrides', self._cluster_config_overrides), ) - assert len(override) == 0 + assert not override return resources def valid_on_region_zones(self, region: str, zones: List[str]) -> bool: diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py index a4278f192fb..7a6311ad535 100644 --- a/sky/serve/autoscalers.py +++ b/sky/serve/autoscalers.py @@ -320,8 +320,8 @@ def select_outdated_replicas_to_scale_down( """Select outdated replicas to scale down.""" if self.update_mode == serve_utils.UpdateMode.ROLLING: - latest_ready_replicas = [] - old_nonterminal_replicas = [] + latest_ready_replicas: List['replica_managers.ReplicaInfo'] = [] + old_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = [] for info in replica_infos: if info.version == self.latest_version: if info.is_ready: diff --git a/sky/serve/core.py b/sky/serve/core.py index 561314bcbe0..f71c60b2fef 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -360,7 +360,7 @@ def update( raise RuntimeError(e.error_msg) from e service_statuses = serve_utils.load_service_status(serve_status_payload) - if len(service_statuses) == 0: + if not service_statuses: with ux_utils.print_exception_no_traceback(): raise RuntimeError(f'Cannot find service {service_name!r}.' f'To spin up a service, use {ux_utils.BOLD}' @@ -491,9 +491,9 @@ def down( stopped_message='All services should have terminated.') service_names_str = ','.join(service_names) - if sum([len(service_names) > 0, all]) != 1: - argument_str = f'service_names={service_names_str}' if len( - service_names) > 0 else '' + if sum([bool(service_names), all]) != 1: + argument_str = (f'service_names={service_names_str}' + if service_names else '') argument_str += ' all' if all else '' raise ValueError('Can only specify one of service_names or all. ' f'Provided {argument_str!r}.') diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index c0e5220e779..5f92dda0e2f 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -172,7 +172,7 @@ def _get_resources_ports(task_yaml: str) -> str: """Get the resources ports used by the task.""" task = sky.Task.from_yaml(task_yaml) # Already checked all ports are the same in sky.serve.core.up - assert len(task.resources) >= 1, task + assert task.resources, task task_resources: 'resources.Resources' = list(task.resources)[0] # Already checked the resources have and only have one port # before upload the task yaml. diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 983e17d00ae..f3e8fbf1e53 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -226,7 +226,7 @@ def from_replica_statuses( for status in ReplicaStatus.failed_statuses()) > 0: return cls.FAILED # When min_replicas = 0, there is no (provisioning) replica. - if len(replica_statuses) == 0: + if not replica_statuses: return cls.NO_REPLICA return cls.REPLICA_INIT diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 7e665929d66..35d2c25ff40 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -110,7 +110,7 @@ class UpdateMode(enum.Enum): class ThreadSafeDict(Generic[KeyType, ValueType]): """A thread-safe dict.""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._dict: Dict[KeyType, ValueType] = dict(*args, **kwargs) self._lock = threading.Lock() @@ -383,7 +383,7 @@ def _get_service_status( def get_service_status_encoded(service_names: Optional[List[str]]) -> str: - service_statuses = [] + service_statuses: List[Dict[str, str]] = [] if service_names is None: # Get all service names service_names = serve_state.get_glob_service_names(None) @@ -400,7 +400,7 @@ def get_service_status_encoded(service_names: Optional[List[str]]) -> str: def load_service_status(payload: str) -> List[Dict[str, Any]]: service_statuses_encoded = common_utils.decode_payload(payload) - service_statuses = [] + service_statuses: List[Dict[str, Any]] = [] for service_status in service_statuses_encoded: service_statuses.append({ k: pickle.loads(base64.b64decode(v)) @@ -432,7 +432,7 @@ def _terminate_failed_services( A message indicating potential resource leak (if any). If no resource leak is detected, return None. """ - remaining_replica_clusters = [] + remaining_replica_clusters: List[str] = [] # The controller should have already attempted to terminate those # replicas, so we don't need to try again here. for replica_info in serve_state.get_replica_infos(service_name): @@ -459,8 +459,8 @@ def _terminate_failed_services( def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: service_names = serve_state.get_glob_service_names(service_names) - terminated_service_names = [] - messages = [] + terminated_service_names: List[str] = [] + messages: List[str] = [] for service_name in service_names: service_status = _get_service_status(service_name, with_replica_info=False) @@ -506,7 +506,7 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: f.write(UserSignal.TERMINATE.value) f.flush() terminated_service_names.append(f'{service_name!r}') - if len(terminated_service_names) == 0: + if not terminated_service_names: messages.append('No service to terminate.') else: identity_str = f'Service {terminated_service_names[0]} is' @@ -784,9 +784,9 @@ def get_endpoint(service_record: Dict[str, Any]) -> str: # Don't use backend_utils.is_controller_accessible since it is too slow. handle = global_user_state.get_handle_from_cluster_name( SKY_SERVE_CONTROLLER_NAME) - assert isinstance(handle, backends.CloudVmRayResourceHandle) if handle is None: return '-' + assert isinstance(handle, backends.CloudVmRayResourceHandle) load_balancer_port = service_record['load_balancer_port'] if load_balancer_port is None: return '-' @@ -816,7 +816,7 @@ def format_service_table(service_records: List[Dict[str, Any]], ]) service_table = log_utils.create_table(service_columns) - replica_infos = [] + replica_infos: List[Dict[str, Any]] = [] for record in service_records: for replica in record['replica_info']: replica['service_name'] = record['name'] @@ -888,7 +888,8 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], region = '-' zone = '-' - replica_handle: 'backends.CloudVmRayResourceHandle' = record['handle'] + replica_handle: Optional['backends.CloudVmRayResourceHandle'] = record[ + 'handle'] if replica_handle is not None: resources_str = resources_utils.get_readable_resources_repr( replica_handle, simplify=not show_all) diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index fbbca5bc0dd..41de54cf806 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -2,7 +2,7 @@ import json import os import textwrap -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml @@ -186,9 +186,12 @@ def from_yaml(yaml_path: str) -> 'SkyServiceSpec': return SkyServiceSpec.from_yaml_config(config['service']) def to_yaml_config(self) -> Dict[str, Any]: - config = dict() + config: Dict[str, Any] = {} - def add_if_not_none(section, key, value, no_empty: bool = False): + def add_if_not_none(section: str, + key: Optional[str], + value: Any, + no_empty: bool = False): if no_empty and not value: return if value is not None: @@ -231,8 +234,8 @@ def probe_str(self): ' with custom headers') return f'{method}{headers}' - def spot_policy_str(self): - policy_strs = [] + def spot_policy_str(self) -> str: + policy_strs: List[str] = [] if (self.dynamic_ondemand_fallback is not None and self.dynamic_ondemand_fallback): policy_strs.append('Dynamic on-demand fallback') diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index dfd8332b019..65311688fb4 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -586,7 +586,7 @@ def update_job_status(job_ids: List[int], This function should only be run on the remote instance with ray>=2.4.0. """ echo = logger.info if not silent else logger.debug - if len(job_ids) == 0: + if not job_ids: return [] statuses = [] diff --git a/sky/skylet/providers/ibm/node_provider.py b/sky/skylet/providers/ibm/node_provider.py index 5e2a2d64493..44622369e92 100644 --- a/sky/skylet/providers/ibm/node_provider.py +++ b/sky/skylet/providers/ibm/node_provider.py @@ -377,7 +377,7 @@ def non_terminated_nodes(self, tag_filters) -> List[str]: node["id"], nic_id ).get_result() floating_ips = res["floating_ips"] - if len(floating_ips) == 0: + if not floating_ips: # not adding a node that's yet/failed to # to get a floating ip provisioned continue @@ -485,7 +485,7 @@ def _get_instance_data(self, name): """Returns instance (node) information matching the specified name""" instances_data = self.ibm_vpc_client.list_instances(name=name).get_result() - if len(instances_data["instances"]) > 0: + if instances_data["instances"]: return instances_data["instances"][0] return None diff --git a/sky/skylet/providers/scp/config.py b/sky/skylet/providers/scp/config.py index c20b1837f26..d19744e7322 100644 --- a/sky/skylet/providers/scp/config.py +++ b/sky/skylet/providers/scp/config.py @@ -107,7 +107,7 @@ def get_vcp_subnets(self): for item in subnet_contents if item['subnetState'] == 'ACTIVE' and item["vpcId"] == vpc ] - if len(subnet_list) > 0: + if subnet_list: vpc_subnets[vpc] = subnet_list return vpc_subnets diff --git a/sky/skylet/providers/scp/node_provider.py b/sky/skylet/providers/scp/node_provider.py index 004eaac3830..f99b477ab06 100644 --- a/sky/skylet/providers/scp/node_provider.py +++ b/sky/skylet/providers/scp/node_provider.py @@ -259,7 +259,7 @@ def _config_security_group(self, zone_id, vpc, cluster_name): for sg in sg_contents if sg["securityGroupId"] == sg_id ] - if len(sg) != 0 and sg[0] == "ACTIVE": + if sg and sg[0] == "ACTIVE": break time.sleep(5) @@ -282,16 +282,16 @@ def _del_security_group(self, sg_id): for sg in sg_contents if sg["securityGroupId"] == sg_id ] - if len(sg) == 0: + if not sg: break def _refresh_security_group(self, vms): - if len(vms) > 0: + if vms: return # remove security group if vm does not exist keys = self.metadata.keys() security_group_id = self.metadata[ - keys[0]]['creation']['securityGroupId'] if len(keys) > 0 else None + keys[0]]['creation']['securityGroupId'] if keys else None if security_group_id: try: self._del_security_group(security_group_id) @@ -308,7 +308,7 @@ def _del_vm(self, vm_id): for vm in vm_contents if vm["virtualServerId"] == vm_id ] - if len(vms) == 0: + if not vms: break def _del_firwall_rules(self, firewall_id, rule_ids): @@ -391,7 +391,7 @@ def _create_instance_sequence(self, vpc, instance_config): return None, None, None, None def _undo_funcs(self, undo_func_list): - while len(undo_func_list) > 0: + while undo_func_list: func = undo_func_list.pop() func() @@ -468,7 +468,7 @@ def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], zone_config = ZoneConfig(self.scp_client, node_config) vpc_subnets = zone_config.get_vcp_subnets() - if (len(vpc_subnets) == 0): + if not vpc_subnets: raise SCPError("This region/zone does not have available VPCs.") instance_config = zone_config.bootstrap_instance_config(node_config) diff --git a/sky/task.py b/sky/task.py index cebc616dc6d..bd454216b0f 100644 --- a/sky/task.py +++ b/sky/task.py @@ -956,7 +956,7 @@ def sync_storage_mounts(self) -> None: }``. """ for storage in self.storage_mounts.values(): - if len(storage.stores) == 0: + if not storage.stores: store_type, store_region = self._get_preferred_store() self.storage_plans[storage] = store_type storage.add_store(store_type, store_region) diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py index 78a708efb91..6f4a80cb886 100644 --- a/sky/utils/accelerator_registry.py +++ b/sky/utils/accelerator_registry.py @@ -106,7 +106,7 @@ def canonicalize_accelerator_name(accelerator: str, return names[0] # Do not print an error message here. Optimizer will handle it. - if len(names) == 0: + if not names: return accelerator # Currently unreachable. diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 3fcdd24e505..ee8f5cf7bec 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -633,7 +633,7 @@ def get_cleaned_username(username: str = '') -> str: return username -def fill_template(template_name: str, variables: Dict, +def fill_template(template_name: str, variables: Dict[str, Any], output_path: str) -> None: """Create a file from a Jinja template and return the filename.""" assert template_name.endswith('.j2'), template_name diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 3229f86abf9..d0eb03d46ea 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -89,7 +89,7 @@ def load_chain_dag_from_yaml( elif len(configs) == 1: dag_name = configs[0].get('name') - if len(configs) == 0: + if not configs: # YAML has only `name: xxx`. Still instantiate a task. configs = [{'name': dag_name}] diff --git a/sky/utils/kubernetes/gpu_labeler.py b/sky/utils/kubernetes/gpu_labeler.py index 6877c94a2a8..9f5a11cea42 100644 --- a/sky/utils/kubernetes/gpu_labeler.py +++ b/sky/utils/kubernetes/gpu_labeler.py @@ -139,7 +139,7 @@ def label(): # Create the job for this node` batch_v1.create_namespaced_job(namespace, job_manifest) print(f'Created GPU labeler job for node {node_name}') - if len(gpu_nodes) == 0: + if not gpu_nodes: print('No GPU nodes found in the cluster. If you have GPU nodes, ' 'please ensure that they have the label ' f'`{kubernetes_utils.get_gpu_resource_key()}: `') diff --git a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py index 380c82f8c88..a764fb6e5e4 100644 --- a/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +++ b/sky/utils/kubernetes/ssh_jump_lifecycle_manager.py @@ -126,7 +126,7 @@ def manage_lifecycle(): f'error: {e}\n') raise - if len(ret.items) == 0: + if not ret.items: sys.stdout.write( f'[Lifecycle] Did not find pods with label ' f'"{label_selector}" in namespace {current_namespace}\n') diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index 7d304b60633..a9fad1b4b83 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -96,8 +96,8 @@ def test_empty_fields_storage(tmp_path): storage = task.storage_mounts['/mystorage'] assert storage.name == 'sky-dataset' assert storage.source is None - assert len(storage.stores) == 0 - assert storage.persistent is True + assert not storage.stores + assert storage.persistent def test_invalid_fields_storage(tmp_path): diff --git a/tests/unit_tests/test_storage_utils.py b/tests/unit_tests/test_storage_utils.py index cd1e436390b..6edb5abf2f5 100644 --- a/tests/unit_tests/test_storage_utils.py +++ b/tests/unit_tests/test_storage_utils.py @@ -7,7 +7,7 @@ def test_get_excluded_files_from_skyignore_no_file(): excluded_files = storage_utils.get_excluded_files_from_skyignore('.') - assert len(excluded_files) == 0 + assert not excluded_files def test_get_excluded_files_from_skyignore():