Source code for pychron.pipeline.editors.flux_results_editor

# ===============================================================================
# Copyright 2015 Jake Ross
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
from operator import attrgetter

# ============= enthought library imports =======================
from numpy import average, array, diff, arctan, Inf
from scipy.stats import mode
from traits.api import HasTraits, Str, Int, Bool, Float, Property, List, Event, Button, on_trait_change
from traitsui.api import View, UItem, TableEditor, VGroup, HGroup, Item, spring, Tabbed, Readonly, EnumEditor
from traitsui.extras.checkbox_column import CheckboxColumn
from traitsui.table_column import ObjectColumn
from uncertainties import nominal_value, std_dev, ufloat

from pychron.core.helpers.formatting import calc_percent_error, floatfmt
from pychron.core.helpers.iterfuncs import groupby_key
from pychron.core.regression.mean_regressor import MeanRegressor, WeightedMeanRegressor
from pychron.envisage.icon_button_editor import icon_button_editor
from pychron.pipeline.editors.flux_visualization_editor import BaseFluxVisualizationEditor
from pychron.pipeline.plot.plotter.arar_figure import SelectionFigure
from pychron.processing.argon_calculations import calculate_flux
from pychron.pychron_constants import LEAST_SQUARES_1D, WEIGHTED_MEAN_1D
from pychron.pychron_constants import PLUSMINUS_ONE_SIGMA


def mean_j(ans, use_weights, error_kind, monitor_age, lambda_k):
    js = [calculate_flux(ai.uF, monitor_age, lambda_k=lambda_k) for ai in ans]

    fs = [nominal_value(fi) for fi in js]
    es = [std_dev(fi) for fi in js]

    if use_weights:
        klass = WeightedMeanRegressor
    else:
        klass = MeanRegressor

    reg = klass(ys=fs, yserr=es, error_calc_type=error_kind)
    reg.calculate()
    av = reg.predict()
    werr = reg.predict_error(1)

    j = ufloat(av, werr)

    return j, reg.mswd


def column(klass=ObjectColumn, editable=False, **kw):
    return klass(text_font='arial 10', editable=editable, **kw)


def sciformat(x):
    return '{:0.6E}'.format(x) if x else ''


def ff(x):
    return floatfmt(x, n=2) if x else ''


MONITOR_COLUMNS = [
    column(klass=CheckboxColumn, name='use', label='Use', editable=True, width=30),
    column(klass=CheckboxColumn, name='save', label='Save', editable=True, width=30),
    column(name='hole_id', label='Hole'),
    column(name='identifier', label='Identifier'),
    column(name='sample', label='Sample', width=115),

    column(name='n', label='N'),
    column(name='saved_j', label='Saved J',
           format_func=sciformat),
    column(name='saved_jerr', label=PLUSMINUS_ONE_SIGMA,
           format_func=sciformat),
    column(name='percent_saved_error',
           label='%',
           format_func=ff),
    column(name='mean_j', label='Mean J',
           format_func=sciformat),
    column(name='mean_jerr', label=PLUSMINUS_ONE_SIGMA,
           format_func=sciformat),
    column(name='percent_mean_error',
           label='%',
           format_func=ff),
    column(name='mean_j_mswd',
           label='MSWD',
           format_func=ff),
    column(name='j', label='Pred. J',
           format_func=sciformat,
           width=75),
    column(name='jerr',
           format_func=sciformat,
           label=PLUSMINUS_ONE_SIGMA,
           width=75),
    column(name='percent_pred_error',
           label='%',
           format_func=ff),
    column(name='dev', label='dev',
           format='%0.2f',
           width=70),
    column(name='position_jerr',
           format_func=sciformat,
           label='Position Err.'),
    column(name='percent_position_jerr',
           label='%',
           format_func=ff)]

MONITOR_EDITOR = TableEditor(columns=MONITOR_COLUMNS, sortable=False, reorderable=False)


