from __future__ import annotations
import numpy as np
from ...common.constants import Constants
from ..core.i_port import IPort
from ..core.io_node import IONode
from ..core.o_port import OPort
# Module-level aliases for the default data input/output port names, so the
# node code below can refer to them without the full Constants path.
PORT_IN, PORT_OUT = (
    Constants.Defaults.PORT_IN,
    Constants.Defaults.PORT_OUT,
)
class Trigger(IONode):
    """Event-triggered data extraction node for BCI applications.

    Monitors a trigger input and extracts time-locked data segments
    (epochs) around trigger onsets. A rolling buffer of the main input is
    maintained; when the trigger input changes to one of the target values,
    a countdown is started, and once enough post-trigger samples have
    arrived an epoch of ``samples_pre + samples_post`` samples is emitted.
    Several epochs may be pending at the same time. Commonly used in
    event-related potential (ERP) analysis.
    """

    # Default time window values in seconds
    DEFAULT_TIME_PRE = 0.7  # Pre-trigger window
    DEFAULT_TIME_POST = 0.2  # Post-trigger window

    # Port name for trigger input
    PORT_TRIGGER = "trigger"

    class Configuration(IONode.Configuration):
        """Configuration class for Trigger parameters."""

        class Keys(IONode.Configuration.Keys):
            """Configuration key constants for the Trigger."""

            TIME_PRE = "time_pre"
            TIME_POST = "time_post"
            TARGET = "target"

    def __init__(
        self,
        time_pre: float | None = None,
        time_post: float | None = None,
        target: float | list[float] | None = None,
        **kwargs,
    ):
        """Initialize the Trigger node with timing and target configurations.

        Args:
            time_pre: Time in seconds before the trigger to include in the
                epoch. Must be > 0. Defaults to 0.7 seconds.
            time_post: Time in seconds after the trigger to include in the
                epoch. Must be > 0. Defaults to 0.2 seconds.
            target: Trigger value(s) that cause epoch extraction. May be a
                single value, a list, a tuple/set, or a NumPy array (which
                is flattened). Defaults to [1].
            **kwargs: Additional configuration parameters passed to IONode.

        Raises:
            ValueError: If time_pre or time_post is <= 0.
        """
        # Fall back to class defaults for anything not provided.
        if time_pre is None:
            time_pre = self.DEFAULT_TIME_PRE
        if time_post is None:
            time_post = self.DEFAULT_TIME_POST
        if target is None:
            target = [1]

        # Normalize target to a plain list so membership tests behave the
        # same regardless of how the caller passed it. Collections are
        # expanded (not wrapped as a single element); scalars are wrapped.
        if isinstance(target, np.ndarray):
            target = target.ravel().tolist()
        elif isinstance(target, (tuple, set, frozenset)):
            target = list(target)
        elif not isinstance(target, list):
            target = [target]

        # Validate timing parameters.
        if time_pre <= 0:
            raise ValueError("time_pre must be greater than 0.")
        if time_post <= 0:
            raise ValueError("time_post must be greater than 0.")

        # Input ports: the main data input plus the trigger input, whose
        # timing is inherited.
        input_ports = [
            IPort.Configuration(),  # Main data input
            IPort.Configuration(
                name=self.PORT_TRIGGER, timing=Constants.Timing.INHERITED
            ),
        ]
        # The output only produces data when an epoch completes, hence
        # asynchronous timing.
        output_ports = [OPort.Configuration(timing=Constants.Timing.ASYNC)]

        super().__init__(
            time_pre=time_pre,
            time_post=time_post,
            target=target,
            input_ports=input_ports,
            output_ports=output_ports,
            **kwargs,
        )

        # Internal state; populated in setup().
        self._buf_input = None  # Rolling input data buffer
        self._buf_output = None  # Output buffer (legacy, unused)
        self._frame_size = None  # Total epoch frame size
        self._target = None  # Target trigger values
        self._countdown = None  # Remaining samples per pending epoch
        self._samples_pre = None  # Pre-trigger samples count
        self._samples_post = None  # Post-trigger samples count
        self._last_trigger = None  # Last seen trigger value (None = none yet)

    def setup(
        self, data: dict[str, np.ndarray], port_context_in: dict[str, dict]
    ) -> dict[str, dict]:
        """Set up the Trigger node and initialize internal buffers.

        Validates input requirements, converts the time windows to sample
        counts using the input sampling rate, and allocates the rolling
        input buffer.

        Args:
            data: Initial data dictionary for port configuration.
            port_context_in: Input port context with sampling rates, frame
                sizes, and channel counts.

        Returns:
            Output port context with the epoch frame size, the configured
            time windows, and the timing information for the output port.

        Raises:
            ValueError: If the input frame size is not 1, the sampling rate
                is not provided, or the epoch would be shorter than one
                sample at the given sampling rate.
        """
        port_context_out = super().setup(data, port_context_in)

        # Single-sample frames keep the rolling buffer sample-accurate.
        frame_size_in = port_context_in[PORT_IN][Constants.Keys.FRAME_SIZE]
        if frame_size_in != 1:
            raise ValueError("Input frame size must be 1.")

        tpre_key = self.Configuration.Keys.TIME_PRE
        tpost_key = self.Configuration.Keys.TIME_POST
        time_pre = self.config[tpre_key]
        time_post = self.config[tpost_key]

        sampling_rate = port_context_in[PORT_IN][Constants.Keys.SAMPLING_RATE]
        if sampling_rate is None:
            raise ValueError("Sampling rate must be provided in context.")

        # Convert time windows to sample counts.
        self._samples_pre = int(round(time_pre * sampling_rate))
        self._samples_post = int(round(time_post * sampling_rate))
        frame_size_out = self._samples_pre + self._samples_post
        # An empty buffer would make every step() fail on indexing, so
        # reject it here with an explicit message.
        if frame_size_out < 1:
            raise ValueError(
                "time_pre + time_post must span at least one sample at the "
                "given sampling rate."
            )

        cc_key = Constants.Keys.CHANNEL_COUNT
        fsz_key = Constants.Keys.FRAME_SIZE
        timing_key = OPort.Configuration.Keys.TIMING

        out_ctx = port_context_out[PORT_OUT]
        channel_count = out_ctx[cc_key]
        # The parent's output context holds timing indexed by input port
        # name; the data port's entry becomes the output's timing.
        timing = out_ctx[timing_key][PORT_IN]

        out_ctx[fsz_key] = frame_size_out
        out_ctx[tpre_key] = time_pre
        out_ctx[tpost_key] = time_post
        out_ctx[timing_key] = timing

        # Rows not yet filled by real samples stay zero.
        self._buf_input = np.zeros(shape=(frame_size_out, channel_count))
        self._buf_output = []  # Legacy buffer, kept for compatibility
        self._frame_size = frame_size_out
        self._target = self.config[self.Configuration.Keys.TARGET]
        self._countdown = []  # Pending epochs, oldest first
        # None means "no trigger seen yet", so the first sample always
        # counts as a change. Unlike a numeric sentinel, it cannot collide
        # with a target value and works for an empty target list.
        self._last_trigger = None
        return port_context_out

    def step(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray] | None:
        """Process one frame of data and check for trigger events.

        Appends the newest sample to the rolling buffer and starts a
        countdown whenever the trigger input changes to a target value.
        When the oldest countdown expires, the current buffer is emitted
        as an epoch. Multiple triggers can be pending simultaneously.

        Args:
            data: Dictionary containing input data arrays. Must include
                both the main data port and the trigger port. The trigger
                entry must hold exactly one value.

        Returns:
            ``{PORT_OUT: epoch}`` when a trigger countdown completes,
            otherwise None. The epoch is a copy with shape
            (samples_pre + samples_post, channel_count). When samples_post
            is at least 1, the data sample received in the same step as the
            trigger onset is at row ``samples_pre``.

        Raises:
            ValueError: If the trigger input holds more than one value.
        """
        # Shift rows up by one (NumPy handles the overlapping copy) and
        # write the newest sample into the last row.
        self._buf_input[:-1] = self._buf_input[1:]
        self._buf_input[-1] = data[PORT_IN][-1]

        # Snapshot the trigger as a Python scalar. Keeping a reference to
        # the caller's array would break change detection if that array
        # were later reused or modified in place.
        trigger = np.asarray(data[self.PORT_TRIGGER]).item()
        if trigger != self._last_trigger:
            if trigger in self._target:
                self._countdown.append(self._samples_post)
            self._last_trigger = trigger

        if not self._countdown:
            return None

        # Every pending epoch advances by one sample. Countdowns are started
        # at most once per step with the same initial value, so the list is
        # strictly increasing. Therefore only the oldest (index 0) can
        # complete in a given step.
        self._countdown = [remaining - 1 for remaining in self._countdown]
        if self._countdown[0] <= 0:
            self._countdown.pop(0)
            return {PORT_OUT: self._buf_input.copy()}

        # No epoch ready: asynchronous output produces nothing this step.
        return None