-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathjitify.hpp
4540 lines (4172 loc) · 163 KB
/
jitify.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
-----------
Jitify 0.9
-----------
A C++ library for easy integration of CUDA runtime compilation into
existing codes.
--------------
How to compile
--------------
Compiler dependencies: <jitify.hpp>, -std=c++11
Linker dependencies: dl cuda nvrtc
--------------------------------------
Embedding source files into executable
--------------------------------------
With g++:
g++ ... -ldl -rdynamic -DJITIFY_ENABLE_EMBEDDED_FILES=1
-Wl,-b,binary,my_kernel.cu,include/my_header.cuh,-b,default
With nvcc:
nvcc ... -ldl -Xcompiler "-rdynamic
-Wl\,-b\,binary\,my_kernel.cu\,include/my_header.cuh\,-b\,default"
Then declare the embedded files in the host code:
JITIFY_INCLUDE_EMBEDDED_FILE(my_kernel_cu);
JITIFY_INCLUDE_EMBEDDED_FILE(include_my_header_cuh);
----
TODO
----
Extract valid compile options and pass the rest to cuModuleLoadDataEx
See if stringified headers can be looked up automatically
by having stringify add them to a (static) global map.
The global map can be updated by creating a static class instance
whose constructor performs the registration.
Can then remove all headers from JitCache constructor in example code
See other TODOs in code
*/
/*! \file jitify.hpp
* \brief The Jitify library header
*/
/*! \mainpage Jitify - A C++ library that simplifies the use of NVRTC
* \p Use class jitify::JitCache to manage and launch JIT-compiled CUDA
* kernels.
*
* \p Use namespace jitify::reflection to reflect types and values into
* code-strings.
*
* \p Use JITIFY_INCLUDE_EMBEDDED_FILE() to declare files that have been
* embedded into the executable using the GCC linker.
*
* \p Use jitify::parallel_for and JITIFY_LAMBDA() to generate and launch
* simple kernels.
*/
#pragma once
#ifndef JITIFY_THREAD_SAFE
#define JITIFY_THREAD_SAFE 1
#endif
#if JITIFY_ENABLE_EMBEDDED_FILES
#include <dlfcn.h>
#endif
#include <stdint.h>
#include <algorithm>
#include <cctype>
#include <cstring> // For strtok_r etc.
#include <deque>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#if JITIFY_THREAD_SAFE
#include <mutex>
#endif
#include <cuda.h>
#include <cuda_runtime_api.h> // For dim3, cudaStream_t
#if CUDA_VERSION >= 8000
#define NVRTC_GET_TYPE_NAME 1
#endif
#include <nvrtc.h>
// For use by get_current_executable_path().
#ifdef __linux__
#include <linux/limits.h> // For PATH_MAX
#include <cstdlib> // For realpath
#define JITIFY_PATH_MAX PATH_MAX
#elif defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#define JITIFY_PATH_MAX MAX_PATH
#else
#error "Unsupported platform"
#endif
#ifdef _MSC_VER // MSVC compiler
#include <dbghelp.h> // For UnDecorateSymbolName
#else
#include <cxxabi.h> // For abi::__cxa_demangle
#endif
#if defined(_WIN32) || defined(_WIN64)
// WAR for strtok_r being called strtok_s on Windows
#pragma push_macro("strtok_r")
#undef strtok_r
#define strtok_r strtok_s
// WAR for min and max possibly being macros defined by windows.h
#pragma push_macro("min")
#pragma push_macro("max")
#undef min
#undef max
#endif
#ifndef JITIFY_PRINT_LOG
#define JITIFY_PRINT_LOG 1
#endif
#if JITIFY_PRINT_ALL
#define JITIFY_PRINT_INSTANTIATION 1
#define JITIFY_PRINT_SOURCE 1
#define JITIFY_PRINT_LOG 1
#define JITIFY_PRINT_PTX 1
#define JITIFY_PRINT_LINKER_LOG 1
#define JITIFY_PRINT_LAUNCH 1
#define JITIFY_PRINT_HEADER_PATHS 1
#endif
#if JITIFY_ENABLE_EMBEDDED_FILES
// Defines a namespace-scope `void*` named <x>_forced initialized with the
// address of `x`. Taking the address creates a reference to the symbol `x`,
// presumably so that the linker must resolve it and cannot drop it — confirm.
// NOTE(review): the generated variable has external linkage, so expanding
// this for the same `x` in more than one translation unit would produce
// duplicate definitions.
#define JITIFY_FORCE_UNDEFINED_SYMBOL(x) void* x##_forced = (void*)&x
/*! Include a source file that has been embedded into the executable using the
* GCC linker.
* \param name The name of the source file (<b>not</b> as a string), which must
* be sanitized by replacing non-alpha-numeric characters with underscores.
* E.g., \code{.cpp}JITIFY_INCLUDE_EMBEDDED_FILE(my_header_h)\endcode will
* include the embedded file "my_header.h".
* \note Files declared with this macro can be referenced using
* their original (unsanitized) filenames when creating a \p
* jitify::Program instance.
* \note Expands to extern "C" declarations of _jitify_binary_<name>_start and
* _jitify_binary_<name>_end, bound through asm labels to the linker symbols
* _binary_<name>_start and _binary_<name>_end, plus a
* JITIFY_FORCE_UNDEFINED_SYMBOL reference to each of them.
*/
#define JITIFY_INCLUDE_EMBEDDED_FILE(name) \
extern "C" uint8_t _jitify_binary_##name##_start[] asm("_binary_" #name \
"_start"); \
extern "C" uint8_t _jitify_binary_##name##_end[] asm("_binary_" #name \
"_end"); \
JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_start); \
JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_end)
#endif // JITIFY_ENABLE_EMBEDDED_FILES
/*! Jitify library namespace
*/
namespace jitify {
/*! Source-file load callback.
 *
 * Called with the full path of a requested source file and a scratch stream
 * that the callback may fill and return.
 *
 * \param filename The name of the requested source file.
 * \param tmp_stream A temporary stream that can be used to hold source code.
 * \return A pointer to an input stream containing the source code, or NULL
 * to defer loading of the file to Jitify's file-loading mechanisms.
 */
using file_callback_type = std::istream* (*)(std::string filename,
                                             std::iostream& tmp_stream);
// Exclude from Doxygen
//! \cond
class JitCache;
// Simple cache using LRU discard policy.
//
// Holds at most `capacity` key/value pairs. Every key in the cache appears
// exactly once in the recency ranking (front = most recently used), and
// inserting into a full cache discards the least recently used entry.
template <typename KeyType, typename ValueType>
class ObjectCache {
 public:
  typedef KeyType key_type;
  typedef ValueType value_type;

