Skip to content

Commit

Permalink
Merge pull request #93 from spacetelescope/run_all_improvements
Browse files Browse the repository at this point in the history
Updated centroid to have separate plot arguments, utils.run_all can now run all vetters and make diagnostic plots.
  • Loading branch information
m-dallas authored May 6, 2024
2 parents fe1cf0a + 0ad2188 commit 0f31742
Show file tree
Hide file tree
Showing 9 changed files with 525 additions and 1,432 deletions.
172 changes: 170 additions & 2 deletions exovetter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

import sys
import warnings

from exovetter import lightkurve_utils
from exovetter import utils
from exovetter import const as exo_const
import astropy.units as u
from exovetter import vetters as vet
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import numpy as np
import os

__all__ = ['sine', 'estimate_scatter', 'mark_transit_cadences', 'median_detrend',
'plateau', 'set_median_flux_to_zero', 'set_median_flux_to_one', 'sigmaClip',
'get_mast_tce', 'WqedLSF', 'compute_phases', 'first_epoch']
'get_mast_tce', 'WqedLSF', 'compute_phases', 'first_epoch', 'run_all']

def sine(x, order, period=1):
"""Sine function for SWEET vetter."""
Expand Down Expand Up @@ -653,3 +660,164 @@ def first_epoch(epoch, period, lc):
first_epoch = epoch + N*period

return first_epoch

def run_all(tce, lc, tpf=None, vetters=None, remove_metrics=None, plot=False, plot_dir=None):
"""Run a set of vetters on a tce and lc, returning a dictionary of all metrics collected
Parameters
----------
tce : tce object
tce object is a dictionary that contains information about the tce
to vet, like period, epoch, duration, depth
lc : lightkurve object
lightkurve object with the time and flux to use for vetting.
tpf : obj
``lightkurve`` target pixel file object with pixels in column lc_name
vetters : list
list of vetter objects to run on, ie [vet.ModShift(), vet.OddEven(dur_frac=0.75)]
Defaults to all vetters
remove_metrics : list
metrics to not store, defaults to removing plotting values
plot : bool
Option to return a pdf of the vetting diagnostic plots, defaults to False
plot_dir : str
Path to store diagnostic pdfs in, defaults to current working directory
Returns
-------
results_dict : dictionary
Dictionary of the kept vetting metrics
"""

# Set initial parameters
if not vetters:
vetters = [vet.VizTransits(), vet.ModShift(), vet.Lpp(), vet.OddEven(), vet.TransitPhaseCoverage(), vet.Sweet(), vet.LeoTransitEvents(), vet.Centroid()]

if not remove_metrics:
remove_metrics = ['plot_data']

if not plot_dir:
plot_dir = os.getcwd()+'/'

if not tpf:
if any(vetter.__class__.__name__ == 'Centroid' for vetter in vetters):
raise Exception("TPF file required while running centroid")

# Run all listed vetters
results_list = []

if not plot:
for vetter in vetters:
if vetter.__class__.__name__ != 'Centroid':
vetter_results = vetter.run(tce, lc, plot=False) # dictionary returned from each vetter
results_list.append(vetter_results)
else:
vetter_results = vetter.run(tce, tpf, plot=False) # centroid uses tpf rather than lc
results_list.append(vetter_results)

else:
plot_name = lc.LABEL

diagnostic_plot = PdfPages(plot_dir+plot_name+'.pdf') # initialize a pdf to save each figure into
plot_figures = []

# Manually run viz transits with an extra mark cadences plot
cadences_plot = mark_cadences_plot(lc, tce)
diagnostic_plot.savefig(cadences_plot)

for vetter in vetters:
if vetter.__class__.__name__ not in ['VizTransits', 'Centroid', 'LeoTransitEvents']: # viz_transits and Centroid generate more than one figures so handle later
vetter_results = vetter.run(tce, lc, plot=True) # dictionary returned from each vetter
plot_figures.append(plt.gcf())
plt.close() # Make them not show up if running in a notebook
results_list.append(vetter_results)

if vetter.__class__.__name__ == 'Centroid': # centroid produces 2 plots, the second of which is the most useful so just collect that one
vetter_results = vet.Centroid(lc_name=vetter.lc_name, diff_plots=False, centroid_plots=True).run(tce, tpf)
plot_figures.append(plt.gcf())
plt.close()
results_list.append(vetter_results)

# run viz_transits plots
transit = vet.VizTransits(transit_plot=True, folded_plot=False).run(tce, lc)
transit_plot = plt.gcf()
transit_plot.suptitle(plot_name+' Transits')
transit_plot.tight_layout()
plt.close()
diagnostic_plot.savefig(transit_plot)

folded = vet.VizTransits(transit_plot=False, folded_plot=True).run(tce, lc)
folded_plot = plt.gcf()
folded_plot.suptitle(plot_name+' Folded Transits')
folded_plot.tight_layout()
plt.close()
diagnostic_plot.savefig(folded_plot)

# Save each diagnostic plot stored in plot_figures to diagnostic_plot pdf file
for plot in plot_figures:
diagnostic_plot.savefig(plot)

diagnostic_plot.close()

# Convert to a single dictionary output
results_dict = {k: v for d in results_list for k, v in d.items()} # Combine all dictionaries returned from vetters

# delete specified metrics
for key in remove_metrics:
if results_dict.get(key):
del results_dict[key]

return results_dict


def mark_cadences_plot(lc, tce):
"""return figure object of the lightcurve with epochs oeverplotted"""

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(9,5))

# Get lightcurve data
time, flux, time_offset_str = lightkurve_utils.unpack_lk_version(lc, "flux") # noqa: E50
time_offset_q = getattr(exo_const, time_offset_str)
epoch = tce.get_epoch(time_offset_q).to_value(u.day)

# Get points of lightcurve found to be in transit
period = tce["period"].to_value(u.day)
dur = tce["duration"].to_value(u.day)
intransit = mark_transit_cadences(time, period, epoch, dur)

# Plot epoch
ax1.axvline(x=epoch, lw='0.6', color='r', label='epoch', alpha=0.5)

# Plot transit train
ax1.plot(time, flux, lw=0.72, alpha=0.9)
ax1.scatter(time, flux, color='k', s=3, label='cadences', alpha=0.5)

# TODO This only plots forward in time from the epoch, works fine assuming tce epoch is first in light curve but not robust
transit_epochs = epoch-dur/2
while transit_epochs <= time[-1]:
ax1.axvline(x=transit_epochs, lw='0.6', color='r', alpha=0.3, ls='--')
transit_epochs = transit_epochs+dur
ax1.axvline(x=transit_epochs, lw='0.6', color='r', alpha=0.3, ls='--')
transit_epochs = transit_epochs-dur + period

# Plot cadences in transit
ax1.scatter(time[intransit], flux[intransit], color='r', s=4, label='cadences in transit');

# Plotting params
ax1.set_ylabel('Flux')
ax1.set_xlabel('Time '+time_offset_str)
ax1.set_title(lc.label+' period='+'{0:.2f}'.format(period)+'d, dur='+'{0:.2f}'.format(dur)+'d')
ax1.legend();

cadences_plot = plt.gcf()
plt.close()

return cadences_plot


Loading

0 comments on commit 0f31742

Please sign in to comment.