"""Source code for estimator.util: search helpers and batch-estimation utilities."""

import itertools as it
from multiprocessing import Pool
from functools import partial
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple

from sage.all import ceil, floor, log, oo

from .io import Logging
from .lwe_parameters import LWEParameters


def log2(x):
    """
    Return the binary logarithm of ``x``.

    The base is passed as the float ``2.0`` to Sage's ``log``.
    """
    base = 2.0
    return log(x, base)
class Bounds(NamedTuple):
    """
    A generic ``(low, high)`` pair.

    In this module it holds either an interval's end points or an ``(argument, value)`` pair.
    """

    # first component (lower end point, or the argument of a result)
    low: Any
    # second component (upper end point, or the value of a result)
    high: Any
class local_minimum_base:
    """
    An iterator context for finding a local minimum using binary search.

    We use the immediate neighborhood of a point to decide the next direction to go into (gradient
    descent style), so the algorithm is not plain binary search (see ``update()`` function.)

    .. note :: We combine an iterator and a context to give the caller access to the result.
    """

    def __init__(
        self,
        start,
        stop,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point
        :param stop: end point (exclusive)
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: level used for the ``"bins"`` log messages
        :raises ValueError: if ``stop < start``
        """
        if stop < start:
            raise ValueError(f"Incorrect bounds {start} > {stop}.")

        self._suppress_bounds_warning = suppress_bounds_warning
        self._log_level = log_level
        # Current search window; it shrinks as binary-search steps succeed or fail.
        self._start = start
        self._stop = stop - 1  # make the upper end point inclusive
        # The window we were asked to search, kept for bounds checks and the edge warning.
        self._initial_bounds = Bounds(start, stop - 1)
        self._smallerf = smallerf
        # abs(self._direction) == 2: binary search step
        # abs(self._direction) == 1: gradient descent direction
        self._direction = -1  # going down
        self._last_x = None
        # The search starts from the top of the window.
        self._next_x = self._stop
        # Best (argument, value) seen so far.
        self._best = Bounds(None, None)
        # Every evaluated argument, used to detect loops.
        self._all_x = set()

    def __enter__(self):
        """Return the search object itself as the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Nothing to clean up; exceptions are not suppressed."""
        pass

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """
        Return the next argument to evaluate, or stop the search.

        The search stops when ``update()`` cleared the next candidate, when the candidate was
        already evaluated, or when it lies outside the initial bounds. On stopping, a warning is
        logged (unless suppressed) if the best argument coincides with one of the initial bounds.
        """
        candidate = self._next_x
        if (
            candidate is not None  # we've not been told to abort
            and candidate not in self._all_x  # we're not looping
            and self._initial_bounds.low <= candidate <= self._initial_bounds.high  # in bounds
        ):
            self._last_x = candidate
            self._next_x = None
            return self._last_x

        if self._best.low in self._initial_bounds and not self._suppress_bounds_warning:
            # We warn the user if the optimal solution is at the edge and thus possibly not optimal.
            # NOTE: no trailing comma here; the original one turned ``msg`` into a 1-tuple.
            msg = f'warning: "optimal" solution {self._best.low} matches a bound ∈ {self._initial_bounds}.'
            Logging.log("bins", self._log_level, msg)

        raise StopIteration

    @property
    def x(self):
        """The best argument found so far (``None`` before the first ``update()``)."""
        return self._best.low

    @property
    def y(self):
        """The value at the best argument found so far (``None`` before the first ``update()``)."""
        return self._best.high

    def update(self, res):
        """
        Report the value ``res`` computed at the last yielded argument.

        This records the result and chooses the next argument to try. A value of ``False`` is
        treated as "not better".

        TESTS:

        We cache old inputs in ``_all_x`` to prevent infinite loops::

            >>> from estimator.util import binary_search
            >>> from estimator.cost import Cost
            >>> f = lambda x, log_level=1: Cost(rop=1) if x >= 19 else Cost(rop=2)
            >>> binary_search(f, 10, 30, "x")
            rop: 1

        """
        Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})")

        self._all_x.add(self._last_x)

        # We got nothing yet
        if self._best.low is None:
            self._best = Bounds(self._last_x, res)

        # We found something better
        if res is not False and self._smallerf(res, self._best.high):
            # store it
            self._best = Bounds(self._last_x, res)

            # if it's a result of a long jump figure out the next direction
            if abs(self._direction) != 1:
                self._direction = -1
                self._next_x = self._last_x - 1
            # going down worked, so let's keep on doing that.
            elif self._direction == -1:
                self._direction = -2
                self._stop = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            # going up worked, so let's keep on doing that.
            elif self._direction == 1:
                self._direction = 2
                self._start = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)
        else:
            # going downwards didn't help, let's try up
            # (+2 because _last_x is one below the point we came from)
            if self._direction == -1:
                self._direction = 1
                self._next_x = self._last_x + 2
            # going up didn't help either, so we stop
            elif self._direction == 1:
                self._next_x = None
            # it got no better in a long jump, halve the search space and try again
            elif self._direction == -2:
                self._start = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            elif self._direction == 2:
                self._stop = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)

        # We are repeating ourselves, time to stop
        if self._next_x == self._last_x:
            self._next_x = None
class local_minimum(local_minimum_base):
    """
    Local-minimum search over a coarsened grid.

    The interval is first scaled down by ``precision``, so only every ``precision``-th value is
    visited by the binary/gradient search of the base class. The caller can then scan
    ``neighborhood`` around the approximate minimum at full resolution.

    .. note :: The object is both an iterator and a context manager, so the caller can read the
       result from it after the loop.
    """

    def __init__(
        self,
        start,
        stop,
        precision=1,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Set up a coarsened local-minimum search.

        :param start: first admissible value
        :param stop: end of the interval (exclusive)
        :param precision: spacing of the values considered by the main loop
        :param smallerf: comparison deciding whether ``lhs`` beats ``rhs``
        :param suppress_bounds_warning: skip the warning when an end point is chosen as optimum
        :param log_level: level used for log messages
        """
        # Unscaled interval, used to clip ``neighborhood``.
        self._orig_bounds = (start, stop)
        self._precision = precision

        scaled_start = ceil(start / precision)
        scaled_stop = floor(stop / precision)
        local_minimum_base.__init__(
            self, scaled_start, scaled_stop, smallerf, suppress_bounds_warning, log_level
        )

    def __next__(self):
        """Map the base class's scaled candidate back to the original scale."""
        scaled = local_minimum_base.__next__(self)
        return scaled * self._precision

    @property
    def x(self):
        """Best value found so far, in the original (unscaled) coordinates."""
        return self._best.low * self._precision

    @property
    def neighborhood(self):
        """
        A range over the values within ``precision`` of the current best value.

        The range is clipped to the original ``[start, stop)`` interval and excludes
        ``x + precision``.
        """
        lower_limit, upper_limit = self._orig_bounds
        center = self.x
        width = self._precision
        return range(max(lower_limit, center - width), min(upper_limit, center + width))
class early_abort_range:
    """
    An iterator for finding a local minimum using linear search with early abort.

    Iterating yields ``(x, self)`` for ``x = start, start + step, ...`` while ``x < stop``. The
    caller evaluates its function at ``x`` and reports the value via ``self.update()``. The scan
    stops as soon as a reported value is ``False`` or ``smallerf`` says it does not improve on the
    best value so far.

    .. note :: Yielding ``self`` alongside ``x`` gives the caller access to the result.
    """

    # TODO: unify whether we like contexts or not

    def __init__(
        self,
        start,
        stop=oo,
        step=1,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh linear local-minimum search.

        :param start: starting point
        :param stop: end point (exclusive, optional)
        :param step: step size
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: stored but currently not consulted by this class
        :param log_level: level used for the ``"lins"`` log messages
        :raises ValueError: if ``stop < start``
        """
        if stop < start:
            raise ValueError(f"Incorrect bounds {start} > {stop}.")

        self._suppress_bounds_warning = suppress_bounds_warning
        self._log_level = log_level
        self._start = start
        self._step = step
        self._stop = stop
        self._smallerf = smallerf
        self._last_x = None
        self._next_x = self._start
        # Best (argument, value) seen so far.
        self._best = Bounds(None, None)

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return ``(x, self)`` for the next point, or stop if aborted or past ``stop``."""
        # ``_next_x is None`` means ``update()`` asked us to abort.
        if self._next_x is None or self._next_x >= self._stop:
            raise StopIteration
        self._last_x = self._next_x
        self._next_x += self._step
        return self._last_x, self

    @property
    def x(self):
        """The best argument found so far (``None`` before the first ``update()``)."""
        return self._best.low

    @property
    def y(self):
        """The value at the best argument found so far (``None`` before the first ``update()``)."""
        return self._best.high

    def update(self, res):
        """
        Report the value ``res`` computed at the last yielded point.

        The first report is always stored. After that, an improvement (per ``smallerf``) is stored
        and the scan continues; ``False`` or a non-improvement aborts the scan.
        """
        Logging.log("lins", self._log_level, f"({self._last_x}, {repr(res)})")

        # First evaluation: nothing to compare against yet.
        if self._best.low is None:
            # Keep ``Bounds`` (not a bare tuple): ``.low``/``.high`` are read later.
            self._best = Bounds(self._last_x, res)
            return

        if res is False:
            self._next_x = None
        elif self._smallerf(res, self._best.high):
            self._best = Bounds(self._last_x, res)
        else:
            self._next_x = None
def _batch_estimatef(f, x, log_level=0, f_repr=None, catch_exceptions=True): try: y = f(x) except Exception as e: if catch_exceptions: print(f"Algorithm {f_repr} on {x} failed with {e}") return None else: raise e if f_repr is None: f_repr = repr(f) Logging.log("batch", log_level, f"f: {f_repr}") Logging.log("batch", log_level, f"x: {x}") Logging.log("batch", log_level, f"f(x): {y!r}") return y
def f_name(f):
    """
    Return a readable name for the callable ``f``.

    This is ``f.__name__`` when that attribute exists (e.g. plain functions), and ``repr(f)``
    otherwise (e.g. ``functools.partial`` objects).
    """
    missing = object()
    name = getattr(f, "__name__", missing)
    if name is missing:
        return repr(f)
    return name
class Task(NamedTuple):
    """
    One estimation job.

    The field order mirrors the positional parameters of ``_batch_estimatef``, so a task can be
    unpacked straight into it.
    """

    # callable performing the estimate
    f: Callable
    # parameter set passed to ``f``
    x: LWEParameters
    # logging level for this job
    log_level: int
    # display name of ``f`` (also the key under which results are reported)
    f_name: str
    # whether failures are reported and swallowed instead of raised
    catch_exceptions: bool
@dataclass(frozen=True)
class TaskResults:
    """
    Immutable container of task outcomes, looked up by parameter set.

    ``results[params]`` gives ``{task.f_name: result}`` for every task whose ``x`` equals
    ``params``. Tasks whose result is ``None`` are omitted.
    """

    # maps each task to its result (``None`` for a missing result)
    _map: dict

    def __getitem__(self, params):
        selected = {}
        for task, outcome in self._map.items():
            if not (task.x == params) or outcome is None:
                continue
            selected[task.f_name] = outcome
        return selected
def batch_estimate(params, algorithm, jobs=1, log_level=0, catch_exceptions=True, **kwds):
    """
    Run estimates for all algorithms for all parameters.

    :param params: (List of) LWE parameters.
    :param algorithm: (List of) algorithms.
    :param jobs: Number of worker processes (``multiprocessing.Pool``); ``1`` runs everything
        sequentially in the current process.
    :param log_level: Logging level forwarded to every individual estimate.
    :param catch_exceptions: When an estimate fails, just print a warning.
    :param kwds: Extra keyword arguments bound to every algorithm call.
    :returns: A ``TaskResults`` mapping each task to its result.

    Example::

        >>> from estimator import LWE
        >>> from estimator.schemes import Kyber512
        >>> _ = batch_estimate(Kyber512, [LWE.primal_usvp, LWE.primal_bdd])
        >>> _ = batch_estimate(Kyber512, [LWE.primal_usvp, LWE.primal_bdd], jobs=2)

    """
    # Accept a single parameter set / a single algorithm as well as collections of them.
    if isinstance(params, LWEParameters):
        params = (params,)
    if not hasattr(algorithm, "__iter__"):
        algorithm = (algorithm,)

    # One task per (algorithm, parameter set) pair, algorithm-major order.
    tasks = []
    for f, x in it.product(algorithm, params):
        tasks.append(Task(partial(f, **kwds), x, log_level, f_name(f), catch_exceptions))

    if jobs == 1:
        results = list(it.starmap(_batch_estimatef, tasks))
    else:
        with Pool(jobs) as pool:
            results = pool.starmap(_batch_estimatef, tasks)

    return TaskResults({task: result for task, result in zip(tasks, results)})