Coverage for peakipy/fitting.py: 98%

624 statements  

coverage.py v7.10.6, created at 2025-09-15 20:42 -0400

import re
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Tuple, Optional

import numpy as np
from numpy import sqrt
import pandas as pd
from rich import print
from lmfit import Model, Parameters, Parameter
from lmfit.model import ModelResult
from pydantic import BaseModel

from peakipy.lineshapes import (
    Lineshape,
    pvoigt2d,
    pv_pv,
    pv_g,
    pv_l,
    voigt2d,
    gaussian_lorentzian,
    get_lineshape_function,
)
from peakipy.constants import log2


class FitDataModel(BaseModel):
    plane: int
    clustid: int
    assignment: str
    memcnt: int
    amp: float
    height: float
    center_x_ppm: float
    center_y_ppm: float
    fwhm_x_hz: float
    fwhm_y_hz: float
    lineshape: str
    x_radius: float
    y_radius: float
    center_x: float
    center_y: float
    sigma_x: float
    sigma_y: float


class FitDataModelPVGL(FitDataModel):
    fraction: float


class FitDataModelVoigt(FitDataModel):
    fraction: float
    gamma_x: float
    gamma_y: float


class FitDataModelPVPV(FitDataModel):
    fraction_x: float
    fraction_y: float


def validate_fit_data(dict):
    lineshape = dict.get("lineshape")
    if lineshape in ["PV", "G", "L"]:
        fit_data = FitDataModelPVGL(**dict)
    elif lineshape == "V":
        fit_data = FitDataModelVoigt(**dict)
    else:
        fit_data = FitDataModelPVPV(**dict)

    return fit_data.model_dump()


def validate_fit_dataframe(df):
    validated_fit_data = []
    for _, row in df.iterrows():
        fit_data = validate_fit_data(row.to_dict())
        validated_fit_data.append(fit_data)
    return pd.DataFrame(validated_fit_data)
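
# Example (illustrative sketch, not part of the original module): validation
# dispatches on the "lineshape" key, so a row fitted with a pseudo-Voigt ("PV")
# lineshape is checked against FitDataModelPVGL. All values below are made up.
#
#   row = dict(
#       plane=0, clustid=1, assignment="ala12", memcnt=1, amp=1e6, height=2e5,
#       center_x_ppm=8.1, center_y_ppm=120.2, fwhm_x_hz=25.0, fwhm_y_hz=20.0,
#       lineshape="PV", x_radius=0.04, y_radius=0.4, center_x=100.0,
#       center_y=150.0, sigma_x=2.0, sigma_y=2.0, fraction=0.5,
#   )
#   validate_fit_data(row)  # returns a plain dict via pydantic's model_dump()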

def make_mask(data, c_x, c_y, r_x, r_y):
    """Create an elliptical mask

    Generate an elliptical boolean mask with center c_x/c_y in points
    with radii r_x and r_y. Used to generate the fit mask.

    :param data: 2D array
    :type data: np.array

    :param c_x: x center
    :type c_x: float

    :param c_y: y center
    :type c_y: float

    :param r_x: radius in x
    :type r_x: float

    :param r_y: radius in y
    :type r_y: float

    :return: boolean mask of data.shape
    :rtype: numpy.array

    """
    a, b = c_y, c_x
    n_y, n_x = data.shape
    y, x = np.ogrid[-a : n_y - a, -b : n_x - b]
    mask = x**2.0 / r_x**2.0 + y**2.0 / r_y**2.0 <= 1.0
    return mask
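
# Example (sketch): an elliptical mask centred at (c_x=5, c_y=4) with radii 3
# and 2 on a 9 x 11 plane; the boolean mask selects the points used for fitting.
#
#   data = np.zeros((9, 11))
#   mask = make_mask(data, c_x=5, c_y=4, r_x=3, r_y=2)
#   mask.shape    # (9, 11)
#   data[mask]    # 1D array of points inside the ellipse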

def fix_params(params, to_fix):
    """Set parameters to fix

    :param params: lmfit parameters
    :type params: lmfit.Parameters

    :param to_fix: list of parameter names to fix
    :type to_fix: list

    :return: updated parameter object
    :rtype: lmfit.Parameters

    """
    for k in params:
        for p in to_fix:
            if p in k:
                params[k].vary = False

    return params


def get_params(params, name):
    ps = []
    ps_err = []
    names = []
    prefixes = []
    for k in params:
        if name in k:
            ps.append(params[k].value)
            ps_err.append(params[k].stderr)
            names.append(k)
            prefixes.append(k.split(name)[0])
    return ps, ps_err, names, prefixes
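
# Example (sketch, hypothetical parameter names): fix_params stops varying any
# parameter whose name contains one of the given substrings, and get_params
# pulls values, errors, names and prefixes back out by substring.
#
#   params = Parameters()
#   params.add("_ala12_n_sigma_x", value=2.0)
#   params.add("_ala12_n_center_x", value=100.0)
#   fix_params(params, to_fix=["sigma"])   # _ala12_n_sigma_x no longer varies
#   values, errors, names, prefixes = get_params(params, "sigma_x")
#   # prefixes == ["_ala12_n_"]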

@dataclass
class PeakLimits:
    """Given a peak position and linewidth in points determine
    the limits based on the data

    Arguments
    ---------
    peak: pd.DataFrame
        peak is a row from a pandas dataframe
    data: np.array
        2D numpy array
    """

    peak: pd.DataFrame
    data: np.array
    min_x: int = field(init=False)
    max_x: int = field(init=False)
    min_y: int = field(init=False)
    max_y: int = field(init=False)

    def __post_init__(self):
        assert self.peak.Y_AXIS <= self.data.shape[0]
        assert self.peak.X_AXIS <= self.data.shape[1]
        self.max_y = int(np.ceil(self.peak.Y_AXIS + self.peak.YW)) + 1
        if self.max_y > self.data.shape[0]:
            self.max_y = self.data.shape[0]
        self.max_x = int(np.ceil(self.peak.X_AXIS + self.peak.XW)) + 1
        if self.max_x > self.data.shape[1]:
            self.max_x = self.data.shape[1]

        self.min_y = int(self.peak.Y_AXIS - self.peak.YW)
        if self.min_y < 0:
            self.min_y = 0
        self.min_x = int(self.peak.X_AXIS - self.peak.XW)
        if self.min_x < 0:
            self.min_x = 0


def estimate_amplitude(peak, data):
    assert len(data.shape) == 2
    limits = PeakLimits(peak, data)
    amplitude_est = data[limits.min_y : limits.max_y, limits.min_x : limits.max_x].sum()
    return amplitude_est
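
# Example (sketch): PeakLimits and estimate_amplitude expect a peak "row" that
# exposes X_AXIS, Y_AXIS, XW and YW attributes (a pandas Series works) and sum
# the data inside the clipped box as a crude starting amplitude.
#
#   peak = pd.Series(dict(X_AXIS=10, Y_AXIS=8, XW=3.0, YW=2.0))
#   data = np.ones((20, 30))
#   estimate_amplitude(peak, data)   # sums data[6:11, 7:14] -> 35.0 here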

def make_param_dict(peaks, data, lineshape: Lineshape = Lineshape.PV):
    """Make dict of parameter names using prefix"""

    param_dict = {}

    for _, peak in peaks.iterrows():
        str_form = lambda x: "%s%s" % (to_prefix(peak.ASS), x)
        # using exact value of points (i.e. decimal)
        param_dict[str_form("center_x")] = peak.X_AXISf
        param_dict[str_form("center_y")] = peak.Y_AXISf
        # estimate peak volume
        amplitude_est = estimate_amplitude(peak, data)
        param_dict[str_form("amplitude")] = amplitude_est
        # sigma linewidth estimate
        param_dict[str_form("sigma_x")] = peak.XW / 2.0
        param_dict[str_form("sigma_y")] = peak.YW / 2.0

        match lineshape:
            case lineshape.V:
                # Voigt G sigma from linewidth estimate
                param_dict[str_form("sigma_x")] = peak.XW / (
                    2.0 * sqrt(2.0 * log2)
                )  # 3.6013
                param_dict[str_form("sigma_y")] = peak.YW / (
                    2.0 * sqrt(2.0 * log2)
                )  # 3.6013
                # Voigt L gamma from linewidth estimate
                param_dict[str_form("gamma_x")] = peak.XW / 2.0
                param_dict[str_form("gamma_y")] = peak.YW / 2.0
                # height
                # add height here

            case lineshape.G:
                param_dict[str_form("fraction")] = 0.0
            case lineshape.L:
                param_dict[str_form("fraction")] = 1.0
            case lineshape.PV_PV:
                param_dict[str_form("fraction_x")] = 0.5
                param_dict[str_form("fraction_y")] = 0.5
            case _:
                param_dict[str_form("fraction")] = 0.5

    return param_dict
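
# Example (sketch, made-up peak list): for a single peak assigned "ala12" the
# returned dict maps prefixed names to starting values.
#
#   peaks = pd.DataFrame(dict(ASS=["ala12"], X_AXISf=[100.2], Y_AXISf=[150.7],
#                             X_AXIS=[100], Y_AXIS=[150], XW=[4.0], YW=[3.0]))
#   data = np.ones((256, 512))
#   make_param_dict(peaks, data, lineshape=Lineshape.PV)
#   # -> {"_ala12_center_x": 100.2, "_ala12_center_y": 150.7,
#   #     "_ala12_amplitude": ..., "_ala12_sigma_x": 2.0,
#   #     "_ala12_sigma_y": 1.5, "_ala12_fraction": 0.5}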

def to_prefix(x):
    """
    Peak assignments with characters that are not compatible with lmfit model
    naming are converted to lmfit "safe" names.

    :param x: Peak assignment to be used as prefix for lmfit model
    :type x: str

    :returns: lmfit model prefix (_Peak_assignment_)
    :rtype: str

    """
    # must be string
    if type(x) != str:
        x = str(x)

    prefix = "_" + x
    to_replace = [
        [".", "_"],
        [" ", ""],
        ["{", "_"],
        ["}", "_"],
        ["[", "_"],
        ["]", "_"],
        ["-", ""],
        ["/", "or"],
        ["?", "maybe"],
        ["\\", ""],
        ["(", "_"],
        [")", "_"],
        ["@", "_at_"],
    ]
    for p in to_replace:
        prefix = prefix.replace(*p)

    # Replace any remaining disallowed characters with underscore
    prefix = re.sub(r"[^a-z0-9_]", "_", prefix)
    return prefix + "_"
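
# Example (sketch):
#
#   to_prefix("ala12.n")   # -> "_ala12_n_"
#
# Note that, as written, the final re.sub only whitelists lowercase letters,
# digits and underscores, so any remaining uppercase characters are also
# replaced with "_".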

def make_models(
    model,
    peaks,
    data,
    lineshape: Lineshape = Lineshape.PV,
    xy_bounds=None,
):
    """Make composite models for multiple peaks

    :param model: lineshape function
    :type model: function

    :param peaks: instance of pandas.df.groupby("CLUSTID")
    :type peaks: pandas.df.groupby("CLUSTID")

    :param data: NMR data
    :type data: numpy.array

    :param lineshape: lineshape to use for fit (PV/G/L/PV_PV)
    :type lineshape: str

    :param xy_bounds: bounds for peak centers (+/-x, +/-y)
    :type xy_bounds: tuple

    :return mod: Composite lmfit model containing all peaks
    :rtype mod: lmfit.CompositeModel

    :return p_guess: params for composite model with starting values
    :rtype p_guess: lmfit.Parameters

    """
    if len(peaks) == 1:
        # make model for first peak
        mod = Model(model, prefix="%s" % to_prefix(peaks.ASS.iloc[0]))
        # add parameters
        param_dict = make_param_dict(
            peaks,
            data,
            lineshape=lineshape,
        )
        p_guess = mod.make_params(**param_dict)

    elif len(peaks) > 1:
        # make model for first peak
        first_peak, *remaining_peaks = peaks.iterrows()
        mod = Model(model, prefix="%s" % to_prefix(first_peak[1].ASS))
        for _, peak in remaining_peaks:
            mod += Model(model, prefix="%s" % to_prefix(peak.ASS))

        param_dict = make_param_dict(
            peaks,
            data,
            lineshape=lineshape,
        )
        p_guess = mod.make_params(**param_dict)
        # add Peak params to p_guess

    update_params(p_guess, param_dict, lineshape=lineshape, xy_bounds=xy_bounds)

    return mod, p_guess
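
# Example (sketch): peak_cluster and plane_data below stand in for one group of
# a peak list grouped by CLUSTID and a single 2D plane of the spectrum.
#
#   mod, p_guess = make_models(
#       pvoigt2d, peak_cluster, plane_data,
#       lineshape=Lineshape.PV, xy_bounds=(2.0, 2.0),
#   )
#   # p_guess holds "<prefix>center_x", "<prefix>amplitude", ... for every peak,
#   # with bounds already applied by update_params.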

def update_params(
    params, param_dict, lineshape: Lineshape = Lineshape.PV, xy_bounds=None
):
    """Update lmfit parameters with values from Peak

    :param params: lmfit parameters
    :type params: lmfit.Parameters object
    :param param_dict: parameters corresponding to each peak in fit
    :type param_dict: dict
    :param lineshape: Lineshape (PV, G, L, PV_PV etc.)
    :type lineshape: Lineshape
    :param xy_bounds: bounds on xy peak positions
    :type xy_bounds: tuple

    :returns: None
    :rtype: None

    ToDo
        -- deal with boundaries
        -- currently positions in points

    """
    for k, v in param_dict.items():
        params[k].value = v
        # print("update", k, v)
        if "center" in k:
            if xy_bounds == None:
                # no bounds set
                pass
            else:
                if "center_x" in k:
                    # set x bounds
                    x_bound = xy_bounds[0]
                    params[k].min = v - x_bound
                    params[k].max = v + x_bound
                elif "center_y" in k:
                    # set y bounds
                    y_bound = xy_bounds[1]
                    params[k].min = v - y_bound
                    params[k].max = v + y_bound
                # pass
                # print(
                #     "setting limit of %s, min = %.3e, max = %.3e"
                #     % (k, params[k].min, params[k].max)
                # )
        elif "sigma" in k:
            params[k].min = 0.0
            params[k].max = 1e4

        elif "gamma" in k:
            params[k].min = 0.0
            params[k].max = 1e4
            # print(
            #     "setting limit of %s, min = %.3e, max = %.3e"
            #     % (k, params[k].min, params[k].max)
            # )
        elif "fraction" in k:
            # fix weighting between 0 and 1
            params[k].min = 0.0
            params[k].max = 1.0

            # fix fraction of G or L
            match lineshape:
                case lineshape.G | lineshape.L:
                    params[k].vary = False
                case lineshape.PV | lineshape.PV_PV:
                    params[k].vary = True
                case _:
                    pass

    # return params


def make_mask_from_peak_cluster(group, data):
    mask = np.zeros(data.shape, dtype=bool)
    for _, peak in group.iterrows():
        mask += make_mask(
            data, peak.X_AXISf, peak.Y_AXISf, peak.X_RADIUS, peak.Y_RADIUS
        )
    return mask, peak

def select_reference_planes_using_indices(data, indices: List[int]):
    n_planes = data.shape[0]
    if indices == []:
        return data

    max_index = max(indices)
    min_index = min(indices)

    if max_index >= n_planes:
        raise IndexError(
            f"Your data has {n_planes} planes. You selected plane {max_index} (allowed indices between 0 and {n_planes-1})"
        )
    elif min_index < (-1 * n_planes):
        raise IndexError(
            f"Your data has {n_planes} planes. You selected plane {min_index} (allowed indices between -{n_planes} and {n_planes-1})"
        )
    else:
        data = data[indices]
    return data
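
# Example (sketch): pick reference planes of a pseudo-3D cube by index; an
# empty list returns the data unchanged.
#
#   cube = np.random.rand(4, 64, 64)
#   select_reference_planes_using_indices(cube, []).shape        # (4, 64, 64)
#   select_reference_planes_using_indices(cube, [0, -1]).shape   # (2, 64, 64)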

def select_planes_above_threshold_from_masked_data(data, threshold=None):
    """This function returns planes with data above the threshold.

    It currently uses absolute intensity values.
    Negative thresholds just result in return of the original data.

    """
    if threshold == None:
        selected_data = data
    else:
        selected_data = data[np.abs(data).max(axis=1) > threshold]

    if selected_data.shape[0] == 0:
        selected_data = data

    return selected_data
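
# Example (sketch): the input here is already masked and flattened per plane
# (n_planes x n_masked_points); planes whose absolute maximum does not exceed
# the threshold are dropped, unless that would drop everything.
#
#   masked = np.array([[0.1, 0.2], [5.0, 0.3]])
#   select_planes_above_threshold_from_masked_data(masked, threshold=1.0)
#   # -> array([[5.0, 0.3]])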

def validate_plane_selection(plane, pseudo3D):
    if (plane == []) or (plane == None):
        plane = list(range(pseudo3D.n_planes))

    elif max(plane) > (pseudo3D.n_planes - 1):
        raise ValueError(
            f"[red]There are {pseudo3D.n_planes} planes in your data; you selected --plane {max(plane)}...[/red] "
            f"plane numbering starts from 0."
        )
    elif min(plane) < 0:
        raise ValueError(
            f"[red]Plane number can not be negative; you selected --plane {min(plane)}...[/red]"
        )
    else:
        plane = sorted(plane)

    return plane

def slice_peaks_from_data_using_mask(data, mask):
    peak_slices = np.array([d[mask] for d in data])
    return peak_slices


def get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points):
    max_point, min_point = (
        int(np.ceil(max(group_axis_points) + mask_radius_in_points + 1)),
        int(np.floor(min(group_axis_points) - mask_radius_in_points)),
    )
    return max_point, min_point


def deal_with_peaks_on_edge_of_spectrum(data_shape, max_x, min_x, max_y, min_y):
    if min_y < 0:
        min_y = 0

    if min_x < 0:
        min_x = 0

    if max_y > data_shape[-2]:
        max_y = data_shape[-2]

    if max_x > data_shape[-1]:
        max_x = data_shape[-1]
    return max_x, min_x, max_y, min_y
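
# Example (sketch): compute a fitting window around a cluster and clip it to
# the spectrum edges (data_shape is (planes, y, x)).
#
#   max_x, min_x = get_limits_for_axis_in_points([100.0, 104.0], 5.0)  # 110, 95
#   max_y, min_y = get_limits_for_axis_in_points([30.0, 32.0], 4.0)    # 37, 26
#   deal_with_peaks_on_edge_of_spectrum((10, 40, 105), max_x, min_x, max_y, min_y)
#   # -> (105, 95, 37, 26)   max_x clipped to the x-axis size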

def make_meshgrid(data_shape):
    # must be a better way to make the meshgrid
    x = np.arange(data_shape[-1])
    y = np.arange(data_shape[-2])
    XY = np.meshgrid(x, y)
    return XY


def unpack_xy_bounds(xy_bounds, peakipy_data):
    match xy_bounds:
        case (0, 0):
            xy_bounds = None
        case (x, y):
            # convert ppm to points
            xy_bounds = list(xy_bounds)
            xy_bounds[0] = xy_bounds[0] * peakipy_data.pt_per_ppm_f2
            xy_bounds[1] = xy_bounds[1] * peakipy_data.pt_per_ppm_f1
        case _:
            raise TypeError(
                "xy_bounds should be a tuple (<x_bounds_ppm>, <y_bounds_ppm>)"
            )
    return xy_bounds
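
# Example (sketch, with a stand-in object): ppm bounds are converted to points
# using the spectrum's points-per-ppm; (0, 0) means "no bounds".
#
#   from types import SimpleNamespace
#   spectrum = SimpleNamespace(pt_per_ppm_f2=50.0, pt_per_ppm_f1=10.0)
#   unpack_xy_bounds((0, 0), spectrum)        # -> None
#   unpack_xy_bounds((0.02, 0.2), spectrum)   # -> [1.0, 2.0]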

def select_specified_planes(plane, peakipy_data):
    plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])
    # only fit specified planes
    if plane:
        inds = [i for i in plane]
        data_inds = [
            (i in inds) for i in range(peakipy_data.data.shape[peakipy_data.dims[0]])
        ]
        plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[
            data_inds
        ]
        peakipy_data.data = peakipy_data.data[data_inds]
        print(
            f"[yellow]Using only planes {plane}; data now has the following shape[/yellow]",
            peakipy_data.data.shape,
        )
        if peakipy_data.data.shape[peakipy_data.dims[0]] == 0:
            print("[red]You have excluded all the data![/red]", peakipy_data.data.shape)
            exit()
    return plane_numbers, peakipy_data


def exclude_specified_planes(exclude_plane, peakipy_data):
    plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])
    # do not fit these planes
    if exclude_plane:
        inds = [i for i in exclude_plane]
        data_inds = [
            (i not in inds)
            for i in range(peakipy_data.data.shape[peakipy_data.dims[0]])
        ]
        plane_numbers = np.arange(peakipy_data.data.shape[peakipy_data.dims[0]])[
            data_inds
        ]
        peakipy_data.data = peakipy_data.data[data_inds]
        print(
            f"[yellow]Excluding planes {exclude_plane}; data now has the following shape[/yellow]",
            peakipy_data.data.shape,
        )
        if peakipy_data.data.shape[peakipy_data.dims[0]] == 0:
            print("[red]You have excluded all the data![/red]", peakipy_data.data.shape)
            exit()
    return plane_numbers, peakipy_data

def get_fit_data_for_selected_peak_clusters(fits, clusters):
    match clusters:
        case None | []:
            pass
        case _:
            # only use these clusters
            fits = fits[fits.clustid.isin(clusters)]
            if len(fits) < 1:
                exit(f"Are you sure clusters {clusters} exist?")
    return fits


def make_masks_from_plane_data(empty_mask_array, plane_data):
    # make masks
    individual_masks = []
    for cx, cy, rx, ry, name in zip(
        plane_data.center_x,
        plane_data.center_y,
        plane_data.x_radius,
        plane_data.y_radius,
        plane_data.assignment,
    ):
        tmp_mask = make_mask(empty_mask_array, cx, cy, rx, ry)
        empty_mask_array += tmp_mask
        individual_masks.append(tmp_mask)
    filled_mask_array = empty_mask_array
    return individual_masks, filled_mask_array


def simulate_pv_pv_lineshapes_from_fitted_peak_parameters(
    peak_parameters, XY, sim_data, sim_data_singles
):
    for amp, c_x, c_y, s_x, s_y, frac_x, frac_y, ls in zip(
        peak_parameters.amp,
        peak_parameters.center_x,
        peak_parameters.center_y,
        peak_parameters.sigma_x,
        peak_parameters.sigma_y,
        peak_parameters.fraction_x,
        peak_parameters.fraction_y,
        peak_parameters.lineshape,
    ):
        sim_data_i = pv_pv(XY, amp, c_x, c_y, s_x, s_y, frac_x, frac_y).reshape(
            sim_data.shape
        )
        sim_data += sim_data_i
        sim_data_singles.append(sim_data_i)
    return sim_data, sim_data_singles


def simulate_lineshapes_from_fitted_peak_parameters(
    peak_parameters, XY, sim_data, sim_data_singles
):
    shape = sim_data.shape
    for amp, c_x, c_y, s_x, s_y, frac, lineshape in zip(
        peak_parameters.amp,
        peak_parameters.center_x,
        peak_parameters.center_y,
        peak_parameters.sigma_x,
        peak_parameters.sigma_y,
        peak_parameters.fraction,
        peak_parameters.lineshape,
    ):
        # print(amp)
        match lineshape:
            case "G" | "L" | "PV":
                sim_data_i = pvoigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape)
            case "PV_L":
                sim_data_i = pv_l(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape)

            case "PV_G":
                sim_data_i = pv_g(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape)

            case "G_L":
                sim_data_i = gaussian_lorentzian(
                    XY, amp, c_x, c_y, s_x, s_y, frac
                ).reshape(shape)

            case "V":
                sim_data_i = voigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape)
        sim_data += sim_data_i
        sim_data_singles.append(sim_data_i)
    return sim_data, sim_data_singles

@dataclass
class FitPeaksArgs:
    noise: float
    uc_dics: dict
    lineshape: Lineshape
    dims: List[int] = field(default_factory=lambda: [0, 1, 2])
    colors: Tuple[str] = ("#5e3c99", "#e66101")
    max_cluster_size: Optional[int] = None
    to_fix: List[str] = field(default_factory=lambda: ["fraction", "sigma", "center"])
    xy_bounds: Tuple[float, float] = ((0, 0),)
    vclist: Optional[Path] = (None,)
    plane: Optional[List[int]] = (None,)
    exclude_plane: Optional[List[int]] = (None,)
    reference_plane_indices: List[int] = ([],)
    initial_fit_threshold: Optional[float] = (None,)
    jack_knife_sample_errors: bool = False
    mp: bool = (True,)
    verbose: bool = (False,)
    vclist_data: Optional[np.array] = None


@dataclass
class Config:
    fit_method: str = "leastsq"


@dataclass
class FitPeaksInput:
    """input data for the fit_peaks function"""

    args: FitPeaksArgs
    data: np.array
    config: Config
    plane_numbers: list


@dataclass
class FitPeakClusterInput:
    args: FitPeaksArgs
    data: np.array
    config: Config
    plane_numbers: list
    clustid: int
    group: pd.DataFrame
    last_peak: pd.DataFrame
    mask: np.array
    mod: Model
    p_guess: Parameters
    XY: np.array
    peak_slices: np.array
    XY_slices: np.array
    min_x: float
    max_x: float
    min_y: float
    max_y: float
    uc_dics: dict
    first_plane_data: np.array
    weights: np.array
    fit_method: str = "leastsq"
    verbose: bool = False
    masked_plane_data: np.array = field(init=False)

    def __post_init__(self):
        self.masked_plane_data = np.array([d[self.mask] for d in self.data])

@dataclass
class FitResult:
    out: ModelResult
    mask: np.array
    fit_str: str
    log: str
    group: pd.core.groupby.generic.DataFrameGroupBy
    uc_dics: dict
    min_x: float
    min_y: float
    max_x: float
    max_y: float
    X: np.array
    Y: np.array
    Z: np.array
    Z_sim: np.array
    peak_slices: np.array
    XY_slices: np.array
    weights: np.array
    mod: Model

    def check_shifts(self):
        """Calculate difference between initial peak positions
        and check whether they moved too much from original
        position

        """
        pass


@dataclass
class FitPeaksResult:
    df: pd.DataFrame
    log: str


class FitPeaksResultDfRow(BaseModel):
    fit_prefix: str
    assignment: str
    amp: float
    amp_err: float
    center_x: float
    init_center_x: float
    center_y: float
    init_center_y: float
    sigma_x: float
    sigma_y: float
    clustid: int
    memcnt: int
    plane: int
    x_radius: float
    y_radius: float
    x_radius_ppm: float
    y_radius_ppm: float
    lineshape: str
    aic: float
    chisqr: float
    redchi: float
    residual_sum: float
    height: float
    height_err: float
    fwhm_x: float
    fwhm_y: float
    center_x_ppm: float
    center_y_ppm: float
    init_center_x_ppm: float
    init_center_y_ppm: float
    sigma_x_ppm: float
    sigma_y_ppm: float
    fwhm_x_ppm: float
    fwhm_y_ppm: float
    fwhm_x_hz: float
    fwhm_y_hz: float
    jack_knife_sample_index: Optional[int]


class FitPeaksResultRowGLPV(FitPeaksResultDfRow):
    fraction: float


class FitPeaksResultRowPVPV(FitPeaksResultDfRow):
    fraction_x: float  # for PV_PV model
    fraction_y: float  # for PV_PV model


class FitPeaksResultRowVoigt(FitPeaksResultDfRow):
    gamma_x_ppm: float  # for voigt
    gamma_y_ppm: float  # for voigt

def get_fit_peaks_result_validation_model(lineshape):
    """
    Retrieve the appropriate validation model based on the lineshape used for fitting.

    Parameters
    ----------
    lineshape : Lineshape
        Enum or string indicating the type of lineshape model used for fitting.

    Returns
    -------
    type
        The validation model class corresponding to the specified lineshape.
    """
    match lineshape:
        case lineshape.V:
            validation_model = FitPeaksResultRowVoigt
        case lineshape.PV_PV:
            validation_model = FitPeaksResultRowPVPV
        case _:
            validation_model = FitPeaksResultRowGLPV
    return validation_model


def filter_peak_clusters_by_max_cluster_size(grouped_peak_clusters, max_cluster_size):
    filtered_peak_clusters = grouped_peak_clusters.filter(
        lambda x: len(x) <= max_cluster_size
    )
    return filtered_peak_clusters


def set_parameters_to_fix_during_fit(first_plane_fit_params, to_fix):
    # fix sigma, center and fraction parameters
    # could add an option to select params to fix
    match to_fix:
        case None | () | []:
            float_str = "Floating all parameters"
            parameter_set = first_plane_fit_params
        case ["None"] | ["none"]:
            float_str = "Floating all parameters"
            parameter_set = first_plane_fit_params
        case _:
            float_str = f"Fixing parameters: {to_fix}"
            parameter_set = fix_params(first_plane_fit_params, to_fix)
    return parameter_set, float_str


def get_default_lineshape_param_names(lineshape: Lineshape):
    match lineshape:
        case Lineshape.PV | Lineshape.G | Lineshape.L:
            param_names = Model(pvoigt2d).param_names
        case Lineshape.V:
            param_names = Model(voigt2d).param_names
        case Lineshape.PV_PV:
            param_names = Model(pv_pv).param_names
    return param_names


def split_parameter_sets_by_peak(
    default_param_names: List, params: List[Tuple[str, Parameter]]
):
    """params is a list of tuples where the first element of each tuple is a
    prefixed parameter name and the second element is the corresponding
    Parameter object. This is created by calling .items() on a Parameters
    object
    """
    number_of_fitted_parameters = len(params)
    number_of_default_params = len(default_param_names)
    number_of_fitted_peaks = int(number_of_fitted_parameters / number_of_default_params)
    split_param_items = [
        params[i : (i + number_of_default_params)]
        for i in range(0, number_of_fitted_parameters, number_of_default_params)
    ]
    assert len(split_param_items) == number_of_fitted_peaks
    return split_param_items
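
# Example (sketch, assuming fitted_params comes from a composite PV fit of two
# peaks): Model(pvoigt2d).param_names gives six names per peak (amplitude,
# center_x, center_y, sigma_x, sigma_y, fraction), so the 12 (name, Parameter)
# tuples from .items() are chunked back into two lists of six.
#
#   default_names = get_default_lineshape_param_names(Lineshape.PV)
#   per_peak = split_parameter_sets_by_peak(default_names, list(fitted_params.items()))
#   len(per_peak)   # -> number of peaks in the cluster (2 here)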

def create_parameter_dict(prefix, parameters: List[Tuple[str, Parameter]]):
    parameter_dict = dict(prefix=prefix)
    parameter_dict.update({k.replace(prefix, ""): v.value for k, v in parameters})
    parameter_dict.update(
        {f"{k.replace(prefix,'')}_stderr": v.stderr for k, v in parameters}
    )
    return parameter_dict


def get_prefix_from_parameter_names(
    default_param_names: List, parameters: List[Tuple[str, Parameter]]
):
    prefixes = [
        param_key_val[0].replace(default_param_name, "")
        for param_key_val, default_param_name in zip(parameters, default_param_names)
    ]
    assert len(set(prefixes)) == 1
    return prefixes[0]


def unpack_fitted_parameters_for_lineshape(
    lineshape: Lineshape, params: List[dict], plane_number: int
):
    default_param_names = get_default_lineshape_param_names(lineshape)
    split_parameter_names = split_parameter_sets_by_peak(default_param_names, params)
    prefixes = [
        get_prefix_from_parameter_names(default_param_names, i)
        for i in split_parameter_names
    ]
    unpacked_params = []
    for parameter_names, prefix in zip(split_parameter_names, prefixes):
        parameter_dict = create_parameter_dict(prefix, parameter_names)
        parameter_dict.update({"plane": plane_number})
        unpacked_params.append(parameter_dict)
    return unpacked_params


def perform_initial_lineshape_fit_on_cluster_of_peaks(
    fit_peak_cluster_input: FitPeakClusterInput,
) -> FitResult:
    mod = fit_peak_cluster_input.mod
    peak_slices = fit_peak_cluster_input.peak_slices
    XY_slices = fit_peak_cluster_input.XY_slices
    p_guess = fit_peak_cluster_input.p_guess
    weights = fit_peak_cluster_input.weights
    fit_method = fit_peak_cluster_input.fit_method
    mask = fit_peak_cluster_input.mask
    XY = fit_peak_cluster_input.XY
    X, Y = XY
    first_plane_data = fit_peak_cluster_input.first_plane_data
    peak = fit_peak_cluster_input.last_peak
    group = fit_peak_cluster_input.group
    min_x = fit_peak_cluster_input.min_x
    min_y = fit_peak_cluster_input.min_y
    max_x = fit_peak_cluster_input.max_x
    max_y = fit_peak_cluster_input.max_y
    verbose = fit_peak_cluster_input.verbose
    uc_dics = fit_peak_cluster_input.uc_dics

    out = mod.fit(
        peak_slices, XY=XY_slices, params=p_guess, weights=weights, method=fit_method
    )

    if verbose:
        print(out.fit_report())

    z_sim = mod.eval(XY=XY, params=out.params)
    z_sim[~mask] = np.nan
    z_plot = first_plane_data.copy()
    z_plot[~mask] = np.nan
    fit_str = ""
    log = ""

    return FitResult(
        out=out,
        mask=mask,
        fit_str=fit_str,
        log=log,
        group=group,
        uc_dics=uc_dics,
        min_x=min_x,
        min_y=min_y,
        max_x=max_x,
        max_y=max_y,
        X=X,
        Y=Y,
        Z=z_plot,
        Z_sim=z_sim,
        peak_slices=peak_slices,
        XY_slices=XY_slices,
        weights=weights,
        mod=mod,
    )

def refit_peak_cluster_with_constraints(
    fit_input: FitPeakClusterInput, fit_result: FitPeaksResult
):
    fit_results = []
    for num, d in enumerate(fit_input.masked_plane_data):
        plane_number = fit_input.plane_numbers[num]
        fit_result.out.fit(
            data=d,
            params=fit_result.out.params,
            weights=fit_result.weights,
        )
        fit_results.extend(
            unpack_fitted_parameters_for_lineshape(
                fit_input.args.lineshape,
                list(fit_result.out.params.items()),
                plane_number,
            )
        )
    return fit_results


def merge_unpacked_parameters_with_metadata(cluster_fit_df, group_of_peaks_df):
    """
    Combine fitted peak parameters with their associated metadata.

    Parameters
    ----------
    cluster_fit_df : pd.DataFrame
        DataFrame containing peak fitting results.
    group_of_peaks_df : pd.DataFrame
        DataFrame with metadata for corresponding peaks.

    Returns
    -------
    pd.DataFrame
        Merged DataFrame with both fitting results and metadata.
    """
    group_of_peaks_df["prefix"] = group_of_peaks_df.ASS.apply(to_prefix)
    merged_cluster_fit_df = cluster_fit_df.merge(
        group_of_peaks_df, on="prefix", suffixes=["", "_init"]
    )
    return merged_cluster_fit_df


def update_cluster_df_with_fit_statistics(cluster_df, fit_result: ModelResult):
    cluster_df["chisqr"] = fit_result.chisqr
    cluster_df["redchi"] = fit_result.redchi
    cluster_df["residual_sum"] = np.sum(fit_result.residual)
    cluster_df["aic"] = fit_result.aic
    cluster_df["bic"] = fit_result.bic
    cluster_df["nfev"] = fit_result.nfev
    cluster_df["ndata"] = fit_result.ndata
    return cluster_df


def rename_columns_for_compatibility(df):
    mapping = {
        "amplitude": "amp",
        "amplitude_stderr": "amp_err",
        "X_AXIS": "init_center_x",
        "Y_AXIS": "init_center_y",
        "ASS": "assignment",
        "MEMCNT": "memcnt",
        "X_RADIUS": "x_radius",
        "Y_RADIUS": "y_radius",
    }
    df = df.rename(columns=mapping)
    return df


def add_vclist_to_df(fit_input: FitPeaksInput, df: pd.DataFrame):
    vclist_data = fit_input.args.vclist_data
    df["vclist"] = df.plane.apply(lambda x: vclist_data[x])
    return df

def prepare_group_of_peaks_for_fitting(clustid, group, fit_peaks_input: FitPeaksInput):
    lineshape_function = get_lineshape_function(fit_peaks_input.args.lineshape)

    first_plane_data = fit_peaks_input.data[0]
    mask, peak = make_mask_from_peak_cluster(group, first_plane_data)

    x_radius = group.X_RADIUS.max()
    y_radius = group.Y_RADIUS.max()

    max_x, min_x = get_limits_for_axis_in_points(
        group_axis_points=group.X_AXISf, mask_radius_in_points=x_radius
    )
    max_y, min_y = get_limits_for_axis_in_points(
        group_axis_points=group.Y_AXISf, mask_radius_in_points=y_radius
    )
    max_x, min_x, max_y, min_y = deal_with_peaks_on_edge_of_spectrum(
        fit_peaks_input.data.shape, max_x, min_x, max_y, min_y
    )
    selected_data = select_reference_planes_using_indices(
        fit_peaks_input.data, fit_peaks_input.args.reference_plane_indices
    ).sum(axis=0)
    mod, p_guess = make_models(
        lineshape_function,
        group,
        selected_data,
        lineshape=fit_peaks_input.args.lineshape,
        xy_bounds=fit_peaks_input.args.xy_bounds,
    )
    peak_slices = slice_peaks_from_data_using_mask(fit_peaks_input.data, mask)
    peak_slices = select_reference_planes_using_indices(
        peak_slices, fit_peaks_input.args.reference_plane_indices
    )
    peak_slices = select_planes_above_threshold_from_masked_data(
        peak_slices, fit_peaks_input.args.initial_fit_threshold
    )
    peak_slices = peak_slices.sum(axis=0)

    XY = make_meshgrid(fit_peaks_input.data.shape)
    X, Y = XY

    XY_slices = np.array([X.copy()[mask], Y.copy()[mask]])
    weights = 1.0 / np.array([fit_peaks_input.args.noise] * len(np.ravel(peak_slices)))
    return FitPeakClusterInput(
        args=fit_peaks_input.args,
        data=fit_peaks_input.data,
        config=fit_peaks_input.config,
        plane_numbers=fit_peaks_input.plane_numbers,
        clustid=clustid,
        group=group,
        last_peak=peak,
        mask=mask,
        mod=mod,
        p_guess=p_guess,
        XY=XY,
        peak_slices=peak_slices,
        XY_slices=XY_slices,
        weights=weights,
        fit_method=Config.fit_method,
        first_plane_data=first_plane_data,
        uc_dics=fit_peaks_input.args.uc_dics,
        min_x=min_x,
        min_y=min_y,
        max_x=max_x,
        max_y=max_y,
        verbose=fit_peaks_input.args.verbose,
    )


def fit_cluster_of_peaks(data_for_fitting: FitPeakClusterInput) -> pd.DataFrame:
    fit_result = perform_initial_lineshape_fit_on_cluster_of_peaks(data_for_fitting)
    fit_result.out.params, float_str = set_parameters_to_fix_during_fit(
        fit_result.out.params, data_for_fitting.args.to_fix
    )
    fit_results = refit_peak_cluster_with_constraints(data_for_fitting, fit_result)
    cluster_df = pd.DataFrame(fit_results)
    cluster_df = update_cluster_df_with_fit_statistics(cluster_df, fit_result.out)
    cluster_df["clustid"] = data_for_fitting.clustid
    cluster_df = merge_unpacked_parameters_with_metadata(
        cluster_df, data_for_fitting.group
    )
    return cluster_df

def fit_peak_clusters(peaks: pd.DataFrame, fit_input: FitPeaksInput) -> FitPeaksResult:
    """Fit set of peak clusters to lineshape model

    :param peaks: peak list generated by peakipy read or edit
    :type peaks: pd.DataFrame

    :param fit_input: Data structure containing input parameters (args, config and NMR data)
    :type fit_input: FitPeaksInput

    :returns: Data structure containing pd.DataFrame with the fitted results and a log
    :rtype: FitPeaksResult
    """
    peak_clusters = peaks.groupby("CLUSTID")
    filtered_peaks = filter_peak_clusters_by_max_cluster_size(
        peak_clusters, fit_input.args.max_cluster_size
    )
    peak_clusters = filtered_peaks.groupby("CLUSTID")
    out_str = ""
    cluster_dfs = []
    for clustid, peak_cluster in peak_clusters:
        data_for_fitting = prepare_group_of_peaks_for_fitting(
            clustid,
            peak_cluster,
            fit_input,
        )
        if fit_input.args.jack_knife_sample_errors:
            cluster_df = jack_knife_sample_errors(data_for_fitting)
        else:
            cluster_df = fit_cluster_of_peaks(data_for_fitting)
        cluster_dfs.append(cluster_df)
    df = pd.concat(cluster_dfs, ignore_index=True)

    df["lineshape"] = fit_input.args.lineshape.value

    if fit_input.args.vclist:
        df = add_vclist_to_df(fit_input, df)
    df = rename_columns_for_compatibility(df)
    return FitPeaksResult(df=df, log=out_str)


def jack_knife_sample_errors(fit_input: FitPeakClusterInput) -> pd.DataFrame:
    peak_slices = fit_input.peak_slices.copy()
    XY_slices = fit_input.XY_slices.copy()
    weights = fit_input.weights.copy()
    masked_plane_data = fit_input.masked_plane_data.copy()
    jk_results = []
    # first fit without jackknife
    jk_result = fit_cluster_of_peaks(data_for_fitting=fit_input)
    jk_result["jack_knife_sample_index"] = 0
    jk_results.append(jk_result)
    for i in np.arange(0, len(peak_slices), 10, dtype=int):
        fit_input.peak_slices = np.delete(peak_slices, i, None)
        XY_slices_0 = np.delete(XY_slices[0], i, None)
        XY_slices_1 = np.delete(XY_slices[1], i, None)
        fit_input.XY_slices = np.array([XY_slices_0, XY_slices_1])
        fit_input.weights = np.delete(weights, i, None)
        fit_input.masked_plane_data = np.delete(masked_plane_data, i, axis=1)
        jk_result = fit_cluster_of_peaks(data_for_fitting=fit_input)
        jk_result["jack_knife_sample_index"] = i + 1
        jk_results.append(jk_result)
    return pd.concat(jk_results, ignore_index=True)