class FluxPosition(HasTraits):
    hole_id = Int
    identifier = Str
    sample = Str
    x = Float
    y = Float
    z = Float
    theta = Float
    saved_j = Float
    saved_jerr = Float

    mean_j = Float
    mean_jerr = Float
    mean_j_mswd = Float

    model_kind = Str

    n = Int

    j = Float(enter_set=True, auto_set=False)
    jerr = Float(enter_set=True, auto_set=False)
    position_jerr = Float

    use = Bool(True)
    save = Bool(True)
    dev = Float

    percent_saved_error = Property
    percent_mean_error = Property
    percent_pred_error = Property
    percent_position_jerr = Property

    analyses = List
    error_kind = Str
    monitor_age = Float
    lambda_k = Float
    was_altered = Bool

    bracket_a = Int
    bracket_b = Int
    available_positions = List

    def set_mean_j(self, use_weights):

        ans = [a for a in self.analyses if not a.is_omitted()]
        if ans:
            j, mswd = mean_j(ans, use_weights, self.error_kind, self.monitor_age, self.lambda_k)
            self.mean_j = nominal_value(j)
            self.mean_jerr = std_dev(j)
            self.mean_j_mswd = mswd

        self.n = len(ans)

    def _get_percent_saved_error(self):
        return calc_percent_error(self.saved_j, self.saved_jerr)

    def _get_percent_mean_error(self):
        if self.mean_jerr and self.mean_jerr:
            return calc_percent_error(self.mean_j, self.mean_jerr)

    def _get_percent_pred_error(self):
        if self.j and self.jerr:
            return calc_percent_error(self.j, self.jerr)

    def _get_percent_position_jerr(self):
        if self.j and self.position_jerr:
            return calc_percent_error(self.j, self.position_jerr)


