diff --git a/examples/gds-example.ipynb b/examples/gds-example.ipynb
index 57604af..3b032d6 100644
--- a/examples/gds-example.ipynb
+++ b/examples/gds-example.ipynb
@@ -2,6 +2,7 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "7fb27b941602401d91542211134fc71a",
"metadata": {},
"source": [
"# Visualizing Neo4j Graph Data Science (GDS) Graphs"
@@ -10,6 +11,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "acae54e37e7d407bbb7b55eff062a284",
"metadata": {},
"outputs": [],
"source": [
@@ -17,47 +19,103 @@
"%pip install matplotlib"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9a63283cbaf04dbcab1f6479b197f3a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "load_dotenv()"
+ ]
+ },
{
"cell_type": "markdown",
+ "id": "8dd0d8092fe74a7c96281538738b07e2",
"metadata": {},
"source": [
- "## Setup GDS graph"
+ "## Setup GDS graph\n",
+ "\n",
+ "To use GDS, you can either use GDS as a plugin or Aura Graph Analytics.\n",
+ "In the following, you can choose:\n",
+ "\n",
+ " * Provide Aura API credentials and use Aura Graph Analytics.\n",
+ " * Use Neo4j + GDS Plugin.\n",
+ "\n",
+ "For more information, see the [GDS documentation](https://neo4j.com/docs/graph-data-science/current/installation/)."
]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "72eea5119410473aa328ad9291626812",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
+ "from graphdatascience.session import (\n",
+ " GdsSessions,\n",
+ " DbmsConnectionInfo,\n",
+ " AuraAPICredentials,\n",
+ " SessionMemory,\n",
+ ")\n",
"from graphdatascience import GraphDataScience\n",
"\n",
"# Get Neo4j DB URI, credentials and name from environment if applicable\n",
- "NEO4J_URI = os.environ.get(\"NEO4J_URI\", \"bolt://localhost:7687\")\n",
- "NEO4J_AUTH = (\"neo4j\", None)\n",
- "NEO4J_DB = os.environ.get(\"NEO4J_DB\", \"neo4j\")\n",
- "if os.environ.get(\"NEO4J_USER\") and os.environ.get(\"NEO4J_PASSWORD\"):\n",
- " NEO4J_AUTH = (\n",
- " os.environ.get(\"NEO4J_USER\"),\n",
- " os.environ.get(\"NEO4J_PASSWORD\"),\n",
+ "db_connection = DbmsConnectionInfo(\n",
+ " aura_instance_id=os.environ.get(\"AURA_INSTANCEID\"),\n",
+ " username=os.environ.get(\"NEO4J_USERNAME\") or os.environ.get(\"NEO4J_USER\", \"neo4j\"), # accept both names\n",
+ " password=os.environ[\"NEO4J_PASSWORD\"],\n",
+ " uri=os.environ[\"NEO4J_URI\"],\n",
+ ")\n",
+ "\n",
+ "session_name = \"neo4j-viz-gds-example\"\n",
+ "if os.environ.get(\"AURA_API_CLIENT_ID\"):\n",
+ " # Use Aura Graph Analytics\n",
+ " sessions = GdsSessions(\n",
+ " api_credentials=AuraAPICredentials(\n",
+ " client_id=os.environ[\"AURA_API_CLIENT_ID\"],\n",
+ " client_secret=os.environ[\"AURA_API_CLIENT_SECRET\"],\n",
+ " project_id=os.environ.get(\"AURA_API_PROJECT_ID\"),\n",
+ " )\n",
+ " )\n",
+ " gds = sessions.get_or_create(\n",
+ " session_name=session_name,\n",
+ " memory=SessionMemory.m_2GB,\n",
+ " db_connection=db_connection,\n",
" )\n",
- "gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)"
+ "else:\n",
+ " # Use GDS Plugin\n",
+ " sessions = None\n",
+ " gds = GraphDataScience(\n",
+ " endpoint=db_connection.get_uri(),\n",
+ " auth=(db_connection.username, db_connection.password),\n",
+ " )"
]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "8edb47106e1a46a883d545849b8ab81b",
"metadata": {},
"outputs": [],
"source": [
"G = gds.graph.load_cora(graph_name=\"cora\")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "10185d26023b46108eb7d9f57d49d2b3",
+ "metadata": {},
+ "source": []
+ },
{
"cell_type": "code",
"execution_count": null,
+ "id": "8763a12b2bbd4a93a75aff182afb95dc",
"metadata": {},
"outputs": [],
"source": [
@@ -71,6 +129,7 @@
},
{
"cell_type": "markdown",
+ "id": "7623eae2785240b9bd12b16a66d81610",
"metadata": {},
"source": [
"## Visualization"
@@ -79,1610 +138,32 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "7cdc8c89c7104fffa095e18ddfef8986",
"metadata": {},
"outputs": [],
"source": [
"from neo4j_viz.gds import from_gds\n",
"\n",
- "VG = from_gds(gds, G, max_node_count=500)\n",
- "str(VG)"
+ "VG = from_gds(\n",
+ " gds,\n",
+ " G,\n",
+ " max_node_count=100,\n",
+ ")"
]
},
{
"cell_type": "code",
- "execution_count": 41,
- "metadata": {
- "tags": [
- "preserve-output"
- ]
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " neo4j-viz\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 41,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "id": "b118ea5561624da68c537baed56e602f",
+ "metadata": {},
+ "outputs": [],
"source": [
- "VG.render(theme=\"auto\")"
+ "VG.render()"
]
},
{
"cell_type": "markdown",
+ "id": "938c804e27f84196a10c8828c723f798",
"metadata": {},
"source": [
"### Changing captions\n",
@@ -1694,6 +175,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "504fb2a444614c0babb325280ed9130a",
"metadata": {},
"outputs": [],
"source": [
@@ -1704,9 +186,8 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "id": "59bbdb311c014d738909a11f9e486628",
+ "metadata": {},
"outputs": [],
"source": [
"VG.render()"
@@ -1714,6 +195,7 @@
},
{
"cell_type": "markdown",
+ "id": "b43b363d81ae4b689946ece5c682cd59",
"metadata": {},
"source": [
"## Sizing the nodes\n",
@@ -1724,16 +206,17 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "8a65eabff63a45729fe45fb5ade58bdc",
"metadata": {},
"outputs": [],
"source": [
"VG.resize_nodes(property=\"pagerank\")\n",
- "VG.color_nodes(property=\"componentId\")\n",
"VG.render()"
]
},
{
"cell_type": "markdown",
+ "id": "c3933fab20d04ec698c2621248eb3be0",
"metadata": {},
"source": [
"### Coloring"
@@ -1741,6 +224,7 @@
},
{
"cell_type": "markdown",
+ "id": "4dd4641cc4064e0191573fe9c69df29b",
"metadata": {},
"source": [
"There are two main ways of coloring the nodes of a graph:\n",
@@ -1754,17 +238,17 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "id": "8309879909854d7188b41380fd92a7c3",
+ "metadata": {},
"outputs": [],
"source": [
- "VG.color_nodes(property=\"componentId\")\n",
+ "VG.color_nodes(property=\"subject\")\n",
"VG.render()"
]
},
{
"cell_type": "markdown",
+ "id": "3ed186c9a28b402fb0bc4494df01f08d",
"metadata": {},
"source": [
"Now, let us color by our continuous node field \"size\" that we computed above with PageRank, again using the default colors.\n",
@@ -1775,9 +259,8 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "id": "cb1e1581032b452c9409d6c6813c49d1",
+ "metadata": {},
"outputs": [],
"source": [
"from neo4j_viz.colors import ColorSpace\n",
@@ -1788,6 +271,7 @@
},
{
"cell_type": "markdown",
+ "id": "379cbbc1e968416e875cc15c1202d7eb",
"metadata": {},
"source": [
"#### Custom coloring\n",
@@ -1799,6 +283,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "277c27b1587741f2af2001be3712ef0d",
"metadata": {},
"outputs": [],
"source": [
@@ -1808,6 +293,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "db7b79bc585a40fcaf58bf750017e135",
"metadata": {},
"outputs": [],
"source": [
@@ -1825,1599 +311,17 @@
},
{
"cell_type": "code",
- "execution_count": 49,
- "metadata": {
- "tags": [
- "preserve-output"
- ]
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " \n",
- " neo4j-viz\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 49,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "id": "916684f9a58a4a2aa5f864670399430d",
+ "metadata": {},
+ "outputs": [],
"source": [
"VG.render()"
]
},
{
"cell_type": "markdown",
+ "id": "1671c31a24314836a5b85d7ef7fbf015",
"metadata": {},
"source": [
"### Render options\n",
@@ -3433,1593 +337,10 @@
},
{
"cell_type": "code",
- "execution_count": 50,
- "metadata": {
- "tags": [
- "preserve-output"
- ]
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " \n",
- " neo4j-viz\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 50,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "id": "33b0902fd34d4ace834912fa1002cf8e",
+ "metadata": {},
+ "outputs": [],
"source": [
"from neo4j_viz import Layout\n",
"\n",
@@ -5028,6 +349,7 @@
},
{
"cell_type": "markdown",
+ "id": "f6fa52606d8c4a75a9b52967216f8f3f",
"metadata": {},
"source": [
"## Saving the visualization"
@@ -5036,6 +358,7 @@
{
"cell_type": "code",
"execution_count": null,
+ "id": "f5a1fa73e5044315a093ec459c9be902",
"metadata": {},
"outputs": [],
"source": [
@@ -5050,6 +373,7 @@
},
{
"cell_type": "markdown",
+ "id": "cdf66aed5cc84ca1b48e60bad68798a8",
"metadata": {},
"source": [
"## Cleanup\n",
@@ -5060,22 +384,26 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": [
- "teardown"
- ]
- },
+ "id": "28d3efd5258a48a79c179ea5c6759f01",
+ "metadata": {},
"outputs": [],
"source": [
"gds.graph.drop(\"cora\")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3f9bc0b9dd2c44919cc8dcca39b469f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if sessions:\n",
+ " sessions.delete(session_name=session_name)"
+ ]
}
],
- "metadata": {
- "language_info": {
- "name": "python"
- }
- },
+ "metadata": {},
"nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 5
}
diff --git a/examples/neo4j-example.ipynb b/examples/neo4j-example.ipynb
index 898bb43..87e3f02 100644
--- a/examples/neo4j-example.ipynb
+++ b/examples/neo4j-example.ipynb
@@ -167,1613 +167,14 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "2322065c",
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " \n",
- " neo4j-viz\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"from neo4j_viz.neo4j import from_neo4j\n",
"\n",
@@ -1794,1613 +195,14 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "e8b0f4c6",
"metadata": {
"tags": [
"preserve-output"
]
},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " \n",
- " neo4j-viz\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"VG.render()"
]
diff --git a/justfile b/justfile
index ea83288..bc58784 100644
--- a/justfile
+++ b/justfile
@@ -18,13 +18,25 @@ py-test-gds:
trap "cd $ENV_DIR && docker compose down" EXIT
cd $ENV_DIR && docker compose up -d
cd -
+ cd python-wrapper && \
NEO4J_URI=bolt://localhost:7687 \
NEO4J_USER=neo4j \
NEO4J_PASSWORD=password \
NEO4J_DB=neo4j \
- cd python-wrapper && uv run --group dev --extra gds pytest tests --include-neo4j-and-gds
+ uv run --group dev --extra gds pytest tests --include-neo4j-and-gds
cd ..
+
+# this expects the local compose setup to be running.
+py-test-gds-sessions filter="":
+ #!/usr/bin/env bash
+ cd python-wrapper && \
+ GDS_SESSION_URI=bolt://localhost:7688 \
+ NEO4J_URI=bolt://localhost:7687 \
+ NEO4J_USER=neo4j \
+ NEO4J_PASSWORD=password \
+ uv run --group dev --extra gds pytest tests --include-neo4j-and-gds {{ if filter != "" { "-k '" + filter + "'" } else { "" } }}
+
local-neo4j-setup:
#!/usr/bin/env bash
set -e
diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml
index 0a63d73..baafb12 100644
--- a/python-wrapper/pyproject.toml
+++ b/python-wrapper/pyproject.toml
@@ -42,7 +42,7 @@ requires-python = ">=3.10"
[project.optional-dependencies]
pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"]
-gds = ["graphdatascience>=1, <2"]
+gds = ["graphdatascience>=1.21, <2"]
neo4j = ["neo4j"]
snowflake = ["snowflake-snowpark-python>=1, <2"]
@@ -76,9 +76,9 @@ notebook = [
"palettable>=3.3.3",
"matplotlib>=3.9.4",
"snowflake-snowpark-python==1.42.0",
- "python-dotenv",
"requests",
"marimo",
+ "python-dotenv"
]
[project.urls]
@@ -174,9 +174,3 @@ exclude = [
]
plugins = ['pydantic.mypy']
untyped_calls_exclude=["nbconvert"]
-
-[tool.marimo.runtime]
-output_max_bytes = 20_000_000
-#
-#[tool.marimo.server]
-#follow_symlink = true
diff --git a/python-wrapper/src/neo4j_viz/gds.py b/python-wrapper/src/neo4j_viz/gds.py
index e6a399a..d3feaff 100644
--- a/python-wrapper/src/neo4j_viz/gds.py
+++ b/python-wrapper/src/neo4j_viz/gds.py
@@ -2,11 +2,13 @@
import warnings
from itertools import chain
-from typing import Optional, cast
+from typing import Collection, Optional
from uuid import uuid4
import pandas as pd
from graphdatascience import Graph, GraphDataScience
+from graphdatascience.graph.v2 import GraphV2
+from graphdatascience.session import AuraGraphDataScience
from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace
@@ -15,48 +17,40 @@
def _fetch_node_dfs(
- gds: GraphDataScience,
- G: Graph,
+ gds: GraphDataScience | AuraGraphDataScience,
+ G: GraphV2,
node_properties_by_label: dict[str, list[str]],
- node_labels: list[str],
+ node_labels: Collection[str],
additional_db_node_properties: list[str],
) -> dict[str, pd.DataFrame]:
return {
- lbl: gds.graph.nodeProperties.stream(
+ lbl: gds.v2.graph.node_properties.stream(
G,
node_properties=node_properties_by_label[lbl],
node_labels=[lbl],
- separate_property_columns=True,
db_node_properties=additional_db_node_properties,
)
for lbl in node_labels
}
-def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]:
- rel_types = G.relationship_types()
-
- rel_props = {rel_type: G.relationship_properties(rel_type) for rel_type in rel_types}
+def _fetch_rel_dfs(gds: GraphDataScience | AuraGraphDataScience, G: GraphV2) -> list[pd.DataFrame]:
+ rel_props = G.relationship_properties()
rel_dfs: list[pd.DataFrame] = []
# Have to call per stream per relationship type as there was a bug in GDS < 2.21
for rel_type, props in rel_props.items():
- assert isinstance(props, list)
- if len(props) > 0:
- rel_df = gds.graph.relationshipProperties.stream(
- G, relationship_types=rel_type, relationship_properties=list(props), separate_property_columns=True
- )
- else:
- rel_df = gds.graph.relationships.stream(G, relationship_types=[rel_type])
-
+ rel_df = gds.v2.graph.relationships.stream(
+ G, relationship_types=[rel_type], relationship_properties=list(props)
+ )
rel_dfs.append(rel_df)
return rel_dfs
def from_gds(
- gds: GraphDataScience,
- G: Graph,
+ gds: GraphDataScience | AuraGraphDataScience,
+ G: Graph | GraphV2,
node_properties: Optional[list[str]] = None,
db_node_properties: Optional[list[str]] = None,
max_node_count: int = 10_000,
@@ -76,9 +70,9 @@ def from_gds(
Parameters
----------
- gds : GraphDataScience
- GraphDataScience object.
- G : Graph
+ gds
+ GraphDataScience object. AuraGraphDataScience object if using Aura Graph Analytics.
+ G
Graph object.
node_properties : list[str], optional
Additional properties to include in the visualization node, by default None which means that all node
@@ -91,37 +85,41 @@ def from_gds(
"""
if db_node_properties is None:
db_node_properties = []
+ if isinstance(G, Graph):
+ G_v2 = gds.v2.graph.get(G.name())
+ else:
+ G_v2 = G
- node_properties_from_gds = G.node_properties()
- assert isinstance(node_properties_from_gds, pd.Series)
- actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict())
- all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values()))
+ gds_properties_per_label = G_v2.node_properties()
+ all_gds_properties = list(chain.from_iterable(gds_properties_per_label.values()))
node_properties_by_label_sets: dict[str, set[str]] = dict()
if node_properties is None:
- node_properties_by_label_sets = {k: set(v) for k, v in actual_node_properties.items()}
+ node_properties_by_label_sets = {k: set(v) for k, v in gds_properties_per_label.items()}
else:
for prop in node_properties:
- if prop not in all_actual_node_properties:
+ if prop not in all_gds_properties:
raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'")
- for label, props in actual_node_properties.items():
+ for label, props in gds_properties_per_label.items():
node_properties_by_label_sets[label] = {
- prop for prop in actual_node_properties[label] if prop in node_properties
+ prop for prop in gds_properties_per_label[label] if prop in node_properties
}
node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()}
- node_count = G.node_count()
+ node_count = G_v2.node_count()
if node_count > max_node_count:
warnings.warn(
- f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
+ f"The '{G_v2.name()}' projection's node count ({G_v2.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
)
sampling_ratio = float(max_node_count) / node_count
sample_name = f"neo4j-viz_sample_{uuid4()}"
- G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
+ G_fetched, _ = gds.v2.graph.sample.rwr(
+ G_v2, sample_name, sampling_ratio=sampling_ratio, node_label_stratification=True
+ )
else:
- G_fetched = G
+ G_fetched = G_v2
property_name = None
try:
@@ -129,12 +127,12 @@ def from_gds(
# as a temporary property to ensure that we have at least one property for each label to fetch
if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0:
property_name = f"neo4j-viz_property_{uuid4()}"
- gds.degree.mutate(G_fetched, mutateProperty=property_name)
+ gds.v2.degree_centrality.mutate(G_fetched, mutate_property=property_name)
for props in node_properties_by_label.values():
props.append(property_name)
node_dfs = _fetch_node_dfs(
- gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), db_node_properties
+ gds, G_fetched, node_properties_by_label, node_properties_by_label.keys(), db_node_properties
)
if property_name is not None:
for df in node_dfs.values():
@@ -145,7 +143,7 @@ def from_gds(
if G_fetched.name() != G.name():
G_fetched.drop()
elif property_name is not None:
- gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])
+ gds.v2.graph.node_properties.drop(G_fetched, node_properties=[property_name])
for df in node_dfs.values():
if property_name is not None and property_name in df.columns:
@@ -154,7 +152,7 @@ def from_gds(
node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates(subset=["nodeId"])
for lbl, df in node_dfs.items():
- if "labels" in all_actual_node_properties:
+ if "labels" in all_gds_properties:
df.rename(columns={"labels": "__labels"}, inplace=True)
df["labels"] = lbl
diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py
index 40f7f4e..14eead8 100644
--- a/python-wrapper/tests/conftest.py
+++ b/python-wrapper/tests/conftest.py
@@ -1,4 +1,5 @@
import os
+import random
from typing import Any, Generator
import pytest
@@ -31,15 +32,20 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None:
@pytest.fixture(scope="package")
-def aura_ds_instance() -> Generator[Any, None, None]:
+def aura_db_instance() -> Generator[Any, None, None]:
+ if os.environ.get("NEO4J_URI", ""):
+ print(f"Skipping Aura DB setup since NEO4J_URI is set to {os.environ['NEO4J_URI']}")
+ yield None
+ return
+
if os.environ.get("AURA_API_CLIENT_ID", None) is None:
yield None
return
- from tests.gds_helper import aura_api, create_aurads_instance
+ from tests.gds_helper import aura_api, create_auradb_instance
api = aura_api()
- id, dbms_connection_info = create_aurads_instance(api)
+ dbms_connection_info = create_auradb_instance(api)
# setting as environment variables to run notebooks with this connection
os.environ["NEO4J_URI"] = dbms_connection_info.get_uri()
@@ -47,40 +53,55 @@ def aura_ds_instance() -> Generator[Any, None, None]:
os.environ["NEO4J_USER"] = dbms_connection_info.username
assert isinstance(dbms_connection_info.password, str)
os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password
+ old_instance = os.environ.get("AURA_INSTANCEID", "")
+ if dbms_connection_info.aura_instance_id:
+ os.environ["AURA_INSTANCEID"] = dbms_connection_info.aura_instance_id
+
yield dbms_connection_info
# Clear Neo4j_URI after test (rerun should create a new instance)
- os.environ["NEO4J_URI"] = ""
- api.delete_instance(id)
+ os.environ["AURA_INSTANCEID"] = old_instance
+ assert dbms_connection_info.aura_instance_id is not None
+ api.delete_instance(dbms_connection_info.aura_instance_id)
@pytest.fixture(scope="package")
-def gds(aura_ds_instance: Any) -> Generator[Any, None, None]:
- from graphdatascience import GraphDataScience
+def gds(aura_db_instance: Any) -> Generator[Any, None, None]:
+ from graphdatascience.session import SessionMemory
- from tests.gds_helper import connect_to_plugin_gds
+ from tests.gds_helper import connect_to_local_gds_session, connect_to_plugin_gds, gds_sessions
- if aura_ds_instance:
- yield GraphDataScience(
- endpoint=aura_ds_instance.uri,
- auth=(aura_ds_instance.username, aura_ds_instance.password),
- aura_ds=True,
- database="neo4j",
+ if aura_db_instance:
+ sessions = gds_sessions()
+
+ gds = sessions.get_or_create(
+ f"neo4j-viz-ci-{os.environ.get('GITHUB_RUN_ID', random.randint(0, 10**6))}",
+ memory=SessionMemory.m_2GB,
+ db_connection=aura_db_instance,
)
+
+ yield gds
+ gds.delete()
else:
- NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
- gds = connect_to_plugin_gds(NEO4J_URI)
+ neo4j_uri = os.environ["NEO4J_URI"]
+ neo4j_auth = (os.environ.get("NEO4J_USER", "neo4j"), os.environ.get("NEO4J_PASSWORD", "password"))
+
+ session_uri = os.environ.get("GDS_SESSION_URI")
+ if session_uri:
+ gds = connect_to_local_gds_session(session_uri, neo4j_uri, neo4j_auth) # type: ignore
+ else:
+ gds = connect_to_plugin_gds(neo4j_uri, neo4j_auth) # type: ignore
yield gds
gds.close()
@pytest.fixture(scope="package")
-def neo4j_driver(aura_ds_instance: Any) -> Generator[Any, None, None]:
+def neo4j_driver(aura_db_instance: Any) -> Generator[Any, None, None]:
import neo4j
- if aura_ds_instance:
+ if aura_db_instance:
driver = neo4j.GraphDatabase.driver(
- aura_ds_instance.uri, auth=(aura_ds_instance.username, aura_ds_instance.password)
+ aura_db_instance.uri, auth=(aura_db_instance.username, aura_db_instance.password)
)
else:
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
diff --git a/python-wrapper/tests/gds_helper.py b/python-wrapper/tests/gds_helper.py
index e5a0d3d..b81d6ae 100644
--- a/python-wrapper/tests/gds_helper.py
+++ b/python-wrapper/tests/gds_helper.py
@@ -1,9 +1,10 @@
import os
import re
-from graphdatascience import GraphDataScience
+from graphdatascience import GdsSessions, GraphDataScience
+from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
from graphdatascience.semantic_version.semantic_version import SemanticVersion
-from graphdatascience.session import DbmsConnectionInfo, SessionMemory
+from graphdatascience.session import AuraAPICredentials, AuraGraphDataScience, DbmsConnectionInfo, SessionMemory
from graphdatascience.session.aura_api import AuraApi
from graphdatascience.session.aura_api_responses import InstanceCreateDetails
from graphdatascience.version import __version__
@@ -26,12 +27,22 @@ def parse_version(version: str) -> SemanticVersion:
GDS_VERSION = parse_version(__version__)
-def connect_to_plugin_gds(uri: str) -> GraphDataScience:
- NEO4J_AUTH = ("neo4j", "password")
- if os.environ.get("NEO4J_USER"):
- NEO4J_AUTH = (os.environ.get("NEO4J_USER", "DUMMY"), os.environ.get("NEO4J_PASSWORD", "neo4j"))
+def connect_to_plugin_gds(uri: str, auth: tuple[str, str]) -> GraphDataScience:
+ return GraphDataScience(endpoint=uri, auth=auth, database="neo4j")
- return GraphDataScience(endpoint=uri, auth=NEO4J_AUTH, database="neo4j")
+
+def connect_to_local_gds_session(session_uri: str, db_uri: str, db_auth: tuple[str, str]) -> AuraGraphDataScience:
+ session_bolt_connection_info = DbmsConnectionInfo(uri=session_uri, username="neo4j", password="password")
+ db_connection_info = DbmsConnectionInfo(uri=db_uri, username=db_auth[0], password=db_auth[1])
+
+ return AuraGraphDataScience.create(
+ session_bolt_connection_info=session_bolt_connection_info,
+ arrow_authentication=UsernamePasswordAuthentication(
+ session_bolt_connection_info.username, session_bolt_connection_info.password
+ ),
+ session_lifecycle_manager=None, # type: ignore
+ db_endpoint=db_connection_info,
+ )
def aura_api() -> AuraApi:
@@ -49,21 +60,29 @@ def aura_api() -> AuraApi:
)
-def create_aurads_instance(api: AuraApi) -> tuple[str, DbmsConnectionInfo]:
- # Switch to Sessions once they can be created without a DB
+def gds_sessions() -> GdsSessions:
+ return GdsSessions(
+ api_credentials=AuraAPICredentials(
+ client_id=os.environ["AURA_API_CLIENT_ID"],
+ client_secret=os.environ["AURA_API_CLIENT_SECRET"],
+ project_id=os.environ.get("AURA_API_TENANT_ID"),
+ )
+ )
+
+
+def create_auradb_instance(api: AuraApi) -> DbmsConnectionInfo:
instance_details: InstanceCreateDetails = api.create_instance(
- name="ci-neo4j-viz-session",
- memory=SessionMemory.m_8GB.value,
+ name="ci-neo4j-viz-db",
+ memory=SessionMemory.m_2GB.value,
cloud_provider="gcp",
region="europe-west1",
+ type="enterprise-db",
)
wait_result = api.wait_for_instance_running(instance_id=instance_details.id)
if wait_result.error:
raise Exception(f"Error while waiting for instance to be running: {wait_result.error}")
- return instance_details.id, DbmsConnectionInfo(
- uri=wait_result.connection_url,
- username="neo4j",
- password=instance_details.password,
+ return DbmsConnectionInfo(
+ uri=wait_result.connection_url, username="neo4j", password=instance_details.password, aura_instance_id=instance_details.id
)
diff --git a/python-wrapper/tests/test_gds.py b/python-wrapper/tests/test_gds.py
index fb078aa..47dfbd3 100644
--- a/python-wrapper/tests/test_gds.py
+++ b/python-wrapper/tests/test_gds.py
@@ -20,12 +20,28 @@ def db_setup(gds: Any) -> Generator[None, None, None]:
gds.run_cypher("MATCH (n:_CI_A|_CI_B) DETACH DELETE n")
+def project_graph(gds: Any) -> Any:
+ from graphdatascience import GraphDataScience
+ from graphdatascience.session import AuraGraphDataScience
+
+ if isinstance(gds, GraphDataScience):
+ return gds.graph.project("g2", ["*"], "*")
+ elif isinstance(gds, AuraGraphDataScience):
+ return gds.v2.graph.project(
+ "g2",
+ query="""
+ MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)
+ """,
+ )
+ raise Exception(f"Unsupported GDS type {type(gds)}")
+
+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.requires_neo4j_and_gds
def test_from_gds_integration_all_db_properties(gds: Any, db_setup: None) -> None:
from neo4j_viz.gds import from_gds
- with gds.graph.project("g2", ["_CI_A", "_CI_B"], "*") as G:
+ with project_graph(gds) as G:
VG = from_gds(gds, G, db_node_properties=["name"])
assert len(VG.nodes) == 2
@@ -106,7 +122,7 @@ def test_from_gds_integration_all_properties(gds: Any) -> None:
def test_from_gds_sample(gds: Any) -> None:
from neo4j_viz.gds import from_gds
- with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
+ with gds.v2.graph.generate("hello", node_count=11_000, average_degree=1) as G:
with pytest.warns(
UserWarning,
match=re.escape(
diff --git a/python-wrapper/uv.lock b/python-wrapper/uv.lock
index 3b3044a..e7ab7a4 100644
--- a/python-wrapper/uv.lock
+++ b/python-wrapper/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.10"
resolution-markers = [
"python_full_version >= '3.14'",
@@ -1028,7 +1028,7 @@ wheels = [
[[package]]
name = "graphdatascience"
-version = "1.20"
+version = "1.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "multimethod" },
@@ -1044,9 +1044,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/af/3c/c9cddf119a5ec642a7f5d4b72b423994bc65ea6298be95735a4fb44342c6/graphdatascience-1.20.tar.gz", hash = "sha256:1b993b25196adacf6754463985cedcdeea18777d0624847f27546c6e9a78e069", size = 1744493, upload-time = "2026-02-26T09:23:15.98Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c2/e9/a49127cc728c7ea2acaa2aad44f06f291aabe0a5a700f980ffdf88f853da/graphdatascience-1.21.tar.gz", hash = "sha256:5a9f16be010eee69d027c5b1ea76bba7029fad8426646c603f137cc9841e3934", size = 1746004, upload-time = "2026-04-16T11:49:01.974Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/72/38/70226ef4d6a909153e84b7b1815f26f4a29cf70df6a9cf3a2b3357b65ced/graphdatascience-1.20-py3-none-any.whl", hash = "sha256:6cf4a81029073f612ccf2371812357e7912c5b5bb09316cb856a935ca0c0e7c6", size = 2014190, upload-time = "2026-02-26T09:23:13.887Z" },
+ { url = "https://files.pythonhosted.org/packages/d5/67/a2a944d6d6c0baea1b30017435d2b04ce0b766328c277fd90b7ee36b7bbc/graphdatascience-1.21-py3-none-any.whl", hash = "sha256:4813c3fa6eef5d469a7d344ac37b8992461ed067ee30b6b11bd17c3d5c471592", size = 2016416, upload-time = "2026-04-16T11:48:59.86Z" },
]
[[package]]
@@ -2455,7 +2455,7 @@ notebook = [
requires-dist = [
{ name = "anywidget", specifier = ">=0.9,<1" },
{ name = "enum-tools", specifier = "==0.13.0" },
- { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1,<2" },
+ { name = "graphdatascience", marker = "extra == 'gds'", specifier = ">=1.21,<2" },
{ name = "ipython", specifier = ">=7,<10" },
{ name = "neo4j", marker = "extra == 'neo4j'" },
{ name = "pandas", marker = "extra == 'pandas'", specifier = ">=2,<3" },