diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..3040da1
Binary files /dev/null and b/.DS_Store differ
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/FLIPv3.iml b/.idea/FLIPv3.iml
new file mode 100644
index 0000000..ecf479a
--- /dev/null
+++ b/.idea/FLIPv3.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..e460088
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..c419510
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..8acbfef
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..0ff281a
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/amlt/kky.yaml b/amlt/kky.yaml
new file mode 100644
index 0000000..725c154
--- /dev/null
+++ b/amlt/kky.yaml
@@ -0,0 +1,88 @@
+description: Results on FLIPv3.
+
+target:
+ service: sing
+ name: msrresrchvc
+ workspace_name: bio0
+
+environment:
+ image: amlt-sing/acpt-torch2.7.1-py3.10-cuda12.6-ubuntu22.04
+ setup:
+ - cd sequence_models
+ - pip install -U --user -e .
+ - cd ..
+ - pip install biopython
+ - pip install httpx
+ - pip install esm
+# - pip install scikit-learn
+# - pip install scipy
+storage:
+ amlt:
+ storage_account_name: kevyaneastus2
+ container_name: amlt
+code:
+ local_dir: src/
+jobs:
+ - name: "get_predictions-2"
+ identity: managed
+ sku: 1xG1-A100
+ command:
+ - python FLIPv3/baselines/get_predictions.py /mnt/amlt/data/flip_data_pruned/ /mnt/amlt/flip_predictions/
+ mpi: True
+ process_count_per_node: -1
+ submit_args:
+ env:
+ WANDB_BASE_URL: "https://microsoft-research.wandb.io"
+ WANDB_API_KEY: "$WANDB_API_KEY"
+ WANDB_ENTITY: "bio"
+ NCCL_DEBUG: "INFO"
+ NCCL_SOCKET_IFNAME: "bond0"
+ NCCL_CUMEM_ENABLE: 0
+ NCCL_DEBUG_SUBSYS: "ALL"
+ TORCH_DISTRIBUTED_DEBUG: "INFO"
+ AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai"
+ tags: [ Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0 ]
+search:
+ job_template:
+ name: "{model_name}_{task}_{weights}_{seed}_{lr}"
+ sku: 1xG1-A100
+ command:
+ - python FLIPv3/baselines/{model_name}.py /mnt/amlt/data/flip_data_5-9-2025/datasets/ /mnt/amlt/flip_results/ {task} {weights} --seed {seed} --lr {lr}
+ mpi: True
+ process_count_per_node: -1
+ submit_args:
+ env:
+ WANDB_BASE_URL: "https://microsoft-research.wandb.io"
+ WANDB_API_KEY: "$WANDB_API_KEY"
+ WANDB_ENTITY: "bio"
+ NCCL_DEBUG: "INFO"
+ NCCL_SOCKET_IFNAME: "bond0"
+ NCCL_CUMEM_ENABLE: 0
+ NCCL_DEBUG_SUBSYS: "ALL"
+ TORCH_DISTRIBUTED_DEBUG: "INFO"
+ _AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai"
+ tags: [Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0]
+ type: grid
+ max_trials: 10000
+ params:
+ - name: lr
+ spec: discrete
+ values: [1e-5]
+ - name: model_name
+ spec: discrete
+ values: [carp]
+ - name: seed
+ spec: discrete
+ values: [1]
+ - name: weights
+ spec: discrete
+ values: [ naive ]
+ - name: task
+ spec: discrete
+ values: [ hydro_med_P06241test_split
+ ]
+# RhoMax_by_wt, AMY_BACSU_easy_split, AMY_BACSU_hard_split_, AMY_BACSU_med_split_is_buried_0, AMY_BACSU_med_split_is_buried_1, AMY_BACSU_med_split_is_close_to_as_0, AMY_BACSU_med_split_is_close_to_as_1,
+# AMY_BACSU_med_split_is_connected_0, AMY_BACSU_med_split_is_connected_1, AMY_BACSU_med_split_is_secondary_0, AMY_BACSU_med_split_is_secondary_1, AMY_BACSU_random_split,
+# hydro_hard_split, hydro_med1_split, hydro_med2_split, hydro_random_split,
+# ired_ired_excludeT241mut_mutation_order_split, ired_ired_low_high_split, ired_ired_mutation_order_split, ired_ired_random_split,
+# PDZ3_low_vs_high, PDZ3_one_vs_rest, PDZ3_sampled, PDZ3_three_vs_rest, PDZ3_two_vs_rest,
\ No newline at end of file
diff --git a/amlt/ridge.yaml b/amlt/ridge.yaml
new file mode 100644
index 0000000..c1a6962
--- /dev/null
+++ b/amlt/ridge.yaml
@@ -0,0 +1,50 @@
+description: Results on FLIPv3.
+
+target:
+ service: sing
+ name: msrresrchvc
+ workspace_name: bio0
+
+environment:
+ image: amlt-sing/acpt-torch2.7.0-py3.10-cuda12.6-ubuntu22.04
+ setup:
+ - cd sequence_models
+ - pip install -U --user -e .
+ - cd ..
+ - pip install biopython
+ - pip install esm
+ - pip install scikit-learn
+ - pip install scipy
+storage:
+ amlt:
+ storage_account_name: kevyaneastus2
+ container_name: amlt
+code:
+ local_dir: src/
+search:
+ job_template:
+ name: "ridge_{task}"
+ sku: C3
+ command:
+ - python FLIPv3/baselines/linear_models.py /mnt/amlt/data/flip_data/datasets/ /mnt/amlt/flip_results/ {task}
+ mpi: True
+ submit_args:
+ env:
+ WANDB_BASE_URL: "https://microsoft-research.wandb.io"
+ WANDB_API_KEY: "$WANDB_API_KEY"
+ WANDB_ENTITY: "bio"
+ NCCL_DEBUG: "INFO"
+ NCCL_SOCKET_IFNAME: "bond0"
+ NCCL_CUMEM_ENABLE: 0
+ NCCL_DEBUG_SUBSYS: "ALL"
+ TORCH_DISTRIBUTED_DEBUG: "INFO"
+ _AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai"
+ tags: [Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0]
+ type: random
+ max_trials: 25
+ params:
+ - name: task
+ spec: discrete
+ values: [trpb_trpB_four_to_three_split, trpb_trpB_no_position_overlap_split, trpb_trpB_one_vs_many_split, trpb_trpB_three_to_four_split, trpb_trpB_two_vs_many_split,
+ NucB_easy, NucB_medium, NucB_hard
+ ]
diff --git a/analysis/plot_baselines.py b/analysis/plot_baselines.py
new file mode 100644
index 0000000..035fa64
--- /dev/null
+++ b/analysis/plot_baselines.py
@@ -0,0 +1,139 @@
+import os
+from collections import Counter
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Patch
+import seaborn as sns
+from scipy.stats import spearmanr, pearsonr
+_ = sns.set(font_scale=1.7)
+_ = sns.set_style('white')
+pd.set_option('display.max_columns', None)
+pd.set_option('display.width', 1000)
+pd.set_option('display.max_rows', None)
+flip_path = '/home/kevyan/results/flipv3/'
+pruned_path = "/home/kevyan/data/flip_data_pruned/"
+
+datasets = os.listdir(os.path.join(flip_path, "all_predictions"))
+
+
+df = pd.read_csv(os.path.join(flip_path, "random_ridge.csv"))
+
+df_zs = pd.DataFrame()
+zs_model_dict = {
+ 'dayhoff': 'Dayhoff',
+ 'carp_640m_zs': 'CARP-640M',
+ # 'carp_640m_masked_zs': 'CARP-640M',
+ 'esm2_650M_scores': 'ESM2-650M'
+}
+idx = 0
+for dataset in datasets:
+ predictions = pd.read_csv(os.path.join(flip_path, 'zs', dataset + '_zs.csv'))
+ for model in zs_model_dict.keys():
+ if model in predictions.columns:
+ df_zs.loc[idx, 'dataset'] = dataset
+ df_zs.loc[idx, 'model'] = zs_model_dict[model]
+ df_zs.loc[idx, 'Spearman'] = spearmanr(predictions['target'], predictions[model]).correlation
+ idx += 1
+best_zs = df_zs.groupby('dataset').agg({'Spearman': 'max'}).reset_index()
+
+# Get number of training examples in each split
+dataset_path = '/home/kevyan/data/flip_data_pruned/'
+df_sizes = pd.DataFrame(columns=['dataset', 'split', 'n_train', 'n_valid', 'n_test'])
+for dataset in datasets:
+ split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits'))
+ split_csvs = [c for c in split_csvs if c[-4:] == '.csv']
+ for split_csv in split_csvs:
+ df_data = pd.read_csv(os.path.join(dataset_path, dataset, 'splits', split_csv))
+ n_test = len(df_data[df_data['set'] == 'test'])
+ n_train = len(df_data[(df_data['set'] == 'train') & (~df_data['validation'])])
+ n_valid = len(df_data[df_data['validation']])
+ df_sizes.loc[len(df_sizes)] = [dataset, split_csv[:-4], n_train, n_valid, n_test]
+print(df_sizes)
+
+df_metrics = pd.read_csv(os.path.join(flip_path, 'all_metrics.csv'))
+model_dict = {
+ "Ridge": "Ridge (one-hot)",
+ "zsRidge": 'Ridge (one-hot + likelihoods)',
+ "Dayhoff": "Dayhoff likelihood",
+ "ESM2-650M": "ESM2-650M likelihood",
+ "CARP-640M zero shot": "CARP-640M likelihood",
+ ("CARP-640M", True): "CARP-640M supervised",
+ ("CARP-640M", False): "CARP-640M naive supervised",
+ ("ESMC-300M", True): "ESMC-300M supervised",
+ ("ESMC-300M", False): "ESMC-300M naive supervised",
+}
+
+for i, row in df_metrics.iterrows():
+ if row['model'] in model_dict:
+ df_metrics.loc[i, 'model'] = model_dict[row['model']]
+ else:
+ df_metrics.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])]
+df_metrics = df_metrics.fillna(0)
+split_dict = {
+ 'close_to_far': 'position',
+ 'far_to_close': 'position',
+ 'by_mutation': 'mutation',
+ 'by_position': 'position',
+ 'by_wt': 'wild type',
+ 'random': 'random',
+ 'one_to_many': 'number',
+ 'to_P06241': 'wild type',
+ 'to_P01053': 'wild type',
+ 'to_P0A9X9': 'wild type',
+ 'low_to_high': 'fitness',
+ 'three_to_many': 'number',
+ 'single_to_double': 'number',
+ 'two_to_many': 'number',
+}
+pal = [
+ '#76B900',
+ '#A77BB5',
+ '#4E79A7',
+ '#FF8A80',
+ '#F28E2B',
+ '#E15759'
+]
+split_hues = {"number": pal[0], "wild type": pal[1], "position": pal[2], "mutation": pal[3], "fitness": pal[4]}
+
+
+for i, dataset in enumerate(datasets):
+ fig, ax = plt.subplots()
+ _ = sns.lineplot(data=df[df['dataset'] == dataset], x='n_train', y='Spearman', color='grey', style='model',
+ markers=True, ax=ax, ms=20, alpha=0.8)
+ _ = ax.axhline(y=best_zs[best_zs['dataset'] == dataset]['Spearman'].values[0], color='grey', linestyle='-',
+ label='best zero-shot likelihood score')
+ _ = ax.semilogx()
+ _ = ax.set_xlabel('Number of training examples')
+ _ = ax.set_ylim([-0.35, 1])
+ split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits'))
+ split_csvs = [c for c in split_csvs if c[-4:] == '.csv']
+ for split_csv in split_csvs:
+ if 'random' in split_csv:
+ continue
+ split_name = split_csv[:-4]
+ if dataset == 'Amylase' and split_name == 'by_position':
+ split_name = 'by_mutation'
+ color = split_hues[split_dict[split_name]]
+ x = df_sizes[(df_sizes['dataset'] == dataset) & (df_sizes['split'] == split_csv[:-4])]['n_train']
+ y1 = df_metrics[(df_metrics['dataset'] == dataset) & (df_metrics['split'] == split_csv[:-4]) & (df_metrics['model'] == 'Ridge (one-hot)')]['Spearman'].values[0]
+ _ = ax.plot(x, y1, 'o', color=color, ms=20, alpha=0.7, mew=0)
+ y2 = df_metrics[(df_metrics['dataset'] == dataset) & (df_metrics['split'] == split_csv[:-4]) & (df_metrics['model'] == 'Ridge (one-hot + likelihoods)')]['Spearman'].values[0]
+ _ = ax.plot(x, y2, 'x', color=color, ms=20, alpha=1.0, mew=4)
+ _ = ax.set_title(dataset)
+ legend = ax.legend(title='Model')
+ legend.remove()
+ fig.savefig(os.path.join(flip_path, "plots", "random_ridge_%s.pdf" %dataset), dpi=300, bbox_inches='tight')
+handles, labels = ax.get_legend_handles_labels()
+fig, ax = plt.subplots()
+legend = ax.legend(handles, labels, title='Model')
+bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+fig.savefig(os.path.join(flip_path, 'plots', 'random_ridge_models.pdf'), dpi=300, bbox_inches=bbox)
+elements = [
+ Patch(facecolor='gray', edgecolor=None, label='random'),
+]
+elements += [Patch(facecolor=split_hues[s], edgecolor=None, label=s, alpha=0.7) for s in split_hues]
+legend = ax.legend(handles=elements, title='Split type')
+bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+fig.savefig(os.path.join(flip_path, 'plots', 'random_ridge_split_types.pdf'), dpi=300, bbox_inches=bbox)
diff --git a/analysis/plot_zs.py b/analysis/plot_zs.py
new file mode 100644
index 0000000..131d563
--- /dev/null
+++ b/analysis/plot_zs.py
@@ -0,0 +1,457 @@
+import os
+from collections import Counter
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy.stats import spearmanr, pearsonr
+_ = sns.set(font_scale=1.7)
+_ = sns.set_style('white')
+pd.set_option('display.max_columns', None)
+pd.set_option('display.width', 1000)
+pd.set_option('display.max_rows', None)
+flip_path = '/home/kevyan/results/flipv3/'
+pruned_path = "/home/kevyan/data/flip_data_pruned/"
+
+datasets = os.listdir(os.path.join(flip_path, "all_predictions"))
+
+pal = sns.color_palette()
+models = [
+ ('dayhoff', 'Dayhoff likelihood', pal[0]),
+ ('esm2_650M_scores', "ESM2-650M likelihood", pal[1]),
+ ('carp_640m_zs', "CARP-640M likelihood", pal[2]),
+ ('Ridge', 'Ridge (one-hot)', pal[3]),
+ ('zsRidge', 'Ridge (one-hot + likelihoods)', pal[4]),
+ ('carp_naive', 'CARP-640M naive supervised', pal[5]),
+ ('carp_pretrained', 'CARP-640M supervised', pal[6]),
+ ('esmc_naive', 'ESMC-300M naive supervised', pal[7]),
+ ('esmc_pretrained', 'ESMC-300M supervised', pal[8]),
+]
+os.makedirs(os.path.join(flip_path, 'plots', 'predictions'), exist_ok=True)
+# plot individual predictions
+_ = sns.set(font_scale=1)
+_ = sns.set_style('white')
+for dataset in datasets:
+ splits = os.listdir(os.path.join(flip_path, "all_predictions", dataset))
+ for split in splits:
+ split_name = split[:-4]
+ df = pd.read_csv(os.path.join(flip_path, 'all_predictions', dataset, split))
+ df['carp_naive'] = (df['carp_naive_0'] + df['carp_naive_1'] + df['carp_naive_2'] + df['carp_naive_3'] + df['carp_naive_4']) / 5
+ df['carp_pretrained'] = (df['carp_pretrained_0'] + df['carp_pretrained_1'] + df['carp_pretrained_2'] + df['carp_pretrained_3'] + df['carp_pretrained_4']) / 5
+ df['esmc_naive'] = (df['esmc_naive_0'] + df['esmc_naive_1'] + df['esmc_naive_2'] + df['esmc_naive_3'] + df['esmc_naive_4']) / 5
+ df['esmc_pretrained'] = (df['esmc_pretrained_0'] + df['esmc_pretrained_1'] + df['esmc_pretrained_2'] + df['esmc_pretrained_3'] + df['esmc_pretrained_4']) / 5
+ for ugly, pretty, color in models:
+ # if dataset == "PDZ3":
+ # if ugly in ['dayhoff', 'esm2_650M_scores', 'carp_640m_zs']:
+ # continue
+ fig, ax = plt.subplots()
+ _ = sns.scatterplot(data=df, x='scaled_target', y=ugly, alpha=0.3, marker='o', ax=ax, color='gray')
+ _ = ax.set_ylabel(pretty)
+ _ = ax.set_xlabel('scaled target')
+ fig.savefig(os.path.join(flip_path, 'plots', 'predictions',
+ '_'.join([dataset, split_name, ugly + '.png'])),
+ dpi=100, bbox_inches='tight')
+
+# Plot landscape zero-shot scores
+_ = sns.set(font_scale=1.7)
+_ = sns.set_style('white')
+pal = [
+ '#76B900',
+ '#A77BB5',
+ '#4E79A7',
+ '#FF8A80',
+ '#F28E2B',
+ '#E15759'
+]
+model_dict = {
+ 'dayhoff': 'Dayhoff likelihood',
+ 'carp_640m_zs': 'CARP-640M likelihood',
+ 'esm2_650M_scores': 'ESM2-650M likelihood',
+}
+model_order = [
+ 'Dayhoff',
+ 'ESM2-650M',
+ 'CARP-640M',
+]
+plot_me = pd.DataFrame()
+idx = 0
+for dataset in datasets:
+ predictions = pd.read_csv(os.path.join(flip_path, 'zs', dataset + '_zs.csv'))
+ for model in model_dict.keys():
+ if model in predictions.columns:
+ plot_me.loc[idx, 'dataset'] = dataset
+ plot_me.loc[idx, 'split'] = 'full'
+ plot_me.loc[idx, 'model'] = model_dict[model]
+ plot_me.loc[idx, 'Spearman'] = spearmanr(predictions['target'], predictions[model]).correlation
+ idx += 1
+dataset_order = ['Amylase', 'IRED', 'NucB', 'TrpB', 'Hydro', 'Rhomax', 'PDZ3']
+fig, ax = plt.subplots(figsize=(8, 6))
+palette = ['gray', 'gray', 'gray']
+_ = sns.barplot(x='dataset', y='Spearman', data=plot_me, ax=ax, hue='model', palette=palette, order=dataset_order)
+hatches = ['/', '.', '*']
+handles, labels = ax.get_legend_handles_labels()
+for i, bar in enumerate(ax.patches[:21]):
+ hatch = hatches[i // 7]
+ bar.set_hatch(hatch)
+for i, handle in enumerate(handles):
+ handle.set_hatch(hatches[i])
+ax.legend(handles, labels)
+ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
+fig.savefig(os.path.join(flip_path, 'plots', 'dataset_zeroshot.pdf'), dpi=300, bbox_inches='tight')
+
+
+# plot aggregate things
+df = pd.read_csv(os.path.join(flip_path, 'all_metrics.csv'))
+model_dict = {
+ "Ridge": "Ridge (one-hot)",
+ "zsRidge": 'Ridge (one-hot + likelihoods)',
+ "Dayhoff": "Dayhoff likelihood",
+ "ESM2-650M": "ESM2-650M likelihood",
+ "CARP-640M zero shot": "CARP-640M likelihood",
+ ("CARP-640M", True): "CARP-640M supervised",
+ ("CARP-640M", False): "CARP-640M naive supervised",
+ ("ESMC-300M", True): "ESMC-300M supervised",
+ ("ESMC-300M", False): "ESMC-300M naive supervised",
+}
+model_order = [
+ "Dayhoff likelihood",
+ "ESM2-650M likelihood",
+ "CARP-640M likelihood",
+ "Ridge (one-hot)",
+ "Ridge (one-hot + likelihoods)",
+ "CARP-640M naive supervised",
+ "CARP-640M supervised",
+ "ESMC-300M naive supervised",
+ "ESMC-300M supervised",
+]
+# hue = {m: p for m, p in zip(model_order, pal)}
+
+
+for i, row in df.iterrows():
+ if row['model'] in model_dict:
+ df.loc[i, 'model'] = model_dict[row['model']]
+ else:
+ df.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])]
+
+# for dataset in set(df['dataset']):
+# for split in set(df[df['dataset'] == dataset]['split']):
+# pretty_split = split.replace('_', '-')
+# data = df[(df['dataset'] == dataset) & (df['split'] == split)]
+# fig, ax = plt.subplots()
+# _ = sns.barplot(data=data, x='model', y='Spearman', hue='model', hue_order=model_order, palette=hue,
+# errorbar=None, ax=ax, alpha=0.7, order=model_order)
+# _ = sns.stripplot(data=data, x='model', y='Spearman', hue='model', hue_order=model_order, palette=hue,
+# ax=ax, order=model_order)
+# _ = ax.set_title(dataset + " " + pretty_split)
+# _ = ax.set_ylim([-0.2, 1])
+# ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
+# fig.savefig(os.path.join(flip_path, 'plots', 'spearman_' + dataset + '_' + split + '.pdf'),
+# dpi=300, bbox_inches='tight')
+
+
+split_dict = {
+ 'close_to_far': 'position',
+ 'far_to_close': 'position',
+ 'by_position': 'position',
+ 'by_mutation': 'mutation',
+ 'by_wt': 'wild type',
+ 'random': 'random',
+ 'one_to_many': 'number',
+ 'to_P06241': 'wild type',
+ 'to_P01053': 'wild type',
+ 'to_P0A9X9': 'wild type',
+ 'low_to_high': 'fitness',
+ 'three_to_many': 'number',
+ 'single_to_double': 'number',
+ 'two_to_many': 'number',
+}
+split_hues = {"full": 'gray', "number": pal[0], "wild type": pal[1], "position": pal[2], "mutation": pal[3], "fitness": pal[4]}
+split_order = ["full", 'number', 'wild type', 'position', 'mutation', 'fitness']
+df['split type'] = df.apply(lambda row: split_dict[row['split']], axis=1)
+
+all_zs = df[df['model'].isin(['Dayhoff likelihood', 'ESM2-650M likelihood', 'CARP-640M likelihood'])]
+all_zs = all_zs[['dataset', 'split', 'model', 'Spearman']]
+all_zs = pd.concat([plot_me, all_zs], ignore_index=True)
+
+for idx, row in all_zs.iterrows():
+ all_zs.loc[idx, 'task'] = row['dataset'] + ' ' + '-'.join(row['split'].split('_'))
+ all_zs.loc[idx, 'split type'] = split_dict[row['split']] if row['split'] in split_dict else row['split']
+
+task_order = [
+ 'Amylase full',
+ 'Amylase one-to-many',
+ 'Amylase close-to-far',
+ 'Amylase far-to-close',
+ 'Amylase by-mutation',
+ 'IRED full',
+ 'IRED two-to-many',
+ 'NucB full',
+ 'NucB two-to-many',
+ 'TrpB full',
+ 'TrpB one-to-many',
+ 'TrpB two-to-many',
+ 'TrpB by-position',
+ 'Hydro full',
+ 'Hydro three-to-many',
+ 'Hydro low-to-high',
+ 'Hydro to-P06241',
+ 'Hydro to-P0A9X9',
+ 'Hydro to-P01053',
+ 'Rhomax full',
+ 'Rhomax by-wt',
+ 'PDZ3 full',
+ 'PDZ3 single-to-double'
+]
+all_zs['task'] = pd.Categorical(all_zs['task'], categories=task_order, ordered=True)
+_ = sns.set(font_scale=1.5)
+_ = sns.set_style('white')
+fig, ax = plt.subplots(figsize=(16, 6))
+# _ = ax.fill_betweenx([-0.5, 0.67], x1=[-1.1, -1.1], x2=[4.5, 4.5], color='gray', alpha=0.1)
+_ = ax.fill_betweenx([-0.5, 0.67], x1=[4.5, 4.5], x2=[6.5, 6.5], color='gray', alpha=0.1, linewidth=0)
+_ = ax.fill_betweenx([-0.5, 0.67], x1=[8.5, 8.5], x2=[12.5, 12.5], color='gray', alpha=0.1, linewidth=0)
+_ = ax.fill_betweenx([-0.5, 0.67], x1=[18.5, 18.5], x2=[20.5, 20.5], color='gray', alpha=0.1, linewidth=0)
+
+_ = sns.scatterplot(data=all_zs, x='task', y='Spearman', hue='split type', hue_order=split_order,
+ ax=ax, style='model', palette=split_hues, s=150, markers=['X', 'o', 's'], alpha=0.7)
+_ = ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
+_ = ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1))
+fig.savefig(os.path.join(flip_path, 'plots', 'all_zeroshot.pdf'), dpi=300, bbox_inches='tight')
+
+_ = sns.set(font_scale=1.7)
+_ = sns.set_style('white')
+split_hues = {"number": pal[0], "wild type": pal[1], "position": pal[2], "mutation": pal[3], "fitness": pal[4]}
+split_order = ['number', 'wild type', 'position', 'mutation', 'fitness']
+plot_me = pd.DataFrame()
+idx = 0
+tasks = list(df[['dataset', 'split']].values)
+tasks = set([(t[0], t[1]) for t in tasks])
+for dataset, split in tasks:
+ current = df[(df['dataset'] == dataset) & (df['split'] == split)]
+ for j, row in current.iterrows():
+ plot_me.loc[idx, 'dataset'] = dataset
+ plot_me.loc[idx, 'split'] = split
+ plot_me.loc[idx, row['model']] = row['Spearman']
+ plot_me.loc[idx, 'split type'] = split_dict[split]
+ idx += 1
+plot_me = plot_me.fillna(0)
+plot_me = plot_me[plot_me['split type'] != 'random']
+models = list(plot_me.columns[2:])
+models.remove('split type')
+models = np.array(models)
+plot_me['best model'] = models[plot_me[models].values.argmax(axis=1)]
+print(plot_me[['dataset', 'split', 'best model']].sort_values('dataset'))
+# Plot Ridge vs Ridge + zs
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot)', y='Ridge (one-hot + likelihoods)', hue='split type',
+ alpha=0.7, ax=ax, color='gray', hue_order=split_order, s=150, legend=False, palette=split_hues)
+_ = ax.set_ylabel('Ridge (one-hot+likelihoods)')
+fig.savefig(os.path.join(flip_path, 'plots', 'ridge_comparison.pdf'), dpi=300, bbox_inches='tight')
+
+# Plot Ridge + zs vs best zs
+plot_me['best zero-shot'] = plot_me.loc[:, ['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood']].max(axis=1)
+plot_me['best zs method'] = np.array(['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood'])[plot_me.loc[:, ['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood']].values.argmax(axis=1)]
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me, x='best zero-shot', y='Ridge (one-hot + likelihoods)', hue='split type',
+ alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150)
+_ = ax.set_ylabel('Ridge (one-hot+likelihoods)')
+
+legend = ax.legend()
+legend.remove()
+fig.savefig(os.path.join(flip_path, 'plots', 'zs_vs_ridgezs.pdf'), dpi=300, bbox_inches='tight')
+
+# Plot Ridge vs best PLM
+plot_me['best PLM'] = plot_me.loc[:, ['CARP-640M supervised', 'CARP-640M naive supervised', 'ESMC-300M supervised', 'ESMC-300M naive supervised']].max(axis=1)
+archs = np.array(['CARP-640M', 'CARP-640M', 'ESMC-300M', 'ESMC-300M'])
+plot_me['best architecture'] = archs[plot_me.loc[:, ['CARP-640M supervised', 'CARP-640M naive supervised', 'ESMC-300M supervised', 'ESMC-300M naive supervised']].values.argmax(axis=1)]
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot)', y='best PLM', hue='split type', style='best architecture',
+ alpha=0.7,
+ legend=False, style_order=['CARP-640M', 'ESMC-300M'], ax=ax, palette=split_hues,
+ hue_order=split_order, s=150)
+fig.savefig(os.path.join(flip_path, 'plots', 'ridge_vs_plm.pdf'), dpi=300, bbox_inches='tight')
+
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot + likelihoods)', y='best PLM', hue='split type',
+ style='best architecture', style_order=['CARP-640M', 'ESMC-300M'], legend=False,
+ alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150)
+fig.savefig(os.path.join(flip_path, 'plots', 'ridgezs_vs_plm.pdf'), dpi=300, bbox_inches='tight')
+
+# Plot CARP
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me, x='CARP-640M likelihood', y='CARP-640M supervised', hue='split type',
+ alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150)
+legend = ax.legend()
+legend.remove()
+fig.savefig(os.path.join(flip_path, 'plots', 'carp.pdf'), dpi=300, bbox_inches='tight')
+handles, labels = ax.get_legend_handles_labels()
+fig, ax = plt.subplots()
+legend = ax.legend(handles, labels, title='Split type')
+bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+fig.savefig(os.path.join(flip_path, 'plots', 'split_types.pdf'), dpi=300, bbox_inches=bbox)
+
+# Get dataset sizes
+dataset_path = '/home/kevyan/data/flip_data_pruned/'
+for dataset in set(plot_me['dataset']):
+ split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits'))
+ split_csvs = [c for c in split_csvs if c[-4:] == '.csv']
+ for split_csv in split_csvs:
+ df_data = pd.read_csv(os.path.join(dataset_path, dataset, 'splits', split_csv))
+ n_test = len(df_data[df_data['set'] == 'test'])
+ n_train = len(df_data[(df_data['set'] == 'train') & (~df_data['validation'])])
+ n_valid = len(df_data[df_data['validation']])
+ index = plot_me[(plot_me['dataset'] == dataset) & (plot_me['split'] == split_csv[:-4])].index
+ if not index.empty:
+ plot_me.loc[index, 'n_train'] = n_train
+
+
+
+# Plot pretrained vs naive
+plot_me2 = pd.DataFrame()
+idx = 0
+for i, row in plot_me.iterrows():
+ plot_me2.loc[idx, 'dataset'] = row['dataset']
+ plot_me2.loc[idx, 'split'] = row['split']
+ plot_me2.loc[idx, 'split type'] = split_dict[row['split']]
+ plot_me2.loc[idx, 'pretrained'] = row['CARP-640M supervised']
+ plot_me2.loc[idx, 'naive'] = row['CARP-640M naive supervised']
+ plot_me2.loc[idx, 'architecture'] = 'CARP-640M'
+ idx += 1
+ plot_me2.loc[idx, 'dataset'] = row['dataset']
+ plot_me2.loc[idx, 'split'] = row['split']
+ plot_me2.loc[idx, 'split type'] = split_dict[row['split']]
+ plot_me2.loc[idx, 'pretrained'] = row['ESMC-300M supervised']
+ plot_me2.loc[idx, 'naive'] = row['ESMC-300M naive supervised']
+ plot_me2.loc[idx, 'architecture'] = 'ESMC-300M'
+ idx += 1
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray')
+_ = sns.scatterplot(data=plot_me2, x='naive', y='pretrained', hue='split type', style='architecture', alpha=0.7,
+ palette=split_hues, ax=ax, style_order=['CARP-640M', 'ESMC-300M'], s=150, legend=True,
+ hue_order=split_order)
+legend = ax.legend()
+legend.remove()
+handles, labels = ax.get_legend_handles_labels()
+fig.savefig(os.path.join(flip_path, 'plots', 'plms.pdf'), dpi=300, bbox_inches='tight')
+fig, ax = plt.subplots()
+legend = ax.legend(handles, labels)
+bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+fig.savefig(os.path.join(flip_path, 'plots', 'plms_legend.pdf'), dpi=300, bbox_inches=bbox)
+
+# Barplot of ridge spearmans
+fig, ax = plt.subplots(figsize=(8, 6))
+plot_me3 = df[df['model'] == 'Ridge (one-hot)']
+plot_me3 = plot_me3[plot_me3['split'] != 'random']
+for i, row in plot_me3.iterrows():
+ plot_me3.loc[i, 'task'] = row['dataset'] + ' ' + row['split'].replace('_', '-')
+task_order = [
+ 'Amylase one-to-many',
+ 'Amylase close-to-far',
+ 'Amylase far-to-close',
+ 'Amylase by-mutation',
+ 'IRED two-to-many',
+ 'NucB two-to-many',
+ 'TrpB one-to-many',
+ 'TrpB two-to-many',
+ 'TrpB by-position',
+ 'Hydro three-to-many',
+ 'Hydro low-to-high',
+ 'Hydro to-P06241',
+ 'Hydro to-P0A9X9',
+ 'Hydro to-P01053',
+ 'Rhomax by-wt',
+ 'PDZ3 single-to-double'
+]
+_ = sns.barplot(data=plot_me3, x='task', y='Spearman', ax=ax, hue='split type', hue_order=split_order,
+ palette=split_hues,
+ order=task_order, legend=False)
+_ = ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
+fig.savefig(os.path.join(flip_path, 'plots', 'all_ridge.pdf'), dpi=300, bbox_inches='tight')
+
+
+plot_me3 = df[df['model'].isin(['Dayhoff likelihood', 'ESM2-650M likelihood', 'CARP-640M likelihood'])]
+for i, row in plot_me3.iterrows():
+ plot_me3.loc[i, 'task'] = row['dataset'] + ' ' + row['split'].replace('_', '-')
+plot_me3 = plot_me3.groupby('task').agg({'Spearman': 'mean', 'split type': lambda x: x.values[0]})
+plot_me3 = plot_me3.reset_index()
+fig, ax = plt.subplots(figsize=(8, 6))
+_ = sns.barplot(data=plot_me3, x='task', y='Spearman', ax=ax, hue='split type', hue_order=split_order,
+ palette=split_hues, legend=False, order=task_order)
+_ = ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right')
+fig.savefig(os.path.join(flip_path, 'plots', 'best_zs.pdf'), dpi=300, bbox_inches='tight')
+
+
+# Plot train zs Spearman vs test zs Spearman
+datasets = os.listdir(pruned_path)
+model_order = [
+ "Dayhoff",
+ "ESM2-650M",
+ "CARP-640M",
+ "Ridge (one-hot)",
+ "Ridge (one-hot + likelihoods)",
+ "CARP-640M naive supervised",
+ "CARP-640M supervised",
+ "ESMC-300M naive supervised",
+ "ESMC-300M supervised",
+]
+hue = {m: p for m, p in zip(model_order, pal)}
+
+zs_df = pd.DataFrame()
+zs_columns = {
+ 'esm2_650M_scores': "ESM2-650M",
+ 'carp_640m_zs': "CARP-640M",
+ 'dayhoff': 'Dayhoff',}
+idx = 0
+for dataset in datasets:
+ for split in os.listdir(os.path.join(pruned_path, dataset, 'splits')):
+ if split[-4:] != ".csv":
+ continue
+ results = pd.read_csv(os.path.join(pruned_path, dataset, 'splits', split))
+ if 'carp_640m_masked_zs' in results.columns:
+ results['carp_640m_zs'] = results['carp_640m_masked_zs']
+ results = results.drop(columns=['carp_640m_masked_zs'])
+ if dataset != "PDZ3":
+ results.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"}, inplace=True)
+ results['dayhoff'] = (results['dayhoff_3bgrhmc_fwd'] + results['dayhoff_3bgrhmc_bwd'] + results['dayhoff_3bur90_fwd'] + results['dayhoff_3bur90_bwd']) / 4
+ else:
+ results.rename(columns={"dayhoff_fwd_1": "dayhoff_3bgrhmc_fwd_1", "dayhoff_bwd_1": "dayhoff_3bgrhmc_bwd_1"}, inplace=True)
+ results.rename(columns={"dayhoff_fwd_2": "dayhoff_3bgrhmc_fwd_2", "dayhoff_bwd_2": "dayhoff_3bgrhmc_bwd_2"}, inplace=True)
+ d1 = (results['dayhoff_3bgrhmc_fwd_1'] + results['dayhoff_3bgrhmc_bwd_1'] + results['dayhoff_3bur90_fwd_1'] + results['dayhoff_3bur90_bwd_1']) / 4
+ d1 += (results['dayhoff_3bgrhmc_fwd_2'] + results['dayhoff_3bgrhmc_bwd_2'] + results['dayhoff_3bur90_fwd_2'] + results['dayhoff_3bur90_bwd_2']) / 4
+ d1 /= 2
+ results['dayhoff'] = d1
+ results['esm2_650M_scores'] = (results['esm2_650M_scores1'] + results['esm2_650M_scores2'].fillna(0)) / 2
+ results['carp_640m_zs'] = (results['carp_640m_zs_1'] + results['carp_640m_zs_2']) / 2
+ for m in zs_columns.keys():
+ if m in results.columns:
+ zs_df.loc[idx, 'dataset'] = dataset
+ zs_df.loc[idx, 'split type'] = split_dict[split[:-4]]
+ zs_df.loc[idx, 'split'] = split[:-4]
+ zs_df.loc[idx, 'model'] = zs_columns[m]
+ for s in ['train', 'test']:
+ d = results[results['set'] == s]
+ sp = spearmanr(d['target'], d[m]).correlation
+ zs_df.loc[idx, s + ' Spearman'] = sp
+ idx += 1
+plot_me = zs_df[zs_df['split'] != 'random']
+fig, ax = plt.subplots()
+_ = ax.plot([-0.2, 0.7], [-0.2, 0.7], color='gray')
+_ = sns.scatterplot(data=plot_me, x='train Spearman', y='test Spearman', hue='split type',
+ palette=split_hues, hue_order=split_order, ax=ax, alpha=0.7, s=150, legend=False)
+fig.savefig(os.path.join(flip_path, 'plots', 'zs_train_v_test.pdf'), dpi=300, bbox_inches='tight')
+print("Train/test zero shot spearmans")
+print(pearsonr(zs_df.dropna()['train Spearman'], zs_df.dropna()['test Spearman']))
+zs_df = zs_df[zs_df['split type'] != 'random']
+for model in model_order[:3]:
+ d = zs_df[zs_df['model'] == model]
+ print(model, pearsonr(d.dropna()['train Spearman'], d.dropna()['test Spearman']))
+
+for st in set(split_dict.values()):
+ if st != 'random':
+ d = zs_df[zs_df['split type'] == st]
+ print(st, pearsonr(d.dropna()['train Spearman'], d.dropna()['test Spearman']).correlation)
diff --git a/baselines/aggregate.py b/baselines/aggregate.py
new file mode 100644
index 0000000..9cf4afa
--- /dev/null
+++ b/baselines/aggregate.py
@@ -0,0 +1,237 @@
+import os
+from collections import Counter
+
+import pandas as pd
+import numpy as np
+from scipy.stats import spearmanr
+from sklearn.metrics import ndcg_score, roc_auc_score
+
# Show full frames when printing to the console (no column/row truncation).
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.max_rows', None)

