Skip to content

Commit

Permalink
Support variables in base files for configs (#1083)
Browse files Browse the repository at this point in the history
* Support variables in base files for configs

Signed-off-by: lizz <[email protected]>

* Test json and yaml as well

Signed-off-by: lizz <[email protected]>

* Add test for recusive base

Signed-off-by: lizz <[email protected]>

* Test misleading values

Signed-off-by: lizz <[email protected]>

* Improve comments

Signed-off-by: lizz <[email protected]>

* Add doc

Signed-off-by: lizz <[email protected]>

* Improve doc

Signed-off-by: lizz <[email protected]>

* More tests

Signed-off-by: lizz <[email protected]>

* Harder test case

Signed-off-by: lizz <[email protected]>

* use BASE_KEY instead of base

Signed-off-by: lizz <[email protected]>
  • Loading branch information
innerlee authored Jun 25, 2021
1 parent eb08835 commit d9effbd
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 0 deletions.
26 changes: 26 additions & 0 deletions docs/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,32 @@ _base_ = ['./config_a.py', './config_e.py']
... d='string')
```

#### Reference variables from base

You can reference variables defined in base using the following grammar.

`base.py`

```python
item1 = 'a'
item2 = dict(item3 = 'b')
```

`config_g.py`

```python
_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```

```python
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
```

### ProgressBar

If you want to apply a method to a list of items and track the progress, `track_progress`
Expand Down
60 changes: 60 additions & 0 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
Expand Down Expand Up @@ -121,6 +123,57 @@ def _substitute_predefined_vars(filename, temp_config_name):
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)

@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict

@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)

if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split('.'):
new_v = new_v[new_k]
cfg = new_v

return cfg

@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
Expand All @@ -141,6 +194,9 @@ def _file2dict(filename, use_predefined_variables=True):
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)

if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
Expand Down Expand Up @@ -185,6 +241,10 @@ def _file2dict(filename, use_predefined_variables=True):
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)

# Subtitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)

base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict

Expand Down
13 changes: 13 additions & 0 deletions tests/data/config/t.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"_base_": [
"./l1.py",
"./l2.yaml",
"./l3.json",
"./l4.py"
],
"item3": false,
"item4": "test",
"item8": "{{fileBasename}}",
"item9": {{ _base_.item2 }},
"item10": {{ _base_.item7.b.c }}
}
6 changes: 6 additions & 0 deletions tests/data/config/t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 = False
item4 = 'test'
item8 = '{{fileBasename}}'
item9 = {{ _base_.item2 }}
item10 = {{ _base_.item7.b.c }}
6 changes: 6 additions & 0 deletions tests/data/config/t.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 : False
item4 : 'test'
item8 : '{{fileBasename}}'
item9 : {{ _base_.item2 }}
item10 : {{ _base_.item7.b.c }}
26 changes: 26 additions & 0 deletions tests/data/config/u.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"_base_": [
"./t.py"
],
"base": "_base_.item8",
"item11": {{ _base_.item8 }},
"item12": {{ _base_.item9 }},
"item13": {{ _base_.item10 }},
"item14": {{ _base_.item1 }},
"item15": {
"a": {
"b": {{ _base_.item2 }}
},
"b": [
{{ _base_.item3 }}
],
"c": [{{ _base_.item4 }}],
"d": [[
{
"e": {{ _base_.item5.a }}
}
],
{{ _base_.item6 }}],
"e": {{ _base_.item1 }}
}
}
13 changes: 13 additions & 0 deletions tests/data/config/u.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = ['./t.py']
base = '_base_.item8'
item11 = {{ _base_.item8 }}
item12 = {{ _base_.item9 }}
item13 = {{ _base_.item10 }}
item14 = {{ _base_.item1 }}
item15 = dict(
a = dict( b = {{ _base_.item2 }} ),
b = [{{ _base_.item3 }}],
c = [{{ _base_.item4 }}],
d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}],
e = {{ _base_.item1 }}
)
15 changes: 15 additions & 0 deletions tests/data/config/u.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_base_: ["./t.py"]
base: "_base_.item8"
item11: {{ _base_.item8 }}
item12: {{ _base_.item9 }}
item13: {{ _base_.item10 }}
item14: {{ _base_.item1 }}
item15:
a:
b: {{ _base_.item2 }}
b: [{{ _base_.item3 }}]
c: [{{ _base_.item4 }}]
d:
- [e: {{ _base_.item5.a }}]
- {{ _base_.item6 }}
e: {{ _base_.item1 }}
11 changes: 11 additions & 0 deletions tests/data/config/v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = ['./u.py']
item21 = {{ _base_.item11 }}
item22 = item21
item23 = {{ _base_.item10 }}
item24 = item23
item25 = dict(
a = dict( b = item24 ),
b = [item24],
c = [[dict(e = item22)],{{ _base_.item6 }}],
e = item21
)
75 changes: 75 additions & 0 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,81 @@ def test_merge_from_multiple_bases():
Config.fromfile(osp.join(data_path, 'config/m.py'))


def test_base_variables():
for file in ['t.py', 't.json', 't.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == file
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]

# test nested base
for file in ['u.py', 'u.json', 'u.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.base == '_base_.item8'
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 't.py'
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]
assert cfg.item11 == 't.py'
assert cfg.item12 == dict(a=0)
assert cfg.item13 == [3.1, 4.2, 5.3]
assert cfg.item14 == [1, 2]
assert cfg.item15 == dict(
a=dict(b=dict(a=0)),
b=[False],
c=['test'],
d=[[{
'e': 0
}], [{
'a': 0
}, {
'b': 1
}]],
e=[1, 2])

# test reference assignment for py
cfg_file = osp.join(data_path, 'config/v.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.item21 == 't.py'
assert cfg.item22 == 't.py'
assert cfg.item23 == [3.1, 4.2, 5.3]
assert cfg.item24 == [3.1, 4.2, 5.3]
assert cfg.item25 == dict(
a=dict(b=[3.1, 4.2, 5.3]),
b=[[3.1, 4.2, 5.3]],
c=[[{
'e': 't.py'
}], [{
'a': 0
}, {
'b': 1
}]],
e='t.py')


def test_merge_recursive_bases():
cfg_file = osp.join(data_path, 'config/f.py')
cfg = Config.fromfile(cfg_file)
Expand Down

0 comments on commit d9effbd

Please sign in to comment.