import json
import typing
from pathlib import Path
import dolfin
from tqdm import tqdm
from . import save_load_functions as io
from . import utils
from .config import Config
from .models import em_model
from .time_stepper import TimeStepper
logger = utils.getLogger(__name__)
[docs]
class Runner:
    """Drive a simulation of a coupled EM model.

    The runner owns a ``coupling`` object (created from the configuration,
    loaded from a saved state, or handed in through :meth:`from_models`),
    steps it forward in time in :meth:`solve`, and stores results through a
    ``DataCollector`` written to the configured output directory.
    """

    def __init__(
        self,
        config: typing.Optional[Config] = None,
        empty: bool = False,
    ) -> None:
        """Create the output directory, dump the config and set up the model.

        Parameters
        ----------
        config : typing.Optional[Config]
            Simulation configuration. A default ``Config()`` is used when
            omitted.
        empty : bool
            If True, stop after writing ``config.json`` and setting the log
            level; no coupling or data collector is created. Used by
            :meth:`from_models`, which attaches these afterwards.
        """
        if config is None:
            config = Config()

        self._config = config
        # Defined before the `empty` early return so that the `t` property
        # raises its RuntimeError (not AttributeError) on an empty runner.
        self._time_stepper: typing.Optional[TimeStepper] = None

        # `parents=True` lets a nested output directory be created in one go.
        self.outdir.mkdir(exist_ok=True, parents=True)

        # Save config to outdir
        def serialize(x):
            if isinstance(x, Path):
                return x.as_posix()
            # A json `default` hook must raise TypeError for unsupported
            # objects; returning them unchanged yields a misleading
            # "Circular reference detected" error instead.
            raise TypeError(
                f"Object of type {type(x).__name__} is not JSON serializable",
            )

        (self.outdir / "config.json").write_text(
            json.dumps(self._config.as_dict(), default=serialize),
        )

        # Local import (presumably to avoid a circular import at package
        # load time -- TODO confirm).
        from . import set_log_level

        set_log_level(config.loglevel)

        if empty:
            return

        # The data collector is told to reset only when no state is loaded.
        # NOTE(review): this stays False when `load_state` is set but the
        # state file is missing and a new state is created -- confirm that
        # this is intended.
        self._reset = not self._config.load_state
        if self._config.load_state and self.state_path.is_file():
            # Load state
            logger.info("Load previously saved state")
            self.coupling = io.load_state(
                path=self.state_path,
                drug_factors_file=self._config.drug_factors_file,
                popu_factors_file=self._config.popu_factors_file,
                disease_state=self._config.disease_state,
                PCL=self._config.PCL,  # Set bcl from cli
            )
        else:
            logger.info("Create a new state")
            # Create a new state
            self.coupling = em_model.setup_EM_model_from_config(self._config)

        self._t0 = self.coupling.t
        self._setup_datacollector()

        logger.info(f"Starting at t0={self._t0}")

    @property
    def _dt(self):
        """Time step, read from the configuration."""
        return self._config.dt

    @_dt.setter
    def _dt(self, value):
        self._config.dt = value

    @property
    def state_path(self) -> Path:
        """Path of the state file (``state.h5``) inside the output directory."""
        return self.outdir / "state.h5"

    @property
    def outdir(self) -> Path:
        """Output directory from the configuration, as a ``Path``."""
        return Path(self._config.outdir)

    @property
    def t(self) -> float:
        """Current time of the time stepper.

        Raises
        ------
        RuntimeError
            If no time stepper has been created yet.
        """
        if self._time_stepper is None:
            raise RuntimeError("Please create a time stepper before solving")
        return self._time_stepper.t

    @property
    def t0(self) -> float:
        """Start time, taken from ``coupling.t`` when the runner was set up."""
        return self._t0

    def _setup_time_stepper(
        self,
        T: float,
        use_ns: bool = True,
        st_progress: typing.Any = None,
    ) -> None:
        """Create a ``TimeStepper`` from ``t0`` to ``T`` and register it
        with the coupling."""
        self._time_stepper = TimeStepper(
            t0=self._t0,
            T=T,
            dt=self._dt,
            use_ns=use_ns,
            st_progress=st_progress,
        )
        self.coupling.register_time_stepper(self._time_stepper)

    @classmethod
    def from_models(
        cls,
        coupling: em_model.BaseEMCoupling,
        config: typing.Optional[Config] = None,
        reset: bool = True,
    ):
        """Build a runner around an already constructed coupling.

        Parameters
        ----------
        coupling : em_model.BaseEMCoupling
            The coupling object to run.
        config : typing.Optional[Config]
            Simulation configuration; defaults to ``Config()``.
        reset : bool
            Forwarded to the data collector as ``reset_state``.

        Returns
        -------
        Runner
            A runner whose start time is ``coupling.t``.
        """
        obj = cls(empty=True, config=config)
        obj.coupling = coupling
        obj._t0 = coupling.t
        obj._reset = reset
        obj._time_stepper = None
        obj._setup_datacollector()
        return obj

    def store(self):
        """Store the current fields in the data collector at the current time."""
        # Assign u, v and Ca for postprocessing
        self.coupling.assigners.assign()
        self.collector.store(TimeStepper.ns2ms(self.t))

    def _setup_datacollector(self):
        """Create the ``DataCollector`` and register it with the coupling."""
        from .datacollector import DataCollector

        self.collector = DataCollector(
            outdir=self.outdir,
            geo=self.coupling.geometry,
            reset_state=self._reset,
            outfilename=self._config.outfilename,
        )
        self.coupling.register_datacollector(self.collector)

    def _solve_mechanics_now(self) -> bool:
        """Decide whether the mechanics model should be solved at this step.

        With the ``"fixed"`` strategy the decision depends only on
        ``coupling.dt_mechanics`` exceeding ``config.dt_mech``. Otherwise the
        mechanics is also solved when the norm reported by the assigners
        reaches ``config.mech_threshold``.
        """
        dt_exceeded = self.coupling.dt_mechanics > self._config.dt_mech
        if self._config.mechanics_solve_strategy == "fixed":
            return dt_exceeded

        self.coupling.assigners.assign_pre()
        norm = self.coupling.assigners.compute_pre_norm()
        return norm >= self._config.mech_threshold or dt_exceeded

    def _post_mechanics_solve(self) -> None:
        """Propagate the mechanics solution back towards the EP model."""
        # Update previous lmbda
        self.coupling.update_prev_mechanics()
        self.coupling.mechanics_to_coupling()
        self.coupling.coupling_to_ep()

    def _solve_mechanics(self):
        """Transfer data to the mechanics model, solve it and post-process."""
        self.coupling.coupling_to_mechanics()
        self.coupling.solve_mechanics()
        self._post_mechanics_solve()

    def _post_ep(self):
        """Update the previous EP state and transfer EP data to the coupling."""
        self.coupling.update_prev_ep()
        self.coupling.ep_to_coupling()

    def save_state(self):
        """Save the coupling state to :attr:`state_path`."""
        self.coupling.save_state(path=self.state_path, config=self._config)

    def solve(
        self,
        T: float = Config.T,
        save_freq: int = Config.save_freq,
        show_progress_bar: bool = Config.show_progress_bar,
        st_progress: typing.Any = None,
        default_save_condition: typing.Callable[[int, float, float], bool] = (
            lambda i, T, dt: i > 0 and T >= 40000 and i % int(10000 / dt) == 0
        ),
    ):
        """Run the simulation up to time ``T`` and save the final state.

        Parameters
        ----------
        T : float
            End time passed to the time stepper (presumably in ms --
            TODO confirm against ``TimeStepper``).
        save_freq : int
            Interval between stored results, in the same unit as the time
            step; results are stored every ``int(save_freq / dt)`` steps
            (at least every step).
        show_progress_bar : bool
            Whether to show a progress bar (only shown in serial runs).
        st_progress : typing.Any
            Forwarded unchanged to the ``TimeStepper``.
        default_save_condition : typing.Callable[[int, float, float], bool]
            Called with ``(i, T, dt)`` at every step; when it returns True an
            intermediate state file is written. The default triggers every
            ``int(10000 / dt)`` steps when ``T >= 40000``.
        """
        # At least 1: int(save_freq / dt) is 0 when save_freq < dt, which
        # would make the modulo below raise ZeroDivisionError.
        save_it = max(1, int(save_freq / self._dt))
        self._setup_time_stepper(T, use_ns=True, st_progress=st_progress)
        pbar = create_progressbar(
            time_stepper=self._time_stepper,
            show_progress_bar=show_progress_bar,
        )

        for i, (t0, t) in enumerate(pbar):
            # Convert once per step; both values are reused below.
            t0_ms = TimeStepper.ns2ms(t0)
            t_ms = TimeStepper.ns2ms(t)
            logger.debug(
                f"Solve EP model at step {i} from {t0_ms:.2f} ms to {t_ms:.2f} ms",
            )

            # Solve EP model
            self.coupling.t = t_ms
            self.coupling.solve_ep((t0_ms, t_ms))
            self._post_ep()

            if self._solve_mechanics_now():
                logger.debug(
                    f"Solve mechanics model at step {i} at {t_ms:.2f} ms",
                )
                self._solve_mechanics()

            # Store every 'save_freq' ms
            if i % save_it == 0:
                self.store()

            # Save state every 10 beats if simulation is longer than 40 sec
            if default_save_condition(i, T, self._dt):
                self.coupling.save_state(
                    path=self.outdir.joinpath(f"state_{int(i*self._dt/1000)}beat.h5"),
                    config=self._config,
                )

        self.save_state()
[docs]
def create_progressbar(
    time_stepper: typing.Optional[TimeStepper] = None,
    show_progress_bar: bool = Config.show_progress_bar,
):
    """Wrap a time stepper in a progress bar, or in a silent stand-in.

    A real ``tqdm`` bar is returned only when running on a single MPI
    process and ``show_progress_bar`` is true; otherwise the stepper is
    wrapped in ``_tqdm``, which iterates without printing anything.

    Parameters
    ----------
    time_stepper : typing.Optional[TimeStepper]
        The time stepper to iterate over. Must not be None.
    show_progress_bar : bool
        Whether a visible progress bar is requested.

    Raises
    ------
    ValueError
        If ``time_stepper`` is None.
    """
    if time_stepper is None:
        raise ValueError("Please provide a time stepper")

    running_serial = dolfin.MPI.size(dolfin.MPI.comm_world) == 1
    # Pick the wrapper class first, then build it once with common arguments.
    bar_factory = tqdm if (running_serial and show_progress_bar) else _tqdm
    return bar_factory(time_stepper, total=time_stepper.total_steps)
class _tqdm:
def __init__(self, iterable, *args, **kwargs):
self._iterable = iterable
def set_postfix(self, msg):
pass
def __iter__(self):
return iter(self._iterable)