# Root directory holding all FLIPv3 result artifacts.
flip_path = '/home/kevyan/results/flipv3/'
+# results_path = os.path.join(flip_path, 'combined_flip_results')
+# result_files = os.listdir(results_path)
+
+
+
+
+
# Make clean full landscape zero shot score csvs
pruned_path = "/home/kevyan/data/flip_data_pruned/"
landscapes = os.listdir(pruned_path)
os.makedirs(os.path.join(flip_path, "zs"), exist_ok=True)
# Landscapes with a canonical non-random split read that file; everything else random.csv.
_split_file = {"rhomax": "by_wt.csv", "TrpB": "by_position.csv"}
for landscape in landscapes:
    split_name = _split_file.get(landscape, "random.csv")
    df = pd.read_csv(os.path.join(pruned_path, landscape, "splits", split_name))
    drop_these = ("set", "validation", "dayhoff_min", "dayhoff_3bur90_min")
    df = df[[c for c in df.columns if c not in drop_these]]
    if landscape != "PDZ3":
        df.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"}, inplace=True)
        # Dayhoff score = mean of both models in both reading directions.
        df['dayhoff'] = (df['dayhoff_3bgrhmc_fwd'] + df['dayhoff_3bgrhmc_bwd']
                         + df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd']) / 4
    else:
        # PDZ3 is scored per chain (suffixes _1/_2); average the two per-chain means.
        for chain in ('1', '2'):
            df.rename(columns={"dayhoff_fwd_" + chain: "dayhoff_3bgrhmc_fwd_" + chain,
                               "dayhoff_bwd_" + chain: "dayhoff_3bgrhmc_bwd_" + chain}, inplace=True)
        chain_means = [(df['dayhoff_3bgrhmc_fwd_' + chain] + df['dayhoff_3bgrhmc_bwd_' + chain]
                        + df['dayhoff_3bur90_fwd_' + chain] + df['dayhoff_3bur90_bwd_' + chain]) / 4
                       for chain in ('1', '2')]
        df['dayhoff'] = (chain_means[0] + chain_means[1]) / 2
    df['esm2_650M_scores'] = (df['esm2_650M_scores1'] + df['esm2_650M_scores2'].fillna(0)) / 2
    df['carp_640m_zs'] = (df['carp_640m_zs_1'] + df['carp_640m_zs_2']) / 2
    if 'carp_640m_masked_zs' in df.columns:
        # Prefer the masked-marginal CARP score when it is available.
        df['carp_640m_zs'] = df['carp_640m_masked_zs']
        df = df.drop(columns=['carp_640m_masked_zs'])
    rho1 = spearmanr(df['target'], df['esm2_650M_scores']).statistic
    rho2 = spearmanr(df['target'], df['dayhoff']).statistic
    rho3 = spearmanr(df['target'], df['carp_640m_zs']).statistic

    print(landscape, rho1, rho2, rho3)
    df.to_csv(os.path.join(flip_path, "zs", landscape + "_zs.csv"), index=False)
+
# Make clean prediction csvs for each test set
prediction_path = os.path.join(flip_path, "all_predictions")
for landscape in landscapes:
    os.makedirs(os.path.join(prediction_path, landscape), exist_ok=True)
    split_dir = os.path.join(pruned_path, landscape, "splits")
    for split_csv in [c for c in os.listdir(split_dir) if ".csv" in c]:
        df = pd.read_csv(os.path.join(split_dir, split_csv), index_col=0)
        df = df[df['set'] == 'test']
        drop_these = ("set", "validation", "dayhoff_min", "dayhoff_3bur90_min", "Unnamed: 0")
        df = df[[c for c in df.columns if c not in drop_these]]
        if landscape != "PDZ3":
            df.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"},
                      inplace=True)
            # Dayhoff score = mean of both models in both reading directions.
            df['dayhoff'] = (df['dayhoff_3bgrhmc_fwd'] + df['dayhoff_3bgrhmc_bwd']
                             + df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd']) / 4
        else:
            # PDZ3 is scored per chain (suffixes _1/_2); average the two per-chain means.
            for chain in ('1', '2'):
                df.rename(columns={"dayhoff_fwd_" + chain: "dayhoff_3bgrhmc_fwd_" + chain,
                                   "dayhoff_bwd_" + chain: "dayhoff_3bgrhmc_bwd_" + chain},
                          inplace=True)
            chain_means = [(df['dayhoff_3bgrhmc_fwd_' + chain] + df['dayhoff_3bgrhmc_bwd_' + chain]
                            + df['dayhoff_3bur90_fwd_' + chain] + df['dayhoff_3bur90_bwd_' + chain]) / 4
                           for chain in ('1', '2')]
            df['dayhoff'] = (chain_means[0] + chain_means[1]) / 2
        df['esm2_650M_scores'] = (df['esm2_650M_scores1'] + df['esm2_650M_scores2'].fillna(0)) / 2
        df['carp_640m_zs'] = (df['carp_640m_zs_1'] + df['carp_640m_zs_2']) / 2
        if 'carp_640m_masked_zs' in df.columns:
            df['carp_640m_zs'] = df['carp_640m_masked_zs']
            df = df.drop(columns=['carp_640m_masked_zs'])
        plm_path = os.path.join(flip_path, "plm_predictions", landscape + "_" + split_csv[:-4] + "_predictions.csv")
        if not os.path.isfile(plm_path):
            continue
        plm_df = pd.read_csv(plm_path)
        if landscape == "PDZ3":
            # Presumably the PDZ3 PLM file has no usable sequence column;
            # align by row order -- verify against how it was written.
            plm_df['sequence'] = df['sequence'].values
        merged = df.merge(plm_df, on='sequence')
        ridge_df = pd.read_csv(os.path.join(flip_path, "ridge", landscape, split_csv))
        ridge_df.rename(columns={'prediction': 'Ridge', "target": "linear_target"}, inplace=True)
        merged = merged.merge(ridge_df, on='sequence')
        zs_ridge_df = pd.read_csv(os.path.join(flip_path, "ridge_zs", landscape, split_csv))
        zs_ridge_df = zs_ridge_df[['sequence', 'prediction']]
        zs_ridge_df.rename(columns={'prediction': 'zsRidge'}, inplace=True)
        merged = merged.merge(zs_ridge_df, on='sequence')
        merged.to_csv(os.path.join(prediction_path, landscape, split_csv), index=False)
+
# Make csv for metrics with all replicates
# Maps a raw prediction-column name -> (display name, pretrained?, seed).
# Zero-shot and linear baselines are deterministic, so they get seed 0.
nice_models = {
    'Ridge': ("Ridge", False, 0),
    "zsRidge": ("zsRidge", False, 0),
    "dayhoff": ("Dayhoff", False, 0),
    "carp_640m_zs": ("CARP-640M zero shot", False, 0),
    "esm2_650M_scores": ("ESM2-650M", False, 0),
}
# Supervised models: 5 seeds x {carp, esmc} x {pretrained, naive}.
_display = {'carp': "CARP-640M", 'esmc': "ESMC-300M"}
nice_models.update({
    '_'.join([model, p, str(seed)]): (_display[model], p == "pretrained", seed)
    for seed in range(5)
    for model in ['carp', 'esmc']
    for p in ['pretrained', 'naive']
})
+
# One row per (dataset, split, model, seed) with test-set metrics.
all_results = pd.DataFrame(columns=['dataset', "split", "model", "pretrained", "Spearman", "MSE", "NDCG", "seed"])
for landscape in landscapes:
    for pred_csv in os.listdir(os.path.join(prediction_path, landscape)):
        pred_df = pd.read_csv(os.path.join(prediction_path, landscape, pred_csv))
        split = pred_csv[:-4]
        for key, nice in nice_models.items():
            if key not in pred_df.columns:
                continue
            row = len(all_results)
            all_results.loc[row, ['dataset', 'split']] = (landscape, split)
            all_results.loc[row, ['model', 'pretrained', 'seed']] = nice
            all_results.loc[row, ["Spearman"]] = spearmanr(pred_df['scaled_target'], pred_df[key]).correlation
            # MSE only for models trained to regress the scaled target.
            if "carp" in key or "esmc" in key or "Ridge" in key:
                all_results.loc[row, ["MSE"]] = ((pred_df['scaled_target'] - pred_df[key]) ** 2).mean()
            # ndcg_score needs non-negative relevances, so shift targets to start at zero.
            pos_targets = (pred_df['target'] - pred_df['target'].min()).values[None, :]
            all_results.loc[row, ["NDCG"]] = ndcg_score(pos_targets, pred_df[key].values[None, :])
all_results.to_csv(os.path.join(flip_path, "all_metrics.csv"), index=False)
+
# Make csv with replicates aggregated: mean/std over seeds per (dataset, split, model, pretrained).
grouped = all_results.groupby(['dataset', 'split', 'model', 'pretrained'])
# Use string aggregation names: passing np.mean/np.std to .agg is deprecated in
# pandas >= 2.1 (and np.std would eventually switch to ddof=0). The string forms
# use pandas' own mean/std (ddof=1), matching the previous behavior.
agged = grouped.agg(
    spearman_mean=('Spearman', 'mean'),
    spearman_std=('Spearman', 'std'),
    mse_mean=('MSE', 'mean'),
    mse_std=('MSE', 'std'),
    ndcg_mean=('NDCG', 'mean'),
    ndcg_std=('NDCG', 'std'),
)


compiled = agged.reset_index()
print(compiled)
compiled.to_csv(os.path.join(flip_path, 'all_metrics_aggregated.csv'), index=True)
+
# Display names for tables/figures. Zero-shot and linear baselines are keyed
# by their raw model name; supervised models are keyed by (name, pretrained).
model_dict = {
    "Ridge": "Ridge (one-hot)",
    "zsRidge": 'Ridge (one-hot + likelihoods)',
    "Dayhoff": "Dayhoff likelihood",
    "ESM2-650M": "ESM2-650M likelihood",
    "CARP-640M zero shot": "CARP-640M likelihood",
}
for _supervised in ("CARP-640M", "ESMC-300M"):
    model_dict[(_supervised, True)] = _supervised + " supervised"
    model_dict[(_supervised, False)] = _supervised + " naive supervised"

# Canonical ordering of models in tables and figures.
model_order = [
    "Ridge (one-hot)",
    'Ridge (one-hot + likelihoods)',
    "Dayhoff likelihood",
    "ESM2-650M likelihood",
    "CARP-640M likelihood",
    "CARP-640M supervised",
    "CARP-640M naive supervised",
    "ESMC-300M supervised",
    "ESMC-300M naive supervised",
]
+
# Replace raw model names with display names; supervised models are keyed
# by (name, pretrained), everything else by name alone.
for i, row in compiled.iterrows():
    if row['model'] in model_dict:
        compiled.loc[i, 'model'] = model_dict[row['model']]
    else:
        compiled.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])]

