# Source code for OpenPinch.utils.miscellaneous

"""Shared numerical helpers."""

from typing import Any, Iterable, Tuple, Union

import numpy as np

from ..classes.value import Value
from ..lib.config import tol
from ..lib.schemas.common import MaybeVU, StatefulValueWithUnit, ValueWithUnit

# Public names exported by ``from <this module> import *``.
# NOTE(review): ``get_values`` and ``resolve_value_for_state`` are defined in
# this module without a leading underscore but are not listed here — confirm
# whether that omission is intentional.
__all__ = [
    "clean_composite_curve",
    "clean_composite_curve_ends",
    "delta_vals",
    "delta_with_zero_at_start",
    "g_ineq_penalty",
    "get_state_index",
    "get_value",
    "graph_simple_cc_plot",
    "interp_with_plateaus",
    "linear_interpolation",
    "make_monotonic",
    "resolve_stream_attr",
    "resolve_stream_attr_array",
]


def _require_plotly():
    try:
        import plotly.graph_objects as go
    except ImportError as exc:  # pragma: no cover - optional dependency guard
        raise ImportError(
            "Plotly is required for graph_simple_cc_plot. "
            "Install it directly or reinstall OpenPinch with "
            "'pip install openpinch[notebook]' or 'pip install openpinch[dashboard]'."
        ) from exc
    return go


def resolve_value_for_state(
    val: Any,
    state_id: str | None = None,
    *,
    state_ids: dict[str, int] | list[str] | None = None,
    default_allowed: bool = True,
) -> float | None:
    """Return one scalar magnitude from a scalar or stateful value-like object.

    Args:
        val: A ``Value`` instance or anything ``Value(...)`` accepts. ``None``
            is passed straight through.
        state_id: Identifier of the state to pick when ``val`` holds more than
            one state value.
        state_ids: Either a mapping ``{state_id: index}`` or an ordered
            sequence of state ids. When omitted and ``val`` is a dict with a
            non-``None`` ``"state_ids"`` entry, that entry is used instead.
        default_allowed: When ``False``, refuse to fall back to a default
            state (see *Raises*).

    Returns:
        ``None`` when ``val`` is ``None``; otherwise the selected magnitude as
        a ``float``. Single-state values return their only magnitude; when no
        state-id lookup is available the first state value is returned; when
        ``state_id`` is omitted the state named ``"0"`` (if any) or else the
        first listed state is used.

    Raises:
        ValueError: If ``state_id`` is not among the known state ids; if
            ``default_allowed`` is ``False`` and ``state_id`` is missing for a
            stateful value with known ids; or if ``default_allowed`` is
            ``False`` and a ``state_id`` is given but no state ids are known.
    """
    if val is None:
        return None
    raw_value = val if isinstance(val, Value) else Value(val)

    # Nothing to select when there is at most one state value.
    if len(raw_value.state_values) <= 1:
        return float(raw_value.value)

    # Fall back to the ids embedded in a dict payload.
    if state_ids is None and isinstance(val, dict) and val.get("state_ids") is not None:
        state_ids = {str(sid): idx for idx, sid in enumerate(val["state_ids"])}

    if isinstance(state_ids, dict):
        state_lookup = {str(sid): int(idx) for sid, idx in state_ids.items()}
    elif state_ids:
        state_lookup = {str(sid): idx for idx, sid in enumerate(state_ids)}
    else:
        state_lookup = {}

    # An empty lookup (``None``, ``[]`` or ``{}``) carries no state names, so
    # all three take the same default path. Previously ``{}`` slipped through
    # and hit ``next(iter({}))`` below, leaking a StopIteration.
    if not state_lookup:
        if not default_allowed and state_id is not None:
            raise ValueError("state_ids are required for stateful values.")
        return float(raw_value[0].value)

    resolved_state_id = None if state_id is None else str(state_id)
    if resolved_state_id is None:
        if not default_allowed:
            raise ValueError("state_id is required for stateful values.")
        # Prefer a state literally named "0", otherwise the first listed one.
        resolved_state_id = "0" if "0" in state_lookup else next(iter(state_lookup))

    if resolved_state_id not in state_lookup:
        raise ValueError(
            f"Unknown state_id {resolved_state_id!r}. "
            f"Available states: {', '.join(state_lookup)}."
        )
    return float(raw_value[state_lookup[resolved_state_id]].value)


