# Source code for pyrtl.helperfuncs

""" Helper functions that make constructing hardware easier.
"""

from __future__ import annotations

import collections
import math
import numbers
import sys
from functools import reduce

from .core import working_block, _NameIndexer, _get_debug_mode, Block
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list

# -----------------------------------------------------------------
#        ___       __   ___  __   __
#  |__| |__  |    |__) |__  |__) /__`
#  |  | |___ |___ |    |___ |  \ .__/
#


# Name generator used by probe(): its make_valid_string() supplies the
# auto-generated part of default probe names (presumably prefixed 'Probe-').
probeIndexer = _NameIndexer('Probe-')


def probe(w, name=None):
    """Print useful information about a WireVector when in debug mode.

    :param WireVector w: WireVector from which to get info
    :param str name: optional name for probe (defaults to an autogenerated name)
    :return: original WireVector `w`
    :rtype: WireVector

    Probe can be inserted into an existing design easily as it returns the
    original wire unmodified. For example ``y <<= x[0:3] + 4`` could be turned
    into ``y <<= probe(x)[0:3] + 4`` to give visibility into both the origin of
    ``x`` (including the line that WireVector was originally created) and the
    run-time values of ``x`` (which will be named and thus show up by default in
    a trace). Likewise ``y <<= probe(x[0:3]) + 4``, ``y <<= probe(x[0:3] + 4)``,
    and ``probe(y) <<= x[0:3] + 4`` are all valid uses of `probe`.

    Note: `probe` does actually add an Output wire to the working block of `w`
    (which can confuse various post-processing transforms such as output to
    verilog).
    """
    if not isinstance(w, WireVector):
        raise PyrtlError('Only WireVectors can be probed')

    if name is None:
        # Default name pairs a fresh probe id with the probed wire's own name.
        name = f'({probeIndexer.make_valid_string()}: {w.name})'

    if _get_debug_mode():
        print("Probe: " + name + ' ' + get_stack(w))

    probe_output = Output(name=name)
    probe_output <<= w  # the Output's bitwidth is taken from w on assignment
    return w
# Name generator used by rtl_assert(): make_valid_string() names each assertion
# Output wire (presumably prefixed 'assertion').
assertIndexer = _NameIndexer('assertion')
def rtl_assert(w, exp, block=None):
    """Add hardware assertions to be checked on the RTL design.

    :param WireVector w: should be a WireVector
    :param Exception exp: Exception to throw when assertion fails
    :param Block block: block to which the assertion should be added (default to working block)
    :return: the Output wire for the assertion (can be ignored in most cases)
    :rtype: Output
    :raises PyrtlError: if `w` is not a 1-bit WireVector of `block`, or `exp` is
        not an Exception instance (KeyError instances are rejected).
    :raises PyrtlInternalError: if `w` is already a key of the block's assertion dict.

    If at any time during execution the wire `w` is not `true` (i.e. its value
    is 0) then simulation will raise `exp`.
    """
    block = working_block(block)

    if not isinstance(w, WireVector):
        raise PyrtlError('Only WireVectors can be asserted with rtl_assert')
    if len(w) != 1:
        raise PyrtlError('rtl_assert checks only a WireVector of bitwidth 1')
    if not isinstance(exp, Exception):
        raise PyrtlError('the second argument to rtl_assert must be an instance of Exception')
    if isinstance(exp, KeyError):
        # NOTE(review): presumably rejected because check_rtl_assertions treats
        # KeyError from wire inspection specially.
        raise PyrtlError('the second argument to rtl_assert cannot be a KeyError')
    # (A second, identical wirevector_set membership check used to follow this
    # one; it was unreachable and has been removed.)
    if w not in block.wirevector_set:
        raise PyrtlError('assertion wire not part of the block to which it is being added')
    if w in block.rtl_assert_dict:
        raise PyrtlInternalError('assertion conflicts with existing registered assertion')

    # The assertion is registered as an Output driven by w, keyed to its exception.
    assert_wire = Output(bitwidth=1, name=assertIndexer.make_valid_string(), block=block)
    assert_wire <<= w
    block.rtl_assert_dict[assert_wire] = exp
    return assert_wire
def check_rtl_assertions(sim):
    """Check the values in `sim` to see if any registered assertions fail.

    :param Simulation sim: Simulation in which to check the assertions
    :return: None
    :raises: the Exception registered (via :func:`rtl_assert`) for the first
        assertion wire whose inspected value is falsy.

    Assertion wires for which ``sim.inspect`` raises KeyError are skipped.
    Only the inspection is guarded, so a registered exception is always
    propagated, even if it happens to be a KeyError.
    """
    for w, exp in sim.block.rtl_assert_dict.items():
        try:
            value = sim.inspect(w)
        except KeyError:
            # NOTE(review): presumably the wire is not tracked by this simulation.
            continue
        if not value:
            raise exp
def log2(integer_val):
    """Return the base-2 logarithm of an exact power of two.

    :param int integer_val: The integer to take the log base 2 of.
    :return: The log base 2 of `integer_val`
    :rtype: int
    :raises PyrtlError: if `integer_val` is not an int, is less than 1, or is
        not an exact power of two.

    Handy for checking that power-of-two sizes are passed to functions, e.g.
    deriving an address width from a memory size.

    Examples: ::

        log2(2)  # returns 1
        log2(256)  # returns 8
        addrwidth = log2(size_of_memory)  # will fail if size_of_memory is not a power of two
    """
    value = integer_val
    if not isinstance(value, int):
        raise PyrtlError('this function can only take integers')
    if value <= 0:
        raise PyrtlError('this function can only take positive numbers 1 or greater')
    # A power of two has exactly one set bit, so clearing the lowest set bit yields 0.
    is_power_of_two = (value & (value - 1)) == 0
    if not is_power_of_two:
        raise PyrtlError('this function can only take even powers of 2')
    return value.bit_length() - 1
def truncate(wirevector_or_integer, bitwidth):
    """Return a WireVector or integer truncated to the specified bitwidth.

    :param wirevector_or_integer: Either a WireVector or an integer to be truncated.
    :param int bitwidth: The length to which the first argument should be truncated.
    :return: A truncated WireVector or integer as appropriate.
    :raises PyrtlError: if `bitwidth` is less than 1.

    This function truncates the most significant bits of the input, leaving a
    result that is only `bitwidth` bits wide. For integers this is performed
    with a simple bitmask of size `bitwidth`. For WireVectors the function calls
    :meth:`.WireVector.truncate` and returns a WireVector of the specified
    `bitwidth`.

    Examples: ::

        truncate(9,3)  # returns 1  (0b1001 truncates to 0b001)
        truncate(5,3)  # returns 5  (0b101 truncates to 0b101)
        truncate(-1,3)  # returns 7  (-0b1 truncates to 0b111)
        y = truncate(x+1, x.bitwidth)  # y.bitwidth will equal x.bitwidth
    """
    if bitwidth < 1:
        raise PyrtlError('bitwidth must be a positive integer')
    x = wirevector_or_integer
    # Look the method up explicitly instead of wrapping the call in
    # try/except AttributeError, so an AttributeError raised *inside* a
    # WireVector's truncate is not masked by the integer fallback.
    truncate_method = getattr(x, 'truncate', None)
    if truncate_method is not None:
        return truncate_method(bitwidth)
    return x & ((1 << bitwidth) - 1)
class MatchedFields(collections.namedtuple('MatchedFields', 'matched fields')):
    """A ``(matched, fields)`` pair that also works as a ``with`` context.

    ``matched`` is the match result and ``fields`` the extracted fields.
    Entering the context pushes ``matched`` as a condition (via
    ``pyrtl.conditional``) and yields ``fields``; exiting pops that condition.
    """

    def __enter__(self):
        from .conditional import _push_condition
        condition = self.matched
        _push_condition(condition)
        return self.fields

    def __exit__(self, *exc_info):
        # Returns None, so any exception raised inside the block propagates.
        from .conditional import _pop_condition
        _pop_condition()
