Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The rules for this file:
<!-- New added features -->

- Added support to duplicate widgets (PR #13)
- Added support for batching and parallelization (PR #15)

### Fixed

Expand Down
18 changes: 12 additions & 6 deletions mdadash/backend/analyses/com_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
from MDAnalysis.exceptions import NoDataError
from MDAnalysis.lib.distances import calc_bonds

from mdadash.backend.widgets.base import WidgetBase

Expand Down Expand Up @@ -150,11 +151,16 @@ def on_input_change(self, attribute, _old_value, new_value):
self.y_values = deque(maxlen=self.maxlen)
self._set_x_values()

def run(self):
"""run handler"""
com1 = self.ag1.center_of_mass()
com2 = self.ag2.center_of_mass()
self.y_values.append(np.linalg.norm(com1 - com2))
def run_per_frame(self):
"""per-frame run handler"""
try:
com1 = self.ag1.center_of_mass(unwrap=True)
com2 = self.ag2.center_of_mass(unwrap=True)
except NoDataError: # pragma: no cover
# unwrap can fail if there is no bonds info
com1 = self.ag1.center_of_mass()
com2 = self.ag2.center_of_mass()
self.y_values.append(calc_bonds(com1, com2, box=self.u.dimensions))
self.steps.append(self.u.trajectory.ts.data["step"])
self.times.append(self.u.trajectory.ts.data["time"])
plt.plot(self.x_values, self.y_values)
Expand Down
6 changes: 3 additions & 3 deletions mdadash/backend/analyses/energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def on_input_change(self, attribute, _old_value, new_value):
elif attribute == "x_type":
self._set_x_values()

def run(self):
"""run handler"""
def run_per_frame(self):
"""per-frame run handler"""
ts = getattr(self, "u").trajectory.ts
if self.data_key not in ts.data:
return # pragma no cover
return # pragma: no cover
self.steps.append(ts.data["step"])
self.times.append(ts.data["time"])
self.y_values.append(ts.data[self.data_key])
Expand Down
68 changes: 58 additions & 10 deletions mdadash/backend/analyses/rog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib.pyplot as plt
import numpy as np
from joblib import delayed

from mdadash.backend.widgets.base import WidgetBase

Expand All @@ -19,16 +20,26 @@ class ROG(WidgetBase):

name = "ROG"
description = "Radii of Gyration of a selection"
_analysis_mode = "per-frame"

_inputs = [
{
"attribute": "_analysis_mode",
"name": "Analysis mode",
"description": "The mode to run this analysis widget",
"attribute": "_run_frequency",
"name": "Run frequency",
"description": "The frequency with which the widget is run",
"type": "select",
"items": [
"per-frame",
"batch",
],
},
{
"attribute": "_run_mode",
"name": "Run mode",
"description": "The mode in which the widget is run",
"type": "select",
"items": [
"serial",
"parallel",
],
},
{
Expand Down Expand Up @@ -130,9 +141,8 @@ def on_input_change(self, attribute, _old_value, new_value):
self.y_values = deque(maxlen=self.maxlen)
self._set_x_values()

# pylint: disable=too-many-locals
def run(self):
"""run handler"""
def _compute_rog_per_frame(self):
"""Compute ROG values for current frame"""
masses = self.ag.masses
total_mass = np.sum(masses)
coordinates = self.ag.positions
Expand All @@ -149,10 +159,30 @@ def run(self):
rog_sq = np.sum(masses * sq_rs, axis=1) / total_mass
# square root
rog = np.sqrt(rog_sq)
return (
self.u.trajectory.ts.data["step"],
self.u.trajectory.ts.data["time"],
rog,
)

def _compute_rog_batch(self, batch_size):
"""Compute ROG values for current batch"""
values = []
for i in range(batch_size):
_ = self.u.trajectory[1 - batch_size + i]
values.append(self._compute_rog_per_frame())
return values

def _create_plot(self, values):
"""Append ROG values and create plot"""
if isinstance(values, tuple):
values = [values]
# update plot points
self.y_values.append(rog)
self.steps.append(self.u.trajectory.ts.data["step"])
self.times.append(self.u.trajectory.ts.data["time"])
for value in values:
(steps, times, rog) = value
self.steps.append(steps)
self.times.append(times)
self.y_values.append(rog)
# create plot
data = np.array(self.y_values)
labels = ["all", "x-axis", "y-axis", "z-axis"]
Expand All @@ -164,3 +194,21 @@ def run(self):
plt.title(self.custom_title if self.custom_title else self.title)
plt.grid(True)
plt.show()

def run_per_frame(self):
"""per-frame run handler"""
self._create_plot(self._compute_rog_per_frame())

def run_batch(self, batch_size):
"""batch run handler"""
self._create_plot(self._compute_rog_batch(batch_size))

def get_parallel_job(self, batch_size):
"""get parallel job handler"""
if self._run_frequency == "batch":
return delayed(self._compute_rog_batch)(batch_size)
return delayed(self._compute_rog_per_frame)()

def apply_parallel_results(self, values):
"""apply parallel results handler"""
self._create_plot(values)
Loading
Loading