"""
Plot fgmax output from GeoClaw run.

"""

from pylab import *
from matplotlib import image
from numpy import ma
import os
from clawpack.geoclaw import fgmax_tools, geoplot, dtopotools
import interptools

# Specific to WA coast project: WAcoast is the project root directory; the
# dtopo source file and Google Earth map images below are read relative to it.
try:
    WAcoast = os.environ['WAcoast']
except KeyError:
    # Only a missing variable is expected here; other errors should surface.
    raise Exception("Need to set WAcoast environment variable")

# Finest-level topo file used in the run (kept for reference, not needed):
# topofile = WAcoast + '/topo/la_push_wa.asc'

# Earthquake source (dtopo) model, used to determine subsidence/uplift at
# the fgmax points, and the format code of that dtopo file.
dtopofile = WAcoast + '/dtopo/CSZ_L1.tt3'
dtopotype = 3

# Per-fixed-grid (FG) settings, keyed by FG number.  Several FG solutions
# can be plotted: for each FG, add (or remove) one entry in every dictionary
# below and make sure make_plots is called for it at the bottom of this file.
# GEmap is a Google Earth image of the FG region to plot on top of, and
# GEextent its (x1, x2, y1, y2) lon/lat extent; both must be changed for
# each location on the coast.
title_name, FG_name, fgmax_input_file = {}, {}, {}
FGfigno, GEmap, GEextent = {}, {}, {}

# ---- FG 1: NeahMakah ----
title_name[1] = "NeahMakah "
FG_name[1] = "NeahMakah"
fgmax_input_file[1] = "fgmax1.txt"
FGfigno[1] = 201
GEmap[1] = image.imread(WAcoast + '/maps/NeahMakahFGmax.png')
GEextent[1] = (-124.692, -124.576, 48.2710, 48.3965)

### PRIMARY FUNCTION DEFINITION: MAKE ALL PLOTS FOR ONE FIXED GRID (FG); CALL IT ONCE PER FG

