Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/Single_Sequence_Inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ For generating template information, you will need the UniRef90 and PDB70 databa
SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.

```{note}
Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
```
Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. By default, `scripts/precompute_embeddings.py` truncates longer sequences to 1022 residues and emits a warning. You can disable truncation with `--no-truncate`, in which case the script raises an error for overlength sequences instead of truncating them.
```
54 changes: 44 additions & 10 deletions scripts/precompute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

logging.basicConfig(level=logging.INFO)

ESM_MAX_RESIDUES = 1022

class SequenceDataset(object):
def __init__(self, labels, sequences) -> None:
self.labels = labels
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(self,
self.truncate = truncate
self.use_local_esm = use_local_esm
self.nogpu = nogpu
self.max_residues = ESM_MAX_RESIDUES

# Generate embeddings in bulk
if self.use_local_esm:
Expand Down Expand Up @@ -122,24 +125,46 @@ def run(
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
effective_lengths = []
truncated_labels = []
for label, seq in zip(labels, strs):
seq_len = len(seq)
if seq_len > self.max_residues:
if not self.truncate:
raise ValueError(
f"Sequence '{label}' has length {seq_len}, exceeding "
f"the ESM-1b limit of {self.max_residues} residues. "
"Re-run with truncation enabled."
)
truncated_labels.append((label, seq_len))
effective_lengths.append(min(seq_len, self.max_residues))

if truncated_labels:
for label, seq_len in truncated_labels:
logging.warning(
"Truncating sequence '%s' from %d to %d residues for ESM-1b.",
label,
seq_len,
self.max_residues,
)

max_effective_length = max(effective_lengths)
toks = toks[:, : max_effective_length + 2]
if torch.cuda.is_available() and not self.nogpu:
toks = toks.to(device="cuda", non_blocking=True)

if self.truncate:
toks = toks[:1022]


out = self.model(toks, repr_layers=repr_layers, return_contacts=False)

representations = {
33: out["representations"][33].to(device="cpu")
}

for i, label in enumerate(labels):
for i, (label, effective_length) in enumerate(zip(labels, effective_lengths)):
os.makedirs(os.path.join(output_dir, label), exist_ok=True)
result = {"label": label}

result["representations"] = {
33: representations[33][i, 1: len(strs[i]) + 1].clone()
33: representations[33][i, 1: effective_length + 1].clone()
}
torch.save(
result,
Expand Down Expand Up @@ -168,7 +193,7 @@ def main(args):
logging.info("Completed.")


if __name__ == "__main__":
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_dir", type=str,
Expand All @@ -183,8 +208,13 @@ def main(args):
help="maximum tokens in a batch"
)
parser.add_argument(
"--truncate", action="store_true", default=True,
help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
"--truncate",
action=argparse.BooleanOptionalAction,
default=True,
help=(
f"Truncate sequences longer than {ESM_MAX_RESIDUES} residues "
"to satisfy the ESM-1b limit. Default: True"
),
)
parser.add_argument(
"--use_local_esm", type=str, default=None,
Expand All @@ -195,6 +225,10 @@ def main(args):
help="Do not use GPU"
)

args = parser.parse_args()
return parser


if __name__ == "__main__":
parser = create_parser()
args = parser.parse_args()
main(args)
95 changes: 95 additions & 0 deletions tests/test_precompute_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import tempfile
import unittest
from unittest import mock

import torch

from scripts import precompute_embeddings


class _FakeModel:
def __init__(self):
self.last_toks_shape = None

def to(self, device="cpu"):
return self

def __call__(self, toks, repr_layers=None, return_contacts=False):
self.last_toks_shape = tuple(toks.shape)
batch, tok_len = toks.shape
reps = torch.arange(batch * tok_len * 3, dtype=torch.float32)
reps = reps.reshape(batch, tok_len, 3)
return {"representations": {33: reps}}


class _FakeAlphabet:
@staticmethod
def get_batch_converter():
def _convert(batch):
labels = [label for label, _ in batch]
sequences = [seq for _, seq in batch]
max_len = max(len(seq) for seq in sequences)
toks = torch.zeros((len(batch), max_len + 2), dtype=torch.int64)
return labels, sequences, toks

return _convert


class TestPrecomputeEmbeddings(unittest.TestCase):
    """Tests for the CLI parser and over-length handling of the embedding script.

    ``torch.hub.load`` is patched in the generator tests so that no real ESM
    weights are fetched; the fake model and alphabet are used instead.
    """

    def test_parser_defaults_and_no_truncate_flag(self):
        """Truncation defaults to on and ``--no-truncate`` switches it off."""
        cli = precompute_embeddings.create_parser()

        default_ns = cli.parse_args(["input", "output"])
        self.assertTrue(default_ns.truncate)

        opt_out_ns = cli.parse_args(["input", "output", "--no-truncate"])
        self.assertFalse(opt_out_ns.truncate)

    @mock.patch("scripts.precompute_embeddings.torch.hub.load")
    def test_long_sequences_are_truncated_on_token_axis(self, mock_load):
        """A 1030-residue sequence is cut to 1022 while a short one is untouched."""
        fake_model = _FakeModel()
        mock_load.return_value = (fake_model, _FakeAlphabet())
        generator = precompute_embeddings.EmbeddingGenerator(
            truncate=True,
            nogpu=True,
        )
        long_seq = "A" * 1030
        short_seq = "G" * 12

        with tempfile.TemporaryDirectory() as workdir:
            fasta_file = os.path.join(workdir, "input.fasta")
            embeddings_dir = os.path.join(workdir, "embeddings")
            with open(fasta_file, "w") as handle:
                handle.write(f">long\n{long_seq}\n>short\n{short_seq}\n")

            generator.run(fasta_file, embeddings_dir)

            # Load the saved results before the temporary directory is removed.
            long_result = torch.load(
                os.path.join(embeddings_dir, "long", "long.pt")
            )
            short_result = torch.load(
                os.path.join(embeddings_dir, "short", "short.pt")
            )

        long_embedding = long_result["representations"][33]
        short_embedding = short_result["representations"][33]
        self.assertEqual(long_embedding.shape[0], 1022)
        self.assertEqual(short_embedding.shape[0], len(short_seq))
        # Both sequences stay in the batch; only the token axis shrinks
        # (1022 residues plus the two extra token columns).
        self.assertEqual(fake_model.last_toks_shape, (2, 1024))

    @mock.patch("scripts.precompute_embeddings.torch.hub.load")
    def test_long_sequences_fail_when_truncation_disabled(self, mock_load):
        """With truncation off, an over-length sequence raises ``ValueError``."""
        mock_load.return_value = (_FakeModel(), _FakeAlphabet())
        generator = precompute_embeddings.EmbeddingGenerator(
            truncate=False,
            nogpu=True,
        )

        with tempfile.TemporaryDirectory() as workdir:
            fasta_file = os.path.join(workdir, "input.fasta")
            with open(fasta_file, "w") as handle:
                handle.write(f">long\n{'A' * 1030}\n")

            with self.assertRaisesRegex(ValueError, "exceeding the ESM-1b limit"):
                generator.run(fasta_file, workdir)


# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()