diff --git a/sky/usage/usage_lib.py b/sky/usage/usage_lib.py index 7b43aa9464d4..df6842870e29 100644 --- a/sky/usage/usage_lib.py +++ b/sky/usage/usage_lib.py @@ -35,9 +35,7 @@ def _get_current_timestamp_ns() -> int: def _get_user_hash(): """Returns a unique user-machine specific hash as a user id for logging.""" user_id = os.getenv(constants.USAGE_USER_ENV) - if user_id and len(user_id) == 8: - return user_id - return common_utils.get_user_hash() + return common_utils.get_user_hash(default_value=user_id) class MessageType(enum.Enum): diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 48075be17c58..fe320800a7ea 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -18,6 +18,7 @@ from sky import sky_logging _USER_HASH_FILE = os.path.expanduser('~/.sky/user_hash') +USER_HASH_LENGTH = 8 _PAYLOAD_PATTERN = re.compile(r'(.*)') _PAYLOAD_STR = '{}' @@ -40,14 +41,37 @@ def get_usage_run_id() -> str: return _usage_run_id -def get_user_hash() -> str: - """Returns a unique user-machine specific hash as a user id.""" +def get_user_hash(default_value: Optional[str] = None) -> str: + """Returns a unique user-machine specific hash as a user id. + + We cache the user hash in a file to avoid potential user_name or + hostname changes causing a new user hash to be generated. + """ + + def _is_valid_user_hash(user_hash: Optional[str]) -> bool: + try: + int(user_hash, 16) + except (TypeError, ValueError): + return False + return len(user_hash) == USER_HASH_LENGTH + + user_hash = default_value + if _is_valid_user_hash(user_hash): + return user_hash + if os.path.exists(_USER_HASH_FILE): + # Read from cached user hash file. with open(_USER_HASH_FILE, 'r') as f: - return f.read() + # Remove invalid characters. + user_hash = f.read().strip() + if _is_valid_user_hash(user_hash): + return user_hash hash_str = user_and_hostname_hash() - user_hash = hashlib.md5(hash_str.encode()).hexdigest()[:8] + user_hash = hashlib.md5(hash_str.encode()).hexdigest()[:USER_HASH_LENGTH] + if not _is_valid_user_hash(user_hash): + # A fallback in case the hash is invalid. + user_hash = uuid.uuid4().hex[:USER_HASH_LENGTH] os.makedirs(os.path.dirname(_USER_HASH_FILE), exist_ok=True) with open(_USER_HASH_FILE, 'w') as f: f.write(user_hash)