import time
import os
import numpy as np
from Additional_Funs.detection_fun import detection_veq02 as dtctv
from obspy import read as rdseed
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
matplotlib.use('TkAgg')

# Start the wall-clock timer that is reported at the end of the script
tic = time.time()

pwd_folder = os.getcwd()
data_folder = pwd_folder + '/Waveform_Data/'
File_Name = 'VG.MEPAS.00.HHZ.D.2024.325.msd'

# Output folder handed to the detection routine (kept with its trailing '/')
output_csv = pwd_folder + '/Csv_Output_All_Detections/'

# Read Mseed File
stream_ori = rdseed((data_folder + File_Name), type='mseed')

# Everything below analyses only the first trace of the stream.
if len(stream_ori.traces) == 0:
    raise ValueError('No traces found in ' + data_folder + File_Name)
if len(stream_ori.traces) > 1:
    # ObsPy returns one trace per contiguous segment (e.g. when the record has
    # gaps/overlaps); later segments would otherwise be ignored without notice.
    print('WARNING: ' + File_Name + ' contains ' + str(len(stream_ori.traces))
          + ' traces (gaps/overlaps?); only the first one is processed.')

# Get stats of the first trace
str_channel = stream_ori.traces[0].stats.channel
str_network = stream_ori.traces[0].stats.network
str_loc = stream_ori.traces[0].stats.location
str_station = stream_ori.traces[0].stats.station
str_start = stream_ori.traces[0].stats.starttime
str_npts = stream_ori.traces[0].stats.npts
str_fs = stream_ori.traces[0].stats.sampling_rate
str_dur = str_npts/str_fs/3600      # record length of the first trace, in hours

# Stream processing is done on a copy of the original stream
stream_proc = stream_ori.copy()

# Extract Data (np array) from Stream, converted to float32 ('f')
data_eq_int = stream_proc.traces[0].data
data_eq = data_eq_int.astype('f')

# Choose Period of Interest
# Ask for a whole-hour window [hr_init, hr_final) relative to the start of the
# record and re-prompt until both values are valid. Raw input is validated
# before conversion so empty or non-numeric answers re-prompt instead of crashing.
while True:
    txt_init = input('Specify initial hour, may start from zero (integer value): ').strip()
    txt_final = input('Specify final hour, must be greater than initial hour (integer value): ').strip()
    if txt_init == '' or txt_final == '':
        print("You haven't input either initial hour or final hour\nPlease do the input again....")
        continue
    try:
        # int(float(...)) also accepts answers such as "3.0" (fractions are truncated)
        hr_init = int(float(txt_init))
        hr_final = int(float(txt_final))
    except (ValueError, OverflowError):
        print('Hours must be numeric values\nPlease do the input again...')
        continue
    if hr_init < 0:
        # A negative start would turn the sample slice into a wrap-around from the end
        print('Your initial hour is negative\nPlease do the input again...')
    elif hr_init >= hr_final:
        print('Your initial hour is greater than or equal to final hour\nPlease do the input again...')
    elif hr_init >= str_dur:
        # The window would start after the last sample, leaving no data to analyse
        print('Your initial hour is beyond the end of the record ('
              + '{:.2f}'.format(str_dur) + ' hours long)\nPlease do the input again...')
    else:
        print('The selected initial hour is ' + str(hr_init) + ', while the selected final hour is '
              + str(hr_final))
        break

hr_dur = hr_final - hr_init
cdat_pick = data_eq.copy()

# Sample indices bounding the selected period (slicing clips idx_end to the
# record length if hr_final lies past the end of the data)
idx_start = int(hr_init * 3600 * str_fs)
idx_end = int(hr_final * 3600 * str_fs)
cdat = cdat_pick[idx_start:idx_end]
if idx_start < 0 or cdat.size == 0:
    raise ValueError('The selected period (hour ' + str(hr_init) + ' to ' + str(hr_final)
                     + ') contains no samples; the record is '
                     + '{:.2f}'.format(str_dur) + ' hours long')

# Create Left - Right Buffer
t_buffer = 60                  # seconds
# Buffer length in samples. np.repeat requires an integer count and str_fs
# (the trace sampling rate) is a float, so convert explicitly.
n_buffer = int(round(t_buffer * str_fs))

# Left Buffer: the samples just before the window; whatever would fall before
# the start of the record is filled with the record's first sample.
n_left = min(n_buffer, idx_start)
ldat = np.concatenate((np.repeat(cdat_pick[0], n_buffer - n_left),
                       cdat_pick[idx_start - n_left:idx_start]))