# Every (dataset, split) pair present in the aggregated results.
all_tasks = set()
for i, row in compiled.iterrows():
    all_tasks.add((row['dataset'], row['split']))

# Print one LaTeX table body per non-random task.
# Templates are raw strings: '\p' in a plain string literal is an invalid
# escape sequence (SyntaxWarning on Python >= 3.12, error in the future).
for task in all_tasks:
    if 'random' in task:
        continue
    print(task)
    comp = compiled[(compiled['dataset'] == task[0]) & (compiled['split'] == task[1])]
    for model in model_order:
        c = comp[comp['model'] == model]
        for i, row in c.iterrows():
            if 'supervised' in model:
                # Supervised models have seed replicates, so report mean +/- std.
                print(row['model'] + r'& $%.3f \pm %.3f$ & $%.3f \pm %.3f$\\' % (row['spearman_mean'], row['spearman_std'], row['ndcg_mean'], row['ndcg_std']))
            else:
                print(row['model'] + r'& $%.3f$ & $%.3f$\\' % (row['spearman_mean'], row['ndcg_mean']))
# Wide view: one row per (dataset, split); metric columns keyed by (model, pretrained).
# NOTE(review): this section indexes by the raw model names from nice_models, so it
# assumes `compiled` has NOT yet had its 'model' column rewritten to display names.
by_task = compiled.pivot(index=['dataset', 'split'], columns=['model', 'pretrained'])
by_task = by_task.reset_index()
by_task.head()
# Number of tasks where pretraining beats a naively initialized run.
(by_task['spearman_mean']['CARP-640M'][True] > by_task['spearman_mean']['CARP-640M'][False]).sum()
# Fixed: the supervised ESMC model in this pipeline is ESMC-300M (see nice_models);
# 'ESMC-35M' is not a column in this frame and raised a KeyError.
(by_task['spearman_mean']['ESMC-300M'][True] > by_task['spearman_mean']['ESMC-300M'][False]).sum()
# Number of tasks where each baseline beats the zero-shot-augmented ridge.
(by_task['spearman_mean']['Ridge'][False] > by_task['spearman_mean']['zsRidge'][False]).sum()
(by_task['spearman_mean']['Dayhoff'][False] > by_task['spearman_mean']['zsRidge'][False]).sum()
(by_task['spearman_mean']['ESM2-650M'][False] > by_task['spearman_mean']['zsRidge'][False]).sum()

