callbacks.cpp 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
/******************************************************************************
 *
 * Copyright (C) 2020-2021 by
 * The Salk Institute for Biological Studies
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 *
******************************************************************************/
  11. #include "api/callbacks.h"
  12. #include "api/model.h"
  13. #include "api/geometry_object.h"
  14. #include "world.h"
  15. #include "molecule.h"
  16. using namespace std;
  17. namespace MCell {
  18. namespace API {
  19. Callbacks::Callbacks(Model* model_)
  20. : model(model_) {
  21. assert(model != nullptr);
  22. }
  23. void Callbacks::register_mol_wall_hit_callback(
  24. const mol_wall_hit_callback_function_t func,
  25. py::object context,
  26. const geometry_object_id_t geometry_object_id,
  27. const BNG::species_id_t species_id
  28. ) {
  29. assert(model != nullptr);
  30. auto it_geom_obj = mol_wall_hit_callbacks.find(geometry_object_id);
  31. if (it_geom_obj != mol_wall_hit_callbacks.end() && it_geom_obj->second.count(species_id) != 0) {
  32. string geom_name;
  33. if (geometry_object_id != GEOMETRY_OBJECT_ID_INVALID) {
  34. geom_name = model->get_world()->get_geometry_object(geometry_object_id).name;
  35. }
  36. else {
  37. geom_name = "any";
  38. }
  39. string species_name;
  40. if (species_id != BNG::SPECIES_ID_INVALID) {
  41. species_name = model->get_world()->get_all_species().get(species_id).name;
  42. }
  43. else {
  44. species_name = "any";
  45. }
  46. throw RuntimeError(S("Cannot register two callbacks for an identical pair or geometry object and species id, ") +
  47. " error while trying to register second callback for geometry object '" + geom_name + "' and species '" + species_name +"'.");
  48. }
  49. mol_wall_hit_callbacks[geometry_object_id][species_id] = MolWallHitCallbackInfo(func, context, geometry_object_id, species_id);
  50. // make sure that the species_id won't change in the future
  51. if (species_id != BNG::SPECIES_ID_INVALID) {
  52. model->get_world()->get_all_species().get(species_id).clear_flag(BNG::SPECIES_FLAG_IS_REMOVABLE);
  53. }
  54. }
  55. void Callbacks::do_mol_wall_hit_callbacks(std::shared_ptr<MolWallHitInfo> info) {
  56. // set geometry data
  57. info->geometry_object = model->get_geometry_object_with_id(info->geometry_object_id);
  58. assert(is_set(info->geometry_object));
  59. assert(info->partition_wall_index >= info->geometry_object->first_wall_index);
  60. info->wall_index = info->partition_wall_index - info->geometry_object->first_wall_index;
  61. // convert units
  62. assert(model->get_world() != nullptr);
  63. info->time = info->time * model->get_world()->config.time_unit;
  64. info->pos3d = mult_vec(info->pos3d, model->get_world()->config.length_unit);
  65. info->time_before_hit = info->time_before_hit * model->get_world()->config.time_unit;
  66. info->pos3d_before_hit = mult_vec(info->pos3d_before_hit, model->get_world()->config.length_unit);
  67. geometry_object_id_t geometry_object_id = info->geometry_object->geometry_object_id;
  68. // only one partition for now
  69. const MCell::Molecule& m = model->get_world()->get_partition(PARTITION_ID_INITIAL).get_m(info->molecule_id);
  70. // call callback for all matching registered callbacks
  71. auto it_specific_geom_obj = mol_wall_hit_callbacks.find(geometry_object_id);
  72. if (it_specific_geom_obj != mol_wall_hit_callbacks.end()) {
  73. do_mol_wall_hit_callback_for_specific_and_any_species(
  74. info, m.species_id, it_specific_geom_obj->second);
  75. }
  76. auto it_any_geom_obj = mol_wall_hit_callbacks.find(GEOMETRY_OBJECT_ID_INVALID);
  77. if (it_any_geom_obj != mol_wall_hit_callbacks.end()) {
  78. do_mol_wall_hit_callback_for_specific_and_any_species(
  79. info, m.species_id, it_any_geom_obj->second);
  80. }
  81. }
  82. void Callbacks::do_mol_wall_hit_callback_for_specific_and_any_species(
  83. std::shared_ptr<MolWallHitInfo> info,
  84. const BNG::species_id_t specific_species_id,
  85. const SpeciesMolWallHitCallbackInfoMap& species_map) {
  86. auto it_specific_species = species_map.find(specific_species_id);
  87. if (it_specific_species != species_map.end()) {
  88. do_individual_mol_wall_hit_callback(info, it_specific_species->second);
  89. }
  90. auto it_any_species = species_map.find(BNG::SPECIES_ID_INVALID);
  91. if (it_any_species != species_map.end()) {
  92. do_individual_mol_wall_hit_callback(info, it_any_species->second);
  93. }
  94. }
  95. void Callbacks::do_individual_mol_wall_hit_callback(
  96. std::shared_ptr<MolWallHitInfo> info,
  97. MolWallHitCallbackInfo callback_function_and_context) {
  98. // acquire GIL before calling Python code
  99. py::gil_scoped_acquire acquire;
  100. // call the actual callback
  101. callback_function_and_context.callback_function(info, callback_function_and_context.context);
  102. }
  103. void Callbacks::register_rxn_callback(
  104. const rxn_callback_function_t func,
  105. py::object context,
  106. const BNG::rxn_rule_id_t rxn_rule_id
  107. ) {
  108. assert(model != nullptr);
  109. assert(rxn_rule_id != BNG::RXN_RULE_ID_INVALID);
  110. if (rxn_callbacks.count(rxn_rule_id) != 0) {
  111. std::string name = model->get_world()->get_all_rxns().get(rxn_rule_id)->to_str();
  112. throw RuntimeError(S("Each reaction rule can have only a single callback, error while trying to register ") +
  113. "second callback for " + name + ".");
  114. }
  115. rxn_callbacks[rxn_rule_id] = RxnCallbackInfo(func, context, rxn_rule_id);
  116. }
  117. bool Callbacks::do_rxn_callback(std::shared_ptr<ReactionInfo> info) {
  118. // select the correct callback
  119. assert(rxn_callbacks.count(info->rxn_rule_id) != 0);
  120. const RxnCallbackInfo& specific_callback = rxn_callbacks[info->rxn_rule_id];
  121. // set reaction rule object
  122. info->reaction_rule = model->get_reaction_rule_with_fwd_id(info->rxn_rule_id);
  123. assert(is_set(info->reaction_rule));
  124. // convert units
  125. assert(model->get_world() != nullptr);
  126. info->time = info->time * model->get_world()->config.time_unit;
  127. info->pos3d = mult_vec(info->pos3d, model->get_world()->config.length_unit);
  128. const BNG::RxnRule* rxn = model->get_world()->get_all_rxns().get(info->rxn_rule_id);
  129. if (info->geometry_object_id != GEOMETRY_OBJECT_ID_INVALID) {
  130. info->geometry_object = model->get_geometry_object_with_id(info->geometry_object_id);
  131. assert(is_set(info->geometry_object));
  132. assert(info->partition_wall_index >= info->geometry_object->first_wall_index);
  133. info->wall_index = info->geometry_object->get_object_wall_index(info->partition_wall_index);
  134. info->pos2d = mult_vec(info->pos2d, model->get_world()->config.length_unit);
  135. }
  136. // acquire GIL before calling Python code
  137. py::gil_scoped_acquire acquire;
  138. // call the actual callback
  139. bool cancel_reaction = specific_callback.callback_function(info, specific_callback.context);
  140. return cancel_reaction;
  141. }
  142. } /* namespace API */
  143. } /* namespace MCell */