diff --git a/suite2p/detection/sparsedetect.py b/suite2p/detection/sparsedetect.py index 794ffc89..efe96d2f 100644 --- a/suite2p/detection/sparsedetect.py +++ b/suite2p/detection/sparsedetect.py @@ -452,7 +452,7 @@ def estimate_spatial_scale(I): ipk = np.abs(I0 - maximum_filter(I0, size=(11, 11))).flatten() < 1e-4 isort = np.argsort(I0.flatten()[ipk])[::-1] im, _ = mode(imap[ipk][isort[:50]], keepdims=True) - return im.item() + return int(im.item()) def find_best_scale(I, spatial_scale): diff --git a/suite2p/io/sbx.py b/suite2p/io/sbx.py index d25072c3..865d0ed4 100644 --- a/suite2p/io/sbx.py +++ b/suite2p/io/sbx.py @@ -1,140 +1,132 @@ """ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. """ -import os +import logging import numpy as np -import logging -logger = logging.getLogger(__name__) - -from .utils import init_settings, find_files_open_binaries +logger = logging.getLogger(__name__) try: from sbxreader import sbx_memmap + HAS_SBX = True except: HAS_SBX = False - -def sbx_to_binary(settings, ndeadcols=-1, ndeadrows=0): - """ finds scanbox files and writes them to binaries + +def sbx_to_binary(dbs, settings, reg_file, reg_file_chan2): + """finds scanbox files and writes them to binaries Parameters ---------- - settings : dictionary - "nplanes", "data_path", "save_path", "save_folder", "fast_disk", - "nchannels", "keep_movie_raw", "look_one_level_down" + dbs : list of dict + Database dictionaries for each plane. Must contain keys "file_list", + "nplanes", "nchannels", "batch_size", "functional_chan", "sbx_ndeadcols", + and "sbx_ndeadrows". Updated in-place with "Ly", "Lx", "nframes", + "nframes_per_folder", "meanImg", and "meanImg_chan2". + settings : dict + Suite2p settings dictionary, saved alongside each plane's database. + reg_file : list of file objects + Opened binary files for writing each plane's functional channel data. + reg_file_chan2 : list of file objects + Opened binary files for writing each plane's second channel data + (used only when nchannels > 1). Returns ------- - settings : dictionary of first plane - "Ly", "Lx", settings["reg_file"] or settings["raw_file"] is created binary - + dbs : list of dict + Updated database dictionaries with image dimensions, frame counts, and + mean images populated. """ if not HAS_SBX: raise ImportError("sbxreader is required for this file type, please 'pip install sbxreader'") - settings1 = init_settings(settings) - # the following should be taken from the metadata and not needed but the files are initialized before... - nplanes = settings1[0]["nplanes"] - nchannels = settings1[0]["nchannels"] - # open all binary files for writing - settings1, sbxlist, reg_file, reg_file_chan2 = find_files_open_binaries(settings1) + sbxlist = dbs[0]["file_list"] + nplanes = dbs[0]["nplanes"] + nchannels = dbs[0]["nchannels"] + batch_size = dbs[0].get("batch_size", 500) + functional_chan = dbs[0].get("functional_chan", 1) + ndeadcols = int(dbs[0].get("sbx_ndeadcols", -1)) + ndeadrows = int(dbs[0].get("sbx_ndeadrows", 0)) + iall = 0 - for j in range(settings1[0]["nplanes"]): - settings1[j]["nframes_per_folder"] = np.zeros(len(sbxlist), np.int32) - ik = 0 - if "sbx_ndeadcols" in settings1[0].keys(): - ndeadcols = int(settings1[0]["sbx_ndeadcols"]) - if "sbx_ndeadrows" in settings1[0].keys(): - ndeadrows = int(settings1[0]["sbx_ndeadrows"]) + for j in range(nplanes): + dbs[j]["nframes_per_folder"] = np.zeros(len(sbxlist), np.int32) if ndeadcols == -1 or ndeadrows == -1: - # compute dead rows and cols from the first file tmpsbx = sbx_memmap(sbxlist[0]) - # do not remove dead rows in non-multiplane mode - # This number should be different for each plane since the artifact is larger - # for larger ETL jumps. - if nplanes > 1 and ndeadrows == -1: - colprofile = np.array(np.mean(tmpsbx[0][0][0], axis=1)) - ndeadrows = np.argmax(np.diff(colprofile)) + 1 - else: - ndeadrows = 0 + # do not remove dead rows in non-multiplane mode; artifact is larger for bigger ETL jumps + if ndeadrows == -1: + if nplanes > 1: + colprofile = np.array(np.mean(tmpsbx[0][0][0], axis=1)) + ndeadrows = np.argmax(np.diff(colprofile)) + 1 + else: + ndeadrows = 0 # do not remove dead columns in unidirectional scanning mode - # do this only if ndeadcols is -1 - if tmpsbx.metadata["scanning_mode"] == "bidirectional" and ndeadcols == -1: - ndeadcols = tmpsbx.ndeadcols - else: - ndeadcols = 0 + if ndeadcols == -1: + if tmpsbx.metadata["scanning_mode"] == "bidirectional": + ndeadcols = tmpsbx.ndeadcols + else: + ndeadcols = 0 del tmpsbx logger.info("Removing {0} dead columns while loading sbx data.".format(ndeadcols)) logger.info("Removing {0} dead rows while loading sbx data.".format(ndeadrows)) - settings1[0]["sbx_ndeadcols"] = ndeadcols - settings1[0]["sbx_ndeadrows"] = ndeadrows + for db in dbs: + db["sbx_ndeadcols"] = ndeadcols + db["sbx_ndeadrows"] = ndeadrows for ifile, sbxfname in enumerate(sbxlist): f = sbx_memmap(sbxfname) - nplanes = f.shape[1] - nchannels = f.shape[2] + nplanes_f = f.shape[1] + nchannels_f = f.shape[2] nframes = f.shape[0] - iblocks = np.arange(0, nframes, settings1[0]["batch_size"]) + iblocks = np.arange(0, nframes, batch_size) if iblocks[-1] < nframes: iblocks = np.append(iblocks, nframes) - # data = nframes x nplanes x nchannels x pixels x pixels - if nchannels > 1: - nfunc = settings1[0]["functional_chan"] - 1 - else: - nfunc = 0 - # loop over all frames + # data shape: nframes x nplanes x nchannels x Ly x Lx + nfunc = functional_chan - 1 if nchannels_f > 1 else 0 for ichunk, onset in enumerate(iblocks[:-1]): offset = iblocks[ichunk + 1] im = np.array(f[onset:offset, :, :, ndeadrows:, ndeadcols:]) // 2 im = im.astype(np.int16) - im2mean = im.mean(axis=0).astype(np.float32) / len(iblocks) - for ichan in range(nchannels): - nframes = im.shape[0] + nframes_batch = im.shape[0] + if iall == 0: + for j in range(nplanes_f): + dbs[j]["meanImg"] = np.zeros((im.shape[3], im.shape[4]), np.float32) + if nchannels_f > 1: + dbs[j]["meanImg_chan2"] = np.zeros( + (im.shape[3], im.shape[4]), np.float32 + ) + dbs[j]["nframes"] = 0 + for ichan in range(nchannels_f): im2write = im[:, :, ichan, :, :] - for j in range(0, nplanes): - if iall == 0: - settings1[j]["meanImg"] = np.zeros((im.shape[3], im.shape[4]), - np.float32) - if nchannels > 1: - settings1[j]["meanImg_chan2"] = np.zeros( - (im.shape[3], im.shape[4]), np.float32) - settings1[j]["nframes"] = 0 + for j in range(nplanes_f): + plane_frames = im2write[:, j, :, :].astype(np.int16) if ichan == nfunc: - settings1[j]["meanImg"] += np.squeeze(im2mean[j, ichan, :, :]) - reg_file[j].write( - bytearray(im2write[:, j, :, :].astype("int16"))) + dbs[j]["meanImg"] += plane_frames.astype(np.float32).sum(axis=0) + reg_file[j].write(bytearray(plane_frames)) + dbs[j]["nframes"] += nframes_batch + dbs[j]["nframes_per_folder"][ifile] += nframes_batch else: - settings1[j]["meanImg_chan2"] += np.squeeze(im2mean[j, ichan, :, :]) - reg_file_chan2[j].write( - bytearray(im2write[:, j, :, :].astype("int16"))) - - settings1[j]["nframes"] += im2write.shape[0] - settings1[j]["nframes_per_folder"][ifile] += im2write.shape[0] - ik += nframes - iall += nframes - - # write settings files - do_registration = settings1[0]["do_registration"] - do_nonrigid = settings1[0]["nonrigid"] - for settings in settings1: - settings["Ly"] = im.shape[3] - settings["Lx"] = im.shape[4] + dbs[j]["meanImg_chan2"] += plane_frames.astype(np.float32).sum(axis=0) + reg_file_chan2[j].write(bytearray(plane_frames)) + iall += nframes_batch + + # update dbs with image dimensions and mean images + do_registration = settings["run"]["do_registration"] + for db in dbs: + db["Ly"] = im.shape[3] + db["Lx"] = im.shape[4] if not do_registration: - settings["yrange"] = np.array([0, settings["Ly"]]) - settings["xrange"] = np.array([0, settings["Lx"]]) - #settings["meanImg"] /= settings["nframes"] - #if nchannels>1: - # settings["meanImg_chan2"] /= settings["nframes"] - np.save(settings["settings_path"], settings) - # close all binary files and write settings files - for j in range(0, nplanes): - reg_file[j].close() + db["yrange"] = np.array([0, db["Ly"]]) + db["xrange"] = np.array([0, db["Lx"]]) + db["meanImg"] /= db["nframes"] if nchannels > 1: - reg_file_chan2[j].close() - return settings1[0] + db["meanImg_chan2"] /= db["nframes"] + np.save(db["db_path"], db) + np.save(db["settings_path"], settings) + + return dbs