class FluxResultsEditor(BaseFluxVisualizationEditor, SelectionFigure):
    save_all_button = Event
    save_unknowns_button = Event
    recalculate_button = Button('Calculate')

    suppress_metadata_change = Bool(False)
    _analyses = List

    @property
    def analyses(self):
        return self._analyses[0] if self._analyses else []

    def set_items(self, analyses):
        pass
        # if self.geometry:
        #     self.set_positions(analyses)
        #     self.predict_values()

    def _recalculate_button_fired(self):
        self.predict_values()

    def set_positions(self, monitors, unk=None):
        self.debug('setting positions mons={}, unks={}'.format(len(monitors), len(unk) if unk else 0))
        opt = self.plotter_options
        monage = opt.monitor_age * 1e6
        lk = opt.lambda_k
        ek = opt.error_kind

        geom = self.geometry
        poss = []
        ans = []
        slope = True
        prev = None

        # calculate padding of the individuals analyses
        # by taking mean of the diffs between adjacent positions divided by 4
        if opt.model_kind in (LEAST_SQUARES_1D, WEIGHTED_MEAN_1D):
            idx = 0 if self.plotter_options.one_d_axis == 'X' else 1
            vs = array([p[idx] for p in geom])
            vs = abs(diff(vs))
            vs = vs[vs.astype(bool)].mean()
        else:
            vs = [p[1]/p[0] if p[0] else Inf for p in geom]
            vs = arctan(vs)

            vs = abs(diff(vs))
            vs = mode(vs[vs.astype(bool)], axis=None)[0][0]

        padding = vs / 4.

        for identifier, ais in groupby_key(monitors, 'identifier'):

            ais = list(ais)
            n = len(ais)

            ref = ais[0]
            j = ref.j
            ip = ref.irradiation_position
            sample = ref.sample

            x, y, r, idx = geom[ip - 1]

            p = FluxPosition(identifier=identifier,
                             irradiation=self.irradiation,
                             level=self.level,
                             sample=sample, hole_id=ip,
                             saved_j=nominal_value(j),
                             saved_jerr=std_dev(j),

                             error_kind=ek,
                             monitor_age=monage,
                             analyses=ais,
                             lambda_k=lk,
                             x=x, y=y,
                             n=n)

            p.set_mean_j(self.plotter_options.use_weighted_fit)
            poss.append(p)
            if prev:
                slope = prev < p.j
            prev = p.j
            vs = self._sort_individuals(p, monage, lk, slope, padding)
            if ans:
                ans = [list(ans[i]) + list(v) for i, v in enumerate(vs)]
                # ans = [ans[0].extend(aa), ans[0].extend(xx), ans[0].extend(yy), ans[0].extend(es)]
            else:
                ans = list(vs)

        self._analyses = ans
        self.monitor_positions = sorted(poss, key=attrgetter('hole_id'))

        if unk is not None:
            ps = [p.hole_id for p in self.monitor_positions]
            for ui in unk:
                ui.available_positions = ps
            self.unknown_positions = sorted(unk, key=attrgetter('hole_id'))

    def _update_graph_metadata(self, obj, name, old, new):
        if not self.suppress_metadata_change:
            self._filter_metadata_changes(obj, self.analyses, self._recalculate_means)

    def _recalculate_means(self, sel, ans=None):
        if ans is None:
            ans = self.analyses

        if sel:
            idx = {ans[si].identifier for si in sel}
        else:
            idx = [None]

        for identifier in idx:
            # self.debug('sel:{} idx:{}'.format(sel, idx))
            for p in self.monitor_positions:
                if p.identifier == identifier:
                    # self.debug('recalculate position {} {}, {}'.format(sel, p.hole_id, p.identifier))
                    p.set_mean_j(self.plotter_options.use_weighted_fit)
                    p.was_altered = True
                elif p.was_altered:
                    # self.debug('was altered recalculate position {} {}, {}'.format(sel, p.hole_id, p.identifier))
                    p.set_mean_j(self.plotter_options.use_weighted_fit)
                    p.was_altered = False

        self.predict_values(refresh=True)

    def traits_view(self):
        unk_cols = [column(klass=CheckboxColumn, name='save', label='Save', editable=True, width=30),
                    column(name='hole_id', label='Hole'),
                    column(name='identifier', label='Identifier'),
                    column(name='sample', label='Sample', width=115),
                    column(name='model_kind', label='Model'),
                    column(name='saved_j', label='Saved J',
                           format_func=sciformat),
                    column(name='saved_jerr', label=PLUSMINUS_ONE_SIGMA,
                           format_func=sciformat),
                    column(name='percent_saved_error',
                           label='%',
                           format_func=ff),
                    column(name='j', label='Pred. J',
                           format_func=sciformat,
                           width=75),
                    column(name='jerr',
                           format_func=sciformat,
                           label=PLUSMINUS_ONE_SIGMA,
                           width=75),
                    column(name='percent_pred_error',
                           label='%',
                           format_func=ff),
                    column(name='dev', label='dev',
                           format='%0.2f',
                           width=70)]

        unk_editor = TableEditor(columns=unk_cols, sortable=False,
                                 reorderable=False)

        pgrp = VGroup(HGroup(UItem('monitor_positions', editor=MONITOR_EDITOR),
                             show_border=True, label='Monitors'),
                      HGroup(UItem('unknown_positions', editor=unk_editor),
                             show_border=True, label='Unknowns'),
                      label='Tables')

        ggrp = UItem('graph', style='custom')
        tgrp = HGroup(UItem('recalculate_button'),
                      Item('min_j', format_str='%0.4e',
                           style='readonly',
                           label='Min. J'),
                      Item('max_j',
                           style='readonly',
                           format_str='%0.4e', label='Max. J'),
                      Item('percent_j_change',
                           style='readonly',
                           format_func=floatfmt,
                           label='Delta J(%)'),
                      Readonly('holder', label='Tray'),
                      # Item('j_gradient',
                      #      style='readonly',
                      #      format_func=floatfmt,
                      #      label='Gradient J(%/cm)'),
                      spring, icon_button_editor('save_unknowns_button', 'dialog-ok-5',
                                                 tooltip='Toggle "save" for unknown positions'),
                      icon_button_editor('save_all_button', 'dialog-ok-apply-5',
                                         tooltip='Toggle "save" for all positions'))

        v = View(VGroup(tgrp, Tabbed(ggrp, pgrp)))
        return v


