English | 简体中文
本项目仓库是论文 STT: Building Extraction from Remote Sensing Images with Sparse Token Transformers 的代码实现。
当前分支在PyTorch 2.x 和 CUDA 12.1 下测试通过,支持 Python 3.7+,能兼容绝大多数的 CUDA 版本。
如果你觉得本项目对你有帮助,请给我们一个 star ⭐️,你的支持是我们最大的动力。
🌟 2022.10.23 发布了 STT 项目代码。
🌟 2022.10.25 开源了 WHU 和 INRIA 数据集的预训练模型,你可以在 Hugging Face Spaces 上找到它们。
🌟 2024.02.28 重新整理了本项目。
- Linux 或 Windows
- Python 3.7+,推荐使用 3.10
- PyTorch 2.0 或更高版本,推荐使用 2.1
- CUDA 11.7 或更高版本,推荐使用 12.1
我们推荐使用 Miniconda 来进行安装,以下命令将会创建一个名为 stt
的虚拟环境,并安装 PyTorch。
注解:如果你对 PyTorch 有经验并且已经安装了它,你可以直接跳转到下一小节。否则,你可以按照下述步骤进行准备。
步骤 0:安装 Miniconda。
步骤 1:创建一个名为 stt
的虚拟环境,并激活它。
conda create -n stt python=3.10 -y
conda activate stt
步骤 2:安装 PyTorch2.1.x。
Linux/Windows:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
或者
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
步骤 4:安装其他依赖项。
pip install -U pandas opencv-python tqdm scikit-image einops matplotlib
下载或克隆 STT 仓库即可。
git clone [email protected]:KyanChen/STT.git
cd STT
- 图片及标签下载地址: WHU。
你也可以选择其他来源进行数据的下载,但是需要将数据集组织成如下的格式:
${DATASET_ROOT} # 数据集根目录,例如:/home/username/data/WHU
├── train
│ ├── img
│ └── label
├── val
│ ├── img
│ └── label
└── test
├── img
└── label
注解:在项目文件夹中,我们提供了一个名为 Data
的文件夹,其中包含了上述数据集的组织方式的示例。
你也可以选择其他来源进行数据的下载,但是需要将数据集组织成上述的格式。
如果你想使用其他数据集,可以参考上述方式来进行数据集的准备。
我们提供了训练脚本 Train.py。下面我们提供了一些主要参数的解析。
参数解析:
line3
:os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3' 设置使用的 GPU,一般不需要修改。line17
:backbone:主干网络的类型,一般不需要修改。line22
:top_k_s:选择的空间稀疏的 token 数量,一般不需要修改。line23
:top_k_c:选择的通道稀疏的 token 数量,一般不需要修改。line28
:BATCH_SIZE:单卡的 batch size,需要根据显存大小进行修改。line31
:DATASET:训练数据列表的路径,需要根据数据集的路径进行修改。line32
:model_path:模型的保存路径,一般不需要修改。line33
:log_path:日志的保存路径,一般不需要修改。line35
:IS_VAL:是否进行验证,一般不需要修改。line37
:VAL_DATASET:验证数据列表的路径,需要根据数据集的路径进行修改。line39
:IS_TEST:是否进行测试,一般不需要修改。line40
:TEST_DATASET:测试数据列表的路径,需要根据数据集的路径进行修改。line45
:PRIOR_MEAN:训练集的图片的均值,需要根据数据集的均值进行修改。line46
:PRIOR_STD:训练集的图片的标准差,需要根据数据集的标准差进行修改。line53
:load_checkpoint_path:是否加载检查点,一般为空。line55
:resume_checkpoint_path:是否断点续训,一般为空。
CUDA_VISIBLE_DEVICES=0 python Train.py # 0 为使用的 GPU 编号
CUDA_VISIBLE_DEVICES=0,1,2,3 python Train.py # 0,1,2,3 为使用的 GPU 编号
我们提供了训练脚本 Test.py。需要修改Test.py中line47
的load_checkpoint_path
为你想要使用的检查点文件。
CUDA_VISIBLE_DEVICES=0 python Test.py # 0 为使用的 GPU 编号
CUDA_VISIBLE_DEVICES=0,1,2,3 python Test.py # 0,1,2,3 为使用的 GPU 编号
注解:输出的结果将会保存在 log_path
中。
如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 TTP。
@Article{rs13214441,
AUTHOR = {Chen, Keyan and Zou, Zhengxia and Shi, Zhenwei},
TITLE = {Building Extraction from Remote Sensing Images with Sparse Token Transformers},
JOURNAL = {Remote Sensing},
VOLUME = {13},
YEAR = {2021},
NUMBER = {21},
ARTICLE-NUMBER = {4441},
URL = {https://www.mdpi.com/2072-4292/13/21/4441},
ISSN = {2072-4292},
DOI = {10.3390/rs13214441}
}
该项目采用 Apache 2.0 开源许可证。
如果有其他问题❓,请及时与我们联系 👬