def match_bitpattern(w, bitpattern, field_map=None):
    """Returns a single-bit WireVector that is 1 if and only if `w` matches the
    `bitpattern`, and a tuple containing the matched fields, if any. Compatible
    with the `with` statement.

    :param WireVector w: The WireVector to be compared to the bitpattern
    :param str bitpattern: A string holding the pattern (of bits and wildcards) to match
    :param field_map: (optional) A map from single-character field name in the
        bitpattern to the desired name of field in the returned namedtuple. If
        given, all non-"1"/"0"/"?" characters in the `bitpattern` must be present
        in the map.
    :return: A tuple of 1-bit WireVector carrying the result of the comparison,
        followed by a named tuple containing the matched fields, if any.
    :raises PyrtlError: if `bitpattern` is not a string, its length (ignoring
        whitespace and underscores) differs from `w`, or `field_map` lacks a field.

    This function will compare a multi-bit WireVector to a specified pattern of
    bits, where some of the pattern can be "wildcard" bits. If any of the ``1``
    or ``0`` values specified in the bitpattern fail to match the WireVector
    during execution, a ``0`` will be produced, otherwise the value carried on
    the wire will be ``1``. The wildcard characters can be any other
    alphanumeric character, with characters other than ``?`` having special
    functionality (see below). The string must have length equal to the
    WireVector specified, although whitespace and underscore characters will be
    ignored and can be used for pattern readability. A pattern made only of
    wildcards always matches.

    For all other characters besides ``1``, ``0``, or ``?``, a tuple of
    WireVectors will be returned as the second return value. Each character
    will be treated as the name of a field, and non-consecutive fields with the
    same name will be concatenated together, left-to-right, into a single field
    in the resultant tuple. For example, ``01aa1?bbb11a`` will match a string
    such as ``010010100111``, and the resultant matched fields are::

        (a, b) = (0b001, 0b100)

    where the ``a`` field is the concatenation of bits 9, 8, and 0, and the
    ``b`` field is the concatenation of bits 5, 4, and 3. Thus, arbitrary
    characters beside ``?`` act as wildcard characters for the purposes of
    matching, with the additional benefit of returning the WireVectors
    corresponding to those fields.

    A prime example of this is for decoding instructions. Here we decode some RISC-V: ::

        with pyrtl.conditional_assignment:
            with match_bitpattern(inst, "iiiiiiiiiiiirrrrr010ddddd0000011") as (imm, rs1, rd):
                regfile[rd] |= mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)]
                pc.next |= pc + 1
            with match_bitpattern(inst, "iiiiiiirrrrrsssss010iiiii0100011") as (imm, rs2, rs1):
                mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)] |= regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(inst, "0000000rrrrrsssss111ddddd0110011") as (rs2, rs1, rd):
                regfile[rd] |= regfile[rs1] & regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(inst, "0000000rrrrrsssss000ddddd0110011") as (rs2, rs1, rd):
                regfile[rd] |= (regfile[rs1] + regfile[rs2]).truncate(32)
                pc.next |= pc + 1
            # ...etc...

    Some smaller examples: ::

        m, _ = match_bitpattern(w, '0101')  # basically the same as w == '0b0101'
        m, _ = match_bitpattern(w, '01?1')  # m will be true when w is '0101' or '0111'
        m, _ = match_bitpattern(w, '??01')  # m will be true when last two bits of w are '01'
        m, _ = match_bitpattern(w, '??_0 1')  # spaces/underscores are ignored, same as line above
        m, (a, b) = match_bitpattern(w, '01aa1?bbb11a')  # all bits with same letter make up same field
        m, fs = match_bitpattern(w, '01aa1?bbb11a', {'a': 'foo', 'b': 'bar'})  # fields fs.foo, fs.bar
    """
    w = as_wires(w)
    if not isinstance(bitpattern, str):
        raise PyrtlError('bitpattern must be a string')
    nospace_string = ''.join(bitpattern.replace('_', '').split())
    if len(w) != len(nospace_string):
        raise PyrtlError('bitpattern string different length than wirevector provided')
    lsb_first_string = nospace_string[::-1]  # flip so index 0 is lsb

    zero_bits = [w[index] for index, x in enumerate(lsb_first_string) if x == '0']
    one_bits = [w[index] for index, x in enumerate(lsb_first_string) if x == '1']
    # Never call rtl_all/rtl_any with an empty argument list (NOTE(review): they
    # presumably reject zero arguments), so patterns without '1' bits, without
    # '0' bits, or with only wildcards are built explicitly.
    if one_bits and zero_bits:
        match = rtl_all(*one_bits) & ~rtl_any(*zero_bits)
    elif one_bits:
        match = rtl_all(*one_bits)
    elif zero_bits:
        match = ~rtl_any(*zero_bits)
    else:
        match = Const(1, bitwidth=1)  # only wildcards: always matches

    def field_name(name):
        # Translate a pattern letter through field_map, if one was given.
        if field_map is not None:
            if name not in field_map:
                raise PyrtlError('field_map argument has been given, '
                                 'but %s field is not present' % name)
            return field_map[name]
        return name

    # Fields are ordered by their first (leftmost) appearance in the pattern;
    # dict.fromkeys keeps that first-seen order.
    field_order = [c for c in dict.fromkeys(nospace_string) if c not in '01?']
    bits_by_field = collections.defaultdict(list)
    for index, c in enumerate(lsb_first_string):
        if c not in '01?':
            bits_by_field[c].append(w[index])  # lsb-first, as concat_list expects
    Fields = collections.namedtuple('Fields', [field_name(c) for c in field_order])
    fields = Fields(*(concat_list(bits_by_field[c]) for c in field_order))
    return MatchedFields(match, fields)
def bitpattern_to_val(bitpattern, *ordered_fields, **named_fields):
    """Return an unsigned integer representation of field format filled with the provided values.

    :param bitpattern: A string holding the pattern (of bits and wildcards) to match
    :param ordered_fields: A list of parameters to be matched to the provided bit
        pattern in the order provided. If ordered_fields are provided then no
        named_fields can be used.
    :param named_fields: A list of parameters to be matched to the provided bit
        pattern by the names provided. If named_fields are provided then no
        ordered_fields can be used. A special keyword argument, 'field_map', can
        be provided, which will allow you to specify a correspondence between the
        1-letter field names in the bitpattern string and longer, human readable
        field names (see example below).
    :return: An unsigned integer carrying the result of the field substitution.
    :rtype: int
    :raises PyrtlError: if the pattern is empty or contains ``?``, fields are
        mixed/missing/miscounted, or a value does not fit in its field.

    This function will compare a specified pattern of bits, where some of the
    pattern can be "wildcard" bits. The wildcard bits must all be named with a
    single letter and, unlike the related function ``match_bitpattern``, no "?"
    can be used. The function will take the provided bitpattern and create an
    integer that substitutes the provided fields in for the given wildcards at
    the bit level. This sort of bit substitution is useful when creating values
    for testing when the resulting values will be "chopped" up by the hardware
    later (e.g. instruction decode or other bitfield heavy functions).

    A field of ``n`` bits accepts values in ``[-2**(n-1), 2**n)``: non-negative
    values are stored unsigned and negative values as ``n``-bit two's
    complement.

    If a special keyword argument, 'field_map', is provided, then the named
    fields provided can be longer, human-readable field names, which will
    correspond to the field in the bitpattern according to the field_map. See
    the third example below.

    Examples::

        bitpattern_to_val('0000000sssssrrrrr000ddddd0110011', s=2, r=1, d=3)  # RISCV ADD instr
        # evaluates to 0b00000000001000001000000110110011

        bitpattern_to_val('iiiiiiisssssrrrrr010iiiii0100011', i=1, s=4, r=3)  # RISCV SW instr
        # evaluates to 0b00000000010000011010000010100011

        bitpattern_to_val(
            'iiiiiiisssssrrrrr010iiiii0100011',
            imm=1, rs2=4, rs1=3,
            field_map={'i': 'imm', 's': 'rs2', 'r': 'rs1'}
        )  # RISCV SW instr
        # evaluates to 0b00000000010000011010000010100011
    """
    if not bitpattern:
        raise PyrtlError('bitpattern must be nonempty')
    if len(ordered_fields) > 0 and len(named_fields) > 0:
        raise PyrtlError('named and ordered fields cannot be mixed')
    if '?' in bitpattern:
        # Reject anonymous wildcards before counting fields, so the caller gets
        # this error rather than a misleading field-count mismatch.
        raise PyrtlError('all fields in the bitpattern must have names')

    def letters_in_field_order():
        # Unique field letters in order of first appearance, reading left to right.
        seen = []
        for c in bitpattern:
            if c != '0' and c != '1' and c not in seen:
                seen.append(c)
        return seen

    field_map = named_fields.pop('field_map', None)

    lifo = letters_in_field_order()
    if ordered_fields:
        if len(lifo) != len(ordered_fields):
            raise PyrtlError('number of fields and number of unique patterns do not match')
        intfields = [int(f) for f in ordered_fields]
    else:
        if len(lifo) != len(named_fields):
            raise PyrtlError('number of fields and number of unique patterns do not match')
        try:
            def fn(n):
                return field_map[n] if field_map else n
            intfields = [int(named_fields[fn(n)]) for n in lifo]
        except KeyError as e:
            raise PyrtlError('bitpattern field %s was not provided in named_field list'
                             % e.args[0])

    fmap = dict(zip(lifo, intfields))
    # Each field is as wide as the number of times its letter occurs.  Negative
    # values must fit as two's complement, so e.g. -3 is rejected for a 2-bit
    # field instead of silently becoming 0b01.
    for f, value in fmap.items():
        width = bitpattern.count(f)
        if not -(1 << (width - 1)) <= value < (1 << width):
            raise PyrtlError('too many bits given to value to fit in field %s' % f)

    bitlist = []  # built lsb-first
    for c in bitpattern[::-1]:
        if c == '0' or c == '1':
            bitlist.append(c)
        else:
            bitlist.append(str(fmap[c] & 0x1))  # append lsb of the field
            fmap[c] = fmap[c] >> 1  # and bit shift by one position
    if len(bitpattern) != len(bitlist):
        raise PyrtlInternalError('resulting values have different bitwidths')
    final_str = ''.join(bitlist[::-1])
    return int(final_str, 2)
