from functools import reduce
import builtins
from pyrtl.rtllib import multipliers as mult
from ..wire import Const, WireVector
from ..corecircuits import as_wires, concat, select
from ..pyrtlexceptions import PyrtlError
from ..helperfuncs import formatted_str_to_val
[docs]
class Matrix(object):
    ''' Class for making a Matrix using PyRTL.

    Provides the ability to perform different matrix operations, such as
    elementwise arithmetic, matrix multiplication, slicing, and reshaping,
    on a two-dimensional grid of WireVectors.
    '''
    # Internally, this class uses a Python matrix of WireVectors.
    # So, a Matrix is represented as follows for a 2 x 2:
    # [[WireVector, WireVector], [WireVector, WireVector]]
[docs]
def __init__(self, rows, columns, bits, signed=False, value=None, max_bits=64):
    ''' Constructs a Matrix object.

    :param int rows: the number of rows in the matrix. Must be greater than 0
    :param int columns: the number of columns in the matrix. Must be greater than 0
    :param int bits: The amount of bits per WireVector. Must be greater than 0
    :param bool signed: Currently not supported (will be added in the future);
        the argument is accepted but the matrix is always stored as unsigned
    :param (WireVector/list) value: The value you want to initialize the Matrix with.
        If a WireVector, must be of size `rows * columns * bits`. If a list, must have
        `rows` rows and `columns` columns, and every element must fit in `bits` size.
        If not given, the matrix initializes to 0
    :param int max_bits: The maximum number of bits each WireVector can have, even
        after operations like adding two matrices together results in larger
        resulting WireVectors. `bits` is clamped to this limit; None disables it
    :raises PyrtlError: if a dimension or `bits` is not a positive int, or if
        `value` has the wrong type, bitwidth, or shape
    :return: a constructed Matrix object
    '''
    for label, dimension in (('Rows', rows), ('Columns', columns)):
        if not isinstance(dimension, int):
            raise PyrtlError('%s must be of type int, instead "%s" '
                             'was passed of type %s' %
                             (label, str(dimension), type(dimension)))
        if dimension <= 0:
            raise PyrtlError('%s cannot be less than or equal to zero. '
                             '%s value passed: %s' % (label, label, str(dimension)))
    if not isinstance(bits, int):
        raise PyrtlError('Bits must be of type int, instead "%s" '
                         'was passed of type %s' %
                         (str(bits), type(bits)))
    if bits <= 0:
        raise PyrtlError(
            'Bits cannot be negative or zero, '
            'instead "%s" was passed' % str(bits))

    # Requested widths above the cap are silently clamped.
    if max_bits is not None and bits > max_bits:
        bits = max_bits

    if value is None:
        self._matrix = [[Const(0) for _ in range(columns)] for _ in range(rows)]
    elif isinstance(value, WireVector):
        if value.bitwidth != bits * rows * columns:
            raise PyrtlError('Initialized bitwidth value does not match '
                             'given value.bitwidth: %s, expected: %s'
                             '' % (str(value.bitwidth),
                                   str(bits * rows * columns)))
        self._matrix = [[0 for _ in range(columns)] for _ in range(rows)]
        for i in range(rows):
            for j in range(columns):
                # The chunk at bit offset 0 fills the LAST element, so the
                # first element comes from the highest-offset chunk.
                start_index = (j * bits) + (i * columns * bits)
                self._matrix[rows - i - 1][columns - j - 1] = \
                    as_wires(value[start_index:start_index + bits], bitwidth=bits)
    elif isinstance(value, list):
        def row_length(row):
            # Rows without a length (e.g. a flat list of ints) are reported
            # as a shape mismatch instead of escaping as a TypeError.
            try:
                return len(row)
            except TypeError:
                return None

        lengths = [row_length(row) for row in value]
        if len(value) != rows or any(length != columns for length in lengths):
            # Report the first offending row; `value` may be empty, so
            # value[0] must never be indexed unconditionally.
            found = next((length for length in lengths if length != columns), columns)
            raise PyrtlError('Rows and columns mismatch\n'
                             'Rows: %s, expected: %s\n'
                             'Columns: %s, expected: %s'
                             '' % (str(len(value)), str(rows),
                                   str(found), str(columns)))
        self._matrix = [[as_wires(value[i][j], bitwidth=bits) for j in range(columns)]
                        for i in range(rows)]
    else:
        raise PyrtlError('Initialized value must be of type WireVector or '
                         'list. Instead was passed value of type %s' % (type(value)))

    self.rows = rows
    self.columns = columns
    self._bits = bits
    # Going through the property setter truncates every element to `bits`.
    self.bits = bits
    # Signed matrices are not supported yet, so the `signed` argument is ignored.
    self.signed = False
    self.max_bits = max_bits
@property
def bits(self):
    ''' Number of bits each element of the matrix is allowed to hold.

    :return: an integer representing the number of bits
    '''
    return self._bits

@bits.setter
def bits(self, bits):
    ''' Changes the number of bits held by every element.

    :param int bits: The number of bits. Must be greater than 0
    :raises PyrtlError: if `bits` is not an int or is not positive

    Called automatically when ``matrix.bits`` is assigned.
    NOTE: This function will truncate the most significant bits.
    '''
    if not isinstance(bits, int):
        raise PyrtlError('Bits must be of type int, instead "%s" '
                         'was passed of type %s' %
                         (str(bits), type(bits)))
    if bits <= 0:
        raise PyrtlError(
            'Bits cannot be negative or zero, '
            'instead "%s" was passed' % str(bits))
    self._bits = bits
    # Re-slice every stored element down to the new width.
    for row_index in range(self.rows):
        row = self._matrix[row_index]
        for col_index in range(self.columns):
            row[col_index] = row[col_index][:bits]
def __len__(self):
    ''' Gets the bitwidth of the matrix when output as a single WireVector.

    :return: an integer equal to ``bits * rows * columns``

    Used with the default ``len()`` function.
    '''
    bits_per_element = self.bits
    element_count = self.rows * self.columns
    return bits_per_element * element_count
[docs]
def to_wirevector(self):
    ''' Outputs the PyRTL Matrix as a singular concatenated WireVector.

    :return: a WireVector representing the whole PyRTL matrix

    Elements are passed to ``concat`` in row-major order. For instance, a
    1 x 2 matrix ``[[wire_a, wire_b]]`` returns ``concat(wire_a, wire_b)``.
    '''
    elements = [
        as_wires(self[row, col], bitwidth=self.bits)
        for row in range(len(self._matrix))
        for col in range(len(self._matrix[0]))
    ]
    return as_wires(concat(*elements), bitwidth=len(self))
