-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.c
executable file
·137 lines (106 loc) · 3.77 KB
/
main.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <assert.h>
#include <dlfcn.h>
#include <errno.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
/*
 * External-linkage global initialized to 42. Nothing in this file reads or
 * writes it. Presumably it exists as a symbol for the dlsym/patching
 * experiments in main() to find -- TODO confirm its intended role.
 */
size_t test_sym = 42;
/*
 * Returns 42.
 *
 * main() calls this before and after patching. If "one" appears in test.so's
 * `magic` table, main() resolves this function with dlsym and overwrites the
 * first bytes of its machine code with a branch into test.so's definition.
 * Changing this body changes the bytes that patching reads and overwrites.
 */
int
one(void)
{
return 42;
}
/*
 * Call `fn` with no arguments and ignore whatever it returns.
 *
 * Beware the parameter type: it is a pointer to a function *returning*
 * `void *`, not `void (*)(void)`. The declarator below spells the pointer
 * out explicitly; it is the same type the original `void *fn(void)` form
 * adjusts to. main() currently only looks this function up by name via
 * dlsym; it does not call it.
 */
void
execute_function(void *(*fn)(void))
{
	void *(*const target)(void) = fn;

	(void)target();
}
/* Length of an x86-64 `jmp rel32`: opcode 0xE9 plus a 4-byte displacement. */
enum { JMP_REL32_LEN = 5 };

/*
 * dlerror() can return NULL (e.g. when dlsym finds a symbol whose value is
 * NULL). Passing NULL to a %s conversion is undefined, so substitute text.
 */
static const char *
dl_error_message(void)
{
	const char *msg = dlerror();

	return msg != NULL ? msg : "unknown dynamic-linker error";
}

/*
 * Compute the displacement for a `jmp rel32` placed at `from` that lands on
 * `to`. The CPU adds the displacement to the address of the *next*
 * instruction, i.e. from + JMP_REL32_LEN.
 *
 * Returns 0 and stores the displacement in *rel. Returns -1 when `to` lies
 * beyond the reach of a signed 32-bit offset, since truncating the distance
 * would jump to an unrelated address.
 */
static int
jmp_rel32_displacement(const void *from, const void *to, int32_t *rel)
{
	/* Subtract as unsigned to avoid signed overflow; viewing the result as
	 * intptr_t gives the signed distance on two's-complement targets. */
	intptr_t disp = (intptr_t)((uintptr_t)to - ((uintptr_t)from + JMP_REL32_LEN));

	if (disp < INT32_MIN || disp > INT32_MAX) {
		return -1;
	}
	*rel = (int32_t)disp;
	return 0;
}

/*
 * Hot-patching experiment (x86-64 only: it emits raw x86 opcodes).
 *
 * Calls one(), loads ./test.so, and reads its `magic` table of symbol names.
 * For each name that also resolves in this executable, it overwrites the
 * start of the local definition with a `jmp rel32` to test.so's definition.
 * It then calls one() again.
 *
 * Returns 0 on success. Returns EXIT_FAILURE when test.so, this executable's
 * own handle, the `magic`/`magic_len` symbols, execute_function, or the page
 * size cannot be obtained. Other per-symbol failures are reported and skipped.
 */