def make_plots(outdir='_output', plotdir='_plots', gridno=1):
    """
    Make all fgmax plots for fixed grid (FG) number *gridno*.

    Reads the FG description file ``fgmax<gridno>.txt`` from the current
    directory and the GeoClaw fgmax output for this grid from *outdir*
    (presumably ``fort.FG<gridno>.valuemax``, the name printed below), then
    saves PNG files named ``<FG_name>_<plot>.png`` in *plotdir*: max zeta
    (with and without the Google Earth image), time of max zeta, first
    arrival time and, when enabled below, max speed, max momentum (hs),
    max momentum flux (hss) and min depth.

    Parameters
    ----------
    outdir : str
        Directory holding the GeoClaw output; must already exist.
    plotdir : str
        Directory for the plot files; created (a single level) if missing.
    gridno : int
        FG number; key into the module-level FG_name, GEmap and GEextent.

    Raises
    ------
    Exception
        If *outdir* does not exist.
    """
    # ---- Files and directories ----
    topdir = '.'  # FG description files are read from here
    if not os.path.isdir(outdir):
        raise Exception("Missing directory: %s" % outdir)
    if not os.path.isdir(plotdir):
        os.mkdir(plotdir)

    fg_name = FG_name[gridno]
    fgmax_file = os.path.join(topdir, 'fgmax%s.txt' % gridno)  # FG description
    fname = os.path.join(outdir, 'fort.FG%s.valuemax' % gridno)  # GeoClaw output

    print('')
    print('********* Start %s plots ************' % fg_name)
    print('')
    print('gridno = %s' % gridno)
    print('FG_name = %s' % fg_name)
    print('Reading output from %s' % outdir)
    print('Using fgmax input from %s' % fgmax_file)
    print('Reading %s ...' % fname)
    print('x1,x2,y1,y2 = %s' % (GEextent[gridno],))

    # ---- Plotting parameters common to each FG ----
    figsize = (9, 7)    # may want to adjust for different regions
    cbar_shrink = 1.0   # to shrink colorbar
    # Topography contour colours; more colours e.g. at
    #  http://www.discoveryplayground.com/computer-programming-for-kids/rgb-colors/
    LightGray = '#d3d3d3'
    linecolors_topo = ['k', LightGray]  # coast = black, higher = light gray
    clines_topo = linspace(0, 20, 2)    # contours for topography
    # First (coast) topography contour thick, the rest thin.
    topo_linewidths = [1] + [0.5] * (len(clines_topo) - 1)

    # ---- Which plots to make ----
    plot_zeta = True                    # max zeta without the GE image
    plot_zeta_map = True                # max zeta on the GE image; False if none
    plot_zeta_times = True              # time of maximum flow depth
    plot_arrival_times = True           # time of first arrival
    plot_arrival_times_on_zeta = False  # arrival-time contours on zeta plots
    # speed, hs, hss and min_depth might not be available; to get them set
    # rundata.fgmax_data.num_fgmax_val = 2 or 5 in setrun.py.
    plot_speed = True                   # max flow speed
    plot_others = True                  # max momentum, momentum flux, min depth

    # ---- Contour levels ----
    clines_zeta = [0, 0.25, 0.5, 0.75, 1, 1.5, 2, 5, 10, 16]   # meters
    clines_t = (0, 2, 4, 6, 8, 10, 15, 20, 30, 40, 50, 60)     # minutes; None = auto
    clines_t_label = clines_t[::2] if clines_t is not None else None  # labelled ones
    clines_t_colors = (.5, .5, .5)   # RGB of arrival-time contour lines
    clines_speed = [0, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14]  # m/s
    clines_hs = [0, 50, 100, 120, 140, 150, 160, 170, 180, 200]          # momentum
    clines_hss = [0, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]  # mom. flux
    clines_min_depth = [0, 0.5, 1, 1.5, 2, 2.5, 3]                        # min depth

    # ---- Read the FG description and GeoClaw results ----
    fg = fgmax_tools.FGmaxGrid()
    fg.read_input_data(input_file_name=fgmax_file)
    fg.read_output(fgno=gridno, outdir=outdir)

    # ---- Bottom deformation interpolated to the FG points ----
    dtopo = dtopotools.DTopography()
    dtopo.read(dtopofile, dtopotype)
    # NOTE(review): dZ[1] presumably is the deformation at the second dtopo
    # time; for a source with more than two times the final deformation may
    # be dZ[-1] -- confirm.
    dzi = interptools.interp(fg.X, fg.Y, dtopo.X, dtopo.Y, dtopo.dZ[1, :, :])
    print('Average subsidence/uplift:  %s m' % dzi.mean())

    # ---- Derived fields ----
    sea_level = 0.      # relative to MHW
    dry_tol = 1e-3      # mask colours where max depth h < dry_tol
    x, y = fg.X, fg.Y
    aspect = 1. / cos(y.mean() * pi / 180.)  # scale longitude by latitude
    h = fg.h            # max depth
    B = fg.B            # final GeoClaw topography
    # Remove the subsidence/uplift to recover the pre-earthquake bathymetry.
    B0 = B - dzi
    on_shore = (B0 > sea_level)   # relative to original bathymetry B0
    dry = (h < dry_tol)
    # zeta = h onshore, h + B (surface elevation) offshore; masked where dry.
    zeta = ma.masked_where(dry, where(on_shore, h, h + B))
    arrival_time = fg.arrival_time / 60.  # convert to minutes

    # ---- Shared plotting helpers ----
    def time_levels(t):
        """Return clines_t, or 10 even levels spanning *t* if clines_t is None."""
        if clines_t is None:
            return linspace(t.min(), t.max(), 10)
        return clines_t

    def add_filled_contours(v, clines, colors, units):
        """Filled contours of *v* with a colorbar, plus topography contours."""
        contourf(x, y, v, clines, colors=colors, extend='max')
        cbar = colorbar(shrink=cbar_shrink, spacing='proportional')
        cbar.set_ticks(clines)
        cbar.set_label(units, fontsize=15)
        contour(x, y, B0, clines_topo, colors=linecolors_topo,
                linestyles='-', linewidths=topo_linewidths)

    def finish_plot(title_text, basename):
        """Format the lon/lat axes, add the title and save to plotdir."""
        ticklabel_format(style='plain', useOffset=False)
        xticks(rotation=20)
        gca().set_aspect(aspect)
        title(title_text, fontsize=14)
        png_path = os.path.join(plotdir, fg_name + '_' + basename)
        savefig(png_path)
        print('Created %s' % png_path)

    # ---- Max zeta, optionally over the GE image and with arrival times ----
    def plot_zeta_max(on_map):
        figure(201, figsize=figsize)
        clf()
        if on_map:
            imshow(GEmap[gridno], extent=GEextent[gridno])
        print('max zeta = %s' % zeta.max())
        add_filled_contours(zeta, clines_zeta,
                            geoplot.discrete_cmap_1(clines_zeta), 'meters')
        if plot_arrival_times_on_zeta:
            cs = contour(x, y, arrival_time, time_levels(arrival_time),
                         colors=clines_t_colors)
            clabel(cs, clines_t_label)
            title_text = fg_name + ' Zeta Maximum and arrival times'
        else:
            title_text = fg_name + ' Zeta Maximum'
        finish_plot(title_text, 'zeta_map.png' if on_map else 'zeta.png')

    if plot_zeta:
        plot_zeta_max(on_map=False)
    if plot_zeta_map:
        plot_zeta_max(on_map=True)

    if plot_zeta_times:
        # Time at which the maximum depth h was recorded, in minutes.
        tzeta = ma.masked_where(dry, fg.h_time / 60.)
        figure(102, figsize=figsize)
        clf()
        levels = time_levels(tzeta)
        add_filled_contours(tzeta, levels, geoplot.discrete_cmap_2(levels),
                            'minutes')
        finish_plot(fg_name + ' Time of max zeta', 'zetatimes.png')

    if plot_arrival_times:
        figure(103, figsize=figsize)
        clf()
        levels = time_levels(arrival_time)
        add_filled_contours(arrival_time, levels,
                            geoplot.discrete_cmap_2(levels), 'minutes')
        finish_plot(fg_name + ' Arrival time', 'arrival_times.png')

    # ---- speed, hs, hss, min_depth ----
    def plot_variable(name, v, clines, units='m', on_map=True):
        """Plot variable *v* (named *name*) in its own figure and save it."""
        fig = figure(figsize=figsize)
        clf()
        if on_map:
            imshow(GEmap[gridno], extent=GEextent[gridno])
        print('max %s = %s' % (name, v.max()))
        add_filled_contours(v, clines, geoplot.discrete_cmap_1(clines), units)
        if name == 'min_depth':
            title_text = fg_name + ' Minimum flow depth'
        else:
            title_text = fg_name + ' Maximum %s' % name
        finish_plot(title_text, ('%s_map.png' if on_map else '%s.png') % name)
        close(fig)  # a new figure is made per call; free it once saved

    if plot_speed:
        speed = ma.masked_where(dry, fg.s)
        for on_map in (True, False):
            plot_variable('speed', speed, clines_speed, 'm/s', on_map)
    if plot_others:
        hs = ma.masked_where(dry, fg.hs)
        hss = ma.masked_where(dry, fg.hss)
        min_depth = ma.masked_where(on_shore, -fg.hmin)
        for on_map in (True, False):
            plot_variable('hs', hs, clines_hs, 'm**2 / s', on_map)
            plot_variable('hss', hss, clines_hss, 'm**3 / s**2', on_map)
            plot_variable('min_depth', min_depth, clines_min_depth, 'm', on_map)

# Make all the plot files for every fixed grid (FG) configured in the
# dictionaries near the top of this file: make_plots is called once per
# FG number in FG_name, in increasing order.  To skip an FG, remove its
# dictionary entries.

if __name__ == "__main__":
    for gridno in sorted(FG_name):
        make_plots(plotdir='_plots/', gridno=gridno)
