a "lifelong" CSP solver that reuses past solutions #109

Merged: 6 commits, Dec 5, 2024
README.md: 2 additions & 1 deletion
@@ -40,6 +40,7 @@ python experiments/run_single_experiment.py -m \
env.max_environment_steps=100 \
env.eval_frequency=50 \
env.num_eval_trials=1 \
csp_solver=random_walk \
csp_solver.min_num_satisfying_solutions=1 \
csp_solver.max_iters=100
```
@@ -52,7 +53,7 @@ python experiments/run_single_experiment.py \
seed=0 \
llm=openai \
approach.max_motion_planning_candidates=50 \
csp_solver.min_num_satisfying_solutions=100 \
csp_solver.base_solver.min_num_satisfying_solutions=100 \
env.env.scene_spec.surface_dust_patch_size=4 \
env.env.scene_spec.use_standard_books=true \
env.env.hidden_spec.book_preferences='I only like science fiction. I do not like any other kinds of fiction or nonfiction.' \
experiments/conf/config.yaml: 1 addition & 1 deletion
@@ -30,6 +30,6 @@ defaults:
- _self_
- approach: random
- env: tiny
- csp_solver: random_walk
- csp_solver: lifelong_random_walk
- rom_model: spherical
- llm: canned # change to openai for final experiments
experiments/conf/csp_solver/lifelong_random_walk.yaml: 8 additions & 0 deletions (new file)
@@ -0,0 +1,8 @@
_target_: "multitask_personalization.csp_solvers.LifelongCSPSolverWrapper"
base_solver:
  _target_: "multitask_personalization.csp_solvers.RandomWalkCSPSolver"
  seed: ${seed}
  max_iters: 100000
  min_num_satisfying_solutions: 50
  show_progress_bar: True
seed: ${seed}
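Taken together, this new config wires a RandomWalkCSPSolver inside the LifelongCSPSolverWrapper. Below is a rough, non-authoritative Python sketch of what Hydra's `_target_` instantiation would build, assuming `${seed}` resolves to the experiment-level seed (stubbed here as 0):

```python
from multitask_personalization.csp_solvers import (
    LifelongCSPSolverWrapper,
    RandomWalkCSPSolver,
)

seed = 0  # stand-in for ${seed}

# Inner solver, configured as in lifelong_random_walk.yaml.
base_solver = RandomWalkCSPSolver(
    seed=seed,
    max_iters=100000,
    min_num_satisfying_solutions=50,
    show_progress_bar=True,
)

# Wrapper that adds memory of past constraint solutions (defined in
# csp_solvers.py further down in this diff).
csp_solver = LifelongCSPSolverWrapper(base_solver=base_solver, seed=seed)
```

This also explains the README change above: with the wrapper in between, the inner solver's options move under `csp_solver.base_solver.*`.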
pyproject.toml: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ split_on_trailing_comma = true
[tool.mypy]
strict_equality = true
disallow_untyped_calls = true
disable_error_code = ["method-assign"]
warn_unreachable = true
exclude = ["venv/*"]

src/multitask_personalization/csp_solvers.py: 79 additions & 1 deletion
@@ -1,12 +1,19 @@
"""Different methods for solving CSPs."""

import abc
from collections import defaultdict, deque
from typing import Any

import numpy as np
from tqdm import tqdm

from multitask_personalization.structs import CSP, CSPSampler, CSPVariable
from multitask_personalization.structs import (
    CSP,
    CSPConstraint,
    CSPSampler,
    CSPVariable,
    FunctionalCSPSampler,
)


class CSPSolver(abc.ABC):
@@ -83,3 +90,74 @@ def solve(
sol = sol.copy()
sol.update(partial_sol)
return best_satisfying_sol


class LifelongCSPSolverWrapper(CSPSolver):
    """A wrapper that samples from past constraint solutions."""

    def __init__(
        self, base_solver: CSPSolver, seed: int, memory_size: int = 100
    ) -> None:
        super().__init__(seed)
        self._base_solver = base_solver
        self._memory_size = memory_size
        self._constraint_to_recent_solutions: dict[
            CSPConstraint, deque[dict[CSPVariable, Any]]
        ] = defaultdict(lambda: deque(maxlen=self._memory_size))

    def solve(
        self,
        csp: CSP,
        initialization: dict[CSPVariable, Any],
        samplers: list[CSPSampler],
    ) -> dict[CSPVariable, Any] | None:
        # Create the samplers from past experience.
        memory_based_samplers = self._create_memory_based_samplers(csp)
        samplers = samplers + memory_based_samplers
        # Need to wrap the constraints so that we can memorize solutions.
        # Note that we could also just take the output of solve() and memorize,
        # but that would miss out on the opportunity to memorize intermediates.
        wrapped_csp = self._wrap_csp(csp)
        return self._base_solver.solve(wrapped_csp, initialization, samplers)

    def _create_memory_based_samplers(self, csp: CSP) -> list[CSPSampler]:
        new_samplers: list[CSPSampler] = []
        for constraint in csp.constraints:
            if constraint in self._constraint_to_recent_solutions:
                sampler = self._create_memory_based_sampler(constraint, csp)
                new_samplers.append(sampler)
        return new_samplers

    def _create_memory_based_sampler(
        self, constraint: CSPConstraint, csp: CSP
    ) -> CSPSampler:
        recent_solutions = self._constraint_to_recent_solutions[constraint]
        num_recent_solutions = len(recent_solutions)

        def _sample(
            _: dict[CSPVariable, Any], rng: np.random.Generator
        ) -> dict[CSPVariable, Any] | None:
            idx = rng.choice(num_recent_solutions)
            return recent_solutions[idx]

        return FunctionalCSPSampler(_sample, csp, set(constraint.variables))

    def _wrap_csp(self, csp: CSP) -> CSP:
        new_constraints = [self._wrap_constraint(c) for c in csp.constraints]
        return CSP(csp.variables, new_constraints, csp.cost)

    def _wrap_constraint(self, constraint: CSPConstraint) -> CSPConstraint:
        # Make sure not to modify original constraint.
        new_constraint = constraint.copy()

        # Memorize solutions.
        def _wrapped_check_solution(sol: dict[CSPVariable, Any]) -> bool:
            result = constraint.check_solution(sol)
            if result:
                partial_sol = {v: sol[v] for v in constraint.variables}
                self._constraint_to_recent_solutions[constraint].append(partial_sol)
            return result

        # Overwrite check_solution method.
        new_constraint.check_solution = _wrapped_check_solution
        return new_constraint
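In short, the wrapper records a per-constraint memory of partial solutions whenever a wrapped `check_solution` call succeeds, and on later `solve()` calls it adds samplers that replay that memory. A minimal usage sketch follows; it leans on the `_create_test_csp` helper added in `tests/test_csp_solvers.py` below and assumes the `tests` directory is importable:

```python
from multitask_personalization.csp_solvers import (
    LifelongCSPSolverWrapper,
    RandomWalkCSPSolver,
)
from tests.test_csp_solvers import _create_test_csp  # helper added in this PR

solver = LifelongCSPSolverWrapper(
    RandomWalkCSPSolver(seed=0, show_progress_bar=False), seed=0
)

# First solve: only the provided samplers are available, and every satisfying
# partial assignment encountered along the way is memorized per constraint.
csp, initialization, samplers = _create_test_csp()
assert solver.solve(csp, initialization, samplers) is not None

# Second solve on a freshly regenerated (but equal) CSP: the wrapper injects
# memory-based samplers, so it can proceed even with no external samplers.
csp, initialization, _ = _create_test_csp()
assert solver.solve(csp, initialization, []) is not None
```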
src/multitask_personalization/structs.py: 25 additions & 1 deletion
@@ -1,5 +1,7 @@
"""Common data structures."""

from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import Any, Callable, Generic
@@ -26,7 +28,9 @@ def __hash__(self) -> int:

    def __eq__(self, other: Any) -> bool:
        assert isinstance(other, CSPVariable)
        return self.name == other.name
        # Example where domain checking matters: books in pybullet that may
        # have different titles between evaluation episodes.
        return self.name == other.name and str(self.domain) == str(other.domain)


class CSPConstraint(abc.ABC):
@@ -36,10 +40,22 @@ def __init__(self, name: str, variables: list[CSPVariable]):
        self.name = name
        self.variables = variables

    def __hash__(self) -> int:
        return hash((self.name, tuple(self.variables)))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, CSPConstraint):
            return False
        return self.name == other.name and self.variables == other.variables

    @abc.abstractmethod
    def check_solution(self, sol: dict[CSPVariable, Any]) -> bool:
        """Check whether the constraint holds given values of the variables."""

    @abc.abstractmethod
    def copy(self) -> CSPConstraint:
        """Create a copy of this constraint."""


class FunctionalCSPConstraint(CSPConstraint):
"""A constraint defined by a function that outputs bools."""
@@ -57,6 +73,9 @@ def check_solution(self, sol: dict[CSPVariable, Any]) -> bool:
        vals = [sol[v] for v in self.variables]
        return self.constraint_fn(*vals)

    def copy(self) -> CSPConstraint:
        return FunctionalCSPConstraint(self.name, self.variables, self.constraint_fn)


class LogProbCSPConstraint(CSPConstraint):
"""A constraint defined by a function that outputs log probabilities.
@@ -85,6 +104,11 @@ def get_logprob(self, sol: dict[CSPVariable, Any]) -> float:
        vals = [sol[v] for v in self.variables]
        return self.constraint_logprob_fn(*vals)

    def copy(self) -> CSPConstraint:
        return LogProbCSPConstraint(
            self.name, self.variables, self.constraint_logprob_fn, self.threshold
        )


@dataclass(frozen=True)
class CSPCost:
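These `__hash__`/`__eq__`/`copy` additions are what let the lifelong wrapper's memory (a dict keyed by constraints) recognize "the same" constraint in a freshly regenerated CSP. A small illustrative sketch; the variable name, domain, and constraint function here are made up for the example:

```python
import gymnasium as gym
import numpy as np

from multitask_personalization.structs import CSPVariable, FunctionalCSPConstraint


def _make_constraint() -> FunctionalCSPConstraint:
    # Rebuilt from scratch, as a new evaluation episode would do.
    x = CSPVariable("x", gym.spaces.Box(0, 1, dtype=np.float_))
    return FunctionalCSPConstraint("x_is_small", [x], lambda xv: bool(xv < 0.5))


c1, c2 = _make_constraint(), _make_constraint()

# Same name and variables (with string-identical domains), so the two
# independently constructed constraints compare equal and hash alike.
assert c1 == c2 and hash(c1) == hash(c2)

# That is what keeps a constraint-keyed memory usable across episodes.
memory = {c1: ["remembered partial solution"]}
assert c2 in memory
```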
tests/test_csp_solvers.py: 25 additions & 6 deletions
@@ -3,7 +3,10 @@
import gymnasium as gym
import numpy as np

from multitask_personalization.csp_solvers import RandomWalkCSPSolver
from multitask_personalization.csp_solvers import (
    LifelongCSPSolverWrapper,
    RandomWalkCSPSolver,
)
from multitask_personalization.structs import (
    CSP,
    CSPVariable,
@@ -13,8 +16,7 @@
)


def test_solve_csp():
    """Tests for csp_solvers.py."""
def _create_test_csp():
    x = CSPVariable("x", gym.spaces.Box(0, 1, dtype=np.float_))
    y = CSPVariable("y", gym.spaces.Box(0, 1, dtype=np.float_))
    z = CSPVariable("z", gym.spaces.Discrete(5))
@@ -38,10 +40,27 @@ def test_solve_csp():
    sampler_xy = FunctionalCSPSampler(sample_xy, csp, {x, y})
    sampler_z = FunctionalCSPSampler(sample_z, csp, {z})
    samplers = [sampler_xy, sampler_z]

    initialization = {x: 0.0, y: 0.0, z: 0}

    return csp, initialization, samplers


def test_solve_csp():
    """Tests for csp_solvers.py."""

    # Test RandomWalkCSPSolver().
    csp, initialization, samplers = _create_test_csp()
    solver = RandomWalkCSPSolver(seed=123, show_progress_bar=False)
    sol = solver.solve(csp, initialization, samplers)
    assert sol is not None
    assert sol[x] < sol[y]
    assert sol[y] < sol[z] / 5

    # Test LifelongCSPSolverWrapper(RandomWalkCSPSolver()).
    # The lifelong solver should still work after deleting the samplers because
    # it should use its own memory-based samplers.
    lifelong_solver = LifelongCSPSolverWrapper(solver, seed=123)
    sol = lifelong_solver.solve(csp, initialization, samplers)
    assert sol is not None
    # Regenerate the CSP to make sure that equality checking is based on names.
    csp, initialization, samplers = _create_test_csp()
    sol = lifelong_solver.solve(csp, initialization, [])
    assert sol is not None
tests/test_run_single_experiment.py: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ def test_run_single_experiment():
"env.max_environment_steps=10",
"env.eval_frequency=5",
"env.num_eval_trials=1",
"csp_solver=random_walk",
"csp_solver.min_num_satisfying_solutions=1",
"csp_solver.max_iters=1000",
]