diff --git a/shmoverride/shmoverride.c b/shmoverride/shmoverride.c index d399905d..3cd95362 100644 --- a/shmoverride/shmoverride.c +++ b/shmoverride/shmoverride.c @@ -72,6 +72,8 @@ static int (*real_munmap) (void *shmaddr, size_t len); static int (*real_fstat64) (VER_ARG int fd, struct stat64 *buf); static int (*real_fstat)(VER_ARG int fd, struct stat *buf); +static int try_init(void); + static struct stat global_buf; static int gntdev_fd = -1; @@ -84,7 +86,7 @@ static int xc_hnd; static xengnttab_handle *xgt; static char __shmid_filename[SHMID_FILENAME_LEN]; static char *shmid_filename = NULL; -static int idfd = -1, display = -1; +static int idfd = -1, display = -1, init_called = 0; static uint8_t *mmap_mfns(struct shm_args_hdr *shm_args) { uint8_t *map; @@ -151,6 +153,8 @@ ASM_DEF(void *, mmap, real_fstat = FSTAT; } + try_init(); + #if defined MAP_ANON && defined MAP_ANONYMOUS && (MAP_ANONYMOUS) != (MAP_ANON) # error header bug (def mismatch) #endif @@ -217,6 +221,9 @@ ASM_DEF(int, munmap, void *addr, size_t len) { if (len > SIZE_MAX - XC_PAGE_SIZE) abort(); + + try_init(); + const uintptr_t addr_int = (uintptr_t)addr; const uintptr_t rounded_addr = addr_int & ~(uintptr_t)(XC_PAGE_SIZE - 1); return real_munmap((void *)rounded_addr, len + (addr_int - rounded_addr)); @@ -438,6 +445,7 @@ static int assign_off(off_t *off) { #define STAT(id) \ ASM_DEF(int, f ## id, int filedes, struct id *buf) { \ + try_init(); \ int res = real_f ## id(VER filedes, buf); \ if (res || \ !S_ISCHR(buf->st_mode) || \ @@ -454,6 +462,7 @@ STAT(stat64) #ifdef _STAT_VER #define STAT(id) \ ASM_DEF(int, __fx ## id, int ver, int filedes, struct id *buf) { \ + try_init(); \ if (ver != _STAT_VER) { \ fprintf(stderr, \ "Wrong _STAT_VER: got %d, expected %d, libc has incompatibly changed\n", \ @@ -467,8 +476,13 @@ STAT(stat64) #undef STAT #endif -int __attribute__ ((constructor)) initfunc(void) +static int try_init(void) { + // Ideally it is being called in constructor, if something is calling this before + // constructor - we're assuming it is not multi-threaded code. + if (__builtin_expect(init_called, 1)) return 0; + init_called = 1; + unsetenv("LD_PRELOAD"); fprintf(stderr, "shmoverride constructor running\n"); dlerror(); @@ -581,6 +595,10 @@ int __attribute__ ((constructor)) initfunc(void) shm_args = NULL; return 0; } +int __attribute__ ((constructor)) initfunc(void) +{ + return try_init(); +} int __attribute__ ((destructor)) descfunc(void) {