# Which model has the best mean Spearman on each task?
model_list = by_task.columns[2:11]
best_spearmans = []
for i, row in by_task.iterrows():
    print(row['dataset'].values[0], row['split'].values[0], *model_list[row['spearman_mean'].argmax()][1:])
    best_spearmans.append(model_list[row['spearman_mean'].argmax()][1:])
Counter(best_spearmans)
+#
+# best_mses = []
+# for i, row in by_task.iterrows():
+# print(row['dataset'].values[0], row['split'].values[0], *model_list[row['mse_mean'].argmin()][1:])
+# best_mses.append(model_list[row['mse_mean'].argmin()][1:])
+# Counter(best_mses)
+#
+#
+# by_task[['dataset', 'split', 'spearman_mean']]
+# by_task['spearman_mean'].mean()
+# print("dataset,split,p>n,p>r")
+# for dataset in set(compiled.dataset):
+# for split in set(compiled[compiled['dataset'] == dataset].split):
+# r = compiled[(compiled['dataset'] == dataset) & (compiled['split'] == split) & (compiled['model'] == 'Ridge')]['spearman_mean'].values[0]
+# p = compiled[(compiled['dataset'] == dataset) & (compiled['split'] == split) & (compiled['model'] == 'CARP-640M') & (compiled['pretrained'])]['spearman_mean'].values[0]
+# n = compiled[(compiled['dataset'] == dataset) & (compiled['split'] == split) & (compiled['model'] == 'CARP-640M') & (~compiled['pretrained'])]['spearman_mean']
+# if len(n) > 0:
+# n = n.values[0]
+# else:
+# n = -np.inf
+# print(dataset + "," + split + "," + str(p > n) + "," + str(p > r))
\ No newline at end of file
diff --git a/baselines/carp.py b/baselines/carp.py
new file mode 100644
index 0000000..e6301e5
--- /dev/null
+++ b/baselines/carp.py
@@ -0,0 +1,248 @@
+import argparse
+import os
+import json
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import LambdaLR
+from torch.optim import Adam
+import numpy as np
+from scipy.stats import spearmanr
+from sklearn.metrics import ndcg_score
+
+from sequence_models.collaters import Seq2PropertyCollater
+from sequence_models.constants import PAD
+from sequence_models.structure import Attention1d
+from sequence_models.utils import warmup
+from sequence_models.flip_utils import load_flip_data
+from sequence_models.pretrained import load_model_and_alphabet
+
+
+
+
+
class Model(nn.Module):
    """Attention-pooling regression head: pools per-residue embeddings into a
    single vector, then applies a two-layer MLP to predict one scalar."""

    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.attention = Attention1d(d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, 1)

    def forward(self, e, input_mask=None):
        """Pool embeddings `e` (masked by `input_mask`) and return per-sequence scalars."""
        pooled = self.attention(e, input_mask=input_mask)
        h = self.hidden(self.activation(pooled))
        h = self.dropout(self.activation(h))
        return self.linear(h)
