# Source code for getdist.paramnames

import os
import fnmatch
from itertools import chain


def makeList(roots):
    """
    Wraps a single value in a list, passing lists and tuples through untouched.

    :param roots: a single item, or a list/tuple of items
    :return: ``roots`` itself when it is a list or tuple, otherwise ``[roots]``
    """
    if not isinstance(roots, (list, tuple)):
        return [roots]
    return roots
def escapeLatex(text):
    """
    Escapes underscores in ``text`` when matplotlib renders text through LaTeX.

    Empty or None ``text`` is returned as-is without importing matplotlib. Otherwise, if
    ``matplotlib.rcParams['text.usetex']`` is set, each ``_`` is replaced by
    ``{\\textunderscore}``; if not, ``text`` is returned unchanged.
    """
    if not text:
        return text
    import matplotlib  # imported lazily: only needed when there is text to escape
    if not matplotlib.rcParams['text.usetex']:
        return text
    return text.replace('_', '{\\textunderscore}')
def mergeRenames(*dicts, **kwargs):
    """
    Joins several dicts of renames.

    Each input dict maps a parameter name to a single rename or a list/tuple of renames.
    All names linked through any shared name (in any of the inputs) are grouped together,
    and each group gives (at most) one entry of the result.

    If `keep_names_1st=True` (default: `False`), keeps empty entries when possible in order
    to preserve the parameter names of the first input dictionary.

    Returns a merged dictionary of renames, whose keys are chosen from the left-most input.
    When that input has several keys in the same group, the first one (in dict order) is
    used. The renames of each key are listed in order of first appearance in the inputs.

    :param dicts: dictionaries of renames ``{name: rename or list/tuple of renames}``
    :param keep_names_1st: (keyword only) see above
    :return: dict mapping each chosen name to the list of its other names
    :raises ValueError: if keyword arguments other than `keep_names_1st` are given
    """
    keep_names_1st = kwargs.pop("keep_names_1st", False)
    if kwargs:
        raise ValueError("kwargs not recognized: %r" % kwargs)
    groups = []
    first_seen = {}  # name -> position of first appearance, for deterministic ordering
    for dic in dicts:
        for key, value in dic.items():
            # list(): makeList passes tuples through, and list + tuple raises TypeError
            names = [key] + list(makeList(value or []))
            for name in names:
                first_seen.setdefault(name, len(first_seen))
            groups.append(set(names))
    # If two groups have elements in common, join them (transitively).
    joined = []
    while groups:
        group = groups.pop(0)
        grown = True
        while grown:
            grown = False
            rest = []
            for other in groups:
                if group.isdisjoint(other):
                    rest.append(other)
                else:
                    group |= other
                    grown = True
            groups = rest
        joined.append(group)
    merged = {}
    for group in joined:
        for dic in dicts:
            keys = [k for k in dic if k in group]
            # Skip an input whose keys cover the whole group (no renames to record),
            # unless keep_names_1st asks to keep the name even with an empty rename list.
            if keys and (len(keys) < len(group) or keep_names_1st):
                key = keys[0]
                merged[key] = sorted(group - {key}, key=first_seen.__getitem__)
                break
    return merged
