"""
Volta → TouchDesigner Bootstrap
================================

Fetches a Volta layout via the REST API and generates a complete TouchDesigner
network: WebSocket DAT, per-control Constant CHOPs (and Text DATs for String
controls), Index option tables, and a dynamic callback script that routes
incoming OSC messages to the right operators.

USAGE
-----
1. Drop this file into a Text DAT named `volta_bootstrap` in your TD project.
2. Open the Textport (Alt+T) and call:

       mod('volta_bootstrap').build(
           layout_id = 'VAL-01ARZ3NDEKTSV4RRFFQ69G5FAV',
           api_key   = 'your-rest-api-key',
           ws_host   = 'abc123.execute-api.eu-west-2.amazonaws.com/production',
           parent    = op('/project1'),
       )

   A container COMP named `volta_<id-suffix>` is created inside `parent`.

3. To rebuild after a layout change, call `build()` again — it destroys and
   recreates the container.

WHAT GETS CREATED
-----------------
volta_<suffix>/
  websocket          WebSocket DAT  — connects to Volta, wss:// port 443
  ws_callbacks       Text DAT       — generated OSC dispatch script
  lid_001            Constant CHOP  — value channel for control at /001
  lid_001_trigger    Constant CHOP  — pulses 1→0 each Button press  (Button only)
  lid_002            Constant CHOP  — e.g. Slider 0–1 value
  lid_003_x          Constant CHOP  — XYPad / Accelerometer / Gyroscope, per axis
  lid_003_y          Constant CHOP
  options_004        Table DAT      — index/label/image rows  (Index only)
  lid_005            Text DAT       — latest received text  (String only)
  all_controls       Select CHOP    — merges every value CHOP into one place

REQUIREMENTS
------------
- REST API key (x-api-key header) — ask your Volta account owner
- WebSocket host — the API Gateway domain shown in your Volta deployment output
  (format: <id>.execute-api.<region>.amazonaws.com/<stage>)
- TouchDesigner 2023+ (Python 3.11 compatible)

TOKEN NOTE
----------
A JWT is fetched from POST /token during build() and embedded in the WebSocket
URL as ?aeToken=<jwt> — the $connect Lambda reads queryStringParameters.aeToken
to verify the caller. JWTs are short-lived; if the connection drops for more
than ~15 minutes, call build() again to refresh the token.
"""

import json
import ssl
import struct
import urllib.request
import urllib.error

# Base URL of the production Volta REST API. build() falls back to this value
# when the caller does not pass a rest_base override.
_VOLTA_REST_BASE_DEFAULT = 'https://api.volta-xr.com'

# TouchDesigner's bundled Python on macOS doesn't ship CA certificates, so SSL
# verification fails for any HTTPS request there. Verification is therefore
# switched off only as a last resort, when neither certifi nor a populated
# system trust store is available.
def _ssl_ctx():
    """Return the SSL context used for the Volta REST requests.

    Preference order:
      1. certifi's CA bundle, when the package can be imported.
      2. The interpreter's default trust store, when it actually holds CA
         certificates.
      3. An unverified context (no hostname check, no certificate check), with
         a warning printed, so the bootstrap keeps working on installs that
         have no CA certificates at all.

    Returns:
        ssl.SSLContext: context to hand to urllib.request.urlopen().
    """
    try:
        import certifi
    except ImportError:
        pass
    else:
        return ssl.create_default_context(cafile=certifi.where())

    ctx = ssl.create_default_context()
    # 'x509_ca' counts the CA certificates already loaded into the store. Zero
    # means verification cannot be relied on with this context (certificates
    # that would only be picked up lazily from a capath directory are not
    # counted, so such setups keep the unverified fallback).
    if ctx.cert_store_stats().get('x509_ca', 0) > 0:
        return ctx

    # Make the downgrade visible instead of silently dropping verification.
    print('[Volta] WARNING: no CA certificates found (certifi is not installed); '
          'HTTPS certificate verification is disabled.')
    ctx.check_hostname = False  # must be cleared before verify_mode can be CERT_NONE
    ctx.verify_mode = ssl.CERT_NONE
    return ctx

# Spacing of the generated operator grid, in network-editor pixels
_COL_WIDTH, _ROW_HEIGHT = 230, 160

# Multi-axis control types mapped to the axis suffixes of their CHOPs; these
# controls get one Constant CHOP per axis instead of a single value CHOP
_AXIS_CHANNELS = dict(
    XYPad=('x', 'y'),
    XYPathPad=('x', 'y'),
    Accelerometer=('x', 'y', 'z'),
    Gyroscope=('x', 'y', 'z'),
)

