Skip to content

Commit

Permalink
reformat file and group constants together in config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer committed Jun 25, 2023
1 parent 9819330 commit afd9810
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,9 @@
import os
import pathlib

from dotenv import find_dotenv, load_dotenv
from dotenv import load_dotenv
from loguru import logger

ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
ZIPPED_RECORDING_FOLDER_PATH = ROOT_DIRPATH / "data" / "zipped"

ENV_FILE_PATH = (ROOT_DIRPATH / ".env").resolve()
logger.info(f"{ENV_FILE_PATH=}")
dotenv_file = find_dotenv()
load_dotenv(dotenv_file)


def set_db_url(db_fname):
"""Set the database URL based on the given database file name.
Args:
db_fname (str): The database file name.
"""
global DB_FNAME, DB_FPATH, DB_URL
DB_FNAME = db_fname
DB_FPATH = ROOT_DIRPATH / DB_FNAME
DB_URL = f"sqlite:///{DB_FPATH}"
logger.info(f"{DB_URL=}")
os.environ["DB_FNAME"] = db_fname


_DEFAULTS = {
"CACHE_DIR_PATH": ".cache",
Expand All @@ -46,8 +24,11 @@ def set_db_url(db_fname):
"DB_ECHO": False,
"DB_FNAME": "openadapt.db",
"OPENAI_API_KEY": "<set your api key in .env>",
# "OPENAI_MODEL_NAME": "gpt-4",
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
# may incur significant performance penalty
"RECORD_READ_ACTIVE_ELEMENT_STATE": False,
# TODO: remove?
"REPLAY_STRIP_ELEMENT_STATE": True,
# IGNORES WARNINGS (PICKLING, ETC.)
# TODO: ignore warnings by default on GUI
Expand Down Expand Up @@ -123,17 +104,19 @@ def getenv_fallback(var_name):
return rval


load_dotenv()

for key in _DEFAULTS:
val = getenv_fallback(key)
locals()[key] = val


ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
DB_FPATH = ROOT_DIRPATH / DB_FNAME # type: ignore # noqa
DB_URL = f"sqlite:///{DB_FPATH}"
DIRNAME_PERFORMANCE_PLOTS = "performance"
DB_ECHO = False
DT_FMT = "%Y-%m-%d_%H-%M-%S"
DIRNAME_PERFORMANCE_PLOTS = "performance"
ZIPPED_RECORDING_FOLDER_PATH = ROOT_DIRPATH / "data" / "zipped"
ENV_FILE_PATH = (ROOT_DIRPATH / ".env").resolve()

if multiprocessing.current_process().name == "MainProcess":
for key, val in locals().items():
Expand All @@ -157,4 +140,18 @@ def filter_log_messages(data):
messages_to_ignore = [
"Cannot pickle Objective-C objects",
]
return not any(msg in data["message"] for msg in messages_to_ignore)
return not any(msg in data["message"] for msg in messages_to_ignore)


def set_db_url(db_fname):
"""Set the database URL based on the given database file name.
Args:
db_fname (str): The database file name.
"""
global DB_FNAME, DB_FPATH, DB_URL
DB_FNAME = db_fname
DB_FPATH = ROOT_DIRPATH / DB_FNAME
DB_URL = f"sqlite:///{DB_FPATH}"
logger.info(f"{DB_URL=}")
os.environ["DB_FNAME"] = db_fname

0 comments on commit afd9810

Please sign in to comment.