forked from dereklstinson/gocudnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcudnnAlgorithm.go
222 lines (194 loc) · 5.75 KB
/
cudnnAlgorithm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
import (
"runtime"
"github.com/dereklstinson/cutil"
)
// AlgorithmD holds the C.cudnnAlgorithmDescriptor_t.
//
// Create one with CreateAlgorithmDescriptor and assign an algorithm with Set.
type AlgorithmD struct {
// descriptor is the cuDNN handle passed to the cuDNN calls made by the methods in this file.
descriptor C.cudnnAlgorithmDescriptor_t
// gogc is set from the package-level setfinalizer flag by
// CreateAlgorithmDescriptor, which also attaches a runtime finalizer when that
// flag is on. When gogc is true, Destroy returns nil without freeing anything.
// Values built without CreateAlgorithmDescriptor leave it false.
gogc bool
}
// Algorithm is the Go-side form of C.cudnnAlgorithm_t. It is used to pass an
// algorithm choice generically, for example to (*AlgorithmD).Set.
type Algorithm C.cudnnAlgorithm_t

// c converts the value back to the C type expected by cuDNN calls.
func (algo Algorithm) c() C.cudnnAlgorithm_t {
	return C.cudnnAlgorithm_t(algo)
}

// cptr reinterprets the pointer as a pointer to the C type. This is valid
// because Algorithm is defined directly on C.cudnnAlgorithm_t.
func (algo *Algorithm) cptr() *C.cudnnAlgorithm_t {
	return (*C.cudnnAlgorithm_t)(algo)
}
// CreateAlgorithmDescriptor allocates a new cuDNN algorithm descriptor. The
// returned AlgorithmD has no algorithm assigned yet; call Set on it.
//
// When the package-level setfinalizer flag is on, the result is marked gogc and
// destroyalgorithmdescriptor is registered as its runtime finalizer.
func CreateAlgorithmDescriptor() (*AlgorithmD, error) {
	d := &AlgorithmD{gogc: setfinalizer}
	if err := Status(C.cudnnCreateAlgorithmDescriptor(&d.descriptor)).error("CreateAlgorithmDescriptor"); err != nil {
		return nil, err
	}
	if setfinalizer {
		runtime.SetFinalizer(d, destroyalgorithmdescriptor)
	}
	return d, nil
}
// Set stores algo in the algorithm descriptor.
func (a *AlgorithmD) Set(algo Algorithm) error {
	rc := C.cudnnSetAlgorithmDescriptor(a.descriptor, algo.c())
	return Status(rc).error("SetAlgorithmDescriptor")
}
// Get reads back the Algorithm stored in the algorithm descriptor.
//
// The Algorithm is returned together with the error. When the error is non-nil,
// the Algorithm value should not be relied on.
func (a *AlgorithmD) Get() (Algorithm, error) {
	var out Algorithm
	rc := C.cudnnGetAlgorithmDescriptor(a.descriptor, out.cptr())
	return out, Status(rc).error("GetAlgorithmDescriptor")
}
// Copy returns a new AlgorithmD holding a copy of the algorithm in a.
//
// cudnnCopyAlgorithmDescriptor copies into an already-created destination
// descriptor; it receives the destination by value and cannot allocate one.
// The destination is therefore created first with CreateAlgorithmDescriptor,
// which also applies the package's gogc/finalizer policy to the copy.
func (a *AlgorithmD) Copy() (*AlgorithmD, error) {
	dst, err := CreateAlgorithmDescriptor()
	if err != nil {
		return nil, err
	}
	err = Status(C.cudnnCopyAlgorithmDescriptor(
		a.descriptor,
		dst.descriptor,
	)).error("CopyAlgorithmDescriptor")
	if err != nil {
		// Without a finalizer nothing else will release dst, so free it here.
		// This is best effort; the copy error is the one worth reporting.
		if !dst.gogc {
			_ = destroyalgorithmdescriptor(dst)
		}
		return nil, err
	}
	return dst, nil
}
// Destroy releases the descriptor explicitly.
//
// It does nothing and returns nil when a.gogc is set or when the
// package-level setfinalizer flag is on. In those cases the descriptor is
// expected to be released by a runtime finalizer instead.
func (a *AlgorithmD) Destroy() error {
	managedByGC := a.gogc || setfinalizer
	if !managedByGC {
		return destroyalgorithmdescriptor(a)
	}
	return nil
}
// destroyalgorithmdescriptor releases the cuDNN descriptor held by a.
//
// It is called by (*AlgorithmD).Destroy and is registered as the runtime
// finalizer in CreateAlgorithmDescriptor. On success a.descriptor is cleared,
// so the freed handle is never handed to cuDNN again.
func destroyalgorithmdescriptor(a *AlgorithmD) error {
	err := Status(C.cudnnDestroyAlgorithmDescriptor(a.descriptor)).error("DestroyDescriptor")
	if err != nil {
		return err
	}
	// Assigning nil to the parameter itself would only change the local copy;
	// clear the field so the caller's AlgorithmD no longer holds a dangling handle.
	a.descriptor = nil
	return nil
}
// CreateAlgorithmPerformance creates numberToCreate algorithm performance
// descriptors and returns them as a slice of AlgorithmPerformance.
//
// returns
//
//	nil = Success
//	CUDNN_STATUS_BAD_PARAM - numberToCreate is not positive (checked before calling cuDNN)
//	CUDNN_STATUS_ALLOC_FAILED - The resources could not be allocated
//
// On error the returned slice is nil.
//
// NOTE(review): each element gets gogc = setfinalizer, but no finalizer is
// attached here. With setfinalizer on, (*AlgorithmPerformance).Destroy is a
// no-op, so these descriptors may never be freed. Confirm the intended
// ownership model.
func CreateAlgorithmPerformance(numberToCreate int32) ([]AlgorithmPerformance, error) {
	// make panics on a negative length and &algoperf[0] panics on an empty
	// slice, so reject non-positive counts up front with a cuDNN-style status.
	if numberToCreate <= 0 {
		return nil, Status(C.CUDNN_STATUS_BAD_PARAM).error("CreateAlgorithmPerformance")
	}
	algoperf := make([]C.cudnnAlgorithmPerformance_t, numberToCreate)
	err := Status(C.cudnnCreateAlgorithmPerformance(
		&algoperf[0],
		C.int(numberToCreate),
	)).error("CreateAlgorithmPerformance")
	if err != nil {
		return nil, err
	}
	return calgoperftogoarray(algoperf, setfinalizer), nil
}
// Set fills the performance descriptor with an algorithm descriptor, a status,
// a time value and a memory size in bytes.
func (a *AlgorithmPerformance) Set(aD *AlgorithmD, s Status, time float32, memory uint) error {
	rc := C.cudnnSetAlgorithmPerformance(
		a.descriptor,
		aD.descriptor,
		s.c(),
		C.float(time),
		C.size_t(memory),
	)
	return Status(rc).error("SetAlgorithmPerformance")
}
// keepsalive marks a as reachable up to the point of the call, so the garbage
// collector cannot finalize it earlier (see runtime.KeepAlive).
func (a *AlgorithmPerformance) keepsalive() {
	runtime.KeepAlive(a)
}
// Get reads the contents of the performance descriptor. In order, it returns
// the algorithm descriptor, the recorded Status, the time as float32, the
// memory size in bytes, and the error from the cuDNN call. The matching setter
// is (*AlgorithmPerformance).Set.
//
// The returned AlgorithmD is a plain value that wraps the handle reported by
// cuDNN. Its gogc is false and no finalizer is attached to it.
// NOTE(review): the handle presumably stays owned by the performance
// descriptor; confirm before calling Destroy on it.
func (a *AlgorithmPerformance) Get() (AlgorithmD, Status, float32, uint, error) {
	var (
		out     AlgorithmD
		status  C.cudnnStatus_t
		elapsed C.float
		memSize C.size_t
	)
	rc := C.cudnnGetAlgorithmPerformance(a.descriptor, &out.descriptor, &status, &elapsed, &memSize)
	err := Status(rc).error("GetAlgorithmPerformance")
	return out, Status(status), float32(elapsed), uint(memSize), err
}
// Destroy releases the performance descriptor explicitly.
//
// It does nothing and returns nil when a.gogc is set or when the
// package-level setfinalizer flag is on.
func (a *AlgorithmPerformance) Destroy() error {
	switch {
	case a.gogc, setfinalizer:
		return nil
	default:
		return destroyalgorithmperformance(a)
	}
}
// destroyalgorithmperformance releases the single cuDNN performance descriptor
// held by a.
//
// cudnnDestroyAlgorithmPerformance destroys numberToDestroy consecutive handles
// starting at the given pointer. &a.descriptor points at exactly one handle, so
// the count must be 1; a count of 0 destroys nothing and leaks the descriptor.
// On success a.descriptor is cleared so the freed handle is not reused.
func destroyalgorithmperformance(a *AlgorithmPerformance) error {
	err := Status(C.cudnnDestroyAlgorithmPerformance(
		&a.descriptor,
		C.int(1),
	)).error("DestroyPerformance")
	if err != nil {
		return err
	}
	a.descriptor = nil
	return nil
}
// calgoperftogoarray wraps each C performance handle in input in an
// AlgorithmPerformance. Each element records the handle, the given gogc flag,
// and its position in input as index.
func calgoperftogoarray(input []C.cudnnAlgorithmPerformance_t, gogc bool) []AlgorithmPerformance {
	perfs := make([]AlgorithmPerformance, len(input))
	for i, handle := range input {
		p := &perfs[i]
		p.descriptor = handle
		p.gogc = gogc
		p.index = C.int(i)
	}
	return perfs
}
// GetAlgorithmSpaceSize returns the size in bytes that cuDNN reports for the
// algorithm held by a.
// NOTE(review): this is presumably the buffer size SaveAlgorithm and
// RestoreAlgorithm expect; confirm.
//
// When the handle has a worker (handle.w != nil), the cuDNN call is run through
// handle.w.Work. Otherwise it is made directly.
func (a *AlgorithmD) GetAlgorithmSpaceSize(handle *Handle) (uint, error) {
	var size C.size_t
	query := func() error {
		return Status(C.cudnnGetAlgorithmSpaceSize(handle.x, a.descriptor, &size)).error("(a *AlgorithmD) GetAlgorithmSpaceSize(handle *Handle)")
	}
	var err error
	if handle.w == nil {
		err = query()
	} else {
		err = handle.w.Work(query)
	}
	// err is computed before the return so that size is read only after the call has completed.
	return uint(size), err
}
// SaveAlgorithm saves the algorithm held by a into algoSpace, passing
// sizeinbytes as the buffer size to cudnnSaveAlgorithm.
// NOTE(review): algoSpace has been described as host memory; confirm against the cuDNN docs.
//
// When the handle has a worker (handle.w != nil), the cuDNN call is run through
// handle.w.Work. Otherwise it is made directly.
func (a *AlgorithmD) SaveAlgorithm(handle *Handle, algoSpace cutil.Mem, sizeinbytes uint) error {
	save := func() error {
		return Status(C.cudnnSaveAlgorithm(
			handle.x,
			a.descriptor,
			algoSpace.Ptr(),
			C.size_t(sizeinbytes),
		)).error("SaveAlgorithm")
	}
	if handle.w == nil {
		return save()
	}
	return handle.w.Work(save)
}
// RestoreAlgorithm loads an algorithm from algoSpace (sizeinbytes bytes) into a
// via cudnnRestoreAlgorithm.
// NOTE(review): algoSpace has been described as host memory; confirm against the cuDNN docs.
//
// When the handle has a worker (handle.w != nil), the cuDNN call is run through
// handle.w.Work. Otherwise it is made directly.
func (a *AlgorithmD) RestoreAlgorithm(handle *Handle, algoSpace cutil.Mem, sizeinbytes uint) error {
	restore := func() error {
		return Status(C.cudnnRestoreAlgorithm(
			handle.x,
			algoSpace.Ptr(),
			C.size_t(sizeinbytes),
			a.descriptor,
		)).error("RestoreAlgorithm")
	}
	if handle.w == nil {
		return restore()
	}
	return handle.w.Work(restore)
}