#!/usr/bin/ruby
require 'complex'
# The Array class is abused for representing polynomials (i.e. the
# Z-transforms of digital filters). Coefficients are stored highest power
# first, so [a, b, c] stands for a*x**2 + b*x + c (see #peval and #pderiv).
#
# NOTE(review): #sum and #each_slice replace the core Ruby methods of the same
# name for every Array in the process. #each_slice in particular yields
# sliding windows, not the consecutive chunks of Enumerable#each_slice.
class Array

  # Create array with the real part of each element.
  def real
    collect { |x| x.real }
  end

  # Create array with the imaginary part of each element.
  def imag
    collect { |x| x.imag }
  end

  # Compute the average of the values (sum / size; NaN for an empty array).
  def average
    sum / size
  end

  # Compute the sum of the values, starting from init (default 0.0, so
  # numeric input yields a Float). As with core Array#sum, an optional block
  # maps each element before it is added; a block used to be silently ignored.
  def sum( init = 0.0 )
    if block_given?
      inject( init ) { |s, v| s + yield( v ) }
    else
      inject( init ) { |s, v| s + v }
    end
  end

  # Compute the absolute value of each element.
  def abs
    collect { |x| x.abs }
  end

  # Multiply each element by a scalar value.
  def mult( scalar )
    collect { |x| x * scalar }
  end

  # Compute the element-wise sum of two arrays (truncated to the size of self).
  def plus( other )
    zip( other ).collect { |t| t[0] + t[1] }
  end

  # Compute the element-wise difference of two arrays (truncated to the size
  # of self).
  def diff( other )
    zip( other ).collect { |t| t[0] - t[1] }
  end

  # Downsampling: keep the elements whose index i satisfies i % d == s
  # (s is the offset, d the sampling rate).
  def downsample( s = 1, d = 2 )
    retval = []
    each_with_index { |x, i| retval << x if i % d == s }
    retval
  end

  # Upsampling: insert d - 1 zeros after every element, shifted right by the
  # offset s. The result has size * d elements.
  def upsample( s = 1, d = 2 )
    retval = [ 0.0 ] * s
    each do |x|
      retval << x
      ( d - 1 ).times { retval << 0.0 }
    end
    # Dropping s trailing entries keeps the length at size * d.
    s.times { retval.pop }
    retval
  end

  # Shift the array right (k > 0) or left (k < 0) by |k| positions and fill
  # the vacated slots with 0.0. The size of the array is preserved.
  def roll( k )
    retval = dup
    if k < 0
      ( -k ).times do
        retval.shift
        retval.push( 0.0 )
      end
    else
      k.times do
        retval.pop
        retval.unshift( 0.0 )
      end
    end
    retval
  end

  # Iterate over sliding windows of size n. Window i is
  # [self[i], self[i-1], ..., self[i-n+1]], with 0.0 for indices outside the
  # array, for i = 0 ... size + n - 2. Each yielded window is a fresh copy,
  # so callers may keep it. Without a block an Enumerator is returned.
  def each_slice( n )
    return enum_for( :each_slice, n ) unless block_given?
    y = [ 0.0 ] * n
    each do |x|
      y.pop
      y.unshift( x )
      yield y.dup
    end
    # Flush the tail so that every element passes every window position.
    ( n - 1 ).times do
      y.pop
      y.unshift( 0.0 )
      yield y.dup
    end
  end

  # Polynomial long division by arr (both highest power first).
  # Returns [quotient, remainder]. When self is shorter than arr the quotient
  # is [] and the remainder is self; otherwise the remainder has
  # arr.size - 1 coefficients.
  def pdiv( arr )
    if size >= arr.size
      factor = 1.0 * self[0] / arr[0]
      # Subtract factor * arr (aligned with the leading coefficient) and drop
      # the now-cancelled leading term.
      remainder = diff( ( arr + [0.0] * ( size - arr.size ) ).mult( factor ) )
      remainder.shift
      retval = remainder.pdiv( arr )
      retval[0] = ( [0.0] + retval[0] ).plus( [factor] +
                                              [0.0] * ( size - arr.size ) )
      retval
    else
      [ [], self ]
    end
  end

  # Polynomial division without remainder, as a least-squares fit: with H
  # the convolution matrix of arr (H[i][j] = arr[i - j]), this returns
  # (H^T H)^-1 H^T self. For a divisor of size 1 it just scales by 1/arr[0].
  # NOTE(review): uses Matrix/Vector from the 'matrix' library, which this
  # file does not require itself; confirm it is loaded elsewhere.
  def pediv( arr )
    if arr.size > 1
      rsize = size + 1 - arr.size
      zeros = [0.0] * ( rsize - 1 )
      padded = ( zeros + arr.reverse + zeros )
      c = ( 0...size ).collect do |i|
        pos = rsize - 2 + arr.size - i
        padded[ pos...pos + rsize ]
      end
      h = Matrix[ *c ]
      ( ( h.t * h ).inverse * h.t * Vector[ *self ] ).to_a
    else
      mult( 1.0 / arr[0] )
    end
  end

  # Full convolution of self with arr (size + arr.size - 1 entries), computed
  # as a dot product of arr with each sliding window from #each_slice.
  def convolvex( arr )
    retval = []
    each_slice( arr.size ) do |x|
      retval << x.zip( arr ).inject( 0.0 ) { |acc, t| acc + t[0] * t[1] }
    end
    retval
  end

  # Correlate array k-times with itself: self ** 0 == [1], self ** 1 == self,
  # otherwise correlatex(self ** (k - 1)). For symmetric arrays such as
  # [1, 1] this is the k-th polynomial power.
  # Raises ArgumentError for negative k (the recursion would never end).
  def **( k )
    raise ArgumentError, "exponent must be a non-negative integer" if k < 0
    if k == 0
      [1]
    elsif k == 1
      self
    else
      correlatex( self ** ( k - 1 ) )
    end
  end

  # Correlate two arrays (self reversed, convolved with other).
  def correlatex( other )
    reverse.convolvex( other )
  end

  # Multiply two polynomials (same as convolution of arrays).
  def pmul( other )
    convolvex( other )
  end

  # Compute the k-th derivative of the polynomial: coefficient i is multiplied
  # by its power size - i - 1 and the constant term is dropped.
  # Raises ArgumentError for negative k (the recursion would never end).
  def pderiv( k = 1 )
    raise ArgumentError, "derivative order must not be negative" if k < 0
    if k == 0
      self
    elsif k == 1
      retval = []
      each_with_index { |x, i| retval.push( x * ( size - i - 1 ) ) }
      retval.pop
      retval
    else
      pderiv( 1 ).pderiv( k - 1 )
    end
  end

  # Evaluate the polynomial at x: sum of c_i * x**(size - i - 1).
  # Returns 0 for an empty array.
  def peval( x )
    retval = 0
    each_with_index { |c, i| retval += c * x ** ( size - i - 1 ) }
    retval
  end

  # Sign-alternated copy of the filter (called "inverse" in this file):
  # element i is multiplied by (-1)**i * offset.
  def inverse( offset = +1 )
    retval = []
    each_with_index { |x, i| retval.push( x * (-1)**i * offset ) }
    retval
  end

