Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@ lightning_logs
.env
dev

*.ckpt
*.ckpt
web/build
web/node_modules
web/.svelte-kit

share/lobster/ui
65 changes: 65 additions & 0 deletions hatch_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import shutil
import subprocess
import sys
from pathlib import Path

from hatchling.builders.hooks.plugin.interface import BuildHookInterface


class CustomHook(BuildHookInterface):
    """
    A custom build hook for lobster that builds the web UI into share/lobster/ui.

    The UI build is skipped when LOBSTER_BUILD_SKIP_UI is set or when the
    'npm' executable is not available, so pure-Python installs work without
    Node.js.
    """

    def initialize(self, version, build_data):
        # https://hatch.pypa.io/1.1/plugins/build-hook/#hatchling.builders.hooks.plugin.interface.BuildHookInterface

        # Hatchling intends for us to mutate the input build_data to communicate
        # that 'share/lobster/ui' contains build artifacts that should be included
        # in the distribution.

        # Set this irrespective of whether the build happens below. It may have
        # already been done manually by the user. This simply allow-lists the
        # files, however they were put there.
        artifact_path = "share/lobster/ui"  # must be relative
        build_data["artifacts"].append(artifact_path)

        if os.getenv("LOBSTER_BUILD_SKIP_UI"):
            print(
                "Will skip building the lobster web UI because LOBSTER_BUILD_SKIP_UI is set",
                file=sys.stderr,
            )
            return
        npm_path = shutil.which("npm")
        if npm_path is None:
            print(
                "Will skip building the lobster web UI because 'npm' executable is not found",
                file=sys.stderr,
            )
            return
        print(
            f"Building lobster web UI using {npm_path!r}. (Set LOBSTER_BUILD_SKIP_UI=1 to skip.)",
            file=sys.stderr,
        )
        try:
            subprocess.check_call([npm_path, "install"], cwd="web")
            subprocess.check_call([npm_path, "run", "build"], cwd="web")
            # Remove stale artifacts from a previous build before copying the
            # fresh output into place.
            if Path(artifact_path).exists():
                shutil.rmtree(artifact_path)
            shutil.copytree("web/build", artifact_path)
        except Exception:
            print(
                f"There was an error while building the lobster web UI using {npm_path!r}. "
                "If you do not need the web UI, you can set LOBSTER_BUILD_SKIP_UI=1 to skip it; "
                "the Python aspects will work fine without it.",
                file=sys.stderr,
            )
            raise
41 changes: 25 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dependencies = [
]

[build-system]
requires = ["setuptools >= 65", "setuptools_scm[toml]>=6.2"]
build-backend = 'setuptools.build_meta'
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project.scripts]
lobster_train = "lobster.cmdline:train"
Expand All @@ -52,6 +52,7 @@ lobster_intervene_multiproperty = "lobster.cmdline:intervene_multiproperty"
lobster_perplexity = "lobster.cmdline:perplexity"
lobster_eval_embed = "lobster.cmdline:eval_embed"
lobster_eval = "lobster.cmdline:evaluate"
lobster_web = "lobster.server:serve"
lobster_mcp_server = "lobster.mcp.inference_server:main"
lobster_mcp_setup = "lobster.mcp.setup:main"

Expand All @@ -70,6 +71,7 @@ mcp = [
"fastmcp>=0.2.0",
"python-Levenshtein>=0.20.0",
"pydantic>=2.0.0",
"fastapi>=0.115.14",
]
# eval = [
# "umap-learn<=0.5.6"
Expand All @@ -78,20 +80,10 @@ trl = [
"trl",
"accelerate",
]

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
lobster = ["*.txt", "*.json", "*.yaml"]
"lobster.assets" = ["**/*.txt", "**/*.json", "**/*.yaml"]
"lobster.hydra_config" = [ "**/*.yaml"]

[tool.setuptools_scm]
search_parent_directories = true
version_scheme = "no-guess-dev"
fallback_version = "0.0.0"
local_scheme = "no-local-version" # see https://github.com/pypa/setuptools-scm/issues/455
web = [
"fastapi>=0.115.14",
"uvicorn>=0.34.3",
]

