# Source code for Arms.Bernoulli

# -*- coding: utf-8 -*-
""" Bernoulli distributed arm.

Example of creating an arm:

>>> import random; import numpy as np
>>> random.seed(0); np.random.seed(0)
>>> B03 = Bernoulli(0.3)
>>> B03
B(0.3)
>>> B03.mean
0.3

Examples of sampling from an arm:

>>> B03.draw()
0
>>> B03.draw_nparray(20)
array([1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1.,
       1., 1., 1.])
"""
from __future__ import division, print_function  # Python 2 compatibility

__author__ = "Lilian Besson"
__version__ = "0.6"

import numpy as np
from numpy.random import binomial

# Local imports
try:
from .Arm import Arm
from .kullback import klBern
except ImportError:
from Arm import Arm
from kullback import klBern

class Bernoulli(Arm):
    """ Bernoulli distributed arm.

    Each sample is 1 with probability ``probability`` and 0 otherwise, so the
    mean of the arm is equal to its parameter.
    """

    def __init__(self, probability):
        """New arm.

        :param probability: parameter p of the Bernoulli distribution, in [0, 1].
        :raises AssertionError: if ``probability`` is outside [0, 1]
            (only checked when assertions are enabled, i.e. not under ``python -O``).
        """
        self._check_probability(probability)
        self.probability = probability  #: Parameter p for this Bernoulli arm
        self.mean = probability  #: Mean for this Bernoulli arm

    @staticmethod
    def _check_probability(probability):
        """Assert that ``probability`` lies in [0, 1]; shared by the constructor and the setter."""
        # An assert (not a raise) is kept on purpose: same exception type and message
        # as before, and like before the check disappears under ``python -O``.
        assert 0 <= probability <= 1, "Error, the parameter probability for Bernoulli class has to be in [0, 1]."  # DEBUG

    # --- Random samples

    def draw(self, t=None):
        """ Draw one random sample.

        :param t: not used by this arm; accepted so that ``draw`` can be called
            with the same arguments as other arms (NOTE(review): presumably a
            time step — confirm against ``Arm``).
        :returns: 0 or 1, as returned by ``numpy.random.binomial(1, p)``.
            Note this is an integer, unlike :meth:`draw_nparray`, which returns floats.
        """
        return binomial(1, self.probability)

    def draw_nparray(self, shape=(1,)):
        """ Draw a numpy array of random samples, of a certain shape.

        :param shape: shape of the returned array (default ``(1,)``).
        :returns: a float numpy array of 0. and 1. values.
        """
        return np.asarray(binomial(1, self.probability, shape), dtype=float)

    def set_mean_param(self, probability):
        """Change the parameter (and therefore the mean) of this arm in place.

        The new value is validated exactly like in :meth:`__init__`, so the arm
        cannot be moved to a parameter outside [0, 1] after construction.
        """
        self._check_probability(probability)
        self.probability = self.mean = probability

    # --- Printing

    # This decorator @property makes this method an attribute, cf. https://docs.python.org/3/library/functions.html#property
    @property
    def lower_amplitude(self):
        """(lower, amplitude): samples lie in [lower, lower + amplitude], here [0, 1]."""
        return 0., 1.

    def __str__(self):
        return "Bernoulli"

    def __repr__(self):
        # 3 significant digits, e.g. B(0.3)
        return "B({:.3g})".format(self.probability)

    # --- Lower bound

    @staticmethod
    def kl(x, y):
        """ The kl(x, y) to use for this arm: delegates to ``klBern(x, y)``."""
        return klBern(x, y)

    @staticmethod
    def oneLR(mumax, mu):
        """ One term of the Lai & Robbins lower bound for Bernoulli arms: (mumax - mu) / KL(mu, mumax).

        NOTE(review): the divisor ``klBern(mu, mumax)`` is presumably 0 when
        ``mu == mumax``, so callers should only evaluate this for sub-optimal
        arms — confirm the behaviour of ``klBern`` at equal arguments.
        """
        return (mumax - mu) / klBern(mu, mumax)

# Public API of this module: ``from ... import *`` only exposes the arm class.
__all__ = [
    "Bernoulli",
]

# --- Debugging

if __name__ == "__main__":
    # When run as a script, execute every doctest found in this module's docstrings.
    import doctest
    print("\nTesting automatically all the docstring written in each functions of this module :")
    doctest.testmod(verbose=True)