+
+
def main():
    """CLI entry point: parse arguments and launch supervised CARP training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('task', type=str)
    # nargs='?' makes this positional genuinely optional: argparse ignores a
    # default on a plain positional and would require it on the command line.
    parser.add_argument('weights', type=str, nargs='?', default='pretrained')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)

    args = parser.parse_args()
    train(args)
+
+
def train(args):
    """Fine-tune CARP-640M's embedder plus an attention-pooling decoder on one
    FLIP task, early-stopping on validation loss, then evaluate the best
    checkpoint on the test set and dump metrics to JSON.

    args fields used: data_fpath, out_fpath, task ('<dataset>_<split>'),
    weights ('pretrained' keeps published weights; anything else re-initializes),
    seed, lr.
    """
    _ = torch.manual_seed(args.seed)
    torch.cuda.set_device(0)
    device = torch.device('cuda:0')
    lr = args.lr

    np.random.seed(args.seed)
    carp, collater = load_model_and_alphabet('carp_640M')
    if args.weights != 'pretrained':
        # Naive baseline: keep the architecture but re-initialize every module
        # that knows how to reset itself.
        for p in carp.modules():
            try:
                p.reset_parameters()
            except AttributeError:
                continue
    embedder = carp.model.embedder.to(device)
    d_model = carp.model.embedder.up_embedder.conv.out_channels
    decoder = Model(d_model, dropout=0)
    decoder = decoder.to(device)
    # Both the embedder and the decoder head are trained end-to-end.
    optimizer = Adam(list(embedder.parameters()) + list(decoder.parameters()), lr=lr)
    model = nn.ModuleDict({'embedder': embedder, 'decoder': decoder})
    alphabet = collater.tokenizer.alphabet

    ## Grab data
    batch_size = 16
    loss_func = nn.MSELoss()
    # Task names are '<dataset>_<split>'; the AMY_BACSU dataset name itself
    # contains an underscore, so it needs special handling.
    if "AMY_BACSU" in args.task:
        flip_dataset = '_'.join(args.task.split('_')[:2])
        flip_split = '_'.join(args.task.split('_')[2:])
    else:
        flip_dataset = args.task.split('_')[0]
        flip_split = '_'.join(args.task.split('_')[1:])
    ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, flip_dataset, flip_split, max_len=2048, scale=True)
    collate_fn = Seq2PropertyCollater(alphabet, return_mask=True)
    num_workers = 4
    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collate_fn,
                          num_workers=num_workers, shuffle=True)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    print('%d Train samples %d valid samples %d test samples' %(len(ds_train), len(ds_valid), len(ds_test)))
    checkpoint_stem = 'carp_%s_%s_%d' %(args.task, args.weights, args.seed)
    def step(model, batch, train=True, return_values=False):
        """Run one batch; backprop + optimizer/scheduler step when train=True.

        Returns (loss, n_examples) or, with return_values=True, additionally the
        detached predictions, inputs, targets, and a dummy mask.
        """
        src, tgt, input_mask = batch
        src = src.to(device)
        tgt = tgt.to(device)
        # Recomputed from PAD tokens; the collater-provided input_mask is discarded.
        input_mask = (src != alphabet.index(PAD)).float().unsqueeze(-1)
        e = model['embedder'](src, input_mask=input_mask)
        outputs = model['decoder'](e, input_mask=input_mask)

        loss = loss_func(outputs, tgt)
        locations = len(tgt)
        mask = torch.ones(1) # dummy
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Warmup schedule advances per optimizer step, not per epoch.
            scheduler.step()
        if return_values:
            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu(), mask.detach().cpu()
        else:
            return loss.item(), locations


    def epoch(model, current_step=0):
        """One pass over the training loader; returns (last batch index, running loss).

        NOTE(review): the bare `if train:` checks below resolve to the
        module-level `train` FUNCTION, which is always truthy -- the `else`
        branches are dead code. A boolean parameter was presumably intended.
        """
        model = model.train()
        loader = dl_train
        t = 'Training:'
        losses = []
        ns = []
        n_seen = 0
        if train:
            n_total = len(ds_train)
        else:
            n_total = len(ds_valid)
        for i, batch in enumerate(loader):
            new_loss, new_n = step(model, batch, True)
            # Accumulate example-weighted losses for a running mean.
            losses.append(new_loss * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            if total_n == 0:
                rloss = 0
            else:
                rloss = sum(losses) / total_n
            if train:
                nsteps = current_step + i + 1
            else:
                nsteps = i
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),
                  end='')
        if not train:
            return rloss
        return i, rloss

    def test_epoch(model, dl):
        """Evaluate on a loader; returns {'spearman', 'loss', 'ndcg'}."""
        model = model.eval()
        with torch.no_grad():
            losses = []
            ns = []
            n_seen = 0
            pred = []
            tgt = []
            masks = []
            for i, batch in enumerate(dl):
                new_loss, new_n, p, s, t, m = step(model, batch, False, return_values=True)
                losses.append(new_loss * new_n)
                pred.append(p)
                tgt.append(t)
                masks.append(m)
                ns.append(new_n)
                n_seen += len(batch[0])
            total_n = sum(ns)

            test_loss = sum(losses) / total_n
            pred = torch.cat(pred)
            tgt = torch.cat(tgt)
            pred = pred.numpy()
            tgt = tgt.numpy()
            spearman = spearmanr(pred, tgt).correlation
            # ndcg_score requires non-negative relevances; shift only when needed.
            if (tgt < 0).any():
                pos_tgt = tgt - tgt.min()
            else:
                pos_tgt = tgt
            ndcg = ndcg_score(pos_tgt.T, pred.T)
            print('\tloss: %f' %test_loss, end='\t')
            print('spearman: %f' %(spearman), end='\t')
            print('ndcg: %f' %(ndcg), end='\t')
            results = {
                'spearman': spearman,
                'loss': test_loss,
                'ndcg': ndcg
            }
            return results

    epochs = 500
    n_warmup = 1000
    total_steps = 0
    best_valid_metric = -np.inf
    best_valid_loss = np.inf
    patience = 10
    scheduler = LambdaLR(optimizer, warmup(n_warmup))
    waiting = 0
    os.makedirs(args.out_fpath, exist_ok=True)
    for e in range(epochs):
        ts, train_loss = epoch(model, current_step=total_steps)
        total_steps += ts
        nsteps = total_steps
        results = test_epoch(model, dl_valid)
        vloss = results['loss']
        vmetric = results['spearman']
        # Early stopping: the counter resets whenever validation loss improves,
        # validation Spearman improves, or validation loss is below train loss.
        waiting += 1
        if vloss < best_valid_loss:
            best_valid_loss = vloss
            waiting = 0
            # Checkpoint only on best validation loss.
            torch.save({
                'step': nsteps,
                'epoch': e + 1,
                'model_state_dict': model.state_dict(),
                'val_spearman': vmetric,
                'val_ndcg': results['ndcg'],
                'val_loss': vloss,
                'train_loss': train_loss,
            }, args.out_fpath + checkpoint_stem + '_best.pt')
        if vmetric > best_valid_metric:
            best_valid_metric = vmetric
            waiting = 0
        if vloss < train_loss:
            waiting = 0
        print("waiting: %d" % waiting)
        if waiting == patience:
            break
    # TODO: checkpoint race condition
    # NOTE(review): out_fpath is a required positional, so this guard is always true.
    if args.out_fpath is not None:
        sd = torch.load(args.out_fpath + checkpoint_stem + '_best.pt', weights_only=False)
        model.load_state_dict(sd['model_state_dict'])
        results = test_epoch(model, dl_test)
        results['batch_size'] = batch_size
        results['lr'] = lr
        results['epoch'] = sd['epoch']
        results['step'] = sd['step']
        results['train_loss'] = sd['train_loss']
        results['val_spearman'] = sd['val_spearman']
        results['val_loss'] = sd['val_loss']
        results['val_ndcg'] = sd['val_ndcg']
        results['dataset'] = flip_dataset
        results['split'] = flip_split
        results['task'] = args.task
        results['seed'] = args.seed
        with open(args.out_fpath + checkpoint_stem + '.json', 'w') as f:
            json.dump(results, f)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/baselines/carp_masked_zeroshot.py b/baselines/carp_masked_zeroshot.py
new file mode 100644
index 0000000..5fb259a
--- /dev/null
+++ b/baselines/carp_masked_zeroshot.py
@@ -0,0 +1,126 @@
+import argparse
+import os
+from typing import Tuple
+from tqdm import tqdm
+import json
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from scipy.stats import spearmanr
+
+from sequence_models.pretrained import load_model_and_alphabet
+from sequence_models.constants import MASK
+
+
+
+
def train(args: argparse.Namespace) -> None:
    """Masked-marginal zero-shot scoring of one landscape with CARP-640M.

    For each variant, every mutated position is replaced by MASK, the model's
    log-probabilities at those positions are read out, and the score is the sum
    of (mutant log-prob - wild-type log-prob) over the mutated positions.
    Because all mutated positions are masked, the model input depends only on
    the position set, so results are cached keyed by positions.
    """

    # get the config, tokenizer, and model
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    carp, collator = load_model_and_alphabet('carp_640M')
    # Move only model to GPU
    model = carp.to(DEVICE)
    model = model.eval()
    # amino-acid character -> token index
    a_to_t = collator.tokenizer.a_to_t
    seq_to_result = {}
    cache_file = os.path.join(output_dir, "carp640m_masked_scores.pt")
    if os.path.exists(cache_file):
        seq_to_result = torch.load(cache_file, weights_only=False)


    # Get files
    ## Grab data
    batch_size = 1
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)


    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        likelihoods = np.empty(len(df))
        # get the WT
        # wt_seq = df[df['variant_info'].isna()]['sequence'].values[0]
        for row in tqdm(df.itertuples(), total=len(df)):
            sequence = row.sequence
            variant_info = row.variant_info
            mut_sequence = sequence[:]
            input_sequence = sequence[:]
            if isinstance(variant_info, str):
                # variant_info entries look like 'A123G', comma-separated for
                # multi-mutants; positions in the strings are 1-based.
                if "," in variant_info:
                    variant_info = variant_info.split(",")
                else:
                    variant_info = [variant_info]
                positions = tuple(int(v[1:-1]) - 1 for v in variant_info)
                old = tuple(v[0] for v in variant_info)
                new = tuple(v[-1] for v in variant_info)
            else:
                # Non-string variant_info (NaN) marks the wild type: score 0.
                likelihoods[row.Index] = 0.0
                continue
            for i, pos in enumerate(positions):
                mut_sequence = mut_sequence[:pos] + new[i] + mut_sequence[pos + 1:]
                input_sequence = input_sequence[:pos] + MASK + input_sequence[pos + 1:]
            # Sanity check: re-applying the mutations to the stored sequence is a
            # no-op, i.e. variant_info is consistent with the sequence column.
            assert mut_sequence == sequence
            positions_key = ','.join(str(pos) for pos in positions)
            if positions_key not in seq_to_result:
                src = collator([[input_sequence]])[0].to(DEVICE)
                with torch.no_grad():
                    logits = model(src, repr_layers=[], logits=True)['logits'] # 1, ell, 30
                    logits = logits[0, positions] # n_positions, 30
                    log_probs = F.log_softmax(logits, dim=-1).cpu().detach().numpy()
                seq_to_result[positions_key] = log_probs
            log_probs = seq_to_result[positions_key]
            # Masked-marginal score: sum over positions of log p(mut) - log p(wt).
            score = 0
            for i, lp in enumerate(log_probs):
                wt_lp = lp[a_to_t[old[i]]]
                mut_lp = lp[a_to_t[new[i]]]
                score += mut_lp - wt_lp
            likelihoods[row.Index] = score




        torch.save(seq_to_result, cache_file)
        if args.landscape == "PDZ3":
            # NOTE(review): `likelihoods` is a flat array here, so likelihoods[0]
            # and likelihoods[1] are single scalars broadcast over whole columns.
            # This looks wrong (cf. carp_zeroshot.py, where likelihoods is a list
            # of two per-chain arrays) -- confirm PDZ3 is really run through this
            # script.
            df['carp_640m_zs_1'] = likelihoods[0]
            df['carp_640m_zs_2'] = likelihoods[1]
        else:
            df['carp_640m_masked_zs'] = likelihoods
        df.to_csv(os.path.join(output_dir, split_csv), index=False)



        if args.landscape != "PDZ3":
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_masked_zs']).statistic)
        # else:
        #     print(split_csv,
        #           spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_zs_1'] + df[df['set'] == 'test']['carp_640m_zs_2']).statistic)
+
+
+
+
def main():
    """Parse command-line arguments and run masked zero-shot scoring."""
    parser = argparse.ArgumentParser()
    for positional in ('data_fpath', 'out_fpath', 'landscape'):
        parser.add_argument(positional, type=str)
    parser.add_argument('--gpu', type=int, default=1)
    train(parser.parse_args())
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+
+
diff --git a/baselines/carp_zeroshot.py b/baselines/carp_zeroshot.py
new file mode 100644
index 0000000..e8214b6
--- /dev/null
+++ b/baselines/carp_zeroshot.py
@@ -0,0 +1,108 @@
+import argparse
+import os
+from typing import Tuple
+from tqdm import tqdm
+import json
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from scipy.stats import spearmanr
+
+from sequence_models.collaters import Seq2PropertyCollater
+from sequence_models.constants import PAD, START, STOP
+from sequence_models.structure import Attention1d
+from sequence_models.utils import warmup
+from sequence_models.flip_utils import load_flip_data
+from sequence_models.pretrained import load_model_and_alphabet
+
+
+
+
def train(args: argparse.Namespace) -> None:
    """Score every split CSV of one landscape with CARP-640M mean per-token
    log-likelihoods and write the augmented CSVs back to the output directory.

    Scores are cached (sequence -> score) in a JSON file in the output
    directory so reruns and overlapping splits skip the forward pass.
    Multi-chain entries ("A:B") are scored per chain; PDZ3 keeps both chain
    scores as separate columns.
    """
    # get the config, tokenizer, and model
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    carp, collator = load_model_and_alphabet('carp_640M')
    # Move only model to GPU
    model = carp.to(DEVICE)
    model = model.eval()
    seq_to_result = {}
    cache_file = os.path.join(output_dir, "carp640m_scores.json")
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            seq_to_result = json.load(f)

    # Grab data
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)


    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        # One slot per chain (at most two chains in these landscapes).
        likelihoods = [np.empty(len(df)), np.empty(len(df))]
        for row in tqdm(df.itertuples(), total=len(df)):
            sequence = row.sequence
            if ":" in sequence:
                sequences = sequence.split(":")
            else:
                sequences = [sequence]
            for j, sequence in enumerate(sequences):
                if sequence not in seq_to_result:
                    if len(sequence) == 0:
                        # An empty chain contributes a log-likelihood of 0.
                        seq_to_result[sequence] = 0
                    else:
                        src = collator([[sequence]])[0].to(DEVICE)
                        with torch.no_grad():
                            logits = model(src, repr_layers=[], logits=True)['logits']
                            out = F.cross_entropy(logits[0], src.flatten())
                            seq_to_result[sequence] = -out.detach().cpu().item()
                # BUG FIX: always record the (possibly cached) score. The previous
                # version `continue`d right after caching a first-seen empty
                # sequence, leaving uninitialized np.empty garbage in that slot.
                likelihoods[j][row.Index] = seq_to_result[sequence]
        with open(cache_file, "w") as f:
            json.dump(seq_to_result, f)
        if args.landscape == "PDZ3":
            df['carp_640m_zs_1'] = likelihoods[0]
            df['carp_640m_zs_2'] = likelihoods[1]
        else:
            df['carp_640m_zs'] = likelihoods[0]
        df.to_csv(os.path.join(output_dir, split_csv), index=False)



        if args.landscape != "PDZ3":
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_zs']).statistic)
        else:
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_zs_1'] + df[df['set'] == 'test']['carp_640m_zs_2']).statistic)
+
+
+
+
def main():
    """Parse command-line arguments and run zero-shot CARP scoring."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('landscape', type=str)
    parser.add_argument('--gpu', type=int, default=1)
    train(parser.parse_args())
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+
+
diff --git a/baselines/dayhoff_zeroshot.py b/baselines/dayhoff_zeroshot.py
new file mode 100644
index 0000000..c822bb9
--- /dev/null
+++ b/baselines/dayhoff_zeroshot.py
@@ -0,0 +1,148 @@
+import argparse
+import os
+from typing import Tuple
+from tqdm import tqdm
+import json
+
+import numpy as np
+import pandas as pd
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+from scipy.stats import spearmanr
+from sklearn.metrics import ndcg_score
+from sequence_models.constants import START, STOP
+
def is_amlt() -> bool:
    """True when running under AMLT (the AMLT_OUTPUT_DIR env var is set)."""
    return "AMLT_OUTPUT_DIR" in os.environ