[docs]
def transpose(self):
    ''' Constructs the transpose of the matrix.

    :return: a new Matrix whose element ``[i, j]`` is this matrix's ``[j, i]``
    '''
    flipped = Matrix(self.columns, self.rows, self.bits, max_bits=self.max_bits)
    positions = ((r, c) for r in range(flipped.rows) for c in range(flipped.columns))
    for r, c in positions:
        flipped[r, c] = self[c, r]
    return flipped
def __reversed__(self):
    ''' Constructs the reverse of the matrix.

    :return: a new Matrix with both the row order and the column order reversed

    Used with the ``reversed()`` function.
    '''
    last_row = self.rows - 1
    last_col = self.columns - 1
    mirrored = Matrix(self.rows, self.columns, self.bits, max_bits=self.max_bits)
    for r in range(self.rows):
        for c in range(self.columns):
            mirrored[r, c] = self[last_row - r, last_col - c]
    return mirrored
def __getitem__(self, key):
    ''' Accessor for the matrix.

    :param (int/slice row, int/slice column) key: The key value to get
    :return: WireVector or Matrix containing the value of key
    :raises PyrtlError: if an index has the wrong type, is out of bounds,
        or the selection contains no elements

    Called when using square brackets (``matrix[...]``). A selection of
    exactly one element returns a WireVector; anything larger returns a
    Matrix. Slice steps are ignored: selections are always contiguous.

    Examples::

        int_matrix = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        matrix = Matrix.Matrix(3, 3, 4, value=int_matrix)
        matrix[1] == [3, 4, 5]
        matrix[2, 0] == 6
        matrix[(2, 0)] == 6
        matrix[slice(0, 2), slice(0, 3)] == [[0, 1, 2], [3, 4, 5]]
        matrix[0:2, 0:3] == [[0, 1, 2], [3, 4, 5]]
        matrix[:2] == [[0, 1, 2], [3, 4, 5]]
        matrix[-1] == [6, 7, 8]
        matrix[-2:] == [[3, 4, 5], [6, 7, 8]]
    '''
    def as_slice(index, size, title):
        # Turn an int index into a one-element slice, resolving negative ints.
        if isinstance(index, slice):
            return index
        if not isinstance(index, int):
            raise PyrtlError('%s must be of type int or slice, '
                             'instead "%s" was passed of type %s' %
                             (title, str(index), type(index)))
        if index < 0:
            index = size - abs(index)
            if index < 0:
                raise PyrtlError("Invalid bounds for %s. Max %s: %s, got: %s" % (
                    title.lower(), title.lower(), str(size), str(index)))
        return slice(index, index + 1, 1)

    def resolve(span, size):
        # Fill in open ends and convert negative ends; the step is forced to 1.
        start, stop = span.start, span.stop
        if start is None:
            start = 0
        elif start < 0:
            start = size - abs(start)
        if stop is None:
            stop = size
        elif stop < 0:
            stop = size - abs(stop)
        return slice(start, stop, 1)

    def check_bounds(span, size, name):
        if span.start > size or span.stop > size \
                or span.start < 0 or span.stop < 0:
            raise PyrtlError("Invalid bounds for %s. Max %s: %s, got: %s" % (
                name, name, str(size), str(span.start) + ":" + str(span.stop)))

    if isinstance(key, tuple):
        rows, columns = key
        # Each phase handles rows before columns so that the first error
        # reported for a bad key stays the same as it has always been.
        rows = as_slice(rows, self.rows, 'Rows')
        columns = as_slice(columns, self.columns, 'Columns')
        rows = resolve(rows, self.rows)
        columns = resolve(columns, self.columns)
        check_bounds(rows, self.rows, 'rows')
        check_bounds(columns, self.columns, 'columns')

        height = rows.stop - rows.start
        width = columns.stop - columns.start
        # An empty or inverted selection cannot be represented as a Matrix;
        # fail clearly rather than with an IndexError further down.
        if height <= 0:
            raise PyrtlError("Invalid bounds for rows. Slice %s:%s selects no rows" % (
                str(rows.start), str(rows.stop)))
        if width <= 0:
            raise PyrtlError("Invalid bounds for columns. Slice %s:%s selects no columns" % (
                str(columns.start), str(columns.stop)))

        # If it's a single value we want to return a wirevector
        if height == 1 and width == 1:
            return as_wires(self._matrix[rows.start][columns.start],
                            bitwidth=self.bits)

        # Otherwise set up matrix and return that
        result = [[self._matrix[r][c] for c in range(columns.start, columns.stop)]
                  for r in range(rows.start, rows.stop)]
        return Matrix(height, width, self._bits,
                      signed=self.signed, value=result, max_bits=self.max_bits)

    # Second case when we just want to get full row
    if isinstance(key, int):
        if key < 0:
            start = self.rows - abs(key)
            if start < 0:
                raise PyrtlError('Index %d is out of bounds for '
                                 'matrix with %d rows' % (key, self.rows))
            key = slice(start, start + 1, None)
        else:
            key = slice(key, key + 1, None)
        return self[key, :]

    # Third case when we want multiple rows
    if isinstance(key, slice):
        return self[key, :]

    # Otherwise improper value was passed
    raise PyrtlError('Rows must be of type int or slice, '
                     'instead "%s" was passed of type %s' %
                     (str(key), type(key)))
