Skip to content

Commit

Permalink
arguments fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alphonsce committed Apr 22, 2024
1 parent 7a93b90 commit 7bae1fb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/metr/finetune_ldm_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def aa(*args, **kwargs):
aa("--val_dir", type=str, help="Path to the validation data directory", required=True)

group = parser.add_argument_group('Model parameters')
aa("--ldm_config", type=str, default="sd/stable-diffusion-v-1-4-original/v1-inference.yaml", help="Path to the configuration file for the LDM model")
aa("--ldm_ckpt", type=str, default="sd/stable-diffusion-v-1-4-original/sd-v1-4-full-ema.ckpt", help="Path to the checkpoint file for the LDM model")
aa("--msg_decoder_path", type=str, default= "models/hidden/dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model")
aa("--ldm_config", type=str, default="v2-inference.yaml", help="Path to the configuration file for the LDM model")
aa("--ldm_ckpt", type=str, default="v2-1_512-ema-pruned.ckpt", help="Path to the checkpoint file for the LDM model")
aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model")
aa("--num_bits", type=int, default=48, help="Number of bits in the watermark")
aa("--redundancy", type=int, default=1, help="Number of times the watermark is repeated to increase robustness")
aa("--decoder_depth", type=int, default=8, help="Depth of the decoder in the watermarking model")
Expand Down
2 changes: 1 addition & 1 deletion src/metr/metr_pp_eval_stable_sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def aa(*args, **kwargs):
aa("--eval_bits", type=utils.bool_inst, default=True, help="")
aa("--decode_only", type=utils.bool_inst, default=False, help="")
aa("--key_str", type=str, default="111010110101000001010111010011010100010000100111")
aa("--msg_decoder_path", type=str, default= "models/dec_48b_whit.torchscript.pt")
aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt")
aa("--attack_mode", type=str, default= "all")
aa("--num_bits", type=int, default=48)
aa("--redundancy", type=int, default=1)
Expand Down
13 changes: 7 additions & 6 deletions src/metr/run_metr_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(args):
pipe = pipe.to(device)

if args.use_stable_sig:
pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path)
pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path, args=args)
print("VAE CHANGED!")

# hard coding for now
Expand Down Expand Up @@ -316,8 +316,8 @@ def main(args):

# watermark
parser.add_argument('--w_seed', default=999999, type=int)
parser.add_argument('--w_channel', default=0, type=int)
parser.add_argument('--w_pattern', default='rand')
parser.add_argument('--w_channel', default=3, type=int)
parser.add_argument('--w_pattern', default='ring')
parser.add_argument('--w_mask_shape', default='circle')
parser.add_argument('--w_radius', default=10, type=int)
parser.add_argument('--w_measurement', default='l1_complex')
Expand All @@ -336,12 +336,13 @@ def main(args):
parser.add_argument('--msg_type', default='rand', help="Can be: rand or binary or decimal")
parser.add_argument('--msg', default='1110101101')
parser.add_argument('--use_random_msgs', action='store_true', help="Generate random message each step of cycle")
parser.add_argument('--msgs_file', default=None, help="Path to file, whicha")
parser.add_argument('--msg_scaler', default=100, type=int, help="Scaling coefficient of message")

# Stable-Signature arguments:
# METR++:
parser.add_argument('--use_stable_sig', action='store_true')
parser.add_argument('--decoder_state_dict_path', default='sd2_decoder.pth')
parser.add_argument('--decoder_state_dict_path', default='finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth')
parser.add_argument('--stable_sig_full_model_config', default="v2-inference.yaml")
parser.add_argument('--stable_sig_full_model_ckpt', default='v2-1_512-ema-pruned.ckpt')

# for image distortion
parser.add_argument('--r_degree', default=None, type=float)
Expand Down

0 comments on commit 7bae1fb

Please sign in to comment.