#!/usr/bin/env /Users/csafta/miniconda3/bin/python
import sys
import numpy as np
from   dateutil import parser
import json
from prime_mcmc       import ammcmc, save_mcmc_chain
from prime_posterior  import logpost, logpost_negb, logpost_poisson
from prime_utils      import runningAvg, compute_error_weight
def get_opts(setupfile,verb=False,return_run_setup=False):
    """
    Read the json setup file and flatten it into a single dictionary of run options.

    Parameters
    ----------
    setupfile: string
        path to the json input file; it must provide the sections "regioninfo",
        "modelopts", "mcmcopts", "bayesmod", and "ppopts"
    verb: bool
        if True, print the raw content of the setup file
    return_run_setup: bool
        if True, also return the raw dictionary parsed from the setup file

    Returns
    -------
    run_opts: dict
        flattened run options; "inicov" is always a 2D numpy array (a 1D
        "cvini" entry is expanded into a diagonal matrix), "spllo" and "splhi"
        are numpy arrays
    run_setup: dict
        raw content of the setup file, returned only if return_run_setup is True

    Raises
    ------
    AssertionError
        if "num_waves" is missing from "modelopts"
    KeyError
        if any other required section or entry is missing
    """
    # the context manager closes the file handle as soon as parsing is done
    with open(setupfile) as setup_fh:
        run_setup = json.load(setup_fh)
    if verb:
        print("=====================================================")
        print(run_setup)
        print("=====================================================")
    # NOTE: keys are inserted in a fixed order because callers print this dictionary
    run_opts = dict()
    #-daily counts
    regioninfo = run_setup["regioninfo"]
    run_opts["count_data"] = regioninfo["count_data"]
    run_opts["population_data"] = regioninfo["population_data"]
    # optional: this key is only present in run_opts if the user supplied it
    if "running_avg_obs" in regioninfo:
        run_opts["running_avg_obs"] = regioninfo["running_avg_obs"]
    run_opts["region_tag"] = regioninfo["region_tag"]
    run_opts["day0"] = regioninfo["day0"]
    #------------------------------------------------------------------
    #-incubation model
    modelopts = run_setup["modelopts"]
    assert "num_waves" in modelopts, "\"num_waves\" is required in \"modelopts\""
    run_opts["num_waves"] = modelopts["num_waves"]
    # optional: fall back to 0 when the user does not define "useconv"
    run_opts["useconv"] = modelopts.get("useconv", 0)
    run_opts["inc_median"] = modelopts["incubation_median"]
    run_opts["inc_sigma"] = modelopts["incubation_sigma"]
    run_opts["inc_model"] = modelopts.get("incubation_model", "lognormal")
    run_opts["inc_type"] = modelopts.get("incubation_type", "deterministic")

    #------------------------------------------------------------------
    #-mcmc model parameters
    mcmcopts = run_setup["mcmcopts"]
    run_opts["mcmc_log"] = mcmcopts["logfile"]
    run_opts["mcmc_nsteps"] = mcmcopts["nsteps"]
    run_opts["mcmc_nfinal"] = mcmcopts["nfinal"]
    run_opts["mcmc_gamma"] = mcmcopts["gamma"]
    run_opts["inicov"] = np.array(mcmcopts["cvini"])
    run_opts["inistate"] = mcmcopts["cini"]
    # a 1D "cvini" holds the diagonal of the initial covariance matrix
    if run_opts["inicov"].ndim == 1:
        run_opts["inicov"] = np.diag(run_opts["inicov"])
    run_opts["spllo"] = np.array(mcmcopts["spllo"])
    run_opts["splhi"] = np.array(mcmcopts["splhi"])
    #------------------------------------------------------------------
    #-bayes framework
    bayesmod = run_setup["bayesmod"]
    run_opts["lpf_type"] = bayesmod["lpf_type"]
    run_opts["error_model_type"] = bayesmod["error_model_type"]
    run_opts["prior_types"] = bayesmod["prior_types"]
    run_opts["prior_info"] = bayesmod["prior_info"]

    #------------------------------------------------------------------
    #-postprocessing
    run_opts["days_extra"] = run_setup["ppopts"]["days_extra"]
    if return_run_setup:
        return run_opts, run_setup
    return run_opts