def __setitem__(self, key, value):
    ''' Mutator for the matrix.

    :param (slice/int rows, slice/int columns) key: The key value to set
    :param Wirevector/int/Matrix value: The value in which to set the key
    :raises PyrtlError: if an index has the wrong type or is out of bounds,
        or if `value` does not match the shape of the selected region

    Called when setting a value using square brackets (e.g. ``matrix[a, b] = value``).
    A Matrix `value` must have exactly the shape of the selected region; an
    int or WireVector `value` can only be assigned to a single element.
    Negative int indices count from the end, as they do when reading.

    The `value` given will be truncated to match the bitwidth of all the elements
    in the matrix.
    '''
    def as_slice(index, size, title):
        # Turn an int index into a one-element slice. Negative ints are
        # resolved here: slice(-1, 0) would otherwise describe an empty,
        # inverted range and the assignment would be rejected.
        if isinstance(index, slice):
            return index
        if not isinstance(index, int):
            raise PyrtlError('%s must be of type int or slice, '
                             'instead "%s" was passed of type %s' %
                             (title, str(index), type(index)))
        if index < 0:
            index = size - abs(index)
            if index < 0:
                raise PyrtlError("Invalid bounds for %s. Max %s: %s, got: %s" % (
                    title.lower(), title.lower(), str(size), str(index)))
        return slice(index, index + 1, 1)

    def resolve(span, size):
        # Fill in open ends and convert negative ends of a slice.
        start, stop = span.start, span.stop
        if start is None:
            start = 0
        elif start < 0:
            start = size - abs(start)
        if stop is None:
            stop = size
        elif stop < 0:
            stop = size - abs(stop)
        return slice(start, stop, span.step)

    def check_bounds(span, size, name):
        if span.start > size or span.stop > size \
                or span.start < 0 or span.stop < 0:
            raise PyrtlError("Invalid bounds for %s. Max %s: %s, got: %s" % (
                name, name, str(size), str(span.start) + ":" + str(span.stop)))

    if isinstance(key, tuple):
        rows, columns = key
        # Each phase handles rows before columns so that the first error
        # reported for a bad key stays the same as it has always been.
        rows = as_slice(rows, self.rows, 'Rows')
        columns = as_slice(columns, self.columns, 'Columns')
        rows = resolve(rows, self.rows)
        columns = resolve(columns, self.columns)
        check_bounds(rows, self.rows, 'rows')
        check_bounds(columns, self.columns, 'columns')

        height = rows.stop - rows.start
        width = columns.stop - columns.start

        # First case when setting value to Matrix
        if isinstance(value, Matrix):
            if value.rows != height:
                raise PyrtlError(
                    'Value rows mismatch. Expected Matrix '
                    'of rows "%s", instead recieved Matrix of rows "%s"' %
                    (str(height), str(value.rows)))
            if value.columns != width:
                raise PyrtlError(
                    'Value columns mismatch. Expected Matrix '
                    'of columns "%s", instead recieved Matrix of columns "%s"' %
                    (str(width), str(value.columns)))
            for i in range(height):
                for j in range(width):
                    self._matrix[rows.start + i][columns.start + j] = \
                        as_wires(value[i, j], bitwidth=self.bits)
        # Second case when setting value to wirevector
        elif isinstance(value, (int, WireVector)):
            if height != 1 or width != 1:
                raise PyrtlError(
                    'Value mismatch: expected Matrix, instead received WireVector')
            self._matrix[rows.start][columns.start] = as_wires(value, bitwidth=self.bits)
        # Otherwise Error
        else:
            raise PyrtlError('Invalid value of type %s' % type(value))
    # Second case if we just want to set a full row
    elif isinstance(key, int):
        if key < 0:
            start = self.rows - abs(key)
            if start < 0:
                raise PyrtlError('Index %d is out of bounds for '
                                 'matrix with %d rows' % (key, self.rows))
            key = slice(start, start + 1, None)
        else:
            key = slice(key, key + 1, None)
        self[key, :] = value
    # Third case if we want to set full rows
    elif isinstance(key, slice):
        self[key, :] = value
    else:
        raise PyrtlError('Rows must be of type int or slice, '
                         'instead "%s" was passed of type %s' %
                         (str(key), type(key)))
[docs]
def copy(self):
    ''' Constructs a copy of the Matrix.

    The copy is rebuilt from this matrix's concatenated WireVector form,
    with the same dimensions, bits, and max_bits.

    :return: a Matrix copy
    '''
    flattened = self.to_wirevector()
    return Matrix(self.rows, self.columns, self.bits,
                  value=flattened, max_bits=self.max_bits)
def __iadd__(self, other):
    ''' Perform the in-place addition operation.

    :param Matrix other: the PyRTL Matrix to add
    :return: a copy of this Matrix after the elementwise addition

    Is used with ``a += b``. Performs an elementwise addition, stores the
    result in this matrix, and returns a copy of it.
    '''
    total = self + other
    # `bits` is assigned through the property setter, after `_matrix`.
    self._matrix, self.bits = total._matrix, total._bits
    return self.copy()
def __add__(self, other):
    ''' Perform the addition operation.

    :param Matrix other: the PyRTL Matrix to add; must have the same shape
    :return: a Matrix object with the element wise addition being performed
    :raises PyrtlError: if `other` is not a Matrix or the shapes differ

    Is used with ``a + b``. Performs an elementwise addition. Each result
    element is requested one bit wider than the wider operand (subject to
    `max_bits`).
    '''
    if not isinstance(other, Matrix):
        raise PyrtlError('error: expecting a Matrix, '
                         'got %s instead' % type(other))
    if self.columns != other.columns:
        raise PyrtlError('error: columns mismatch. '
                         'Matrix a: %s columns, Matrix b: %s columns' %
                         (str(self.columns), str(other.columns)))
    if self.rows != other.rows:
        raise PyrtlError('error: row mismatch. '
                         'Matrix a: %s rows, Matrix b: %s rows' %
                         (str(self.rows), str(other.rows)))

    # Conditional expression rather than max(): this module defines its own
    # max() for matrices, which shadows the builtin.
    wider_bits = other.bits if other.bits > self.bits else self.bits
    result = Matrix(self.rows, self.columns, wider_bits + 1, max_bits=self.max_bits)
    for i in range(result.rows):
        for j in range(result.columns):
            result[i, j] = self[i, j] + other[i, j]
    return result
def __isub__(self, other):
    ''' Perform the in-place subtraction operation.

    :param Matrix other: the PyRTL Matrix to subtract
    :return: a copy of this Matrix after the elementwise subtraction

    Is used with ``a -= b``. Performs an elementwise subtraction, stores the
    result in this matrix, and returns a copy of it.
    '''
    difference = self - other
    self._matrix, self._bits = difference._matrix, difference._bits
    return self.copy()
def __sub__(self, other):
    ''' Perform the subtraction operation.

    :param Matrix other: the PyRTL Matrix to subtract; must have the same shape
    :return: a Matrix object with the elementwise subtraction being performed
    :raises PyrtlError: if `other` is not a Matrix or the shapes differ

    Is used with ``a - b``. Performs an elementwise subtraction.
    Note: If using unsigned numbers, the result will be floored at 0.
    '''
    if not isinstance(other, Matrix):
        raise PyrtlError('error: expecting a Matrix, '
                         'got %s instead' % type(other))
    if self.columns != other.columns:
        raise PyrtlError('error: columns mismatch. '
                         'Matrix a: %s columns, Matrix b: %s columns' %
                         (str(self.columns), str(other.columns)))
    if self.rows != other.rows:
        raise PyrtlError('error: row mismatch. '
                         'Matrix a: %s rows, Matrix b: %s rows' %
                         (str(self.rows), str(other.rows)))

    # Conditional expression rather than max(): this module defines its own
    # max() for matrices, which shadows the builtin.
    new_bits = other.bits if other.bits > self.bits else self.bits
    result = Matrix(self.rows, self.columns, new_bits, max_bits=self.max_bits)
    for i in range(result.rows):
        for j in range(result.columns):
            if self.signed:
                result[i, j] = self[i, j] - other[i, j]
            else:
                # Unsigned: yield 0 unless the left operand is strictly larger.
                result[i, j] = select(self[i, j] > other[i, j],
                                      self[i, j] - other[i, j],
                                      Const(0))
    return result
