import bob.bio.base
import numpy
import os

import logging
logger = logging.getLogger('bob.bio.tutorial')


class Network (bob.bio.base.extractor.Extractor):
  """This class extracts features using a pre-trained Caffe network

  This wrapper takes a color image and extracts the deep features from the specified network.
  The image must already have exactly the shape of the network's input blob without the batch dimension, i.e., channels first: ``(channels, height, width)``.
  The channel order (e.g. RGB vs. BGR) must match what the network was trained with; this class does not reorder channels.
  The extracted deep features will always be one-dimensional of type ``float64``; the number of elements depends on the network and the output layer.
  Network parameters such as the ``output_layer`` or the ``mean_values`` need to be selected appropriately.

  The network is read either by an explicit call to :py:meth:`load` or, lazily, on the first extraction.

  Parameters
  ----------

  network_caffemodel : str
    The path to the .caffemodel file of the network

  network_prototxt : str
    The path to the .prototxt file of the network

  output_layer : str
    The name of the blob (usually identical to the name of the layer producing it) at which to extract the deep features.

  input_layer : str
    The name of the input blob of the network. Rarely changed.

  mean_values : [float] or ``None``
    The mean values (one per color channel) to subtract from each input image.
    If not given, mean values are not subtracted.

  """

  def __init__(self,
      network_caffemodel,
      network_prototxt,
      output_layer,
      input_layer='data',
      mean_values=None
  ):
    # call base class constructor; the keyword arguments are forwarded to the base class
    # (NOTE(review): presumably only recorded to describe this extractor's configuration -- confirm)
    bob.bio.base.extractor.Extractor.__init__(
        self,
        requires_training=False,
        network_caffemodel=network_caffemodel,
        network_prototxt=network_prototxt,
        mean_values=mean_values,
        input_layer=input_layer,
        output_layer=output_layer
    )

    self.network_caffemodel = network_caffemodel
    self.network_prototxt = network_prototxt
    self.mean_values = mean_values
    self.input_layer = input_layer
    self.output_layer = output_layer
    # the caffe network and its required input shape are only available after load()
    self.network = None
    self.input_shape = None


  def load(self, _=None):
    """Loads the caffe network from the files specified in the constructor

    The network is read only once; subsequent calls return immediately.
    The network is only kept if all checks below succeed, so a failed load can be retried.

    Parameters
    ----------

    _ : ignored
      Accepted for compatibility with the base class ``load`` interface; not used.

    Raises
    ------

    ValueError
      If the input or output blob cannot be found in the network, or if the number of ``mean_values`` does not match the number of input channels.
    """
    if self.network is not None:
      # already loaded the network
      return
    logger.info("- Loading caffe network from file %s", self.network_caffemodel)
    # avoid logging information of caffe, unless the user configured the log level explicitly
    os.environ.setdefault('GLOG_minloglevel', '2')
    import warnings
    # silence RuntimeWarnings while importing caffe and reading the network,
    # without permanently changing the process-wide warning filters
    with warnings.catch_warnings():
      warnings.simplefilter("ignore", RuntimeWarning)
      import caffe
      # read network
      network = caffe.Net(self.network_prototxt, self.network_caffemodel, caffe.TEST)

    # check that we have the correct blobs; ``forward`` addresses inputs and outputs by blob name
    blob_names = list(network.blobs.keys())
    if self.input_layer not in blob_names:
      raise ValueError("The given input layer '%s' cannot be found in the given network blobs '%s'" % (self.input_layer, blob_names))
    if self.output_layer not in blob_names:
      raise ValueError("The given output layer '%s' cannot be found in the given network blobs '%s'" % (self.output_layer, blob_names))

    # get required input shape
    input_shape = tuple(network.blobs[self.input_layer].shape)
    # check that we have a mean value for each color channel (axis 1 of the input blob)
    if self.mean_values is not None and len(self.mean_values) != input_shape[1]:
      raise ValueError("Got %d mean values, but the network input has %d channels" % (len(self.mean_values), input_shape[1]))

    # only keep the network once it passed all checks
    self.network = network
    self.input_shape = input_shape


  def __call__(self, data):
    """Extracts caffe features for the given image

    The image needs to be a color image of the exact same size as required by the network.
    The given image is not modified.

    Parameters
    ----------

    data : 3D :py:class:`numpy.ndarray`
      The input image in ``(channels, height, width)`` layout, usually a color image.

    Returns
    -------

    1D :py:class:`numpy.ndarray` of type ``numpy.float64``
      The deep features extracted from the network.

    Raises
    ------

    ValueError
      If the image is not 3D or does not match the input shape of the network.
    """
    # make sure the network (and hence the required input shape) is available; cheap if already loaded
    self.load()

    data = numpy.asarray(data)
    if data.ndim != 3:
      raise ValueError("The input data needs to be 3D, but it has %d dimensions" % data.ndim)

    # convert data into blob format: add the batch dimension and work on a float copy, so that
    # subtracting the mean values neither modifies the caller's image nor fails for integral image types
    blob = data[numpy.newaxis].astype(numpy.float64)

    if blob.shape != self.input_shape:
      raise ValueError("The input data of shape %s cannot be transformed into a blob of shape %s" % (str(blob.shape), str(self.input_shape)))

    # check if we have to subtract mean values (one per channel, broadcast over height and width)
    if self.mean_values is not None:
      blob -= numpy.asarray(self.mean_values, dtype=numpy.float64).reshape(1, -1, 1, 1)

    # extract features from network
    result = self.network.forward(**{self.input_layer: blob, 'blobs': [self.output_layer]})

    # the returned arrays may reference caffe's internal buffers; ``astype`` always creates an independent copy
    return result[self.output_layer].ravel().astype(numpy.float64)