class ParamInfo:
    """
    Parameter information object.

    :ivar name: the parameter name tag (no spacing or punctuation)
    :ivar label: latex label (without enclosing $)
    :ivar comment: any descriptive comment describing the parameter
    :ivar isDerived: True if a derived parameter, False otherwise (e.g. for MCMC parameters)
    :ivar renames: alternative names for the parameter (list, or tuple if a tuple was given)
    :ivar number: optional parameter number, as given to the constructor
    :ivar filenameLoadedFrom: initialised to ''
    """

    def __init__(self, line=None, name='', label='', comment='', derived=False, renames=None,
                 number=None):
        """
        :param line: optional .paramnames-format line; if given it is parsed by
                     :meth:`setFromString` after the other arguments are applied
        :param name: parameter name
        :param label: latex label; defaults to the name
        :param comment: descriptive comment
        :param derived: whether the parameter is derived
        :param renames: a rename or list/tuple of renames
        :param number: optional parameter number
        """
        self.setName(name)
        self.isDerived = derived
        self.label = label or name
        self.comment = comment
        self.filenameLoadedFrom = ''
        self.number = number
        self.renames = makeList(renames or [])
        if line is not None:
            self.setFromString(line)

    def nameEquals(self, name):
        """
        Tests whether this parameter has the given name.

        :param name: a name string, or another :class:`ParamInfo`
        :return: True if the names are equal
        """
        # Compare against self.name (the original compared the argument with itself).
        if isinstance(name, ParamInfo):
            return self.name == name.name
        return self.name == name

    def setFromString(self, line):
        """
        Sets name, label and comment from a .paramnames-format line
        ``name[*] [label] [#comment]``.

        A trailing ``*`` on the name marks the parameter as derived. Any ``!`` in the label
        is converted to a backslash. The label and comment are only updated when the line
        has text after the name.

        :param line: the line to parse
        :return: self
        """
        items = line.split(None, 1)
        name = items[0]
        if name.endswith('*'):
            name = name.strip('*')
            self.isDerived = True
        self.setName(name)
        if len(items) > 1:
            tmp = items[1].split('#', 1)
            self.label = tmp[0].strip().replace('!', '\\')
            if len(tmp) > 1:
                self.comment = tmp[1].strip()
            else:
                self.comment = ''
        return self

    def setName(self, name):
        """
        Sets the parameter name after validating it.

        :raises ValueError: if ``name`` is not a string, or contains spaces, tabs, * or ?
        """
        if not isinstance(name, str):
            raise ValueError('"name" must be a parameter name string not %s: %s' % (type(name), name))
        if '*' in name or '?' in name or ' ' in name or '\t' in name:
            raise ValueError('Parameter names must not contain spaces, * or ?')
        self.name = name

    def getLabel(self):
        """Returns the label, or the name if the label is empty."""
        if self.label:
            return self.label
        else:
            return self.name

    def latexLabel(self):
        """Returns the label wrapped in $...$, or the plain name if the label is empty."""
        if self.label:
            return '$' + self.label + '$'
        else:
            return self.name

    def setFromStringWithComment(self, items):
        """
        Sets values from a (line, comment) pair; a comment of 'NULL' is ignored.

        :param items: sequence whose first item is a .paramnames-format line and whose
                      second item is the comment
        """
        self.setFromString(items[0])
        if items[1] != 'NULL':
            self.comment = items[1]

    def string(self, wantComments=True):
        """
        Formats the parameter as a .paramnames line: name (with ``*`` if derived),
        tab, label, and optionally tab + ``#comment`` when the comment is non-empty.
        """
        res = self.name
        if self.isDerived:
            res += '*'
        res = res + '\t' + self.label
        if wantComments and self.comment != '':
            res = res + '\t#' + self.comment
        return res

    def __str__(self):
        return self.string()
