Source code for xga.models

#  This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
#  Last modified by David J Turner (turne540@msu.edu) 25/08/2025, 15:58. Copyright (c) The Contributors

import inspect
from types import FunctionType

# Doing star imports just because its more convenient, and there won't ever be enough code in these that
#  it becomes a big inefficiency
from .density import *
from .entropy import *
from .mass import *
from .misc import *
from .pressure import *
from .sb import *
from .temperature import *

# This dictionary is meant to provide pretty versions of model/function names to go in plots
# This method of merging dictionaries only works in Python 3.5+, but that should be fine
MODEL_PUBLICATION_NAMES = {**DENS_MODELS_PUB_NAMES, **MISC_MODELS_PUB_NAMES, **SB_MODELS_PUB_NAMES,
                           **TEMP_MODELS_PUB_NAMES, **ENTROPY_MODELS_PUB_NAMES, **MASS_MODELS_PUB_NAMES,
                           **PRESSURE_MODELS_PUB_NAMES}
MODEL_PUBLICATION_PAR_NAMES = {**DENS_MODELS_PAR_NAMES, **MISC_MODELS_PAR_NAMES, **SB_MODELS_PAR_NAMES,
                               **TEMP_MODELS_PAR_NAMES, **ENTROPY_MODELS_PAR_NAMES, **MASS_MODELS_PAR_NAMES,
                               **PRESSURE_MODELS_PAR_NAMES}
# These dictionaries tell the profile fitting function what models, start pars, and priors are allowed
PROF_TYPE_MODELS = {"brightness": SB_MODELS, "gas_density": DENS_MODELS, "gas_temperature": TEMP_MODELS,
                    '1d_proj_temperature': TEMP_MODELS, 'specific_entropy': ENTROPY_MODELS,
                    'hydrostatic_mass': MASS_MODELS, 'thermal_pressure': PRESSURE_MODELS}


[docs] def convert_to_odr_compatible(model_func: FunctionType, new_par_name: str = 'β', new_data_name: str = 'x_values') \ -> FunctionType: """ This is a bit of a weird one; its meant to convert model functions from the standard XGA setup (i.e. pass x values, then parameters as individual variables), into the form expected by Scipy's ODR. I'd recommend running a check to compare results from the original and converted functions where-ever this function is called - I don't completely trust it. :param FunctionType model_func: The original model function to be converted. :param str new_par_name: The name we want to use for the new list/array of fit parameters. :param str new_data_name: The new name we want to use for the x_data. :return: A successfully converted model function (hopefully) which can be used with ODR. :rtype: FunctionType """ # This is not at all perfect, but its a bodge that will do for now. If type hints are included in # the signature (as they should be in all XGA models), then np.ndarray will be numpy.ndarray in the # signature I extract. This dictionary will be used to swap that out, along with any similar problems I encounter common_conversions = {'numpy': 'np'} # This reads out the function signature - which should be structured as x_values, par1, par2, par3 etc. mod_sig = inspect.signature(model_func) # Convert that signature into a string str_mod_sig = str(mod_sig) # Go through the conversion dictionary and 'correct' the signature for conv in common_conversions: str_mod_sig = str_mod_sig.replace(conv, common_conversions[conv]) # For ODR I've decided that β is the name of the new fit parameter array, and x_values the name of the # x data. This will replace the current signature of the function. new_mod_sig = '({np}, {nd})'.format(np=new_par_name, nd=new_data_name) # I find the current names of the parameters in the signature, excluding the x value name in the original function # and reading that into a separate variable mod_sig_pars = list(mod_sig.parameters.keys()) par_names = mod_sig_pars[1:] # Store the name of the x data here data_name = mod_sig_pars[0] # This gets the source code of the function as a string mod_code = inspect.getsource(model_func) # I swap in the new signature new_mod_code = mod_code.replace(str_mod_sig, new_mod_sig) # And now I know the exact form of the whole def line I can define that as a variable and then temporarily # remove it from the source code known_def = 'def {mn}'.format(mn=model_func.__name__) + new_mod_sig + ':' new_mod_code = new_mod_code.replace(known_def, '') # Then I swing through all the original parameter names and replace them with accessing elements of our # new beta parameter list/array. for par_ind, par_name in enumerate(par_names): new_mod_code = new_mod_code.replace(par_name, '{np}[{i}]'.format(np=new_par_name, i=par_ind)) # Then I do the same thing for the new x data variable name new_mod_code = new_mod_code.replace(data_name, new_data_name) # Adds the def SIGNATURE line back in new_mod_code = known_def + new_mod_code # This compiles the code and creates a new function new_model_func_code = compile(new_mod_code, '<string>', 'exec') new_model_func = FunctionType(new_model_func_code.co_consts[0], globals(), model_func.__name__) return new_model_func
[docs] def derivative(func: FunctionType, x0: float, dx: float = 1.0, n: int = 1, args: tuple= (), order: int = 3): """ Find the nth derivative of a function at a point. Given a function, use a central difference formula with spacing `dx` to compute the nth derivative at `x0`. This is intended as a drop-in replacement for Scipy's misc.derivative function, which was deprecated in Scipy v1.10.0 and removed after Scipy v1.14.1. It has been directly copied/reconstructed from Scipy code. :param FunctionType func: Input function :param x0: The point at which the nth derivative is found. :param dx: Spacing. :param n: Order of the derivative. Default is 1. :param args: Arguments :param order: Number of points to use, must be odd. """ def _central_diff_weights(Np, ndiv=1): """ Return weights for an Np-point central derivative. Assumes equally-spaced function points. If weights are in the vector w, then derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+h0*dx) """ if Np < ndiv + 1: raise ValueError( "Number of points must be at least the derivative order + 1." ) if Np % 2 == 0: raise ValueError("The number of points must be odd.") ho = Np >> 1 x = np.arange(-ho, ho + 1.0) x = x[:, np.newaxis] X = x ** 0.0 for k in range(1, Np): X = np.hstack([X, x ** k]) w = np.prod(np.arange(1, ndiv + 1), axis=0) * np.linalg.inv(X)[ndiv] return w if order < n + 1: raise ValueError( "'order' (the number of points used to compute the derivative), " "must be at least the derivative order 'n' + 1." ) if order % 2 == 0: raise ValueError( "'order' (the number of points used to compute the derivative) " "must be odd." ) # pre-computed for n=1 and 2 and low-order for speed. if n == 1: if order == 3: weights = np.array([-1, 0, 1]) / 2.0 elif order == 5: weights = np.array([1, -8, 0, 8, -1]) / 12.0 elif order == 7: weights = np.array([-1, 9, -45, 0, 45, -9, 1]) / 60.0 elif order == 9: weights = np.array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0 else: weights = _central_diff_weights(order, 1) elif n == 2: if order == 3: weights = np.array([1, -2.0, 1]) elif order == 5: weights = np.array([-1, 16, -30, 16, -1]) / 12.0 elif order == 7: weights = np.array([2, -27, 270, -490, 270, -27, 2]) / 180.0 elif order == 9: weights = ( np.array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9]) / 5040.0 ) else: weights = _central_diff_weights(order, 2) else: weights = _central_diff_weights(order, n) val = 0.0 ho = order >> 1 for k in range(order): val += weights[k] * func(x0 + (k - ho) * dx, *args) return val / np.prod((dx,) * n, axis=0)