diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..3040da1 Binary files /dev/null and b/.DS_Store differ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/FLIPv3.iml b/.idea/FLIPv3.iml new file mode 100644 index 0000000..ecf479a --- /dev/null +++ b/.idea/FLIPv3.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..e460088 --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,16 @@ + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..c419510 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,17 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..8acbfef --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..0ff281a --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/amlt/kky.yaml b/amlt/kky.yaml new file mode 100644 index 0000000..725c154 --- /dev/null +++ b/amlt/kky.yaml @@ -0,0 +1,88 @@ +description: Results on FLIPv3. 
+ +target: + service: sing + name: msrresrchvc + workspace_name: bio0 + +environment: + image: amlt-sing/acpt-torch2.7.1-py3.10-cuda12.6-ubuntu22.04 + setup: + - cd sequence_models + - pip install -U --user -e . + - cd .. + - pip install biopython + - pip install httpx + - pip install esm +# - pip install scikit-learn +# - pip install scipy +storage: + amlt: + storage_account_name: kevyaneastus2 + container_name: amlt +code: + local_dir: src/ +jobs: + - name: "get_predictions-2" + identity: managed + sku: 1xG1-A100 + command: + - python FLIPv3/baselines/get_predictions.py /mnt/amlt/data/flip_data_pruned/ /mnt/amlt/flip_predictions/ + mpi: True + process_count_per_node: -1 + submit_args: + env: + WANDB_BASE_URL: "https://microsoft-research.wandb.io" + WANDB_API_KEY: "$WANDB_API_KEY" + WANDB_ENTITY: "bio" + NCCL_DEBUG: "INFO" + NCCL_SOCKET_IFNAME: "bond0" + NCCL_CUMEM_ENABLE: 0 + NCCL_DEBUG_SUBSYS: "ALL" + TORCH_DISTRIBUTED_DEBUG: "INFO" + AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai" + tags: [ Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0 ] +search: + job_template: + name: "{model_name}_{task}_{weights}_{seed}_{lr}" + sku: 1xG1-A100 + command: + - python FLIPv3/baselines/{model_name}.py /mnt/amlt/data/flip_data_5-9-2025/datasets/ /mnt/amlt/flip_results/ {task} {weights} --seed {seed} --lr {lr} + mpi: True + process_count_per_node: -1 + submit_args: + env: + WANDB_BASE_URL: "https://microsoft-research.wandb.io" + WANDB_API_KEY: "$WANDB_API_KEY" + WANDB_ENTITY: "bio" + NCCL_DEBUG: "INFO" + NCCL_SOCKET_IFNAME: "bond0" + NCCL_CUMEM_ENABLE: 0 + NCCL_DEBUG_SUBSYS: "ALL" + TORCH_DISTRIBUTED_DEBUG: "INFO" + _AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai" + tags: 
[Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0] + type: grid + max_trials: 10000 + params: + - name: lr + spec: discrete + values: [1e-5] + - name: model_name + spec: discrete + values: [carp] + - name: seed + spec: discrete + values: [1] + - name: weights + spec: discrete + values: [ naive ] + - name: task + spec: discrete + values: [ hydro_med_P06241test_split + ] +# RhoMax_by_wt, AMY_BACSU_easy_split, AMY_BACSU_hard_split_, AMY_BACSU_med_split_is_buried_0, AMY_BACSU_med_split_is_buried_1, AMY_BACSU_med_split_is_close_to_as_0, AMY_BACSU_med_split_is_close_to_as_1, +# AMY_BACSU_med_split_is_connected_0, AMY_BACSU_med_split_is_connected_1, AMY_BACSU_med_split_is_secondary_0, AMY_BACSU_med_split_is_secondary_1, AMY_BACSU_random_split, +# hydro_hard_split, hydro_med1_split, hydro_med2_split, hydro_random_split, +# ired_ired_excludeT241mut_mutation_order_split, ired_ired_low_high_split, ired_ired_mutation_order_split, ired_ired_random_split, +# PDZ3_low_vs_high, PDZ3_one_vs_rest, PDZ3_sampled, PDZ3_three_vs_rest, PDZ3_two_vs_rest, \ No newline at end of file diff --git a/amlt/ridge.yaml b/amlt/ridge.yaml new file mode 100644 index 0000000..c1a6962 --- /dev/null +++ b/amlt/ridge.yaml @@ -0,0 +1,50 @@ +description: Results on FLIPv3. + +target: + service: sing + name: msrresrchvc + workspace_name: bio0 + +environment: + image: amlt-sing/acpt-torch2.7.0-py3.10-cuda12.6-ubuntu22.04 + setup: + - cd sequence_models + - pip install -U --user -e . + - cd .. 
+ - pip install biopython + - pip install esm + - pip install scikit-learn + - pip install scipy +storage: + amlt: + storage_account_name: kevyaneastus2 + container_name: amlt +code: + local_dir: src/ +search: + job_template: + name: "ridge_{task}" + sku: C3 + command: + - python FLIPv3/baselines/linear_models.py /mnt/amlt/data/flip_data/datasets/ /mnt/amlt/flip_results/ {task} + mpi: True + submit_args: + env: + WANDB_BASE_URL: "https://microsoft-research.wandb.io" + WANDB_API_KEY: "$WANDB_API_KEY" + WANDB_ENTITY: "bio" + NCCL_DEBUG: "INFO" + NCCL_SOCKET_IFNAME: "bond0" + NCCL_CUMEM_ENABLE: 0 + NCCL_DEBUG_SUBSYS: "ALL" + TORCH_DISTRIBUTED_DEBUG: "INFO" + _AZUREML_SINGULARITY_JOB_UAI: "/subscriptions/7be94754-107d-43b8-b840-202dff0e7cae/resourceGroups/bio0/providers/Microsoft.ManagedIdentity/userAssignedIdentities/bio0-uai" + tags: [Project_Name:Biomedical_ML,ProjectID:PRJ-0045-A32,Experiment:Bio-0] + type: random + max_trials: 25 + params: + - name: task + spec: discrete + values: [trpb_trpB_four_to_three_split, trpb_trpB_no_position_overlap_split, trpb_trpB_one_vs_many_split, trpb_trpB_three_to_four_split, trpb_trpB_two_vs_many_split, + NucB_easy, NucB_medium, NucB_hard + ] diff --git a/analysis/plot_baselines.py b/analysis/plot_baselines.py new file mode 100644 index 0000000..035fa64 --- /dev/null +++ b/analysis/plot_baselines.py @@ -0,0 +1,139 @@ +import os +from collections import Counter + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +import seaborn as sns +from scipy.stats import spearmanr, pearsonr +_ = sns.set(font_scale=1.7) +_ = sns.set_style('white') +pd.set_option('display.max_columns', None) +pd.set_option('display.width', 1000) +pd.set_option('display.max_rows', None) +flip_path = '/home/kevyan/results/flipv3/' +pruned_path = "/home/kevyan/data/flip_data_pruned/" + +datasets = os.listdir(os.path.join(flip_path, "all_predictions")) + + +df = pd.read_csv(os.path.join(flip_path, 
"random_ridge.csv")) + +df_zs = pd.DataFrame() +zs_model_dict = { + 'dayhoff': 'Dayhoff', + 'carp_640m_zs': 'CARP-640M', + # 'carp_640m_masked_zs': 'CARP-640M', + 'esm2_650M_scores': 'ESM2-650M' +} +idx = 0 +for dataset in datasets: + predictions = pd.read_csv(os.path.join(flip_path, 'zs', dataset + '_zs.csv')) + for model in zs_model_dict.keys(): + if model in predictions.columns: + df_zs.loc[idx, 'dataset'] = dataset + df_zs.loc[idx, 'model'] = zs_model_dict[model] + df_zs.loc[idx, 'Spearman'] = spearmanr(predictions['target'], predictions[model]).correlation + idx += 1 +best_zs = df_zs.groupby('dataset').agg({'Spearman': 'max'}).reset_index() + +# Get number of training examples in each split +dataset_path = '/home/kevyan/data/flip_data_pruned/' +df_sizes = pd.DataFrame(columns=['dataset', 'split', 'n_train', 'n_valid', 'n_test']) +for dataset in datasets: + split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits')) + split_csvs = [c for c in split_csvs if c[-4:] == '.csv'] + for split_csv in split_csvs: + df_data = pd.read_csv(os.path.join(dataset_path, dataset, 'splits', split_csv)) + n_test = len(df_data[df_data['set'] == 'test']) + n_train = len(df_data[(df_data['set'] == 'train') & (~df_data['validation'])]) + n_valid = len(df_data[df_data['validation']]) + df_sizes.loc[len(df_sizes)] = [dataset, split_csv[:-4], n_train, n_valid, n_test] +print(df_sizes) + +df_metrics = pd.read_csv(os.path.join(flip_path, 'all_metrics.csv')) +model_dict = { + "Ridge": "Ridge (one-hot)", + "zsRidge": 'Ridge (one-hot + likelihoods)', + "Dayhoff": "Dayhoff likelihood", + "ESM2-650M": "ESM2-650M likelihood", + "CARP-640M zero shot": "CARP-640M likelihood", + ("CARP-640M", True): "CARP-640M supervised", + ("CARP-640M", False): "CARP-640M naive supervised", + ("ESMC-300M", True): "ESMC-300M supervised", + ("ESMC-300M", False): "ESMC-300M naive supervised", +} + +for i, row in df_metrics.iterrows(): + if row['model'] in model_dict: + df_metrics.loc[i, 'model'] = 
model_dict[row['model']] + else: + df_metrics.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])] +df_metrics = df_metrics.fillna(0) +split_dict = { + 'close_to_far': 'position', + 'far_to_close': 'position', + 'by_mutation': 'mutation', + 'by_position': 'position', + 'by_wt': 'wild type', + 'random': 'random', + 'one_to_many': 'number', + 'to_P06241': 'wild type', + 'to_P01053': 'wild type', + 'to_P0A9X9': 'wild type', + 'low_to_high': 'fitness', + 'three_to_many': 'number', + 'single_to_double': 'number', + 'two_to_many': 'number', +} +pal = [ + '#76B900', + '#A77BB5', + '#4E79A7', + '#FF8A80', + '#F28E2B', + '#E15759' +] +split_hues = {"number": pal[0], "wild type": pal[1], "position": pal[2], "mutation": pal[3], "fitness": pal[4]} + + +for i, dataset in enumerate(datasets): + fig, ax = plt.subplots() + _ = sns.lineplot(data=df[df['dataset'] == dataset], x='n_train', y='Spearman', color='grey', style='model', + markers=True, ax=ax, ms=20, alpha=0.8) + _ = ax.axhline(y=best_zs[best_zs['dataset'] == dataset]['Spearman'].values[0], color='grey', linestyle='-', + label='best zero-shot likelihood score') + _ = ax.semilogx() + _ = ax.set_xlabel('Number of training examples') + _ = ax.set_ylim([-0.35, 1]) + split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits')) + split_csvs = [c for c in split_csvs if c[-4:] == '.csv'] + for split_csv in split_csvs: + if 'random' in split_csv: + continue + split_name = split_csv[:-4] + if dataset == 'Amylase' and split_name == 'by_position': + split_name = 'by_mutation' + color = split_hues[split_dict[split_name]] + x = df_sizes[(df_sizes['dataset'] == dataset) & (df_sizes['split'] == split_csv[:-4])]['n_train'] + y1 = df_metrics[(df_metrics['dataset'] == dataset) & (df_metrics['split'] == split_csv[:-4]) & (df_metrics['model'] == 'Ridge (one-hot)')]['Spearman'].values[0] + _ = ax.plot(x, y1, 'o', color=color, ms=20, alpha=0.7, mew=0) + y2 = df_metrics[(df_metrics['dataset'] == dataset) & 
(df_metrics['split'] == split_csv[:-4]) & (df_metrics['model'] == 'Ridge (one-hot + likelihoods)')]['Spearman'].values[0] + _ = ax.plot(x, y2, 'x', color=color, ms=20, alpha=1.0, mew=4) + _ = ax.set_title(dataset) + legend = ax.legend(title='Model') + legend.remove() + fig.savefig(os.path.join(flip_path, "plots", "random_ridge_%s.pdf" %dataset), dpi=300, bbox_inches='tight') +handles, labels = ax.get_legend_handles_labels() +fig, ax = plt.subplots() +legend = ax.legend(handles, labels, title='Model') +bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) +fig.savefig(os.path.join(flip_path, 'plots', 'random_ridge_models.pdf'), dpi=300, bbox_inches=bbox) +elements = [ + Patch(facecolor='gray', edgecolor=None, label='random'), +] +elements += [Patch(facecolor=split_hues[s], edgecolor=None, label=s, alpha=0.7) for s in split_hues] +legend = ax.legend(handles=elements, title='Split type') +bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) +fig.savefig(os.path.join(flip_path, 'plots', 'random_ridge_split_types.pdf'), dpi=300, bbox_inches=bbox) diff --git a/analysis/plot_zs.py b/analysis/plot_zs.py new file mode 100644 index 0000000..131d563 --- /dev/null +++ b/analysis/plot_zs.py @@ -0,0 +1,457 @@ +import os +from collections import Counter + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from scipy.stats import spearmanr, pearsonr +_ = sns.set(font_scale=1.7) +_ = sns.set_style('white') +pd.set_option('display.max_columns', None) +pd.set_option('display.width', 1000) +pd.set_option('display.max_rows', None) +flip_path = '/home/kevyan/results/flipv3/' +pruned_path = "/home/kevyan/data/flip_data_pruned/" + +datasets = os.listdir(os.path.join(flip_path, "all_predictions")) + +pal = sns.color_palette() +models = [ + ('dayhoff', 'Dayhoff likelihood', pal[0]), + ('esm2_650M_scores', "ESM2-650M likelihood", pal[1]), + ('carp_640m_zs', "CARP-640M likelihood", pal[2]), + 
('Ridge', 'Ridge (one-hot)', pal[3]), + ('zsRidge', 'Ridge (one-hot + likelihoods)', pal[4]), + ('carp_naive', 'CARP-640M naive supervised', pal[5]), + ('carp_pretrained', 'CARP-640M supervised', pal[6]), + ('esmc_naive', 'ESMC-300M naive supervised', pal[7]), + ('esmc_pretrained', 'ESMC-300M supervised', pal[8]), +] +os.makedirs(os.path.join(flip_path, 'plots', 'predictions'), exist_ok=True) +# plot individual predictions +_ = sns.set(font_scale=1) +_ = sns.set_style('white') +for dataset in datasets: + splits = os.listdir(os.path.join(flip_path, "all_predictions", dataset)) + for split in splits: + split_name = split[:-4] + df = pd.read_csv(os.path.join(flip_path, 'all_predictions', dataset, split)) + df['carp_naive'] = (df['carp_naive_0'] + df['carp_naive_1'] + df['carp_naive_2'] + df['carp_naive_3'] + df['carp_naive_4']) / 5 + df['carp_pretrained'] = (df['carp_pretrained_0'] + df['carp_pretrained_1'] + df['carp_pretrained_2'] + df['carp_pretrained_3'] + df['carp_pretrained_4']) / 5 + df['esmc_naive'] = (df['esmc_naive_0'] + df['esmc_naive_1'] + df['esmc_naive_2'] + df['esmc_naive_3'] + df['esmc_naive_4']) / 5 + df['esmc_pretrained'] = (df['esmc_pretrained_0'] + df['esmc_pretrained_1'] + df['esmc_pretrained_2'] + df['esmc_pretrained_3'] + df['esmc_pretrained_4']) / 5 + for ugly, pretty, color in models: + # if dataset == "PDZ3": + # if ugly in ['dayhoff', 'esm2_650M_scores', 'carp_640m_zs']: + # continue + fig, ax = plt.subplots() + _ = sns.scatterplot(data=df, x='scaled_target', y=ugly, alpha=0.3, marker='o', ax=ax, color='gray') + _ = ax.set_ylabel(pretty) + _ = ax.set_xlabel('scaled target') + fig.savefig(os.path.join(flip_path, 'plots', 'predictions', + '_'.join([dataset, split_name, ugly + '.png'])), + dpi=100, bbox_inches='tight') + +# Plot landscape zero-shot scores +_ = sns.set(font_scale=1.7) +_ = sns.set_style('white') +pal = [ + '#76B900', + '#A77BB5', + '#4E79A7', + '#FF8A80', + '#F28E2B', + '#E15759' +] +model_dict = { + 'dayhoff': 'Dayhoff 
likelihood', + 'carp_640m_zs': 'CARP-640M likelihood', + 'esm2_650M_scores': 'ESM2-650M likelihood', +} +model_order = [ + 'Dayhoff', + 'ESM2-650M', + 'CARP-640M', +] +plot_me = pd.DataFrame() +idx = 0 +for dataset in datasets: + predictions = pd.read_csv(os.path.join(flip_path, 'zs', dataset + '_zs.csv')) + for model in model_dict.keys(): + if model in predictions.columns: + plot_me.loc[idx, 'dataset'] = dataset + plot_me.loc[idx, 'split'] = 'full' + plot_me.loc[idx, 'model'] = model_dict[model] + plot_me.loc[idx, 'Spearman'] = spearmanr(predictions['target'], predictions[model]).correlation + idx += 1 +dataset_order = ['Amylase', 'IRED', 'NucB', 'TrpB', 'Hydro', 'Rhomax', 'PDZ3'] +fig, ax = plt.subplots(figsize=(8, 6)) +palette = ['gray', 'gray', 'gray'] +_ = sns.barplot(x='dataset', y='Spearman', data=plot_me, ax=ax, hue='model', palette=palette, order=dataset_order) +hatches = ['/', '.', '*'] +handles, labels = ax.get_legend_handles_labels() +for i, bar in enumerate(ax.patches[:21]): + hatch = hatches[i // 7] + bar.set_hatch(hatch) +for i, handle in enumerate(handles): + handle.set_hatch(hatches[i]) +ax.legend(handles, labels) +ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right') +fig.savefig(os.path.join(flip_path, 'plots', 'dataset_zeroshot.pdf'), dpi=300, bbox_inches='tight') + + +# plot aggregate things +df = pd.read_csv(os.path.join(flip_path, 'all_metrics.csv')) +model_dict = { + "Ridge": "Ridge (one-hot)", + "zsRidge": 'Ridge (one-hot + likelihoods)', + "Dayhoff": "Dayhoff likelihood", + "ESM2-650M": "ESM2-650M likelihood", + "CARP-640M zero shot": "CARP-640M likelihood", + ("CARP-640M", True): "CARP-640M supervised", + ("CARP-640M", False): "CARP-640M naive supervised", + ("ESMC-300M", True): "ESMC-300M supervised", + ("ESMC-300M", False): "ESMC-300M naive supervised", +} +model_order = [ + "Dayhoff likelihood", + "ESM2-650M likelihood", + "CARP-640M likelihood", + "Ridge (one-hot)", + "Ridge (one-hot + likelihoods)", + 
"CARP-640M naive supervised", + "CARP-640M supervised", + "ESMC-300M naive supervised", + "ESMC-300M supervised", +] +# hue = {m: p for m, p in zip(model_order, pal)} + + +for i, row in df.iterrows(): + if row['model'] in model_dict: + df.loc[i, 'model'] = model_dict[row['model']] + else: + df.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])] + +# for dataset in set(df['dataset']): +# for split in set(df[df['dataset'] == dataset]['split']): +# pretty_split = split.replace('_', '-') +# data = df[(df['dataset'] == dataset) & (df['split'] == split)] +# fig, ax = plt.subplots() +# _ = sns.barplot(data=data, x='model', y='Spearman', hue='model', hue_order=model_order, palette=hue, +# errorbar=None, ax=ax, alpha=0.7, order=model_order) +# _ = sns.stripplot(data=data, x='model', y='Spearman', hue='model', hue_order=model_order, palette=hue, +# ax=ax, order=model_order) +# _ = ax.set_title(dataset + " " + pretty_split) +# _ = ax.set_ylim([-0.2, 1]) +# ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right') +# fig.savefig(os.path.join(flip_path, 'plots', 'spearman_' + dataset + '_' + split + '.pdf'), +# dpi=300, bbox_inches='tight') + + +split_dict = { + 'close_to_far': 'position', + 'far_to_close': 'position', + 'by_position': 'position', + 'by_mutation': 'mutation', + 'by_wt': 'wild type', + 'random': 'random', + 'one_to_many': 'number', + 'to_P06241': 'wild type', + 'to_P01053': 'wild type', + 'to_P0A9X9': 'wild type', + 'low_to_high': 'fitness', + 'three_to_many': 'number', + 'single_to_double': 'number', + 'two_to_many': 'number', +} +split_hues = {"full": 'gray', "number": pal[0], "wild type": pal[1], "position": pal[2], "mutation": pal[3], "fitness": pal[4]} +split_order = ["full", 'number', 'wild type', 'position', 'mutation', 'fitness'] +df['split type'] = df.apply(lambda row: split_dict[row['split']], axis=1) + +all_zs = df[df['model'].isin(['Dayhoff likelihood', 'ESM2-650M likelihood', 'CARP-640M likelihood'])] +all_zs = 
all_zs[['dataset', 'split', 'model', 'Spearman']] +all_zs = pd.concat([plot_me, all_zs], ignore_index=True) + +for idx, row in all_zs.iterrows(): + all_zs.loc[idx, 'task'] = row['dataset'] + ' ' + '-'.join(row['split'].split('_')) + all_zs.loc[idx, 'split type'] = split_dict[row['split']] if row['split'] in split_dict else row['split'] + +task_order = [ + 'Amylase full', + 'Amylase one-to-many', + 'Amylase close-to-far', + 'Amylase far-to-close', + 'Amylase by-mutation', + 'IRED full', + 'IRED two-to-many', + 'NucB full', + 'NucB two-to-many', + 'TrpB full', + 'TrpB one-to-many', + 'TrpB two-to-many', + 'TrpB by-position', + 'Hydro full', + 'Hydro three-to-many', + 'Hydro low-to-high', + 'Hydro to-P06241', + 'Hydro to-P0A9X9', + 'Hydro to-P01053', + 'Rhomax full', + 'Rhomax by-wt', + 'PDZ3 full', + 'PDZ3 single-to-double' +] +all_zs['task'] = pd.Categorical(all_zs['task'], categories=task_order, ordered=True) +_ = sns.set(font_scale=1.5) +_ = sns.set_style('white') +fig, ax = plt.subplots(figsize=(16, 6)) +# _ = ax.fill_betweenx([-0.5, 0.67], x1=[-1.1, -1.1], x2=[4.5, 4.5], color='gray', alpha=0.1) +_ = ax.fill_betweenx([-0.5, 0.67], x1=[4.5, 4.5], x2=[6.5, 6.5], color='gray', alpha=0.1, linewidth=0) +_ = ax.fill_betweenx([-0.5, 0.67], x1=[8.5, 8.5], x2=[12.5, 12.5], color='gray', alpha=0.1, linewidth=0) +_ = ax.fill_betweenx([-0.5, 0.67], x1=[18.5, 18.5], x2=[20.5, 20.5], color='gray', alpha=0.1, linewidth=0) + +_ = sns.scatterplot(data=all_zs, x='task', y='Spearman', hue='split type', hue_order=split_order, + ax=ax, style='model', palette=split_hues, s=150, markers=['X', 'o', 's'], alpha=0.7) +_ = ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right') +_ = ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1)) +fig.savefig(os.path.join(flip_path, 'plots', 'all_zeroshot.pdf'), dpi=300, bbox_inches='tight') + +_ = sns.set(font_scale=1.7) +_ = sns.set_style('white') +split_hues = {"number": pal[0], "wild type": pal[1], "position": pal[2], 
"mutation": pal[3], "fitness": pal[4]} +split_order = ['number', 'wild type', 'position', 'mutation', 'fitness'] +plot_me = pd.DataFrame() +idx = 0 +tasks = list(df[['dataset', 'split']].values) +tasks = set([(t[0], t[1]) for t in tasks]) +for dataset, split in tasks: + current = df[(df['dataset'] == dataset) & (df['split'] == split)] + for j, row in current.iterrows(): + plot_me.loc[idx, 'dataset'] = dataset + plot_me.loc[idx, 'split'] = split + plot_me.loc[idx, row['model']] = row['Spearman'] + plot_me.loc[idx, 'split type'] = split_dict[split] + idx += 1 +plot_me = plot_me.fillna(0) +plot_me = plot_me[plot_me['split type'] != 'random'] +models = list(plot_me.columns[2:]) +models.remove('split type') +models = np.array(models) +plot_me['best model'] = models[plot_me[models].values.argmax(axis=1)] +print(plot_me[['dataset', 'split', 'best model']].sort_values('dataset')) +# Plot Ridge vs Ridge + zs +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot)', y='Ridge (one-hot + likelihoods)', hue='split type', + alpha=0.7, ax=ax, color='gray', hue_order=split_order, s=150, legend=False, palette=split_hues) +_ = ax.set_ylabel('Ridge (one-hot+likelihoods)') +fig.savefig(os.path.join(flip_path, 'plots', 'ridge_comparison.pdf'), dpi=300, bbox_inches='tight') + +# Plot Ridge + zs vs best zs +plot_me['best zero-shot'] = plot_me.loc[:, ['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood']].max(axis=1) +plot_me['best zs method'] = np.array(['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood'])[plot_me.loc[:, ['Dayhoff likelihood', 'CARP-640M likelihood', 'ESM2-650M likelihood']].values.argmax(axis=1)] +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me, x='best zero-shot', y='Ridge (one-hot + likelihoods)', hue='split type', + alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150) +_ = 
ax.set_ylabel('Ridge (one-hot+likelihoods)') + +legend = ax.legend() +legend.remove() +fig.savefig(os.path.join(flip_path, 'plots', 'zs_vs_ridgezs.pdf'), dpi=300, bbox_inches='tight') + +# Plot Ridge vs best PLM +plot_me['best PLM'] = plot_me.loc[:, ['CARP-640M supervised', 'CARP-640M naive supervised', 'ESMC-300M supervised', 'ESMC-300M naive supervised']].max(axis=1) +archs = np.array(['CARP-640M', 'CARP-640M', 'ESMC-300M', 'ESMC-300M']) +plot_me['best architecture'] = archs[plot_me.loc[:, ['CARP-640M supervised', 'CARP-640M naive supervised', 'ESMC-300M supervised', 'ESMC-300M naive supervised']].values.argmax(axis=1)] +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot)', y='best PLM', hue='split type', style='best architecture', + alpha=0.7, + legend=False, style_order=['CARP-640M', 'ESMC-300M'], ax=ax, palette=split_hues, + hue_order=split_order, s=150) +fig.savefig(os.path.join(flip_path, 'plots', 'ridge_vs_plm.pdf'), dpi=300, bbox_inches='tight') + +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me, x='Ridge (one-hot + likelihoods)', y='best PLM', hue='split type', + style='best architecture', style_order=['CARP-640M', 'ESMC-300M'], legend=False, + alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150) +fig.savefig(os.path.join(flip_path, 'plots', 'ridgezs_vs_plm.pdf'), dpi=300, bbox_inches='tight') + +# Plot CARP +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me, x='CARP-640M likelihood', y='CARP-640M supervised', hue='split type', + alpha=0.7, ax=ax, palette=split_hues, hue_order=split_order, s=150) +legend = ax.legend() +legend.remove() +fig.savefig(os.path.join(flip_path, 'plots', 'carp.pdf'), dpi=300, bbox_inches='tight') +handles, labels = ax.get_legend_handles_labels() +fig, ax = plt.subplots() +legend = ax.legend(handles, labels, 
title='Split type') +bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) +fig.savefig(os.path.join(flip_path, 'plots', 'split_types.pdf'), dpi=300, bbox_inches=bbox) + +# Get dataset sizes +dataset_path = '/home/kevyan/data/flip_data_pruned/' +for dataset in set(plot_me['dataset']): + split_csvs = os.listdir(os.path.join(dataset_path, dataset, 'splits')) + split_csvs = [c for c in split_csvs if c[-4:] == '.csv'] + for split_csv in split_csvs: + df_data = pd.read_csv(os.path.join(dataset_path, dataset, 'splits', split_csv)) + n_test = len(df_data[df_data['set'] == 'test']) + n_train = len(df_data[(df_data['set'] == 'train') & (~df_data['validation'])]) + n_valid = len(df_data[df_data['validation']]) + index = plot_me[(plot_me['dataset'] == dataset) & (plot_me['split'] == split_csv[:-4])].index + if not index.empty: + plot_me.loc[index, 'n_train'] = n_train + + + +# Plot pretrained vs naive +plot_me2 = pd.DataFrame() +idx = 0 +for i, row in plot_me.iterrows(): + plot_me2.loc[idx, 'dataset'] = row['dataset'] + plot_me2.loc[idx, 'split'] = row['split'] + plot_me2.loc[idx, 'split type'] = split_dict[row['split']] + plot_me2.loc[idx, 'pretrained'] = row['CARP-640M supervised'] + plot_me2.loc[idx, 'naive'] = row['CARP-640M naive supervised'] + plot_me2.loc[idx, 'architecture'] = 'CARP-640M' + idx += 1 + plot_me2.loc[idx, 'dataset'] = row['dataset'] + plot_me2.loc[idx, 'split'] = row['split'] + plot_me2.loc[idx, 'split type'] = split_dict[row['split']] + plot_me2.loc[idx, 'pretrained'] = row['ESMC-300M supervised'] + plot_me2.loc[idx, 'naive'] = row['ESMC-300M naive supervised'] + plot_me2.loc[idx, 'architecture'] = 'ESMC-300M' + idx += 1 +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.8], [-0.2, 0.8], color='gray') +_ = sns.scatterplot(data=plot_me2, x='naive', y='pretrained', hue='split type', style='architecture', alpha=0.7, + palette=split_hues, ax=ax, style_order=['CARP-640M', 'ESMC-300M'], s=150, legend=True, + hue_order=split_order) +legend 
= ax.legend() +legend.remove() +handles, labels = ax.get_legend_handles_labels() +fig.savefig(os.path.join(flip_path, 'plots', 'plms.pdf'), dpi=300, bbox_inches='tight') +fig, ax = plt.subplots() +legend = ax.legend(handles, labels) +bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) +fig.savefig(os.path.join(flip_path, 'plots', 'plms_legend.pdf'), dpi=300, bbox_inches=bbox) + +# Barplot of ridge spearmans +fig, ax = plt.subplots(figsize=(8, 6)) +plot_me3 = df[df['model'] == 'Ridge (one-hot)'] +plot_me3 = df[df['split'] != 'random'] +for i, row in plot_me3.iterrows(): + plot_me3.loc[i, 'task'] = row['dataset'] + ' ' + row['split'].replace('_', '-') +task_order = [ + 'Amylase one-to-many', + 'Amylase close-to-far', + 'Amylase far-to-close', + 'Amylase by-mutation', + 'IRED two-to-many', + 'NucB two-to-many', + 'TrpB one-to-many', + 'TrpB two-to-many', + 'TrpB by-position', + 'Hydro three-to-many', + 'Hydro low-to-high', + 'Hydro to-P06241', + 'Hydro to-P0A9X9', + 'Hydro to-P01053', + 'Rhomax by-wt', + 'PDZ3 single-to-double' +] +_ = sns.barplot(data=plot_me3, x='task', y='Spearman', ax=ax, hue='split type', hue_order=split_order, + palette=split_hues, + order=task_order, legend=False) +_ = ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right') +fig.savefig(os.path.join(flip_path, 'plots', 'all_ridge.pdf'), dpi=300, bbox_inches='tight') + + +plot_me3 = df[df['model'].isin(['Dayhoff likelihood', 'ESM-650M likelihood', 'CARP-640M likelihood'])] +for i, row in plot_me3.iterrows(): + plot_me3.loc[i, 'task'] = row['dataset'] + ' ' + row['split'].replace('_', '-') +plot_me3 = plot_me3.groupby('task').agg({'Spearman': 'mean', 'split type': lambda x: x.values[0]}) +plot_me3 = plot_me3.reset_index() +fig, ax = plt.subplots(figsize=(8, 6)) +_ = sns.barplot(data=plot_me3, x='task', y='Spearman', ax=ax, hue='split type', hue_order=split_order, + palette=split_hues, legend=False, order=task_order) +_ = 
ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha='right') +fig.savefig(os.path.join(flip_path, 'plots', 'best_zs.pdf'), dpi=300, bbox_inches='tight') + + +# Plot train zs Spearman vs test zs Spearman +datasets = os.listdir(pruned_path) +model_order = [ + "Dayhoff", + "ESM2-650M", + "CARP-640M", + "Ridge (one-hot)", + "Ridge (one-hot + likelihoods)", + "CARP-640M naive supervised", + "CARP-640M supervised", + "ESMC-300M naive supervised", + "ESMC-300M supervised", +] +hue = {m: p for m, p in zip(model_order, pal)} + +zs_df = pd.DataFrame() +zs_columns = { + 'esm2_650M_scores': "ESM2-650M", + 'carp_640m_zs': "CARP-640M", + 'dayhoff': 'Dayhoff',} +idx = 0 +for dataset in datasets: + for split in os.listdir(os.path.join(pruned_path, dataset, 'splits')): + if split[-4:] != ".csv": + continue + results = pd.read_csv(os.path.join(pruned_path, dataset, 'splits', split)) + if 'carp_640m_masked_zs' in results.columns: + results['carp_640m_zs'] = results['carp_640m_masked_zs'] + results = results.drop(columns=['carp_640m_masked_zs']) + if dataset != "PDZ3": + results.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"}, inplace=True) + results['dayhoff'] = (results['dayhoff_3bgrhmc_fwd'] + results['dayhoff_3bgrhmc_bwd'] + results['dayhoff_3bur90_fwd'] + results['dayhoff_3bur90_bwd']) / 4 + else: + results.rename(columns={"dayhoff_fwd_1": "dayhoff_3bgrhmc_fwd_1", "dayhoff_bwd_1": "dayhoff_3bgrhmc_bwd_1"}, inplace=True) + results.rename(columns={"dayhoff_fwd_2": "dayhoff_3bgrhmc_fwd_2", "dayhoff_bwd_2": "dayhoff_3bgrhmc_bwd_2"}, inplace=True) + d1 = (results['dayhoff_3bgrhmc_fwd_1'] + results['dayhoff_3bgrhmc_bwd_1'] + results['dayhoff_3bur90_fwd_1'] + results['dayhoff_3bur90_bwd_1']) / 4 + d1 += (results['dayhoff_3bgrhmc_fwd_2'] + results['dayhoff_3bgrhmc_bwd_2'] + results['dayhoff_3bur90_fwd_2'] + results['dayhoff_3bur90_bwd_2']) / 4 + d1 /= 2 + results['dayhoff'] = d1 + results['esm2_650M_scores'] = 
(results['esm2_650M_scores1'] + results['esm2_650M_scores2'].fillna(0)) / 2 + results['carp_640m_zs'] = (results['carp_640m_zs_1'] + results['carp_640m_zs_2']) / 2 + for m in zs_columns.keys(): + if m in results.columns: + zs_df.loc[idx, 'dataset'] = dataset + zs_df.loc[idx, 'split type'] = split_dict[split[:-4]] + zs_df.loc[idx, 'split'] = split[:-4] + zs_df.loc[idx, 'model'] = zs_columns[m] + for s in ['train', 'test']: + d = results[results['set'] == s] + sp = spearmanr(d['target'], d[m]).correlation + zs_df.loc[idx, s + ' Spearman'] = sp + idx += 1 +plot_me = zs_df[zs_df['split'] != 'random'] +fig, ax = plt.subplots() +_ = ax.plot([-0.2, 0.7], [-0.2, 0.7], color='gray') +_ = sns.scatterplot(data=plot_me, x='train Spearman', y='test Spearman', hue='split type', + palette=split_hues, hue_order=split_order, ax=ax, alpha=0.7, s=150, legend=False) +fig.savefig(os.path.join(flip_path, 'plots', 'zs_train_v_test.pdf'), dpi=300, bbox_inches='tight') +print("Train/test zero shot spearmans") +print(pearsonr(zs_df.dropna()['train Spearman'], zs_df.dropna()['test Spearman'])) +zs_df = zs_df[zs_df['split type'] != 'random'] +for model in model_order[:3]: + d = zs_df[zs_df['model'] == model] + print(model, pearsonr(d.dropna()['train Spearman'], d.dropna()['test Spearman'])) + +for st in set(split_dict.values()): + if st != 'random': + d = zs_df[zs_df['split type'] == st] + print(st, pearsonr(d.dropna()['train Spearman'], d.dropna()['test Spearman']).correlation) diff --git a/baselines/aggregate.py b/baselines/aggregate.py new file mode 100644 index 0000000..9cf4afa --- /dev/null +++ b/baselines/aggregate.py @@ -0,0 +1,237 @@ +import os +from collections import Counter + +import pandas as pd +import numpy as np +from scipy.stats import spearmanr +from sklearn.metrics import ndcg_score, roc_auc_score + +pd.set_option('display.max_columns', None) +pd.set_option('display.width', 1000) +pd.set_option('display.max_rows', None) + +flip_path = '/home/kevyan/results/flipv3/' +# 
results_path = os.path.join(flip_path, 'combined_flip_results') +# result_files = os.listdir(results_path) + + + + + +# Make clean full landscape zero shot score csvs +pruned_path = "/home/kevyan/data/flip_data_pruned/" +landscapes = os.listdir(pruned_path) +os.makedirs(os.path.join(flip_path, "zs"), exist_ok=True) +for landscape in landscapes: + if landscape == "rhomax": + df = pd.read_csv(os.path.join(pruned_path, landscape, "splits", "by_wt.csv")) + elif landscape == "TrpB": + df = pd.read_csv(os.path.join(pruned_path, landscape, "splits", "by_position.csv")) + else: + df = pd.read_csv(os.path.join(pruned_path, landscape, "splits", "random.csv")) + columns = [c for c in df.columns if c not in ("set", "validation", "dayhoff_min", "dayhoff_3bur90_min")] + df = df[columns] + if landscape != "PDZ3": + df.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"}, inplace=True) + df['dayhoff'] = (df['dayhoff_3bgrhmc_fwd'] + df['dayhoff_3bgrhmc_bwd'] + df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd']) / 4 + else: + df.rename(columns={"dayhoff_fwd_1": "dayhoff_3bgrhmc_fwd_1", "dayhoff_bwd_1": "dayhoff_3bgrhmc_bwd_1"}, inplace=True) + df.rename(columns={"dayhoff_fwd_2": "dayhoff_3bgrhmc_fwd_2", "dayhoff_bwd_2": "dayhoff_3bgrhmc_bwd_2"}, inplace=True) + d = (df['dayhoff_3bgrhmc_fwd_1'] + df['dayhoff_3bgrhmc_bwd_1'] + df['dayhoff_3bur90_fwd_1'] + df['dayhoff_3bur90_bwd_1']) / 4 + d += (df['dayhoff_3bgrhmc_fwd_2'] + df['dayhoff_3bgrhmc_bwd_2'] + df['dayhoff_3bur90_fwd_2'] + df['dayhoff_3bur90_bwd_2']) / 4 + d /= 2 + df['dayhoff'] = d + df['esm2_650M_scores'] = (df['esm2_650M_scores1'] + df['esm2_650M_scores2'].fillna(0)) / 2 + df['carp_640m_zs'] = (df['carp_640m_zs_1'] + df['carp_640m_zs_2']) / 2 + if 'carp_640m_masked_zs' in df.columns: + df['carp_640m_zs'] = df['carp_640m_masked_zs'] + df = df.drop(columns=['carp_640m_masked_zs']) + rho1 = spearmanr(df['target'], df['esm2_650M_scores']).statistic + rho2 = spearmanr(df['target'], 
df['dayhoff']).statistic + rho3 = spearmanr(df['target'], df['carp_640m_zs']).statistic + + print(landscape, rho1, rho2, rho3) + df.to_csv(os.path.join(flip_path, "zs", landscape + "_zs.csv"), index=False) + +# Make clean prediction csvs for each test set +prediction_path = os.path.join(flip_path, "all_predictions") +for landscape in landscapes: + os.makedirs(os.path.join(prediction_path, landscape), exist_ok=True) + split_csvs = os.listdir(os.path.join(pruned_path, landscape, "splits")) + split_csvs = [c for c in split_csvs if ".csv" in c] + for split_csv in split_csvs: + df = pd.read_csv(os.path.join(pruned_path, landscape, "splits", split_csv), index_col=0) + df = df[df['set'] == 'test'] + columns = [c for c in df.columns if c not in ("set", "validation", "dayhoff_min", "dayhoff_3bur90_min", "Unnamed: 0")] + df = df[columns] + if landscape != "PDZ3": + df.rename(columns={"dayhoff_fwd": "dayhoff_3bgrhmc_fwd", "dayhoff_bwd": "dayhoff_3bgrhmc_bwd"}, + inplace=True) + df['dayhoff'] = (df['dayhoff_3bgrhmc_fwd'] + df['dayhoff_3bgrhmc_bwd'] + df['dayhoff_3bur90_fwd'] + df[ + 'dayhoff_3bur90_bwd']) / 4 + else: + df.rename(columns={"dayhoff_fwd_1": "dayhoff_3bgrhmc_fwd_1", "dayhoff_bwd_1": "dayhoff_3bgrhmc_bwd_1"}, + inplace=True) + df.rename(columns={"dayhoff_fwd_2": "dayhoff_3bgrhmc_fwd_2", "dayhoff_bwd_2": "dayhoff_3bgrhmc_bwd_2"}, + inplace=True) + d = (df['dayhoff_3bgrhmc_fwd_1'] + df['dayhoff_3bgrhmc_bwd_1'] + df['dayhoff_3bur90_fwd_1'] + df[ + 'dayhoff_3bur90_bwd_1']) / 4 + d += (df['dayhoff_3bgrhmc_fwd_2'] + df['dayhoff_3bgrhmc_bwd_2'] + df['dayhoff_3bur90_fwd_2'] + df[ + 'dayhoff_3bur90_bwd_2']) / 4 + d /= 2 + df['dayhoff'] = d + df['esm2_650M_scores'] = (df['esm2_650M_scores1'] + df['esm2_650M_scores2'].fillna(0)) / 2 + df['carp_640m_zs'] = (df['carp_640m_zs_1'] + df['carp_640m_zs_2']) / 2 + if 'carp_640m_masked_zs' in df.columns: + df['carp_640m_zs'] = df['carp_640m_masked_zs'] + df = df.drop(columns=['carp_640m_masked_zs']) + plm_path = 
os.path.join(flip_path, "plm_predictions", landscape + "_" + split_csv[:-4] + "_predictions.csv") + if os.path.isfile(plm_path): + df2 = pd.read_csv(plm_path) + if landscape == "PDZ3": + df2['sequence'] = df['sequence'].values + df3 = df.merge(df2, left_on='sequence', right_on='sequence') + df4 = pd.read_csv(os.path.join(flip_path, "ridge", landscape, split_csv)) + df4.rename(columns={'prediction': 'Ridge', "target": "linear_target"}, inplace=True) + df3 = df3.merge(df4, left_on='sequence', right_on='sequence') + df4 = pd.read_csv(os.path.join(flip_path, "ridge_zs", landscape, split_csv)) + df4 = df4[['sequence', 'prediction']] + df4.rename(columns={'prediction': 'zsRidge'}, inplace=True) + df3 = df3.merge(df4, left_on='sequence', right_on='sequence') + df3.to_csv(os.path.join(prediction_path, landscape, split_csv), index=False) + +# Make csv for metrics with all replicates +nice_models = { + 'Ridge': ("Ridge", False, 0), + "zsRidge": ("zsRidge", False, 0), + "dayhoff": ("Dayhoff", False, 0), + "carp_640m_zs": ("CARP-640M zero shot", False, 0), + "esm2_650M_scores": ("ESM2-650M", False, 0), +} +for seed in range(5): + for model in ['carp', 'esmc']: + for p in ['pretrained', 'naive']: + key = '_'.join([model, p, str(seed)]) + if model == "carp": + nice_models[key] = ("CARP-640M", p == "pretrained", seed) + else: + nice_models[key] = ("ESMC-300M", p == "pretrained", seed) + +all_results = pd.DataFrame(columns=['dataset', "split", "model", "pretrained", "Spearman", "MSE", "NDCG", "seed"]) +for landscape in landscapes: + pred_csvs = os.listdir(os.path.join(prediction_path, landscape)) + for pred_csv in pred_csvs: + pred_df = pd.read_csv(os.path.join(prediction_path, landscape, pred_csv)) + split = pred_csv[:-4] + for key in nice_models: + if key in pred_df.columns: + idx = len(all_results) + all_results.loc[idx, ['dataset', 'split']] = (landscape, split) + all_results.loc[idx, ['model', 'pretrained', 'seed']] = nice_models[key] + all_results.loc[idx, ["Spearman"]] = 
spearmanr(pred_df['scaled_target'], pred_df[key]).correlation + if "carp" in key or "esmc" in key or "Ridge" in key: + all_results.loc[idx, ["MSE"]] = ((pred_df['scaled_target'] - pred_df[key]) ** 2).mean() + pos_targets = (pred_df['target'] - pred_df['target'].min()).values[None, :] + all_results.loc[idx, ["NDCG"]] = ndcg_score(pos_targets, pred_df[key].values[None, :]) +all_results.to_csv(os.path.join(flip_path, "all_metrics.csv"), index=False) + +# Make csv with replicates aggregated +grouped = all_results.groupby(['dataset', 'split', 'model', 'pretrained']) +agged = grouped.agg( + spearman_mean=('Spearman', np.mean), + spearman_std=('Spearman', np.std), + mse_mean=('MSE', np.mean), + mse_std=('MSE', np.std), + ndcg_mean=('NDCG', np.mean), + ndcg_std=('NDCG', np.std), +) + + +compiled = agged.reset_index() +print(compiled) +compiled.to_csv(os.path.join(flip_path, 'all_metrics_aggregated.csv'), index=True) + +model_dict = { + "Ridge": "Ridge (one-hot)", + "zsRidge": 'Ridge (one-hot + likelihoods)', + "Dayhoff": "Dayhoff likelihood", + "ESM2-650M": "ESM2-650M likelihood", + "CARP-640M zero shot": "CARP-640M likelihood", + ("CARP-640M", True): "CARP-640M supervised", + ("CARP-640M", False): "CARP-640M naive supervised", + ("ESMC-300M", True): "ESMC-300M supervised", + ("ESMC-300M", False): "ESMC-300M naive supervised", +} + +model_order = [ + "Ridge (one-hot)", + 'Ridge (one-hot + likelihoods)', + "Dayhoff likelihood", + "ESM2-650M likelihood", + "CARP-640M likelihood", + "CARP-640M supervised", + "CARP-640M naive supervised", + "ESMC-300M supervised", + "ESMC-300M naive supervised", +] + +for i, row in compiled.iterrows(): + if row['model'] in model_dict: + compiled.loc[i, 'model'] = model_dict[row['model']] + else: + compiled.loc[i, 'model'] = model_dict[(row['model'], row['pretrained'])] + +all_tasks = [] +for i, row in compiled.iterrows(): + all_tasks.append((row['dataset'], row['split'])) + +all_tasks = set(all_tasks) +for task in all_tasks: + if 'random' in 
task: + continue + print(task) + comp = compiled[(compiled['dataset'] == task[0]) & (compiled['split'] == task[1])] + for model in model_order: + c = comp[comp['model'] == model] + for i, row in c.iterrows(): + if 'supervised' in model: + print(row['model'] + '& $%.3f \pm %.3f$ & $%.3f \pm %.3f$\\\\' %(row['spearman_mean'], row['spearman_std'], row['ndcg_mean'], row['ndcg_std'])) + + else: + print(row['model'] + '& $%.3f$ & $%.3f$\\\\' %(row['spearman_mean'], row['ndcg_mean'])) +by_task = compiled.pivot(index=['dataset', 'split'], columns=['model', 'pretrained']) +by_task = by_task.reset_index() +by_task.head() +(by_task['spearman_mean']['CARP-640M'][True] > by_task['spearman_mean']['CARP-640M'][False]).sum() +(by_task['spearman_mean']['ESMC-35M'][True] > by_task['spearman_mean']['ESMC-35M'][False]).sum() +(by_task['spearman_mean']['Ridge'][False] > by_task['spearman_mean']['zsRidge'][False]).sum() +(by_task['spearman_mean']['Dayhoff'][False] > by_task['spearman_mean']['zsRidge'][False]).sum() +(by_task['spearman_mean']['ESM2-650M'][False] > by_task['spearman_mean']['zsRidge'][False]).sum() + +model_list = by_task.columns[2:11] +best_spearmans = [] +for i, row in by_task.iterrows(): + print(row['dataset'].values[0], row['split'].values[0], *model_list[row['spearman_mean'].argmax()][1:]) + best_spearmans.append(model_list[row['spearman_mean'].argmax()][1:]) +Counter(best_spearmans) +# +# best_mses = [] +# for i, row in by_task.iterrows(): +# print(row['dataset'].values[0], row['split'].values[0], *model_list[row['mse_mean'].argmin()][1:]) +# best_mses.append(model_list[row['mse_mean'].argmin()][1:]) +# Counter(best_mses) +# +# +# by_task[['dataset', 'split', 'spearman_mean']] +# by_task['spearman_mean'].mean() +# print("dataset,split,p>n,p>r") +# for dataset in set(compiled.dataset): +# for split in set(compiled[compiled['dataset'] == dataset].split): +# r = compiled[(compiled['dataset'] == dataset) & (compiled['split'] == split) & (compiled['model'] == 
class Model(nn.Module):
    """Regression head mapping per-residue embeddings to a single scalar.

    Pools the (batch, length, d_model) embeddings with a learned Attention1d,
    then applies a one-hidden-layer MLP with GELU activations and dropout.
    """

    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        # Learned pooling over sequence positions (project-local Attention1d).
        self.attention = Attention1d(d_model)
        self.hidden = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, 1)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, e, input_mask=None):
        """Pool embeddings to one vector per sequence and regress to a scalar."""
        pooled = self.attention(e, input_mask=input_mask)
        h = self.activation(self.hidden(self.activation(pooled)))
        return self.linear(self.dropout(h))