class ParamList:
    """
    Holds an ordered list of :class:`ParamInfo` objects describing a set of parameters.

    :ivar names: list of :class:`ParamInfo` objects
    """

    # Implemented by subclasses (e.g. ParamNames); only called when a fileName is given.
    loadFromFile: callable

    def __init__(self, fileName=None, setParamNameFile=None, default=0, names=None, labels=None):
        """
        :param fileName: name of .paramnames file to load from
        :param setParamNameFile: override specific parameter names' labels using another file
        :param default: set to int>0 to automatically generate that number of default names and labels
                        (param1, p_{1}, etc.)
        :param names: a list of name strings to use
        :param labels: a list of labels, applied in order to the parameters
        """
        self.names = []
        self.info_dict = None  # if read from yaml file, saved here
        if default:
            self.setDefault(default)
        if names is not None:
            self.setWithNames(names)
        if fileName is not None:
            self.loadFromFile(fileName)
        if setParamNameFile is not None:
            self.setLabelsFromParamNames(setParamNameFile)
        if labels is not None:
            self.setLabels(labels)

    def setDefault(self, n):
        """
        Replaces the parameters with ``n`` defaults named param1, param2, ... with labels
        p_{1}, p_{2}, ...

        :return: self
        """
        self.names = [ParamInfo(name='param' + str(i + 1), label='p_{' + str(i + 1) + '}')
                      for i in range(n)]
        return self

    def setWithNames(self, names):
        """
        Replaces the parameters with ones built from the given strings.

        Each string is passed as the ``line`` argument of :class:`ParamInfo`, so it is parsed
        as a .paramnames line (e.g. a trailing ``*`` marks a derived parameter).

        :return: self
        """
        self.names = [ParamInfo(name) for name in names]
        return self

    def setLabels(self, labels):
        """Sets labels in order; surplus labels or parameters are left untouched."""
        for par, label in zip(self.names, labels):
            par.label = label

    def numDerived(self):
        """Returns the number of derived parameters."""
        return sum(1 for info in self.names if info.isDerived)

    def list(self):
        """
        Gets a list of parameter name strings
        """
        return [name.name for name in self.names]

    def labels(self):
        """
        Gets a list of parameter labels
        """
        return [name.label for name in self.names]

    def listString(self):
        """Returns the parameter names joined by single spaces."""
        return " ".join(self.list())

    def numParams(self):
        """Returns the total number of parameters."""
        return len(self.names)

    def numNonDerived(self):
        """Returns the number of non-derived parameters."""
        return sum(1 for info in self.names if not info.isDerived)

    def parWithNumber(self, num):
        """Returns the first parameter whose ``number`` equals ``num``, or None."""
        for par in self.names:
            if par.number == num:
                return par
        return None

    def _check_name_str(self, name):
        # Raise a clear error when a non-string (e.g. a ParamInfo) is passed as a name.
        if not isinstance(name, str):
            raise ValueError('"name" must be a parameter name string not %s: %s' % (type(name), name))

    def parWithName(self, name, error=False, renames=None):
        """
        Gets the :class:`ParamInfo` object for the parameter with the given name

        :param name: name of the parameter
        :param error: if True raise an error if parameter not found, otherwise return None
        :param renames: a dictionary that is used to provide optional name mappings
                        to the stored names
        """
        self._check_name_str(name)
        given_names = {name}
        if renames:
            given_names.update(makeList(renames.get(name, [])))
        for par in self.names:
            # set.update accepts both lists and tuples of renames (list + tuple would fail)
            known_names = {par.name}
            known_names.update(makeList(getattr(par, 'renames', [])))
            if renames:
                known_names.update(makeList(renames.get(par.name, [])))
            if not known_names.isdisjoint(given_names):
                return par
        if error:
            raise Exception("parameter name not found: %s" % name)
        return None

    def numberOfName(self, name):
        """
        Gets the parameter number of the given parameter name

        :param name: parameter name tag
        :return: index of the parameter, or -1 if not found
        """
        self._check_name_str(name)
        for i, par in enumerate(self.names):
            if par.name == name:
                return i
        return -1

    def hasParam(self, name):
        """Returns True if a parameter with exactly this name exists (renames ignored)."""
        return self.numberOfName(name) != -1

    def parsWithNames(self, names, error=False, renames=None):
        """
        gets the list of :class:`ParamInfo` instances for given list of name strings.
        Also expands any names that are globs into list with matching parameter names

        :param names: list of name strings
        :param error: if True, raise an error if any name not found,
                      otherwise returns None items. Can be a list of length `len(names)`
        :param renames: optional dictionary giving mappings of parameter names
        """
        res = []
        if isinstance(names, str):
            names = [names]
        errors = makeList(error)
        if len(errors) < len(names):
            errors = len(names) * errors
        for name, err in zip(names, errors):
            if isinstance(name, ParamInfo):
                res.append(name)
            elif '?' in name or '*' in name:
                res += self.getMatches(name)
            else:
                res.append(self.parWithName(name, err, renames))
        return res

    def getMatches(self, pattern, strings=False):
        """
        Returns parameters whose names match the glob ``pattern`` (case-sensitive).

        :param strings: if True return name strings, otherwise :class:`ParamInfo` objects
        """
        pars = []
        for par in self.names:
            if fnmatch.fnmatchcase(par.name, pattern):
                if strings:
                    pars.append(par.name)
                else:
                    pars.append(par)
        return pars

    def setLabelsFromParamNames(self, fname):
        """Copies labels (but not derived flags) from another .paramnames file or object."""
        self.setLabelsAndDerivedFromParamNames(fname, False)

    def setLabelsAndDerivedFromParamNames(self, fname, set_derived=True):
        """
        Copies labels, and optionally derived flags, from matching parameters of another set.

        :param fname: a :class:`ParamNames` instance or a file name to load one from
        :param set_derived: whether to also copy ``isDerived``
        """
        if isinstance(fname, ParamNames):
            p = fname
        else:
            p = ParamNames(fname)
        for par in p.names:
            param = self.parWithName(par.name)
            if param is not None:
                param.label = par.label
                if set_derived:
                    param.isDerived = par.isDerived

    def getRenames(self, keep_empty=False):
        """
        Gets dictionary of renames known to each parameter.
        """
        return {param.name: getattr(param, "renames", [])
                for param in self.names
                if (getattr(param, "renames", False) or keep_empty)}

    def updateRenames(self, renames):
        """
        Updates the renames known to each parameter with the given dictionary of renames.
        """
        merged_renames = mergeRenames(
            self.getRenames(keep_empty=True), renames, keep_names_1st=True)
        known_names = set(self.list())
        for name, rename in merged_renames.items():
            if name in known_names:
                self.parWithName(name).renames = rename

    def fileList(self, fname):
        """Returns the lines of a UTF-8 text file (BOM tolerated), newlines kept."""
        with open(fname, encoding='utf-8-sig') as f:
            return f.readlines()

    def deleteIndices(self, indices):
        """Removes the parameters at the given positions."""
        indices = set(indices)  # O(1) membership per parameter
        self.names = [name for i, name in enumerate(self.names) if i not in indices]

    def filteredCopy(self, params):
        """
        Returns a new instance of the same class holding only the parameters found in
        ``params`` (either a list of name strings or another parameter list).
        """
        usedNames = self.__class__()
        for name in self.names:
            if isinstance(params, list):
                p = name.name in params
            else:
                p = params.parWithName(name.name)
            if p:
                usedNames.names.append(name)
        return usedNames

    def addDerived(self, name, **kwargs):
        """
        adds a new parameter

        :param name: name tag for the new parameter
        :param kwargs: other arguments for constructing the new :class:`ParamInfo`
        """
        if kwargs.get('derived') is None:
            kwargs['derived'] = True
        self._check_name_str(name)
        kwargs['name'] = name
        self.names.append(ParamInfo(**kwargs))
        return self.names[-1]

    def maxNameLen(self):
        """Returns the length of the longest parameter name (0 if there are none)."""
        return max((len(par.name) for par in self.names), default=0)

    def parFormat(self):
        """Returns a left-justified %-format for names, at least 10 characters wide."""
        maxLen = max(9, self.maxNameLen()) + 1
        return "%-" + str(maxLen) + "s"

    def name(self, ix, tag_derived=False):
        """Returns the name at index ``ix``, with a ``*`` appended if derived and tag_derived."""
        par = self.names[ix]
        if tag_derived and par.isDerived:
            return par.name + '*'
        else:
            return par.name

    def __str__(self):
        return ''.join(par.string() + '\n' for par in self.names)

    def saveAsText(self, filename):
        """
        Saves to a plain text .paramnames file

        :param filename: filename to save to
        """
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(str(self))

    def getDerivedNames(self):
        """
        Get the names of all derived parameters
        """
        return [name.name for name in self.names if name.isDerived]

    def getRunningNames(self):
        """
        Get the names of all running (non-derived) parameters
        """
        return [name.name for name in self.names if not name.isDerived]
