#!/bin/bash -l
#SBATCH --account=harmslab
#SBATCH --job-name=tfscreen
#SBATCH --output=hostname.out
#SBATCH --error=hostname.err
#SBATCH --partition=gpu
#SBATCH --time=00-02:00:00
#SBATCH --nodes=1
#SBATCH --cpus-per-task=8
#SBATCH --gpus=1
#SBATCH --ntasks-per-node=1

# Crash on any failure (including any failing stage of a pipeline).
# This is enabled BEFORE `module load` so that a failed CUDA module load
# aborts the job instead of letting the remaining steps run without the
# requested CUDA environment for the whole allocation.
# (All #SBATCH directives must precede the first executable command, so this
# is the earliest point strict mode can be turned on.)
set -eo pipefail

# Environment setup -- keep exactly one of these two lines active:
#   * on the cluster: `module load` the CUDA toolkit;
#   * locally: comment out `module load` and uncomment XLA_FLAGS so JAX
#     exposes 8 host (CPU) devices instead.
module load cuda/12.4.1
#export XLA_FLAGS="--xla_force_host_platform_device_count=8"

# Uncomment for debugging
#export JAX_TRACEBACK_FILTERING=off
#export JAX_ENABLE_X64=True

# ---------------------------------------------------------------------------
# Template variables, substituted by tfs-setup-grid from the template: blocks.
#
# tfs-configure-model was already run during grid setup and wrote
# tfs_configure_config.yaml into this directory.  That file carries the model
# component choices (theta, condition_growth, etc.), so they are not repeated
# here.
# ---------------------------------------------------------------------------

# Rendered once per grid point and never reassigned below.
readonly RUN_SEED={{ seed }}
readonly PREDICT_GENOTYPES_FILE="{{ predict_genotypes_file }}"

# ---------------------------------------------------------------------------
# Run prefit calibration
#
# Runs tfs-prefit-calibration on the configuration written during grid setup,
# using this grid point's seed.
# ---------------------------------------------------------------------------

echo ">>> Prefit calibration"
tfs-prefit-calibration tfs_configure_config.yaml \
    --seed "${RUN_SEED}" \
    --convergence_tolerance 0.00001

# ---------------------------------------------------------------------------
# MAP fit
#
# Fits the model in tfs_configure_config.yaml with --analysis_method map,
# seeded with this grid point's seed.  The posterior-sampling step below
# reads tfs_fit_model_checkpoint.pkl; presumably that checkpoint is written
# here (see --checkpoint_interval) -- TODO confirm against tfs-fit-model.
# ---------------------------------------------------------------------------

echo ">>> MAP fit"
tfs-fit-model \
    tfs_configure_config.yaml \
    --seed "${RUN_SEED}" \
    --analysis_method map \
    --adam_step_size 1e-6 \
    --convergence_check_interval 100 \
    --convergence_window 50 \
    --checkpoint_interval 100 \
    --max_num_epochs 100000000 \
    --pre_map_num_epoch 100000 \
    --convergence_tolerance 0.0005 \
    --patience 5

# ---------------------------------------------------------------------------
# Sample posterior
#
# Draws 1000 posterior samples (in batches of 10) from the fitted checkpoint
# tfs_fit_model_checkpoint.pkl.  The summary steps below read
# tfs_posterior.h5, presumably produced here -- TODO confirm.
# ---------------------------------------------------------------------------

echo ">>> Sample posterior"
tfs-sample-posterior \
    tfs_configure_config.yaml \
    tfs_fit_model_checkpoint.pkl \
    --sampling_batch_size 10 \
    --num_posterior_samples 1000 \
    --seed "${RUN_SEED}"

# ---------------------------------------------------------------------------
# Summarize and predict
#
# All three summary tools read the configuration plus tfs_posterior.h5.
# tfs-cat-response then reads tfs_theta_pred.csv, presumably written by
# tfs-predict-theta -- TODO confirm.
# ---------------------------------------------------------------------------

echo ">>> Extract parameters"
tfs-extract-params \
    tfs_configure_config.yaml \
    tfs_posterior.h5

echo ">>> Predict growth"
tfs-predict-growth \
    tfs_configure_config.yaml \
    tfs_posterior.h5

# Quoted so a genotypes path containing spaces or glob characters is passed
# through as a single argument.
echo ">>> Predict theta"
tfs-predict-theta \
    tfs_configure_config.yaml \
    tfs_posterior.h5 \
    --genotypes_file "${PREDICT_GENOTYPES_FILE}"

# Match the worker count to the CPUs Slurm actually allocated
# (#SBATCH --cpus-per-task); fall back to 8 when run outside Slurm.
echo ">>> Categorize response"
tfs-cat-response \
    tfs_theta_pred.csv \
    --workers "${SLURM_CPUS_PER_TASK:-8}"