# Right Buffer: the samples just after the window. The slice comes back short
# only when it hits the end of the record; the missing part is then filled
# with the record's last sample, so the buffer is always n_buffer long.
rdat = cdat_pick[idx_end:idx_end + n_buffer]
if rdat.size < n_buffer:
    rdat = np.concatenate((rdat, np.repeat(cdat_pick[-1], n_buffer - rdat.size)))

w_final = np.concatenate((ldat, cdat, rdat))

# ---- STA/LTA parameters (all in seconds) ------------------------------------
t_sta, t_lta = 2, 20          # short-term / long-term averaging windows
tstalta_smoothing = 10        # smoothing length applied to the STA/LTA values

# Trigger-on / trigger-off thresholds on log STA/LTA. These can be set
# arbitrarily, by trial and error, or from statistics of a quiet period.
thresh_on = 0.2551            # mean + 3 standard deviations (log STA/LTA)
thresh_off = -0.0056          # mean (log STA/LTA); MERAPI MEPAS 25 Dec 2020 UTC, STA = 2 s, LTA = 20 s

# ---- Envelope parameters ----------------------------------------------------
tenv_smoothing = 2            # smoothing length for the event envelope, seconds
limit_envelope = 1.1          # envelope limit used to determine event duration

# ---- Detect earthquakes -----------------------------------------------------
# dtctv takes all of its inputs as one list; it presumably reads them by
# position, so the order below must not change.
input_det = [
    w_final, cdat, str_fs, str_station, t_buffer,
    t_sta, t_lta, tstalta_smoothing, thresh_on, thresh_off,
    tenv_smoothing, limit_envelope,
    str_start, hr_init, hr_final, output_csv,
]

(time_array, data_fin, t_plot_ratio,
 idx_dt, stalta_ratio, smoothed_env) = dtctv(input_det)

# Remove the samples belonging to the left and right buffers so that only the
# selected period is plotted, then shift the times so the period starts at 0 s.
# time_array is compared against t_buffer (seconds), so it is presumably in
# seconds from the start of the buffered signal.
# A sample is dropped when it is earlier than t_buffer, or later than
# t_buffer before the last time value. One vectorised mask replaces a
# per-sample Python loop over the whole signal.
time_array_arr = np.asarray(time_array)
data_fin_arr = np.asarray(data_fin)
if time_array_arr.size:
    in_buffer = ((time_array_arr < float(t_buffer))
                 | (time_array_arr > time_array_arr[-1] - float(t_buffer)))
else:
    in_buffer = np.zeros(0, dtype=bool)

time_plot = time_array_arr[~in_buffer] - float(t_buffer)
time_plotr = np.round(time_plot, 2)
# data_fin is plotted against time_array, so both share the same sample mask
data_fin_plot = data_fin_arr[~in_buffer]


# Plot the selected period and highlight every detection
sz_detect = idx_dt.shape
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(time_plotr, data_fin_plot, linewidth=0.4, color='black')
ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Amplitude (counts)')

# Row j of idx_dt holds the start/end indices of detection j in t_plot_ratio;
# t_buffer is subtracted to match the shifted time axis of the plot.
# axvspan always covers the full plot height and does not stretch the y-axis,
# whatever the signal amplitude. Labels sit at 90 % of the axes height:
# x is in data coordinates, y in axes coordinates.
label_transform = ax.get_xaxis_transform()
for j in range(sz_detect[0]):
    t_on = t_plot_ratio[idx_dt[j, 0]] - float(t_buffer)
    t_off = t_plot_ratio[idx_dt[j, 1]] - float(t_buffer)
    ax.axvspan(t_on, t_off, alpha=0.2, ec="none", fc="red")
    ax.text((t_on + t_off) / 2, 0.9, str(j + 1), transform=label_transform,
            ha='center', color='blue', fontweight='bold')

ax.grid()
# NOTE(review): block=False returns immediately; when run as a plain
# `python script.py` the window closes once the interpreter exits. This is
# presumably meant for an interactive session or IDE - confirm.
plt.show(block=False)

print('===================================================')

# Total wall-clock run time since tic was taken at the top of the script
toc = time.time()

elapsed_time = toc - tic
print('elapsed time is: {:.3f}'.format(elapsed_time) + ' seconds')