class ParamNames(ParamList):
    """
    Holds an ordered list of :class:`ParamInfo` objects describing a set of parameters,
    inheriting from :class:`ParamList`.

    Can be constructed programmatically, and also loaded and saved to a .paramnames files, which is a plain text file
    giving the names and optional label and comment for each parameter, in order.

    :ivar names: list of :class:`ParamInfo` objects describing each parameter
    :ivar filenameLoadedFrom: if loaded from file, the file name
    """

    def loadFromFile(self, fileName):
        """
        loads from fileName, a plain text .paramnames file or a "full" yaml file

        :param fileName: path of a ``.paramnames``, ``.yaml`` or ``.yml`` file
                         (the extension is matched case-insensitively)
        :raises ValueError: for any other file extension
        """
        self.filenameLoadedFrom = os.path.basename(fileName)
        extension = os.path.splitext(fileName)[-1].lower()
        if extension == '.paramnames':
            with open(fileName, encoding='utf-8-sig') as f:
                # one parameter per non-blank line
                self.names = [ParamInfo(line) for line in (s.strip() for s in f) if line]
        elif extension in ('.yaml', '.yml'):
            from getdist import yaml_tools
            from getdist.cobaya_interface import get_info_params, is_sampled_param
            from getdist.cobaya_interface import is_derived_param, _p_label, _p_renames
            self.info_dict = yaml_tools.yaml_load_file(fileName)
            info_params = get_info_params(self.info_dict)

            def make_info(param, info, derived):
                # info may be None for a parameter with no options; label defaults to the name
                info = info or {}
                return ParamInfo(name=param, label=info.get(_p_label, param),
                                 renames=info.get(_p_renames), derived=derived)

            # first sampled, then derived
            self.names = [make_info(param, info, False)
                          for param, info in info_params.items() if is_sampled_param(info)]
            self.names += [make_info(param, info, True)
                           for param, info in info_params.items() if is_derived_param(info)]
        else:
            raise ValueError('ParamNames must be loaded from .paramnames or .yaml/.yml file, '
                             'found %s' % fileName)

    def loadFromKeyWords(self, keywordProvider):
        """
        Appends parameters read from keywords ``param_1`` ... ``param_N``, where
        N = ``num_params_used`` + ``num_derived_params``.

        :param keywordProvider: object providing ``keyWord_int`` and ``keyWordAndComment``
        :return: the number of parameters read
        """
        num_params_used = keywordProvider.keyWord_int('num_params_used')
        num_derived_params = keywordProvider.keyWord_int('num_derived_params')
        nparam = num_params_used + num_derived_params
        for i in range(nparam):
            info = ParamInfo()
            info.setFromStringWithComment(keywordProvider.keyWordAndComment('param_' + str(i + 1)))
            self.names.append(info)
        return nparam

    def saveKeyWords(self, keywordProvider):
        """
        Writes the parameter counts and one ``param_i`` keyword per parameter. Backslashes
        in labels are written as ``!`` (the inverse of the parsing in ParamInfo.setFromString).

        :param keywordProvider: object providing ``setKeyWord_int`` and ``setKeyWord``
        """
        num_derived = self.numDerived()
        keywordProvider.setKeyWord_int('num_params_used', len(self.names) - num_derived)
        keywordProvider.setKeyWord_int('num_derived_params', num_derived)
        for i, name in enumerate(self.names):
            keywordProvider.setKeyWord('param_' + str(i + 1), name.string(False).replace('\\', '!'),
                                       name.comment)