"""
Module implementing the WASM/JS "standalone" device.
"""
import os
import platform
import re
import shlex
import shutil
import subprocess
import tempfile
import time
from collections import Counter

import numpy as np

from brian2.units import second
from brian2.core.namespace import get_local_namespace
from brian2.core.preferences import prefs, BrianPreference
from brian2.synapses import Synapses
from brian2.utils.logger import get_logger
from brian2.utils.filetools import in_directory
from brian2.devices import all_devices
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice, CPPWriter
from brian2.utils.filetools import ensure_directory
logger = get_logger(__name__)

# Preferences that control how the Emscripten SDK is located and which extra
# flags are handed to ``emcc``. They are collected in a mapping first and then
# registered in one go under the ``devices.wasm_standalone`` namespace.
_wasm_preference_definitions = {
    'emsdk_directory': BrianPreference(
        default="",
        docs="""
Absolute path to the *emsdk* installation. Leave empty to use the
EMSDK/CONDA_EMSDK_DIR environment variables or an already-activated
emsdk in your shell.
""",
    ),
    'emsdk_version': BrianPreference(
        default="latest",
        docs="""
Version string passed to ``emsdk activate`` (e.g. ``"3.1.56"``).
Ignored when *emsdk_directory* is empty and the SDK is pre-activated.
""",
    ),
    'emcc_compile_args': BrianPreference(
        default=["-w"],
        docs="""
Extra flags appended to every *emcc* **compile** command.
Example: ``["-O3", "-sASSERTIONS"]``.
""",
    ),
    'emcc_link_args': BrianPreference(
        default=[],
        docs="""
Extra flags appended to the final *emcc* **link** command that produces
``wasm_module.js`` / ``.wasm``.
Example: ``["-sEXPORT_ES6", "-sEXPORTED_RUNTIME_METHODS=['cwrap']"]``.
""",
    ),
}
prefs.register_preferences('devices.wasm_standalone',
                           'Preferences for the WebAsm backend',
                           **_wasm_preference_definitions)