[docs] def resolve_stream_attr( stream: Any, attr_name: str, state_id: str | None = None, *, default_allowed: bool = True, ) -> float | None: """Resolve one stream attribute to a scalar for the selected state.""" if not hasattr(stream, attr_name): raise AttributeError(f"Stream {stream!r} has no attribute {attr_name!r}.") return resolve_value_for_state( getattr(stream, attr_name), state_id=state_id, state_ids=getattr(stream, "state_ids", None), default_allowed=default_allowed, )
[docs] def resolve_stream_attr_array( streams: Iterable[Any], attr_name: str, state_id: str | None = None, *, default_allowed: bool = True, ) -> np.ndarray: """Resolve one attribute across a stream iterable into a float array.""" return np.asarray( [ resolve_stream_attr( stream, attr_name, state_id=state_id, default_allowed=default_allowed, ) for stream in streams ], dtype=float, )
def get_value(
    val: Union[float, int, str, dict, "ValueWithUnit", None],
    val2: Union[float, int, str, None] = None,
    zone_name: str | None = None,
    state_id: str | None = None,
) -> float | None:
    """Extract a numeric value from supported scalars and payload wrappers.

    Inputs are handled in this order:

    * ``bool`` -- rejected with ``TypeError``.
    * ``Value`` -- resolved via ``resolve_value_for_state`` for ``state_id``.
    * ``float`` / ``int`` -- returned as ``float``.
    * objects with ``model_dump`` -- dumped to a dict and processed again.
    * ``dict`` -- an entry keyed by ``zone_name`` is unwrapped first; a
      stateful payload (only ``values``/``state_ids``/``weights``/``unit``
      keys) is resolved for ``state_id``; anything else is an operation
      payload ``{"value": v, <op>: arg}`` holding at most one operation,
      where ``val2`` supplies a missing ``"value"``.
    * objects exposing ``value`` and ``unit`` attributes -- ``.value``.
    * numeric strings -- parsed with ``float``.
    * ``None`` -- ``val2`` is used when given, otherwise ``None`` is returned.

    Supported operations: ``multiplier``/``multiply``, ``add``, ``subtract``,
    ``divide`` (0.0 when the divisor is 0), ``power``, ``log`` (0.0 for
    non-positive values; base is the given number or *e*), ``exp`` (base is
    the given number or *e*), ``abs``, ``min``, ``max``.

    Raises:
        TypeError: For ``bool``, non-numeric strings or unsupported types.
        KeyError: If an operation payload lacks ``"value"`` and ``val2`` is
            ``None``.
        ValueError: If an operation payload holds more than one operation.
    """
    if isinstance(val, bool):
        raise TypeError(
            "Unsupported type: "
            f"{type(val)}. Expected float, int, numeric string, dict, "
            "or ValueWithUnit."
        )
    elif isinstance(val, Value):
        return resolve_value_for_state(val, state_id=state_id)
    elif isinstance(val, (float, int)):
        return float(val)
    elif hasattr(val, "model_dump"):
        return get_value(
            val.model_dump(mode="python"),
            val2=val2,
            zone_name=zone_name,
            state_id=state_id,
        )
    elif isinstance(val, dict):
        if zone_name in val:
            return get_value(val[zone_name], val2=val2, state_id=state_id)
        if _is_stateful_value_payload(val):
            return resolve_value_for_state(val, state_id=state_id)
        payload = val.copy()
        if "value" not in payload:
            if val2 is None:
                raise KeyError("value")
            payload["value"] = val2
        if len(payload) > 2:
            raise ValueError(
                "Invalid payload: more than one operation specified. Payload "
                "must contain only 'value' and at most one of "
                "'multiplier', 'multiply', 'add', 'subtract', 'divide', "
                "'power', 'log', 'exp', 'abs', 'min', or 'max'."
            )
        value = get_value(payload["value"], state_id=state_id)
        return _apply_payload_operation(value, payload, state_id)
    elif _is_value_with_unit(val):
        return val.value
    elif isinstance(val, str):
        try:
            return float(val)
        except ValueError as exc:
            raise TypeError(
                f"Unsupported string value: {val}. String must be convertible to float."
            ) from exc
    elif val is None and val2 is not None:
        return get_value(val2, zone_name=zone_name, state_id=state_id)
    elif val is None:
        return None
    else:
        raise TypeError("Unsupported type")


