Simplify zone db locking to avoid a race #2561

Merged
merged 6 commits into from
Nov 4, 2024
138 changes: 34 additions & 104 deletions src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
@@ -52,67 +52,34 @@ public class GpuTimeZoneDB {
// For the timezone database, we store the transitions in a ColumnVector that is a list of
// structs. The type of this column vector is:
// LIST<STRUCT<utcInstant: int64, localInstant: int64, offset: int32>>
private Map<String, Integer> zoneIdToTable;
private static Map<String, Integer> zoneIdToTable;

// use this reference to indicate if time zone cache is initialized.
private HostColumnVector fixedTransitions;
private static HostColumnVector fixedTransitions;

// Guarantee singleton instance
private GpuTimeZoneDB() {
}

// singleton instance
private static final GpuTimeZoneDB instance = new GpuTimeZoneDB();

// This method is default visibility for testing purposes only.
// The instance will be never be exposed publicly for this class.
static GpuTimeZoneDB getInstance() {
return instance;
}

static class LoadingLock {
Boolean isLoading = false;

// record whether a shutdown is called ever.
// if `isCloseCalledEver` is true, then the following loading should be skipped.
Boolean isShutdownCalledEver = false;
}

private static final LoadingLock lock = new LoadingLock();
private static boolean isShutdownCalledEver = false;

/**
* This should be called on startup of an executor.
* Runs in a thread asynchronously.
* If `shutdown` was called ever, then will not load the cache
*/
public static void cacheDatabaseAsync() {
synchronized (lock) {
if (lock.isShutdownCalledEver) {
// shutdown was called ever, will never load cache again.
// This has a race in that we could still launch a thread after
// shutting down. This is just to prevent the thread from launching
// in some cases.
synchronized (GpuTimeZoneDB.class) {
Collaborator

This still leaves room for a race where shutdown is called from another thread after the lock is released on L87. Should we make the whole method cacheDatabaseAsync synchronized instead?

Collaborator

the second thread will check that the resource it is trying to update is not null, which would be closed and checked under the lock. So the second thread will do work, unnecessarily, but I don't see a case for a runtime error here, unless I am missing something.

But agree, what if this was all locked, is it bad?

Collaborator Author

@gerashegalov is 100% correct. I will fix it.

Is it bad?

Yes and no. We should not get inconsistent data, but we might load data after shutdown was called and have no way to properly free it. It only happens on shutdown, but the change is small enough, and enough of an improvement, that I think it is best to make it.
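
The point settled in this thread can be illustrated with a minimal sketch. It is not the code from this PR: `ZoneCacheSketch`, `load`, `isLoaded`, and `loadAsync` are hypothetical names, and a plain `HashMap` stands in for the real transition table. The assumption it demonstrates is that when the shutdown check and the load run under the same class-level lock, a shutdown from another thread cannot land between them, so nothing is allocated after shutdown.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for GpuTimeZoneDB; all names here are illustrative only.
public class ZoneCacheSketch {
  // null means "not loaded"; a HashMap stands in for the real cached data.
  private static Map<String, Integer> table;
  private static boolean shutdownCalled = false;

  // The shutdown check and the load share one lock (ZoneCacheSketch.class),
  // so shutdown() cannot run between the check and the allocation.
  public static synchronized boolean load() {
    if (shutdownCalled) {
      return false; // refuse to allocate anything once shutdown has run
    }
    if (table == null) {
      table = new HashMap<>();
      table.put("UTC", 0);
    }
    return true;
  }

  // Takes the same class-level lock as load().
  public static synchronized void shutdown() {
    shutdownCalled = true;
    table = null;
  }

  public static synchronized boolean isLoaded() {
    return table != null;
  }

  // Async variant: the thread body goes through the synchronized load(),
  // so the check is repeated under the lock at the time the work happens.
  public static Thread loadAsync() {
    Thread t = new Thread(() -> load());
    t.setDaemon(true);
    t.start();
    return t;
  }

  public static void main(String[] args) throws InterruptedException {
    Thread t = loadAsync();
    t.join();
    System.out.println("loaded before shutdown: " + isLoaded());
    shutdown();
    System.out.println("load accepted after shutdown: " + load());
    System.out.println("loaded after shutdown: " + isLoaded());
  }
}
```

In this sketch a late `load()` is refused with a `false` return. Throwing an exception instead is an equally reasonable design choice; which one fits depends on whether callers can tolerate a missing cache.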

if (isShutdownCalledEver) {
log.error("cache async called after DB already loaded");
return;
}

if (lock.isLoading) {
// another thread is loading(), return
return;
} else {
lock.isLoading = true;
}
}

// start a new thread to load
Runnable runnable = () -> {
try {
instance.cacheDatabaseImpl();
cacheDatabaseImpl();
} catch (Exception e) {
log.error("cache time zone transitions cache failed", e);
} finally {
synchronized (lock) {
// now loading is done
lock.isLoading = false;
// `cacheDatabase` and `shutdown` may wait loading is done.
lock.notify();
}
}
};
Thread thread = Executors.defaultThreadFactory().newThread(runnable);
@@ -127,55 +94,21 @@ public static void cacheDatabaseAsync() {
* If cache is exits, do not load cache again.
*/
public static void cacheDatabase() {
synchronized (lock) {
if (lock.isLoading) {
// another thread is loading(), wait loading is done
while (lock.isLoading) {
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("cache time zone transitions cache failed", e);
}
}
return;
} else {
lock.isLoading = true;
}
}

try {
instance.cacheDatabaseImpl();
} finally {
// loading is done.
synchronized (lock) {
lock.isLoading = false;
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
}
cacheDatabaseImpl();
}

/**
* close the cache, used when Plugin is closing
*/
public static void shutdown() {
synchronized (lock) {
lock.isShutdownCalledEver = true;
while (lock.isLoading) {
// wait until loading is done
try {
lock.wait();
} catch (InterruptedException e) {
throw new IllegalStateException("shutdown time zone transitions cache failed", e);
}
}
instance.shutdownImpl();
// `cacheDatabase` and/or `shutdown` may wait loading is done.
lock.notify();
}
public static synchronized void shutdown() {
isShutdownCalledEver = true;
closeResources();
}

private void cacheDatabaseImpl() {
private static synchronized void cacheDatabaseImpl() {
if (isShutdownCalledEver) {
throw new IllegalStateException("GpuTimeZoneDB has already been shut down");
}
if (fixedTransitions == null) {
try {
loadData();
@@ -186,11 +119,7 @@ private void cacheDatabaseImpl() {
}
}

private void shutdownImpl() {
closeResources();
}

private void closeResources() {
private static synchronized void closeResources() {
if (zoneIdToTable != null) {
zoneIdToTable.clear();
zoneIdToTable = null;
@@ -208,9 +137,12 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
currentTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = instance.getZoneIDMap().get(currentTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
Integer tzIndex = zoneIdToTable.get(currentTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
return new ColumnVector(convertTimestampColumnToUTC(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
@@ -223,9 +155,12 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI
throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
desiredTimeZone.toString()));
}
// there is technically a race condition on shutdown. Shutdown could be called after
// the database is cached. This would result in a null pointer exception at some point
// in the processing. This should be rare enough that it is not a big deal.
cacheDatabase();
Integer tzIndex = instance.getZoneIDMap().get(desiredTimeZone.normalized().toString());
try (Table transitions = instance.getTransitions()) {
Integer tzIndex = zoneIdToTable.get(desiredTimeZone.normalized().toString());
try (Table transitions = getTransitions()) {
return new ColumnVector(convertUTCTimestampColumnToTimeZone(input.getNativeView(),
transitions.getNativeView(), tzIndex));
}
@@ -258,7 +193,7 @@ public static ZoneId getZoneId(String timeZoneId) {
}

@SuppressWarnings("unchecked")
private void loadData() {
private static synchronized void loadData() {
try {
List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
zoneIdToTable = new HashMap<>();
@@ -334,17 +269,13 @@ private void loadData() {
}
}

private Map<String, Integer> getZoneIDMap() {
return zoneIdToTable;
}

private Table getTransitions() {
private static synchronized Table getTransitions() {
try (ColumnVector fixedTransitions = getFixedTransitions()) {
return new Table(fixedTransitions);
}
}

private ColumnVector getFixedTransitions() {
private static synchronized ColumnVector getFixedTransitions() {
return fixedTransitions.copyToDevice();
}

@@ -358,16 +289,15 @@ private ColumnVector getFixedTransitions() {
* @param zoneId
* @return list of fixed transitions
*/
List getHostFixedTransitions(String zoneId) {
static synchronized List getHostFixedTransitions(String zoneId) {
zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe
Integer idx = getZoneIDMap().get(zoneId);
Integer idx = zoneIdToTable.get(zoneId);
if (idx == null) {
return null;
}
return fixedTransitions.getList(idx);
}


private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);

private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex);
7 changes: 3 additions & 4 deletions src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -44,11 +44,10 @@ static void cleanup() {
@Test
void databaseLoadedTest() {
// Check for a few timezones
GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance();
List transitions = instance.getHostFixedTransitions("UTC+8");
List transitions = GpuTimeZoneDB.getHostFixedTransitions("UTC+8");
assertNotNull(transitions);
assertEquals(1, transitions.size());
transitions = instance.getHostFixedTransitions("Asia/Shanghai");
transitions = GpuTimeZoneDB.getHostFixedTransitions("Asia/Shanghai");
assertNotNull(transitions);
ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized();
assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size());