def __imul__(self, other):
    ''' Perform the in-place multiplication operation.

    :param Matrix/Wirevector other: the Matrix or scalar to multiply
    :return: a copy of this Matrix after the multiplication

    Is used with ``a *= b``. Performs an elementwise or scalar multiplication,
    stores the result in this matrix, and returns a copy of it.
    '''
    product = self * other
    self._matrix, self._bits = product._matrix, product._bits
    return self.copy()
def __mul__(self, other):
    ''' Perform the elementwise or scalar multiplication operation.

    :param Matrix/Wirevector other: a Matrix of the same shape (elementwise
        product) or a WireVector (every element is multiplied by it)
    :return: a Matrix object with the resulting multiplication operation being performed
    :raises PyrtlError: if `other` is neither a Matrix nor a WireVector, or
        if a Matrix `other` has a different shape

    Is used with ``a * b``. Result elements are requested with the combined
    bitwidth of both operands (subject to `max_bits`).
    '''
    elementwise = isinstance(other, Matrix)
    if elementwise:
        if self.columns != other.columns:
            raise PyrtlError('error: columns mismatch. '
                             'Matrix a: %s columns, Matrix b: %s columns' %
                             (str(self.columns), str(other.columns)))
        if self.rows != other.rows:
            raise PyrtlError('error: row mismatch. '
                             'Matrix a: %s rows, Matrix b: %s rows' %
                             (str(self.rows), str(other.rows)))
        bits = self.bits + other.bits
    elif isinstance(other, WireVector):
        bits = self.bits + len(other)
    else:
        raise PyrtlError('Expecting a Matrix or WireVector '
                         'got %s instead' % type(other))

    result = Matrix(self.rows, self.columns, bits, max_bits=self.max_bits)
    for i in range(self.rows):
        for j in range(self.columns):
            factor = other[i, j] if elementwise else other
            result[i, j] = self[i, j] * factor
    return result
def __imatmul__(self, other):
    ''' Performs the in-place matrix multiplication operation.

    :param Matrix other: the second matrix.
    :return: a copy of this Matrix after it has been replaced by the
        matrix multiplication product of this and other

    Is used with ``a @= b``. This matrix takes on the dimensions of the product.

    Note: The matmul symbol (``@``) only works in Python 3.5+. Otherwise you must
    call ``__imatmul__(other)``.
    '''
    product = self.__matmul__(other)
    self.columns, self.rows = product.columns, product.rows
    self._matrix, self._bits = product._matrix, product._bits
    return self.copy()
def __matmul__(self, other):
    ''' Performs the matrix multiplication operation.

    :param Matrix other: the second matrix; its row count must equal this
        matrix's column count
    :return: a PyRTL Matrix that contains the matrix multiplication product of this and other

    Is used with ``a @ b``.

    Note: The matmul symbol (`@`) only works in Python 3.5+. Otherwise you must
    call ``__matmul__(other)``.
    '''
    if not isinstance(other, Matrix):
        raise PyrtlError('error: expecting a Matrix, '
                         'got %s instead' % type(other))
    if self.columns != other.rows:
        raise PyrtlError('error: rows and columns mismatch. '
                         'Matrix a: %s columns, Matrix b: %s rows' %
                         (str(self.columns), str(other.rows)))

    element_bits = self.columns * other.rows * (self.bits + other.bits)
    product = Matrix(self.rows, other.columns, element_bits, max_bits=self.max_bits)
    for out_row in range(self.rows):
        for out_col in range(other.columns):
            # Fold each term of the dot product into the running element.
            for inner in range(self.columns):
                product[out_row, out_col] = mult.fused_multiply_adder(
                    self[out_row, inner], other[inner, out_col],
                    product[out_row, out_col], signed=self.signed)
    return product
def __ipow__(self, power):
    ''' Performs the in-place matrix power operation.

    :param int power: the power to raise the matrix to
    :return: a copy of this Matrix after it has been replaced by the
        matrix power product

    Is used with ``a **= b``.
    '''
    raised = self ** power
    self._matrix, self._bits = raised._matrix, raised._bits
    return self.copy()
def __pow__(self, power):
    ''' Performs the matrix power operation.

    :param int power: the power to raise the matrix to; must be >= 0
    :return: a PyRTL Matrix that contains the matrix power product

    Is used with ``a ** b``. The matrix must be square. A power of 0 gives
    the identity matrix; larger powers repeat matrix multiplication.
    '''
    if not isinstance(power, int):
        raise PyrtlError('Unexpected power given. Type int expected, '
                         'but recieved type %s' % type(power))
    if self.rows != self.columns:
        raise PyrtlError("Matrix must be square")

    base = self.copy()
    if power < 0:
        raise PyrtlError('Power must be greater than or equal to 0')

    if power == 0:
        # Overwrite the copy with ones on the diagonal and zeros elsewhere.
        for i in range(self.rows):
            for j in range(self.columns):
                base[i, j] = Const(1 if i == j else 0)
        return base

    # Left-to-right chain of power - 1 matrix multiplications.
    accumulated = base
    for _ in range(power - 1):
        accumulated = accumulated.__matmul__(base)
    return accumulated
[docs]
def put(self, ind, v, mode='raise'):
    ''' Replace specified elements of the matrix with given values

    :param int/list[int]/tuple[int] ind: target indices
    :param int/list[int]/tuple[int]/Matrix row-vector v: values to place in
        matrix at target indices; if `v` is shorter than `ind`, its last
        value is repeated as necessary
    :param str mode: how out-of-bounds indices behave; ``raise`` raises an
        error, ``wrap`` wraps around, and ``clip`` clips to the range
    :raises PyrtlError: if `ind`, `v`, or `mode` is of an unsupported
        type/value, or an index is out of bounds in ``raise`` mode

    Note that the index is on the flattened matrix. Negative indices count
    from the end. If `v` is an empty list or tuple, nothing is changed.
    '''
    count = self.rows * self.columns
    if isinstance(ind, int):
        ind = (ind,)
    elif not isinstance(ind, (tuple, list)):
        raise PyrtlError("Expected int or list-like indices, got %s" % type(ind))

    if isinstance(v, int):
        v = (v,)

    if isinstance(v, (tuple, list)) and len(v) == 0:
        return
    elif isinstance(v, Matrix):
        if v.rows != 1:
            raise PyrtlError(
                "Expected a row-vector matrix, instead got matrix with %d rows" % v.rows
            )

    if mode not in ['raise', 'wrap', 'clip']:
        raise PyrtlError(
            "Unexpected mode %s; allowable modes are 'raise', 'wrap', and 'clip'" % mode
        )

    def get_ix(requested):
        # Map a possibly negative / out-of-range flat index onto [0, count).
        ix = count - abs(requested) if requested < 0 else requested
        if ix < 0 or ix >= count:
            if mode == 'raise':
                # Report the index the caller passed, not the adjusted one.
                raise PyrtlError(
                    "index %d is out of bounds with size %d" % (requested, count))
            elif mode == 'wrap':
                ix = ix % count
            elif mode == 'clip':
                ix = 0 if ix < 0 else count - 1
        return ix

    def get_value(ix):
        # Return the ix-th value, repeating the last one when `v` runs out.
        if isinstance(v, (tuple, list)):
            if ix >= len(v):
                return v[-1]
            return v[ix]
        if isinstance(v, Matrix):
            # Compare against the length of `v` itself, not the size of the
            # matrix being written to.
            if ix >= v.columns:
                return v[0, -1]
            return v[0, ix]
        raise PyrtlError(
            "Expected int, list-like, or row-vector Matrix values, got %s" % type(v))

    for v_ix, mat_ix in enumerate(ind):
        mat_ix = get_ix(mat_ix)
        row = mat_ix // self.columns
        col = mat_ix % self.columns
        self[row, col] = get_value(v_ix)