def _payload_base(raw: Any) -> float:
    """Return ``raw`` as a float log/exp base when it is a real number, else *e*."""
    # ``bool`` is an ``int`` subclass; flags such as ``{"log": True}`` keep *e*.
    if isinstance(raw, (int, float)) and not isinstance(raw, bool):
        return float(raw)
    return np.e


def _apply_payload_operation(value: float, payload: dict, state_id: str | None) -> float:
    """Apply the (at most one) operation found in ``payload`` to ``value``."""

    def operand(key: str) -> float:
        # Operands may themselves be payloads, so resolve them recursively.
        return get_value(payload[key], state_id=state_id)

    if "multiplier" in payload:
        return value * operand("multiplier")
    if "multiply" in payload:
        return value * operand("multiply")
    if "add" in payload:
        return value + operand("add")
    if "subtract" in payload:
        return value - operand("subtract")
    if "divide" in payload:
        divisor = operand("divide")
        # Guard on the divisor so a zero divisor yields 0.0 instead of raising
        # ZeroDivisionError.
        return value / divisor if divisor != 0 else 0.0
    if "power" in payload:
        return value ** operand("power")
    if "log" in payload:
        base = _payload_base(payload["log"])
        return np.log(value) / np.log(base) if value > 0 else 0.0
    if "exp" in payload:
        # Exponentiation is defined for every real exponent (base**0 == 1).
        return _payload_base(payload["exp"]) ** value
    if "abs" in payload:
        return abs(value)
    if "min" in payload:
        return min(value, operand("min"))
    if "max" in payload:
        return max(value, operand("max"))
    return value
def get_values(obj: MaybeVU) -> np.ndarray:
    """Return the magnitude(s) carried by ``obj`` as a NumPy array.

    Args:
        obj: A ``ValueWithUnit``, a ``StatefulValueWithUnit``, a plain
            ``float``/``int``, or ``None``.

    Returns:
        ``ValueWithUnit`` -> its ``value`` as an array of at least one
        dimension; number -> one-element float array;
        ``StatefulValueWithUnit`` -> its ``values`` as an array;
        ``None`` -> an empty array.

    Raises:
        TypeError: For any other input type.
    """
    if isinstance(obj, ValueWithUnit):
        # A scalar ``.value`` would otherwise become a 0-d array, unlike the
        # one-element array returned for plain numbers below.
        return np.atleast_1d(np.asarray(obj.value))
    elif isinstance(obj, (float, int)):
        return np.asarray([float(obj)])
    elif isinstance(obj, StatefulValueWithUnit):
        return np.asarray(obj.values)
    elif obj is None:
        return np.array([])
    else:
        raise TypeError("Unsupported type")


def _is_value_with_unit(val: Any) -> bool:
    """Return ``True`` for objects that look like ``ValueWithUnit`` containers."""
    return hasattr(val, "value") and hasattr(val, "unit")


def _is_stateful_value_payload(val: Any) -> bool:
    """Return ``True`` for dict-like stateful value payloads.

    A payload qualifies when every key is one of ``values``, ``state_ids``,
    ``weights`` or ``unit`` and at least one of the first three is present.
    """
    if not isinstance(val, dict):
        return False
    keys = set(val)
    return keys.issubset({"values", "state_ids", "weights", "unit"}) and (
        "values" in keys or "state_ids" in keys or "weights" in keys
    )


