Skip to content

Commit

Permalink
i#1569 AArch64: Reimplement encoder and decoder.
Browse files Browse the repository at this point in the history
In the new implementation there are encoder/decoder functions for each
operand type, rather than for each set of operands. A significant
proportion of A64 is now handled, including all loads and stores.

Review-URL: https://codereview.appspot.com/305320043
  • Loading branch information
egrimley-arm committed Sep 5, 2016
1 parent bc6588d commit 521301a
Show file tree
Hide file tree
Showing 9 changed files with 16,192 additions and 1,001 deletions.
2,345 changes: 1,815 additions & 530 deletions core/arch/aarch64/codec.c

Large diffs are not rendered by default.

309 changes: 223 additions & 86 deletions core/arch/aarch64/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,54 +30,68 @@
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

# This script reads "codec.txt" and generates "decode_gen.h", "encode_gen.h",
# "opcode.h" and "opcode_names.h". Run it manually, in this directory, when
# "codec.txt" has been changed.

import re

N = 32 # bits in an instruction word

header = '/* This file was generated by codec.py from codec.txt. */\n\n'

def encoding_to_str(enc):
    """Format one encoding record as a single line for diagnostics.

    enc is a 4-item sequence: two integers, rendered as 8-digit lower-case
    hex, followed by two items rendered with %s.  Any sequence is accepted
    (not only a tuple), because it is converted to a tuple before being
    handed to the % operator.
    """
    return '%08x %08x %s %s' % tuple(enc)

def check(encodings):
    """Validate a list of 4-item encoding records.

    Each record is indexed as (item0, item1, item2, item3), where item0 and
    item1 are integers and item1 is treated as a mask.  Three rules are
    enforced:

    * item0 and item1 of a record must not share any set bit;
    * all records with the same item3 must have the same item1 (mask);
    * no two records may agree on every bit that is outside both masks.

    Raises Exception, with the offending record(s) in the message, on the
    first violation found.  Returns None when every record passes.
    """
    mask_by_key = dict()
    for i, enc in enumerate(encodings):
        bits, mask, key = enc[0], enc[1], enc[3]
        if bits & mask:
            raise Exception('Bad encoding: %s' % encoding_to_str(enc))
        if key in mask_by_key:
            if mask != mask_by_key[key]:
                raise Exception('Inconsistent mask: %s' % encoding_to_str(enc))
        else:
            mask_by_key[key] = mask
        # Compare against every earlier record: two records overlap when
        # they do not differ in any bit that both of them treat as fixed.
        for j in range(i):
            earlier = encodings[j]
            if (earlier[0] ^ bits) & ~earlier[1] & ~mask == 0:
                raise Exception('Overlapping encodings:\n%s\n%s' %
                                (encoding_to_str(earlier),
                                 encoding_to_str(enc)))

def generate_decoder(encodings):
def gen(c, encs, depth):
def generate_decoder(patterns, opndsgen, opndtypes):

# Function to generate decoder for one opndset.
# For every name in opndsgen (visited in sorted order) this appends to the
# list c the lines of one C function called decode_opnds<name>.
# opndsgen maps a name to a (dsts, srcs) pair of operand-type name lists.
# opndtypes is not a parameter here; it is looked up in the surrounding scope
# and maps an operand-type name to an integer bit mask.
def gen1(c, opndsgen):
for name in sorted(opndsgen):
(dsts, srcs) = opndsgen[name]
f = 0xffffffff # bits not handled by any operand
# Clear from f every bit that belongs to one of this set's operand types.
for x in dsts + srcs:
f &= ~opndtypes[x]
c += ['static bool',
('decode_opnds%s' % name) + '(uint enc, dcontext_t *dcontext, ' +
'byte *pc, instr_t *instr, int opcode)',
'{']
# Operand-less sets emit no locals and no decode_opnd_* calls.
if dsts + srcs != []:
vars = (['dst%d' % i for i in range(len(dsts))] +
['src%d' % i for i in range(len(srcs))])
# One negated decode_opnd_<type> call per operand.  Each call receives enc
# masked down to that operand's own bits plus the bits in f.
tests = (['!decode_opnd_%s(enc & 0x%08x, opcode, pc, &dst%d)' %
(dsts[i], f | opndtypes[dsts[i]], i) for i in range(len(dsts))]
+
['!decode_opnd_%s(enc & 0x%08x, opcode, pc, &src%d)' %
(srcs[i], f | opndtypes[srcs[i]], i) for i in range(len(srcs))])
c += [' opnd_t ' + ', '.join(vars) + ';']
# The tests are joined with ||, so the emitted function returns false as
# soon as any single operand decoder returns false.
c += [' if (' + ' ||\n '.join(tests) + ')']
c += [' return false;']
# Emitted success path: set the opcode, size the operand arrays, then store
# each decoded destination and source at its index.
c.append(' instr_set_opcode(instr, opcode);')
c.append(' instr_set_num_opnds(dcontext, instr, %d, %d);' %
(len(dsts), len(srcs)))
for i in range(len(dsts)):
c.append(' instr_set_dst(instr, %d, dst%d);' % (i, i))
for i in range(len(srcs)):
c.append(' instr_set_src(instr, %d, src%d);' % (i, i))
c.append(' return true;')
c.append('}')
# Empty string: becomes a blank line between emitted functions once the
# caller joins c with newlines.
c.append('')