[docs]
def reshape(self, *newshape, **order):
    ''' Create a matrix of the given shape from the current matrix.

    :param int/ints/tuple[int] newshape: shape of the matrix to return;
        if a single int, will result in a 1-D row-vector of that length;
        if a tuple, will use values for number of rows and cols. Can also
        be a varargs.
    :param str order: ``C`` means to read from self using
        row-major order (C-style), and ``F`` means to read from self
        using column-major order (Fortran-style).
    :raises PyrtlError: if the shape is missing, malformed, or does not hold
        exactly the number of elements in this matrix, or if `order` is
        not ``C`` or ``F``
    :return: A copy of the matrix with same data, with a new number of rows/cols

    One shape dimension in newshape can be -1; in this case, the value
    for that dimension is inferred from the other given dimension (if any)
    and the number of elements in the matrix.

    Examples::

        int_matrix = [[0, 1, 2, 3], [4, 5, 6, 7]]
        matrix = Matrix.Matrix(2, 4, 4, value=int_matrix)
        matrix.reshape(-1) == [[0, 1, 2, 3, 4, 5, 6, 7]]
        matrix.reshape(8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
        matrix.reshape(1, 8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
        matrix.reshape((1, 8)) == [[0, 1, 2, 3, 4, 5, 6, 7]]
        matrix.reshape((1, -1)) == [[0, 1, 2, 3, 4, 5, 6, 7]]
        matrix.reshape(4, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
        matrix.reshape(-1, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
        matrix.reshape(4, -1) == [[0, 1], [2, 3], [4, 5], [6, 7]]
    '''
    # python2 does not support named arguments after *args, so we use
    # **kwargs for 'order' and set the default here.
    order = order.get('order', 'C')
    count = self.rows * self.columns

    # Because of *newshape the argument is always a tuple: either the
    # dimensions themselves or a single tuple that holds them.
    if len(newshape) > 0 and isinstance(newshape[0], tuple):
        newshape = newshape[0]
    if len(newshape) == 0:
        raise PyrtlError("newshape must contain at least one dimension")
    if len(newshape) == 1:
        # A single dimension describes a row vector of that length.
        newshape = (1, newshape[0])
    if len(newshape) > 2:
        raise PyrtlError("length of newshape tuple must be <= 2")

    rows, cols = newshape
    if not isinstance(rows, int) or not isinstance(cols, int):
        raise PyrtlError(
            "newshape dimensions must be integers, instead got %s" % type(newshape)
        )
    if rows == -1 and cols == -1:
        raise PyrtlError("Both dimensions in newshape cannot be -1")
    if rows == -1 or cols == -1:
        known = cols if rows == -1 else rows
        if known == 0:
            # Nothing can be inferred from a zero-sized dimension.
            raise PyrtlError(
                "Cannot reshape matrix of size %d into shape %s" % (count, str(newshape))
            )
        inferred = count // known
        if rows == -1:
            rows = inferred
        else:
            cols = inferred
        newshape = (rows, cols)

    if rows * cols != count:
        raise PyrtlError(
            "Cannot reshape matrix of size %d into shape %s" % (count, str(newshape))
        )

    # Tuple membership, not `in 'CF'`: a substring test would also accept
    # '' and 'CF'.
    if order not in ('C', 'F'):
        raise PyrtlError(
            "Invalid order %s. Acceptable orders are 'C' (for row-major C-style order) "
            "and 'F' (for column-major Fortran-style order)." % order
        )

    value = [[0] * cols for _ in range(rows)]
    ix = 0
    if order == 'C':
        # Read and write in row-wise order
        for newr in range(rows):
            for newc in range(cols):
                r = ix // self.columns
                c = ix % self.columns
                value[newr][newc] = self[r, c]
                ix += 1
    else:
        # Read and write in column-wise order
        for newc in range(cols):
            for newr in range(rows):
                r = ix % self.rows
                c = ix // self.rows
                value[newr][newc] = self[r, c]
                ix += 1

    return Matrix(rows, cols, self.bits, self.signed, value, self.max_bits)
[docs]
def flatten(self, order='C'):
    ''' Flatten the matrix into a single row.

    :param str order: ``C`` means row-major order (C-style), and
        ``F`` means column-major order (Fortran-style)
    :return: A copy of the matrix flattened in to a row vector matrix
    '''
    element_count = self.rows * self.columns
    return self.reshape(element_count, order=order)
[docs]
def multiply(first, second):
    ''' Perform the elementwise or scalar multiplication operation.

    :param Matrix first: first matrix
    :param Matrix/Wirevector second: second matrix, or a scalar WireVector
    :raises PyrtlError: if `first` is not a Matrix
    :return: a Matrix object with the element wise or scalar multiplication being performed
    '''
    if not isinstance(first, Matrix):
        # Report the type of the argument that actually failed the check.
        raise PyrtlError('error: expecting a Matrix, '
                         'got %s instead' % type(first))
    return first * second