 private:
  typedef std::map<key_type, value_type> object_map;
  typedef std::deque<key_type> key_rank;
  typedef typename key_rank::iterator rank_iterator;
  object_map _objects;
  key_rank _ranked_keys;
  size_t _capacity;

  // Discards least-recently-used entries until `n` more fit.
  // Throws std::runtime_error if n exceeds the total capacity.
  inline void discard_old(size_t n = 0) {
    if (n > _capacity) {
      throw std::runtime_error("Insufficient capacity in cache");
    }
    while (_objects.size() > _capacity - n) {
      key_type discard_key = _ranked_keys.back();
      _ranked_keys.pop_back();
      _objects.erase(discard_key);
    }
  }

 public:
  inline ObjectCache(size_t capacity = 8) : _capacity(capacity) {}

  // Changes the capacity, discarding LRU entries if the cache is now too full.
  inline void resize(size_t capacity) {
    _capacity = capacity;
    this->discard_old();
  }

  inline bool contains(const key_type& k) const {
    return (bool)_objects.count(k);
  }

  // Marks `k` as most recently used. Throws std::runtime_error if absent.
  inline void touch(const key_type& k) {
    if (!this->contains(k)) {
      throw std::runtime_error("Key not found in cache");
    }
    rank_iterator rank = std::find(_ranked_keys.begin(), _ranked_keys.end(), k);
    if (rank != _ranked_keys.begin()) {
      // Move key to front of ranks
      _ranked_keys.erase(rank);
      _ranked_keys.push_front(k);
    }
  }

  // Returns the value for `k` and marks it as most recently used.
  // Throws std::runtime_error if absent.
  inline value_type& get(const key_type& k) {
    typename object_map::iterator it = _objects.find(k);
    if (it == _objects.end()) {
      throw std::runtime_error("Key not found in cache");
    }
    this->touch(k);
    return it->second;
  }

  // Inserts (k, v) as the most recently used entry, evicting the LRU entry
  // if the cache is full. If `k` is already cached, its existing value is
  // kept (like std::map::insert) and simply marked as most recently used;
  // this keeps each key ranked exactly once.
  inline value_type& insert(const key_type& k,
                            const value_type& v = value_type()) {
    typename object_map::iterator existing = _objects.find(k);
    if (existing != _objects.end()) {
      this->touch(k);
      return existing->second;
    }
    this->discard_old(1);
    _ranked_keys.push_front(k);
    return _objects.insert(std::make_pair(k, v)).first->second;
  }

  // Constructs the value for `k` in place from `args` (perfectly forwarded),
  // with the same eviction and existing-key rules as insert().
  template <typename... Args>
  inline value_type& emplace(const key_type& k, Args&&... args) {
    typename object_map::iterator existing = _objects.find(k);
    if (existing != _objects.end()) {
      this->touch(k);
      return existing->second;
    }
    this->discard_old(1);
    // Note: Use of piecewise_construct allows non-movable non-copyable types
    auto iter = _objects
                    .emplace(std::piecewise_construct, std::forward_as_tuple(k),
                             std::forward_as_tuple(std::forward<Args>(args)...))
                    .first;
    _ranked_keys.push_front(iter->first);
    return iter->second;
  }
};
namespace detail {
// Convenience wrapper for std::vector that provides handy constructors
// (from a size, a std::vector, a C array, or an initializer list).
template <typename T>
class vector : public std::vector<T> {
  typedef std::vector<T> super_type;

 public:
  vector() : super_type() {}
  // Note: Not explicit, allows =0
  vector(size_t n) : super_type(n) {}
  vector(std::vector<T> const& vals) : super_type(vals) {}
  template <int N>
  vector(T const (&vals)[N]) : super_type(vals, vals + N) {}
  // `vals` is a named rvalue reference (an lvalue), so it must be moved
  // explicitly to take over its buffer instead of copying the elements.
  vector(std::vector<T>&& vals) : super_type(std::move(vals)) {}
  vector(std::initializer_list<T> vals) : super_type(vals) {}
};
// Helper functions for parsing/manipulating source code

// Returns a copy of `str` in which every character that occurs in `oldchars`
// has been replaced by `newchar`.
inline std::string replace_characters(std::string str,
                                      std::string const& oldchars,
                                      char newchar) {
  for (std::string::size_type pos = 0;
       (pos = str.find_first_of(oldchars, pos)) != std::string::npos; ++pos) {
    str[pos] = newchar;
  }
  return str;
}

// Replaces path separators, dots, dashes, spaces and the characters
// : ? % * | " < > with underscores. The result is used to build linker
// symbol names for embedded files.
inline std::string sanitize_filename(std::string name) {
  static const char kUnsafeChars[] = "/\\.-: ?%*|\"<>";
  return replace_characters(std::move(name), kUnsafeChars, '_');
}
#if JITIFY_ENABLE_EMBEDDED_FILES
// Looks up files embedded into the running executable by the linker. Each
// embedded file is exposed as a pair of symbols named
// "_binary_<sanitized name>_start" and "_binary_<sanitized name>_end".
// Lookups that fail throw std::runtime_error.
class EmbeddedData {
  void* _app;  // dlopen handle of the main program.
  // Non-copyable: the object owns the dlopen handle.
  EmbeddedData(EmbeddedData const&) = delete;
  EmbeddedData& operator=(EmbeddedData const&) = delete;