+
+
class SimpleCollator():
    """Wraps a HF tokenizer to produce forward and reversed token ids for one sequence."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, seq: "str") -> Tuple[torch.Tensor]:
        # Score each sequence in both reading directions:
        # START...STOP forward, STOP...START over the reversed residues.
        directions = [START + seq + STOP, STOP + seq[::-1] + START]
        tokenized = self.tokenizer(directions, return_tensors="pt", return_token_type_ids=False)
        return (tokenized['input_ids'],)
+
+
+
def train(args: argparse.Namespace) -> None:
    """Zero-shot scoring of one landscape with two Dayhoff causal LMs.

    Each sequence (per chain for multi-chain entries) is scored in the forward
    and reversed reading directions; mean per-token losses are cached to JSON
    per model. Scores written to the CSVs are negated losses, so larger is
    more likely.
    """

    # get the config, tokenizer, and model
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    os.makedirs(output_dir, exist_ok=True)
    # [HF model id, cache filename, column-name stem]
    model_names = [
        ['microsoft/Dayhoff-3b-UR90', "3b-UR90-seq_to_result.json", "dayhoff_3bur90"],
        ['microsoft/Dayhoff-3b-GR-HM-c', "3b-gr-hm-c-seq_to_result.json", "dayhoff"]
    ]
    models = []
    for m in model_names:
        model = AutoModelForCausalLM.from_pretrained(m[0])
        tokenizer = AutoTokenizer.from_pretrained(m[0], trust_remote_code=True)

        # model = AutoModelForCausalLM.from_pretrained('microsoft/Dayhoff-3b-GR-HM-c')
        # tokenizer = AutoTokenizer.from_pretrained('microsoft/Dayhoff-3b-GR-HM-c', trust_remote_code=True)
        collator = SimpleCollator(tokenizer)

        # Move only model to GPU
        model = model.to(DEVICE)
        model = model.to(torch.bfloat16)
        model = model.eval()
        seq_to_result = {}
        cache_file = os.path.join(output_dir, m[1])
        if os.path.exists(cache_file):
            with open(cache_file) as f:
                seq_to_result = json.load(f)
        models.append([model, tokenizer, collator, cache_file, seq_to_result, m[2]])

    # Get files
    ## Grab data
    batch_size = 1
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)


    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        for model in models:
            # NOTE: deliberately shadows the loop variable with its unpacked parts.
            model, tokenizer, collator, cache_file, seq_to_result, model_stem = model
            # One loss array per chain; PDZ3 sequences have two chains.
            if args.landscape == "PDZ3":
                fwd_lls = [np.empty(len(df)), np.empty(len(df))]
                bwd_lls = [np.empty(len(df)), np.empty(len(df))]
            else:
                fwd_lls = [np.empty(len(df))]
                bwd_lls = [np.empty(len(df))]
            for row in tqdm(df.itertuples(), total=len(df)):
                sequence = row.sequence
                if ":" in sequence:
                    sequences = sequence.split(":")
                else:
                    sequences = [sequence]

                for j, sequence in enumerate(sequences):
                    if sequence not in seq_to_result:
                        # Row 0 of the collator output is the forward direction,
                        # row 1 the reversed sequence; out.loss is the mean
                        # per-token cross-entropy.
                        tokenized = collator(sequence)[0]
                        tokenized = tokenized.to(DEVICE)
                        with torch.no_grad():
                            out = model(input_ids=tokenized[:1], labels=tokenized[:1])
                        seq_to_result[sequence] = {"fwd": out.loss.detach().cpu().item()}
                        with torch.no_grad():
                            out = model(input_ids=tokenized[1:], labels=tokenized[1:])
                        seq_to_result[sequence]["bwd"] = out.loss.detach().cpu().item()

                    fwd_lls[j][row.Index] = seq_to_result[sequence]["fwd"]
                    bwd_lls[j][row.Index] = seq_to_result[sequence]["bwd"]
            with open(cache_file, "w") as f:
                json.dump(seq_to_result, f)
            if args.landscape == "PDZ3":
                df[model_stem + '_fwd_1'] = -fwd_lls[0]
                df[model_stem + '_bwd_1'] = -bwd_lls[0]
                df[model_stem + '_fwd_2'] = -fwd_lls[1]
                df[model_stem + '_bwd_2'] = -bwd_lls[1]
            else:
                df[model_stem + '_fwd'] = -fwd_lls[0]
                df[model_stem + '_bwd'] = -bwd_lls[0]
            # '_min' = min over directions of the negated loss (the less likely
            # direction). NOTE(review): fwd_lls/bwd_lls are LISTS of arrays, so
            # np.maximum(...) has shape (n_chains, n_rows); assigning a 2-D array
            # to a single column looks wrong, especially for PDZ3 (2 chains) --
            # confirm which pandas version this ran under / whether it squeezes.
            df[model_stem + '_min'] = -np.maximum(fwd_lls, bwd_lls)
            df.to_csv(os.path.join(output_dir, split_csv), index=False)



        if args.landscape != "PDZ3":
            if "esm2_650M_scores" in df.columns:
                print(split_csv,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_3bur90_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['esm2_650M_scores']).statistic)
            else:
                print(split_csv,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_3bur90_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_min']).statistic)
+
+
def main():
    """Parse CLI arguments and score all splits for one landscape."""
    parser = argparse.ArgumentParser()
    for positional in ('data_fpath', 'out_fpath', 'landscape'):
        parser.add_argument(positional, type=str)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument("--no_fa2", action="store_true")
    args = parser.parse_args()
    train(args)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+
+
diff --git a/baselines/dayhoff_zeroshot.py.zip b/baselines/dayhoff_zeroshot.py.zip
new file mode 100644
index 0000000..e8ec5bf
Binary files /dev/null and b/baselines/dayhoff_zeroshot.py.zip differ
diff --git a/baselines/esm2_zeroshot.py b/baselines/esm2_zeroshot.py
new file mode 100644
index 0000000..4a3dd59
--- /dev/null
+++ b/baselines/esm2_zeroshot.py
@@ -0,0 +1,84 @@
+import argparse
+import os
+from tqdm import tqdm
+import json
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
+import esm
+
+def is_amlt() -> bool:  # True when running under AMLT (Azure ML), detected via its output-dir env var. NOTE(review): unused in this file's visible code.
+    return os.environ.get("AMLT_OUTPUT_DIR", None) is not None
+
+
+
+
+def train(args: argparse.Namespace) -> None:  # Zero-shot scoring: mean per-residue ESM-2 log-likelihood per sequence, appended to each split CSV.
+
+    # Select GPU and mirror the input's <landscape>/splits layout in the output dir.
+    torch.cuda.set_device(args.gpu)
+    DEVICE = torch.device('cuda:%d' % args.gpu)
+    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
+    os.makedirs(output_dir, exist_ok=True)
+
+    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # 650M-parameter ESM-2
+    batch_converter = alphabet.get_batch_converter()
+
+    model = model.to(DEVICE).eval()  # inference only; no gradients needed
+
+    seq_to_result = {}
+    cache_file = os.path.join(output_dir, "esm2_650m.json")  # sequence -> score cache shared across splits
+    if os.path.exists(cache_file):
+        with open(cache_file) as f:
+            seq_to_result = json.load(f)
+
+    # Collect the split CSVs for this landscape.
+    ## Grab data
+    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
+    split_csvs = os.listdir(landscape_path)
+    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
+    print(split_csvs)
+
+    for split_csv in split_csvs:
+        df = pd.read_csv(os.path.join(landscape_path, split_csv))
+        likelihoods = np.empty(len(df))
+        for row in tqdm(df.itertuples(), total=len(df)):
+            sequence = row.sequence
+            if sequence not in seq_to_result:
+                batch_labels, batch_strs, batch_tokens = batch_converter([("seq", sequence)])
+
+                with torch.no_grad():
+                    logits = model(batch_tokens.to(DEVICE))["logits"][0, 1:-1]  # drop BOS/EOS positions
+                    out = F.cross_entropy(logits, batch_tokens[0, 1:-1].to(DEVICE))  # mean per-residue NLL
+                seq_to_result[sequence] = -out.detach().cpu().item()  # negate so higher = more likely
+            likelihoods[row.Index] = seq_to_result[sequence]  # assumes a default RangeIndex — TODO confirm
+            with open(cache_file, "w") as f:  # NOTE(review): cache rewritten every row; hoisting outside the loop would be far cheaper
+                json.dump(seq_to_result, f)
+        df['esm2_650M_scores'] = likelihoods
+        df.to_csv(os.path.join(output_dir, split_csv), index=False)
+
+
+
+
+def main():  # CLI entry point for the ESM-2 zero-shot scorer.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('data_fpath', type=str)  # root dir containing <landscape>/splits/*.csv
+    parser.add_argument('out_fpath', type=str)  # root dir where scored CSVs are written
+    parser.add_argument('landscape', type=str)  # landscape (dataset) name
+    parser.add_argument('--gpu', type=int, default=1)  # CUDA device index used by train()
+    parser.add_argument("--no_fa2", action="store_true")  # NOTE(review): parsed but never read in the visible code
+
+    args = parser.parse_args()
+    train(args)
+
+
+if __name__ == "__main__":
+    main()
+
+
+
+
+
diff --git a/baselines/esmc.py b/baselines/esmc.py
new file mode 100644
index 0000000..7050998
--- /dev/null
+++ b/baselines/esmc.py
@@ -0,0 +1,262 @@
+import argparse
+import os
+import json
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import LambdaLR
+from torch.optim import Adam
+import torch.nn.functional as F
+import numpy as np
+from scipy.stats import spearmanr
+from sklearn.metrics import ndcg_score
+
+from esm.models.esmc import ESMC
+from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
+
+from sequence_models.structure import Attention1d
+from sequence_models.utils import warmup
+from sequence_models.flip_utils import load_flip_data
+
+
+
+
+
+class Model(nn.Module):  # Attention-pooled regression head: per-token embeddings (B, L, d_model) -> scalar (B, 1).
+
+    def __init__(self, d_model, dropout=0.0):
+        super().__init__()
+        self.d_model = d_model
+        self.attention = Attention1d(d_model)  # learned pooling over the sequence-length dimension
+        self.activation = nn.GELU()
+        self.dropout = nn.Dropout(dropout)
+        self.hidden = nn.Linear(d_model, d_model)
+        self.linear = nn.Linear(d_model, 1)  # scalar regression output
+
+    def forward(self, e, input_mask=None):
+        attended = self.attention(e, input_mask=input_mask)  # input_mask presumably excludes padding — see Attention1d
+        hidden = self.hidden(self.activation(attended))
+        return self.linear(self.dropout(self.activation(hidden)))
+
+
+def main():  # CLI entry point: fine-tune ESMC + regression head on one FLIP task.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('data_fpath', type=str)  # FLIP data root
+    parser.add_argument('out_fpath', type=str)  # checkpoint/results dir; NOTE(review): concatenated with '+' below, so it must end with a path separator
+    parser.add_argument('task', type=str)  # "<dataset>_<split>" task name
+    parser.add_argument('weights', type=str, default='pretrained')  # 'pretrained' or anything else for random re-init ("naive")
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--lr', default=1e-4, type=float)
+
+    args = parser.parse_args()
+    train(args)
+
+
+def train(args):  # Fine-tune ESMC (or a randomly re-initialized copy) with an attention-pooled head on one FLIP regression task.
+    _ = torch.manual_seed(args.seed)
+    torch.cuda.set_device(0)
+    device = torch.device('cuda:0')
+    lr = args.lr
+
+    np.random.seed(args.seed)
+    esm = ESMC.from_pretrained("esmc_300m")  # 300M-parameter ESMC backbone
+    if args.weights != 'pretrained':  # "naive" baseline: reset every module that supports it
+        for p in esm.modules():
+            try:
+                p.reset_parameters()
+            except AttributeError:
+                continue
+    esm = esm.to(device)
+    tokenizer = EsmSequenceTokenizer()
+
+    def collator(batch):  # (seq, label) pairs -> padded token batch, float targets, padding mask
+        data = tuple(zip(*batch))
+        seqs, labels = data
+        t = [torch.tensor(tokenizer.encode(s)) for s in seqs]
+        max_len = max([len(tt) for tt in t])
+        t = [F.pad(tt, (0, max_len - len(tt)), value=1) for tt in t]  # pad value 1 — presumably the tokenizer's pad id, TODO confirm
+        t = torch.stack(t)
+        y = torch.tensor(labels).unsqueeze(-1).float()
+        input_mask = t != 1  # True on real tokens
+        return t, y, input_mask
+
+
+    d_model = 960  # ESMC-300m embedding width
+    decoder = Model(d_model, dropout=0).to(device)
+    model = nn.ModuleDict({'embed': esm.embed, 'transformer': esm.transformer, 'decoder': decoder})
+    optimizer = Adam(model.parameters(), lr=lr)
+
+
+    ## Grab data
+    batch_size = 16
+    loss_func = nn.MSELoss()
+    if "AMY_BACSU" in args.task:  # dataset name itself contains an underscore, so split differently
+        flip_dataset = '_'.join(args.task.split('_')[:2])
+        flip_split = '_'.join(args.task.split('_')[2:])
+    else:
+        flip_dataset = args.task.split('_')[0]
+        flip_split = '_'.join(args.task.split('_')[1:])
+    ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, flip_dataset, flip_split, max_len=2048, scale=True)
+    num_workers = 4
+    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collator,
+                          num_workers=num_workers, shuffle=True)
+    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collator, num_workers=num_workers)
+    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collator, num_workers=num_workers)
+    print('%d Train samples %d valid samples %d test samples' %(len(ds_train), len(ds_valid), len(ds_test)))
+    checkpoint_stem = 'esmc_%s_%s_%d' %(args.task, args.weights, args.seed)
+    def step(model, batch, train=True, return_values=False):  # one forward (and optionally backward) pass on a batch
+        src, tgt, input_mask = batch
+        src = src.to(device)
+        tgt = tgt.to(device)
+        input_mask = input_mask.to(device)
+        e = model['embed'](src)
+        e = model['transformer'](e)[0].float()
+        outputs = model['decoder'](e, input_mask=input_mask)
+
+        loss = loss_func(outputs, tgt)
+        locations = len(tgt)  # batch size, used to weight the running loss
+        mask = torch.ones(1) # dummy
+        if train:
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            scheduler.step()  # `scheduler` is defined later in train() but exists before step() is first called
+        if return_values:
+            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu(), mask.detach().cpu()
+        else:
+            return loss.item(), locations
+
+
+    def epoch(model, current_step=0):  # one pass over dl_train; returns (steps taken, running mean loss)
+        model = model.train()
+        loader = dl_train
+        t = 'Training:'
+        losses = []
+        ns = []
+        n_seen = 0
+        if train:  # NOTE(review): `train` here is the enclosing *function* object, so this is always truthy — the else branches below are dead
+            n_total = len(ds_train)
+        else:
+            n_total = len(ds_valid)
+        for i, batch in enumerate(loader):
+            new_loss, new_n = step(model, batch, True)
+            losses.append(new_loss * new_n)  # re-weight the per-batch mean loss by batch size
+            ns.append(new_n)
+            n_seen += len(batch[0])
+            total_n = sum(ns)
+            if total_n == 0:
+                rloss = 0
+            else:
+                rloss = sum(losses) / total_n
+            if train:  # NOTE(review): always truthy — see above
+                nsteps = current_step + i + 1
+            else:
+                nsteps = i
+            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
+                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),  # `e` and `epochs` come from the training loop below via closure
+                  end='')
+        if not train:  # NOTE(review): dead branch — `train` is always truthy
+            return rloss
+        return i, rloss
+
+    def test_epoch(model, dl):  # evaluate on a loader; returns loss, Spearman rho, and NDCG
+        model = model.eval()
+        with torch.no_grad():
+            losses = []
+            ns = []
+            n_seen = 0
+            pred = []
+            tgt = []
+            masks = []
+            for i, batch in enumerate(dl):
+                new_loss, new_n, p, s, t, m = step(model, batch, False, return_values=True)
+                losses.append(new_loss * new_n)
+                pred.append(p)
+                tgt.append(t)
+                masks.append(m)
+                ns.append(new_n)
+                n_seen += len(batch[0])
+            total_n = sum(ns)
+
+            test_loss = sum(losses) / total_n
+            pred = torch.cat(pred)
+            tgt = torch.cat(tgt)
+            pred = pred.numpy()
+            tgt = tgt.numpy()
+            spearman = spearmanr(pred, tgt).correlation
+            if (tgt < 0).any():  # ndcg_score requires non-negative relevance scores
+                pos_tgt = tgt - tgt.min()
+            else:
+                pos_tgt = tgt
+            ndcg = ndcg_score(pos_tgt.T, pred.T)
+            print('\tloss: %f' %test_loss, end='\t')
+            print('spearman: %f' %(spearman), end='\t')
+            print('ndcg: %f' %(ndcg), end='\t')
+            results = {
+                'spearman': spearman,
+                'loss': test_loss,
+                'ndcg': ndcg
+            }
+        return results
+
+    epochs = 500
+    n_warmup = 1000  # LR warmup steps
+    total_steps = 0
+    best_valid_metric = -np.inf
+    best_valid_loss = np.inf
+    patience = 10  # epochs without improvement before early stop
+    scheduler = LambdaLR(optimizer, warmup(n_warmup))
+    waiting = 0
+    os.makedirs(args.out_fpath, exist_ok=True)
+    for e in range(epochs):
+        ts, train_loss = epoch(model, current_step=total_steps)
+        total_steps += ts
+        nsteps = total_steps
+        results = test_epoch(model, dl_valid)
+        vloss = results['loss']
+        vmetric = results['spearman']
+        waiting += 1
+        if vloss < best_valid_loss:  # checkpoint on best validation loss
+            best_valid_loss = vloss
+            waiting = 0
+            torch.save({
+                'step': nsteps,
+                'epoch': e + 1,
+                'model_state_dict': model.state_dict(),
+                'val_spearman': vmetric,
+                'val_ndcg': results['ndcg'],
+                'val_loss': vloss,
+                'train_loss': train_loss,
+            }, args.out_fpath + checkpoint_stem + '_best.pt')  # NOTE(review): string concat — out_fpath must end with a separator
+        if vmetric > best_valid_metric:  # also reset patience on a new best Spearman
+            best_valid_metric = vmetric
+            waiting = 0
+        if vloss < train_loss:  # not yet overfitting: keep going
+            waiting = 0
+        print("waiting: %d" % waiting)
+        if waiting == patience:
+            break
+    # TODO: checkpoint race condition
+    if args.out_fpath is not None:  # reload the best checkpoint and report test metrics
+        sd = torch.load(args.out_fpath + checkpoint_stem + '_best.pt', weights_only=False)
+        model.load_state_dict(sd['model_state_dict'])
+        results = test_epoch(model, dl_test)
+        results['batch_size'] = batch_size
+        results['lr'] = lr
+        results['epoch'] = sd['epoch']
+        results['step'] = sd['step']
+        results['train_loss'] = sd['train_loss']
+        results['val_spearman'] = sd['val_spearman']
+        results['val_loss'] = sd['val_loss']
+        results['val_ndcg'] = sd['val_ndcg']
+        results['dataset'] = flip_dataset
+        results['split'] = flip_split
+        results['task'] = args.task
+        results['seed'] = args.seed
+        with open(args.out_fpath + checkpoint_stem + '.json', 'w') as f:
+            json.dump(results, f)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/baselines/get_predictions.py b/baselines/get_predictions.py
new file mode 100644
index 0000000..2a21735
--- /dev/null
+++ b/baselines/get_predictions.py
@@ -0,0 +1,190 @@
+import argparse
+import os
+from datetime import datetime
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import pandas as pd
+import numpy as np
+
+from esm.models.esmc import ESMC
+from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
+
+from sequence_models.collaters import Seq2PropertyCollater
+from sequence_models.constants import PAD
+from sequence_models.structure import Attention1d
+from sequence_models.utils import warmup
+from sequence_models.flip_utils import load_flip_data
+from sequence_models.pretrained import load_model_and_alphabet
+
+
+
+
+
+class Model(nn.Module):  # Attention-pooled regression head: per-token embeddings (B, L, d_model) -> scalar (B, 1).
+
+    def __init__(self, d_model, dropout=0.0):
+        super().__init__()
+        self.d_model = d_model
+        self.attention = Attention1d(d_model)  # learned pooling over the sequence-length dimension
+        self.activation = nn.GELU()
+        self.dropout = nn.Dropout(dropout)
+        self.hidden = nn.Linear(d_model, d_model)
+        self.linear = nn.Linear(d_model, 1)  # scalar regression output
+
+    def forward(self, e, input_mask=None):
+        attended = self.attention(e, input_mask=input_mask)  # input_mask presumably excludes padding — see Attention1d
+        hidden = self.hidden(self.activation(attended))
+        return self.linear(self.dropout(self.activation(hidden)))
+
+
+def main():  # CLI entry point: dump test-set predictions from saved CARP and ESMC checkpoints.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('data_fpath', type=str)  # FLIP data root with <landscape>/splits/*.csv
+    parser.add_argument('out_fpath', type=str)  # dir for per-landscape/split prediction CSVs
+
+    args = parser.parse_args()
+    train(args)
+
+
+def train(args):  # Despite the name, this only runs inference: loads saved CARP/ESMC checkpoints and writes test-set predictions per landscape/split.
+    _ = torch.manual_seed(0)
+    torch.cuda.set_device(0)
+    device = torch.device('cuda:0')
+
+    np.random.seed(0)
+    carp, carp_collater = load_model_and_alphabet('carp_640M')  # 640M CARP backbone + matching collater
+    embedder = carp.model.embedder.to(device)
+    d_model = carp.model.embedder.up_embedder.conv.out_channels  # infer CARP embedding width from the conv layer
+    decoder = Model(d_model, dropout=0)
+    decoder = decoder.to(device)
+    carp_model = nn.ModuleDict({'embedder': embedder, 'decoder': decoder})
+    carp_model = carp_model.eval()
+    carp_alphabet = carp_collater.tokenizer.alphabet
+    carp_collate_fn = Seq2PropertyCollater(carp_alphabet, return_mask=True)
+    esm = ESMC.from_pretrained("esmc_300m")  # 300M ESMC backbone
+    esm = esm.to(device)
+    esm_tokenizer = EsmSequenceTokenizer()
+
+    def esm_collator(batch):  # (seq, label) pairs -> padded token batch, float targets, padding mask
+        data = tuple(zip(*batch))
+        seqs, labels = data
+        t = [torch.tensor(esm_tokenizer.encode(s)) for s in seqs]
+        max_len = max([len(tt) for tt in t])
+        t = [F.pad(tt, (0, max_len - len(tt)), value=1) for tt in t]  # pad value 1 — presumably the tokenizer's pad id, TODO confirm
+        t = torch.stack(t)
+        y = torch.tensor(labels).unsqueeze(-1).float()
+        input_mask = t != 1  # True on real tokens
+        return t, y, input_mask
+    d_model = 960  # ESMC-300m embedding width (shadows the CARP d_model above)
+    decoder = Model(d_model, dropout=0).to(device)
+    esm_model = nn.ModuleDict({'embed': esm.embed, 'transformer': esm.transformer, 'decoder': decoder})
+    esm_model = esm_model.eval()
+
+    # Maps the current split filename -> the split name the checkpoints were trained under.
+    split_dict = {
+        "AMY_BACSU": {
+            "random.csv": "random_split.csv",
+            "by_position.csv": "hard_split_.csv",
+            "close_to_far.csv":"med_split_is_close_to_as_0.csv",
+            "far_to_close.csv": "med_split_is_close_to_as_1.csv",
+            "one_to_many.csv": "one_to_many.csv",
+        },
+        "hydro": {
+            "low_to_high.csv": "hard_split.csv",
+            "random.csv": "random_split.csv",
+            "three_to_many.csv": "easy_split.csv",
+            "to_06241.csv": "med_P06241test_split.csv",
+            "to_P01053.csv": "med_P01053test_split.csv",
+            "to_P0A9X9.csv": "med_P0A9X9test_split.csv"
+        },
+        "ired": {
+            "mutation_order.csv": "ired_mutation_order_split.csv",
+            "random.csv": "ired_random_split.csv",
+        },
+        "NucB": {
+            "random.csv": "easy.csv",
+            "two_to_many.csv": "two_to_many.csv",
+        },
+        "PDZ3": {
+            "random.csv": "rand_split.csv",
+            "single_to_double.csv": "single_to_double.csv",
+        },
+        "RhoMax": {
+            "by_wt.csv": "by_wt.csv",
+        },
+        "trpb": {
+            "by_position.csv": "trpB_no_position_overlap_split.csv",
+            "one_to_many.csv": "trpB_one_vs_many_split.csv",
+            "two_to_many.csv": "trpB_two_vs_many_split.csv"
+        }
+    }
+    ## Grab data
+    checkpoint_paths = ["/mnt/amlt/flip_results_5-9-2025", "/mnt/amlt/flip_results/"]  # searched in order for each checkpoint
+    os.makedirs(args.out_fpath, exist_ok=True)
+    landscapes = os.listdir(args.data_fpath)
+    for landscape in landscapes:
+        split_csvs = os.listdir(os.path.join(args.data_fpath, landscape, "splits"))
+        split_csvs = [csv for csv in split_csvs if ".csv" in csv]
+        num_workers = 4
+        for split_csv in split_csvs[::-1]:  # reversed order — presumably so concurrent runs collide less, TODO confirm
+            out_file = os.path.join(args.out_fpath, "%s_%s_predictions.csv" % (landscape, split_csv[:-4]))
+            if os.path.isfile(out_file):  # resume support: skip splits already predicted
+                continue
+            print(landscape, split_csv, datetime.now())
+            ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, landscape, split_csv[:-4], max_len=2048,
+                                                         scale=True)
+            results_df = pd.DataFrame()
+            results_df['sequence'] = [d[0] for d in ds_test]
+            results_df['scaled_target'] = [d[1] for d in ds_test]
+
+            carp_dl_test = DataLoader(ds_test, batch_size=32, collate_fn=carp_collate_fn, num_workers=num_workers)
+            esm_dl_test = DataLoader(ds_test, batch_size=32, collate_fn=esm_collator, num_workers=num_workers)
+            task = landscape + "_" + split_dict[landscape][split_csv][:-4]  # checkpoint naming uses the original split name
+            for seed in range(5):
+                for weights in ['pretrained', 'naive']:
+                    checkpoint_stem = 'carp_%s_%s_%d' % (task, weights, seed)
+                    try:  # fall back to the older results dir if the checkpoint moved
+                        sd = torch.load(os.path.join(checkpoint_paths[0], checkpoint_stem + "_best.pt"),
+                                        weights_only=False)
+                    except FileNotFoundError:
+                        sd = torch.load(os.path.join(checkpoint_paths[1], checkpoint_stem + "_best.pt"),
+                                        weights_only=False)
+                    carp_model.load_state_dict(sd['model_state_dict'])
+                    predictions = []
+                    for i, batch in enumerate(carp_dl_test):
+                        src, tgt, input_mask = batch
+                        src = src.to(device)
+                        with torch.no_grad():
+                            input_mask = (src != carp_alphabet.index(PAD)).float().unsqueeze(-1)  # rebuild mask from PAD tokens for CARP
+                            e = carp_model['embedder'](src, input_mask=input_mask)
+                            predictions.append(carp_model['decoder'](e, input_mask=input_mask).detach().cpu().numpy())
+                    predictions = np.concatenate(predictions)
+                    results_df["carp_%s_%d" % (weights, seed)] = predictions
+
+                    checkpoint_stem = 'esmc_%s_%s_%d' % (task, weights, seed)
+                    try:
+                        sd = torch.load(os.path.join(checkpoint_paths[0], checkpoint_stem + "_best.pt"),
+                                        weights_only=False)
+                    except FileNotFoundError:
+                        sd = torch.load(os.path.join(checkpoint_paths[1], checkpoint_stem + "_best.pt"),
+                                        weights_only=False)
+                    esm_model.load_state_dict(sd['model_state_dict'])
+                    predictions = []
+                    for i, batch in enumerate(esm_dl_test):
+                        src, tgt, input_mask = batch
+                        src = src.to(device)
+                        input_mask = input_mask.to(device)
+                        with torch.no_grad():
+                            e = esm_model['embed'](src)
+                            e = esm_model['transformer'](e)[0].float()
+                            predictions.append(esm_model['decoder'](e, input_mask=input_mask).detach().cpu().numpy())
+                    predictions = np.concatenate(predictions)
+                    results_df["esmc_%s_%d" % (weights, seed)] = predictions
+            results_df.to_csv(out_file,
+                              index=False)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/baselines/linear_models.py b/baselines/linear_models.py
new file mode 100644
index 0000000..a67db4b
--- /dev/null
+++ b/baselines/linear_models.py
@@ -0,0 +1,206 @@
+import argparse
+import json
+import os
+from tqdm import tqdm
+
+from sklearn.metrics import mean_squared_error
+from sklearn.linear_model import Ridge, RidgeClassifier
+from sklearn.preprocessing import StandardScaler
+from scipy.stats import spearmanr
+from sklearn.metrics import ndcg_score, roc_auc_score
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+import pandas as pd
+
+torch.manual_seed(0)  # NOTE(review): statement before the remaining imports — unconventional ordering
+
+from sequence_models.utils import Tokenizer
+from sequence_models.flip_utils import load_flip_data
+
+AAINDEX_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'  # 20 canonical amino acids; defines one-hot width
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--seed', default=0, type=int)
+parser.add_argument('--solver', type=str, default='auto')  # passed to sklearn Ridge
+parser.add_argument('--max_iter', type=int, default=1000000)
+parser.add_argument('--tol', type=float, default=1e-5)
+args = parser.parse_args()
+
+results_path = "/home/kevyan/results/flipv3/"  # output root for all result CSVs below
+
+# Maps result-column names to model stems. NOTE(review): not referenced anywhere in this script's visible code.
+model_dict = {
+    "esm2_650M_scores": "esm2_650M_zs",
+    # "dayhoff_fwd": "dayhoff_3b_gr_hm_c_fwd_zs",
+    # "dayhoff_bwd": "dayhoff_3b_gr_hm_c_bwd_zs",
+    # "dayhoff_min": "dayhoff_3b_gr_hm_c_min_zs",
+    # "dayhoff_max": "dayhoff_3b_gr_hm_c_max_zs",
+    "dayhoff_mean": "dayhoff_3b_gr_hm_c_mean_zs",
+    # "dayhoff_3bur90_fwd": "dayhoff_3b_ur90_fwd_zs",
+    # "dayhoff_3bur90_bwd": "dayhoff_3b_ur90_bwd_zs",
+    # "dayhoff_3bur90_min": "dayhoff_3b_ur90_zs",
+    # "dayhoff_3bur90_max": "dayhoff_3b_ur90_max_zs",
+    "dayhoff_3bur90_mean": "dayhoff_3b_ur90_mean_zs",
+    "dayhoff_both_mean": "dayhoff_both_mean"
+}
+tokenizer = Tokenizer(AAINDEX_ALPHABET) # tokenizer over the 20-AA alphabet
+
+# Experiment 1: Ridge regression at 10 logarithmically-spaced training-set sizes, 5 random replicates each.
+input_path = "/home/kevyan/results/flipv3/zs/"
+landscapes = os.listdir(input_path)
+np.random.seed(23)
+replicates = 50  # 50 reps / 10 sizes = 5 replicates per training size
+results = pd.DataFrame(columns=['dataset', 'model', 'fraction_train', 'n_train', 'Spearman', 'replicate'])
+n_min = 50  # smallest training-set size tried
+with tqdm(total=len(landscapes) * replicates) as pbar:
+    for landscape in landscapes:
+        df = pd.read_csv(os.path.join(input_path, landscape))
+        landscape_name = landscape[:-7]  # strips a 7-char suffix — presumably "_zs.csv", TODO confirm
+        n = len(df)
+        X = df['sequence'].values
+        X = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in X]  # ':' separates chains; drop it before tokenizing
+        maxlen = max([len(i) for i in X])
+        X = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.) for i in X]  # right-pad token column to maxlen
+        X_enc = [] # one-hot encode each padded sequence
+        for i in X:
+            i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET))
+            i_onehot.zero_()
+            i_onehot.scatter_(1, i, 1)
+            X_enc.append(i_onehot)
+        X_enc = np.array([np.array(i.view(-1)) for i in X_enc]) # flatten to (n, maxlen * 20)
+        cols = ['esm2_650M_scores', 'carp_640m_zs', 'dayhoff']  # zero-shot likelihood features
+        X_zs = df[cols].values
+        new_X = np.hstack([X_enc, X_zs])  # one-hot features + zero-shot scores
+        log10_min = np.log10(n_min)
+        log10_max = np.log10(n * 0.8)  # cap training size at 80% of the data
+        n_trains = np.logspace(log10_min, log10_max, num=10)
+        for rep in range(replicates):
+            n_train = int(n_trains[rep % 10])  # cycle through the 10 sizes
+            fraction = n_train / n
+            n_test = n - n_train  # NOTE(review): unused below
+            idx = np.arange(n)
+            np.random.shuffle(idx)  # fresh random train/test partition per replicate
+            y_scale = df.iloc[idx[:n_train]]['target'].values[:, None]
+            y_test = df.iloc[idx[n_train:]]['target'].values[:, None]
+            X_train_enc = X_enc[idx[:n_train]]
+            X_test_enc = X_enc[idx[n_train:]]
+            scaler = StandardScaler()
+            scaler.fit(y_scale)  # scale targets using training rows only
+            y_train = scaler.transform(y_scale)
+            y_test = scaler.transform(y_test)
+            lr = Ridge(solver='auto', tol=1e-5, max_iter=1000000, alpha=10)  # NOTE(review): hard-coded; the --solver/--tol/--max_iter args are ignored here
+            # lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10)
+            lr.fit(X_train_enc, y_train)
+            preds = lr.predict(X_test_enc)
+            preds = preds.reshape(-1, 1)
+            rho = spearmanr(y_test, preds).correlation
+            results.loc[len(results)] = [landscape_name, 'Ridge (one-hot)', fraction, n_train, rho, rep]
+            new_X_train = new_X[idx[:n_train]]
+            new_X_test = new_X[idx[n_train:]]
+            lr = Ridge(solver='auto', tol=1e-5, max_iter=1000000, alpha=10)
+            lr.fit(new_X_train, y_train)
+            preds = lr.predict(new_X_test)
+            rho = spearmanr(y_test, preds).correlation
+            results.loc[len(results)] = [landscape_name, 'Ridge (one-hot + likelihoods)', fraction, n_train, rho, rep]
+            pbar.update(1)
+results.to_csv(results_path + "random_ridge.csv", index=False)
+
+# Experiment 2: Ridge on the curated train/test splits, with and without zero-shot likelihood features.
+input_path = "/home/kevyan/data/flip_data_pruned/"
+landscapes = os.listdir(input_path)
+
+for landscape in landscapes:
+    split_csvs = os.listdir(os.path.join(input_path, landscape, "splits"))
+    split_csvs = [c for c in split_csvs if "csv" in c]
+    for split_csv in split_csvs:
+        df = pd.read_csv(os.path.join(input_path, landscape, "splits", split_csv))
+        # Tokenize train data (validation rows excluded from the fit; full train set used for target scaling).
+        X_train = df[(df["set"] == "train") & (~df['validation'])]['sequence'].values
+        y_scale = df[df['set'] == "train"]['target'].values[:, None]
+        y_train = df[(df["set"] == "train") & (~df['validation'])]['target'].values[:, None]
+        X_train = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in X_train]
+        seq_test = df[df["set"] == "test"]['sequence'].values
+        y_test = df[df["set"] == "test"]['target'].values[:, None]
+        X_test = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in seq_test]
+        # Pad both sets to the common maximum length before one-hot encoding.
+        maxlen_train = max([len(i) for i in X_train])
+        maxlen_test = max([len(i) for i in X_test])
+        maxlen = max([maxlen_train, maxlen_test])
+
+        X_train = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.) for i in X_train]
+        X_train_enc = [] # one-hot encode
+        for i in X_train:
+            i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET))
+            i_onehot.zero_()
+            i_onehot.scatter_(1, i, 1)
+            X_train_enc.append(i_onehot)
+        X_train_enc = np.array([np.array(i.view(-1)) for i in X_train_enc]) # flatten to (n, maxlen * 20)
+
+        X_test = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.) for i in X_test]
+        X_test_enc = [] # one-hot encode
+        for i in X_test:
+            i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET))
+            i_onehot.zero_()
+            i_onehot.scatter_(1, i, 1)
+            X_test_enc.append(i_onehot)
+        X_test_enc = np.array([np.array(i.view(-1)) for i in X_test_enc]) # flatten to (n, maxlen * 20)
+        scaler = StandardScaler()
+        scaler.fit(y_scale)  # fit on all train targets (incl. validation) — presumably intentional, TODO confirm
+        y_train = scaler.transform(y_train)
+        y_test = scaler.transform(y_test)
+        lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10)
+        lr.fit(X_train_enc, y_train)
+
+        print(landscape, split_csv[:-4], 'one-hot')
+        preds = lr.predict(X_test_enc)
+        preds = preds.reshape(-1, 1)
+
+        mse = mean_squared_error(y_test, preds)
+        print('TEST MSE: ', mse)
+        print('TEST RHO: ', spearmanr(y_test, preds).correlation)
+        y_test_pos = y_test - y_test.min()  # ndcg_score requires non-negative relevance
+        print('TEST NDCG: ', ndcg_score(y_test_pos.T, preds.T))
+
+        results = pd.DataFrame()
+        results['sequence'] = seq_test
+        results['target'] = y_test
+        results['prediction'] = preds
+        os.makedirs(os.path.join(results_path, "ridge", landscape), exist_ok=True)
+        results.to_csv(os.path.join(results_path, "ridge", landscape, split_csv), index=False)
+
+        # One-hot + zero-shot features. BUGFIX: this section was indented at the landscape level,
+        # so it ran once per landscape on the *last* split's df/X_train_enc; it now runs per split.
+        if landscape != "PDZ3":
+            df['dayhoff'] = df['dayhoff_bwd'] + df['dayhoff_fwd'] + df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd']
+            if 'carp_640m_masked_zs' in df.columns:
+                cols = ['dayhoff', "esm2_650M_scores", 'carp_640m_masked_zs']
+            else:
+                cols = ['dayhoff', "esm2_650M_scores", 'carp_640m_zs']
+        else:  # PDZ3 has two proteins, so every feature comes in an _1/_2 pair
+            df['dayhoff1'] = df['dayhoff_bwd_1'] + df['dayhoff_fwd_1'] + df['dayhoff_3bur90_fwd_1'] + df['dayhoff_3bur90_bwd_1']
+            df['dayhoff2'] = df['dayhoff_bwd_2'] + df['dayhoff_fwd_2'] + df['dayhoff_3bur90_fwd_2'] + df['dayhoff_3bur90_bwd_2']
+            cols = ["dayhoff1", "dayhoff2", "esm2_650M_scores1", "esm2_650M_scores2", 'carp_640m_zs_1', 'carp_640m_zs_2']  # BUGFIX: was "esm2_650M_scores1" twice
+        zs_x_train = df[(df["set"] == "train") & (~df['validation'])][cols].values
+        new_x_train = np.hstack([X_train_enc, zs_x_train])
+        zs_x_test = df[df["set"] == "test"][cols].values
+        new_x_test = np.hstack([X_test_enc, zs_x_test])
+        lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10)
+
+        lr.fit(new_x_train, y_train)
+
+        print(landscape, split_csv[:-4], 'one-hot + zs')
+        preds = lr.predict(new_x_test)
+        preds = preds.reshape(-1, 1)
+        mse = mean_squared_error(y_test, preds)
+        print('TEST MSE: ', mse)
+        print('TEST RHO: ', spearmanr(y_test, preds).correlation)
+        y_test_pos = y_test - y_test.min()
+        print('TEST NDCG: ', ndcg_score(y_test_pos.T, preds.T))
+        results = pd.DataFrame()
+        results['sequence'] = seq_test
+        results['target'] = y_test
+        results['prediction'] = preds
+        os.makedirs(os.path.join(results_path, "ridge_zs", landscape), exist_ok=True)
+        results.to_csv(os.path.join(results_path, "ridge_zs", landscape, split_csv), index=False)
+
+
diff --git a/collect_splits/amy_bacsu.py b/collect_splits/amy_bacsu.py
new file mode 100644
index 0000000..b1c9fad
--- /dev/null
+++ b/collect_splits/amy_bacsu.py
@@ -0,0 +1,40 @@
+import os
+
+import numpy as np
+import pandas as pd
+from scipy.stats import spearmanr
+
+pd.set_option('display.max_columns', 100)  # wide printing for interactive inspection
+pd.set_option('display.width', 1000)
+
+# Re-split AMY_BACSU: single mutants -> train, multi-mutants -> test ("one_to_many").
+n_muts = []
+df = pd.read_csv('/home/kevyan/data/flip_data_zs/AMY_BACSU/splits/easy_split.csv', index_col=0)
+for i, row in df.iterrows():
+    muts = row['variant_info']
+    if isinstance(muts, str):  # NaN (non-str) means wild type -> 0 mutations
+        n_muts.append(len(row['variant_info'].split(',')))
+    else:
+        n_muts.append(0)
+df['n_mutations'] = n_muts
+np.random.seed(0)
+df['validation'] = False
+for i, row in df.iterrows():
+    if row['n_mutations'] > 1:
+        df.loc[i, 'set'] = "test"
+    else:
+        df.loc[i, 'set'] = "train"
+    if np.random.random() < 0.15:  # NOTE(review): validation flag is drawn for test rows too — confirm downstream filters always combine it with set == 'train'
+        df.loc[i, "validation"] = True
+df.to_csv('/home/kevyan/data/flip_data_zs/AMY_BACSU/splits/one_to_many.csv', index=False)
+
+
+# The remainder is REPL-style inspection: expressions are evaluated and discarded when run as a script.
+df[df['set'] == "test"].shape
+df['validation'].sum()
+df[(df['set'] == 'train') & (~df['validation'])].shape
+df[df['set'] == "test"]['target'].max()
+
+df[df['validation']]['target'].max()
+df[(df['set'] == 'train') & (~df['validation'])]['target'].max()
+df['target'].max()
+df[df['target'] > 0.19]
+spearmanr(df['esm2_650M_scores'], df['dayhoff_3bur90_fwd'])
\ No newline at end of file
diff --git a/collect_splits/nucb_fix_val.py b/collect_splits/nucb_fix_val.py
new file mode 100644
index 0000000..8aa47c9
--- /dev/null
+++ b/collect_splits/nucb_fix_val.py
@@ -0,0 +1,32 @@
+import pandas as pd
+import numpy as np
+from scipy.stats import spearmanr  # BUGFIX: spearmanr is called below but was never imported (NameError)
+
+# Re-split NucB: <=2 mutations -> train, >2 -> test ("two_to_many"), with a fresh 15% validation draw.
+pd.set_option('display.max_columns', 100)
+pd.set_option('display.width', 1000)
+df = pd.read_csv('/home/kevyan/data/flip_data_20250814/flip_data/datasets/NucB/splits/medium.csv', index_col=0)
+n_muts = []
+for i, row in df.iterrows():
+    muts = row['variant_info']
+    if isinstance(muts, str):  # NaN (non-str) means wild type -> 0 mutations
+        n_muts.append(len(row['variant_info'].split(',')))
+    else:
+        n_muts.append(0)
+df['n_mutations'] = n_muts
+np.random.seed(0)
+df['validation'] = False
+for i, row in df.iterrows():
+    if row['n_mutations'] > 2:
+        df.loc[i, 'set'] = 'test'  # BUGFIX: was row['set'] = 'test' — iterrows yields copies, so the DataFrame was never updated
+    else:
+        df.loc[i, 'set'] = 'train'  # BUGFIX: ditto
+    if np.random.random() < 0.15:  # NOTE(review): drawn for test rows too — confirm downstream filters combine it with set == 'train'
+        df.loc[i, "validation"] = True
+
+df.to_csv('/home/kevyan/data/flip_data_pruned/NucB/splits/two_to_many.csv', index=False)
+
+# REPL-style checks on an unrelated landscape (RhoMax); return values are discarded when run as a script.
+df = pd.read_csv('/home/kevyan/data/flip_data_pruned/RhoMax/splits/by_wt.csv', index_col=0)
+spearmanr(df['target'], df['esm2_650M_scores'])
+spearmanr(df['target'], df['dayhoff_fwd'] + df['dayhoff_bwd'])
+spearmanr(df['target'], df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd'])
+spearmanr(df['target'], df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd'] + df['dayhoff_fwd'] + df['dayhoff_bwd'])
\ No newline at end of file
diff --git a/collect_splits/rhomax.py b/collect_splits/rhomax.py
new file mode 100644
index 0000000..aa8f48d
--- /dev/null
+++ b/collect_splits/rhomax.py
@@ -0,0 +1,62 @@
+import os
+from pathlib import Path
+import subprocess
+import argparse
+import os
+
+import pandas as pd
+import numpy as np
+
+parser = argparse.ArgumentParser()  # Build the RhoMax "by wildtype" split from the OpsiGen spreadsheet.
+parser.add_argument("split_path", type=str, help="Directory to download raw data")
+parser.add_argument("out_path", type=str, help="Directory to save processed file")
+args = parser.parse_args()
+
+split_path = Path(args.split_path)
+# split_path = "/home/kevyan/src/FLIPv3/splits/rhomax"
+os.makedirs(split_path, exist_ok=True)
+subprocess.call(["wget", "-P", split_path, "https://github.com/dina-lab3D/OpsiGen/raw/refs/heads/colab/excel/data.xlsx"])
+
+df = pd.read_excel(os.path.join(split_path, "data.xlsx"))
+
+# Per-wildtype variant counts, sorted ascending so small families are assigned to val/test first.
+grouped = df.groupby("Wildtype")
+grouped = grouped.agg({"lmax": ['mean', 'count']})
+grouped.columns = grouped.columns.to_flat_index()  # MultiIndex -> ('lmax', 'mean') / ('lmax', 'count') tuples
+grouped = grouped.sort_values(('lmax', 'count'))
+grouped = grouped.reset_index()
+val_wt = []  # wildtypes assigned to validation
+n_val = 0  # running count of validation variants
+test_wt = []  # wildtypes assigned to test
+n_test = 0  # running count of test variants
+num_val = 100  # target number of validation variants
+num_test = 175  # target number of test variants
+current = 'val'
+
+# Alternate whole wildtype families between val and test until both quotas fill; the rest is train.
+for wildtype in grouped['Wildtype'].values:
+    if current == 'val' and n_val < num_val:
+        val_wt.append(wildtype)
+        n_val += grouped[grouped['Wildtype'] == wildtype][('lmax', 'count')].values[0]
+        if n_test < num_test:  # hand the next family to test while its quota is open
+            current = 'test'
+    elif current == 'test' and n_test < num_test:
+        test_wt.append(wildtype)
+        n_test += grouped[grouped['Wildtype'] == wildtype][('lmax', 'count')].values[0]
+        if n_val < num_val:  # hand the next family back to val while its quota is open
+            current = 'val'
+
+df_val = df[df['Wildtype'].isin(val_wt)]
+df_test = df[df['Wildtype'].isin(test_wt)]
+df_train = df[~df['Wildtype'].isin(np.concatenate([val_wt, test_wt]))]
+print(len(df_val), "validation samples", "mean = ", df_val['lmax'].mean())
+print(len(df_test), "Test samples", "mean = ", df_test['lmax'].mean())
+print(len(df_train), "Train samples", "mean = ", df_train['lmax'].mean())
+
+df_out = pd.DataFrame()  # emit the standard FLIP columns: sequence / target / set / validation
+df_out['wildtype'] = df['Wildtype']
+df_out['sequence'] = df['Sequence']
+df_out['target'] = df['lmax']  # target is the absorption maximum (lmax)
+df_out['set'] = 'train'
+df_out.loc[df_out['wildtype'].isin(test_wt), 'set'] = 'test'
+df_out['validation'] = df_out['wildtype'].isin(val_wt)
+df_out = df_out[['sequence', 'target', 'set', 'validation']]  # drop the helper wildtype column
+df_out.to_csv(os.path.join(os.path.join(args.out_path, "by_wt.csv")), index=False)  # NOTE(review): nested os.path.join is redundant
\ No newline at end of file
diff --git a/flip/utils/ranking.py b/flip/utils/ranking.py
new file mode 100644
index 0000000..cdefe88
--- /dev/null
+++ b/flip/utils/ranking.py
@@ -0,0 +1,15 @@
+import torch
+from torch.nn import functional as F
+
+def make_square_and_get_triu(input):  # All pairwise differences of a score column, flattened. NOTE(review): `input` shadows the builtin; assumes shape (n, 1) — TODO confirm.
+    input_by_input = input - input.transpose(0, 1)  # broadcast (n, 1) - (1, n) -> (n, n) matrix of x_i - x_j
+    idx = torch.triu_indices(len(input), len(input), 1, device=input.device)  # strict upper triangle: each unordered pair once
+    return input_by_input[idx[0], idx[1]].view(-1, 1)  # (n*(n-1)/2, 1) column of pair differences
+
+
+def bradley_terry_loss(predictions, targets):  # Pairwise (Bradley–Terry) ranking loss: BCE-with-logits on all pairwise orderings.
+    flat_targets = make_square_and_get_triu(targets)  # pairwise target differences t_i - t_j
+    flat_targets = flat_targets > 0  # binary label: does item i outrank item j?
+    flat_targets = flat_targets.float()
+    flat_predictions = make_square_and_get_triu(predictions)  # pairwise prediction differences serve as logits
+    return F.binary_cross_entropy_with_logits(flat_predictions, flat_targets)
\ No newline at end of file