Skip to content

Commit

Permalink
Merge branch 'master' into fix-some-bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Mosakana committed Aug 21, 2024
2 parents b78f447 + 63b5c16 commit dcc9e1c
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions src/xpmir/text/huggingface/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import InitVar
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import InitVar
from pathlib import Path
from typing import Type

import torch.nn as nn
from experimaestro import Config, Param

Expand All @@ -30,7 +31,12 @@ class HFModelConfigFromId(HFModelConfig):
model_id: Param[str]
"""HuggingFace Model ID"""

def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
def get_config(
self,
options: ModuleInitOptions,
autoconfig: Type[AutoModel],
automodel: Type[AutoConfig],
):
model_id_or_path = self.model_id

# Use saved models
Expand All @@ -49,19 +55,29 @@ def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
)

# Load the model configuration
config = AutoConfig.from_pretrained(model_id_or_path)
config = autoconfig.from_pretrained(model_id_or_path)

# Return it
return config, model_id_or_path

def __call__(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
config, model_id_or_path = self.get_config(options, automodel)
def __call__(
self,
options: ModuleInitOptions,
autoconfig: Type[AutoConfig],
automodel: Type[AutoModel],
):
config, model_id_or_path = self.get_config(options, autoconfig, automodel)

if options.mode == ModuleInitMode.NONE or options.mode == ModuleInitMode.RANDOM:
logging.info("Random initialization of HF model")
return config, automodel.from_config(config)

logging.info("Loading model from HF (%s)", self.model_id)
logging.info(
"Loading model from HF (%s) with model %s.%s",
self.model_id,
automodel.__module__,
automodel.__name__,
)
return config, automodel.from_pretrained(model_id_or_path, config=config)


Expand All @@ -81,6 +97,10 @@ class HFModel(Module):
def from_pretrained_id(cls, model_id: str):
return cls(config=HFModelConfigFromId(model_id=model_id))

@property
def autoconfig(self):
return AutoConfig

@property
def automodel(self):
return AutoModel
Expand All @@ -93,7 +113,9 @@ def __initialize__(self, options: ModuleInitOptions):
"""
super().__initialize__(options)

self.hf_config, self.model = self.config(options, self.automodel)
self.hf_config, self.model = self.config(
options, self.autoconfig, self.automodel
)

@property
def contextual_model(self) -> nn.Module:
Expand Down

0 comments on commit dcc9e1c

Please sign in to comment.