Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#66 from wjm202/flash_attn814
Browse files (browse the repository at this point in the history)
add_blip2_flash_attn
  • Loading branch information
lyuwenyu authored Aug 15, 2023
2 parents 5747a4a + c74d90c commit c71874a
Show file tree
Hide file tree
Showing 16 changed files with 739 additions and 879 deletions.
44 changes: 22 additions & 22 deletions paddlemix/examples/blip2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../.."))
import argparse
import os
sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
from dataclasses import dataclass, field

import paddle
import requests
import yaml
from paddlenlp.trainer import PdArgumentParser
from PIL import Image

from paddlemix.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlemix.processors.blip_processing import Blip2Processor
from paddlemix.utils.log import logger
import os
import yaml
import paddle
import argparse
import os
import paddle


@dataclass
Expand All @@ -42,8 +42,8 @@ class DataArguments:

input_image: str = field(
default="http://images.cocodataset.org/val2017/000000039769.jpg",
metadata={"help": "The name of input image."},
) # "http://images.cocodataset.org/val2017/000000039769.jpg"
metadata={"help": "The name of input image."
}) # "http://images.cocodataset.org/val2017/000000039769.jpg"
prompt: str = field(
default=None,
metadata={"help": "The prompt of the image to be generated."
Expand Down Expand Up @@ -87,35 +87,35 @@ def main():
decorated = paddle.amp.decorate(
models=[model.visual_encoder, model.language_model],
optimizers=None,
level="O2", )
level="O2")
model.visual_encoder, model.language_model = decorated
dtype = "float16"

shape1 = [None, 3, None, None]
input_spec = [paddle.static.InputSpec(shape=shape1, dtype="float32"), ]
input_spec = [paddle.static.InputSpec(shape=shape1, dtype='float32'), ]
image_encoder = paddle.jit.to_static(
model.encode_image, input_spec=input_spec)
save_path = "blip2_export"
paddle.jit.save(image_encoder, os.path.join(save_path, "image_encoder"))
paddle.jit.save(image_encoder, os.path.join(save_path, 'image_encoder'))

# TODO add test config
deploy_info = {
"Deploy": {
"model": "image_encoder.pdmodel",
"params": "image_encoder.pdiparams",
"input_img_shape": shape1,
"output_dtype": dtype,
'Deploy': {
'model': 'image_encoder.pdmodel',
'params': 'image_encoder.pdiparams',
'input_img_shape': shape1,
'output_dtype': dtype
}
}
msg = "\n---------------Deploy Information---------------\n"
msg = '\n---------------Deploy Information---------------\n'
msg += str(yaml.dump(deploy_info))
logger.info(msg)

yml_file = os.path.join(save_path, "deploy.yaml")
with open(yml_file, "w") as file:
yml_file = os.path.join(save_path, 'deploy.yaml')
with open(yml_file, 'w') as file:
yaml.dump(deploy_info, file)

logger.info(f"The inference model is saved in {save_path}")
logger.info(f'The inference model is saved in {save_path}')


if __name__ == "__main__":
Expand Down
115 changes: 0 additions & 115 deletions paddlemix/examples/blip2/merge_weight.py

This file was deleted.

59 changes: 27 additions & 32 deletions paddlemix/examples/blip2/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import os
sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../.."))
import random
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
import paddle.distributed as dist
from paddle.distributed import fleet
from dataclasses import dataclass, field

import numpy as np
import random
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from sklearn.utils import compute_sample_weight
from paddlenlp.trainer import (PdArgumentParser, TrainingArguments,
get_last_checkpoint)
from paddlenlp.transformers import (AutoConfig, AutoTokenizer, OPTConfig,
T5Config)
from sklearn.utils import compute_sample_weight

from paddlenlp.transformers import AutoConfig, OPTConfig, T5Config
from paddlemix.datasets import load_dataset
from paddlemix.examples.blip2.utils import BlipCollator
from paddlemix.models.blip2.configuration import (
Blip2Config, Blip2QFormerConfig, Blip2VisionConfig)
from paddlemix.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlemix.processors.blip_processing import (
Blip2Processor, BlipImageProcessor, BlipTextProcessor)
from paddlemix.processors.blip_processing import Blip2Processor
from paddlemix.trainer.blip2_trainer import BLIP2Trainer as Trainer
from paddlemix.utils.log import logger
from paddlenlp.transformers import AutoTokenizer
from paddlemix.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
from paddlemix.examples.blip2.utils import BlipCollator


@dataclass
Expand All @@ -58,8 +55,8 @@ class DataArguments:
}, )
prompt: str = field(
default="a photo of ",
metadata={"help": "The prompt of the image to be generated."},
) # "Question: how many cats are there? Answer:"
metadata={"help": "The prompt of the image to be generated."
}) # "Question: how many cats are there? Answer:"