# [docs] (link residue from a rendered documentation page; not code)
def get_counts(run_opts,return_raw_data=False):
    """
    Get counts from raw files

    Each file listed in run_opts["count_data"] is read as comma-separated
    text; the first column is parsed as a date and the second column as a
    count.

    Parameters
    ----------
    run_opts: dict
        run options; the entries "count_data", "day0", "population_data", and
        "region_tag" are used, and "running_avg_obs" is used if present
    return_raw_data: bool
        if True, also return the raw string arrays read from the files

    Returns
    -------
    days_since_day0: list of numpy arrays
        for each file, the number of days between each date and run_opts["day0"]
    daily_counts: list of numpy arrays
        for each file, the counts divided by population_data[ireg]*1.e6
        (presumably the population is given in millions - TODO confirm),
        passed through runningAvg if "running_avg_obs" is in run_opts
    rawdata_all: list of numpy arrays
        raw file contents as strings, returned only if return_raw_data is True
    """
    # the reference date is the same for every row of every file: parse it once
    day0 = parser.parse(run_opts["day0"])
    days_since_day0 = []
    daily_counts = []
    rawdata_all = []
    for ireg, region in enumerate(run_opts["count_data"]):
        # ndmin=2 keeps a single-row file two-dimensional so that the
        # column indexing below works for it as well
        rawdata = np.loadtxt(region,delimiter=",",dtype=str,ndmin=2)
        rawdata_all.append(rawdata)
        days_since_day0.append(np.array([(parser.parse(date)-day0).days for date in rawdata[:,0]]))
        counts = np.array([float(count) for count in rawdata[:,1]])
        # scale daily counts
        counts = counts/(run_opts["population_data"][ireg] * 1.e6)
        # run averages
        if "running_avg_obs" in run_opts:
            counts = runningAvg(counts, run_opts["running_avg_obs"])
            print("Taking {}-day running average of observations for {}".format(run_opts["running_avg_obs"],run_opts["region_tag"][ireg]))
        daily_counts.append(counts)
    if return_raw_data:
        return days_since_day0, daily_counts, rawdata_all
    return days_since_day0, daily_counts
# [docs] (link residue from a rendered documentation page; not code)
def main(setupfile):
    r"""
    Driver script to run MCMC for parameter inference for a multi-wave 
    epidemic model. Currently limited to up to three infection curves.

    To run this script:

    python <path-to-this-directory>/prime_run.py <name-of-json-input-file>

    The resulting chain is saved through save_mcmc_chain to a file named after
    the concatenated region tags followed by "_mcmc.h5".

    Parameters
    ----------
    setupfile: string
        json format input file with information on observations data, filtering options,
        MCMC options, and postprocessing options. See "setup_template.json" for a detailed
        example
    """
    #-------------------------------------------------------
    # use the argument as given (do not re-read sys.argv here) so that
    # main() can also be called from other python code
    run_opts = get_opts(setupfile)
    print(run_opts)
    print("=====================================================")

    #-------------------------------------------------------
    # echo some settings
    print("Running inference with %d waves"%(run_opts["num_waves"]))   
    print("Error model %s"%(run_opts["error_model_type"]))
    assert run_opts["error_model_type"] in ["add","addMult"]
    #-------------------------------------------------------
    # get counts
    days_since_day0, daily_counts = get_counts(run_opts)

    #-------------------------------------------------------
    # mcmc
    opts = {"nsteps": run_opts["mcmc_nsteps"], "nfinal": run_opts["mcmc_nfinal"],"gamma": run_opts["mcmc_gamma"],
            "inicov": np.array(run_opts["inicov"]),"inistate": np.array(run_opts["inistate"]),
            "spllo": np.array(run_opts["spllo"]),"splhi": np.array(run_opts["splhi"]),
            "logfile": run_opts["mcmc_log"],"burnsc":5,
            "nburn":1000,"nadapt":100,"coveps":1.e-10,"ofreq":5000,"tmpchn":"tmpchn"
            }
    modelinfo={"num_waves":        run_opts["num_waves"],
               "error_model_type": run_opts["error_model_type"],
               "days_since_day0":  days_since_day0,
               "daily_counts":     daily_counts,
               "incubation_model": run_opts["inc_model"],
               "incubation_median":run_opts["inc_median"],
               "incubation_sigma": run_opts["inc_sigma"], 
               "incubation_type":  run_opts["inc_type"], 
               "inftype":          "gamma",
               "useconv":          run_opts["useconv"],
               "days_extra":       0,
               "prior_types":run_opts["prior_types"],"prior_info": run_opts["prior_info"]}

    # Convolution vs Quadrature:
    #   -The user can choose to use a fft convolution instead of 
    #    quadrature to perform the integration of Y(t)
    #   -To set, add "useconv":1 to the modelopts in the *json file 
    if modelinfo["useconv"] == 1:
        print("Using FFT convolution instead of quadrature")

    # choose log-posterior function
    logpost_types={"gaussian":logpost,"negative_binomial":logpost_negb,"poisson":logpost_poisson}
    lpf = run_opts["lpf_type"]
    if lpf == "poisson":
        # sum of log(k!) over all positive observations; daily_counts holds
        # one array per region, so iterate over the individual values
        # NOTE(review): get_counts scales the counts by the population, so
        # int(k) truncates scaled values here - confirm this is intended
        modelinfo["sumLogK"] = sum([sum([np.log(i) for i in range(1,int(k)+1)])
                                    for region_counts in daily_counts
                                    for k in np.ravel(region_counts) if k>0])

    # run MCMC
    sol=ammcmc(opts,logpost_types[lpf],modelinfo)
    save_mcmc_chain("".join(run_opts["region_tag"])+"_mcmc.h5",sol) 
if __name__ == '__main__':
    # the json setup file is the first command-line argument; exit with a
    # usage message instead of an IndexError traceback when it is missing
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <name-of-json-input-file>".format(sys.argv[0]))
    main(sys.argv[1])