Source code for estimator.cost

# -*- coding: utf-8 -*-
from collections import UserDict

from sage.all import log, oo, round


# UserDict inherits from collections.abc.MutableMapping
class Cost(UserDict):
    """
    Algorithms costs.

    A dictionary-like container mapping cost component names (e.g. ``"rop"``,
    ``"beta"``, ``"d"``) to their values, with helpers for pretty-printing,
    reordering, repeating and combining cost estimates.  Costs are ordered and
    tested for truth by their ``"rop"`` entry.
    """

    # An entry is "impermanent" if it grows when we run the algorithm again. For example, `δ`
    # would not scale with the number of operations but `rop` would. This check is strict such that
    # unknown entries raise an error. This is to enforce a decision on whether an entry should be
    # scaled.
    impermanents = {
        "rop": True,
        "repetitions": False,
        "tag": False,
        "problem": False,
    }

    @staticmethod
    def _update_without_overwrite(dst, src):
        """
        Merge ``src`` into ``dst`` in place, refusing to change existing values.

        :param dst: dictionary to update
        :param src: dictionary with the new entries
        :raises ValueError: if a key is present in both with different values;
            ``dst`` is left untouched in that case
        """
        keys_intersect = set(dst.keys()) & set(src.keys())
        attempts = [
            f"{k}: {dst[k]} with {src[k]}" for k in keys_intersect if dst[k] != src[k]
        ]
        if attempts:
            s = ", ".join(attempts)
            raise ValueError(f"Attempting to overwrite {s}")
        dst.update(src)

    @classmethod
    def register_impermanent(cls, data=None, **kwds):
        """
        Declare whether entries scale when an algorithm is repeated.

        :param data: optional dictionary ``key -> bool``
        :param kwds: further ``key=bool`` declarations
        :raises ValueError: if a key was already registered with a different value
        """
        if data is not None:
            cls._update_without_overwrite(cls.impermanents, data)
        cls._update_without_overwrite(cls.impermanents, kwds)

    # Pretty names used when printing keys.
    key_map = {
        "delta": "δ",
        "beta": "β",
        "beta_": "β'",
        "eta": "η",
        "eta_": "η'",
        "epsilon": "ε",
        "zeta": "ζ",
        "zeta_": "ζ'",
        "ell": "ℓ",
        "ell_": "ℓ'",
        "repetitions": "↻",
    }
    # Keys with a dedicated %-format; these are never abbreviated to powers of two.
    val_map = {"beta": "%8d", "beta_": "%8d", "d": "%8d", "delta": "%8.6f"}

    def str(self, keyword_width=0, newline=False, round_bound=2048, compact=False):
        """
        Format all entries except ``"problem"`` as ``key: value`` pairs.

        :param keyword_width: keys are printed with this width
        :param newline: insert a newline
        :param round_bound: values beyond this bound are represented as powers of two
        :param compact: do not add extra whitespace to align entries

        EXAMPLE::

            >>> from estimator.cost import Cost
            >>> s = Cost(delta=5, bar=2)
            >>> s
            δ: 5.000000, bar: 2

        """

        def value_str(k, v):
            kstr = self.key_map.get(k, k)
            kk = f"{kstr:>{keyword_width}}"
            try:
                if (1 / round_bound < abs(v) < round_bound) or (not v) or (k in self.val_map):
                    if abs(v % 1) < 1e-7:
                        # (numerically) integral values are printed as integers
                        vv = self.val_map.get(k, "%8d") % round(v)
                    else:
                        vv = self.val_map.get(k, "%8.3f") % v
                else:
                    vv = "%7s" % ("≈2^%.1f" % log(v, 2))
            except TypeError:  # strings and such
                vv = "%8s" % v
            if compact is True:
                kk = kk.strip()
                vv = vv.strip()
            return f"{kk}: {vv}"

        # we store the problem instance in a cost object for reference
        s = [value_str(k, v) for k, v in self.items() if k != "problem"]
        delimiter = "\n" if newline is True else ", "
        return delimiter.join(s)

    def reorder(self, *args):
        """
        Return a new ordered dict from the key:value pairs in dictionary but reordered such that the
        keys given to this function come first.

        :param args: keys which should come first (in order); keys not present are ignored

        EXAMPLE::

            >>> from estimator.cost import Cost
            >>> d = Cost(a=1,b=2,c=3); d
            a: 1, b: 2, c: 3

            >>> d.reorder("b","c","a")
            b: 2, c: 3, a: 1

        """
        reord = {k: self[k] for k in args if k in self}
        # remaining keys are appended in their current order
        reord.update(self)
        return Cost(**reord)

    def filter(self, *args, **keys):
        """
        Return new ordered dictionary from dictionary restricted to the keys.

        :param args: names of the keys which should be copied (ordered)
        :param keys: further keys which should be copied, given as keyword arguments;
            only the names are used, the values are ignored

        Keys that are not present in this cost are silently skipped.

        EXAMPLE::

            >>> from estimator.cost import Cost
            >>> Cost(a=1, b=2, c=3).filter("c", "a")
            c: 3, a: 1

        """
        r = {k: self[k] for k in (*args, *keys) if k in self}
        return Cost(**r)

    def repeat(self, times, select=None):
        """
        Return a report with all costs multiplied by ``times``.

        :param times: the number of times it should be run
        :param select: toggle which fields ought to be repeated and which should not
        :returns: a new cost estimate
        :raises NotImplementedError: if an entry is neither registered in
            ``impermanents`` nor covered by ``select``

        EXAMPLE::

            >>> from estimator.cost import Cost
            >>> c0 = Cost(a=1, b=2)
            >>> c0.register_impermanent(a=True, b=False)
            >>> c0.repeat(1000)
            a: 1000, b: 2, ↻: 1000

        TESTS::

            >>> from estimator.cost import Cost
            >>> Cost(rop=1).repeat(1000).repeat(1000)
            rop: ≈2^19.9, ↻: ≈2^19.9

        """
        # work on a copy so that ``select`` never leaks into the class-level registry
        impermanents = dict(self.impermanents)
        if select is not None:
            impermanents.update(select)

        try:
            ret = {k: times * v if impermanents[k] else v for k, v in self.items()}
        except KeyError as error:
            raise NotImplementedError(
                f"You found a bug, this function does not know about a key but should: {error}"
            ) from error
        ret["repetitions"] = times * ret.get("repetitions", 1)
        return Cost(**ret)

    def __rmul__(self, times):
        """``times * cost`` is shorthand for ``cost.repeat(times)``."""
        return self.repeat(times)

    def combine(self, right, base=None):
        """Combine this cost with ``right``.

        On duplicate keys ``right`` takes precedence over ``self``, which takes
        precedence over ``base``.

        :param right: cost dictionary
        :param base: add entries to ``base``

        EXAMPLE::

            >>> from estimator.cost import Cost
            >>> c0 = Cost(a=1)
            >>> c1 = Cost(b=2)
            >>> c2 = Cost(c=3)
            >>> c0.combine(c1)
            a: 1, b: 2
            >>> c0.combine(c1, base=c2)
            c: 3, a: 1, b: 2

        """
        base_dict = {} if base is None else base
        cost = {**base_dict, **self, **right}
        return Cost(**cost)

    def __bool__(self):
        """A cost is truthy iff it has a finite ``"rop"`` entry."""
        return self.get("rop", oo) < oo

    def __add__(self, other):
        """``self + other`` is ``self.combine(other)``."""
        # ``combine`` is an instance method: passing ``self`` again would make
        # ``other`` the *base* and give the left operand precedence.
        return self.combine(other)

    def __repr__(self):
        return self.str(compact=True)

    def __str__(self):
        return self.str(newline=True, keyword_width=12)

    def __lt__(self, other):
        """Compare by ``"rop"``; ``other`` may be a cost or a plain number."""
        try:
            return self["rop"] < other["rop"]
        except (AttributeError, TypeError):
            # numbers are not subscriptable (TypeError): compare against the value itself
            return self["rop"] < other

    def __le__(self, other):
        """Compare by ``"rop"``; ``other`` may be a cost or a plain number."""
        try:
            return self["rop"] <= other["rop"]
        except (AttributeError, TypeError):
            # numbers are not subscriptable (TypeError): compare against the value itself
            return self["rop"] <= other

    def sanity_check(self):
        """
        Perform basic checks.

        ``"rop"`` values above ``2**10000`` are replaced by infinity.  Missing
        ``"beta"``, ``"eta"`` and ``"d"`` entries are treated as ``0``.

        :returns: ``self``
        :raises RuntimeError: if ``beta > d`` or ``eta > d``
        """
        if self.get("rop", 0) > 2**10000:
            self["rop"] = oo

        # use the fetched values in the messages: indexing ``self`` would raise a
        # KeyError instead of the intended RuntimeError when an entry is missing
        d = self.get("d", 0)
        beta = self.get("beta", 0)
        if beta > d:
            raise RuntimeError(f"β = {beta} > d = {d}")
        eta = self.get("eta", 0)
        if eta > d:
            raise RuntimeError(f"η = {eta} > d = {d}")
        return self