@dataclass
Expand All @@ -82,7 +79,6 @@ class PreTrainingArguments(TrainingArguments):
"""
Arguments pertaining to what training options we are going to use during pretraining.
"""

weight_decay: float = field(
default=0.05, metadata={"help": "Weight decay if we apply some."})
learning_rate: float = field(
Expand All @@ -103,12 +99,12 @@ class PreTrainingArguments(TrainingArguments):
default=128,
metadata={
"help": "Batch size per GPU core/CPU for training. (default: 8)"
}, )
})
per_device_eval_batch_size: int = field(
default=1,
metadata={
"help": " Batch size per GPU core/CPU for evaluation. (default:8)"
}, )
})
warmup_start_lr: float = field(
default=1e-6,
metadata={"help": " The initial learning rate of blip2."})
Expand Down Expand Up @@ -137,7 +133,7 @@ class PreTrainingArguments(TrainingArguments):
default=1,
metadata={
"help": "Set the number of sharding, enable sharding parallel"
}, )
})
pipeline_parallel_degree: int = field(
default=1, metadata={"help": "Enable pipeline parallel"})
fp16_opt_level: str = field(
Expand All @@ -154,7 +150,7 @@ class PreTrainingArguments(TrainingArguments):
default=1,
metadata={
"help": "Set the number of sharding, enable sharding parallel"
}, )
})
pipeline_parallel_degree: int = field(
default=1, metadata={"help": "Enable pipeline parallel"})
model_path: str = field(
Expand Down Expand Up @@ -241,18 +237,18 @@ def main():
eval_collator=blip_eval_collator,
processor=processor,
eval_processor=eval_processor,
tokenizer=tokenizer_class, )
tokenizer=tokenizer_class)
eval_metrics = trainer.evaluate(eval_dataset)
trainer.log_metrics("eval", eval_metrics)


def setdistenv(args):
if (args.tensor_parallel_degree * args.sharding_parallel_degree *
args.pipeline_parallel_degree != 1):
if args.tensor_parallel_degree * args.sharding_parallel_degree * args.pipeline_parallel_degree != 1:
args.use_hybrid_parallel = True
args.dp_degree = dist.get_world_size() // (args.tensor_parallel_degree *
args.sharding_parallel_degree *
args.pipeline_parallel_degree)
args.dp_degree = dist.get_world_size() \
// (args.tensor_parallel_degree \
* args.sharding_parallel_degree * \
args.pipeline_parallel_degree)
strategy = fleet.DistributedStrategy()
if args.tensor_parallel_degree > 1:
strategy.tensor_parallel = True
Expand All @@ -271,7 +267,7 @@ def setdistenv(args):
MICRO_BATCH_SIZE = 32
strategy.pipeline_configs = {
"accumulate_steps": BATCH_SIZE // MICRO_BATCH_SIZE,
"micro_batch_size": MICRO_BATCH_SIZE,
"micro_batch_size": MICRO_BATCH_SIZE
}
strategy.find_unused_parameters = True

Expand All @@ -290,8 +286,7 @@ def setdistenv(args):
args.dp_rank = hcg.get_data_parallel_rank()
args.sharding_rank = hcg.get_sharding_parallel_rank()

args.data_world_rank = (
args.dp_rank * args.sharding_parallel_degree + args.sharding_rank)
args.data_world_rank = args.dp_rank * args.sharding_parallel_degree + args.sharding_rank
args.data_world_size = dist.get_world_size() // abs(
args.tensor_parallel_degree * args.pipeline_parallel_degree)

Expand All @@ -301,12 +296,12 @@ def setdistenv(args):

def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank=0):
device_id = paddle.device.get_device()
assert "gpu" in device_id
assert 'gpu' in device_id

random.seed(basic_seed + data_world_rank)
np.random.seed(basic_seed + data_world_rank)
paddle.seed(basic_seed + data_world_rank)
# TODO add manual_seed
#TODO add manual_seed
# local_seed/ global_seed is used to control dropout in ModelParallel
local_seed = 1024 + basic_seed + mp_rank * 100 + data_world_rank
global_seed = 2048 + basic_seed + data_world_rank
Expand Down
Loading

0 comments on commit c71874a

Please sign in to comment.