diff --git a/examples/gds-example.ipynb b/examples/gds-example.ipynb index 57604af..3b032d6 100644 --- a/examples/gds-example.ipynb +++ b/examples/gds-example.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "source": [ "# Visualizing Neo4j Graph Data Science (GDS) Graphs" @@ -10,6 +11,7 @@ { "cell_type": "code", "execution_count": null, + "id": "acae54e37e7d407bbb7b55eff062a284", "metadata": {}, "outputs": [], "source": [ @@ -17,47 +19,103 @@ "%pip install matplotlib" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" + ] + }, { "cell_type": "markdown", + "id": "8dd0d8092fe74a7c96281538738b07e2", "metadata": {}, "source": [ - "## Setup GDS graph" + "## Setup GDS graph\n", + "\n", + "To use GDS, you can either use GDS as a plugin or Aura Graph Analytics.\n", + "In the following, you can choose:\n", + "\n", + " * Provide Aura API credentials and use Aura Graph Analytics.\n", + " * Use Neo4j + GDS Plugin.\n", + "\n", + "For more information, see the [GDS documentation](https://neo4j.com/docs/graph-data-science/current/installation/)." 
] }, { "cell_type": "code", "execution_count": null, + "id": "72eea5119410473aa328ad9291626812", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", + "from graphdatascience.session import (\n", + " GdsSessions,\n", + " DbmsConnectionInfo,\n", + " AuraAPICredentials,\n", + " SessionMemory,\n", + ")\n", "from graphdatascience import GraphDataScience\n", "\n", "# Get Neo4j DB URI, credentials and name from environment if applicable\n", - "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n", - "NEO4J_AUTH = (\"neo4j\", None)\n", - "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n", - "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n", - " NEO4J_AUTH = (\n", - " os.environ.get(\"NEO4J_USER\"),\n", - " os.environ.get(\"NEO4J_PASSWORD\"),\n", + "db_connection = DbmsConnectionInfo(\n", + " aura_instance_id=os.environ.get(\"AURA_INSTANCEID\"),\n", + " username=os.environ[\"NEO4J_USERNAME\"],\n", + " password=os.environ[\"NEO4J_PASSWORD\"],\n", + " uri=os.environ[\"NEO4J_URI\"],\n", + ")\n", + "\n", + "session_name = \"neo4j-viz-gds-example\"\n", + "if os.environ.get(\"AURA_API_CLIENT_ID\"):\n", + " # Use Aura Graph Analytics\n", + " sessions = GdsSessions(\n", + " api_credentials=AuraAPICredentials(\n", + " client_id=os.environ[\"AURA_API_CLIENT_ID\"],\n", + " client_secret=os.environ[\"AURA_API_CLIENT_SECRET\"],\n", + " project_id=os.environ.get(\"AURA_API_PROJECT_ID\"),\n", + " )\n", + " )\n", + " gds = sessions.get_or_create(\n", + " session_name=session_name,\n", + " memory=SessionMemory.m_2GB,\n", + " db_connection=db_connection,\n", " )\n", - "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)" + "else:\n", + " # Use GDS Plugin\n", + " sessions = None\n", + " gds = GraphDataScience(\n", + " endpoint=db_connection.get_uri(),\n", + " auth=(db_connection.username, db_connection.password),\n", + " )" ] }, { "cell_type": "code", "execution_count": null, + "id": 
"8edb47106e1a46a883d545849b8ab81b", "metadata": {}, "outputs": [], "source": [ "G = gds.graph.load_cora(graph_name=\"cora\")" ] }, + { + "cell_type": "markdown", + "id": "10185d26023b46108eb7d9f57d49d2b3", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, + "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": {}, "outputs": [], "source": [ @@ -71,6 +129,7 @@ }, { "cell_type": "markdown", + "id": "7623eae2785240b9bd12b16a66d81610", "metadata": {}, "source": [ "## Visualization" @@ -79,1610 +138,32 @@ { "cell_type": "code", "execution_count": null, + "id": "7cdc8c89c7104fffa095e18ddfef8986", "metadata": {}, "outputs": [], "source": [ "from neo4j_viz.gds import from_gds\n", "\n", - "VG = from_gds(gds, G, max_node_count=500)\n", - "str(VG)" + "VG = from_gds(\n", + " gds,\n", + " G,\n", + " max_node_count=100,\n", + ")" ] }, { "cell_type": "code", - "execution_count": 41, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "id": "b118ea5561624da68c537baed56e602f", + "metadata": {}, + "outputs": [], "source": [ - "VG.render(theme=\"auto\")" + "VG.render()" ] }, { "cell_type": "markdown", + "id": "938c804e27f84196a10c8828c723f798", "metadata": {}, "source": [ "### Changing captions\n", @@ -1694,6 +175,7 @@ { "cell_type": "code", "execution_count": null, + "id": "504fb2a444614c0babb325280ed9130a", "metadata": {}, "outputs": [], "source": [ @@ -1704,9 +186,8 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "id": "59bbdb311c014d738909a11f9e486628", + "metadata": {}, "outputs": [], "source": [ "VG.render()" @@ -1714,6 +195,7 @@ }, { "cell_type": "markdown", + "id": "b43b363d81ae4b689946ece5c682cd59", "metadata": {}, "source": [ "## Sizing the nodes\n", @@ -1724,16 +206,17 @@ { "cell_type": "code", "execution_count": null, + "id": "8a65eabff63a45729fe45fb5ade58bdc", "metadata": {}, "outputs": [], "source": [ "VG.resize_nodes(property=\"pagerank\")\n", - "VG.color_nodes(property=\"componentId\")\n", "VG.render()" ] }, { "cell_type": "markdown", + "id": "c3933fab20d04ec698c2621248eb3be0", "metadata": {}, "source": [ "### Coloring" @@ -1741,6 +224,7 @@ }, { "cell_type": "markdown", + "id": "4dd4641cc4064e0191573fe9c69df29b", "metadata": {}, "source": [ "There are two main ways of coloring the nodes of a graph:\n", @@ -1754,17 +238,17 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "id": "8309879909854d7188b41380fd92a7c3", + "metadata": {}, "outputs": [], "source": [ - "VG.color_nodes(property=\"componentId\")\n", + "VG.color_nodes(property=\"subject\")\n", "VG.render()" ] }, { "cell_type": "markdown", + "id": "3ed186c9a28b402fb0bc4494df01f08d", "metadata": {}, "source": [ "Now, let us color by our continuous node field \"size\" that we computed 
above with PageRank, again using the default colors.\n", @@ -1775,9 +259,8 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "id": "cb1e1581032b452c9409d6c6813c49d1", + "metadata": {}, "outputs": [], "source": [ "from neo4j_viz.colors import ColorSpace\n", @@ -1788,6 +271,7 @@ }, { "cell_type": "markdown", + "id": "379cbbc1e968416e875cc15c1202d7eb", "metadata": {}, "source": [ "#### Custom coloring\n", @@ -1799,6 +283,7 @@ { "cell_type": "code", "execution_count": null, + "id": "277c27b1587741f2af2001be3712ef0d", "metadata": {}, "outputs": [], "source": [ @@ -1808,6 +293,7 @@ { "cell_type": "code", "execution_count": null, + "id": "db7b79bc585a40fcaf58bf750017e135", "metadata": {}, "outputs": [], "source": [ @@ -1825,1599 +311,17 @@ }, { "cell_type": "code", - "execution_count": 49, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "id": "916684f9a58a4a2aa5f864670399430d", + "metadata": {}, + "outputs": [], "source": [ "VG.render()" ] }, { "cell_type": "markdown", + "id": "1671c31a24314836a5b85d7ef7fbf015", "metadata": {}, "source": [ "### Render options\n", @@ -3433,1593 +337,10 @@ }, { "cell_type": "code", - "execution_count": 50, - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "id": "33b0902fd34d4ace834912fa1002cf8e", + "metadata": {}, + "outputs": [], "source": [ "from neo4j_viz import Layout\n", "\n", @@ -5028,6 +349,7 @@ }, { "cell_type": "markdown", + "id": "f6fa52606d8c4a75a9b52967216f8f3f", "metadata": {}, "source": [ "## Saving the visualization" @@ -5036,6 +358,7 @@ { "cell_type": "code", "execution_count": null, + "id": "f5a1fa73e5044315a093ec459c9be902", "metadata": {}, "outputs": [], "source": [ @@ -5050,6 +373,7 @@ }, { "cell_type": "markdown", + "id": "cdf66aed5cc84ca1b48e60bad68798a8", "metadata": {}, "source": [ "## Cleanup\n", @@ -5060,22 +384,26 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [ - "teardown" - ] - }, + "id": "28d3efd5258a48a79c179ea5c6759f01", + "metadata": {}, "outputs": [], "source": [ "gds.graph.drop(\"cora\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f9bc0b9dd2c44919cc8dcca39b469f8", + "metadata": {}, + "outputs": [], + "source": [ + "if sessions:\n", + " sessions.delete(session_name=session_name)" + ] } ], - "metadata": { - "language_info": { - "name": "python" - } - }, + "metadata": {}, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/examples/neo4j-example.ipynb b/examples/neo4j-example.ipynb index 898bb43..87e3f02 100644 --- a/examples/neo4j-example.ipynb +++ b/examples/neo4j-example.ipynb @@ -167,1613 +167,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "2322065c", "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from neo4j_viz.neo4j import from_neo4j\n", "\n", @@ -1794,1613 +195,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "e8b0f4c6", "metadata": { "tags": [ "preserve-output" ] }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "VG.render()" ] diff --git a/justfile b/justfile index ea83288..bc58784 100644 --- a/justfile +++ b/justfile @@ -18,13 +18,25 @@ py-test-gds: trap "cd $ENV_DIR && docker compose down" EXIT cd $ENV_DIR && docker compose up -d cd - + cd python-wrapper && \ NEO4J_URI=bolt://localhost:7687 \ NEO4J_USER=neo4j \ NEO4J_PASSWORD=password \ NEO4J_DB=neo4j \ - cd python-wrapper && uv run --group dev --extra gds pytest tests --include-neo4j-and-gds + uv run --group dev --extra gds pytest tests --include-neo4j-and-gds cd .. + +# this expects the local compose setup to be running. +py-test-gds-sessions filter="": + #!/usr/bin/env bash + cd python-wrapper && \ + GDS_SESSION_URI=bolt://localhost:7688 \ + NEO4J_URI=bolt://localhost:7687 \ + NEO4J_USER=neo4j \ + NEO4J_PASSWORD=password \ + uv run --group dev --extra gds pytest tests --include-neo4j-and-gds {{ if filter != "" { "-k '" + filter + "'" } else { "" } }} + local-neo4j-setup: #!/usr/bin/env bash set -e diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml index 0a63d73..baafb12 100644 --- a/python-wrapper/pyproject.toml +++ b/python-wrapper/pyproject.toml @@ -42,7 +42,7 @@ requires-python = ">=3.10" [project.optional-dependencies] pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"] -gds = ["graphdatascience>=1, <2"] +gds = ["graphdatascience>=1.21, <2"] neo4j = ["neo4j"] snowflake = ["snowflake-snowpark-python>=1, <2"] @@ -76,9 +76,9 @@ notebook = [ "palettable>=3.3.3", "matplotlib>=3.9.4", "snowflake-snowpark-python==1.42.0", - "python-dotenv", "requests", "marimo", + "python-dotenv" ] [project.urls] @@ -174,9 +174,3 @@ exclude = [ ] plugins = ['pydantic.mypy'] untyped_calls_exclude=["nbconvert"] - -[tool.marimo.runtime] -output_max_bytes = 20_000_000 -# -#[tool.marimo.server] -#follow_symlink = true diff --git 
a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py index e6a399a..d3feaff 100644 --- a/python-wrapper/src/neo4j_viz/gds.py +++ b/python-wrapper/src/neo4j_viz/gds.py @@ -2,11 +2,13 @@ import warnings from itertools import chain -from typing import Optional, cast +from typing import Collection, Optional from uuid import uuid4 import pandas as pd from graphdatascience import Graph, GraphDataScience +from graphdatascience.graph.v2 import GraphV2 +from graphdatascience.session import AuraGraphDataScience from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace @@ -15,48 +17,40 @@ def _fetch_node_dfs( - gds: GraphDataScience, - G: Graph, + gds: GraphDataScience | AuraGraphDataScience, + G: GraphV2, node_properties_by_label: dict[str, list[str]], - node_labels: list[str], + node_labels: Collection[str], additional_db_node_properties: list[str], ) -> dict[str, pd.DataFrame]: return { - lbl: gds.graph.nodeProperties.stream( + lbl: gds.v2.graph.node_properties.stream( G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], - separate_property_columns=True, db_node_properties=additional_db_node_properties, ) for lbl in node_labels } -def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]: - rel_types = G.relationship_types() - - rel_props = {rel_type: G.relationship_properties(rel_type) for rel_type in rel_types} +def _fetch_rel_dfs(gds: GraphDataScience, G: GraphV2) -> list[pd.DataFrame]: + rel_props = G.relationship_properties() rel_dfs: list[pd.DataFrame] = [] # Have to call per stream per relationship type as there was a bug in GDS < 2.21 for rel_type, props in rel_props.items(): - assert isinstance(props, list) - if len(props) > 0: - rel_df = gds.graph.relationshipProperties.stream( - G, relationship_types=rel_type, relationship_properties=list(props), separate_property_columns=True - ) - else: - rel_df = gds.graph.relationships.stream(G, relationship_types=[rel_type]) - + rel_df = 
gds.v2.graph.relationships.stream( + G, relationship_types=[rel_type], relationship_properties=list(props) + ) rel_dfs.append(rel_df) return rel_dfs def from_gds( - gds: GraphDataScience, - G: Graph, + gds: GraphDataScience | AuraGraphDataScience, + G: Graph | GraphV2, node_properties: Optional[list[str]] = None, db_node_properties: Optional[list[str]] = None, max_node_count: int = 10_000, @@ -76,9 +70,9 @@ def from_gds( Parameters ---------- - gds : GraphDataScience - GraphDataScience object. - G : Graph + gds + GraphDataScience object. AuraGraphDataScience object if using Aura Graph Analytics. + G Graph object. node_properties : list[str], optional Additional properties to include in the visualization node, by default None which means that all node @@ -91,37 +85,41 @@ def from_gds( """ if db_node_properties is None: db_node_properties = [] + if isinstance(G, Graph): + G_v2 = gds.v2.graph.get(G.name()) + else: + G_v2 = G - node_properties_from_gds = G.node_properties() - assert isinstance(node_properties_from_gds, pd.Series) - actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict()) - all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values())) + gds_properties_per_label = G_v2.node_properties() + all_gds_properties = list(chain.from_iterable(gds_properties_per_label.values())) node_properties_by_label_sets: dict[str, set[str]] = dict() if node_properties is None: - node_properties_by_label_sets = {k: set(v) for k, v in actual_node_properties.items()} + node_properties_by_label_sets = {k: set(v) for k, v in gds_properties_per_label.items()} else: for prop in node_properties: - if prop not in all_actual_node_properties: + if prop not in all_gds_properties: raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'") - for label, props in actual_node_properties.items(): + for label, props in gds_properties_per_label.items(): node_properties_by_label_sets[label] = { - 
prop for prop in actual_node_properties[label] if prop in node_properties + prop for prop in gds_properties_per_label[label] if prop in node_properties } node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()} - node_count = G.node_count() + node_count = G_v2.node_count() if node_count > max_node_count: warnings.warn( - f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed" + f"The '{G_v2.name()}' projection's node count ({G_v2.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed" ) sampling_ratio = float(max_node_count) / node_count sample_name = f"neo4j-viz_sample_{uuid4()}" - G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True) + G_fetched, _ = gds.v2.graph.sample.rwr( + G_v2, sample_name, sampling_ratio=sampling_ratio, node_label_stratification=True + ) else: - G_fetched = G + G_fetched = G_v2 property_name = None try: @@ -129,12 +127,12 @@ def from_gds( # as a temporary property to ensure that we have at least one property for each label to fetch if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0: property_name = f"neo4j-viz_property_{uuid4()}" - gds.degree.mutate(G_fetched, mutateProperty=property_name) + gds.v2.degree_centrality.mutate(G_fetched, mutate_property=property_name) for props in node_properties_by_label.values(): props.append(property_name) node_dfs = _fetch_node_dfs( - gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), db_node_properties + gds, G_fetched, node_properties_by_label, node_properties_by_label.keys(), db_node_properties ) if property_name is not None: for df in node_dfs.values(): @@ -145,7 +143,7 @@ def from_gds( if G_fetched.name() != G.name(): G_fetched.drop() elif property_name is not None: - 
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name]) + gds.v2.graph.node_properties.drop(G_fetched, node_properties=[property_name]) for df in node_dfs.values(): if property_name is not None and property_name in df.columns: @@ -154,7 +152,7 @@ def from_gds( node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates(subset=["nodeId"]) for lbl, df in node_dfs.items(): - if "labels" in all_actual_node_properties: + if "labels" in all_gds_properties: df.rename(columns={"labels": "__labels"}, inplace=True) df["labels"] = lbl diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py index 40f7f4e..14eead8 100644 --- a/python-wrapper/tests/conftest.py +++ b/python-wrapper/tests/conftest.py @@ -1,4 +1,5 @@ import os +import random from typing import Any, Generator import pytest @@ -31,15 +32,20 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: @pytest.fixture(scope="package") -def aura_ds_instance() -> Generator[Any, None, None]: +def aura_db_instance() -> Generator[Any, None, None]: + if os.environ.get("NEO4J_URI", ""): + print(f"Skipping Aura DB setup since NEO4J_URI is set to {os.environ['NEO4J_URI']}") + yield None + return + if os.environ.get("AURA_API_CLIENT_ID", None) is None: yield None return - from tests.gds_helper import aura_api, create_aurads_instance + from tests.gds_helper import aura_api, create_auradb_instance api = aura_api() - id, dbms_connection_info = create_aurads_instance(api) + dbms_connection_info = create_auradb_instance(api) # setting as environment variables to run notebooks with this connection os.environ["NEO4J_URI"] = dbms_connection_info.get_uri() @@ -47,40 +53,55 @@ def aura_ds_instance() -> Generator[Any, None, None]: os.environ["NEO4J_USER"] = dbms_connection_info.username assert isinstance(dbms_connection_info.password, str) os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password + old_instance = os.environ.get("AURA_INSTANCEID", "") + if 
dbms_connection_info.aura_instance_id: + os.environ["AURA_INSTANCEID"] = dbms_connection_info.aura_instance_id + yield dbms_connection_info # Clear Neo4j_URI after test (rerun should create a new instance) - os.environ["NEO4J_URI"] = "" - api.delete_instance(id) + os.environ["AURA_INSTANCEID"] = old_instance + assert dbms_connection_info.aura_instance_id is not None + api.delete_instance(dbms_connection_info.aura_instance_id) @pytest.fixture(scope="package") -def gds(aura_ds_instance: Any) -> Generator[Any, None, None]: - from graphdatascience import GraphDataScience +def gds(aura_db_instance: Any) -> Generator[Any, None, None]: + from graphdatascience.session import SessionMemory - from tests.gds_helper import connect_to_plugin_gds + from tests.gds_helper import connect_to_local_gds_session, connect_to_plugin_gds, gds_sessions - if aura_ds_instance: - yield GraphDataScience( - endpoint=aura_ds_instance.uri, - auth=(aura_ds_instance.username, aura_ds_instance.password), - aura_ds=True, - database="neo4j", + if aura_db_instance: + sessions = gds_sessions() + + gds = sessions.get_or_create( + f"neo4j-viz-ci-{os.environ.get('GITHUB_RUN_ID', random.randint(0, 10**6))}", + memory=SessionMemory.m_2GB, + db_connection=aura_db_instance, ) + + yield gds + gds.delete() else: - NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") - gds = connect_to_plugin_gds(NEO4J_URI) + neo4j_uri = os.environ["NEO4J_URI"] + neo4j_auth = (os.environ.get("NEO4J_USER", "neo4j"), os.environ.get("NEO4J_PASSWORD", "password")) + + session_uri = os.environ.get("GDS_SESSION_URI") + if session_uri: + gds = connect_to_local_gds_session(session_uri, neo4j_uri, neo4j_auth) # type: ignore + else: + gds = connect_to_plugin_gds(neo4j_uri, neo4j_auth) # type: ignore yield gds gds.close() @pytest.fixture(scope="package") -def neo4j_driver(aura_ds_instance: Any) -> Generator[Any, None, None]: +def neo4j_driver(aura_db_instance: Any) -> Generator[Any, None, None]: import neo4j - if 
aura_ds_instance: + if aura_db_instance: driver = neo4j.GraphDatabase.driver( - aura_ds_instance.uri, auth=(aura_ds_instance.username, aura_ds_instance.password) + aura_db_instance.uri, auth=(aura_db_instance.username, aura_db_instance.password) ) else: NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py index e5a0d3d..b81d6ae 100644 --- a/python-wrapper/tests/gds_helper.py +++ b/python-wrapper/tests/gds_helper.py @@ -1,9 +1,10 @@ import os import re -from graphdatascience import GraphDataScience +from graphdatascience import GdsSessions, GraphDataScience +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.semantic_version.semantic_version import SemanticVersion -from graphdatascience.session import DbmsConnectionInfo, SessionMemory +from graphdatascience.session import AuraAPICredentials, AuraGraphDataScience, DbmsConnectionInfo, SessionMemory from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import InstanceCreateDetails from graphdatascience.version import __version__ @@ -26,12 +27,22 @@ def parse_version(version: str) -> SemanticVersion: GDS_VERSION = parse_version(__version__) -def connect_to_plugin_gds(uri: str) -> GraphDataScience: - NEO4J_AUTH = ("neo4j", "password") - if os.environ.get("NEO4J_USER"): - NEO4J_AUTH = (os.environ.get("NEO4J_USER", "DUMMY"), os.environ.get("NEO4J_PASSWORD", "neo4j")) +def connect_to_plugin_gds(uri: str, auth: tuple[str, str]) -> GraphDataScience: + return GraphDataScience(endpoint=uri, auth=auth, database="neo4j") - return GraphDataScience(endpoint=uri, auth=NEO4J_AUTH, database="neo4j") + +def connect_to_local_gds_session(session_uri: str, db_uri: str, db_auth: tuple[str, str]) -> AuraGraphDataScience: + session_bolt_connection_info = DbmsConnectionInfo(uri=session_uri, username="neo4j", password="password") + 
db_connection_info = DbmsConnectionInfo(uri=db_uri, username=db_auth[0], password=db_auth[1]) + + return AuraGraphDataScience.create( + session_bolt_connection_info=session_bolt_connection_info, + arrow_authentication=UsernamePasswordAuthentication( + session_bolt_connection_info.username, session_bolt_connection_info.password + ), + session_lifecycle_manager=None, # type: ignore + db_endpoint=db_connection_info, + ) def aura_api() -> AuraApi: @@ -49,21 +60,29 @@ def aura_api() -> AuraApi: ) -def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]: - # Switch to Sessions once they can be created without a DB +def gds_sessions() -> GdsSessions: + return GdsSessions( + api_credentials=AuraAPICredentials( + client_id=os.environ["AURA_API_CLIENT_ID"], + client_secret=os.environ["AURA_API_CLIENT_SECRET"], + project_id=os.environ.get("AURA_API_TENANT_ID"), + ) + ) + + +def create_auradb_instance(api: AuraApi) -> DbmsConnectionInfo: instance_details: InstanceCreateDetails = api.create_instance( - name="ci-neo4j-viz-session", - memory=SessionMemory.m_8GB.value, + name="ci-neo4j-viz-db", + memory=SessionMemory.m_2GB.value, cloud_provider="gcp", region="europe-west1", + type="enterprise-db", ) wait_result = api.wait_for_instance_running(instance_id=instance_details.id) if wait_result.error: raise Exception(f"Error while waiting for instance to be running: {wait_result.error}") - return instance_details.id, DbmsConnectionInfo( - uri=wait_result.connection_url, - username="neo4j", - password=instance_details.password, + return DbmsConnectionInfo( + username="neo4j", password=instance_details.password, aura_instance_id=instance_details.id ) diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py index fb078aa..47dfbd3 100644 --- a/python-wrapper/tests/test_gds.py +++ b/python-wrapper/tests/test_gds.py @@ -20,12 +20,28 @@ def db_setup(gds: Any) -> Generator[None, None, None]: gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n") +def 
project_graph(gds: Any) -> Any: + from graphdatascience import GraphDataScience + from graphdatascience.session import AuraGraphDataScience + + if isinstance(gds, GraphDataScience): + return gds.graph.project("g2", ["*"], "*") + elif isinstance(gds, AuraGraphDataScience): + return gds.v2.graph.project( + "g2", + query=""" + MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m) + """, + ) + raise Exception(f"Unsupported GDS type {type(gds)}") + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.requires_neo4j_and_gds def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None: from neo4j_viz.gds import from_gds - with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G: + with project_graph(gds) as G: VG = from_gds(gds, G, db_node_properties=["name"]) assert len(VG.nodes) == 2 @@ -106,7 +122,7 @@ def test_from_gds_integration_all_properties(gds: Any) -> None: def test_from_gds_sample(gds: Any) -> None: from neo4j_viz.gds import from_gds - with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G: + with gds.v2.graph.generate("hello", node_count=11_000, average_degree=1) as G: with pytest.warns( UserWarning, match=re.escape( diff --git a/python-wrapper/uv.lock b/python-wrapper/uv.lock index 3b3044a..e7ab7a4 100644 --- a/python-wrapper/uv.lock +++ b/python-wrapper/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -1028,7 +1028,7 @@ wheels = [ [[package]] name = "graphdatascience" -version = "1.20" +version = "1.21" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "multimethod" }, @@ -1044,9 +1044,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/3c/c9cddf119a5ec642a7f5d4b72b423994bc65ea6298be95735a4fb44342c6/graphdatascience-1.20.tar.gz", hash = 
"sha256:1b993b25196adacf6754463985cedcdeea18777d0624847f27546c6e9a78e069", size = 1744493, upload-time = "2026-02-26T09:23:15.98Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/e9/a49127cc728c7ea2acaa2aad44f06f291aabe0a5a700f980ffdf88f853da/graphdatascience-1.21.tar.gz", hash = "sha256:5a9f16be010eee69d027c5b1ea76bba7029fad8426646c603f137cc9841e3934", size = 1746004, upload-time = "2026-04-16T11:49:01.974Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/72/38/70226ef4d6a909153e84b7b1815f26f4a29cf70df6a9cf3a2b3357b65ced/graphdatascience-1.20-py3-none-any.whl", hash = "sha256:6cf4a81029073f612ccf2371812357e7912c5b5bb09316cb856a935ca0c0e7c6", size = 2014190, upload-time = "2026-02-26T09:23:13.887Z" }, + { url = "https://files.pythonhosted.org/packages/d5/67/a2a944d6d6c0baea1b30017435d2b04ce0b766328c277fd90b7ee36b7bbc/graphdatascience-1.21-py3-none-any.whl", hash = "sha256:4813c3fa6eef5d469a7d344ac37b8992461ed067ee30b6b11bd17c3d5c471592", size = 2016416, upload-time = "2026-04-16T11:48:59.86Z" }, ] [[package]] @@ -2455,7 +2455,7 @@ notebook = [ requires-dist = [ { name = "anywidget", specifier = ">=0.9,<1" }, { name = "enum-tools", specifier = "==0.13.0" }, - { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1,<2" }, + { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1.21,<2" }, { name = "ipython", specifier = ">=7,<10" }, { name = "neo4j", marker = "extra == 'neo4j'" }, { name = "pandas", marker = "extra == 'pandas'", specifier = ">=2,<3" },