# Control types whose value lives in a Text DAT rather than a Constant CHOP
_STRING_TYPES = {'String'}

# Control types that receive an extra <name>_trigger CHOP next to the value CHOP
_TRIGGER_TYPES = {'Button'}


# ── Public entry point ────────────────────────────────────────────────────────

def build(layout_id, api_key, ws_host, parent=None, rest_base=None):
    """
    Fetch the Volta layout and generate a self-contained TouchDesigner network.

    Both REST calls (layout first, then the JWT) complete before anything in
    the project is modified, so a failed request leaves an existing container
    untouched.

    Args:
        layout_id (str): Volta layout ID, e.g. 'VAL-01ARZ3NDEKTSV4RRFFQ69G5FAV'
        api_key   (str): REST API key for the x-api-key header
        ws_host   (str): WebSocket host, e.g. 'abc.execute-api.eu-west-2.amazonaws.com/production'
        parent    (OP):  COMP to build inside. Defaults to op('/project1').
        rest_base (str): REST API base URL. Defaults to 'https://api.volta-xr.com'.
                         Override for staging, e.g. 'https://pr42.api.volta-xr.com'.

    Returns:
        The created baseCOMP container.

    Raises:
        urllib.error.URLError: if either REST request fails (HTTPError, raised
            for non-2xx responses, is a subclass).
    """
    if parent is None:
        parent = op('/project1')
    if rest_base is None:
        rest_base = _VOLTA_REST_BASE_DEFAULT

    # 1. Fetch layout from REST API
    layout   = _fetch_layout(layout_id, api_key, rest_base)
    # `or []` also covers an explicit `"controls": null` in the response, which
    # would otherwise crash on len(None) below.
    controls = layout.get('controls') or []
    print(f'[Volta] Layout {layout_id!r}: {len(controls)} controls, page {layout.get("activePageId")}')

    # 2. Fetch JWT for WebSocket authentication
    token = _fetch_token(layout_id, api_key, rest_base)

    # 3. Create (or replace) container COMP, named after the last dash-separated
    #    part of the layout ID, truncated to 10 characters
    suffix         = layout_id.split('-')[-1][:10]
    container_name = f'volta_{suffix}'
    existing       = parent.op(container_name)
    if existing:
        existing.destroy()
    container = parent.create(baseCOMP, container_name)
    container.nodeX, container.nodeY = 0, 0

    # 4. Generate network inside the container. The callback DAT is created
    #    before the WebSocket DAT because the latter refers to it by name.
    _build_controls(container, controls)
    callback_dat = _build_callback_dat(container, controls)
    _build_websocket(container, ws_host, token, callback_dat.name)
    _build_merge(container, controls)

    print(f'[Volta] Network "{container_name}" ready.')
    return container


# ── REST helpers ──────────────────────────────────────────────────────────────

def _fetch_layout(layout_id, api_key, rest_base):
    """GET the layout definition and return the decoded JSON body.

    Args:
        layout_id (str): Volta layout ID; percent-encoded before being placed
            in the URL path.
        api_key   (str): sent as the x-api-key header.
        rest_base (str): REST API base URL; a trailing '/' is tolerated.

    Returns:
        The parsed JSON response (build() reads its 'controls' and
        'activePageId' keys).

    Raises:
        urllib.error.URLError: on network failure or a non-2xx response.
    """
    from urllib.parse import quote

    # Normalise the base so an override such as 'https://host/' does not
    # produce 'https://host//layout/...'.
    base = rest_base.rstrip('/')
    url = f'{base}/layout/{quote(layout_id)}'
    req = urllib.request.Request(url, headers={'x-api-key': api_key})
    with urllib.request.urlopen(req, timeout=10, context=_ssl_ctx()) as resp:
        return json.loads(resp.read())


