-
Notifications
You must be signed in to change notification settings - Fork 242
/
Copy pathGpuSemaphore.scala
161 lines (145 loc) · 4.82 KB
/
GpuSemaphore.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids
import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import ai.rapids.cudf.{NvtxColor, NvtxRange}
import org.apache.commons.lang3.mutable.MutableInt
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
/**
 * Process-wide access point for the GPU task semaphore.
 *
 * All public methods are no-ops when the semaphore has been disabled through the
 * `com.nvidia.spark.rapids.semaphore.enabled` system property or when the supplied
 * task context is null.
 */
object GpuSemaphore {
  // The semaphore is enabled unless the system property is set to something that
  // `Boolean.parseBoolean` does not read as true. An unset property leaves it enabled.
  private val enabled: Boolean =
    Option(System.getProperty("com.nvidia.spark.rapids.semaphore.enabled"))
      .forall(value => java.lang.Boolean.parseBoolean(value))

  // DO NOT ACCESS DIRECTLY! Use `getInstance` instead.
  @volatile private var instance: GpuSemaphore = _

  /**
   * Returns the current semaphore instance, creating one on demand.
   *
   * The volatile field is read exactly once on the fast path so that a concurrent
   * `shutdown()` cannot clear it between the null check and the return.
   */
  private def getInstance: GpuSemaphore = {
    val current = instance
    if (current != null) {
      current
    } else {
      GpuSemaphore.synchronized {
        // The instance is trying to be used before it is initialized.
        // Since we don't have access to a configuration object here,
        // default to only one task per GPU behavior.
        if (instance == null) {
          initialize(1)
        }
        // `shutdown()` takes the same monitor, so the field cannot be cleared
        // between the initialization above and this read.
        instance
      }
    }
  }

  /**
   * Initializes the GPU task semaphore. Does nothing when the semaphore is disabled.
   * @param tasksPerGpu number of tasks that will be allowed to use the GPU concurrently.
   * @throws IllegalStateException if the semaphore has already been initialized.
   */
  def initialize(tasksPerGpu: Int): Unit = synchronized {
    if (enabled) {
      if (instance != null) {
        throw new IllegalStateException("already initialized")
      }
      instance = new GpuSemaphore(tasksPerGpu)
    }
  }

  /**
   * Tasks must call this when they begin to use the GPU.
   * If the task has not already acquired the GPU semaphore then it is acquired,
   * blocking if necessary.
   * NOTE: A task completion listener will automatically be installed to ensure
   * the semaphore is always released by the time the task completes.
   * @param context the calling task's context; a null context makes this a no-op.
   */
  def acquireIfNecessary(context: TaskContext): Unit = {
    if (enabled && context != null) {
      getInstance.acquireIfNecessary(context)
    }
  }

  /**
   * Tasks must call this when they are finished using the GPU.
   * @param context the calling task's context; a null context makes this a no-op.
   */
  def releaseIfNecessary(context: TaskContext): Unit = {
    if (enabled && context != null) {
      getInstance.releaseIfNecessary(context)
    }
  }

  /**
   * Uninitialize the GPU semaphore.
   * NOTE: This does not wait for active tasks to release!
   */
  def shutdown(): Unit = synchronized {
    if (instance != null) {
      instance.shutdown()
      instance = null
    }
  }
}
/**
 * Semaphore that limits how many tasks may use the GPU at the same time, together
 * with per-task bookkeeping of which task attempts currently hold a permit.
 *
 * @param tasksPerGpu number of permits the underlying semaphore is created with.
 */
private final class GpuSemaphore(tasksPerGpu: Int) extends Logging {
  private val semaphore = new Semaphore(tasksPerGpu)
  // Hold counters keyed by task attempt ID. A value above zero means that task
  // attempt currently holds a permit.
  private val activeTasks = new ConcurrentHashMap[Long, MutableInt]

  /**
   * Acquires a permit for the task unless it is already recorded as holding one.
   * The first time a task attempt is seen, a completion listener is registered
   * on its context so that `completeTask` runs when the task finishes.
   */
  def acquireIfNecessary(context: TaskContext): Unit = {
    val range = new NvtxRange("Acquire GPU", NvtxColor.RED)
    try {
      val taskAttemptId = context.taskAttemptId()
      val holdCount = activeTasks.get(taskAttemptId)
      val unseenTask = holdCount == null
      if (unseenTask || holdCount.getValue == 0) {
        logDebug(s"Task $taskAttemptId acquiring GPU")
        semaphore.acquire()
        if (unseenTask) {
          // Start tracking the task and make sure it is cleaned up on completion.
          activeTasks.put(taskAttemptId, new MutableInt(1))
          context.addTaskCompletionListener[Unit](completeTask)
        } else {
          holdCount.increment()
        }
        GpuDeviceManager.initializeFromTask()
      }
    } finally {
      range.close()
    }
  }

  /**
   * Gives the task's permit back if the task is recorded as holding one.
   * Tasks that are unknown or hold nothing are left untouched.
   */
  def releaseIfNecessary(context: TaskContext): Unit = {
    val range = new NvtxRange("Release GPU", NvtxColor.RED)
    try {
      val taskAttemptId = context.taskAttemptId()
      val holdCount = activeTasks.get(taskAttemptId)
      val holdsPermit = holdCount != null && holdCount.getValue > 0
      // Short-circuiting keeps the decrement from running for tasks holding nothing.
      if (holdsPermit && holdCount.decrementAndGet() == 0) {
        logDebug(s"Task $taskAttemptId releasing GPU")
        semaphore.release()
      }
    } finally {
      range.close()
    }
  }

  /**
   * Stops tracking the task and releases its permit if one is still held.
   * Throws IllegalStateException when the task attempt was never tracked.
   */
  def completeTask(context: TaskContext): Unit = {
    val taskAttemptId = context.taskAttemptId()
    Option(activeTasks.remove(taskAttemptId)) match {
      case None =>
        throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
      case Some(holdCount) if holdCount.getValue > 0 =>
        logDebug(s"Task $taskAttemptId releasing GPU")
        semaphore.release()
      case Some(_) =>
        // The permit was already given back; nothing left to release.
        ()
    }
  }

  /**
   * Logs a debug message if any tasks are still registered. Held permits are not
   * released here.
   */
  def shutdown(): Unit =
    if (!activeTasks.isEmpty) {
      logDebug(s"shutting down with ${activeTasks.size} tasks still registered")
    }
}