Skip to content

Commit

Permalink
fix synchronous training error for redis backend
Browse files Browse the repository at this point in the history
  • Loading branch information
fuhailin authored and rhdong committed Sep 7, 2023
1 parent 7f20412 commit 363bdb4
Showing 1 changed file with 5 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
# pylint: disable=g-bad-name

import copy
import fcntl
import functools
import json
import os
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.lookup_ops import LookupInterface
from tensorflow.python.training.saver import BaseSaverBuilder

Expand Down Expand Up @@ -156,7 +153,6 @@ def __init__(
self._new_obj_trackable = None # for restore op can easily found this table

self.redis_config_file_exist = False
self.redis_config_file_create = False

if self._config.redis_config_abs_dir_env:
if self._config.redis_config_abs_dir_env in os.environ:
Expand All @@ -169,7 +165,7 @@ def __init__(
" in system environment variable.")
self.redis_config_file_exist = os.path.exists(
self._config.redis_config_abs_dir)
if self.redis_config_file_exist == False:
if not self.redis_config_file_exist:
raise ValueError(
"Config redis_config_abs_dir_env in RedisTableConfig is not None, but the FILE which path stored in environment variable "
+ self._config.redis_config_abs_dir_env + " DOES NOT EXIST.")
Expand All @@ -180,20 +176,19 @@ def __init__(
warnings.warn(
"TFRA-Redis try to use environment variable TFRA_REDIS_CONFIG_PATH regardless redis_config_abs_dir in RedisTableConfig."
)
if self.redis_config_file_exist == False:
if not self.redis_config_file_exist:
raise ValueError(
"environment variable TFRA_REDIS_CONFIG_PATH exists, but the FILE which path stored in TFRA_REDIS_CONFIG_PATH DOES NOT EXIST. Please create a FILE in the corresponding path or delete the environment variable TFRA_REDIS_CONFIG_PATH."
)
"environment variable TFRA_REDIS_CONFIG_PATH exists, but the FILE which path stored in TFRA_REDIS_CONFIG_PATH DOES NOT EXIST. Please create a FILE in the corresponding path or delete "
"the environment variable TFRA_REDIS_CONFIG_PATH.")
elif self._config.redis_config_abs_dir_env is None and "TFRA_REDIS_CONFIG_PATH" not in os.environ and self._config.redis_config_abs_dir:
self.redis_config_file_exist = os.path.exists(
self._config.redis_config_abs_dir)
if self.redis_config_file_exist == False:
if not self.redis_config_file_exist:
raise ValueError(
"Config redis_config_abs_dir in RedisTableConfig is not None and redis_config_abs_dir_env is None, but the FILE "
+ self._config.redis_config_abs_dir +
" which path is redis_config_abs_dir DOES NOT EXIST.")
elif self._config.redis_config_abs_dir_env is None and "TFRA_REDIS_CONFIG_PATH" not in os.environ and self._config.redis_config_abs_dir is None:
self.redis_config_file_create = True
self._config.redis_config_abs_dir = "/tmp/tmp_TFRA_Redis_config_file.json"
warnings.warn(
"Both redis_config_abs_dir_env and redis_config_abs_dir in RedisTableConfig are None, now creating a temporary config file in /tmp/tmp_TFRA_Redis_config_file.json."
Expand All @@ -203,27 +198,6 @@ def __init__(
"TFRA-Redis didn't get the correct RedisTableConfig class initial parameter."
)

if self.redis_config_file_create == True and self.redis_config_file_exist == False:
with open(self._config.redis_config_abs_dir, 'w+',
encoding='utf-8') as f0:
fcntl.flock(f0, fcntl.LOCK_EX)
f0.write(
json.dumps(self.default_redis_params, indent=2, ensure_ascii=True))
fcntl.flock(f0, fcntl.LOCK_UN)
else:
with open(self._config.redis_config_abs_dir, 'r', encoding='utf-8') as f0:
fcntl.flock(f0, fcntl.LOCK_EX)
params_load = json.load(f0)
fcntl.flock(f0, fcntl.LOCK_UN)
self._redis_params = self.default_redis_params.copy()
for k in self._redis_params.keys():
if k in params_load:
self._redis_params[k] = params_load[k]
with open(self._config.redis_config_abs_dir, 'w', encoding='utf-8') as f1:
fcntl.flock(f1, fcntl.LOCK_EX)
f1.write(json.dumps(self._redis_params, indent=2, ensure_ascii=True))
fcntl.flock(f1, fcntl.LOCK_UN)

self._shared_name = None
if context.executing_eagerly():
# TODO(allenl): This will leak memory due to kernel caching by the
Expand Down

0 comments on commit 363bdb4

Please sign in to comment.