def _fetch_token(layout_id, api_key, rest_base):
    """POST to the token endpoint and return the JWT for this layout.

    The request asks for the 'layoutWrite', 'layoutStateWrite' and
    'layoutDestination' permissions.

    Args:
        layout_id (str): Volta layout ID the token is issued for.
        api_key   (str): sent as the x-api-key header.
        rest_base (str): REST API base URL; a trailing '/' is tolerated.

    Returns:
        The value of the 'token' field of the JSON response.

    Raises:
        urllib.error.URLError: on network failure or a non-2xx response.
        KeyError: if the response carries no 'token' field.
    """
    # Normalise the base so an override such as 'https://host/' does not
    # produce 'https://host//token'.
    url  = rest_base.rstrip('/') + '/token'
    body = json.dumps({
        'layoutId': layout_id,
        'permissions': ['layoutWrite', 'layoutStateWrite', 'layoutDestination'],
    }).encode()
    req  = urllib.request.Request(
        url, data=body,
        headers={'x-api-key': api_key, 'Content-Type': 'application/json'},
        method='POST',
    )
    with urllib.request.urlopen(req, timeout=10, context=_ssl_ctx()) as resp:
        return json.loads(resp.read())['token']


# ── Network builders ──────────────────────────────────────────────────────────

def _chop_name(ctrl):
    """Return the operator name for a control: 'lid_' plus its address without leading slashes."""
    lid = ctrl['address'].lstrip('/')
    return f'lid_{lid}'


def _build_controls(container, controls):
    """Create the operators for every control, laid out on a five-column grid.

    Per control type:
      * String types     -> one empty Text DAT named lid_NNN.
      * multi-axis types -> one Constant CHOP per axis, named lid_NNN_<axis>.
      * everything else  -> one Constant CHOP lid_NNN starting at 0; Button
                            types also get lid_NNN_trigger directly below it.

    Any control carrying a non-empty 'options' list additionally gets an
    options Table DAT (built by _build_options_table) below its own grid slot.

    Args:
        container: COMP in which the operators are created.
        controls:  list of control dicts; each must provide 'type' and
                   'address', and may provide 'options'.
    """
    columns      = 5
    companion_dy = -90   # vertical offset of an OP placed right below its control
    axis_dx      = 110   # horizontal step between side-by-side OPs in one slot

    for i, ctrl in enumerate(controls):
        row, col = divmod(i, columns)
        x     = col * _COL_WIDTH - columns * _COL_WIDTH // 2
        y     = -row * _ROW_HEIGHT - _ROW_HEIGHT
        ctype = ctrl['type']
        name  = _chop_name(ctrl)
        has_trigger = False

        if ctype in _STRING_TYPES:
            # Text DAT for free-text inputs
            dat = container.create(textDAT, name)
            dat.nodeX, dat.nodeY = x, y
            dat.clear()

        elif ctype in _AXIS_CHANNELS:
            # Multi-axis controls: one Constant CHOP per axis, named lid_NNN_x etc.
            for j, axis in enumerate(_AXIS_CHANNELS[ctype]):
                chop = container.create(constantCHOP, f'{name}_{axis}')
                chop.nodeX, chop.nodeY = x + j * axis_dx, y
                # NOTE(review): x/y start at 0.5 for every multi-axis type,
                # including Accelerometer/Gyroscope — confirm this is intended
                # and not only meant to centre the pads.
                chop.par.value0.val = 0.5 if axis in ('x', 'y') else 0

        else:
            # Single-value CHOP (Button, Slider, Toggle, Index, …)
            chop = container.create(constantCHOP, name)
            chop.nodeX, chop.nodeY = x, y
            chop.par.value0.val = 0

            # Button also gets a trigger pulse CHOP
            if ctype in _TRIGGER_TYPES:
                trig = container.create(constantCHOP, f'{name}_trigger')
                trig.nodeX, trig.nodeY = x, y + companion_dy
                trig.par.value0.val = 0
                has_trigger = True

        # Index controls: Table DAT with option labels (and image URLs if present)
        options = ctrl.get('options') or []
        if options:
            # Placed below the control. The former y + _ROW_HEIGHT position is
            # exactly the grid slot of the control one row up in the same
            # column, so the two nodes ended up stacked on top of each other.
            table_x = x + axis_dx if has_trigger else x
            _build_options_table(container, ctrl, options, table_x, y + companion_dy)