end

# Find a root of a polynom using Laguerre's algorithm.
#
# poly    - coefficient array, highest power first (the convention of
#           Array#peval / Array#pderiv).
# maxiter - the update step is attempted at most maxiter + 1 times.
#
# Starting from x = 0.0 it iterates
#   G = p'(x) / p(x),  H = G**2 - p''(x) / p(x),
#   x <- x - n / (G +- sqrt((n - 1) * (n*H - G**2)))
# and picks the sign that gives the larger denominator magnitude. Returns the
# last finite iterate. The loop stops early when G is no longer finite
# (p(x) == 0, so x already is a root) or when an update is not finite.
def laguerre( poly, maxiter = 100 )
  n = poly.size - 1
  # The derivatives do not depend on x, so build them once up front.
  dpoly = poly.pderiv
  ddpoly = poly.pderiv( 2 )
  finite = lambda { |v| v.real.to_f.finite? && v.imag.to_f.finite? }
  x = 0.0
  ( maxiter + 1 ).times do
    px = poly.peval( x )
    g = dpoly.peval( x ) / px
    # p(x) == 0 turns G into Infinity/NaN: x is already a root.
    break unless finite.call( g )
    h = g * g - ddpoly.peval( x ) / px
    d2 = ( n - 1 ) * ( n * h - g ** 2 )
    # d2 is negative or complex whenever the root is complex. Math.sqrt
    # accepts only non-negative reals (Math::DomainError otherwise on
    # Ruby >= 1.9), so take the principal complex square root explicitly.
    d = if d2.is_a?( Complex ) || d2 < 0
          Complex.polar( Math::sqrt( d2.abs ), d2.arg / 2.0 )
        else
          Math::sqrt( d2 )
        end
    denom = ( g + d ).abs > ( g - d ).abs ? g + d : g - d
    xnew = x - n / denom
    # Overflow close to a root can make the step Infinity/NaN. Keep the last
    # finite iterate instead of returning NaN.
    break unless finite.call( xnew )
    x = xnew
  end
  x
