Skip to content

Commit

Permalink
separate vw and pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Nov 21, 2023
1 parent 97d9f0c commit 324ca81
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 66 deletions.
17 changes: 9 additions & 8 deletions src/learn_to_pick/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
BasedOn,
Embed,
Featurizer,
ModelRepository,
Policy,
SelectionScorer,
ToSelectFrom,
VwPolicy,
VwLogger,
embed,
)
from learn_to_pick.pick_best import (
Expand All @@ -21,9 +18,13 @@
PickBestSelected,
)

from learn_to_pick.byom.pytorch_policy import PyTorchPolicy

from learn_to_pick.byom.pytorch_feature_embedder import PyTorchFeatureEmbedder
from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger

from learn_to_pick.pytorch.policy import PyTorchPolicy
from learn_to_pick.pytorch.pytorch_feature_embedder import PyTorchFeatureEmbedder


def configure_logger() -> None:
Expand Down Expand Up @@ -52,11 +53,11 @@ def configure_logger() -> None:
"SelectionScorer",
"AutoSelectionScorer",
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"VwPolicy",
"VwLogger",
"embed",
"ModelRepository",
"VwPolicy",
"VwLogger"
]
53 changes: 1 addition & 52 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
Callable,
)

from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger

from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

Expand Down Expand Up @@ -89,10 +86,6 @@ def EmbedAndKeep(anything: Any) -> Any:
# helper functions


def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
return [parser.parse_line(line) for line in input_str.split("\n")]


def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
return {
k: v.value
Expand Down Expand Up @@ -144,50 +137,6 @@ def save(self) -> None:
pass


class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: List[str],
featurizer: Featurizer,
formatter: Callable,
vw_logger: VwLogger,
**kwargs: Any,
):
super().__init__(**kwargs)
self.model_repo = model_repo
self.vw_cmd = vw_cmd
self.workspace = self.model_repo.load(vw_cmd)
self.featurizer = featurizer
self.formatter = formatter
self.vw_logger = vw_logger

def format(self, event):
return self.formatter(*self.featurizer.featurize(event))

def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw

text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw

vw_ex = self.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = _parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)

def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.format(event)
self.vw_logger.log(vw_ex)

def save(self) -> None:
self.model_repo.save(self.workspace)


class Featurizer(Generic[TEvent], ABC):
def __init__(self, *args: Any, **kwargs: Any):
pass
Expand Down
8 changes: 4 additions & 4 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import numpy as np

from learn_to_pick import base
from learn_to_pick import base, VwPolicy, ModelRepository, VwLogger

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -333,14 +333,14 @@ def create_policy(

vw_cmd = interactions + vw_cmd

return base.VwPolicy(
model_repo=base.ModelRepository(
return VwPolicy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd,
featurizer=featurizer,
formatter=formatter,
vw_logger=base.VwLogger(rl_logs),
vw_logger=VwLogger(rl_logs),
)

def _default_policy(self):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from learn_to_pick import base, PickBestEvent
from learn_to_pick.byom.logistic_regression import ResidualLogisticRegressor
from learn_to_pick.byom.igw import SamplingIGW
from learn_to_pick.pytorch.logistic_regression import ResidualLogisticRegressor
from learn_to_pick.pytorch.igw import SamplingIGW
import torch
import os

Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
61 changes: 61 additions & 0 deletions src/learn_to_pick/vw/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from learn_to_pick.base import Event, Featurizer, Policy
from learn_to_pick import ModelRepository, VwLogger
from typing import (
Any,
List,
Callable,
TYPE_CHECKING,
TypeVar
)

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

TEvent = TypeVar("TEvent", bound=Event)

def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
return [parser.parse_line(line) for line in input_str.split("\n")]


class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: List[str],
featurizer: Featurizer,
formatter: Callable,
vw_logger: VwLogger,
**kwargs: Any,
):
super().__init__(**kwargs)
self.model_repo = model_repo
self.vw_cmd = vw_cmd
self.workspace = self.model_repo.load(vw_cmd)
self.featurizer = featurizer
self.formatter = formatter
self.vw_logger = vw_logger

def format(self, event):
return self.formatter(*self.featurizer.featurize(event))

def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw

text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw

vw_ex = self.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = _parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)

def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.format(event)
self.vw_logger.log(vw_ex)

def save(self) -> None:
self.model_repo.save(self.workspace)

0 comments on commit 324ca81

Please sign in to comment.