Coverage for peakipy/plotting.py: 96%

181 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-15 20:42 -0400

1from dataclasses import dataclass, field 

2from typing import List 

3 

4import pandas as pd 

5import numpy as np 

6import plotly.graph_objects as go 

7import matplotlib.pyplot as plt 

8from matplotlib import cm 

9from matplotlib.widgets import Button 

10from matplotlib.backends.backend_pdf import PdfPages 

11from rich import print 

12 

13from peakipy.io import Pseudo3D 

14from peakipy.utils import df_to_rich_table, bad_color_selection, bad_column_selection 

15 

16 

17@dataclass 

18class PlottingDataForPlane: 

19 pseudo3D: Pseudo3D 

20 plane_id: int 

21 plane_lineshape_parameters: pd.DataFrame 

22 X: np.array 

23 Y: np.array 

24 mask: np.array 

25 individual_masks: List[np.array] 

26 sim_data: np.array 

27 sim_data_singles: List[np.array] 

28 min_x: int 

29 max_x: int 

30 min_y: int 

31 max_y: int 

32 fit_color: str 

33 data_color: str 

34 rcount: int 

35 ccount: int 

36 

37 x_plot: np.array = field(init=False) 

38 y_plot: np.array = field(init=False) 

39 masked_data: np.array = field(init=False) 

40 masked_sim_data: np.array = field(init=False) 

41 residual: np.array = field(init=False) 

42 single_colors: List = field(init=False) 

43 

44 def __post_init__(self): 

45 self.plane_data = self.pseudo3D.data[self.plane_id] 

46 self.masked_data = self.plane_data.copy() 

47 self.masked_sim_data = self.sim_data.copy() 

48 self.masked_data[~self.mask] = np.nan 

49 self.masked_sim_data[~self.mask] = np.nan 

50 

51 self.x_plot = self.pseudo3D.uc_f2.ppm( 

52 self.X[self.min_y : self.max_y, self.min_x : self.max_x] 

53 ) 

54 self.y_plot = self.pseudo3D.uc_f1.ppm( 

55 self.Y[self.min_y : self.max_y, self.min_x : self.max_x] 

56 ) 

57 self.masked_data = self.masked_data[ 

58 self.min_y : self.max_y, self.min_x : self.max_x 

59 ] 

60 self.sim_plot = self.masked_sim_data[ 

61 self.min_y : self.max_y, self.min_x : self.max_x 

62 ] 

63 self.residual = self.masked_data - self.sim_plot 

64 

65 for single_mask, single in zip(self.individual_masks, self.sim_data_singles): 

66 single[~single_mask] = np.nan 

67 self.sim_data_singles = [ 

68 sim_data_single[self.min_y : self.max_y, self.min_x : self.max_x] 

69 for sim_data_single in self.sim_data_singles 

70 ] 

71 self.single_colors = [ 

72 cm.viridis(i) for i in np.linspace(0, 1, len(self.sim_data_singles)) 

73 ] 

74 

75 

76def plot_data_is_valid(plot_data: PlottingDataForPlane) -> bool: 

77 if len(plot_data.x_plot) < 1 or len(plot_data.y_plot) < 1: 

78 print( 

79 f"[red]Nothing to plot for cluster {int(plot_data.plane_lineshape_parameters.clustid)}[/red]" 

80 ) 

81 print(f"[red]x={plot_data.x_plot},y={plot_data.y_plot}[/red]") 

82 print( 

83 df_to_rich_table( 

84 plot_data.plane_lineshape_parameters, 

85 title="", 

86 columns=bad_column_selection, 

87 styles=bad_color_selection, 

88 ) 

89 ) 

90 plt.close() 

91 validated = False 

92 # print(Fore.RED + "Maybe your F1/F2 radii for fitting were too small...") 

93 elif plot_data.masked_data.shape[0] == 0 or plot_data.masked_data.shape[1] == 0: 

94 print(f"[red]Nothing to plot for cluster {int(plot_data.plane.clustid)}[/red]") 

95 print( 

96 df_to_rich_table( 

97 plot_data.plane_lineshape_parameters, 

98 title="Bad plane", 

99 columns=bad_column_selection, 

100 styles=bad_color_selection, 

101 ) 

102 ) 

103 spec_lim_f1 = " - ".join( 

104 ["%8.3f" % i for i in plot_data.pseudo3D.f1_ppm_limits] 

105 ) 

106 spec_lim_f2 = " - ".join( 

107 ["%8.3f" % i for i in plot_data.pseudo3D.f2_ppm_limits] 

108 ) 

109 print(f"Spectrum limits are {plot_data.pseudo3D.f2_label:4s}:{spec_lim_f2} ppm") 

110 print(f" {plot_data.pseudo3D.f1_label:4s}:{spec_lim_f1} ppm") 

111 plt.close() 

112 validated = False 

113 else: 

114 validated = True 

115 return validated 

116 

117 

