# -*- coding: utf-8 -*-
#
#
# TheVirtualBrain-Framework Package. This package holds all Data Management, and
# Web-UI helpful to run brain-simulations. To use it, you also need to download
# TheVirtualBrain-Scientific Package (for simulators). See content of the
# documentation-folder for more details. See also http://www.thevirtualbrain.org
#
# (c) 2012-2024, Baycrest Centre for Geriatric Care ("Baycrest") and others
#
# This program is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with this
# program. If not, see <http://www.gnu.org/licenses/>.
#
#
# CITATION:
# When using The Virtual Brain for scientific publications, please cite it as explained here:
# https://www.thevirtualbrain.org/tvb/zwei/neuroscience-publications
#
#
"""
Adapter that uses the traits module to generate interfaces to the Simulator.
Few supplementary steps are done here:
* from submitted Monitor/Model... names, build transient entities
* after UI parameters submit, compose transient Cortex entity to be passed to the Simulator.
.. moduleauthor:: Paula Popa <paula.popa@codemart.ro>
.. moduleauthor:: Lia Domide <lia.domide@codemart.ro>
.. moduleauthor:: Stuart A. Knock <Stuart@tvb.invalid>
"""
import json
from tvb.adapters.datatypes.db.connectivity import ConnectivityIndex
from tvb.adapters.datatypes.db.region_mapping import RegionMappingIndex, RegionVolumeMappingIndex
from tvb.adapters.datatypes.db.simulation_history import SimulationHistoryIndex
from tvb.adapters.datatypes.db.time_series import TimeSeriesIndex
from tvb.adapters.forms.coupling_forms import CouplingFunctionsEnum
from tvb.adapters.forms.model_forms import get_model_to_form_dict
from tvb.adapters.forms.monitor_forms import get_monitor_to_form_dict
from tvb.adapters.forms.simulator_fragments import *
from tvb.basic.neotraits.api import EnumAttr
from tvb.core.adapters.abcadapter import ABCAdapterForm, ABCAdapter
from tvb.core.adapters.exceptions import LaunchException, InvalidParameterException
from tvb.core.entities.file.simulator.simulation_history_h5 import SimulationHistory
from tvb.core.entities.file.simulator.view_model import SimulatorAdapterModel
from tvb.core.entities.storage import dao
from tvb.core.neocom import h5
from tvb.core.neotraits.forms import FloatField, SelectField
from tvb.simulator.simulator import Simulator
class SimulatorAdapter(ABCAdapter):
    """
    Interface between the Simulator and the Framework.
    """
    _ui_name = "Simulation Core"

    # Simulator instance driven by this adapter. The class default is None; the estimation
    # methods and launch() dereference it, so it is expected to be assigned before they run.
    algorithm = None
    # GID of a stored SimulationHistory to continue from (branching), or None for a fresh run.
    branch_simulation_state_gid = None

    def __init__(self):
        super().__init__()
        self.log.debug("%s: Initialized..." % str(self))

    def get_adapter_fragments(self, view_model):
        # type: (SimulatorAdapterModel) -> dict
        """
        Map each simulator UI section to the form fragments it displays.

        The key None holds the always-present fragments and "surface" the surface-related
        ones. "model" holds the form registered for the view model's model class, and
        "monitors" holds one form per configured monitor. Entries are None when no form is
        registered for a class.
        """
        forms = {None: [SimulatorStimulusFragment, SimulatorModelFragment, SimulatorIntegratorFragment,
                        SimulatorMonitorFragment, SimulatorFinalFragment],
                 "surface": [SimulatorSurfaceFragment, SimulatorRMFragment]}

        model_forms = get_model_to_form_dict()
        forms["model"] = [model_forms.get(type(view_model.model))]

        monitor_forms = get_monitor_to_form_dict()
        forms["monitors"] = [monitor_forms.get(type(monitor)) for monitor in view_model.monitors]

        # TODO: decide whether the entire configuration tree should be included here, or whether
        # that becomes too tedious. For now the plan is to rename this section "Summary" and
        # filter what is shown.
        return forms

    def get_output(self):
        """
        :returns: list of classes for possible results of the Simulator.
        """
        return [TimeSeriesIndex, SimulationHistoryIndex]

    def get_required_memory_size(self, view_model):
        # type: (SimulatorAdapterModel) -> int
        """
        Return the required memory to run this algorithm, as estimated by the Simulator.
        """
        return self.algorithm.memory_requirement()

    def get_required_disk_size(self, view_model):
        # type: (SimulatorAdapterModel) -> int
        """
        Return the disk size this algorithm estimates it will take (in kB).

        The Simulator's storage estimate is divided by 2 ** 10. Note that this is a true
        division, so the result may be a float.
        """
        return self.algorithm.storage_requirement() / 2 ** 10

    def get_execution_time_approximation(self, view_model):
        # type: (SimulatorAdapterModel) -> int
        """
        Approximate, from the configured simulator, the time the operation will take to
        finish (in seconds). The result is always at least 1.
        """
        # This is only a brute approximation, so that cluster nodes won't kill the operation
        # before it has finished. It should ideally be computed with higher sensitivity.
        # Magic number connecting simulation length to simulation computation time. It should be
        # as big as possible while still realistic, so the estimate errs on the generous side.
        magic_number = 6.57e-06  # seconds
        approx_number_of_nodes = 500
        approx_nvar = 15
        approx_modes = 15

        approx_integrator_dt = self.algorithm.integrator.dt
        if approx_integrator_dt == 0.0:
            # Avoid a division by zero below.
            approx_integrator_dt = 1.0

        if self.algorithm.is_surface_simulation:
            approx_number_of_nodes *= approx_number_of_nodes

        estimation = (magic_number * approx_number_of_nodes * approx_nvar *
                      approx_modes * self.algorithm.simulation_length / approx_integrator_dt)
        return max(int(estimation), 1)

    def _try_find_mapping(self, mapping_class, connectivity_gid):
        """
        Try to find a DataType index of class "mapping_class" linked to the given Connectivity.
        Entities created in the current project have priority; otherwise the first match is used.

        :param mapping_class: index class with a "fk_connectivity_gid" column
        :param connectivity_gid: GUID (hex string) of the Connectivity
        :return: None or an instance of "mapping_class"
        """
        dts_list = dao.get_generic_entity(mapping_class, connectivity_gid, 'fk_connectivity_gid')
        if not dts_list:
            return None

        for dt in dts_list:
            dt_operation = dao.get_operation_by_id(dt.fk_from_operation)
            if dt_operation.fk_launched_in == self.current_project_id:
                return dt
        return dts_list[0]

    def _try_load_region_mapping(self):
        """
        Load the RegionMapping and RegionVolumeMapping linked to the simulator's Connectivity.

        :returns: tuple (region_map, region_volume_map); each element is None when no
                  matching index exists
        """
        connectivity_gid = self.algorithm.connectivity.gid.hex
        region_map = None
        region_volume_map = None

        region_map_index = self._try_find_mapping(RegionMappingIndex, connectivity_gid)
        region_volume_map_index = self._try_find_mapping(RegionVolumeMappingIndex, connectivity_gid)

        if region_map_index:
            region_map = h5.load_from_index(region_map_index)
        if region_volume_map_index:
            region_volume_map = h5.load_from_index(region_volume_map_index)
        return region_map, region_volume_map

    def _prepare_time_series(self, monitor, start_time, region_map, region_volume_map):
        """
        Create the TimeSeries for one monitor, together with its INTERMEDIATE index and its
        H5 file. The file already has the scalars, metadata and references stored and is
        ready to receive time/data slices.

        :returns: tuple (ts_index, ts_h5). The caller owns ts_h5 and must close it.
        """
        ts = monitor.create_time_series(self.algorithm.connectivity, self.algorithm.surface, region_map,
                                        region_volume_map)
        self.log.debug("Monitor created the TS")
        ts.start_time = start_time

        ts_index_class = h5.REGISTRY.get_index_for_datatype(type(ts))
        ts_index = ts_index_class()
        ts_index.fill_from_has_traits(ts)
        ts_index.data_ndim = 4
        ts_index.state = 'INTERMEDIATE'

        if monitor.voi is not None:
            # Relabel the state-variable dimension with only the variables of interest
            # selected by this monitor's voi indices.
            state_variable_dimension_name = ts.labels_ordering[1]
            selected_vois = [self.algorithm.model.variables_of_interest[idx] for idx in monitor.voi]
            ts.labels_dimensions[state_variable_dimension_name] = selected_vois
            ts_index.labels_dimensions = json.dumps(ts.labels_dimensions)

        ts_h5_class = h5.REGISTRY.get_h5file_for_datatype(type(ts))
        ts_h5_path = h5.path_by_dir(self._get_output_path(), ts_h5_class, ts.gid)
        self.log.info("Generating Timeseries at: {}".format(ts_h5_path))
        ts_h5 = ts_h5_class(ts_h5_path)
        try:
            ts_h5.store(ts, scalars_only=True, store_references=False)
            ts_h5.sample_rate.store(ts.sample_rate)
            ts_h5.nr_dimensions.store(ts_index.data_ndim)
            # Storing the generic attributes here is redundant, except for HPC runs.
            ts_h5.store_generic_attributes(self.generic_attributes)
            ts_h5.store_references(ts)
        except Exception:
            # The caller never receives this handle, so release it here before propagating.
            ts_h5.close()
            raise
        return ts_index, ts_h5

    def launch(self, view_model):
        # type: (SimulatorAdapterModel) -> [TimeSeriesIndex, SimulationHistoryIndex]
        """
        Called from the GUI to run the configured simulation and persist its results.

        One TimeSeries (index + H5 file) is prepared per monitor and filled slice by slice
        while the simulator runs. For non-group launches, the final simulator state is also
        stored as a hidden SimulationHistory.

        :param view_model: SimulatorAdapterModel holding the simulation parameters
        :returns: list containing the SimulationHistoryIndex (non-group launches only),
                  followed by one TimeSeriesIndex per monitor
        :raises InvalidParameterException: if a monitor's sampling period is bigger than
                  the simulation length
        :raises LaunchException: if branch_simulation_state_gid does not refer to a
                  SimulationHistory
        """
        # Computed before configure(), from the simulator's current step.
        start_time = self.algorithm.current_step * self.algorithm.integrator.dt
        self.algorithm.configure(full_configure=False)
        if self.branch_simulation_state_gid is not None:
            history = self.load_traited_by_gid(self.branch_simulation_state_gid)
            if not isinstance(history, SimulationHistory):
                raise LaunchException("Cannot branch simulation: %s is not a SimulationHistory"
                                      % str(self.branch_simulation_state_gid))
            history.fill_into(self.algorithm)

        # Validate every monitor before any output file is created, so a rejected configuration
        # does not leave H5 files (and open handles) behind for monitors processed earlier.
        for monitor in self.algorithm.monitors:
            if monitor.period > view_model.simulation_length:
                raise InvalidParameterException("Sampling period for monitors can not be bigger "
                                                "than the simulation length!")

        region_map, region_volume_map = self._try_load_region_mapping()

        # NOTE(review): results are keyed by monitor class name, so two monitors of the same
        # class would share one entry (the later one wins).
        result_indexes = dict()
        result_h5 = dict()
        results = []
        try:
            for monitor in self.algorithm.monitors:
                m_name = type(monitor).__name__
                ts_index, ts_h5 = self._prepare_time_series(monitor, start_time, region_map,
                                                            region_volume_map)
                result_indexes[m_name] = ts_index
                result_h5[m_name] = ts_h5

            # Run simulation
            self.log.debug("Starting simulation...")
            # H5 target per monitor position; result[j] belongs to the j-th monitor.
            monitor_h5 = [result_h5[type(monitor).__name__] for monitor in self.algorithm.monitors]
            for result in self.algorithm(simulation_length=self.algorithm.simulation_length):
                for j, ts_h5 in enumerate(monitor_h5):
                    if result[j] is not None:
                        ts_h5.write_time_slice([result[j][0]])
                        ts_h5.write_data_slice([result[j][1]])
            self.log.debug("Completed simulation, starting to store simulation state ")

            # Now store the simulator history, at the simulation end.
            if not self._is_group_launch():
                simulation_history = SimulationHistory()
                simulation_history.populate_from(self.algorithm)
                self.generic_attributes.visible = False
                try:
                    history_index = h5.store_complete_to_dir(simulation_history, self._get_output_path(),
                                                             self.generic_attributes)
                finally:
                    # Restore visibility even if storing fails; later results must stay visible.
                    self.generic_attributes.visible = True
                history_index.fixed_generic_attributes = True
                results.append(history_index)
            self.log.debug("Simulation state persisted, returning results ")

            for m_name, ts_h5 in result_h5.items():
                result_indexes[m_name].fill_shape(ts_h5.read_data_shape())
        finally:
            # Release every TimeSeries H5 handle, on success as well as on any failure above.
            for ts_h5 in result_h5.values():
                ts_h5.close()

        self.log.debug("%s: Adapter simulation finished!!" % str(self))
        results.extend(result_indexes.values())
        return results

    def _get_output_path(self):
        """Return the storage folder where this adapter writes its H5 results."""
        return self.get_storage_path()