def _build_options_table(container, ctrl, options, x, y):
    """Create a Table DAT listing the selectable options for an Index control.

    The table is named options_<address> and starts with a header row
    ('index', 'label', plus 'image' when at least one option has a truthy
    image). Each option contributes one row; a missing or null 'index' or
    'label' falls back to the option's position in the list, and a missing or
    null 'image' becomes an empty cell.

    Args:
        container: COMP in which the table is created.
        ctrl:      control dict; only 'address' is read.
        options:   list of option dicts ('index', 'label', 'image' all optional).
        x, y:      node position of the table in the network editor.
    """
    address    = ctrl['address'].lstrip('/')
    has_images = any(o.get('image') for o in options)

    tbl = container.create(tableDAT, f'options_{address}')
    tbl.nodeX, tbl.nodeY = x, y
    tbl.clear()

    tbl.appendRow(['index', 'label'] + (['image'] if has_images else []))
    for i, opt in enumerate(options):
        index = opt.get('index')
        label = opt.get('label')
        # Treat explicit nulls like missing keys so no None reaches the table.
        row = [str(i if index is None else index),
               str(i) if label is None else label]
        if has_images:
            row.append(opt.get('image') or '')
        tbl.appendRow(row)


def _build_websocket(container, ws_host, token, callback_dat_name):
    """Add the 'websocket' WebSocket DAT and point it at the Volta endpoint.

    Args:
        container:         COMP in which the DAT is created.
        ws_host:           WebSocket host (and stage path) to connect to.
        token:             JWT appended to the address as the aeToken query value.
        callback_dat_name: name of the DAT holding the callback script.
    """
    socket_dat = container.create(websocketDAT, 'websocket')
    socket_dat.nodeX = -700
    socket_dat.nodeY = 0

    pars = socket_dat.par
    # $connect verifies the JWT it finds in queryStringParameters.aeToken
    pars.netaddress = '{}?aeToken={}'.format(ws_host, token)
    # TD switches to wss:// by itself on port 443; any other wss port needs an
    # explicit "wss://" prefix on the address.
    pars.port = 443
    pars.callbacks = callback_dat_name
    pars.active = True


def _build_merge(container, controls):
    """Add the 'all_controls' Select CHOP referencing every value CHOP.

    String controls are left out (they are Text DATs, not CHOPs) and
    multi-axis controls contribute one entry per axis. Nothing is created
    when no control yields a CHOP.
    """
    sources = []
    for ctrl in controls:
        kind = ctrl['type']
        base = _chop_name(ctrl)
        if kind in _STRING_TYPES:
            continue  # a Text DAT cannot feed a CHOP
        axes = _AXIS_CHANNELS.get(kind)
        if axes is None:
            sources.append(base)
        else:
            sources += [f'{base}_{axis}' for axis in axes]

    if sources:
        merged = container.create(selectCHOP, 'all_controls')
        merged.nodeX, merged.nodeY = 700, 0
        merged.par.chop = ' '.join(sources)


def _build_callback_dat(container, controls):
    """Render the WebSocket callback script and store it in a new 'ws_callbacks' Text DAT.

    Returns:
        The created Text DAT.
    """
    type_by_address   = {}
    trigger_addresses = []
    axes_by_address   = {}
    for ctrl in controls:
        address, kind = ctrl['address'], ctrl['type']
        type_by_address[address] = kind
        if kind in _TRIGGER_TYPES:
            trigger_addresses.append(address)
        if kind in _AXIS_CHANNELS:
            axes_by_address[address] = _AXIS_CHANNELS[kind]

    source = _generate_callback_script(type_by_address, trigger_addresses, axes_by_address)

    callbacks = container.create(textDAT, 'ws_callbacks')
    callbacks.nodeX, callbacks.nodeY = -700, -200
    callbacks.clear()
    callbacks.write(source)
    return callbacks


# ── Callback script generation ────────────────────────────────────────────────