def chop(w, *segment_widths):
    """Returns a list of WireVectors, each a slice of the original `w`.

    :param WireVector w: The WireVector to be chopped up into segments
    :param int segment_widths: Additional arguments are positive integers which are bitwidths
    :return: A list of WireVectors each with a proper segment width
    :rtype: List[WireVector]
    :raises PyrtlError: if a segment width is not a positive integer, or the
        widths do not sum to the length of `w`.

    This function chops a WireVector into a set of smaller WireVectors of
    different lengths. It is most useful when multiple "fields" are contained
    within a single WireVector, for example when breaking apart an instruction.
    For example, if you wish to break apart a 32-bit MIPS I-type (Immediate)
    instruction you know it has a 6-bit opcode, 2 5-bit operands, and 16-bit
    offset. You could take each of those slices in absolute terms:
    ``offset=instr[0:16]``, ``rt=instr[16:21]`` and so on, but then you have to
    do the arithmetic yourself. With this function you can do all the fields at
    once which can be seen in the examples below.

    As a check, chop will throw an error if the sum of the lengths of the
    fields given is not the same as the length of the WireVector to chop. Note
    also that chop assumes that the "rightmost" arguments are the least
    significant bits (just like :func:`.concat`) which is normal for hardware
    functions but makes the list order a little counter intuitive.

    Examples: ::

        opcode, rs, rt, offset = chop(instr, 6, 5, 5, 16)  # MIPS I-type instruction
        opcode, instr_index = chop(instr, 6, 26)  # MIPS J-type instruction
        opcode, rs, rt, rd, sa, function = chop(instr, 6, 5, 5, 5, 5, 6)  # MIPS R-type
        msb, middle, lsb = chop(data, 1, 30, 1)  # break out the most and least significant bits
    """
    w = as_wires(w)
    for seg in segment_widths:
        if not isinstance(seg, int):
            raise PyrtlError('segment widths must be integers')
        if seg < 1:
            # Non-positive widths could still sum to len(w) and yield bogus slices.
            raise PyrtlError('segment widths must be positive')
    if sum(segment_widths) != len(w):
        raise PyrtlError('sum of segment widths must equal length of wirevector')

    # Widths are given MSB-first: walk down from the top bit, one segment at a time.
    slices = []
    end = len(w)
    for seg in segment_widths:
        start = end - seg
        slices.append(w[start:end])
        end = start
    return slices
def input_list(names, bitwidth=None):
    """Allocate and return a list of :class:`Inputs<pyrtl.wire.Input>`.

    :param names: Names for the Inputs. Can be a list or single comma/space-separated string
    :param int bitwidth: The desired bitwidth for the resulting Inputs.
    :return: List of Inputs.
    :rtype: List[Input]

    Thin wrapper; equivalent to: ::

        wirevector_list(names, bitwidth, wvtype=pyrtl.wire.Input)
    """
    return wirevector_list(names, bitwidth=bitwidth, wvtype=Input)
def output_list(names, bitwidth=None):
    """Allocate and return a list of :class:`Outputs<pyrtl.wire.Output>`.

    :param names: Names for the Outputs. Can be a list or single comma/space-separated string
    :param int bitwidth: The desired bitwidth for the resulting Outputs.
    :return: List of Outputs.
    :rtype: List[Output]

    Thin wrapper; equivalent to: ::

        wirevector_list(names, bitwidth, wvtype=pyrtl.wire.Output)
    """
    return wirevector_list(names, bitwidth=bitwidth, wvtype=Output)
def register_list(names, bitwidth=None):
    """Allocate and return a list of :class:`Registers<pyrtl.wire.Register>`.

    :param names: Names for the Registers. Can be a list or single comma/space-separated string
    :param int bitwidth: The desired bitwidth for the resulting Registers.
    :return: List of Registers.
    :rtype: List[Register]

    Thin wrapper; equivalent to: ::

        wirevector_list(names, bitwidth, wvtype=pyrtl.wire.Register)
    """
    return wirevector_list(names, bitwidth=bitwidth, wvtype=Register)
