# Source code for estimator.util

```python
from multiprocessing import Pool
from functools import partial

from sage.all import ceil, floor, oo

from .io import Logging

class local_minimum_base:
    """
    An iterator context for finding a local minimum using binary search.

    We use the immediate neighborhood of a point to decide the next direction to go into (gradient
    descent style), so the algorithm is not plain binary search (see ``update()`` function.)

    The best point seen so far is stored as the pair ``(x, y)`` in ``self._best`` and exposed via
    the ``x`` and ``y`` properties.

    .. note :: We combine an iterator and a context to give the caller access to the result.
    """

    def __init__(
        self,
        start,
        stop,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point
        :param stop:  end point (exclusive)
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: log level passed to ``Logging.log``
        :raises ValueError: if ``stop < start``

        """

        if stop < start:
            raise ValueError(f"Incorrect bounds {start} > {stop}.")

        self._suppress_bounds_warning = suppress_bounds_warning
        self._log_level = log_level
        # Current search window (both ends inclusive); it shrinks during binary-search steps.
        self._start = start
        self._stop = stop - 1
        # The window we started with, used for the out-of-bounds check and the edge warning.
        self._initial_bounds = (start, stop - 1)
        self._smallerf = smallerf
        # abs(self._direction) == 2: binary search step
        # abs(self._direction) == 1: gradient descent direction
        self._direction = -1  # going down
        self._last_x = None
        # The search starts at the upper end of the window.
        self._next_x = self._stop
        # (argument, value) of the best point seen so far.
        self._best = (None, None)
        # Every argument passed through ``update()``; used by ``__next__`` to detect loops.
        self._all_x = set()

    def __enter__(self):
        """Return ``self`` so the search can be used in a ``with`` statement."""
        return self

    def __exit__(self, type, value, traceback):
        """Nothing to clean up; exceptions are not suppressed."""
        pass

    def __iter__(self):
        """Return ``self``: this object is its own iterator."""
        return self

    def __next__(self):
        """
        Return the next point to evaluate, or raise ``StopIteration`` when the search is over.

        The search ends when ``update()`` scheduled no next point, when the next point was already
        evaluated (we would loop), or when it lies outside the initial bounds.
        """
        lo, hi = self._initial_bounds

        abort = False
        if self._next_x is None:
            abort = True  # we're told to abort
        elif self._next_x in self._all_x:
            abort = True  # we're looping
        elif self._next_x < lo or hi < self._next_x:
            abort = True  # we're stepping out of bounds

        if not abort:
            self._last_x = self._next_x
            self._next_x = None
            return self._last_x

        if self._best[0] in self._initial_bounds and not self._suppress_bounds_warning:
            # We warn the user if the optimal solution is at the edge and thus possibly not optimal.
            Logging.log(
                "bins",
                self._log_level,
                f'warning: "optimal" solution {self._best} matches a bound ∈ {self._initial_bounds}.',
            )

        raise StopIteration

    @property
    def x(self):
        """The argument of the best point found so far (``None`` if nothing was recorded)."""
        return self._best[0]

    @property
    def y(self):
        """The value at the best point found so far (``None`` if nothing was recorded)."""
        return self._best[1]

    def update(self, res):
        """
        Record the result ``res`` for the point most recently returned by ``__next__`` and decide
        which point to try next.

        :param res: the value at the last point, or ``False`` if that point is infeasible.

        TESTS:

        We cache old inputs in ``_all_x`` to prevent infinite loops::

            >>> from estimator.util import binary_search
            >>> from estimator.cost import Cost
            >>> f = lambda x, log_level=1: Cost(rop=1) if x >= 19 else Cost(rop=2)
            >>> binary_search(f, 10, 30, "x")
            rop: 1

        """

        Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})")

        # Remember this point so that ``__next__`` stops instead of revisiting it.
        self._all_x.add(self._last_x)

        # We got nothing yet
        if self._best[1] is None:
            self._best = self._last_x, res

        # We found something better (any real result beats a stored ``False``)
        if res is not False and (self._best[1] is False or self._smallerf(res, self._best[1])):
            # store it
            self._best = self._last_x, res

            # if it's a result of a long jump figure out the next direction
            if abs(self._direction) != 1:
                self._direction = -1
                self._next_x = self._last_x - 1
            # going down worked, so let's keep on doing that.
            elif self._direction == -1:
                self._direction = -2
                self._stop = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            # going up worked, so let's keep on doing that.
            elif self._direction == 1:
                self._direction = 2
                self._start = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)
        else:
            # going downwards didn't help, let's try up
            if self._direction == -1:
                self._direction = 1
                self._next_x = self._last_x + 2
            # going up didn't help either, so we stop
            elif self._direction == 1:
                self._next_x = None
            # it got no better in a long jump, half the search space and try again
            elif self._direction == -2:
                self._start = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            elif self._direction == 2:
                self._stop = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)

        # We are repeating ourselves, time to stop
        if self._next_x == self._last_x:
            self._next_x = None

class local_minimum(local_minimum_base):
    """
    An iterator context for finding a local minimum using binary search.

    We use the neighborhood of a point to decide the next direction to go into (gradient descent
    style), so the algorithm is not plain binary search (see ``update()`` function.)

    We also zoom out by a factor ``precision``, find an approximate local minimum and then
    search the neighbourhood for the smallest value.

    .. note :: We combine an iterator and a context to give the caller access to the result.

    """

    def __init__(
        self,
        start,
        stop,
        precision=1,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point
        :param stop:  end point (exclusive)
        :param precision: only consider every ``precision``-th value in the main loop
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: log level passed on to the base search

        """
        self._precision = precision
        # Bounds in the caller's units; ``neighborhood`` clips against these.
        self._orig_bounds = (start, stop)
        # The base search runs on the coarse grid x / precision.
        start = ceil(start / precision)
        stop = floor(stop / precision)
        local_minimum_base.__init__(self, start, stop, smallerf, suppress_bounds_warning, log_level)

    def __next__(self):
        """Return the next point to evaluate, scaled back to the caller's units."""
        x = local_minimum_base.__next__(self)
        return x * self._precision

    @property
    def x(self):
        """The best argument found so far in the caller's units, or ``None`` if there is none."""
        if self._best[0] is None:
            return None
        # ``_best`` holds (argument, value); only the argument is scaled.
        return self._best[0] * self._precision

    @property
    def neighborhood(self):
        """
        An iterator over the neighborhood of the currently best value: all integers from
        ``self.x - precision`` (inclusive) to ``self.x + precision`` (exclusive), clipped to the
        original bounds. Yields nothing if no best value has been recorded yet.

        NOTE(review): this does not update ``_last_x``, so ``update()`` calls made while iterating
        here record (and log) the last point of the main loop instead of the neighbour. ``y`` stays
        correct but ``x`` may be stale afterwards; ``binary_search`` only reads ``y``.
        """
        best_x = self.x
        if best_x is None:
            return

        start, stop = self._orig_bounds

        for x in range(max(start, best_x - self._precision), min(stop, best_x + self._precision)):
            yield x

class early_abort_range:
    """
    An iterator context for finding a local minimum using linear search.

    Iterating yields pairs ``(x, self)`` for ``x = start, start + step, ...`` below ``stop``. The
    caller reports each result via ``update()``, which ends the iteration once a result is
    ``False`` or is no longer smaller than the best one seen.

    .. note :: We combine an iterator and a context to give the caller access to the result.
    """

    # TODO: unify whether we like contexts or not

    def __init__(
        self,
        start,
        stop=oo,
        step=1,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point
        :param stop:  end point (exclusive, optional)
        :param step:  step size
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: log level passed to ``Logging.log``
        :raises ValueError: if ``stop < start``

        """

        if stop < start:
            raise ValueError(f"Incorrect bounds {start} > {stop}.")

        # Stored for signature parity with ``local_minimum_base``; not read by this class.
        self._suppress_bounds_warning = suppress_bounds_warning
        self._log_level = log_level
        self._start = start
        self._step = step
        self._stop = stop
        self._smallerf = smallerf
        self._last_x = None
        self._next_x = self._start
        # (argument, value) of the best point seen so far.
        self._best = (None, None)

    def __enter__(self):
        """Return ``self`` so the search can be used in a ``with`` statement."""
        return self

    def __exit__(self, type, value, traceback):
        """Nothing to clean up; exceptions are not suppressed."""
        pass

    def __iter__(self):
        """Return ``self``: this object is its own iterator."""
        return self

    def __next__(self):
        """Return ``(x, self)`` for the next point, or raise ``StopIteration``."""
        if self._next_x is None or self._next_x >= self._stop:
            raise StopIteration

        self._last_x = self._next_x
        self._next_x += self._step
        return self._last_x, self

    @property
    def x(self):
        """The argument of the best point found so far (``None`` if nothing was recorded)."""
        return self._best[0]

    @property
    def y(self):
        """The value at the best point found so far (``None`` if nothing was recorded)."""
        return self._best[1]

    def update(self, res):
        """
        Record the result ``res`` for the point most recently returned by ``__next__``.

        The first result is always stored. After that, the search continues only while results
        improve on the best one according to ``smallerf``. A ``False`` result, or one that is not
        smaller, stops the iteration.

        :param res: the value at the last point, or ``False`` if that point is infeasible.
        """
        Logging.log("lins", self._log_level, f"({self._last_x}, {repr(res)})")

        # First result: store it unconditionally.
        if self._best[1] is None:
            self._best = self._last_x, res
            return

        if res is False:
            self._next_x = None  # infeasible point: stop
        elif self._best[1] is False or self._smallerf(res, self._best[1]):
            # improvement (any real result beats a stored ``False``): keep going
            self._best = self._last_x, res
        else:
            self._next_x = None  # no improvement: stop

def binary_search(
    f, start, stop, param, step=1, smallerf=lambda x, best: x <= best, log_level=5, *args, **kwds
):
    """
    Searches for the best value of ``f`` as ``param`` ranges over the closed interval
    ``[start, stop]``, where "best" is decided by the given comparison function.

    :param f: the function to minimise; it is called as ``f(*args, **kwds)`` with
        ``kwds[param]`` set to the candidate value.
    :param start: start of range to search (inclusive)
    :param stop: end of range to search (inclusive)
    :param param: the keyword parameter to modify when calling ``f``
    :param step: initially only consider every ``step``-th value, then scan the neighbourhood of
        the best one found
    :param smallerf: comparison is performed by evaluating ``smallerf(current, best)``
    :param log_level: log level for the search
    :returns: the best value of ``f`` found (not the argument that achieves it)
    """

    def evaluate(x):
        # Copy ``kwds`` so every call sees only its own value of ``param``.
        kwds_ = dict(kwds)
        kwds_[param] = x
        return f(*args, **kwds_)

    # ``local_minimum`` treats its end point as exclusive, hence ``stop + 1``.
    with local_minimum(start, stop + 1, step, smallerf=smallerf, log_level=log_level) as it:
        for x in it:
            it.update(evaluate(x))

        # Refine around the coarse optimum found with stride ``step``.
        for x in it.neighborhood:
            it.update(evaluate(x))

    return it.y

def _batch_estimatef(f, x, log_level=0, f_repr=None, catch_exceptions=True):
try:
y = f(x)
except Exception as e:
if catch_exceptions:
print(f"Algorithm {f_repr} on {x} failed with {e}")
return None
else:
raise e

if f_repr is None:
f_repr = repr(f)

Logging.log("batch", log_level, f"f: {f_repr}")
Logging.log("batch", log_level, f"x: {x}")
Logging.log("batch", log_level, f"f(x): {repr(y)}")

return y

def f_name(f):
    """
    Return a printable name for ``f``: its ``__name__`` when it has one, else ``repr(f)``.
    """
    try:
        name = f.__name__
    except AttributeError:
        # e.g. ``functools.partial`` objects and plain values carry no ``__name__``
        name = repr(f)
    return name

def batch_estimate(params, algorithm, jobs=1, log_level=0, catch_exceptions=True, **kwds):
    """
    Run estimates for all algorithms for all parameters.

    :param params: (List of) LWE parameters.
    :param algorithm: (List of) algorithms.
    :param jobs: Number of worker processes; ``1`` runs everything in the current process.
    :param log_level: log level for the ``"batch"`` logger.
    :param catch_exceptions: When an estimate fails, just print a warning.
    :param kwds: extra keyword arguments forwarded to every algorithm.
    :returns: a dict mapping each parameter set to a dict ``{algorithm name: estimate}``;
        estimates that failed (returned ``None``) are omitted.

    Example::

        >>> from estimator import Kyber512, LWE
        >>> _ = batch_estimate(Kyber512, [LWE.primal_usvp, LWE.primal_bdd])
        >>> _ = batch_estimate(Kyber512, [LWE.primal_usvp, LWE.primal_bdd], jobs=2)

    """
    from .lwe_parameters import LWEParameters

    if isinstance(params, LWEParameters):
        params = (params,)
    try:
        # Materialise so a one-shot iterator is not exhausted after the first parameter set.
        algorithm = tuple(algorithm)
    except TypeError:
        algorithm = (algorithm,)

    tasks = [
        (partial(f, **kwds), x, log_level, f_name(f), catch_exceptions)
        for x in params
        for f in algorithm
    ]

    if jobs == 1:
        results = [_batch_estimatef(*task) for task in tasks]
    else:
        # The context manager shuts the workers down once all results are in.
        with Pool(jobs) as pool:
            results = pool.starmap(_batch_estimatef, tasks)

    res = {}
    for (_, x, _, f_repr, _), y in zip(tasks, results):
        res[f_repr, x] = y

    ret = {}
    for (f_repr, x), y in res.items():
        per_param = ret.setdefault(x, {})
        if y is not None:
            per_param[f_repr] = y

    return ret
```