118def create_matplotlib_figure( 

119 plot_data: PlottingDataForPlane, 

120 pdf: PdfPages, 

121 individual=False, 

122 label=False, 

123 ccpn_flag=False, 

124 show=True, 

125 test=False, 

126): 

127 fig = plt.figure(figsize=(10, 6)) 

128 ax = fig.add_subplot(projection="3d") 

129 if plot_data_is_valid(plot_data): 

130 cset = ax.contourf( 

131 plot_data.x_plot, 

132 plot_data.y_plot, 

133 plot_data.residual, 

134 zdir="z", 

135 offset=np.nanmin(plot_data.masked_data) * 1.1, 

136 alpha=0.5, 

137 cmap=cm.coolwarm, 

138 ) 

139 cbl = fig.colorbar(cset, ax=ax, shrink=0.5, format="%.2e") 

140 cbl.ax.set_title("Residual", pad=20) 

141 

142 if individual: 

143 #  for plotting single fit surfaces 

144 single_colors = [ 

145 cm.viridis(i) 

146 for i in np.linspace(0, 1, len(plot_data.sim_data_singles)) 

147 ] 

148 [ 

149 ax.plot_surface( 

150 plot_data.x_plot, 

151 plot_data.y_plot, 

152 z_single, 

153 color=c, 

154 alpha=0.5, 

155 ) 

156 for c, z_single in zip(single_colors, plot_data.sim_data_singles) 

157 ] 

158 ax.plot_wireframe( 

159 plot_data.x_plot, 

160 plot_data.y_plot, 

161 plot_data.sim_plot, 

162 # colors=[cm.coolwarm(i) for i in np.ravel(residual)], 

163 colors=plot_data.fit_color, 

164 linestyle="--", 

165 label="fit", 

166 rcount=plot_data.rcount, 

167 ccount=plot_data.ccount, 

168 ) 

169 ax.plot_wireframe( 

170 plot_data.x_plot, 

171 plot_data.y_plot, 

172 plot_data.masked_data, 

173 colors=plot_data.data_color, 

174 linestyle="-", 

175 label="data", 

176 rcount=plot_data.rcount, 

177 ccount=plot_data.ccount, 

178 ) 

179 ax.set_ylabel(plot_data.pseudo3D.f1_label) 

180 ax.set_xlabel(plot_data.pseudo3D.f2_label) 

181 

182 # axes will appear inverted 

183 ax.view_init(30, 120) 

184 

185 title = f"Plane={plot_data.plane_id},Cluster={plot_data.plane_lineshape_parameters.clustid.iloc[0]}" 

186 plt.title(title) 

187 print(f"[green]Plotting: {title}[/green]") 

188 out_str = "Volumes (Heights)\n===========\n" 

189 for _, row in plot_data.plane_lineshape_parameters.iterrows(): 

190 out_str += f"{row.assignment} = {row.amp:.3e} ({row.height:.3e})\n" 

191 if label: 

192 ax.text( 

193 row.center_x_ppm, 

194 row.center_y_ppm, 

195 row.height * 1.2, 

196 row.assignment, 

197 (1, 1, 1), 

198 ) 

199 

200 ax.text2D( 

201 -0.5, 

202 1.0, 

203 out_str, 

204 transform=ax.transAxes, 

205 fontsize=10, 

206 fontfamily="sans-serif", 

207 va="top", 

208 bbox=dict(boxstyle="round", ec="k", fc="k", alpha=0.5), 

209 ) 

210 

211 ax.legend() 

212 

213 if show: 

214 

215 def exit_program(event): 

216 exit() 

217 

218 def next_plot(event): 

219 plt.close() 

220 

221 axexit = plt.axes([0.81, 0.05, 0.1, 0.075]) 

222 bnexit = Button(axexit, "Exit") 

223 bnexit.on_clicked(exit_program) 

224 axnext = plt.axes([0.71, 0.05, 0.1, 0.075]) 

225 bnnext = Button(axnext, "Next") 

226 bnnext.on_clicked(next_plot) 

227 if test: 

228 return 

229 if ccpn_flag: 

230 plt.show(windowTitle="", size=(1000, 500)) 

231 else: 

232 plt.show() 

233 else: 

234 pdf.savefig() 

235 

236 plt.close() 

237 

238 

239def create_plotly_wireframe_lines(plot_data: PlottingDataForPlane): 

240 lines = [] 

241 show_legend = lambda x: x < 1 

242 showlegend = False 

243 # make simulated data wireframe 

244 line_marker = dict(color=plot_data.fit_color, width=4) 

245 counter = 0 

246 for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.sim_plot): 

247 showlegend = show_legend(counter) 

248 lines.append( 

249 go.Scatter3d( 

250 x=i, 

251 y=j, 

252 z=k, 

253 mode="lines", 

254 line=line_marker, 

255 name="fit", 

256 showlegend=showlegend, 

257 ) 

258 ) 

259 counter += 1 