end

# Numerical factorisation using Laguerre's algorithm.
#
# Repeatedly takes one zero from laguerre and deflates the polynomial by the
# linear factor [1.0, -zero] (least-squares Array#pediv). This continues until
# fewer than two coefficients remain. Returns the zeros in the order found.
# err is accepted for interface compatibility but is currently unused.
def pfact( poly, err = 1.0e-12 )
  # TODO: Divide out the gcd of the polynomial and its derivative to remove
  # duplicate poles.
  zeros = []
  rest = poly
  while rest.size >= 2
    zero = laguerre( rest )
    zeros.push( zero )
    rest = rest.pediv( [ 1.0, -zero ] )
  end
  zeros
end

# Create polynom with given zero-crossings (inverse of factorisation).
#
# Multiplies the linear factors [1.0, -z] for every z in zeros. The product is
# then scaled so that its middle coefficient (index size / 2) becomes 1.0.
def pmult( zeros )
  product = [ 1.0 ]
  zeros.each { |z| product = product.pmul( [ 1.0, -z ] ) }
  middle = product[ product.size / 2 ]
  product.mult( 1.0 / middle )
end

# Create polynom with given zero-crossings and another zero-crossing at
# 1/x for each zero-crossing at x (symmetric polynom).
#
# Returns the coefficients of (x - z) * (x - 1/z) for z = zero_crossing.
def realsympoly( zero_crossing )
  neg_reciprocal = -1.0 / zero_crossing
  factor = [ 1.0, -zero_crossing ]
  factor.pmul( [ 1.0, neg_reciprocal ] )
end

# Create polynom with given zero-crossings and three other zero-crossings
# at 1/x, x.conj and 1/x.conj for each zero-crossing at x (symmetric,
# real-valued polynom).
#
# Multiplies the four linear factors in the order z, 1/z, conj(z),
# 1/conj(z) and returns the resulting coefficient array.
def complexsympoly( zero_crossing )
  conjugate = zero_crossing.conj
  remaining_factors = [ [ 1.0, -1.0 / zero_crossing ],
                        [ 1.0, -conjugate ],
                        [ 1.0, -1.0 / conjugate ] ]
  remaining_factors.inject( [ 1.0, -zero_crossing ] ) do |poly, factor|
    poly.pmul( factor )
  end
end

# Polynomial factorisation for real-valued, symmetric polynoms.
#
# Repeatedly takes a zero z from laguerre and tries dividing by
# complexsympoly(z). If the remainder's absolute values sum to at most err,
# the four zeros z, 1/z, conj(z) and 1/conj(z) are deflated together.
# Otherwise only the reciprocal pair z, 1/z is deflated (realsympoly).
# Deflation uses Array#pediv. Returns the list of zero groups in the order
# found.
def sympfact( poly, err = 1.0e-12 )
  groups = []
  rest = poly
  until rest.size < 2
    zero = laguerre( rest )
    # complexsympoly is pure, so its result is reused for both divisions.
    quartic = complexsympoly( zero )
    remainder = rest.pdiv( quartic )[1]
    if remainder.abs.sum > err
      rest = rest.pediv( realsympoly( zero ) )
      groups.push( [ zero, 1.0 / zero ] )
    else
      rest = rest.pediv( quartic )
      groups.push( [ zero, 1.0 / zero, zero.conj, 1.0 / zero.conj ] )
    end
  end
  groups
end

class Integer
  # Compute the factorial of the integer (self!). For self <= 0 the range
  # 1..self is empty, so the result is 1.
  def factorial
    (1..self).inject( 1 ) { |f, n| f * n }
  end

  # Compute k out-of n, i.e. the binomial coefficient C(n, k) with k = self.
  # Returns 0 when k is negative or larger than n; the factorial formula used
  # before returned non-zero values there (e.g. 1.out_of(0) == 1).
  # Uses the multiplicative formula, so no large factorials are built. After
  # step i the accumulator equals C(n - k + i, i), which keeps each integer
  # division exact.
  def out_of( n )
    return 0 if self < 0 || self > n
    k = [ self, n - self ].min
    (1..k).inject( 1 ) { |acc, i| acc * ( n - k + i ) / i }
  end
end

# Thiran filter (approximates delay "tau"). It is used to create approximate
# half-sample delay filters here.
#
# Returns l coefficients. Coefficient n (0 <= n < l) is
#   C(l-1, n) * (-1)**n * prod_{k=0}^{n-1} (tau - l + 1 + k) / (tau + 1 + k).
def thiran( l, tau = 0.5 )
  order = l - 1
  (0..order).map do |n|
    gain = 1.0
    n.times { |k| gain = gain * ( tau - l + 1 + k ) / ( tau + 1 + k ) }
    sign = n.even? ? 1 : -1
    n.out_of( order ) * sign * gain
  end
