from spidr4.bitfield import BitField, Field
from spidr4 import rpc
from typing import Optional
from numpy.typing import ArrayLike
import numpy as np
# Pixel matrix dimensions: number of columns (x) and rows (y)
WIDTH = 448
HEIGHT = 512
# Packet counts for frame-based readout. Each chip half holds WIDTH * HEIGHT // 2 pixels.
# NOTE(review): the extra 20 packets per half presumably cover the control packets
# (frame/segment start and end markers) — confirm against the readout specification.
PACKET_COUNT_FB8_HALF = (WIDTH * HEIGHT // 2) // 8 + 20
PACKET_COUNT_FB16_HALF = (WIDTH * HEIGHT // 2) // 4 + 20
# Number of 64-bit data packets in one of the 8 segments of an 8-bit frame (8 pixels per packet)
PACKETS_COUNT_FB_SEGMENT = (WIDTH * HEIGHT // 2) // (8 * 8)
# Header codes of control/status packets (the 8 bits starting at bit 55 of a packet)
CS_HEARTBEAT = 0xE0
CS_SHUTTER_RISE = 0xE1
CS_SHUTTER_FALL = 0xE2
CS_T0_SYNC = 0xE3
CS_SIGNAL_RISE = 0xE4
CS_SIGNAL_FALL = 0xE5
CS_CTRL_DATA_TEST = 0xEA
CS_FRAME_START = 0xF0
CS_FRAME_END = 0xF1
CS_SEGMENT_START = 0xF2
CS_SEGMENT_END = 0xF3
CS_IDLE = 0xFF

# Header code -> printable name (the constant name without its CS_ prefix)
_cslookup = {
    code: name
    for name, code in (
        ('HEARTBEAT', CS_HEARTBEAT),
        ('SHUTTER_RISE', CS_SHUTTER_RISE),
        ('SHUTTER_FALL', CS_SHUTTER_FALL),
        ('T0_SYNC', CS_T0_SYNC),
        ('SIGNAL_RISE', CS_SIGNAL_RISE),
        ('SIGNAL_FALL', CS_SIGNAL_FALL),
        ('CTRL_DATA_TEST', CS_CTRL_DATA_TEST),
        ('FRAME_START', CS_FRAME_START),
        ('FRAME_END', CS_FRAME_END),
        ('SEGMENT_START', CS_SEGMENT_START),
        ('SEGMENT_END', CS_SEGMENT_END),
        ('IDLE', CS_IDLE),
    )
}
[docs]
def cs_lookup(header):
    """
    Translate a control/status header code into a printable name.

    :param header:  Header code, e.g. one of the ``CS_*`` constants
    :return:        The name registered in ``_cslookup`` for this code, or
                    ``"Unknown<header>"`` when the code is not in the table
    """
    # The fallback used to read "Uknown<{header}" (misspelled, missing the closing bracket)
    return _cslookup.get(header, f"Unknown<{header}>")
[docs]
class ToAToTPacket(BitField):
    """
    Bit layout of a 64-bit data-driven pixel packet carrying ToA/ToT information.

    This is the type returned by ``decode_dd_packet()`` for non-control packets
    when ``pc24`` is False.
    """
    # NOTE(review): Field arguments appear to be (width in bits, position of the least
    # significant bit); this matches decode_dd_packet(), which reads the 8 bits at
    # bit 55 as the header. Confirm against spidr4.bitfield.

    # Pixel address, split into its components
    Top = Field(1, 63)
    EoC = Field(8, 55)
    SPGroup = Field(4, 51)
    SPixel = Field(2, 49)
    Pixel = Field(3, 46)
    # Under the (width, lsb) reading this spans bits 46..63, i.e. all five address
    # components above read as one combined value
    addr = Field(18, 46)
    # Payload
    ToA = Field(16, 30)
    ufToA_start = Field(4, 26)
    ufToA_stop = Field(4, 22)
    fToA_rise = Field(5, 17)
    fToA_fall = Field(5, 12)
    ToT = Field(11, 1)
    Pileup = Field(1, 0)
[docs]
class PC24Packet(BitField):
    """
    Bit layout of a 64-bit data-driven pixel packet carrying a 24-bit event count.

    This is the type returned by ``decode_dd_packet()`` for non-control packets
    when ``pc24`` is True.
    """
    # NOTE(review): Field arguments appear to be (width in bits, position of the least
    # significant bit) — confirm against spidr4.bitfield.

    # Pixel address components (same positions as in ToAToTPacket)
    Top = Field(1, 63)
    EoC = Field(8, 55)
    SPGroup = Field(4, 51)
    SPixel = Field(2, 49)
    Pixel = Field(3, 46)
    # Payload: the counter value in the low 24 bits
    EventCount = Field(24, 0)
[docs]
class ControlStatusPacket(BitField):
    """
    Bit layout of a 64-bit control/status packet.

    ``decode_dd_packet()`` returns this type when the 8 bits starting at bit 55
    are 0xE0 or higher; the ``CS_*`` constants name the known header values.
    """
    # NOTE(review): Field arguments appear to be (width in bits, position of the least
    # significant bit) — confirm against spidr4.bitfield.
    top = Field(1, 63)
    # Header code, compared against the CS_* constants
    header = Field(8, 55)
    segment = Field(3, 52)
    # Remaining low 52 bits; meaning presumably depends on the header — TODO confirm
    data = Field(52, 0)
[docs]
class PixelConfig(BitField):
    """
    Bit layout of an 8-bit pixel configuration value.
    """
    # NOTE(review): Field arguments appear to be (width in bits, position of the least
    # significant bit) — confirm against spidr4.bitfield. Presumably one such byte is
    # stored per pixel in the (512, 448) configuration matrix — verify against callers.
    mask = Field(1, 7)
    tp_enable = Field(1, 6)
    power_enable = Field(1, 5)
    # 5-bit DAC value in the low bits
    dac = Field(5, 0)
[docs]
class SPGroupConfig(BitField):
    """
    Bit layout of a 24-bit super pixel group configuration value.

    Holds one mask bit and one 4-bit ``vco_adj`` value for each of the four
    super pixels (0..3), plus the bypass and digital-pixel-enable flags.
    """
    # NOTE(review): Field arguments appear to be (width in bits, position of the least
    # significant bit) — confirm against spidr4.bitfield. Under that reading bit 18 is
    # not covered by any field below; verify whether that gap is intentional.
    mask_sp_3 = Field(1, 23)
    vco_adj3 = Field(4, 19)
    mask_sp_2 = Field(1, 17)
    vco_adj2 = Field(4, 13)
    bypass_up = Field(1, 12)
    mask_sp_1 = Field(1, 11)
    vco_adj1 = Field(4, 7)
    bypass_down = Field(1, 6)
    mask_sp_0 = Field(1, 5)
    vco_adj0 = Field(4, 1)
    digital_pixel_enable = Field(1, 0)
# Convert logical X, Y coordinates to chip coordinates
# (half, eoc, spgroup, spixel and pixel coordinates)
#
# The origin (x=0, y=0) is located at half=0 (bottom), eoc=0, spgroup=0, spixel=0, pixel=0,
# while the opposite side (x=447, y=511) is located at half=1 (top), eoc=0, spgroup=0, spixel=0, pixel=0
# (x=0, y=511) is located at half=1 (top), eoc=223, spgroup=0, spixel=0, pixel=4
#
# (0, 511) +--------+ (447, 511)
#          |  top   |
#          +--------+
#          | bottom |
#   (0, 0) +--------+ (447, 0)
#
[docs]
def chip_coords2logic(top, eoc, spgroup, spixel, pixel, is_config=False):
    """
    Convert chip coordinates to a logical (x, y) position.

    :param top:             Half, 0 - bottom, 1 - top
    :param eoc:             End of column 0..223
    :param spgroup:         Super pixel group 0..15
    :param spixel:          Super pixel 0..3
    :param pixel:           Pixel 0..7
    :param is_config:       Whether or not the coordinates are for a configuration.
                            If set, the super pixel group order is mirrored and, in the
                            top half, the end-of-column order as well.
    :return:                Tuple (x, y)
    """
    if is_config:
        spgroup = 15 - spgroup
        if top:
            eoc = 223 - eoc

    # One end-of-column is two pixels wide: pixels 0..3 form the first column, 4..7 the second
    col = 2 * eoc + pixel // 4
    row = 16 * spgroup + 4 * spixel + pixel % 4

    # The top half is rotated by 180 degrees with respect to the bottom half
    return (447 - col, 511 - row) if top else (col, row)
[docs]
def logic2chip_coords(x, y, is_config=False):
    """
    Convert a logical (x, y) position to chip coordinates.

    :param x:               Logical column 0..447
    :param y:               Logical row 0..511
    :param is_config:       Whether or not the coordinates are for a configuration.
                            If set, the super pixel group order is mirrored and, in the
                            top half, the end-of-column order as well.
    :return:                Tuple (top, eoc, spgroup, spixel, pixel)
    """
    top = y // 256
    if top:
        # The top half is rotated by 180 degrees with respect to the bottom half
        x, y = 447 - x, 511 - y

    eoc, column_in_eoc = divmod(x, 2)
    spgroup, row_in_group = divmod(y, 16)
    spixel, row_in_spixel = divmod(row_in_group, 4)
    # Pixels 0..3 are the first column of the end-of-column pair, 4..7 the second
    pixel = row_in_spixel + 4 * column_in_eoc

    if is_config:
        spgroup = 15 - spgroup
        if top:
            eoc = 223 - eoc
    return top, eoc, spgroup, spixel, pixel
[docs]
def chip_coords2logic_idx(half, eoc, spgroup, spixel, pixel, is_config=False):
    """
    Convert chip coordinates to an index into the flattened logical matrix.

    Takes the same parameters as ``chip_coords2logic()``.

    :return:    Row-major index ``x + y * 448`` of the logical position
    """
    col, row = chip_coords2logic(half, eoc, spgroup, spixel, pixel, is_config)
    return row * 448 + col
[docs]
def logic2chip_coords_idx(x, y, is_config=False, blob=False):
    """
    Convert a logical (x, y) position to an index in chip ordering.

    :param x:               Logical column
    :param y:               Logical row
    :param is_config:       Passed on to ``logic2chip_coords()``
    :param blob:            If set, the top half (half=1) is placed first, otherwise
                            the bottom half (half=0) is placed first
    :return:                Offset of the pixel in chip ordering
    :raises RuntimeError:   If the computed offset is 512 * 448 or larger
    """
    half, eoc, spgroup, spixel, pixel = logic2chip_coords(x, y, is_config)
    half_slot = (1 - half) if blob else half
    # Nesting from outer to inner: half, end of column, super pixel group, super pixel, pixel
    offset = (
        half_slot * 256 * 448
        + eoc * 4 * 8 * 16
        + spgroup * 4 * 8
        + spixel * 8
        + pixel
    )
    if offset >= 512 * 448:
        culprit = (half, eoc, spgroup, spixel, pixel, x, y)
        raise RuntimeError(",".join(str(value) for value in culprit))
    return offset
# Look-up table: position in the chip configuration blob -> flat index into the logical matrix.
# The top half (half=1) comes first in the blob, followed by the bottom half (half=0).
_CHIP2LOGIC_CFG_IDX = [
    chip_coords2logic_idx(half, eoc, spgroup, spixel, pixel, is_config=True)
    for half in (1, 0)
    for eoc in range(224)
    for spgroup in range(16)
    for spixel in range(4)
    for pixel in range(8)
]
# Inverse look-up table: flat logical index (x + y * 448) -> position in the chip configuration blob
_LOGIC2CHIP_CONFIG_IDX = [
    logic2chip_coords_idx(flat_idx % 448, flat_idx // 448, is_config=True, blob=True)
    for flat_idx in range(512 * 448)
]
[docs]
class PartialPixelUpdater:
    """
    Convenience class for (partially) updating a pixel matrix.

    It works by initially programming the matrix with the provided pixel config.
    After this you can update the pixel matrix by calling the `update()` function, which will only
    send those columns (or super pixel groups) which have actually changed.
    """

    def __init__(self, tpx4: rpc.Timepix4Stub, initial_cfg: ArrayLike, idx: int = 0, spgaccess=False):
        """
        Program the complete pixel matrix and remember it as the current state.

        :param tpx4:            The gRPC service tpx4 stub
        :param initial_cfg:     Logical pixel configuration of shape (512, 448), [row, column] ordering
        :param idx:             Chip index sent along with every request
        :param spgaccess:       Super pixel group access. This mode may boost performance, but at
                                the cost of one bit of the pixel trim dacs at specific pixels.
                                So, don't use unless you know what you are doing.
        :raises ValueError:     If ``initial_cfg`` does not have shape (512, 448)
        """
        self._tpx4 = tpx4
        self._idx = idx
        # Granularity of a partial update: (half, column, bytes) per column, or
        # (half, column, super pixel group, bytes) with super pixel group access
        self._shape = (2, 224, 16, 32) if spgaccess else (2, 224, 512)
        # NOTE(review): the array is sent via tobytes(), so it presumably must be uint8 — confirm
        self._cur = np.array(initial_cfg)
        self._cur_blob = logic2chip_cfg_matrix(self._cur).reshape(self._shape)
        self._nxt = None
        tpx4.ConfigPixels(
            rpc.Tpx4PixelConfig(
                idx=self._idx,
                config=self._cur_blob.tobytes()
            )
        )

    @classmethod
    def to_pp(cls, h, c, s=0, data=None):
        """
        Build one part of a partial pixel configuration request.

        :param h:       Index of the half in the configuration blob
        :param c:       Column index
        :param s:       Super pixel group index (stays 0 for whole-column updates)
        :param data:    Array holding the configuration of this part; required despite the default
        :return:        An ``rpc.Tpx4PartialPixelConfig.Tpx4PixelPart``
        """
        return rpc.Tpx4PartialPixelConfig.Tpx4PixelPart(
            # Blob index 0 is the top half: the chip-order LUT iterates half=1 first
            half=rpc.TPX4_TOP if h == 0 else rpc.TPX4_BOTTOM,
            column=c,
            sp_group=s,
            config=data.tobytes())

    def _update(self, validate):
        """
        Send the pending configuration (``self._nxt``) to the chip, transmitting changed parts only.

        :param validate:        If set, print the number of changed entries, read the matrix back
                                afterwards and check it against the requested changes
        :raises RuntimeError:   If ``validate`` is set and the number of pixels that differ in the
                                read-back is not the number of pixels that were requested to change
        """
        nxt_blob = logic2chip_cfg_matrix(self._nxt).reshape(self._shape)
        changed = self._cur_blob != nxt_blob
        if validate:
            print(f"Actual changes: {len(np.argwhere(changed))}")
        # Collapse the innermost (byte) axis: one index tuple per part containing any change
        changes = np.argwhere(changed.any(axis=-1))
        ppc = rpc.Tpx4PartialPixelConfig(
            idx=self._idx,
            parts=[PartialPixelUpdater.to_pp(*x, data=nxt_blob[tuple(x)]) for x in changes])
        self._tpx4.ConfigPixelsPartial(ppc)
        if validate:
            # Must run before the state below is replaced: it compares against self._cur
            self._validate_readback()
        self._cur_blob = nxt_blob
        self._cur = self._nxt
        self._nxt = None

    def _validate_readback(self):
        """Read the pixel matrix back and verify the number of applied changes."""
        readback = np.frombuffer(
            self._tpx4.ConfigGetPixels(rpc.ChipIndex(idx=self._idx)).config,
            dtype=np.uint8)
        pretty = chip2logic_cfg_matrix(readback)
        # argwhere() yields one (row, column) pair per changed pixel; transposed this becomes
        # an array of rows and an array of columns, which is the form the plot needs
        expected_changes = np.argwhere(self._nxt != self._cur).transpose()
        applied_changes = np.argwhere(pretty != self._cur).transpose()
        # Compare the number of changed pixels. len() of the transposed arrays is the number
        # of dimensions (always 2), so comparing those lengths could never detect a mismatch.
        expected_count = expected_changes.shape[1]
        applied_count = applied_changes.shape[1]
        if expected_count == applied_count:
            return
        self._plot_changes(expected_changes, applied_changes)
        print(f"Expected changes count={expected_count}")
        print(f"Applied changes count={applied_count}")
        # Raise instead of sys.exit(0): a library must not end the process, let alone with a
        # success status, when validation fails
        raise RuntimeError(
            f"Pixel configuration validation failed: expected {expected_count} changed pixels, "
            f"read-back shows {applied_count}")

    @staticmethod
    def _plot_changes(expected_changes, applied_changes):
        """Save a diagnostic plot of expected vs. applied changes to ``result.pdf`` (best effort)."""
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            # The plot is only a debugging aid; the caller still reports the failure
            return
        # Light grid on pixel boundaries, darker every 2 columns and every 16 rows
        for i, pos in enumerate(np.arange(0.5, 448, 1)):
            plt.axvline(pos, color="#eeeeee" if i % 2 != 1 else "#cccccc", lw=0.5)
        for i, pos in enumerate(np.arange(0.5, 512, 1)):
            plt.axhline(pos, color="#eeeeee" if i % 16 != 15 else "#cccccc", lw=0.5)
        plt.plot(expected_changes[1], expected_changes[0], "o", label="Expected")
        plt.plot(applied_changes[1], applied_changes[0], "x", label="Applied")
        plt.legend()
        plt.gcf().set_size_inches(32, 24)
        plt.tight_layout()
        plt.savefig("result.pdf")

    def update(self, pixel_cfg: Optional[ArrayLike], hold_off=False, validate=False):
        """
        Set a new pixel configuration and, unless held off, send the changes to the chip.

        :param pixel_cfg:       New logical pixel configuration of shape (512, 448)
        :param hold_off:        If set, only store the configuration; nothing is sent
        :param validate:        Passed on to the internal update, see ``_update()``
        """
        self._nxt = np.array(pixel_cfg)
        if not hold_off:
            self._update(validate)
 
[docs]
def logic2chip_cfg_matrix(pixelcfg):
    """
    Convert a logical pixel configuration to the flat, chip-ordered array
    which can be fed into ``tpx4.ConfigPixels()``.

    Parameters
    ----------
    pixelcfg : ndarray
        Pixel configuration of shape (512, 448) in [row, column] ordering.
        The dtype is expected to be uint8 (this is not checked here).

    Returns
    -------
    ndarray
        A flat array which can be fed to ``rpc.Tpx4PixelConfig()``

    Raises
    ------
    ValueError
        If ``pixelcfg`` does not have shape (512, 448)

    Example
    -------
    ::

        pixel_config = np.zeros(shape=(512, 448), dtype=np.uint8)
        pixel_config[5, 5] = 0xF
        config_blob = logic2chip_cfg_matrix(pixel_config)
        tpx4.ConfigPixels(
          rpc.Tpx4PixelConfig(idx=0, config=config_blob.tobytes())
        )
    """
    expected_shape = (512, 448)
    if pixelcfg.shape != expected_shape:
        raise ValueError("Expected array of 512x448")
    flat_cfg = pixelcfg.flatten()
    # Gather the logical entries in the order the chip expects them
    return flat_cfg[_CHIP2LOGIC_CFG_IDX]
[docs]
def chip2logic_cfg_matrix(configblob):
    """
    Convert a chip-ordered pixel configuration to a 2D array (rows x cols).

    :param configblob: flat numpy array containing the pixel configuration data in chip order
    :return: A numpy array of shape (512, 448)

    Example::

        matrix_response = tpx4.ConfigGetPixels(
          rpc.ChipIndex(idx=0)
        )
        blob = np.frombuffer(matrix_response.config, dtype=np.uint8)
        pixel_config = chip2logic_cfg_matrix(blob)
        print(pixel_config[5, 5])
    """
    # Gather the blob entries in logical (row-major) order, then fold them into rows
    logical_order = configblob[_LOGIC2CHIP_CONFIG_IDX]
    return logical_order.reshape(512, 448)
[docs]
def decode_eoc_mon(reg_value):
    """
    Decode the value of an EoC monitoring register.

    :param reg_value:       Reg value
    :return:                Tuple <DLL code>,<locked status>
    """
    # Bit 0 is the lock flag
    locked = (reg_value & 1) != 0
    # The code is stored in two swapped nibbles: register bits 5..8 are the low
    # nibble of the code, register bits 1..4 are the high nibble
    low_nibble = (reg_value >> 5) & 0xF
    high_nibble = (reg_value << 3) & 0xF0
    return low_nibble | high_nibble, locked
[docs]
def decode_dd_packet(val, pc24=False):
    """
    Decode a packet received in data-driven mode.

    :param val:     The raw 64 bit packet value
    :param pc24:    If set, pixel packets are decoded as PC24Packet instead of ToAToTPacket
    :return:        A ControlStatusPacket when the 8 bits starting at bit 55 are 0xE0 or
                    higher, otherwise a PC24Packet or ToAToTPacket
    """
    header = (val >> 55) & 0xFF
    if header >= 0xE0:
        return ControlStatusPacket(val)
    if pc24:
        return PC24Packet(val)
    return ToAToTPacket(val)
[docs]
def fb8_decode(pkt_gen):
    """
    Return the first full frame found while scanning the packets.

    A packet generator can be created through the use of ``stream.unwrap_stream(strm)``.
    Packets in front of the frame start marker are skipped. The frame must consist of
    8 segments, each enclosed by a segment start and a segment end marker and holding
    ``PACKETS_COUNT_FB_SEGMENT`` data packets, followed by a frame end marker.

    :param pkt_gen:         An iterator (or any iterable) providing 64 bit packets.
    :return:                a 2D uint8 array (segment, pixel data)
    :raises ValueError:     If no frame start is found, an expected marker is missing or
                            the packets run out before the frame is complete
    """
    # All steps below must consume from one shared iterator; iter() also makes plain
    # iterables such as lists work
    packets = iter(pkt_gen)

    def expect(header, what):
        # A stream that ends early is reported like a wrong marker, instead of leaking
        # a bare StopIteration to the caller
        pkt = next(packets, None)
        if pkt is None or 0xFF & (pkt >> 55) != header:
            raise ValueError(f"Frame decode error: Expected {what}")

    # seek frame start
    for pkt in packets:
        if 0xFF & (pkt >> 55) == CS_FRAME_START:
            break
    else:
        raise ValueError("Frame decode error: Frame start not found")

    segments = []
    for _ in range(8):
        expect(CS_SEGMENT_START, "segment start")
        # np.fromiter raises ValueError itself when fewer than `count` packets are left
        segment = np.fromiter(packets, dtype=np.uint64, count=PACKETS_COUNT_FB_SEGMENT)
        # Reinterpret every 64 bit packet as 8 bytes of pixel data
        segments.append(segment.view(dtype=np.uint8))
        expect(CS_SEGMENT_END, "segment end")
    expect(CS_FRAME_END, "frame end")
    return np.stack(segments)
[docs]
def fb_to_image(bottom, top):
    """
    Convert frame-based acquisition arrays to a 448x512 pixel image.

    The input is what was received from fb8_decode or fb16_decode.

    NOTE: this function is a placeholder; it is not implemented and returns None.

    :param bottom:  Bottom frame-based segments
    :param top:     Top frame-based segments
    :return:        Intended: a pixel array of 512x448 pixels. Currently always None.
    """
    return None
# Internal function
def __set_prbs(tpx4, is_top, channels, mode):
    """
    Write the PRBS register of one chip half.

    :param tpx4:        The gRPC service tpx4 stub
    :param is_top:      Selects the register: 0xcb02 if set, 0x4b02 otherwise
    :param channels:    Bitmask of the channels (bits 0..7) to enable
    :param mode:        The mode to select; anything outside 1..5 writes 0 to the register
    """
    val = 0
    if 0x1 <= mode <= 0x5:
        # The low byte holds the channel enable mask ...
        val = channels & 0x0000_00FF
        # ... followed by a 3 bit mode field per channel, starting at bit 8
        for channel in range(8):
            if channels & (0x01 << channel):
                shift = 8 + channel * 3
                val |= (mode << shift) & (0x7 << shift)
    addr = 0xcb02 if is_top else 0x4b02
    tpx4.SimpleWrite(rpc.SimpleWriteRequest(addr=addr, val=val))
[docs]
def enable_prbs(tpx4, top_channels, bot_channels, mode):
    """
    Enable PRBS for the requested channels of both chip halves.

    :param tpx4:            The gRPC service tpx4 stub
    :param top_channels:    Bitmask of top channels to enable
    :param bot_channels:    Bitmask of bottom channels to enable
    :param mode:            The mode to select.
    """
    # Top half first, then bottom half
    for is_top, channels in ((True, top_channels), (False, bot_channels)):
        __set_prbs(tpx4, is_top, channels, mode)
[docs]
def disable_prbs(tpx4):
    """
    Disable the PRBS generation on both chip halves.

    :param tpx4:            The gRPC service tpx4 stub
    """
    # Mode 0 with no channels writes 0 to the register; top half first, then bottom half
    for is_top in (True, False):
        __set_prbs(tpx4, is_top, 0, 0)
if __name__ == "__main__":
    # Self-test of the coordinate conversions; runs only when the module is executed directly
    import unittest
    import numpy as np

    class TestCoords(unittest.TestCase):
        """Checks the logical <-> chip coordinate conversions and the configuration LUTs."""

        def test_logic2chip_coords(self):
            self.assertEqual(logic2chip_coords(0, 0), (0, 0, 0, 0, 0))
            self.assertEqual(logic2chip_coords(0, 511), (1, 223, 0, 0, 4))
            self.assertEqual(logic2chip_coords(447, 511), (1, 0, 0, 0, 0))
            self.assertEqual(logic2chip_coords(444, 490), (1, 1, 1, 1, 5))

        def test_chip_coords2logic(self):
            self.assertEqual(chip_coords2logic(0, 0, 0, 0, 0), (0, 0))
            self.assertEqual(chip_coords2logic(1, 223, 0, 0, 4), (0, 511))
            self.assertEqual(chip_coords2logic(1, 0, 0, 0, 0), (447, 511))
            self.assertEqual(chip_coords2logic(1, 1, 1, 1, 5), (444, 490))

        def test_chip_coords2logic_idx(self):
            self.assertEqual(chip_coords2logic_idx(1, 0, 0, 0, 0), 512 * 448 - 1)

        def test_mapping(self):
            # Round trip over the whole matrix: x is the column (0..447), y the row (0..511).
            # The ranges used to be swapped, which skipped rows 448..511 and probed
            # columns that do not exist.
            for y in range(512):
                for x in range(448):
                    chip_coords = logic2chip_coords(x, y)
                    xx, yy = chip_coords2logic(*chip_coords)
                    self.assertEqual(xx, x)
                    self.assertEqual(yy, y)

        def test_pixelconfig(self):
            # Logical -> chip order -> logical must reproduce the input
            original = np.random.randint(0, 256, (512, 448), dtype=np.uint8)
            blob = logic2chip_cfg_matrix(original)
            recreated = chip2logic_cfg_matrix(blob)
            np.testing.assert_array_equal(original, recreated)

    unittest.main()