int
main(void)
{
	for (int i = 0; i < 1; i++) {
		int ret = one();
		printf ("ret %d\n", ret);
	}

	/* Never dlclose()d: patched code jumps into test.so, so it must stay
	 * mapped (RTLD_NODELETE also keeps it resident). */
	void *handle = dlopen("./test.so", RTLD_LAZY | RTLD_LOCAL | RTLD_NODELETE);
	if (handle == NULL) {
		printf("could not open lib: %s\n", dl_error_message());
		return EXIT_FAILURE;
	}
	void *own_handle = dlopen(NULL, RTLD_NOW | RTLD_GLOBAL);
	if (own_handle == NULL) {
		printf("could not resolve own symbols: %s\n", dl_error_message());
		return EXIT_FAILURE;
	}

	/* Explicit checks rather than assert(): assert compiles away under
	 * NDEBUG and would leave the dereferences below unguarded. */
	void *magic_sym_addr = dlsym(handle, "magic");
	if (magic_sym_addr == NULL) {
		printf("could not find magic sym: %s\n", dl_error_message());
		return EXIT_FAILURE;
	}
	void *magic_sym_len_addr = dlsym(handle, "magic_len");
	if (magic_sym_len_addr == NULL) {
		printf("could not find magic length: %s\n", dl_error_message());
		return EXIT_FAILURE;
	}
	/* Assumes test.so defines `size_t magic_len` and an array
	 * `char *magic[]` (dlsym yields the array's address) -- TODO confirm. */
	size_t magic_sym_len = *(const size_t *)magic_sym_len_addr;
	char **magic_sym = (char **)magic_sym_addr;
	printf("number of magic symbols: %zu\n", magic_sym_len);

	/* mprotect works on whole pages; sysconf reports the page size. */
	long page_size = sysconf(_SC_PAGESIZE);
	if (page_size <= 0) {
		printf("could not determine page size\n");
		return EXIT_FAILURE;
	}

	for (size_t i = 0; i < magic_sym_len; i++) {
		printf("looking for symbol %s.\n", magic_sym[i]);
		/* Finding our own symbols presumably needs the executable to export
		 * them (e.g. linked with -rdynamic) -- verify against build flags. */
		void *own_addr = dlsym(own_handle, magic_sym[i]);
		if (own_addr == NULL) {
			printf("Could not find symbol %s from magic in own executable: %s. Continuing to next sym.\n", magic_sym[i], dl_error_message());
			continue;
		}
		/* Not used yet; placeholder for the trampoline idea in the XXX below. */
		void *magic_function_addr = dlsym(own_handle, "execute_function");
		if (magic_function_addr == NULL) {
			printf("Could not find symbol execute_function from magic in own executable: %s\n", dl_error_message());
			return EXIT_FAILURE;
		}
		void *new_addr = dlsym(handle, magic_sym[i]);
		if (new_addr == NULL) {
			printf("Could not find symbol %s from magic in new shared lib: %s. Continuing to next sym.\n", magic_sym[i], dl_error_message());
			continue;
		}
		printf("address of own function: %p\n", own_addr);
		printf("want to jump to: %p\n", new_addr);

		int32_t rel_addr;
		if (jmp_rel32_displacement(own_addr, new_addr, &rel_addr) != 0) {
			printf("%s in new shared lib is out of rel32 range. Continuing to next sym.\n", magic_sym[i]);
			continue;
		}
		printf("relative jmp address: %" PRIx32 "\n", (uint32_t)rel_addr);

		// TODO: Strategy:
		// http://www.ragestorm.net/blogs/?p=107
		//
		/*
		 * A tail `jmp` (not `call`) leaves the caller's return address and
		 * stack alignment untouched. test.so's function therefore returns
		 * straight to our function's caller. The bytes after the jmp are
		 * never executed and need not be preserved.
		 */
		uint8_t new_contents[JMP_REL32_LEN];
		new_contents[0] = 0xE9; /* jmp rel32 */
		/* x86 is little-endian, which is the byte order the encoding needs. */
		memcpy(&new_contents[1], &rel_addr, sizeof rel_addr);

		uint32_t encoded;
		memcpy(&encoded, &new_contents[1], sizeof encoded);
		printf("relative jmp address from new_content: %" PRIx32 "\n", encoded);
		printf("new_contents:");
		for (size_t b = 0; b < sizeof new_contents; b++) {
			printf(" %02" PRIx8, new_contents[b]);
		}
		printf("\n");

		/* Round down to the page that contains own_addr. The patch may
		 * straddle a page boundary, so cover every byte it touches. */
		uintptr_t page_start = (uintptr_t)own_addr & ~((uintptr_t)page_size - 1);
		void *nearest_page_down = (void *)page_start;
		printf("nearest page down from own_addr %p: %p\n", own_addr, nearest_page_down);
		size_t span = (size_t)((uintptr_t)own_addr + sizeof new_contents - page_start);
		if (mprotect(nearest_page_down, span, PROT_READ | PROT_WRITE | PROT_EXEC) != 0) {
			printf("could not mprotect the right address: %s. Continuing to next sym.\n", strerror(errno));
			continue;
		}
		/* THIS IS DANGEROUS: rewrites live machine code. Assumes the target
		 * is a function at least JMP_REL32_LEN bytes long. x86 keeps the
		 * instruction cache coherent with these stores. */
		memcpy(own_addr, new_contents, sizeof new_contents);
		/* XXX: create a function which jumps to an arbitrary global address. jump to that function. */
		/* XXX: mprotect back. The original protection is unknown here: text
		 * is normally r-x, but a data symbol would need rw-. */
		printf("Resolved symbol %s\n", magic_sym[i]);
	}

	/* Call through a volatile pointer so the compiler cannot inline or
	 * constant-fold one() and bypass the patched machine code. */
	int (*volatile patched_one)(void) = one;
	for (int i = 0; i < 1; i++) {
		/* POSIX requires function pointers to convert to void * (dlsym
		 * relies on this). */
		printf ("running function %p\n", (void *)one);
		int ret = patched_one();
		printf ("ret %d\n", ret);
	}
	return 0;
}