def train(args):
    """Fine-tune CARP-640M (or a weight-reset copy) on one FLIP task.

    Trains the CARP embedder plus an Attention1d regression head with MSE on
    scaled fitness targets, early-stops on validation loss with patience, then
    reloads the best checkpoint and writes test metrics to a JSON file.

    Args:
        args: argparse namespace with data_fpath, out_fpath, task, weights
            ('pretrained' keeps the released weights), seed, and lr.
    """
    _ = torch.manual_seed(args.seed)
    torch.cuda.set_device(0)
    device = torch.device('cuda:0')
    lr = args.lr

    np.random.seed(args.seed)
    carp, collater = load_model_and_alphabet('carp_640M')
    # 'naive' baseline: keep the architecture but discard pretrained weights.
    if args.weights != 'pretrained':
        for p in carp.modules():
            try:
                p.reset_parameters()
            except AttributeError:
                continue
    embedder = carp.model.embedder.to(device)
    d_model = carp.model.embedder.up_embedder.conv.out_channels
    decoder = Model(d_model, dropout=0)
    decoder = decoder.to(device)
    optimizer = Adam(list(embedder.parameters()) + list(decoder.parameters()), lr=lr)
    model = nn.ModuleDict({'embedder': embedder, 'decoder': decoder})
    alphabet = collater.tokenizer.alphabet

    ## Grab data
    batch_size = 16
    loss_func = nn.MSELoss()
    # Task names are '<dataset>_<split>'; AMY_BACSU is the one dataset whose
    # own name contains an underscore, so it needs a wider split.
    if "AMY_BACSU" in args.task:
        flip_dataset = '_'.join(args.task.split('_')[:2])
        flip_split = '_'.join(args.task.split('_')[2:])
    else:
        flip_dataset = args.task.split('_')[0]
        flip_split = '_'.join(args.task.split('_')[1:])
    ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, flip_dataset, flip_split, max_len=2048, scale=True)
    collate_fn = Seq2PropertyCollater(alphabet, return_mask=True)
    num_workers = 4
    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collate_fn,
                          num_workers=num_workers, shuffle=True)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    print('%d Train samples %d valid samples %d test samples' %(len(ds_train), len(ds_valid), len(ds_test)))
    checkpoint_stem = 'carp_%s_%s_%d' %(args.task, args.weights, args.seed)

    def step(model, batch, train=True, return_values=False):
        """Run one batch; backprop + LR-scheduler step when train is True."""
        src, tgt, input_mask = batch
        src = src.to(device)
        tgt = tgt.to(device)
        # Recompute the padding mask from the tokens on-device (the
        # collater-provided mask is discarded; the two are equivalent).
        input_mask = (src != alphabet.index(PAD)).float().unsqueeze(-1)
        e = model['embedder'](src, input_mask=input_mask)
        outputs = model['decoder'](e, input_mask=input_mask)

        loss = loss_func(outputs, tgt)
        locations = len(tgt)
        mask = torch.ones(1)  # dummy
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # `scheduler` is defined later in train(); step() is only ever
            # called after it exists.
            scheduler.step()
        if return_values:
            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu(), mask.detach().cpu()
        else:
            return loss.item(), locations

    def epoch(model, current_step=0):
        """Run one training epoch over dl_train.

        Returns (last batch index, running mean loss).

        BUG FIX: the previous version branched on `if train:` — but `train`
        there resolved to the enclosing *function*, which is always truthy,
        so the "validation" branches were dead code. This loop only trains.
        """
        model = model.train()
        loader = dl_train
        t = 'Training:'
        losses = []
        ns = []
        n_seen = 0
        n_total = len(ds_train)
        for i, batch in enumerate(loader):
            new_loss, new_n = step(model, batch, True)
            losses.append(new_loss * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            if total_n == 0:
                rloss = 0
            else:
                rloss = sum(losses) / total_n
            nsteps = current_step + i + 1
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),
                  end='')
        # NOTE(review): returns i (the last batch *index*), so the caller's
        # total_steps undercounts by one per epoch relative to the printed
        # step count; preserved as-is since it is only used for logging.
        return i, rloss

    def test_epoch(model, dl):
        """Evaluate on a loader; returns {'spearman', 'loss', 'ndcg'}."""
        model = model.eval()
        with torch.no_grad():
            losses = []
            ns = []
            n_seen = 0
            pred = []
            tgt = []
            masks = []
            for i, batch in enumerate(dl):
                new_loss, new_n, p, s, t, m = step(model, batch, False, return_values=True)
                losses.append(new_loss * new_n)
                pred.append(p)
                tgt.append(t)
                masks.append(m)
                ns.append(new_n)
                n_seen += len(batch[0])
            total_n = sum(ns)

            test_loss = sum(losses) / total_n
            pred = torch.cat(pred)
            tgt = torch.cat(tgt)
            pred = pred.numpy()
            tgt = tgt.numpy()
            spearman = spearmanr(pred, tgt).correlation
            # ndcg_score requires non-negative relevance values.
            if (tgt < 0).any():
                pos_tgt = tgt - tgt.min()
            else:
                pos_tgt = tgt
            ndcg = ndcg_score(pos_tgt.T, pred.T)
            print('\tloss: %f' %test_loss, end='\t')
            print('spearman: %f' %(spearman), end='\t')
            print('ndcg: %f' %(ndcg), end='\t')
            results = {
                'spearman': spearman,
                'loss': test_loss,
                'ndcg': ndcg
            }
            return results

    epochs = 500
    n_warmup = 1000
    total_steps = 0
    best_valid_metric = -np.inf
    best_valid_loss = np.inf
    patience = 10
    scheduler = LambdaLR(optimizer, warmup(n_warmup))
    waiting = 0
    os.makedirs(args.out_fpath, exist_ok=True)
    for e in range(epochs):
        ts, train_loss = epoch(model, current_step=total_steps)
        total_steps += ts
        nsteps = total_steps
        results = test_epoch(model, dl_valid)
        vloss = results['loss']
        vmetric = results['spearman']
        waiting += 1
        # Checkpoint and reset patience on a new best validation loss.
        if vloss < best_valid_loss:
            best_valid_loss = vloss
            waiting = 0
            torch.save({
                'step': nsteps,
                'epoch': e + 1,
                'model_state_dict': model.state_dict(),
                'val_spearman': vmetric,
                'val_ndcg': results['ndcg'],
                'val_loss': vloss,
                'train_loss': train_loss,
            }, args.out_fpath + checkpoint_stem + '_best.pt')
        # Patience also resets on a new best Spearman, or while the model is
        # still underfitting (validation loss below training loss).
        if vmetric > best_valid_metric:
            best_valid_metric = vmetric
            waiting = 0
        if vloss < train_loss:
            waiting = 0
        print("waiting: %d" % waiting)
        if waiting == patience:
            break
    # TODO: checkpoint race condition — concurrent jobs sharing out_fpath and
    # checkpoint_stem could read a partially written '_best.pt'.
    if args.out_fpath is not None:
        sd = torch.load(args.out_fpath + checkpoint_stem + '_best.pt', weights_only=False)
        model.load_state_dict(sd['model_state_dict'])
        results = test_epoch(model, dl_test)
        results['batch_size'] = batch_size
        results['lr'] = lr
        results['epoch'] = sd['epoch']
        results['step'] = sd['step']
        results['train_loss'] = sd['train_loss']
        results['val_spearman'] = sd['val_spearman']
        results['val_loss'] = sd['val_loss']
        results['val_ndcg'] = sd['val_ndcg']
        results['dataset'] = flip_dataset
        results['split'] = flip_split
        results['task'] = args.task
        results['seed'] = args.seed
        with open(args.out_fpath + checkpoint_stem + '.json', 'w') as f:
            json.dump(results, f)
