multiTemplateMDN.pm 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
#!/usr/bin/perl -w
  2. package multiTemplateMDN;
  3. use strict;
  4. use Math::Trig; ## for tanh
  5. #our @ISA = qw(Exporter);
  6. #our @EXPORT = qw(readMDNWeights getMixtureParametersFromMDN);
  7. sub activationFunction
  8. {
  9. my $value = shift;
  10. my $fct = shift;
  11. return 1.0/(1.0 + exp(-$value)) if ($fct =~ /L/);
  12. return tanh($value) if ($fct =~ /H/);
  13. return $value if ($fct =~ /N/);
  14. }
  15. ##################################################################################
  16. ## input: reference to vector and reference to matrix, which are to be multiplied
  17. ## dimensions must be correct
  18. ## output: matrix * vector => vector
  19. ##################################################################################
  20. sub vectorTimesMatrix {
  21. my ($vecPtr, $matPtr) = @_;
  22. my @vec = @$vecPtr;
  23. my @matrix = @$matPtr;
  24. ## test dimensions
  25. if (scalar(@matrix) == 0) {
  26. print "matrix is empty!\n";
  27. exit(1);
  28. }
  29. if (scalar(@vec != scalar(@{$matrix[0]}))) {
  30. print "Error: vector and matrix have non-corresponding dim!\n";
  31. exit(1);
  32. }
  33. my @erg;
  34. for (my $i=0; $i<scalar(@matrix); $i++) {
  35. my $scaProd = 0;
  36. for (my $j=0; $j<scalar(@{$matrix[$i]}); $j++) {
  37. $scaProd += $vec[$j] * $matrix[$i][$j];
  38. }
  39. push(@erg, $scaProd);
  40. }
  41. return @erg;
  42. }
  43. ########################################################
  44. ## weights has the following structure:
  45. ## w_11 w_21 w_31 bias
  46. ## w_12 w_22 w_32 bias
  47. ## ...
  48. ## w_15 2_25 w_35 bias
  49. ##
  50. ## one row represents the weights w_ij
  51. ## between node i and j (i is one layer before j)
  52. ########################################################
  53. ## requires two arguments:
  54. ## file for weights of layer1
  55. ## file for weights of layer2
  56. ##
  57. ########################################################
  58. sub new {
  59. my $class = shift;
  60. my $self = {
  61. numberOfInputs => undef,
  62. numberOfHiddenNodes => undef,
  63. numberOfOutputs => undef,
  64. dimTarget => 1,
  65. hiddenLayerActivationFunction => undef,
  66. outputLayerActivationFunction => undef,
  67. weightsLayer1 => [],
  68. weightsLayer2 => []
  69. };
  70. bless($self, $class);
  71. $self->readMDNWeights(shift, shift);
  72. return $self;
  73. }
  74. ## input: weight file name for layer 1 and layer 2
  75. ## output: fills weightsLayer1 and weightsLayer2 and some other
  76. ## variables
  77. ##
  78. sub readMDNWeights {
  79. my $self = shift;
  80. my $MDNLayer1WeightsFile = shift;
  81. my $MDNLayer2WeightsFile = shift;
  82. my @weightsLayer1;
  83. my @weightsLayer2;
  84. open (WH1, "< $MDNLayer1WeightsFile") or die "Cant open $MDNLayer1WeightsFile! $!\n";
  85. my $weightLineRead = 0;
  86. while(<WH1>) {
  87. ## ignore comments
  88. next if (/\s*#/);
  89. if(/hiddenLayerActivationFunction\s*=\s*(\S+)/) {
  90. $self->{hiddenLayerActivationFunction} = $1;
  91. }
  92. ## weights (at least 2)
  93. elsif (/^\s*(\S+\s+)+\S+/) {
  94. my @weights = split(/\s+/);
  95. $weightsLayer1[$weightLineRead] = \@weights;
  96. $weightLineRead++;
  97. }
  98. }
  99. close (WH1);
  100. my $numberOfInputs = scalar(@{$weightsLayer1[0]}) - 1;
  101. $self->{numberOfInputs} = $numberOfInputs;
  102. print "numberOfInputs=$numberOfInputs\n";
  103. my $numberOfHiddenNodes = $weightLineRead;
  104. $self->{numberOfHiddenNodes} = $numberOfHiddenNodes;
  105. ## reset
  106. $weightLineRead = 0;
  107. open(WH2, "< $MDNLayer2WeightsFile") or die "Cant open $MDNLayer2WeightsFile! $!\n";
  108. while(<WH2>) {
  109. next if (/\s*#/);
  110. if (/outputLayerActivationFunction\s*=\s*(\S+)/) {
  111. $self->{outputLayerActivationFunction} = $1;
  112. }
  113. elsif (/^\s*(\S+\s+)+\S+/) {
  114. my @weights = split(/\s+/);
  115. $weightsLayer2[$weightLineRead] = \@weights;
  116. $weightLineRead++;
  117. }
  118. }
  119. my $numberOfOutputs = $weightLineRead;
  120. $self->{numberOfOutputs} = $numberOfOutputs;
  121. close(WH2);
  122. $self->{weightsLayer1} = \@weightsLayer1;
  123. $self->{weightsLayer2} = \@weightsLayer2;
  124. }
  125. ##############################################################
  126. ## input: vector containing input for MDN
  127. ## output: MDN result, i.e. parameters of Gaussian components
  128. ## e.g. (pi1, pi2, mu1, mu2, sigma1, sigma2)
  129. ##############################################################
  130. sub getMixtureParametersFromMDN {
  131. my $self = shift;
  132. ## vector containing input for MDN, e.g. (distance, post, sim)
  133. my @input = @_;
  134. if (scalar(@input) != $self->{numberOfInputs}) {
  135. print "number of inputs doesnt match weights in first layer!\n";
  136. exit(1);
  137. }
  138. ## since the weight matrices have bias as last column,
  139. ## input must get a "1" as last entry
  140. push (@input, 1);
  141. my @inputLayer1 = &vectorTimesMatrix(\@input, $self->{weightsLayer1});
  142. my @outputLayer1;
  143. for (my $i=0; $i<@inputLayer1; $i++) {
  144. $outputLayer1[$i] = &activationFunction($inputLayer1[$i], $self->{hiddenLayerActivationFunction});
  145. }
  146. push (@outputLayer1, 1);
  147. my @outputLayer2 = &vectorTimesMatrix(\@outputLayer1, $self->{weightsLayer2});
  148. for (my $i=0; $i<@outputLayer2; $i++) {
  149. $outputLayer2[$i] = &activationFunction($outputLayer2[$i], $self->{outputLayerActivationFunction});
  150. }
  151. ## each component is decribed by 3 parameters:
  152. ## pi, mu, sigma (each one-dimensional)
  153. my $numberOfComponents = $self->{numberOfOutputs} / 3;
  154. ## from "raw output", calculate now the "real" parameters
  155. ## pi
  156. my $pi_normalizer = 0;
  157. for (my $i=0; $i<$numberOfComponents; $i++) {
  158. $outputLayer2[$i] = exp($outputLayer2[$i]);
  159. $pi_normalizer += $outputLayer2[$i];
  160. }
  161. for (my $i=0; $i<$numberOfComponents; $i++) {
  162. $outputLayer2[$i] /= $pi_normalizer;
  163. }
  164. ## mu stays unchanged, i.e. network emits mu directly
  165. ## sigma
  166. for (my $i=0; $i<$numberOfComponents; $i++) {
  167. $outputLayer2[$#outputLayer2 - $i] = sqrt( exp($outputLayer2[$#outputLayer2 - $i]) );
  168. }
  169. return @outputLayer2;
  170. }
  171. 1;
## Example usage:
##   my $mdn = multiTemplateMDN->new($layer1WeightsFile, $layer2WeightsFile);
##   my @input = (5, 0.9, 0.7);
##   my @erg = $mdn->getMixtureParametersFromMDN(@input);
##   print "@erg\n";