from __future__ import print_function
import imageio
import matplotlib
import matplotlib.animation as animation
import matplotlib.cm as cm
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pylab as pl
from joblib import Parallel, delayed
from matplotlib.patches import Circle
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PySide6 import QtCore
from PySide6.QtCore import QRunnable, Signal, Slot
from scipy.ndimage import median_filter
from tqdm import tqdm
from piscat.InputOutput.cpu_configurations import CPUConfigurations
from piscat.Preproccessing.normalization import Normalization
class SignalConnection(QtCore.QObject):
updateProgress_ = QtCore.Signal(int)
finished = Signal()
[docs]class DisplaySubplot:
def __init__(
self, list_videos, numRows, numColumns, step=0, median_filter_flag=False, color="gray"
):
"""This class shows several videos (with the same number of frames) at
once.
Parameters
----------
list_videos: list of NDArray
List of videos.
numRows: int
It defines number of rows in sub-display.
numColumns: int
It defines number of columns in sub-display.
median_filter_flag: bool
In case it defines as True, a median filter is applied with size 3
to remove hot pixel effect.
color: str
It defines the colormap for visualization.
step: int
Stride between visualization frames.
"""
self.color = color
self.jump = step
self.pressed_key = {}
self.memory = 0
self.video = list_videos
self.median_filter_flag = median_filter_flag
self.numRows = numRows
self.numColumns = numColumns
self.numVideos = len(list_videos)
self.max_numberFrames = np.max([vid_.shape[0] for vid_ in list_videos])
self.fig, self.axes = plt.subplots(numRows, numColumns)
self.im = []
self.cb = []
self.tx = []
for vid_, ax in zip(self.video, self.axes):
self.div = make_axes_locatable(ax)
self.cax = self.div.append_axes("right", "5%", "5%")
self.cv0 = vid_[0, :, :]
im_ = ax.imshow(self.cv0, origin="lower", cmap=self.color)
self.im.append(im_)
self.cb.append(
self.fig.colorbar(im_, cax=self.cax, format=ticker.FuncFormatter(self.fmt))
)
self.tx.append(ax.set_title("Frame 0"))
self.ani = animation.FuncAnimation(
self.fig, self.animate, frames=np.arange(0, self.max_numberFrames), blit=False
)
plt.show()
def animate(self, i_):
print(i_)
self.fig.canvas.mpl_connect("key_press_event", self.press)
if self.pressed_key.get("key") == "q":
pl.close("all")
i_ = i_ + self.memory
if i_ < self.max_numberFrames:
for vid_, im, tx in zip(self.video, self.im, self.tx):
if self.median_filter_flag:
arr = median_filter(vid_[int(i_), :, :], 3)
else:
arr = median_filter(vid_[int(i_), :, :], 3)
vmax = np.max(arr)
vmin = np.min(arr)
im.set_data(arr)
im.set_clim(vmin, vmax)
tx.set_text("Frame {0}".format(i_))
self.memory = i_ + self.jump
else:
self.ani.event_source.stop()
def fmt(self, x, pos):
a, b = "{:.2e}".format(x).split("e")
b = int(b)
return r"${} \times 10^{{{}}}$".format(a, b)
def press(self, event):
print("press", event.key)
if event.key == "q":
# close the current figure
plt.close(event.canvas.figure)
self.pressed_key["key"] = event.key
[docs]class Display:
def __init__(self, video, step=0, color="gray", time_delay=0, median_filter_flag=False):
"""This class display the video.
Parameters
----------
video:(NDArray)
Input video.
median_filter_flag: bool
In case it defines as True, a median filter is applied with size 3
to remove hot pixel effect.
color: str
It defines the colormap for visualization.
time_delay: float
Delay between frames in milliseconds.
step: int
Stride between visualization frames.
"""
self.video = video
self.median_filter_flag = median_filter_flag
self.jump = step
self.time_delay = time_delay
self.memory = 0
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
self.pressed_key = {}
self.div = make_axes_locatable(self.ax)
self.cax = self.div.append_axes("right", "5%", "5%")
self.cv0 = self.video[0, :, :]
self.im = self.ax.imshow(
self.cv0, origin="lower", cmap=color
) # Here make an AxesImage rather than contour
self.cb = self.fig.colorbar(self.im, cax=self.cax, format=ticker.FuncFormatter(self.fmt))
self.tx = self.ax.set_title("Frame 0")
self.ani = animation.FuncAnimation(
self.fig,
self.animate,
frames=self.video.shape[0],
blit=False,
interval=self.time_delay,
repeat=False,
cache_frame_data=False,
)
plt.show()
def animate(self, i_):
self.fig.canvas.mpl_connect("key_press_event", self.press)
if self.pressed_key.get("key") == "q":
pl.close("all")
i_ = i_ + self.memory
if i_ < self.video.shape[0]:
arr = self.video[i_, :, :]
vmax = np.max(arr)
vmin = np.min(arr)
if self.median_filter_flag:
frame_v = median_filter(arr, 3)
else:
frame_v = arr
self.im.set_data(frame_v)
self.im.set_clim(vmin, vmax)
self.tx.set_text("Frame {0}".format(i_))
self.memory = i_ + self.jump
else:
self.ani.event_source.stop()
def fmt(self, x, pos):
a, b = "{:.2e}".format(x).split("e")
b = int(b)
return r"${} \times 10^{{{}}}$".format(a, b)
def press(self, event):
print("press", event.key)
if event.key == "q":
# close the current figure
plt.close(event.canvas.figure)
self.pressed_key["key"] = event.key
[docs]class DisplayPSFs_subplotLocalizationDisplay:
def __init__(
self,
list_videos,
list_df_PSFs,
list_titles,
numRows,
numColumns,
color="gray",
median_filter_flag=False,
imgSizex=5,
imgSizey=5,
time_delay=0.1,
):
"""This class shows several videos (with the same number of frames) at
once while highlight localize PSFs.
Parameters
----------
list_videos: list of NDArray
List of videos
list_df_PSFs: list panda data_frame
List Data Frames that contains the location of PSFs for each video.
numRows: int
It defines number of rows in sub-display
numColumns: int
It defines number of columns in sub-display
list_titles: list str
List of titles for each sub plot
median_filter_flag: bool
In case it defines as True, a median filter is applied with size 3
to remove hot pixel effect.
color: str
It defines the colormap for visualization.
imgSizex: int
Image length size.
imgSizey: int
Image width size.
time_delay: float
Delay between frames in milliseconds.
"""
self.list_video = list_videos
self.list_df_PSFs = list_df_PSFs
self.median_filter_flag = median_filter_flag
self.memory = 0
self.numRows = numRows
self.numColumns = numColumns
self.numVideos = len(list_videos)
if list_titles is None:
self.list_titles = [None for _ in range(self.numVideos)]
else:
self.list_titles = list_titles
self.fig = plt.figure(figsize=(imgSizex, imgSizey))
self.imgGrid_list = []
for i_ in range(1, self.numRows * self.numColumns + 1):
img_grid_ = self.fig.add_subplot(self.numRows, self.numColumns, i_)
img_grid_.axis("off")
self.imgGrid_list.append(img_grid_)
self.fig.tight_layout()
self.fig.subplots_adjust(wspace=0.18)
self.pressed_key = {}
self.div_0 = []
self.cax_0 = []
self.im_0 = []
self.cb_0 = []
self.tx_0 = []
for ax_, vid_, df_PSFs, tit_ in zip(
self.imgGrid_list, self.list_video, self.list_df_PSFs, self.list_titles
):
div_ = make_axes_locatable(ax_)
self.div_0.append(div_)
cax_ = div_.append_axes("right", "5%", "5%")
self.cax_0.append(cax_)
im_ = ax_.imshow(vid_[0, :, :], origin="lower", cmap=color)
self.im_0.append(im_)
self.cb_0.append(
self.fig.colorbar(im_, cax=cax_, format=ticker.FuncFormatter(self.fmt))
)
if tit_ is not None:
self.tx_0.append(ax_.set_title(tit_ + ", Frame 0"))
else:
self.tx_0.append(ax_.set_title("Frame 0"))
particle = df_PSFs.loc[df_PSFs["frame"] == 0]
particle_X = particle["x"].tolist()
particle_Y = particle["y"].tolist()
particle_sigma = particle["sigma"].tolist()
for j_ in range(len(particle_X)):
y = int(particle_Y[j_])
x = int(particle_X[j_])
sigma = particle_sigma[j_]
ax_.add_patch(
Circle(
(x, y),
radius=np.sqrt(2) * sigma,
edgecolor="r",
facecolor="none",
linewidth=2,
)
)
self.ani = animation.FuncAnimation(
self.fig,
self.animate,
frames=self.list_video[0].shape[0],
blit=False,
interval=time_delay,
repeat=True,
cache_frame_data=False,
)
plt.show()
def animate(self, i_):
self.fig.canvas.mpl_connect("key_press_event", self.press)
if self.pressed_key.get("key") == "q":
pl.close("all")
if i_ < self.list_video[0].shape[0]:
for idx_, (ax_, vid_, df_PSFs, tit_) in enumerate(
zip(self.imgGrid_list, self.list_video, self.list_df_PSFs, self.list_titles)
):
[p.remove() for p in reversed(ax_.patches)]
arr = vid_[i_, :, :]
vmax = np.max(arr)
vmin = np.min(arr)
if self.median_filter_flag:
frame_v = median_filter(arr, 3)
else:
frame_v = arr
self.im_0[idx_].set_data(frame_v)
self.im_0[idx_].set_clim(vmin, vmax)
if tit_ is not None:
self.tx_0[idx_].set_text(tit_ + ", Frame {}".format(i_))
else:
self.tx_0[idx_].set_text("Frame {}".format(i_))
particle = df_PSFs.loc[df_PSFs["frame"] == i_]
particle_X = particle["x"].tolist()
particle_Y = particle["y"].tolist()
particle_sigma = particle["sigma"].tolist()
for j_ in range(len(particle_X)):
y = int(particle_Y[j_])
x = int(particle_X[j_])
sigma = particle_sigma[j_]
ax_.add_patch(
Circle(
(x, y),
radius=np.sqrt(2) * sigma,
edgecolor="r",
facecolor="none",
linewidth=2,
)
)
else:
self.ani.event_source.stop()
def fmt(self, x, pos):
a, b = "{:.2e}".format(x).split("e")
b = int(b)
return r"${} \times 10^{{{}}}$".format(a, b)
def press(self, event):
print("press", event.key)
if event.key == "q":
# close the current figure
plt.close(event.canvas.figure)
self.pressed_key["key"] = event.key
[docs]class DisplayDataFramePSFsLocalization(QRunnable):
def __init__(self, video, df_PSFs, time_delay=0.1, GUI_progressbar=False, *args, **kwargs):
"""
This class displays video while highlighting PSFs.
Parameters
----------
video: NDArray
Input video.
df_PSFs: panda data_frame
Data Frames that contains the location of PSFs.
time_delay: float
Delay between frames in milliseconds.
GUI_progressbar: bool
This actives GUI progress bar
"""
self.cpu = CPUConfigurations()
super(DisplayDataFramePSFsLocalization, self).__init__()
self.video = video
self.time_delay = time_delay
self.df_PSFs = df_PSFs
self.pressed_key = {}
self.list_line = []
if "particle" in self.df_PSFs.keys():
self.list_particles_idx = self.df_PSFs.particle.unique()
else:
self.df_PSFs["particle"] = 0
self.list_particles_idx = self.df_PSFs.particle.unique()
self.GUI_progressbar = GUI_progressbar
self.args = args
self.kwargs = kwargs
colors_ = cm.autumn(np.linspace(0, 1, len(self.list_particles_idx)))
self.colors = colors_[0 : len(self.list_particles_idx), :]
self.obj_connection = SignalConnection()
[docs] @Slot()
def run(self):
self.gif_genrator(*self.args, **self.kwargs)
if self.GUI_progressbar is True:
self.obj_connection.finished.emit()
def show_psf(
self, jump=1, display_history=True, color_map="gray", save_flag=False, save_path=None
):
self.norm_vid = Normalization(self.video).normalized_image_specific()
img = None
pl.ion()
fig, ax = plt.subplots(1)
ax.set_aspect("equal")
if len(self.norm_vid.shape) == 3 and self.norm_vid.shape[0] > 0:
for frame_number in range(0, self.norm_vid.shape[0] - jump, jump):
if self.GUI_progressbar:
fig.canvas.mpl_connect("close_event", self.closeEvent)
else:
fig.canvas.mpl_connect("key_press_event", self.press)
if self.pressed_key.get("key") == "q":
pl.close("all")
break
im = self.norm_vid[frame_number, :, :]
pl.pause(self.time_delay)
[p.remove() for p in reversed(ax.patches)]
if img is None:
img = pl.imshow(im, cmap=color_map)
self.draw_circles(self.df_PSFs, ax, frame_number, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number, color=self.colors)
else:
img.set_data(im)
pl.draw()
pl.title("Press Q for exit\n Frame: " + str(frame_number))
self.draw_circles(self.df_PSFs, ax, frame_number, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number, color=self.colors)
if save_flag:
pl.savefig(save_path + "fig" + str(frame_number) + ".png")
else:
im = self.norm_vid
pl.pause(self.time_delay)
[p.remove() for p in reversed(ax.patches)]
if img is None:
img = pl.imshow(im)
self.draw_circles(self.df_PSFs, ax, frame_number=0, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number=0, color=self.colors)
else:
img.set_data(im)
pl.draw()
self.draw_circles(self.df_PSFs, ax, frame_number=0, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number=0, color=self.colors)
def draw_circles(self, list_psf, ax, frame_number, color="red"):
particle = list_psf.loc[list_psf["frame"] == frame_number]
particle_X = particle["x"].tolist()
particle_Y = particle["y"].tolist()
particle_sigma = particle["sigma"].tolist()
particle_labels = particle["particle"].tolist()
for j_, p_label in zip(range(len(particle_X)), particle_labels):
y = particle_Y[j_]
x = particle_X[j_]
sigma = particle_sigma[j_]
color_idx = list(self.list_particles_idx).index(p_label)
c = patches.Circle(
(x, y),
np.sqrt(2) * sigma,
color=self.colors[color_idx],
alpha=1,
fill=False,
linewidth=2,
)
ax.add_patch(c)
pl.draw()
def draw_trajectory(self, list_psf, ax, frame_number, color="red"):
particle = list_psf.loc[list_psf["frame"] == frame_number]
particle_labels = particle["particle"].tolist()
self.list_line = []
for label in particle_labels:
all_particle_ = list_psf.loc[list_psf["particle"] == label]
particle_f = all_particle_["frame"].tolist()
particle_X = all_particle_["x"].tolist()
particle_Y = all_particle_["y"].tolist()
particle_sigma = all_particle_["sigma"].tolist()
list_center = None
for f, x, y, _ in zip(particle_f, particle_X, particle_Y, particle_sigma):
if list_center is None:
center_position = [[x, y]]
list_center = center_position
list_center_array = np.asarray(list_center)
else:
center_position = [x, y]
list_center.append(center_position)
list_center_array = np.asarray(list_center)
if f <= frame_number:
color_idx = list(self.list_particles_idx).index(label)
(ln,) = ax.plot(
list_center_array[:, 0],
list_center_array[:, 1],
color=self.colors[color_idx],
linewidth=1,
)
self.list_line.append(ln)
pl.draw()
def draw_trajectory_1(self, list_psf, ax, frame_number, color="red"):
particle = list_psf.loc[list_psf["frame"] == frame_number]
particle_labels = particle["particle"].tolist()
list_line = []
for label in particle_labels:
all_particle_ = list_psf.loc[list_psf["particle"] == label]
particle_f = all_particle_["frame"].tolist()
particle_X = all_particle_["x"].tolist()
particle_Y = all_particle_["y"].tolist()
particle_sigma = all_particle_["sigma"].tolist()
list_center = None
flag_loop = True
(ln,) = ax.plot(
particle_X, particle_Y, color=self.colors[label], linewidth=1, alpha=0.7
)
list_line.append(ln)
pl.draw()
plt.pause(self.time_delay)
[ln_.remove() for ln_ in list_line]
def draw_trajectory_2(self, list_psf, ax, frame_number, color="red"):
particle = list_psf.loc[list_psf["frame"] == frame_number]
particle_labels = particle["particle"].tolist()
list_line = []
for label in particle_labels:
all_particle_ = list_psf.loc[list_psf["particle"] == label]
particle_f = all_particle_["frame"].tolist()
particle_X = all_particle_["x"].tolist()
particle_Y = all_particle_["y"].tolist()
particle_sigma = all_particle_["sigma"].tolist()
list_center = None
flag_loop = True
while flag_loop:
list_line = []
for f, x, y, _ in zip(particle_f, particle_X, particle_Y, particle_sigma):
if list_center is None:
center_position = [[x, y]]
list_center = center_position
list_center_array = np.asarray(list_center)
else:
center_position = [x, y]
list_center.append(center_position)
list_center_array = np.asarray(list_center)
if f <= frame_number:
(ln,) = ax.plot(
list_center_array[:, 0],
list_center_array[:, 1],
color=self.colors[label],
linewidth=1,
alpha=0.55,
)
list_line.append(ln)
pl.draw()
plt.pause(self.time_delay)
else:
flag_loop = False
pl.draw()
plt.pause(self.time_delay)
[ln_.remove() for ln_ in list_line]
def press(self, event):
print("press", event.key)
if event.key == "q":
plt.close(event.canvas.figure)
self.pressed_key["key"] = event.key
def closeEvent(self, event):
plt.close(event.canvas.figure)
self.pressed_key["key"] = "q"
def make_gif(self, frame_number, display_history=True, color_map="gray"):
matplotlib.use("Agg")
img = None
pl.ioff()
plt.ioff()
fig, ax = plt.subplots(1)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.set_aspect("equal")
if len(self.norm_vid.shape) == 3 and self.norm_vid.shape[0] > 0:
im = self.norm_vid[frame_number, :, :]
pl.pause(self.time_delay)
[p.remove() for p in reversed(ax.patches)]
if img is None:
img = pl.imshow(im, cmap=color_map)
self.draw_circles(self.df_PSFs, ax, frame_number, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number, color=self.colors)
else:
img.set_data(im)
pl.draw()
pl.title(str(frame_number))
self.draw_circles(self.df_PSFs, ax, frame_number, color=self.colors)
if display_history:
[ln_.remove() for ln_ in self.list_line]
self.draw_trajectory(self.df_PSFs, ax, frame_number, color=self.colors)
fig.canvas.draw() # draw the canvas, cache the renderer
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
pl.close(fig)
return image
def gif_genrator(self, save_path, jump=5, fps=1.0):
self.norm_vid = Normalization(self.video).normalized_image_specific()
image_ = []
if self.cpu.parallel_active is True:
image_ = Parallel(
n_jobs=self.cpu.n_jobs, backend=self.cpu.backend, verbose=self.cpu.verbose
)(
delayed(self.make_gif)(frame_number)
for frame_number in range(0, self.norm_vid.shape[0] - jump, jump)
)
imageio.mimsave(save_path, image_, fps=fps)
else:
for frame_number in tqdm(range(0, self.norm_vid.shape[0] - jump, jump)):
image_.append(self.make_gif(frame_number))
if self.GUI_progressbar:
self.obj_connection.updateProgress_.emit(frame_number)
imageio.mimsave(save_path, image_, fps=fps)
def histogram_of_each_frames(
frame, bins="auto", range=None, normed=None, weights=None, density=None
):
return np.histogram(
frame, bins=bins, range=range, normed=normed, weights=weights, density=density
)
def histogram_1D_signal(signal, bins="auto"):
if bins == "auto":
plt.hist(signal, bins=bins) # arguments are passed to np.histogram
plt.title("Histogram with 'auto' bins")
plt.show()
else:
plt.hist(signal, bins=bins) # arguments are passed to np.histogram
plt.title("Histogram with" + str(bins) + "bins")
plt.show()