def get_state_index(
    state_ids: dict[str, int] | None,
    args: dict | None,
) -> Tuple[int, str | None]:
    """Resolve which state index ``args`` selects.

    Args:
        state_ids: Mapping ``{state_id: index}`` for the collection, or
            ``None``. When empty or ``None``, any ``state_id`` resolves to
            index 0 and any non-negative ``idx`` is accepted.
        args: Optional dict that may carry ``"state_id"`` and/or ``"idx"``.
            Non-dict values are treated as "nothing requested".

    Returns:
        ``(idx, state_id)``. ``state_id`` is the requested id as ``str`` when
        selection was by id; it is ``None`` when selection was by ``idx`` or
        when nothing was requested (in which case ``idx`` is 0).

    Raises:
        ValueError: If ``state_id`` is not a key of a non-empty ``state_ids``;
            if ``state_id`` and ``idx`` resolve to different indices; if
            ``idx`` (given without ``state_id``) is negative; or if ``idx`` is
            not among the values of a non-empty ``state_ids``.
    """
    sid = None if not isinstance(args, dict) else args.get("state_id")
    sid = None if sid is None else str(sid)
    raw_idx = None if not isinstance(args, dict) else args.get("idx")
    explicit_idx = None if raw_idx is None else int(raw_idx)
    lookup = {} if state_ids is None else state_ids

    if sid is not None:
        if lookup and sid not in lookup:
            raise ValueError(
                f"state_id {sid!r} was not found on this collection. "
                f"Available states: {', '.join(lookup)}."
            )
        resolved_idx = lookup.get(sid, 0)
        # An explicit idx is allowed alongside state_id only if they agree.
        if explicit_idx is not None and explicit_idx != resolved_idx:
            raise ValueError(
                f"state_id {sid!r} resolves to idx {resolved_idx}, "
                f"but idx {explicit_idx} was also provided."
            )
        return resolved_idx, sid

    if explicit_idx is not None:
        if explicit_idx < 0:
            raise ValueError("idx must be a non-negative integer.")
        if lookup and explicit_idx not in set(lookup.values()):
            raise ValueError(
                f"idx {explicit_idx} was not found on this collection. "
                f"Available indices: {', '.join(str(idx) for idx in lookup.values())}."
            )
        return explicit_idx, None

    return 0, None
def linear_interpolation(
    xi: float, x1: float, x2: float, y1: float, y2: float
) -> float:
    """Evaluate, at ``xi``, the straight line through ``(x1, y1)`` and ``(x2, y2)``.

    ``xi`` may lie outside ``[x1, x2]``, in which case the line is
    extrapolated.

    Raises:
        ValueError: If ``x1 == x2``, because the slope is undefined.
    """
    if x1 != x2:
        slope = (y1 - y2) / (x1 - x2)
        intercept = y1 - slope * x1
        return slope * xi + intercept
    raise ValueError(
        "Cannot perform interpolation when x1 == x2 (undefined slope)."
    )
def delta_with_zero_at_start(x: np.ndarray) -> np.ndarray:
    """Return ``delta_vals(x)`` (default direction) with a leading ``0.0``.

    For a 1-D ``x`` the result has the same length as ``x``, so entry ``i``
    pairs with ``x[i]``.
    """
    steps = delta_vals(x)
    return np.insert(steps, obj=0, values=0.0)
def delta_vals(x: np.ndarray, descending_vals: bool = True) -> np.ndarray:
    """Compute differences between successive entries in a column.

    Args:
        x: 1-D array-like of values. Plain lists are accepted as well as
            NumPy arrays.
        descending_vals: When ``True`` (default) return ``x[i] - x[i + 1]``;
            otherwise ``x[i + 1] - x[i]``.

    Returns:
        A new array with one entry fewer than ``x``. Differences whose
        magnitude is at most the module tolerance ``tol`` are set to ``0.0``,
        which suppresses floating-point noise.
    """
    # ``np.asarray`` lets list input work; slicing arithmetic on a plain list
    # raised TypeError. Arrays are used as-is (no copy).
    values = np.asarray(x)
    if descending_vals:
        deltas = values[:-1] - values[1:]
    else:
        deltas = values[1:] - values[:-1]
    deltas[np.abs(deltas) <= tol] = 0.0
    return deltas