  // dlerror() returns NULL when no error is pending (e.g. a symbol whose
  // value really is NULL); never pass that NULL into std::string.
  static std::string last_dl_error() {
    const char* err = dlerror();
    return err ? std::string(err) : std::string("no error message available");
  }

 public:
  EmbeddedData() {
    _app = dlopen(NULL, RTLD_LAZY);
    if (!_app) {
      throw std::runtime_error(std::string("dlopen failed: ") +
                               last_dl_error());
    }
    dlerror();  // Clear any existing error
  }
  ~EmbeddedData() {
    if (_app) {
      dlclose(_app);
    }
  }
  // Returns the address of symbol "_binary_" + sanitize_filename(key).
  const uint8_t* operator[](std::string key) const {
    key = sanitize_filename(key);
    key = "_binary_" + key;
    dlerror();  // Clear stale errors so any error reported below is ours.
    uint8_t const* data =
        static_cast<uint8_t const*>(dlsym(_app, key.c_str()));
    if (!data) {
      throw std::runtime_error(std::string("dlsym failed: ") +
                               last_dl_error());
    }
    return data;
  }
  const uint8_t* begin(std::string key) const {
    return (*this)[key + "_start"];
  }
  const uint8_t* end(std::string key) const { return (*this)[key + "_end"]; }
};
#endif // JITIFY_ENABLE_EMBEDDED_FILES
// Returns true if `c` can appear in a C/C++ identifier ([A-Za-z0-9_]).
inline bool is_tokenchar(char c) {
  return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
         (c >= '0' && c <= '9') || c == '_';
}

// Replaces every whole-token occurrence of `token` in `src` with
// `replacement` and returns the result.
//
// An occurrence counts only if it is not glued to neighbouring identifier
// characters. If the token starts (ends) with an identifier character, the
// character just before (after) the occurrence must not be one. So "foo"
// matches in "a foo(" but not in "foobar" or "_foo". Punctuation at an edge
// of the token (e.g. the token "\"") imposes no constraint on that side.
// Inserted text is never rescanned. An empty token leaves `src` unchanged.
inline std::string replace_token(std::string src, std::string token,
                                 std::string replacement) {
  if (token.empty()) {
    return src;  // Nothing to match; also avoids an endless loop.
  }
  const bool check_before = is_tokenchar(token.front());
  const bool check_after = is_tokenchar(token.back());
  size_t i = src.find(token);
  while (i != std::string::npos) {
    const size_t end = i + token.size();
    // Check both sides independently; a match at the start of `src` must
    // still be followed by a boundary, and vice versa.
    const bool boundary_before =
        !check_before || i == 0 || !is_tokenchar(src[i - 1]);
    const bool boundary_after =
        !check_after || end == src.size() || !is_tokenchar(src[end]);
    if (boundary_before && boundary_after) {
      src.replace(i, token.size(), replacement);
      i += replacement.size();
    } else {
      i += 1;
    }
    i = src.find(token, i);
  }
  return src;
}
// Returns everything before the last path separator of `p`, or "" if `p`
// contains no separator. On Windows both '\\' and '/' count as separators.
//   "/usr/local/myfile.dat" -> "/usr/local"
//   "foo/bar"               -> "foo"
//   "foo/bar/"              -> "foo/bar"
inline std::string path_base(std::string p) {
#if defined _WIN32 || defined _WIN64
  const size_t last_sep = p.find_last_of("\\/");
#else
  const size_t last_sep = p.rfind('/');
#endif
  return last_sep == std::string::npos ? std::string() : p.substr(0, last_sep);
}
// Joins two path fragments with the platform separator ('\\' on Windows,
// '/' elsewhere). A separator is inserted only if `p1` is non-empty and does
// not already end with one.
// Throws std::invalid_argument if `p1` is non-empty and `p2` starts with the
// separator (i.e. is absolute).
inline std::string path_join(std::string p1, std::string p2) {
#ifdef _WIN32
  const char separator = '\\';
#else
  const char separator = '/';
#endif
  const bool p2_is_absolute = !p2.empty() && p2[0] == separator;
  if (!p1.empty() && p2_is_absolute) {
    throw std::invalid_argument("Cannot join to absolute path");
  }
  if (!p1.empty() && p1.back() != separator) {
    p1.push_back(separator);
  }
  return p1 + p2;
}
// Elides "/." and "/.." tokens from path.
//
// Splits `path` on '/' only (backslashes are not treated as separators), then:
//  - collapses runs of repeated slashes into one;
//  - drops "." components that are followed by a '/';
//  - makes a ".." component that is followed by a '/' cancel the preceding
//    component, unless there is none or it is itself "..". In that case the
//    ".." is kept, so leading ".." of a relative path survive;
//  - keeps a trailing slash if the input has one.
// An absolute path is recognized by its leading empty component.
// NOTE: when the path does not end in '/', its final component is appended
// verbatim, so a trailing "." or ".." (e.g. "a/b/..") is not simplified.
// \throws std::runtime_error if ".." would climb above the root of an
//   absolute path (e.g. "/../x").
inline std::string path_simplify(const std::string& path) {
std::vector<std::string> dirs;
std::string cur_dir;
bool after_slash = false;
for (int i = 0; i < (int)path.size(); ++i) {
if (path[i] == '/') {
if (after_slash) continue; // Ignore repeat slashes
after_slash = true;
if (cur_dir == ".." && !dirs.empty() && dirs.back() != "..") {
// A lone empty component is the root of an absolute path.
if (dirs.size() == 1 && dirs.front().empty()) {
throw std::runtime_error(
"Invalid path: back-traversals exceed depth of absolute path");
}
dirs.pop_back();
} else if (cur_dir != ".") { // Ignore /./
dirs.push_back(cur_dir);
}
cur_dir.clear();
} else {
after_slash = false;
cur_dir.push_back(path[i]);
}
}
// Final component (not followed by '/'); appended as-is.
if (!after_slash) {
dirs.push_back(cur_dir);
}
// Re-join the surviving components with '/'.
std::stringstream ss;
for (int i = 0; i < (int)dirs.size() - 1; ++i) {
ss << dirs[i] << "/";
}
if (!dirs.empty()) ss << dirs.back();
if (after_slash) ss << "/";
return ss.str();
}
// Larson's multiplicative string hash (h = h * 101 + c) over the
// NUL-terminated string `s`, continuing from `seed`. Each character is added
// as its (possibly sign-extended) char value.
inline unsigned long long hash_larson64(const char* s,
                                        unsigned long long seed = 0) {
  unsigned long long h = seed;
  for (const char* p = s; *p != '\0'; ++p) {
    h = h * 101 + *p;
  }
  return h;
}
// Mixes hash `b` into hash `a` and returns the combined 64-bit hash.
inline uint64_t hash_combine(uint64_t a, uint64_t b) {
  // Magic constant derived from the golden ratio.
  const uint64_t kGoldenRatio = 0x9E3779B97F4A7C17ull;
  const uint64_t mixed = kGoldenRatio + b + (b >> 2) + (a << 6);
  return a ^ mixed;
}
// Searches an NVRTC compile log for a "could not open source file" /
// "cannot open source file" error and extracts where it happened.
//
// The offending log line is expected to look like
//   <parent>(<line>): ... cannot open source file "<name>"
// \param log      The compile log to search.
// \param name     Set to the quoted name of the file that could not be opened.
// \param parent   Set to the name of the source containing the include.
// \param line_num Set to the line number reported inside `parent`.
// \return true if such an error was found and fully parsed. Returns false
//         otherwise, and in that case the outputs are left untouched.
inline bool extract_include_info_from_compile_error(std::string log,
                                                    std::string& name,
                                                    std::string& parent,
                                                    int& line_num) {
  static const std::vector<std::string> pattern = {
      "could not open source file \"", "cannot open source file \""};
  for (auto& p : pattern) {
    const size_t pattern_pos = log.find(p);
    if (pattern_pos == std::string::npos) continue;
    const size_t name_beg = pattern_pos + p.size();
    const size_t name_end = log.find('"', name_beg);
    if (name_end == std::string::npos) continue;  // Unterminated file name.
    // Start of the log line containing the error.
    size_t line_beg = log.rfind('\n', pattern_pos);
    line_beg = (line_beg == std::string::npos) ? 0 : line_beg + 1;
    // The "(<line>)" suffix is the last '(' on the line before the message.
    // Searching backwards tolerates '(' inside the parent's path.
    const size_t split = log.rfind('(', pattern_pos);
    if (split == std::string::npos || split < line_beg) continue;
    const size_t close = log.find(')', split + 1);
    if (close == std::string::npos || close > pattern_pos) continue;
    name = log.substr(name_beg, name_end - name_beg);
    parent = log.substr(line_beg, split - line_beg);
    line_num = atoi(log.substr(split + 1, close - (split + 1)).c_str());
    return true;
  }
  return false;
}
// Returns true if line `line_num` (1-based) of `source` is an include
// directive whose file name is delimited by double quotes (#include "x").
// Returns false for angle brackets (#include <x>), for a line that does not
// exist, or for a line without a recognizable include directive. Only the
// requested line is inspected.
inline bool is_include_directive_with_quotes(const std::string& source,
                                             int line_num) {
  size_t line_beg = 0;
  for (int i = 1; i < line_num; ++i) {
    const size_t newline = source.find('\n', line_beg);
    if (newline == std::string::npos) {
      return false;  // `source` has fewer than line_num lines.
    }
    line_beg = newline + 1;
  }
  size_t line_end = source.find('\n', line_beg);
  if (line_end == std::string::npos) {
    line_end = source.size();
  }
  const std::string line = source.substr(line_beg, line_end - line_beg);
  const std::string keyword = "include";
  const size_t keyword_pos = line.find(keyword);
  if (keyword_pos == std::string::npos) {
    return false;
  }
  const size_t delim_pos =
      line.find_first_of("\"<", keyword_pos + keyword.size());
  return delim_pos != std::string::npos && line[delim_pos] == '"';
}
// Returns `source` with "//" inserted at the start of line `line_num`
// (1-based), commenting that line out. Values of line_num <= 1 target the
// first line.
// Throws std::runtime_error if `source` has fewer than line_num lines,
// rather than silently commenting out the wrong line.
inline std::string comment_out_code_line(int line_num, std::string source) {
  size_t line_beg = 0;
  for (int i = 1; i < line_num; ++i) {
    const size_t newline = source.find('\n', line_beg);
    if (newline == std::string::npos) {
      throw std::runtime_error("comment_out_code_line: line " +
                               std::to_string(line_num) +
                               " is past the end of the source");
    }
    line_beg = newline + 1;
  }
  source.insert(line_beg, "//");
  return source;
}
// Writes `source` to std::cout with each line prefixed by its 1-based line
// number, right-aligned in a field of width 3 and followed by a space.
// Numbers are formatted with the classic "C" locale.
inline void print_with_line_numbers(std::string const& source) {
  std::istringstream lines_in(source);
  std::ostringstream numbered;
  numbered.imbue(std::locale::classic());
  std::string text;
  int line_number = 0;
  while (std::getline(lines_in, text)) {
    ++line_number;
    numbered << std::setw(3) << std::setfill(' ') << line_number << ' '
             << text << '\n';
  }
  std::cout << numbered.str();
}
// Writes `log` to std::cout between dashed separator lines, under a header
// naming `program_name`.
inline void print_compile_log(std::string program_name,
                              std::string const& log) {
  const char* const rule =
      "---------------------------------------------------";
  std::cout << rule << std::endl
            << "--- JIT compile log for " << program_name << " ---"
            << std::endl
            << rule << std::endl
            << log << std::endl
            << rule << std::endl;
}
// Splits `str` into tokens separated by runs of any characters in `delims`
// (default: space and tab), using strtok_r. Leading delimiters are skipped
// and empty tokens are never produced.
//
// \param maxsplit Maximum number of tokens to split off.
//   - 0 returns {str} unchanged.
//   - A positive value yields at most maxsplit tokens. If non-delimiter text
//     remains, it is added as one final element holding the rest of the
//     string, with leading delimiters removed and inner/trailing delimiters
//     kept.
//   - A negative value means no limit.
// NOTE(review): tokenizing runs on a NUL-terminated copy, so text after an
// embedded '\0' in `str` is presumably not tokenized — confirm if callers
// can pass such strings.
inline std::vector<std::string> split_string(std::string str,
long maxsplit = -1,
std::string delims = " \t") {
std::vector<std::string> results;
if (maxsplit == 0) {
results.push_back(str);
return results;
}
// strtok_r writes NULs into its input, so tokenize a mutable copy.
// Note: +1 to include NULL-terminator
std::vector<char> v_str(str.c_str(), str.c_str() + (str.size() + 1));
char* c_str = v_str.data();
char* saveptr = c_str;
char* token = nullptr;
for (long i = 0; i != maxsplit; ++i) {
// Only the first call passes the buffer; later calls pass NULL to continue.
token = ::strtok_r(c_str, delims.c_str(), &saveptr);
c_str = 0;
if (!token) {
return results;
}
results.push_back(token);
}
// Check if there's a final piece
// Step past the last token and the NUL strtok_r wrote after it.
token += ::strlen(token) + 1;
if (token - v_str.data() < (ptrdiff_t)str.size()) {
// Find the start of the final piece
token += ::strspn(token, delims.c_str());
if (*token) {
results.push_back(token);
}
}
return results;
}
static const std::map<std::string, std::string>& get_jitsafe_headers_map();
// Loads a source into `sources[name]`, applying Jitify's source-level
// workarounds, and returns whether it was found.
//
// \param filename Either a file name to look up, or a direct source string of
//   the form "name\nsource code...". For a direct string, the text after the
//   first newline is the source and the text before it is the name.
// \param sources Map from name to processed source text. If the name is
//   already present, nothing is loaded and true is returned.
// \param current_dir Joined with the name for the callback, embedded-file and
//   current-directory lookups.
// \param include_paths Directories searched, in order, after current_dir.
// \param file_callback Optional user callback, tried first. If it returns
//   NULL, the other mechanisms are used.
// \param program_name If non-null, receives the name.
// \param fullpaths If non-null, receives name -> simplified path for sources
//   found by lookup (not for direct source strings).
// \param search_current_dir Whether to look in current_dir on the filesystem.
// \return false if no lookup mechanism (including the built-in jitsafe
//   headers) provides the file; true otherwise.
// NOTE(review): path_join throws std::invalid_argument when the name is an
// absolute path and current_dir is non-empty; this function does not catch it.
//
// Rewrites applied to the loaded text:
//  - "#pragma once" lines are removed (together with the blank line after
//    them). The file is then wrapped in an include guard whose name is
//    derived from a hash of the file's lines.
//  - Other "#pragma X" lines become _Pragma("X") (WAR for Thrust).
//  - "#define cudaDeviceSynchronize() cudaSuccess" is prepended (WAR for cub).
inline bool load_source(
std::string filename, std::map<std::string, std::string>& sources,
std::string current_dir = "",
std::vector<std::string> include_paths = std::vector<std::string>(),
file_callback_type file_callback = 0, std::string* program_name = nullptr,
std::map<std::string, std::string>* fullpaths = nullptr,
bool search_current_dir = true) {
std::istream* source_stream = 0;
std::stringstream string_stream;
std::ifstream file_stream;
// First detect direct source-code string ("my_program\nprogram_code...")
size_t newline_pos = filename.find("\n");
if (newline_pos != std::string::npos) {
std::string source = filename.substr(newline_pos + 1);
filename = filename.substr(0, newline_pos);
string_stream << source;
source_stream = &string_stream;
}
if (program_name) {
*program_name = filename;
}
if (sources.count(filename)) {
// Already got this one
return true;
}
if (!source_stream) {
// Lookup order: callback, embedded data (if enabled), current_dir on the
// filesystem, include_paths, then the built-in jitsafe headers.
std::string fullpath = path_join(current_dir, filename);
// Try loading from callback
if (!file_callback ||
!((source_stream = file_callback(fullpath, string_stream)) != 0)) {
// When embedded files are enabled, the brace block after the #endif below
// is the body of the catch clause, so it runs only if the embedded lookup
// threw. Otherwise it is a plain nested scope.
#if JITIFY_ENABLE_EMBEDDED_FILES
// Try loading as embedded file
EmbeddedData embedded;
std::string source;
try {
source.assign(embedded.begin(fullpath), embedded.end(fullpath));
string_stream << source;
source_stream = &string_stream;
} catch (std::runtime_error const&)
#endif // JITIFY_ENABLE_EMBEDDED_FILES
{
// Try loading from filesystem
bool found_file = false;
if (search_current_dir) {
file_stream.open(fullpath.c_str());
if (file_stream) {
source_stream = &file_stream;
found_file = true;
}
}
// Search include directories
if (!found_file) {
for (int i = 0; i < (int)include_paths.size(); ++i) {
fullpath = path_join(include_paths[i], filename);
file_stream.open(fullpath.c_str());
if (file_stream) {
source_stream = &file_stream;
found_file = true;
break;
}
}
if (!found_file) {
// Try loading from builtin headers
fullpath = path_join("__jitify_builtin", filename);
auto it = get_jitsafe_headers_map().find(filename);
if (it != get_jitsafe_headers_map().end()) {
string_stream << it->second;
source_stream = &string_stream;
} else {
return false;
}
}
}
}
}
if (fullpaths) {
// Record the full file path corresponding to this include name.
(*fullpaths)[filename] = path_simplify(fullpath);
}
}
// Read the stream line by line into sources[filename], rewriting as we go.
sources[filename] = std::string();
std::string& source = sources[filename];
std::string line;
size_t linenum = 0;
unsigned long long hash = 0;
bool pragma_once = false;
bool remove_next_blank_line = false;
while (std::getline(*source_stream, line)) {
++linenum;
// HACK WAR for static variables not allowed on the device (unless
// __shared__)
// TODO: This breaks static member variables
// line = replace_token(line, "static const", "/*static*/ const");
// TODO: Need to watch out for /* */ comments too
std::string cleanline =
line.substr(0, line.find("//")); // Strip line comments
// if( cleanline.back() == "\r" ) { // Remove Windows line ending
// cleanline = cleanline.substr(0, cleanline.size()-1);
//}
// TODO: Should trim whitespace before checking .empty()
if (cleanline.empty() && remove_next_blank_line) {
remove_next_blank_line = false;
continue;
}
// Maintain a file hash for use in #pragma once WAR
hash = hash_larson64(line.c_str(), hash);
if (cleanline.find("#pragma once") != std::string::npos) {
pragma_once = true;
// Note: This is an attempt to recover the original line numbering,
// which otherwise gets off-by-one due to the include guard.
remove_next_blank_line = true;
// line = "//" + line; // Comment out the #pragma once line
continue;
}
// HACK WAR for Thrust using "#define FOO #pragma bar"
// TODO: This is not robust to block comments, line continuations, or tabs.
size_t pragma_beg = cleanline.find("#pragma ");
if (pragma_beg != std::string::npos) {
// 8 == strlen("#pragma ")
std::string line_after_pragma = line.substr(pragma_beg + 8);
// TODO: Handle block comments (currently they cause a compilation error).
size_t comment_start = line_after_pragma.find("//");
std::string pragma_args = line_after_pragma.substr(0, comment_start);
// handle quote character used in #pragma expression
pragma_args = replace_token(pragma_args, "\"", "\\\"");
std::string comment = comment_start != std::string::npos
? line_after_pragma.substr(comment_start)
: "";
line = line.substr(0, pragma_beg) + "_Pragma(\"" + pragma_args + "\")" +
comment;
}
source += line + "\n";
}
// HACK TESTING (WAR for cub)
source = "#define cudaDeviceSynchronize() cudaSuccess\n" + source;
////source = "cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }\n" +
/// source;
// WAR for #pragma once causing problems when there are multiple inclusions
// of the same header from different paths.
if (pragma_once) {
// Guard name is the content hash in upper-case hex, so the same file
// reached via different paths gets the same guard.
std::stringstream ss;
ss.imbue(std::locale::classic());
ss << std::uppercase << std::hex << std::setw(8) << std::setfill('0')
<< hash;
std::string include_guard_name = "_JITIFY_INCLUDE_GUARD_" + ss.str() + "\n";
std::string include_guard_header;
include_guard_header += "#ifndef " + include_guard_name;
include_guard_header += "#define " + include_guard_name;
std::string include_guard_footer;
include_guard_footer += "#endif // " + include_guard_name;
source = include_guard_header + source + "\n" + include_guard_footer;
}
// return filename;
return true;
}
} // namespace detail
//! \endcond
/*! Jitify reflection utilities namespace
*/
namespace reflection {
// Provides type and value reflection via a function 'reflect':
// reflect<Type>() -> "Type"
// reflect(value) -> "(T)value"
// reflect<VAL>() -> "VAL"
// reflect<Type,VAL> -> "VAL"
// reflect_template<float,NonType<int,7>,char>() -> "<float,7,char>"
// reflect_template({"float", "7", "char"}) -> "<float,7,char>"
/*! A wrapper class for non-type template parameters.
*/
template <typename T, T VALUE_>
struct NonType {
constexpr static T VALUE = VALUE_;
};
// Forward declaration
template <typename T>
inline std::string reflect(T const& value);
//! \cond
namespace detail {
// Non-floating-point overload: streams `x` using the classic "C" locale so
// the text is valid source code whatever the program's global locale is.
// A global locale could otherwise insert digit-group separators or use ','
// as the decimal point.
template <typename T>
inline std::string value_string_impl(const T& x, std::false_type) {
  std::stringstream ss;
  ss.imbue(std::locale::classic());
  ss << x;
  return ss.str();
}

// Floating-point overload: starts at the stream's default precision (6), so
// values that were already printed exactly keep their short form. It then
// adds digits until the text parses back to exactly `x`; max_digits10 digits
// always suffice for finite values. Non-finite values never round-trip and
// are printed at max_digits10 as the stream spells them.
template <typename T>
inline std::string value_string_impl(const T& x, std::true_type) {
  std::string text;
  for (int precision = 6; precision <= std::numeric_limits<T>::max_digits10;
       ++precision) {
    std::stringstream out;
    out.imbue(std::locale::classic());
    out.precision(precision);
    out << x;
    text = out.str();
    std::stringstream in(text);
    in.imbue(std::locale::classic());
    T parsed = T();
    if ((in >> parsed) && parsed == x) {
      break;
    }
  }
  return text;
}

// Returns the source-code text of the value `x`. Floating-point values are
// printed with enough digits to reproduce `x` exactly.
template <typename T>
inline std::string value_string(const T& x) {
  return value_string_impl(x, std::is_floating_point<T>());
}
// WAR for non-printable characters: character types are reflected as their
// integer code.
template <>
inline std::string value_string<char>(const char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<signed char>(const signed char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<unsigned char>(const unsigned char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<wchar_t>(const wchar_t& x) {
  return value_string<long>(static_cast<long>(x));
}
// Specialisation for bool true/false literals
template <>
inline std::string value_string<bool>(const bool& x) {
  return x ? "true" : "false";
}
// Removes, in place, every token starting with a double underscore (e.g.
// "__cdecl", "__ptr64") from the NUL-terminated string `s`. When "__" is
// found, it and all identifier characters that follow it are dropped.
inline void strip_double_underscore_tokens(char* s) {
  using jitify::detail::is_tokenchar;
  const char* src = s;
  char* dst = s;
  for (;;) {
    if (src[0] == '_' && src[1] == '_') {
      // Skip the "__" and the rest of the identifier characters.
      do {
        ++src;
      } while (is_tokenchar(*src));
    }
    const char c = *src++;
    *dst++ = c;
    if (c == '\0') {
      break;
    }
  }
}
//#if CUDA_VERSION < 8000
#ifdef _MSC_VER // MSVC compiler
// Returns `mangled_name` unchanged: no demangler for CUDA symbol names is
// available under MSVC.
inline std::string demangle_cuda_symbol(const char* mangled_name) {
// We don't have a way to demangle CUDA symbol names under MSVC.
return mangled_name;
}
// Returns a readable name for `typeinfo`'s type by running DbgHelp's
// UnDecorateSymbolName on the decorated (raw) type name, then stripping
// tokens that start with "__".
// Throws std::runtime_error if UnDecorateSymbolName fails.
// NOTE(review): the output buffer holds 4096 chars; behavior for longer
// names depends on UnDecorateSymbolName — confirm.
inline std::string demangle_native_type(const std::type_info& typeinfo) {
// Get the decorated name and skip over the leading '.'.
const char* decorated_name = typeinfo.raw_name() + 1;
char undecorated_name[4096];
if (UnDecorateSymbolName(
decorated_name, undecorated_name,
sizeof(undecorated_name) / sizeof(*undecorated_name),
UNDNAME_NO_ARGUMENTS | // Treat input as a type name
UNDNAME_NAME_ONLY // No "class" and "struct" prefixes
/*UNDNAME_NO_MS_KEYWORDS*/)) { // No "__cdecl", "__ptr64" etc.
// WAR for UNDNAME_NO_MS_KEYWORDS messing up function types.
strip_double_underscore_tokens(undecorated_name);
return undecorated_name;
}
throw std::runtime_error("UnDecorateSymbolName failed");
}
#else // not MSVC
// Demangles an Itanium C++ ABI symbol name with abi::__cxa_demangle.
// Names that are not valid mangled names (status -2) are returned unchanged
// and treated as plain C names.
// Throws std::runtime_error on allocation failure (-1) or an invalid argument
// (-3). Any other status yields an empty string.
inline std::string demangle_cuda_symbol(const char* mangled_name) {
size_t bufsize = 0;
char* buf = nullptr;
std::string demangled_name;
int status;
// __cxa_demangle returns a malloc'd buffer; unique_ptr frees it.
auto demangled_ptr = std::unique_ptr<char, decltype(free)*>(
abi::__cxa_demangle(mangled_name, buf, &bufsize, &status), free);
if (status == 0) {
demangled_name = demangled_ptr.get(); // all worked as expected
} else if (status == -2) {
demangled_name = mangled_name; // we interpret this as plain C name
} else if (status == -1) {
throw std::runtime_error(
std::string("memory allocation failure in __cxa_demangle"));
} else if (status == -3) {
throw std::runtime_error(std::string("invalid argument to __cxa_demangle"));
}
return demangled_name;
}
// Returns the demangled name of `typeinfo`'s type (typeinfo.name() is the
// mangled form on this toolchain).
inline std::string demangle_native_type(const std::type_info& typeinfo) {
return demangle_cuda_symbol(typeinfo.name());
}
#endif // not MSVC
//#endif // CUDA_VERSION < 8000
// Dummy template whose only purpose is to carry T, including its
// cv-qualifiers, through typeid.
template <typename>
class JitifyTypeNameWrapper_ {};

// Computes the source-code spelling of type T.
template <typename T>
struct type_reflection {
  // Returns the demangled name of T (e.g. "float").
  //
  // typeid discards top-level cv-qualifiers, so T is wrapped in
  // JitifyTypeNameWrapper_<T> first. The wrapper, along with any namespace
  // prefix the demangler adds to it, is then cut from the demangled string.
  // TODO: Use nvrtcGetTypeName once it has the same behavior as this.
  // Throws std::runtime_error if the wrapper name cannot be found.
  inline static std::string name() {
    std::string reflected =
        demangle_native_type(typeid(JitifyTypeNameWrapper_<T>));
    const std::string marker = "JitifyTypeNameWrapper_<";
    const size_t marker_pos = reflected.find(marker);
    if (marker_pos == std::string::npos) {
      throw std::runtime_error("Type reflection failed: " + reflected);
    }
    // Keep only the text between the marker and the closing '>'.
    reflected.erase(0, marker_pos + marker.size());
    if (!reflected.empty()) {
      reflected.erase(reflected.size() - 1);
    }
    return reflected;
  }
};  // struct type_reflection
// A NonType<T, VALUE> template argument is spelled as its value, rendered by
// reflect(VALUE) as "(T)value".
template <typename T, T VALUE>
struct type_reflection<NonType<T, VALUE>> {
  static inline std::string name() {
    const T value = VALUE;
    return jitify::reflection::reflect(value);
  }
};
} // namespace detail
//! \endcond
/*! Holds a const reference to a value so that its run-time (dynamic) type
 * can be extracted later, e.g. the derived type of an object known only
 * through its abstract base. This allows templating on the derived type when
 * only the abstract type is known at compile time.
 * \note Only a reference is stored, so the referenced object must outlive
 * the Instance.
 */
template <typename T>
struct Instance {
  T const& value;
  Instance(T const& value_arg) : value(value_arg) {}
};

/*! Create an Instance object from which we can extract the value's run-time
 * type.
 * \param value The const value to be captured (by reference).
 */
template <typename T>
inline Instance<T const> instance_of(T const& value) {
  Instance<T const> wrapped(value);
  return wrapped;
}
/*! Empty tag used to pass a type around as a value,
 * e.g. \code{.cpp}reflect(Type<float>())\endcode.
 */
template <class T>
struct Type {};
// Type reflection
// E.g., reflect<float>() -> "float"
// Note: This strips trailing const and volatile qualifiers
/*! Generate a code-string for a type.
* \code{.cpp}reflect<float>() --> "float"\endcode
*/
template <typename T>
inline std::string reflect() {
return detail::type_reflection<T>::name();
}
// Value reflection
// E.g., reflect(3.14f) -> "(float)3.14"
/*! Generate a code-string for a value.
* \code{.cpp}reflect(3.14f) --> "(float)3.14"\endcode
*/
template <typename T>
inline std::string reflect(T const& value) {
return "(" + reflect<T>() + ")" + detail::value_string(value);
}
// Non-type template arg reflection (implicit conversion to int64_t)
// E.g., reflect<7>() -> "(int64_t)7"
/*! Generate a code-string for an integer non-type template argument.
* \code{.cpp}reflect<7>() --> "(int64_t)7"\endcode
*/
template <int64_t N>
inline std::string reflect() {
return reflect<NonType<int64_t, N> >();
}
// Non-type template arg reflection (explicit type)
// E.g., reflect<int,7>() -> "(int)7"
/*! Generate a code-string for a generic non-type template argument.
* \code{.cpp} reflect<int,7>() --> "(int)7" \endcode
*/
template <typename T, T N>
inline std::string reflect() {
return reflect<NonType<T, N> >();
}
// Type reflection via value
// E.g., reflect(Type<float>()) -> "float"
/*! Generate a code-string for a type wrapped as a Type instance.
* \code{.cpp}reflect(Type<float>()) --> "float"\endcode
*/
template <typename T>
inline std::string reflect(jitify::reflection::Type<T>) {
return reflect<T>();
}
/*! Generate a code-string for a type wrapped as an Instance instance.
* \code{.cpp}reflect(Instance<float>(3.1f)) --> "float"\endcode
* or more simply when passed to a instance_of helper
* \code{.cpp}reflect(instance_of(3.1f)) --> "float"\endcodei
* This is specifically for the case where we want to extract the run-time
* type, e.g., derived type, of an object pointer.
*/
template <typename T>
inline std::string reflect(jitify::reflection::Instance<T>& value) {
return detail::demangle_native_type(typeid(value.value));
}
/*! Create a Type object representing the type of a non-const lvalue.
 * E.g. for `float x;`, type_of(x) -> Type<float>()
 * \param value The value whose type is to be captured.
 */
template <typename T>
inline Type<T> type_of(T&) {
  return {};
}

/*! Create a Type object representing the type of a const value or an
 * rvalue; the const is part of the captured type.
 * E.g. type_of(3.14f) -> Type<float const>()
 * \param value The const value whose type is to be captured.
 */
template <typename T>
inline Type<T const> type_of(T const&) {
  return {};
}
// Reflects each argument with reflect() and returns the resulting
// code-strings in argument order.
template <typename... Args>
inline std::vector<std::string> reflect_all(Args... args) {
  std::vector<std::string> reflected{reflect(args)...};
  return reflected;
}
inline std::string reflect_list(jitify::detail::vector<std::string> const& args,
std::string opener = "",
std::string closer = "") {