# Recursive function to generate nested conditionals in main decoder.
def gen(c, pats, depth):
indent = " " * depth
if len(encs) < 4:
for (f, v, m, t) in sorted(encs, key = lambda (f, v, m, t): (m, t, f, v)):
if len(pats) < 4:
for (f, v, m, t) in sorted(pats, key = lambda (f, v, m, t): (m, t, f, v)):
c.append('%sif ((enc & 0x%08x) == 0x%08x)' %
(indent, ((1 << N) - 1) & ~v, f))
c.append('%s return decode_%s(enc, dc, pc, instr, OP_%s);' %
c.append('%s return decode_opnds%s(enc, dc, pc, instr, OP_%s);' %
(indent, t, m))
return
# Look for best bit to test. We aim to reduce the number of patterns remaining.
best_b = -1
best_x = len(encs)
best_x = len(pats)
for b in range(N):
x0 = 0
x1 = 0
for (f, v, _, _) in encs:
for (f, v, _, _) in pats:
if (1 << b) & (~f | v):
x0 += 1
if (1 << b) & (f | v):
Expand All @@ -87,59 +101,102 @@ def gen(c, encs, depth):
best_b = b
best_x = x
c.append('%sif ((enc >> %d & 1) == 0) {' % (indent, best_b))
encs0 = []
encs1 = []
for e in encs:
(f, v, _, _) = e
pats0 = []
pats1 = []
for p in pats:
(f, v, _, _) = p
if (1 << best_b) & (~f | v):
encs0.append(e)
pats0.append(p)
if (1 << best_b) & (f | v):
encs1.append(e)
gen(c, encs0, depth + 1)
pats1.append(p)
gen(c, pats0, depth + 1)
c.append('%s} else {' % indent)
gen(c, encs1, depth + 1)
gen(c, pats1, depth + 1)
c.append('%s}' % indent)

c = ['static bool',
'decoder(uint enc, dcontext_t *dc, byte *pc, instr_t *instr)',
'{']
gen(c, encodings, 1)
c = []
gen1(c, opndsgen)
c += ['static bool',
'decoder(uint enc, dcontext_t *dc, byte *pc, instr_t *instr)',
'{']
gen(c, patterns, 1)
c.append(' return false;')
c.append('}')
return '\n'.join(c) + '\n'

def generate_encoder(patterns, opndsgen, opndtypes):
    """Return the C source text of the encoder as a single string.

    patterns  -- list of (opcode_bits, opnd_bits, opcode, opndset_name)
                 tuples; opndset_name is the suffix of the encode_opnds*
                 function that handles the pattern.
    opndsgen  -- dict mapping a generated opndset name to (dsts, srcs), two
                 lists of operand-type names.
    opndtypes -- dict mapping an operand-type name to its integer bit mask.

    The output consists of one "static uint encode_opnds<name>" function per
    entry of opndsgen (in sorted name order), followed by the "encoder"
    function, which switches on instr->opcode and tries each pattern of that
    opcode in sorted order, returning the first result that is not ENCFAIL.
    """
    # NOTE(review): the whitespace inside the emitted C fragments is kept
    # exactly as found in this copy of the script; confirm it against the
    # checked-in generated headers before regenerating them.
    c = []
    for name in sorted(opndsgen):
        (dsts, srcs) = opndsgen[name]
        f = 0xffffffff  # bits not handled by any operand
        for x in dsts + srcs:
            f &= ~opndtypes[x]
        c += ['static uint',
              ('encode_opnds%s' % name) + '(byte *pc, instr_t *instr, uint enc)',
              '{']
        if not (dsts + srcs):
            # No operands: the fixed opcode bits are the whole encoding.
            c.append(' return enc;')
        else:
            names = (['dst%d' % i for i in range(len(dsts))] +
                     ['src%d' % i for i in range(len(srcs))])
            c += [' int opcode = instr->opcode;']
            c += [' uint ' + ', '.join(names) + ';']
            # First the operand counts must match, then every operand must
            # be encodable by its encode_opnd_<type> function.
            tests = (['instr_num_dsts(instr) == %d && instr_num_srcs(instr) == %d' %
                      (len(dsts), len(srcs))] +
                     ['encode_opnd_%s(enc & 0x%08x, opcode, '
                      'pc, instr_get_dst(instr, %d), &dst%d)' %
                      (ot, f | opndtypes[ot], i, i)
                      for (i, ot) in enumerate(dsts)] +
                     ['encode_opnd_%s(enc & 0x%08x, opcode, '
                      'pc, instr_get_src(instr, %d), &src%d)' %
                      (ot, f | opndtypes[ot], i, i)
                      for (i, ot) in enumerate(srcs)])
            # After OR-ing everything together, each operand's bits in the
            # result must still equal what that operand produced.
            tests2 = (['dst%d == (enc & 0x%08x)' % (i, opndtypes[ot])
                       for (i, ot) in enumerate(dsts)] +
                      ['src%d == (enc & 0x%08x)' % (i, opndtypes[ot])
                       for (i, ot) in enumerate(srcs)])
            c += [' if (' + ' &&\n '.join(tests) + ') {']
            c += [' ASSERT((dst%d & 0x%08x) == 0);' %
                  (i, 0xffffffff & ~opndtypes[ot])
                  for (i, ot) in enumerate(dsts)]
            c += [' ASSERT((src%d & 0x%08x) == 0);' %
                  (i, 0xffffffff & ~opndtypes[ot])
                  for (i, ot) in enumerate(srcs)]
            c += [' enc |= ' + ' | '.join(names) + ';']
            c += [' if (' + ' &&\n '.join(tests2) + ')']
            c += [' return enc;']
            c += [' }']
            c += [' return ENCFAIL;']
        c.append('}')
        c.append('')

    # Group the patterns by opcode name.
    case = dict()
    for p in patterns:
        case.setdefault(p[2], []).append(p)

    c += ['static uint',
          'encoder(byte *pc, instr_t *instr)',
          '{',
          ' uint enc;',
          ' (void)enc;',
          ' switch (instr->opcode) {']
    for mn in sorted(case):
        c.append(' case OP_%s:' % (mn))
        # Order by (opcode, opndset name, opcode_bits, opnd_bits).  An
        # index-based key is used so that this works on Python 2 and 3.
        pats = sorted(case[mn], key=lambda p: (p[2], p[3], p[0], p[1]))
        # The last pattern's result is returned unconditionally; all earlier
        # ones are returned only when they did not fail.
        last = pats.pop()
        for (b, _, _, f) in pats:
            c.append(' if ((enc = encode_opnds%s(pc, instr, 0x%08x)) != ENCFAIL)' %
                     (f, b))
            c.append(' return enc;')
        (b, _, _, f) = last
        c.append(' return encode_opnds%s(pc, instr, 0x%08x);' % (f, b))
    c += [' }',
          ' return ENCFAIL;',
          '}']
    return '\n'.join(c) + '\n'

def generate_opcodes(encodings):
def generate_opcodes(patterns):
mns = dict()
for e in encodings:
mns[e[2]] = 1
for p in patterns:
mns[p[2]] = 1
c = ['#ifndef OPCODE_H',
'#define OPCODE_H 1',
'',
Expand Down Expand Up @@ -191,10 +248,10 @@ def generate_opcodes(encodings):
'#endif /* OPCODE_H */']
return '\n'.join(c) + '\n'

def generate_opcode_names(encodings):
def generate_opcode_names(patterns):
mns = dict()
for e in encodings:
mns[e[2]] = 1
for p in patterns:
mns[p[2]] = 1
c = ['#ifndef OPCODE_NAMES_H',
'#define OPCODE_NAMES_H 1',
'',
Expand All @@ -221,30 +278,110 @@ def write_if_changed(file, data):
pass
open(file, 'w').write(data)

def _pattern_bits(pattern):
    """Split a 32-character 0/1/x pattern into (opcode_bits, opnd_bits).

    opcode_bits has a 1 wherever the pattern has '1'; opnd_bits has a 1
    wherever the pattern has 'x'.
    """
    opcode_bits = int(re.sub("x", "0", pattern), 2)
    opnd_bits = int(re.sub("x", "1", re.sub("1", "0", pattern)), 2)
    return (opcode_bits, opnd_bits)

def read_file(path):
    """Parse the codec description file at path.

    Comments (from '#' to end of line), trailing white space and empty
    lines are ignored.  Every other line must have one of three forms:

        mask opndtype                      mask is 32 chars of 'x' / '-'
        pattern opcode opndtype* : opndtype*
        pattern opcode opndset             pattern is 32 chars of 0 / 1 / x

    Returns (patterns, opndtypes).  opndtypes maps each operand-type name to
    its mask as an integer ('x' -> 1, '-' -> 0).  patterns is a list of
    (opcode_bits, opnd_bits, opcode, opndset) tuples in file order, where
    opndset is either the opndset name (third form) or a (dsts, srcs) pair
    of operand-type name lists (second form).

    Raises Exception if an operand type is defined twice or if a line does
    not match any of the three forms.
    """
    opndtypes = dict()
    patterns = []
    # The with-statement closes the file even when a parse error is raised.
    with open(path, 'r') as infile:
        for line in infile:
            # Remove comment and trailing spaces.
            line = re.sub(r"\s*(#.*)?\n?$", "", line)
            if line == '':
                continue
            if re.match(r"^[x-]{32} +[a-zA-Z_0-9]+$", line):
                # Syntax: mask opndtype
                (mask, opndtype) = line.split()
                if opndtype in opndtypes:
                    raise Exception('Repeated definition of opndtype %s' % opndtype)
                opndtypes[opndtype] = int(re.sub("x", "1", re.sub("-", "0", mask)), 2)
                continue
            if re.match(r"^[01x]{32} +[a-zA-Z_0-9][a-zA-Z_0-9 ]*:[a-zA-Z_0-9 ]*$", line):
                # Syntax: pattern opcode opndtype* : opndtype*
                (str1, str2) = line.split(":")
                (words, srcs) = (str1.split(), str2.split())
                (pattern, opcode, dsts) = (words[0], words[1], words[2:])
                (opcode_bits, opnd_bits) = _pattern_bits(pattern)
                patterns.append((opcode_bits, opnd_bits, opcode, (dsts, srcs)))
                continue
            # The trailing '$' makes a line with extra words fall through to
            # the "Cannot parse line" error instead of failing in the unpack.
            if re.match(r"^[01x]{32} +[a-zA-Z_0-9]+ +[a-zA-Z_0-9]+$", line):
                # Syntax: pattern opcode opndset
                (pattern, opcode, opndset) = line.split()
                (opcode_bits, opnd_bits) = _pattern_bits(pattern)
                patterns.append((opcode_bits, opnd_bits, opcode, opndset))
                continue
            raise Exception('Cannot parse line: %s' % line)
    return (patterns, opndtypes)

def pattern_to_str(pattern):
    """Render a pattern tuple back into the textual form used for messages.

    pattern is (opcode_bits, opnd_bits, opcode, opndset).  The result is the
    N-character bit string (most significant bit first; 'x' where opnd_bits
    is set, otherwise the corresponding bit of opcode_bits), then the opcode,
    then the operand description: the opndset itself when it is a string, or
    "dsts : srcs" when it is a (dsts, srcs) pair of name lists.

    The tuple is unpacked in the body rather than in the signature so that
    the function is valid on both Python 2 and Python 3.
    """
    (opcode_bits, opnd_bits, opcode, opndset) = pattern
    bits = ''.join('x' if (opnd_bits >> i & 1) else '%d' % (opcode_bits >> i & 1)
                   for i in range(N - 1, -1, -1))
    if isinstance(opndset, str):
        opnds = opndset
    else:
        (dsts, srcs) = opndset
        opnds = ' '.join(dsts) + ' : ' + ' '.join(srcs)
    return '%s %s %s' % (bits, opcode, opnds)

def consistency_check(patterns, opndtypes):
    """Check the parsed patterns for internal consistency.

    For every pattern whose opndset is a (dsts, srcs) pair, each operand
    type must be defined in opndtypes and the operand masks together must
    cover every bit of the pattern's opnd_bits.  Patterns whose opndset is a
    string are not checked in that way.  In addition, no two patterns may
    agree on all bits that are fixed in both of them.

    Raises Exception describing the first problem found; returns None when
    there is none.
    """
    for pat in patterns:
        (opcode_bits, opnd_bits, opcode, opndset) = pat
        if isinstance(opndset, str):
            continue
        (dsts, srcs) = opndset
        remaining = opnd_bits
        for opndtype in dsts + srcs:
            if opndtype not in opndtypes:
                raise Exception('Undefined opndtype %s in:\n%s' %
                                (opndtype, pattern_to_str(pat)))
            remaining &= ~opndtypes[opndtype]
        if remaining != 0:
            # Show the uncovered bits as 'x' marks, right-aligned to 32 columns.
            marks = re.sub('1', 'x', re.sub('0', ' ', bin(remaining)[2:]))
            raise Exception('Unhandled bits:\n%32s in:\n%s' %
                            (marks, pattern_to_str(pat)))
    for (i, later) in enumerate(patterns):
        for j in range(i):
            earlier = patterns[j]
            differing = earlier[0] ^ later[0]
            if differing & ~earlier[1] & ~later[1] == 0:
                raise Exception('Overlapping patterns:\n%s\n%s' %
                                (pattern_to_str(earlier),
                                 pattern_to_str(later)))

# Here we give the opndsets names, which will be used in function names.
# We use the hex representation of the smallest pattern.
def _opndset_key(dsts, srcs):
    """Return the string that identifies a (dsts, srcs) operand set."""
    return ' '.join(dsts) + ':' + ' '.join(srcs)

def opndset_naming(patterns):
    """Replace every pattern's opndset by the name used in generated code.

    patterns is a list of (opcode_bits, opnd_bits, opcode, opndset) tuples.
    An opndset given as a string is renamed to '_' + that string.  An
    opndset given as a (dsts, srcs) pair is renamed to 'gen_%08x' of the
    smallest opcode_bits among all patterns that use the same pair.

    Returns (new_patterns, opndsgen): new_patterns has the same order as
    patterns with the opndset replaced by its name, and opndsgen maps each
    generated 'gen_...' name to its (dsts, srcs) pair.
    """
    # Maps the operand-set key to the smallest opcode_bits seen so far.
    smallest = dict()
    for (opcode_bits, opnd_bits, opcode, opndset) in patterns:
        if isinstance(opndset, str):
            continue
        (dsts, srcs) = opndset
        key = _opndset_key(dsts, srcs)
        if key not in smallest or opcode_bits < smallest[key]:
            smallest[key] = opcode_bits

    opndsgen = dict()  # maps generated name to original opndsets
    new_patterns = []
    for (opcode_bits, opnd_bits, opcode, opndset) in patterns:
        if isinstance(opndset, str):
            new_opndset = '_' + opndset
        else:
            (dsts, srcs) = opndset
            new_opndset = 'gen_%08x' % smallest[_opndset_key(dsts, srcs)]
            opndsgen[new_opndset] = (dsts, srcs)
        new_patterns.append((opcode_bits, opnd_bits, opcode, new_opndset))
    return (new_patterns, opndsgen)

def main():
    """Regenerate the four headers from "codec.txt" in the current directory.

    Reads and checks the description, names the operand sets, then passes
    each generated text (prefixed with the standard header comment) to
    write_if_changed for "decode_gen.h", "encode_gen.h", "opcode.h" and
    "opcode_names.h".
    """
    (patterns, opndtypes) = read_file('codec.txt')
    consistency_check(patterns, opndtypes)
    # From here on each pattern carries a generated opndset name.
    (patterns, opndsgen) = opndset_naming(patterns)
    write_if_changed('decode_gen.h',
                     header + generate_decoder(patterns, opndsgen, opndtypes))
    write_if_changed('encode_gen.h',
                     header + generate_encoder(patterns, opndsgen, opndtypes))
    write_if_changed('opcode.h',
                     header + generate_opcodes(patterns))
    write_if_changed('opcode_names.h',
                     header + generate_opcode_names(patterns))

if __name__ == "__main__":
main()
Loading

0 comments on commit 521301a

Please sign in to comment.