[docs]
def sum(matrix, axis=None, bits=None):
    ''' Returns the sum of all the values in a matrix

    :param Matrix/Wirevector matrix: the matrix to perform sum operation on.
        If it is a WireVector, it will return itself
    :param None/int axis: The axis to perform the operation on
        None refers to sum of all item. 0 is sum of column. 1 is sum of rows. Defaults to None
    :param int bits: The bits per value of the sum. Defaults to bits of old matrix
    :return: A WireVector or Matrix representing sum

    With an axis of 0 or 1 the result is a single-row Matrix holding one
    sum per column or per row respectively.
    '''
    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        raise PyrtlError('error: expecting a Matrix or Wirevector for matrix, '
                         'got %s instead' % type(matrix))
    if bits is not None and not isinstance(bits, int):
        raise PyrtlError('error: expecting an int/None for bits, '
                         'got %s instead' % type(bits))
    if axis is not None and not isinstance(axis, int):
        raise PyrtlError('error: expecting an int or None for axis, '
                         'got %s instead' % type(axis))

    width = matrix.bits if bits is None else bits
    if width <= 0:
        raise PyrtlError('error: bits cannot be negative or zero, '
                         'got %s instead' % width)

    def total(wires):
        # Left-to-right chain of additions over the collected elements.
        return reduce(lambda running, wire: running + wire, wires)

    if axis is None:
        return total([matrix[r, c]
                      for r in range(matrix.rows)
                      for c in range(matrix.columns)])

    if axis == 0:
        column_sums = Matrix(1, matrix.columns, signed=matrix.signed, bits=width)
        for c in range(matrix.columns):
            column_sums[0, c] = total([matrix[r, c] for r in range(matrix.rows)])
        return column_sums

    if axis == 1:
        row_sums = Matrix(1, matrix.rows, signed=matrix.signed, bits=width)
        for r in range(matrix.rows):
            row_sums[0, r] = total([matrix[r, c] for c in range(matrix.columns)])
        return row_sums

    raise PyrtlError('Axis invalid: expected (None, 0, or 1), got %s' % axis)
[docs]
def min(matrix, axis=None, bits=None):
    ''' Returns the minimum value in a matrix.

    :param Matrix/Wirevector matrix: the matrix to perform min operation on.
        If it is a WireVector, it will return itself
    :param None/int axis: The axis to perform the operation on
        None refers to min of all item. 0 is min of column. 1 is min of rows. Defaults to None
    :param int bits: The bits per value of the min. Defaults to bits of old matrix
    :return: A WireVector or Matrix representing the min value

    With an axis of 0 or 1 the result is a single-row Matrix holding one
    minimum per column or per row respectively.
    '''
    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        raise PyrtlError('error: expecting a Matrix or Wirevector for matrix, '
                         'got %s instead' % type(matrix))
    if bits is not None and not isinstance(bits, int):
        raise PyrtlError('error: expecting an int/None for bits, '
                         'got %s instead' % type(bits))
    if axis is not None and not isinstance(axis, int):
        raise PyrtlError('error: expecting an int or None for axis, '
                         'got %s instead' % type(axis))

    width = matrix.bits if bits is None else bits
    if width <= 0:
        raise PyrtlError('error: bits cannot be negative or zero, '
                         'got %s instead' % width)

    def smallest(wires):
        # Pairwise selection: keep the left value only when it is strictly smaller.
        return reduce(lambda left, right: select(left < right, left, right), wires)

    if axis is None:
        return smallest([matrix[r, c]
                         for r in range(matrix.rows)
                         for c in range(matrix.columns)])

    if axis == 0:
        column_mins = Matrix(1, matrix.columns, signed=matrix.signed, bits=width)
        for c in range(matrix.columns):
            column_mins[0, c] = smallest([matrix[r, c] for r in range(matrix.rows)])
        return column_mins

    if axis == 1:
        row_mins = Matrix(1, matrix.rows, signed=matrix.signed, bits=width)
        for r in range(matrix.rows):
            row_mins[0, r] = smallest([matrix[r, c] for c in range(matrix.columns)])
        return row_mins

    raise PyrtlError('Axis invalid: expected (None, 0, or 1), got %s' % axis)
[docs]
def max(matrix, axis=None, bits=None):
    ''' Returns the max value in a matrix.

    :param Matrix/Wirevector matrix: the matrix to perform max operation on.
        If it is a WireVector, it will return itself
    :param None/int axis: The axis to perform the operation on
        None refers to max of all items. 0 is max of the columns. 1 is max of rows.
        Defaults to None
    :param int bits: The bits per value of the max. Defaults to bits of old matrix
    :return: A WireVector or Matrix representing the max value

    With an axis of 0 or 1 the result is a single-row Matrix holding one
    maximum per column or per row respectively.
    '''
    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        raise PyrtlError('error: expecting a Matrix or WireVector for matrix, '
                         'got %s instead' % type(matrix))
    if bits is not None and not isinstance(bits, int):
        raise PyrtlError('error: expecting an int/None for bits, '
                         'got %s instead' % type(bits))
    if axis is not None and not isinstance(axis, int):
        raise PyrtlError('error: expecting an int or None for axis, '
                         'got %s instead' % type(axis))

    width = matrix.bits if bits is None else bits
    if width <= 0:
        raise PyrtlError('error: bits cannot be negative or zero, '
                         'got %s instead' % width)

    def largest(wires):
        # Pairwise selection: keep the left value only when it is strictly larger.
        return reduce(lambda left, right: select(left > right, left, right), wires)

    if axis is None:
        return largest([matrix[r, c]
                        for r in range(matrix.rows)
                        for c in range(matrix.columns)])

    if axis == 0:
        column_maxes = Matrix(1, matrix.columns, signed=matrix.signed, bits=width)
        for c in range(matrix.columns):
            column_maxes[0, c] = largest([matrix[r, c] for r in range(matrix.rows)])
        return column_maxes

    if axis == 1:
        row_maxes = Matrix(1, matrix.rows, signed=matrix.signed, bits=width)
        for r in range(matrix.rows):
            row_maxes[0, r] = largest([matrix[r, c] for c in range(matrix.columns)])
        return row_maxes

    raise PyrtlError('Axis invalid: expected (None, 0, or 1), got %s' % axis)
