From 7bae1fbec566c9fc0916ed212b390a4645ea9a65 Mon Sep 17 00:00:00 2001
From: Alphonsce <varlamov.al@phystech.edu>
Date: Mon, 22 Apr 2024 14:49:11 +0300
Subject: [PATCH] arguments fixes

---
 src/metr/finetune_ldm_decoder.py    |  6 +++---
 src/metr/metr_pp_eval_stable_sig.py |  2 +-
 src/metr/run_metr_fid.py            | 13 +++++++------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/metr/finetune_ldm_decoder.py b/src/metr/finetune_ldm_decoder.py
index b310baf..4423e48 100644
--- a/src/metr/finetune_ldm_decoder.py
+++ b/src/metr/finetune_ldm_decoder.py
@@ -58,9 +58,9 @@ def aa(*args, **kwargs):
     aa("--val_dir", type=str, help="Path to the validation data directory", required=True)
 
     group = parser.add_argument_group('Model parameters')
-    aa("--ldm_config", type=str, default="sd/stable-diffusion-v-1-4-original/v1-inference.yaml", help="Path to the configuration file for the LDM model") 
-    aa("--ldm_ckpt", type=str, default="sd/stable-diffusion-v-1-4-original/sd-v1-4-full-ema.ckpt", help="Path to the checkpoint file for the LDM model") 
-    aa("--msg_decoder_path", type=str, default= "models/hidden/dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model")
+    aa("--ldm_config", type=str, default="v2-inference.yaml", help="Path to the configuration file for the LDM model") 
+    aa("--ldm_ckpt", type=str, default="v2-1_512-ema-pruned.ckpt", help="Path to the checkpoint file for the LDM model") 
+    aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt", help="Path to the hidden decoder for the watermarking model")
     aa("--num_bits", type=int, default=48, help="Number of bits in the watermark")
     aa("--redundancy", type=int, default=1, help="Number of times the watermark is repeated to increase robustness")
     aa("--decoder_depth", type=int, default=8, help="Depth of the decoder in the watermarking model")
diff --git a/src/metr/metr_pp_eval_stable_sig.py b/src/metr/metr_pp_eval_stable_sig.py
index 37c7c80..af8534a 100644
--- a/src/metr/metr_pp_eval_stable_sig.py
+++ b/src/metr/metr_pp_eval_stable_sig.py
@@ -338,7 +338,7 @@ def aa(*args, **kwargs):
     aa("--eval_bits", type=utils.bool_inst, default=True, help="")
     aa("--decode_only", type=utils.bool_inst, default=False, help="")
     aa("--key_str", type=str, default="111010110101000001010111010011010100010000100111")
-    aa("--msg_decoder_path", type=str, default= "models/dec_48b_whit.torchscript.pt")
+    aa("--msg_decoder_path", type=str, default= "dec_48b_whit.torchscript.pt")
     aa("--attack_mode", type=str, default= "all")
     aa("--num_bits", type=int, default=48)
     aa("--redundancy", type=int, default=1)
diff --git a/src/metr/run_metr_fid.py b/src/metr/run_metr_fid.py
index 456d3d7..f56a0cb 100644
--- a/src/metr/run_metr_fid.py
+++ b/src/metr/run_metr_fid.py
@@ -74,7 +74,7 @@ def main(args):
     pipe = pipe.to(device)
 
     if args.use_stable_sig:
-        pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path)
+        pipe = change_pipe_vae_decoder(pipe, weights_path=args.decoder_state_dict_path, args=args)
         print("VAE CHANGED!")
 
     # hard coding for now
@@ -316,8 +316,8 @@ def main(args):
 
     # watermark
     parser.add_argument('--w_seed', default=999999, type=int)
-    parser.add_argument('--w_channel', default=0, type=int)
-    parser.add_argument('--w_pattern', default='rand')
+    parser.add_argument('--w_channel', default=3, type=int)
+    parser.add_argument('--w_pattern', default='ring')
     parser.add_argument('--w_mask_shape', default='circle')
     parser.add_argument('--w_radius', default=10, type=int)
     parser.add_argument('--w_measurement', default='l1_complex')
@@ -336,12 +336,13 @@ def main(args):
     parser.add_argument('--msg_type', default='rand', help="Can be: rand or binary or decimal")
     parser.add_argument('--msg', default='1110101101')
     parser.add_argument('--use_random_msgs', action='store_true', help="Generate random message each step of cycle")
-    parser.add_argument('--msgs_file', default=None, help="Path to file, whicha")
     parser.add_argument('--msg_scaler', default=100, type=int, help="Scaling coefficient of message")
 
-    # Stable-Signature arguments:
+    # METR++:
     parser.add_argument('--use_stable_sig', action='store_true')
-    parser.add_argument('--decoder_state_dict_path', default='sd2_decoder.pth')
+    parser.add_argument('--decoder_state_dict_path', default='finetune_ldm_decoder/ldm_decoder_checkpoint_000.pth')
+    parser.add_argument('--stable_sig_full_model_config', default="v2-inference.yaml")
+    parser.add_argument('--stable_sig_full_model_ckpt', default='v2-1_512-ema-pruned.ckpt')
 
     # for image distortion
     parser.add_argument('--r_degree', default=None, type=float)