Skip to content

Commit

Permalink
[util,scramble_image] Remove hard coded memory scrambling parameters
Browse files Browse the repository at this point in the history
In order to remove the hard-coded, top-specific memory scrambling
parameters, we search for a memory controller based on the provided top
HJSON file and the scrambling mode.

Signed-off-by: Samuel Ortiz <[email protected]>
  • Loading branch information
sameo committed Feb 18, 2025
1 parent 3c7ce67 commit a4eea4c
Showing 1 changed file with 122 additions and 132 deletions.
254 changes: 122 additions & 132 deletions hw/ip/rom_ctrl/util/scramble_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
'''Script for scrambling a ROM image'''

import argparse
from enum import Enum
import sys
from typing import Dict, List, IO
from typing import Dict, IO, Optional

import hjson # type: ignore
from Crypto.Hash import cSHAKE256
Expand All @@ -17,54 +18,133 @@
from util.design.secded_gen import load_secded_config


class MemCtrlParams:
class ScramblingMode(Enum):
ROM0 = "base-rom"
ROM1 = "second-rom"
SRAM = "sram"


_UDict = Dict[object, object]


class MemoryController:
def __init__(
self, ctrl_name: str,
memory_type: str,
module_type: str,
scr_key: str, scr_key_width: int,
nonce: str, nonce_width: int
):
'''A memory controller parameters constructor
self, mem_ctrl: _UDict, ctrl_name: str, memory_type: str, mode: str):
'''A memory controller constructor
@mem_ctrl The memory controller hjson dictionary.
@ctrl_name is the memory controller IP block name (e.g. 'rom_ctrl0' for the base ROM)
@memory_type is the memory type this memory controller run ('rom' or 'ram')
@module_type is the HJSON module type for this memory controller (e.g. 'rom_ctrl')
@scr_key is the scrambling key entry name in the HJSON file
@scr_key_width is the expected scrambling key width
@nonce is the nonce entry name in the HJSON file
@nonce_width is the expected nonce width
@mode is the memory scrambling mode
'''

memory_type, key_name, nonce_name = {
ScramblingMode.ROM0.value: ["rom", "RndCnstScrKey", "RndCnstScrNonce"],
ScramblingMode.ROM1.value: ["rom", "RndCnstScrKey", "RndCnstScrNonce"],
ScramblingMode.SRAM.value: ["ram", "RndCnstSramKey", "RndCnstSramNonce"],
}[mode]

assert memory_type is not None
assert key_name is not None
assert nonce_name is not None

size_words = MemoryController._get_size_words(mem_ctrl, memory_type)
base = MemoryController._get_base(mem_ctrl, memory_type)
params = MemoryController._get_params(mem_ctrl)
nonce, nonce_width = MemoryController._get_param_cnst(params, nonce_name)
scr_key, scr_key_width = MemoryController._get_param_cnst(params, key_name)

self.ctrl_name = ctrl_name
self.memory_type = memory_type
self.module_type = module_type
self.base = base
self.size_words = size_words
self.scr_key = scr_key
self.scr_key_width = scr_key_width
self.nonce = nonce
self.nonce_width = nonce_width

@staticmethod
def _get_params(module: _UDict) -> Dict[str, _UDict]:
params = module.get('param_list')
assert isinstance(params, list)

named_params = {} # type: Dict[str, _UDict]
for param in params:
name = param.get('name')
assert isinstance(name, str)
assert name not in named_params
named_params[name] = param

return named_params

@staticmethod
def _get_param_cnst(params: Dict[str, _UDict], name: str) -> tuple[int, int]:
param = params.get(name)
assert isinstance(param, dict)

default = param.get("default")
assert isinstance(default, str)
val = int(default, 0)

width = param.get("randwidth")
assert isinstance(width, int)

assert 0 <= val < (1 << width)

return val, width

@staticmethod
def _get_size_words(module: _UDict, memory_type: str) -> int:
memory = module.get("memory")
assert isinstance(memory, dict)
memory = memory.get(memory_type)
assert isinstance(memory, dict)
size_words_bytes_str = memory.get("size")
assert isinstance(size_words_bytes_str, str)
size_words_bytes = int(size_words_bytes_str, 16)
assert size_words_bytes % 4 == 0
return size_words_bytes // 4

@staticmethod
def _get_base(module: _UDict, memory_type: str) -> int:
base = module.get("base_addrs")
assert isinstance(base, dict)
base_addr_rom = base.get(memory_type)
assert isinstance(base_addr_rom, dict)
base_addr = base_addr_rom.get("hart")
assert isinstance(base_addr, str)
return int(base_addr, 16)

@staticmethod
def from_hjson_path(path: str, mode: str) -> Optional["MemoryController"]:
with open(path, "r", encoding='utf-8') as handle:
top = hjson.load(handle, use_decimal=True)

assert isinstance(top, dict)
modules = top.get('module')
assert isinstance(modules, list)

for entry in modules:
assert isinstance(entry, dict)
entry_type = entry.get('type')
assert isinstance(entry_type, str)
entry_name = entry.get('name')
assert isinstance(entry_name, str)

if mode == ScramblingMode.ROM0.value:
# Earlgrey has only one ROM, named `rom_ctrl`, while Darjeeling's
# first ROM is named `rom_ctrl0`
if entry_name in ("rom_ctrl", "rom_ctrl0"):
if entry_type == "rom_ctrl":
return MemoryController(entry, entry_name, entry_type, mode)
elif mode == ScramblingMode.ROM1.value:
if entry_name == "rom_ctrl1" and entry_type == "rom_ctrl":
return MemoryController(entry, entry_name, entry_type, mode)
elif mode == ScramblingMode.SRAM.value:
if entry_name == "sram_ctrl_main" and entry_type == "sram_ctrl":
return MemoryController(entry, entry_name, entry_type, mode)

return None

MEM_CTRL_PARAMS = {
"earlgrey": {
'base-rom': MemCtrlParams('rom_ctrl',
'rom', 'rom_ctrl',
'RndCnstScrKey', 128,
'RndCnstScrNonce', 64),
},
"darjeeling": {
'base-rom': MemCtrlParams('rom_ctrl0',
'rom', 'rom_ctrl',
'RndCnstScrKey', 128,
'RndCnstScrNonce', 64),
'second-rom': MemCtrlParams('rom_ctrl1',
'rom', 'rom_ctrl',
'RndCnstScrKey', 128,
'RndCnstScrNonce', 64),
'sram': MemCtrlParams('sram_ctrl_main',
'ram', 'sram_ctrl',
'RndCnstSramKey', 128,
'RndCnstSramNonce', 128),
}
}

PRESENT_SBOX4 = [
0xc, 0x5, 0x6, 0xb,
Expand All @@ -80,8 +160,6 @@ def __init__(
0x0, 0x7, 0x9, 0xa
]

_UDict = Dict[object, object]


def subst_perm_enc(data: int, key: int, width: int, num_rounds: int) -> int:
'''A model of prim_subst_perm in encrypt mode'''
Expand Down Expand Up @@ -179,102 +257,14 @@ def __init__(self, nonce: int, nonce_width: int,

self._addr_width = (rom_size_words - 1).bit_length()

@staticmethod
def _get_mem_ctrl(modules: List[object], type: str, name: str) -> _UDict:
mem_ctrls = [] # type: List[_UDict]
for entry in modules:
assert isinstance(entry, dict)
entry_type = entry.get('type')
assert isinstance(entry_type, str)
entry_name = entry.get('name')
assert isinstance(entry_name, str)

if entry_type == type and entry_name == name:
mem_ctrls.append(entry)

assert len(mem_ctrls) == 1
return mem_ctrls[0]

@staticmethod
def _get_params(module: _UDict) -> Dict[str, _UDict]:
params = module.get('param_list')
assert isinstance(params, list)

named_params = {} # type: Dict[str, _UDict]
for param in params:
name = param.get('name')
assert isinstance(name, str)
assert name not in named_params
named_params[name] = param

return named_params

@staticmethod
def _get_param_value(params: Dict[str, _UDict], name: str,
width: int) -> int:
param = params.get(name)
assert isinstance(param, dict)

default = param.get('default')
assert isinstance(default, str)
int_val = int(default, 0)
assert 0 <= int_val < (1 << width)
return int_val

@staticmethod
def _get_size_words(module: _UDict, memory_type: str) -> int:
memory = module.get("memory")
assert isinstance(memory, dict)
memory = memory.get(memory_type)
assert isinstance(memory, dict)
size_words_bytes_str = memory.get("size")
assert isinstance(size_words_bytes_str, str)
size_words_bytes = int(size_words_bytes_str, 16)
assert size_words_bytes % 4 == 0
return size_words_bytes // 4

@staticmethod
def _get_base(module: _UDict, memory_type: str) -> int:
base = module.get("base_addrs")
assert isinstance(base, dict)
base_addr_rom = base.get(memory_type)
assert isinstance(base_addr_rom, dict)
base_addr = base_addr_rom.get("hart")
assert isinstance(base_addr, str)
return int(base_addr, 16)

@staticmethod
def from_hjson_path(path: str, mode: str, hash_file: IO[str]) -> 'Scrambler':
with open(path) as handle:
top = hjson.load(handle, use_decimal=True)

assert isinstance(top, dict)
modules = top.get('module')
assert isinstance(modules, list)
mem_ctrl = MemoryController.from_hjson_path(path, mode)
assert mem_ctrl is not None

print(top["name"])
assert top["name"] in MEM_CTRL_PARAMS, "top {} is not supported".format(top["name"])
mem_ctrl_params = MEM_CTRL_PARAMS[top["name"]]
assert mode in mem_ctrl_params, \
"mode {} is not supported in top {}".format(mode, top["name"])
mc_params = mem_ctrl_params[mode]

mem_ctrl = Scrambler._get_mem_ctrl(modules,
mc_params.module_type,
mc_params.ctrl_name)
size_words = Scrambler._get_size_words(mem_ctrl, mc_params.memory_type)
base = Scrambler._get_base(mem_ctrl, mc_params.memory_type)
params = Scrambler._get_params(mem_ctrl)
nonce = Scrambler._get_param_value(params,
mc_params.nonce,
mc_params.nonce_width)
key = Scrambler._get_param_value(params,
mc_params.scr_key,
mc_params.scr_key_width)

return Scrambler(nonce, mc_params.nonce_width,
key, mc_params.scr_key_width,
base, size_words, hash_file)
return Scrambler(mem_ctrl.nonce, mem_ctrl.nonce_width,
mem_ctrl.scr_key, mem_ctrl.scr_key_width,
mem_ctrl.base, mem_ctrl.size_words, hash_file)

def flatten(self, mem: MemFile) -> MemFile:
'''Flatten and pad mem up to the correct size
Expand Down

0 comments on commit a4eea4c

Please sign in to comment.