# Source code for simcardems.models.fully_coupled_ORdmm_Land.em_model

from __future__ import annotations

from functools import partial
from pathlib import Path
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

import dolfin
import numpy as np
from dolfin import FiniteElement  # noqa: F401
from dolfin import MixedElement  # noqa: F401
from dolfin import tetrahedron  # noqa: F401
from dolfin import VectorElement  # noqa: F401

from ... import save_load_functions as io
from ... import utils
from ...config import Config
from ...geometry import BaseGeometry
from ...geometry import load_geometry
from ...time_stepper import TimeStepper
from ..em_model import BaseEMCoupling
from ..em_model import setup_EM_model

if TYPE_CHECKING:
    from ... import mechanics_model
    from ... import datacollector

logger = utils.getLogger(__name__)


class EMCoupling(BaseEMCoupling):
    """Electro-mechanical coupling for the fully coupled ORdmm-Land model.

    Two meshes are involved (``self.mech_mesh`` and ``self.ep_mesh``, both
    taken from ``self.geometry``):

    * mechanics -> EP: ``lmbda``, ``Zetas`` and ``Zetaw`` held by the active
      model of the mechanics solver are transferred onto CG1 functions on the
      EP mesh with a precomputed PETSc transfer matrix (see ``interpolate``).
    * EP -> mechanics: the ``XS`` and ``XW`` sub-functions of the EP state are
      interpolated onto CG1 functions on the mechanics mesh
      (see ``coupling_to_mechanics``).
    """

    def __init__(
        self,
        geometry: BaseGeometry,
        **state_params,
    ) -> None:
        """Create the coupling function spaces, fields and transfer matrix.

        Parameters
        ----------
        geometry : BaseGeometry
            Geometry providing both the mechanics and the EP mesh.
        **state_params
            Forwarded unchanged to ``BaseEMCoupling.__init__``.
        """
        super().__init__(geometry=geometry, **state_params)

        # CG1 fields on the mechanics mesh that receive EP quantities.
        self.V_mech = dolfin.FunctionSpace(self.mech_mesh, "CG", 1)
        self.XS_mech = dolfin.Function(self.V_mech, name="XS_mech")
        self.XW_mech = dolfin.Function(self.V_mech, name="XW_mech")

        # CG1 fields on the EP mesh that receive mechanics quantities.
        self.V_ep = dolfin.FunctionSpace(self.ep_mesh, "CG", 1)
        self.lmbda_ep = dolfin.Function(self.V_ep, name="lambda_ep")
        self.Zetas_ep = dolfin.Function(self.V_ep, name="Zetas_ep")
        self.Zetaw_ep = dolfin.Function(self.V_ep, name="Zetaw_ep")

        # Built once and reused by ``interpolate`` for every V_mech -> V_ep
        # transfer.
        self.transfer_matrix = dolfin.PETScDMCollection.create_transfer_matrix(
            self.V_mech,
            self.V_ep,
        ).mat()

    def interpolate(
        self,
        f_mech: dolfin.Function,
        f_ep: dolfin.Function,
    ) -> dolfin.Function:
        """Interpolate a function from the mechanics mesh to the EP mesh.

        Parameters
        ----------
        f_mech : dolfin.Function
            Source function, expected to live on ``self.V_mech``.
        f_ep : dolfin.Function
            Target function, expected to live on ``self.V_ep``. Its vector is
            overwritten in place.

        Returns
        -------
        dolfin.Function
            ``f_ep`` itself (same object), now holding the transferred values.
        """
        x = dolfin.as_backend_type(f_mech.vector()).vec()
        # getVecs() returns (right, left) work vectors; ``temp`` has the
        # layout of the matrix-vector product result.
        a, temp = self.transfer_matrix.getVecs()
        try:
            self.transfer_matrix.mult(x, temp)
            # aypx(0.0, temp): y <- temp + 0.0 * y, i.e. overwrite f_ep with temp.
            f_ep.vector().vec().aypx(0.0, temp)
            f_ep.vector().apply("")
        finally:
            # Remember to free memory allocated by petsc (also on error):
            # https://gitlab.com/petsc/petsc/-/issues/1309
            x.destroy()
            a.destroy()
            temp.destroy()
        return f_ep

    @property
    def coupling_type(self):
        """Identifier string of this coupling."""
        return "fully_coupled_ORdmm_Land"

    def __eq__(self, __o: object) -> bool:
        """Compare base state plus all coupling fields by ``np.allclose``.

        Returns ``NotImplemented`` for objects of a different type. Logs the
        name of the first differing field.
        """
        if not isinstance(__o, type(self)):
            return NotImplemented
        if not super().__eq__(__o):
            return False

        for attr in [
            "vs",
            "mech_state",
            "lmbda_mech",
            "Zetas_mech",
            "Zetaw_mech",
            "lmbda_ep",
            "Zetas_ep",
            "Zetaw_ep",
            "XS_mech",
            "XW_mech",
        ]:
            if not np.allclose(
                getattr(self, attr).vector().get_local(),
                getattr(__o, attr).vector().get_local(),
            ):
                logger.info(f"{attr} differs in equality")
                return False

        return True

    def register_time_stepper(self, time_stepper: TimeStepper) -> None:
        """Register the time stepper here and with the mechanics active model."""
        super().register_time_stepper(time_stepper)
        self.mech_solver.material.active.register_time_stepper(time_stepper)

    @property
    def dt_mechanics(self) -> float:
        """Time step used by the mechanics active model."""
        return self.mech_solver.material.active.dt

    @property
    def mech_mesh(self):
        """Mechanics mesh from the geometry."""
        return self.geometry.mechanics_mesh

    @property
    def ep_mesh(self):
        """EP mesh from the geometry."""
        return self.geometry.ep_mesh

    @property
    def mech_state(self) -> dolfin.Function:
        """Current state function of the mechanics solver."""
        return self.mech_solver.state

    @property
    def vs(self) -> dolfin.Function:
        """First solution field reported by the EP solver."""
        return self.ep_solver.solution_fields()[0]

    @property
    def assigners(self) -> datacollector.Assigners:
        """Assigners created by ``setup_assigners``."""
        return self._assigners

    @assigners.setter
    def assigners(self, assigners) -> None:
        self._assigners = assigners

    def setup_assigners(self) -> None:
        """Create the sub-function assigners and push EP values to mechanics.

        EP sub-functions are registered by their index in the EP state
        ``vs``. The mechanics displacement uses the solver's
        ``u_subspace_index``. XS and XW are also registered with
        ``is_pre=True`` (presumably the previous-step values; confirm
        against ``Assigners``).
        """
        from ...datacollector import Assigners

        self.assigners = Assigners(vs=self.vs, mech_state=self.mech_state)
        for name, index in [
            ("V", 0),
            ("Ca", 45),
            ("CaTrpn", 42),
            ("TmB", 43),
            ("Cd", 44),
            ("XS", 40),
            ("XW", 41),
        ]:
            self.assigners.register_subfunction(
                name=name,
                group="ep",
                subspace_index=index,
            )

        self.assigners.register_subfunction(
            name="u",
            group="mechanics",
            subspace_index=self.mech_solver.u_subspace_index,
        )

        for name, index in [
            ("XS", 40),
            ("XW", 41),
        ]:
            self.assigners.register_subfunction(
                name=name,
                group="ep",
                subspace_index=index,
                is_pre=True,
            )
        self.coupling_to_mechanics()

    def register_ep_model(self, solver):
        """Attach the EP solver and sync coupling fields if mechanics exists."""
        logger.debug("Registering EP model")
        self.ep_solver = solver
        if hasattr(self, "mech_solver"):
            self.mechanics_to_coupling()
            self.coupling_to_mechanics()
        logger.debug("Done registering EP model")

    def register_mech_model(self, solver: mechanics_model.MechanicsProblem):
        """Attach the mechanics solver and alias its active-model fields.

        ``lmbda_mech``, ``Zetas_mech`` and ``Zetaw_mech`` are the
        ``*_prev`` functions of ``solver.material.active``, not copies.
        """
        logger.debug("Registering mech model")
        self.mech_solver = solver

        self.Zetas_mech = solver.material.active.Zetas_prev
        self.Zetaw_mech = solver.material.active.Zetaw_prev
        self.lmbda_mech = solver.material.active.lmbda_prev

        # Not sure why we need to do this for the LV?
        self.lmbda_mech.set_allow_extrapolation(True)
        self.Zetas_mech.set_allow_extrapolation(True)
        self.Zetaw_mech.set_allow_extrapolation(True)

        self.mechanics_to_coupling()
        if hasattr(self, "ep_solver"):
            self.coupling_to_mechanics()
        logger.debug("Done registering mech model")

    def update_prev_mechanics(self):
        """Delegate to the active model's ``update_prev``."""
        self.mech_solver.material.active.update_prev()

    def update_prev_ep(self):
        """Copy the EP solver's ``vs`` into its ``vs_``."""
        self.ep_solver.vs_.assign(self.ep_solver.vs)

    def ep_to_coupling(self):
        """Assign the registered EP sub-functions via the assigners."""
        logger.debug("Update mechanics")
        self.assigners.assign_ep()
        logger.debug("Done updating mechanics")

    def coupling_to_mechanics(self):
        """Interpolate assigned EP XS/XW onto the mechanics-mesh functions.

        Does nothing until ``setup_assigners`` has set ``_assigners``.
        """
        logger.debug("Interpolate mechanics")
        if hasattr(self, "_assigners"):
            self.XS_mech.interpolate(self.assigners.functions["ep"]["XS"])
            self.XW_mech.interpolate(self.assigners.functions["ep"]["XW"])
        logger.debug("Done interpolating mechanics")

    def mechanics_to_coupling(self):
        """Transfer lmbda, Zetas and Zetaw from the mechanics to the EP mesh."""
        logger.debug("Interpolate EP")
        self.interpolate(self.lmbda_mech, self.lmbda_ep)
        self.interpolate(self.Zetas_mech, self.Zetas_ep)
        self.interpolate(self.Zetaw_mech, self.Zetaw_ep)
        logger.debug("Done interpolating EP")

    def coupling_to_ep(self):
        """No transfer is needed in this coupling; only logs."""
        logger.debug("Update EP")
        logger.debug("Done updating EP")

    def solve_mechanics(self) -> None:
        """Run one mechanics solve."""
        logger.debug("Solve mechanics")
        self.mech_solver.solve()

    def solve_ep(self, interval: Tuple[float, float]) -> None:
        """Advance the EP solver over ``interval``."""
        logger.debug("Solve EP")
        self.ep_solver.step(interval)

    def print_mechanics_info(self):
        """Log a header, then mesh and dof information for the mechanics model."""
        total_dofs = self.mech_state.function_space().dim()
        # Header first, matching print_ep_info.
        logger.info("Mechanics model")
        utils.print_mesh_info(self.mech_mesh, total_dofs)

    def print_ep_info(self):
        """Log a header, then mesh and dof information for the EP model."""
        # Output some degrees of freedom
        total_dofs = self.vs.function_space().dim()
        logger.info("EP model")
        utils.print_mesh_info(self.ep_mesh, total_dofs)

    def cell_params(self):
        """Return the parameters of the EP ODE model."""
        return self.ep_solver.ode_solver._model.parameters()

    def register_datacollector(self, collector: datacollector.DataCollector) -> None:
        """Register coupling, mechanics and Ta fields with ``collector``."""
        super().register_datacollector(collector=collector)

        collector.register("ep", "Zetas", self.Zetas_ep)
        collector.register("ep", "Zetaw", self.Zetaw_ep)
        collector.register("ep", "lambda", self.lmbda_ep)
        collector.register("mechanics", "XS", self.XS_mech)
        collector.register("mechanics", "XW", self.XW_mech)
        collector.register("mechanics", "Zetas", self.Zetas_mech)
        collector.register("mechanics", "Zetaw", self.Zetaw_mech)
        collector.register("mechanics", "lambda", self.lmbda_mech)
        collector.register(
            "mechanics",
            "Ta",
            self.mech_solver.material.active.Ta_current,
        )
        self.mech_solver.solver.register_datacollector(collector)

    def save_state(
        self,
        path: Union[str, Path],
        config: Optional[Config] = None,
    ) -> None:
        """Save the base state, then append EM fields and cell parameters.

        The HDF5 file is opened in append mode ("a") after the base class
        has written to ``path``.
        """
        super().save_state(path=path, config=config)

        with dolfin.HDF5File(
            self.geometry.comm(),
            Path(path).as_posix(),
            "a",
        ) as h5file:
            h5file.write(self.lmbda_mech, "/em/lmbda_prev")
            h5file.write(self.Zetas_mech, "/em/Zetas_prev")
            h5file.write(self.Zetaw_mech, "/em/Zetaw_prev")
            h5file.write(self.ep_solver.vs, "/ep/vs")
            h5file.write(self.mech_solver.state, "/mechanics/state")

        io.dict_to_h5(
            self.cell_params(),
            path,
            "ep/cell_params",
            comm=self.geometry.comm(),
        )

    @classmethod
    def from_state(
        cls,
        path: Union[str, Path],
        drug_factors_file: Union[str, Path] = "",
        popu_factors_file: Union[str, Path] = "",
        disease_state="healthy",
        PCL: float = 1000,
    ) -> BaseEMCoupling:
        """Rebuild an EM model from a file written by ``save_state``.

        Parameters
        ----------
        path : str or Path
            HDF5 state file. The geometry schema is expected next to it with
            a ``.json`` suffix.
        drug_factors_file, popu_factors_file, disease_state, PCL
            Override the corresponding fields of the config read from file.

        Returns
        -------
        BaseEMCoupling
            The result of ``setup_EM_model`` built from the restored state.

        Raises
        ------
        FileNotFoundError
            If ``path`` is not an existing file.
        """
        logger.debug(f"Load state from path {path}")
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(f"File {path} does not exist")

        geo = load_geometry(path, schema_path=path.with_suffix(".json"))
        logger.debug("Open file with h5py")
        with io.h5pyfile(path) as h5file:
            config = Config(**io.h5_to_dict(h5file["config"]))
            state_params = io.h5_to_dict(h5file["state_params"])
            cell_params = io.h5_to_dict(h5file["ep"]["cell_params"])
            vs_signature = h5file["ep"]["vs"].attrs["signature"].decode()
            mech_signature = h5file["mechanics"]["state"].attrs["signature"].decode()

        config.drug_factors_file = drug_factors_file
        config.popu_factors_file = popu_factors_file
        config.disease_state = disease_state
        config.PCL = PCL

        # NOTE(review): eval() executes element signatures read from the state
        # file, so only load trusted files. It resolves names such as
        # FiniteElement/MixedElement/VectorElement/tetrahedron from module
        # globals (presumably why they are imported with noqa: F401).
        VS = dolfin.FunctionSpace(geo.ep_mesh, eval(vs_signature))
        vs = dolfin.Function(VS)

        W = dolfin.FunctionSpace(geo.mechanics_mesh, eval(mech_signature))
        mech_state = dolfin.Function(W)

        # FIXME: load this signature from the file as well
        V = dolfin.FunctionSpace(geo.mechanics_mesh, "CG", 1)
        lmbda_prev = dolfin.Function(V, name="lambda")
        Zetas_prev = dolfin.Function(V, name="Zetas")
        Zetaw_prev = dolfin.Function(V, name="Zetaw")
        logger.debug("Load functions")
        with dolfin.HDF5File(geo.ep_mesh.mpi_comm(), path.as_posix(), "r") as h5file:
            h5file.read(vs, "/ep/vs")
            h5file.read(mech_state, "/mechanics/state")
            h5file.read(lmbda_prev, "/em/lmbda_prev")
            h5file.read(Zetas_prev, "/em/Zetas_prev")
            h5file.read(Zetaw_prev, "/em/Zetaw_prev")

        # Local import (possibly to avoid an import cycle within the package).
        from . import CellModel, ActiveModel

        cell_inits = io.vs_functions_to_dict(
            vs,
            state_names=CellModel.default_initial_conditions().keys(),
        )

        # Pre-bind the restored active-model fields so setup_EM_model can
        # construct the ActiveModel with them.
        cls_ActiveModel = partial(
            ActiveModel,
            Zetas=Zetas_prev,
            Zetaw=Zetaw_prev,
            lmbda=lmbda_prev,
        )

        return setup_EM_model(
            cls_EMCoupling=cls,
            cls_CellModel=CellModel,
            cls_ActiveModel=cls_ActiveModel,
            geometry=geo,
            config=config,
            cell_inits=cell_inits,
            cell_params=cell_params,
            mech_state_init=mech_state,
            state_params=state_params,
        )