260 for i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.sim_plot.T): 

261 lines.append( 

262 go.Scatter3d( 

263 x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend 

264 ) 

265 ) 

266 # make experimental data wireframe 

267 line_marker = dict(color=plot_data.data_color, width=4) 

268 counter = 0 

269 for i, j, k in zip(plot_data.x_plot, plot_data.y_plot, plot_data.masked_data): 

270 showlegend = show_legend(counter) 

271 lines.append( 

272 go.Scatter3d( 

273 x=i, 

274 y=j, 

275 z=k, 

276 mode="lines", 

277 name="data", 

278 line=line_marker, 

279 showlegend=showlegend, 

280 ) 

281 ) 

282 counter += 1 

283 for i, j, k in zip(plot_data.x_plot.T, plot_data.y_plot.T, plot_data.masked_data.T): 

284 lines.append( 

285 go.Scatter3d( 

286 x=i, y=j, z=k, mode="lines", line=line_marker, showlegend=showlegend 

287 ) 

288 ) 

289 

290 return lines 

291 

292 

293def construct_surface_legend_string(row): 

294 surface_legend = "" 

295 surface_legend += row.assignment 

296 return surface_legend 

297 

298 

299def create_plotly_surfaces(plot_data: PlottingDataForPlane): 

300 data = [] 

301 color_scale_values = np.linspace(0, 1, len(plot_data.single_colors)) 

302 color_scale = [ 

303 [val, f"rgb({', '.join('%d'%(i*255) for i in c[0:3])})"] 

304 for val, c in zip(color_scale_values, plot_data.single_colors) 

305 ] 

306 for val, individual_peak, row in zip( 

307 color_scale_values, 

308 plot_data.sim_data_singles, 

309 plot_data.plane_lineshape_parameters.itertuples(), 

310 ): 

311 name = construct_surface_legend_string(row) 

312 colors = np.zeros(shape=individual_peak.shape) + val 

313 data.append( 

314 go.Surface( 

315 z=individual_peak, 

316 x=plot_data.x_plot, 

317 y=plot_data.y_plot, 

318 opacity=0.5, 

319 surfacecolor=colors, 

320 colorscale=color_scale, 

321 showscale=False, 

322 cmin=0, 

323 cmax=1, 

324 name=name, 

325 ) 

326 ) 

327 return data 

328 

329 

330def create_residual_contours(plot_data: PlottingDataForPlane): 

331 contours = go.Contour( 

332 x=plot_data.x_plot[0], y=plot_data.y_plot.T[0], z=plot_data.residual 

333 ) 

334 return contours 

335 

336 

337def create_residual_figure(plot_data: PlottingDataForPlane): 

338 data = create_residual_contours(plot_data) 

339 fig = go.Figure(data=data) 

340 fig.update_layout( 

341 title="Fit residuals", 

342 xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm", 

343 yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm", 

344 xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]), 

345 yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]), 

346 

347 ) 

348 return fig 

349 

350 

351def create_plotly_figure(plot_data: PlottingDataForPlane): 

352 lines = create_plotly_wireframe_lines(plot_data) 

353 surfaces = create_plotly_surfaces(plot_data) 

354 fig = go.Figure(data=lines + surfaces) 

355 fig = update_axis_ranges(fig, plot_data) 

356 return fig 

357 

358 

359def update_axis_ranges(fig, plot_data: PlottingDataForPlane): 

360 fig.update_layout( 

361 scene=dict( 

362 xaxis=dict(range=[plot_data.x_plot.max(), plot_data.x_plot.min()]), 

363 yaxis=dict(range=[plot_data.y_plot.max(), plot_data.y_plot.min()]), 

364 xaxis_title=f"{plot_data.pseudo3D.f2_label} ppm", 

365 yaxis_title=f"{plot_data.pseudo3D.f1_label} ppm", 

366 annotations=make_annotations(plot_data), 

367 ) 

368 ) 

369 return fig 

370 

371 

372def make_annotations(plot_data: PlottingDataForPlane): 

373 annotations = [] 

374 for row in plot_data.plane_lineshape_parameters.itertuples(): 

375 annotations.append( 

376 dict( 

377 showarrow=True, 

378 x=row.center_x_ppm, 

379 y=row.center_y_ppm, 

380 z=row.height * 1.0, 

381 text=row.assignment, 

382 opacity=0.8, 

383 textangle=0, 

384 arrowsize=1, 

385 ) 

386 ) 

387 return annotations 

388 

389 

390def validate_sample_count(sample_count): 

391 if type(sample_count) == int: 

392 sample_count = sample_count 

393 else: 

394 raise TypeError("Sample count (ccount, rcount) should be an integer") 

395 return sample_count 

396 

397 

398def unpack_plotting_colors(colors): 

399 match colors: 

400 case (data_color, fit_color): 

401 data_color, fit_color = colors 

402 case _: 

403 data_color, fit_color = "green", "blue" 

404 return data_color, fit_color