def train(args: argparse.Namespace) -> None:
    """Score a landscape's split CSVs with CARP-640M masked-marginal scores.

    For every variant row, the mutated positions (parsed from `variant_info`,
    e.g. "A42G" or comma-separated multiples) are replaced with MASK tokens,
    and the score is sum over positions of log p(mutant aa) - log p(wt aa)
    under the model's masked predictions. Results are cached by position key
    (valid because all rows of a landscape share the same background outside
    the masked positions) and written back as a new CSV column.

    Args:
        args: argparse namespace with data_fpath, out_fpath, landscape, gpu.
    """
    # get the config, tokenizer, and model
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    carp, collator = load_model_and_alphabet('carp_640M')
    # Move only model to GPU
    model = carp.to(DEVICE)
    model = model.eval()
    a_to_t = collator.tokenizer.a_to_t
    # positions_key -> log-prob array cache, persisted across runs.
    seq_to_result = {}
    cache_file = os.path.join(output_dir, "carp640m_masked_scores.pt")
    if os.path.exists(cache_file):
        seq_to_result = torch.load(cache_file, weights_only=False)

    ## Grab data
    batch_size = 1
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)

    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        likelihoods = np.empty(len(df))
        for row in tqdm(df.itertuples(), total=len(df)):
            sequence = row.sequence
            variant_info = row.variant_info
            mut_sequence = sequence[:]
            input_sequence = sequence[:]
            if isinstance(variant_info, str):
                if "," in variant_info:
                    variant_info = variant_info.split(",")
                else:
                    variant_info = [variant_info]
                # Mutation strings look like "A42G": wt aa, 1-based position, mutant aa.
                positions = tuple(int(v[1:-1]) - 1 for v in variant_info)
                old = tuple(v[0] for v in variant_info)
                new = tuple(v[-1] for v in variant_info)
            else:
                # NaN variant_info marks the wild type: score 0 by definition.
                likelihoods[row.Index] = 0.0
                continue
            for i, pos in enumerate(positions):
                mut_sequence = mut_sequence[:pos] + new[i] + mut_sequence[pos + 1:]
                input_sequence = input_sequence[:pos] + MASK + input_sequence[pos + 1:]
            # Sanity check: variant_info must be consistent with the sequence column.
            assert mut_sequence == sequence
            positions_key = ','.join(str(pos) for pos in positions)
            if positions_key not in seq_to_result:
                src = collator([[input_sequence]])[0].to(DEVICE)
                with torch.no_grad():
                    logits = model(src, repr_layers=[], logits=True)['logits']  # 1, ell, 30
                logits = logits[0, positions]  # n_positions, 30
                log_probs = F.log_softmax(logits, dim=-1).cpu().detach().numpy()
                seq_to_result[positions_key] = log_probs
            log_probs = seq_to_result[positions_key]
            score = 0
            for i, lp in enumerate(log_probs):
                wt_lp = lp[a_to_t[old[i]]]
                mut_lp = lp[a_to_t[new[i]]]
                score += mut_lp - wt_lp
            likelihoods[row.Index] = score

        torch.save(seq_to_result, cache_file)
        # BUG FIX: `likelihoods` is a single 1-D array (one score per row).
        # The old PDZ3 branch wrote the *scalars* likelihoods[0]/likelihoods[1]
        # into carp_640m_zs_1/zs_2, broadcasting two wrong constants over the
        # whole columns and clobbering the columns written by carp_zeroshot.py.
        # Write the per-row masked scores under their own column for every
        # landscape instead.
        # NOTE(review): PDZ3 sequences are ':'-joined pairs in the other
        # zero-shot scripts; the masking logic here does not split them, so
        # treat PDZ3 masked scores with caution — confirm before use.
        df['carp_640m_masked_zs'] = likelihoods
        df.to_csv(os.path.join(output_dir, split_csv), index=False)

        if args.landscape != "PDZ3":
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_masked_zs']).statistic)


