diff --git a/.vscode/launch.json b/.vscode/launch.json index 96d3ef57b5aac..1e95820a2cd93 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -33,13 +33,13 @@ "-resource-dir", "./build/lib/clang/17", "-isysroot", - "../wasi-libc/sysroot", + "/scratch/martin/src/wasm/wasi-libc/sysroot", "-internal-isystem", "./build/lib/clang/17/include", "-internal-isystem", - "../wasi-libc/sysroot/include/wasm64-wasi", + "/scratch/martin/src/wasm/wasi-libc/sysroot/include/wasm64-wasi", "-internal-isystem", - "../wasi-libc/sysroot/include", + "/scratch/martin/src/wasm/wasi-libc/sysroot/include", "-Os", "-ferror-limit", "19", @@ -135,7 +135,7 @@ "command": "extension.commandvariable.file.pickFile", "args": { "fromFolder": { - "fixed": "/scratch/martin/src/wasm/hello-world" + "fixed": "/scratch/fritz/src/safe-wasm/llvm-project/demo" }, "include": "**/*.c" } diff --git a/.vscode/settings.json b/.vscode/settings.json index c9117a3d8ec98..ebf127975b3ad 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -95,7 +95,8 @@ "shared_mutex": "cpp", "stop_token": "cpp", "variant": "cpp", - "*.inc": "cpp" + "*.inc": "cpp", + "*.def": "c" }, "clangd.arguments": [ "--query-driver=${env:PATH_TO_CLANG}" diff --git a/demo/.gitignore b/demo/.gitignore new file mode 100644 index 0000000000000..19e1bced9ad8a --- /dev/null +++ b/demo/.gitignore @@ -0,0 +1 @@ +*.wasm diff --git a/demo/README.md b/demo/README.md index 71baa21fe5097..4168f3f0f00d0 100644 --- a/demo/README.md +++ b/demo/README.md @@ -1,5 +1,11 @@ # Demo +## PAC Testing + +To test the PAC features, which consist of a module pass that inserts custom wasm instructions that sign (`pointer_sign`) and authenticate (`pointer_auth`) a pointer before it is stored to a memory address: +- Disable MTE (in wasmtime or simply don't generate our MTE-based custom instructions in llvm/clang), since PAC can currently not effectively work when MTE is used (no PAC instructions would be inserted due to limits in our LLVM analysis). 
+- Compile demo C file `test-prevent-real-attack.c` without optimizations (`-O0`), since the code that tests PAC would otherwise be optimized away. + ## Building ```shell @@ -39,3 +45,13 @@ Wasmtime can be cross-compiled for aarch64 with the provided `Dockerfile` in thi ./wasmtime compile demo-scanf.wasm --cranelift-enable use_mte --wasm-features=memory64,mem-safety ./wasmtime run --allow-precompiled --wasm-features=memory64,mem-safety -- demo-scanf.cwasm ``` + +## Checking generated code + +You can inspect the generated code with the following commands, even if you are on a machine that is not aarch64 or doesn't support MTE. + +```shell +./wasmtime compile --target aarch64-unknown-linux-gnu --cranelift-enable use_mte --wasm-features=memory64,mem-safety demo-<name>.wasm +llvm-objdump -D demo-<name>.cwasm > demo-<name>.s +``` + diff --git a/demo/benchmarks-code/memory-tagging/Makefile b/demo/benchmarks-code/memory-tagging/Makefile new file mode 100644 index 0000000000000..817e825338a9b --- /dev/null +++ b/demo/benchmarks-code/memory-tagging/Makefile @@ -0,0 +1,51 @@ +# CC=/scratch/fritz/src/safe-wasm/llvm-project/build/bin/clang +# WASM_FLAGS=--target=wasm64-unknown-wasi --sysroot /scratch/martin/src/wasm/wasi-libc/sysroot -g -D_WASI_EMULATED_PROCESS_CLOCKS -lwasi-emulated-process-clocks -Wl,--stack-first -Wl,--initial-memory=104857600 -Wl,--max-memory=104857600 -Wl,-z,stack-size=83886080 +# SAN_FLAGS=-march=wasm64-wasi+mem-safety -fsanitize=wasm-memsafety +# CFLAGS=-O0 ${WASM_FLAGS} +# # CFLAGS=-O2 ${WASM_FLAGS} +# BUILD_DIR=build + +# all: ${BUILD_DIR}/tagging.wasm + +# # ${BUILD_DIR}/tagging.wasm: tagging.c +# ${BUILD_DIR}/tagging.wasm: tagging.ll +# ${CC} -o $@ $< ${CFLAGS} ${EXTRA_FLAGS} ${SAN_FLAGS} + +# clean: +# @ rm -f ${BUILD_DIR}/tagging.wasm + +# ${BUILD_DIR}: +# mkdir -p $@ + +# # Add the directory as a dependency to ensure it's created before compilation +# ${BUILD_DIR}/%.wasm: | ${BUILD_DIR} + + +CC=/scratch/fritz/src/safe-wasm/llvm-project/build/bin/clang 
WASM_FLAGS=--target=wasm64-unknown-wasi --sysroot /scratch/martin/src/wasm/wasi-libc/sysroot -g -D_WASI_EMULATED_PROCESS_CLOCKS -lwasi-emulated-process-clocks -Wl,--stack-first -Wl,--initial-memory=104857600 -Wl,--max-memory=104857600 -Wl,-z,stack-size=83886080
SAN_FLAGS=-march=wasm64-wasi+mem-safety -fsanitize=wasm-memsafety
CFLAGS=-O0 ${WASM_FLAGS}
# CFLAGS=-O2 ${WASM_FLAGS}
BUILD_DIR=build

# One .wasm module is built per hand-written LLVM IR benchmark listed here.
FILES=tagging-few-loops-large-segments \
	tagging-few-loops-small-segments \
	tagging-many-loops-large-segments \
	tagging-many-loops-small-segments

TARGETS=$(addprefix ${BUILD_DIR}/, $(addsuffix .wasm, ${FILES}))

# Neither target names a real file.
.PHONY: all clean

all: ${TARGETS}

# The build directory is an order-only prerequisite of the rule that carries
# the recipe. A separate recipe-less pattern rule does not add prerequisites
# to this one, so the directory must be listed here to be created first.
${BUILD_DIR}/%.wasm: %.ll | ${BUILD_DIR}
	${CC} -o $@ $< ${CFLAGS} ${EXTRA_FLAGS} ${SAN_FLAGS}

clean:
	@ rm -f ${BUILD_DIR}/*.wasm

${BUILD_DIR}:
	mkdir -p $@
%iteration_count + br label %loop_start + +loop_end: + ret i32 0 +} diff --git a/demo/benchmarks-code/memory-tagging/tagging-few-loops-small-segments.ll b/demo/benchmarks-code/memory-tagging/tagging-few-loops-small-segments.ll new file mode 100644 index 0000000000000..5bf3a741f7094 --- /dev/null +++ b/demo/benchmarks-code/memory-tagging/tagging-few-loops-small-segments.ll @@ -0,0 +1,29 @@ +; File: segment_stack_test.ll + +declare ptr @llvm.wasm.segment.stack.new(ptr, i64) +declare void @llvm.wasm.segment.stack.free(ptr, ptr, i64) + +define i32 @__main_void() { +entry: + %static_size_array = alloca [400 x i32], align 16 + %iteration_count = alloca i32 + store i32 1600, i32* %iteration_count + + br label %loop_start + +loop_start: + %count = load i32, i32* %iteration_count + %is_done = icmp eq i32 %count, 0 + br i1 %is_done, label %loop_end, label %loop_body + +loop_body: + %segment_ptr = call ptr @llvm.wasm.segment.stack.new(ptr %static_size_array, i64 1600) + call void @llvm.wasm.segment.stack.free(ptr %segment_ptr, ptr %static_size_array, i64 1600) + + %new_count = sub i32 %count, 1 + store i32 %new_count, i32* %iteration_count + br label %loop_start + +loop_end: + ret i32 0 +} diff --git a/demo/benchmarks-code/memory-tagging/tagging-many-loops-large-segments.ll b/demo/benchmarks-code/memory-tagging/tagging-many-loops-large-segments.ll new file mode 100644 index 0000000000000..a49acaaed1c96 --- /dev/null +++ b/demo/benchmarks-code/memory-tagging/tagging-many-loops-large-segments.ll @@ -0,0 +1,29 @@ +; File: segment_stack_test.ll + +declare ptr @llvm.wasm.segment.stack.new(ptr, i64) +declare void @llvm.wasm.segment.stack.free(ptr, ptr, i64) + +define i32 @__main_void() { +entry: + %static_size_array = alloca [10000 x i32], align 16 + %iteration_count = alloca i32 + store i32 40000, i32* %iteration_count + + br label %loop_start + +loop_start: + %count = load i32, i32* %iteration_count + %is_done = icmp eq i32 %count, 0 + br i1 %is_done, label %loop_end, label 
%loop_body + +loop_body: + %segment_ptr = call ptr @llvm.wasm.segment.stack.new(ptr %static_size_array, i64 40000) + call void @llvm.wasm.segment.stack.free(ptr %segment_ptr, ptr %static_size_array, i64 40000) + + %new_count = sub i32 %count, 1 + store i32 %new_count, i32* %iteration_count + br label %loop_start + +loop_end: + ret i32 0 +} diff --git a/demo/benchmarks-code/memory-tagging/tagging-many-loops-small-segments.ll b/demo/benchmarks-code/memory-tagging/tagging-many-loops-small-segments.ll new file mode 100644 index 0000000000000..551c329184ddb --- /dev/null +++ b/demo/benchmarks-code/memory-tagging/tagging-many-loops-small-segments.ll @@ -0,0 +1,29 @@ +; File: segment_stack_test.ll + +declare ptr @llvm.wasm.segment.stack.new(ptr, i64) +declare void @llvm.wasm.segment.stack.free(ptr, ptr, i64) + +define i32 @__main_void() { +entry: + %static_size_array = alloca [400 x i32], align 16 + %iteration_count = alloca i32 + store i32 4000000, i32* %iteration_count + + br label %loop_start + +loop_start: + %count = load i32, i32* %iteration_count + %is_done = icmp eq i32 %count, 0 + br i1 %is_done, label %loop_end, label %loop_body + +loop_body: + %segment_ptr = call ptr @llvm.wasm.segment.stack.new(ptr %static_size_array, i64 1600) + call void @llvm.wasm.segment.stack.free(ptr %segment_ptr, ptr %static_size_array, i64 1600) + + %new_count = sub i32 %count, 1 + store i32 %new_count, i32* %iteration_count + br label %loop_start + +loop_end: + ret i32 0 +} diff --git a/demo/benchmarks-code/memory-tagging/tagging.c b/demo/benchmarks-code/memory-tagging/tagging.c new file mode 100644 index 0000000000000..10efa154855cb --- /dev/null +++ b/demo/benchmarks-code/memory-tagging/tagging.c @@ -0,0 +1,46 @@ +// #include +// #include + +// int main() { +// size_t n = 1000; +// // volatile int static_size_array[1000]; +// volatile int static_size_array[n]; + +// // for (int i = 0; i < n; i++) { +// // static_size_array[i] = i % 255; +// // } + +// for (int i = 0; i < n; i++) { 
+// static_size_array[i] = i % 255; +// } + +// // Pretend to use the data +// use_array(static_size_array, n); + +// return 0; +// } + + +#include +#include + +int check = 0; // global variable + +int main() { + size_t n = 10000; + size_t alignment = 32; + int static_size_array[n * 32]; + // int *static_size_array = malloc(sizeof(int) * (n*32)); + + for (int i = 0; i < n; i++) { + static_size_array[i] = i % 255; + } + + // Unpredictable branch to compiler, will never actually run, + // but compiler doesn't know that for sure + if (check) { + printf("%d", static_size_array[0]); + } + + return 0; +} diff --git a/demo/benchmarks-code/pac-store-load-loops/Makefile b/demo/benchmarks-code/pac-store-load-loops/Makefile new file mode 100644 index 0000000000000..755f762f2cf53 --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/Makefile @@ -0,0 +1,25 @@ +CC=/scratch/fritz/src/safe-wasm/llvm-project/build/bin/clang +# WASM_FLAGS=--target=wasm64-unknown-wasi --sysroot /scratch/martin/src/wasm/wasi-libc/sysroot -g -D_WASI_EMULATED_PROCESS_CLOCKS -lwasi-emulated-process-clocks -Wl,--stack-first -Wl,--initial-memory=104857600 -Wl,--max-memory=104857600 -Wl,-z,stack-size=83886080 +WASM_FLAGS=--target=wasm64-unknown-wasi --sysroot /scratch/martin/src/wasm/wasi-libc/sysroot -g -D_WASI_EMULATED_PROCESS_CLOCKS -lwasi-emulated-process-clocks -Wl,--stack-first -Wl,--initial-memory=1677721600 -Wl,--max-memory=1677721600 -Wl,-z,stack-size=1342177280 + +SAN_FLAGS=-march=wasm64-wasi+mem-safety -fsanitize=wasm-memsafety +#CFLAGS=-O0 ${WASM_FLAGS} +CFLAGS=-O2 ${WASM_FLAGS} +BUILD_DIR=build + +PAC_SOURCES=$(wildcard pac-*.c) +PAC_WASMS=$(PAC_SOURCES:%.c=${BUILD_DIR}/%.wasm) + +all: ${PAC_WASMS} + +${BUILD_DIR}/%.wasm: %.c + ${CC} -o $@ $< ${CFLAGS} ${EXTRA_FLAGS} ${SAN_FLAGS} + +clean: + @ rm -f ${BUILD_DIR}/*.wasm + +${BUILD_DIR}: + mkdir -p $@ + +# Add the directory as a dependency to ensure it's created before compilation +${BUILD_DIR}/%.wasm: | ${BUILD_DIR} diff --git 
a/demo/benchmarks-code/pac-store-load-loops/pac-1.c b/demo/benchmarks-code/pac-store-load-loops/pac-1.c new file mode 100644 index 0000000000000..b32f2e389d7d8 --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/pac-1.c @@ -0,0 +1,21 @@ +#include +#include + +int main(int argc, char **argv) { + size_t n = 10000; + + void* ptrArray[n]; + + // Store n pointers in the array + for (size_t i = 0; i < n; i++) { + ptrArray[i] = (void*) i; // casting the iterating variable to a pointer + } + + // Load the n pointers from the array and accumulate their values + size_t sum = 0; + for (size_t i = 0; i < n; i++) { + sum += (size_t) ptrArray[i]; + } + + return sum % 125; // modulo to make sure it's a valid return code +} diff --git a/demo/benchmarks-code/pac-store-load-loops/pac-2.c b/demo/benchmarks-code/pac-store-load-loops/pac-2.c new file mode 100644 index 0000000000000..55a9ba9b3d9f5 --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/pac-2.c @@ -0,0 +1,21 @@ +#include +#include + +int main(int argc, char **argv) { + size_t n = 100000; + + void* ptrArray[n]; + + // Store n pointers in the array + for (size_t i = 0; i < n; i++) { + ptrArray[i] = (void*) i; // casting the iterating variable to a pointer + } + + // Load the n pointers from the array and accumulate their values + size_t sum = 0; + for (size_t i = 0; i < n; i++) { + sum += (size_t) ptrArray[i]; + } + + return sum % 125; // modulo to make sure it's a valid return code +} diff --git a/demo/benchmarks-code/pac-store-load-loops/pac-3.c b/demo/benchmarks-code/pac-store-load-loops/pac-3.c new file mode 100644 index 0000000000000..74b960236e1d1 --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/pac-3.c @@ -0,0 +1,21 @@ +#include +#include + +int main(int argc, char **argv) { + size_t n = 1000000; + + void* ptrArray[n]; + + // Store n pointers in the array + for (size_t i = 0; i < n; i++) { + ptrArray[i] = (void*) i; // casting the iterating variable to a pointer + } + + // Load the n 
pointers from the array and accumulate their values + size_t sum = 0; + for (size_t i = 0; i < n; i++) { + sum += (size_t) ptrArray[i]; + } + + return sum % 125; // modulo to make sure it's a valid return code +} diff --git a/demo/benchmarks-code/pac-store-load-loops/pac-4.c b/demo/benchmarks-code/pac-store-load-loops/pac-4.c new file mode 100644 index 0000000000000..a49c86bdcd4bf --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/pac-4.c @@ -0,0 +1,21 @@ +#include +#include + +int main(int argc, char **argv) { + size_t n = 10000000; + + void* ptrArray[n]; + + // Store n pointers in the array + for (size_t i = 0; i < n; i++) { + ptrArray[i] = (void*) i; // casting the iterating variable to a pointer + } + + // Load the n pointers from the array and accumulate their values + size_t sum = 0; + for (size_t i = 0; i < n; i++) { + sum += (size_t) ptrArray[i]; + } + + return sum % 125; // modulo to make sure it's a valid return code +} diff --git a/demo/benchmarks-code/pac-store-load-loops/pac-5.c b/demo/benchmarks-code/pac-store-load-loops/pac-5.c new file mode 100644 index 0000000000000..a4790ce6592f8 --- /dev/null +++ b/demo/benchmarks-code/pac-store-load-loops/pac-5.c @@ -0,0 +1,21 @@ +#include +#include + +int main(int argc, char **argv) { + size_t n = 100000000; + + void* ptrArray[n]; + + // Store n pointers in the array + for (size_t i = 0; i < n; i++) { + ptrArray[i] = (void*) i; // casting the iterating variable to a pointer + } + + // Load the n pointers from the array and accumulate their values + size_t sum = 0; + for (size_t i = 0; i < n; i++) { + sum += (size_t) ptrArray[i]; + } + + return sum % 125; // modulo to make sure it's a valid return code +} diff --git a/demo/benchmarks-code/sorting-ints/Makefile b/demo/benchmarks-code/sorting-ints/Makefile new file mode 100644 index 0000000000000..470a5811a4351 --- /dev/null +++ b/demo/benchmarks-code/sorting-ints/Makefile @@ -0,0 +1,25 @@ +CC=/scratch/fritz/src/safe-wasm/llvm-project/build/bin/clang 
// Sorts arr[0..n) into ascending order in place using bubble sort.
// Each outer pass moves the largest value of the unsorted prefix to its end,
// so the prefix that still needs scanning shrinks by one element per pass.
void bubble_sort(int* arr, int n) {
    for (int unsorted_end = n - 1; unsorted_end > 0; --unsorted_end) {
        for (int pos = 0; pos < unsorted_end; ++pos) {
            if (arr[pos + 1] < arr[pos]) {
                // Neighbours are out of order: swap them.
                const int larger = arr[pos];
                arr[pos] = arr[pos + 1];
                arr[pos + 1] = larger;
            }
        }
    }
}
// Merges the two adjacent sorted runs arr[l..m] and arr[m+1..r] (inclusive
// bounds) back into arr[l..r]. Both runs are first copied into temporary heap
// buffers; the process exits with status 1 if either allocation fails.
void merge(int* arr, int l, int m, int r) {
    const int left_len = m - l + 1;
    const int right_len = r - m;

    int* left = (int*)malloc(left_len * sizeof(int));
    int* right = (int*)malloc(right_len * sizeof(int));
    if (left == NULL || right == NULL) {
        fprintf(stderr, "Memory allocation failed\n");
        exit(1);
    }

    // Snapshot both runs, since arr[l..r] is overwritten below.
    for (int idx = 0; idx < left_len; ++idx) {
        left[idx] = arr[l + idx];
    }
    for (int idx = 0; idx < right_len; ++idx) {
        right[idx] = arr[m + 1 + idx];
    }

    int li = 0;
    int ri = 0;
    int out = l;

    // Take the smaller head each step; on ties the left run wins.
    while (li < left_len && ri < right_len) {
        if (right[ri] < left[li]) {
            arr[out++] = right[ri++];
        } else {
            arr[out++] = left[li++];
        }
    }

    // At most one of the runs still has elements; append them.
    while (li < left_len) {
        arr[out++] = left[li++];
    }
    while (ri < right_len) {
        arr[out++] = right[ri++];
    }

    free(left);
    free(right);
}
// Largest run length (in elements) that merge() buffers on the stack.
#define LIMIT 160

// Merges the sorted runs L[0..n1) and R[0..n2) into arr, writing from index l
// onwards. When the two heads compare equal, the element from L is taken first.
void perform_merge(int *arr, const int *L, const int *R, size_t l, size_t n1, size_t n2) {
    size_t i = 0, j = 0, k = l;

    while (i < n1 && j < n2) {
        if (L[i] <= R[j]) {
            arr[k++] = L[i++];
        } else {
            arr[k++] = R[j++];
        }
    }

    // Append whatever is left of L, then whatever is left of R.
    while (i < n1) {
        arr[k++] = L[i++];
    }

    while (j < n2) {
        arr[k++] = R[j++];
    }
}

// Merges the adjacent sorted runs arr[l..m] and arr[m+1..r] (inclusive bounds)
// back into arr[l..r]. Exits with status 1 if a heap buffer cannot be allocated.
void merge(int* arr, size_t l, size_t m, size_t r) {
    size_t L_size = m - l + 1;
    size_t R_size = r - m;

    // For small partition sizes, allocate constant size on the stack to take advantage of our optimizations
    if (L_size <= LIMIT && R_size <= LIMIT) {
        int L[LIMIT], R[LIMIT];

        // Copy data to temp arrays L[] and R[]
        for (size_t i = 0; i < L_size; i++)
            L[i] = arr[l + i];
        for (size_t j = 0; j < R_size; j++)
            R[j] = arr[m + 1 + j];

        perform_merge(arr, L, R, l, L_size, R_size);
    } else {
        int *L = (int *) malloc(L_size * sizeof(int));
        int *R = (int *) malloc(R_size * sizeof(int));

        // Fail loudly instead of writing through a NULL buffer; this matches
        // the allocation check in the plain merge_sort benchmark.
        if (!L || !R) {
            fprintf(stderr, "Memory allocation failed\n");
            exit(1);
        }

        // Copy data to temp arrays L[] and R[]
        for (size_t i = 0; i < L_size; i++)
            L[i] = arr[l + i];
        for (size_t j = 0; j < R_size; j++)
            R[j] = arr[m + 1 + j];

        perform_merge(arr, L, R, l, L_size, R_size);

        free(L);
        free(R);
    }
}
// Benchmark driver: reads the element count from argv[1] and that many
// integers from stdin, sorts them with modified_merge_sort and verifies the
// result. Returns 0 on success and 1 on bad usage, bad input, allocation
// failure or an unsorted result.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s \n", argv[0]);
        return 1;
    }

    int n = atoi(argv[1]);

    // n - 1 is converted to size_t below; for n <= 0 that would wrap to a
    // huge upper bound and send the sort far outside the array.
    if (n <= 0) {
        fprintf(stderr, "Element count must be a positive integer\n");
        return 1;
    }

    int* arr = (int*)malloc((size_t)n * sizeof(int));
    if (!arr) {
        fprintf(stderr, "Memory allocation failed\n");
        return 1;
    }

    // Read unsorted input array from stdin
    for (int i = 0; i < n; i++) {
        if (scanf("%d", &arr[i]) != 1) {
            // Stop rather than sort uninitialized memory.
            fprintf(stderr, "Failed to read element %d\n", i);
            free(arr);
            return 1;
        }
    }

    // Invoke SUT
    modified_merge_sort(arr, 0, (size_t)n - 1);

    if (!assert_sorted(arr, n)) {
        fprintf(stderr, "Array is not sorted!\n");
        free(arr);
        return 1; // Error
    }

    free(arr);

    return 0;
}
// Self-contained pointer bubble-sort benchmark: fills a stack array with
// descending pointer-sized values, sorts it in place and checks the result.
// Returns 0 when the array ends up sorted, 1 on wrong argument count or when
// the check fails. argv[1] is required by the argc check but not read; the
// element count is fixed.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        // fprintf(stderr, "Usage: %s \n", argv[0]);
        fprintf(stderr, "Usage: \n");
        return 1;
    }

    // Input size is hard-coded; the argv-based variant is kept for reference.
    // size_t n = (size_t) atoi(argv[1]);
    size_t n = 40000;
    // The array is a variable-length array on the stack; the heap variant is kept for reference.
    // void** arr = (void**)malloc(n * sizeof(void*));
    void* arr[n];

    // Fill the array with the values n, n-1, ..., 1 cast to pointers (worst
    // case for bubble sort). Nothing is read from stdin; the stdin-based
    // variant is kept for reference.
    for (size_t i = 0; i < n; i++) {
        // int value;
        // scanf("%d", &value);
        // arr[i] = (void*)value;
        arr[i] = (void*)(n-i);
    }

    // We inline bubble sort here, so that our LLVM is able to insert PAC instructions
    for (size_t i = 0; i < n-1; i++) {
        for (size_t j = 0; j < n-i-1; j++) {
            if (arr[j] > arr[j+1]) {
                void* temp = arr[j];
                arr[j] = arr[j+1];
                arr[j+1] = temp;
            }
        }
    }

    // Assert that the array was sorted correctly
    int assert_sorted = 1;
    for (size_t i = 0; i < n-1; i++) {
        if (arr[i] > arr[i+1]) {
            assert_sorted = 0;
            break;
        }
    }

    // NOTE(review): this remark dates from the malloc-based variant above. With
    // the stack array there is nothing to free and nothing leaks.
    // We aren't allowed to free the array, as this counts as using the value elsewhere, so we memory-leak here for testing purposes
    // free(arr);

    if (!assert_sorted) {
        fprintf(stderr, "Array is not sorted!\n");
        return 1; // Error
    }

    return 0;
}
a/demo/benchmarks-code/sorting-pointers/merge_sort.c b/demo/benchmarks-code/sorting-pointers/merge_sort.c new file mode 100644 index 0000000000000..3662b553c7f1d --- /dev/null +++ b/demo/benchmarks-code/sorting-pointers/merge_sort.c @@ -0,0 +1,102 @@ +#include +#include + +void merge(void** arr, size_t l, size_t m, size_t r) { + size_t L_size = m - l + 1; + size_t R_size = r - m; + + // create temporary arrays on the heap + void** L = (void**)malloc(L_size * sizeof(void*)); + void** R = (void**)malloc(R_size * sizeof(void*)); + + if (!L || !R) { + fprintf(stderr, "Memory allocation failed\n"); + exit(1); + } + + // Copy data to temp arrays L[] and R[] + for (size_t i = 0; i < L_size; i++) + L[i] = arr[l + i]; + for (size_t j = 0; j < R_size; j++) + R[j] = arr[m + 1 + j]; + + // Merge the temp arrays back into arr[l..r] + size_t i = 0; + size_t j = 0; + size_t k = l; + while (i < L_size && j < R_size) { + if (L[i] <= R[j]) { + arr[k] = L[i]; + i++; + } else { + arr[k] = R[j]; + j++; + } + k++; + } + + // Copy the remaining elements of L[], if there are any + while (i < L_size) { + arr[k] = L[i]; + i++; + k++; + } + + // Copy the remaining elements of R[], if there are any + while (j < R_size) { + arr[k] = R[j]; + j++; + k++; + } + + free(L); + free(R); +} + +void merge_sort(void** arr, size_t l, size_t r) { + if (l < r) { + size_t m = l + (r - l) / 2; + merge_sort(arr, l, m); + merge_sort(arr, m + 1, r); + merge(arr, l, m, r); + } +} + +int assert_sorted(void** arr, size_t n) { + for (size_t i = 0; i < n-1; i++) { + if (arr[i] > arr[i+1]) { + return 0; // Not sorted + } + } + return 1; // Sorted +} + +int main(int argc, char* argv[]) { + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + size_t n = (size_t) atoi(argv[1]); + void** arr = (void**)malloc(n * sizeof(void*)); + + // Read unsorted input array from stdin + for (size_t i = 0; i < n; i++) { + int value; + scanf("%d", &value); + arr[i] = (void*)value; + } + + // Invoke SUT + 
// Largest run length (in elements) that merge() buffers on the stack.
#define LIMIT 160

// Merges the sorted runs L[0..n1) and R[0..n2) into arr, writing from index l
// onwards. Elements are compared as pointer values; when the two heads compare
// equal, the element from L is taken first.
void perform_merge(void** arr, void** const L, void** const R, size_t l, size_t n1, size_t n2) {
    size_t i = 0, j = 0, k = l;

    while (i < n1 && j < n2) {
        if (L[i] <= R[j]) {
            arr[k++] = L[i++];
        } else {
            arr[k++] = R[j++];
        }
    }

    // Append whatever is left of L, then whatever is left of R.
    while (i < n1) {
        arr[k++] = L[i++];
    }

    while (j < n2) {
        arr[k++] = R[j++];
    }
}

// Merges the adjacent sorted runs arr[l..m] and arr[m+1..r] (inclusive bounds)
// back into arr[l..r]. Exits with status 1 if a heap buffer cannot be allocated.
void merge(void** arr, size_t l, size_t m, size_t r) {
    size_t L_size = m - l + 1;
    size_t R_size = r - m;

    // For small partition sizes, allocate constant size on the stack to take advantage of our optimizations
    if (L_size <= LIMIT && R_size <= LIMIT) {
        void* L[LIMIT], *R[LIMIT];

        // Copy data to temp arrays L[] and R[]
        for (size_t i = 0; i < L_size; i++)
            L[i] = arr[l + i];
        for (size_t j = 0; j < R_size; j++)
            R[j] = arr[m + 1 + j];

        perform_merge(arr, L, R, l, L_size, R_size);
    } else {
        void** L = (void**) malloc(L_size * sizeof(void*));
        void** R = (void**) malloc(R_size * sizeof(void*));

        // Fail loudly instead of writing through a NULL buffer; this matches
        // the allocation check in the plain merge_sort benchmark.
        if (!L || !R) {
            fprintf(stderr, "Memory allocation failed\n");
            exit(1);
        }

        // Copy data to temp arrays L[] and R[]
        for (size_t i = 0; i < L_size; i++)
            L[i] = arr[l + i];
        for (size_t j = 0; j < R_size; j++)
            R[j] = arr[m + 1 + j];

        perform_merge(arr, L, R, l, L_size, R_size);

        free(L);
        free(R);
    }
}
// Benchmark driver: reads the element count from argv[1] and that many
// integers from stdin, stores each one as a pointer-sized value, sorts them
// with modified_merge_sort and verifies the result. Returns 0 on success and
// 1 on bad usage, bad input, allocation failure or an unsorted result.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s \n", argv[0]);
        return 1;
    }

    int count = atoi(argv[1]);

    // n - 1 is computed in size_t below; for a count of 0 it would wrap to
    // SIZE_MAX and send the sort far outside the array.
    if (count <= 0) {
        fprintf(stderr, "Element count must be a positive integer\n");
        return 1;
    }

    size_t n = (size_t) count;
    void** arr = (void**)malloc(n * sizeof(void*));
    if (!arr) {
        fprintf(stderr, "Memory allocation failed\n");
        return 1;
    }

    // Read unsorted input array from stdin
    for (size_t i = 0; i < n; i++) {
        int value;
        if (scanf("%d", &value) != 1) {
            // Stop rather than sort uninitialized memory.
            fprintf(stderr, "Failed to read element %zu\n", i);
            free(arr);
            return 1;
        }
        // Widen to a pointer-sized integer first; casting an int straight to
        // a pointer of a different size draws an int-to-pointer-cast warning.
        arr[i] = (void*)(size_t)value;
    }

    // Invoke SUT
    modified_merge_sort(arr, 0, n - 1);

    if (!assert_sorted(arr, n)) {
        fprintf(stderr, "Array is not sorted!\n");
        free(arr);
        return 1; // Error
    }

    free(arr);

    return 0;
}
+// return *x; +// } + +// int main(int argc, char **argv) { +// if (argc < 2) { +// printf("Usage: %s value\n", argv[0]); +// return 1; +// } + +// int *buf[4]; +// int x = atoi(argv[1]); + +// // &buf[1] is signed with PAC here +// store_pointer(&x, &buf[1]); + +// // load succeeds because &buf[1] can be authenticated with PAC +// int *x_ptr = load_pointer(&buf[1]); + +// printf("Value stored in x is: %d\n", *x_ptr); + +// return 0; +// } + + +// void print_pointer(int *ptr) { +// printf("Printing pointer: %d\n", *ptr); +// } + +// void print_double_pointer(int **ptr) { +// printf("Printing pointer: %d\n", **ptr); +// } + +// int main(int argc, char **argv) { +// if (argc < 2) { +// printf("Usage: %s value\n", argv[0]); +// return 1; +// } + +// int *buf[4]; +// int x = atoi(argv[1]); +// int *x_pointer = &x; + +// // &x is signed with PAC here +// buf[1] = &x; + +// // These should both be recognized by as users of x +// print_pointer(&x); +// print_double_pointer(&x_pointer); + +// // load succeeds because &x can be authenticated with PAC +// int *x_ptr = *(&buf[1]); + +// printf("Value stored in x is: %d\n", *x_ptr); + +// return 0; +// } + + +// __attribute__((noinline)) +// void print_pointer(volatile int *ptr) { +// printf("Printing pointer: %d\n", *ptr); +// } + +// __attribute__((noinline)) +// int main(int argc, char **argv) { +// if (argc < 2) { +// printf("Usage: %s value\n", argv[0]); +// return 1; +// } + +// volatile int *buf[4]; +// volatile int x = atoi(argv[1]); + +// // &x is signed with PAC here +// // this would count as alias already +// buf[2] = &x; + +// // These should both be recognized by us as users of x +// print_pointer(&x); +// print_pointer(buf[2]); + +// // load succeeds because &x can be authenticated with PAC +// volatile int *x_ptr = *(&buf[2]); + +// printf("Value stored in x is: %d\n", *x_ptr); + +// return 0; +// } + + +// int main(int argc, char **argv) { +// int return_val = (int) argv[0]; +// printf("Code until here was 
executed.\n"); +// return return_val; +// } + + +void print_pointer_external(volatile int *ptr); + +__attribute__((noinline)) +void print_pointer(volatile int *ptr) { + printf("Printing pointer: %d\n", *ptr); +} + +// expect no auths or signs to be inserted, since they alias each other +__attribute__((noinline)) +int main(int argc, char **argv) { + // this should be alias of y + volatile int x = 42; + printf("x before: %d\n", x); + + // this should be alias of x + volatile int *y = &x; + *y = 41; + printf("x after: %d\n", x); + + // for now, this should be disallowed + // print_pointer(y); + // print_pointer_external(y); + // print_pointer(&x); + print_pointer_external(&x); + + return 0; +} diff --git a/demo/demo-test.c b/demo/demo-test.c new file mode 100644 index 0000000000000..3d006df607cb7 --- /dev/null +++ b/demo/demo-test.c @@ -0,0 +1,17 @@ +#include +#include + +int main() { + char *string = "Hello World!"; + char **pointer_storage = &string; + char name[10]; + + printf("What is your name?\n"); + scanf("%s", name); // potential buffer overflow + printf("Hello user %s!\n", name); + + char *loaded_string = *pointer_storage; // failed authentication + printf("String protected with PAC: %s\n", loaded_string); + + return 0; +} diff --git a/demo/lto-tests/a.cpp b/demo/lto-tests/a.cpp new file mode 100644 index 0000000000000..d9b5f6bd37e51 --- /dev/null +++ b/demo/lto-tests/a.cpp @@ -0,0 +1,7 @@ +int function_a(int x, int y) { + return x + y / x - y * x; +} + +int main() { + return function_a(2, 3); +} \ No newline at end of file diff --git a/demo/lto-tests/b.cpp b/demo/lto-tests/b.cpp new file mode 100644 index 0000000000000..6cc85c596e5ca --- /dev/null +++ b/demo/lto-tests/b.cpp @@ -0,0 +1,3 @@ +int function_b(int x, int y, int z) { + return z + y / x - z * x; +} \ No newline at end of file diff --git a/demo/pac/test-1.c b/demo/pac/test-1.c new file mode 100644 index 0000000000000..7f16a6d5d4dbc --- /dev/null +++ b/demo/pac/test-1.c @@ -0,0 +1,37 @@ +// int 
*some_declared_function(int *ptr); + +// int *increment_ptr(int *ptr) { +// printf("print something just so this function doesn't get inlined, so the returned pointer can't be identified as an alias."); +// return ptr; +// } + +// int main() { +// int buf[4]; +// int x = 42; + +// int *alias_to_buf_2 = increment_ptr(&buf[1]); +// some_declared_function(alias_to_buf_2); + +// return 0; +// } + + +// this function is external +void external_function(int **ptr) { + // load ptr and do sth +} + +// this function is not external (since it is defined), but it does call an external function +void non_external_function(int **ptr) { + external_function(ptr); +} + +// our llvm should detect that ptr has other uses, since it is indirectly passed to an external function +int main() { + int *buf[4]; + + int x = 16; + buf[1] = &x; + non_external_function(&buf[1]); + return 0; +} diff --git a/demo/pac/test-argv.c b/demo/pac/test-argv.c new file mode 100644 index 0000000000000..6380a770df3ea --- /dev/null +++ b/demo/pac/test-argv.c @@ -0,0 +1,13 @@ +#include + +int main(int argc, char **argv) { + printf("argv[0] = %s\n", argv[0]); + + // // Internally, this looks like so: + // // wasi_libc_get_argv() is an external function, meaning it doesn't sign the pointers it has stored + // char **argv = wasi_libc_get_argv(); + // // now we load a non-signed pointer, so we are not allowed to auth it + // argv[0]; + + return 0; +} \ No newline at end of file diff --git a/demo/pac/test-double-load.c b/demo/pac/test-double-load.c new file mode 100644 index 0000000000000..4545cc39e5b6d --- /dev/null +++ b/demo/pac/test-double-load.c @@ -0,0 +1,11 @@ +#include + +int main(int argc, char **argv) { + // store argv + char ***argv_ptr = &argv[0]; + + // load argv + printf("argv[0] = %s\n", *argv_ptr); + + return 0; +} diff --git a/demo/pac/test-elsewhere.ll b/demo/pac/test-elsewhere.ll new file mode 100644 index 0000000000000..217915ec1c1c4 --- /dev/null +++ b/demo/pac/test-elsewhere.ll @@ -0,0 +1,9 
@@ +declare i8** @function2(); + +define void @function3() { + %1 = call i8** @function2() + %2 = getelementptr i8*, i8** %1, i32 1 + %string = load i8*, i8** %2 + + ret void +} diff --git a/demo/pac/test-loop.c b/demo/pac/test-loop.c new file mode 100644 index 0000000000000..0281ff5b03619 --- /dev/null +++ b/demo/pac/test-loop.c @@ -0,0 +1,11 @@ +#include + +int main() { + char **argv; + + while (argv++) { + printf("%s\n" ,*argv); + } + + return 0; +} \ No newline at end of file diff --git a/demo/pac/test-memory-location-comes-from-function.c b/demo/pac/test-memory-location-comes-from-function.c new file mode 100644 index 0000000000000..598ade84fcd2f --- /dev/null +++ b/demo/pac/test-memory-location-comes-from-function.c @@ -0,0 +1,14 @@ +int** sign_pointer() { + int *ptr = (int*) malloc(20); + int **memory_location = (int**) malloc(1); + // store pointer + *memory_location = ptr; + return memory_location; +} + +int main() { + int **memory_location = sign_pointer(); + + // load pointer + int *loaded_ptr = *memory_location; +} diff --git a/demo/pac/test-most-basic-usecase.c b/demo/pac/test-most-basic-usecase.c new file mode 100644 index 0000000000000..6b47a4953c512 --- /dev/null +++ b/demo/pac/test-most-basic-usecase.c @@ -0,0 +1,67 @@ +#include +#include + +// // This store would never get signed, because the memory location the pointer is stored to comes from the parameter of the function. +// void store_pointer(int *value_ptr, int **dst_ptr) { +// // store pointer +// *dst_ptr = value_ptr; +// } + +// // This load would never get authenticated, because the memory location the pointer is loaded from comes from the parameter of the function.
+// int *load_pointer(int **x) { +// // load pointer +// return *x; +// } + +// int main(int argc, char **argv) { +// if (argc < 2) { +// printf("Usage: %s value\n", argv[0]); +// return 1; +// } + +// // int *buf[4]; +// int **buf; +// int x = atoi(argv[1]); + +// // &x should be signed with PAC here before storing +// // store pointer +// // buf[1] = &x; + +// // no signing, see explanation in method +// // store_pointer(&x, &buf[1]); +// store_pointer(&x, buf); + +// // &x should be authenticated with PAC here after loading +// // load pointer +// // int *x_ptr = buf[1]; +// int *x_ptr = *buf; + +// // no authentication, see explanation in method +// // int *x_ptr = load_pointer(&buf[1]); + +// printf("Value stored in x is: %d\n", *x_ptr); + +// return 0; +// } + +void store_pointer(int *value_ptr, int **dst_ptr) { + // store pointer + *dst_ptr = value_ptr; +} + +int *load_pointer(int **x) { + return *x; +} + +int main() { + int **buf; + int x = 42; + + buf[1] = &x; + + int *x_ptr = load_pointer(&buf[1]); + + int x_alias = *x_ptr; + + return 0; +} diff --git a/demo/pac/test-prevent-real-attack-external/external.c b/demo/pac/test-prevent-real-attack-external/external.c new file mode 100644 index 0000000000000..264f2dd3aa9b5 --- /dev/null +++ b/demo/pac/test-prevent-real-attack-external/external.c @@ -0,0 +1,7 @@ +#include +#include "external.h" + +void external_function(char **ptr) { + //... 
(implementation of the function) + printf("Printing from external function, ptr is: %p\n", *ptr); +} diff --git a/demo/pac/test-prevent-real-attack-external/external.h b/demo/pac/test-prevent-real-attack-external/external.h new file mode 100644 index 0000000000000..9aac1c49d86f4 --- /dev/null +++ b/demo/pac/test-prevent-real-attack-external/external.h @@ -0,0 +1,7 @@ +// File: external.h +#ifndef EXTERNAL_H +#define EXTERNAL_H + +void external_function(char **ptr); + +#endif // EXTERNAL_H diff --git a/demo/pac/test-prevent-real-attack-external/test-prevent-real-attack-extended.c b/demo/pac/test-prevent-real-attack-external/test-prevent-real-attack-extended.c new file mode 100644 index 0000000000000..bd6afad9009e4 --- /dev/null +++ b/demo/pac/test-prevent-real-attack-external/test-prevent-real-attack-extended.c @@ -0,0 +1,39 @@ +#include +#include +#include "external.h" + +void print_pointer_to_string(char **ptr) { + printf("Printing pointer to string: %p\n", *ptr); + // Signing *ptr should not be done because of this external function using it + external_function(ptr); +} + +__attribute__((noinline)) +int main() { + char *string = "Hello World!"; + + // store pointer: we should insert a PAC sign here to prevent the attacker from overwriting it by overflowing `name` + // if the attacker overflows `name`, then `pointer_storage` will be overwritten + // char **pointer_storage = &string; + char **pointer_storage; + // read variable length user input into this array + char name[10]; + + // If we remove this line, then the loading of pointer storage below would also succeed + // store_pointer(&string, &pointer_storage); + pointer_storage = &string; + + // We should be able to pass aliases of the pointer_storage to other functions, as long as they don't end in external functions + // print_pointer_to_string(pointer_storage); + + printf("What is your name?\n"); + scanf("%s", name); + + printf("Hello user %s!\n", name); + + // load pointer: we should insert a PAC auth here 
+ char *loaded_string = *pointer_storage; + printf("Here is the string we stored and protected using PAC: %s\n", loaded_string); + + return 0; +} diff --git a/demo/pac/test-prevent-real-attack.c b/demo/pac/test-prevent-real-attack.c new file mode 100644 index 0000000000000..cd15ebed649bf --- /dev/null +++ b/demo/pac/test-prevent-real-attack.c @@ -0,0 +1,24 @@ +#include +#include + +__attribute__((noinline)) +int main() { + char *string = "Hello World!"; + + // store pointer: we should insert a PAC sign here to prevent the attacker from overwriting it by overflowing `name` + // if the attacker overflows `name`, then `pointer_storage` will be overwritten + char **pointer_storage = &string; + // read variable length user input into this array + char name[10]; + + printf("What is your name?\n"); + scanf("%s", name); + + printf("Hello user %s!\n", name); + + // load pointer: we should insert a PAC auth here + char *loaded_string = *pointer_storage; + printf("Here is the string we stored and protected using PAC: %s\n", loaded_string); + + return 0; +} diff --git a/demo/pac/test-value-comes-from-elsewhere.c b/demo/pac/test-value-comes-from-elsewhere.c new file mode 100644 index 0000000000000..217544b698d4b --- /dev/null +++ b/demo/pac/test-value-comes-from-elsewhere.c @@ -0,0 +1,107 @@ +#include + +// === TEST-CASE 1: + +// gets memory location, to be loaded from, from parameter, our llvm should detect not to auth here on load +void function(char **double_ptr) { + *double_ptr; +} + +// Equivalent llvm-ir: +/* +define void @function(i8** %double_ptr) { + %1 = load i8*, i8** %double_ptr + ret void +} +*/ + + +// === TEST-CASE 2: + +char **function2() { + char *string = malloc(42 * sizeof(char)); + char **double_ptr = malloc(sizeof(char *)); + // store pointer + *double_ptr = string; + return double_ptr; +} + +// gets memory location, to be loaded from, from other function's return value directly, our llvm should detect not to auth here on load +void function3() { + char 
**string = function2(); + // load pointer + char *loaded_ptr = *string; + free(*string); + free(string); +} + +// Equivalent llvm-ir: +/* +declare i8** @function2(); + +define void @function3() { + %1 = call i8** @function2() + %string = load i8*, i8** %1 + + ret void +} +*/ + + +// === TEST-CASE 3: + +// gets memory location, to be loaded from, from other function's return value with some slight code in between, our llvm should detect not to auth here on load +void function3() { + char **string = function2(); + string++; + char *loaded_ptr = *string; +} + +// Equivalent llvm-ir: +/* +declare i8** @function2(); + +define void @function3() { + %1 = call i8** @function2() + %2 = getelementptr i8*, i8** %1, i32 1 + %string = load i8*, i8** %2 + + ret void +} +*/ + +// === TEST-CASE 4: + +// gets memory location, to be loaded from, from function parameter with some slight code in between, our llvm should detect not to auth here on load +void function3(char **string) { + string++; + char *loaded_ptr = *string; +} + +// Equivalent llvm-ir: +/* +define void @function3(i8** %string) { + %1 = getelementptr i8**, i8*** %string, i32 1 + %2 = load i8**, i8*** %1 + %loaded_ptr = load i8*, i8** %2 + + ret void +} +*/ + +// === TEST-CASE 5: + +void main(int argc, char **argv) { + // load pointer => comes from elsewhere (i.e. 
function parameter), so we shouldn't authenticate it + argv[0]; +} + +// Equivalent llvm-ir: +/* +define i32 @main(i32 %argc, i8** %argv) { + %1 = getelementptr i8*, i8** %argv, i32 0 + %loaded_ptr = load i8*, i8** %1 + + ret i32 0 +} +*/ diff --git a/demo/pac/test-value-has-other-uses.c b/demo/pac/test-value-has-other-uses.c new file mode 100644 index 0000000000000..a7a8377da4303 --- /dev/null +++ b/demo/pac/test-value-has-other-uses.c @@ -0,0 +1,69 @@ +#include + +// === TEST-CASE 1: + +// this function is external +void external_function(int *ptr); + +// our llvm should detect that ptr has other uses, since it is passed to an external function +void main() { + int *ptr = get_pointer_from_not_elsewhere(); + external_function(ptr); + return 0; +} + +// Equivalent llvm-ir: +/* +*/ + +// === TEST-CASE 2: + +// this function is external +void external_function(int *ptr); + +// this function is not external (since it is defined), but it does call an external function +void non_external_function(int *ptr) { + external_function(ptr); +} + +// our llvm should detect that ptr has other uses, since it is indirectly passed to an external function +void main() { + int *ptr = get_pointer_from_not_elsewhere(); + non_external_function(ptr); + return 0; +} + +// Equivalent llvm-ir: +/* +*/ + +// === TEST-CASE 3: + +// this function is external +void external_function(int *ptr); + +// this function is not external (since it is defined), but it does call an external function for one of its parameters +void non_external_function(int *ptr1, int *ptr2) { + // ptr1 is unused; + external_function(ptr2); +} + +// our llvm should detect that ptr has other uses, since it is indirectly passed to an external function +void main() { + int *ptr = get_pointer_from_not_elsewhere(); + non_external_function(ptr, ptr); + return 0; +} + +// Equivalent llvm-ir: +/* +*/ + + +// === TEST-CASE 4: + +// TODO: test vararg function + +// === TEST-CASE 5: + +// TODO: test function pointer, with if
statement depending on input that will decide the function to be executed (or just read in a function and that gets executed) diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td index 9a6b40a6333bb..e042814bc2e31 100644 --- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td +++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td @@ -35,6 +35,10 @@ def int_wasm_segment_stack_new : DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_ptr_ty, llvm_i64_ty], []>; def int_wasm_segment_stack_free : DefaultAttrsIntrinsic<[], [llvm_ptr_ty, llvm_ptr_ty, llvm_i64_ty], []>; +def int_wasm_pointer_sign : + DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_ptr_ty], []>; +def int_wasm_pointer_auth : + DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_ptr_ty], []>; //===----------------------------------------------------------------------===// // ref.null intrinsics diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index 1d3e8ec90420a..ce9dc18d4441a 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -52,6 +52,9 @@ add_llvm_target(WebAssemblyCodeGen WebAssemblySelectionDAGInfo.cpp WebAssemblySetP2AlignOperands.cpp WebAssemblyMemorySafety.cpp + WebAssemblyPointerAuthenticationFunctionPass.cpp + WebAssemblyPointerAuthenticationModulePass.cpp + WebAssemblyPointerAuthenticationLTOPass.cpp WebAssemblySortRegion.cpp WebAssemblyMemIntrinsicResults.cpp WebAssemblySubtarget.cpp diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index d7739673f4852..c54b3b363d240 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -55,6 +55,9 @@ FunctionPass *createWebAssemblyRegNumbering(); FunctionPass *createWebAssemblyDebugFixup(); FunctionPass *createWebAssemblyPeephole(); FunctionPass *createWebAssemblyMemorySafetyPass(bool IsOptNone); +FunctionPass 
*createWebAssemblyPointerAuthenticationFunctionPass(); +ModulePass *createWebAssemblyPointerAuthenticationModulePass(); +ModulePass *createWebAssemblyPointerAuthenticationLTOPass(); ModulePass *createWebAssemblyMCLowerPrePass(); // PassRegistry initialization declarations. @@ -85,6 +88,9 @@ void initializeWebAssemblyRegStackifyPass(PassRegistry &); void initializeWebAssemblyReplacePhysRegsPass(PassRegistry &); void initializeWebAssemblySetP2AlignOperandsPass(PassRegistry &); void initializeWebAssemblyMemorySafetyPass(PassRegistry &); +void initializeWebAssemblyPointerAuthenticationFunctionPass(PassRegistry &); +void initializeWebAssemblyPointerAuthenticationModulePass(PassRegistry &); +void initializeWebAssemblyPointerAuthenticationLTOPass(PassRegistry &); namespace WebAssembly { enum TargetIndex { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td index 73fb7aac15286..028c2ba412663 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td @@ -208,6 +208,22 @@ defm SEGMENT_STACK_FREE_A64 : I<(outs), "segment.stack_free\t${off}(${addr})${p2align}, $sp, $size", "segment.stack_free\t${off}${p2align}", 0xfa03, false>, Requires<[HasAddr64]>; // HasMemSafety + +defm POINTER_SIGN_A64 : I<(outs I64:$signed_ptr), + (ins I64:$ptr), + (outs), (ins), + [], + "i64.pointer_sign\t$signed_ptr, $ptr", "i64.pointer_sign", + 0xfa04, false>, + Requires<[HasAddr64]>; // HasMemSafety + +defm POINTER_AUTH_A64 : I<(outs I64:$authed_ptr), + (ins I64:$ptr), + (outs), (ins), + [], + "i64.pointer_auth\t$authed_ptr, $ptr", "i64.pointer_auth", + 0xfa05, false>, + Requires<[HasAddr64]>; // HasMemSafety } def : Pat<(int_wasm_segment_new I64:$size), @@ -218,6 +234,10 @@ def : Pat<(int_wasm_segment_stack_new (AddrOps64 offset64_op:$offset, I64:$addr) (!cast(SEGMENT_STACK_NEW_A64) 0, offset64_op:$offset, I64:$addr, I64:$size)>; def : 
Pat<(int_wasm_segment_stack_free (AddrOps64 offset64_op:$offset, I64:$addr), I64:$sp, I64:$size), (!cast(SEGMENT_STACK_FREE_A64) 0, offset64_op:$offset, I64:$addr, I64:$sp, I64:$size)>; +def : Pat<(int_wasm_pointer_sign I64:$ptr), + (!cast(POINTER_SIGN_A64) I64:$ptr)>; +def : Pat<(int_wasm_pointer_auth I64:$ptr), + (!cast(POINTER_AUTH_A64) I64:$ptr)>; multiclass MemoryOps { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMemorySafety.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMemorySafety.cpp index 9d8971318a0f6..b4a703f5d0104 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMemorySafety.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMemorySafety.cpp @@ -21,6 +21,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/LiveRegUnits.h" #include "llvm/CodeGen/MachineBasicBlock.h" @@ -61,6 +62,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include @@ -175,6 +177,7 @@ class WebAssemblyMemorySafety : public FunctionPass { private: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequiredTransitive(); } bool isAllocKind(Attribute Attr, AllocFnKind Kind) const { @@ -203,6 +206,8 @@ bool WebAssemblyMemorySafety::runOnFunction(Function &F) { F.getName().starts_with("__wasm_memsafety_")) return false; + auto &TLIAnalysis = getAnalysis(); + DataLayout DL = F.getParent()->getDataLayout(); LLVMContext &Ctx(F.getContext()); @@ -222,6 +227,7 @@ bool WebAssemblyMemorySafety::runOnFunction(Function &F) { } if (auto *Call = dyn_cast(&I)) { auto *CalledFunction = Call->getCalledFunction(); + 
inferNonMandatoryLibFuncAttrs(CalledFunction->getParent(), CalledFunction->getName(), TLIAnalysis.getTLI(F)); auto Attr = CalledFunction->getFnAttribute(Attribute::AttrKind::AllocKind); if (Attr.hasAttribute(Attribute::AllocKind)) { @@ -358,6 +364,7 @@ bool WebAssemblyMemorySafety::runOnFunction(Function &F) { FreeSegmentInst->insertBefore(Terminator); } } + // F.dump(); return true; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationFunctionPass.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationFunctionPass.cpp new file mode 100644 index 0000000000000..c9264c8b1090a --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationFunctionPass.cpp @@ -0,0 +1,394 @@ +//===- WebAssemblyPointerAuthenticationFunctionPass.cpp - Pointer Authentication for WASM --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "Utils/WebAssemblyUtilities.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/CodeGen/LiveRegUnits.h" +#include 
"llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineLoopInfo.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/MemoryTaggingSupport.h" +#include +#include +#include +#include +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "wasm-pointer-authentication-function" + +namespace { +class WebAssemblyPointerAuthenticationFunction final : public FunctionPass { + bool runOnFunction(Function &F) override; + + StringRef getPassName() const override { return "WebAssembly Pointer Authentication Function"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + 
AU.setPreservesCFG(); + } + +public: + static char ID; + WebAssemblyPointerAuthenticationFunction() : FunctionPass(ID) {} +}; +} // end anonymous namespace + +char WebAssemblyPointerAuthenticationFunction::ID = 0; +INITIALIZE_PASS(WebAssemblyPointerAuthenticationFunction, DEBUG_TYPE, + "WebAssembly Pointer Authentication Function Pass", false, false) + +FunctionPass *llvm::createWebAssemblyPointerAuthenticationFunctionPass() { // Returns a newly allocated instance of the pass. + return new WebAssemblyPointerAuthenticationFunction(); +} + +std::string getAliasResultString(AliasResult result) { // Maps each alias kind to its enumerator name. NOTE(review): there is no return after the switch, so control falls off the end if no case matches. + switch (result) { + case AliasResult::NoAlias: + return "NoAlias"; + case AliasResult::MayAlias: + return "MayAlias"; + case AliasResult::PartialAlias: + return "PartialAlias"; + case AliasResult::MustAlias: + return "MustAlias"; + } +} + +// TODO: read somewhere that AliasAnalysis does not account for loops apparently => test +void findAllAliasesOfValue(Value &V, SmallVector &Aliases, AliasAnalysis &AA, Function &F) { // Appends &V, then every other value yielded by iterating F's basic blocks for which AA.isNoAlias(&V, other) is false. + // The pointer itself counts as one of its own aliases + Aliases.emplace_back(&V); + + // std::cout << " Value \"" << V.getName().str() << "\" is aliased by:" << std::endl; + for (BasicBlock &BB : F) { + for (Value &OtherValue : BB) { + // Skip V itself; only other values are compared against it + if (&V == &OtherValue) { + continue; + } + + // AliasResult aliasResult = AA.alias(&V, &OtherValue); + // if (aliasResult != AliasResult::NoAlias) { + if (!AA.isNoAlias(&V, &OtherValue)) { + // std::cout << " Other value \"" << OtherValue.getName().str() << "\" is a: " << getAliasResultString(aliasResult) << std::endl; + + Aliases.emplace_back(&OtherValue); + } + } + } +} + +// Tracks all visited values, and skips recursive call if we have already +// visited a certain value before (to avoid endless recursion). For V and, recursively, each of its users, records the callee of every call that takes that value as an argument.
+void findAllFunctionsWhereValueIsPassedAsArgumentHelper(Value &V, SmallVector &FunctionCalls, std::set &VisitedValues) { + auto [_, ValueWasInserted] = VisitedValues.insert(&V); + if (!ValueWasInserted) { + errs() << "in find all functions where passed as param: Found value we have seen before: " << V.getName().str() << "; exiting to prevent infinite loop\n"; + // We found a value we have seen before, so we are in some sort of loop. + // Therefore, we have already checked this value and all of its users. + return; + } + + // std::cout << " Value \"" << V.getName().str() << "\" is used in functions:" << std::endl; + errs() << " Value \"" << V << "\" is used in functions:\n"; + + for (User *U : V.users()) { + // TODO: we can't only consider function users, we also have to consider e.g. normal loads and stores, which are not function calls + // TODO: what if we add and subtract, and then use the new pointer to load => probably counts as an alias, but we could just recursively directly trace those calls + if (CallInst *CI = dyn_cast(U)) { + for (Value *Arg : CI->args()) { + if (Arg == &V) { + // std::cout << " Function \"" << CI->getCalledFunction()->getName().str() << "\"" << std::endl; + errs() << " Function: " << CI << "\n"; // NOTE(review): CI is a pointer, so this logs an address rather than the call; `*CI` was probably intended + + FunctionCalls.emplace_back(CI->getCalledFunction()); // appended once per matching argument. NOTE(review): getCalledFunction() is presumably null for indirect calls - confirm that null entries are acceptable + + // TODO: only do this if the original value, that was passed to another function, is again returned from this function + // TODO: would a simple "if (CI == &V)" work? + // Recursively checks if any other functions use the function's return value + // findAllFunctionsWhereValueIsPassedAsArgument(*CI, FunctionCalls); + } + } + } + // Consider all users, and recurse on them, not just the functions with value as parameter + findAllFunctionsWhereValueIsPassedAsArgumentHelper(*U, FunctionCalls, VisitedValues); + } +} + +// Find all function calls that use the specified value as an argument.
+// Once we found a function, we also have to recursively find all of +// the functions that use that function('s return value). +void findAllFunctionsWhereValueIsPassedAsArgument(Value &V, SmallVector &FunctionCalls) { + std::set VisitedValues; + return findAllFunctionsWhereValueIsPassedAsArgumentHelper(V, FunctionCalls, VisitedValues); +} + +// A value has other uses if it is passed as a function parameter to any other +// function, either directly or through one of its transitive users. F and AA are currently unused here. +bool valueHasOtherUses(Value &V, Function &F, AliasAnalysis &AA) { + SmallVector functionsUsingValue; + findAllFunctionsWhereValueIsPassedAsArgument(V, functionsUsingValue); + + // TODO: major optimization: we don't need to construct the entire list of all recursive functions, we basically just need to find if there are **any**, so we can immediately return once we found the first function that takes it as parameter. + return !functionsUsingValue.empty(); +} + +// Checks whether V is one of F's formal arguments (compared by address). +bool valueIsParameterOfFunction(Value &V, Function &F) { + for (Argument &arg : F.args()) { + if (&arg == &V) { + return true; + } + } + return false; +} + +// Tracks all visited values, and skips recursive call if we have already +// visited a certain value before (to avoid endless recursion). +bool valueComesFromElsewhereHelper(Value &V, Function &ParentFunction, std::set &VisitedValues) { + errs() << "Checking value: " << V.getName().str() << "\n"; + + auto [_, ValueWasInserted] = VisitedValues.insert(&V); + if (!ValueWasInserted) { + // We found a value we have seen before, so we are in some sort of loop. + // Therefore, we continue searching, but skip re-entering the loop.
+ errs() << "Found value we have seen before: " << V.getName().str() << "; exiting to prevent infinite loop\n"; + return false; + } + + if (valueIsParameterOfFunction(V, ParentFunction)) { + errs() << "Value: " << V.getName().str() << " is the parameter of function: " << ParentFunction.getName() << "\n"; + return true; + } + + // A global value could be used across different modules, so we can never control/know that global values aren't used elsewhere + if (isa(&V)) { + errs() << "Value: " << V.getName().str() << " is a global value\n"; + return true; + } + + // Checks, recursively, whether a Value was returned by a function call. + if (auto *I = dyn_cast(&V)) { + // Check if instruction is (directly) the return value of a function call. + if (isa(I)) { + errs() << "Instruction: " << I << " is the return value of a function call\n"; // NOTE(review): I is a pointer, so this logs an address rather than the instruction; `*I` was probably intended + return true; + } + // Check if value was loaded from a memory location, i.e. value is the + // return value of a load instruction. + if (isa(I)) { + return true; + } + + // Since `V` doesn't come from elsewhere directly, we have to + // check whether any of the parameters/operands of the instruction `V` + // come from elsewhere. + for (auto &Op : I->operands()) { + if (valueComesFromElsewhereHelper(*Op, ParentFunction, VisitedValues)) { + return true; + } + } + } + + return false; +} + +// Checks whether a value "comes from elsewhere". +// A value comes from elsewhere if any of the following conditions are met: +// 1. The value was passed as a parameter to the current function. +// 2. The value is the return value of any function. +// 3. The value was loaded from any memory location. +// 4. The value is a global value. +// In case the current value/instruction does not come from elsewhere, we also +// need to check whether any of its operands come from elsewhere.
+bool valueComesFromElsewhere(Value &V, Function &ParentFunction) { + std::set VisitedValues; + return valueComesFromElsewhereHelper(V, ParentFunction, VisitedValues); +} + +// Pointer Authentication Rules: +// +// A pointer (value), that is being stored in or loaded from a memory location, +// is suitable for pointer authentication, if that memory location has no other +// uses and does not come from elsewhere. +// A pointer is only suitable for PA, if all of its aliases are also suitable for +// PA. Returns true only when every collected alias (MemoryLocation itself included) passes both checks. +bool memoryLocationIsSuitableForPA(Value &MemoryLocation, Function &F, AliasAnalysis &AA) { + SmallVector Aliases; + findAllAliasesOfValue(MemoryLocation, Aliases, AA, F); + + // TODO: optimization possibility: cache the aliases that were already found to be non-suitable + // If any one of the aliases is not suitable, then none of the aliases is treated as suitable + for (auto Alias : Aliases) { + if (valueHasOtherUses(*Alias, F, AA) || valueComesFromElsewhere(*Alias, F)) { + return false; + } + } + + return true; +} + +// Go through all loads and stores of pointers and check if they are suitable for +// pointer authentication. Suitable stores get their value operand routed through wasm_pointer_sign, suitable loads get their result routed through wasm_pointer_auth; returns true if any such call was inserted.
+bool authenticateStoredAndLoadedPointers(Function &F, AliasAnalysis &AA) {
+  auto *PointerSignFunc = Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::wasm_pointer_sign);
+  auto *PointerAuthFunc = Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::wasm_pointer_auth);
+
+  // Candidates are collected first and rewritten afterwards: we must not
+  // mutate the instruction lists we are iterating over.
+  SmallVector<StoreInst *> StorePointerInsts;
+  SmallVector<LoadInst *> LoadPointerInsts;
+
+  // Look for instructions that load/store a pointer
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      if (auto *SI = dyn_cast<StoreInst>(&I)) {
+        // Store(value, ptr): $value is stored at the address pointed to by $ptr.
+        // Only stores whose stored value is itself a pointer are of interest.
+        if (!SI->getValueOperand()->getType()->isPointerTy())
+          continue;
+
+        errs() << "==== Checking if store: " << SI->getName().str() << " is suitable for PA\n";
+        // Print the instruction itself (*SI); streaming the raw pointer would
+        // only print its address.
+        if (memoryLocationIsSuitableForPA(*SI->getPointerOperand(), F, AA)) {
+          errs() << "Store instruction: " << *SI << " is suitable for pointer authentication\n";
+          StorePointerInsts.emplace_back(SI);
+        } else {
+          errs() << "Store instruction: " << *SI << " is not suitable for pointer authentication\n";
+        }
+      } else if (auto *LI = dyn_cast<LoadInst>(&I)) {
+        // Load(ptr): the value located at the address pointed to by $ptr is returned.
+        // Only loads that produce a pointer are of interest.
+        if (!LI->getType()->isPointerTy())
+          continue;
+
+        errs() << "==== Checking if load: " << LI->getName().str() << " is suitable for PA\n";
+        if (memoryLocationIsSuitableForPA(*LI->getPointerOperand(), F, AA)) {
+          errs() << "Load instruction: " << *LI << " is suitable for pointer authentication\n";
+          LoadPointerInsts.emplace_back(LI);
+        } else {
+          errs() << "Load instruction: " << *LI << " is not suitable for pointer authentication\n";
+        }
+      }
+    }
+  }
+
+  // Add pointer signing inst before pointer store inst
+  for (StoreInst *SI : StorePointerInsts) {
+    auto *PointerSignInst = CallInst::Create(PointerSignFunc, {SI->getValueOperand()});
+    PointerSignInst->insertBefore(SI);
+
+    // Operand 0 of a store is the stored value: store the signed pointer instead.
+    SI->setOperand(0, PointerSignInst);
+  }
+
+  // Add pointer authentication inst after pointer load inst
+  for (LoadInst *LI : LoadPointerInsts) {
+    auto *PointerAuthInst = CallInst::Create(PointerAuthFunc, {LI});
+    PointerAuthInst->insertAfter(LI);
+
+    // All further uses of the load's result must use the authenticated pointer;
+    // the auth call itself keeps consuming the raw loaded value.
+    LI->replaceUsesWithIf(PointerAuthInst, [&](Use &U) {
+      return U.getUser() != PointerAuthInst;
+    });
+  }
+
+  // We made changes if we added any pointer sign or auth instructions.
+  return !(LoadPointerInsts.empty() && StorePointerInsts.empty());
+}
+
+// Runs the pointer authentication rewrite on `F` and reports whether any sign
+// or auth call was inserted.
+bool WebAssemblyPointerAuthenticationFunction::runOnFunction(Function &F) {
+  errs() << "=== Starting analysis on function: " << F.getName().str() << "\n";
+
+  AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+
+  bool modified = authenticateStoredAndLoadedPointers(F, AA);
+
+  // Function::dump() is only available in builds with assertions or
+  // LLVM_ENABLE_DUMP, so print explicitly to keep release builds linking.
+  F.print(errs(), /*AAW=*/nullptr, /*ShouldPreserveUseListOrder=*/false,
+          /*IsForDebug=*/true);
+
+  // The inserted calls are intrinsics other passes are unaware of, but the IR
+  // did change, so report it.
+  // TODO: potentially set to false in the future
+  return modified;
+}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationLTOPass.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationLTOPass.cpp
new file mode 100644
index 0000000000000..74cb2124ce4b5
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationLTOPass.cpp
@@ -0,0 +1,617 @@
+//===- WebAssemblyPointerAuthenticationLTOPass.cpp - Pointer Authentication for WASM --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+
+#include "WebAssembly.h"
+#include "Utils/WebAssemblyUtilities.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Passes/PassPlugin.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/StackSafetyAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/LiveRegUnits.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsWebAssembly.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/MemoryTaggingSupport.h"
+#include <cassert>
+#include <map>
+#include <set>
+#include <string>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "wasm-pointer-authentication-lto"
+
+namespace {
+
+// Module pass that decides, for every pointer-typed load and store in the
+// module, whether it is suitable for pointer authentication, and inserts
+// wasm pointer_sign / pointer_auth intrinsic calls for the suitable ones.
+class WebAssemblyPointerAuthenticationLTO : public ModulePass {
+
+public:
+  static char ID;
+
+  WebAssemblyPointerAuthenticationLTO() : ModulePass(ID) {
+    initializeWebAssemblyPointerAuthenticationLTOPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override;
+
+  StringRef getPassName() const override { return "WebAssembly Pointer Authentication LTO"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.setPreservesCFG();
+  }
+
+private:
+  // Alias query between two values that live in function `F`.
+  AliasResult alias(Value &V1, Value &V2, Function &F) {
+    return getAliasAnalysisForFunction(F).alias(&V1, &V2);
+  }
+
+  // Fetches the alias analysis results computed for `F`.
+  AliasAnalysis &getAliasAnalysisForFunction(Function &F) {
+    return getAnalysis<AAResultsWrapperPass>(F).getAAResults();
+  }
+
+// TODO: remove
+// Human-readable name of an alias result, for debug output.
+std::string getAliasResultString(AliasResult result) {
+  switch (result) {
+  case AliasResult::NoAlias:
+    return "NoAlias";
+  case AliasResult::MayAlias:
+    return "MayAlias";
+  case AliasResult::PartialAlias:
+    return "PartialAlias";
+  case AliasResult::MustAlias:
+    return "MustAlias";
+  }
+  // All enumerators are handled above; without this the function could fall
+  // off the end of a non-void function.
+  llvm_unreachable("unknown AliasResult kind");
+}
+
+
+// TODO: a potential bug in our implementation I thought of: what if a pointer is stored in a function, and then passed to another function where it is loaded from? Then it will be signed in the original function, but never authed in the other function, because memory locations passed as parameters are always classified as coming from elsewhere, which makes sense, because other modules could use this function and we can't assume they will sign the memory location's stored pointer.
+// A solution to this would most likely be the LTO pass, because there we can be certain no-one else would use the function that should auth
+
+
+
+// We define a function as external if it is declared, but not defined, in
+// the current module. Intrinsics are never treated as external.
+bool isExternalFunction(Function &F, Module &ParentModule) {
+  // Sanity check: If the function belongs to a different module than the
+  // one we are currently analyzing, then it is definitely an external function.
+  if (F.getParent() != &ParentModule)
+    return true;
+
+  return F.isDeclaration() && !F.isIntrinsic();
+}
+
+// TODO: think of edge case: if we have 2 functions, and they call each other, then that would be endless recursion => check visited functions+values used to visit
+
+// TODO: read somewhere that AliasAnalysis does not account for loops apparently => test
+// Collects into `Aliases` the value `V` itself plus every instruction of `F`
+// that alias analysis does not prove to be NoAlias with `V`.
+void findAllAliasesOfValue(Value &V, SmallVector<Value *> &Aliases, Function &F) {
+  // The pointer itself counts as one of its own aliases
+  Aliases.emplace_back(&V);
+
+  // A function without a body has no instructions to compare against, so no
+  // analysis lookup is needed for it.
+  if (F.empty())
+    return;
+
+  // Look the analysis up once per query instead of once per instruction: from
+  // a module pass, each per-function getAnalysis call may re-run the on-demand
+  // function analyses.
+  AliasAnalysis &AA = getAliasAnalysisForFunction(F);
+
+  for (BasicBlock &BB : F) {
+    for (Instruction &OtherValue : BB) {
+      // Only iterate on all other values
+      if (&V == &OtherValue)
+        continue;
+
+      if (AA.alias(&V, &OtherValue) != AliasResult::NoAlias)
+        Aliases.emplace_back(&OtherValue);
+    }
+  }
+}
+
+// TODO: think about function pointers, what if those somehow call external functions
+
+// Tracks all visited values, and skips recursive call if we have already
+// visited a certain value before (to avoid endless recursion).
+// If we encounter an unidentifiable function (e.g. function pointer,
+// vararg function), then we immediately return false.
+bool findAllFunctionsWhereValueIsPassedAsArgumentHelper(Value &V, SmallVector<Function *> &FunctionCalls, std::set<Value *> &VisitedValues, Module &BaseModule, Function &BaseFunction) {
+  // TODO: what about aliases? Are their users transitively counted as users as well? TEST this!!!
+  if (!VisitedValues.insert(&V).second) {
+    // We found a value we have seen before, so we are in some sort of loop.
+    // Therefore, we have already checked this value and all of its users.
+    // We want to continue searching, so don't mark this as an error.
+    return true;
+  }
+
+  for (User *U : V.users()) {
+    // TODO: we can't only consider function users, we also have to consider e.g. normal loads and stores, which are not function calls
+    // TODO: what if we add and subtract, and then use the new pointer to load => probably counts as an alias, but we could just recursively directly trace those calls
+    // We use CallBase to check for both InvokeInst and CallInst.
+    if (auto *CI = dyn_cast<CallBase>(U)) {
+      // Index-based loop: the argument position must stay correct on every
+      // path, including the `continue` for recursive calls below. The same
+      // value may be passed several times, so all arguments are inspected.
+      for (unsigned ArgIndex = 0, NumArgs = CI->arg_size(); ArgIndex != NumArgs; ++ArgIndex) {
+        if (CI->getArgOperand(ArgIndex) != &V)
+          continue;
+
+        Function *PassedToFunction = CI->getCalledFunction();
+
+        // The callee is unknown (e.g. a call through a function pointer), or
+        // the argument index exceeds the declared parameters, indicating a
+        // vararg function. We can't track these further, so we are
+        // conservative and mark this as potentially calling external
+        // functions. This check must come before any use of the callee.
+        if (PassedToFunction == nullptr || PassedToFunction->arg_size() <= ArgIndex) {
+          errs() << "1\n";
+          return false;
+        }
+        errs() << "Passed to function " << PassedToFunction->getName() << "\n";
+
+        // We only add valid function calls to the vector.
+        FunctionCalls.emplace_back(PassedToFunction);
+
+        // We need to check if our value has other users in the function
+        // it is passed to.
+        Value *ValueAsArg = PassedToFunction->getArg(ArgIndex);
+
+        // The actual Value we passed to the function from another function
+        // differs from the parameter Value used inside the function.
+        assert(&V != ValueAsArg);
+
+        // TODO: This is DFS, moving this to a second for loop would make it BFS, which might be more efficient in our case.
+        // TODO: actually, this wouldn't work, since we would also check valuecomesfromelsewhere, which would always be true, since it is the parameter
+
+        // The reason we have to thoroughly analyze the Value passed to the
+        // Function is that it might again have aliases inside that function,
+        // which we would not detect with just a recursive call to
+        // findAllFunctionsWhereValueIsPassedAsArgumentHelper.
+        // However, passing to a recursive function is fine, since we already
+        // analyzed and counted that function here.
+        if (PassedToFunction == &BaseFunction)
+          continue;
+
+        if (valueHasOtherUsesWithAA(*ValueAsArg, *PassedToFunction, BaseModule)) {
+          errs() << "2\n";
+          return false;
+        }
+      }
+    }
+    // Also add all other functions that use the function's return value
+    if (!findAllFunctionsWhereValueIsPassedAsArgumentHelper(*U, FunctionCalls, VisitedValues, BaseModule, BaseFunction)) {
+      errs() << "3\n";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Debug aid for dumping a list of functions. Printing is currently disabled,
+// so this is intentionally a no-op.
+void printSmallVector(const llvm::SmallVectorImpl<Function *> &vec) {
+  (void)vec;
+}
+
+// Find all function calls that use the specified value as an argument.
+// Once we found a function, we also have to recursively find all of
+// the functions that use that function('s return value).
+// Additionally, returns false if we directly find some function we could
+// not analyze further, and therefore classify as external.
+bool findAllFunctionsWhereValueIsPassedAsArgument(Value &V, SmallVector<Function *> &FunctionCalls, Module &BaseModule, Function &BaseFunction) {
+  std::set<Value *> VisitedValues;
+  return findAllFunctionsWhereValueIsPassedAsArgumentHelper(V, FunctionCalls, VisitedValues, BaseModule, BaseFunction);
+}
+
+// A value has other uses if it is recursively passed as a function parameter
+// to an external function.
+// Therefore, once we see that a value is passed to a non-external function,
+// we still need to check if the value has other uses in that function.
+// This function does not perform any Alias Analysis.
+bool valueHasOtherUsesWithoutAA(Value &Value, Function &F, Module &BaseModule) {
+  SmallVector<Function *> FunctionsUsingValue;
+  if (!findAllFunctionsWhereValueIsPassedAsArgument(Value, FunctionsUsingValue, BaseModule, F)) {
+    // While searching for the functions, we already encountered some
+    // error/invalid function that uses the Value, so we immediately return.
+    errs() << "encountered error during find all functions\n";
+    return true;
+  }
+
+  // TODO: !functionsOutsideModuleUsingPointer.empty() vs assert that all functionsUsingPointer are from this module
+  // TODO: optimization: don't find all functions, just the first one that is external
+  return llvm::any_of(FunctionsUsingValue, [&](Function *FunctionUsingValue) {
+    return isExternalFunction(*FunctionUsingValue, BaseModule);
+  });
+}
+
+// This is used instead of the WithoutAA variant when Alias Analysis is
+// required, but valueComesFromElsewhere should not be bundled in.
+bool valueHasOtherUsesWithAA(Value &V, Function &F, Module &BaseModule) {
+  SmallVector<Value *> Aliases;
+  findAllAliasesOfValue(V, Aliases, F);
+
+  // TODO: optimization possibility: cache the aliases that were already found to be non-suitable
+  // If any of the aliases has other uses, the value itself has other uses.
+  return llvm::any_of(Aliases, [&](llvm::Value *Alias) {
+    return valueHasOtherUsesWithoutAA(*Alias, F, BaseModule);
+  });
+}
+
+// Checks whether the value is a parameter of a function.
+bool valueIsParameterOfFunction(Value &V, Function &F) {
+  if (auto *Arg = dyn_cast<Argument>(&V))
+    return Arg->getParent() == &F;
+  return false;
+}
+
+// Tracks all visited values, and skips recursive call if we have already
+// visited a certain value before (to avoid endless recursion).
+bool valueComesFromElsewhereHelper(Value &V, Function &ParentFunction, std::set<Value *> &VisitedValues) {
+  if (!VisitedValues.insert(&V).second) {
+    // We found a value we have seen before, so we are in some sort of loop.
+    // Therefore, we continue searching, but skip re-entering the loop.
+    errs() << "Found value we have seen before: " << V.getName().str() << "; exiting to prevent infinite loop\n";
+    return false;
+  }
+
+  if (valueIsParameterOfFunction(V, ParentFunction)) {
+    errs() << "Value: " << V.getName().str() << " is the parameter of function: " << ParentFunction.getName() << "\n";
+    return true;
+  }
+
+  // A global value could be used across different modules, so we can never control/know that global values aren't used elsewhere
+  if (isa<GlobalValue>(&V)) {
+    errs() << "Value: " << V.getName().str() << " is a global value\n";
+    return true;
+  }
+
+  // Everything below only applies to instructions.
+  auto *I = dyn_cast<Instruction>(&V);
+  if (!I)
+    return false;
+
+  // Check if instruction is (directly) the return value of a function call.
+  // CallBase covers both CallInst and InvokeInst.
+  if (auto *call = dyn_cast<CallBase>(I)) {
+    // Print the instruction itself (*I); streaming the raw pointer would only
+    // print its address.
+    if (Function *calledFunction = call->getCalledFunction()) {
+      errs() << "Instruction: " << *I << " is the return value of a function call to " << calledFunction->getName() << "\n";
+    } else {
+      errs() << "Instruction: " << *I << " is the return value of an indirect function call\n";
+    }
+    return true;
+  }
+
+  // Check if value was loaded from a memory location, i.e. value is the
+  // return value of a load instruction. Also, the loaded value has to be
+  // a pointer.
+  if (auto *LI = dyn_cast<LoadInst>(I)) {
+    if (LI->getType()->isPointerTy()) {
+      errs() << "Instruction: " << I->getName() << " is a load inst loading a pointer\n";
+      return true;
+    }
+  }
+
+  // Since `V` doesn't come from elsewhere directly, we have to
+  // check whether any of the parameters/operands of the instruction `V`
+  // come from elsewhere.
+  for (auto &Op : I->operands()) {
+    if (valueComesFromElsewhereHelper(*Op, ParentFunction, VisitedValues)) {
+      errs() << "Resursive search in comes from elsewhere\n";
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// TODO: adapted for analysis over entire module
+// Checks whether a value "comes from elsewhere".
+// A value comes from elsewhere if any of the following conditions are met:
+// 1. The value was passed as a parameter to the current function.
+// 2. The value is the return value of any function call.
+// 3. The value is a pointer that was loaded from a memory location.
+// 4. The value is a global value.
+// In case the current value/instruction does not come from elsewhere, we also
+// need to check whether any of its operands come from elsewhere.
+bool valueComesFromElsewhere(Value &V, Function &ParentFunction) {
+  std::set<Value *> VisitedValues;
+  return valueComesFromElsewhereHelper(V, ParentFunction, VisitedValues);
+}
+
+// TODO:
+// Rule Relaxations (only possible with module pass):
+// - A value only has other uses if it is passed as a function parameter to an
+// **external** function (aliases must still be accounted for though) or comes
+// from such a function.
+
+// Pointer Authentication Rules:
+//
+// A pointer (value), that is being stored in or loaded from a memory location,
+// is suitable for pointer authentication, if that memory location has no other
+// uses and does not come from elsewhere.
+// A pointer is only suitable for PA, if all of its aliases are also suitable for
+// PA.
+bool memoryLocationIsSuitableForPA(Value &MemoryLocation, Function &F, Module &BaseModule) {
+  SmallVector<Value *> Aliases;
+  findAllAliasesOfValue(MemoryLocation, Aliases, F);
+
+  // TODO: optimization possibility: cache the aliases that were already found to be non-suitable
+  // The location is suitable only if no alias (the location itself included)
+  // has other uses or comes from elsewhere.
+  return llvm::none_of(Aliases, [&](Value *Alias) {
+    return valueHasOtherUsesWithoutAA(*Alias, F, BaseModule) || valueComesFromElsewhere(*Alias, F);
+  });
+}
+
+// Go through all load and stores of pointers and insert them into respective
+// vector if they are suitable for pointer authentication.
+// Returns true if at least one suitable load or store is present in the
+// output vectors afterwards. This function only collects; it does not modify
+// the IR.
+bool authenticateStoredAndLoadedPointers(Function &F, Module &BaseModule, SmallVector<StoreInst *> &StorePointerInsts, SmallVector<LoadInst *> &LoadPointerInsts) {
+  // Look for instructions that load/store a pointer
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      if (auto *SI = dyn_cast<StoreInst>(&I)) {
+        // Store(value, ptr): $value is stored at the address pointed to by $ptr.
+        // Only stores whose stored value is itself a pointer are of interest.
+        if (!SI->getValueOperand()->getType()->isPointerTy())
+          continue;
+
+        // Print the instruction itself (*SI); streaming the raw pointer would
+        // only print its address.
+        if (memoryLocationIsSuitableForPA(*SI->getPointerOperand(), F, BaseModule)) {
+          errs() << "Store instruction: " << *SI << " is suitable for pointer authentication\n";
+          StorePointerInsts.emplace_back(SI);
+        } else {
+          errs() << "Store instruction: " << *SI << " is not suitable for pointer authentication\n";
+        }
+      } else if (auto *LI = dyn_cast<LoadInst>(&I)) {
+        // Load(ptr): the value located at the address pointed to by $ptr is returned.
+        // Only loads that produce a pointer are of interest.
+        if (!LI->getType()->isPointerTy())
+          continue;
+
+        if (memoryLocationIsSuitableForPA(*LI->getPointerOperand(), F, BaseModule)) {
+          errs() << "Load instruction: " << *LI << " is suitable for pointer authentication\n";
+          LoadPointerInsts.emplace_back(LI);
+        } else {
+          errs() << "Load instruction: " << *LI << " is not suitable for pointer authentication\n";
+        }
+      }
+    }
+  }
+
+  // There is something to rewrite if any suitable load or store was collected.
+  return !(LoadPointerInsts.empty() && StorePointerInsts.empty());
+}
+
+// Inserts a pointer_sign call in front of every collected store and a
+// pointer_auth call after every collected load of function `F`.
+void insertPACInstructions(SmallVector<StoreInst *> &StorePointerInsts, SmallVector<LoadInst *> &LoadPointerInsts, Function &F) {
+  auto *PointerSignFunc = Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::wasm_pointer_sign);
+  auto *PointerAuthFunc = Intrinsic::getDeclaration(
+      F.getParent(), Intrinsic::wasm_pointer_auth);
+
+  // Add pointer signing inst before pointer store inst
+  for (StoreInst *SI : StorePointerInsts) {
+    auto *PointerSignInst = CallInst::Create(PointerSignFunc, {SI->getValueOperand()});
+    PointerSignInst->insertBefore(SI);
+
+    // Operand 0 of a store is the stored value: store the signed pointer instead.
+    SI->setOperand(0, PointerSignInst);
+  }
+
+  // Add pointer authentication inst after pointer load inst
+  for (LoadInst *LI : LoadPointerInsts) {
+    auto *PointerAuthInst = CallInst::Create(PointerAuthFunc, {LI});
+    PointerAuthInst->insertAfter(LI);
+
+    // All further uses of the load's result must use the authenticated pointer;
+    // the auth call itself keeps consuming the raw loaded value.
+    LI->replaceUsesWithIf(PointerAuthInst, [&](Use &U) {
+      return U.getUser() != PointerAuthInst;
+    });
+  }
+}
+
+}; // end class WebAssemblyPointerAuthenticationLTO
+
+// TODO: only run this on webassembly targets. Either do a target check here, or, ideally, only add to the wasm pipeline (if that works with LTO)
+// TODO: keep in mind that this **must** only be run once, and it will be during LTO
+// TODO: this means we have to prevent it from running multiple times, e.g. once before and once during LTO. To do this, we could check if any pointer auth instructions have been inserted already. If yes, we exit.
+bool WebAssemblyPointerAuthenticationLTO::runOnModule(Module &M) {
+  // Suitable stores and loads of one function, waiting to be rewritten.
+  struct PendingRewrite {
+    Function *F;
+    SmallVector<StoreInst *> Stores;
+    SmallVector<LoadInst *> Loads;
+  };
+
+  // We only want to insert the new pointer sign and auth instructions after
+  // the analysis of all functions, so the analysis never sees inserted calls.
+  SmallVector<PendingRewrite> Pending;
+
+  for (Function &F : M) {
+    SmallVector<StoreInst *> storeList;
+    SmallVector<LoadInst *> loadList;
+
+    if (authenticateStoredAndLoadedPointers(F, M, storeList, loadList)) {
+      // Move the collected vectors instead of copying them.
+      Pending.push_back({&F, std::move(storeList), std::move(loadList)});
+    }
+  }
+
+  // Actually insert the new pointer authentication instructions
+  for (PendingRewrite &Rewrite : Pending) {
+    insertPACInstructions(Rewrite.Stores, Rewrite.Loads, *Rewrite.F);
+  }
+
+  // The module was modified exactly when at least one function had a
+  // suitable load or store.
+  return !Pending.empty();
+}
+
+} // namespace
+
+char WebAssemblyPointerAuthenticationLTO::ID = 0;
+
+// The pass requires alias analysis (see getAnalysisUsage), so that
+// dependency is declared as part of the pass initialization.
+INITIALIZE_PASS_BEGIN(WebAssemblyPointerAuthenticationLTO, DEBUG_TYPE,
+                      "WebAssembly Pointer Authentication LTO Pass", false, false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(WebAssemblyPointerAuthenticationLTO, DEBUG_TYPE,
+                    "WebAssembly Pointer Authentication LTO Pass", false, false)
+
+ModulePass *llvm::createWebAssemblyPointerAuthenticationLTOPass() {
+  return new WebAssemblyPointerAuthenticationLTO();
+}
+
+#undef DEBUG_TYPE
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationModulePass.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationModulePass.cpp
new file mode 100644
index 0000000000000..3cb32478f64eb
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyPointerAuthenticationModulePass.cpp
@@ -0,0 +1,628 @@
+//===- WebAssemblyPointerAuthenticationModulePass.cpp - Pointer Authentication for WASM --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "Utils/WebAssemblyUtilities.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/CodeGen/LiveRegUnits.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineLoopInfo.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" 
+#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/MemoryTaggingSupport.h" +#include +#include +#include +#include +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "wasm-pointer-authentication-module" + +namespace { + +class WebAssemblyPointerAuthenticationModule : public ModulePass { + +public: + static char ID; + + WebAssemblyPointerAuthenticationModule() : ModulePass(ID) { + initializeWebAssemblyPointerAuthenticationModulePass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + StringRef getPassName() const override { return "WebAssembly Pointer Authentication Module"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.setPreservesCFG(); + } + +private: + AliasResult alias(Value &V1, Value &V2, Function &F) { + AliasAnalysis& AA = getAnalysis(F).getAAResults(); + return AA.alias(&V1, &V2); + } + + AliasAnalysis& getAliasAnalysisForFunction(Function &F) { + return getAnalysis(F).getAAResults(); + } + +// TODO: remove +std::string getAliasResultString(AliasResult result) { + switch (result) { + case AliasResult::NoAlias: + return "NoAlias"; + case AliasResult::MayAlias: + return "MayAlias"; + case AliasResult::PartialAlias: + return "PartialAlias"; + case AliasResult::MustAlias: + return "MustAlias"; + } +} + + +// TODO: a potential bug in our implementation i thought of: what if a pointer is stored in a function, and then passed to another function where it is loaded from? 
then, it will be signed in the original function, but never authed in the other function, because memory locations passed as parameters are always classified as coming from elsewhere, which makes sense, because other modules could use this function and we can't assume they will sign the memory lcoation's stored pointer. +// A solution to this would most likely be the LTO pass, because there we can be certain no-one else would use the function that should auth + + + +// We define a function as external if it is declared, but not defined, in +// the current module. +bool isExternalFunction(Function &F, Module &ParentModule) { + // std::cout << "==== Checking whether function \"" << F.getName().str() << "\" is an external function." << std::endl; + + // Sanity check: If the function belongs to a different module than the + // one we are currently analyzing, then it is definitely an external function. + if (F.getParent() != &ParentModule) { + // errs() << "Found external function from a different module: " << F.getName() << "\n"; + return true; + } + + if (F.isDeclaration() && !F.isIntrinsic()) { + // errs() << "Found external function: " << F.getName() << "\n"; + // std::cout << "==== From the base function: " << BaseFunction.getName().str() << " Function \"" << F.getName().str() << "\" is an external function." << std::endl; + // Check if the function has any external linkage + // if (F.hasExternalLinkage() || F.hasAvailableExternallyLinkage()) { + // if (F.hasExternalLinkage()) { + // // std::cout << "==== with external linkage" << std::endl; + // return true; + // } + + return true; + } + // std::cout << "==== From the base function: " << BaseFunction.getName().str() << " Function \"" << F.getName().str() << "\" is NOT an external function." 
<< std::endl; + return false; +} + +// TODO: think of edge case: if we have 2 functions, and they call each other, then that would be endless recursion => check visitied funcitons+values used to visit + +// TODO: read somewhere that AliasAnalysis does not account for loops apparently => test +void findAllAliasesOfValue(Value &V, SmallVector &Aliases, Function &F) { + // The pointer itself counts as one of its own aliases + Aliases.emplace_back(&V); + + // std::cout << " Value \"" << V.getName().str() << "\" is aliased by:" << std::endl; + for (BasicBlock &BB : F) { + for (Value &OtherValue : BB) { + // Only iterate on all other values + if (&V == &OtherValue) { + continue; + } + + // AliasResult aliasResult = AA.alias(&V, &OtherValue); + AliasResult aliasResult = alias(V, OtherValue, F); + if (aliasResult != AliasResult::NoAlias) { + // if (!AA.isNoAlias(&V, &OtherValue)) { + // std::cout << " Other value \"" << OtherValue.getName().str() << "\" is a: " << getAliasResultString(aliasResult) << std::endl; + + Aliases.emplace_back(&OtherValue); + } + } + } +} + +// TODO: think about function pointers, what if those somehow call external functions + +// Tracks all visited values, and skips recursive call if we have already +// visited a certain value before (to avoid endless recursion). +// If we encounter an unidentifiable function (e.g. function pointer, +// vararg function), then we immediately return false. +bool findAllFunctionsWhereValueIsPassedAsArgumentHelper(Value &V, SmallVector &FunctionCalls, std::set &VisitedValues, Module &BaseModule, Function &BaseFunction) { + // TODO: what about aliases? Are their users transitively counted as users as well? TEST this!!! 
+ auto [_, ValueWasInserted] = VisitedValues.insert(&V); + if (!ValueWasInserted) { + // errs() << "in find all functions where passed as param: Found value we have seen before: " << V.getName().str() << "; exiting to prevent infinite loop\n"; + // We found a value we have seen before, so were are in some sort of loop. + // Therefore, we have already checked this value and all of its users. + // We want to continue searching, so don't mark this as an error. + return true; + } + + // std::cout << " Value \"" << V.getName().str() << "\" is used in functions:" << std::endl; + // errs() << " Value \"" << V << "\" is used in functions:\n"; + + for (User *U : V.users()) { + // TODO: we can't only consider function users, we also have to consider e.g. normal loads and stores, which are not function calls + // TODO: what if we add and subtract, and then use the new pointer to load => probably counts as an alias, but we could just recursively directly trace those calls + // We use CallBase to check for both InvokeInst and CallInst. + if (auto *CI = dyn_cast(U)) { + size_t ArgIndex = 0; + for (Value *Arg : CI->args()) { + if (Arg == &V) { + // std::cout << " Function \"" << CI->getCalledFunction()->getName().str() << "\"" << std::endl; + // errs() << " Function: " << CI << "\n"; + + auto PassedToFunction = CI->getCalledFunction(); + errs() << "Passed to function " << PassedToFunction->getName() << "\n"; + + // This might occur if the CallInst we tried to convert to a Function + // didn't have a known function signature at compile-time, e.g. because + // it was a function pointer, or if the argument size doesn't match + // the argument index, indicating a vararg function. We can't track + // function pointers further, so we are conservative and mark this + // as potentially calling external functions. 
+ if (PassedToFunction == nullptr || PassedToFunction->arg_size() <= ArgIndex) { + // errs() << "We found a function that takes as value but we don't want to handle: " << PassedToFunction->getName() << "\n"; + errs() << "1\n"; + return false; + } + + // errs() << "We found a function that was not a nullptr\n"; + + // We only add valid function calls to the vector. + assert(PassedToFunction != nullptr); + FunctionCalls.emplace_back(PassedToFunction); + + // We need to check if our value has other users in the function + // it is passed to. + Value *ValueAsArg = PassedToFunction->getArg(ArgIndex); + + // The actual Value we passed to the function from another function + // differs from the parameter Value used inside the function. + assert(&V != ValueAsArg); + + // errs() << "Since Value " << V << " was passed to function " << PassedToFunction->getName() << " we need to recursively follow that function\n"; + // findAllFunctionsWhereValueIsPassedAsArgumentHelper(*ValueAsArg, FunctionCalls, VisitedValues); + + // TODO: This is DFS, moving this to a second for loop would make it BFS, which might be more efficient in our case. + // TODO: actually, this wouldn't work, since we would also check valuecomesfromelsewhere, which would always be true, since it is the parameter + + // The reason we have to thoroughly analyze the Value passed to the + // Function is that it might again have aliases inside that function, + // which we would not detect with just a recursive call to + // findAllFunctionsWhereValueIsPassedAsArgumentHelper. + // However, passing to a recursive function is fine, since we already + // analyzed and counted that function here. 
+ if (PassedToFunction == &BaseFunction) { + continue; + } + // if (!memoryLocationIsSuitableForPA(*ValueAsArg, *PassedToFunction, BaseModule)) { + // TODO: fixed bug here + if (valueHasOtherUsesWithAA(*ValueAsArg, *PassedToFunction, BaseModule)) { + // errs() << "This will always be not suitable since it's the argument of the function. part 2\n"; + errs() << "2\n"; + return false; + } + + // Continue iterating over parameters even if we found our Value, + // since the same Value could be passed multiple times to the + // same function. + } + + ++ArgIndex; + } + } + // Also add all other functions that use the function's return value + if (!findAllFunctionsWhereValueIsPassedAsArgumentHelper(*U, FunctionCalls, VisitedValues, BaseModule, BaseFunction)) { + errs() << "3\n"; + return false; + } + } + + return true; +} + +void printSmallVector(const llvm::SmallVectorImpl &vec) { + for (const auto &item : vec) { + // llvm::errs() << item->getName() << " "; + } + // llvm::errs() << "\n"; +} + +// Find all function calls that use the specified value as an argument. +// Once we found a function, we also have to recursively find all of +// the functions that use that function('s return value). +// Additionally, returns false if we directly find some function we could +// not analyze further, and therefore classify as external. 
+bool findAllFunctionsWhereValueIsPassedAsArgument(Value &V, SmallVector &FunctionCalls, Module &BaseModule, Function &BaseFunction) { + std::set VisitedValues; + // return findAllFunctionsWhereValueIsPassedAsArgumentHelper(V, FunctionCalls, VisitedValues, BaseModule); + auto boolean = findAllFunctionsWhereValueIsPassedAsArgumentHelper(V, FunctionCalls, VisitedValues, BaseModule, BaseFunction); + if (FunctionCalls.size() != 0) { + // errs() << "found all functions where value is passed as arg: Val: " << V << "\n"; + printSmallVector(FunctionCalls); + } + return boolean; +} + +// TODO: discuss depth first search (what we do) vs breadth first search in thesis + +// A value has other uses if it is recursively passed as a function parameter +// to an external function. +// Therefore, once we see that a value is passed to a non-external function, +// we still need to check if the value has other uses in in that function. +// This function does not perform any Alias Analysis. +bool valueHasOtherUsesWithoutAA(Value &Value, Function &F, Module &BaseModule) { + SmallVector FunctionsUsingValue; + if (!findAllFunctionsWhereValueIsPassedAsArgument(Value, FunctionsUsingValue, BaseModule, F)) { + // While searching for the functions, we already encountered some + // error/invalid function that uses the Value, so we immediately return. 
+ errs() << "encountered error during find all functions\n"; + return true; + } + + // TODO: !functionsOutsideModuleUsingPointer.empty() vs assert that all functionsUsingPointer are from this module + for (auto FunctionUsingValue: FunctionsUsingValue) { + // if (isExternalFunction(*function)) { + if (isExternalFunction(*FunctionUsingValue, BaseModule)) { + return true; + } + } + + // TODO: optimization: don't find all functions, just the first one that is + // // TODO: !functionsOutsideModuleUsingPointer.empty() vs assert that all functionsUsingPointer are from this module + + return false; +} + +// This is used instead of the WithoutAA variant when Alias Analysis is +// required, but valueComesFromElsewhere should not be bundled in. +bool valueHasOtherUsesWithAA(Value &V, Function &F, Module &BaseModule) { + SmallVector Aliases; + findAllAliasesOfValue(V, Aliases, F); + + // TODO: optimization possibility: cache the aliases that were already found to be non-suitable + // If any of the aliases are not suitable, then all of the aliases should be not suitable + for (auto Alias : Aliases) { + if (valueHasOtherUsesWithoutAA(*Alias, F, BaseModule)) { + // errs() << "This will always be not suitable since it's the argument of the function. part 1\n"; + return true; + } + } + + return false; +} + +// Checks whether the value is a parameter of a function. +bool valueIsParameterOfFunction(Value &V, Function &F) { + for (Argument &arg : F.args()) { + if (&arg == &V) { + return true; + } + } + return false; +} + +// Tracks all visited values, and skips recursive call if we have already +// visited a certain value before (to avoid endless recursion). +bool valueComesFromElsewhereHelper(Value &V, Function &ParentFunction, std::set &VisitedValues) { + // errs() << "Checking value: " << V.getName().str() << "\n"; + + auto [_, ValueWasInserted] = VisitedValues.insert(&V); + if (!ValueWasInserted) { + // We found a value we have seen before, so were are in some sort of loop. 
+ // Therefore, we continue searching, but skip re-entering the loop. + errs() << "Found value we have seen before: " << V.getName().str() << "; exiting to prevent infinite loop\n"; + return false; + } + + if (valueIsParameterOfFunction(V, ParentFunction)) { + errs() << "Value: " << V.getName().str() << " is the parameter of function: " << ParentFunction.getName() << "\n"; + return true; + } + + // A global value could be used across different modules, so we can never control/know that global values aren't used elsewhere + if (isa(&V)) { + errs() << "Value: " << V.getName().str() << " is a global value\n"; + return true; + } + + // Checks, recursively, whether a Value was returned by a function call. + if (auto *I = dyn_cast(&V)) { + // Check if instruction is (directly) the return value of a function call. + // TODO: check for superclass CallBase instead + // if (isa(I)) { + if (isa(I)) { + // errs() << "Instruction: " << I << " is the return value of a function call\n"; + // return true; + + CallBase *call = cast(I); + if (Function *calledFunction = call->getCalledFunction()) { + errs() << "Instruction: " << I << " is the return value of a function call to " << calledFunction->getName() << "\n"; + } else { + errs() << "Instruction: " << I << " is the return value of an indirect function call\n"; + } + return true; + } + // // Check if value was loaded from a memory location, i.e. value is the + // // return value of a load instruction. + // if (isa(I)) { + // errs() << "Instruction: " << I->getName() << " is a load inst\n"; + // return true; + // } + + // Check if value was loaded from a memory location, i.e. value is the + // return value of a load instruction. Also, the loaded value has to be + // a pointer. 
+ if (LoadInst *LI = dyn_cast(I)) { + if (LI->getType()->isPointerTy()) { + errs() << "Instruction: " << I->getName() << " is a load inst loading a pointer\n"; + return true; + } + } + + // Since `V` doesn't come from elsewhere directly, we have to + // check whether any of the parameters/operands of the instruction `V` + // come from elsewhere. + for (auto &Op : I->operands()) { + if (valueComesFromElsewhereHelper(*Op, ParentFunction, VisitedValues)) { + errs() << "Resursive search in comes from elsewhere\n"; + return true; + } + } + } + + return false; +} + +// TODO: adapted for analysis over entire module +// Checks whether a value "comes from elsewhere". +// A value comes from elsewhere if any of the following conditions are met: +// 1. The value was passed as a parameter to the current function. +// 2. The value is the return value of any function. +// 3. The value was loaded from any memory location. +// 4. The value is a global value. +// In case the current value/instruction does not come from elsewhere, we also +// need to check whether any of its operands come from elsewhere. +bool valueComesFromElsewhere(Value &V, Function &ParentFunction) { + std::set VisitedValues; + return valueComesFromElsewhereHelper(V, ParentFunction, VisitedValues); +} + + +// TODO: +// Rule Relaxations (only possible with module pass): +// - A value only has other uses if it is passed as a function parameter to an +// **external** function (aliases must still be accounted for though) or comes +// from such a function. + +// Pointer Authentication Rules: +// +// A pointer (value), that is being stored in or loaded from a memory location, +// is suitable for pointer authentication, if that memory location has no other +// uses and does not come from elsewhere. +// A pointer is only suitable for PA, if all of its aliases are also suitable for +// PA. 
+bool memoryLocationIsSuitableForPA(Value &MemoryLocation, Function &F, Module &BaseModule) { + SmallVector Aliases; + findAllAliasesOfValue(MemoryLocation, Aliases, F); + + // TODO: optimization possibility: cache the aliases that were already found to be non-suitable + // If any of the aliases are not suitable, then all of the aliases should be not suitable + for (auto Alias : Aliases) { + // if (valueHasOtherUsesWithoutAA(*Alias, F, BaseModule) || valueComesFromElsewhere(*Alias, F)) { + // // errs() << "This will always be not suitable since it's the argument of the function. part 1\n"; + // return false; + // } + if (valueHasOtherUsesWithoutAA(*Alias, F, BaseModule)) { + errs() << "Value " << Alias->getName() << " has other uses\n"; + return false; + } + if (valueComesFromElsewhere(*Alias, F)) { + errs() << "Value " << Alias->getName() << " comes from elsewhere\n"; + return false; + } + } + + return true; +} + +void insertPACInstructions(SmallVector &StorePointerInsts, SmallVector &LoadPointerInsts, Function &F) { + auto *PointerSignFunc = Intrinsic::getDeclaration( + F.getParent(), Intrinsic::wasm_pointer_sign); + auto *PointerAuthFunc = Intrinsic::getDeclaration( + F.getParent(), Intrinsic::wasm_pointer_auth); + + // Add pointer signing inst before pointer store inst + for (auto SI : StorePointerInsts) { + Value *PointerValueToStore = SI->getValueOperand(); + + auto *PointerSignInst = CallInst::Create(PointerSignFunc, {PointerValueToStore}); + PointerSignInst->insertBefore(SI); + + // Replace the value operand in the store inst with the new signed value + SI->setOperand(0, PointerSignInst); + } + + // Add pointer authentication inst after pointer load inst + for (auto LI : LoadPointerInsts) { + auto *PointerAuthInst = CallInst::Create(PointerAuthFunc, {LI}); + PointerAuthInst->insertAfter(LI); + + // All further uses of the load's return value must use our authenticated pointer instead now + LI->replaceUsesWithIf(PointerAuthInst, [&](Use &U) { + return 
U.getUser() != PointerAuthInst; + }); + } +} + +// Go through all load and stores of pointers and insert them into respective +// vector if they are suitable for pointer authentication. +bool authenticateStoredAndLoadedPointers(Function &F, Module &BaseModule, SmallVector &StorePointerInsts, SmallVector &LoadPointerInsts) { + // Look for instructions that load/store a pointer + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (StoreInst *SI = dyn_cast(&I)) { + // Store(value, ptr): $value is stored at data address pointed to by $ptr + // Check if value to be stored in memory is a pointer + Value *PointerValueToStore = SI->getValueOperand(); + if (PointerValueToStore->getType()->isPointerTy()) { + auto MemoryLocation = SI->getPointerOperand(); + // errs() << "==== Checking if store: " << SI->getName().str() << " is suitable for PA\n"; + + if (memoryLocationIsSuitableForPA(*MemoryLocation, F, BaseModule)) { + errs() << "Store instruction: " << SI << " is suitable for pointer authentication\n"; + // We shouldn't mutate the instructions we are iterating over + StorePointerInsts.emplace_back(SI); + } else { + errs() << "Store instruction: " << SI << " is not suitable for pointer authentication\n"; + } + } + } else + if (LoadInst *LI = dyn_cast(&I)) { + // Load(ptr): The data value located at the memory address pointed to by $ptr is returned + // Check if value to be loaded from memory is a pointer + if (LI->getType()->isPointerTy()) { + auto MemoryLocation = LI->getPointerOperand(); + // errs() << "==== Checking if load: " << LI->getName().str() << " is suitable for PA\n"; + + if (memoryLocationIsSuitableForPA(*MemoryLocation, F, BaseModule)) { + errs() << "Load instruction: " << LI << " is suitable for pointer authentication\n"; + // std::cout << "Load instruction: " << LI->getName().str() << " is suitable for pointer authentication\n"; + // We shouldn't mutate the instructions we are iterating over + LoadPointerInsts.emplace_back(LI); + } else { + errs() 
<< "Load instruction: " << LI << " is not suitable for pointer authentication\n"; + // std::cout << "Load instruction: " << LI->getName().str() << " is not suitable for pointer authentication\n"; + } + } + } + } + } + + // We made changes if we added any pointer sign or auth instructions. + bool modified = !(LoadPointerInsts.empty() && StorePointerInsts.empty()); + return modified; +} +}; + +bool WebAssemblyPointerAuthenticationModule::runOnModule(Module &M) { + bool modified = false; + + // We only want to insert the new pointer sign and auth instructions after + // the analysis of all functions. + std::map, SmallVector>> functionPointerMap; + + for (Function &F : M) { + // errs() << "======= Checking function: " << F.getName() << "\n"; + + SmallVector storeList; + SmallVector loadList; + + if (authenticateStoredAndLoadedPointers(F, M, storeList, loadList)) { + // Collect suitable Stores and Loads into vectors + functionPointerMap[&F] = std::make_pair(storeList, loadList); + modified = true; + } + } + + // Actually insert the new pointer authentication instructions + for (auto &[F, vectors] : functionPointerMap) { + auto &[storeList, loadList] = vectors; + insertPACInstructions(storeList, loadList, *F); + } + + for (Function &F : M) { + if (F.getName() == "__main_argc_argv" || F.getName() == "__original_main") { + F.dump(); + } + // errs() << "------ Printing altered function: " << F.getName() << "\n"; + // F.dump(); + } + + // TODO: potentially set to false in the future + + // No changes relevant to other LLVM transformation passes were made. + // We simply added some instructions other passes are unaware of anyways. + // However, to be on the safe side, we will still indicate that the function + // was modified. 
+ return modified; +} + +} // namespace + +char WebAssemblyPointerAuthenticationModule::ID = 0; + +INITIALIZE_PASS_BEGIN(WebAssemblyPointerAuthenticationModule, DEBUG_TYPE, + "WebAssembly Pointer Authentication Module Pass", false, false) +INITIALIZE_PASS_END(WebAssemblyPointerAuthenticationModule, DEBUG_TYPE, + "WebAssembly Pointer Authentication Module Pass", false, false) + +ModulePass *llvm::createWebAssemblyPointerAuthenticationModulePass() { + return new WebAssemblyPointerAuthenticationModule(); +} + +#undef DEBUG_TYPE diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index c10147eaa326f..35a4968924d7f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -85,6 +85,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeWebAssemblyTarget() { initializeWebAssemblyLowerRefTypesIntPtrConvPass(PR); initializeWebAssemblyFixBrTableDefaultsPass(PR); initializeWebAssemblyDAGToDAGISelPass(PR); + initializeWebAssemblyPointerAuthenticationFunctionPass(PR); } //===----------------------------------------------------------------------===// @@ -462,8 +463,12 @@ void WebAssemblyPassConfig::addIRPasses() { // Expand indirectbr instructions to switches. addPass(createIndirectBrExpandPass()); - addPass( - createWebAssemblyMemorySafetyPass(TM->getOptLevel() == CodeGenOpt::None)); + // addPass( + // createWebAssemblyMemorySafetyPass(TM->getOptLevel() == CodeGenOpt::None)); + + addPass(createWebAssemblyPointerAuthenticationFunctionPass()); + // addPass(createWebAssemblyPointerAuthenticationModulePass()); + // addPass(createWebAssemblyPointerAuthenticationLTOPass()); TargetPassConfig::addIRPasses(); }