From 7bae1fbec566c9fc0916ed212b390a4645ea9a65 Mon Sep 17 00:00:00 2001 From: Alphonsce <varlamov.al@phystech.edu> Date: Mon, 22 Apr 2024 14:49:11 +0300 Subject: [PATCH] arguments fixes --- src/metr/finetune_ldm_decoder.py | 6 +++--- src/metr/metr_pp_eval_stable_sig.py | 2 +- src/metr/run_metr_fid.py | 13 +++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/metr/finetune_ldm_decoder.py b/src/metr/finetune_ldm_decoder.py index b310baf..4423e48 100644 --- a/src/metr/finetune_ldm_decoder.py +++ b/src/metr/finetune_ldm_decoder.py @@ -58,9 +58,9 @@ def aa(*args, **kwargs): aa("--val_dir", type=str, help="Path to the validation data directory", required=True) group = parser.add_argument_group('Model parameters') - aa("--ldm_config", type=str, default="sd/stable-diffusion-v-1-4-original/v1-inference.yaml", help="Path to the configuration file for the LDM model") - aa("--ldm_ckpt", type=str, default="sd/stable-diffusion-v-1-4-original/sd-v1-4-full-ema.ckpt", help="Path to the checkpoint file for the LDM model") - aa("--msg_decoder_path", type=str, default= "models/hidden/dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model") + aa("--ldm_config", type=str, default="v2-inference.yaml", help="Path to the configuration file for the LDM model") + aa("--ldm_ckpt", type=str, default="v2-1_512-ema-pruned.ckpt", help="Path to the checkpoint file for the LDM model") + aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model") aa("--num_bits", type=int, default=48, help="Number of bits in the watermark") aa("--redundancy", type=int, default=1, help="Number of times the watermark is repeated to increase robustness") aa("--decoder_depth", type=int, default=8, help="Depth of the decoder in the watermarking model") diff --git a/src/metr/metr_pp_eval_stable_sig.py b/src/metr/metr_pp_eval_stable_sig.py index 37c7c80..af8534a 100644 --- a/src/metr/metr_pp_eval_stable_sig.py +++ b/src/metr/metr_pp_eval_stable_sig.py @@ -338,7 +338,7 @@ def aa(*args, **kwargs): aa("--eval_bits", type=utils.bool_inst, default=True, help="") aa("--decode_only", type=utils.bool_inst, default=False, help="") aa("--key_str", type=str, default="111010110101000001010111010011010100010000100111") - aa("--msg_decoder_path", type=str, default= "models/dec_48b_whit.torchscript.pt") + aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt") aa("--attack_mode", type=str, default= "all") aa("--num_bits", type=int, default=48) aa("--redundancy", type=int, default=1) diff --git a/src/metr/run_metr_fid.py b/src/metr/run_metr_fid.py index 456d3d7..f56a0cb 100644 --- a/src/metr/run_metr_fid.py +++ b/src/metr/run_metr_fid.py @@ -74,7 +74,7 @@ def main(args): pipe = pipe.to(device) if args.use_stable_sig: - pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path) + pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path, args=args) print("VAE CHANGED!") # hard coding for now @@ -316,8 +316,8 @@ def main(args): # watermark parser.add_argument('--w_seed', default=999999, type=int) - parser.add_argument('--w_channel', default=0, type=int) - parser.add_argument('--w_pattern', default='rand') + parser.add_argument('--w_channel', default=3, type=int) + parser.add_argument('--w_pattern', default='ring') parser.add_argument('--w_mask_shape', default='circle') parser.add_argument('--w_radius', default=10, type=int) parser.add_argument('--w_measurement', default='l1_complex') @@ -336,12 +336,13 @@ def main(args): parser.add_argument('--msg_type', default='rand', help="Can be: rand or binary or decimal") parser.add_argument('--msg', default='1110101101') parser.add_argument('--use_random_msgs', action='store_true', help="Generate random message each step of cycle") - parser.add_argument('--msgs_file', default=None, help="Path to file, whicha") parser.add_argument('--msg_scaler', default=100, type=int, help="Scaling coefficient of message") - # Stable-Signature arguments: + # METR++: parser.add_argument('--use_stable_sig', action='store_true') - parser.add_argument('--decoder_state_dict_path', default='sd2_decoder.pth') + parser.add_argument('--decoder_state_dict_path', default='finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth') + parser.add_argument('--stable_sig_full_model_config', default="v2-inference.yaml") + parser.add_argument('--stable_sig_full_model_ckpt', default='v2-1_512-ema-pruned.ckpt') # for image distortion parser.add_argument('--r_degree', default=None, type=float)