2 changes: 2 additions & 0 deletions .gitignore
@@ -29,9 +29,11 @@ scripts/combined_db*
*_play.py
src/lobster/hydra_config/experiment/*
src/lobster/mcp/claude_desktop_config.json
*.ipynb_checkpoints

notebooks/nathan/*
notebooks/karina/*
notebooks/amyxlu/*

models/*

17 changes: 10 additions & 7 deletions src/lobster/data/_fasta_datamodule.py
@@ -1,9 +1,10 @@
import importlib
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import Any, TypeVar
from typing import Any, TypeVar, Optional

import pandas as pd
import numpy as np
import torch.utils.data

# from beignet.datasets import FASTADataset
@@ -43,6 +44,7 @@
is_relative_model: bool = False,
tokenizer_dir: str | None = "pmlm_tokenizer",
mlm: bool = True,
offsets_arr: Optional[np.ndarray] = None,

Check failure (GitHub Actions / ruff) — src/lobster/data/_fasta_datamodule.py:47:22: UP007 Use `X | Y` for type annotations
) -> None:
"""
:param path_to_fasta: path to fasta file
Expand Down Expand Up @@ -139,6 +141,7 @@
self._is_relative_model = is_relative_model
self._tokenizer_dir = tokenizer_dir
self._mlm = mlm
self._offsets_arr = offsets_arr

path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
self._transform_fn = transform_fn or PmlmTokenizerTransform(
@@ -159,16 +162,16 @@
if stage == "fit":
if any(["train" in self._path_to_fasta]): # pre computed splits
self._train_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "train" in p]
[FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "train" in p]
)
self._val_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "val" in p]
[FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "val" in p]
)
self._test_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "test" in p]
[FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "test" in p]
)
else: # iid split
datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta]
dataset = torch.utils.data.ConcatDataset(datasets)
(
self._train_dataset,
@@ -181,7 +184,7 @@
)

if stage == "predict":
datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta]
dataset = torch.utils.data.ConcatDataset(datasets)
self._predict_dataset = dataset

@@ -236,4 +239,4 @@
seq_dict = dict(seqs_for_dl)
seq_dict_df = pd.DataFrame(seq_dict.items(), columns=["input_ids", "Labels"])
seq_dict_df = Dataset.from_pandas(seq_dict_df)
return seq_dict_df
return seq_dict_df

Check failure (GitHub Actions / ruff) — src/lobster/data/_fasta_datamodule.py:242:27: W292 No newline at end of file
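Usage sketch (not part of the diff): the new `offsets_arr` argument lets a caller hand the datamodule a precomputed offsets/sizes index instead of having every `FASTADataset` reload or rebuild its own `.offsets.npy` file. A minimal example under assumptions — the datamodule class name (`FastaLightningDataModule`) and its import path are guesses, file names are placeholders, and all other constructor arguments keep their defaults:

```python
import numpy as np

from lobster.data import FastaLightningDataModule  # assumed class name and import path

# Row 0 holds byte offsets, row 1 holds record sizes, matching the
# numpy.stack([offsets, sizes]) layout that FASTADataset writes to disk.
offsets_arr = np.load("data/train.fasta.offsets.npy")

dm = FastaLightningDataModule(
    path_to_fasta=["data/train.fasta", "data/val.fasta", "data/test.fasta"],
    tokenizer_dir="pmlm_tokenizer",
    mlm=True,
    offsets_arr=offsets_arr,
)
dm.setup(stage="fit")
```

Note that in this hunk the same `offsets_arr` is forwarded to every `FASTADataset` (train, val, test, and predict), so it is only meaningful when all the FASTA paths share a single index.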
20 changes: 12 additions & 8 deletions src/lobster/datasets/_fasta_dataset.py
@@ -1,7 +1,7 @@
import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import TypeVar
from typing import TypeVar, Optional

import numpy
from beignet.datasets._sized_sequence_dataset import SizedSequenceDataset
@@ -17,6 +17,7 @@
*,
transform: Callable | None = None,
use_text_descriptions: bool = True,
offsets_arr: Optional[numpy.ndarray] = None,

Check failure (GitHub Actions / ruff) — src/lobster/datasets/_fasta_dataset.py:20:22: UP007 Use `X | Y` for type annotations
) -> None:
if isinstance(root, str):
root = Path(root)
@@ -32,14 +33,17 @@

self.data = ThreadSafeFile(self.root, open)

offsets = Path(f"{self.root}.offsets.npy")
if offsets_arr is None:
offsets_path = Path(f"{self.root}.offsets.npy")
if offsets_path.exists():
self.offsets, sizes = numpy.load(f"{offsets_path}")
else:
self.offsets, sizes = self._build_index()
numpy.save(f"{offsets_path}", numpy.stack([self.offsets, sizes]))

if offsets.exists():
self.offsets, sizes = numpy.load(f"{offsets}")
else:
self.offsets, sizes = self._build_index()

numpy.save(f"{offsets}", numpy.stack([self.offsets, sizes]))
self.offsets = offsets_arr[0, :]
sizes = offsets_arr[1, :]

self.transform = transform

@@ -93,4 +97,4 @@
dtype=numpy.int64,
sep=" ",
),
)
)

Check failure (GitHub Actions / ruff) — src/lobster/datasets/_fasta_dataset.py:100:10: W292 No newline at end of file
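Sketch of the two code paths introduced above (illustrative only; the import path for `FASTADataset` is assumed and the file names are placeholders):

```python
import numpy

from lobster.datasets import FASTADataset  # assumed import path

# Default path: offsets_arr=None, so the dataset loads "<root>.offsets.npy" if it
# exists, otherwise builds the index with _build_index() and saves it to that file.
ds = FASTADataset(root="data/uniref50.fasta")

# New path: pass a precomputed (2, N) array, where row 0 is byte offsets and row 1
# is record sizes -- the same layout as numpy.stack([offsets, sizes]).
offsets_arr = numpy.load("data/uniref50.fasta.offsets.npy")
ds_shared = FASTADataset(root="data/uniref50.fasta", offsets_arr=offsets_arr)
```

Sharing one in-memory array this way avoids re-reading the `.offsets.npy` file once per dataset instance, which is the point of threading `offsets_arr` through the datamodule above.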
72 changes: 42 additions & 30 deletions src/lobster/model/_clm.py
@@ -4,7 +4,7 @@
import lightning.pytorch as pl
import torch
from torch.nn import CrossEntropyLoss
from transformers import LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline

from lobster.constants import SchedulerType
from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform
@@ -13,6 +13,9 @@
from ._clm_configuration import PCLM_CONFIG_ARGS


ALLOWABLE_MODEL_NAMES = list(PCLM_CONFIG_ARGS.keys()) + ["ProtGPT2"]


class LobsterPCLM(pl.LightningModule):
def __init__(
self,
@@ -68,36 +71,45 @@ def __init__(
self.scheduler_kwargs = scheduler_kwargs or {}
model_kwargs = model_kwargs or {}

if self._tokenizer_dir is not None:
path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
self.tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False)
self._transform_fn = transform_fn or PmlmTokenizerTransform(
path,
padding="max_length",
truncation=True,
max_length=self._max_length,
mlm=False,
assert model_name in ALLOWABLE_MODEL_NAMES, f"model_name must be one of {ALLOWABLE_MODEL_NAMES}"

if model_name == "ProtGPT2":
self.tokenizer = AutoTokenizer.from_pretrained("nferruz/ProtGPT2")
self.model = AutoModelForCausalLM.from_pretrained("nferruz/ProtGPT2")
self.config = self.model.config

else:
# Create PCLM model
if self._tokenizer_dir is not None:
path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
self.tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False)
self._transform_fn = transform_fn or PmlmTokenizerTransform(
path,
padding="max_length",
truncation=True,
max_length=self._max_length,
mlm=False,
)

config_args = PCLM_CONFIG_ARGS[model_name]
if num_key_value_heads is None:
num_key_value_heads = config_args["num_attention_heads"]
self._num_key_value_heads = num_key_value_heads

config = LlamaConfig(
**config_args,
mask_token_id=self.tokenizer.mask_token_id,
pad_token_id=self.tokenizer.pad_token_id,
cls_token_id=self.tokenizer.cls_token_id,
eos_token_id=self.tokenizer.eos_token_id,
vocab_size=len(self.tokenizer.get_vocab()),
max_position_embeddings=self._max_length,
num_key_value_heads=self._num_key_value_heads,
attention_bias=self._attention_bias,
**model_kwargs,
)

config_args = PCLM_CONFIG_ARGS[model_name]
if num_key_value_heads is None:
num_key_value_heads = config_args["num_attention_heads"]
self._num_key_value_heads = num_key_value_heads

config = LlamaConfig(
**config_args,
mask_token_id=self.tokenizer.mask_token_id,
pad_token_id=self.tokenizer.pad_token_id,
cls_token_id=self.tokenizer.cls_token_id,
eos_token_id=self.tokenizer.eos_token_id,
vocab_size=len(self.tokenizer.get_vocab()),
max_position_embeddings=self._max_length,
num_key_value_heads=self._num_key_value_heads,
attention_bias=self._attention_bias,
**model_kwargs,
)
self.model = LlamaForCausalLM(config)
self.config = self.model.config
self.model = LlamaForCausalLM(config)
self.config = self.model.config

self.save_hyperparameters(logger=False)

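Usage sketch (not part of the diff): with this change, `model_name="ProtGPT2"` loads the pretrained nferruz/ProtGPT2 tokenizer and weights from the Hugging Face Hub, while any key of `PCLM_CONFIG_ARGS` still builds a fresh `LlamaForCausalLM`; any other name now fails the new assertion. The import path for `LobsterPCLM` is assumed and the remaining constructor arguments keep their defaults:

```python
from lobster.model import LobsterPCLM  # assumed import path

# Pretrained ProtGPT2 branch: tokenizer and weights come from nferruz/ProtGPT2.
protgpt2 = LobsterPCLM(model_name="ProtGPT2")

# A name outside ALLOWABLE_MODEL_NAMES (PCLM_CONFIG_ARGS keys + "ProtGPT2")
# now raises an AssertionError instead of failing later with a KeyError.
```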