def main():
    """CLI entry point: masked zero-shot scoring of one landscape."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('landscape', type=str)
    parser.add_argument('--gpu', type=int, default=1)

    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()
def train(args: argparse.Namespace) -> None:
    """Score a landscape's split CSVs with CARP-640M full-sequence likelihoods.

    Each sequence's score is the negated mean cross-entropy of the model's own
    token predictions. PDZ3 rows hold two ':'-joined sequences, which are
    scored separately into carp_640m_zs_1/zs_2; other landscapes get a single
    carp_640m_zs column. Scores are cached to JSON so reruns only pay for
    unseen sequences.

    Args:
        args: argparse namespace with data_fpath, out_fpath, landscape, gpu.
    """
    # get the config, tokenizer, and model
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    carp, collator = load_model_and_alphabet('carp_640M')
    # Move only model to GPU
    model = carp.to(DEVICE)
    model = model.eval()
    seq_to_result = {}
    cache_file = os.path.join(output_dir, "carp640m_scores.json")
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            seq_to_result = json.load(f)

    ## Grab data
    batch_size = 1
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)

    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        likelihoods = [np.empty(len(df)), np.empty(len(df))]
        for row in tqdm(df.itertuples(), total=len(df)):
            sequence = row.sequence
            if ":" in sequence:
                sequences = sequence.split(":")
            else:
                sequences = [sequence]
            for j, sequence in enumerate(sequences):
                # BUG FIX: the old code `continue`d after caching a 0 for an
                # empty sequence, skipping the likelihoods assignment and
                # leaving uninitialized np.empty garbage for that row on its
                # first occurrence. Compute-if-missing, then always assign.
                if sequence not in seq_to_result:
                    if len(sequence) == 0:
                        seq_to_result[sequence] = 0
                    else:
                        src = collator([[sequence]])[0].to(DEVICE)
                        with torch.no_grad():
                            logits = model(src, repr_layers=[], logits=True)['logits']
                            out = F.cross_entropy(logits[0], src.flatten())
                        seq_to_result[sequence] = -out.detach().cpu().item()
                likelihoods[j][row.Index] = seq_to_result[sequence]
        # Persist the cache after every split so partial runs are reusable.
        with open(cache_file, "w") as f:
            json.dump(seq_to_result, f)
        if args.landscape == "PDZ3":
            df['carp_640m_zs_1'] = likelihoods[0]
            df['carp_640m_zs_2'] = likelihoods[1]
        else:
            df['carp_640m_zs'] = likelihoods[0]
        df.to_csv(os.path.join(output_dir, split_csv), index=False)

        if args.landscape != "PDZ3":
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_zs']).statistic)
        else:
            print(split_csv,
                  spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['carp_640m_zs_1'] + df[df['set'] == 'test']['carp_640m_zs_2']).statistic)


def main():
    """CLI entry point: full-sequence zero-shot scoring of one landscape."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('landscape', type=str)
    parser.add_argument('--gpu', type=int, default=1)

    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()