# Source text of the generated `ws_callbacks` script. _generate_callback_script()
# fills in three tokens with plain str.replace(): __CONTROL_TYPES__,
# __AXIS_TYPE_SET__ and __AXIS_MAP_VALUE__. Since this is NOT an f-string,
# {var} in the output is just {var} — no escaping needed.
# Note: __AXIS_TYPES__ and __AXIS_MAP__ are NOT substituted; they remain in the
# generated script as ordinary module-level variable names.
_CALLBACK_TEMPLATE = '''\
"""
Auto-generated by volta_bootstrap.py — do not edit by hand.
Call mod("volta_bootstrap").build(...) to regenerate if the layout changes.

MESSAGE FORMAT (Volta WebSocket -> destination):
  {"action": "audienceAction", "data": "<base64 OSC packet>"}

OSC PACKET LAYOUT:
  address  /NNN  -- matches control lid (3-digit, zero-padded)
  args[0]  primary value  (index, float, int, bool)
  args[1]  control message ID
  args[2]  hashed audience ID
  args[3]  cumulative count  (Button with isCumulative only)
"""

import json
import struct
import base64


# -- OSC decoding -------------------------------------------------------------

def _decode_osc(b64_data):
    data     = base64.b64decode(b64_data)
    addr_end = data.index(0)
    address  = data[:addr_end].decode('ascii')

    tt_off   = (addr_end + 4) & ~3
    tt_end   = data.index(0, tt_off)
    type_tag = data[tt_off:tt_end].decode('ascii')

    offset = (tt_end + 4) & ~3
    args   = []
    for t in type_tag[1:]:  # skip leading comma
        if t in ('i', 'I'):
            args.append(struct.unpack('>i', data[offset:offset + 4])[0]); offset += 4
        elif t == 'f':
            args.append(struct.unpack('>f', data[offset:offset + 4])[0]); offset += 4
        elif t == 's':
            s_end = data.index(0, offset)
            args.append(data[offset:s_end].decode('utf-8'))
            offset = (s_end + 4) & ~3
        elif t == 'T':
            args.append(1)
        elif t == 'F':
            args.append(0)

    return address, type_tag, args


# -- WebSocket DAT callbacks --------------------------------------------------

def onReceiveText(dat, rowIndex, message):
    try:
        envelope = json.loads(message)
    except Exception:
        return

    if envelope.get('action') != 'audienceAction':
        debug('[Volta] Unhandled action: ' + str(envelope.get('action')))
        return

    b64 = envelope.get('data')
    if not b64:
        return

    try:
        address, type_tag, args = _decode_osc(b64)
    except Exception as e:
        debug('[Volta] OSC decode error: ' + str(e))
        return

    ctrl_type = CONTROL_TYPES.get(address)
    if ctrl_type is None:
        debug('[Volta] Unknown address: ' + address)
        return

    chop_name = 'lid_' + address.lstrip('/')
    value     = args[0] if args else None

    if ctrl_type == 'Button':
        chop = op(chop_name)
        if chop and value is not None:
            chop.par.value0 = float(value)
        trigger = op(chop_name + '_trigger')
        if trigger:
            trigger.par.value0 = 1
            run('op("' + chop_name + '_trigger").par.value0 = 0', delayFrames=1)

    elif ctrl_type in ('Slider', 'Toggle', 'Index'):
        chop = op(chop_name)
        if chop and value is not None:
            chop.par.value0 = float(value)

    elif ctrl_type in __AXIS_TYPES__:
        axes = __AXIS_MAP__.get(address, ())
        for i, axis in enumerate(axes):
            chop = op(chop_name + '_' + axis)
            if chop and i < len(args):
                chop.par.value0 = float(args[i])

    elif ctrl_type == 'String':
        dat_op = op(chop_name)
        if dat_op and value is not None:
            dat_op.clear()
            dat_op.write(str(value))

    debug('[Volta] ' + address + ' (' + str(ctrl_type) + ') = ' + str(value))


def onConnect(dat):
    debug('[Volta] WebSocket connected')


def onDisconnect(dat):
    debug('[Volta] WebSocket disconnected')


def onReceivePing(dat, contents):
    pass  # API Gateway keepalive


def onReceivePong(dat, contents):
    pass


def onMonitorMessage(dat, message):
    debug('[Volta WS] ' + str(message))


# -- Generated control map ----------------------------------------------------

CONTROL_TYPES = __CONTROL_TYPES__

# Multi-axis control types present in this layout
__AXIS_TYPES__ = __AXIS_TYPE_SET__

# address -> axis names, e.g. {"/003": ["x", "y"]}
__AXIS_MAP__ = __AXIS_MAP_VALUE__
'''


def _generate_callback_script(addr_map, button_addresses, axis_addresses):
    """Render _CALLBACK_TEMPLATE for one layout and return the script text.

    Args:
        addr_map:         control address -> control type.
        button_addresses: addresses of trigger-type controls (currently not
                          used by the template).
        axis_addresses:   control address -> axis names, for multi-axis controls.
    """
    # (token, replacement) pairs, applied strictly in this order
    substitutions = (
        ('__CONTROL_TYPES__', json.dumps(addr_map, indent=4)),
        ('__AXIS_TYPE_SET__', repr({addr_map[address] for address in axis_addresses})),
        ('__AXIS_MAP_VALUE__', json.dumps(
            {address: list(axes) for address, axes in axis_addresses.items()},
            indent=4,
        )),
    )

    rendered = _CALLBACK_TEMPLATE
    for token, replacement in substitutions:
        rendered = rendered.replace(token, replacement)
    return rendered