class BracketingFluxResultsEditor(FluxResultsEditor):
    @on_trait_change('unknown_positions:[bracket_a, bracket_b]')
    def handle_bracket(self, obj, name, old, new):
        if obj.bracket_a and obj.bracket_b:
            a, b = [p for p in self.monitor_positions if p.hole_id in (obj.bracket_a, obj.bracket_b)]

            ws = array([1 / a.mean_jerr ** 2, 1 / b.mean_jerr ** 2])
            vs = array([a.mean_j, b.mean_j])
            if self.plotter_options.use_weighted_fit:

                j = average(vs, weights=ws)
                je = sum(ws)

            else:
                j, je = vs.mean(), vs.std()

            oj = obj.saved_j
            obj.j = j
            obj.jerr = je
            obj.dev = (oj - j) / j * 100

    def traits_view(self):
        unk_cols = [column(klass=CheckboxColumn, name='save', label='Save', editable=True, width=30),
                    column(name='hole_id', label='Hole'),
                    column(name='identifier', label='Identifier'),
                    column(name='sample', label='Sample', width=115),
                    column(name='bracket_a', editable=True, label='Bracket A', editor=EnumEditor(
                        name='available_positions')),
                    column(name='bracket_b', editable=True, label='Bracket B', editor=EnumEditor(
                        name='available_positions')),

                    column(name='saved_j', label='Saved J',
                           format_func=sciformat),
                    column(name='saved_jerr', label=PLUSMINUS_ONE_SIGMA,
                           format_func=sciformat),
                    column(name='percent_saved_error',
                           label='%',
                           format_func=ff),
                    column(name='j', label='Pred. J',
                           format_func=sciformat,
                           width=75),
                    column(name='jerr',
                           format_func=sciformat,
                           label=PLUSMINUS_ONE_SIGMA,
                           width=75),
                    column(name='percent_pred_error',
                           label='%',
                           format_func=ff),
                    column(name='dev', label='dev',
                           format='%0.2f',
                           width=70)]

        unk_editor = TableEditor(columns=unk_cols, sortable=False, reorderable=False)

        pgrp = VGroup(HGroup(UItem('monitor_positions', editor=MONITOR_EDITOR),
                             show_border=True, label='Monitors'),
                      HGroup(UItem('unknown_positions', editor=unk_editor),
                             show_border=True, label='Unknowns'),
                      label='Tables')

        tgrp = HGroup(UItem('recalculate_button'),
                      Item('min_j', format_str='%0.4e',
                           style='readonly',
                           label='Min. J'),
                      Item('max_j',
                           style='readonly',
                           format_str='%0.4e', label='Max. J'),
                      Item('percent_j_change',
                           style='readonly',
                           format_func=floatfmt,
                           label='Delta J(%)'),
                      Readonly('holder', label='Tray'),
                      # Item('j_gradient',
                      #      style='readonly',
                      #      format_func=floatfmt,
                      #      label='Gradient J(%/cm)'),
                      spring, icon_button_editor('save_unknowns_button', 'dialog-ok-5',
                                                 tooltip='Toggle "save" for unknown positions'),
                      icon_button_editor('save_all_button', 'dialog-ok-apply-5',
                                         tooltip='Toggle "save" for all positions'))

        ggrp = UItem('graph', style='custom')
        v = View(VGroup(tgrp, Tabbed(ggrp, pgrp)))
        return v

# ============= EOF =============================================