end

# Complex wavelet design for biorthogonal wavelets by Selesnick.
# http://taco.poly.edu/selesi/ComplexWavelets/
#
# k  - number of [1, 1] factors put into the analysis lowpass (f)
# ks - number of [1, 1] factors put into the synthesis lowpass (fs)
# l  - length of the Thiran filter d used as approximate half-sample delay
#
# Returns [h0, h0s, h1, h1s, g0, g0s, g1, g1s]. h0 and h0s are scaled so that
# their coefficients sum to sqrt(2). g0/g0s are their time-reversals, and
# h1/h1s/g1/g1s are sign-alternated copies (Array#inverse) of h0s/h0/g0s/g0.
# Raises RuntimeError unless k + ks + 2*l - 1 is odd.
def hwt_wavelet_design( k, ks, l )
  raise "k+ks+2*l-1 must be odd!" unless ( k + ks + 2 * l - 1 ) % 2 == 1
  d = thiran( l )
  s = ( [1, 1] ** ( k + ks ) ).pmul( d ).pmul( d.reverse )

  # Solve c * r = b, where row i of c is s shifted by 2 * (i - mid) and b is
  # the unit vector at the centre row.
  mid = s.size.div( 2 )
  rows = ( 0...s.size ).collect { |i| s.roll( 2 * ( i - mid ) ) }
  unit = ( 0...s.size ).collect { |i| i == mid ? 1 : 0 }
  r = ( Matrix[ *rows ].inverse * Vector[ *unit ] ).to_a

  # Factorise the inner coefficients of r into groups of symmetric zeros and
  # distribute the groups as evenly as possible between q and qs.
  z = sympfact( r[ 1...r.size - 1 ] )
  z1, z2 = hwt_balanced_zero_split( z )
  q = pmult( z1 ).real
  qs = [ 0.0 ] + pmult( z2 ).real + [ 0.0 ]

  q = q.mult( 1.0 / q.abs.sum )
  qs = qs.mult( 1.0 / qs.abs.sum )
  f = q.pmul( [1, 1] ** k )
  fs = qs.pmul( [1, 1] ** ks )

  # Disabled experiment kept from the original author: zero-pad f and fs to
  # equal length.
  #  while f.size < fs.size
  #    f = [0] + f + [0]
  #  end
  #  while fs.size < f.size
  #    fs = [0] + fs + [0]
  #  end

  h0 = f.pmul( d )
  h0s = fs.pmul( d.reverse )

  h0 = h0.mult( Math::sqrt( 2 ) / h0.sum )
  h0s = h0s.mult( Math::sqrt( 2 ) / h0s.sum )

  g0 = h0.reverse
  g0s = h0s.reverse

  # NOTE(review): the original author marked the h1s/g1s offsets with "-1";
  # possibly the intended sign. Confirm against the reference design.
  h1 = h0s.inverse( +1 )
  h1s = h0.inverse( +1 )
  g1 = g0s.inverse( +1 )
  g1s = g0.inverse( +1 )

  # Disabled experiment kept from the original author: trim one coefficient
  # from each end, e.g. h0s = h0s[1...h0s.size-1] (likewise g0s, h1, g1).

  [ h0, h0s, h1, h1s, g0, g0s, g1, g1s ]
end

# Choose the split of the zero groups z into two sets whose flattened sizes
# differ the least. Ties go to the lowest bitmask (bit j set => group j is in
# the first set). Returns [zeros1, zeros2] as flat arrays.
def hwt_balanced_zero_split( z )
  sizes = z.collect { |group| group.flatten.size }
  total = sizes.inject( 0 ) { |acc, n| acc + n }
  best_mask = nil
  best_diff = nil
  ( 0...( 1 << z.size ) ).each do |mask|
    size1 = 0
    sizes.each_with_index { |n, j| size1 += n if mask & ( 1 << j ) != 0 }
    diff = ( total - 2 * size1 ).abs
    # Strict comparison keeps the first mask that reaches the minimum.
    if best_diff.nil? || diff < best_diff
      best_mask = mask
      best_diff = diff
    end
  end
  z1 = []
  z2 = []
  z.each_with_index do |group, j|
    if best_mask & ( 1 << j ) != 0
      z1.push( group )
    else
      z2.push( group )
    end
  end
  [ z1.flatten, z2.flatten ]
end