del _wasm_preference_definitions
# Default template variables for the generated ``index.html`` page. The keys of
# this dict are also the only keys accepted in the ``html_content`` build
# option; missing keys are filled in from here.
DEFAULT_HTML_CONTENT = {
    'title': 'Brian simulation',
    'h1': '',
    'h2': '',
    'description': '',
    'canvas_width': '95%',
    'canvas_height': '500px',
}
class WASMStandaloneDevice(CPPStandaloneDevice):
    """
    The `Device` used for WASM simulations.

    Code generation is inherited from `CPPStandaloneDevice`. This device
    additionally prefers the ``brian2wasm`` templates, compiles the generated
    C++ sources with Emscripten (``emcc``) and launches the result in a
    browser with ``emrun``.
    """

    def __init__(self, *args, **kwds):
        """
        Initialize the WASM standalone device.

        Parameters
        ----------
        *args : tuple
            Positional arguments passed to ``CPPStandaloneDevice``.
        **kwds : dict
            Keyword arguments passed to ``CPPStandaloneDevice``.
        """
        # Variables selected with `transfer_only`; ``None`` means "not set yet".
        self.transfer_results = None
        super().__init__(*args, **kwds)

    def transfer_only(self, variableviews):
        """
        Mark variables for transfer from WASM to JavaScript.

        The selected variables are passed to the ``objects`` template (as
        ``transfer_results``) when the source code is generated.

        Parameters
        ----------
        variableviews : iterable of VariableView
            Views whose underlying ``variable`` should be transferred.

        Raises
        ------
        AssertionError
            If the variables to transfer have already been set.
        """
        # Raised explicitly instead of via ``assert`` so that the check is not
        # stripped when Python runs with optimizations (``-O``).
        if self.transfer_results is not None:
            raise AssertionError(
                "transfer_only() has already been called; the variables to "
                "transfer can only be set once."
            )
        self.transfer_results = [view.variable for view in variableviews]

    def activate(self, *args, **kwargs):
        """
        Activate the WASM standalone device for simulation.

        After the parent activation, the code object templater is replaced by
        one that prefers the ``brian2wasm`` templates, and ``<emscripten.h>``
        is added to the C++ headers if it is not already present.

        Parameters
        ----------
        *args : tuple
            Positional arguments passed to the parent ``activate``.
        **kwargs : dict
            Keyword arguments passed to the parent ``activate``.
        """
        super().activate(*args, **kwargs)
        # Overwrite the templater to prefer our templates
        codeobj_class = self.code_object_class()
        codeobj_class.templater = codeobj_class.templater.derive('brian2wasm')
        if '<emscripten.h>' not in prefs.codegen.cpp.headers:
            prefs.codegen.cpp.headers += ['<emscripten.h>']

    def generate_objects_source(
        self,
        writer,
        arange_arrays,
        synapses,
        static_array_specs,
        networks,
        timed_arrays,
    ):
        """
        Generate the ``objects.*`` source files for WASM compilation.

        Renders the ``objects`` template with the device's arrays, clocks,
        code objects and the variables selected via `transfer_only`.

        Parameters
        ----------
        writer : CPPWriter
            Object used to write the generated code.
        arange_arrays : dict
            Specifications for arange arrays.
        synapses : set
            ``Synapses`` objects in the simulation.
        static_array_specs : dict
            Specifications for static arrays.
        networks : set
            ``Network`` objects in the simulation.
        timed_arrays : dict
            Specifications for timed arrays.
        """
        arr_tmp = self.code_object_class().templater.objects(
            None,
            None,
            array_specs=self.arrays,
            dynamic_array_specs=self.dynamic_arrays,
            dynamic_array_2d_specs=self.dynamic_arrays_2d,
            zero_arrays=self.zero_arrays,
            arange_arrays=arange_arrays,
            synapses=synapses,
            clocks=self.clocks,
            static_array_specs=static_array_specs,
            networks=networks,
            get_array_filename=self.get_array_filename,
            get_array_name=self.get_array_name,
            profiled_codeobjects=self.profiled_codeobjects,
            code_objects=list(self.code_objects.values()),
            timed_arrays=timed_arrays,
            transfer_results=self.transfer_results,
        )
        writer.write("objects.*", arr_tmp)

    @staticmethod
    def _resolve_emsdk_directory():
        """
        Return the emsdk directory and store it in the preferences.

        The ``devices.wasm_standalone.emsdk_directory`` preference is used if
        set; otherwise the ``EMSDK`` and then the ``CONDA_EMSDK_DIR``
        environment variables are tried. The resolved path is written back to
        the preference because `run` reads it from there.

        Raises
        ------
        ValueError
            If no emsdk directory can be found.
        """
        emsdk_path = (
            prefs.devices.wasm_standalone.emsdk_directory
            or os.environ.get("EMSDK")
            or os.environ.get("CONDA_EMSDK_DIR")
        )
        if not emsdk_path:
            # Raise before touching the preference: it only accepts strings, so
            # writing ``None`` back would fail with a confusing error.
            raise ValueError("Please provide the path to the emsdk directory in the preferences")
        prefs.devices.wasm_standalone.emsdk_directory = emsdk_path
        return emsdk_path

    def generate_makefile(self, writer, compiler, compiler_flags, linker_flags, nb_threads, debug):
        """
        Generate the makefile used to build the project with Emscripten.

        On Windows the ``win_makefile`` template is rendered and written to
        ``win_makefile``; elsewhere the ``makefile`` template is written to
        ``makefile``. As a side effect, the resolved emsdk directory is
        stored in ``prefs.devices.wasm_standalone.emsdk_directory``.

        Parameters
        ----------
        writer : CPPWriter
            Object used to write the generated files; provides the lists of
            source and header files.
        compiler : str
            Compiler name (typically ``emcc``); not used by the templates.
        compiler_flags : str
            Compiler flags to apply.
        linker_flags : str
            Linker flags to apply.
        nb_threads : int
            Number of threads (unused for WASM).
        debug : bool
            Whether to add debug flags (``-g -DDEBUG`` / ``-g``).

        Raises
        ------
        ValueError
            If no emsdk directory is configured or found in the environment.
        """
        # Every static array is bundled into the virtual file system
        preloads = ' '.join(f'--preload-file static_arrays/{static_array}'
                            for static_array in sorted(self.static_arrays))
        if debug:
            compiler_debug_flags, linker_debug_flags = '-g -DDEBUG', '-g'
        else:
            compiler_debug_flags = linker_debug_flags = ''
        emsdk_path = self._resolve_emsdk_directory()
        template_kwds = dict(
            source_files=' '.join(sorted(writer.source_files)),
            header_files=' '.join(sorted(writer.header_files)),
            compiler_flags=compiler_flags,
            compiler_debug_flags=compiler_debug_flags,
            linker_debug_flags=linker_debug_flags,
            linker_flags=linker_flags,
            preloads=preloads,
            preamble_file=os.path.join(os.path.dirname(__file__), 'templates', 'pre.js'),
            rm_cmd='rm $(OBJS) $(PROGRAM) $(DEPS)',
            emsdk_path=emsdk_path,
            emsdk_version=prefs.devices.wasm_standalone.emsdk_version,
        )
        templater = self.code_object_class().templater
        if os.name == 'nt':
            makefile_tmp = templater.win_makefile(None, None, **template_kwds)
            outputfile_name = 'win_makefile'
        else:
            makefile_tmp = templater.makefile(None, None, **template_kwds)
            outputfile_name = 'makefile'
        writer.write(outputfile_name, makefile_tmp)

    def copy_source_files(self, writer, directory):
        """
        Copy the JavaScript runtime files into the build directory.

        Besides the parent's files, ``worker.js`` and ``brian.js`` are copied
        from the ``templates`` folder; a user-supplied ``html_file`` build
        option is copied as ``index.html``.

        Parameters
        ----------
        writer : CPPWriter
            Object containing source file information.
        directory : str
            Target directory for the copied files.

        Raises
        ------
        OSError
            If copying a file fails.
        """
        super().copy_source_files(writer, directory)
        templates_dir = os.path.join(os.path.dirname(__file__), 'templates')
        for js_file in ('worker.js', 'brian.js'):
            shutil.copy(os.path.join(templates_dir, js_file), directory)
        html_file = self.build_options.get('html_file')
        if html_file:
            shutil.copy(html_file, os.path.join(directory, 'index.html'))

    def get_report_func(self, report):
        """
        Generate the C++ code of the progress reporting function.

        The standard reporter posts every progress update to JavaScript via
        ``EM_ASM``/``postMessage`` and also prints it to a C++ stream.

        Parameters
        ----------
        report : str or None
            ``None`` (no reporting), ``'text'``/``'stdout'`` (print to
            ``std::cout``), ``'stderr'`` (print to ``std::cerr``), or the C++
            body of a custom ``report_progress`` function.

        Returns
        -------
        str
            The generated C++ source code (empty for ``report=None``).

        Raises
        ------
        TypeError
            If ``report`` is neither ``None`` nor a string.
        """
        # Code for a progress reporting function
        standard_code = """
std::string _format_time(float time_in_s)
{
float divisors[] = {24*60*60, 60*60, 60, 1};
char letters[] = {'d', 'h', 'm', 's'};
float remaining = time_in_s;
std::string text = "";
int time_to_represent;
for (int i =0; i < sizeof(divisors)/sizeof(float); i++)
{
time_to_represent = int(remaining / divisors[i]);
remaining -= time_to_represent * divisors[i];
if (time_to_represent > 0 || text.length())
{
if(text.length() > 0)
{
text += " ";
}
text += (std::to_string(time_to_represent)+letters[i]);
}
}
//less than one second
if(text.length() == 0)
{
text = "< 1s";
}
return text;
}
void report_progress(const double elapsed, const double completed, const double start, const double duration)
{
// Send progress to javascript
EM_ASM({
(postMessage({ type: 'progress', elapsed: $0, completed: $1, start: $2, duration: $3}));
}, elapsed, completed, start, duration);
if (completed == 0.0)
{
%STREAMNAME% << "Starting simulation at t=" << start << " s for duration " << duration << " s";
} else
{
%STREAMNAME% << completed*duration << " s (" << (int)(completed*100.) << "%) simulated in " << _format_time(elapsed) << " (" << elapsed << "s)";
if (completed < 1.0)
{
const int remaining = (int)((1-completed)/completed*elapsed+0.5);
%STREAMNAME% << ", estimated " << _format_time(remaining) << " remaining.";
}
}
%STREAMNAME% << std::endl << std::flush;
}
"""
        if report is None:
            return ''
        if report == 'text' or report == 'stdout':
            return standard_code.replace('%STREAMNAME%', 'std::cout')
        if report == 'stderr':
            return standard_code.replace('%STREAMNAME%', 'std::cerr')
        if isinstance(report, str):
            # Any other string is taken as the body of a custom reporter
            return """
void report_progress(const double elapsed, const double completed, const double start, const double duration)
{
%REPORT%
}
""".replace('%REPORT%', report)
        raise TypeError("report argument has to be either 'text', "
                        "'stdout', 'stderr', or the code for a report "
                        "function")

    def network_run(self, net, duration, report=None, report_period=10*second,
                    namespace=None, profile=None, level=0, **kwds):
        """
        Prepare a Brian2 network run for the WASM backend.

        Sets up the clocks, extracts the code objects, queues the generated
        run code and triggers the build if ``build_on_run`` is enabled.

        Parameters
        ----------
        net : Network
            The Brian2 network to simulate.
        duration : Quantity
            Duration of the simulation (must be non-negative).
        report : str or None, optional
            Progress reporting mode, see `get_report_func`. Default is None.
        report_period : Quantity, optional
            Interval between progress reports. Default is 10*second.
        namespace : dict, optional
            Namespace for resolving external variables. Default is the
            caller's local namespace.
        profile : bool, optional
            Whether to enable profiling. If None, the ``profile`` build
            option is used (False if unset).
        level : int, optional
            Stack level offset for namespace detection. Default is 0.
        **kwds : dict
            Unsupported; a warning is logged if any are given.

        Raises
        ------
        ValueError
            If duration is negative.
        NotImplementedError
            If different runs use different report functions.
        RuntimeError
            If the network was already built and run (with ``build_on_run``).
        """
        if duration < 0:
            raise ValueError(
                f"Function 'run' expected a non-negative duration but got '{duration}'"
            )
        self.networks.add(net)
        if kwds:
            logger.warn('Unsupported keyword argument(s) provided for run: '
                        + ', '.join(kwds.keys()))
        # Stored for later access by the `code_object` method. `profile` can
        # also be set in the `set_device` call (used e.g. in brian2cuda
        # SpeedTest configurations).
        self.enable_profiling = profile
        if profile is None:
            self.enable_profiling = self.build_options.get('profile', False)

        all_objects = net.sorted_objects
        net._clocks = {obj.clock for obj in all_objects}
        t_end = net.t + duration
        for clock in net._clocks:
            clock.set_interval(net.t, t_end)

        # Get the local namespace
        if namespace is None:
            namespace = get_local_namespace(level=level + 2)
        net.before_run(namespace)

        self.synapses |= {s for s in net.objects if isinstance(s, Synapses)}
        self.clocks.update(net._clocks)
        net.t_ = float(t_end)

        # TODO: remove this horrible hack
        for clock in self.clocks:
            if clock.name == 'clock':
                clock._name = '_clock'

        # Extract all the CodeObjects. Since the Network has been prepared,
        # they are sorted into the right running order (assuming a single clock).
        code_objects = [(obj.clock, codeobj)
                        for obj in all_objects if obj.active
                        for codeobj in obj._code_objects]

        report_func = self.get_report_func(report)
        if report_func != '':
            if self.report_func != '' and report_func != self.report_func:
                raise NotImplementedError("The C++ standalone device does not "
                                          "support multiple report functions, "
                                          "each run has to use the same (or "
                                          "none).")
            self.report_func = report_func
        report_call = 'report_progress' if report_func else 'NULL'

        # Generate the updaters
        run_lines = [f'{net.name}.clear();']
        all_clocks = set()
        for clock, codeobj in code_objects:
            run_lines.append(f'{net.name}.add(&{clock.name}, _run_{codeobj.name});')
            all_clocks.add(clock)
        # Under rare circumstances (e.g. a NeuronGroup that only defines a
        # subexpression used elsewhere, computes nothing itself and has its
        # own clock) a clock without any code object must still advance, so it
        # is added to the network without a code function.
        for clock in net._clocks:
            if clock not in all_clocks:
                run_lines.append(f'{net.name}.add(&{clock.name}, NULL);')

        run_lines.extend(self.code_lines['before_network_run'])
        if not self.run_args_applied:
            run_lines.append('set_from_command_line(args);')
            self.run_args_applied = True
        run_lines.append(f'{net.name}.run({float(duration)!r}, {report_call}, {float(report_period)!r});')
        run_lines.extend(self.code_lines['after_network_run'])
        self.main_queue.append(('run_network', (net, run_lines)))

        net.after_run()

        # Manually set the cache for the clocks: the time is only set in the
        # generated code, so scripts could not read it before the build/run.
        for clock in net._clocks:
            self.array_cache[clock.variables['timestep']] = np.array([clock._i_end])
            self.array_cache[clock.variables['t']] = np.array([clock._i_end * clock.dt_])

        if self.build_on_run:
            if self.has_been_run:
                raise RuntimeError("The network has already been built and run "
                                   "before. Use set_device with "
                                   "build_on_run=False and an explicit "
                                   "device.build call to use multiple run "
                                   "statements with this device.")
            self.build(direct_call=False, **self.build_options)

    @staticmethod
    def _default_html_file():
        """
        Return ``<main script>.html`` for the running script, or ``None``.

        ``None`` is returned when ``__main__`` has no ``__file__`` (e.g. in an
        interactive session or a notebook).
        """
        import __main__
        main_file = getattr(__main__, '__file__', None)
        if not main_file:
            return None
        return os.path.splitext(main_file)[0] + '.html'

    def _prepare_index_html(self, html_file, html_content):
        """
        Put the ``index.html`` served by ``emrun`` into ``self.project_dir``.

        An existing ``html_file`` (or, if not given, ``<main script>.html``)
        is copied. Otherwise the ``html_template`` template is rendered with
        ``html_content`` merged over `DEFAULT_HTML_CONTENT`.

        Raises
        ------
        KeyError
            If ``html_content`` contains a key not in `DEFAULT_HTML_CONTENT`.
        """
        if html_file is None:
            html_file = self._default_html_file()
        index_html = os.path.join(self.project_dir, 'index.html')
        if html_file is not None and os.path.exists(html_file):
            # HTML file exists, copy it to the project directory
            shutil.copy(html_file, index_html)
            return
        if html_content is None:
            html_content = {}
        for key in html_content:
            if key not in DEFAULT_HTML_CONTENT:
                raise KeyError(f"Key '{key}' is not a valid key for html_content. "
                               f"Allowed keys: {', '.join(DEFAULT_HTML_CONTENT.keys())}")
        # Merge into a new dict so that the caller's dict is left untouched
        content = {**DEFAULT_HTML_CONTENT, **html_content}
        # Create HTML file from template in code directory
        html_tmp = self.code_object_class().templater.html_template(None, None, **content)
        # UTF-8 so that non-ASCII titles/descriptions work on every platform
        with open(index_html, 'w', encoding='utf-8') as f:
            f.write(html_tmp)

    @staticmethod
    def _emrun_command(emsdk_path, run_args):
        """
        Build the shell command that sources the emsdk environment and runs
        ``emrun index.html`` with ``run_args`` appended.

        The emsdk path and each argument are quoted, so paths containing
        spaces or shell metacharacters work.
        """
        run_args = list(run_args or [])
        if platform.system() == "Windows":
            extra = (' ' + subprocess.list2cmdline(run_args)) if run_args else ''
            return f'cmd.exe /C "call "{emsdk_path}\\emsdk_env.bat" & emrun index.html{extra}"'
        inner = ' '.join(['source', shlex.quote(f'{emsdk_path}/emsdk_env.sh'),
                          '&&', 'emrun', 'index.html']
                         + [shlex.quote(arg) for arg in run_args])
        return f"/bin/bash -c {shlex.quote(inner)}"

    def run(self, directory, results_directory, with_output, run_args):
        """
        Launch the compiled WASM simulation in a browser with ``emrun``.

        First places ``index.html`` in the project directory, then (inside
        ``directory``) sources the emsdk environment and runs
        ``emrun index.html``. The wall-clock time is stored in
        ``self.timers['run_binary']``. Setting the environment variable
        ``BRIAN2WASM_NO_SERVER=1`` skips starting the server.

        Parameters
        ----------
        directory : str
            Build directory containing the compiled files.
        results_directory : str
            Results sub-directory (not used by this device).
        with_output : bool
            Output forwarding flag (not used by this device).
        run_args : list of str
            Extra arguments appended to the ``emrun`` command line.

        Raises
        ------
        KeyError
            If the ``html_content`` build option contains an unknown key.
        """
        self._prepare_index_html(self.build_options.get('html_file'),
                                 self.build_options.get('html_content'))
        with in_directory(directory):
            if os.environ.get('BRIAN2WASM_NO_SERVER', '0') == '1':
                print("Skipping server startup (--no-server flag set)")
                return
            os.environ['EMSDK_QUIET'] = '1'
            cmd_line = self._emrun_command(prefs.devices.wasm_standalone.emsdk_directory,
                                           run_args)
            start_time = time.time()
            os.system(cmd_line)
            self.timers['run_binary'] = time.time() - start_time

    def _gather_compiler_setting(self, name):
        """
        Concatenate the device's, the ``codegen.cpp`` preference's and every
        code object's ``compiler_kwds`` values for the setting *name*.
        """
        return (list(getattr(self, name))
                + list(prefs[f"codegen.cpp.{name}"])
                + [value for codeobj in self.code_objects.values()
                   for value in codeobj.compiler_kwds.get(name, [])])

    @staticmethod
    def _macro_flags(define_macros):
        """Turn ``name`` / ``(name,)`` / ``(name, value)`` macros into ``-D`` flags."""
        flags = []
        for macro in define_macros:
            if isinstance(macro, (list, tuple)):
                name, value = macro if len(macro) == 2 else (macro[0], None)
            else:
                name, value = macro, None
            flags.append(f"-D{name}={value}" if value is not None else f"-D{name}")
        return flags

    def _log_timers(self):
        """Log the recorded build/run durations, skipping unmeasured steps."""
        measurements = (
            ("'make clean'", self.timers["compile"]["clean"]),
            ("'make'", self.timers["compile"]["make"]),
            ("running 'main'", self.timers["run_binary"]),
        )
        logger.debug("Time measurements: " + ", ".join(
            f"{label}: {seconds:.2f}s"
            for label, seconds in measurements if seconds is not None
        ))

    def build(self, html_file=None, html_content=None, **kwds):
        """
        Build the project for the WASM backend.

        Generates all sources and the makefile, then optionally compiles them
        with Emscripten and runs the result.

        Parameters
        ----------
        html_file : str, optional
            Path to a custom HTML file used as ``index.html``.
        html_content : dict, optional
            Template variables for the generated ``index.html`` (keys of
            `DEFAULT_HTML_CONTENT`).
        directory : str, optional
            Target build directory. Defaults to a new temporary directory
            (prefix ``brian_standalone_``).
        results_directory : str, optional
            Relative results sub-folder. Defaults to "results".
        compile : bool, optional
            Whether to compile the sources. Default is True.
        run : bool, optional
            Whether to run the generated bundle. Default is True.
        debug : bool, optional
            Whether to include debug flags. Default is False.
        clean : bool, optional
            Whether to clean old build artifacts first. Default is False.
        with_output : bool, optional
            Passed on to `run`. Default is True.
        additional_source_files : list of str, optional
            Extra ``.cpp`` files to include (the list is not modified).
        run_args : list of str, optional
            Extra arguments for ``emrun``.
        direct_call : bool, optional
            True when called by the user; False when triggered by a run call.

        Raises
        ------
        RuntimeError
            If called manually with ``build_on_run=True``, or if the network
            has already been built and run.
        TypeError
            If ``results_directory`` is an absolute path.
        ValueError
            If the number of OpenMP threads is negative, object names are
            duplicated, or no emsdk directory can be found.
        """
        self.build_options.update({'html_file': html_file,
                                   'html_content': html_content})
        direct_call = kwds.get('direct_call', True)
        # Copies: the source file list is extended below and must not modify
        # the caller's list (possibly the one stored in ``self.build_options``).
        additional_source_files = list(kwds.get('additional_source_files') or [])
        run_args = list(kwds.get('run_args') or [])
        directory = kwds.get('directory') or tempfile.mkdtemp(prefix="brian_standalone_")
        run = kwds.get('run', True)
        debug = kwds.get('debug', False)
        clean = kwds.get('clean', False)
        with_output = kwds.get('with_output', True)
        results_directory = kwds.get('results_directory', 'results')
        do_compile = kwds.get('compile', True)

        if self.build_on_run and direct_call:
            raise RuntimeError(
                "You used set_device with build_on_run=True "
                "(the default option), which will automatically "
                "build the simulation at the first encountered "
                "run call - do not call device.build manually "
                "in this case. If you want to call it manually, "
                "e.g. because you have multiple run calls, use "
                "set_device with build_on_run=False."
            )
        if self.has_been_run:
            raise RuntimeError(
                "The network has already been built and run "
                "before. To build several simulations in "
                'the same script, call "device.reinit()" '
                'and "device.activate()". Note that you '
                "will have to set build options (e.g. the "
                "directory) and defaultclock.dt again."
            )

        self.project_dir = directory
        ensure_directory(directory)
        if os.path.isabs(results_directory):
            raise TypeError(
                "The 'results_directory' argument needs to be a relative path but was "
                f"'{results_directory}'."
            )
        # Translate path to absolute path which ends with /
        self.results_dir = os.path.join(
            os.path.abspath(os.path.join(directory, results_directory)), ""
        )

        compiler = "emcc"
        # This library is only relevant when targetting Windows
        if "advapi32" in self.libraries:
            self.libraries.remove("advapi32")
        compiler_flags = (
            self.extra_compile_args
            + prefs["devices.wasm_standalone.emcc_compile_args"]
            + self._macro_flags(self._gather_compiler_setting("define_macros"))
            + [f"-I{d}" for d in self._gather_compiler_setting("include_dirs")]
        )
        linker_flags = (
            self.extra_link_args
            + prefs["devices.wasm_standalone.emcc_link_args"]
            + [f"-L{d}" for d in self._gather_compiler_setting("library_dirs")]
            + [f"-l{lib}" for lib in self._gather_compiler_setting("libraries")]
        )
        additional_source_files += [
            f for c in self.code_objects.values() for f in c.compiler_kwds.get("sources", [])
        ]

        for subdirectory in ("code_objects", "results", "static_arrays"):
            ensure_directory(os.path.join(directory, subdirectory))
        self.writer = CPPWriter(directory)

        nb_threads = prefs.devices.cpp_standalone.openmp_threads
        if nb_threads < 0:
            raise ValueError("OpenMP threads cannot be negative.")
        self.check_openmp_compatible(nb_threads)
        self.write_static_arrays(directory)

        names = [obj.name for net in self.networks for obj in net.sorted_objects]
        duplicates = [name for name, count in Counter(names).items() if count > 1]
        if duplicates:
            raise ValueError("Duplicate object names: " + ", ".join(f"'{n}'" for n in duplicates))

        self.generate_objects_source(self.writer, self.arange_arrays, self.synapses,
                                     self.static_array_specs, self.networks, self.timed_arrays)
        self.generate_main_source(self.writer)
        self.generate_codeobj_source(self.writer)
        self.generate_network_source(self.writer, compiler)
        self.generate_synapses_classes_source(self.writer)
        self.generate_run_source(self.writer)
        self.copy_source_files(self.writer, directory)
        self.writer.source_files.update(additional_source_files)
        self.generate_makefile(
            self.writer,
            compiler,
            compiler_flags=" ".join(compiler_flags),
            linker_flags=" ".join(linker_flags),
            nb_threads=nb_threads,
            debug=debug,
        )

        if do_compile:
            # Switch the compiler name back to `msvc` on Windows so that `nmake` is used
            self.compile_source(directory, 'msvc' if os.name == 'nt' else compiler, debug, clean)
        if run:
            self.run(directory, results_directory, with_output, run_args)
        self._log_timers()
# Module-level device instance, registered in brian2's device registry under
# the name 'wasm_standalone'.
all_devices['wasm_standalone'] = wasm_standalone_device = WASMStandaloneDevice()