diff --git a/src/crypto.cpp b/src/crypto.cpp index 9ad17a7..daddf8e 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -13,7 +13,7 @@ #include #include #include -#define DEBUG 1 +#define DEBUG 0 #include "crypto.h" @@ -32,7 +32,6 @@ int current_cipher = 0; #define AES_BLOCK_SIZE 8 - static MUTEX_TYPE *mutex_buf = NULL; static void locking_function(int mode, int n, const char*file, int line); @@ -57,7 +56,6 @@ const int max_block_size = 64*1024; // Function for OpenSSL to lock mutex static void locking_function(int mode, int n, const char*file, int line) { - pris("LOCKING FUNCTION CALLED"); if (mode & CRYPTO_LOCK) MUTEX_LOCK(mutex_buf[n]); @@ -116,14 +114,8 @@ int THREAD_cleanup(void) void *crypto_update_thread(void* _args) { - // clock_t start = clock(); - e_thread_args* args = (e_thread_args*)_args; - int evp_outlen = 0; - - // for (int i = 0; i < args->len; i ++) - // args->in[i] = args->in[i] ^ 0xCC; if(!EVP_CipherUpdate(args->ctx, args->in, &evp_outlen, args->in, args->len)){ fprintf(stderr, "encryption error\n"); @@ -138,10 +130,6 @@ void *crypto_update_thread(void* _args) args->len = evp_outlen; - // clock_t end = clock(); - // double time_elapsed_in_seconds = (end - start)/(double)CLOCKS_PER_SEC; - // fprintf(stderr, "Time in crypto: %.3f s\n", time_elapsed_in_seconds); - pthread_exit(NULL); } @@ -161,14 +149,8 @@ int crypto_update(char* in, int len, crypto *c) } else { - // UPDATE CIPHER NUMBER - i = 0; - - // for (int i = 0; i < len; i ++) - // in[i] = in[i] ^ 0xCC; - // [EN][DE]CRYPT - if(!EVP_CipherUpdate(&c->ctx[i], (uchar*)in, &evp_outlen, (uchar*)in, len)){ + if(!EVP_CipherUpdate(&c->ctx[0], (uchar*)in, &evp_outlen, (uchar*)in, len)){ fprintf(stderr, "encryption error\n"); exit(EXIT_FAILURE); } @@ -186,5 +168,70 @@ int crypto_update(char* in, int len, crypto *c) } +int joined[N_CRYPTO_THREADS]; + +int pthread_join_disregard_ESRCH(pthread_t thread, int thread_id){ + + if (joined[thread_id]) + return 0; + + int ret = pthread_join(thread, NULL); + + pthread_mutex_lock(&c_lock); + joined[thread_id] = 1; + pthread_mutex_unlock(&c_lock); + + if (ret){ + if (ret != ESRCH){ + fprintf(stderr, "Unable to join encryption thread: %d\n", ret); + exit(1); + } + } + + return 0; +} + +int join_all_encryption_threads(pthread_t threads[N_CRYPTO_THREADS]){ + for (int i = 0; i < N_CRYPTO_THREADS; i++) + pthread_join_disregard_ESRCH(threads[i], i); + return 0; +} + +int pass_to_enc_thread(pthread_t crypto_threads[N_CRYPTO_THREADS], + e_thread_args * e_args, + int * curr_crypto_thread, + char* in, int len, + crypto*c){ + + pthread_join_disregard_ESRCH(crypto_threads[*curr_crypto_thread], *curr_crypto_thread); + + // Initialize and set thread detached attribute + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE); + + e_args[*curr_crypto_thread].in = (uchar*) in; + e_args[*curr_crypto_thread].len = len; + e_args[*curr_crypto_thread].ctx = &c->ctx[*curr_crypto_thread]; + + int ret = pthread_create(&crypto_threads[*curr_crypto_thread], + &attr, crypto_update_thread, &e_args[*curr_crypto_thread]); + + pthread_mutex_lock(&c_lock); + joined[*curr_crypto_thread] = 0; + pthread_mutex_unlock(&c_lock); + + if (ret){ + fprintf(stderr, "Unable to create thread: %d\n", ret); + exit(1); + } + + *curr_crypto_thread = *curr_crypto_thread+1; + + if (*curr_crypto_thread>=N_CRYPTO_THREADS) + *curr_crypto_thread = 0; + + return 0; +} diff --git a/src/crypto.h b/src/crypto.h index 6829f0b..d6abc35 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -19,7 +19,7 @@ and limitations under the License. #define CRYPTO_H -#define N_CRYPTO_THREADS 4 +#define N_CRYPTO_THREADS 16 #define USE_CRYPTO 1 #define PASSPHRASE_SIZE 32 @@ -186,6 +186,13 @@ typedef struct e_thread_args int crypto_update(char* in, int len, crypto *c); void *crypto_update_thread(void* _args); +int pthread_join_disregard_ESRCH(pthread_t thread, int thread_id); +int join_all_encryption_threads(pthread_t threads[N_CRYPTO_THREADS]); +int pass_to_enc_thread(pthread_t crypto_threads[N_CRYPTO_THREADS], + e_thread_args * e_args, + int * curr_crypto_thread, + char* in, int len, + crypto*c); #endif diff --git a/src/udtcat.h b/src/udtcat.h index 756245c..01fbfcc 100644 --- a/src/udtcat.h +++ b/src/udtcat.h @@ -30,7 +30,6 @@ and limitations under the License. #include "crypto.h" #define BUFF_SIZE 67108864 -/* #define BUFF_SIZE 2097152 */ typedef struct recv_args{ UDTSOCKET*usocket; diff --git a/src/udtcat_threads.cpp b/src/udtcat_threads.cpp index c25af1b..f5df686 100644 --- a/src/udtcat_threads.cpp +++ b/src/udtcat_threads.cpp @@ -25,11 +25,12 @@ and limitations under the License. #include "udtcat.h" #include "udtcat_threads.h" +#define EXIT_FAILURE 1 + #define prii(x) fprintf(stderr,"debug:%d\n",x) #define pris(x) fprintf(stderr,"debug: %s\n",x) #define prisi(x,y) fprintf(stderr,"%s: %d\n",x,y) -#define uc_err(x) {fprintf(stderr,"error:%s\n",x);exit(1);} - +#define uc_err(x) {fprintf(stderr,"error:%s\n",x);exit(EXIT_FAILURE);} const int ECONNLOST = 2001; @@ -40,56 +41,54 @@ int n_recv_threads = 0; int last_printed = -1; pthread_mutex_t lock; +/* + THINGS NEEDED IN APPLICATION TO RUN THE MT SUPPORTED CRYPTO: + pthread_t crypto_threads[N_THREADS]; + e_thread_args * e_args; + crypto *c; + int* curr_crypto_thread; +*/ + void* recvdata(void * _args) { - recv_args * args = (recv_args*)_args; + // Handle socket + recv_args * args = (recv_args*)_args; UDTSOCKET recver = *args->usocket; - // delete (UDTSOCKET*) args->usocket; - int size = BUFF_SIZE; + // Decryption locals int read_len; - char* data = new char[size]; - e_thread_args e_args[N_CRYPTO_THREADS]; - int decrypt_buf_len = BUFF_SIZE / N_CRYPTO_THREADS; - prisi("decrypt_buf_len", decrypt_buf_len); - pthread_t decryption_threads[N_CRYPTO_THREADS]; - - int len; - int decrypt_cursor = 0; - int buffer_cursor = 0; - int curr_crypto_thread = 0; + + int decrypt_buf_len = BUFF_SIZE / N_CRYPTO_THREADS; + int len, decrypt_cursor, buffer_cursor, curr_crypto_thread; + decrypt_cursor = buffer_cursor = curr_crypto_thread = 0; if (USE_CRYPTO){ char* decrypt_buffer = (char*) malloc(BUFF_SIZE*sizeof(char)); - if (!decrypt_buffer) - uc_err("Error allocating decryption buffer"); + if (!decrypt_buffer){ + fprintf(stderr, "Unable to allocate decryption buffer"); + exit(EXIT_FAILURE); + } while (1) { - // read in from UDT - if (UDT::ERROR == (len = UDT::recv(recver, - decrypt_buffer+buffer_cursor, + // Read in from UDT + if (UDT::ERROR == (len = UDT::recv(recver, decrypt_buffer+buffer_cursor, BUFF_SIZE-buffer_cursor, 0))) { if (UDT::getlasterror().getErrorCode() != ECONNLOST) cerr << "recv:" << UDT::getlasterror().getErrorMessage() << endl; + // Finish any remaining data in the buffer if (buffer_cursor > 0){ - - for (int i = 0; i < N_CRYPTO_THREADS; i++) - pthread_join(decryption_threads[i], NULL); - + join_all_encryption_threads(decryption_threads); crypto_update(decrypt_buffer+decrypt_cursor, - buffer_cursor-decrypt_cursor, - args->dec); - + buffer_cursor-decrypt_cursor, args->dec); write(fileno(stdout), decrypt_buffer, buffer_cursor); - } exit(0); @@ -97,83 +96,52 @@ void* recvdata(void * _args) buffer_cursor += len; + // This should never happen if (buffer_cursor > BUFF_SIZE) uc_err("Decryption buffer overflow"); + // Decrypt what we've got while (decrypt_cursor+decrypt_buf_len <= buffer_cursor){ - - pthread_join(decryption_threads[curr_crypto_thread], NULL); - - e_args[curr_crypto_thread].in = (uchar*) decrypt_buffer+decrypt_cursor; - e_args[curr_crypto_thread].len = decrypt_buf_len; - e_args[curr_crypto_thread].c = args->dec; - e_args[curr_crypto_thread].ctx = &args->dec->ctx[curr_crypto_thread]; - - pthread_create(&decryption_threads[curr_crypto_thread], - NULL, crypto_update_thread, &e_args[curr_crypto_thread]); - - curr_crypto_thread++; - if (curr_crypto_thread>=N_CRYPTO_THREADS) - curr_crypto_thread = 0; - + pass_to_enc_thread(decryption_threads, + e_args, + &curr_crypto_thread, + decrypt_buffer+decrypt_cursor, + decrypt_buf_len, + args->dec); decrypt_cursor += decrypt_buf_len; - } + // Write the decrypted buffer and reset if (decrypt_cursor >= BUFF_SIZE){ - - for (int i = 0; i < N_CRYPTO_THREADS; i++) - pthread_join(decryption_threads[i], NULL); - + join_all_encryption_threads(decryption_threads); write(fileno(stdout), decrypt_buffer, BUFF_SIZE); - - buffer_cursor = 0; - decrypt_cursor = 0; - curr_crypto_thread = 0; - + buffer_cursor = decrypt_cursor = curr_crypto_thread = 0; } } - prisi("Last buffer state", buffer_cursor); - + free(decrypt_buffer); } else { - + char* data = new char[BUFF_SIZE]; while (1){ - - // read in from UDT - if (UDT::ERROR == (read_len = UDT::recv(recver, data, size, 0))) { + if (UDT::ERROR == (read_len = UDT::recv(recver, data, BUFF_SIZE, 0))) { if (UDT::getlasterror().getErrorCode() != ECONNLOST) cerr << "recv:" << UDT::getlasterror().getErrorMessage() << endl; - exit(0); + break; } - write(fileno(stdout), data, read_len); } - } - - // free(data); - // delete [] data; - UDT::close(recver); - return NULL; -} -clock_t start, end; +} int send_buf(UDTSOCKET client, char* buf, int size, int flags) { - - // end = clock(); - // double time_elapsed_in_seconds = (end - start)/(double)CLOCKS_PER_SEC; - // fprintf(stderr, "Time since last send: %f\n", time_elapsed_in_seconds); - - int ssize = 0; int ss; while (ssize < size) { @@ -190,11 +158,9 @@ int send_buf(UDTSOCKET client, char* buf, int size, int flags) if (UDT::ERROR == ss) { cerr << "send:" << UDT::getlasterror().getErrorMessage() << endl; - exit(1); + exit(EXIT_FAILURE); } - // start = clock(); - return ss; } @@ -202,31 +168,20 @@ int send_buf(UDTSOCKET client, char* buf, int size, int flags) void* senddata(void* _args) { - - start = clock(); - + snd_args * args = (snd_args*) _args; - UDTSOCKET client = *(UDTSOCKET*)args->usocket; - // delete (UDTSOCKET*)usocket; + char* encrypt_buffer; e_thread_args e_args[N_CRYPTO_THREADS]; int encrypt_buf_len = BUFF_SIZE / N_CRYPTO_THREADS; - char* encrypt_buffer; - + pthread_t encryption_threads[N_CRYPTO_THREADS]; - char *data; int flags = 0; - - if (!(data = (char*)malloc(BUFF_SIZE*sizeof(char)))) - uc_err("Unable to allocate thread buffer data"); - - int len; - int encrypt_cursor = 0; - int buffer_cursor = 0; - int curr_crypto_thread = 0; + int len, encrypt_cursor, buffer_cursor, curr_crypto_thread; + encrypt_cursor = buffer_cursor = curr_crypto_thread = 0; if (USE_CRYPTO){ @@ -247,70 +202,54 @@ void* senddata(void* _args) buffer_cursor += len; + // This should never happen if (buffer_cursor > BUFF_SIZE) uc_err("Encryption buffer overflow"); + // Encrypt data while (encrypt_cursor+encrypt_buf_len <= buffer_cursor){ - - pthread_join(encryption_threads[curr_crypto_thread], NULL); - - e_args[curr_crypto_thread].in = (uchar*) encrypt_buffer+encrypt_cursor; - e_args[curr_crypto_thread].len = encrypt_buf_len; - e_args[curr_crypto_thread].c = args->enc; - e_args[curr_crypto_thread].ctx = &args->enc->ctx[curr_crypto_thread]; - - pthread_create(&encryption_threads[curr_crypto_thread], - NULL, crypto_update_thread, &e_args[curr_crypto_thread]); - - curr_crypto_thread++; - if (curr_crypto_thread >= N_CRYPTO_THREADS) - curr_crypto_thread = 0; - + pass_to_enc_thread(encryption_threads, + e_args, + &curr_crypto_thread, + encrypt_buffer+encrypt_cursor, + encrypt_buf_len, + args->enc); encrypt_cursor += encrypt_buf_len; - } + // If full buffer, then send to UDT if (encrypt_cursor >= BUFF_SIZE){ - - for (int i = 0; i < N_CRYPTO_THREADS; i++) - pthread_join(encryption_threads[i], NULL); - + join_all_encryption_threads(encryption_threads); send_buf(client, encrypt_buffer, buffer_cursor, flags); - - buffer_cursor = 0; - encrypt_cursor = 0; - curr_crypto_thread = 0; - + buffer_cursor = encrypt_cursor = curr_crypto_thread = 0; } - } + } + // Finish any remaining buffer data if (buffer_cursor > 0){ - - for (int i = 0; i < N_CRYPTO_THREADS; i++) - pthread_join(encryption_threads[i], NULL); - - crypto_update(encrypt_buffer+encrypt_cursor, - buffer_cursor-encrypt_cursor, - args->enc); - + join_all_encryption_threads(encryption_threads); + crypto_update(encrypt_buffer+encrypt_cursor, buffer_cursor-encrypt_cursor, args->enc); send_buf(client, encrypt_buffer, buffer_cursor, flags); } + free(encrypt_buffer); } else { // Ignore crypto + + char *data; + if (!(data = (char*)malloc(BUFF_SIZE*sizeof(char)))) + uc_err("Unable to allocate thread buffer data"); while (1) { - len = read(STDIN_FILENO, data, BUFF_SIZE); + len = read(STDIN_FILENO, data, BUFF_SIZE); if (len < 0){ uc_err(strerror(errno)); - } else if (!len) { break; - } else { send_buf(client, data, len, flags); } @@ -353,7 +292,7 @@ void* send_buf_threaded(void*_args) if (UDT::ERROR == ss) { cerr << "send:" << UDT::getlasterror().getErrorMessage() << endl; - exit(1); + exit(EXIT_FAILURE); } args->idle = 1;