def train(args: argparse.Namespace) -> None:
    """Score a landscape's split CSVs with two Dayhoff causal language models.

    Each sequence is scored in both orientations (forward START..STOP and
    reversed STOP..START, built by SimpleCollator); the cached value is the
    model's mean cross-entropy, so negated values are mean per-token
    log-likelihoods. PDZ3 rows hold two ':'-joined sequences scored into
    *_1/*_2 columns; other landscapes also get a *_min column (elementwise
    minimum of the two orientation scores).

    Args:
        args: argparse namespace with data_fpath, out_fpath, landscape, gpu,
            no_fa2 (unused here).
    """
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    os.makedirs(output_dir, exist_ok=True)
    # (HF model id, per-model JSON cache name, output column stem)
    model_names = [
        ['microsoft/Dayhoff-3b-UR90', "3b-UR90-seq_to_result.json", "dayhoff_3bur90"],
        ['microsoft/Dayhoff-3b-GR-HM-c', "3b-gr-hm-c-seq_to_result.json", "dayhoff"]
    ]
    models = []
    for m in model_names:
        model = AutoModelForCausalLM.from_pretrained(m[0])
        tokenizer = AutoTokenizer.from_pretrained(m[0], trust_remote_code=True)
        collator = SimpleCollator(tokenizer)
        # bf16 on GPU, eval mode: inference only.
        model = model.to(DEVICE)
        model = model.to(torch.bfloat16)
        model = model.eval()
        seq_to_result = {}
        cache_file = os.path.join(output_dir, m[1])
        if os.path.exists(cache_file):
            with open(cache_file) as f:
                seq_to_result = json.load(f)
        models.append([model, tokenizer, collator, cache_file, seq_to_result, m[2]])

    # Get files
    ## Grab data
    batch_size = 1
    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = os.listdir(landscape_path)
    split_csvs = [csv for csv in split_csvs if ".csv" in csv]
    print(split_csvs)

    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        # Both models add columns to the same df; the CSV is rewritten after
        # each model, so the final file carries all columns.
        for entry in models:  # renamed: the old loop variable shadowed `model`
            model, tokenizer, collator, cache_file, seq_to_result, model_stem = entry
            if args.landscape == "PDZ3":
                fwd_lls = [np.empty(len(df)), np.empty(len(df))]
                bwd_lls = [np.empty(len(df)), np.empty(len(df))]
            else:
                fwd_lls = [np.empty(len(df))]
                bwd_lls = [np.empty(len(df))]
            for row in tqdm(df.itertuples(), total=len(df)):
                sequence = row.sequence
                if ":" in sequence:
                    sequences = sequence.split(":")
                else:
                    sequences = [sequence]

                for j, sequence in enumerate(sequences):
                    if sequence not in seq_to_result:
                        tokenized = collator(sequence)[0]
                        tokenized = tokenized.to(DEVICE)
                        # Row 0 is the forward orientation, row 1 the reversed one.
                        with torch.no_grad():
                            out = model(input_ids=tokenized[:1], labels=tokenized[:1])
                        seq_to_result[sequence] = {"fwd": out.loss.detach().cpu().item()}
                        with torch.no_grad():
                            out = model(input_ids=tokenized[1:], labels=tokenized[1:])
                        seq_to_result[sequence]["bwd"] = out.loss.detach().cpu().item()

                    fwd_lls[j][row.Index] = seq_to_result[sequence]["fwd"]
                    bwd_lls[j][row.Index] = seq_to_result[sequence]["bwd"]
            # Persist the cache after every split so partial runs are reusable.
            with open(cache_file, "w") as f:
                json.dump(seq_to_result, f)
            if args.landscape == "PDZ3":
                df[model_stem + '_fwd_1'] = -fwd_lls[0]
                df[model_stem + '_bwd_1'] = -bwd_lls[0]
                df[model_stem + '_fwd_2'] = -fwd_lls[1]
                df[model_stem + '_bwd_2'] = -bwd_lls[1]
            else:
                df[model_stem + '_fwd'] = -fwd_lls[0]
                df[model_stem + '_bwd'] = -bwd_lls[0]
                # BUG FIX: the old code passed the *lists* fwd_lls/bwd_lls to
                # np.maximum, producing a (1, n) 2-D array where a 1-D column
                # is required. Index the single arrays explicitly;
                # -max(loss) == elementwise min of the two negated-loss scores.
                df[model_stem + '_min'] = -np.maximum(fwd_lls[0], bwd_lls[0])
            df.to_csv(os.path.join(output_dir, split_csv), index=False)

        if args.landscape != "PDZ3":
            if "esm2_650M_scores" in df.columns:
                print(split_csv,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_3bur90_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['esm2_650M_scores']).statistic)
            else:
                print(split_csv,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_3bur90_min']).statistic,
                      spearmanr(df[df['set'] == 'test']['target'], df[df['set'] == 'test']['dayhoff_min']).statistic)