def clean_composite_curve_ends(
    y_vals: np.ndarray | list, x_vals: np.ndarray | list
) -> Tuple[np.ndarray, np.ndarray]:
    """Trim redundant points at both ends of a composite curve.

    Leading points whose ``x`` matches ``x_vals[0]`` (within ``tol``) are
    dropped except the last of them. Trailing points whose ``x`` matches
    ``x_vals[-1]`` are dropped except the first of them.

    Args:
        y_vals: Curve ordinates, paired index-by-index with ``x_vals``.
        x_vals: Curve abscissae.

    Returns:
        ``(y_clean, x_clean)`` -- note the ``y``-first order. Both are empty
        when the curve is degenerate: all ``x`` are ~0, the ``x`` variance is
        below ``tol``, or no point differs from an endpoint by more than
        ``tol``.
    """
    y_vals = np.array(y_vals)
    x_vals = np.array(x_vals)
    empty = (np.array([]), np.array([]))

    if np.all(np.isclose(x_vals, 0.0, atol=tol)) or x_vals.var() < tol:
        return empty

    # Compare with an absolute tolerance only (rtol=0). np.isclose's default
    # rtol=1e-5 made "close" scale with |x|, so large x values could all be
    # treated as equal to an endpoint: real end segments were trimmed and an
    # empty flatnonzero result raised IndexError.
    leaves_start = np.flatnonzero(
        ~np.isclose(x_vals, x_vals[0], rtol=0.0, atol=tol)
    )
    leaves_end = np.flatnonzero(
        ~np.isclose(x_vals, x_vals[-1], rtol=0.0, atol=tol)
    )
    if leaves_start.size == 0 or leaves_end.size == 0:
        return empty

    # Keep the last point sitting at the starting x and the first point
    # sitting at the final x.
    start = leaves_start[0] - 1
    end = leaves_end[-1] + 1
    return y_vals[start : end + 1], x_vals[start : end + 1]
def clean_composite_curve(
    y_array: np.ndarray | list, x_array: np.ndarray | list
) -> Tuple[np.ndarray, np.ndarray]:
    """Remove redundant points in composite curves.

    The ends are first trimmed with ``clean_composite_curve_ends``. An interior
    point is then dropped when it lies on the straight line through its
    original neighbours (within ``tol`` in ``y``), or when all three points
    share the same ``x`` (a vertical run). Finally, an endpoint whose ``x``
    duplicates its neighbour's is removed.

    Args:
        y_array: Curve ordinates, paired index-by-index with ``x_array``.
        x_array: Curve abscissae.

    Returns:
        ``(y_clean, x_clean)`` as NumPy arrays -- note the ``y``-first order.
    """
    y_vals, x_vals = clean_composite_curve_ends(y_array, x_array)
    if len(x_vals) <= 2:
        return y_vals, x_vals

    x_clean, y_clean = [x_vals[0]], [y_vals[0]]
    for i in range(1, len(x_vals) - 1):
        # Neighbours are taken from the trimmed input, not from kept points.
        x1, x2, x3 = x_vals[i - 1], x_vals[i], x_vals[i + 1]
        y1, y2, y3 = y_vals[i - 1], y_vals[i], y_vals[i + 1]
        if x1 == x3:
            # Neighbours share an x: keep x2 only if it sticks out (a turning
            # point); if x1 == x2 == x3 it is an interior vertical point.
            if x1 != x2:
                x_clean.append(x2)
                y_clean.append(y2)
        else:
            # Keep x2 only if it deviates from the chord between neighbours.
            y_interp = y1 + (y3 - y1) * (x2 - x1) / (x3 - x1)
            if abs(y2 - y_interp) > tol:
                x_clean.append(x2)
                y_clean.append(y2)
    x_clean.append(x_vals[-1])
    y_clean.append(y_vals[-1])

    # Drop a leading point that duplicates the next x.
    if abs(x_clean[0] - x_clean[1]) < tol:
        x_clean.pop(0)
        y_clean.pop(0)
    # Drop a trailing point that duplicates the previous x. Guard the length:
    # with a single point left, x_clean[-2] would be that same point, and it
    # would always be removed, leaving an empty curve.
    if len(x_clean) >= 2 and abs(x_clean[-1] - x_clean[-2]) < tol:
        x_clean.pop()
        y_clean.pop()

    return np.asarray(y_clean), np.asarray(x_clean)
