Source code for ploonetide.numerical.simulator

"""This module defines the Variable and Simulation classes."""
import numpy as np
import warnings

from ploonetide.utils.constants import *

from scipy.integrate._ivp.base import OdeSolver
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm

# === Monkey-patch OdeSolver to include tqdm progress bar (currently disabled: all code below is commented out) ===

# Save original methods
# _original_init = OdeSolver.__init__
# _original_step = OdeSolver.step


# # Define patched methods
# def _patched_init(self, fun, t0, y0, t_bound, vectorized=True, support_complex=False, **kwargs):
#     progress_total = kwargs.pop('_progress_total', None)
#     total_steps = progress_total if progress_total is not None else int(np.ceil(t_bound - t0))

#     bar_format = '{desc}{percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} steps | {elapsed}<{remaining}'
#     self._pbar = tqdm(
#         desc='Computing orbital evolution: ',
#         bar_format=bar_format,
#         total=total_steps,
#         initial=0
#     )
#     self._last_t = t0

#     _original_init(self, fun, t0, y0, t_bound, vectorized, support_complex, **kwargs)


# def _patched_step(self):
#     _original_step(self)

#     delta_t = self.t - self._last_t
#     self._pbar.update(delta_t)  # Advance the bar by the integration time covered in this step
#     self._last_t = self.t

#     if self.t >= self.t_bound:
#         self._pbar.close()


# # Apply patch
# OdeSolver.__init__ = _patched_init
# OdeSolver.step = _patched_step


# === Variable and Simulation classes ===

class Variable:
    """A named quantity that takes part in the integration.

    Args:
        name (str): Name of the variable
        v_ini (float): Initial value
    """

    def __init__(self, name, v_ini):
        # Keep both constructor arguments as plain public attributes.
        self.v_ini = v_ini
        self.name = name

    def return_vec(self) -> np.ndarray:
        """Return the initial value wrapped in a NumPy array."""
        initial_value = self.v_ini
        return np.array([initial_value])
class Simulation:
    """Build and run a simulation.

    Args:
        variables (list): List of Variable instances. Each one must provide
            a ``return_vec()`` method; the returned arrays are concatenated
            into the initial state vector ``quant_vec``.
    """

    # Absolute tolerances for the leading state components, in order.
    # Labels are taken from the comments that accompanied these values.
    _LEADING_ATOL = (
        1e-12,  # Omega_p
        1e-13,  # n_p
        1e-8,   # log(n_m)
    )
    # Absolute tolerance used for every remaining component.
    _DEFAULT_ATOL = 1e-8

    def __init__(self, variables):
        self.variables = variables
        self.N_variables = len(self.variables)
        self.Ndim = self.N_variables
        self.quant_vec = np.concatenate([var.return_vec() for var in self.variables])
        # Same default as set_integration_method(), so run() works even if
        # the caller never selects a method explicitly.
        self.integration_method = 'RK45'

    def set_diff_eq(self, calc_diff_eqs, params, ini_conds, events=None):
        """Set the differential equation function and its extra inputs.

        Args:
            calc_diff_eqs: Callable returning dy/dt. ``run`` hands it to
                ``solve_ivp`` with ``args=(params, ini_conds)``.
            params: First extra argument forwarded to ``calc_diff_eqs``.
            ini_conds: Second extra argument forwarded to ``calc_diff_eqs``.
            events: Event function(s) forwarded to ``solve_ivp``
                (default None, i.e. no events).
        """
        self.calc_diff_eqs = calc_diff_eqs
        self.diff_eq_kwargs = params
        self.diff_eq_ini_conds = ini_conds
        self.events = events

    def set_integration_method(self, method='RK45'):
        """Set the integration method.

        Args:
            method (str): One of ['RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA']
        """
        self.integration_method = method

    @classmethod
    def _build_atol(cls, size):
        """Return the absolute-tolerance vector for a state of length ``size``."""
        atol = np.full(size, cls._DEFAULT_ATOL, dtype=float)
        # Only override the components that actually exist, so short state
        # vectors do not trigger an IndexError.
        n_lead = min(size, len(cls._LEADING_ATOL))
        atol[:n_lead] = cls._LEADING_ATOL[:n_lead]
        return atol

    @staticmethod
    def _count_steps(t0, t, dt):
        """Return ``len(np.arange(t0, t, dt))`` without allocating the array."""
        return max(int(np.ceil((t - t0) / dt)), 0)

    def run(self, t, dt, t0=0.0, jacobian=None):
        """Run the simulation and store the solver result in ``self.history``.

        ``set_diff_eq`` must have been called beforehand.

        Args:
            t (float): Final time
            dt (float): Timestep. It is stored in ``self.time_step`` and used
                to size the progress total; the solver itself chooses its own
                steps (``t_eval`` is None and dense output is requested).
            t0 (float): Initial time (default 0.0)
            jacobian: Optional Jacobian, forwarded to the solver only when the
                integration method is 'Radau', 'BDF' or 'LSODA'.
        """
        self.time_step = dt
        self.total_time = t

        t_span = np.array([t0, t])
        progress_total = self._count_steps(t_span[0], t_span[1], dt)

        self.bar_fmt = (
            '{desc}{percentage:4.0f}%|{bar}|'
            ' {n_fmt}/{total_fmt} steps | {elapsed}<{remaining}'
        )

        # All warnings raised while preparing and running the solver are
        # deliberately silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            y0 = np.asarray(self.quant_vec, dtype=float)
            atol = self._build_atol(y0.size)

            # Cap the step at the smaller of 5 MYEAR and 0.5% of the span.
            total_time = t_span[1] - t_span[0]
            max_step = min(5.0 * MYEAR, 0.005 * total_time)

            uses_jacobian = self.integration_method in ("Radau", "BDF", "LSODA")

            sols = solve_ivp(
                self.calc_diff_eqs,
                t_span,
                y0,
                method=self.integration_method,
                rtol=1e-6,
                atol=atol,
                args=(self.diff_eq_kwargs, self.diff_eq_ini_conds),
                t_eval=None,
                dense_output=True,
                # NOTE(review): only the (currently disabled) tqdm patch of
                # OdeSolver reads this option; confirm whether it is still
                # needed.
                _progress_total=progress_total,
                jac=jacobian if uses_jacobian else None,
                max_step=max_step,
                events=self.events,
            )

        self.history = sols