Skip to content

Commit

Permalink
Merge pull request #27 from maejima-fumika/fix-jit
Browse files Browse the repository at this point in the history
Fix JIT for general array type.
  • Loading branch information
maejima-fumika authored Nov 22, 2024
2 parents cd48c99 + e654557 commit ad21c84
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 57 deletions.
56 changes: 32 additions & 24 deletions notebook/src/hooks/repl-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ export default function ReplProvider({children}: {children: ReactNode}) {
const dram = useMemory('DRAM')

const bluetooth = useRef(new Bluetooth())
let onExecutionComplete = useRef((executionTime: number) => {})

// To use these variables in callbacks
const latestCellRef = useRef(latestCell)
latestCellRef.current = latestCell
const iramRef = useRef(iram)
const dramRef = useRef(dram)
iramRef.current = iram
dramRef.current = dram


useEffect(() => {
bluetooth.current.setNotificationHandler(onReceiveNotification);
Expand All @@ -75,10 +83,9 @@ export default function ReplProvider({children}: {children: ReactNode}) {
}

const executeLatestCell = async () => {
let updatedLatestCell:CellT = {...latestCell, compileError: '', state: 'compiling'}
setLatestCell(updatedLatestCell)
setLatestCell({...latestCell, compileError: '', state: 'compiling'})
try {
const compileResult = useJIT ? await network.compileWithProfiling(updatedLatestCell.code) : await network.compile(updatedLatestCell.code)
const compileResult = useJIT ? await network.compileWithProfiling(latestCell.code) : await network.compile(latestCell.code)
const iramBuffer = Buffer.from(compileResult.iram.data, "hex")
const dramBuffer = Buffer.from(compileResult.dram.data, "hex")
const flashBuffer = Buffer.from(compileResult.flash.data, "hex")
Expand All @@ -89,26 +96,12 @@ export default function ReplProvider({children}: {children: ReactNode}) {
.loadToFlash(compileResult.flash.address, flashBuffer)
.jump(compileResult.entryPoint)
.generate()
updatedLatestCell = {...updatedLatestCell, state: 'sending'}
setLatestCell(updatedLatestCell)
setLatestCell({...latestCell, compileError: '', state: 'sending'})
iram.actions.setUsedSegment(compileResult.iram.address, iramBuffer.length)
dram.actions.setUsedSegment(compileResult.dram.address, dramBuffer.length)
const bluetoothTime = await bluetooth.current.sendBuffers(bytecodeBuffer)
const compileTime = compileResult.compileTime
updatedLatestCell = {...updatedLatestCell, state: 'executing', time: {compile: compileTime, bluetooth: bluetoothTime}}
setLatestCell(updatedLatestCell)
onExecutionComplete.current = (executionTime: number) => {
if (updatedLatestCell.time !== undefined && updatedLatestCell.time?.execution === undefined) {
updatedLatestCell.time.execution = executionTime
updatedLatestCell.state = 'done'
const nextCellId = updatedLatestCell.id + 1
setPostExecutionCells((cells) =>
[...cells, updatedLatestCell]
)
setLatestCell({id: nextCellId, code: '', state: 'user-writing'})
}

}
setLatestCell({...latestCell, state: 'executing', time: {compile: compileTime, bluetooth: bluetoothTime}})
} catch (error: any) {
if (error instanceof CompileError) {
setLatestCell({...latestCell, state: 'user-writing', compileError: error.toString()})
Expand Down Expand Up @@ -145,13 +138,29 @@ export default function ReplProvider({children}: {children: ReactNode}) {
});
}

const onExecutionComplete = (executionTime: number) => {
if (latestCellRef.current.time !== undefined && latestCellRef.current.time?.execution === undefined) {
latestCellRef.current.time.execution = executionTime
latestCellRef.current.state = 'done'
const nextCellId = latestCellRef.current.id + 1
setPostExecutionCells((cells) =>
[...cells, latestCellRef.current]
)
setLatestCell({id: nextCellId, code: '', state: 'user-writing'})
}
}

const jitCompile = (fid: number, paramtypes: string[]) => {
network.jitCompile(fid, paramtypes).then((compileResult) => {
console.log(compileResult)
const iramBuffer = Buffer.from(compileResult.iram.data, "hex")
const dramBuffer = Buffer.from(compileResult.dram.data, "hex")
iramRef.current.actions.setUsedSegment(compileResult.iram.address, iramBuffer.length)
dramRef.current.actions.setUsedSegment(compileResult.dram.address, dramBuffer.length)
const bytecodeBuffer =
new BytecodeBufferBuilder(MAX_MTU)
.loadToRAM(compileResult.iram.address, Buffer.from(compileResult.iram.data, "hex"))
.loadToRAM(compileResult.dram.address, Buffer.from(compileResult.dram.data, "hex"))
.loadToRAM(compileResult.iram.address, iramBuffer)
.loadToRAM(compileResult.dram.address, dramBuffer)
.loadToFlash(compileResult.flash.address, Buffer.from(compileResult.flash.data, "hex"))
.jump(compileResult.entryPoint)
.generate()
Expand All @@ -175,7 +184,7 @@ export default function ReplProvider({children}: {children: ReactNode}) {
onDeviceResetComplete(parseResult.meminfo)
break;
case BYTECODE.RESULT_EXECTIME:
onExecutionComplete.current(parseResult.exectime)
onExecutionComplete(parseResult.exectime)
break;
case BYTECODE.RESULT_PROFILE: {
console.log("receive profile", parseResult.fid, parseResult.paramtypes);
Expand All @@ -186,7 +195,6 @@ export default function ReplProvider({children}: {children: ReactNode}) {
}
}


return (
<ReplContext.Provider value={{
state: replState,
Expand Down
4 changes: 2 additions & 2 deletions notebook/src/hooks/use-memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ export const MemoryDummry = {
name: '',
size: 0,
usedSize: 0,
buffer: [],
buffer: [false],
unitSize: UNIT_SIZE
},
actions: {
reset: (size: number) => {},
reset: (size: number, address: number) => {},
setUsedSegment: (start: number, size: number) => {},
}
}
Expand Down
15 changes: 10 additions & 5 deletions server/src/jit/jit-code-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function specializedFunctionBodyName(name: string) {

// returns '(' or '<type check function>('
// '(' is returned if the type cannot be checked.
function checkType(type?: StaticType) {
function checkType(env: VariableEnv, type?: StaticType): string|undefined {
if (type instanceof ArrayType) {
if (type.elementType === Integer)
return 'gc_is_intarray(';
Expand All @@ -83,7 +83,7 @@ function checkType(type?: StaticType) {
if (type.elementType === Any)
return 'gc_is_anyarray(';
else
throw new JITCompileError('Unknown array type.');
return `gc_is_instance_of(&${env.useArrayType(type)[0]}.clazz, `;
}

if (type instanceof InstanceType) {
Expand All @@ -98,13 +98,12 @@ function checkType(type?: StaticType) {
case BooleanT:
return 'is_bool_value(';
case StringT:
return `gc_is_string_object(`;
return 'gc_is_string_object(';
default:
return undefined;
}
}


export class JitCodeGenerator extends CodeGenerator{
private profiler: Profiler;
private bsSrc: string[];
Expand Down Expand Up @@ -250,13 +249,19 @@ export class JitCodeGenerator extends CodeGenerator{
this.result.write(') {')
this.result.right().nl()

// For test
this.result.write(`#ifdef TEST64`).nl().write('puts("Execute specialized function");').nl().write('#endif').nl()

this.result.write('return ');
this.functionCall(node, fenv, specializedFuncName, specializedType, funcType.paramTypes, 'self')

this.result.left().nl()
this.result.write('} else {')
this.result.right().nl()

// For test
this.result.write(`#ifdef TEST64`).nl().write('puts("Execute original function");').nl().write('#endif').nl()

this.result.write('return ');
this.functionCall(node, fenv, originalFuncName, funcType, funcType.paramTypes, 'self')
this.signatures += this.makeFunctionStruct(originalFuncName, funcType, false)
Expand Down Expand Up @@ -310,7 +315,7 @@ export class JitCodeGenerator extends CodeGenerator{
const paramName = (node.params[i] as AST.Identifier).name
const info = fenv.table.lookup(paramName)
if (info !== undefined) {
const check = checkType(targetParamTypes[i]);
const check = checkType(fenv, targetParamTypes[i]);
if (check) {
const name = info.transpiledName(paramName)
paramSig.push(`${check}${name})`);
Expand Down
29 changes: 26 additions & 3 deletions server/src/jit/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ export function typeStringToStaticType(typeString: string, gvnt?: GlobalVariable
return new ArrayType('float')
} else if (typeString === 'Array<boolean>') {
return new ArrayType('boolean')
} else if (typeString === 'Array') {
return 'any'
} else if (isArray(typeString)) {
return getArrayType(typeString, gvnt)
} else if (typeString === 'Function') {
return 'any'
} else {
Expand All @@ -42,6 +42,30 @@ export function typeStringToStaticType(typeString: string, gvnt?: GlobalVariable
}
}

function isArray(typeString: string) {
return /\[\]$/.test(typeString)
}

function getArrayType(typeString: string, gvnt?: GlobalVariableNameTable):StaticType {
const matches = typeString.match(/(\[\])+$/);
let ndim = matches ? matches[0].length / 2 : 0;
const className = typeString.replace(/(\[\])+$/, "");
let arr: StaticType|undefined;
if (className === 'string')
arr = 'string'
else {
arr = gvnt === undefined ? undefined : gvnt.lookup(className)?.type
if (arr === undefined || !(arr instanceof InstanceType))
throw new ProfileError(`Cannot find the profiled class: ${className}`)
}

while (ndim > 0) {
arr = new ArrayType(arr)
ndim -= 1
}
return arr
}

function staticTypeToTSType(type: StaticType): AST.TSType {
if (type === Integer || type === Float)
return tsTypeReference(identifier(type));
Expand Down Expand Up @@ -76,7 +100,6 @@ function staticTypeToTSType(type: StaticType): AST.TSType {
return tsAnyKeyword();
}


function staticTypeToNode(type: StaticType):AST.TSTypeAnnotation {
return tsTypeAnnotation(staticTypeToTSType(type));
}
Expand Down
8 changes: 4 additions & 4 deletions server/src/server/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ export default class Session {

public executeWithProfiling(tsString: string) {
this.currentCodeId += 1;
const ast = runBabelParser(tsString, 1)

const codeGenerator = (initializerName: string, codeId: number, moduleId: number) => {
return new JitCodeGenerator(initializerName, codeId, moduleId, this.profiler, tsString);
Expand All @@ -72,7 +71,10 @@ export default class Session {
}

const start = performance.now();

// Transpile
const ast = runBabelParser(tsString, 1);
convertAst(ast, this.profiler);
const tResult = jitTranspile(this.currentCodeId, ast, typeChecker, codeGenerator, this.nameTable, undefined)
const entryPointName = tResult.main;
const cString = cProlog + tResult.code;
Expand Down Expand Up @@ -105,11 +107,9 @@ export default class Session {
}

// Transpile
const ast = runBabelParser(func.src, 1);
const start = performance.now();

const ast = runBabelParser(func.src, 1);
convertAst(ast, this.profiler);
// const tResult = transpile(0, func.src, this.nameTable, undefined, -1, ast, codeGenerator);
const tResult = jitTranspile(this.currentCodeId, ast, typeChecker, codeGenerator, this.nameTable, undefined)
const entryPointName = tResult.main;
const cString = cProlog + tResult.code;
Expand Down
Loading

0 comments on commit ad21c84

Please sign in to comment.