from mpmath import *
mp.dps = 60
import matplotlib.pyplot as plt
import numpy as np

# Read eigenvalues into a list of real numbers (mpf).
# The file 'eigenvalues' holds one number per line. Blank lines are skipped,
# because mp.mpf('') would raise and abort the whole run.
with open('eigenvalues') as f:
    eigenvalues_ = [line.strip() for line in f.read().splitlines() if line.strip()]

# How many eigenvalues to take? (max = len(eigenvalues_))
maxn_eigen = len(eigenvalues_)

# Set to True to truncate (floor) every eigenvalue to 5 decimal places.
round_eigenvalues = False

eigenvalues = []
for i in range(0, maxn_eigen):
    ev = mp.mpf(eigenvalues_[i])
    if round_eigenvalues:
        ev = floor(ev * 100000.0) / 100000.0
    eigenvalues.append(ev)

X = 49
logX = ln(mp.mpf(X))

# values[j] is the complex number exp(i * t_j * log X), built explicitly from
# its parts: the real part is cos(t_j log X), the imaginary part sin(t_j log X).
values = []
for j in range(maxn_eigen):
    angle = eigenvalues[j] * logX
    values.append(mp.mpc(cos(angle), sin(angle)))

# Finally create the sums for the plot.
#
# tempsum is the running sum of values[0..i]. For every eigenvalue t_i in
# [minT, maxT] we record
#   sums1: 2 * Re(tempsum) / t_i**normalisation  (minus extra_wave_real)
#   sums2: 2 * Im(tempsum) / t_i**normalisation  (minus extra_wave_imag)
sums1 = []
sums2 = []
tempsum = mp.mpc(0, 0)
minT = 1
maxT = 830
# minIndex: first index whose eigenvalue lies in [minT, maxT].
# maxIndex: first index whose eigenvalue exceeds maxT (set to maxn_eigen
# below when no eigenvalue does). The plot slices eigenvalues with these
# indices, which assumes the eigenvalues are sorted ascending.
minIndex = maxIndex = -1

normalisation = 1.0

# Set to True to subtract the oscillatory main term
#   (real, imag) = t * (sin(t y), -cos(t y)) / (3 y),  with y = log X,
# from the sums.
subtract_main_term = False

if subtract_main_term:
    extra_wave_real = lambda t, y: pow(eigenvalues[t], 1.0) * sin(eigenvalues[t] * y) / (3 * y)
    extra_wave_imag = lambda t, y: -pow(eigenvalues[t], 1.0) * (cos(eigenvalues[t] * y)) / (3 * y)
else:
    extra_wave_real = lambda t, y: 0
    extra_wave_imag = lambda t, y: 0

for i in range(0, maxn_eigen):
    tempsum += values[i]
    t_i = eigenvalues[i]
    if minT <= t_i <= maxT:
        if minIndex == -1:
            minIndex = i
        scale = pow(t_i, normalisation)
        sums1.append(mp.mpf(2.0) * tempsum.real / scale - extra_wave_real(i, logX) / scale)
        sums2.append(mp.mpf(2.0) * tempsum.imag / scale - extra_wave_imag(i, logX) / scale)
    if maxIndex == -1 and t_i > maxT:
        maxIndex = i

# If no eigenvalue exceeds maxT, the window runs to the end of the list.
# Leaving maxIndex at -1 would make eigenvalues[minIndex:-1] drop the last
# point, so the x-coordinates would no longer match sums1/sums2.
if maxIndex == -1:
    maxIndex = maxn_eigen

if minIndex == -1:
    # Nothing to plot; stop instead of continuing with invalid indices.
    raise SystemExit('Error: no eigenvalues in [' + str(minT) + ', ' + str(maxT) + ']')
print(len(sums1))
print("minIndex = " + str(minIndex) + ", maxIndex = " + str(maxIndex))
# x-axis: the eigenvalues belonging to the [minIndex, maxIndex) window.
xcoords = eigenvalues[minIndex:maxIndex]

plt.figure(1)
with workdps(1):
    plot_title = '$X=' + mp.nstr(X) + '$'
plt.suptitle(plot_title, fontsize=30, y=0.92)

plt.plot(xcoords, sums1)      # curve built from the real parts
plt.ylim([-1.0, 1.0])         # fix the vertical range to [-1, 1]
plt.plot(xcoords, sums2)      # curve built from the imaginary parts

plt.tight_layout()
plt.savefig('tplot.pdf', format='pdf', transparent=True)
plt.close()
# For an interactive window, call plt.show() before plt.close() instead.