[tool.ruff]
line-length = 120
Expand Down Expand Up @@ -158,3 +150,20 @@ dev = [

[[tool.uv.index]]
url = "https://pypi.python.org/simple"

[tool.hatch]
version.source = "vcs"
version.fallback-version = "0.0.0"

[tool.hatch.version.raw-options]
local_scheme = "no-local-version"

# Defining this (empty) section invokes hatch_build.py
[tool.hatch.build.hooks.custom]

[tool.hatch.build.targets.wheel]
packages = ["src/lobster"]

[tool.hatch.build.targets.wheel.shared-data]
"share/lobster/.identifying_file_42688fa5f3a65e2caa822dec544d5694" = "share/lobster/.identifying_file_42688fa5f3a65e2caa822dec544d5694"
"share/lobster/ui" = "share/lobster/ui"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this file is used to identify share/lobster
19 changes: 2 additions & 17 deletions src/lobster/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,2 @@
"""
Lobster MCP (Model Context Protocol) Integration

This module provides MCP servers that expose Lobster's pretrained models
for sequence representation, concept analysis, and interventions.

Available when the 'mcp' extra is installed:
uv sync --extra mcp
"""

try:
from .inference_server import LobsterInferenceServer

__all__ = ["LobsterInferenceServer"]
except ImportError:
# MCP dependencies not installed
__all__ = []
# TODO
# from ._mcp import serve
9 changes: 9 additions & 0 deletions src/lobster/mcp/_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastmcp import FastMCP

from lobster.server import app

# Wrap the lobster FastAPI app so its HTTP routes are exposed as MCP tools.
mcp = FastMCP.from_fastapi(app=app)


def serve():
    """Run the MCP server (blocking) built from the lobster FastAPI app."""
    mcp.run()
2 changes: 0 additions & 2 deletions src/lobster/server/_server.py → src/lobster/mcp/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
class UMEServer:
def __init__(self, model: UME):
"""Initialize the UME MCP server with a model.

Args:
model: A UME model instance to use for embeddings
"""
Expand All @@ -26,7 +25,6 @@ def embed_sequences(sequences, modality, aggregate=True):

def get_server(self):
"""Get the FastMCP server instance.

Returns:
FastMCP: The configured MCP server instance
"""
Expand Down
1 change: 1 addition & 0 deletions src/lobster/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._app import app, serve
76 changes: 76 additions & 0 deletions src/lobster/server/_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
from pathlib import Path
from typing import Literal

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from ._esm import get_esm_cached, esm_aa_naturalness
from ._utils import SHARE_LOBSTER_PATH

DEFAULT_LOBSTER_ALLOW_ORIGINS = "*"

ALLOW_ORIGINS = os.getenv("LOBSTER_ALLOW_ORIGINS", DEFAULT_LOBSTER_ALLOW_ORIGINS).split(",")


class NaturalnessInput(BaseModel):
    """Request body for the /naturalness endpoint."""

    # Protein sequence to score; the endpoint handles the empty case itself.
    sequence: str
    # Which pretrained ESM-2 masked-LM checkpoint to score with.
    model_name: Literal[
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
    ] = Field(default="facebook/esm2_t6_8M_UR50D")


class NaturalnessOutput(BaseModel):
    """Response body for the /naturalness endpoint."""

    # Per-position log-probabilities over the 20 standard residues
    # (sequence length x 20, residues sorted alphabetically).
    logp: list[list[float]]
    # Log-probability of the wild-type residue at each position.
    wt_logp: list[float]
    # exp(mean(wt_logp)) — a single scalar naturalness score.
    naturalness: float
    # Alphabetical residue index of each input position.
    # NOTE(review): the producer emits integer indices; `list[int]` looks
    # more precise, but changing it would alter the response schema — confirm.
    encoded: list[float]


app = FastAPI()


@app.post("/naturalness")
def naturalness(input: NaturalnessInput) -> NaturalnessOutput:
    """Score per-residue naturalness of a protein sequence with an ESM-2 MLM.

    Args:
        input: Request body carrying the sequence and the checkpoint name.

    Returns:
        NaturalnessOutput-shaped dict with per-position log-probabilities,
        wild-type log-probabilities, a scalar naturalness score, and the
        integer-encoded sequence.

    Raises:
        HTTPException: 422 when the sequence is empty. (The previous code
            returned ``"naturalness": None``, which fails response validation
            against ``NaturalnessOutput.naturalness: float`` and surfaced as
            a 500.)
    """
    if not input.sequence:
        raise HTTPException(status_code=422, detail="sequence must be non-empty")

    model, tokenizer = get_esm_cached(input.model_name)

    return esm_aa_naturalness(sequence=input.sequence, model=model, tokenizer=tokenizer, batch_size=64)


# Serve the built web UI (copied into share/lobster/ui at build time) at the
# root path; API routes registered above take precedence.
app.mount(
    "/",
    StaticFiles(directory=Path(SHARE_LOBSTER_PATH, "ui"), html=True),
    name="ui",
)

# Honor LOBSTER_ALLOW_ORIGINS (comma-separated). Previously the computed
# ALLOW_ORIGINS list was dead code and "*" was always used; the default env
# value keeps "*", so behavior is unchanged unless the variable is set.
# NOTE(review): browsers ignore allow_credentials=True with a wildcard
# origin — confirm the intended CORS policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def serve(host: str = "localhost", port: int = 8000):
    """Run the lobster web app under uvicorn (blocking).

    Args:
        host: Interface to bind. Defaults to localhost (loopback only), the
            same value that was previously hard-coded.
        port: TCP port to listen on. Defaults to the previous hard-coded 8000.
    """
    # Imported lazily so the module stays importable without the 'web' extra.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
72 changes: 72 additions & 0 deletions src/lobster/server/_esm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import functools
from typing import Literal

import einops
import torch
from transformers import EsmForMaskedLM, EsmTokenizer
from beignet.constants import STANDARD_RESIDUES

from ._utils import single_position_masked_sequences, batched


@functools.cache
def get_esm_cached(
    mlm_model_name: Literal[
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
    ] = "facebook/esm2_t6_8M_UR50D",
):
    """Load (and memoize, per model name) an ESM-2 masked-LM and tokenizer.

    The model is moved to CUDA when available and put in eval mode. Because
    of ``functools.cache``, each checkpoint is loaded at most once per
    process.

    Args:
        mlm_model_name: One of the supported ESM-2 checkpoint names.

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        ValueError: For an unsupported model name. (The Literal hint is not
            enforced at runtime, and the previous ``assert`` disappears
            under ``python -O``.)
    """
    supported = {
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
    }
    if mlm_model_name not in supported:
        raise ValueError(f"Unsupported ESM model name: {mlm_model_name!r}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tokenizer = EsmTokenizer.from_pretrained(mlm_model_name, clean_up_tokenization_spaces=False)
    model = EsmForMaskedLM.from_pretrained(mlm_model_name)
    model.to(device)
    model.eval()

    return model, tokenizer


def esm_aa_naturalness(sequence: str, model, tokenizer, batch_size: int = 16) -> dict:
    """Compute per-position amino-acid naturalness of ``sequence`` under an ESM-2 MLM.

    Each position is masked in turn, the model predicts a residue distribution
    at that position, and the wild-type residue's log-probability is read off.

    Args:
        sequence: Protein sequence containing only the 20 standard residues.
        model: An ESM masked-LM (eval mode, already on its target device).
        tokenizer: The matching ESM tokenizer.
        batch_size: Number of single-position-masked variants per forward pass.

    Returns:
        dict with:
            logp: (L x 20) per-position log-probabilities over the standard
                residues, sorted alphabetically.
            wt_logp: log-probability of the wild-type residue per position.
            naturalness: exp(mean(wt_logp)), a scalar in (0, 1].
            encoded: alphabetical residue index of each position.

    Raises:
        ValueError: If the sequence contains a non-standard residue. (The
            previous ``assert`` vanishes under ``python -O``.)
    """
    L = len(sequence)
    vocab = tokenizer.get_vocab()
    # Use the device the model actually lives on. Re-deriving it from CUDA
    # availability (as before) mismatches — and crashes — when the caller's
    # model is on a different device; the original already mixed the two.
    device = next(model.parameters()).device

    if not set(sequence) <= set(STANDARD_RESIDUES):
        raise ValueError("sequence must contain only the 20 standard residues")

    vocab_id_to_aa_id = {vocab[aa]: i for i, aa in enumerate(sorted(STANDARD_RESIDUES))}
    amino_acid_vocab_ids = list(vocab_id_to_aa_id.keys())

    encoded = tokenizer(sequence, return_tensors="pt")
    # Drop the BOS/EOS tokens, then map vocab ids to alphabetical residue ids.
    encoded_aa_id = torch.as_tensor(
        [vocab_id_to_aa_id[x] for x in encoded["input_ids"].squeeze(0)[1:-1].cpu().tolist()], device=device
    )

    masked_sequences = single_position_masked_sequences(sequence)

    with torch.inference_mode():
        logits = torch.cat(
            [
                model(**{k: v.to(device) for k, v in tokenizer(batch, return_tensors="pt").items()}).logits
                for batch in batched(masked_sequences, batch_size)
            ],
            dim=0,
        )
    # Variant i's prediction for sequence position i sits on the diagonal of
    # (variant, position) after stripping BOS/EOS.
    logits = torch.diagonal(logits[:, 1:-1, :], dim1=0, dim2=1)
    logits = einops.rearrange(logits, "token length -> length token", length=L)

    # Restrict the vocabulary to the 20 standard residues before normalizing.
    logits = logits[:, torch.as_tensor(amino_acid_vocab_ids, device=device)]
    logp = torch.nn.functional.log_softmax(logits, dim=-1)
    wt_logp = logp[torch.arange(L), encoded_aa_id]
    naturalness = torch.exp(wt_logp.mean()).item()

    return {
        "logp": logp.cpu().tolist(),
        "wt_logp": wt_logp.cpu().tolist(),
        "naturalness": naturalness,
        "encoded": encoded_aa_id.cpu().tolist(),
    }
Loading
Loading