# Source code for flory.mcmp._finder_impl

"""The implementation details of the core algorithm for finder.

:mod:`~flory.mcmp._finder_impl` contains the implementation details of the module
:mod:`~flory.mcmp.finder`. The main components of the module are the function
:func:`multicomponent_self_consistent_metastep`, which implements the self-consistent
iterations for minimizing the extended free energy functional, and the function
:func:`get_clusters`, which finds the unique phases.

In this module, arguments of functions are always marked by `Constant`, `Output` or
`Mutable`, to indicate whether the arguments will be kept invariant, directly overwritten,
or used as input and then updated in place, respectively.

.. codeauthor:: Yicheng Qiang <yicheng.qiang@ds.mpg.de>
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import numba as nb
import numpy as np
from numba import literal_unroll

from ..constraint.base import ConstraintBaseCompiled
from ..ensemble.base import EnsembleBaseCompiled
from ..entropy.base import EntropyBaseCompiled
from ..interaction.base import InteractionBaseCompiled


@nb.njit()
def count_valid_compartments(Js: np.ndarray, threshold: float) -> int:
    r"""Count the compartments that are still alive.

    A compartment is counted as valid when its relative volume is strictly larger than
    :paramref:`threshold`; entries equal to the threshold are not counted.

    Args:
        Js:
            Constant. The 1D array with the size of :math:`N_\mathrm{M}`, containing
            the relative volumes of compartments :math:`J_m`.
        threshold:
            Constant. Compartments whose relative volume is not larger than this value
            are considered dead.

    Returns:
        : Number of entries in :paramref:`Js` strictly larger than
        :paramref:`threshold`.
    """
    alive = Js > threshold
    return np.count_nonzero(alive)


@nb.njit()
def make_valid_compartment_masks(Js: np.ndarray, threshold: float) -> np.ndarray:
    r"""Build a floating-point mask marking the valid compartments.

    Each entry of the result is 1.0 where the corresponding entry of :paramref:`Js` is
    strictly larger than :paramref:`threshold`, and 0.0 where it is smaller than or
    equal to it.

    Args:
        Js:
            Constant. The 1D array with the size of :math:`N_\mathrm{M}`, containing
            the relative volumes of compartments :math:`J_m`.
        threshold:
            Constant. Compartments whose relative volume is not larger than this value
            are considered dead.

    Returns:
        : 1D array with the size of :math:`N_\mathrm{M}`, holding 1.0 for valid and
        0.0 for dead compartments.
    """
    # sign() gives +1 above the threshold, 0 at it and -1 below it; flooring the
    # negative values at zero turns this into a 1/0 mask
    return np.maximum(np.sign(Js - threshold), 0.0)


@nb.njit()
def revive_compartments_by_random(
    Js: np.ndarray,
    targets: np.ndarray,
    threshold: float,
    rng: np.random.Generator,
    scaler: float,
) -> int:
    r"""Revive dead compartments randomly.

    Randomly revive compartments whose relative volume (element of :paramref:`Js`) is
    not larger than :paramref:`threshold`.

    For every row of :paramref:`targets`, the revived value is drawn uniformly from an
    interval with the following properties:

    * It is centered at the midpoint of the extreme values of that row across all
      compartments (dead ones included).
    * Its half-width is the half-range of that row multiplied by
      :paramref:`scaler`.

    Hence, for :paramref:`scaler` not larger than 1, the new values lie between the
    current extreme values. For larger :paramref:`scaler`, they may lie outside of
    them.

    Note that this function does not conserve the quantities in :paramref:`targets`
    across all compartments, since the new values are randomly generated.

    Args:
        Js:
            Mutable. The 1D array with the size of :math:`N_\mathrm{M}`, containing
            the relative volumes of compartments :math:`J_m`. The entries of revived
            compartments are set to unity in place.
        targets:
            Mutable. 2D array with the size of :math:`N_* \times N_\mathrm{M}`,
            containing the values to be revived. The second dimension has to be the
            same as that of :paramref:`Js`. Note that this is not checked. The columns
            of revived compartments are overwritten in place.
        threshold:
            Constant. The threshold at or below which the corresponding compartment is
            considered dead. For each element of :paramref:`Js` not larger than this
            parameter, the corresponding compartment is revived as described above.
        rng:
            Mutable. Random number generator for reviving.
        scaler:
            Constant. The factor scaling the half-width of the sampling interval.

    Returns:
        : Number of dead compartments that have been revived.
    """
    revive_count = 0
    num_comp, num_part = targets.shape

    # midpoint and half-range of each row of targets, taken over all compartments
    # before any compartment is revived
    target_centers = np.full(num_comp, 0.0, float)
    target_half_widths = np.full(num_comp, 0.0, float)
    for itr_component in range(num_comp):
        current_target_max = targets[itr_component].max()
        current_target_min = targets[itr_component].min()
        target_centers[itr_component] = (current_target_max + current_target_min) * 0.5
        target_half_widths[itr_component] = (
            current_target_max - current_target_min
        ) * 0.5

    # revive each dead compartment with a random value for every row of targets
    for itr_compartment in range(num_part):
        if Js[itr_compartment] <= threshold:
            Js[itr_compartment] = 1.0
            revive_count += 1
            for itr_component in range(num_comp):
                offset = (
                    target_half_widths[itr_component] * scaler * rng.uniform(-1, 1)
                )
                targets[itr_component, itr_compartment] = (
                    target_centers[itr_component] + offset
                )
    return revive_count


@nb.njit()
def revive_compartments_by_copy(
    Js: np.ndarray,
    targets: np.ndarray,
    threshold: float,
    rng: np.random.Generator,
) -> int:
    r"""Revive dead compartments by copying living ones.

    Revive compartments whose relative volume (element of :paramref:`Js`) is not
    larger than :paramref:`threshold`. Each revived compartment copies the
    :paramref:`targets` values of a randomly chosen donor compartment whose relative
    volume is larger than twice :paramref:`threshold`. The donor's volume is then
    split equally between the donor and the revived compartment.

    The sum of the donor and revived volumes equals the donor's former volume.
    Therefore, the :paramref:`Js` sum and the :paramref:`Js`-weighted sums of
    :paramref:`targets` are conserved up to the previous, sub-threshold volume of the
    dead compartment, which is overwritten.

    Revived compartments may themselves serve as donors for later revivals. If no
    eligible donor is left, the remaining dead compartments are left untouched, and
    the returned count is smaller than the number of dead compartments.

    Args:
        Js:
            Mutable. The 1D array with the size of :math:`N_\mathrm{M}`, containing
            the relative volumes of compartments :math:`J_m`. The entries of revived
            compartments and of their donors are overwritten in place.
        targets:
            Mutable. 2D array with the size of :math:`N_* \times N_\mathrm{M}`,
            containing the values to be revived. The second dimension has to be the
            same as that of :paramref:`Js`. Note that this is not checked.
        threshold:
            Constant. The threshold at or below which the corresponding compartment is
            considered dead. For each element of :paramref:`Js` not larger than this
            parameter, the corresponding :paramref:`targets` values are copied from a
            living compartment as described above.
        rng:
            Mutable. Random number generator for reviving.

    Returns:
        : Number of dead compartments that have been revived.
    """
    revive_count = 0
    num_part = targets.shape[1]

    # Split the compartments into dead ones and candidate donors. Compartments with
    # threshold < J <= 2 * threshold are neither: halving them would make both halves
    # dead.
    dead_indexes = np.full(num_part, -1, dtype=np.int32)
    dead_count = 0
    donor_indexes = np.full(num_part, -1, dtype=np.int32)
    donor_count = 0
    for itr_compartment in range(num_part):
        if Js[itr_compartment] > 2.0 * threshold:
            donor_indexes[donor_count] = itr_compartment
            donor_count += 1
        elif Js[itr_compartment] <= threshold:
            dead_indexes[dead_count] = itr_compartment
            dead_count += 1

    for itr_dead in dead_indexes[:dead_count]:
        revived = False
        while donor_count > 0:
            pos_in_donors = rng.integers(0, donor_count)
            ref_index = donor_indexes[pos_in_donors]
            if Js[ref_index] > 2.0 * threshold:
                # copy the donor and split its volume equally
                targets[:, itr_dead] = targets[:, ref_index]
                new_J = 0.5 * Js[ref_index]
                Js[itr_dead] = new_J
                Js[ref_index] = new_J
                # the revived compartment becomes a candidate donor itself
                donor_indexes[donor_count] = itr_dead
                donor_count += 1
                revive_count += 1
                revived = True
                break
            # donor shrank too much through earlier splits: swap-remove it from the pool
            donor_count -= 1
            donor_indexes[pos_in_donors] = donor_indexes[donor_count]
            donor_indexes[donor_count] = -1
        if not revived:
            # No eligible donor is left. Previously this crashed in rng.integers(0, 0);
            # now the remaining dead compartments are left untouched.
            break
    return revive_count


@nb.njit()
def multicomponent_self_consistent_metastep(
    interaction: InteractionBaseCompiled,
    entropy: EntropyBaseCompiled,
    ensemble: EnsembleBaseCompiled,
    constraints: tuple[ConstraintBaseCompiled],
    *,
    omegas: np.ndarray,
    Js: np.ndarray,
    phis_comp: np.ndarray,
    phis_feat: np.ndarray,
    steps_inner: int,
    acceptance_Js: float,
    Js_step_upper_bound: float,
    acceptance_omega: float,
    kill_threshold: float,
    revive_tries: int,
    revive_scaler: float,
    rng: np.random.Generator,
) -> tuple[float, float, float, float, int, bool]:
    r"""Run self-consistent iterations for coexisting states of multicomponent systems.

    This is the core algorithm of finding coexisting states. It performs
    :paramref:`steps_inner` self-consistent iteration steps. Each step updates the
    conjugate fields :paramref:`omegas`, the relative compartment volumes
    :paramref:`Js` and the internal states of :paramref:`constraints` in place.

    Args:
        interaction:
            Constant. The compiled interaction instance. See
            :class:`~flory.interaction.base.InteractionBaseCompiled` for more
            information.
        entropy:
            Constant. The compiled entropy instance. See
            :class:`~flory.entropy.base.EntropyBaseCompiled` for more information.
        ensemble:
            Constant. The compiled ensemble instance. See
            :class:`~flory.ensemble.base.EnsembleBaseCompiled` for more information.
        constraints:
            Constant. The tuple of compiled constraint instances. Note that constraint
            instances are usually stateful, therefore the internal states of
            :paramref:`constraints` are actually mutable. See
            :class:`~flory.constraint.base.ConstraintBaseCompiled` for more
            information.
        omegas:
            Mutable. 2D array with size of :math:`N_\mathrm{S} \times N_\mathrm{M}`,
            containing the conjugate field :math:`w_r^{(m)}` of features. Note that
            this field is both used as input and output. Note also that this function
            does not initialize :paramref:`omegas`. It should be initialized
            externally, and usually a random initialization is a reasonable choice.
        Js:
            Mutable. 1D array with size of :math:`N_\mathrm{M}`, containing the
            relative volumes of the compartments :math:`J_m`. After each step, the
            mean of `Js` over the valid compartments is reset to unity, in order to
            keep the values invariant for different :math:`N_\mathrm{M}`. Entries of
            dead compartments are set to zero. Note that this field is both used as
            input and output. An all-one array is usually a nice initialization,
            unless resuming a previous run is intended.
        phis_comp:
            Output. 2D array with size of :math:`N_\mathrm{C} \times N_\mathrm{M}`,
            containing the volume fractions of components :math:`\phi_i^{(m)}`.
        phis_feat:
            Output. 2D array with size of :math:`N_\mathrm{S} \times N_\mathrm{M}`,
            containing the volume fractions of features :math:`\phi_r^{(m)}`.
        steps_inner:
            Constant. Number of steps in current routine. Within these steps,
            convergence is not checked and no output will be generated. Should be at
            least 1, since the returned errors are computed inside the iteration loop.
        acceptance_Js:
            Constant. The acceptance of :paramref:`Js` (the relative compartment size
            :math:`J_m`). This value determines the amount of changes accepted in each
            step for the :math:`J_m` field. Typically this value can take the order of
            :math:`10^{-3}`, or smaller when the system becomes larger or stiffer.
        Js_step_upper_bound:
            Constant. The maximum change of :paramref:`Js` (the relative compartment
            size :math:`J_m`) per step. This value is designed to reduce the risk that
            the volume of a compartment changes too fast before it develops meaningful
            composition. If the intended change is larger than this value, all the
            changes will be scaled down to guarantee that the maximum changes do not
            exceed this value. Typically this value can take the order of
            :math:`10^{-3}`, or smaller when the system becomes larger or stiffer.
        acceptance_omega:
            Constant. The acceptance of :paramref:`omegas` (the conjugate fields
            :math:`w_r^{(m)}`). This value determines the amount of changes accepted
            in each step for the :math:`w_r^{(m)}` field.

            If the iteration of :math:`J_m` is scaled down due to parameter
            :paramref:`Js_step_upper_bound`, the iteration of :math:`w_r^{(m)}` fields
            will be scaled down simultaneously. This value also scales the evolution
            of the internal states (Lagrange multipliers) of the
            :paramref:`constraints`. See the documentation of the actual constraint
            class for additional acceptances for :paramref:`constraints`.

            Typically this value can take the order of :math:`10^{-2}`, or smaller
            when the system becomes larger or stiffer.
        kill_threshold:
            Constant. The threshold of the :math:`J_m` for a compartment to be
            considered dead and killed afterwards. Should not be less than 0. In each
            iteration step, the :math:`J_m` array will be checked. For each element
            not larger than this parameter, the corresponding compartment will be
            killed and 0 will be assigned to the internal mask. The dead compartment
            may be revived, depending on whether reviving is allowed and whether the
            revive tries have been exhausted.
        revive_tries:
            Constant. Maximum number of revives in the current function call. 0 or a
            negative value indicates no reviving. Reviving is turned off once the
            number of revives in the current function call reaches this value. Since
            several compartments may be revived in one step, the final count may
            exceed it. Note that this function does not decrease this value, but
            returns the number of revives that have happened after completion.
        revive_scaler:
            Constant. The scaling factor for the conjugate fields :math:`w_r^{(m)}`
            when a dead compartment is revived. This value determines the range of the
            random conjugate field generated by the algorithm. Typically 1.0 or some
            value slightly larger will be a reasonable choice. See
            :func:`revive_compartments_by_random` for more information.
        rng:
            Mutable. Random number generator for reviving.

    Returns:
        [0]: Max absolute incompressibility.
        [1]: Max absolute conjugate field error.
        [2]: Max absolute relative volumes error.
        [3]: Max absolute constraints error.
        [4]: Number of revives.
        [5]: Whether no phase is killed in the last step.

        The errors [0] to [3] refer to the last iteration step.
    """
    num_feat, num_part = omegas.shape
    n_valid_phase = 0
    revive_count = 0
    for _ in range(steps_inner):
        # check if we are still allowed to revive compartments
        if revive_count < revive_tries:
            n_valid_phase = count_valid_compartments(Js, kill_threshold)
            if n_valid_phase != num_part:
                # revive dead compartments
                revive_count += revive_compartments_by_random(
                    Js, omegas, kill_threshold, rng, revive_scaler
                )

        # generate masks for the compartments; dead compartments get zero volume
        masks = make_valid_compartment_masks(Js, kill_threshold)
        n_valid_phase = int(masks.sum())
        Js *= masks

        # volume fractions, single-molecule partition functions Q and incompressibility
        Qs = entropy.partition(phis_comp, omegas, Js)  # modifies phis_comp directly
        incomp = ensemble.normalize(phis_comp, Qs, masks)  # modifies phis_comp directly
        entropy.comp_to_feat(phis_feat, phis_comp)  # modifies phis_feat directly
        max_abs_incomp = np.abs(incomp).max()

        # prepare constraints: constraints are stateful
        if constraints:
            for cons in literal_unroll(constraints):
                cons.prepare(phis_feat, Js, masks)

        # temp for omega, namely chi.phi
        omega_temp = interaction.potential(phis_feat)

        # xi, the Lagrange multiplier
        xi = interaction.incomp_coef(phis_feat) * incomp
        for itr_feat in range(num_feat):
            xi += omegas[itr_feat] - omega_temp[itr_feat]
        for cons in literal_unroll(constraints):
            for itr_feat in range(num_feat):
                # potentials from constraints are already calculated in `prepare`
                xi -= cons.potential[itr_feat]
        xi *= masks
        xi /= num_feat

        # local energy, i.e. energy of phases excluding the partition function part
        local_energy = (
            interaction.volume_derivative(omega_temp, phis_feat)
            + entropy.volume_derivative(phis_comp)
            + xi * incomp
        )
        for cons in literal_unroll(constraints):
            # volume derivatives from constraints are already calculated in `prepare`
            local_energy += cons.volume_derivative
            omega_temp += cons.potential
        for itr_feat in range(num_feat):
            omega_temp[itr_feat] += xi
            local_energy -= omega_temp[itr_feat] * phis_feat[itr_feat]

        # calculate the difference of Js
        local_energy_mean = (local_energy * Js).sum() / n_valid_phase
        Js_diff = (local_energy_mean - local_energy) * masks
        max_abs_Js_diff = np.abs(Js_diff).max()

        # scale down the whole step if the change of Js would exceed the upper bound
        Js_max_change = max(max_abs_Js_diff * acceptance_Js, Js_step_upper_bound)
        additional_factor = Js_step_upper_bound / Js_max_change

        # update Js and restore a mean of unity over the valid compartments
        Js += additional_factor * acceptance_Js * Js_diff
        Js *= masks
        Js += 1 - (Js.sum() / n_valid_phase)
        Js *= masks

        # calculate difference of omega and update omega directly; the error is the
        # maximum ABSOLUTE difference, so negative deviations are not ignored
        max_abs_omega_diff = 0.0
        for itr_feat in range(num_feat):
            omega_temp[itr_feat] -= omegas[itr_feat]
            omega_temp[itr_feat] *= masks
            max_abs_omega_diff = max(
                max_abs_omega_diff, np.abs(omega_temp[itr_feat]).max()
            )
            omegas[itr_feat] += (
                additional_factor * acceptance_omega * omega_temp[itr_feat]
            )
            omegas[itr_feat] *= masks

        # evolve the internal states of the constraints with the same scaled acceptance
        max_constraint_residue = 0.0
        for cons in literal_unroll(constraints):
            max_constraint_residue = max(
                max_constraint_residue,
                cons.evolve(additional_factor * acceptance_omega, masks),
            )

    # count the valid phases in the last step
    n_valid_phase_last = count_valid_compartments(Js, kill_threshold)
    return (
        max_abs_incomp,
        max_abs_omega_diff,
        max_abs_Js_diff,
        max_constraint_residue,
        revive_count,
        n_valid_phase == n_valid_phase_last,
    )