Skip to content

Commit

Permalink
Revert "Simplify zone db locking to avoid a race (#2561)"
Browse files Browse the repository at this point in the history
This reverts commit e6a7128.
  • Loading branch information
revans2 authored Nov 4, 2024
1 parent e6a7128 commit 2dc0ede
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 37 deletions.
138 changes: 104 additions & 34 deletions src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,67 @@ public class GpuTimeZoneDB {
// For the timezone database, we store the transitions in a ColumnVector that is a list of
// structs. The type of this column vector is:
// LIST<STRUCT<utcInstant: int64, localInstant: int64, offset: int32>>
private static Map<String, Integer> zoneIdToTable;
private Map<String, Integer> zoneIdToTable;

// use this reference to indicate if time zone cache is initialized.
private static HostColumnVector fixedTransitions;
private HostColumnVector fixedTransitions;

private static boolean isShutdownCalledEver = false;
// Guarantee singleton instance
private GpuTimeZoneDB() {
}

// singleton instance
private static final GpuTimeZoneDB instance = new GpuTimeZoneDB();

// This method is default visibility for testing purposes only.
// The instance will be never be exposed publicly for this class.
static GpuTimeZoneDB getInstance() {
return instance;
}

static class LoadingLock {
Boolean isLoading = false;

// record whether a shutdown is called ever.
// if `isCloseCalledEver` is true, then the following loading should be skipped.
Boolean isShutdownCalledEver = false;
}

private static final LoadingLock lock = new LoadingLock();

/**
* This should be called on startup of an executor.
* Runs in a thread asynchronously.
* If `shutdown` was called ever, then will not load the cache
*/
public static void cacheDatabaseAsync() {
// This has a race in that we could still launch a thread after
// shutting down. This is just to prevent the thread from launching
// in some cases.
synchronized (GpuTimeZoneDB.class) {
if (isShutdownCalledEver) {
log.error("cache async called after DB already loaded");
synchronized (lock) {
if (lock.isShutdownCalledEver) {
// shutdown was called ever, will never load cache again.
return;
}

if (lock.isLoading) {
// another thread is loading(), return
return;
} else {
lock.isLoading = true;
}
}

// start a new thread to load
Runnable runnable = () -> {
try {
cacheDatabaseImpl();
instance.cacheDatabaseImpl();
} catch (Exception e) {
log.error("cache time zone transitions cache failed", e);
} finally {
synchronized (lock) {
// now loading is done
lock.isLoading = false;
// `cacheDatabase` and `shutdown` may wait loading is done.
lock.notify();
}
}
};
Thread thread = Executors.defaultThreadFactory().newThread(runnable);
Expand All @@ -94,21 +127,55 @@ public static void cacheDatabaseAsync() {
* If cache is exits, do not load cache again.
*/
public static void cacheDatabase() {
cacheDatabaseImpl();
synchronized (lock) {
if (lock.isLoading) {
// another thread is loading(), wait loading is done
while (lock.isLoading) {
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("cache time zone transitions cache failed", e);
}
}
return;
} else {
lock.isLoading = true;
}
}

try {
instance.cacheDatabaseImpl();
} finally {
// loading is done.
synchronized (lock) {
lock.isLoading = false;
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
}
}

/**
* close the cache, used when Plugin is closing
*/
public static synchronized void shutdown() {
isShutdownCalledEver = true;
closeResources();
public static void shutdown() {
synchronized (lock) {
lock.isShutdownCalledEver = true;
while (lock.isLoading) {
// wait until loading is done
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("shutdown time zone transitions cache failed", e);
}
}
instance.shutdownImpl();
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
}

private static synchronized void cacheDatabaseImpl() {
if (isShutdownCalledEver) {
throw new IllegalStateException("GpuTimeZoneDB has already been shut down");
}
private void cacheDatabaseImpl() {
if (fixedTransitions == null) {
try {
loadData();
Expand All @@ -119,7 +186,11 @@ private static synchronized void cacheDatabaseImpl() {
}
}

private static synchronized void closeResources() {
private void shutdownImpl() {
closeResources();
}

private void closeResources() {
if (zoneIdToTable != null) {
zoneIdToTable.clear();
zoneIdToTable = null;
Expand All @@ -137,12 +208,9 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
currentTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = zoneIdToTable.get(currentTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
Integer tzIndex = instance.getZoneIDMap().get(currentTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
return new ColumnVector(convertTimestampColumnToUTC(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
Expand All @@ -155,12 +223,9 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
desiredTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = zoneIdToTable.get(desiredTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
Integer tzIndex = instance.getZoneIDMap().get(desiredTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
return new ColumnVector(convertUTCTimestampColumnToTimeZone(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
Expand Down Expand Up @@ -193,7 +258,7 @@ public static ZoneId getZoneId(String timeZoneId) {
}

@SuppressWarnings("unchecked")
private static synchronized void loadData() {
private void loadData() {
try {
List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
zoneIdToTable = new HashMap<>();
Expand Down Expand Up @@ -269,13 +334,17 @@ private static synchronized void loadData() {
}
}

private static synchronized Table getTransitions() {
private Map<String, Integer> getZoneIDMap() {
return zoneIdToTable;
}

private Table getTransitions() {
try (ColumnVector fixedTransitions = getFixedTransitions()) {
return new Table(fixedTransitions);
}
}

private static synchronized ColumnVector getFixedTransitions() {
private ColumnVector getFixedTransitions() {
return fixedTransitions.copyToDevice();
}

Expand All @@ -289,15 +358,16 @@ private static synchronized ColumnVector getFixedTransitions() {
* @param zoneId
* @return list of fixed transitions
*/
static synchronized List getHostFixedTransitions(String zoneId) {
List getHostFixedTransitions(String zoneId) {
zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe
Integer idx = zoneIdToTable.get(zoneId);
Integer idx = getZoneIDMap().get(zoneId);
if (idx == null) {
return null;
}
return fixedTransitions.getList(idx);
}


private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);

private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex);
Expand Down
7 changes: 4 additions & 3 deletions src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,10 +44,11 @@ static void cleanup() {
@Test
void databaseLoadedTest() {
// Check for a few timezones
List transitions = GpuTimeZoneDB.getHostFixedTransitions("UTC+8");
GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance();
List transitions = instance.getHostFixedTransitions("UTC+8");
assertNotNull(transitions);
assertEquals(1, transitions.size());
transitions = GpuTimeZoneDB.getHostFixedTransitions("Asia/Shanghai");
transitions = instance.getHostFixedTransitions("Asia/Shanghai");
assertNotNull(transitions);
ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized();
assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size());
Expand Down

0 comments on commit 2dc0ede

Please sign in to comment.