def graph_simple_cc_plot(Tc, Hc, Th, Hh):
    """Show a minimal Plotly figure of the cold and hot composite curves.

    Intended as a debugging aid. Enthalpy (``Hc``/``Hh``) goes on the x-axis
    and temperature (``Tc``/``Th``) on the y-axis. The figure is displayed
    with ``fig.show()`` and also returned.

    Raises:
        ImportError: If Plotly is not installed.
    """
    go = _require_plotly()
    fig = go.Figure()

    curves = (
        ("Cold composite", Hc, Tc),
        ("Hot composite", Hh, Th),
    )
    for label, enthalpy, temperature in curves:
        fig.add_trace(
            go.Scatter(x=enthalpy, y=temperature, mode="lines", name=label)
        )

    fig.update_layout(
        title="Balanced Composite Curves",
        xaxis_title="Enthalpy",
        yaxis_title="Temperature",
        template="plotly_white",
    )
    grid_style = {"showgrid": True, "gridcolor": "rgba(0, 0, 0, 0.15)"}
    fig.update_yaxes(**grid_style)
    fig.update_xaxes(**grid_style)
    fig.show()
    return fig
def interp_with_plateaus(
    h_vals: np.ndarray,
    t_vals: np.ndarray,
    targets: np.ndarray,
    side: str,
    tol: float = 1e-6,
) -> np.ndarray:
    """Interpolate ``t`` at ``targets`` along ``(h_vals, t_vals)``.

    Repeated ``h`` values (vertical segments) are nudged apart with
    ``make_monotonic`` before ``np.interp`` is called. ``side`` selects which
    end of each vertical segment a target lands on. A single-point curve
    returns its only temperature for every target.

    Raises:
        ValueError: If ``side`` is neither ``"left"`` nor ``"right"``.
    """
    if side not in {"left", "right"}:
        raise ValueError("side must be 'left' or 'right'")

    h_arr, t_arr, target_arr = (
        np.asarray(arr, dtype=float) for arr in (h_vals, t_vals, targets)
    )
    if h_arr.size != 1:
        return np.interp(target_arr, make_monotonic(h_arr, side, tol), t_arr)
    return np.full_like(target_arr, t_arr[0], dtype=float)
def make_monotonic(h_vals: np.ndarray, side: str, tol: float = 1e-6) -> np.ndarray:
    """Spread out runs of (near-)equal values so interpolation can use them.

    Consecutive entries whose gap is at most ``tol`` form a run. Inside each
    run of length ``L``, entries are shifted by multiples of ``tol / 2``:

    * ``side == "right"``: entry ``k`` (0-based) is decreased by
      ``(L - 1 - k) * tol / 2``, so the run's last entry is unchanged.
    * any other ``side``: entry ``k`` is increased by ``k * tol / 2``, so the
      run's first entry is unchanged.

    A ``float`` copy is returned; the input is never modified.
    """
    values = np.array(h_vals, dtype=float)
    count = values.size
    if count < 2:
        return values

    half_tol = tol * 0.5
    breaks = np.abs(np.diff(values)) > tol
    run_starts = np.flatnonzero(np.concatenate(([True], breaks)))
    run_sizes = np.diff(np.concatenate((run_starts, [count])))
    if (run_sizes == 1).all():
        return values

    size_at = np.repeat(run_sizes, run_sizes)
    position = np.arange(count) - np.repeat(run_starts, run_sizes)
    in_run = size_at > 1

    if side == "right":
        shift = (size_at[in_run] - 1 - position[in_run]) * half_tol
        values[in_run] -= shift
    else:
        shift = position[in_run] * half_tol
        values[in_run] += shift
    return values
[docs] def g_ineq_penalty( g: float | list | np.ndarray, *, eta: float = 0.01, rho: float = 10, form: str = "square", ) -> np.float64: """Return a penalty value for an inequality-constraint residual.""" g = np.asarray(g, dtype=float) if ( form.lower() == "square_root_smoothing" or form.lower() == "square root smoothing" ): p = 0.5 * rho * (g + ((g) ** 2 + (eta) ** 2) ** 0.5) elif form.lower() == "square": p = rho * (g**2) else: raise ValueError("Unrecognised penalty function form selection.") if isinstance(p, float): return np.float64(p) elif isinstance(p, np.ndarray): return p.sum() else: raise ValueError( "Return of the penalty function failed due to unrecognised type." )