diff --git a/src/PyMca5/PyMcaCore/DataObject.py b/src/PyMca5/PyMcaCore/DataObject.py index 20494681b..3fe1337d2 100644 --- a/src/PyMca5/PyMcaCore/DataObject.py +++ b/src/PyMca5/PyMcaCore/DataObject.py @@ -59,6 +59,59 @@ def __init__(self): self.info = {} self.data = numpy.array([]) + def padIncompleteScan(self, imageShape, padPositioners=False, mcaIndex=-1): + """ + Pad an incomplete/unfinished scan with NaN. + Result array have channels as the last dimension. + + :param imageShape: desired shape after padding (without the channels) + :param bool padPositioners: also pad the per-point positioners with NaN. + """ + + nChannels = self.data.shape[mcaIndex] + imageShape = tuple(int(d) for d in imageShape) + finalShape = imageShape + (nChannels,) + nOld = numpy.prod(numpy.delete(self.data.shape, mcaIndex)) + numberOfSpectra = int(numpy.prod(imageShape)) + if numberOfSpectra <= nOld: + return + if self.data.dtype in (numpy.float16, numpy.float32): + dtype = self.data.dtype + else: + dtype = numpy.float64 + padded = numpy.full((numberOfSpectra, nChannels), numpy.nan, dtype=dtype) + reorderedData = numpy.moveaxis(self.data, mcaIndex, -1) + padded[:nOld] = reorderedData.reshape(nOld, nChannels) + self.data = padded.reshape(finalShape) + for i in range(len(self.data.shape)): + self.info["Dim_%d" % (i + 1)] = self.data.shape[i] + self.info["McaIndex"] = len(self.data.shape) - 1 + + # per-point metadata (e.g. McaLiveTime) must follow the data to stay + # aligned, so it is always padded with zeros + for key, value in self.info.items(): + if key == "positioners": + continue + if hasattr(value, "size"): + arr = numpy.asarray(value) + if arr.size == nOld: + paddedArr = numpy.zeros(numberOfSpectra, dtype=arr.dtype) + paddedArr[:nOld] = arr.ravel() + self.info[key] = paddedArr + + if padPositioners: + if "positioners" in self.info and hasattr(self.info["positioners"], "items"): + for motor_name, motor_values in self.info["positioners"].items(): + if hasattr(motor_values, "size") and motor_values.size == nOld: + if motor_values.dtype in (numpy.float16, numpy.float32): + pos_dtype = motor_values.dtype + else: + pos_dtype = numpy.float64 + paddedPos = numpy.full(numberOfSpectra, numpy.nan, dtype=pos_dtype) + paddedPos[:nOld] = numpy.asarray(motor_values).ravel() + self.info["positioners"][motor_name] = paddedPos + + # all the following methods are here for compatibility purposes # they are obsolete and bound to disappear. diff --git a/src/PyMca5/PyMcaGui/io/hdf5/QNexusWidget.py b/src/PyMca5/PyMcaGui/io/hdf5/QNexusWidget.py index aca0abf59..bfbe126d7 100644 --- a/src/PyMca5/PyMcaGui/io/hdf5/QNexusWidget.py +++ b/src/PyMca5/PyMcaGui/io/hdf5/QNexusWidget.py @@ -440,8 +440,9 @@ def getOutputFilename(self): def getWidgetConfiguration(self): cntSelection = self.cntTable.getCounterSelection() - if hasattr(self, "actions"): - ddict =self.actions.getConfiguration() + # self.actions only exists when built without buttons + if hasattr(self, "actions") and hasattr(self.actions, "getConfiguration"): + ddict = self.actions.getConfiguration() else: ddict = {} ddict['counters'] = cntSelection['cntlist'] diff --git a/src/PyMca5/PyMcaGui/plotting/MaskScatterViewWidget.py b/src/PyMca5/PyMcaGui/plotting/MaskScatterViewWidget.py new file mode 100644 index 000000000..85418207d --- /dev/null +++ b/src/PyMca5/PyMcaGui/plotting/MaskScatterViewWidget.py @@ -0,0 +1,288 @@ +#/*########################################################################## +# Copyright (C) 2004-2026 European Synchrotron Radiation Facility +# +# This file is part of the PyMca X-ray Fluorescence Toolkit developed at +# the ESRF. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#############################################################################*/ +""" +It is a wraper of a silx :class:`ScatterView` +""" + +import logging +import numpy + +from PyMca5.PyMcaGui import PyMcaQt as qt + +from silx.gui.plot.ScatterView import ScatterView +from silx.gui.plot import items as silx_items +from silx.gui.colors import Colormap + +_logger = logging.getLogger(__name__) + + +class AxesPositionersSelector(qt.QWidget): + sigSelectionChanged = qt.pyqtSignal(object, object) + + def __init__(self, parent=None): + qt.QWidget.__init__(self, parent) + hlayout = qt.QHBoxLayout() + self.setLayout(hlayout) + self._initializing = True + xlabel = qt.QLabel("X:", parent=parent) + self.xPositioner = qt.QComboBox(parent) + self.xPositioner.currentIndexChanged.connect(self._emitSelectionChanged) + + ylabel = qt.QLabel("Y:", parent=parent) + self.yPositioner = qt.QComboBox(parent) + self.yPositioner.currentIndexChanged.connect(self._emitSelectionChanged) + self._initializing = False + + hlayout.addWidget(xlabel) + hlayout.addWidget(self.xPositioner) + hlayout.addWidget(ylabel) + hlayout.addWidget(self.yPositioner) + + self._nPoints = None + """If set to an integer, only motors with this number of data points + can be added.""" + + self._initComboBoxes() + + def _initComboBoxes(self): + self.xPositioner.clear() + self.xPositioner.insertItem(0, "None") + self.yPositioner.clear() + self.yPositioner.insertItem(0, "None") + + def _emitSelectionChanged(self, idx): + if not self._initializing: + self.sigSelectionChanged.emit(*self.getSelectedPositioners()) + + def setNumPoints(self, n): + self._nPoints = n + + def unsetNumPoints(self): + self._nPoints = None + + def fillPositioners(self, positioners): + """ + + :param dict positioners: Dictionary of positioners + The key is the motor name, the value are the motor's position data + """ + currentX, currentY = self.getSelectedPositioners() + + self._initializing = True + self._initComboBoxes() + i = 0 + for motorName, motorValues in positioners.items(): + if not numpy.isscalar(motorValues) and self._nPoints is not None and self._nPoints != motorValues.size: + # checks consistency of number of data points (but accepts scalars) + continue + else: + i += 1 + self.xPositioner.insertItem(i, motorName) + self.yPositioner.insertItem(i, motorName) + + if currentX in positioners and currentY in positioners: + self.xPositioner.setCurrentIndex(self.xPositioner.findText(currentX)) + self.yPositioner.setCurrentIndex(self.yPositioner.findText(currentY)) + self._initializing = False + def getSelectedPositioners(self): + """ + + :return: 2-tuple of selected positioner names (or None) + """ + selected = [None, None] + if self.xPositioner.currentText() != "None": + selected[0] = self.xPositioner.currentText() + if self.yPositioner.currentText() != "None": + selected[1] = self.yPositioner.currentText() + return selected + + +class MaskScatterViewWidget(qt.QWidget): + """ + Plain QWidget (not a QMainWindow) + to work in a layout and in a plugin + """ + def __init__(self, parent=None, backend="mpl"): + qt.QWidget.__init__(self, parent) + self.setWindowTitle("Mask Scatter View") + layout = qt.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + + self._scatterView = ScatterView(parent=self, backend=backend) + self._scatterView.setColormap(Colormap("temperature")) + self._scatterView.getScatterItem().setSymbol("s") + + self._axesSelector = AxesPositionersSelector(parent=self) + self._axesSelector.sigSelectionChanged.connect(self._setAxesData) + + layout.addWidget(self._scatterView, 1) + layout.addWidget(self._axesSelector, 0) + + self._positioners = {} + self._xdata = None + self._ydata = None + self._stackImage = None + + def getMaskToolsWidget(self): + return self._scatterView.getMaskToolsWidget() + + def resetZoom(self): + return self._scatterView.resetZoom() + + def fillPositioners(self, positioners): + self._positioners = positioners + self._axesSelector.fillPositioners(positioners) + + def setNumPoints(self, n): + self._axesSelector.setNumPoints(n) + + def _setAxesData(self, xPositioner, yPositioner): + """ + + :param str xPositioner: motor name, or None + :param str yPositioner: motor name, or None + :return: + """ + if xPositioner not in [None, ""]: + assert xPositioner in self._positioners + self._xdata = self._positioners[xPositioner] + else: + self._xdata = None + if yPositioner not in [None, ""]: + assert yPositioner in self._positioners + self._ydata = self._positioners[yPositioner] + else: + self._ydata = None + if self._stackImage is not None: + self.setData() + if not self._scatterView.getMaskToolsWidget().isVisible(): + # synchronization inactive, force mask redrawing + mask = self._scatterView.getMaskToolsWidget().getSelectionMask() + if mask is not None: + self._scatterView.getMaskToolsWidget().setSelectionMask(mask) + + self._scatterView.resetZoom() + + def setData(self, stackImage=None): + first_time = self._stackImage is None + if first_time: + assert stackImage is not None + + if stackImage is None: + # use previous data + stackImage = self._stackImage + else: + # update stored data + self._stackImage = stackImage + nrows, ncols = stackImage.shape + + # flatten image + stackValues = stackImage.reshape((-1,)) + + # get regular grid coordinates as a 1D array + if self._xdata is None or self._ydata is None: + defaultX, defaultY = numpy.meshgrid(numpy.arange(ncols), + numpy.arange(nrows)) + defaultX = defaultX.reshape(*stackValues.shape) + defaultY = defaultY.reshape(*stackValues.shape) + + xdata = self._xdata if self._xdata is not None else defaultX + ydata = self._ydata if self._ydata is not None else defaultY + + if numpy.isscalar(xdata): + xdata = xdata * numpy.ones_like(stackValues) + _logger.debug("converting scalar to constant 1D array for x") + elif len(xdata.shape) > 1: + _logger.debug("flattening %s array", str(xdata.shape)) + xdata = xdata.reshape((-1,)) + + if numpy.isscalar(ydata): + ydata = ydata * numpy.ones_like(stackValues) + _logger.debug("converting scalar to constant 1D array for y") + elif len(ydata.shape) > 1: + _logger.debug("flattening %s array", str(ydata.shape)) + ydata = ydata.reshape((-1,)) + + self._scatterView.setData(xdata, ydata, stackValues, + copy=False) + if first_time: + self._scatterView.resetZoom() + + def _maskDockWidget(self): + widget = self._scatterView.getMaskToolsWidget() + while widget is not None and not isinstance(widget, qt.QDockWidget): + widget = widget.parentWidget() + return widget + + def setMaskToolsVisible(self, visible=True): + dock = self._maskDockWidget() + if dock is not None: + dock.setVisible(visible) + + def setSelectionReadOnly(self): + ''' + To disable the mask selection for original stack in ROI tool + ''' + dock = self._maskDockWidget() + if dock is not None: + dock.setVisible(False) + dock.toggleViewAction().setVisible(False) + + def setVisualizationMode(self, mode): + """ + Allow to set visualization mode like IRREGULAR_GRID + """ + self._scatterView.getScatterItem().setVisualization(mode) + + def addControlToolBar(self, title="Controls"): + """ + To be able to add extra menus to the toolbar + """ + toolbar = qt.QToolBar(title, self._scatterView) + self._scatterView.addToolBar(qt.Qt.TopToolBarArea, toolbar) + return toolbar + + def getScatterView(self): + return self._scatterView + + def setSelectedPositioners(self, xName, yName): + """ + + :param str xName: motor name to use as X axis (or None) + :param str yName: motor name to use as Y axis (or None) + """ + if xName is not None: + self._axesSelector.xPositioner.setCurrentText(xName) + if xName in self._positioners: + self._xdata = self._positioners[xName] + if yName is not None: + self._axesSelector.yPositioner.setCurrentText(yName) + if yName in self._positioners: + self._ydata = self._positioners[yName] + + + diff --git a/src/PyMca5/PyMcaGui/pymca/QHDF5StackWizard.py b/src/PyMca5/PyMcaGui/pymca/QHDF5StackWizard.py index 8c58d3e45..620bf14d5 100644 --- a/src/PyMca5/PyMcaGui/pymca/QHDF5StackWizard.py +++ b/src/PyMca5/PyMcaGui/pymca/QHDF5StackWizard.py @@ -36,6 +36,7 @@ from PyMca5.PyMcaCore import NexusDataSource from PyMca5 import PyMcaDirs import logging +import numpy _logger = logging.getLogger(__name__) @@ -190,6 +191,11 @@ def __init__(self, parent): self.stackIndexWidget = StackIndexWidget(self) self.mainLayout.addWidget(self.stackIndexWidget, 0) + self._scatterCheckBox = qt.QCheckBox( + "Scatter plot (X, Y coordinates are set per image point)", self) + self._scatterCheckBox.setChecked(False) + self.mainLayout.addWidget(self._scatterCheckBox, 0) + def setFileList(self, filelist): self.dataSource = NexusDataSource.NexusDataSource(filelist[0]) self.nexusWidget.setDataSource(self.dataSource) @@ -207,7 +213,7 @@ def setFileList(self, filelist): try: attr = attr.decode('utf-8') except Exception: - print("WARNING: Cannot decode NX_class attribute") + _logger.warning("Cannot decode NX_class attribute") attr = None else: attr = None @@ -226,7 +232,7 @@ def setFileList(self, filelist): try: attr = attr.decode('utf-8') except Exception: - print("WARNING: Cannot decode NX_class attribute") + _logger.warning("Cannot decode NX_class attribute") continue if attr in ['NXdata', b'NXdata']: nxDataList.append(key) @@ -247,7 +253,7 @@ def setFileList(self, filelist): try: signal_key = signal_key.decode('utf-8') except AttributeError: - print("WARNING: Cannot decode NX_class attribute") + _logger.warning("Cannot decode NX_class attribute") signal_dataset = nxData.get(signal_key) if signal_dataset is None: @@ -258,7 +264,7 @@ def setFileList(self, filelist): try: interpretation = interpretation.decode('utf-8') except AttributeError: - print("WARNING: Cannot decode interpretation") + _logger.warning("Cannot decode interpretation") axesList = list(nxData.attrs.get("axes", [])) if not axesList: @@ -270,7 +276,7 @@ def setFileList(self, filelist): try: axes = axes.decode('utf-8') except AttributeError: - print("WARNING: Cannot decode axes") + _logger.warning("Cannot decode axes") axes = axes.split(":") axesList = [ax for ax in axes if ax in nxData] signalList.append(signal_key) @@ -287,7 +293,7 @@ def setFileList(self, filelist): try: interpretation = interpretation.decode('utf-8') except Exception: - print("WARNING: Cannot decode interpretation") + _logger.warning("Cannot decode interpretation") if 'axes' in nxData[key].attrs.keys(): axes = nxData[key].attrs['axes'] @@ -295,7 +301,7 @@ def setFileList(self, filelist): try: axes = axes.decode('utf-8') except Exception: - print("WARNING: Cannot decode axes") + _logger.warning("Cannot decode axes") axes = axes.split(":") for axis in axes: if axis in nxData.keys(): @@ -335,16 +341,39 @@ def setFileList(self, filelist): self.nexusWidget.cntTable.setCounterSelection({'y': [0]}) def validatePage(self): + """ + Validate data while wizard is open + """ + selection = self._buildSelection() + if selection is None: + return False + if not self._validateScatterSelection(selection): + return False + + signalShapes, nPoints, axisSizes = self._collectValidationData(selection) + # required for case when no axes are selected + if signalShapes: + if not self._validateSignalShapes(signalShapes): + return False + if signalShapes and nPoints and axisSizes: + if not self._validateScanGeometry(selection, signalShapes, nPoints, axisSizes): + return False + + self.selection = selection + return True + + def _buildSelection(self): + """ + Build the selection dictionary from the counter table. + """ cntSelection = self.nexusWidget.cntTable.getCounterSelection() cntlist = cntSelection['cntlist'] if not len(cntlist): - text = "No dataset selection" - self.showMessage(text) - return False + self.showMessage("No dataset selection") + return None if not len(cntSelection['y']): - text = "No dataset selected as y" - self.showMessage(text) - return False + self.showMessage("No dataset selected as y") + return None selection = {} selection['x'] = [] selection['y'] = [] @@ -354,9 +383,132 @@ def validatePage(self): if len(cntSelection[key]): for idx in cntSelection[key]: selection[key].append(cntlist[idx]) - self.selection = selection + selection['scatter'] = self._scatterCheckBox.isChecked() + selection['allowPadding'] = False + return selection + + def _validateScatterSelection(self, selection): + if selection['scatter'] and len(selection['x']) < 2: + self.showMessage("Scatter mode requires two datasets selected as axes") + return False + else: + return True + + def _collectValidationData(self, selection): + """ + Read the signal and axes shapes from the selected datasets. + """ + try: + # choose selected entry, or first entry if none selected + h5file = self.dataSource._sourceObjectList[0] + entries = self.nexusWidget.getSelectedEntries() + entry = entries[0][0] if entries else list(h5file.keys())[0] + + signalShapes = [] + for yPath in selection['y']: + yShape = h5file[posixpath.join(entry, yPath.lstrip("/"))].shape + signalShapes.append(yShape) + + if selection['index'] == -1: + mcaAxis = len(signalShapes[0]) - 1 + else: + mcaAxis = selection['index'] + nPoints = int(numpy.prod(numpy.delete(signalShapes[0], mcaAxis))) + + axisSizes = [] + for xPath in selection['x']: + dataset = h5file[posixpath.join(entry, xPath.lstrip("/"))].shape + axisSizes.append(int(numpy.prod(dataset))) + + return signalShapes, nPoints, axisSizes + + except Exception: + _logger.warning("Fail to identify number of scan points and/or axes sizes") + return None, None, None + + def _validateSignalShapes(self, signalShapes): + # the selected signals are summed later so they must have the same shape + if not all(shape == signalShapes[0] for shape in signalShapes): + self.showMessage("Not all signal shapes are equal") + return False return True + def _validateScanGeometry(self, selection, signalShapes, nPoints, axisSizes): + scatter = selection['scatter'] + signalShape = signalShapes[0] + axisSize = axisSizes[0] + # scatter needs a single scan dimension (besides the channels) + if scatter and (len(signalShape) - 1) > 1: + self.showMessage( + "Scatter mode needs a flat per-point scan (one scan " + "dimension besides the channels).") + return False + + # the axes hold one value per scan point so they must have equal sizes + if scatter and not all(size == axisSize for size in axisSizes): + self.showMessage( + "Axes should have same number of positions " + "(as they hold one value per scan point)") + return False + + # check the number of points in the signal against the axes positions + if (nPoints > 1): + if scatter: + return self._validateScatterGeometry(selection, axisSize, nPoints) + elif len(axisSizes) == 2: + return self._validateGridGeometry(selection, axisSizes, nPoints) + return True + + def _validateScatterGeometry(self, selection, axisSize, nPoints): + if axisSize < nPoints: + self.showMessage("Fewer positions than points in scan is impossible") + return False + if axisSize > nPoints: + if not self._confirmPadding( + "There are %d motor positions but only %d scan points." + "The missing points can be padded with NaN and will be shown as empty." + % (axisSize, nPoints)): + return False + selection['allowPadding'] = True + return True + + def _validateGridGeometry(self, selection, axisSizes, nPoints): + nA, nB = axisSizes + # can cause a problem but only in unrealistic scenario + # when user want to pad symmetric scan which failed almost at the start + if (nA == nB) and (nA >= nPoints): + self.showMessage( + "Most probably the selected motor positions hold one value per scan point." + "The regular grid can not be defined." + "Enable 'Scatter plot' or select different axes.") + return False + elif (nA * nB) < nPoints: + self.showMessage( + "The selected axes define %d positions (%d x %d) but the " + "signal has %d points. Please select differently." % (nA * nB, nA, nB, nPoints)) + return False + elif (nA * nB) > nPoints: + if not self._confirmPadding( + "The %d x %d grid has %d positions but there are only %d scan points." + "The missing points can be padded with NaN and will be shown as empty." + % (nA, nB, nA * nB, nPoints)): + return False + # protecting from accidental padding + selection['allowPadding'] = True + return True + + def _confirmPadding(self, text): + """Confirm padding while wizard is open""" + msg = qt.QMessageBox(self) + msg.setIcon(qt.QMessageBox.Warning) + msg.setWindowTitle("Unfinished scan") + msg.setText(text + "\n\nContinue, or cancel to change the selection?") + contButton = msg.addButton("Continue", qt.QMessageBox.AcceptRole) + msg.addButton("Cancel", qt.QMessageBox.RejectRole) + msg.exec() + clicked = msg.clickedButton() + return clicked is contButton + def showMessage(self, text): msg = qt.QMessageBox(self) msg.setIcon(qt.QMessageBox.Information) diff --git a/src/PyMca5/PyMcaGui/pymca/QStackWidget.py b/src/PyMca5/PyMcaGui/pymca/QStackWidget.py index 6d79f275b..d24e9f1a1 100644 --- a/src/PyMca5/PyMcaGui/pymca/QStackWidget.py +++ b/src/PyMca5/PyMcaGui/pymca/QStackWidget.py @@ -63,6 +63,9 @@ from PyMca5.PyMcaGui.misc import CloseEventNotifyingWidget from PyMca5.PyMcaGui.plotting import MaskImageWidget convertToRowAndColumn = MaskImageWidget.convertToRowAndColumn +from PyMca5.PyMcaGui.plotting.MaskScatterViewWidget import MaskScatterViewWidget +from silx.gui.plot import items as silx_items +from silx.gui.plot.items.scatter import _guess_grid from PyMca5.PyMcaGui.pymca import RGBCorrelator from PyMca5.PyMcaGui.pymca import QStackWidget @@ -204,6 +207,13 @@ def _build(self, vertical=False): boxmainlayout.addWidget(self.roiWindow) self.mainLayout.addWidget(box) + # scatter mode is configured lazily in setStack when the loaded stack + # carries info["scatter"] == True (see _configureScatterMode). Two + # scatter views replace the two grid images (stack + ROI). + self._scatterStackView = None + self._scatterRoiView = None + self._scatterMode = False + #add some missing icons offset = 8 @@ -243,7 +253,8 @@ def setStack(self, *var, **kw): self.stackWidget.setImageData(None) self.roiWidget.setImageData(None) StackBase.StackBase.setStack(self, *var, **kw) - if (1 in self._stack.data.shape) and\ + scatter = self._stack.info.get("scatter", False) + if (not scatter) and (1 in self._stack.data.shape) and\ isinstance(self._stack.data, numpy.ndarray): oldshape = self._stack.data.shape dialog = ImageShapeDialog(self, shape=oldshape[0:2]) @@ -257,8 +268,9 @@ def setStack(self, *var, **kw): new_npixels = shape[0] * shape[1] old_npixels = oldshape[0] * oldshape[1] if pad_with_nan and new_npixels > old_npixels: - self._padIncompleteStack(old_npixels, new_npixels, - shape, oldshape[2]) + self._stack.padIncompleteScan( + (shape[0], shape[1]), + padPositioners=True) else: self._stack.data = self._stack.data.reshape( shape[0], shape[1], oldshape[2]) @@ -267,6 +279,7 @@ def setStack(self, *var, **kw): # make sure old ROI images are not used self._ROIImageDict["ROI"] = None StackBase.StackBase.setStack(self, self._stack, **kw) + self._configureScatterMode(scatter) if self._mcaMax is not None: self.addMcaMaxButton.show() else: @@ -296,41 +309,6 @@ def setStack(self, *var, **kw): def normalizeIconChecked(self): pass - def _padIncompleteStack(self, old_npixels, new_npixels, shape, nchannels): - """ - Pad an incomplete scan stack to the full grid size. - Data and positioner values are padded with NaN, - other per-pixel metadata is padded with zeros. - """ - # Pad spectrum - padded = numpy.full((new_npixels, nchannels), - numpy.nan, dtype=numpy.float64) - padded[:old_npixels, :] = self._stack.data.reshape( - old_npixels, nchannels) - self._stack.data = padded.reshape(shape[0], shape[1], nchannels) - - # Pad positioners - info = self._stack.info - if "positioners" in info and hasattr(info["positioners"], "items"): - for motor_name, motor_values in info["positioners"].items(): - if hasattr(motor_values, "size") and \ - motor_values.size == old_npixels: - padded_pos = numpy.full(new_npixels, - numpy.nan, - dtype=numpy.float64) - padded_pos[:old_npixels] = numpy.asarray( - motor_values).ravel() - info["positioners"][motor_name] = padded_pos - - # Pad other arrays (e.g. McaLiveTime) with zeros - for key, value in info.items(): - if hasattr(value, "size"): - arr = numpy.asarray(value) - if arr.size == old_npixels: - padded_arr = numpy.zeros(new_npixels, dtype=arr.dtype) - padded_arr[:old_npixels] = arr.ravel() - info[key] = padded_arr - def _roiSubtractBackgroundClicked(self): if not len(self._ROIImageList): return @@ -1087,7 +1065,12 @@ def calculateROIImages(self, index1, index2, imiddle=None, energy=None): def showROIImageList(self, imageList, image_names=None): xScale = self._stack.info.get("xScale", None) yScale = self._stack.info.get("yScale", None) - if self.roiBackgroundButton.isChecked(): + if self._scatterMode and (self._scatterRoiView is not None): + # required for the "Add Image" + self.roiWidget.graphWidget.graph.setGraphTitle(image_names[0]) + self._feedScatterView(self._scatterRoiView, imageList[0], full=False) + self._setRoiWidgetGridImage() + elif self.roiBackgroundButton.isChecked(): self.roiWidget.graphWidget.graph.setGraphTitle(image_names[0] + \ " Net") self.roiWidget.setImageData(imageList[0]-imageList[-1], @@ -1271,6 +1254,21 @@ def setSelectionMask(self, mask, instance_id=None): else: widget.setSelectionMask(mask, plot=True) + # mirror the mask onto the scatter views + if self._scatterMode and (self.getStackOriginalImage() is not None): + scatterMask = mask.reshape((-1,)) if mask is not None else None + # update original (left) stack + if self._scatterStackView is not None: + self._scatterStackView.getMaskToolsWidget().setSelectionMask(scatterMask) + # right view keep selection + if (self._scatterRoiView is not None) and (instance_id != id(self._scatterRoiView)): + roiMask = self._scatterRoiView.getMaskToolsWidget() + roiMask.sigMaskChanged.disconnect(self._scatterMaskChanged) + try: + roiMask.setSelectionMask(scatterMask) + finally: + roiMask.sigMaskChanged.connect(self._scatterMaskChanged) + if self.rgbWidget is not None: if hasattr(self.rgbWidget, "setSelectionMask"): self.rgbWidget.setSelectionMask(mask, instance_id=instance_id) @@ -1312,6 +1310,137 @@ def setSelectionMask(self, mask, instance_id=None): def getSelectionMask(self): return self._selectionMask + def _configureScatterMode(self, scatter): + """ + Replace the raw stack and the ROI image with MaskScatterViews + or restore the regular widgets. + + The MaskImageWidgets and MaskScatterViewWidgets + are kept alive but hidden in coressponding mode + """ + if scatter: + if self._scatterRoiView is None: + # left - orginal stack + try: + self._scatterStackView = MaskScatterViewWidget(parent=self.stackWindow, backend="mpl") + except: + self._scatterStackView = MaskScatterViewWidget(parent=self.stackWindow, backend="gl") + irregular = silx_items.Scatter.Visualization.IRREGULAR_GRID + self._scatterStackView.setVisualizationMode(irregular) + self._scatterStackView.setSelectionReadOnly() + # insert the scatter view in the origin + sidx = self.stackWindow.mainLayout.indexOf(self.stackWidget) + self.stackWindow.mainLayout.insertWidget(sidx, self._scatterStackView) + # right - ROI image + try: + self._scatterRoiView = MaskScatterViewWidget(parent=self.roiWindow, backend="mpl") + except: + self._scatterRoiView = MaskScatterViewWidget(parent=self.roiWindow, backend="gl") + self._scatterRoiView.setVisualizationMode(irregular) + # show mask tool on start + self._scatterRoiView.setMaskToolsVisible(True) + # insert the scatter view in the roi + ridx = self.roiWindow.mainLayout.indexOf(self.roiWidget) + self.roiWindow.mainLayout.insertWidget(ridx, self._scatterRoiView) + self._buildScatterControls() + self._scatterRoiView.getMaskToolsWidget().sigMaskChanged.connect(self._scatterMaskChanged) + self.stackWidget.hide() + self.roiWidget.graphWidget.hide() + self._scatterStackView.show() + self._scatterRoiView.show() + self._scatterMode = True + self._updateScatterData(full=True) + else: + if self._scatterStackView is not None: + self._scatterStackView.hide() + if self._scatterRoiView is not None: + self._scatterRoiView.hide() + self.stackWidget.show() + self.roiWidget.graphWidget.show() + self._scatterMode = False + + def _buildScatterControls(self): + """ + Recreate, control buttons on scatterStackView + """ + scatterToolBar = self._scatterStackView.addControlToolBar("Stack") + + self._scatterNormalizeAction = qt.QAction( + self.normalizeIcon, "Add spectra normalized to the number of selected pixels", scatterToolBar) + self._scatterNormalizeAction.setCheckable(True) + self._scatterNormalizeAction.setChecked(self.normalizeButton.isChecked()) + self._scatterNormalizeAction.toggled.connect(self.normalizeButton.setChecked) + self.normalizeButton.toggled.connect(self._scatterNormalizeAction.setChecked) + scatterToolBar.addAction(self._scatterNormalizeAction) + + if self.primary: + loadAction = qt.QAction(self.loadIcon, "Load another stack of same shape", scatterToolBar) + loadAction.triggered.connect(self.loadSecondaryStack) + scatterToolBar.addAction(loadAction) + + pluginAction = qt.QAction(self.pluginIcon, "Call/Load Stack Plugins", scatterToolBar) + pluginAction.triggered.connect(self._pluginClicked) + scatterToolBar.addAction(pluginAction) + + def _feedScatterView(self, view, imageData, full): + """ + Show ``imageData`` on the scatter view. + :param full: if True also populate the positioners + """ + if (view is None) or (imageData is None): + return + if full: + info = self._stack.info + view.setNumPoints(imageData.size) + view.fillPositioners(info.get("positioners", {})) + axes = info.get("scatterAxes") + if axes: + view.setSelectedPositioners(axes[0], axes[1]) + view.setData(imageData) + if full: + view.resetZoom() + + def _updateScatterData(self, full=False): + image = self.getStackOriginalImage() + self._feedScatterView(self._scatterStackView, image, full) + self._feedScatterView(self._scatterRoiView, image, full) + self._setRoiWidgetGridImage() + + def _setRoiWidgetGridImage(self): + gridded = self._scatterRegularGridImage() + if gridded is not None: + self.roiWidget.setImageData(gridded) + else: + _logger.warning("could define the regular-grid") + + def _scatterRegularGridImage(self): + item = self._scatterRoiView.getScatterView().getScatterItem() + guess = _guess_grid(item.getXData(copy=False), item.getYData(copy=False)) + if guess is None: + return None + order, (height, width) = guess + # nan padding (before padding was done only for the visualisation) + values = numpy.asarray(item.getValueData(copy=False), dtype=numpy.float64) + image = numpy.full(height * width, numpy.nan, dtype=numpy.float64) + image[:values.size] = values + if order == "row": + return image.reshape(height, width) + else: + return image.reshape(width, height).T + + def _scatterMaskChanged(self): + """ + Propagate a selection drawn on the ROI to the original stack. + """ + mtw = self._scatterRoiView.getMaskToolsWidget() + scattermask = mtw.getSelectionMask(copy=False) + original = self.getStackOriginalImage() + if (scattermask is not None) and (original is not None): + mask = scattermask.reshape(original.shape) + else: + mask = None + self.setSelectionMask(mask, instance_id=id(self._scatterRoiView)) + def _maskImageWidgetSlot(self, ddict): if ddict['event'] == "selectionMaskChanged": self.setSelectionMask(ddict['current'], instance_id=ddict['id']) @@ -1448,6 +1577,10 @@ def closeEvent(self, event): # Inform plugins for key in self.pluginInstanceDict.keys(): self.pluginInstanceDict[key].stackClosed() + if self._scatterStackView is not None: + self._scatterStackView.close() + if self._scatterRoiView is not None: + self._scatterRoiView.close() CloseEventNotifyingWidget.CloseEventNotifyingWidget.closeEvent(self, event) if (self._primaryStack is None) and __name__ == "__main__": app = qt.QApplication.instance() diff --git a/src/PyMca5/PyMcaIO/HDF5Stack1D.py b/src/PyMca5/PyMcaIO/HDF5Stack1D.py index a90df4d3a..ec08d8026 100644 --- a/src/PyMca5/PyMcaIO/HDF5Stack1D.py +++ b/src/PyMca5/PyMcaIO/HDF5Stack1D.py @@ -110,6 +110,7 @@ def loadFileList(self, filelist, selection, scanlist=None): xSelection = None else: xSelection = None + xDatasetList = [] # only one y is taken ySelection = selection['y'] if type(ySelection) == type([]): @@ -133,6 +134,10 @@ def loadFileList(self, filelist, selection, scanlist=None): else: mSelection = None + scatter = selection.get('scatter', False) + # padded only if authorized by user in wizard + allowPadding = selection.get('allowPadding', False) + USE_JUST_KEYS = False # deal with the pathological case where the scanlist corresponds # to a selected top level dataset @@ -359,11 +364,12 @@ def loadFileList(self, filelist, selection, scanlist=None): mDataset = numpy.asarray(tmpHdf[mpath], dtype=mdtype) self.monitor = [mDataset] if xSelectionList is not None: - if len(xpathList) == 1: - xpath = xpathList[0] - xDataset = tmpHdf[xpath][()] - xDatasetList = [xDataset] - self.x = [xDataset] + # all axes are loaded otherwise the scatter plot can fail + xDatasetList = [] + for xpath in xpathList: + xDatasetList.append(tmpHdf[xpath][()]) + if len(xDatasetList) == 1: + self.x = [xDatasetList[0]] if h5py.version.version < '2.0': #prevent automatic closing keeping a reference #to the open file @@ -788,6 +794,29 @@ def loadFileList(self, filelist, selection, scanlist=None): self.info['McaCalib'] = _calibration else: self.info['McaCalib'] = [ 0.0, 1.0, 0.0] + # "1D data is first dimension" with McaIndex == 0 case + # scatter / regular-grid imaging expect the channels to be last + # so normalize it once here to fit next steps + if (xSelectionList is not None) and (len(xDatasetList) == 2) \ + and (self.info.get("McaIndex") == 0) and (len(self.data.shape) == 2): + self.data = self.data.T + self.data = self.data.reshape((1,) + self.data.shape) + self.info["McaIndex"] = 2 + # For stack (1, n, nChannels) and axes (i, j) with i*j = n + # case i*j > n is also supported with padding protection + # Reshape it here to process it as a regular stack (nRows, nCols, nChannels) + if (not scatter) and (xSelectionList is not None) \ + and (len(xDatasetList) == 2) and (len(self.data.shape) == 3) \ + and (self.data.shape[0] == 1) and (self.info.get("McaIndex") == 2): + nRows = numpy.asarray(xDatasetList[0]).size + nCols = numpy.asarray(xDatasetList[1]).size + nChannels = self.data.shape[-1] + nFlat = self.data.shape[1] + if nRows * nCols == nFlat: + self.data = self.data.reshape(nRows, nCols, nChannels) + elif (nRows * nCols > nFlat) and allowPadding: + self.padIncompleteScan((nRows, nCols)) + shape = self.data.shape nSpectra = 1 for i in range(len(shape)): @@ -799,7 +828,7 @@ def loadFileList(self, filelist, selection, scanlist=None): # try to get scales scaleList = [] - if xSelectionList is not None: + if (xSelectionList is not None) and (not scatter): if len(xDatasetList) == 1: xDataset = xDatasetList[0] if xDataset.size == shape[self.info['McaIndex']]: @@ -823,7 +852,7 @@ def loadFileList(self, filelist, selection, scanlist=None): for i in range(len(self.data.shape)): dataset = xDatasetList[i].reshape(-1) datasize = self.data.shape[i] - if i == mcaIndex: + if i == self.info["McaIndex"]: self.x = [dataset] else: origin = dataset[0] @@ -842,8 +871,8 @@ def loadFileList(self, filelist, selection, scanlist=None): _logger.warning("Ignoring dimension selections %s" % xSelectionList) elif len(xDatasetList) == (len(self.data.shape) - 1): scaleList = [] - for i in range(len(self.data.shape)): - if i == mcaIndex: + for i in range(len(xDatasetList)): + if i == self.info["McaIndex"]: continue dataset = xDatasetList[i].reshape(-1) datasize = self.data.shape[i] @@ -873,8 +902,8 @@ def loadFileList(self, filelist, selection, scanlist=None): self.info["McaLiveTime"] = _time if positionersGroup: self.info["positioners"] = positioners - if (len(scaleList) == 0) and (nFiles == 1) and (nScans == 1) \ - and (len(self.data.shape) == 3): + if (len(scaleList) == 0) and (not scatter) and (nFiles == 1) \ + and (nScans == 1) and (len(self.data.shape) == 3): # try to figure out the scales from the data layout originalDir = posixpath.dirname(mcaObjectPaths["counts"]) targetDir = posixpath.dirname(mcaObjectPaths["target"]) @@ -901,7 +930,7 @@ def loadFileList(self, filelist, selection, scanlist=None): if len(dims) == len(self.data.shape): scaleList = [] for i in range(len(self.data.shape)): - if i == mcaIndex: + if i == self.info["McaIndex"]: continue dataset = dims[i] origin = dataset[0] @@ -921,6 +950,27 @@ def loadFileList(self, filelist, selection, scanlist=None): self.info["xScale"] = xScale self.info["yScale"] = yScale + if scatter and (xSelectionList is not None) and (len(xDatasetList) >= 2): + # This will never happen while using "show as 1D stack" + # Thus, all motors should be already selected + # The padding should happen anyway! + # At this point all selections should be already validated by the wizard + motors = [] + names = [] + # for scatter plot there should be only two motors + for i in (0, 1): + motors.append(numpy.asarray(xDatasetList[i]).reshape(-1)) + names.append(posixpath.basename(xSelectionList[i])) + nPoints = motors[0].size + if self.data.shape[0] == 1: + self.padIncompleteScan((1, nPoints)) + for name, motor in zip(names, motors): + # the selected axes are not treated as grid scales + positioners[name] = motor + self.info["positioners"] = positioners + self.info["scatter"] = True + self.info["scatterAxes"] = tuple(names) + def getDimensions(self, nFiles, nScans, shape, index=None): #somebody may want to overwrite this """ diff --git a/src/PyMca5/PyMcaPlugins/MaskScatterViewPlugin.py b/src/PyMca5/PyMcaPlugins/MaskScatterViewPlugin.py index 1146de35d..01eaf0ada 100644 --- a/src/PyMca5/PyMcaPlugins/MaskScatterViewPlugin.py +++ b/src/PyMca5/PyMcaPlugins/MaskScatterViewPlugin.py @@ -43,9 +43,7 @@ from PyMca5 import StackPluginBase -from silx.gui.plot.ScatterView import ScatterView -from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget -from silx.gui.colors import Colormap +from PyMca5.PyMcaGui.plotting.MaskScatterViewWidget import MaskScatterViewWidget _logger = logging.getLogger(__name__) # _logger.setLevel(logging.DEBUG) @@ -73,194 +71,6 @@ _logger.debug("GL availability: %s", isGLAvailable) -class AxesPositionersSelector(qt.QWidget): - sigSelectionChanged = qt.pyqtSignal(object, object) - - def __init__(self, parent=None): - qt.QWidget.__init__(self, parent) - hlayout = qt.QHBoxLayout() - self.setLayout(hlayout) - self._initializing = True - xlabel = qt.QLabel("X:", parent=parent) - self.xPositioner = qt.QComboBox(parent) - self.xPositioner.currentIndexChanged.connect(self._emitSelectionChanged) - - ylabel = qt.QLabel("Y:", parent=parent) - self.yPositioner = qt.QComboBox(parent) - self.yPositioner.currentIndexChanged.connect(self._emitSelectionChanged) - self._initializing = False - - hlayout.addWidget(xlabel) - hlayout.addWidget(self.xPositioner) - hlayout.addWidget(ylabel) - hlayout.addWidget(self.yPositioner) - - self._nPoints = None - """If set to an integer, only motors with this number of data points - can be added.""" - - self._initComboBoxes() - - def _initComboBoxes(self): - self.xPositioner.clear() - self.xPositioner.insertItem(0, "None") - self.yPositioner.clear() - self.yPositioner.insertItem(0, "None") - - def _emitSelectionChanged(self, idx): - if not self._initializing: - self.sigSelectionChanged.emit(*self.getSelectedPositioners()) - - def setNumPoints(self, n): - self._nPoints = n - - def unsetNumPoints(self): - self._nPoints = None - - def fillPositioners(self, positioners): - """ - - :param dict positioners: Dictionary of positioners - The key is the motor name, the value are the motor's position data - """ - currentX, currentY = self.getSelectedPositioners() - - self._initializing = True - self._initComboBoxes() - i = 0 - for motorName, motorValues in positioners.items(): - if not numpy.isscalar(motorValues) and self._nPoints is not None and self._nPoints != motorValues.size: - # checks consistency of number of data points (but accepts scalars) - continue - else: - i += 1 - self.xPositioner.insertItem(i, motorName) - self.yPositioner.insertItem(i, motorName) - - if currentX in positioners and currentY in positioners: - self.xPositioner.setCurrentIndex(self.xPositioner.findText(currentX)) - self.yPositioner.setCurrentIndex(self.yPositioner.findText(currentY)) - self._initializing = False - def getSelectedPositioners(self): - """ - - :return: 2-tuple of selected positioner names (or None) - """ - selected = [None, None] - if self.xPositioner.currentText() != "None": - selected[0] = self.xPositioner.currentText() - if self.yPositioner.currentText() != "None": - selected[1] = self.yPositioner.currentText() - return selected - - -class MaskScatterViewWidget(qt.QMainWindow): - def __init__(self, parent=None, backend="mpl"): - qt.QMainWindow.__init__(self, parent) - self._scatterView = ScatterView(parent=self, backend=backend) - self._scatterView.setColormap(Colormap("temperature")) - self._scatterView.getScatterItem().setSymbol("s") - - self._axesSelector = AxesPositionersSelector(parent=self._scatterView) - self._axesSelector.sigSelectionChanged.connect(self._setAxesData) - - self.setCentralWidget(self._scatterView) - _axesSelectorDock = BoxLayoutDockWidget() - _axesSelectorDock.setWindowTitle('Axes selection') - _axesSelectorDock.setWidget(self._axesSelector) - self.addDockWidget(qt.Qt.BottomDockWidgetArea, _axesSelectorDock) - - self._positioners = {} - self._xdata = None - self._ydata = None - self._stackImage = None - - def getMaskToolsWidget(self): - return self._scatterView.getMaskToolsWidget() - - def resetZoom(self): - return self._scatterView.resetZoom() - - def fillPositioners(self, positioners): - self._positioners = positioners - self._axesSelector.fillPositioners(positioners) - - def setNumPoints(self, n): - self._axesSelector.setNumPoints(n) - - def _setAxesData(self, xPositioner, yPositioner): - """ - - :param str xPositioner: motor name, or None - :param str yPositioner: motor name, or None - :return: - """ - if xPositioner not in [None, ""]: - assert xPositioner in self._positioners - self._xdata = self._positioners[xPositioner] - else: - self._xdata = None - if yPositioner not in [None, ""]: - assert yPositioner in self._positioners - self._ydata = self._positioners[yPositioner] - else: - self._ydata = None - if self._stackImage is not None: - self.setData() - if not self._scatterView.getMaskToolsWidget().isVisible(): - # synchronization inactive, force mask redrawing - mask = self._scatterView.getMaskToolsWidget().getSelectionMask() - if mask is not None: - self._scatterView.getMaskToolsWidget().setSelectionMask(mask) - - self._scatterView.resetZoom() - - def setData(self, stackImage=None): - first_time = self._stackImage is None - if first_time: - assert stackImage is not None - - if stackImage is None: - # use previous data - stackImage = self._stackImage - else: - # update stored data - self._stackImage = stackImage - nrows, ncols = stackImage.shape - - # flatten image - stackValues = stackImage.reshape((-1,)) - - # get regular grid coordinates as a 1D array - if self._xdata is None or self._ydata is None: - defaultX, defaultY = numpy.meshgrid(numpy.arange(ncols), - numpy.arange(nrows)) - defaultX = defaultX.reshape(*stackValues.shape) - defaultY = defaultY.reshape(*stackValues.shape) - - xdata = self._xdata if self._xdata is not None else defaultX - ydata = self._ydata if self._ydata is not None else defaultY - - if numpy.isscalar(xdata): - xdata = xdata * numpy.ones_like(stackValues) - _logger.debug("converting scalar to constant 1D array for x") - elif len(xdata.shape) > 1: - _logger.debug("flattening %s array", str(xdata.shape)) - xdata = xdata.reshape((-1,)) - - if numpy.isscalar(ydata): - ydata = ydata * numpy.ones_like(stackValues) - _logger.debug("converting scalar to constant 1D array for y") - elif len(ydata.shape) > 1: - _logger.debug("flattening %s array", str(ydata.shape)) - ydata = ydata.reshape((-1,)) - - self._scatterView.setData(xdata, ydata, stackValues, - copy=False) - if first_time: - self._scatterView.resetZoom() - - class MaskScatterViewPlugin(StackPluginBase.StackPluginBase): """ Widget to handle a stack as a scatter plot, by using