diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 9bf09236..16af56fd 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,15 +1,15 @@ name: publish on: - workflow_call: workflow_dispatch: - + push: + branches: [ main ] jobs: build: + if: contains(github.event.head_commit.message, 'release-please--branches--main') name: Build distribution 📦 runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f189f29c..fc4be7c4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,10 +29,4 @@ jobs: git push origin :stable || true git tag -a stable -m "Last Stable Release" git push origin stable - publish: - needs: release - if: ${{ needs.release.outputs.created }} - permissions: - id-token: write - uses: ./.github/workflows/publish.yml diff --git a/.gitignore b/.gitignore index d62c7696..34b0e478 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,4 @@ package-lock.json /.quarto/ /.luarc.json +_modules.sh diff --git a/README.md b/README.md index 8d99e1af..d242befa 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,7 @@ pip install kimmdy[reactions,analysis] However, this is only half the fun! KIMMDY has two exciting plugins in the making, which properly parameterize your molecules -for radicals using GrAPPa (Graph Attentional Protein Parametrization) and predict -Hydrogen Atom Transfer (HAT) rates. +for radicals using GrAPPa (Graph Attentional Protein Parametrization) and predict Hydrogen Atom Transfer (HAT) rates. Full installation instructions are available [here](https://graeter-group.github.io/kimmdy/guide/how-to/install-ml-plugins.html) diff --git a/example/alanine_hat_naive/kimmdy.yml b/example/alanine_hat_naive/kimmdy.yml index f98e4b1c..bfe77ef0 100644 --- a/example/alanine_hat_naive/kimmdy.yml +++ b/example/alanine_hat_naive/kimmdy.yml @@ -1,5 +1,6 @@ dryrun: false name: 'alanine_hat_000' +max_hours: 23 max_tasks: 100 gromacs_alias: 'gmx' gmx_mdrun_flags: -maxh 24 -dlb yes @@ -7,7 +8,7 @@ ff: 'amber99sb-star-ildnp.ff' # optional, dir endinng with .ff by default top: 'Ala_out.top' gro: 'npt.gro' ndx: 'index.ndx' -kmc: "rfkmc" +kmc: rfkmc mds: equilibrium: mdp: 'md.mdp' diff --git a/example/alanine_hat_naive/kimmdy_restart.yml b/example/alanine_hat_naive/kimmdy_restart.yml index e84d0a9b..543f45a0 100644 --- a/example/alanine_hat_naive/kimmdy_restart.yml +++ b/example/alanine_hat_naive/kimmdy_restart.yml @@ -8,8 +8,7 @@ top: 'Ala_out.top' gro: 'npt.gro' ndx: 'index.ndx' kmc: "rfkmc" -restart: - run_directory: 'alanine_hat_000' +restart: true mds: equilibrium: mdp: 'md.mdp' diff --git a/example/alanine_hat_naive/kimmdy_restart_task.yml b/example/alanine_hat_naive/kimmdy_restart_task.yml index 25a03663..a72dee39 100644 --- a/example/alanine_hat_naive/kimmdy_restart_task.yml +++ b/example/alanine_hat_naive/kimmdy_restart_task.yml @@ -8,8 +8,7 @@ top: 'Ala_out.top' gro: 'npt.gro' ndx: 'index.ndx' kmc: "rfkmc" -restart: - run_directory: 'alanine_hat_000' +restart: true mds: equilibrium: mdp: 'md.mdp' diff --git a/example/charged_peptide_homolysis_hat_naive/kimmdy.yml b/example/charged_peptide_homolysis_hat_naive/kimmdy.yml index fd8af9bc..2a274ad3 100644 --- a/example/charged_peptide_homolysis_hat_naive/kimmdy.yml +++ b/example/charged_peptide_homolysis_hat_naive/kimmdy.yml @@ -30,7 +30,6 @@ reactions: frequency_factor: 100000000 h_cutoff: 3 polling_rate: 1 -plot_rates: true save_recipes: true 
sequence: - equilibrium diff --git a/requirements.txt b/requirements.txt index e2da78fa..e4410f30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,6 @@ # git+ssh://git@github.com/graeter-group/grappa.git ## replace with path to your local copy of the plugins ## for plugin development and kimmdy testing +grappa-ff --index-url https://download.pytorch.org/whl/cpu -e ../kimmdy-grappa -e ../kimmdy-reactions diff --git a/src/kimmdy/analysis.py b/src/kimmdy/analysis.py index a2da7170..f211cdaa 100644 --- a/src/kimmdy/analysis.py +++ b/src/kimmdy/analysis.py @@ -11,7 +11,7 @@ from typing import Optional, Union from collections import defaultdict -import matplotlib as mpl +import matplotlib.axis import matplotlib.pyplot as plt import MDAnalysis as mda import numpy as np @@ -20,9 +20,11 @@ import seaborn.objects as so from seaborn import axes_style +from kimmdy.config import Config from kimmdy.parsing import read_json, write_json, read_time_marker +from kimmdy.plugins import discover_plugins from kimmdy.recipe import Bind, Break, DeferredRecipeSteps, Place, RecipeCollection -from kimmdy.utils import run_shell_cmd, get_task_directories +from kimmdy.utils import read_reaction_time_marker, run_shell_cmd, get_task_directories from kimmdy.constants import MARK_DONE, MARK_STARTED @@ -68,7 +70,6 @@ def concat_traj( """ run_dir = Path(dir).expanduser().resolve() analysis_dir = get_analysis_dir(run_dir) - directories = get_task_directories(run_dir, steps) if not directories: raise ValueError( @@ -90,33 +91,45 @@ def concat_traj( output = output_group ## gather trajectories - trajectories = [] + trajectories: list[Path] = [] tprs = [] gros = [] for d in directories: - trajectories.extend(d.glob(f"*.{filetype}")) + trjs = list(d.glob(f"*.{filetype}")) + trajectories.extend([t for t in trjs if not ".kimmdytrunc." in t.name]) tprs.extend(d.glob("*.tpr")) gros.extend(d.glob("*.gro")) assert ( len(trajectories) > 0 - ), f"No trrs found to concatenate in {run_dir} with subdirectory names {steps}" + ), f"No trajectories found to concatenate in {run_dir} with subdirectory names {steps}" + + for i, trj in enumerate(trajectories): + task_dir = trj.parent + time = read_reaction_time_marker(task_dir) + if time is not None: + new_trj = trj.with_suffix(".kimmdytrunc.xtc") + run_shell_cmd( + f"echo '0' | gmx trjconv -f {trj} -s {tprs[i]} -e {time} -o {new_trj}", + cwd=run_dir, + ) + trajectories[i] = new_trj - trajectories = [str(t) for t in trajectories] - print(trajectories) + flat_trajectories: list[str] = [str(trj) for trj in trajectories] ## write concatenated trajectory tmp_xtc = str(out_xtc.with_name("tmp.xtc")) run_shell_cmd( - f"gmx trjcat -f {' '.join(trajectories)} -o {tmp_xtc} -cat", + rf"gmx trjcat -f {' '.join(flat_trajectories)} -o {tmp_xtc} -cat", cwd=run_dir, ) run_shell_cmd( - f"echo 'Protein\n{output}' | gmx trjconv -dt 0 -f {tmp_xtc} -s {tprs[0]} -o {str(out_xtc)} -center -pbc mol", + s=rf"echo -e 'Protein\n{output}' | gmx trjconv -dt 0 -f {tmp_xtc} -s {tprs[0]} -o {str(out_xtc)} -center -pbc mol", cwd=run_dir, ) + assert out_xtc.exists(), f"Concatenated trajectory {out_xtc} not found." 
run_shell_cmd( - f"echo 'Protein\n{output}' | gmx trjconv -dump 0 -f {tmp_xtc} -s {tprs[0]} -o {str(out_gro)} -center -pbc mol", + rf"echo -e 'Protein\n{output}' | gmx trjconv -dt 0 -dump 0 -f {tmp_xtc} -s {tprs[0]} -o {str(out_gro)} -center -pbc mol", cwd=run_dir, ) run_shell_cmd(f"rm {tmp_xtc}", cwd=run_dir) @@ -126,7 +139,11 @@ def concat_traj( def plot_energy( - dir: str, steps: Union[list[str], str], terms: list[str], open_plot: bool = False + dir: str, + steps: Union[list[str], str], + terms: list[str], + open_plot: bool = False, + truncate: bool = True, ): """Plot GROMACS energy for a KIMMDY run. @@ -139,8 +156,10 @@ def plot_energy( Default is "all". terms Terms from gmx energy that will be plotted. Uses 'Potential' by default. - open_plot : + open_plot Open plot in default system viewer. + truncate + Truncate energy files to the reaction time marker. """ run_dir = Path(dir).expanduser().resolve() xvg_entries = ["time"] + terms @@ -153,9 +172,10 @@ def plot_energy( xvgs_dir.mkdir(exist_ok=True) ## gather energy files - edrs = [] + edrs: list[Path] = [] for d in subdirs_matched: - edrs.extend(d.glob("*.edr")) + new_edrs = d.glob("*.edr") + edrs.extend([edr for edr in new_edrs if not ".kimmdytrunc." in edr.name]) assert ( len(edrs) > 0 ), f"No GROMACS energy files in {run_dir} with subdirectory names {steps}" @@ -165,11 +185,19 @@ def plot_energy( time_offset = 0 for i, edr in enumerate(edrs): ## write energy .xvg file + task_dir = edr.parent xvg = str(xvgs_dir / edr.parents[0].with_suffix(".xvg").name) step_name = edr.parents[0].name.split("_")[1] + time = read_reaction_time_marker(task_dir) + if time is not None and truncate: + print(f"Truncating {edr} to {time} ps.") + new_edr = edr.with_suffix(".kimmdytrunc.edr") + run_shell_cmd(f"gmx eneconv -f {edr} -e {time} -o {new_edr}") + edr = new_edr + run_shell_cmd( - f"echo '{terms_str} \n\n' | gmx energy -f {str(edr)} -o {xvg}", + f"echo '{terms_str} \n\n' | gmx energy -f {edr} -o {xvg}", cwd=run_dir, ) @@ -217,7 +245,9 @@ def plot_energy( plt.text(x=t, y=v + 0.5, s=s, fontsize=6) ax = plt.gca() - steps_y_axis = [c for c in ax.get_children() if isinstance(c, mpl.axis.YAxis)][0] + steps_y_axis = [ + c for c in ax.get_children() if isinstance(c, matplotlib.axis.YAxis) + ][0] steps_y_axis.set_visible(False) output_path = str(run_dir / "analysis" / "energy.png") plt.savefig(output_path, dpi=300) @@ -423,7 +453,6 @@ def radical_migration( out_path = analysis_dir / "radical_migration.json" with open(out_path, "w") as json_file: json.dump(unique_migrations, json_file) - print("Done!") def plot_rates(dir: str, open: bool = False): @@ -694,7 +723,7 @@ def get_analysis_cmdline_args() -> argparse.Namespace: name="trjcat", help="Concatenate trajectories of a KIMMDY run" ) parser_trjcat.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_trjcat.add_argument("--filetype", "-f", default="xtc") parser_trjcat.add_argument( @@ -708,11 +737,13 @@ def get_analysis_cmdline_args() -> argparse.Namespace: ) parser_trjcat.add_argument( "--open-vmd", + "-o", action="store_true", help="Open VMD with the concatenated trajectory.", ) parser_trjcat.add_argument( "--output-group", + "-g", type=str, help="Index group to include in the output. 
Default is 'Protein' for xtc and 'System' for trr.", ) @@ -721,7 +752,7 @@ def get_analysis_cmdline_args() -> argparse.Namespace: name="energy", help="Plot GROMACS energy for a KIMMDY run" ) parser_energy.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_energy.add_argument( "--steps", @@ -743,17 +774,21 @@ def get_analysis_cmdline_args() -> argparse.Namespace: ) parser_energy.add_argument( "--open-plot", - "-p", + "-o", + action="store_true", + help="Open plot in default system viewer.", + ) + parser_energy.add_argument( + "--no-truncate", action="store_true", help="Open plot in default system viewer.", ) - parser_radical_population = subparsers.add_parser( name="radical_population", help="Plot population of radicals for one or multiple KIMMDY run(s)", ) parser_radical_population.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_radical_population.add_argument( "--population_type", @@ -820,7 +855,7 @@ def get_analysis_cmdline_args() -> argparse.Namespace: help="Plot rates of all possible reactions after a MD run. Rates must have been saved!", ) parser_rates.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_rates.add_argument( "--open", @@ -834,7 +869,7 @@ def get_analysis_cmdline_args() -> argparse.Namespace: help="Plot runtime of the tasks of a kimmdy run.", ) parser_runtime.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_runtime.add_argument( "--open-plot", @@ -847,7 +882,7 @@ def get_analysis_cmdline_args() -> argparse.Namespace: help="Plot counts of reaction participation per atom id", ) parser_reaction_participation.add_argument( - "dir", type=str, help="KIMMDY run directory to be analysed." + "dir", type=str, help="KIMMDY run directory to be analysed.", nargs="?" ) parser_reaction_participation.add_argument( "--open-plot", @@ -855,19 +890,38 @@ def get_analysis_cmdline_args() -> argparse.Namespace: action="store_true", help="Open plot in default system viewer.", ) + + for subparser in subparsers.choices.values(): + subparser.add_argument( + "--input", + "-i", + type=str, + help="Kimmdy input file. Default `kimmdy.yml`. Only used to infer the output directory if `dir` is not provided.", + default="kimmdy.yml", + ) + return parser.parse_args() def entry_point_analysis(): """Analyse existing KIMMDY runs.""" args = get_analysis_cmdline_args() + if hasattr(args, "dir") and args.dir is None: + discover_plugins() + # the restart option is used here to avoid creating a new + # output directory and instead use the one from the config verbatim + # without incrementing a number + config = Config(input_file=args.input, restart=True) + args.dir = str(config.out) if args.module == "trjcat": concat_traj( args.dir, args.filetype, args.steps, args.open_vmd, args.output_group ) elif args.module == "energy": - plot_energy(args.dir, args.steps, args.terms, args.open_plot) + plot_energy( + args.dir, args.steps, args.terms, args.open_plot, not args.no_truncate + ) elif args.module == "radical_population": radical_population( args.dir, @@ -890,5 +944,5 @@ def entry_point_analysis(): ) else: print( - "No analysis module specified. 
Use -h for help and a list of available modules." + "No analysis module specified. Use -h or --help for a list of available modules." ) diff --git a/src/kimmdy/assets/templates.py b/src/kimmdy/assets/templates.py index 5cffd2f3..33b0903e 100644 --- a/src/kimmdy/assets/templates.py +++ b/src/kimmdy/assets/templates.py @@ -4,20 +4,23 @@ #SBATCH --output={config.out.name}-job.log #SBATCH --error={config.out.name}-job.log #SBATCH --time={config.max_hours}:00:00 -#SBATCH -N={config.slurm.N} +#SBATCH --nodes={config.slurm.N} #SBATCH --ntasks-per-node={config.slurm.ntasks_per_node} #SBATCH --mincpus={config.slurm.mincpus} #SBATCH --exclusive #SBATCH --cpus-per-task={config.slurm.cpus_per_task} #SBATCH --gpus={config.slurm.gpus} #SBATCH --mail-type=ALL -# #SBATCH -p .p +# # uncomment these to use: +# #SBATCH --partition .p # #SBATCH --mail-user= # Setup up your environment here # modules.sh might load lmod modules, set environment variables, etc. -source ./_modules.sh +if [ -f ./_modules.sh ]; then + source ./_modules.sh +fi CYCLE={config.max_hours} CYCLE_buffered=$(echo "scale=2; $CYCLE - 0.08" | bc) @@ -25,7 +28,7 @@ START=$(date +"%s") -timeout ${{CYCLE_buffered}}h kimmdy -i {config.input_file} +timeout ${{CYCLE_buffered}}h kimmdy -i {config.input_file} --restart END=$(date +"%s") @@ -40,9 +43,7 @@ exit 3 else echo "jobscript resubmitting" - sed -i.bak "s/\\(run_directory:\\s*\\).*/\\1'{config.out.name}'/" {config.input_file} - kimmdy --generate-jobscript - sbatch ./jobscript.sh + {config.slurm.runcmd} ./jobscript.sh exit 2 fi """ diff --git a/src/kimmdy/cmd.py b/src/kimmdy/cmd.py index e77c60fe..8177789e 100644 --- a/src/kimmdy/cmd.py +++ b/src/kimmdy/cmd.py @@ -6,7 +6,6 @@ import argparse import logging import logging.config -import shutil import sys import textwrap from os import chmod @@ -40,26 +39,37 @@ def get_cmdline_args() -> argparse.Namespace: Parsed command line arguments """ parser = argparse.ArgumentParser( - description="Welcome to KIMMDY. `kimmdy` runs KIMMDY, further tools " + description="Welcome to KIMMDY. `kimmdy` runs KIMMDY. Additional tools " "are available as `kimmdy-...` commands. These are `-analysis`, " "`-modify-top` and `-build-examples`. Access their help with " "`kimmdy-... -h.`" + "Visit the documentation online at " ) parser.add_argument( "--input", "-i", type=str, - help="Kimmdy input file. Default `kimmdy.yml`", + help=( + "Kimmdy input file. Defaults to `kimmdy.yml`. See for all options. CLI flags (e.g. --restart or --loglevel) take precedence over their counterparts in the input file." + ), default="kimmdy.yml", ) + parser.add_argument( + "--restart", + "-r", + action="store_true", + help=( + "Restart or continue from a previous run instead of incrementing the run number for the output directory. If the output directory does not exist, it will be like a regular fresh run."
+ ), + ) parser.add_argument( "--loglevel", "-l", type=str, - help="logging level (CRITICAL, ERROR, WARNING, INFO, DEBUG)", + help="Logging level (CRITICAL, ERROR, WARNING, INFO, DEBUG)", default=None, ) - parser.add_argument("--logfile", "-f", type=Path, help="logfile", default=None) + parser.add_argument("--logfile", "-f", type=Path, help="Logfile", default=None) # flag to show available plugins parser.add_argument( @@ -70,26 +80,28 @@ def get_cmdline_args() -> argparse.Namespace: parser.add_argument( "--generate-jobscript", action="store_true", - help="Instead of running KIMMDY directly, generate at jobscript.sh for" + help="Instead of running KIMMDY directly, generate the output directory and a jobscript `jobscript.sh` for" " slurm HPC clusters." - "You can then run this jobscript with sbatch jobscript.sh", + " You can then run this jobscript with sbatch jobscript.sh.", ) parser.add_argument( - "--version", action="version", version=f'KIMMDY {version("kimmdy")}' + "--version", + action="version", + version=f'KIMMDY {version("kimmdy")}', + help=("Show version and exit."), ) # on error, drop into debugger parser.add_argument( - "--debug", action="store_true", help=("on error, drop into debugger") + "--debug", action="store_true", help=("On error, drop into debugger") ) # visualize call stack parser.add_argument( "--callgraph", action="store_true", - help="Generate visualization of function calls. Mostly useful for " - "debugging and documentation.", + help="Generate a visualization of function calls for debugging and documentation.", ) return parser.parse_args() @@ -135,7 +147,10 @@ def _run(args: argparse.Namespace): try: discover_plugins() config = Config( - input_file=args.input, logfile=args.logfile, loglevel=args.loglevel + input_file=args.input, + logfile=args.logfile, + loglevel=args.loglevel, + restart=args.restart, ) logger.info("Welcome to KIMMDY") @@ -154,6 +169,10 @@ def _run(args: argparse.Namespace): runmgr = RunManager(config) if args.generate_jobscript: + if config.max_hours == 0: + m = f"kimmdy.config.max_hours is set to 0, which would create a nonsensical jobscript." + logger.error(m) + raise ValueError(m) content = jobscript.format(config=config).strip("\n") path = "jobscript.sh" @@ -197,8 +216,6 @@ def _run(args: argparse.Namespace): raise e finally: logging.shutdown() - if args.generate_jobscript: - shutil.rmtree(config.out) def kimmdy_run( @@ -209,6 +226,7 @@ def kimmdy_run( generate_jobscript: bool = False, debug: bool = False, callgraph: bool = False, + restart: bool = False, ): """Run KIMMDY from python. @@ -225,7 +243,13 @@ def kimmdy_run( show_plugins Show available plugins and exit. generate_jobscript - Instead of running KIMMDY directly, generate at jobscript.sh for slurm HPC clusters + Instead of running KIMMDY directly, generate a jobscript.sh for slurm HPC clusters. + debug + On error, drop into debugger. + callgraph + Generate visualization of function calls. Mostly useful for debugging and documentation. + restart + Restart from a previous run instead of incrementing the run number for the output directory.
""" args = argparse.Namespace( input=input, @@ -235,6 +259,7 @@ def kimmdy_run( generate_jobscript=generate_jobscript, debug=debug, callgraph=callgraph, + restart=restart, ) _run(args) diff --git a/src/kimmdy/config.py b/src/kimmdy/config.py index 38bfdae0..3780f2dd 100644 --- a/src/kimmdy/config.py +++ b/src/kimmdy/config.py @@ -11,7 +11,7 @@ import shutil from pathlib import Path from pprint import pformat -from typing import Any, Optional +from typing import Any import yaml @@ -126,8 +126,9 @@ def __init__( opts: dict | None = None, scheme: dict | None = None, section: str = "config", - logfile: Optional[Path] = None, - loglevel: Optional[str] = None, + logfile: Path | None = None, + loglevel: str | None = None, + restart: bool = False, ): # initial scheme @@ -208,7 +209,9 @@ def __init__( "errors": [], "debugs": [], } - self._set_defaults(section, scheme) + + # defaults from the scheme + self._set_defaults(section=section, scheme=scheme, restart=restart) self._validate(section=section, cwd=self.cwd) # merge with command line arguments @@ -233,7 +236,9 @@ def __init__( # use the constructed config to set up the logger configure_logger(self) - def _set_defaults(self, section: str = "config", scheme: dict = {}): + def _set_defaults( + self, section: str = "config", scheme: dict = {}, restart: bool = False + ): """ Set defaults for attributes not set in yaml file but specified in scheme (generated from the schema). @@ -296,15 +301,20 @@ def _set_defaults(self, section: str = "config", scheme: dict = {}): # implicit defaults not in the schema # but defined in terms of other attributes + # or from the cli if section == "config": + # restart flag is true if set, otherwise false + if restart is True: + self.restart = restart + self.name = self.name.replace(" ", "_") if not hasattr(self, "cwd"): self.cwd = Path.cwd() if not hasattr(self, "out"): self.out = self.cwd / Path(self.name) - # make sure self.out is empty - while self.out.exists(): + # make sure self.out is empty unless restart is set + while self.out.exists() and not self.restart: self._logmessages["debugs"].append( f"Output dir {self.out} exists, incrementing name" ) @@ -317,8 +327,13 @@ def _set_defaults(self, section: str = "config", scheme: dict = {}): else: self.out = self.out.with_name(self.out.name + "_001") - self.out.mkdir() - self._logmessages["infos"].append(f"Created output dir {self.out}") + if not self.out.exists(): + self.out.mkdir() + self._logmessages["infos"].append(f"Created output dir {self.out}") + else: + self._logmessages["infos"].append( + f"Restarting in output dir {self.out}" + ) def _validate(self, section: str = "config", cwd: Path = Path(".")): """Validates config.""" diff --git a/src/kimmdy/constants.py b/src/kimmdy/constants.py index 29af47ba..ad04fd3f 100644 --- a/src/kimmdy/constants.py +++ b/src/kimmdy/constants.py @@ -4,8 +4,10 @@ MARK_STARTED = ".kimmdy_started" MARK_DONE = ".kimmdy_done" +MARK_FINISHED = ".kimmdy_finished" MARK_FAILED = ".kimmdy_failed" -MARKERS = [MARK_STARTED, MARK_DONE, MARK_FAILED] +MARK_REACION_TIME = ".kimmdy_reaction_time" +MARKERS = [MARK_STARTED, MARK_DONE, MARK_FAILED, MARK_FINISHED, MARK_REACION_TIME] ATOM_ID_FIELDS = { "atoms": [0, 5], # atomnr, chargegroup diff --git a/src/kimmdy/kimmdy-yaml-schema.json b/src/kimmdy/kimmdy-yaml-schema.json index 2550e481..842eaf51 100644 --- a/src/kimmdy/kimmdy-yaml-schema.json +++ b/src/kimmdy/kimmdy-yaml-schema.json @@ -43,6 +43,12 @@ "description": "n gpus", "pytype": "int", "default": 1 + }, + "runcmd": { + "type": "string", + 
"description": "Command to (re)submit the jobscript. Default is `sbatch`. For local testing replace with an empty string do run the jobscript directly.", + "pytype": "str", + "default": "sbatch" } }, "additionalProperties": false @@ -139,7 +145,8 @@ "frm", "extrande", "extrande_mod", - "multi_rfkmc" + "multi_rfkmc", + "dummy_first" ], "default": "" }, @@ -269,17 +276,9 @@ "restart": { "title": "restart", "type": "object", - "description": "Restart from a previous run.", - "properties": { - "run_directory": { - "type": "string", - "pytype": "Path", - "description": "KIMMDY run directory to restart from" - } - }, - "required": [ - "run_directory" - ] + "description": "Restart or continue from a previous run (in config.out) instead of starting from scratch", + "pytype": "bool", + "default": false }, "mds": { "title": "mds", diff --git a/src/kimmdy/kmc.py b/src/kimmdy/kmc.py index ba665e1d..a2630662 100644 --- a/src/kimmdy/kmc.py +++ b/src/kimmdy/kmc.py @@ -74,6 +74,27 @@ class KMCError(BaseException): KMCResult = KMCAccept | KMCReject | KMCError +def dummy_first_kmc( + recipe_collection: RecipeCollection, +) -> KMCResult: + """Dummy KMC method that always chooses the first reaction in the list.""" + if len(recipe_collection.recipes) == 0: + m = "Empty ReactionResult; no reaction chosen" + logger.warning(m) + return KMCError(m) + recipe = recipe_collection.recipes[0] + time_index = np.argmax(recipe.rates) + reaction_time = recipe.timespans[time_index][1] + logger.info(f"Chosen Recipe: {recipe} at time {reaction_time}") + return KMCAccept( + recipe=recipe, + reaction_probability=None, + time_delta=0, + time_start=reaction_time, + time_start_index=int(time_index), + ) + + def rf_kmc( recipe_collection: RecipeCollection, rng: np.random.Generator = default_rng(), @@ -491,7 +512,7 @@ def multi_rfkmc( logger.debug(f"Start multi-KMC, {len(recipes)} recipes. Picking {n} recipes.") results = [] - for i in range(n): + for _ in range(n): # check for empty ReactionResult if len(recipes) == 0 and len(results) == 0: m = "Empty ReactionResult; no reaction chosen" diff --git a/src/kimmdy/plugins.py b/src/kimmdy/plugins.py index 88a4b589..64edd810 100644 --- a/src/kimmdy/plugins.py +++ b/src/kimmdy/plugins.py @@ -25,6 +25,11 @@ def discover_plugins(): + """Discover and load KIMMDY plugins. + + This has to be called before initialzing the [](`~kimmdy.config.Config`) such that + in can be validated against the registered plugins. 
+ """ if sys.version_info > (3, 10): from importlib_metadata import entry_points diff --git a/src/kimmdy/recipe.py b/src/kimmdy/recipe.py index 0c3e33fd..5c4af97b 100644 --- a/src/kimmdy/recipe.py +++ b/src/kimmdy/recipe.py @@ -527,7 +527,6 @@ def get_vmd_selection(self) -> str: return "" ixs = set() for rs in self.recipe_steps: - print(self.recipe_steps) if isinstance(rs, BondOperation): ixs.add(rs.atom_ix_1) ixs.add(rs.atom_ix_2) diff --git a/src/kimmdy/runmanager.py b/src/kimmdy/runmanager.py index 5ee1fdbc..1296400f 100644 --- a/src/kimmdy/runmanager.py +++ b/src/kimmdy/runmanager.py @@ -22,13 +22,20 @@ from typing import Callable, Optional from kimmdy.config import Config -from kimmdy.constants import MARK_DONE, MARK_FAILED, MARK_STARTED, MARKERS +from kimmdy.constants import ( + MARK_DONE, + MARK_FAILED, + MARK_FINISHED, + MARK_STARTED, + MARKERS, +) from kimmdy.coordinates import break_bond_plumed, merge_top_slow_growth, place_atom from kimmdy.kmc import ( KMCError, KMCReject, KMCAccept, KMCResult, + dummy_first_kmc, extrande, extrande_mod, frm, @@ -59,7 +66,8 @@ flatten_recipe_collections, get_task_directories, run_gmx, - truncate_sim_files, + write_coordinate_files_at_reaction_time, + write_reaction_time_marker, ) logger = logging.getLogger(__name__) @@ -68,11 +76,12 @@ AMBIGUOUS_SUFFS = ["dat", "xvg", "log", "itp", "mdp"] # file strings which to ignore IGNORE_SUBSTR = [ - "_prev.cpt", - r"step\d+[bc]\.pdb", - r"\.tail", - r"_mod\.top", - r"\.1#", + "_prev.cpt$", + r"step\d+[bc]\.pdb$", + r"\.tail$", + r"_mod\.top$", + r"\.\d+#$", + r"\.log$", "rotref", ] + MARKERS # are there cases where we have multiple trr files? @@ -182,9 +191,14 @@ def __init__(self, config: Config): self.cptfile: Path = self.config.out / "kimmdy.cpt" self.kmc_algorithm: str - logger.info( - f"Initialized RunManager at cwd: {config.cwd} with output directory {config.out}" - ) + with open(self.histfile, "w") as f: + f.write("KIMMDY task file history\n") + f.write( + "Filepaths in the output directory are shortened to be relative to the output directory.\n\n" + ) + + logger.info(f"Initialized KIMMDY at cwd: {config.cwd}") + logger.info(f"with output directory {config.out}") try: if self.config.changer.topology.parameterization == "basic": self.parameterizer = BasicParameterizer() @@ -211,7 +225,7 @@ def __init__(self, config: Config): reactive_nrexcl=nrexcl, ) self.filehist: list[dict[str, TaskFiles]] = [ - {"setup": TaskFiles(self.get_latest)} + {"0_setup": TaskFiles(self.get_latest)} ] # Initialize reaction plugins used in the sequence @@ -228,12 +242,21 @@ def __init__(self, config: Config): "frm": frm, "extrande_mod": extrande_mod, "multi_rfkmc": multi_rfkmc, + "dummy_first": dummy_first_kmc, } self.task_mapping = { - "md": {"f": self._run_md, "kwargs": {}, "out": None}, + "md": { + "f": self._run_md, + "kwargs": {}, + "out": None, + }, # name of out director is inferred from the instance "reactions": [ - {"f": self._place_reaction_tasks, "kwargs": {}, "out": None}, + { + "f": self._place_reaction_tasks, + "kwargs": {}, + "out": None, + }, # has no output directory { "f": self._decide_recipe, "kwargs": {}, @@ -241,7 +264,6 @@ def __init__(self, config: Config): }, {"f": self._apply_recipe, "kwargs": {}, "out": "apply_recipe"}, ], - "restart": {"f": self._restart_task, "kwargs": {}, "out": None}, } """Mapping of task names to functions and their keyword arguments.""" @@ -251,9 +273,13 @@ def run(self): logger.info("Start run") self.start_time = time.time() - if restart_dir := getattr(self.config.restart, 
"run_directory", None): - logger.info(f"Restarting from: {restart_dir}") - self._restart_from_rundir() + if self.config.restart: + logger.info(f"Restarting from previous run in: {self.config.out.name}") + if (self.config.out / MARK_FINISHED).exists(): + m = f"Run in {self.config.out} already finished. Exiting." + logger.info(m) + return + self._setup_restart() while ( self.state is not State.DONE @@ -265,12 +291,237 @@ def run(self): ): next(self) + write_time_marker(self.config.out / MARK_FINISHED, "finished") logger.info( - f"Finished running tasks, state: {self.state} after " + f"Finished running last task, state: {self.state} after " f"{timedelta(seconds=(time.time() - self.start_time))} " f"In output directory {self.config.out}" ) + def _setup_restart(self): + """Set up RunManager to restart from an existing run directory""" + + task_dirs = get_task_directories(self.config.out) + if task_dirs == []: + # no tasks found in the output directory. this is a fresh run + return + + logger.info( + f"Found task directories in existing output directory ({self.config.out.name}): {[p.name for p in task_dirs]}" + ) + logger.info(f"Task queue: {self.tasks}") + + found_restart_point = False + restart_task_name = None + restart_from_incomplete = False + + # restarting during or after MD except for relaxation/slow_growth MDs are valid restart points + if hasattr(self.config.changer, "coordinates") and hasattr( + self.config.changer.coordinates, "md" + ): + relax_md_name = self.config.changer.coordinates.md + else: + relax_md_name = None + if hasattr(self.config, "mds"): + md_task_names = [ + name + for name in self.config.mds.get_attributes() + if name != relax_md_name + ] + else: + md_task_names = [] + + # keep track of how often we have completed each md instance + # such that we can later restart correctly at the latest instance + md_instance_dir_counter = {} + for name in md_task_names: + md_instance_dir_counter[name] = 0 + + # discover completed or half completed tasks + logger.info("Checking for restart point in existing task dirs.") + for task_dir in task_dirs: + task_n, task_name = task_dir.name.split(sep="_", maxsplit=1) + task_n = int(task_n) + logger.info(f"Checking task: {task_n}_{task_name}") + if (task_dir / MARK_FAILED).exists(): + m = ( + f"Task in directory `{task_dir.name}` is indicated to " + "have failed. Aborting restart. Remove this task " + "directory if you want to restart from before the failed task." + ) + + inp = input("Do you want to delete this task directory? [y/n]") + if inp == "y": + shutil.rmtree(task_dir) + else: + exit(1) + elif ( + (task_dir / MARK_STARTED).exists() + and not (task_dir / MARK_DONE).exists() + and task_name in md_task_names + ): + # Continue from started but not finished md task is a valid restart point + # if it got so far that is has written at least one checkpoint file + checkpoint_files = list(task_dir.glob("*.cpt")) + if len(checkpoint_files) == 0: + m = f"Last started but not done task is {task_dir.name}, but no checkpoint file found. Using an earlier MD or setup task as restart point instead." 
+ logger.warning(m) + continue + logger.info(f"Found started but not finished task {task_dir.name}.") + logger.info(f"Will continue task {task_dir.name}") + found_restart_point = True + self.iteration = task_n + restart_task_name = task_name + restart_from_incomplete = True + md_instance_dir_counter[task_name] += 1 + # no need to search further, as this task is unfinished + break + elif ( + (task_dir / MARK_STARTED).exists() + and (task_dir / MARK_DONE).exists() + and task_name in md_task_names + ): + # Continue after last finished MD task is a valid restart point + # but continue searching for newer tasks after this + logger.info(f"Found completed task {task_dir.name}") + logger.info(f"Will continue after task {task_dir.name}") + found_restart_point = True + self.iteration = task_n + restart_task_name = task_name + restart_from_incomplete = False + md_instance_dir_counter[task_name] += 1 + elif ( + (task_dir / MARK_STARTED).exists() + and (task_dir / MARK_DONE).exists() + and task_name == "setup" + ): + # Continuing just after 0_setup is a valid restart point + found_restart_point = True + self.iteration = task_n + restart_task_name = task_name + restart_from_incomplete = False + elif (task_dir / MARK_STARTED).exists() and (task_dir / MARK_DONE).exists(): + # Completed task, but not an MD task + pass + elif (task_dir / MARK_STARTED).exists() and not ( + task_dir / MARK_DONE + ).exists(): + # Started but not done task, not an MD task + m = f"Last started but not done task is {task_dir.name}, which can not be restarted from. Restarting instead from after the last completed MD task." + logger.info(m) + break + else: + m = f"Encountered task directory {task_dir.name} but kimmdy does not know how to handle this task. Aborting restart." + logger.error(m) + raise RuntimeError(m) + + if not found_restart_point or not restart_task_name: + m = "No valid restart point found in existing task directories." + logger.error(m) + raise RuntimeError(m) + + m = f"Restarting from iteration (task number) {self.iteration} with name {restart_task_name}" + logger.info(m) + + # pop from the task queue until the restart point + # all md tasks from which we may restart are in the task queue. + # Only e.g. relax mds would just show up in the runtime prioroty queue, + # regular md tasks are known before the run starts from the config.sequence. + task = None + instance = None + found_restart_task = False + md_instance_task_counter = {k: 0 for k in md_instance_dir_counter.keys()} + while not self.tasks.empty(): + task = self.tasks.get() + if task.name == restart_task_name and restart_task_name == "setup": + # this is the case if setup ends up as the restart point + instance = "setup" + found_restart_task = True + break + if task.name == "run_md": + instance = task.kwargs["instance"] + md_instance_task_counter[instance] += 1 + logger.debug( + f"{task}, {md_instance_task_counter[instance]}, {md_instance_dir_counter[instance]}" + ) + if ( + instance == restart_task_name + and md_instance_task_counter[instance] + == md_instance_dir_counter[instance] + ): + # restart from the last completed (or half completed) valid restart point + found_restart_task = True + # put the task back in the queue + if restart_from_incomplete: + logger.info( + "Restarting from incomplete task. Will continue this task." 
+ ) + # append to the front of the queue + task.kwargs.update({"continue_md": True}) + self.tasks.queue.appendleft(task) + else: + logger.info("Restarting after completed task.") + break + + if not isinstance(task, Task) or not found_restart_task or not instance: + m = f"Could not find task {restart_task_name} in task queue. Either '{task}' is no 'Task', restart task found state is False: '{found_restart_task}' or MD task instance is None: '{instance}'. Aborting restart." + logger.error(m) + raise RuntimeError(m) + if restart_from_incomplete: + fragment = "within" + else: + fragment = "after" + logger.info( + f"Restarting from {fragment} task {task.name} with instance {instance}" + ) + + # clean up old task directories that will be overwritten + for task_dir in task_dirs[self.iteration + 1 :]: + shutil.rmtree(task_dir) + + # discover after it is clear which tasks will be in queue + for task_dir in task_dirs[: self.iteration + 1]: + task_name = "_".join(task_dir.name.split(sep="_")[1:]) + task_files = TaskFiles( + self.get_latest, {}, {}, self.config.out / task_dir.name + ) + self._discover_output_files(task_name, task_files) + + # plumed fix + for md_config in self.config.mds.__dict__.values(): + if getattr(md_config, "use_plumed"): + plumed_out_name = get_plumed_out(self.latest_files["plumed"]).name + plumed_out = self.get_latest(plumed_out_name) + if plumed_out is not None: + self.latest_files["plumed_out"] = plumed_out + self.latest_files.pop(plumed_out_name) + else: + logger.warning( + f"Plumed out file {plumed_out_name} not found. Continuing without it." + ) + + # use latest top file + top_path = self.get_latest("top") + if top_path is None: + m = "No topology file found in output directory." + logger.error(m) + raise FileNotFoundError(m) + self.top = Topology( + top=read_top(top_path, self.config.ff), + parametrizer=self.parameterizer, + is_reactive_predicate_f=get_is_reactive_predicate_from_config_f( + self.config.topology.reactive + ), + radicals=getattr(self.config, "radicals", None), + residuetypes_path=getattr(self.config, "residuetypes", None), + ) + + # if we restart from within an imcomplete task, + # decrement the iteration because if will be incremented + # when the task starts again + if restart_from_incomplete: + self.iteration -= 1 + def _setup_tasks(self): """Populates the tasks queue. Allows for mapping one sequence entry in the config to multiple tasks @@ -291,6 +542,8 @@ def _setup_tasks(self): md = self.task_mapping["md"] kwargs: dict = copy(md["kwargs"]) kwargs.update({"instance": task_name}) + if task_name is None: + raise ValueError("MD task name is None") task = Task( self, f=md["f"], @@ -312,27 +565,17 @@ def _setup_tasks(self): # check all reactions for task_kwargs in self.task_mapping["reactions"]: self.tasks.put(Task(self, **task_kwargs)) - elif task_name == "restart": - restart = self.task_mapping["restart"] - kwargs: dict = copy(restart["kwargs"]) - task = Task( - self, - f=restart["f"], - kwargs=kwargs, - out=task_name, - ) - self.tasks.put(task) else: m = f"Unknown task encountered in the sequence: {task_name}" logger.error(m) raise ValueError(m) logger.info(f"Task list build:\n{pformat(list(self.tasks.queue), indent=8)}") - def get_latest(self, suffix: str): + def get_latest(self, suffix: str) -> Path | None: """Returns path to latest file of given type. For .dat files (in general ambiguous extensions) use full file name. - Errors if file is not found. + Return None if file is not found. 
""" logger.debug("Getting latest suffix: " + suffix) try: @@ -341,8 +584,8 @@ def get_latest(self, suffix: str): return path except Exception: m = f"File {suffix} requested but not found!" - logger.error(m) - raise FileNotFoundError(m) + logger.warning(m) + return None def __iter__(self): return self @@ -369,62 +612,87 @@ def _discover_output_files( and add those files to the `files` as well as the file history and latest files. + and check if double suffs are properly defined declared files by the task """ - # discover other files written by the task - if hasattr(files, "outputdir"): - # check whether double suffs are properly defined in files by the task - discovered_files = [ - p - for p in files.outputdir.iterdir() - if not any(re.search(s, p.name) for s in IGNORE_SUBSTR) - ] - suffs = [p.suffix[1:] for p in discovered_files] - counts = [suffs.count(s) for s in suffs] - for suff, c in zip(suffs, counts): - if c != 1 and suff not in AMBIGUOUS_SUFFS: - if files.output.get(suff) is None: - e = ( - "ERROR: Task produced multiple files with same suffix but " - "did not define with which to continue!\n" - f"Task {taskname}, Suffix {suff} found {c} times" - ) - logger.error(e) - raise RuntimeError(e) - - # register discovered output files - for path in discovered_files: - suffix = path.suffix[1:] - if suffix in AMBIGUOUS_SUFFS: - suffix = path.name - # don't overwrite manually added keys in files.output - if files.output.get(suffix) is not None: - continue - files.output[suffix] = files.outputdir / path - - # remove double entries - if "plumed" in files.input.keys(): - if plumed := files.output.get("plumed"): - files.output.pop(plumed.name) - if plumed_out := files.output.get("plumed_out"): - files.output.pop(plumed_out.name) - - logger.debug(f"Update latest files with:\n{pformat(files.output)}") - self.latest_files.update(files.output) - self.filehist.append({taskname: files}) - - m = f""" - Task: {taskname} with output directory: {files.outputdir} - Task: {taskname}, input:\n{pformat(files.input)} - Task: {taskname}, output:\n{pformat(files.output)} - """ - with open(self.histfile, "a") as f: - f.write(m) - return files - else: + if not hasattr(files, "outputdir"): logger.debug("No output directory found for task: " + taskname) return None + discovered_files = [ + p + for p in files.outputdir.iterdir() + if not any(re.search(s, p.name) for s in IGNORE_SUBSTR) + ] + suffs = [p.suffix[1:] for p in discovered_files] + # if gro/trr file is found and we wrote a ._reaction.gro file + # explicitly make this the latest gro file + for duplicate_suffix in ["gro", "trr"]: + duplicates = [ + p for p in discovered_files if p.suffix[1:] == duplicate_suffix + ] + for duplicate in duplicates: + if "_reaction" in duplicate.name: + files.output[duplicate_suffix] = duplicate + logger.info( + f"Found reaction coordinate file: {duplicate} and set as latest or the non-reaction file" + ) + break + + counts = [suffs.count(s) for s in suffs] + for suff, c, path in zip(suffs, counts, discovered_files): + if c != 1 and suff not in AMBIGUOUS_SUFFS: + if files.output.get(suff) is None: + e = ( + "ERROR: Task produced multiple files with same suffix but " + "did not define with which to continue!\n" + f"Task {taskname}, Suffix {suff} found {c} times" + ) + logger.error(e) + raise RuntimeError(e) + + # register discovered output files + for path in discovered_files: + suffix = path.suffix[1:] + if suffix in AMBIGUOUS_SUFFS: + suffix = path.name + # don't overwrite manually added keys in files.output + if 
files.output.get(suffix) is not None: + continue + files.output[suffix] = files.outputdir / path + + # remove double entries + if "plumed" in files.input.keys(): + if plumed := files.output.get("plumed"): + files.output.pop(plumed.name) + if plumed_out := files.output.get("plumed_out"): + files.output.pop(plumed_out.name) + + logger.debug(f"Update latest files with:\n{pformat(files.output)}") + self.latest_files.update(files.output) + self.filehist.append({files.outputdir.name: files}) + + shortpaths_input = "" + for k, v in files.input.items(): + if v is not None: + shortpaths_input += ( + f' {k}: {str(v).removeprefix(str(self.config.out) + "/")}\n' + ) + + shortpaths_output = "" + for k, v in files.output.items(): + if v is not None: + shortpaths_output += ( + f' {k}: {str(v).removeprefix(str(self.config.out) + "/")}\n' + ) + + with open(self.histfile, "a") as f: + f.write(f"Task: {files.outputdir.name}\n") + f.write(f"Input:\n{shortpaths_input}") + f.write(f"Output:\n{shortpaths_output}\n") + + return files + def _setup(self, files: TaskFiles) -> TaskFiles: """A setup task to collect files processed by kimmdy such as the topology""" logger = files.logger @@ -457,149 +725,6 @@ def _setup(self, files: TaskFiles) -> TaskFiles: return files - def _restart_from_rundir(self): - """Set up RunManager to restart from a run directory""" - - task_dirs = get_task_directories(self.config.restart.run_directory, "all") - logger.debug(f"Found task directories in restart run directory: {task_dirs}") - logger.debug(f"Task queue: {self.tasks.queue}") - - completed_tasks: list[Task] = [] - nested_tasks: dict = {} - self.iteration = 0 - found_run_end = False - while not self.tasks.empty() and not found_run_end: - task: Task = self.tasks.queue[0] - if task.name == "restart_task": - logger.info("Found restart task.") - self.tasks.queue.popleft() - break - if task.out is None: - completed_tasks.append(self.tasks.queue.popleft()) - - else: - if task_dirs[self.iteration :] == []: - logger.info( - f"Found last finished task with task number {self.iteration}." - ) - # Condition 1: Continue from the last finished task - found_run_end = True - for task_dir in task_dirs[self.iteration :]: - if (task_dir / MARK_FAILED).exists(): - raise RuntimeError( - f"Task in directory `{task_dir}` is indicated to " - "have failed. Aborting restart. Remove this task " - "directory if you want to restart from before the failed task." - ) - if (task_dir / MARK_STARTED).exists(): - # symlink task directories from previous output and discover their files - symlink_dir = self.config.out / task_dir.name - symlink_dir.symlink_to(task_dir, target_is_directory=True) - self.iteration += 1 - - task_name = "_".join(task_dir.name.split(sep="_")[1:]) - if task_name == task.out: - task.kwargs.update( - { - "files": TaskFiles( - self.get_latest, {}, {}, symlink_dir - ) - } - ) - completed_tasks.append(self.tasks.queue.popleft()) - if not (task_dir / MARK_DONE).exists(): - logger.info( - f"Found started but not finished task {task_dir}." 
- ) - if completed_tasks[-1].name == "run_md": - symlink_dir.unlink(missing_ok=True) - shutil.copytree( - task_dir, self.config.out / task_dir.name - ) - kwargs: dict = copy(task.kwargs) - continue_md_task = Task( - self, f=self._run_md, kwargs=kwargs, out=None - ) - - continue_md_task.kwargs.update( - {"continue_md": True} - ) - self.priority_tasks.put(continue_md_task) - # Condition 2: Continue from started but not finished task - found_run_end = True - break - else: - # task probably not unique but having the latest of one kind should suffice - if not completed_tasks[-1] in nested_tasks.keys(): - nested_tasks[completed_tasks[-1]] = [] - nested_tasks[completed_tasks[-1]].append(symlink_dir) - else: - raise RuntimeError( - f"Encountered task directory {task_dir.name} but the" - " task is not indicated to have started. Aborting restart." - ) - - # add completed tasks to queue again until a reliable restart point (i.e after MD) is reached - while completed_tasks: - if completed_tasks[-1].name == "run_md": - logger.info( - f"Will continue after task {completed_tasks[-1].kwargs['files'].outputdir}" - ) - - self.iteration -= 1 - break - else: - current_nested_task_dirs = nested_tasks.get(completed_tasks[-1], []) - try: - current_task_dir = [completed_tasks[-1].kwargs["files"].outputdir] - except KeyError: - current_task_dir = [] - for task_dir in [ - *current_nested_task_dirs, - *current_task_dir, - ]: - task_dir.unlink(missing_ok=True) - self.iteration -= 1 - completed_tasks[-1].kwargs.pop("files", None) - self.tasks.queue.appendleft(completed_tasks.pop()) - else: - self.iteration -= 1 - - # discover after it is clear which tasks will be in queue - for task_dir in get_task_directories(self.config.out, "all"): - task_name = "_".join(task_dir.name.split(sep="_")[1:]) - task_files = TaskFiles( - self.get_latest, {}, {}, self.config.out / task_dir.name - ) - self._discover_output_files(task_name, task_files) - - # plumed fix - for md_config in self.config.mds.__dict__.values(): - if getattr(md_config, "use_plumed"): - try: - plumed_out_name = get_plumed_out(self.latest_files["plumed"]).name - self.latest_files["plumed_out"] = self.get_latest(plumed_out_name) - self.latest_files.pop(plumed_out_name) - except FileNotFoundError as e: - logger.debug(e) - - # use latest top file - self.top = Topology( - top=read_top(self.get_latest("top"), self.config.ff), - parametrizer=self.parameterizer, - is_reactive_predicate_f=get_is_reactive_predicate_from_config_f( - self.config.topology.reactive - ), - radicals=getattr(self.config, "radicals", None), - residuetypes_path=getattr(self.config, "residuetypes", None), - ) - - def _restart_task(self, _: TaskFiles) -> None: - raise RuntimeError( - "Called restart task. This task is only for finding the restart " - "point in the sequence and should never be called!" 
- ) - def _run_md( self, instance: str, files: TaskFiles, continue_md: bool = False ) -> TaskFiles: @@ -617,6 +742,8 @@ def _run_md( mdp = files.input["mdp"] ndx = files.input["ndx"] + outputdir = files.outputdir + # to continue MD after timeout if continue_md: cpt = files.input["cpt"] @@ -624,21 +751,20 @@ def _run_md( else: cpt = f"{instance}.cpt" - outputdir = files.outputdir - - grompp_cmd = ( - f"{gmx_alias} grompp -p {top} -c {gro} " - f"-f {mdp} -n {ndx} -o {instance}.tpr -maxwarn 5" - ) - - # optional files for grompp: - if self.latest_files.get("trr") is not None: - trr = files.input["trr"] - grompp_cmd += f" -t {trr}" - ## disable use of edr for now - # if self.latest_files.get("edr") is not None: - # edr = files.input["edr"] - # grompp_cmd += f" -e {edr}" + # running grompp again fails for pulling MD, skip it for restart because it is not necessary + grompp_cmd = ( + f"{gmx_alias} grompp -p {top} -c {gro} " + f"-f {mdp} -n {ndx} -o {instance}.tpr -maxwarn 5" + ) + # optional files for grompp: + if self.latest_files.get("trr") is not None: + trr = files.input["trr"] + grompp_cmd += f" -t {trr}" + ## disable use of edr for now + # if self.latest_files.get("edr") is not None: + # edr = files.input["edr"] + # grompp_cmd += f" -e {edr}" + logger.debug(f"grompp cmd: {grompp_cmd}") mdrun_cmd = ( f"{gmx_alias} mdrun -s {instance}.tpr -cpi {cpt} " @@ -652,15 +778,20 @@ def _run_md( ) if getattr(md_config, "use_plumed"): - mdrun_cmd += f" -plumed {files.input['plumed']}" + plumed_in = files.input["plumed"] + if plumed_in is None: + m = "Plumed input file not found in input files." + logger.error(m) + raise FileNotFoundError(m) + mdrun_cmd += f" -plumed {plumed_in}" - plumed_out = files.outputdir / get_plumed_out(files.input["plumed"]) + plumed_out = files.outputdir / get_plumed_out(plumed_in) files.output["plumed_out"] = plumed_out - logger.debug(f"grompp cmd: {grompp_cmd}") logger.debug(f"mdrun cmd: {mdrun_cmd}") try: - run_gmx(grompp_cmd, outputdir) + if continue_md is False: + run_gmx(grompp_cmd, outputdir) run_gmx(mdrun_cmd, outputdir) # specify trr to prevent rotref trr getting set as standard trr @@ -851,6 +982,9 @@ def _apply_recipe(self, files: TaskFiles) -> TaskFiles: # Set time to chosen 'time_start' of KMCResult ttime = self.kmcresult.time_start plugin_time_index = self.kmcresult.time_start_index_within_plugin + + shadow_files_binding = None + logger.info(f"Chosen time_start: {ttime} ps") logger.info(f"Time index within plugin: {plugin_time_index}") if plugin_time_index is None: @@ -884,11 +1018,29 @@ def _apply_recipe(self, files: TaskFiles) -> TaskFiles: with open(files.outputdir / "vmd_selection.txt", "w") as f: f.write(vmd_selection) - if not self.config.skip_truncation: - # truncate simulation files to the chosen time - m = f"Truncating simulation files to time {ttime} ps" - logger.info(m) - truncate_sim_files(files=files, time=ttime) + # write time marker for reaction time + # in the current task dir (_apply_recipe) + # but also in the output dir of the MD task + # onto which the reaction is applied + write_reaction_time_marker(dir=files.outputdir, time=ttime) + gro = files.input["gro"] + if gro is None: + m = "No gro file found from the previous md run." 
+ logger.error(m) + else: + write_reaction_time_marker(dir=gro.parent, time=ttime) + + # because the gro_reaction file is written to files.output + # it will be discovered by _discover_output_files + # and set as the latest gro file for the next tasks + # but this only happens after the apply_recipe task + # so we need to set it manually here for intermediate tasks + # like Relax and Place to have the correct coordinates + logger.info(f"Writing coordinates (gro and trr) for reaction.") + if ttime is not None: + write_coordinate_files_at_reaction_time(files=files, time=ttime) + self.latest_files["gro"] = files.output["gro"] + self.latest_files["trr"] = files.output["trr"] top_initial = deepcopy(self.top) focus_nrs: set[str] = set() @@ -906,15 +1058,16 @@ def _apply_recipe(self, files: TaskFiles) -> TaskFiles: self.top.bind_bond((step.atom_id_1, step.atom_id_2)) focus_nrs.update([step.atom_id_1, step.atom_id_2]) elif isinstance(step, Place): - task = Task( + relax_task = Task( self, f=place_atom, kwargs={"step": step, "ttime": None}, out="place_atom", ) - place_files = task() + place_files = relax_task() if place_files is not None: - self._discover_output_files(task.name, place_files) + self._discover_output_files(relax_task.name, place_files) + shadow_files_binding = place_files if step.id_to_place is not None: focus_nrs.update([step.id_to_place]) @@ -980,15 +1133,16 @@ def _apply_recipe(self, files: TaskFiles) -> TaskFiles: write_top(top_merge.to_dict(), top_merge_path) self.latest_files["top"] = top_merge_path instance = self.config.changer.coordinates.md - task = Task( + relax_task = Task( self, f=self._run_md, kwargs={"instance": instance}, out=instance, ) - md_files = task() - if md_files is not None: - self._discover_output_files(task.name, md_files) + relax_task_files = relax_task() + if relax_task_files is not None: + self._discover_output_files(relax_task.name, relax_task_files) + shadow_files_binding = relax_task_files elif isinstance(step, CustomTopMod): step.f(self.top) @@ -1004,6 +1158,18 @@ def _apply_recipe(self, files: TaskFiles) -> TaskFiles: # Recipe done, reset runmanger state self.kmcresult = None + if shadow_files_binding is not None: + # if a relaxation or placement task was run, + # we overwrite the coordinate output files of + # the files object with the files from the relaxation or placement task + # (whichever was later) + # such that the next task will use these files + files.output["gro"] = shadow_files_binding.output["gro"] + files.output["trr"] = shadow_files_binding.output["trr"] + files.output["xtc"] = shadow_files_binding.output["xtc"] + # but not the `top`, because the top for the relaxation + # is only temporary and should not be used for the next task + logger.info("Done with Apply recipe") return files diff --git a/src/kimmdy/tasks.py b/src/kimmdy/tasks.py index 674a6b6e..5aa8eaa9 100644 --- a/src/kimmdy/tasks.py +++ b/src/kimmdy/tasks.py @@ -19,12 +19,16 @@ class AutoFillDict(dict): """Dictionary that gets populated by calling get_missing.""" - def __init__(self, get_missing: Callable): + def __init__(self, get_missing: Callable[[str], Path | None]): self.get_missing = get_missing - def __missing__(self, key): - self[key] = self.get_missing(key) - return self[key] + def __missing__(self, key: str) -> None | Path: + v = self.get_missing(key) + if v is not None: + self[key] = v + return v + else: + return None @dataclass @@ -63,8 +67,8 @@ class TaskFiles: {'top': 'latest top'} """ - get_latest: Callable - input: dict[str, Path] = 
field(default_factory=dict) + get_latest: Callable[[str], Path | None] + input: dict[str, Path | None] = field(default_factory=dict) output: dict[str, Path] = field(default_factory=dict) outputdir: Path = Path() logger: logging.Logger = logging.getLogger("kimmdy.basetask") @@ -73,7 +77,9 @@ def __post_init__(self): self.input = AutoFillDict(self.get_latest) -def create_task_directory(runmng, postfix: str) -> TaskFiles: +def create_task_directory( + runmng, postfix: str, is_continuation: bool = False +) -> TaskFiles: """Creates TaskFiles object, output directory, logger and symlinks ff. Gets called when a Task is called (from the runmanager.tasks queue). @@ -86,12 +92,13 @@ def create_task_directory(runmng, postfix: str) -> TaskFiles: # create outputdir files.outputdir = runmng.config.out / taskname logger.debug(f"Creating Output directory: {files.outputdir}") - if files.outputdir.exists(): + if files.outputdir.exists() and not is_continuation: logger.warning( f"Output directory {files.outputdir} for the task already exists. Deleting." ) shutil.rmtree(files.outputdir) - files.outputdir.mkdir() + if not is_continuation: + files.outputdir.mkdir() # set up logger files.logger = logging.getLogger(f"kimmdy.{taskname}") @@ -151,13 +158,34 @@ def __init__( logger.debug(f"Init task {self.name}\tkwargs: {self.kwargs}\tOut: {self.out}") def __call__(self) -> Optional[TaskFiles]: - logger.info(f"Starting task: {self.name} with args: {self.kwargs}") + logger.debug( + f"Starting task: {self.name} with args: {self.kwargs} in {self.runmng.iteration}_{self.out}" + ) + logger.info(f"Starting task: {self.name} in {self.runmng.iteration}_{self.out}") if self.out is not None: - self.kwargs.update({"files": create_task_directory(self.runmng, self.out)}) + is_continuation = False + if self.kwargs.get("continue_md") is True: + logger.info( + f"Continuing task: {self.name} in {self.runmng.iteration}_{self.out}" + ) + is_continuation = True + self.kwargs.update( + { + "files": create_task_directory( + self.runmng, self.out, is_continuation=is_continuation + ) + } + ) write_time_marker(self.kwargs["files"].outputdir / MARK_STARTED, self.name) + logger.info( + f"Wrote kimmdy start marker for task: {self.name} in {self.runmng.iteration}_{self.out}" + ) files = self.f(**self.kwargs) if self.out is not None: write_time_marker(self.kwargs["files"].outputdir / MARK_DONE, self.name) + logger.info( + f"Wrote kimmdy done marker for task: {self.name} in {self.runmng.iteration}_{self.out}" + ) logger.info(f"Finished task: {self.name}") if files is not None and files.logger: for h in files.logger.handlers: diff --git a/src/kimmdy/topology/topology.py b/src/kimmdy/topology/topology.py index eccb9ca7..57dd5699 100644 --- a/src/kimmdy/topology/topology.py +++ b/src/kimmdy/topology/topology.py @@ -759,11 +759,11 @@ def _extract_mergable_molecules(self) -> dict[str, int]: new_molecules += [(m, n)] self.molecules = new_molecules - logger.info( + logger.debug( "Merging the following molecules into the Reactive moleculetype and making their multiples explicit:" ) for m, n in reactive_molecules.items(): - logger.info(f"\t{m} {n}") + logger.debug(f"\t{m} {n}") return reactive_molecules def _merge_moleculetypes( @@ -889,7 +889,7 @@ def _update_dict_from_moleculetype_atomics( section = top.get(moleculetype_name) if section is None: if create: - logger.info( + logger.debug( f"topology does not contain {moleculetype_name}. Creating new section." 
) section = empty_section() @@ -1370,7 +1370,7 @@ def bind_bond( atom.atom = "HX" name_set = True continue - logger.info(f"Hydrogen will be bound to {other_atom}.") + logger.debug(f"Hydrogen will be bound to {other_atom}.") break else: if name_set: @@ -1383,7 +1383,7 @@ def bind_bond( ) if not name_set: atom.atom = "HX" - logger.info(f"Named newly bonded hydrogen 'HX'") + logger.debug(f"Named newly bonded hydrogen 'HX'") # update bound_to atompair[0].bound_to_nrs.append(atompair[1].nr) @@ -1461,10 +1461,10 @@ def bind_bond( to_delete.append(exclusion_key) if len(to_delete) > 0: - logger.info(f"Removing exclusions {to_delete}") + logger.debug(f"Removing exclusions {to_delete}") for key in to_delete: reactive_moleculetype.exclusions.pop(key) settles = reactive_moleculetype.settles.get(ai) if settles is not None: - logger.info(f"Removing settles {ai}") + logger.debug(f"Removing settles {ai}") reactive_moleculetype.settles.pop(ai) diff --git a/src/kimmdy/utils.py b/src/kimmdy/utils.py index 55ba46bb..fda32b0e 100644 --- a/src/kimmdy/utils.py +++ b/src/kimmdy/utils.py @@ -8,10 +8,11 @@ import re import subprocess as sp from pathlib import Path -from typing import TYPE_CHECKING, Iterable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np +from kimmdy.constants import MARK_REACION_TIME from kimmdy.recipe import RecipeCollection if TYPE_CHECKING: @@ -64,15 +65,15 @@ def run_shell_cmd(s, cwd=None) -> sp.CompletedProcess: return sp.run(s, shell=True, cwd=cwd, capture_output=True, text=True) -def run_gmx(s: str, cwd=None) -> Optional[sp.CalledProcessError]: +def run_gmx(cmd: str, cwd=None) -> Optional[sp.CalledProcessError]: """Run GROMACS command in shell. Adds a '-quiet' flag to the command and checks the return code. """ - logger.debug(f"Starting Gromacs process with command {s} in {cwd}.") - result = run_shell_cmd(f"{s} -quiet", cwd) + logger.debug(f"Starting Gromacs process with command {cmd} in {cwd}.") + result = run_shell_cmd(f"{cmd} -quiet", cwd) if result.returncode != 0: - logger.error(f"Gromacs process with command {s} in {cwd} failed.") + logger.error(f"Gromacs process with command {cmd} in {cwd} failed.") logger.error(f"Gromacs exit code {result.returncode}.") logger.error(f"Gromacs stdout:\n{result.stdout}.") logger.error(f"Gromacs stderr:\n{result.stderr}.") @@ -384,115 +385,88 @@ def check_gmx_version(config): return version -def truncate_sim_files( - files: TaskFiles, - time: Optional[float], - keep_tail: bool = True, -): - """Truncates latest trr, xtc, edr, and gro to the time to a previous - point in time. +def write_reaction_time_marker(dir: Path, time: float): + """Write out a file as marker for the reaction time.""" + logger.info( + f"Writing reaction time marker {time} to {dir.name}/{MARK_REACION_TIME}" + ) + with open(dir / MARK_REACION_TIME, "w") as f: + f.write(str(time)) - The files stay in place, the truncated tail is by default kept and renamed - to '[...xtc].tail' - Parameters - ---------- - time - Time in ps up to which the data should be truncated. - files - TaskFiles to get the latest files. 
- """ +def read_reaction_time_marker(dir: Path) -> float | None: + if not (dir / MARK_REACION_TIME).exists(): + return None + with open(dir / MARK_REACION_TIME, "r") as f: + return float(f.read()) - # TODO: fix this to correctly use the working directory - # and do things in it's own directory and not - # modify the input files in place - if time is None: - logger.debug("time is None, nothing to truncate") - return +def write_coordinate_files_at_reaction_time(files: TaskFiles, time: float): + """Write out a gro file from the trajectory (xtc or trr) at the reaction time.""" + gro = files.input["gro"] + if gro is None: + m = "No gro file found from the previous md run." + logger.error(m) + raise FileNotFoundError(m) - paths = {} - paths["gro"] = files.input["gro"] - for s in ["trr", "xtc", "edr"]: - try: - paths[s] = files.input[s] - except FileNotFoundError: - paths[s] = None - - trjs = [p for p in [paths["trr"], paths["xtc"]] if p is not None and p.exists()] - - # trr or xtc must be present - if len(trjs) == 0: - logger.info("No trajectory files found, nothing to truncate.") - return - - for trj in trjs: - # check time exists in traj - p = sp.run( - f"gmx -quiet -nocopyright check -f {trj}", - text=True, - capture_output=True, - shell=True, - ) - # FOR SOME REASON gmx check writes in stderr instead of stdout - if ms := re.findall(r"Reading frame\s.*time\s+(\d+\.\d+)", p.stderr): - if len(ms) == 0: - m = f"Could not find time in trajectory {trj} with gmx check." - logger.warning(m) - return - last_time = float(ms[-1]) - if last_time == 0.0: - logger.info( - "Last traj contains single frame, will not truncate anything." - ) - return - if last_time * 1.01 <= time: - m = f"Requested to truncate trajectory at time {time} but last frame according to gmx check is at {last_time:.4} ps. This might led to unexpected results." - logger.warning(m) - else: - m = f"gmx check failed:\n{p.stdout}\n{p.stderr}.\nMay not be able to truncate trajectory {trj}." - logger.error(m) - raise RuntimeError(m) - logger.info( - f"Truncating trajectories to {time:.4} ps. Trajectory time was {last_time:.4} ps" - ) + if gro.name.endswith("_reaction.gro"): + m = f"The latest gro file registered already ends in _reaction.gro. This state should not be possible unless multiple reactions where run in sequence without any MD in between (even relaxation)." + logger.error(m) - # backup the tails of trajectories - for trj in trjs: - tmp = trj.rename(trj.with_name("tmp_backup_" + trj.name)) - if keep_tail: - run_gmx( - f"gmx trjconv -f {tmp} -b {time} -o {trj}", - ) - trj.rename(str(trj) + ".tail") + gro_reaction = gro.with_name(gro.stem + f"_reaction.gro") + trr_reaction = gro.with_name(gro.stem + f"_reaction.trr") - run_gmx(f"gmx trjconv -f {tmp} -e {time} -o {trj}") - tmp.unlink() + if gro_reaction.exists() or trr_reaction.exists(): + m = f"gro/trr file at reaction time {time} already exists in {gro.parent}. Removing it. This may happen by restarting from a previous run." 
+ logger.error(m) + gro_reaction.unlink(missing_ok=True) + trr_reaction.unlink(missing_ok=True) - # backup the gro - bck_gro = paths["gro"].rename( - paths["gro"].with_name("tmp_backup_" + paths["gro"].name) - ) - sp.run( - f"gmx trjconv -f {trjs[0]} -s {bck_gro} -dump -1 -o {paths['gro']}", - text=True, - input="0", - shell=True, + logger.info( + f"Writing out gro/trr file {gro_reaction.name}/{trr_reaction.name} at reaction time {time} ps in {gro.parent.name}" ) - bck_gro.rename(str(paths["gro"]) + ".tail") - if not keep_tail: - bck_gro.unlink() - - # backup the edr - if paths["edr"] is not None: - bck_edr = paths["edr"].rename( - paths["edr"].with_name("tmp_backup_" + paths["edr"].name) - ) - run_shell_cmd(f"gmx eneconv -f {bck_edr} -e {time} -o {paths['edr']}") - bck_edr.rename(str(paths["edr"]) + ".tail") - if not keep_tail: - bck_edr.unlink() - return + files.output["gro"] = gro_reaction + files.output["trr"] = trr_reaction + + # Prefer xtc over trr + # It should have more frames and be smaller, + # but sometimes only a specific index group is written to the xtc, + # in which case it fails and we try the trr + if files.input["xtc"] is not None: + try: + run_gmx( + f"echo '0' | gmx trjconv -f {files.input['xtc']} -s {gro} -b {time} -dump {time} -o {gro_reaction}" + ) + run_gmx( + f"echo '0' | gmx trjconv -f {files.input['xtc']} -s {gro} -b {time} -dump {time} -o {trr_reaction}" + ) + logger.info( + f"Successfully wrote out gro/trr file {gro_reaction.name}/{trr_reaction.name} at reaction time in {gro.parent.name} from xtc file." + ) + return + except sp.CalledProcessError: + logger.error( + f"Failed to write out gro/trr file {gro_reaction.name}/{trr_reaction.name} at reaction time in {gro.parent.name} from xtc file because the xtc doesn't contain all atoms. Will try trr file." + ) + if files.input["trr"] is not None: + try: + run_gmx( + f"echo '0' | gmx trjconv -f {files.input['trr']} -s {gro} -b {time} -dump {time} -o {gro_reaction}" + ) + run_gmx( + f"echo '0' | gmx trjconv -f {files.input['trr']} -s {gro} -b {time} -dump {time} -o {trr_reaction}" + ) + logger.info( + f"Successfully wrote out gro/trr file at reaction time in {gro.parent} from trr file." + ) + return + except sp.CalledProcessError: + logger.error( + f"Failed to write out gro/trr file at reaction time in {gro.parent} from trr file." + ) + m = f"No trajectory file found to write out gro/trr file at reaction time in {gro.parent}" + logger.error(m) + raise FileNotFoundError(m) def get_task_directories(dir: Path, tasks: Union[list[str], str] = "all") -> list[Path]: @@ -508,17 +482,14 @@ def get_task_directories(dir: Path, tasks: Union[list[str], str] = "all") -> lis List of steps e.g. ["equilibrium", "production"].
Or a string "all" to return all subdirectories """ directories = sorted( - [p for p in dir.glob("*_*/") if p.is_dir()], + [ + p + for p in dir.glob("*_*/") + if p.is_dir() and "_" in p.name and p.name[0].isdigit() + ], key=lambda p: int(p.name.split("_")[0]), ) if tasks == "all": - matching_directories = directories + return directories else: - matching_directories = list( - filter(lambda d: d.name.split("_")[1] in tasks, directories) - ) - - if not matching_directories: - print(f"WARNING: Could not find directories {tasks} in {dir}.") - - return matching_directories + return [d for d in directories if d.name.split("_")[1] in tasks] diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py index 885743da..dc128f51 100644 --- a/tests/test_coordinates.py +++ b/tests/test_coordinates.py @@ -16,7 +16,6 @@ place_atom, break_bond_plumed, ) -from kimmdy.utils import truncate_sim_files def test_get_explicit_MultipleDihedrals(arranged_tmp_path): @@ -151,31 +150,34 @@ def test_merge_prm_top(arranged_tmp_path): # assert one dihedral merge improper/proper -@pytest.mark.require_gmx -def test_truncate_sim_files(arranged_tmp_path): - files = DummyFiles() - files.input = { - "trr": arranged_tmp_path / "relax.trr", - "xtc": arranged_tmp_path / "relax.xtc", - "edr": arranged_tmp_path / "relax.edr", - "gro": arranged_tmp_path / "relax.gro", - } - files.outputdir = arranged_tmp_path - time = 5.2 - truncate_sim_files(files, time) - - for p in files.input.values(): - assert p.exists() - assert p.with_name(p.name + ".tail").exists() - - p = sp.run( - f"gmx -quiet -nocopyright check -f {files.input['trr']}", - text=True, - capture_output=True, - shell=True, - ) - # FOR SOME REASON gmx check writes in stderr instead of stdout - m = re.search(r"Last frame.*time\s+(\d+\.\d+)", p.stderr) - assert m, p.stderr - last_time = m.group(1) - assert last_time == "5.000" +# @pytest.mark.require_gmx +# def test_truncate_sim_files(arranged_tmp_path): +# files = DummyFiles() +# files.input = { +# "trr": arranged_tmp_path / "relax.trr", +# "xtc": arranged_tmp_path / "relax.xtc", +# "edr": arranged_tmp_path / "relax.edr", +# "gro": arranged_tmp_path / "relax.gro", +# } +# files.outputdir = arranged_tmp_path +# # time = 5.2 +# # TODO: truncate during run is deprecated +# # kimmdy writes marker files for the reaction +# # time instead +# # truncate_sim_files(files, time) +# +# for p in files.input.values(): +# assert p.exists() +# assert p.with_name(p.name + ".tail").exists() +# +# p = sp.run( +# f"gmx -quiet -nocopyright check -f {files.input['trr']}", +# text=True, +# capture_output=True, +# shell=True, +# ) +# # FOR SOME REASON gmx check writes in stderr instead of stdout +# m = re.search(r"Last frame.*time\s+(\d+\.\d+)", p.stderr) +# assert m, p.stderr +# last_time = m.group(1) +# assert last_time == "5.000" diff --git a/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart.yml b/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart.yml index e84d0a9b..543f45a0 100644 --- a/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart.yml +++ b/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart.yml @@ -8,8 +8,7 @@ top: 'Ala_out.top' gro: 'npt.gro' ndx: 'index.ndx' kmc: "rfkmc" -restart: - run_directory: 'alanine_hat_000' +restart: true mds: equilibrium: mdp: 'md.mdp' diff --git a/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart_task.yml b/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart_task.yml index 25a03663..a72dee39 100644 --- 
a/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart_task.yml +++ b/tests/test_files/test_integration/alanine_hat_naive/kimmdy_restart_task.yml @@ -8,8 +8,7 @@ top: 'Ala_out.top' gro: 'npt.gro' ndx: 'index.ndx' kmc: "rfkmc" -restart: - run_directory: 'alanine_hat_000' +restart: true mds: equilibrium: mdp: 'md.mdp' diff --git a/tests/test_files/test_integration/charged_peptide_homolysis_hat_naive/kimmdy.yml b/tests/test_files/test_integration/charged_peptide_homolysis_hat_naive/kimmdy.yml index fd8af9bc..2a274ad3 100644 --- a/tests/test_files/test_integration/charged_peptide_homolysis_hat_naive/kimmdy.yml +++ b/tests/test_files/test_integration/charged_peptide_homolysis_hat_naive/kimmdy.yml @@ -30,7 +30,6 @@ reactions: frequency_factor: 100000000 h_cutoff: 3 polling_rate: 1 -plot_rates: true save_recipes: true sequence: - equilibrium diff --git a/tests/test_integration.py b/tests/test_integration.py index bac59f11..175a3360 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -5,9 +5,11 @@ import pytest from kimmdy.cmd import kimmdy_run -from kimmdy.constants import MARK_DONE +from kimmdy.config import Config +from kimmdy.constants import MARK_DONE, MARK_FINISHED from kimmdy.parsing import read_top, write_top -from kimmdy.plugins import parameterization_plugins +from kimmdy.plugins import discover_plugins, parameterization_plugins +from kimmdy.runmanager import RunManager from kimmdy.topology.topology import Topology from kimmdy.utils import get_task_directories @@ -27,12 +29,11 @@ def read_last_line(file): "arranged_tmp_path", (["test_integration/emptyrun"]), indirect=True ) def test_integration_emptyrun(arranged_tmp_path): - # not expecting this to run - # because the topology is empty Path("emptyrun.txt").touch() with pytest.raises(ValueError): kimmdy_run() - assert len(list(Path.cwd().glob("emptyrun_001/*"))) == 2 + assert len(list(Path.cwd().glob("emptyrun_001/*"))) == 3 + assert not (arranged_tmp_path / "minimal" / MARK_FINISHED).exists() @pytest.mark.parametrize( @@ -40,8 +41,9 @@ def test_integration_emptyrun(arranged_tmp_path): ) def test_integration_valid_input_files(arranged_tmp_path): kimmdy_run() - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("minimal/*"))) == 2 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert (arranged_tmp_path / "minimal" / MARK_FINISHED).exists() + assert len(list(Path.cwd().glob("minimal/*"))) == 4 @pytest.mark.parametrize( @@ -121,8 +123,8 @@ def test_grappa_partial_parameterization(arranged_tmp_path): ) def test_integration_single_reaction(arranged_tmp_path): kimmdy_run(input=Path("kimmdy.yml")) - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("single_reaction_000/*"))) == 7 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert len(list(Path.cwd().glob("single_reaction_000/*"))) == 8 @pytest.mark.parametrize( @@ -132,8 +134,8 @@ def test_integration_single_reaction(arranged_tmp_path): ) def test_integration_just_reactions(arranged_tmp_path): kimmdy_run(input=Path("alternative_kimmdy.yml")) - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("single_reaction_000/*"))) == 7 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert len(list(Path.cwd().glob("single_reaction_000/*"))) == 8 @pytest.mark.slow @@ -142,8 +144,9 @@ def 
test_integration_just_reactions(arranged_tmp_path): ) def test_integration_hat_naive_reaction(arranged_tmp_path): kimmdy_run() - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("alanine_hat_000/*"))) == 15 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + print(list(Path.cwd().glob("alanine_hat_000/*"))) + assert len(list(Path.cwd().glob("alanine_hat_000/*"))) == 16 @pytest.mark.slow @@ -152,8 +155,8 @@ def test_integration_hat_naive_reaction(arranged_tmp_path): ) def test_integration_homolysis_reaction(arranged_tmp_path): kimmdy_run() - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("hexalanine_homolysis_000/*"))) == 12 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert len(list(Path.cwd().glob("hexalanine_homolysis_000/*"))) == 13 @pytest.mark.slow @@ -162,8 +165,8 @@ def test_integration_homolysis_reaction(arranged_tmp_path): ) def test_integration_pull(arranged_tmp_path): kimmdy_run() - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("kimmdy_001/*"))) == 11 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert len(list(Path.cwd().glob("kimmdy_001/*"))) == 12 @pytest.mark.require_grappa @@ -175,8 +178,8 @@ def test_integration_pull(arranged_tmp_path): ) def test_integration_whole_run(arranged_tmp_path): kimmdy_run() - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert len(list(Path.cwd().glob("kimmdy_001/*"))) == 24 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert len(list(Path.cwd().glob("kimmdy_001/*"))) == 25 @pytest.mark.slow @@ -185,32 +188,78 @@ def test_integration_whole_run(arranged_tmp_path): ) def test_integration_restart(arranged_tmp_path): run_dir = Path("alanine_hat_000") - restart_dir = Path("alanine_hat_001") - # get reference kimmdy_run(input=Path("kimmdy_restart.yml")) n_files_original = len(list(run_dir.glob("*"))) - # try restart from restart task - kimmdy_run(input=Path("kimmdy_restart_task.yml")) - n_files_restart_task = len(list(restart_dir.glob("*"))) - - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert n_files_original == n_files_restart_task == 16 + # restart already finished run + kimmdy_run(input=Path("kimmdy_restart.yml")) + assert "already finished" in read_last_line(Path("kimmdy.log")) - # try restart from stopped md (doesn't work with truncated trajectory) - task_dirs = get_task_directories(run_dir, "all") + # try restart from stopped md + task_dirs = get_task_directories(run_dir) (task_dirs[-1] / MARK_DONE).unlink() + (arranged_tmp_path / run_dir / MARK_FINISHED).unlink() kimmdy_run(input=Path("kimmdy_restart.yml")) - n_files_continue_md = len(list(restart_dir.glob("*"))) + n_files_continue_md = len(list(run_dir.glob("*"))) - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert n_files_original == n_files_continue_md == 16 + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert n_files_original == n_files_continue_md == 17 - # try restart from finished task - task_dirs = get_task_directories(run_dir, "all") + # try restart from finished md + task_dirs = get_task_directories(run_dir) (task_dirs[-4] / MARK_DONE).unlink() + (arranged_tmp_path / run_dir / MARK_FINISHED).unlink() kimmdy_run(input=Path("kimmdy_restart.yml")) - n_files_restart = 
len(list(restart_dir.glob("*"))) + n_files_restart = len(list(run_dir.glob("*"))) + + assert "Finished running last task" in read_last_line(Path("kimmdy.log")) + assert n_files_original == n_files_restart == 17 - assert "Finished running tasks" in read_last_line(Path("kimmdy.log")) - assert n_files_original == n_files_restart == 16 + +@pytest.mark.slow +@pytest.mark.parametrize( + "arranged_tmp_path", (["test_integration/alanine_hat_naive"]), indirect=True +) +def test_integration_file_usage(arranged_tmp_path): + """Do a kimmdy run + and verify that at each task the correct files are used + specifically when writing out gro and trr files when + applying a reaction. + """ + kimmdy_run() + histfile = Path("alanine_hat_000/kimmdy.history").read_text().split("\n\n") + header, blocks = histfile[0], histfile[1:] + tasks = {} + for block in blocks: + lines = block.split("\n") + name = lines[0].removeprefix("Task: ") + tasks[name] = {"input": {}, "output": {}} + section = "input" + for l in lines[1:]: + if l.startswith("Input:"): + section = "input" + elif l.startswith("Output:"): + section = "output" + elif l.startswith(" "): + tasks[name][section][l.split(": ")[0].strip()] = l.split(": ")[ + 1 + ].strip() + + assert tasks["0_setup"]["output"]["gro"] == "0_setup/npt.gro" + + assert tasks["5_apply_recipe"]["output"]["gro"] == "6_relax/relax.gro" + assert tasks["5_apply_recipe"]["output"]["trr"] == "6_relax/relax.trr" + assert tasks["5_apply_recipe"]["output"]["xtc"] == "6_relax/relax.xtc" + assert tasks["5_apply_recipe"]["output"]["top"] == "5_apply_recipe/Ala_out.top" + + assert tasks["6_relax"]["input"]["top"] == "5_apply_recipe/Ala_out_mod.top" + assert tasks["6_relax"]["input"]["gro"] == "2_equilibrium/equilibrium_reaction.gro" + assert tasks["6_relax"]["input"]["trr"] == "2_equilibrium/equilibrium_reaction.trr" + assert tasks["6_relax"]["output"]["trr"] == "6_relax/relax.trr" + assert tasks["6_relax"]["output"]["xtc"] == "6_relax/relax.xtc" + + assert tasks["7_equilibrium"]["input"]["gro"] == "6_relax/relax.gro" + assert tasks["7_equilibrium"]["input"]["trr"] == "6_relax/relax.trr" + assert tasks["7_equilibrium"]["output"]["trr"] == "7_equilibrium/equilibrium.trr" + assert tasks["7_equilibrium"]["output"]["cpt"] == "7_equilibrium/equilibrium.cpt" + assert tasks["7_equilibrium"]["output"]["xtc"] == "7_equilibrium/equilibrium.xtc" diff --git a/tests/test_runmanager.py b/tests/test_runmanager.py index 9b7992e4..e69de29b 100644 --- a/tests/test_runmanager.py +++ b/tests/test_runmanager.py @@ -1,46 +0,0 @@ -from kimmdy.config import Config -from kimmdy.runmanager import RunManager -from pathlib import Path - -# reimplement this without checkpoints -# def test_tasks_are_set_up(arranged_tmp_path): -# """ -# use the initial checkpoint writing -# to test properties of the runmanager -# initialization without having to `.run()` it. 
-# """ -# config = Config(Path("config1.yml")) -# runmgr = RunManager(config) -# runmgr.write_one_checkpoint() - -# items = [] -# while not runmgr.tasks.empty(): -# items.append(runmgr.tasks.get().name) -# assert items == [ -# "_setup", -# "_run_md", -# "_run_md", -# "_place_reaction_tasks", -# "_decide_recipe", -# "_apply_recipe", -# ] - -# config2 = Config(Path("config2.yml")) -# runmgr2 = RunManager(config2) -# runmgr2.write_one_checkpoint() - -# items = [] -# while not runmgr2.tasks.empty(): -# items.append(runmgr2.tasks.get().name) -# assert items == [ -# "_setup", -# "_run_md", -# "_run_md", -# "_place_reaction_tasks", -# "_decide_recipe", -# "_apply_recipe", -# "_place_reaction_tasks", -# "_decide_recipe", -# "_apply_recipe", -# "_run_md", -# ]