Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@
if getattr(site, key) == value:
yield site

def iter_sites_by_residue(self, residue_tag):
def iter_sites_by_residue(self, residue_tag, residue_number=None):
Comment thread Fixed
"""Iterate through this topology's sites which contain this specific residue name.

See Also
Expand All @@ -1457,12 +1457,29 @@
"""
if isinstance(residue_tag, str):
for site in self._sites:
if site.residue and getattr(site, "residue").name == residue_tag:
yield site
if residue_tag and residue_number is None:
if site.molecule and getattr(site, "residue").name == residue_tag:
yield site
elif residue_number and residue_tag is None:
if (
site.molecule
and getattr(site, "residue").number == residue_number
):
yield site
else:
if all(
[
site.molecule
and getattr(site, "residue").name == residue_tag,
site.molecule
and getattr(site, "residue").number == residue_number,
]
):
Comment thread Fixed
yield site
else:
return self.iter_sites("residue", residue_tag)

def iter_sites_by_molecule(self, molecule_tag):
def iter_sites_by_molecule(self, molecule_tag, molecule_number=None):
"""Iterate through this topology's sites which contain this specific molecule name.

See Also
Expand All @@ -1472,8 +1489,25 @@
"""
if isinstance(molecule_tag, str):
for site in self._sites:
if site.molecule and getattr(site, "molecule").name == molecule_tag:
yield site
if molecule_tag and molecule_number is None:
if site.molecule and getattr(site, "molecule").name == molecule_tag:
yield site
elif molecule_number and molecule_tag is None:
if (
site.molecule
and getattr(site, "molecule").number == molecule_number
):
yield site
else:
if all(
[
site.molecule
and getattr(site, "molecule").name == molecule_tag,
site.molecule
and getattr(site, "molecule").number == molecule_number,
]
):
yield site
else:
return self.iter_sites("molecule", molecule_tag)

Expand Down
11 changes: 11 additions & 0 deletions gmso/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,17 @@ def test_iter_sites_by_molecule(self, labeled_top):
for site in labeled_top.iter_sites_by_molecule(molecule_name):
assert site.molecule.name == molecule_name

def test_iter_sites_by_molecule_tag_and_number(self, labeled_top):
molecules = labeled_top.unique_site_labels("molecule", name_only=False)
for molecule in molecules:
for site in labeled_top.iter_sites_by_molecule(molecule):
assert site.residue == molecule

molecule_names = labeled_top.unique_site_labels("molecule", name_only=True)
for molecule_name in molecule_names:
for site in labeled_top.iter_sites_by_molecule(molecule_name):
assert site.molecule.name == molecule_name

@pytest.mark.parametrize(
"connections",
["bonds", "angles", "dihedrals", "impropers"],
Expand Down
39 changes: 39 additions & 0 deletions gmso/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mbuild as mb
import numpy as np
import pytest
import unyt as u
Expand All @@ -7,6 +8,10 @@
from gmso.utils.geometry import moment_of_inertia
from gmso.utils.io import run_from_ipython
from gmso.utils.misc import unyt_to_hashable
from gmso.utils.slicing import (
slice_topology_by_molecule,
slice_topology_by_residue,
)
from gmso.utils.sorting import sort_connection_members, sort_connection_strings


Expand Down Expand Up @@ -81,3 +86,37 @@ def test_moment_of_inertia():
masses=np.array([1.0 for i in xyz]),
)
assert np.array_equal(tensor, np.array([1, 1, 2]))


def test_slice_by_molecule():
benzene = mb.load("c1ccccc1", smiles=True)
benzene.name = "Benzene"
ethane = mb.load("CC", smiles=True)
ethane.name = "Ethane"

system = mb.fill_box(compound=[benzene, ethane], n_compounds=[2, 2], box=[2, 2, 2])
topology = system.to_gmso()
topology.identify_connections()

single_benzene_top = slice_topology_by_molecule(topology, "Benzene", 0)
assert single_benzene_top.n_sites == 12

all_benzene_top = slice_topology_by_molecule(topology, "Benzene")
assert all_benzene_top.n_sites == 24


def test_slice_by_residue():
benzene = mb.load("c1ccccc1", smiles=True)
benzene.name = "Benzene"
ethane = mb.load("CC", smiles=True)
ethane.name = "Ethane"

system = mb.fill_box(compound=[benzene, ethane], n_compounds=[2, 2], box=[2, 2, 2])
topology = system.to_gmso()
topology.identify_connections()

single_ethane_top = slice_topology_by_residue(topology, "Ethane", 0)
assert single_ethane_top.n_sites == 8

all_ethane_top = slice_topology_by_residue(topology, "Ethane")
assert all_ethane_top.n_sites == 16
81 changes: 81 additions & 0 deletions gmso/utils/slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from gmso.core.topology import Topology


def slice_topology_by_molecule(topology, molecule_tag, molecule_number=None):
"""Create a Topology that contains a subset of molecules from another Topology.

Parameters
----------
topology : gmso.core.topology.Topology
The gmso Topology to perform the slice on
molecule_tag : str
The name of the gmso.abstract_site.Molecule object to include in the slice
molecule_number : int, default None
If given, only include a single molecule's sites
If None, then all sites in every molecule matching `molecule_tag` are included
in the sliced topology.

Returns
-------
gmso.core.topology.Topology
A new Topology instance containing only sites and connections from matching molecules.
"""
sites = [
s
for s in topology.iter_sites_by_molecule(
molecule_tag=molecule_tag, molecule_number=molecule_number
)
]
return slice_by_sites(topology=topology, sites=sites)


def slice_topology_by_residue(topology, residue_tag, residue_number=None):
"""Create a Topology that contains a subset of residues from another Topology.

Parameters
----------
topology : gmso.core.topology.Topology
The gmso Topology to perform the slice on.
residue_tag : str
The name of the gmso.abstract_site.Residue object to include in the slice.
residue_number : int, default None
If given, only include a single residue's sites
If None, then all sites in every residue matching `residue_tag` are included
in the sliced topology.

Returns
-------
gmso.core.topology.Topology
A new Topology instance containing only sites and connections from matching molecules.
"""
sites = [
s
for s in topology.iter_sites_by_residue(
residue_tag=residue_tag, residue_number=residue_number
)
]
return slice_by_sites(topology=topology, sites=sites)


def slice_by_sites(topology, sites):
"""Used by slice_topology_by_molecule() and slice_topology_by_residue()

Parameters
----------
sites : list of gmso.core.atom.Atom
List of sites to include in the sub-topology.
topology : gmso.core.topology.Topology
The topology being sliced.
"""
new_topology = Topology()

connections = set()
for site in sites:
new_topology.add_site(site)
for connection in topology.iter_connections_by_site(site):
connections.add(connection)

for connection in connections:
new_topology.add_connection(connection)

return new_topology
Loading