[docs]
def argmax(matrix, axis=None, bits=None):
    ''' Computes the index of the largest element of a matrix.

    :param Matrix/WireVector matrix: the matrix to search.
        For a bare WireVector the result is the constant 0
    :param None/int axis: the direction to search along.
        None searches every element (index is the row-major position),
        0 searches each column, and 1 searches each row. Defaults to None
    :param int bits: bits per element of the resulting Matrix (only used
        when `axis` is 0 or 1). Defaults to the bits of `matrix`
    :return: a WireVector when `axis` is None (or when `matrix` is a
        WireVector); otherwise a one-row Matrix with one index per column/row

    NOTE: When several positions hold the same max value, the lowest
    index among them is the one produced.

    Raises PyrtlError for a wrongly typed argument or a non-positive `bits`.
    The max itself is obtained by calling `max` with the same `axis`.
    '''
    if isinstance(matrix, WireVector):
        return Const(0)
    if not isinstance(matrix, Matrix):
        raise PyrtlError('error: expecting a Matrix or Wirevector for matrix, '
                         'got %s instead' % type(matrix))
    if bits is not None and not isinstance(bits, int):
        raise PyrtlError('error: expecting an int/None for bits, '
                         'got %s instead' % type(bits))
    if axis is not None and not isinstance(axis, int):
        raise PyrtlError('error: expecting an int or None for axis, '
                         'got %s instead' % type(axis))

    if bits is None:
        bits = matrix.bits
    if bits <= 0:
        raise PyrtlError('error: bits cannot be negative or zero, '
                         'got %s instead' % bits)

    def first_match(target, cells):
        # `cells` lists (row, column) coordinates in ascending index order.
        # Walk them back-to-front so that the select applied last (the
        # outermost mux) belongs to the lowest index, making it win ties.
        winner = Const(0)
        for position in range(len(cells) - 1, -1, -1):
            row, col = cells[position]
            winner = select(
                target == matrix[row, col], Const(position), winner)
        return winner

    peak = max(matrix, axis=axis, bits=bits)

    if axis is None:
        every_cell = [(r, c)
                      for r in range(matrix.rows)
                      for c in range(matrix.columns)]
        return first_match(peak, every_cell)

    if axis == 0:
        result = Matrix(
            1, matrix.columns, signed=matrix.signed, bits=bits)
        for c in range(matrix.columns):
            result[0, c] = first_match(
                peak[0, c], [(r, c) for r in range(matrix.rows)])
        return result

    if axis == 1:
        result = Matrix(
            1, matrix.rows, signed=matrix.signed, bits=bits)
        for r in range(matrix.rows):
            result[0, r] = first_match(
                peak[0, r], [(r, c) for c in range(matrix.columns)])
        return result
[docs]
def dot(first, second):
    ''' Performs the dot product on two matrices.

    :param Matrix/WireVector first: the first operand
    :param Matrix/WireVector second: the second operand
    :return: the dot product of the two operands

    Specifically, the dot product on two operands is:

    * If either `first` or `second` are WireVectors/have both rows and columns
      equal to 1, it is equivalent to :py:meth:`.Matrix.__mul__`
    * If both `first` and `second` are both arrays (have rows or columns equal to 1),
      it is inner product of vectors.
    * Otherwise it is :py:meth:`.Matrix.__matmul__` between `first` and `second`

    NOTE: Row vectors and column vectors are both treated as arrays

    Raises PyrtlError if either operand is neither a Matrix nor a WireVector.
    '''
    # Both operands accept the same types, so both messages say so.
    if not isinstance(first, (WireVector, Matrix)):
        raise PyrtlError('error: expecting a Matrix/WireVector, '
                         'got %s instead' % type(first))
    if not isinstance(second, (WireVector, Matrix)):
        raise PyrtlError('error: expecting a Matrix/WireVector, '
                         'got %s instead' % type(second))

    # First case: plain multiplication (a scalar is involved).
    if isinstance(first, WireVector):
        if isinstance(second, WireVector):
            return first * second
        # Keep the Matrix on the left so its multiplication operator is used.
        return second[:, :] * first
    if isinstance(second, WireVector):
        return first[:, :] * second
    if (first.rows == 1 and first.columns == 1) \
            or (second.rows == 1 and second.columns == 1):
        return first[:, :] * second[:, :]

    # Second case: inner product of two vectors. When the orientations
    # differ (row vs. column), transpose `second` so the shapes line up for
    # the element-wise product.
    # NOTE(review): `sum` is presumably this module's matrix `sum` rather than
    # the builtin (the file reaches for `builtins.sum` where the builtin is
    # meant) -- confirm against the definition earlier in the file.
    if first.rows == 1:
        if second.rows == 1:
            return sum(first * second)
        if second.columns == 1:
            return sum(first * second.transpose())
    elif first.columns == 1:
        if second.rows == 1:
            return sum(first * second.transpose())
        if second.columns == 1:
            return sum(first * second)

    # Third case: general matrix multiplication.
    return first.__matmul__(second)
[docs]
def hstack(*matrices):
    """ Joins matrices side by side (column-wise), left to right.

    :param list[Matrix] matrices: the matrices to place next to one another
    :return Matrix: a new Matrix with the shared row count, the summed column
        count, and a bitwidth equal to the largest bitwidth among the inputs

    Every matrix must have the same number of rows and the same 'signed' value.
    When exactly one matrix is given, a copy of it is returned.

    For example::

        m1 = Matrix(2, 3, bits=5, value=[[1,2,3],
                                         [4,5,6]])
        m2 = Matrix(2, 1, bits=10, value=[[17],
                                          [23]])
        m3 = hstack(m1, m2)

    ``m3`` looks like::

        [[1,2,3,17],
         [4,5,6,23]]

    Raises PyrtlError if no matrices are given, if any argument is not a
    Matrix, or if the row counts or signedness disagree.
    """
    if not matrices:
        raise PyrtlError("Must supply at least one matrix to hstack()")
    if not all(isinstance(m, Matrix) for m in matrices):
        raise PyrtlError("All arguments to hstack must be matrices.")
    if len(matrices) == 1:
        return matrices[0].copy()

    leader = matrices[0]
    new_rows = leader.rows
    if any(m.rows != new_rows for m in matrices):
        raise PyrtlError("All matrices being hstacked together must have the same number of rows")
    if any(m.signed != leader.signed for m in matrices):
        raise PyrtlError("All matrices being hstacked together must have the same signedness")

    # `builtins.` is spelled out because this module defines its own
    # `max` (and the file imports `builtins` for this purpose).
    stacked = Matrix(
        new_rows,
        builtins.sum(m.columns for m in matrices),
        builtins.max(m.bits for m in matrices),
        max_bits=builtins.max(m.max_bits for m in matrices))

    # Copy each source matrix column by column into its slot, tracking how
    # many destination columns have been filled so far.
    col_offset = 0
    for source in matrices:
        for c in range(source.columns):
            for r in range(source.rows):
                stacked[r, col_offset + c] = source[r, c]
        col_offset += source.columns
    return stacked