def train(args: argparse.Namespace) -> None:
    """Score every split CSV of one landscape with ESM2-650M pseudo-likelihoods.

    Each sequence's score is the negated mean cross-entropy of the model's own
    token predictions (BOS/EOS positions excluded). Scores are cached to JSON
    so reruns only pay for unseen sequences.
    """
    # Pin the requested GPU and resolve output locations up front.
    torch.cuda.set_device(args.gpu)
    DEVICE = torch.device('cuda:%d' % args.gpu)
    output_dir = os.path.join(args.out_fpath, args.landscape, "splits")
    os.makedirs(output_dir, exist_ok=True)

    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model = model.to(DEVICE).eval()

    # sequence -> score cache, persisted across runs.
    cache_file = os.path.join(output_dir, "esm2_650m.json")
    seq_to_result = {}
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            seq_to_result = json.load(f)

    landscape_path = os.path.join(args.data_fpath, args.landscape, "splits")
    split_csvs = [name for name in os.listdir(landscape_path) if ".csv" in name]
    print(split_csvs)

    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(landscape_path, split_csv))
        scores = np.empty(len(df))
        for row in tqdm(df.itertuples(), total=len(df)):
            seq = row.sequence
            if seq not in seq_to_result:
                _, _, tokens = batch_converter([("seq", seq)])
                with torch.no_grad():
                    # Drop BOS/EOS positions before the per-token loss.
                    logits = model(tokens.to(DEVICE))["logits"][0, 1:-1]
                    loss = F.cross_entropy(logits, tokens[0, 1:-1].to(DEVICE))
                seq_to_result[seq] = -loss.detach().cpu().item()
            scores[row.Index] = seq_to_result[seq]
        # Persist the cache after every split so partial runs are reusable.
        with open(cache_file, "w") as f:
            json.dump(seq_to_result, f)
        df['esm2_650M_scores'] = scores
        df.to_csv(os.path.join(output_dir, split_csv), index=False)


