Skip to content

Commit

Permalink
Add num_recycles argument to run functions and update command generation
Browse files Browse the repository at this point in the history
  • Loading branch information
hllelli2 committed Jan 30, 2025
1 parent 83f50a8 commit 416d494
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
3 changes: 3 additions & 0 deletions abcfold/abcfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run(args, config, defaults, config_file):
model_params=args.model_params,
database_dir=args.database_dir,
number_of_models=args.number_of_models,
num_recycles=args.num_recycles,
)

# Need to find the name of the af3_dir
Expand All @@ -137,6 +138,7 @@ def run(args, config, defaults, config_file):
output_dir=args.output_dir,
save_input=args.save_input,
number_of_models=args.number_of_models,
num_recycles=args.num_recycles,
)
bolt_out_dir = list(args.output_dir.glob("boltz_results*"))[0]
bo = BoltzOutput(bolt_out_dir, input_params, name)
Expand All @@ -152,6 +154,7 @@ def run(args, config, defaults, config_file):
output_dir=chai_output_dir,
save_input=args.save_input,
number_of_models=args.number_of_models,
num_recycles=args.num_recycles,
)

co = ChaiOutput(chai_output_dir, input_params, name)
Expand Down
27 changes: 15 additions & 12 deletions abcfold/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def prediction_argparse_util(parser):
default=5,
help="Number of models to generate",
)
parser.add_argument(
"--num_recycles",
type=int,
default=10,
help="Number of recycles to use during the inference",
)
return parser


Expand All @@ -60,6 +66,7 @@ def boltz_argparse_util(parser):
help="Save the input json file",
default=False,
)

return parser


Expand All @@ -70,16 +77,6 @@ def chai_argparse_util(parser):
action="store_true",
help="Run Chai-1",
)
# check if save input is in the parser
if "--save_input" not in parser._option_string_actions:

parser.add_argument(
"--save_input",
action="store_true",
help="Save the input json file",
default=False,
)
# add more arguments here
return parser


Expand All @@ -104,14 +101,20 @@ def alphafold_argparse_util(parser):
"-a",
"--alphafold3",
action="store_true",
help="Run Alphafold",
help="Run Alphafold3",
)

parser.add_argument(
"--override",
help="Override the existing output directory, if it exists",
action="store_true",
)
# add more arguments here
if "--num_recycles" not in parser._option_string:
parser.add_argument(
"--num_recycles",
type=int,
default=10,
help="Number of recycles to use during the inference",
)

return parser
6 changes: 5 additions & 1 deletion abcfold/run_alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def run_alphafold3(
database_dir: Union[str, Path],
interactive: bool = True,
number_of_models: int = 5,
num_recycles: int = 10,
) -> None:
"""
Run Alphafold3 using the input JSON file
Expand Down Expand Up @@ -73,6 +74,7 @@ def run_alphafold3(
database_dir=database_dir,
interactive=interactive,
number_of_models=number_of_models,
num_recycles=num_recycles,
)

logger.info("Running Alphafold3")
Expand All @@ -92,7 +94,8 @@ def generate_af3_cmd(
output_dir: Union[str, Path],
model_params: Union[str, Path],
database_dir: Union[str, Path],
number_of_models: int = 5,
number_of_models: int = 10,
num_recycles: int = 5,
interactive: bool = True,
) -> str:
"""
Expand Down Expand Up @@ -124,6 +127,7 @@ def generate_af3_cmd(
--model_dir=/root/models \
--output_dir=/root/af_output \
--num_diffusion_samples {number_of_models}
--num_recycles {num_recycles}
"""


Expand Down
6 changes: 5 additions & 1 deletion abcfold/run_boltz.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_boltz(
save_input: bool = False,
test: bool = False,
number_of_models: int = 5,
num_recycles: int = 10,
):
"""
Run Boltz1 using the input JSON file
Expand Down Expand Up @@ -56,7 +57,7 @@ def run_boltz(
boltz_yaml.write_yaml(out_file)
logger.info("Running Boltz1")
cmd = (
generate_boltz_command(out_file, output_dir, number_of_models)
generate_boltz_command(out_file, output_dir, number_of_models, num_recycles)
if not test
else generate_boltz_test_command()
)
Expand All @@ -80,6 +81,7 @@ def generate_boltz_command(
input_yaml: Union[str, Path],
output_dir: Union[str, Path],
number_of_models: int = 5,
num_recycles: int = 10,
) -> list:
"""
Generate the Boltz1 command
Expand All @@ -103,6 +105,8 @@ def generate_boltz_command(
"--write_full_pde",
"--diffusion_samples",
str(number_of_models),
"--recycling_steps",
str(num_recycles),
]


Expand Down
3 changes: 3 additions & 0 deletions abcfold/run_chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_chai(
save_input: bool = False,
test: bool = False,
number_of_models: int = 5,
num_recycles: int = 10,
) -> None:
"""
Run Chai-1 using the input JSON file
Expand Down Expand Up @@ -80,6 +81,7 @@ def generate_chai_command(
input_constraints: Union[str, Path],
output_dir: Union[str, Path],
number_of_models: int = 5,
num_recycles: int = 10,
) -> list:
"""
Generate the Chai-1 command
Expand All @@ -105,6 +107,7 @@ def generate_chai_command(
cmd += ["--constraint-path", str(input_constraints)]

cmd += ["--num-diffn-samples", str(number_of_models)]
cmd += ["--num-trunk-recycles", str(num_recycles)]
cmd += [str(output_dir)]

return cmd
Expand Down

0 comments on commit 416d494

Please sign in to comment.