Skip to content
Open
29 changes: 25 additions & 4 deletions packages/reflex-base/src/reflex_base/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
unwrap_var_annotation,
)
from reflex_base.style import Style, format_as_emotion
from reflex_base.utils import console, format, imports, types
from reflex_base.utils import console, format, imports, memo_paths, types
from reflex_base.utils.imports import ImportDict, ImportVar, ParsedImportDict
from reflex_base.vars import VarData
from reflex_base.vars.base import (
Expand Down Expand Up @@ -2092,6 +2092,11 @@ class CustomComponent(Component):
doc="The props of the component.", default_factory=dict
)

_source_module: str | None = field(
doc="The user-app Python module that defined this memo, used to mirror its compiled JSX path.",
default=None,
)

def _post_init(self, **kwargs):
"""Initialize the custom component.

Expand Down Expand Up @@ -2156,6 +2161,9 @@ def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
# Set the tag to the name of the function.
self.tag = format.to_title_case(self.component_fn.__name__)

if specifier := memo_paths.library_specifier_for(self._source_module):
self.library = specifier

for key, value in props.items():
# Skip kwargs that are not props.
if key not in props_types:
Expand Down Expand Up @@ -2304,11 +2312,15 @@ def _get_all_app_wrap_components(

def _register_custom_component(
component_fn: Callable[..., Component],
source_module: str | None = None,
):
"""Register a custom component to be compiled.

Args:
component_fn: The function that creates the component.
source_module: The user-app Python module that defined the component,
used to mirror its compiled JSX path. ``None`` falls back to the
legacy ``utils/components`` location.