def main():
    """CLI entry point: ESM2-650M zero-shot scoring of one landscape."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('landscape', type=str)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument("--no_fa2", action="store_true")

    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()
    def forward(self, e, input_mask=None):
        """Pool per-residue embeddings with learned attention, then regress a scalar."""
        attended = self.attention(e, input_mask=input_mask)
        hidden = self.hidden(self.activation(attended))
        return self.linear(self.dropout(self.activation(hidden)))


def main():
    """Parse CLI arguments and launch ESM-C fine-tuning on one FLIP task."""
    parser = argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)
    parser.add_argument('task', type=str)
    # 'pretrained' keeps the released weights; any other value re-initializes the model.
    parser.add_argument('weights', type=str, default='pretrained')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)

    args = parser.parse_args()
    train(args)


def train(args):
    """Fine-tune ESM-C 300m plus an attention regression head on a FLIP landscape/split."""
    _ = torch.manual_seed(args.seed)
    torch.cuda.set_device(0)
    device = torch.device('cuda:0')
    lr = args.lr

    np.random.seed(args.seed)
    esm = ESMC.from_pretrained("esmc_300m")  # or "cpu"
    if args.weights != 'pretrained':
        # Naive baseline: keep the architecture but discard pretrained weights by
        # re-initializing every module that supports it.
        for p in esm.modules():
            try:
                p.reset_parameters()
            except AttributeError:
                # Modules without reset_parameters (e.g. activations) are left as-is.
                continue
    esm = esm.to(device)
    tokenizer = EsmSequenceTokenizer()

    def collator(batch):
        # Tokenize, right-pad to the batch max length (pad id 1), and build a
        # boolean mask that is True on real tokens and False on padding.
        data = tuple(zip(*batch))
        seqs, labels = data
        t = [torch.tensor(tokenizer.encode(s)) for s in seqs]
        max_len = max([len(tt) for tt in t])
        t = [F.pad(tt, (0, max_len - len(tt)), value=1) for tt in t]
        t = torch.stack(t)
        y = torch.tensor(labels).unsqueeze(-1).float()
        input_mask = t != 1
        return t, y, input_mask

    d_model = 960  # embedding width expected from esmc_300m — matches decoder input
    decoder = Model(d_model, dropout=0).to(device)
    model = nn.ModuleDict({'embed': esm.embed, 'transformer': esm.transformer, 'decoder': decoder})
    optimizer = Adam(model.parameters(), lr=lr)

    ## Grab data
    batch_size = 16
    loss_func = nn.MSELoss()
    if "AMY_BACSU" in args.task:
        # This dataset name itself contains an underscore: task is AMY_BACSU_<split>.
        flip_dataset = '_'.join(args.task.split('_')[:2])
        flip_split = '_'.join(args.task.split('_')[2:])
    else:
        flip_dataset = args.task.split('_')[0]
        flip_split = '_'.join(args.task.split('_')[1:])
    ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, flip_dataset, flip_split, max_len=2048, scale=True)
    num_workers = 4
    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collator,
                          num_workers=num_workers, shuffle=True)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collator, num_workers=num_workers)
    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collator, num_workers=num_workers)
    print('%d Train samples %d valid samples %d test samples' %(len(ds_train), len(ds_valid), len(ds_test)))
    checkpoint_stem = 'esmc_%s_%s_%d' %(args.task, args.weights, args.seed)

    def step(model, batch, train=True, return_values=False):
        """Run one forward (and, when train=True, backward) pass on a batch.

        Returns (loss, batch_size) or, with return_values, additionally the
        detached predictions, inputs, targets, and a dummy mask tensor.
        """
        src, tgt, input_mask = batch
        src = src.to(device)
        tgt = tgt.to(device)
        input_mask = input_mask.to(device)
        e = model['embed'](src)
        e = model['transformer'](e)[0].float()
        outputs = model['decoder'](e, input_mask=input_mask)

        loss = loss_func(outputs, tgt)
        locations = len(tgt)
        mask = torch.ones(1)  # dummy
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # `scheduler` is assigned later in train(); ok because step() only
            # runs inside the epoch loop, after the assignment.
            scheduler.step()
        if return_values:
            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu(), mask.detach().cpu()
        else:
            return loss.item(), locations

    def epoch(model, current_step=0):
        """Train one epoch over dl_train; returns (last_batch_index, running_loss)."""
        model = model.train()
        loader = dl_train
        t = 'Training:'
        losses = []
        ns = []
        n_seen = 0
        # NOTE(review): `train` here resolves to the enclosing train() function
        # object, which is always truthy, so only the training branches below
        # ever execute. Looks like leftover from a shared train/eval epoch
        # helper — confirm intent before relying on the eval paths.
        if train:
            n_total = len(ds_train)
        else:
            n_total = len(ds_valid)
        for i, batch in enumerate(loader):
            new_loss, new_n = step(model, batch, True)
            losses.append(new_loss * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            if total_n == 0:
                rloss = 0
            else:
                # Running sample-weighted mean loss for the progress line.
                rloss = sum(losses) / total_n
            if train:
                nsteps = current_step + i + 1
            else:
                nsteps = i
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),
                  end='')
        if not train:
            return rloss
        return i, rloss

    def test_epoch(model, dl):
        """Evaluate on dl without gradients; returns {'spearman', 'loss', 'ndcg'}."""
        model = model.eval()
        with torch.no_grad():
            losses = []
            ns = []
            n_seen = 0
            pred = []
            tgt = []
            masks = []
            for i, batch in enumerate(dl):
                new_loss, new_n, p, s, t, m = step(model, batch, False, return_values=True)
                losses.append(new_loss * new_n)
                pred.append(p)
                tgt.append(t)
                masks.append(m)
                ns.append(new_n)
                n_seen += len(batch[0])
            total_n = sum(ns)

            # Sample-weighted mean loss over the whole loader.
            test_loss = sum(losses) / total_n
            pred = torch.cat(pred)
            tgt = torch.cat(tgt)
            pred = pred.numpy()
            tgt = tgt.numpy()
            spearman = spearmanr(pred, tgt).correlation
            # ndcg_score requires non-negative relevances; shift targets up if needed.
            if (tgt < 0).any():
                pos_tgt = tgt - tgt.min()
            else:
                pos_tgt = tgt
            ndcg = ndcg_score(pos_tgt.T, pred.T)
            print('\tloss: %f' %test_loss, end='\t')
            print('spearman: %f' %(spearman), end='\t')
            print('ndcg: %f' %(ndcg), end='\t')
            results = {
                'spearman': spearman,
                'loss': test_loss,
                'ndcg': ndcg
            }
            return results

    epochs = 500
    n_warmup = 1000
    total_steps = 0
    best_valid_metric = -np.inf
    best_valid_loss = np.inf
    patience = 10
    scheduler = LambdaLR(optimizer, warmup(n_warmup))
    waiting = 0
    os.makedirs(args.out_fpath, exist_ok=True)
    for e in range(epochs):
        ts, train_loss = epoch(model, current_step=total_steps)
        total_steps += ts
        nsteps = total_steps
        results = test_epoch(model, dl_valid)
        vloss = results['loss']
        vmetric = results['spearman']
        waiting += 1
        # Early stopping: patience resets when validation loss improves (best
        # checkpoint saved), when validation spearman improves, or while the
        # model still underfits (validation loss below training loss).
        if vloss < best_valid_loss:
            best_valid_loss = vloss
            waiting = 0
            torch.save({
                'step': nsteps,
                'epoch': e + 1,
                'model_state_dict': model.state_dict(),
                'val_spearman': vmetric,
                'val_ndcg': results['ndcg'],
                'val_loss': vloss,
                'train_loss': train_loss,
            }, args.out_fpath + checkpoint_stem + '_best.pt')
        if vmetric > best_valid_metric:
            best_valid_metric = vmetric
            waiting = 0
        if vloss < train_loss:
            waiting = 0
        print("waiting: %d" % waiting)
        if waiting == patience:
            break
    # TODO: checkpoint race condition
    if args.out_fpath is not None:
        sd = torch.load(args.out_fpath + checkpoint_stem + '_best.pt', weights_only=False)
        model.load_state_dict(sd['model_state_dict'])
        results
        = test_epoch(model, dl_test)  # (continuation of the `results` assignment split in the raw diff)
        results['batch_size'] = batch_size
        results['lr'] = lr
        results['epoch'] = sd['epoch']
        results['step'] = sd['step']
        results['train_loss'] = sd['train_loss']
        results['val_spearman'] = sd['val_spearman']
        results['val_loss'] = sd['val_loss']
        results['val_ndcg'] = sd['val_ndcg']
        results['dataset'] = flip_dataset
        results['split'] = flip_split
        results['task'] = args.task
        results['seed'] = args.seed
        with open(args.out_fpath + checkpoint_stem + '.json', 'w') as f:
            json.dump(results, f)


if __name__ == '__main__':
    main()

# --- baselines/get_predictions.py ---
import argparse
import os
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np

from esm.models.esmc import ESMC
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

from sequence_models.collaters import Seq2PropertyCollater
from sequence_models.constants import PAD
from sequence_models.structure import Attention1d
from sequence_models.utils import warmup
from sequence_models.flip_utils import load_flip_data
from sequence_models.pretrained import load_model_and_alphabet


class Model(nn.Module):
    """Attention-pooling regression head: per-residue embeddings -> one scalar."""

    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.attention = Attention1d(d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, 1)

    def forward(self, e, input_mask=None):
        attended = self.attention(e, input_mask=input_mask)
        hidden = self.hidden(self.activation(attended))
        return self.linear(self.dropout(self.activation(hidden)))


def main():
    parser = 
argparse.ArgumentParser()
    parser.add_argument('data_fpath', type=str)
    parser.add_argument('out_fpath', type=str)

    args = parser.parse_args()
    train(args)


def train(args):
    """Score every landscape/split test set with all trained CARP and ESM-C checkpoints."""
    _ = torch.manual_seed(0)
    torch.cuda.set_device(0)
    device = torch.device('cuda:0')

    np.random.seed(0)
    carp, carp_collater = load_model_and_alphabet('carp_640M')
    embedder = carp.model.embedder.to(device)
    # CARP embedding width taken from its own up-embedder conv output.
    d_model = carp.model.embedder.up_embedder.conv.out_channels
    decoder = Model(d_model, dropout=0)
    decoder = decoder.to(device)
    carp_model = nn.ModuleDict({'embedder': embedder, 'decoder': decoder})
    carp_model = carp_model.eval()
    carp_alphabet = carp_collater.tokenizer.alphabet
    carp_collate_fn = Seq2PropertyCollater(carp_alphabet, return_mask=True)
    esm = ESMC.from_pretrained("esmc_300m")  # or "cpu"
    esm = esm.to(device)
    esm_tokenizer = EsmSequenceTokenizer()

    def esm_collator(batch):
        # Tokenize and right-pad to batch max (pad id 1); mask marks real tokens.
        data = tuple(zip(*batch))
        seqs, labels = data
        t = [torch.tensor(esm_tokenizer.encode(s)) for s in seqs]
        max_len = max([len(tt) for tt in t])
        t = [F.pad(tt, (0, max_len - len(tt)), value=1) for tt in t]
        t = torch.stack(t)
        y = torch.tensor(labels).unsqueeze(-1).float()
        input_mask = t != 1
        return t, y, input_mask

    d_model = 960  # ESM-C 300m embedding width — must match the trained decoder
    decoder = Model(d_model, dropout=0).to(device)
    esm_model = nn.ModuleDict({'embed': esm.embed, 'transformer': esm.transformer, 'decoder': decoder})
    esm_model = esm_model.eval()

    # Maps each landscape's split csv filename to the split name used in the
    # checkpoint stems at training time.
    split_dict = {
        "AMY_BACSU": {
            "random.csv": "random_split.csv",
            "by_position.csv": "hard_split_.csv",
            "close_to_far.csv": "med_split_is_close_to_as_0.csv",
            "far_to_close.csv": "med_split_is_close_to_as_1.csv",
            "one_to_many.csv": "one_to_many.csv",
        },
        "hydro": {
            "low_to_high.csv": "hard_split.csv",
            "random.csv": "random_split.csv",
            "three_to_many.csv": "easy_split.csv",
            "to_06241.csv": "med_P06241test_split.csv",
            "to_P01053.csv": "med_P01053test_split.csv",
            "to_P0A9X9.csv": "med_P0A9X9test_split.csv"
        },
        "ired": {
"mutation_order.csv": "ired_mutation_order_split.csv", + "random.csv": "ired_random_split.csv", + }, + "NucB": { + "random.csv": "easy.csv", + "two_to_many.csv": "two_to_many.csv", + }, + "PDZ3": { + "random.csv": "rand_split.csv", + "single_to_double.csv": "single_to_double.csv", + }, + "RhoMax": { + "by_wt.csv": "by_wt.csv", + }, + "trpb": { + "by_position.csv": "trpB_no_position_overlap_split.csv", + "one_to_many.csv": "trpB_one_vs_many_split.csv", + "two_to_many.csv": "trpB_two_vs_many_split.csv" + } + } + ## Grab data + checkpoint_paths = ["/mnt/amlt/flip_results_5-9-2025", "/mnt/amlt/flip_results/"] + os.makedirs(args.out_fpath, exist_ok=True) + landscapes = os.listdir(args.data_fpath) + for landscape in landscapes: + split_csvs = os.listdir(os.path.join(args.data_fpath, landscape, "splits")) + split_csvs = [csv for csv in split_csvs if ".csv" in csv] + num_workers = 4 + for split_csv in split_csvs[::-1]: + out_file = os.path.join(args.out_fpath, "%s_%s_predictions.csv" % (landscape, split_csv[:-4])) + if os.path.isfile(out_file): + continue + print(landscape, split_csv, datetime.now()) + ds_train, ds_valid, ds_test = load_flip_data(args.data_fpath, landscape, split_csv[:-4], max_len=2048, + scale=True) + results_df = pd.DataFrame() + results_df['sequence'] = [d[0] for d in ds_test] + results_df['scaled_target'] = [d[1] for d in ds_test] + + carp_dl_test = DataLoader(ds_test, batch_size=32, collate_fn=carp_collate_fn, num_workers=num_workers) + esm_dl_test = DataLoader(ds_test, batch_size=32, collate_fn=esm_collator, num_workers=num_workers) + task = landscape + "_" + split_dict[landscape][split_csv][:-4] + for seed in range(5): + for weights in ['pretrained', 'naive']: + checkpoint_stem = 'carp_%s_%s_%d' % (task, weights, seed) + try: + sd = torch.load(os.path.join(checkpoint_paths[0], checkpoint_stem + "_best.pt"), + weights_only=False) + except FileNotFoundError: + sd = torch.load(os.path.join(checkpoint_paths[1], checkpoint_stem + "_best.pt"), + 
                                        weights_only=False)
                    carp_model.load_state_dict(sd['model_state_dict'])
                    predictions = []
                    for i, batch in enumerate(carp_dl_test):
                        src, tgt, input_mask = batch
                        src = src.to(device)
                        with torch.no_grad():
                            # CARP wants a float mask derived from its own PAD token,
                            # so the collater's mask is rebuilt here.
                            input_mask = (src != carp_alphabet.index(PAD)).float().unsqueeze(-1)
                            e = carp_model['embedder'](src, input_mask=input_mask)
                            predictions.append(carp_model['decoder'](e, input_mask=input_mask).detach().cpu().numpy())
                    predictions = np.concatenate(predictions)
                    results_df["carp_%s_%d" % (weights, seed)] = predictions

                    checkpoint_stem = 'esmc_%s_%s_%d' % (task, weights, seed)
                    try:
                        sd = torch.load(os.path.join(checkpoint_paths[0], checkpoint_stem + "_best.pt"),
                                        weights_only=False)
                    except FileNotFoundError:
                        sd = torch.load(os.path.join(checkpoint_paths[1], checkpoint_stem + "_best.pt"),
                                        weights_only=False)
                    esm_model.load_state_dict(sd['model_state_dict'])
                    predictions = []
                    for i, batch in enumerate(esm_dl_test):
                        src, tgt, input_mask = batch
                        src = src.to(device)
                        input_mask = input_mask.to(device)
                        with torch.no_grad():
                            e = esm_model['embed'](src)
                            e = esm_model['transformer'](e)[0].float()
                            predictions.append(esm_model['decoder'](e, input_mask=input_mask).detach().cpu().numpy())
                    predictions = np.concatenate(predictions)
                    results_df["esmc_%s_%d" % (weights, seed)] = predictions
            results_df.to_csv(out_file,
                              index=False)


if __name__ == '__main__':
    main()

# --- baselines/linear_models.py ---
import argparse
import json
import os
from tqdm import tqdm

from sklearn.metrics import mean_squared_error
from sklearn.linear_model import Ridge, RidgeClassifier
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr
from sklearn.metrics import ndcg_score, roc_auc_score

import torch
import numpy as np
import 
torch.nn.functional as F
import pandas as pd

torch.manual_seed(0)

from sequence_models.utils import Tokenizer
from sequence_models.flip_utils import load_flip_data

AAINDEX_ALPHABET = 'ARNDCQEGHILKMFPSTWYV'


parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--solver', type=str, default='auto')
parser.add_argument('--max_iter', type=int, default=1000000)
parser.add_argument('--tol', type=float, default=1e-5)
args = parser.parse_args()

results_path = "/home/kevyan/results/flipv3/"

# Zero-shot score column in the csvs -> identifier of the model that produced it.
# Commented entries are variants that were considered but are currently unused.
model_dict = {
    "esm2_650M_scores": "esm2_650M_zs",
    # "dayhoff_fwd": "dayhoff_3b_gr_hm_c_fwd_zs",
    # "dayhoff_bwd": "dayhoff_3b_gr_hm_c_bwd_zs",
    # "dayhoff_min": "dayhoff_3b_gr_hm_c_min_zs",
    # "dayhoff_max": "dayhoff_3b_gr_hm_c_max_zs",
    "dayhoff_mean": "dayhoff_3b_gr_hm_c_mean_zs",
    # "dayhoff_3bur90_fwd": "dayhoff_3b_ur90_fwd_zs",
    # "dayhoff_3bur90_bwd": "dayhoff_3b_ur90_bwd_zs",
    # "dayhoff_3bur90_min": "dayhoff_3b_ur90_zs",
    # "dayhoff_3bur90_max": "dayhoff_3b_ur90_max_zs",
    "dayhoff_3bur90_mean": "dayhoff_3b_ur90_mean_zs",
    "dayhoff_both_mean": "dayhoff_both_mean"
}
tokenizer = Tokenizer(AAINDEX_ALPHABET)  # tokenize

# Randomize at different data sizes
input_path = "/home/kevyan/results/flipv3/zs/"
landscapes = os.listdir(input_path)
np.random.seed(23)
replicates = 50
results = pd.DataFrame(columns=['dataset', 'model', 'fraction_train', 'n_train', 'Spearman', 'replicate'])
n_min = 50
with tqdm(total=len(landscapes) * replicates) as pbar:
    for landscape in landscapes:
        df = pd.read_csv(os.path.join(input_path, landscape))
        landscape_name = landscape[:-7]
        n = len(df)
        X = df['sequence'].values
        # ":" separates chains in multi-chain records; strip it before tokenizing.
        X = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in X]
        maxlen = max([len(i) for i in X])
        X = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.)
             for i in X]
        X_enc = []  # ohe
        for i in X:
            i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET))
            i_onehot.zero_()
            i_onehot.scatter_(1, i, 1)
            X_enc.append(i_onehot)
        X_enc = np.array([np.array(i.view(-1)) for i in X_enc])  # flatten
        cols = ['esm2_650M_scores', 'carp_640m_zs', 'dayhoff']
        X_zs = df[cols].values
        # One-hot features plus zero-shot likelihood features side by side.
        new_X = np.hstack([X_enc, X_zs])
        # Ten log-spaced training-set sizes between n_min and 80% of the data;
        # replicate `rep` uses size n_trains[rep % 10].
        log10_min = np.log10(n_min)
        log10_max = np.log10(n * 0.8)
        n_trains = np.logspace(log10_min, log10_max, num=10)
        for rep in range(replicates):
            n_train = int(n_trains[rep % 10])
            fraction = n_train / n
            n_test = n - n_train
            idx = np.arange(n)
            np.random.shuffle(idx)
            y_scale = df.iloc[idx[:n_train]]['target'].values[:, None]
            y_test = df.iloc[idx[n_train:]]['target'].values[:, None]
            X_train_enc = X_enc[idx[:n_train]]
            X_test_enc = X_enc[idx[n_train:]]
            scaler = StandardScaler()
            scaler.fit(y_scale)
            y_train = scaler.transform(y_scale)
            y_test = scaler.transform(y_test)
            # NOTE(review): hard-codes solver/tol/max_iter instead of the parsed
            # CLI args (see commented line) — confirm which is intended.
            lr = Ridge(solver='auto', tol=1e-5, max_iter=1000000, alpha=10)
            # lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10)
            lr.fit(X_train_enc, y_train)
            preds = lr.predict(X_test_enc)
            preds = preds.reshape(-1, 1)
            rho = spearmanr(y_test, preds).correlation
            results.loc[len(results)] = [landscape_name, 'Ridge (one-hot)', fraction, n_train, rho, rep]
            new_X_train = new_X[idx[:n_train]]
            new_X_test = new_X[idx[n_train:]]
            lr = Ridge(solver='auto', tol=1e-5, max_iter=1000000, alpha=10)
            lr.fit(new_X_train, y_train)
            preds = lr.predict(new_X_test)
            rho = spearmanr(y_test, preds).correlation
            results.loc[len(results)] = [landscape_name, 'Ridge (one-hot + likelihoods)', fraction, n_train, rho, rep]
            pbar.update(1)
results.to_csv(results_path + "random_ridge.csv", index=False)

input_path = "/home/kevyan/data/flip_data_pruned/"
landscapes = os.listdir(input_path)

for landscape in landscapes:
    split_csvs = os.listdir(os.path.join(input_path, landscape, "splits"))
    split_csvs = [c
                  for c in split_csvs if "csv" in c]
    for split_csv in split_csvs:
        df = pd.read_csv(os.path.join(input_path, landscape, "splits", split_csv))
        # tokenize train data
        X_train = df[(df["set"] == "train") & (~df['validation'])]['sequence'].values
        # Scaler is fit on ALL training targets (incl. validation rows) below.
        y_scale = df[df['set'] == "train"]['target'].values[:, None]
        y_train = df[(df["set"] == "train") & (~df['validation'])]['target'].values[:, None]
        X_train = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in X_train]
        seq_test = df[df["set"] == "test"]['sequence'].values
        y_test = df[df["set"] == "test"]['target'].values[:, None]
        X_test = [torch.tensor(tokenizer.tokenize(i.replace(":", ""))).view(-1, 1) for i in seq_test]
        # padding
        maxlen_train = max([len(i) for i in X_train])
        maxlen_test = max([len(i) for i in X_test])
        maxlen = max([maxlen_train, maxlen_test])

        X_train = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.) for i in X_train]
        X_train_enc = []  # ohe
        for i in X_train:
            i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET))
            i_onehot.zero_()
            i_onehot.scatter_(1, i, 1)
            X_train_enc.append(i_onehot)
        X_train_enc = np.array([np.array(i.view(-1)) for i in X_train_enc])  # flatten

        X_test = [F.pad(i, (0, 0, 0, maxlen - i.shape[0]), "constant", 0.)
for i in X_test] + X_test_enc = [] # ohe + for i in X_test: + i_onehot = torch.FloatTensor(maxlen, len(AAINDEX_ALPHABET)) + i_onehot.zero_() + i_onehot.scatter_(1, i, 1) + X_test_enc.append(i_onehot) + X_test_enc = np.array([np.array(i.view(-1)) for i in X_test_enc]) # flatten + scaler = StandardScaler() + scaler.fit(y_scale) + y_train = scaler.transform(y_train) + y_test = scaler.transform(y_test) + lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10) + lr.fit(X_train_enc, y_train) + + print(landscape, split_csv[:-4], 'one-hot') + preds = lr.predict(X_test_enc) + preds = preds.reshape(-1, 1) + + mse = mean_squared_error(y_test, preds) + print('TEST MSE: ', mse) + print('TEST RHO: ', spearmanr(y_test, preds).correlation) + y_test_pos = y_test - y_test.min() + print('TEST NDCG: ', ndcg_score(y_test_pos.T, preds.T)) + + results = pd.DataFrame() + results['sequence'] = seq_test + results['target'] = y_test + results['prediction'] = preds + os.makedirs(os.path.join(results_path, "ridge", landscape), exist_ok=True) + results.to_csv(os.path.join(results_path, "ridge", landscape, split_csv), index=False) + + + if landscape != "PDZ3": + df['dayhoff'] = df['dayhoff_bwd'] + df['dayhoff_fwd'] + df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd'] + if 'carp_640m_masked_zs' in df.columns: + cols = ['dayhoff', "esm2_650M_scores", 'carp_640m_masked_zs'] + else: + cols = ['dayhoff', "esm2_650M_scores", 'carp_640m_zs'] + else: + df['dayhoff1'] = df['dayhoff_bwd_1'] + df['dayhoff_fwd_1'] + df['dayhoff_3bur90_fwd_1'] + df['dayhoff_3bur90_bwd_1'] + df['dayhoff2'] = df['dayhoff_bwd_2'] + df['dayhoff_fwd_2'] + df['dayhoff_3bur90_fwd_2'] + df['dayhoff_3bur90_bwd_2'] + cols = ["dayhoff1", "dayhoff2", "esm2_650M_scores1", "esm2_650M_scores1", 'carp_640m_zs_1', 'carp_640m_zs_2'] + zs_x_train = df[(df["set"] == "train") & (~df['validation'])][cols].values + new_x_train = np.hstack([X_train_enc, zs_x_train]) + zs_x_test = df[df["set"] == "test"][cols].values + 
        new_x_test = np.hstack([X_test_enc, zs_x_test])
        lr = Ridge(solver=args.solver, tol=args.tol, max_iter=args.max_iter, alpha=10)

        lr.fit(new_x_train, y_train)

        print(landscape, split_csv[:-4], 'one-hot + zs')
        preds = lr.predict(new_x_test)
        preds = preds.reshape(-1, 1)
        mse = mean_squared_error(y_test, preds)
        print('TEST MSE: ', mse)
        print('TEST RHO: ', spearmanr(y_test, preds).correlation)
        y_test_pos = y_test - y_test.min()  # ndcg_score needs non-negative relevances
        print('TEST NDCG: ', ndcg_score(y_test_pos.T, preds.T))
        results = pd.DataFrame()
        results['sequence'] = seq_test
        results['target'] = y_test
        results['prediction'] = preds
        os.makedirs(os.path.join(results_path, "ridge_zs", landscape), exist_ok=True)
        results.to_csv(os.path.join(results_path, "ridge_zs", landscape, split_csv), index=False)


# --- collect_splits/amy_bacsu.py ---
import os

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

pd.set_option('display.max_columns', 100)
pd.set_option('display.width', 1000)

# Count mutations per variant; non-string variant_info (presumably NaN for
# wildtype — confirm) counts as 0.
n_muts = []
df = pd.read_csv('/home/kevyan/data/flip_data_zs/AMY_BACSU/splits/easy_split.csv', index_col=0)
for i, row in df.iterrows():
    muts = row['variant_info']
    if isinstance(muts, str):
        n_muts.append(len(row['variant_info'].split(',')))
    else:
        n_muts.append(0)
df['n_mutations'] = n_muts
np.random.seed(0)
df['validation'] = False
# Single mutants (and 0-mutation rows) train; multi-mutants test; ~15% of rows
# get the validation flag.
for i, row in df.iterrows():
    if row['n_mutations'] > 1:
        df.loc[i, 'set'] = "test"
    else:
        df.loc[i, 'set'] = "train"
    if np.random.random() < 0.15:
        df.loc[i, "validation"] = True
df.to_csv('/home/kevyan/data/flip_data_zs/AMY_BACSU/splits/one_to_many.csv', index=False)


# Ad-hoc REPL-style sanity checks on the resulting split (values discarded when
# run as a script).
df[df['set'] == "test"].shape
df['validation'].sum()
df[(df['set'] == 'train') & (~df['validation'])].shape
df[df['set'] == "test"]['target'].max()

df[df['validation']]['target'].max()
df[(df['set'] ==
'train') & (~df['validation'])]['target'].max() +df['target'].max() +df[df['target'] > 0.19] +spearmanr(df['esm2_650M_scores'], df['dayhoff_3bur90_fwd']) \ No newline at end of file diff --git a/collect_splits/nucb_fix_val.py b/collect_splits/nucb_fix_val.py new file mode 100644 index 0000000..8aa47c9 --- /dev/null +++ b/collect_splits/nucb_fix_val.py @@ -0,0 +1,32 @@ +import pandas as pd +import numpy as np + + +pd.set_option('display.max_columns', 100) +pd.set_option('display.width', 1000) +df = pd.read_csv('/home/kevyan/data/flip_data_20250814/flip_data/datasets/NucB/splits/medium.csv', index_col=0) +n_muts = [] +for i, row in df.iterrows(): + muts = row['variant_info'] + if isinstance(muts, str): + n_muts.append(len(row['variant_info'].split(','))) + else: + n_muts.append(0) +df['n_mutations'] = n_muts +np.random.seed(0) +df['validation'] = False +for i, row in df.iterrows(): + if row['n_mutations'] > 2: + row['set'] = 'test' + else: + row['set'] = 'train' + if np.random.random() < 0.15: + df.loc[i, "validation"] = True + +df.to_csv('/home/kevyan/data/flip_data_pruned/NucB/splits/two_to_many.csv', index=False) + +df = pd.read_csv('/home/kevyan/data/flip_data_pruned/RhoMax/splits/by_wt.csv', index_col=0) +spearmanr(df['target'], df['esm2_650M_scores']) +spearmanr(df['target'], df['dayhoff_fwd'] + df['dayhoff_bwd']) +spearmanr(df['target'], df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd']) +spearmanr(df['target'], df['dayhoff_3bur90_fwd'] + df['dayhoff_3bur90_bwd'] + df['dayhoff_fwd'] + df['dayhoff_bwd']) \ No newline at end of file diff --git a/collect_splits/rhomax.py b/collect_splits/rhomax.py new file mode 100644 index 0000000..aa8f48d --- /dev/null +++ b/collect_splits/rhomax.py @@ -0,0 +1,62 @@ +import os +from pathlib import Path +import subprocess +import argparse +import os + +import pandas as pd +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument("split_path", type=str, help="Directory to download raw data") 
parser.add_argument("out_path", type=str, help="Directory to save processed file")
args = parser.parse_args()

split_path = Path(args.split_path)
# split_path = "/home/kevyan/src/FLIPv3/splits/rhomax"
os.makedirs(split_path, exist_ok=True)
# Raw rhodopsin absorption (lmax) data from the OpsiGen repository.
subprocess.call(["wget", "-P", split_path, "https://github.com/dina-lab3D/OpsiGen/raw/refs/heads/colab/excel/data.xlsx"])

df = pd.read_excel(os.path.join(split_path, "data.xlsx"))

# Per-wildtype mean lmax and variant count, sorted so that wildtypes with the
# fewest variants come first.
grouped = df.groupby("Wildtype")
grouped = grouped.agg({"lmax": ['mean', 'count']})
grouped.columns = grouped.columns.to_flat_index()
grouped = grouped.sort_values(('lmax', 'count'))
grouped = grouped.reset_index()
val_wt = []
n_val = 0
test_wt = []
n_test = 0
num_val = 100   # target number of validation samples
num_test = 175  # target number of test samples
current = 'val'

# Alternate between filling the validation and test pools with whole wildtypes
# (smallest groups first) until each pool reaches its target sample count, so
# no wildtype is split across sets.
for wildtype in grouped['Wildtype'].values:
    if current == 'val' and n_val < num_val:
        val_wt.append(wildtype)
        n_val += grouped[grouped['Wildtype'] == wildtype][('lmax', 'count')].values[0]
        if n_test < num_test:
            current = 'test'
    elif current == 'test' and n_test < num_test:
        test_wt.append(wildtype)
        n_test += grouped[grouped['Wildtype'] == wildtype][('lmax', 'count')].values[0]
        if n_val < num_val:
            current = 'val'

df_val = df[df['Wildtype'].isin(val_wt)]
df_test = df[df['Wildtype'].isin(test_wt)]
df_train = df[~df['Wildtype'].isin(np.concatenate([val_wt, test_wt]))]
print(len(df_val), "validation samples", "mean = ", df_val['lmax'].mean())
print(len(df_test), "Test samples", "mean = ", df_test['lmax'].mean())
print(len(df_train), "Train samples", "mean = ", df_train['lmax'].mean())

# Emit in FLIP split format: sequence, target, set (train/test), validation flag.
df_out = pd.DataFrame()
df_out['wildtype'] = df['Wildtype']
df_out['sequence'] = df['Sequence']
df_out['target'] = df['lmax']
df_out['set'] = 'train'
df_out.loc[df_out['wildtype'].isin(test_wt), 'set'] = 'test'
df_out['validation'] = df_out['wildtype'].isin(val_wt)
df_out = df_out[['sequence', 'target', 'set', 'validation']]
df_out.to_csv(os.path.join(os.path.join(args.out_path, "by_wt.csv")), index=False)
import torch
from torch.nn import functional as F


def make_square_and_get_triu(input):
    """Collapse a column vector into its strict upper-triangle pairwise differences.

    Given `input` of shape (n, 1), returns a (n*(n-1)/2, 1) tensor whose
    entries are input[i] - input[j] for every pair i < j.
    """
    pairwise = input - input.transpose(0, 1)
    rows, cols = torch.triu_indices(len(input), len(input), 1, device=input.device)
    return pairwise[rows, cols].view(-1, 1)


def bradley_terry_loss(predictions, targets):
    """Pairwise Bradley-Terry ranking loss.

    For every pair (i, j) with i < j, the binary label is 1 when
    targets[i] > targets[j] (ties count as 0) and the logit is
    predictions[i] - predictions[j]; the mean BCE over all pairs is returned.
    """
    pair_labels = (make_square_and_get_triu(targets) > 0).float()
    pair_logits = make_square_and_get_triu(predictions)
    return F.binary_cross_entropy_with_logits(pair_logits, pair_labels)