diff --git a/sentience/agent.py b/sentience/agent.py
index eb42a13..632c282 100644
--- a/sentience/agent.py
+++ b/sentience/agent.py
@@ -188,6 +188,25 @@ def _compute_hash(self, text: str) -> str:
         """Compute SHA256 hash of text."""
         return hashlib.sha256(text.encode("utf-8")).hexdigest()
 
+    async def _best_effort_post_snapshot_digest_async(self, goal: str) -> str | None:
+        """
+        Best-effort post-action snapshot digest for tracing (async).
+        """
+        try:
+            snap_opts = SnapshotOptions(
+                limit=min(10, self.default_snapshot_limit),
+                goal=f"{goal} (post)",
+            )
+            snap_opts.screenshot = False
+            snap_opts.show_overlay = self.config.show_overlay if self.config else None
+            post_snap = await snapshot_async(self.browser, snap_opts)
+            if post_snap.status != "success":
+                return None
+            digest_input = f"{post_snap.url}{post_snap.timestamp}"
+            return f"sha256:{self._compute_hash(digest_input)}"
+        except Exception:
+            return None
+
     def _best_effort_post_snapshot_digest(self, goal: str) -> str | None:
         """
         Best-effort post-action snapshot digest for tracing.
diff --git a/sentience/agent_runtime.py b/sentience/agent_runtime.py
index b64fb12..27cd7ff 100644
--- a/sentience/agent_runtime.py
+++ b/sentience/agent_runtime.py
@@ -358,6 +358,50 @@ async def snapshot(self, **kwargs: Any) -> Snapshot:
         await self._handle_captcha_if_needed(self.last_snapshot, source="gateway")
         return self.last_snapshot
 
+    async def sampled_snapshot(
+        self,
+        *,
+        samples: int = 4,
+        scroll_delta_y: float | None = None,
+        settle_ms: int = 250,
+        union_limit: int | None = None,
+        restore_scroll: bool = True,
+        **kwargs: Any,
+    ) -> Snapshot:
+        """
+        Take multiple snapshots while scrolling and merge them into a "union snapshot".
+
+        Intended for analysis/extraction on long or virtualized pages where a single
+        viewport snapshot is insufficient.
+
+        IMPORTANT:
+        - The returned snapshot's element bboxes may not correspond to the current viewport.
+          Do NOT use it for clicking unless you also scroll to the right position.
+        - This method does NOT update `self.last_snapshot` (to avoid confusing verification
+          loops that depend on the current viewport).
+        """
+        # Legacy browser path: fall back to a single snapshot (we can't rely on backend ops).
+        if hasattr(self, "_legacy_browser") and hasattr(self, "_legacy_page"):
+            return await self.snapshot(**kwargs)
+
+        from .backends.snapshot import sampled_snapshot as backend_sampled_snapshot
+
+        # Merge default options with call-specific kwargs.
+        options_dict = self._snapshot_options.model_dump(exclude_none=True)
+        options_dict.update(kwargs)
+        options = SnapshotOptions(**options_dict)
+
+        snap = await backend_sampled_snapshot(
+            self.backend,
+            options=options,
+            samples=samples,
+            scroll_delta_y=scroll_delta_y,
+            settle_ms=settle_ms,
+            union_limit=union_limit,
+            restore_scroll=restore_scroll,
+        )
+        return snap
+
     async def evaluate_js(self, request: EvaluateJsRequest) -> EvaluateJsResult:
         """
         Evaluate JavaScript expression in the active backend.
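A minimal usage sketch for the new `AgentRuntime.sampled_snapshot` method (the `runtime` object and the element fields are assumed from the surrounding codebase; only the signature and semantics come from the hunk above):

    # Union-snapshot a long, virtualized results page for extraction.
    # Per the docstring, the merged snapshot is NOT safe for direct clicking.
    union = await runtime.sampled_snapshot(
        samples=5,            # five viewport snapshots while scrolling down
        settle_ms=300,        # let lazy-loaded content settle after each scroll
        union_limit=200,      # cap the merged element list
        restore_scroll=True,  # return to the original scroll position afterwards
    )
    hrefs = [el.href for el in union.elements if el.href]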
diff --git a/sentience/backends/snapshot.py b/sentience/backends/snapshot.py
index b1cda88..b09b1cb 100644
--- a/sentience/backends/snapshot.py
+++ b/sentience/backends/snapshot.py
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any
 
 from ..constants import SENTIENCE_API_URL
-from ..models import Snapshot, SnapshotOptions
+from ..models import Element, Snapshot, SnapshotOptions
 from ..snapshot import (
     _build_snapshot_payload,
     _merge_api_result_with_local,
@@ -259,6 +259,182 @@ async def snapshot(
     return await _snapshot_via_extension(backend, options)
 
 
+def _normalize_ws(text: str) -> str:
+    return " ".join((text or "").split()).strip()
+
+
+def _dedupe_key(el: Element) -> tuple:
+    """
+    Best-effort stable dedupe key across scroll-sampled snapshots.
+
+    Notes:
+    - IDs are not reliable across snapshots (virtualization can remount nodes).
+    - BBox coordinates are viewport-relative and depend on scroll position.
+    - Prefer href/name/text + approximate document position when available.
+    """
+    href = (el.href or "").strip()
+    if href:
+        return ("href", href)
+
+    name = _normalize_ws(el.name or "")
+    if name:
+        return ("role_name", el.role, name)
+
+    text = _normalize_ws(el.text or "")
+    doc_y = el.doc_y
+    if text:
+        # Use doc_y when present (more stable across scroll positions than bbox.y).
+        if isinstance(doc_y, (int, float)):
+            return ("role_text_docy", el.role, text[:120], int(float(doc_y) // 10))
+        return ("role_text", el.role, text[:120])
+
+    # Fallback: role + approximate position
+    if isinstance(doc_y, (int, float)):
+        return ("role_docy", el.role, int(float(doc_y) // 10))
+
+    # Last resort (can still dedupe within a single snapshot)
+    return ("id", int(el.id))
+
+
+def merge_snapshots(
+    snaps: list[Snapshot],
+    *,
+    union_limit: int | None = None,
+) -> Snapshot:
+    """
+    Merge multiple snapshots into a single "union snapshot" for analysis/extraction.
+
+    CRITICAL:
+    - Element bboxes are viewport-relative to the scroll position at the time each snapshot
+      was taken. Do NOT use merged elements for direct clicking unless you also scroll
+      back to their position.
+    """
+    if not snaps:
+        raise ValueError("merge_snapshots requires at least one snapshot")
+
+    base = snaps[0]
+    best_by_key: dict[tuple, Element] = {}
+    first_seen_idx: dict[tuple, int] = {}
+
+    # Keep the "best" representative per key:
+    # - Prefer higher importance (usually means in-viewport at that sampling moment)
+    # - Prefer having href/text/name (more useful for extraction)
+    def _quality_score(e: Element) -> tuple:
+        has_href = 1 if (e.href or "").strip() else 0
+        has_text = 1 if _normalize_ws(e.text or "") else 0
+        has_name = 1 if _normalize_ws(e.name or "") else 0
+        has_docy = 1 if isinstance(e.doc_y, (int, float)) else 0
+        return (e.importance, has_href, has_text, has_name, has_docy)
+
+    idx = 0
+    for snap in snaps:
+        for el in list(getattr(snap, "elements", []) or []):
+            k = _dedupe_key(el)
+            if k not in first_seen_idx:
+                first_seen_idx[k] = idx
+            prev = best_by_key.get(k)
+            if prev is None or _quality_score(el) > _quality_score(prev):
+                best_by_key[k] = el
+            idx += 1
+
+    merged: list[Element] = list(best_by_key.values())
+
+    # Deterministic ordering: prefer document order when doc_y is available,
+    # then fall back to "first seen" (stable for a given sampling sequence).
+    def _sort_key(e: Element) -> tuple:
+        doc_y = e.doc_y
+        if isinstance(doc_y, (int, float)):
+            return (0, float(doc_y), -int(e.importance))
+        return (1, float("inf"), first_seen_idx.get(_dedupe_key(e), 10**9))
+
+    merged.sort(key=_sort_key)
+
+    if union_limit is not None:
+        try:
+            lim = max(1, int(union_limit))
+        except (TypeError, ValueError):
+            lim = None
+        if lim is not None:
+            merged = merged[:lim]
+
+    # Construct a new Snapshot object with merged elements.
+    # Keep base url/viewport/diagnostics, and drop the screenshot by default
+    # to avoid confusion.
+    data = base.model_dump()
+    data["elements"] = [e.model_dump() for e in merged]
+    data["screenshot"] = None
+    return Snapshot(**data)
+
+
+async def sampled_snapshot(
+    backend: "BrowserBackend",
+    *,
+    options: SnapshotOptions | None = None,
+    samples: int = 4,
+    scroll_delta_y: float | None = None,
+    settle_ms: int = 250,
+    union_limit: int | None = None,
+    restore_scroll: bool = True,
+) -> Snapshot:
+    """
+    Take multiple snapshots while scrolling downward and return a merged union snapshot.
+
+    Designed for long or virtualized results pages where a single viewport snapshot
+    cannot cover enough relevant items.
+    """
+    if options is None:
+        options = SnapshotOptions()
+
+    k = max(1, int(samples))
+    if k <= 1:
+        return await snapshot(backend, options=options)
+
+    # Baseline scroll position
+    try:
+        info = await backend.refresh_page_info()
+        base_scroll_y = float(getattr(info, "scroll_y", 0.0) or 0.0)
+        vh = float(getattr(info, "height", 800) or 800)
+    except Exception:  # pylint: disable=broad-exception-caught
+        base_scroll_y = 0.0
+        vh = 800.0
+
+    # Choose a conservative scroll delta if not provided.
+    delta = float(scroll_delta_y) if scroll_delta_y is not None else (vh * 0.9)
+    if delta <= 0:
+        delta = max(200.0, vh * 0.9)
+
+    snaps: list[Snapshot] = []
+    try:
+        # Snapshot at the current position.
+        snaps.append(await snapshot(backend, options=options))
+
+        for _i in range(1, k):
+            try:
+                # Scroll by wheel delta (plays nicer with sites that hook scroll events).
+                await backend.wheel(delta_y=delta)
+            except Exception:  # pylint: disable=broad-exception-caught
+                # Fallback: direct scrollTo
+                try:
+                    cur = await backend.eval("window.scrollY")
+                    await backend.call("(y) => window.scrollTo(0, y)", [float(cur) + delta])
+                except Exception:  # pylint: disable=broad-exception-caught
+                    break
+
+            if settle_ms > 0:
+                await asyncio.sleep(float(settle_ms) / 1000.0)
+
+            snaps.append(await snapshot(backend, options=options))
+    finally:
+        if restore_scroll:
+            try:
+                await backend.call("(y) => window.scrollTo(0, y)", [float(base_scroll_y)])
+                if settle_ms > 0:
+                    await asyncio.sleep(min(0.2, float(settle_ms) / 1000.0))
+            except Exception:  # pylint: disable=broad-exception-caught
+                pass
+
+    return merge_snapshots(snaps, union_limit=union_limit)
+
+
 async def _wait_for_extension(
     backend: "BrowserBackend",
     timeout_ms: int = 5000,
@@ -273,7 +449,6 @@ async def _wait_for_extension(
     Raises:
         RuntimeError: If extension not injected within timeout
     """
-    import asyncio
     import logging
 
     logger = logging.getLogger("sentience.backends.snapshot")
@@ -446,6 +621,15 @@ async def _snapshot_via_api(
         # Re-raise validation errors as-is
         raise
     except Exception as e:
+        # Preserve structured gateway details when available. The isinstance check
+        # sits outside the import guard so the re-raise is not swallowed by it.
+        try:
+            from ..snapshot import SnapshotGatewayError  # type: ignore
+        except Exception:  # circular-import guard
+            SnapshotGatewayError = None  # type: ignore[assignment]
+        if SnapshotGatewayError is not None and isinstance(e, SnapshotGatewayError):
+            raise
+
         # Fallback to local extension on API error
         # This matches the behavior of the main snapshot function
         raise RuntimeError(
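To make the dedupe concrete, here is how two samples of the same anchor collapse under `_dedupe_key` and `_quality_score` (the field values are illustrative, not from a real snapshot):

    # Sample 1 (in viewport):    href="/item/42", importance=9,  bbox.y=120
    # Sample 2 (scrolled past):  href="/item/42", importance=2,  bbox.y=-680
    #
    # Both produce the key ("href", "/item/42"), so merge_snapshots keeps exactly
    # one Element: the sample-1 copy, because _quality_score orders by importance
    # first. Elements without an href fall back to role+name, then role+text with
    # doc_y bucketed into 10px bands, and finally the raw element id.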
diff --git a/sentience/cloud_tracing.py b/sentience/cloud_tracing.py
index 8f1e9bc..a707c63 100644
--- a/sentience/cloud_tracing.py
+++ b/sentience/cloud_tracing.py
@@ -581,6 +581,56 @@ def _complete_trace(self) -> None:
             if self.logger:
                 self.logger.warning(f"Error reporting trace completion: {e}")
 
+    def _normalize_screenshot_data(
+        self, screenshot_raw: str, default_format: str = "jpeg"
+    ) -> tuple[str, str]:
+        """
+        Normalize screenshot data by extracting base64 from a data URL if needed.
+
+        Handles both formats:
+        - Data URL: "data:image/jpeg;base64,/9j/4AAQ..."
+        - Pure base64: "/9j/4AAQ..."
+
+        Args:
+            screenshot_raw: Raw screenshot data (data URL or base64)
+            default_format: Default format if not detected from data URL
+
+        Returns:
+            Tuple of (base64_string, format_string)
+        """
+        if not screenshot_raw:
+            return "", default_format
+
+        # Check if it's a data URL
+        if screenshot_raw.startswith("data:image"):
+            # Extract format from "data:image/jpeg;base64,..." or "data:image/png;base64,..."
+            try:
+                # Split on the comma to get the base64 part
+                if "," in screenshot_raw:
+                    header, base64_data = screenshot_raw.split(",", 1)
+                    # Extract the format from the header: "data:image/jpeg;base64"
+                    if "/" in header and ";" in header:
+                        format_part = header.split("/")[1].split(";")[0]
+                        if format_part in ("jpeg", "jpg"):
+                            return base64_data, "jpeg"
+                        elif format_part == "png":
+                            return base64_data, "png"
+                    return base64_data, default_format
+                else:
+                    # Malformed data URL - return as-is with a warning
+                    if self.logger:
+                        self.logger.warning(
+                            "Malformed data URL in screenshot_base64 (missing comma)"
+                        )
+                    return screenshot_raw, default_format
+            except Exception as e:
+                if self.logger:
+                    self.logger.warning(f"Error parsing screenshot data URL: {e}")
+                return screenshot_raw, default_format
+
+        # Already pure base64
+        return screenshot_raw, default_format
+
     def _extract_screenshots_from_trace(self) -> dict[int, dict[str, Any]]:
         """
         Extract screenshots from trace events.
@@ -604,15 +654,22 @@ def _extract_screenshots_from_trace(self) -> dict[int, dict[str, Any]]:
             # Check if this is a snapshot event with screenshot
             if event.get("type") == "snapshot":
                 data = event.get("data", {})
-                screenshot_base64 = data.get("screenshot_base64")
-
-                if screenshot_base64:
-                    sequence += 1
-                    screenshots[sequence] = {
-                        "base64": screenshot_base64,
-                        "format": data.get("screenshot_format", "jpeg"),
-                        "step_id": event.get("step_id"),
-                    }
+                screenshot_raw = data.get("screenshot_base64")
+
+                if screenshot_raw:
+                    # Normalize: extract base64 from a data URL if needed.
+                    # Handles both "data:image/jpeg;base64,..." and pure base64.
+                    screenshot_base64, screenshot_format = self._normalize_screenshot_data(
+                        screenshot_raw,
+                        data.get("screenshot_format", "jpeg"),
+                    )
+                    if screenshot_base64:
+                        sequence += 1
+                        screenshots[sequence] = {
+                            "base64": screenshot_base64,
+                            "format": screenshot_format,
+                            "step_id": event.get("step_id"),
+                        }
         except Exception as e:
             if self.logger:
                 self.logger.error(f"Error extracting screenshots: {e}")
@@ -755,10 +812,29 @@ def upload_one(seq: int, url: str) -> bool:
            try:
                screenshot_data = screenshots[seq]
                base64_str = screenshot_data["base64"]
-               format_str = screenshot_data.get("format", "jpeg")
+               format_str = str(screenshot_data.get("format", "jpeg") or "jpeg").lower()
+               content_type = "image/jpeg"
 
                # Decode base64 to image bytes
                image_bytes = base64.b64decode(base64_str)
+               if format_str not in ("jpeg", "jpg"):
+                   # Convert to JPEG to match the presigned content-type.
+                   try:
+                       from io import BytesIO
+
+                       from PIL import Image
+
+                       with Image.open(BytesIO(image_bytes)) as img:
+                           rgb = img.convert("RGB")
+                           out = BytesIO()
+                           rgb.save(out, format="JPEG", quality=80)
+                           image_bytes = out.getvalue()
+                           format_str = "jpeg"
+                   except Exception as e:
+                       if self.logger:
+                           self.logger.warning(
+                               f"Screenshot {seq} format '{format_str}' could not be converted to JPEG: {e}"
+                           )
                image_size = len(image_bytes)
 
                # Update total size
@@ -769,7 +845,7 @@ def upload_one(seq: int, url: str) -> bool:
                    url,
                    data=image_bytes,  # Binary image data
                    headers={
-                       "Content-Type": f"image/{format_str}",
+                       "Content-Type": content_type,
                    },
                    timeout=30,  # 30 second timeout per screenshot
                )
diff --git a/sentience/llm_provider.py b/sentience/llm_provider.py
index cb25603..21db342 100644
--- a/sentience/llm_provider.py
+++ b/sentience/llm_provider.py
@@ -353,6 +353,21 @@ def __init__(
         )
         super().__init__(api_key=api_key, model=model, base_url=base_url)
 
+    def supports_vision(self) -> bool:
+        """
+        DeepInfra hosts many non-OpenAI multimodal models.
+
+        Their OpenAI-compatible API supports the same `image_url` message format:
+        `{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}`
+
+        We therefore allow vision for common DeepInfra model naming patterns.
+        """
+        model_lower = self._model_name.lower()
+        if any(x in model_lower for x in ["vision", "llava", "qvq", "ocr"]):
+            return True
+        # Preserve OpenAI-style vision detection for GPT models served via DeepInfra.
+        return super().supports_vision()
+
 
 class AnthropicProvider(LLMProvider):
     """
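Because the presigned upload now always sends `Content-Type: image/jpeg`, non-JPEG screenshots are transcoded before upload. A standalone sketch of that conversion path, using the same Pillow calls as the hunk above (the helper name is illustrative):

    from io import BytesIO

    from PIL import Image

    def to_jpeg(image_bytes: bytes, quality: int = 80) -> bytes:
        # Decode any Pillow-readable image (e.g. PNG) and re-encode as RGB JPEG.
        with Image.open(BytesIO(image_bytes)) as img:
            out = BytesIO()
            img.convert("RGB").save(out, format="JPEG", quality=quality)
            return out.getvalue()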
diff --git a/sentience/snapshot.py b/sentience/snapshot.py
index 277a85e..e0393fd 100644
--- a/sentience/snapshot.py
+++ b/sentience/snapshot.py
@@ -20,6 +20,161 @@
 
 MAX_PAYLOAD_BYTES = 10 * 1024 * 1024
 
 
+class SnapshotGatewayError(RuntimeError):
+    """
+    Structured error for server-side (gateway) snapshot failures.
+
+    Keeps HTTP status/URL/response details available to callers for better
+    logging/debugging. Subclasses RuntimeError for backward compatibility.
+    """
+
+    def __init__(
+        self,
+        message: str,
+        *,
+        status_code: int | None = None,
+        url: str | None = None,
+        request_id: str | None = None,
+        response_text: str | None = None,
+        cause: Exception | None = None,
+    ) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+        self.url = url
+        self.request_id = request_id
+        self.response_text = response_text
+        # Note: callers should use `raise ... from cause` to preserve chaining.
+        _ = cause
+
+    @staticmethod
+    def _snip(s: str | None, n: int = 400) -> str | None:
+        if not s:
+            return None
+        t = str(s).replace("\n", " ").replace("\r", " ").strip()
+        return t[:n]
+
+    @classmethod
+    def from_httpx(cls, e: Exception) -> "SnapshotGatewayError":
+        status_code = None
+        url = None
+        request_id = None
+        body = None
+        try:
+            resp = getattr(e, "response", None)
+            if resp is not None:
+                status_code = getattr(resp, "status_code", None)
+                try:
+                    url = str(getattr(resp, "url", None) or "")
+                except Exception:
+                    url = None
+                try:
+                    headers = getattr(resp, "headers", None) or {}
+                    request_id = headers.get("x-request-id") or headers.get("x-trace-id")
+                except Exception:
+                    request_id = None
+                try:
+                    body = getattr(resp, "text", None)
+                except Exception:
+                    body = None
+            req = getattr(e, "request", None)
+            if url is None and req is not None:
+                try:
+                    url = str(getattr(req, "url", None) or "")
+                except Exception:
+                    url = None
+        except Exception:
+            pass
+
+        msg = "Server-side snapshot API failed"
+        bits = []
+        if status_code is not None:
+            bits.append(f"status={status_code}")
+        if url:
+            bits.append(f"url={url}")
+        if request_id:
+            bits.append(f"request_id={request_id}")
+        body_snip = cls._snip(body)
+        if body_snip:
+            bits.append(f"body={body_snip}")
+        # If we don't have an HTTP status/response body, this is usually a transport
+        # error (timeout, DNS, connection reset). Preserve at least the exception
+        # type and message.
+        if status_code is None and not body_snip:
+            try:
+                err_s = cls._snip(str(e), 220)
+            except Exception:
+                err_s = None
+            bits.append(f"err_type={type(e).__name__}")
+            if err_s:
+                bits.append(f"err={err_s}")
+        if bits:
+            msg = f"{msg}: " + " ".join(bits)
+        msg = msg + ". Try using use_api=False to use local extension instead."
+        return cls(
+            msg,
+            status_code=int(status_code) if status_code is not None else None,
+            url=url,
+            request_id=str(request_id) if request_id else None,
+            response_text=body_snip,
+            cause=e,
+        )
+
+    @classmethod
+    def from_requests(cls, e: Exception) -> "SnapshotGatewayError":
+        status_code = None
+        url = None
+        request_id = None
+        body = None
+        try:
+            resp = getattr(e, "response", None)
+            if resp is not None:
+                status_code = getattr(resp, "status_code", None)
+                try:
+                    url = str(getattr(resp, "url", None) or "")
+                except Exception:
+                    url = None
+                try:
+                    headers = getattr(resp, "headers", None) or {}
+                    request_id = headers.get("x-request-id") or headers.get("x-trace-id")
+                except Exception:
+                    request_id = None
+                try:
+                    body = getattr(resp, "text", None)
+                except Exception:
+                    body = None
+        except Exception:
+            pass
+        msg = "Server-side snapshot API failed"
+        bits = []
+        if status_code is not None:
+            bits.append(f"status={status_code}")
+        if url:
+            bits.append(f"url={url}")
+        if request_id:
+            bits.append(f"request_id={request_id}")
+        body_snip = cls._snip(body)
+        if body_snip:
+            bits.append(f"body={body_snip}")
+        if status_code is None and not body_snip:
+            try:
+                err_s = cls._snip(str(e), 220)
+            except Exception:
+                err_s = None
+            bits.append(f"err_type={type(e).__name__}")
+            if err_s:
+                bits.append(f"err={err_s}")
+        if bits:
+            msg = f"{msg}: " + " ".join(bits)
+        msg = msg + ". Try using use_api=False to use local extension instead."
+        return cls(
+            msg,
+            status_code=int(status_code) if status_code is not None else None,
+            url=url,
+            request_id=str(request_id) if request_id else None,
+            response_text=body_snip,
+            cause=e,
+        )
+
+
 def _is_execution_context_destroyed_error(e: Exception) -> bool:
     """
     Playwright can throw while a navigation is in-flight, invalidating the JS execution context.
@@ -170,14 +325,18 @@ def _post_snapshot_to_gateway_sync(
         "Content-Type": "application/json",
     }
 
-    response = requests.post(
-        f"{api_url}/v1/snapshot",
-        data=payload_json,
-        headers=headers,
-        timeout=30,
-    )
-    response.raise_for_status()
-    return response.json()
+    try:
+        response = requests.post(
+            f"{api_url}/v1/snapshot",
+            data=payload_json,
+            headers=headers,
+            timeout=30,
+        )
+        response.raise_for_status()
+        return response.json()
+    except requests.exceptions.RequestException as e:
+        # Covers HTTPError as well as network errors/timeouts (no status code).
+        raise SnapshotGatewayError.from_requests(e) from e
 
 
 async def _post_snapshot_to_gateway_async(
@@ -202,13 +363,18 @@ async def _post_snapshot_to_gateway_async(
     }
 
     async with httpx.AsyncClient(timeout=30.0) as client:
-        response = await client.post(
-            f"{api_url}/v1/snapshot",
-            content=payload_json,
-            headers=headers,
-        )
-        response.raise_for_status()
-        return response.json()
+        try:
+            response = await client.post(
+                f"{api_url}/v1/snapshot",
+                content=payload_json,
+                headers=headers,
+            )
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            # HTTP status errors, transport errors (timeout/DNS), and JSON decode
+            # issues alike: keep whatever details are available.
+            raise SnapshotGatewayError.from_httpx(e) from e
 
 
 def _merge_api_result_with_local(
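A caller-side sketch of consuming the structured error (the `snapshot_sync` call site and the `use_api` flag are assumptions inferred from the error message above; the attributes come from `SnapshotGatewayError`):

    import logging

    from sentience.snapshot import SnapshotGatewayError

    log = logging.getLogger(__name__)

    try:
        snap = snapshot_sync(browser, options)  # hypothetical call site
    except SnapshotGatewayError as e:
        log.warning(
            "gateway snapshot failed: status=%s request_id=%s url=%s",
            e.status_code, e.request_id, e.url,
        )
        # Fall back to the local extension, as the error message suggests.
        snap = snapshot_sync(browser, options, use_api=False)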
diff --git a/tests/test_cloud_tracing.py b/tests/test_cloud_tracing.py
index 31888f0..5b343cb 100644
--- a/tests/test_cloud_tracing.py
+++ b/tests/test_cloud_tracing.py
@@ -394,6 +394,90 @@ def post_side_effect(*args, **kwargs):
         if cleaned_trace_path.exists():
             os.remove(cleaned_trace_path)
 
+    def test_normalize_screenshot_data_handles_data_url(self):
+        """Test that _normalize_screenshot_data extracts base64 from data URLs."""
+        upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz"
+        run_id = f"test-run-{uuid.uuid4().hex[:8]}"
+
+        sink = CloudTraceSink(upload_url, run_id=run_id)
+
+        try:
+            # Test a JPEG data URL
+            jpeg_data_url = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
+            base64_str, fmt = sink._normalize_screenshot_data(jpeg_data_url)
+            assert base64_str == "/9j/4AAQSkZJRg..."
+            assert fmt == "jpeg"
+
+            # Test a PNG data URL
+            png_data_url = "data:image/png;base64,iVBORw0KGgoAAAA..."
+            base64_str, fmt = sink._normalize_screenshot_data(png_data_url)
+            assert base64_str == "iVBORw0KGgoAAAA..."
+            assert fmt == "png"
+
+            # Test pure base64 (should pass through unchanged)
+            pure_base64 = "/9j/4AAQSkZJRg..."
+            base64_str, fmt = sink._normalize_screenshot_data(pure_base64, "jpeg")
+            assert base64_str == "/9j/4AAQSkZJRg..."
+            assert fmt == "jpeg"
+
+            # Test the empty string
+            base64_str, fmt = sink._normalize_screenshot_data("")
+            assert base64_str == ""
+            assert fmt == "jpeg"
+        finally:
+            # Close the sink to release the file handle (required on Windows)
+            sink.close()
+
+    def test_cloud_trace_sink_handles_data_url_in_screenshot(self):
+        """Test that CloudTraceSink properly extracts screenshots from data URLs."""
+        upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz"
+        run_id = f"test-run-{uuid.uuid4().hex[:8]}"
+        api_key = "sk_test_123"
+
+        # Create a test screenshot as a data URL (how langchain-debugging was sending it)
+        test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+        data_url = f"data:image/png;base64,{test_image_base64}"
+
+        sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key)
+
+        # Emit a trace event with the screenshot as a data URL (not pure base64)
+        sink.emit(
+            {
+                "v": 1,
+                "type": "snapshot",
+                "ts": "2026-01-01T00:00:00.000Z",
+                "run_id": run_id,
+                "seq": 1,
+                "step_id": "step-1",
+                "data": {
+                    "url": "https://example.com",
+                    "element_count": 10,
+                    "screenshot_base64": data_url,  # Data URL, not pure base64
+                    "screenshot_format": "png",
+                },
+            }
+        )
+
+        # Extract screenshots - should normalize the data URL to pure base64
+        screenshots = sink._extract_screenshots_from_trace()
+
+        assert len(screenshots) == 1
+        assert 1 in screenshots
+        # Verify the base64 was extracted from the data URL (no "data:image" prefix)
+        assert screenshots[1]["base64"] == test_image_base64
+        assert not screenshots[1]["base64"].startswith("data:")
+        assert screenshots[1]["format"] == "png"
+
+        # Cleanup
+        sink.close()
+        cache_dir = Path.home() / ".sentience" / "traces" / "pending"
+        trace_path = cache_dir / f"{run_id}.jsonl"
+        cleaned_trace_path = cache_dir / f"{run_id}.cleaned.jsonl"
+        if trace_path.exists():
+            os.remove(trace_path)
+        if cleaned_trace_path.exists():
+            os.remove(cleaned_trace_path)
+
 
 class TestTracerFactory:
     """Test create_tracer factory function."""
diff --git a/tests/test_llm_provider_vision.py b/tests/test_llm_provider_vision.py
new file mode 100644
index 0000000..3bc9fe9
--- /dev/null
+++ b/tests/test_llm_provider_vision.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+import pytest
+
+
+def test_deepinfra_provider_supports_vision_for_common_multimodal_names() -> None:
+    pytest.importorskip("openai")
+
+    from sentience.llm_provider import DeepInfraProvider
+
+    p1 = DeepInfraProvider(api_key="x", model="meta-llama/Llama-3.2-11B-Vision-Instruct")
+    assert p1.supports_vision() is True
+
+    p2 = DeepInfraProvider(api_key="x", model="deepseek-ai/DeepSeek-OCR")
+    assert p2.supports_vision() is True
+
+    p3 = DeepInfraProvider(api_key="x", model="deepseek-ai/DeepSeek-V3.1")
+    assert p3.supports_vision() is False
+
diff --git a/tests/test_video_recording.py b/tests/test_video_recording.py
index 9a069da..b164dbf 100644
--- a/tests/test_video_recording.py
+++ b/tests/test_video_recording.py
@@ -10,6 +10,10 @@
 
 from sentience import SentienceBrowser
 
+# Use a data URL to avoid network dependency (DNS resolution can fail in CI).
+# This is a minimal valid HTML page.
+TEST_PAGE_URL = "data:text/html,