From 9f021026edc7660d3f49471215d62844bf51308c Mon Sep 17 00:00:00 2001 From: dimitribarbot Date: Sat, 28 Sep 2024 15:21:48 +0200 Subject: [PATCH] Upgrade to latest version and include BiRefNet General-Lite-2K model --- README.md | 4 +++- birefnet/config.py | 19 ++++++++++++++----- internal_birefnet/pipeline.py | 5 +++-- scripts/postprocessing_birefnet.py | 1 + 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index eba6f40..70a37c6 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ The available models are: - General: A pre-trained model for general use cases. - General-Lite: A light pre-trained model for general use cases. +- General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440). - Portrait: A pre-trained model for human portraits. - DIS: A pre-trained model for dichotomous image segmentation (DIS). - HRSOD: A pre-trained model for high-resolution salient object detection (HRSOD). @@ -32,6 +33,7 @@ Model files go here (automatically downloaded if the folder is not present durin If necessary, they can be downloaded from: - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General.safetensors` - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite.safetensors` +- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite-2K.safetensors` - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Portrait.safetensors` - [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `DIS.safetensors` - [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `HRSOD.safetensors` @@ -48,7 +50,7 @@ Both endpoints share these parameters: - `return_mask`: whether to return mask (can be used for inpainting). - `return_edge_mask`: whether to return edge mask (can be used to blend foreground image with another background). - `edge_mask_width`: edge mask width in pixels. Default to 64. -- `model_name`: `General`, `General-Lite`, `Portrait`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`. +- `model_name`: `General`, `General-Lite`, `General-Lite-2K`, `Portrait`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`. - `output_dir`: directory to save output images. - `output_extension`: output image file extension (without leading dot, `png` by default). - `device_id`: GPU device id. diff --git a/birefnet/config.py b/birefnet/config.py index fc60176..0a93899 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -4,20 +4,29 @@ class Config(): def __init__(self, bb_index: int = 6) -> None: - # PATH settings + # PATH settings # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx - # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][1] # Default, custom + # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][0] # Default, custom # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') # TASK settings self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0] + self.validation_set = { + 'DIS5K': [], + 'COD': [], + 'HRSOD': [], + 'General': ['DIS-VD', 'TE-P3M-500-NP'], + 'General-2K': ['DIS-VD', 'TE-P3M-500-NP'], + 'Matting': ['TE-P3M-500-NP'], + }[self.task] + # datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.validation_set]) self.training_set = { 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 'COD': 'TR-COD10K+TR-CAMO', 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], - 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), # leave DIS-VD,TE-P3M-500-NP for evaluation. - 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), - 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', + 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all + 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all + 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', # datasets_all }[self.task] self.prompt4loc = ['dense', 'sparse'][0] diff --git a/internal_birefnet/pipeline.py b/internal_birefnet/pipeline.py index ce90d49..28a5d15 100644 --- a/internal_birefnet/pipeline.py +++ b/internal_birefnet/pipeline.py @@ -25,6 +25,7 @@ usage_to_weights_file = { "General": "BiRefNet", "General-Lite": "BiRefNet_T", + "General-Lite-2K": "BiRefNet_lite-2K", "Portrait": "BiRefNet-portrait", "DIS": "BiRefNet-DIS5K", "HRSOD": "BiRefNet-HRSOD", @@ -33,7 +34,7 @@ } BiRefNetModelName = Literal[ - "General", "General-Lite", "Portrait", "DIS", "HRSOD", "COD", "DIS-TR_TEs" + "General", "General-Lite", "General-Lite-2K", "Portrait", "DIS", "HRSOD", "COD", "DIS-TR_TEs" ] @@ -109,7 +110,7 @@ def __init__( state_dict = safetensors.torch.load_file(weight_path, device=self.device) - bb_index = 3 if model_name == "General-Lite" else 6 + bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" else 6 self.birefnet = BiRefNet(bb_pretrained=False, bb_index=bb_index) self.birefnet.load_state_dict(state_dict) diff --git a/scripts/postprocessing_birefnet.py b/scripts/postprocessing_birefnet.py index 2651dcf..783fd1d 100644 --- a/scripts/postprocessing_birefnet.py +++ b/scripts/postprocessing_birefnet.py @@ -11,6 +11,7 @@ "None", "General", "General-Lite", + "General-Lite-2K", "Portrait", "DIS", "HRSOD",