Skip to content

Refactored Models

Latest
Compare
Choose a tag to compare
@shyamsn97 shyamsn97 released this 28 Dec 21:44
· 2 commits to main since this release

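Made it easier to create customized hypernetworks by removing the embedding-module and weight-generator concepts from the base hypernetwork. As the signatures below show, the base `TorchHyperNetwork` constructor no longer takes `embedding_dim`, `num_embeddings`, `weight_chunk_dim`, `custom_embedding_module`, or `custom_weight_generator`; only `target_network` and `num_target_parameters` remain.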

Before:

class TorchHyperNetwork(nn.Module, HyperNetwork):
    def __init__(
        self,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
        custom_embedding_module: Optional[nn.Module] = None,
        custom_weight_generator: Optional[nn.Module] = None,
    ):

After:

class TorchHyperNetwork(nn.Module, HyperNetwork):
    def __init__(
        self,
        target_network: nn.Module,
        num_target_parameters: Optional[int] = None,
    ):