# Source code for xga.products.base

```
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
# Last modified by David J Turner (turne540@msu.edu) 20/02/2023, 14:04. Copyright (c) The Contributors
import inspect
import os
import pickle
from copy import deepcopy
from random import randint
from typing import Tuple, List, Dict, Union
from warnings import warn
import corner
import emcee as em
import numpy as np
from astropy.units import Quantity, UnitConversionError, Unit, deg
from getdist import plots, MCSamples
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import FuncFormatter
from scipy.optimize import curve_fit, minimize
from tabulate import tabulate
from ..exceptions import SASGenerationError, UnknownCommandlineError, XGAFitError, XGAInvalidModelError, \
ModelNotAssociatedError
from ..models import PROF_TYPE_MODELS, BaseModel1D, MODEL_PUBLICATION_NAMES
from ..models.fitting import log_likelihood, log_prob
from ..utils import SASERROR_LIST, SASWARNING_LIST, OUTPUT
[docs]class BaseProduct:
"""
The super class for all SAS generated products in XGA. Stores relevant file path information, parses the std_err
output of the generation process, and stores the instrument and ObsID that the product was generated for.
:param str path: The path to where the product file SHOULD be located.
:param str obs_id: The ObsID related to the product being declared.
:param str instrument: The instrument related to the product being declared.
:param str stdout_str: The stdout from calling the terminal command.
:param str stderr_str: The stderr from calling the terminal command.
:param str gen_cmd: The command used to generate the product.
"""
def __init__(self, path: str, obs_id: str, instrument: str, stdout_str: str, stderr_str: str, gen_cmd: str):
"""
The initialisation method for the BaseProduct class.
"""
# This attribute stores strings that indicate why a product object has been deemed as unusable
self._why_unusable = []
# So this flag indicates whether we think this data product can be used for analysis
self._usable = True
if os.path.exists(path):
self._path = path
else:
self._path = None
self._usable = False
self._why_unusable.append("ProductPathDoesNotExist")
# Saving this in attributes for future reference
self.unprocessed_stdout = stdout_str
self.unprocessed_stderr = stderr_str
self._sas_error, self._sas_warn, self._other_error = self.parse_stderr()
self._obs_id = obs_id
self._inst = instrument
self._og_cmd = gen_cmd
self._energy_bounds = (None, None)
self._prod_type = None
self._src_name = None
# Users are not allowed to change this, so just a getter.
@property
def usable(self) -> bool:
"""
Returns whether this product instance should be considered usable for an analysis.
:return: A boolean flag describing whether this product should be used.
:rtype: bool
"""
return self._usable
@property
def path(self) -> str:
"""
Property getter for the attribute containing the path to the product.
:return: The product path.
:rtype: str
"""
return self._path
@path.setter
def path(self, prod_path: str):
"""
Property setter for the attribute containing the path to the product.
:param str prod_path: The product path.
"""
if not os.path.exists(prod_path):
prod_path = None
# We won't be able to make use of this product if it isn't where we think it is
self._usable = False
self._why_unusable.append("ProductPathDoesNotExist")
self._path = prod_path
[docs] def parse_stderr(self) -> Tuple[List[str], List[Dict], List]:
"""
This method parses the stderr associated with the generation of a product into errors confirmed to have
come from SAS, and other unidentifiable errors. The SAS errors are returned with the actual error
name, the error message, and the SAS routine that caused the error.
:return: A list of dictionaries containing parsed, confirmed SAS errors, another containing SAS warnings,
and another list of unidentifiable errors that occured in the stderr.
:rtype: Tuple[List[Dict], List[Dict], List]
"""
def find_sas(split_stderr: list, err_type: str) -> Tuple[List[dict], List[str]]:
"""
Function to search for and parse SAS errors and warnings.
:param list split_stderr: The stderr string split on line endings.
:param str err_type: Should this look for errors or warnings?
:return: Returns the dictionary of parsed errors/warnings, as well as all lines
with SAS errors/warnings in.
:rtype: Tuple[List[dict], List[str]]
"""
parsed_sas = []
# This is a crude way of looking for SAS error/warning strings ONLY
sas_lines = [line for line in split_stderr if "** " in line and ": {}".format(err_type) in line]
for err in sas_lines:
try:
# This tries to split out the SAS task that produced the error
originator = err.split("** ")[-1].split(":")[0]
# And this should split out the actual error name
err_ident = err.split(": {} (".format(err_type))[-1].split(")")[0]
# Actual error message
err_body = err.split("({})".format(err_ident))[-1].strip("\n").strip(", ").strip(" ")
if err_type == "error":
# Checking to see if the error identity is in the list of SAS errors
sas_err_match = [sas_err for sas_err in SASERROR_LIST if err_ident.lower()
in sas_err.lower()]
elif err_type == "warning":
# Checking to see if the error identity is in the list of SAS warnings
sas_err_match = [sas_err for sas_err in SASWARNING_LIST if err_ident.lower()
in sas_err.lower()]
if len(sas_err_match) != 1:
originator = ""
err_ident = ""
err_body = ""
except IndexError:
originator = ""
err_ident = ""
err_body = ""
parsed_sas.append({"originator": originator, "name": err_ident, "message": err_body})
return parsed_sas, sas_lines
# Defined as empty as they are returned by this method
sas_errs_msgs = []
parsed_sas_warns = []
other_err_lines = []
# err_str being "" is ideal, hopefully means that nothing has gone wrong
if self.unprocessed_stderr != "":
# Errors will be added to the error summary, then raised later
# That way if people try except the error away the object will have been constructed properly
err_lines = [e for e in self.unprocessed_stderr.split('\n') if e != '']
# Fingers crossed each line is a separate error
parsed_sas_errs, sas_err_lines = find_sas(err_lines, "error")
parsed_sas_warns, sas_warn_lines = find_sas(err_lines, "warning")
sas_errs_msgs = ["{e} raised by {t} - {b}".format(e=e["name"], t=e["originator"], b=e["message"])
for e in parsed_sas_errs]
# These are impossible to predict the form of, so they won't be parsed
other_err_lines = [line for line in err_lines if line not in sas_err_lines
and line not in sas_warn_lines and line != "" and "warn" not in line]
# Adding some advice
for e_ind, e in enumerate(other_err_lines):
if 'seg' in e.lower() and 'fault' in e.lower():
other_err_lines[e_ind] += ' - Try examining an image of the cluster with regions subtracted, ' \
'and have a look at where your coordinate lies.'
if len(sas_errs_msgs) > 0:
self._usable = False
self._why_unusable.append("SASErrorPresent")
if len(other_err_lines) > 0:
self._usable = False
self._why_unusable.append("OtherErrorPresent")
return sas_errs_msgs, parsed_sas_warns, other_err_lines
@property
def sas_errors(self) -> List[str]:
"""
Property getter for the confirmed SAS errors associated with a product.
:return: The list of confirmed SAS errors.
:rtype: List[Dict]
"""
return self._sas_error
@property
def sas_warnings(self) -> List[Dict]:
"""
Property getter for the confirmed SAS warnings associated with a product.
:return: The list of confirmed SAS warnings.
:rtype: List[Dict]
"""
return self._sas_warn
[docs] def raise_errors(self):
"""
Method to raise the errors parsed from std_err string.
"""
for error in self._sas_error:
raise SASGenerationError(error)
# This is for any unresolved errors.
for error in self._other_error:
if "warning" not in error:
raise UnknownCommandlineError("{}".format(error))
@property
def obs_id(self) -> str:
"""
Property getter for the ObsID of this image. Admittedly this information is implicit in the location
this object is stored in a source object, but I think it worth storing directly as a property as well.
:return: The XMM ObsID of this image.
:rtype: str
"""
return self._obs_id
@property
def instrument(self) -> str:
"""
Property getter for the instrument used to take this image. Admittedly this information is implicit
in the location this object is stored in a source object, but I think it worth storing
directly as a property as well.
:return: The XMM instrument used to take this image.
:rtype: str
"""
return self._inst
@property
def type(self) -> str:
"""
Property getter for the string identifier for the type of product this object is, mostly useful for
internal methods of source objects.
:return: The string identifier for this type of object.
:rtype: str
"""
return self._prod_type
@property
def errors(self) -> List[str]:
"""
Property getter for non-SAS errors detected during the generation of a product.
:return: A list of errors that aren't related to SAS.
:rtype: List[str]
"""
return self._other_error
# This is a fundamental property of the generated product, so I won't allow it be changed.
@property
def energy_bounds(self) -> Tuple[Quantity, Quantity]:
"""
Getter method for the energy_bounds property, which returns the rest frame energy band that this
product was generated in.
:return: Tuple containing the lower and upper energy limits as Astropy quantities.
:rtype: Tuple[Quantity, Quantity]
"""
return self._energy_bounds
@property
def src_name(self) -> str:
"""
Method to return the name of the object a product is associated with. The product becomes
aware of this once it is added to a source object.
:return: The name of the source object this product is associated with.
:rtype: str
"""
return self._src_name
# This needs a setter, as this property only becomes not-None when the product is added to a source object.
@src_name.setter
def src_name(self, name: str):
"""
Property setter for the src_name attribute of a product, should only really be called by a source object,
not by a user.
:param str name: The name of the source object associated with this product.
"""
self._src_name = name
@property
def not_usable_reasons(self) -> List:
"""
Whenever the usable flag of a product is set to False (indicating you shouldn't use the product), a string
indicating the reason is added to a list, which this property returns.
:return: A list of reasons why this product is unusable.
:rtype: List
"""
return self._why_unusable
@property
def sas_command(self) -> str:
"""
A property that returns the original SAS command used to generate this object.
:return: String containing the command.
:rtype: str
"""
return self._og_cmd
[docs]class BaseAggregateProduct:
"""
A base class for any XGA products that are an aggregate of an XGA SAS product, for instance this is sub-classed
to make the AnnularSpectra class. Users really shouldn't be instantiating these for themselves.
:param list file_paths: The file paths of the main files for a given aggregate product.
:param str prod_type: The product type of the individual elements.
:param str obs_id: The ObsID related to the product.
:param str instrument: The instrument related to the product.
"""
def __init__(self, file_paths: list, prod_type: str, obs_id: str, instrument: str):
"""
The init method for the BaseAggregateProduct class
"""
self._all_paths = file_paths
self._all_usable = True
self._obs_id = obs_id
self._inst = instrument
self._prod_type = prod_type
self._src_name = None
# This was originally going to create the individual products here, but realised it was
# easier to do in subclasses
self._component_products = {}
# Setting up energy limits, if they're ever required
self._energy_bounds = (None, None)
@property
def src_name(self) -> str:
"""
Method to return the name of the object a product is associated with. The product becomes
aware of this once it is added to a source object. This is overridden in the AnnularSpectra class.
:return: The name of the source object this product is associated with.
:rtype: str
"""
return self._src_name
# This needs a setter, as this property only becomes not-None when the product is added to a source object.
@src_name.setter
def src_name(self, name: str):
"""
Property setter for the src_name attribute of a product, should only really be called by a source object,
not by a user. This is overridden in the AnnularSpectra class.
:param str name: The name of the source object associated with this product.
"""
self._src_name = name
for p in self._component_products.values():
p.src_name = name
@property
def obs_id(self) -> str:
"""
Property getter for the ObsID of this image. Admittedly this information is implicit in the location
this object is stored in a source object, but I think it worth storing directly as a property as well.
:return: The XMM ObsID of this image.
:rtype: str
"""
return self._obs_id
@property
def instrument(self) -> str:
"""
Property getter for the instrument used to take this image. Admittedly this information is implicit
in the location this object is stored in a source object, but I think it worth storing
directly as a property as well.
:return: The XMM instrument used to take this image.
:rtype: str
"""
return self._inst
@property
def type(self) -> str:
"""
Property getter for the string identifier for the type of product this object is, mostly useful for
internal methods of source objects.
:return: The string identifier for this type of object.
:rtype: str
"""
return self._prod_type
@property
def usable(self) -> bool:
"""
Property getter for the boolean variable that tells you whether all component products have been
found to be usable.
:return: Boolean variable, are all component products usable?
:rtype: bool
"""
return self._all_usable
# This is a fundamental property of the generated product, so I won't allow it be changed.
@property
def energy_bounds(self) -> Tuple[Quantity, Quantity]:
"""
Getter method for the energy_bounds property, which returns the rest frame energy band that this
product was generated in, if relevant.
:return: Tuple containing the lower and upper energy limits as Astropy quantities.
:rtype: Tuple[Quantity, Quantity]
"""
return self._energy_bounds
@property
def sas_errors(self) -> List:
"""
Equivelant to the BaseProduct sas_errors property, but reports any SAS errors stored in the component products.
:return: A list of SAS errors related to component products.
:rtype: List
"""
sas_err_list = []
for p in self._component_products:
prod = self._component_products[p]
sas_err_list += prod.sas_errors
return sas_err_list
@property
def errors(self) -> List:
"""
Equivelant to the BaseProduct errors property, but reports any non-SAS errors stored in the
component products.
:return: A list of non-SAS errors related to component products.
:rtype: List
"""
err_list = []
for p in self._component_products:
prod = self._component_products[p]
err_list += prod.errors
return err_list
@property
def unprocessed_stderr(self) -> List:
"""
Equivelant to the BaseProduct sas_errors unprocessed_stderr, but returns a list of all the unprocessed
standard error outputs.
:return: List of stderr outputs.
:rtype: List
"""
unprocessed_err_list = []
for p in self._component_products:
prod = self._component_products[p]
unprocessed_err_list.append(prod.unprocessed_stderr)
return unprocessed_err_list
def __len__(self) -> int:
"""
The length of an AggregateProduct is the number of component products that makes it up.
:return:
:rtype: int
"""
return len(self._component_products)
def __iter__(self):
"""
Called when initiating iterating through an AggregateProduct based object. Resets the counter _n.
"""
self._n = 0
return self
def __next__(self):
"""
Iterates the counter _n and returns the next entry in the the component_products dictionary.
"""
if self._n < self.__len__():
result = self.__getitem__(self._n)
self._n += 1
return result
else:
raise StopIteration
def __getitem__(self, ind):
return list(self._component_products.values())[ind]
[docs]class BaseProfile1D:
"""
The superclass for all 1D radial profile products, with built in fitting, viewing, and result retrieval
functionality. Classes derived from BaseProfile1D can be added together to create Aggregate Profiles.
:param Quantity radii: The radii at which the y values of this profile have been measured.
:param Quantity values: The y values of this profile.
:param Quantity centre: The central coordinate the profile was generated from.
:param str source_name: The name of the source this profile is associated with.
:param str obs_id: The observation which this profile was generated from.
:param str inst: The instrument which this profile was generated from.
:param Quantity radii_err: Uncertainties on the radii.
:param Quantity values_err: Uncertainties on the values.
:param int associated_set_id: The set ID of the AnnularSpectra that generated this - if applicable. If this
value is supplied a set_storage_key value must also be supplied.
:param str set_storage_key: Must be present if associated_set_id is, this is the storage key which the
associated AnnularSpectra generates to place itself in XGA's store structure.
:param Quantity deg_radii: A slightly unfortunate variable that is required only if radii is not in
units of degrees, or if no set_storage_key is passed. It should be a quantity containing the radii
values converted to degrees, and allows this object to construct a predictable storage key.
:param Quantity x_norm: An astropy quantity to use to normalise the x-axis values, this is only used when
plotting if the user tells the view method that they wish for the plot to use normalised x-axis data.
:param Quantity y_norm: An astropy quantity to use to normalise the y-axis values, this is only used when
plotting if the user tells the view method that they wish for the plot to use normalised y-axis data.
"""
def __init__(self, radii: Quantity, values: Quantity, centre: Quantity, source_name: str, obs_id: str,
inst: str, radii_err: Quantity = None, values_err: Quantity = None, associated_set_id: int = None,
set_storage_key: str = None, deg_radii: Quantity = None, x_norm: Quantity = Quantity(1, ''),
y_norm: Quantity = Quantity(1, '')):
"""
The init of the superclass 1D profile product. Unlikely to ever be declared by a user, but the base
of all other 1D profiles in XGA - contains many useful functions.
"""
if type(radii) != Quantity or type(values) != Quantity:
raise TypeError("Both the radii and values passed into this object definition must "
"be astropy quantities.")
elif radii_err is not None and type(radii_err) != Quantity:
raise TypeError("The radii_err variable must be an astropy Quantity, or None.")
elif radii_err is not None and radii_err.unit != radii.unit:
raise UnitConversionError("The radii_err unit must be the same as the radii unit.")
elif values_err is not None and type(values_err) != Quantity:
raise TypeError("The values_err variable must be an astropy Quantity, or None.")
elif values_err is not None and values_err.unit != values.unit:
raise UnitConversionError("The values_err unit must be the same as the values unit.")
# Check for one dimensionality
if radii.ndim != 1 or values.ndim != 1:
raise ValueError("The radii and values arrays must be one-dimensional. The shape of radii is {0} "
"and the shape of values is {1}".format(radii.shape, values.shape))
elif (radii_err is not None and radii_err.ndim != 1) or (values_err is not None and values_err.ndim != 1):
raise ValueError("The radii_err and values_err arrays must be one-dimensional. The shape of "
"radii_err is {0} and the shape of values_err is "
"{1}".format(radii_err.shape, values_err.shape))
# Making sure the arrays have the same number of entries
elif radii.shape != values.shape:
raise ValueError("The radii and values arrays must have the same shape. The shape of radii is {0} "
"and the shape of values is {1}".format(radii.shape, values.shape))
elif (radii_err is not None and radii_err.shape != radii.shape) or \
(values_err is not None and values_err.shape != values.shape):
raise ValueError("radii_err must be the same shape as radii, and values_err must be the same shape "
"as values. The shape of radii_err is {0} where radii is {1}, and the shape of "
"values_err is {2} where values is {3}".format(radii_err.shape, radii.shape,
values_err.shape, values.shape))
# I'm actually going to enforce that the central coordinates passed when declaring a profile must
# be RA and Dec, I need the storage keys to be predictable to make everything neater
if centre.unit != deg:
raise UnitConversionError("The central coordinate value passed into a profile on declaration must be"
" in RA and Dec coordinates.")
# I'm also going to require that the profiles have knowledge of radii in degree units, also so I can make
# predictable storage strings. I don't really like to do this as it feels bodgy, but oh well
if not radii.unit.is_equivalent('deg') and deg_radii is None and set_storage_key is None:
raise ValueError("If the radii variable is not in units that are convertible to degrees, please pass "
"radii in degrees to deg_radii, this profile needs knowledge of the radii in degrees"
" to construct a storage key.")
elif not radii.unit.is_equivalent('deg') and set_storage_key is None and len(deg_radii) != len(radii):
raise ValueError("deg_radii is a different length to radii, they should be equivelant quantities, simply"
" in different units.")
elif radii.unit.is_equivalent('deg') and set_storage_key is None:
deg_radii = radii.to('deg')
if deg_radii is not None:
deg_radii = deg_radii.to("deg")
self._deg_radii = deg_radii
# Finally I'm going to make a check for infinite or NaN values in the passed value,
# radius, and uncertainty quantities
if np.isnan(radii).any() or np.isinf(radii).any():
raise ValueError("The radii quantity has NaN or infinite values")
if np.isnan(values).any() or np.isinf(values).any():
raise ValueError("The values quantity has NaN or infinite values")
if radii_err is not None and (np.isnan(radii_err).any() or np.isinf(radii_err).any()):
raise ValueError("The radii_err quantity has NaN or infinite values")
if values_err is not None and (np.isnan(values_err).any() or np.isinf(values_err).any()):
raise ValueError("The values_err quantity has NaN or infinite values")
if deg_radii is not None and (np.isnan(deg_radii).any() or np.isinf(deg_radii).any()):
raise ValueError("The deg_radii quantity has NaN or infinite values")
# And now we will check that no uncertainty values are negative, as that does not make sense yet can
# happen sometimes when XSPEC cannot constrain a parameter (for instance).
if radii_err is not None and (radii_err < 0).any():
raise ValueError("The radii_err quantity has negative values, which does not make sense "
"for an uncertainty.")
if values_err is not None and (values_err < 0).any():
raise ValueError("The radii_err quantity has negative values, which does not make sense "
"for an uncertainty.")
# Storing the key values in attributes
self._radii = radii
self._values = values
self._radii_err = radii_err
self._values_err = values_err
self._centre = centre
# This generates an array containing (hopefully) the original annular boundaries of the profile
if self._radii_err is not None:
upper_bounds = self._radii + self._radii_err
bounds = np.insert(upper_bounds, 0, self._radii[0]-self._radii_err[0])
if self._radii[0].value == 0:
bounds[0] = self._radii[0]
bounds[1] = bounds[1] + self._radii_err[0]
self._rad_ann_bounds = bounds
else:
self._rad_ann_bounds = None
# Just checking that if one of these values is combined, then both are. Doesn't make sense otherwise.
if (obs_id == "combined" and inst != "combined") or (inst == "combined" and obs_id != "combined"):
raise ValueError("If ObsID or inst is set to combined, then both must be set to combined.")
# Storing the passed source name in an attribute, as well as the ObsID and instrument
self._src_name = source_name
self._obs_id = obs_id
self._inst = inst
# Going to have this convenient attribute for profile classes, I could just use the type() command
# when I wanted to know but this is easier.
self._prof_type = "base"
# The currently implemented and allowed types of fitting for a profile
self._fit_methods = ['curve_fit', 'mcmc', 'odr']
self._nice_fit_methods = {'curve_fit': 'Curve Fit', 'mcmc': 'MCMC', 'odr': 'ODR'}
# Here is where information about fitted models is stored (and any failed fit attempts)
self._good_model_fits = {m: {} for m in self._fit_methods}
self._bad_model_fits = {m: {} for m in self._fit_methods}
# Some types of profiles will support a background value (like surface brightness), which will
# need to be incorporated into the fit and plotting.
self._background = Quantity(0, self._values.unit)
# Need to be able to store upper and lower energy bounds for those profiles that
# have them (like brightness profiles for instance)
self._energy_bounds = (None, None)
# Checking the if associated_set_id is supplied, so is set_storage_key, and vica versa
if not all([associated_set_id is None, set_storage_key is None]) and \
not all([associated_set_id is not None, set_storage_key is not None]):
raise ValueError("Both associated_set_id and set_storage_key must be None, or both must be not None.")
# Putting the associated set ID into an attribute, if this profile wasn't generated by an AnnularSpectra
# then this will just be None. Same for the set_storage_key
self._set_id = associated_set_id
# Don't think this one will get a property, I can't see why the user would need it.
self._set_storage_key = set_storage_key
# Here we generate a storage key for the profile to use to place itself in XGA's storage structure
if self._set_storage_key is not None:
# If there is a storage key for a spectrum which generated this available, then our life becomes
# quite simple, as it has a lot of information in it.
# In fact as the profile will also be indexed under the profile type name, we can just use this as
# our storage key
self._storage_key = self._set_storage_key
else:
# Default storage key for profiles that don't implement their own storage key will include their radii
# and the central coordinate
# Just being doubly sure its in degrees
cent_chunk = "ra{r}_dec{d}_r".format(r=centre.value[0], d=centre.value[1])
rad_chunk = "_".join(self._deg_radii.value.astype(str))
self._storage_key = cent_chunk + rad_chunk
# The y-axis label used to be stored in a dictionary in the init of models, but it makes more sense
# just declaring it in the init I think - it should be over-ridden in every subclass
self._y_axis_name = "Unknown"
# Akin to the usable attribute of other product classes, will be set for different reasons by different
# profile subclasses
self._usable = True
# These are the normalisation values for plotting (and possibly fitting at some point). I don't check their
# units at any point because the user is welcome to normalise their plots by whatever value they wish
self._x_norm = x_norm
self._y_norm = y_norm
if radii_err is not None:
self._outer_rad = radii[-1] + radii_err[-1]
else:
self._outer_rad = radii[-1]
self._save_path = None
def _model_allegiance(self, model: BaseModel1D):
"""
This internal method with a silly name just checks whether a model instance has already been associated
with a profile other than this one, or if it has any association with a profile. If there is an association
with another profile then it throws an error, as that that can cause serious problems that have caught me
out before (see issue #742). If there is no association it sets the models profile attribute.
:param BaseModel1D model: An instance of a BaseModel1D class (or subclass) to check.
"""
if model.profile is not None and model.profile != self:
raise ModelNotAssociatedError("The passed model instance is already associated with another profile, and"
" as such cannot be fit to this one. Ensure that individual model instances"
" are declared for each profile you are fitting.")
elif model.profile is None:
model.profile = self
[docs] def emcee_fit(self, model: BaseModel1D, num_steps: int, num_walkers: int, progress_bar: bool, show_warn: bool,
num_samples: int) -> Tuple[BaseModel1D, bool]:
"""
A fitting function to fit an XGA model instance to the data in this profile using the emcee
affine-invariant MCMC sampler, this should be called through .fit() for full functionality. An initial
run of curve_fit is used to find start parameters for the sampler, though if that fails a maximum
likelihood estimate is run, and if that fails the method will revert to using the start parameters
set in the model instance.
:param BaseModel1D model: The model to be fit to the data, you cannot pass a model name for this argument.
:param int num_steps: The number of steps each chain should take.
:param int num_walkers: The number of walkers to be run for the ensemble sampler.
:param bool progress_bar: Whether a progress bar should be displayed.
:param bool show_warn: Should warnings be printed out, otherwise they are just stored in the model
instance (this also happens if show_warn is True).
:param int num_samples: The number of random samples to take from the posterior distributions of
the model parameters.
:return: The model instance, and a boolean flag as to whether this was a successful fit or not.
:rtype: Tuple[BaseModel1D, bool]
"""
def find_to_replace(start_pos: np.ndarray, par_lims: np.ndarray) -> np.ndarray:
"""
Tiny function to generate an array of which start positions are currently invalid and should
be replaced.
:param np.ndarray start_pos: The current start positions of the walkers/parameters to be checked.
:param np.ndarray par_lims: The limits imposed on the start parameters.
:return: A numpy array of True and False values, True means the value should be replaced/recalculated.
:rtype: np.ndarray
"""
# It is possible that the start parameters can be outside of the range allowed by the the priors. In which
# case the MCMC fit will get super upset but not actually throw an error.
start_check_greater = np.greater_equal(start_pos, par_lims[:, 0])
start_check_lower = np.less_equal(start_pos, par_lims[:, 1])
# Any true value in this array is a parameter that isn't in the allowed prior range
to_replace_arr = ~(start_check_greater & start_check_lower)
return to_replace_arr
# The very first thing I do is to check whether the passed model is ACTUALLY a model or a model name - I
# expect this confusion could arise because the fit() method (which is what users should REALLY be using)
# allows either an instance or a model name.
if not isinstance(model, BaseModel1D):
raise TypeError("This fitting method requires that a model instance be passed for the model argument, "
"rather than a model name.")
# Then I check that the model instance hasn't already been fit to another profile - I would do this in the
# fit() method (because then I wouldn't have to it in every separate fitting method), but I can't
else:
self._model_allegiance(model)
# I'm just defining these here so that the lines don't get too long for PEP standards
y_data = (self.values.copy() - self._background).value
y_errs = self.values_err.copy().value
rads = self.fit_radii.copy().value
success = True
warning_str = ""
for prior in model.par_priors:
if prior['type'] != 'uniform':
raise NotImplementedError("Sorry but I don't yet support non-uniform priors for profile fitting!")
prior_list = [p['prior'].to(model.par_units[p_ind]).value for p_ind, p in enumerate(model.par_priors)]
prior_arr = np.array(prior_list)
# We can run a curve_fit fit to try and get start values for the model parameters, and if that fails
# we try maximum likelihood, and if that fails then we fall back on the default start parameters in the
# model.
# Making a copy of the model, and setting the profile to None otherwise the allegiance check gets very upset
curve_fit_model = deepcopy(model)
curve_fit_model.profile = None
curve_fit_model, success = self.nlls_fit(curve_fit_model, 10, show_warn=False)
if success or curve_fit_model.fit_warning == "Very large parameter uncertainties":
base_start_pars = np.array([p.value for p in curve_fit_model.model_pars])
else:
# This finds maximum likelihood parameter values for the model+data
max_like_res = minimize(lambda *args: -log_likelihood(*args, model.model), model.unitless_start_pars,
args=(rads, y_data, y_errs))
# I'm now adding this checking step, which will revert to the default start parameters of the model if the
# maximum likelihood estimate produced insane results.
base_start_pars = max_like_res.x
# So if any of the max likelihood pars are outside their prior, we just revert back to the original
# start parameters of the model. This step may make the checks performed later for instances where all
# start positions for a parameter are outside the prior a bit pointless, but I'm leaving them in for safety.
if find_to_replace(base_start_pars, prior_arr).any():
warn("Maximum likelihood estimator has produced at least one start parameter that is outside"
" the allowed values defined by the prior, reverting to default start parameters for this model.")
base_start_pars = model.unitless_start_pars
# This basically finds the order of magnitude of each parameter, so we know the scale on which we should
# randomly perturb
ml_rand_dev = np.power(10, np.floor(np.log10(np.abs(base_start_pars)))) / 10
# Then that order of magnitude is multiplied by a value drawn from a standard gaussian, and this is what
# we perturb the maximum likelihood values with - so we get random start parameters for all
# of our walkers
pos = base_start_pars + (ml_rand_dev * np.random.randn(num_walkers, model.num_pars))
# It is possible that some of the start parameters we've generated are outside the prior, in which
# case emcee gets quite angry. Just in case I draw random values from the priors of all parameters,
# ready to be substituted in, but only if I definitely can't get the start parameters from perturbing
# the max likelihood 'fit'
rand_uniform_pos = np.random.uniform(prior_arr[:, 0], prior_arr[:, 1], size=(num_walkers, model.num_pars))
# Check which of the current start parameters are currently valid
to_replace = find_to_replace(pos, prior_arr)
# Setting up decent starting values is very important, so first I'm going to check if every single starting
# value of any parameter is outside the bounds of our priors, because if that's true then the maximum
# likelihood 'fit' to get the initial starting parameters is probably a bit crappy
all_bad = np.all(to_replace, axis=0)
if any(all_bad):
warn("All walker starting parameters for one or more of the model parameters are outside the priors, which"
"probably indicates a bad initial fit (which is used to get initial start parameters). Values will be"
" drawn from the priors directly.")
# This replacement only affects those parameters for which ALL start positions are outside the
# prior range
all_bad_inds = np.argwhere(all_bad).T[0]
pos[:, all_bad_inds] = rand_uniform_pos[:, all_bad_inds]
# Now need to re-calculate the to_replace array, as it is no longer valid
to_replace = find_to_replace(pos, prior_arr)
# Now, if there are still problems with the starting positions, we know it is likely because some of
# the random perturbations of the 'best fit' start parameters are outside the priors, thus we will
# iteratively try and draw more acceptable perturbations, and if there are any 'bad' start positions
# left after 100 tries then they'll just get assigned randomly drawn values from the priors
iter_cnt = 0
while True in to_replace and iter_cnt < 100:
new_pos = base_start_pars + ml_rand_dev * np.random.randn(num_walkers, model.num_pars)
pos[to_replace] = new_pos[to_replace]
to_replace = find_to_replace(pos, prior_arr)
iter_cnt += 1
# So any start values that fall outside the allowed range will be swapped out with a value randomly drawn
# from the prior
pos[to_replace] = rand_uniform_pos[to_replace]
# This instantiates an Ensemble sampler with the number of walkers specified by the user,
# with the log probability as defined in the functions above
sampler = em.EnsembleSampler(num_walkers, model.num_pars, log_prob, args=(rads, y_data, y_errs, model.model,
prior_list))
try:
# So now we start the sampler, running for the number of steps specified on function call, with
# the starting parameters defined in the if statement above this.
sampler.run_mcmc(pos, num_steps, progress=progress_bar)
model.acceptance_fraction = np.mean(sampler.acceptance_fraction)
success = True
except ValueError as bugger:
warning_str = str(bugger)
model.fit_warning = warning_str
success = False
rng = np.random.default_rng()
if success:
# The auto-correlation can produce an error that basically says not to trust the chains
try:
# The sampler has a convenient auto-correlation time derivation, which returns the
# auto-correlation time for each parameter - with this I simply choose the highest one and
# round up to the nearest 100 to use as the burn-in
auto_corr = np.mean(sampler.get_autocorr_time())
# Find the nearest hundred above the mean auto-correlation time, then multiply by two for
# burn-in region
cut_off = int(np.ceil(auto_corr / 100) * 100)*2
success = True
except ValueError as bugger:
model.fit_warning = str(bugger)
success = False
cut_off = None
except em.autocorr.AutocorrError as bugger:
model.fit_warning = str(bugger)
cut_off = int(0.3 * num_steps)
# Store the chosen cut off in the model instance
model.cut_off = cut_off
# If the fit is considered to have not completely failed then we store distributions and parameters
# in the model instance
if success:
# I am not going to thin the chains, apparently that can actually increase variance?
flat_samp = sampler.get_chain(discard=cut_off, flat=True)
# Construct a numpy array representing the indices of the flattened chains
all_inds = np.arange(flat_samp.shape[0])
# Use the numpy random submodule to choose randomly sample the flattened chains
chosen_inds = rng.choice(all_inds, num_samples)
chosen = flat_samp[chosen_inds, :]
# Then give those chains the correct units and store them in the model instance
par_dists = [Quantity(chosen[:, p_ind], model.par_units[p_ind]) for p_ind in range(model.num_pars)]
model.par_dists = par_dists
# Start constructing the uncertainties on the parameters, though with a more complex model where the
# parameters distributions are not gaussian using these parameter values would not be appropriate
model_par_errs = []
for p_dist in par_dists:
# Store the current unit
u = p_dist.unit
# Measure the 50th percentile value of the current parameter distribution
fiftieth = np.percentile(p_dist, 50).value
# Find the upper and lower bounds of the 1sigma region of the distribution
upper = np.percentile(p_dist, 84.1).value
lower = np.percentile(p_dist, 15.9).value
# Store the upper and lower uncertainties with the correct units
model_par_errs.append(Quantity([fiftieth-lower, upper-fiftieth], u))
# Store the model parameter and uncertainties in the model instance
model.model_pars = [p_dist.mean() for p_dist in par_dists]
model.model_par_errs = model_par_errs
# Store the sampler in the model instance, useful for getting raw chains etc again in the future
model.emcee_sampler = sampler
# I show all the warnings at once, if the user wants that
if model.fit_warning != "" and show_warn:
print(model.fit_warning)
# Tell the model whether we think the fit was successful or not
model.success = success
# And finally storing the fit method used in the model itself
model.fit_method = "mcmc"
# Explicitly deleting the curve fit model, just to be safe
del curve_fit_model
return model, success
[docs] def nlls_fit(self, model: BaseModel1D, num_samples: int, show_warn: bool) -> Tuple[BaseModel1D, bool]:
"""
A function to fit an XGA model instance to the data in this profile using the non-linear least squares
curve_fit routine from scipy, this should be called through .fit() for full functionality
:param BaseModel1D model: An instance of the model to be fit to this profile.
:param int num_samples: The number of random samples to be drawn and stored in the model
parameter distribution property.
:param bool show_warn: Should warnings be printed out, otherwise they are just stored in the model
instance (this also happens if show_warn is True).
:return: The model (with best fit parameters stored within it), and a boolean flag as to whether the
fit was successful or not.
:rtype: Tuple[BaseModel1D, bool]
"""
# The very first thing I do is to check whether the passed model is ACTUALLY a model or a model name - I
# expect this confusion could arise because the fit() method (which is what users should REALLY be using)
# allows either an instance or a model name.
if not isinstance(model, BaseModel1D):
raise TypeError("This fitting method requires that a model instance be passed for the model argument, "
"rather than a model name.")
# Then I check that the model instance hasn't already been fit to another profile - I would do this in the
# fit() method (because then I wouldn't have to it in every separate fitting method), but I can't
else:
self._model_allegiance(model)
y_data = (self.values.copy() - self._background).value
y_errs = self.values_err.copy().value
rads = self.fit_radii.copy().value
success = True
warning_str = ""
lower_bounds = []
upper_bounds = []
for prior_ind, prior in enumerate(model.par_priors):
if prior['type'] == 'uniform':
conv_prior = prior['prior'].to(model.par_units[prior_ind]).value
lower_bounds.append(conv_prior[0])
upper_bounds.append(conv_prior[1])
else:
lower_bounds.append(-np.inf)
upper_bounds.append(np.inf)
# Curve fit is a simple non-linear least squares implementation, its alright but fragile
try:
fit_par, fit_cov = curve_fit(model.model, rads, y_data, p0=model.unitless_start_pars, sigma=y_errs,
absolute_sigma=True, bounds=(lower_bounds, upper_bounds))
# If there is an infinite value in the covariance matrix, it means curve_fit was
# unable to estimate it properly
if np.inf in fit_cov:
warning_str = "Infinity in covariance matrix"
success = False
else:
# Grab the diagonal of the covariance matrix, then sqrt to get sigma values for each parameter
fit_par_err = np.sqrt(np.diagonal(fit_cov))
frac_err = np.divide(fit_par_err, fit_par, where=fit_par != 0)
if frac_err.max() > 10:
warning_str = "Very large parameter uncertainties"
success = False
except RuntimeError as r_err:
warn("{}, curve_fit has failed.".format(str(r_err)))
warning_str = str(r_err)
success = False
fit_par = np.full(len(model.model_pars), np.nan)
fit_par_err = np.full(len(model.model_pars), np.nan)
# Using the parameter values and the covariance matrix I generate parameter distributions to store in the
# model instance.
# Apparently this is what we should be using for random numbers from numpy now!
rng = np.random.default_rng()
if success:
# I generate num_samples random points from the distribution of each parameter
ext_model_par = np.repeat(fit_par[..., None], num_samples, axis=1).T
ext_model_par_err = np.repeat(fit_par_err[..., None], num_samples, axis=1).T
# This generates model_real random samples from the passed model parameters, assuming they are Gaussian
model_par_dists = rng.normal(ext_model_par, ext_model_par_err)
par_dists = [Quantity(model_par_dists[:, p_ind], model.par_units[p_ind])
for p_ind in range(0, len(fit_par))]
model.par_dists = par_dists
# Now we put the values BACK into quantities
fit_par = [Quantity(p, model.par_units[p_ind]) for p_ind, p in enumerate(fit_par)]
fit_par_err = [Quantity(p, model.par_units[p_ind]) for p_ind, p in enumerate(fit_par_err)]
# And pass them into the model
model.model_pars = fit_par
model.model_par_errs = fit_par_err
if not success:
model.fit_warning = warning_str
if show_warn and warning_str != "":
warn(warning_str)
# Tell the model whether we think the fit was successful or not
model.success = success
# And finally storing the fit method used in the model itself
model.fit_method = "curve_fit"
# And then the model gets sent back
return model, success
def _odr_fit(self, model: BaseModel1D, show_warn: bool):
# TODO MAKE THIS A NON-INTERNAL METHOD WHEN ITS WRITTEN
# TODO REMEMBER TO USE THE FIT RADII PROPERTY
# Tell the model whether we think the fit was successful or not
# model.success = success
# The very first thing I do is to check whether the passed model is ACTUALLY a model or a model name - I
# expect this confusion could arise because the fit() method (which is what users should REALLY be using)
# allows either an instance or a model name.
if not isinstance(model, BaseModel1D):
raise TypeError("This fitting method requires that a model instance be passed for the model argument, "
"rather than a model name.")
# Then I check that the model instance hasn't already been fit to another profile - I would do this in the
# fit() method (because then I wouldn't have to it in every separate fitting method), but I can't
else:
self._model_allegiance(model)
# And finally storing the fit method used in the model itself
model.fit_method = "odr"
raise NotImplementedError("This fitting method is still under construction!")
[docs] def fit(self, model: Union[str, BaseModel1D], method: str = "mcmc", num_samples: int = 10000,
num_steps: int = 30000, num_walkers: int = 20, progress_bar: bool = True,
show_warn: bool = True) -> BaseModel1D:
"""
Method to fit a model to this profile's data, then store the resulting model parameter results. Each
profile can store one instance of a type of model per fit method. So for instance you could fit both
a 'beta' and 'double_beta' model to a surface brightness profile with curve_fit, and then you could
fit 'double_beta' again with MCMC.
If any of the parameters of the passed model have a uniform prior associated, and the chosen method
is curve_fit, then those priors will be used to place bounds on those parameters.
:param str/BaseModel1D model: Either an instance of an XGA model to be fit to this profile, or the name
of a profile (e.g. 'beta', or 'simple_vikhlinin_dens').
:param str method: The fit method to use, either 'curve_fit', 'mcmc', or 'odr'.
:param int num_samples: The number of random samples to draw to create the parameter distributions
that are saved in the model.
:param int num_steps: Only applicable if using MCMC fitting, the number of steps each walker should take.
:param int num_walkers: Only applicable if using MCMC fitting, the number of walkers to initialise
for the ensemble sampler.
:param bool progress_bar: Only applicable if using MCMC fitting, should a progress bar be shown.
:param bool show_warn: Should warnings be printed out, otherwise they are just stored in the model
instance (this also happens if show_warn is True).
:return: The fitted model object. The fitted model is also stored within the profile object.
:rtype: BaseModel1D
"""
# Make sure the method is lower case
method = method.lower()
# This chunk is just checking inputs and making sure they're valid
if self._prof_type in PROF_TYPE_MODELS:
# Put the allowed models for this profile type into a string
allowed = ", ".join(PROF_TYPE_MODELS[self._prof_type])
else:
allowed = ""
if self._prof_type == "base":
raise XGAFitError("A BaseProfile1D object currently cannot have a model fitted to it, as there"
" is no physical context.")
elif isinstance(model, str) and model.lower() not in PROF_TYPE_MODELS[self._prof_type]:
raise XGAInvalidModelError("{p} is not available for this type of profile, please use one of the "
"following models {a}".format(p=model, a=allowed))
elif isinstance(model, str):
model = PROF_TYPE_MODELS[self._prof_type][model](self.radii_unit, self.values_unit)
elif isinstance(model, BaseModel1D) and model.name not in PROF_TYPE_MODELS[self._prof_type]:
raise XGAInvalidModelError("{p} is not available for this type of profile, please use one of the "
"following models {a}".format(p=model, a=allowed))
elif isinstance(model, BaseModel1D) and (model.x_unit != self.radii_unit or model.y_unit != self.values_unit):
raise UnitConversionError("The model instance passed to the fit method has units that are incompatible, "
"with the data. This profile has an radius unit of {r} and a value unit of "
"{v}".format(r=self.radii_unit.to_string(), v=self.values_unit.to_string()))
# I don't think I'm going to allow any fits without value uncertainties - just seems daft
if self._values_err is None:
raise XGAFitError("You cannot fit to a profile that doesn't have value uncertainties.")
# Checking that the method passed is valid
if method not in self._fit_methods:
allowed = ", ".join(self._fit_methods)
raise XGAFitError("{me} is not a valid fitting method, please use one of these; {a}".format(me=method,
a=allowed))
# Check whether a good fit result already exists for this model. We use the storage_key property that
# XGA model objects generate from their name and their start parameters
if model.name in self._good_model_fits[method]:
warn("{m} already has a successful fit result for this profile using {me}, with those start "
"parameters".format(m=model.name, me=method))
already_done = True
elif model.name in self._bad_model_fits[method]:
warn("{m} already has a failed fit result for this profile using {me} with those start "
"parameters".format(m=model.name, me=method))
already_done = False
else:
already_done = False
# Running the requested fitting method
if not already_done and method == 'mcmc':
model, success = self.emcee_fit(model, num_steps, num_walkers, progress_bar, show_warn, num_samples)
elif not already_done and method == 'curve_fit':
model, success = self.nlls_fit(model, num_samples, show_warn)
elif not already_done and method == 'odr':
model, success = self._odr_fit(model, show_warn)
else:
model = self.get_model_fit(model.name, method)
# Storing the model in the internal dictionaries depending on whether the fit was successful or not
if not already_done and success:
self._good_model_fits[method][model.name] = model
elif not already_done and not success:
self._bad_model_fits[method][model.name] = model
# This method means that a change has happened to the model, so it should be re-saved
self.save()
return model
[docs] def allowed_models(self, table_format: str = 'fancy_grid'):
"""
This is a convenience function to tell the user what models can be used to fit a profile
of the current type, what parameters are expected, and what the defaults are.
:param str table_format: The desired format of the allowed models table. This is passed to the
tabulate module (allowed formats can be found here - https://pypi.org/project/tabulate/), and
alters the way the printed table looks.
"""
# Base profile don't have any type of model associated with them, so just making an empty list
if self._prof_type == "base":
warn("There are no implemented models for this profile type")
else:
allowed = list(PROF_TYPE_MODELS[self._prof_type].keys())
# These just roll through the available models for this type of profile and construct strings of
# parameter names and start parameters to put in the table
model_par_names = []
model_par_starts = []
for m in allowed:
exp_pars = ""
par_len = 0
def_starts = ""
def_len = 0
# This chunk of code tries to make sure that the strings aren't too long to display nicely
# in the table
mod_inst = PROF_TYPE_MODELS[self._prof_type][m]()
for p_ind, p in enumerate(list(inspect.signature(mod_inst.model).parameters.values())[1:]):
if par_len > 35:
exp_pars += ' \n'
par_len = 0
next_par = '{}, '.format(p.name)
par_len += len(next_par)
exp_pars += next_par
if def_len > 35:
def_starts += ' \n'
def_len = 0
next_def = '{}, '.format(str(mod_inst.start_pars[p_ind]))
def_len += len(next_def)
def_starts += next_def
# We slice out the last character because we know its going to be a spurious comma, just
# because of the lazy way I wrote the loop above
model_par_names.append(exp_pars[:-2])
model_par_starts.append(def_starts[:-2])
# Construct the table data and display it using tabulate module
tab_dat = [[allowed[i], model_par_names[i], model_par_starts[i]] for i in range(0, len(allowed))]
print(tabulate(tab_dat, ["MODEL NAME", "EXPECTED PARAMETERS", "DEFAULT START VALUES"],
tablefmt=table_format))
[docs] def get_model_fit(self, model: str, method: str) -> BaseModel1D:
"""
A get method for fitted model objects associated with this profile. Models for which the fit failed will
also be returned, but a warning will be shown to inform the user that the fit failed.
:param str model: The name of the model to retrieve.
:param str method: The method which was used to fit the model.
:return: An instance of an XGA model object that was fitted to this profile and updated with the
parameter values.
:rtype: BaseModel1D
"""
if model not in PROF_TYPE_MODELS[self._prof_type]:
allowed = list(PROF_TYPE_MODELS[self._prof_type].keys())
prof_name = self._y_axis_name.lower()
raise XGAInvalidModelError("{m} is not a valid model for a {p} profile, please choose from "
"one of these; {a}".format(m=model, a=", ".join(allowed), p=prof_name))
elif model in self._bad_model_fits[method]:
warn("An attempt was made to fit {m} with {me} but it failed, so treat the model with "
"suspicion".format(m=model, me=method))
ret_model = self._bad_model_fits[method][model]
elif model not in self._good_model_fits[method]:
raise ModelNotAssociatedError("{m} is valid for this profile, but hasn't been fit with {me} "
"yet".format(m=model, me=method))
else:
ret_model = self._good_model_fits[method][model]
return ret_model
[docs] def add_model_fit(self, model: BaseModel1D, method: str):
"""
There are rare circumstances where XGA processes might wish to add a model to a profile from the outside,
which is what this method allows you to do.
:param BaseModel1D model: The XGA model object to add to the profile.
:param str method: The method used to fit the model.
"""
# Checking that the method passed is valid
if method not in self._fit_methods:
allowed = ", ".join(self._fit_methods)
raise XGAFitError("{me} is not a valid fitting method, please use one of these; {a}".format(me=method,
a=allowed))
# Checking that the model is valid for this particular profile
allowed = ", ".join(PROF_TYPE_MODELS[self._prof_type])
if model.name not in PROF_TYPE_MODELS[self._prof_type]:
raise XGAInvalidModelError("{p} is not valid for this type of profile, please use one of the "
"following models {a}".format(p=model.name, a=allowed))
elif model.x_unit != self.radii_unit or model.y_unit != self.values_unit:
raise UnitConversionError("The model instance passed to the fit method has units that are incompatible, "
"with the data. This profile has an radius unit of {r} and a value unit of "
"{v}".format(r=self.radii_unit.to_string(), v=self.values_unit.to_string()))
elif not model.success:
raise ValueError("Please only add successful models to this profile.")
else:
self._good_model_fits[method][model.name] = model
# This method means that a change has happened to the model, so it should be re-saved
self.save()
[docs] def get_sampler(self, model: str) -> em.EnsembleSampler:
"""
A get method meant to retrieve the MCMC ensemble sampler used to fit a particular
model (supplied by the user). Checks are applied to the supplied model, to make
sure that it is valid for the type of profile, that a good fit has actually been
performed, and that the fit was performed with Emcee and not another method.
:param str model: The name of the model for which to retrieve the sampler.
:return: The Emcee sampler used to fit the user supplied model.
:rtype: em.EnsembleSampler
"""
model = self.get_model_fit(model, 'mcmc')
return model.emcee_sampler
[docs] def get_chains(self, model: str, discard: Union[bool, int] = True, flatten: bool = True,
thin: int = 1) -> np.ndarray:
"""
Get method for the sampler chains of an MCMC fit to the user supplied model. get_sampler is
called to retrieve the sampler object, as well as perform validity checks on the model name.
:param str model: The name of the model for which to retrieve the chains.
:param bool/int discard: Whether steps should be discarded for burn-in. If True then the cut off decided
using the auto-correlation time will be used. If an integer is passed then this will be used as the
number of steps to discard, and if False then no steps will be discarded.
:param bool flatten: Should the chains of the multiple walkers be flattened into one chain per parameter.
:param int thin: The thinning that should be applied to the chains. The default is 1, which means no
thinning is applied.
:return: The requested chains.
:rtype: np.ndarray
"""
model = self.get_model_fit(model, 'mcmc')
if isinstance(discard, bool) and discard:
chains = model.emcee_sampler.get_chain(discard=model.cut_off, flat=flatten, thin=thin)
elif isinstance(discard, int):
chains = model.emcee_sampler.get_chain(discard=discard, flat=flatten, thin=thin)
else:
chains = model.emcee_sampler.get_chain(flat=flatten, thin=thin)
return chains
[docs] def view_chains(self, model: str, discard: Union[bool, int] = True, thin: int = 1, figsize: Tuple = None):
"""
Simple view method to quickly look at the MCMC chains for a given model fit.
:param str model: The name of the model for which to view the MCMC chains.
:param bool/int discard: Whether steps should be discarded for burn-in. If True then the cut off decided
using the auto-correlation time will be used. If an integer is passed then this will be used as the
number of steps to discard, and if False then no steps will be discarded.
:param int thin: The thinning that should be applied to the chains. The default is 1, which means no
thinning is applied.
:param Tuple figsize: Desired size of the figure, if None will be set automatically.
"""
chains = self.get_chains(model, discard, thin=thin, flatten=False)
model_obj = self.get_model_fit(model, 'mcmc')
if figsize is None:
fig, axes = plt.subplots(nrows=model_obj.num_pars, figsize=(12, 2*model_obj.num_pars), sharex='col')
else:
fig, axes = plt.subplots(model_obj.num_pars, figsize=figsize, sharex='col')
plt.suptitle("{m} Parameter Chains".format(m=model_obj.publication_name), fontsize=14, y=1.02)
for i in range(model_obj.num_pars):
cur_unit = model_obj.par_units[i]
if cur_unit == Unit(''):
par_unit_name = ""
else:
par_unit_name = r" $\left[" + cur_unit.to_string("latex").strip("$") + r"\right]$"
ax = axes[i]
ax.plot(chains[:, :, i], "k", alpha=0.3)
ax.set_xlim(0, len(chains))
ax.set_ylabel(model_obj.par_publication_names[i] + par_unit_name, fontsize=13)
ax.yaxis.set_label_coords(-0.1, 0.5)
axes[-1].set_xlabel("Step Number", fontsize=13)
plt.tight_layout()
plt.show()
[docs] def view_corner(self, model: str, figsize: Tuple = (8, 8)):
"""
A convenient view method to examine the corner plot of the parameter posterior distributions.
:param str model: The name of the model for which to view the corner plot.
:param Tuple figsize: The desired figure size.
"""
flat_chains = self.get_chains(model, flatten=True)
model_obj = self.get_model_fit(model, 'mcmc')
frac_conf_lev = [(50 - 34.1)/100, 0.5, (50 + 34.1)/100]
# If any of the median parameter values are above 1e+4 we get corner to format them in scientific
# notation, to avoid super long numbers spilling over the edge of the corner plot. I will say that
# scientific notation in titles in corner doesn't look that great either, but its better than the
# alternative
if np.any(np.median(flat_chains, axis=0) > 1e+4):
fig = corner.corner(flat_chains, labels=model_obj.par_publication_names, figsize=figsize,
quantiles=frac_conf_lev, show_titles=True, title_fmt=".2e")
else:
fig = corner.corner(flat_chains, labels=model_obj.par_publication_names, figsize=figsize,
quantiles=frac_conf_lev, show_titles=True)
t = self._y_axis_name
plt.suptitle("{m} - {s} {t} Profile".format(m=model_obj.publication_name, s=self.src_name, t=t),
fontsize=14, y=1.02)
plt.show()
[docs] def view_getdist_corner(self, model: str, settings: dict = {}, figsize: tuple = (10, 10)):
"""
A view method to see a corner plot generated with the getdist module, using flattened chains with
burn-in removed (whatever the getdist message might say).
:param str model: The name of the model for which to view the corner plot.
:param dict settings: The settings dictionary for a getdist MCSample.
:param tuple figsize: A tuple to set the size of the figure.
"""
# Grab the flattened chains
flat_chains = self.get_chains(model, flatten=True)
model_obj = self.get_model_fit(model, 'mcmc')
# Need to remove $ from the labels because getdist adds them itself
stripped_labels = [n.replace('$', '') for n in model_obj.par_publication_names]
# Setup the getdist sample object
gd_samp = MCSamples(samples=flat_chains, names=model_obj.par_names, labels=stripped_labels,
settings=settings)
# And generate the triangle plot
g = plots.get_subplot_plotter(width_inch=figsize[0])
g.triangle_plot([gd_samp], filled=True)
plt.show()
[docs] def generate_data_realisations(self, num_real: int):
"""
A method to generate random realisations of the data points in this profile, using their y-axis values
and uncertainties. This can be useful for error propagation for instance, and does not require a model fit
to work. This method assumes that the y-errors are 1-sigma, which isn't necessarily the case.
:param int num_real: The number of random realisations to generate.
:return: An N x R astropy quantity, where N is the number of realisations and R is the number of radii
at which there are data points in this profile.
:rtype: Quantity
"""
if self.values_err is None:
raise ValueError("This profile has no y-error information, and as such you cannot generate random"
" realisations of the data.")
# Here I copy the values and value uncertainties N times, where N is the number of realisations
# the user wants
ext_values = np.repeat(self.values[..., None], num_real, axis=1).T
ext_value_errs = np.repeat(self.values_err[..., None], num_real, axis=1).T
# Then I just generate N realisations of the profiles using a normal distribution, though this does assume
# that the errors are one sigma which isn't necessarily true
realisations = np.random.normal(ext_values, ext_value_errs)
realisations = Quantity(realisations, self.values_unit)
return realisations
[docs] def get_view(self, fig: Figure, main_ax: Axes, xscale="log", yscale="log", xlim=None, ylim=None, models=True,
back_sub: bool = True, just_models: bool = False, custom_title: str = None, draw_rads: dict = {},
x_norm: Union[bool, Quantity] = False, y_norm: Union[bool, Quantity] = False, x_label: str = None,
y_label: str = None, data_colour: str = 'black', model_colour: str = 'seagreen',
show_legend: bool = True, show_residual_ax: bool = True):
"""
A get method for an axes (or multiple axes) showing this profile and model fits. The idea of this get method
is that, whilst it is used by the view() method, it can also be called by external methods that wish to use
the profile plot in concert with other views.
:param Figure fig: The figure which has been set up for this profile plot.
:param Axes main_ax: The matplotlib axes on which to show the image.
:param str xscale: The scaling to be applied to the x axis, default is log.
:param str yscale: The scaling to be applied to the y axis, default is log.
:param Tuple xlim: The limits to be applied to the x axis, upper and lower, default is
to let matplotlib decide by itself.
:param Tuple ylim: The limits to be applied to the y axis, upper and lower, default is
to let matplotlib decide by itself.
:param str models: Should the fitted models to this profile be plotted, default is True
:param bool back_sub: Should the plotted data be background subtracted, default is True.
:param bool just_models: Should ONLY the fitted models be plotted? Default is False
:param str custom_title: A plot title to replace the automatically generated title, default is None.
:param dict draw_rads: A dictionary of extra radii (as astropy Quantities) to draw onto the plot, where
the dictionary key they are stored under is what they will be labelled.
e.g. ({'r500': Quantity(), 'r200': Quantity()}
:param bool x_norm: Controls whether the x-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the x_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param bool y_norm: Controls whether the y-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the y_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param str x_label: Custom label for the x-axis (excluding units, which will be added automatically).
:param str y_label: Custom label for the y-axis (excluding units, which will be added automatically).
:param str data_colour: Used to set the colour of the data points.
:param str model_colour: Used to set the colour of a model fit.
:param bool show_legend: Whether the legend should be displayed or not. Default is True.
:param bool show_residual_ax: Controls whether a lower axis showing the residuals between data and
model (if a model is fitted and being shown) is displayed. Default is True.
"""
# Checks that any extra radii that have been passed are the correct units (i.e. the same as the radius units
# used in this profile)
if not all([r.unit == self.radii_unit for r in draw_rads.values()]):
raise UnitConversionError("All radii in draw_rad have to be in the same units as this profile, "
"{}".format(self.radii_unit.to_string()))
# Default is to show models, but that flag is set to False here if there are none, otherwise we get
# extra plotted stuff that doesn't make sense
if len(self.good_model_fits) == 0:
models = False
just_models = False
# If the user wants the x-axis to be normalised then we grab the value from the profile (though of course
# if the user didn't set it initially then self.x_norm will also be 1
if isinstance(x_norm, bool) and x_norm:
x_norm = self.x_norm
if self.x_norm == Quantity(1, ''):
warn("No normalisation value is stored for the x-axis")
elif isinstance(x_norm, Quantity):
x_norm = x_norm
elif isinstance(x_norm, bool) and not x_norm:
# Otherwise we set x_norm to a harmless values with no units and unity value
x_norm = Quantity(1, '')
if isinstance(y_norm, bool) and y_norm:
y_norm = self.y_norm
if self.y_norm == Quantity(1, ''):
warn("No normalisation value is stored for the y-axis")
elif isinstance(y_norm, Quantity):
y_norm = y_norm
elif isinstance(y_norm, bool) and not y_norm:
y_norm = Quantity(1, '')
main_ax.minorticks_on()
if models and show_residual_ax:
# This sets up an axis for the residuals to be plotted on, if model plotting is enabled
res_ax = fig.add_axes((0.125, -0.075, 0.775, 0.2))
res_ax.minorticks_on()
res_ax.tick_params(axis='both', direction='in', which='both', top=True, right=True)
# Adds a zero line for reference, as its ideally where residuals would be
res_ax.axhline(0.0, color="black")
# Setting some aesthetic parameters for the main plotting axis
main_ax.tick_params(axis='both', direction='in', which='both', top=True, right=True)
if self.type == "brightness_profile" and self.psf_corrected:
leg_label = self.src_name + " PSF Corrected"
else:
leg_label = self.src_name
# This subtracts the background if the user wants a background subtracted plot
plot_y_vals = self.values.copy()
if back_sub:
plot_y_vals -= self.background
rad_vals = self.fit_radii.copy()
plot_y_vals /= y_norm
rad_vals /= x_norm
# Now the actual plotting of the data
if self.radii_err is not None and self.values_err is None:
x_errs = (self.radii_err.copy() / x_norm).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, xerr=x_errs, fmt="x", capsize=2,
label=leg_label, color=data_colour)
elif self.radii_err is None and self.values_err is not None:
y_errs = (self.values_err.copy() / y_norm).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, yerr=y_errs, fmt="x", capsize=2,
label=leg_label, color=data_colour)
elif self.radii_err is not None and self.values_err is not None:
x_errs = (self.radii_err.copy() / x_norm).value
y_errs = (self.values_err.copy() / y_norm).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, xerr=x_errs, yerr=y_errs, fmt="x", capsize=2,
label=leg_label, color=data_colour)
else:
line = main_ax.plot(rad_vals.value, plot_y_vals.value, 'x', label=leg_label, color=data_colour)
if just_models and models:
line[0].set_visible(False)
if len(line) != 1:
for coll in line[1:]:
for art_obj in coll:
art_obj.set_visible(False)
if not back_sub and self.background.value != 0:
main_ax.axhline(self.background.value, label=leg_label + ' Background', linestyle='dashed',
color=line[0].get_color())
if models:
for method in self._good_model_fits:
for model in self._good_model_fits[method]:
model_obj = self._good_model_fits[method][model]
lo_rad = self.fit_radii.min()
hi_rad = self.fit_radii.max()
mod_rads = np.linspace(lo_rad, hi_rad, 100)
mod_reals = model_obj.get_realisations(mod_rads)
# mean_model = np.mean(mod_reals, axis=1)
median_model = np.percentile(mod_reals, 50, axis=1)
upper_model = np.percentile(mod_reals, 84.1, axis=1)
lower_model = np.percentile(mod_reals, 15.9, axis=1)
mod_lab = model_obj.publication_name + " - {}".format(self._nice_fit_methods[method])
main_ax.plot(mod_rads.value / x_norm.value, median_model.value / y_norm, label=mod_lab,
color=model_colour)
main_ax.fill_between(mod_rads.value / x_norm.value, lower_model.value / y_norm.value,
upper_model.value / y_norm.value, alpha=0.7, interpolate=True,
where=upper_model.value >= lower_model.value, facecolor=model_colour)
main_ax.plot(mod_rads.value / x_norm.value, lower_model.value / y_norm.value, color=model_colour,
linestyle="dashed")
main_ax.plot(mod_rads.value / x_norm.value, upper_model.value / y_norm.value, color=model_colour,
linestyle="dashed")
# I only want this to trigger if the user has decided they want a residual axis. I expect most
# of the time that they will, but for things like the Hydrostatic mass diagnostic plots I want
# to be able to turn the residual axis off.
if show_residual_ax:
# This calculates and plots the residuals between the model and the data on the extra
# axis we added near the beginning of this method
res = np.percentile(model_obj.get_realisations(self.fit_radii), 50, axis=1) \
- (plot_y_vals * y_norm)
res_ax.plot(rad_vals.value, res.value, 'D', color=model_colour)
# Parsing the astropy units so that if they are double height then the square brackets will adjust size
x_unit = r"$\left[" + rad_vals.unit.to_string("latex").strip("$") + r"\right]$"
y_unit = r"$\left[" + plot_y_vals.unit.to_string("latex").strip("$") + r"\right]$"
# If the quantity being plotted is unitless then we don't want there to be empty brackets
if x_unit == r"$\left[\mathrm{}\right]$":
x_unit = ""
if y_unit == r"$\left[\mathrm{}\right]$":
y_unit = ""
if x_label is None:
# Setting the main plot's x label
main_ax.set_xlabel("Radius {}".format(x_unit), fontsize=13)
else:
main_ax.set_xlabel(x_label + " {}".format(x_unit), fontsize=13)
if y_label is None and (self._background.value == 0 or not back_sub):
main_ax.set_ylabel(r"{l} {u}".format(l=self._y_axis_name, u=y_unit), fontsize=13)
elif y_label is None:
# If background has been subtracted it will be mentioned in the y axis label
main_ax.set_ylabel(r"Background Subtracted {l} {u}".format(l=self._y_axis_name, u=y_unit), fontsize=13)
elif y_label is not None:
main_ax.set_ylabel(y_label + ' {}'.format(y_unit), fontsize=13)
# If the user wants a legend to be shown, then we create one
if show_legend:
main_leg = main_ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), ncol=1, borderaxespad=0)
# This makes sure legend keys are shown, even if the data is hidden
for leg_key in main_leg.legendHandles:
leg_key.set_visible(True)
# If the user has manually set limits then we can use them, only on the main axis because
# we grab those limits from the axes object for the residual axis later
if xlim is not None:
main_ax.set_xlim(xlim)
if ylim is not None:
main_ax.set_ylim(ylim)
# Setup the scale that the user wants to see, again on the main axis
main_ax.set_xscale(xscale)
main_ax.set_yscale(yscale)
if models and show_residual_ax:
# We want the residual x axis limits to be identical to the main axis, as the
# points should line up
res_ax.set_xlim(main_ax.get_xlim())
res_ax.set_xlabel("Radius {}".format(x_unit), fontsize=13)
res_ax.set_xscale(xscale)
# Grabbing the automatically assigned y limits for the residual axis, then finding the maximum
# difference from zero, increasing it by 10%, then setting that value is the new -+ limits
# That way its symmetrical
outer_ylim = 1.1 * max([abs(lim) for lim in res_ax.get_ylim()])
res_ax.set_ylim(-outer_ylim, outer_ylim)
res_ax.set_ylabel("Model - Data", fontsize=13)
# Adds a title to this figure, changes depending on whether model fits are plotted as well
if models and custom_title is None and len(self.good_model_fits) == 1:
title_str = "{l} Profile - with model".format(l=self._y_axis_name)
elif models and custom_title is None and len(self.good_model_fits) > 1:
title_str = "{l} Profile - with models".format(l=self._y_axis_name)
elif not models and custom_title is None:
title_str = "{l} Profile".format(l=self._y_axis_name)
else:
# If the user doesn't like my title, they can supply their own
title_str = custom_title
# If this particular profile is not considered usable, the user should be made aware in the plot
if not self._usable:
title_str += " [CONSIDERED UNUSABLE]"
# Actually plots the title
plt.suptitle(title_str, y=0.91)
# If the user has passed radii to plot, then we plot them
for r_name in draw_rads:
d_rad = (draw_rads[r_name] / x_norm).value
main_ax.axvline(d_rad, linestyle='dashed', color='black')
main_ax.annotate(r_name, (d_rad * 1.01, 0.9), rotation=90, verticalalignment='center',
color='black', fontsize=14, xycoords=('data', 'axes fraction'))
# Use the axis limits quite a lot in this next bit, so read them out into variables
x_axis_lims = main_ax.get_xlim()
y_axis_lims = main_ax.get_ylim()
# This dynamically changes how tick labels are formatted depending on the values displayed
if max(x_axis_lims) < 100 and not models and min(x_axis_lims) > 0.1:
main_ax.xaxis.set_minor_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
main_ax.xaxis.set_major_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
elif max(x_axis_lims) < 100 and models and show_residual_ax:
res_ax.xaxis.set_minor_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
res_ax.xaxis.set_major_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
if max(y_axis_lims) < 100 and min(y_axis_lims) > 0.1:
main_ax.yaxis.set_minor_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
main_ax.yaxis.set_major_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
elif max(y_axis_lims) < 100 and min(y_axis_lims) <= 0.1:
main_ax.yaxis.set_major_formatter(FuncFormatter(lambda inp, _: '{:g}'.format(inp)))
if models and show_residual_ax:
return main_ax, res_ax
else:
return main_ax, None
[docs] def view(self, figsize=(10, 7), xscale="log", yscale="log", xlim=None, ylim=None, models=True,
back_sub: bool = True, just_models: bool = False, custom_title: str = None, draw_rads: dict = {},
x_norm: Union[bool, Quantity] = False, y_norm: Union[bool, Quantity] = False, x_label: str = None,
y_label: str = None, data_colour: str = 'black', model_colour: str = 'seagreen', show_legend: bool = True,
show_residual_ax: bool = True):
"""
A method that allows us to view the current profile, as well as any models that have been fitted to it,
and their residuals. The models are plotted by generating random model realisations from the parameter
distributions, then plotting the median values, with 1sigma confidence limits.
:param Tuple figsize: The desired size of the figure, the default is (10, 7)
:param str xscale: The scaling to be applied to the x axis, default is log.
:param str yscale: The scaling to be applied to the y axis, default is log.
:param Tuple xlim: The limits to be applied to the x axis, upper and lower, default is
to let matplotlib decide by itself.
:param Tuple ylim: The limits to be applied to the y axis, upper and lower, default is
to let matplotlib decide by itself.
:param str models: Should the fitted models to this profile be plotted, default is True
:param bool back_sub: Should the plotted data be background subtracted, default is True.
:param bool just_models: Should ONLY the fitted models be plotted? Default is False
:param str custom_title: A plot title to replace the automatically generated title, default is None.
:param dict draw_rads: A dictionary of extra radii (as astropy Quantities) to draw onto the plot, where
the dictionary key they are stored under is what they will be labelled.
e.g. ({'r500': Quantity(), 'r200': Quantity()}
:param bool x_norm: Controls whether the x-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the x_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param bool y_norm: Controls whether the y-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the y_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param str x_label: Custom label for the x-axis (excluding units, which will be added automatically).
:param str y_label: Custom label for the y-axis (excluding units, which will be added automatically).
:param str data_colour: Used to set the colour of the data points.
:param str model_colour: Used to set the colour of a model fit.
:param bool show_legend: Whether the legend should be displayed or not. Default is True.
:param bool show_residual_ax: Controls whether a lower axis showing the residuals between data and
model (if a model is fitted and being shown) is displayed. Default is True.
"""
# Setting up figure for the plot
fig = plt.figure(figsize=figsize)
# Grabbing the axis object and making sure the ticks are set up how we want
main_ax = plt.gca()
main_ax, res_ax = self.get_view(fig, main_ax, xscale, yscale, xlim, ylim, models, back_sub, just_models,
custom_title, draw_rads, x_norm, y_norm, x_label, y_label, data_colour,
model_colour, show_legend, show_residual_ax)
# plt.tight_layout()
plt.show()
# Wipe the figure
plt.close("all")
[docs] def save_view(self, save_path: str, figsize=(10, 7), xscale="log", yscale="log", xlim=None, ylim=None, models=True,
back_sub: bool = True, just_models: bool = False, custom_title: str = None, draw_rads: dict = {},
x_norm: Union[bool, Quantity] = False, y_norm: Union[bool, Quantity] = False, x_label: str = None,
y_label: str = None, data_colour: str = 'black', model_colour: str = 'seagreen',
show_legend: bool = True, show_residual_ax: bool = True):
"""
A method that allows us to save a view of the current profile, as well as any models that have been
fitted to it, and their residuals. The models are plotted by generating random model realisations from
the parameter distributions, then plotting the median values, with 1sigma confidence limits.
This method will not display a figure, just save it at the supplied save_path.
:param str save_path: The path (including file name) where you wish to save the profile view.
:param Tuple figsize: The desired size of the figure, the default is (10, 7)
:param str xscale: The scaling to be applied to the x axis, default is log.
:param str yscale: The scaling to be applied to the y axis, default is log.
:param Tuple xlim: The limits to be applied to the x axis, upper and lower, default is
to let matplotlib decide by itself.
:param Tuple ylim: The limits to be applied to the y axis, upper and lower, default is
to let matplotlib decide by itself.
:param str models: Should the fitted models to this profile be plotted, default is True
:param bool back_sub: Should the plotted data be background subtracted, default is True.
:param bool just_models: Should ONLY the fitted models be plotted? Default is False
:param str custom_title: A plot title to replace the automatically generated title, default is None.
:param dict draw_rads: A dictionary of extra radii (as astropy Quantities) to draw onto the plot, where
the dictionary key they are stored under is what they will be labelled.
e.g. ({'r500': Quantity(), 'r200': Quantity()}
:param bool x_norm: Controls whether the x-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the x_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param bool y_norm: Controls whether the y-axis of the profile is normalised by another value, the default is
False, in which case no normalisation is applied. If it is set to True then it will attempt to use the
internal normalisation value (which can be set with the y_norm property), and if a quantity is passed it
will attempt to normalise using that.
:param str x_label: Custom label for the x-axis (excluding units, which will be added automatically).
:param str y_label: Custom label for the y-axis (excluding units, which will be added automatically).
:param str data_colour: Used to set the colour of the data points.
:param str model_colour: Used to set the colour of a model fit.
:param bool show_legend: Whether the legend should be displayed or not. Default is True.
:param bool show_residual_ax: Controls whether a lower axis showing the residuals between data and
model (if a model is fitted and being shown) is displayed. Default is True.
"""
# Setting up figure for the plot
fig = plt.figure(figsize=figsize)
# Grabbing the axis object and making sure the ticks are set up how we want
main_ax = plt.gca()
main_ax, res_ax = self.get_view(fig, main_ax, xscale, yscale, xlim, ylim, models, back_sub, just_models,
custom_title, draw_rads, x_norm, y_norm, x_label, y_label, data_colour,
model_colour, show_legend, show_residual_ax)
plt.savefig(save_path)
# plt.tight_layout()
plt.show()
# Wipe the figure
plt.close("all")
[docs] def save(self, save_path: str = None):
"""
This method pickles and saves the profile object. This will be called automatically when the profile
is initialised, and when changes are made to the profile (such as when a model is fitted). The save
file is a pickled version of this object.
:param str save_path: The path where this profile should be saved. By default this is None, which means
this method will use the save_path attribute of the profile.
"""
# Checks to see if the user has supplied their own custom save path.
if save_path is None and self.save_path is not None:
save_path = self.save_path
elif save_path is None and self.save_path is None:
raise TypeError("Base profiles cannot be saved")
# Pickles and saves this profile instance.
with open(save_path, 'wb') as picklo:
pickle.dump(self, picklo)
@property
def save_path(self) -> str:
"""
Property getter that assembles the default XGA save path of this profile. The file name contains
limited information; the type of profile, the source name, and a random integer.
:return: The default XGA save path for this profile.
:rtype: str
"""
if self._save_path is None and self._prof_type != "base":
temp_path = OUTPUT + "profiles/{sn}/{pt}_{sn}_{id}.xga"
rand_prof_id = randint(0, 1e+8)
while os.path.exists(temp_path.format(pt=self.type, sn=self.src_name, id=rand_prof_id)):
rand_prof_id = randint(0, 1e+8)
self._save_path = temp_path.format(pt=self.type, sn=self.src_name, id=rand_prof_id)
return self._save_path
@property
def good_model_fits(self) -> List:
"""
A list of the names of models that have been successfully fitted to the profile.
:return: A list of model names.
:rtype: Dict
"""
models = []
for method in self._good_model_fits:
for model in self._good_model_fits[method]:
if model not in models:
models.append(model)
return models
# None of these properties concerning the radii and values are going to have setters, if the user
# wants to modify it then they can define a new product.
@property
def radii(self) -> Quantity:
"""
Getter for the radii passed in at init. These radii correspond to radii where the values were measured.
:return: Astropy quantity array of radii.
:rtype: Quantity
"""
return self._radii
@property
def radii_err(self) -> Quantity:
"""
Getter for the uncertainties on the profile radii.
:return: Astropy quantity array of radii uncertainties, or a None value if no radii_err where passed.
:rtype: Quantity
"""
return self._radii_err
@property
def fit_radii(self) -> Quantity:
"""
This property gives the user a sanitised set of radii that is safe to use for fitting to XGA models, by
which I mean if the first element is zero (true for many of XGA's profiles), then it will be replaced by
a value slightly above zero that won't cause divide by zeros in the fit process.
If the radius units are convertible to kpc then the zero value will be set to the equivelant of 1kpc, if
they have pixel units then it will be set to one pixel, and if they are equivelant to degrees then it will
be set to 1eâˆ’5 degrees. The value for degrees is loosely based on the value of 1kpc at a redshift of 1.
:return: A Quantity with a set of radii that are 'safe' for fitting
:rtype: Quantity
"""
safe_rads = self._radii.copy()
if safe_rads[0] == 0 and self.radii_unit.is_equivalent('kpc'):
safe_rads[0] = Quantity(1, 'kpc').to(self.radii_unit)
elif safe_rads[0] == 0 and self.radii_unit.is_equivalent('pix'):
safe_rads[0] = Quantity(1, 'pix')
elif safe_rads[0] == 0 and self.radii_unit.is_equivalent('deg'):
safe_rads[0] = Quantity(1e-5, 'deg')
return safe_rads
@property
def radii_unit(self) -> Unit:
"""
Getter for the unit of the radii passed by the user at init.
:return: An astropy unit object.
:rtype: Unit
"""
return self._radii.unit
@property
def annulus_bounds(self) -> Quantity:
"""
Getter for the original boundary radii of the annuli this profile may have been generated from. Only
available if radii errors were passed on init.
:return: An astropy quantity containing the boundary radii of the annuli, or None if not available.
:rtype: Quantity
"""
return self._rad_ann_bounds
@property
def values(self) -> Quantity:
"""
Getter for the values passed by user at init.
:return: Astropy quantity array of values.
:rtype: Quantity
"""
return self._values
@property
def values_err(self) -> Quantity:
"""
Getter for uncertainties on the profile values.
:return: Astropy quantity array of values uncertainties, or a None value if no values_err where passed.
:rtype: Quantity
"""
return self._values_err
@property
def values_unit(self) -> Unit:
"""
Getter for the unit of the values passed by the user at init.
:return: An astropy unit object.
:rtype: Unit
"""
return self._values.unit
@property
def background(self) -> Quantity:
"""
Getter for the background associated with the profile values. If no background is set this will
be zero.
:return: Astropy scalar quantity.
:rtype: Quantity
"""
return self._background
@property
def centre(self) -> Quantity:
"""
Property that returns the central coordinate that the profile was generated from.
:return: An astropy quantity of the central coordinate
:rtype: Quantity
"""
return self._centre
# This definitely doesn't get a setter, as its basically a proxy for type() return, it will not change
# during the life of the object
@property
def type(self) -> str:
"""
Getter for a string representing the type of profile stored in this object.
:return: String description of profile.
:rtype: str
"""
return self._prof_type + "_profile"
@property
def src_name(self) -> str:
"""
Getter for the name attribute of this profile, what source object it was derived from.
:return:
:rtype: object
"""
return self._src_name
@src_name.setter
def src_name(self, new_name):
"""
Setter for the name attribute of this profile, what source object it was derived from.
"""
self._src_name = new_name
# This method means that a change has happened to the model, so it should be re-saved
self.save()
@property
def obs_id(self) -> str:
"""
Property getter for the ObsID this profile was made from. Admittedly this information is implicit
in the location this object is stored in a source object, but I think it worth storing directly
as a property as well.
:return: XMM ObsID string.
:rtype: str
"""
return self._obs_id
@property
def instrument(self) -> str:
"""
Property getter for the instrument this profile was made from. Admittedly this information is implicit
in the location this object is stored in a source object, but I think it worth storing directly
as a property as well.
directly as a property as well.
:return: XMM instrument name string.
:rtype: str
"""
return self._inst
@property
def energy_bounds(self) -> Union[Tuple[Quantity, Quantity], Tuple[None, None]]:
"""
Getter method for the energy_bounds property, which returns the rest frame energy band that this
profile was generated from
:return: Tuple containing the lower and upper energy limits as Astropy quantities.
:rtype: Union[Tuple[Quantity, Quantity], Tuple[None, None]]
"""
return self._energy_bounds
@property
def set_ident(self) -> int:
"""
If this profile was generated from an annular spectrum, this will contain the set_id of that annular spectrum.
:return: The integer set ID of the annular spectrum that generated this, or None if it wasn't generated
from an AnnularSpectra object.
:rtype: int
"""
return self._set_id
@property
def y_axis_label(self) -> str:
"""
Property to return the name used for labelling the y-axis in any plot generated by a profile object.
:return: The y_axis label.
:rtype: str
"""
return self._y_axis_name
@y_axis_label.setter
def y_axis_label(self, new_name: str):
"""
This allows the user to set a new y axis label for any plots generated by a profile object.
:param str new_name: The new y axis label.
"""
if not isinstance(new_name, str):
raise TypeError("Axis labels must be strings!")
self._y_axis_name = new_name
# This method means that a change has happened to the model, so it should be re-saved
self.save()
@property
def associated_set_storage_key(self) -> str:
"""
This property provides the storage key of the associated AnnularSpectra object, if the profile was generated
from an AnnularSpectra. If it was not then a None value is returned.
:return: The storage key of the associated AnnularSpectra, or None if not applicable.
:rtype: str
"""
return self._set_storage_key
@property
def deg_radii(self) -> Quantity:
"""
The radii in degrees if available.
:return: An astropy quantity containing the radii in degrees, or None.
:rtype: Quantity
"""
return self._deg_radii
@property
def storage_key(self) -> str:
"""
This property returns the storage key which this object assembles to place the profile in
an XGA source's storage structure. If the profile was generated from an AnnularSpectra then the key is based
on the properties of the AnnularSpectra, otherwise it is based upon the properties of the specific profile.
:return: String storage key.
:rtype: str
"""
return self._storage_key
@property
def usable(self) -> bool:
"""
Whether the profile object can be considered usable or not, reasons for this decision will vary for
different profile types.
:return: A boolean variable.
:rtype: bool
"""
return self._usable
@property
def x_norm(self) -> Quantity:
"""
The normalisation value for x-axis data passed on the definition of the this profile object.
:return: An astropy quantity containing the normalisation value.
:rtype: Quantity
"""
return self._x_norm
@x_norm.setter
def x_norm(self, new_val: Quantity):
"""
The setter for the normalisation value for x-axis data passed on the definition of the this profile object.
:param Quantity new_val: The new value for the normalisation of x-axis data.
"""
self._x_norm = new_val
# This method means that a change has happened to the model, so it should be re-saved
self.save()
@property
def y_norm(self) -> Quantity:
"""
The normalisation value for y-axis data passed on the definition of the this profile object.
:return: An astropy quantity containing the normalisation value.
:rtype: Quantity
"""
return self._y_norm
@y_norm.setter
def y_norm(self, new_val: Quantity):
"""
The setter for the normalisation value for y-axis data passed on the definition of the this profile object.
:param Quantity new_val: The new value for the normalisation of y-axis data.
"""
self._y_norm = new_val
# This method means that a change has happened to the model, so it should be re-saved
self.save()
@property
def fit_options(self) -> List[str]:
"""
Returns the supported fit options for XGA profiles.
:return: List of supported fit options.
:rtype: List[str]
"""
return self._fit_methods
@property
def nice_fit_names(self) -> List[str]:
"""
Returns nicer looking names for the supported fit options of XGA profiles.
:return: List of nice fit options.
:rtype: List[str]
"""
return self._nice_fit_methods
@property
def outer_radius(self) -> Quantity:
"""
Property that returns the outer radius used for the generation of this profile.
:return: The outer radius used in the generation of the profile.
:rtype: Quantity
"""
return self._outer_rad
def __len__(self):
"""
The length of a BaseProfile1D object is equal to the length of the radii and values arrays
passed in on init.
:return: The number of bins in this radial profile.
"""
return len(self._radii)
def __add__(self, other):
to_combine = [self]
if type(other) == list:
to_combine += other
elif isinstance(other, BaseProfile1D):
to_combine.append(other)
elif isinstance(other, BaseAggregateProfile1D):
to_combine += other.profiles
else:
raise TypeError("You may only add 1D Profiles, 1D Aggregate Profiles, or a list of 1D profiles"
" to this object.")
return BaseAggregateProfile1D(to_combine)
[docs]class BaseAggregateProfile1D:
"""
Quite a simple class that is generated when multiple 1D radial profile objects are added together. The
purpose of instances of this class is simply to make it easy to view 1D radial profiles on the same axes.
:param list profiles: A list of profile objects (of the same type) to include in this aggregate profile.
"""
def __init__(self, profiles: List[BaseProfile1D]):
"""
The init for the BaseAggregateProfile1D class.
"""
# This checks that all types of profiles in the profiles list are the same
types = [type(p) for p in profiles]
if len(set(types)) != 1:
raise TypeError("All component profiles must be of the same type")
# This checks that all profiles have the same x units
x_units = [p.radii_unit.to_string() for p in profiles]
if len(set(x_units)) != 1:
raise TypeError("All component profiles must have the same radii units.")
# THis checks that they all have the same y units. This is likely to be true if they are the same
# type, but you never know
y_units = [p.values_unit.to_string() for p in profiles]
if len(set(y_units)) != 1:
raise TypeError("All component profiles must have the same value units.")
# We check to see if all profiles either have a background, or not
backs = [p.background.value != 0 for p in profiles]
if len(set(backs)) != 1:
raise ValueError("All component profiles must have a background, or not have a "
"background. You cannot profiles that do to profiles that don't.")
elif backs[0]:
# An attribute to tell us whether backgrounds are present in the component profiles
self._back_avail = True
else:
self._back_avail = False
# Here we check that all energy bounds are the same
lo_bounds = [p.energy_bounds[0] for p in profiles]
hi_bounds = [p.energy_bounds[1] for p in profiles]
if len(set(lo_bounds)) != 1 or len(set(hi_bounds)) != 1:
raise ValueError("All component profiles must have been generate from the same energy range,"
" otherwise they aren't directly comparable.")
self._profiles = profiles
self._radii_unit = x_units[0]
self._values_unit = y_units[0]
# Not doing a check that all the prof types are the same, because that should be included in the
# type check on the first line of this init
self._prof_type = profiles[0].type.split("_profile")[0]
self._energy_bounds = profiles[0].energy_bounds
# We set the y-axis name attribute which is now expected by the plotting function, just grab it from the
# first component because we've already checked that they're all the same type
self._y_axis_name = self._profiles[0].y_axis_label
# Here I grab all the x_norm and y_norms, so that the view method of this aggregate profile can also
# apply normalisation to the separate profiles if the user wants
self._x_norms = [p.x_norm for p in self._profiles]
self._y_norms = [p.y_norm for p in self._profiles]
@property
def radii_unit(self) -> Unit:
"""
Getter for the unit of the radii passed by the user at init.
:return: An astropy unit object.
:rtype: Unit
"""
return self._radii_unit
@property
def values_unit(self) -> Unit:
"""
Getter for the unit of the values passed by the user at init.
:return: An astropy unit object.
:rtype: Unit
"""
return self._values_unit
@property
def type(self) -> str:
"""
Getter for a string representing the type of profile stored in this object.
:return: String description of profile.
:rtype: str
"""
return self._prof_type
@property
def profiles(self) -> List[BaseProfile1D]:
"""
This property is for the constituent profiles that makes up this aggregate profile.
:return: A list of the profiles that make up this object.
:rtype: List[BaseProfile1D]
"""
return self._profiles
@property
def energy_bounds(self) -> Union[Tuple[Quantity, Quantity], Tuple[None, None]]:
"""
Getter method for the energy_bounds property, which returns the rest frame energy band that
the component profiles of this object were generated from.
:return: Tuple containing the lower and upper energy limits as Astropy quantities.
:rtype: Union[Tuple[Quantity, Quantity], Tuple[None, None]]
"""
return self._energy_bounds
@property
def x_norms(self) -> List[Quantity]:
"""
The collated x normalisation values for the constituent profiles of this aggregate profile.
:return: A list of astropy quantities which represent the x-normalisations of the different profiles.
:rtype: List[Quantity]
"""
return self._x_norms
@x_norms.setter
def x_norms(self, new_vals: List[Quantity]):
"""
Setter for the collated x normalisation values for the constituent profiles of this aggregate profile.
:param List[Quantity] new_vals: A list of astropy quantities that the profile's x-axis values are
to be normalised by, there must be one entry for each profile.
"""
if len(new_vals) == len(self._x_norms):
self._x_norms = new_vals
else:
raise ValueError("The new list passed for x-axis normalisations must be the same length"
" as the original.")
@property
def y_norms(self) -> List[Quantity]:
"""
The collated y normalisation values for the constituent profiles of this aggregate profile.
:return: A list of astropy quantities which represent the y-normalisations of the different profiles.
:rtype: List[Quantity]
"""
return self._y_norms
@y_norms.setter
def y_norms(self, new_vals: List[Quantity]):
"""
Setter for the collated y normalisation values for the constituent profiles of this aggregate profile.
:param List[Quantity] new_vals: A list of astropy quantities that the profile's y-axis values are
to be normalised by, there must be one entry for each profile.
"""
if len(new_vals) == len(self._y_norms):
self._y_norms = new_vals
else:
raise ValueError("The new list passed for y-axis normalisations must be the same length"
" as the original.")
[docs] def view(self, figsize: Tuple = (10, 7), xscale: str = "log", yscale: str = "log", xlim: Tuple = None,
ylim: Tuple = None, model: str = None, back_sub: bool = True, legend: bool = True,
just_model: bool = False, custom_title: str = None, draw_rads: dict = {}, normalise_x: bool = False,
normalise_y: bool = False, x_label: str = None, y_label: str = None, save_path: str = None):
"""
A method that allows us to see all the profiles that make up this aggregate profile, plotted
on the same figure.
:param Tuple figsize: The desired size of the figure, the default is (10, 7)
:param str xscale: The scaling to be applied to the x axis, default is log.
:param str yscale: The scaling to be applied to the y axis, default is log.
:param Tuple xlim: The limits to be applied to the x axis, upper and lower, default is
to let matplotlib decide by itself.
:param Tuple ylim: The limits to be applied to the y axis, upper and lower, default is
to let matplotlib decide by itself.
:param str model: The name of the model fit to display, default is None. If the model
hasn't been fitted, or it failed, then it won't be displayed.
:param bool back_sub: Should the plotted data be background subtracted, default is True.
:param bool legend: Should a legend with source names be added to the figure, default is True.
:param bool just_model: Should only the models, not the data, be plotted. Default is False.
:param str custom_title: A plot title to replace the automatically generated title, default is None.
:param dict draw_rads: A dictionary of extra radii (as astropy Quantities) to draw onto the plot, where
the dictionary key they are stored under is what they will be labelled.
e.g. ({'r500': Quantity(), 'r200': Quantity()}. If normalise_x option is also used, and the x-norm values
are not the same for each profile, then draw_rads will be disabled.
:param bool normalise_x: Should the x-axis values be normalised with the x_norm value passed on the
definition of the constituent profile objects.
:param bool normalise_y: Should the y-axis values be normalised with the y_norm value passed on the
definition of the constituent profile objects.
:param str x_label: Custom label for the x-axis (excluding units, which will be added automatically).
:param str y_label: Custom label for the y-axis (excluding units, which will be added automatically).
:param str save_path: The path where the figure produced by this method should be saved. Default is None, in
which case the figure will not be saved.
"""
# Checks that any extra radii that have been passed are the correct units (i.e. the same as the radius units
# used in this profile)
if not all([r.unit == self.radii_unit for r in draw_rads.values()]):
raise UnitConversionError("All radii in draw_rad have to be in the same units as this profile, "
"{}".format(self.radii_unit.to_string()))
# Set up the x normalisation and y normalisation variables
if normalise_x:
x_norms = self.x_norms
else:
x_norms = [Quantity(1, '') for n in self.x_norms]
if normalise_y:
y_norms = self.y_norms
else:
y_norms = [Quantity(1, '') for n in self.y_norms]
# Need to make sure that draw_rads, if set, is compatible with the normalisations. The problem is that
# if the profiles all have different normalisations then the draw_rads values can't be normalised
if len(draw_rads) != 0 and len(set(x_norms)) != 1:
draw_rads = {}
# Setting up figure for the plot
fig = plt.figure(figsize=figsize)
# Grabbing the axis object and making sure the ticks are set up how we want
main_ax = plt.gca()
main_ax.minorticks_on()
if model is not None:
# This sets up an axis for the residuals to be plotted on, if model plotting is enabled
res_ax = fig.add_axes((0.125, -0.075, 0.775, 0.2))
res_ax.minorticks_on()
res_ax.tick_params(axis='both', direction='in', which='both', top=True, right=True)
# Adds a zero line for reference, as its ideally where residuals would be
res_ax.axhline(0.0, color="black")
# Setting some aesthetic parameters for the main plotting axis
main_ax.tick_params(axis='both', direction='in', which='both', top=True, right=True)
# Cycles through the component profiles of this aggregate profile, plotting them all
for p_ind, p in enumerate(self._profiles):
if p.obs_id != 'combined':
p_name = p.src_name + " {o}-{i}".format(o=p.obs_id, i=p.instrument.upper())
else:
p_name = p.src_name
if p.type == "brightness_profile" and p.psf_corrected:
leg_label = p_name + " PSF Corrected"
else:
leg_label = p_name
# This subtracts the background if the user wants a background subtracted plot
plot_y_vals = p.values.copy()
if back_sub:
plot_y_vals -= p.background
rad_vals = p.radii.copy()
plot_y_vals /= y_norms[p_ind]
rad_vals /= x_norms[p_ind]
# Now the actual plotting of the data
if p.radii_err is not None and p.values_err is None:
x_errs = (p.radii_err.copy() / x_norms[p_ind]).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, xerr=x_errs, fmt="x", capsize=2,
label=leg_label)
elif p.radii_err is None and p.values_err is not None:
y_errs = (p.values_err.copy() / y_norms[p_ind]).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, yerr=y_errs, fmt="x", capsize=2,
label=leg_label)
elif p.radii_err is not None and p.values_err is not None:
x_errs = (p.radii_err.copy() / x_norms[p_ind]).value
y_errs = (p.values_err.copy() / y_norms[p_ind]).value
line = main_ax.errorbar(rad_vals.value, plot_y_vals.value, xerr=x_errs, yerr=y_errs, fmt="x", capsize=2,
label=leg_label)
else:
line = main_ax.plot(rad_vals.value, plot_y_vals.value, 'x', label=leg_label)
# If the user only wants the models to be plotted, then this goes through the matplotlib
# artist objects that make up the line plot and hides them.
# Take this approach because I still want them on the legend, and I want the colour to use
# for the model plot
if just_model and model is not None:
line[0].set_visible(False)
if len(line) != 1:
for coll in line[1:]:
for art_obj in coll:
art_obj.set_visible(False)
if not back_sub and p.background.value != 0:
main_ax.axhline(p.background.value, label=leg_label + ' Background', linestyle='dashed',
color=line[0].get_color())
# If the user passes a model name, and that model has been fitted to the data, then that
# model will be plotted
if model is not None:
# I've put them in this order because I would prefer mcmc over odr, and odr over curve_fit
for method in ['mcmc', 'odr', 'curve_fit']:
try:
model_obj = p.get_model_fit(model, method)
lo_rad = p.radii.min()
hi_rad = p.radii.max()
mod_rads = np.linspace(lo_rad, hi_rad, 100)
mod_reals = model_obj.get_realisations(mod_rads)
median_model = np.percentile(mod_reals, 50, axis=1)
upper_model = np.percentile(mod_reals, 84.1, axis=1)
lower_model = np.percentile(mod_reals, 15.9, axis=1)
colour = line[0].get_color()
mod_lab = model_obj.publication_name + " - {}".format(p.nice_fit_names[method])
mod_line = main_ax.plot(mod_rads.value / x_norms[p_ind].value,
median_model.value/y_norms[p_ind], color=colour)
main_ax.fill_between(mod_rads.value / x_norms[p_ind].value,
lower_model.value / y_norms[p_ind].value,
upper_model.value / y_norms[p_ind].value, alpha=0.7, interpolate=True,
where=upper_model.value >= lower_model.value, facecolor=colour)
main_ax.plot(mod_rads.value / x_norms[p_ind].value, lower_model.value / y_norms[p_ind].value,
color=colour, linestyle="dashed")
main_ax.plot(mod_rads.value / x_norms[p_ind].value, upper_model.value / y_norms[p_ind].value,
color=colour, linestyle="dashed")
# This calculates and plots the residuals between the model and the data on the extra
# axis we added near the beginning of this method
res = np.percentile(model_obj.get_realisations(p.radii), 50, axis=1) - \
(plot_y_vals * y_norms[p_ind])
res_ax.plot(rad_vals.value, res.value, 'D', color=colour)
break
except ModelNotAssociatedError:
pass
# Parsing the astropy units so that if they are double height then the square brackets will adjust size
x_unit = r"$\left[" + rad_vals.unit.to_string("latex").strip("$") + r"\right]$"
y_unit = r"$\left[" + plot_y_vals.unit.to_string("latex").strip("$") + r"\right]$"
if x_label is None:
# Setting the main plot's x label
main_ax.set_xlabel("Radius {}".format(x_unit), fontsize=13)
else:
main_ax.set_xlabel(x_label + " {}".format(x_unit), fontsize=13)
if y_label is None and (not self._back_avail or not back_sub):
main_ax.set_ylabel(r"{l} {u}".format(l=self._y_axis_name, u=y_unit), fontsize=13)
elif y_label is None:
# If background has been subtracted it will be mentioned in the y axis label
main_ax.set_ylabel(r"Background Subtracted {l} {u}".format(l=self._y_axis_name, u=y_unit), fontsize=13)
elif y_label is not None:
main_ax.set_ylabel(y_label + ' {}'.format(y_unit), fontsize=13)
# Adds a legend with source names to the side if the user requested it
# I let the user decide because there could be quite a few names in it and it could get messy
if legend:
main_leg = main_ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), ncol=1, borderaxespad=0)
# This makes sure legend keys are shown, even if the data is hidden
for leg_key in main_leg.legendHandles:
leg_key.set_visible(True)
# If the user has manually set limits then we can use them, only on the main axis because
# we grab those limits from the axes object for the residual axis later
if xlim is not None:
main_ax.set_xlim(xlim)
if ylim is not None:
main_ax.set_ylim(ylim)
# Setup the scale that the user wants to see, again on the main axis
main_ax.set_xscale(xscale)
main_ax.set_yscale(yscale)
if model is not None:
# We want the residual x axis limits to be identical to the main axis, as the
# points should line up
res_ax.set_xlim(main_ax.get_xlim())
res_ax.set_xlabel("Radius {}".format(x_unit), fontsize=13)
res_ax.set_xscale(xscale)
# Grabbing the automatically assigned y limits for the residual axis, then finding the maximum
# difference from zero, increasing it by 10%, then setting that value is the new -+ limits
# That way its symmetrical
outer_ylim = 1.1*max([abs(lim) for lim in res_ax.get_ylim()])
res_ax.set_ylim(-outer_ylim, outer_ylim)
res_ax.set_ylabel("Model - Data")
# Adds a title to this figure, changes depending on whether model fits are plotted as well
if model is not None and custom_title is None:
plt.suptitle("{l} Profiles - {m} fit".format(l=self._y_axis_name,
m=MODEL_PUBLICATION_NAMES[model]), y=0.91)
elif model is None and custom_title is None:
plt.suptitle("{l} Profiles".format(l=self._y_axis_name), y=0.91)
else:
# If the user doesn't like my title, they can supply their own
plt.suptitle(custom_title, y=0.91)
# If the user has passed radii to plot, then we plot them
for r_name in draw_rads:
d_rad = (draw_rads[r_name] / x_norms[0]).value
main_ax.axvline(d_rad, linestyle='dashed', color='black')
main_ax.annotate(r_name, (d_rad * 1.01, 0.9), rotation=90, verticalalignment='center',
color='black', fontsize=14, xycoords=('data', 'axes fraction'))
# If the user passed a save_path value, then we assume they want to save the figure
if save_path is not None:
plt.savefig(save_path)
# And of course actually showing it
plt.show()
def __add__(self, other):
to_combine = self.profiles
if type(other) == list:
to_combine += other
elif isinstance(other, BaseProfile1D):
to_combine.append(other)
elif isinstance(other, BaseAggregateProfile1D):
to_combine += other.profiles
else:
raise TypeError("You may only add 1D Profiles, 1D Aggregate Profiles, or a list of 1D profiles"
" to this object.")
return BaseAggregateProfile1D(to_combine)
```