Source code for theanolm.backend.parameters

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""A module that defines the Parameters class.
"""

import logging
import theano
from theanolm.backend.exceptions import IncompatibleStateError
from theanolm.backend.exceptions import TheanoConfigurationError

class Parameters:
    """Theano Function Parameters

    A dictionary of Theano shared variables. The values can be accessed through
    their path, which also acts as their identifier when they are saved to a
    HDF5 file.
    """

    def __init__(self):
        """Initializes an empty parameter dictionary.
        """

        # Maps parameter path -> Theano shared variable.
        self._vars = dict()
        # Running sum of ``value.size`` over every parameter added so far.
        self.total_size = 0

    def __getitem__(self, path):
        """Returns a shared variable given parameter path.

        :type path: str
        :param path: parameter path

        :rtype: SharedVariable
        :returns: the corresponding Theano shared variable

        :raises KeyError: if no parameter has been added with ``path``
        """

        return self._vars[path]

    def add(self, path, value, device=None):
        """Adds a new parameter.

        :type path: str
        :param path: identifier for the shared variable in Theano and its value
                     when stored in a HDF5 file

        :type value: numpy.ndarray
        :param value: initial value for the shared variable

        :type device: str
        :param device: if other than ``None``, the shared variable will be kept
                       in this device

        :raises ValueError: if a parameter with the same path already exists
        :raises TheanoConfigurationError: if Theano is configured for a device
            whose name starts with ``gpu`` and ``value`` is float64
        :raises RuntimeError: if Theano rejects the ``target`` argument when
            creating the shared variable on ``device``
        """

        if path in self._vars:
            raise ValueError("Path `{}´ already in parameters.".format(path))

        if theano.config.device.startswith('gpu') and value.dtype == 'float64':
            raise TheanoConfigurationError(
                'You are using Theano with the old GPU backend ("device=gpu"), '
                'and the parameter {} is float64. This is very inefficient, so '
                'you most likely want to set "floatX=float32".'.format(path))

        if device is None:
            self._vars[path] = theano.shared(value, path)
        else:
            try:
                self._vars[path] = theano.shared(value, path, target=device)
            except TypeError as error:
                # Keep the original TypeError attached as the explicit cause so
                # that the underlying reason is visible in the traceback.
                raise RuntimeError(
                    "Unable to create Theano shared variable for parameter {} "
                    "on device {}. If you are using the old backend, you "
                    "cannot assign layers to different GPU devices."
                    .format(path, device)) from error

        logging.debug(" * %s size=%d type=%s device=%s",
                      path, value.size, value.dtype, str(device))
        self.total_size += value.size

    def get_state(self, state):
        """Pulls values from the shared variables into a HDF5 file.

        If there already is a parameter in the file, it will be replaced, so it
        has to have the same number of elements.

        :type state: h5py.File
        :param state: HDF5 file for storing the parameters
        """

        for path, param in self._vars.items():
            value = param.get_value()
            if path in state:
                # Ellipsis selects the whole dataset for any rank, including
                # zero-dimensional (scalar) datasets, where a ``:`` slice is
                # not a valid index.
                state[path][...] = value
            else:
                state.create_dataset(path, data=value)

    def set_state(self, state):
        """Sets the values of the shared variables.

        Requires that ``state`` contains values for all the parameters.

        :type state: h5py.File
        :param state: HDF5 file that contains the parameters

        :raises IncompatibleStateError: if a parameter is missing from
            ``state``; parameters visited before the missing one have already
            been updated at that point
        """

        for path, param in self._vars.items():
            if path not in state:
                raise IncompatibleStateError(
                    "Parameter `%s´ is missing from state." % path)

            # ``dataset[()]`` reads the entire dataset. It replaces the
            # ``Dataset.value`` property, which has been removed from h5py.
            new_value = state[path][()]
            param.set_value(new_value)

            if len(new_value.shape) == 0:
                logging.debug("%s <- %s", path, str(new_value))
            else:
                logging.debug("%s <- array%s", path, str(new_value.shape))

    def get_variables(self):
        """Returns the shared variables.

        The returned object is the internal dictionary itself, not a copy, so
        modifying it modifies this parameter collection.

        :rtype: dict of str to SharedVariable
        :returns: a mapping from parameter path to the corresponding Theano
                  shared variable
        """

        return self._vars

    def __str__(self):
        """Returns a string representation of the parameters.

        :rtype: str
        :returns: string representation of the parameters
        """

        return str(self._vars)