[docs]
def vstack(*matrices):
    """ Stack matrices in sequence vertically (row-wise).

    :param list[Matrix] matrices: a list of matrices to concatenate one after another vertically
    :return Matrix: a new Matrix, with the same number of columns as the original, with
        a bitwidth equal to the max of the bitwidths of all the matrices

    All the matrices must have the same number of columns and same 'signed' value.
    When exactly one matrix is given, a copy of it is returned.

    For example::

        m1 = Matrix(2, 3, bits=5, value=[[1,2,3],
                                         [4,5,6]])
        m2 = Matrix(1, 3, bits=10, value=[[7,8,9]])
        m3 = vstack(m1, m2)

    ``m3`` looks like::

        [[1,2,3],
         [4,5,6],
         [7,8,9]]

    Raises PyrtlError if no matrices are given, if any argument is not a
    Matrix, or if the column counts or signedness disagree.
    """
    if len(matrices) == 0:
        # Message names vstack (it previously said hstack).
        raise PyrtlError("Must supply at least one matrix to vstack()")
    if any(not isinstance(matrix, Matrix) for matrix in matrices):
        raise PyrtlError("All arguments to vstack must be matrices.")
    if len(matrices) == 1:
        return matrices[0].copy()

    new_cols = matrices[0].columns
    if any(m.columns != new_cols for m in matrices):
        raise PyrtlError("All matrices being vstacked together must have the "
                         "same number of columns")
    new_signed = matrices[0].signed
    if any(m.signed != new_signed for m in matrices):
        # Message says "vstacked" (it previously said "hstacked").
        raise PyrtlError("All matrices being vstacked together must have the same signedness")

    # `builtins.` is spelled out because this module defines its own
    # `max` (and the file imports `builtins` for this purpose).
    new_rows = builtins.sum(m.rows for m in matrices)
    new_bits = builtins.max(m.bits for m in matrices)
    new_max_bits = builtins.max(m.max_bits for m in matrices)
    new = Matrix(new_rows, new_cols, new_bits, max_bits=new_max_bits)

    # Copy each source matrix row by row; `new_r` is the next free
    # destination row.
    new_r = 0
    for matrix in matrices:
        for r in range(matrix.rows):
            for c in range(matrix.columns):
                new[new_r, c] = matrix[r, c]
            new_r += 1
    return new
[docs]
def concatenate(matrices, axis=0):
    """ Joins a sequence of matrices along one axis.

    :param list[Matrix] matrices: the matrices to join, in order
    :param int axis: 0 joins horizontally (via hstack), 1 joins vertically
        (via vstack). Defaults to 0
    :return: a new Matrix made of the given matrices joined together

    This is a thin dispatcher over hstack/vstack; any other `axis` raises
    PyrtlError.
    """
    if axis == 0:
        stacker = hstack
    elif axis == 1:
        stacker = vstack
    else:
        raise PyrtlError("Only allowable axes are 0 or 1")
    return stacker(*matrices)
[docs]
def matrix_wv_to_list(matrix_wv, rows, columns, bits):
    ''' Convert the integer value of a matrix wirevector into a Python list of lists.

    :param int matrix_wv: integer value of a wire holding the result of
        calling to_wirevector() on a Matrix object (e.g. obtained from
        inspecting the wire during simulation)
    :param int rows: number of rows in the matrix `matrix_wv` represents
    :param int columns: number of columns in the matrix `matrix_wv` represents
    :param int bits: number of bits in each element of the matrix `matrix_wv` represents
    :return list[list[int]]: a Python list of lists

    The element at row 0, column 0 is read from the most significant `bits`
    bits, continuing in row-major order down to the least significant bits.
    Only the low `rows * columns * bits` bits of `matrix_wv` are used, so
    extra high bits are ignored and a negative value is read as its
    two's complement pattern of that width.

    This is useful when printing the value of a wire you've inspected
    during Simulation that you know represents a matrix.

    Example::

        values = [[1, 2, 3], [4, 5, 6]]
        rows = 2
        cols = 3
        bits = 4
        m = Matrix.Matrix(rows, cols, bits, value=values)
        output = Output(name='output')
        output <<= m.to_wirevector()
        sim = Simulation()
        sim.step({})
        raw_matrix = Matrix.matrix_wv_to_list(sim.inspect('output'), rows, cols, bits)
        print(raw_matrix)
        # Produces:
        # [[1, 2, 3], [4, 5, 6]]
    '''
    total_bits = rows * columns * bits
    if total_bits > 0:
        # Without this mask, a value wider than the matrix shifts every slice
        # below, and a negative value leaves a '-0b' prefix that int() rejects.
        matrix_wv &= (1 << total_bits) - 1

    # Fixed-width binary string, most significant bit first.
    value = bin(matrix_wv)[2:].zfill(total_bits)

    # Element (r, c) occupies the (r * columns + c)-th `bits`-wide slice.
    return [
        [int(value[(r * columns + c) * bits:(r * columns + c + 1) * bits], 2)
         for c in range(columns)]
        for r in range(rows)
    ]
[docs]
def list_to_int(matrix, n_bits):
    ''' Packs a Python matrix (a list of lists) into a single integer.

    :param list[list[int]] matrix: a pure Python list of lists representing a matrix
    :param int n_bits: number of bits used for each element; must be positive
    :return int: an integer holding the elements of `matrix`, `n_bits` apiece

    Elements are appended `n_bits` at a time in row-major order, so the
    element at row 0, column 0 lands in the most significant `n_bits` bits
    and the last element of the last row lands in the least significant ones.
    The number of columns is taken from the first row.

    Each element is passed through `formatted_str_to_val` using the signed
    format ``'s' + str(n_bits)``; that helper is expected to produce the
    element's `n_bits`-wide two's complement encoding (dropping most
    significant bits that do not fit) -- confirm against its documentation.

    The resulting integer is suitable for building a Const WireVector to pass
    as a Matrix constructor's `value` argument, or for feeding an input wire
    through a Simulation's step function.

    For example, calling Matrix.list_to_int([[3, 5], [7, 9]], 4) produces 13,689,
    which in binary looks like this::

        0011 0101 0111 1001

    Here's an example of using it in simulation::

        a_vals = [[0, 1], [2, 3]]
        b_vals = [[2, 4, 6], [8, 10, 12]]
        a_in = pyrtl.Input(4 * 4, 'a_in')
        b_in = pyrtl.Input(6 * 4, 'b_in')
        a = Matrix.Matrix(2, 2, 4, value=a_in)
        b = Matrix.Matrix(2, 3, 4, value=b_in)
        ...
        sim = pyrtl.Simulation()
        sim.step({
            'a_in': Matrix.list_to_int(a_vals, 4),
            'b_in': Matrix.list_to_int(b_vals, 4)
        })

    Raises PyrtlError if `n_bits` is not positive.
    '''
    if n_bits <= 0:
        raise PyrtlError("Number of bits per element must be positive, instead got %d" % n_bits)

    element_format = 's' + str(n_bits)
    packed = 0
    for row in matrix:
        # Width comes from the first row, exactly as many columns per row.
        for col in range(len(matrix[0])):
            encoded = formatted_str_to_val(str(row[col]), element_format)
            # Make room for the next element, then drop it into the low bits.
            packed = (packed << n_bits) | encoded
    return packed