Skip to content

Commit

Permalink
Simplify zone db locking to avoid a race (#2561)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Nov 4, 2024
1 parent 8ff8490 commit e6a7128
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 108 deletions.
138 changes: 34 additions & 104 deletions src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,67 +52,34 @@ public class GpuTimeZoneDB {
// For the timezone database, we store the transitions in a ColumnVector that is a list of
// structs. The type of this column vector is:
// LIST<STRUCT<utcInstant: int64, localInstant: int64, offset: int32>>
private Map<String, Integer> zoneIdToTable;
private static Map<String, Integer> zoneIdToTable;

// use this reference to indicate if time zone cache is initialized.
private HostColumnVector fixedTransitions;
private static HostColumnVector fixedTransitions;

// Guarantee singleton instance
private GpuTimeZoneDB() {
}

// singleton instance
private static final GpuTimeZoneDB instance = new GpuTimeZoneDB();

// This method is default visibility for testing purposes only.
// The instance will be never be exposed publicly for this class.
static GpuTimeZoneDB getInstance() {
return instance;
}

static class LoadingLock {
Boolean isLoading = false;

// record whether a shutdown is called ever.
// if `isCloseCalledEver` is true, then the following loading should be skipped.
Boolean isShutdownCalledEver = false;
}

private static final LoadingLock lock = new LoadingLock();
private static boolean isShutdownCalledEver = false;

/**
* This should be called on startup of an executor.
* Runs in a thread asynchronously.
* If `shutdown` was called ever, then will not load the cache
*/
public static void cacheDatabaseAsync() {
synchronized (lock) {
if (lock.isShutdownCalledEver) {
// shutdown was called ever, will never load cache again.
// This has a race in that we could still launch a thread after
// shutting down. This is just to prevent the thread from launching
// in some cases.
synchronized (GpuTimeZoneDB.class) {
if (isShutdownCalledEver) {
log.error("cache async called after DB already loaded");
return;
}

if (lock.isLoading) {
// another thread is loading(), return
return;
} else {
lock.isLoading = true;
}
}

// start a new thread to load
Runnable runnable = () -> {
try {
instance.cacheDatabaseImpl();
cacheDatabaseImpl();
} catch (Exception e) {
log.error("cache time zone transitions cache failed", e);
} finally {
synchronized (lock) {
// now loading is done
lock.isLoading = false;
// `cacheDatabase` and `shutdown` may wait loading is done.
lock.notify();
}
}
};
Thread thread = Executors.defaultThreadFactory().newThread(runnable);
Expand All @@ -127,55 +94,21 @@ public static void cacheDatabaseAsync() {
* If cache is exits, do not load cache again.
*/
public static void cacheDatabase() {
synchronized (lock) {
if (lock.isLoading) {
// another thread is loading(), wait loading is done
while (lock.isLoading) {
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("cache time zone transitions cache failed", e);
}
}
return;
} else {
lock.isLoading = true;
}
}

try {
instance.cacheDatabaseImpl();
} finally {
// loading is done.
synchronized (lock) {
lock.isLoading = false;
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
}
cacheDatabaseImpl();
}

/**
* close the cache, used when Plugin is closing
*/
public static void shutdown() {
synchronized (lock) {
lock.isShutdownCalledEver = true;
while (lock.isLoading) {
// wait until loading is done
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("shutdown time zone transitions cache failed", e);
}
}
instance.shutdownImpl();
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
public static synchronized void shutdown() {
isShutdownCalledEver = true;
closeResources();
}

private void cacheDatabaseImpl() {
private static synchronized void cacheDatabaseImpl() {
if (isShutdownCalledEver) {
throw new IllegalStateException("GpuTimeZoneDB has already been shut down");
}
if (fixedTransitions == null) {
try {
loadData();
Expand All @@ -186,11 +119,7 @@ private void cacheDatabaseImpl() {
}
}

private void shutdownImpl() {
closeResources();
}

private void closeResources() {
private static synchronized void closeResources() {
if (zoneIdToTable != null) {
zoneIdToTable.clear();
zoneIdToTable = null;
Expand All @@ -208,9 +137,12 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
currentTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = instance.getZoneIDMap().get(currentTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
Integer tzIndex = zoneIdToTable.get(currentTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
return new ColumnVector(convertTimestampColumnToUTC(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
Expand All @@ -223,9 +155,12 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
desiredTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = instance.getZoneIDMap().get(desiredTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
Integer tzIndex = zoneIdToTable.get(desiredTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
return new ColumnVector(convertUTCTimestampColumnToTimeZone(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
Expand Down Expand Up @@ -258,7 +193,7 @@ public static ZoneId getZoneId(String timeZoneId) {
}

@SuppressWarnings("unchecked")
private void loadData() {
private static synchronized void loadData() {
try {
List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
zoneIdToTable = new HashMap<>();
Expand Down Expand Up @@ -334,17 +269,13 @@ private void loadData() {
}
}

private Map<String, Integer> getZoneIDMap() {
return zoneIdToTable;
}

private Table getTransitions() {
private static synchronized Table getTransitions() {
try (ColumnVector fixedTransitions = getFixedTransitions()) {
return new Table(fixedTransitions);
}
}

private ColumnVector getFixedTransitions() {
private static synchronized ColumnVector getFixedTransitions() {
return fixedTransitions.copyToDevice();
}

Expand All @@ -358,16 +289,15 @@ private ColumnVector getFixedTransitions() {
* @param zoneId
* @return list of fixed transitions
*/
List getHostFixedTransitions(String zoneId) {
static synchronized List getHostFixedTransitions(String zoneId) {
zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe
Integer idx = getZoneIDMap().get(zoneId);
Integer idx = zoneIdToTable.get(zoneId);
if (idx == null) {
return null;
}
return fixedTransitions.getList(idx);
}


private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);

private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex);
Expand Down
7 changes: 3 additions & 4 deletions src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,11 +44,10 @@ static void cleanup() {
@Test
void databaseLoadedTest() {
// Check for a few timezones
GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance();
List transitions = instance.getHostFixedTransitions("UTC+8");
List transitions = GpuTimeZoneDB.getHostFixedTransitions("UTC+8");
assertNotNull(transitions);
assertEquals(1, transitions.size());
transitions = instance.getHostFixedTransitions("Asia/Shanghai");
transitions = GpuTimeZoneDB.getHostFixedTransitions("Asia/Shanghai");
assertNotNull(transitions);
ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized();
assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size());
Expand Down

0 comments on commit e6a7128

Please sign in to comment.