def wirevector_list(names, bitwidth=None, wvtype=WireVector):
    """Allocate and return a list of WireVectors.

    :param names: Names for the WireVectors. Can be a list or single comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting WireVectors: a single
        integer used for every name, or a list with one bitwidth per name.
        Defaults to 1.
    :param WireVector wvtype: Which WireVector type to create.
    :return: List of WireVectors.
    :rtype: List[WireVector]
    :raises PyrtlError: if a name uses the ``/`` bitwidth syntax while `bitwidth`
        is also given.
    :raises ValueError: if the number of bitwidths does not match the number of names.

    Additionally, each name can carry its own bitwidth specification separated
    by a ``/`` (e.g. ``'four_bits/4'``). This cannot be combined with an
    explicit `bitwidth` argument (not even ``bitwidth=1``); names without a
    ``/`` then get a bitwidth of 1.

    Examples: ::

        wirevector_list(['name1', 'name2', 'name3'])
        wirevector_list('name1, name2, name3')
        wirevector_list('input1 input2 input3', bitwidth=8, wvtype=pyrtl.wire.Input)
        wirevector_list('output1, output2 output3', bitwidth=3, wvtype=pyrtl.wire.Output)
        wirevector_list('two_bits/2, four_bits/4, eight_bits/8')
        wirevector_list(['name1', 'name2', 'name3'], bitwidth=[2, 4, 8])
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()

    if any('/' in name for name in names) and bitwidth is not None:
        raise PyrtlError('only one of optional "/" or bitwidth parameter allowed')

    if bitwidth is None:
        bitwidth = 1
    if isinstance(bitwidth, numbers.Integral):
        bitwidth = [bitwidth] * len(names)
    if len(bitwidth) != len(names):
        raise ValueError('number of names ' + str(len(names))
                         + ' should match number of bitwidths ' + str(len(bitwidth)))

    wirelist = []
    for fullname, bw in zip(names, bitwidth):
        try:
            # 'name/width' supplies its own width; anything that does not split
            # into exactly two parts is used verbatim as the name.
            name, bw = fullname.split('/')
        except ValueError:
            name = fullname
        wirelist.append(wvtype(bitwidth=int(bw), name=name))
    return wirelist
def val_to_signed_integer(value: int, bitwidth: int) -> int:
    """Return `value` interpreted as a two's complement signed integer.

    :param value: A Python integer holding the value to convert.
    :param bitwidth: The length of the integer in bits to assume for conversion.
    :return: ``value`` as a signed integer
    :raises PyrtlError: if either argument is a WireVector, or `bitwidth` < 1.

    Given an unsigned integer (not a ``WireVector``!) convert that to a signed
    integer. Bit ``bitwidth - 1`` is the sign bit with weight
    ``-2**(bitwidth - 1)``; bits above ``bitwidth`` are ignored. This is useful
    for printing and interpreting values which are negative numbers in two's
    complement. ::

        val_to_signed_integer(0xff, 8) == -1

    ``val_to_signed_integer`` can also be used as an ``repr_func`` for
    :py:meth:`.SimulationTrace.render_trace`, to display signed integers in traces::

        bitwidth = 3
        counter = Register(name='counter', bitwidth=bitwidth)
        counter.next <<= counter + 1
        sim = Simulation()
        sim.step_multiple(nsteps=2 ** bitwidth)

        # Generates a trace like:
        #        │0 │1 │2 │3 │4 │5 │6 │7
        #
        # counter ──┤1 │2 │3 │-4│-3│-2│-1
        sim.tracer.render_trace(repr_func=val_to_signed_integer)
    """
    if any(isinstance(arg, WireVector) for arg in (value, bitwidth)):
        raise PyrtlError('inputs must not be wirevectors')
    if bitwidth < 1:
        raise PyrtlError('bitwidth must be a positive integer')
    sign_bit = 1 << (bitwidth - 1)
    magnitude_bits = value & (sign_bit - 1)
    return magnitude_bits - (value & sign_bit)
def formatted_str_to_val(data, format, enum_set=None):
    """Return an unsigned integer representation of the data given format specified.

    :param str data: a string holding the value to convert
    :param str format: a string holding a format which will be used to convert the data string
    :param enum_set: an iterable of enums which are used as part of the
        conversion process (required for ``e`` formats)
    :return: `data` as an unsigned integer
    :rtype: int
    :raises PyrtlError: for an unknown format type, a negative ``u`` value, or
        an ``e`` format whose enum or member cannot be found.

    Given a string (not a WireVector!) convert that to an unsigned integer
    ready for input to the simulation environment. This helps deal with
    signed/unsigned numbers (simulation assumes the values have been converted
    via two's complement already), but it also takes hex, binary, and enum
    types as inputs. Only ``s`` values are masked to the bitwidth. It is
    easiest to see how it works with some examples. ::

        formatted_str_to_val('2', 's3') == 2  # 0b010
        formatted_str_to_val('-1', 's3') == 7  # 0b111
        formatted_str_to_val('101', 'b3') == 5
        formatted_str_to_val('5', 'u3') == 5
        formatted_str_to_val('-3', 's3') == 5
        formatted_str_to_val('a', 'x3') == 10
        class Ctl(Enum):
            ADD = 5
            SUB = 12
        formatted_str_to_val('ADD', 'e3/Ctl', [Ctl]) == 5
        formatted_str_to_val('SUB', 'e3/Ctl', [Ctl]) == 12
    """
    fmt_type = format[0]
    bitwidth = int(format[1:].split('/')[0])
    bitmask = (1 << bitwidth) - 1
    if fmt_type == 's':
        rval = int(data) & bitmask
    elif fmt_type == 'x':
        rval = int(data, 16)
    elif fmt_type == 'b':
        rval = int(data, 2)
    elif fmt_type == 'u':
        rval = int(data)
        if rval < 0:
            raise PyrtlError('unsigned format requested, but negative value provided')
    elif fmt_type == 'e':
        parts = format.split('/')
        if len(parts) < 2:
            raise PyrtlError('enum format "{}" must name an enum, e.g. "e3/MyEnum"'
                             .format(format))
        enumname = parts[1]
        if enum_set is None:
            raise PyrtlError('enum format "{}" requires an enum_set'.format(format))
        enum_inst_list = [e for e in enum_set if e.__name__ == enumname]
        if len(enum_inst_list) == 0:
            raise PyrtlError('enum "{}" not found in passed enum_set "{}"'
                             .format(enumname, enum_set))
        # Look up by member name only (not arbitrary class attributes).
        members = enum_inst_list[0].__members__
        if data not in members:
            raise PyrtlError('"{}" is not a member of enum "{}"'.format(data, enumname))
        rval = members[data].value
    else:
        raise PyrtlError('unknown format type {}'.format(format))
    return rval
def val_to_formatted_str(val, format, enum_set=None):
    """Return a string representation of the value given format specified.

    :param int val: an unsigned integer to convert
    :param str format: a string holding a format which will be used to convert the data string
    :param enum_set: an iterable of enums which are used as part of the
        conversion process (required for ``e`` formats)
    :return: a human-readable string representing `val`.
    :rtype: str
    :raises PyrtlError: for an unknown format type, or an ``e`` format whose
        enum cannot be found.

    Given an unsigned integer (not a WireVector!) convert that to a
    human-readable string. This helps deal with signed/unsigned numbers
    (simulation operates on values that have been converted via two's
    complement), but it also generates hex, binary, and enum types as outputs.
    It is easiest to see how it works with some examples. ::

        val_to_formatted_str(2, 's3') == '2'
        val_to_formatted_str(7, 's3') == '-1'
        val_to_formatted_str(5, 'b3') == '101'
        val_to_formatted_str(5, 'u3') == '5'
        val_to_formatted_str(5, 's3') == '-3'
        val_to_formatted_str(10, 'x3') == 'a'
        class Ctl(Enum):
            ADD = 5
            SUB = 12
        val_to_formatted_str(5, 'e3/Ctl', [Ctl]) == 'ADD'
        val_to_formatted_str(12, 'e3/Ctl', [Ctl]) == 'SUB'
    """
    fmt_type = format[0]
    # Parsing the width also validates the format string; only 's' uses it.
    bitwidth = int(format[1:].split('/')[0])
    if fmt_type == 's':
        rval = str(val_to_signed_integer(val, bitwidth))
    elif fmt_type == 'x':
        rval = hex(val)[2:]  # cuts off '0x' at the start
    elif fmt_type == 'b':
        rval = bin(val)[2:]  # cuts off '0b' at the start
    elif fmt_type == 'u':
        rval = str(int(val))  # nothing fancy
    elif fmt_type == 'e':
        parts = format.split('/')
        if len(parts) < 2:
            raise PyrtlError('enum format "{}" must name an enum, e.g. "e3/MyEnum"'
                             .format(format))
        enumname = parts[1]
        if enum_set is None:
            raise PyrtlError('enum format "{}" requires an enum_set'.format(format))
        enum_inst_list = [e for e in enum_set if e.__name__ == enumname]
        if len(enum_inst_list) == 0:
            raise PyrtlError('enum "{}" not found in passed enum_set "{}"'
                             .format(enumname, enum_set))
        rval = enum_inst_list[0](val).name
    else:
        raise PyrtlError('unknown format type {}'.format(format))
    return rval
# Named (value, bitwidth) pair returned by infer_val_and_bitwidth and its
# _convert_* helpers.
ValueBitwidthTuple = collections.namedtuple('ValueBitwidthTuple', 'value bitwidth')
def infer_val_and_bitwidth(rawinput, bitwidth=None, signed=False):
    """Return a tuple (value, bitwidth) inferred from the specified input.

    :param rawinput: a bool, int, or verilog-style string constant
    :param int bitwidth: an integer bitwidth or (by default) None
    :param bool signed: a bool (by default set False) to include bits for proper two's complement
    :return: tuple of integers (`value`, `bitwidth`)
    :rtype: (int, int)
    :raises PyrtlError: if `rawinput` is of an unsupported type.

    Given a boolean, integer, or verilog-style string constant, this function
    returns a tuple of two integers (`value`, `bitwidth`) which are inferred
    from the specified `rawinput`. The tuple returned is, in fact, a named
    tuple with names `.value` and `.bitwidth` for fields 0 and 1 respectively.
    If `signed` is True, bits will be included to ensure a proper two's
    complement representation is possible, otherwise it is assumed all bits can
    be used for standard unsigned representation. Error checks are performed
    that determine if the bitwidths specified are sufficient and appropriate
    for the values specified. Examples can be found below ::

        infer_val_and_bitwidth(2, bitwidth=5) == (2, 5)
        infer_val_and_bitwidth(3) == (3, 2)  # bitwidth inferred from value
        infer_val_and_bitwidth(3, signed=True) == (3, 3)  # need a bit for the leading zero
        infer_val_and_bitwidth(-3, signed=True) == (5, 3)  # 5 = -3 & 0b111 = ..111101 & 0b111
        infer_val_and_bitwidth(-4, signed=True) == (4, 3)  # 4 = -4 & 0b111 = ..111100 & 0b111
        infer_val_and_bitwidth(-3, bitwidth=5, signed=True) == (29, 5)
        infer_val_and_bitwidth(-3) ==> Error  # negative numbers require bitwidth or signed=True
        infer_val_and_bitwidth(3, bitwidth=2) == (3, 2)
        infer_val_and_bitwidth(3, bitwidth=2, signed=True) ==> Error  # need space for sign bit
        infer_val_and_bitwidth(True) == (1, 1)
        infer_val_and_bitwidth(False) == (0, 1)
        infer_val_and_bitwidth("5'd12") == (12, 5)
        infer_val_and_bitwidth("5'b10") == (2, 5)
        infer_val_and_bitwidth("5'b10").bitwidth == 5
        infer_val_and_bitwidth("5'b10").value == 2
        infer_val_and_bitwidth("8'B 0110_1100") == (108, 8)
    """
    # bool must be tested before numbers.Integral because bool subclasses int.
    if isinstance(rawinput, bool):
        return _convert_bool(rawinput, bitwidth, signed)
    elif isinstance(rawinput, numbers.Integral):
        return _convert_int(rawinput, bitwidth, signed)
    elif isinstance(rawinput, str):
        return _convert_verilog_str(rawinput, bitwidth, signed)
    else:
        raise PyrtlError('error, the value provided is of an improper type, "%s"; '
                         'proper types are bool, int, and string' % type(rawinput))
def _convert_bool(bool_val, bitwidth=None, signed=False):
    """ Convert a Python bool into a (value, bitwidth) tuple.

    :param bool bool_val: the boolean to convert
    :param int bitwidth: must be None (the default) or 1
    :param bool signed: must be False; booleans cannot be signed
    :return: ValueBitwidthTuple of (0 or 1, 1)
    :raises PyrtlError: if `signed` is True or `bitwidth` is not 1
    """
    if signed:
        raise PyrtlError('error, booleans cannot be signed (convert to int first)')
    num = int(bool_val)
    if bitwidth is None:
        bitwidth = 1
    if bitwidth != 1:
        raise PyrtlError('error, boolean has bitwidth not equal to 1')
    return ValueBitwidthTuple(num, bitwidth)


def _convert_int(val, bitwidth=None, signed=False):
    """ Convert a Python integer into a (value, bitwidth) tuple.

    Non-negative values get the minimal bitwidth that holds them (plus one
    leading-zero bit when `signed`), unless `bitwidth` is given. Negative
    values are returned as their two's complement encoding in `bitwidth` bits.

    :param int val: the integer to convert
    :param int bitwidth: explicit bitwidth, or None to infer it
    :param bool signed: reserve room for a two's complement sign bit
    :return: ValueBitwidthTuple of (non-negative value, bitwidth)
    :raises PyrtlError: if the bitwidth is insufficient, or if a negative value
        is given with neither `signed` nor an explicit `bitwidth`
    """
    if val >= 0:
        num = val
        # infer bitwidth if it is not specified explicitly
        min_bitwidth = len(bin(num)) - 2  # the -2 for the "0b" at the start of the string
        if signed and val != 0:
            min_bitwidth += 1  # extra bit needed for the leading zero (sign) bit
        if bitwidth is None:
            bitwidth = min_bitwidth
        elif bitwidth < min_bitwidth:
            raise PyrtlError('bitwidth specified is insufficient to represent constant')
    else:  # val is negative
        if not signed and bitwidth is None:
            raise PyrtlError('negative constants require either signed=True or specified bitwidth')
        if bitwidth is None:
            # ~val == -val - 1 is non-negative: its bit length plus one sign bit
            bitwidth = 1 if val == -1 else len(bin(~val)) - 1
        # A negative number needs at least one bit; checking bitwidth < 1 first
        # also avoids a ValueError (negative shift count) from the shift below.
        # (val >> bitwidth - 1) == -1 exactly when val >= -2**(bitwidth - 1).
        if bitwidth < 1 or (val >> bitwidth - 1) != -1:
            raise PyrtlError('insufficient bits for negative number')
        num = val & ((1 << bitwidth) - 1)  # result is a two's complement value
    return ValueBitwidthTuple(num, bitwidth)


def _convert_verilog_str(val, bitwidth=None, signed=False):
    """ Convert a verilog-style constant string into a (value, bitwidth) tuple.

    Accepts strings such as "8'hFF", "5'b1_0" or "4'd3" (base letters b, o, d,
    h, x; underscores ignored; no base letter means decimal). A leading '-'
    yields the two's complement encoding of the magnitude in the given width.

    :param str val: the verilog-style constant
    :param int bitwidth: if truthy, must equal the width written in `val`
    :param bool signed: must be False; signed verilog constants are unsupported
    :return: ValueBitwidthTuple of (non-negative value, bitwidth)
    :raises PyrtlError: if the string is malformed, uses 's' (signed) notation,
        does not fit in its width, or disagrees with the `bitwidth` argument
    """
    if signed:
        raise PyrtlError('error, "signed" option with verilog-style string constants not supported')
    bases = {'b': 2, 'o': 8, 'd': 10, 'h': 16, 'x': 16}
    passed_bitwidth = bitwidth

    neg = False
    if val.startswith('-'):
        neg = True
        val = val[1:]
    split_string = val.lower().split("'")
    if len(split_string) != 2:
        raise PyrtlError('error, string not in verilog style format')
    try:
        bitwidth = int(split_string[0])
        sval = split_string[1]
        if sval[0] == 's':
            raise PyrtlError('error, signed integers are not supported in verilog-style constants')
        base = 10
        if sval[0] in bases:
            base = bases[sval[0]]
            sval = sval[1:]
        sval = sval.replace('_', '')
        num = int(sval, base)
    except (IndexError, ValueError):
        raise PyrtlError('error, string not in verilog style format')

    if neg and num:
        # The most negative value representable in `bitwidth` bits is
        # -2**(bitwidth - 1), so a magnitude of exactly 2**(bitwidth - 1) is
        # allowed (e.g. "-4'd8" encodes to 0b1000), matching _convert_int.
        # bitwidth < 1 is rejected first to avoid a negative shift count.
        if bitwidth < 1 or num > (1 << (bitwidth - 1)):
            raise PyrtlError('error, insufficient bits for negative number')
        num = (1 << bitwidth) - num

    if passed_bitwidth and passed_bitwidth != bitwidth:
        raise PyrtlError('error, bitwidth parameter of constant does not match'
                         ' the bitwidth infered from the verilog style specification'
                         ' (if bitwidth=None is used, pyrtl will determine the bitwidth from the'
                         ' verilog-style constant specification)')

    if num >> bitwidth != 0:
        raise PyrtlError('specified bitwidth %d for verilog constant insufficient to store value %d'
                         % (bitwidth, num))

    return ValueBitwidthTuple(num, bitwidth)


def get_stacks(*wires):
    """ Return the creation call stacks of all `wires` as one string.

    Only the first wire is checked for a recorded ``init_call_stack``; if it
    has none (or no wires are passed), a hint suggesting ``set_debug_mode()``
    is returned instead.
    """
    call_stack = getattr(wires[0], 'init_call_stack', None) if wires else None
    if not call_stack:
        return ' No call info found for wires: use set_debug_mode() ' \
               'to provide more information\n'
    return '\n'.join(str(wire) + ":\n" + get_stack(wire) for wire in wires)


def get_stack(wire):
    """ Return the call stack recorded when `wire` was created, as a string.

    :param WireVector wire: wire whose ``init_call_stack`` should be rendered
    :return: the formatted traceback, or a hint suggesting ``set_debug_mode()``
        when no call stack was recorded
    :raises PyrtlError: if `wire` is not a WireVector
    """
    if not isinstance(wire, WireVector):
        raise PyrtlError('Only WireVectors can be traced')

    call_stack = getattr(wire, 'init_call_stack', None)
    if call_stack:
        # The final recorded frame is omitted (presumably the frame that
        # captured the stack itself -- NOTE(review): confirm).
        frames = ' '.join(call_stack[:-1])
        return "Wire Traceback, most recent call last \n" + frames + "\n"
    else:
        return ' No call info found for wire: use set_debug_mode()'\
               ' to provide more information'


def _check_for_loop(block=None):
    """ Return the wires and nets that could not be topologically resolved.

    Nets are repeatedly removed once none of their arguments is an
    "unresolved" wire. Inputs, Consts, Outputs and Registers are never
    unresolved, and a net's destinations become resolved when it is removed.
    When no further progress is possible, whatever remains is part of, or
    driven from, a combinational loop.

    :param Block block: block to check (defaults to the working block)
    :return: None if every net was resolved, otherwise a tuple
        (remaining wires, remaining nets)
    """
    block = working_block(block)
    logic_left = block.logic.copy()
    wires_left = block.wirevector_subset(exclude=(Input, Const, Output, Register))

    prev_logic_left = len(logic_left) + 1
    while prev_logic_left > len(logic_left):
        prev_logic_left = len(logic_left)
        nets_to_remove = set()  # bc it's not safe to mutate a set inside its own iterator
        for net in logic_left:
            if not any(n_wire in wires_left for n_wire in net.args):
                nets_to_remove.add(net)
                wires_left.difference_update(net.dests)
        logic_left -= nets_to_remove

    if not logic_left:
        return None
    return wires_left, logic_left


def find_loop(block=None):
    """ Search `block` for a combinational loop and return one if found.

    :param Block block: block to search (defaults to the working block)
    :return: None if there is no loop; otherwise a list of search-state objects
        (with attributes ``dst_w`` and ``net``), starting at the most recently
        visited wire and ending at the wire that closes the loop
    :raises PyrtlError: if the search ends in a state it cannot explain

    The search starts from a randomly chosen unresolved wire, so when several
    loops exist, different calls may report different ones.
    """
    block = working_block(block)
    block.sanity_check()  # make sure that the block is sane first

    result = _check_for_loop(block)
    if not result:
        return None
    wires_left, logic_left = result
    import random

    class _FilteringState(object):
        """ One frame of the explicit depth-first search stack. """
        def __init__(self, dst_w):
            self.dst_w = dst_w  # wire being examined
            self.arg_num = -1  # index of the last driver arg explored; -1 = not visited yet

    def dead_end():
        # clean up after a wire is found to not be part of the loop
        wires_left.discard(cur_item.dst_w)
        current_wires.discard(cur_item.dst_w)
        del checking_stack[-1]

    # map each remaining wire to the net that drives it, for quick lookup
    dest_nets = {dest_w: net_ for net_ in logic_left for dest_w in net_.dests}
    initial_w = random.choice(list(wires_left))

    current_wires = set()  # wires on the current search path
    checking_stack = [_FilteringState(initial_w)]

    # An explicit stack is used rather than recursion because Python's default
    # recursion limit (1000 frames) is easily exceeded by deep designs.
    while checking_stack:
        cur_item = checking_stack[-1]
        if cur_item.arg_num == -1:
            # first time testing this item
            if cur_item.dst_w not in wires_left:
                dead_end()
                continue
            current_wires.add(cur_item.dst_w)
            cur_item.net = dest_nets[cur_item.dst_w]
            if cur_item.net.op == 'r':
                dead_end()
                continue

        cur_item.arg_num += 1  # go to the next item
        if cur_item.arg_num == len(cur_item.net.args):
            dead_end()
            continue
        next_wire = cur_item.net.args[cur_item.arg_num]
        if next_wire not in current_wires:
            current_wires.add(next_wire)
            checking_stack.append(_FilteringState(next_wire))
        else:  # We have found the loop!!!!!
            loop_info = []
            for f_state in reversed(checking_stack):
                loop_info.append(f_state)
                if f_state.dst_w is next_wire:
                    break
            else:
                raise PyrtlError("Shouldn't get here! Couldn't figure out the loop")
            return loop_info
    raise PyrtlError("Error in detecting loop")


def find_and_print_loop(block=None):
    """ Find a combinational loop in `block`, print it, and return it.

    :param Block block: block to search (defaults to the working block)
    :return: the loop data returned by :func:`find_loop` (None if no loop)
    """
    loop_data = find_loop(block)
    print_loop(loop_data)
    return loop_data


def print_loop(loop_data):
    """ Print the nets of a loop as returned by :func:`find_loop`.

    :param loop_data: result of :func:`find_loop`; a falsy value prints
        "No Loop Found"
    """
    if not loop_data:
        print("No Loop Found")
    else:
        print("Loop found:")
        print('\n'.join("{}".format(fs.net) for fs in loop_data))
        print("")


def _currently_in_jupyter_notebook():
    """ Return true if running under Jupyter notebook, otherwise return False.

    We want to check for more than just the presence of __IPYTHON__ because
    that is present in both Jupyter notebooks and IPython terminals.
    """
    try:
        # get_ipython() is in the global namespace when ipython is started
        shell = get_ipython().__class__.__name__  # noqa: F821
    except NameError:
        return False  # Probably standard Python interpreter
    # 'ZMQInteractiveShell' is a Jupyter notebook or qtconsole; terminal
    # IPython ('TerminalInteractiveShell') and any other shell are not.
    return shell == 'ZMQInteractiveShell'


def _print_netlist_latex(netlist):
    """ Print each net in netlist in a Latex array

    Requires IPython, whose ``display``/``Latex`` helpers render the output.
    """
    from IPython.display import display, Latex  # pylint: disable=import-error
    out = '\n\\begin{array}{ \\| c \\| c \\| l \\| }\n'
    out += '\n\\hline\n'
    out += '\\hline\n'.join(str(n) for n in netlist)
    out += '\\hline\n\\end{array}\n'
    display(Latex(out))


class _NetCount(object):
    """ Helper class to track when to stop an iteration that depends on number of nets

    Mainly useful for iterations that are for optimization
    """
    def __init__(self, block=None):
        self.block = working_block(block)
        # Deliberately inflated starting count so that, with the default
        # thresholds, the first call to shrank() reports a reduction for any
        # non-empty block.
        self.prev_nets = len(self.block.logic) * 1000

    def shrank(self, block=None, percent_diff=0, abs_diff=1):
        """ Returns whether a block has fewer nets than before

        :param Block block: block to check (if changed)
        :param Number percent_diff: percentage difference threshold
        :param int abs_diff: absolute difference threshold
        :return: boolean

        This function checks whether the number of nets dropped to at most
        ``prev * (1 - percent_diff) - abs_diff``, i.e. by at least the
        percentage and absolute difference thresholds combined. The current
        count is then remembered for the next call.
        """
        if block is None:
            block = self.block
        cur_nets = len(block.logic)
        net_goal = self.prev_nets * (1 - percent_diff) - abs_diff
        less_nets = (cur_nets <= net_goal)
        self.prev_nets = cur_nets
        return less_nets

    shrinking = shrank


# _ComponentMeta holds the component's name, bitwidth, and type. If the
# _ComponentMeta's type is None, then the default component_type should be used
# instead.
_ComponentMeta = collections.namedtuple('_ComponentMeta', ['name', 'bitwidth', 'type'])


def _make_component(component_meta: _ComponentMeta, block: Block, name: str,
                    component_type, component_value):
    '''Determine the component's type, instantiate it, and set its value.'''
    # Determine the component's actual type.
    #
    # If the _ComponentMeta specifies a type, then the component is a
    # wire_struct or a wire_matrix. The _ComponentMeta's type must be used as
    # the component's primary type, and the default_component_type becomes the
    # component's concatenated_type.
    #
    # If the _ComponentMeta does not specify a type, the component uses the
    # default component_type.
    if component_meta.type is None:
        actual_component_type = component_type
    else:
        actual_component_type = component_meta.type

    component_name = ''
    if name:
        if isinstance(component_meta.name, str):
            # wire_struct components are named with strings and printed with
            # dots, like `struct.component`.
            component_name = f'{name}.{component_meta.name}'
        else:
            # wire_matrix components are numbered with integers and printed
            # with brackets, like `matrix[0]`.
            component_name = f'{name}[{component_meta.name}]'

    # The logic below always creates a new wire_struct, wire_matrix, or
    # WireVector for each component. If the component_value already has the
    # appropriate type and name, we could use the component_value directly and
    # we don't need a new struct/matrix/Vector. Correctly detecting these
    # opportunities is complicated, so we keep things simple for now.
    #
    # Components are always initialized with one concatenated component_value,
    # which provides values for all its wires. This implies that component
    # wire_structs and wire_matrices always call _slice().
    if hasattr(actual_component_type, '_is_wire_struct'):
        # Make a wire_struct component. component_value may be None.
        component_kwargs = {
            actual_component_type._class_name: component_value
        }
        component = actual_component_type(
            name=component_name, block=block, concatenated_type=component_type,
            **component_kwargs)
    elif hasattr(actual_component_type, '_is_wire_matrix'):
        # Make a wire_matrix component. component_value may be None.
        component = actual_component_type(
            name=component_name, block=block, concatenated_type=component_type,
            values=[component_value])
    elif (isinstance(component_value, int)
          and actual_component_type is WireVector):
        # Special case: simplify the component type to Const.
        component = Const(
            bitwidth=component_meta.bitwidth, name=component_name, block=block,
            val=component_value)
    else:
        # Make a WireVector component.
        component = actual_component_type(
            bitwidth=component_meta.bitwidth, name=component_name, block=block)
        if component_value is not None:
            component <<= component_value
    return component


def _bitslice(value: int, start: int, end: int) -> int:
    '''Slice an integer value bitwise, from start (inclusive) to end (exclusive).'''
    mask = (1 << (end - start)) - 1
    return (value >> start) & mask


def _slice(block: Block, schema: list[_ComponentMeta], bitwidth: int,
           component_type: type, name: str, concatenated, components,
           concatenated_value):
    '''Slice ``concatenated`` into components.

    ``concatenated_value`` is the driver for ``concatenated``. Some
    optimizations are possible by inspecting ``concatenated_value``, for
    example we immediately slice Consts rather than generating slicing logic.

    Components are created in ``schema`` order, starting from the most
    significant bits, and stored into ``components`` keyed by component name.
    '''
    if concatenated_value is not None and not isinstance(concatenated, Const):
        concatenated <<= concatenated_value

    end_index = bitwidth
    for component_meta in schema:
        start_index = end_index - component_meta.bitwidth
        if isinstance(concatenated_value, int):
            # Special case: immediately slice Const values.
            component_value = _bitslice(concatenated_value, start_index, end_index)
        else:
            component_value = concatenated[start_index:end_index]
        end_index = start_index

        components[component_meta.name] = _make_component(
            component_meta=component_meta, block=block, name=name,
            component_type=component_type, component_value=component_value)


def _concatenate(block: Block, schema: list[_ComponentMeta],
                 component_type: type, name: str, concatenated, components,
                 component_map):
    '''Concatenate components from ``component_map`` to ``concatenated``.

    The first component in ``schema`` becomes the most significant bits.
    Created components are stored into ``components`` keyed by component name.
    '''
    all_components = []
    for component_meta in schema:
        component = _make_component(
            component_meta=component_meta, block=block, name=name,
            component_type=component_type,
            component_value=component_map[component_meta.name])
        components[component_meta.name] = component
        all_components.append(component)
    concatenated <<= concat(*all_components)
def wire_struct(wire_struct_spec):
    '''Decorator that assigns names to ``WireVector`` slices.

    ``@wire_struct`` assigns names to *non-overlapping* ``WireVector`` slices.
    Suppose we have an 8-bit wide ``WireVector`` called ``byte``. We can refer
    to all 8 bits with the name ``byte``, but ``@wire_struct`` lets us refer to
    slices by name, for example we could name the high 4 bits ``byte.high``
    and the low 4 bits ``byte.low``. Without ``@wire_struct``, we would refer
    to these slices as ``byte[4:8]`` and ``byte[0:4]``, which are prone to
    off-by-one errors and harder to read.

    The example ``Byte`` ``@wire_struct`` can be defined as::

        @wire_struct
        class Byte:
            high: 4  # 'high' is name for the 4 most significant bits.
            low: 4   # 'low' is name for the 4 least significant bits.

    Each annotation must be either an ``int`` bitwidth or another
    ``@wire_struct`` / ``wire_matrix`` class; anything else raises
    ``PyrtlError``.

    ------------
    Construction
    ------------

    Once a ``@wire_struct`` class is defined, it can be instantiated by
    providing drivers for all of its wires. This can be done in two ways:

    1. Provide a driver for *each* component wire, for example::

           byte = Byte(high=0xA, low=0xB)

       Note how the component names (``high``, ``low``) are used as keyword
       args for the constructor. Drivers must be provided for *all*
       components.

    2. Provide a driver for the entire ``@wire_struct``, for example::

           byte = Byte(Byte=0xAB)

       Note how the class name (``Byte``) is used as a keyword arg for the
       constructor.

    ----------------
    Accessing Slices
    ----------------

    After instantiating a ``@wire_struct``, the instance functions as a
    ``WireVector`` containing all the wires. For example, ``byte`` functions
    as a ``WireVector`` with bitwidth 8::

        byte = Byte(Byte=0xAB)
        print(byte.bitwidth)  # Prints 8.

    The named slice can be accessed through the ``.`` operator
    (``__getattr__``), for example ``byte.high`` and ``byte.low``, which both
    function as ``WireVector`` with bitwidth 4::

        byte = Byte(Byte=0xAB)
        print(byte.high.bitwidth)  # Prints 4.
        print(byte.low.bitwidth)   # Prints 4.

    Both the instance and the slices are first-class ``WireVector``, so they
    can be manipulated with all the usual PyRTL operators.

    .. NOTE::
        ``len(byte)`` returns the number of components in the ``@wire_struct``
        (2), not the total bitwidth (8 == 4 + 4). To get the total bitwidth,
        use ``byte.bitwidth`` or ``len(as_wires(byte))``.

    ------
    Naming
    ------

    A ``@wire_struct`` can be assigned a name in the usual way::

        byte = Byte(name='b', high=0xC, low=0xD)
        byte = Byte(name='b', Byte=0xCD)

    When a ``@wire_struct`` is assigned a name (``b``), its components will be
    assigned dotted names (``b.high``, ``b.low``)::

        print(byte.high.name)  # Prints 'b.high'.
        print(byte.low.name)   # Prints 'b.low'.

    .. WARNING::
        All ``@wire_struct`` names are only set during construction. You can
        later rename a ``@wire_struct`` or its components, but those changes
        are local, and will not propagate to other ``@wire_struct``
        components. Renaming a ``@wire_struct`` or its components is strongly
        discouraged.

    -----------
    Composition
    -----------

    ``@wire_struct`` can be composed with itself, and with ``wire_matrix``.
    For example, we can define a ``Pixel`` that contains three ``Byte``::

        @wire_struct
        class Pixel:
            red: Byte
            green: Byte
            blue: Byte

    Drivers must be specified for all components, but they can be specified
    at any level. All these examples construct an equivalent
    ``@wire_struct``::

        pixel = Pixel(Pixel=0xABCDEF)
        pixel = Pixel(red=0xAB, green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB), green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB),
                      green=Byte(high=0xC, low=0xD),
                      blue=0xEF)

    Hierarchical ``@wire_struct`` components are accessed by composing ``.``
    operators::

        pixel
        pixel.red
        pixel.red.high
        pixel.red.low
        pixel.green
        pixel.green.high
        pixel.green.low
        pixel.blue
        pixel.blue.high
        pixel.blue.low

    ``@wire_struct`` can be composed with ``wire_matrix``::

        Word = wire_matrix(component_schema=8, size=4)

        @wire_struct
        class CacheLine:
            address: Word
            data: Word
            valid: 1

        cache_line = CacheLine(address=0x01234567, data=0x89ABCDEF, valid=1)

    Leaf-level components can be accessed by combining the ``.`` and ``[]``
    operators, for example ``cache_line.address[3]``.

    -----
    Types
    -----

    You can change the type of a ``@wire_struct``'s components to a
    ``WireVector`` subclass like :py:class:`Input` or :py:class:`Output` with
    the ``component_type`` constructor argument::

        # Generates Outputs named ``output_byte.low`` and ``output_byte.high``.
        output_byte = Byte(name='output_byte', component_type=pyrtl.Output,
                           Byte=0xCD)

    You can also change the type of the ``@wire_struct`` itself with the
    ``concatenated_type`` constructor argument::

        # Generates an Input named ``input_byte``.
        input_byte = Byte(name='input_byte', concatenated_type=pyrtl.Input)

    .. NOTE::
        No values are specified for ``input_byte`` because its value is not
        known until simulation time.

    :param wire_struct_spec: the decorated class; its annotations define the
        component names (in order, most significant first) and their widths
        or nested types.
    :return: a new ``WrappedWireVector`` subclass implementing the struct.
    :raises PyrtlError: if an annotation is not an ``int`` bitwidth or a
        ``@wire_struct`` / ``wire_matrix`` class.
    '''
    # Convert the decorated class' annotations (dict of attr_name: attr_value)
    # to a list of _ComponentMetas.
    #
    # dict iteration order is guaranteed to be insertion order in Python 3.7+.
    schema = []
    for component_name, component_bitwidth in (
            wire_struct_spec.__annotations__.items()):
        if isinstance(component_bitwidth, int):
            # An ordinary component ("foo: 4") that should use the default
            # component_type.
            schema.append(_ComponentMeta(
                name=component_name, bitwidth=component_bitwidth, type=None))
        elif hasattr(component_bitwidth, '_bitwidth'):
            # A nested component ("bar: Byte") that must use the nested
            # component's type.
            schema.append(_ComponentMeta(
                name=component_name, bitwidth=component_bitwidth._bitwidth,
                type=component_bitwidth))
        else:
            # Anything else would otherwise fail with an opaque AttributeError.
            # A common cause is ``from __future__ import annotations`` in the
            # module defining the struct, which turns annotations into strings.
            raise PyrtlError(
                f'@wire_struct component "{component_name}" must be annotated '
                'with an int bitwidth, a @wire_struct, or a wire_matrix (got '
                f'{component_bitwidth!r})')

    total_bitwidth = sum(component.bitwidth for component in schema)

    # Name of the decorated class.
    class_name = wire_struct_spec.__name__

    class _WireStruct(WrappedWireVector):
        '''``wire_struct`` implementation: Concatenate or slice ``WireVector``.

        ``wire_struct`` works by either concatenating component ``WireVector``
        to create the ``wire_struct``'s full value, *or* slicing a
        ``wire_struct``'s value to create component ``WireVectors``. A
        ``wire_struct`` can only concatenate or slice, not both. The decision
        to concatenate or slice is made in __init__.
        '''

        _bitwidth = total_bitwidth
        _class_name = class_name
        _is_wire_struct = True

        def __init__(self, name='', block=None, concatenated_type=WireVector,
                     component_type=WireVector, **kwargs):
            '''Concatenate or slice ``WireVector`` components.

            :param str name: The name of the concatenated wire. Must be unique.
                If none is provided, one will be autogenerated. If a name is
                provided, components will be assigned names of the form
                "{name}.{component_name}".
            :param Block block: The block containing the concatenated and
                component wires. Defaults to the working block.
            :param type concatenated_type: Type for the concatenated
                ``WireVector``.
            :param type component_type: Type for each component.
            :raises PyrtlError: if kwargs mix the concatenated value with
                component values, omit a component, or name an unknown one.

            The remaining keyword args specify values for all wires. If the
            concatenated value is provided, its value must be provided with
            the keyword arg matching the decorated class name. For example, if
            the decorated class is::

                @wire_struct
                class Byte:
                    high: 4  # high is the 4 most significant bits.
                    low: 4   # low is the 4 least significant bits.

            then the concatenated value must be provided like this::

                byte = Byte(Byte=0xAB)

            And if the component values are provided instead, their values are
            set by keyword args matching the component names::

                byte = Byte(low=0xA, high=0xB)
            '''
            # The concatenated WireVector contains all the _WireStruct's wires.
            # WrappedWireVector (base class) will forward all attribute and
            # method accesses on this _WireStruct to the concatenated
            # WireVector.
            if ((class_name in kwargs
                 and isinstance(kwargs[class_name], int)
                 and concatenated_type is WireVector)):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth, name=name, block=block,
                    val=kwargs[class_name])
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block)
            super().__init__(wire=concatenated)

            # self._components maps from component name to each component's
            # WireVector.
            components = {}
            self.__dict__['_components'] = components

            # Handle Input and Register special cases.
            if concatenated_type is Input or concatenated_type is Register:
                kwargs = {class_name: None}
            elif component_type is Input or component_type is Register:
                kwargs = {component_meta.name: None
                          for component_meta in schema}

            if class_name in kwargs:
                # Check for unused kwargs.
                for component_name in kwargs:
                    if component_name != class_name:
                        raise PyrtlError(
                            'Do not pass additional kwargs to @wire_struct '
                            f'when slicing. ("{class_name}" was passed so '
                            f'don\'t pass "{component_name}")')
                # Concatenated value was provided. Slice it into components.
                _slice(block=block, schema=schema, bitwidth=self._bitwidth,
                       component_type=component_type, name=name,
                       concatenated=concatenated, components=components,
                       concatenated_value=kwargs[class_name])
            else:
                # Component values were provided; concatenate them.
                # Check that values were provided for all components.
                expected_component_names = (
                    [component_meta.name for component_meta in schema])
                for expected_component_name in expected_component_names:
                    if expected_component_name not in kwargs:
                        raise PyrtlError(
                            'You must provide kwargs for all @wire_struct '
                            'components when concatenating (missing kwarg '
                            f'"{expected_component_name}")')
                # Check for unused kwargs.
                for component_name in kwargs:
                    if component_name not in expected_component_names:
                        raise PyrtlError(
                            'Do not pass additional kwargs to @wire_struct '
                            'when concatenating (don\'t pass '
                            f'"{component_name}")')

                _concatenate(block=block, schema=schema,
                             component_type=component_type, name=name,
                             concatenated=concatenated, components=components,
                             component_map=kwargs)

        def __getattr__(self, component_name: str):
            '''Retrieve a component by name.

            Components are concatenated to form the concatenated
            ``WireVector``, or sliced from the concatenated ``WireVector``.
            Names that are not components are forwarded to the base class.

            :param component_name: The name of the component wire.
            '''
            components = self.__dict__['_components']
            if component_name in components:
                return components[component_name]
            return super().__getattr__(component_name)

        def __len__(self):
            '''Return the number of components (not the total bitwidth).'''
            components = self.__dict__['_components']
            return len(components)

    return _WireStruct
def wire_matrix(component_schema, size: int):
    '''Returns a class that assigns numbered indices to ``WireVector`` slices.

    ``wire_matrix`` assigns numbered indices to *non-overlapping*
    ``WireVector`` slices. ``wire_matrix`` is very similar to
    :py:func:`wire_struct`, so read :py:func:`wire_struct`'s documentation
    first.

    An example 32-bit ``Word`` ``wire_matrix``, which represents a group of
    four bytes, can be defined as::

        Word = wire_matrix(component_schema=8, size=4)

    .. NOTE::
        ``wire_matrix`` returns a class, like ``namedtuple``.

    ------------
    Construction
    ------------

    Once a ``wire_matrix`` class is defined, it can be instantiated by
    providing drivers for all of its wires. This can be done in two ways::

        # Provide a driver for each component, most significant bits first.
        word = Word(values=[0x89, 0xAB, 0xCD, 0xEF])

        # Provide a driver for all components.
        word = Word(values=[0x89ABCDEF])

    .. NOTE::
        When specifying drivers for each component, the most significant bits
        are specified first.

    After instantiating a ``wire_matrix``, regardless of how it was
    constructed, the instance functions as a ``WireVector`` containing all the
    wires, so ``word`` functions as a ``WireVector`` with bitwidth 32. The
    named slice can be accessed with square brackets (``__getitem__``), for
    example ``word[0]`` and ``word[3]``, which both function as ``WireVector``
    with bitwidth 8. ``word[0]`` refers to the most significant byte, and
    ``word[3]`` refers to the least significant byte. Both the instance and
    the slices are first-class ``WireVector``, so they can be manipulated with
    all the usual PyRTL operators.

    ------
    Naming
    ------

    A ``wire_matrix`` can be assigned a name in the usual way::

        # The whole Word is named 'w', so the components will have names
        # w[0], w[1], ...
        word = Word(name='w', values=[0x89, 0xAB, 0xCD, 0xEF])
        word = Word(name='w', values=[0x89ABCDEF])

    -----------
    Composition
    -----------

    ``wire_matrix`` can be composed with itself and ``@wire_struct``. For
    example, we can define some multi-dimensional byte arrays::

        Array1D = wire_matrix(component_schema=8, size=2)
        Array2D = wire_matrix(component_schema=Array1D, size=2)

    Drivers must be specified for all components, but they can be specified
    at any level. All these examples construct an equivalent
    ``wire_matrix``::

        array_2d = Array2D(values=[0x89AB, 0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]), 0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]),
                                   Array1D(values=[0xCD, 0xEF])])

    ----------------
    Accessing Slices
    ----------------

    Hierarchical components are accessed by composing ``[]`` operators, for
    example::

        print(array_2d[0][0].bitwidth)  # Prints 8.
        print(array_2d[0][1].bitwidth)  # Prints 8.

    When ``wire_matrix`` is composed with ``@wire_struct``, components can be
    accessed by combining the ``[]`` and ``.`` operators::

        @wire_struct
        class Byte:
            high: 4
            low: 4

        Array1D = wire_matrix(component_schema=Byte, size=2)

        array_1d = Array1D(values=[0xAB, 0xCD])
        print(array_1d[0].high.bitwidth)  # Prints 4.

    .. NOTE::
        ``len(array_1d)`` returns the number of components in the
        ``wire_matrix`` (2), not the total bitwidth (16 == 2 * 8). To get the
        total bitwidth, use ``array_1d.bitwidth`` or
        ``len(as_wires(array_1d))``.

    -----
    Types
    -----

    You can change the type of a ``wire_matrix``'s components with the
    ``component_type`` constructor argument::

        # Generates Outputs named ``output_word[0]``, ``output_word[1]``, ...
        word = Word(name='output_word', component_type=pyrtl.Output,
                    values=[0x89ABCDEF])

    You can change the type of the ``wire_matrix`` itself with the
    ``concatenated_type`` constructor argument::

        # Generates an Input named ``input_word``.
        word = Word(name='input_word', concatenated_type=pyrtl.Input)

    .. NOTE::
        No values are specified for ``input_word`` because its value is not
        known until simulation time.

    :param component_schema: either an ``int`` bitwidth for each component,
        or a ``@wire_struct`` / ``wire_matrix`` class used for each component.
    :param int size: number of components.
    :return: a new ``WrappedWireVector`` subclass implementing the matrix.
    '''
    # Determine each component's bitwidth.
    if ((hasattr(component_schema, '_is_wire_struct')
         or hasattr(component_schema, '_is_wire_matrix'))):
        component_bitwidth = component_schema._bitwidth
    else:
        component_bitwidth = component_schema
        component_schema = None

    class _WireMatrix(WrappedWireVector):
        '''``wire_matrix`` implementation: concatenate or slice components.'''

        _component_bitwidth = component_bitwidth
        _component_schema = component_schema
        _size = size
        _bitwidth = component_bitwidth * size
        _is_wire_matrix = True

        def __init__(self, name: str = '', block: Block = None,
                     concatenated_type=WireVector, component_type=WireVector,
                     values: list | None = None):
            '''Concatenate or slice ``WireVector`` components.

            :param str name: name of the concatenated wire; components are
                named "{name}[{index}]" when a name is given.
            :param Block block: block for all wires (defaults to the working
                block).
            :param type concatenated_type: type for the concatenated wire.
            :param type component_type: type for each component.
            :param list values: either one driver for the whole matrix, or
                one driver per component (most significant first). May be
                omitted for the Input/Register special cases.
            :raises PyrtlError: if ``values`` has neither 1 nor ``size``
                entries.
            '''
            # None (not a shared [] default) avoids the mutable-default pitfall.
            if values is None:
                values = []

            # The concatenated WireVector contains all the _WireMatrix's wires.
            # WrappedWireVector (base class) will forward all attribute and
            # method accesses on this _WireMatrix to the concatenated
            # WireVector.
            if ((len(values) == 1
                 and isinstance(values[0], int)
                 and concatenated_type is WireVector)):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth, name=name, block=block,
                    val=values[0])
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block)
            super().__init__(wire=concatenated)

            schema = [
                _ComponentMeta(name=component_name,
                               bitwidth=self._component_bitwidth,
                               type=component_schema)
                for component_name in range(self._size)]

            # Handle Input and Register special cases.
            if concatenated_type is Input or concatenated_type is Register:
                values = [None]
            elif component_type is Input or component_type is Register:
                values = [None] * self._size

            # Store directly in the instance dict, as wire_struct does, so the
            # list lives on this wrapper rather than possibly being forwarded
            # to the wrapped WireVector by attribute assignment.
            components = [None] * len(schema)
            self.__dict__['_components'] = components

            if len(values) == 1:
                # Concatenated value was provided. Slice it into components.
                _slice(block=block, schema=schema, bitwidth=self._bitwidth,
                       component_type=component_type, name=name,
                       concatenated=concatenated, components=components,
                       concatenated_value=values[0])
            else:
                if len(values) != len(schema):
                    raise PyrtlError(
                        'wire_matrix constructor expects 1 value to slice, or '
                        f'{len(schema)} values to concatenate (received '
                        f'{len(values)} values)')
                # Component values were provided; concatenate them.
                _concatenate(block=block, schema=schema,
                             component_type=component_type, name=name,
                             concatenated=concatenated, components=components,
                             component_map=values)

        def __getitem__(self, key):
            '''Return the component at ``key`` (0 is most significant).'''
            return self._components[key]

        def __len__(self):
            '''Return the number of components (not the total bitwidth).'''
            return len(self._components)

    return _WireMatrix