diff --git a/src/app/ota_image_tool.py b/src/app/ota_image_tool.py index 45fbe6bac9dddb..7e14ed19c97759 100755 --- a/src/app/ota_image_tool.py +++ b/src/app/ota_image_tool.py @@ -60,6 +60,8 @@ DIGEST_ALL_ALGORITHMS = hashlib.algorithms_available.intersection( DIGEST_ALGORITHM_ID.keys()) +# Buffer size used for file reads to ensure large files do not need to be loaded +# into memory fully before processing. PAYLOAD_BUFFER_SIZE = 16 * 1024 @@ -76,6 +78,46 @@ class HeaderTag(IntEnum): DIGEST = 9 +def warn(message: str): + sys.stderr.write(f'warning: {message}\n') + + +def error(message: str): + sys.stderr.write(f'error: {message}\n') + sys.exit(1) + + +def validate_header_attributes(args: object): + """ + Validate attributes to be stored in OTA image header + """ + + if args.vendor_id == 0: + error('Vendor ID is zero') + + if args.product_id == 0: + error('Product ID is zero') + + if not 1 <= len(args.version_str) <= 64: + error('Software version string is not of length 1-64') + + if args.min_version is not None and args.min_version >= args.version: + error('Minimum applicable version is greater or equal to software version') + + if args.max_version is not None and args.max_version >= args.version: + error('Maximum applicable version is greater or equal to software version') + + if args.min_version is not None and args.max_version is not None and args.max_version < args.min_version: + error('Minimum applicable version is greater than maximum applicable version') + + if args.release_notes is not None: + if not 1 <= len(args.release_notes) <= 256: + error('Release notes URL must be of length 1-256') + + if not args.release_notes.startswith('https://'): + warn('Release notes URL does not start with "https://"') + + def generate_payload_summary(args: object): """ Calculate total size and hash of all concatenated input payload files @@ -84,6 +126,9 @@ def generate_payload_summary(args: object): total_size = 0 digest = hashlib.new(args.digest_algorithm) + if digest.digest_size < (256 // 8): + warn('Using digest length below 256 bits is not recommended') + for path in args.input_files: with open(path, 'rb') as file: while True: @@ -111,14 +156,14 @@ def generate_header_tlv(args: object, payload_size: int, payload_digest: bytes): HeaderTag.DIGEST: payload_digest, } - if args.min_version: + if args.min_version is not None: fields.update({HeaderTag.MIN_VERSION: uint(args.min_version)}) - if args.max_version: + if args.max_version is not None: fields.update({HeaderTag.MAX_VERSION: uint(args.max_version)}) - if args.release_notes: - fields.append({HeaderTag.RELEASE_NOTES_URL: args.release_notes}) + if args.release_notes is not None: + fields.update({HeaderTag.RELEASE_NOTES_URL: args.release_notes}) writer = TLVWriter() writer.put(None, fields) @@ -242,6 +287,7 @@ def any_base_int(s): return int(s, 0) args = parser.parse_args() if args.subcommand == 'create': + validate_header_attributes(args) generate_image(args) elif args.subcommand == 'show': show_header(args)