Returns:
The custom component.
Expand All @@ -2331,6 +2343,7 @@ def _register_custom_component(
dummy_component = CustomComponent._create(
children=[],
component_fn=component_fn,
_source_module=source_module,
**dummy_props,
)
if dummy_component.tag is None:
Expand All @@ -2351,18 +2364,26 @@ def custom_component(
Returns:
The decorated function.
"""
source_module = memo_paths.capture_source_module(component_fn)

@wraps(component_fn)
def wrapper(*children, **props) -> CustomComponent:
# Remove the children from the props.
props.pop("children", None)
return CustomComponent._create(
children=list(children), component_fn=component_fn, **props
children=list(children),
component_fn=component_fn,
_source_module=source_module,
**props,
)

# Register this component so it can be compiled.
dummy_component = _register_custom_component(component_fn)
dummy_component = _register_custom_component(component_fn, source_module)
if tag := dummy_component.tag:
import_specifier = (
memo_paths.library_specifier_for(source_module)
or f"$/{constants.Dirs.UTILS}/components"
)
object.__setattr__(
wrapper,
"_as_var",
Expand All @@ -2371,7 +2392,7 @@ def wrapper(*children, **props) -> CustomComponent:
_var_type=type[Component],
_var_data=VarData(
imports={
f"$/{constants.Dirs.UTILS}/components": [ImportVar(tag=tag)],
import_specifier: [ImportVar(tag=tag)],
"@emotion/react": [
ImportVar(tag="jsx"),
],
Expand Down
1 change: 1 addition & 0 deletions packages/reflex-base/src/reflex_base/plugins/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ class PageContext(BaseContext):
frontend_imports: ParsedImportDict = dataclasses.field(default_factory=dict)
output_path: str | None = None
output_code: str | None = None
source_module: str | None = None
# Stack of ``id(component)`` for components whose subtree is
# memoize-suppressed. Populated by ``MemoizeStatefulPlugin`` when it
# encounters a ``MemoizationLeaf``-style snapshot boundary and popped on
Expand Down
213 changes: 213 additions & 0 deletions packages/reflex-base/src/reflex_base/utils/memo_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Mirror user-app Python module paths into the compiler's ``.web`` output.

The compiler uses these helpers to write each memo's compiled JSX to a path
that mirrors its Python source module, instead of bundling everything into
``.web/utils/components.jsx``. This module owns the small set of helpers that:

- Read ``fn.__module__`` and reject framework / synthetic modules.
- Walk the live frame stack as a fallback for entry points that don't take a
user-supplied callable (notably ``app.add_page(component)`` with a Component
instance).
- Translate a dotted Python module name into mirrored JSX path segments and
the corresponding ``$/...`` library specifier consumed by the import system.
"""

from __future__ import annotations

import functools
import importlib.util
import inspect
from collections.abc import Callable
from pathlib import Path

# Modules whose names start with one of these prefixes are treated as
# framework code and never mirrored. Mirroring them would emit ``.web/reflex/...``
# files for memos defined inside the framework's own component packages.
_FRAMEWORK_MODULE_PREFIXES = (
"reflex.",
"reflex_base.",
"reflex_components_",
"reflex_site_shared.",
"reflex_hosting_cli.",
"reflex_docgen.",
)

# Bare module names that are treated as framework. Prefix matches above use
# trailing dots, so the bare ``reflex`` package itself is matched here.
_FRAMEWORK_MODULE_NAMES = frozenset({
"reflex",
"reflex_base",
"reflex_site_shared",
"reflex_hosting_cli",
"reflex_docgen",
})


def _is_framework_module(module_name: str) -> bool:
"""Whether ``module_name`` belongs to the framework itself.

Args:
module_name: The dotted module name.

Returns:
True if the module is part of the framework and should not be
mirrored under ``.web/``.
"""
if module_name in _FRAMEWORK_MODULE_NAMES:
return True
return module_name.startswith(_FRAMEWORK_MODULE_PREFIXES)


def capture_source_module(fn: Callable | None) -> str | None:
"""Return the user-app module name for ``fn``, or ``None`` if not user code.

Reads ``fn.__module__`` directly — Python sets this on every function
definition, and it survives re-exports, decorators that ``functools.wraps``
correctly, and aliasing. Returns ``None`` for ``__main__``, missing
modules, and framework modules.

Args:
fn: The user callable whose definition module is wanted.

Returns:
The dotted module name to mirror under ``.web/``, or ``None`` to fall
back to the legacy un-mirrored output path.
"""
if fn is None:
return None
module_name = getattr(fn, "__module__", None)
if not module_name or module_name == "__main__":
return None
if _is_framework_module(module_name):
return None
return module_name


def resolve_user_module_from_frame(skip: int = 0) -> str | None:
"""Walk the live frame stack and return the first user-app module name.

Used only as a fallback for ``app.add_page(component)`` when the caller
passed a pre-built ``Component`` instance instead of a callable, so there
is no ``__module__`` to read directly.

Args:
skip: Number of frames above the immediate caller to skip before
starting the search. Pass ``1`` to ignore the function that is
calling this helper.

Returns:
The first frame's module name that is not a framework module, or
``None`` if no suitable frame exists (e.g. running inside a REPL).
"""
frame = inspect.currentframe()
if frame is None:
return None
frame = frame.f_back
for _ in range(skip):
if frame is None:
return None
frame = frame.f_back
while frame is not None:
module_name = frame.f_globals.get("__name__")
if (
module_name
and module_name != "__main__"
and not _is_framework_module(module_name)
):
return module_name
frame = frame.f_back
return None


def _segment_is_safe(segment: str) -> bool:
"""Whether ``segment`` is a path-safe Python identifier-like fragment.

Args:
segment: A single dotted-module segment.

Returns:
True if the segment can be used as a directory or filename without
introducing path traversal or platform-specific quirks.
"""
if not segment or segment in {".", ".."}:
return False
return not any(ch in segment for ch in ("/", "\\", ":", "\0"))


@functools.cache
def module_to_mirrored_segments(module_name: str | None) -> tuple[str, ...] | None:
"""Translate a dotted module name to a tuple of mirrored path segments.

For a *package* (a module whose import resolves to ``__init__.py``), an
extra ``"index"`` segment is appended so the file lives at
``<pkg>/index.jsx`` and submodule files can coexist alongside it as
siblings under ``<pkg>/``.

Args:
module_name: The dotted Python module name. ``None`` short-circuits.

Returns:
A tuple of safe path segments to join under ``.web/``, or ``None`` if
the module name is missing, contains unsafe segments, or cannot be
resolved as a package vs. module.
"""
if not module_name:
return None
segments = module_name.split(".")
if not all(_segment_is_safe(seg) for seg in segments):
return None
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ValueError):
spec = None
if spec is not None and spec.origin and spec.origin.endswith("__init__.py"):
return (*segments, "index")
return tuple(segments)
Comment thread
FarhanAliRaza marked this conversation as resolved.
Outdated


def library_specifier_for(source_module: str | None) -> str | None:
"""Return the ``$/...`` import specifier mirroring ``source_module``, or None.

Args:
source_module: The dotted module name a memo was defined in.

Returns:
The ``$/<segments>`` specifier, or ``None`` if no source module was
captured or it can't be safely mirrored.
"""
if source_module is None:
return None
segments = module_to_mirrored_segments(source_module)
if segments is None:
return None
return mirrored_library_specifier(segments)


def mirrored_jsx_path(web_dir: Path, segments: tuple[str, ...]) -> Path:
"""Build the absolute ``.jsx`` path under ``web_dir`` for ``segments``.

Args:
web_dir: The project's ``.web`` directory.
segments: Mirrored path segments from
:func:`module_to_mirrored_segments`.

Returns:
The absolute path the compiler should write the memo module to.
"""
return web_dir.joinpath(*segments).with_suffix(".jsx")


def mirrored_library_specifier(segments: tuple[str, ...]) -> str:
"""Build the ``$/...`` import specifier for mirrored ``segments``.

The specifier has no extension; Vite resolves the ``.jsx`` automatically.

Args:
segments: Mirrored path segments from
:func:`module_to_mirrored_segments`.

Returns:
A ``$/`` prefixed module specifier suitable for use as a
``Component.library`` value.
"""
return "$/" + "/".join(segments)
2 changes: 1 addition & 1 deletion pyi_hashes.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,5 @@
"packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7",
"reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df",
"reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e",
"reflex/experimental/memo.pyi": "9946d9b757f7cef5f53d599194d6e50e"
"reflex/experimental/memo.pyi": "ad3685fc293017ebfe2d7803128aaaa8"
}
14 changes: 13 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from reflex_base.event.context import EventContext
from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor
from reflex_base.registry import RegistrationContext
from reflex_base.utils import console
from reflex_base.utils import console, memo_paths
from reflex_base.utils.imports import ImportVar
from reflex_base.utils.types import ASGIApp, Message, Receive, Scope, Send
from reflex_components_core.base.error_boundary import ErrorBoundary
Expand Down Expand Up @@ -238,6 +238,7 @@ class UnevaluatedPage:
on_load: EventType[()] | None = None
meta: Sequence[Mapping[str, Any] | Component] = ()
context: Mapping[str, Any] = dataclasses.field(default_factory=dict)
_source_module: str | None = None

def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage:
"""Merge the other page into this one.
Expand All @@ -256,6 +257,9 @@ def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage:
else other.description,
on_load=self.on_load if self.on_load is not None else other.on_load,
context=self.context if self.context is not None else other.context,
_source_module=self._source_module
if self._source_module is not None
else other._source_module,
)


Expand Down Expand Up @@ -864,6 +868,13 @@ def add_page(
# Check if the route given is valid
verify_route_validity(route)

if isinstance(component, Callable):
source_module = memo_paths.capture_source_module(component)
else:
# The user passed a pre-built Component instance — fall back to
# walking the call stack from add_page's caller.
source_module = memo_paths.resolve_user_module_from_frame(skip=1)

unevaluated_page = UnevaluatedPage(
component=component,
route=route,
Expand All @@ -873,6 +884,7 @@ def add_page(
on_load=on_load,
meta=meta,
context=context or {},
_source_module=source_module,
)

if route in self._unevaluated_pages:
Expand Down
Loading
Loading