feat: Improved cli & updated pydantic2
mrkbac committed Oct 12, 2023
1 parent 54cedd2 commit db1c624
Showing 8 changed files with 189 additions and 166 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -17,11 +17,12 @@ dependencies = [
"mcap-ros1-support>=0.6.0",
"mcap-ros2-support>=0.3.0",
"numpy",
"pydantic>=1.0.0,<2.0.0",
"pydantic>=2.4.2",
"strictyaml",
"tqdm",
"pyyaml>=6.0.1",
"scipy>=1.11.1",
"jsonargparse[signatures]>=4.25.0",
]

dynamic = ["version"]
@@ -70,6 +71,8 @@ ignore = [
 'TCH', # flake8-type-checking
 'TCH', # flake8-type-checking
 'TRY003', # raise-vanilla-args
+'TD',
+'FIX002',
 ]

 src = ['src']
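Note that the pydantic bump crosses a major version: v2 renames most of the model API, so any BaseModel-based settings code has to move with it. A minimal sketch of the renames involved, using an illustrative model rather than kappe's actual Settings class:

from pydantic import BaseModel, field_validator  # v2; v1 used `validator`


class ExampleSettings(BaseModel):
    threads: int = 1

    @field_validator('threads')  # v1: @validator('threads')
    @classmethod
    def check_threads(cls, v: int) -> int:
        if v < 1:
            raise ValueError('threads must be >= 1')
        return v


s = ExampleSettings.model_validate({'threads': 4})  # v1: parse_obj()
print(s.model_dump())       # v1: s.dict()
print(s.model_dump_json())  # v1: s.json()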
298 changes: 159 additions & 139 deletions src/kappe/cli.py
@@ -1,26 +1,19 @@
"""
Convert mcap (ROS1 & ROS2) files.
Message definitions:
Message definitions are read from ROS and from disk (./msgs/):
git clone --depth=1 --branch=humble https://github.com/ros2/common_interfaces.git msgs
"""

import argparse
import logging
from multiprocessing import Pool, RLock
from pathlib import Path
from typing import Any

import pydantic
import strictyaml
from jsonargparse import CLI
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from kappe import __version__
from kappe.convert import Converter
from kappe.cut import CutSettings, cutter
from kappe.cut import CutSettings, CutSplitOn, CutSplits, cutter
from kappe.module.pointcloud import SettingPointCloud
from kappe.module.tf import SettingTF
from kappe.module.timing import SettingTimeOffset
from kappe.plugin import load_plugin
from kappe.settings import Settings
from kappe.settings import SettingGeneral, SettingPlugin, Settings, SettingSchema, SettingTopic


class TqdmLoggingHandler(logging.Handler):
@@ -43,31 +36,7 @@ def emit(self, record: Any):
logger = logging.getLogger(__name__)


def print_error(e: pydantic.ValidationError, config_yaml: strictyaml.YAML):
logger.info('Failed to parse config file')
for err in e.errors():
yaml_obj = config_yaml
for x in err['loc']:
k = None

match x:
case int(idx) if len(yaml_obj) > idx:
k = yaml_obj[idx]
case str(key):
k = yaml_obj.get(key)

if k is None:
break

yaml_obj = yaml_obj[x]

loc = ' -> '.join(str(x) for x in err['loc'])
msg = err['msg']
err_type = err['type']
logger.info('%s: %s @ Line: %i "%s"', err_type, msg, yaml_obj.start_line, loc)


def worker(arg: tuple[Path, Path, Settings, int]):
def convert_worker(arg: tuple[Path, Path, Settings, int]):
# TODO: dataclass
input_path, output_path, config, tqdm_idx = arg

@@ -86,7 +55,13 @@ def worker(arg: tuple[Path, Path, Settings, int]):
logger.info('Done %s', output_path)


def process(config: Settings, input_path: Path, output_path: Path, *, overwrite: bool) -> None:
def convert_process(
config: Settings,
input_path: Path,
output_path: Path,
*,
overwrite: bool,
) -> None:
tasks: list[tuple[Path, Path, Settings, int]] = []

# TODO: make more generic
@@ -110,13 +85,10 @@ def process(config: Settings, input_path: Path, output_path: Path, *, overwrite:
return

logger.info('Using %d threads', config.general.threads)
tqdm.set_lock(RLock()) # for managing output contention

pool = None
try:
pool = Pool(min(config.general.threads, len(tasks)),
initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),))
pool.map(worker, tasks)
process_map(convert_worker, tasks, max_workers=min(config.general.threads, len(tasks)))
except KeyboardInterrupt:
logger.info('Keyboard interrupt')
finally:
@@ -125,113 +97,161 @@ def process(config: Settings, input_path: Path, output_path: Path, *, overwrite:
pool.join()


def cmd_convert(args: argparse.Namespace):
if args.config is None:
config = Settings()
else:

config_text = args.config.read()
config_yaml: strictyaml.YAML = strictyaml.load(config_text)
try:
config = Settings(**config_yaml.data)
except pydantic.ValidationError as e:
print_error(e, config_yaml)
return
class KappeCLI:
def convert( # noqa: PLR0913, PLR0912
self,
input: Path, # noqa: A002
output: Path,
*,
general: SettingGeneral | None = None,
topic: SettingTopic | None = None,
tf_static: SettingTF | None = None,
msg_schema: SettingSchema | None = None,
msg_folder: Path | None = Path('./msgs'),
point_cloud: dict[str, SettingPointCloud] | None = None,
time_offset: dict[str, SettingTimeOffset] | None = None,
plugins: list[SettingPlugin] | None = None,
plugin_folder: Path | None = Path('./plugins'),
time_start: float | None = None,
time_end: float | None = None,
keep_all_static_tf: bool = False,
overwrite: bool = False,
) -> None:
"""Convert mcap(s) with changing, filtering, converting, ... data.
Args:
input: Input mcap or folder of mcaps.
general: General settings (threads, etc.).
topic: Migrations for topics (remove, rename, etc.).
tf_static: Migrations for TF (insert, remove).
msg_schema: Updating or changing a schema.
msg_folder: Path to the folder containing .msg files used to change the schema and
upgrading from ROS1.
point_cloud: Migrations for point clouds (Update filed, rotate, etc.).
time_offset: Migrations for time (Add offset, sync with mcap time, etc.).
plugins: Settings to loading custom plugins.
plugin_folder: Path to plugin folder.
time_start: Start time of the new MCAP.
time_end: End time of the new MCAP.
keep_all_static_tf: If true ensue all /tf_static messages are in the outputted file.
overwrite: If true already existing files will be overwritten.
"""

# TODO: cleanup
if general is None:
general = SettingGeneral()
if topic is None:
topic = SettingTopic()
if tf_static is None:
tf_static = SettingTF()
if msg_schema is None:
msg_schema = SettingSchema()
if point_cloud is None:
point_cloud = {}
if time_offset is None:
time_offset = {}
if plugins is None:
plugins = []

config.raw_text = config_text

# check for msgs folder
if config.msg_folder is not None and not config.msg_folder.exists():
logger.error('msg_folder does not exist: %s', config.msg_folder)
config.msg_folder = None

errors = False

for conv in config.plugins:
try:
load_plugin(config.plugin_folder, conv.name)
continue
except ValueError:
pass

errors = True
logger.error('Failed to load plugin: %s', conv.name)
config = Settings()
config.general = general
config.topic = topic
config.tf_static = tf_static
config.msg_schema = msg_schema
config.point_cloud = point_cloud
config.time_offset = time_offset
config.plugins = plugins
config.time_start = time_start
config.time_end = time_end
config.keep_all_static_tf = keep_all_static_tf
config.msg_folder = msg_folder
config.plugin_folder = plugin_folder

# check for msgs folder
if msg_folder is not None and not msg_folder.exists():
logger.error('msg_folder does not exist: %s', msg_folder)
msg_folder = None

errors = False

for conv in plugins:
try:
load_plugin(plugin_folder, conv.name)
continue
except ValueError:
pass

input_path: Path = args.input
if not input_path.exists():
raise FileNotFoundError(f'Input path does not exist: {input_path}')
errors = True
logger.error('Failed to load plugin: %s', conv.name)

output_path: Path = args.output
input_path: Path = input
if not input_path.exists():
raise FileNotFoundError(f'Input path does not exist: {input_path}')

if errors:
logger.error('Errors found, aborting')
else:
process(config, input_path, output_path, overwrite=args.overwrite)
output_path: Path = output

if errors:
logger.error('Errors found, aborting')
else:
convert_process(
config,
input_path,
output_path,
overwrite=overwrite,
)

def cut( # noqa: PLR0913
self,
mcap: Path,
output: Path = Path('./output'),
*,
overwrite: bool = False,
keep_tf_tree: bool = False,
splits: list[CutSplits] | None = None,
topic: str | None = None,
debounce: float = 0.0,
) -> None:
"""
Cut an mcap based on time or on a marker topic.
Args:
mcap: Input mcap file.
output: Output folder.
overwrite: Overwrite existing files.
keep_tf_tree: Keep all /tf_static messages in the file.
splits: List of splits.
topic: Topic to use for splitting.
debounce: Number of seconds to wait before splitting on the same topic.
"""
split_on_topic = None

if output.exists() and not overwrite:
logger.error('Output folder already exists. Delete or use --overwrite=true.')
return

def cmd_cut(args: argparse.Namespace):
logger.info('cut')
if topic is not None:
split_on_topic = CutSplitOn(
topic=topic,
debounce=debounce,
)

config_text = args.config.read()
config_yaml: strictyaml.YAML = strictyaml.load(config_text)
config = CutSettings(
keep_tf_tree=keep_tf_tree,
splits=splits,
split_on_topic=split_on_topic,
)

try:
config = CutSettings(**config_yaml.data)
except pydantic.ValidationError as e:
print_error(e, config_yaml)
return
cutter(mcap, output, config)

cutter(args.input, args.output_folder, config)
def version(self) -> None:
"""
Print version.
"""
logger.info('kappe %s', __version__)


def main() -> None:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
fromfile_prefix_chars='@',
)

parser.add_argument('--verbose', action='store_true', help='Verbose output')
parser.add_argument('--version', action='version', version=__version__)

sub = parser.add_subparsers(
title='subcommands',
required=True,
)

cutter = sub.add_parser('cut')
cutter.set_defaults(func=cmd_cut)

cutter.add_argument('input', type=Path, help='input file')
cutter.add_argument(
'output_folder',
type=Path,
help='output folder, default: ./cut_out',
default=Path('./cut_out'),
nargs='?')
cutter.add_argument('--config', type=argparse.FileType(), help='config file', required=True)
cutter.add_argument(
'--overwrite',
action='store_true',
help='Overwrite existing files')

convert = sub.add_parser('convert')
convert.set_defaults(func=cmd_convert)

convert.add_argument('input', type=Path, help='input folder or file')
convert.add_argument('output', type=Path, help='output folder')
convert.add_argument('--config', type=argparse.FileType(), help='config file')
convert.add_argument(
'--overwrite',
action='store_true',
help='Overwrite existing files')

args = parser.parse_args()

if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)

args.func(args)
CLI(KappeCLI)


if __name__ == '__main__':
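Worth noting in the diff above: the hand-rolled Pool with a shared tqdm lock is replaced by a single tqdm.contrib.concurrent.process_map call, which wraps concurrent.futures.ProcessPoolExecutor and manages the progress-bar locking itself. A standalone sketch of the pattern (the worker here is a placeholder, not kappe's convert_worker):

import time

from tqdm.contrib.concurrent import process_map


def worker(task: int) -> int:
    # stand-in for convert_worker: any picklable callable works
    time.sleep(0.1)
    return task * 2


if __name__ == '__main__':
    tasks = list(range(16))
    # replaces Pool(...) + tqdm.set_lock(...) + pool.map(...) in one call;
    # a progress bar over `tasks` is rendered automatically
    results = process_map(worker, tasks, max_workers=4)
    print(results)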
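The same rewrite drops the whole argparse subparser block in favor of jsonargparse.CLI, which exposes each public method of KappeCLI as a subcommand, builds flags from the type-hinted signatures, and takes help text from the Args: sections of the docstrings (the [signatures] extra added in pyproject.toml pulls in the docstring parsing this needs). A self-contained sketch of the pattern; Tool and its method bodies are toys standing in for kappe's real ones:

from pathlib import Path

from jsonargparse import CLI


class Tool:
    def convert(self, input: Path, output: Path, *, overwrite: bool = False) -> None:
        """Convert a file.

        Args:
            input: Input mcap or folder of mcaps.
            output: Output folder.
            overwrite: If true, existing files will be overwritten.
        """
        print(f'convert {input} -> {output} (overwrite={overwrite})')

    def version(self) -> None:
        """Print version."""
        print('0.0.0')


if __name__ == '__main__':
    # usage, mirroring the new kappe CLI style:
    #   python tool.py convert in.mcap out --overwrite=true
    #   python tool.py version
    CLI(Tool)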
