from pathlib import Path
import gdown
from socialnet import model_test, model_train
from socialnet.plot import (
    plot_aggregation_subnetwork,
    plot_interaction_scores,
    plot_product,
    plot_interaction_subnetwork,
)

# data from https://drive.google.com/drive/folders/1VH97_bNFz09Ke_kBL1oV2HbTS25BwnKC
gdown.download(
    id="1y1ZhNr3eWbhYwA_ZPfIs9UjsinszKCwp", output="zebrafish_60_1.npy", resume=True
)
gdown.download(
    id="1aJb2pgzJE8dhDkWVGKWdBgYgFRhpWkj7", output="zebrafish_60_2.npy", resume=True
)
gdown.download(
    id="1LaNIqFGD5N9SUIq0yfIGEZUnrzF3TgLy", output="zebrafish_60_3.npy", resume=True
)
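
# Train a model on the three sample files; session_name identifies this training session
results_dict = model_train(
    ["zebrafish_60_1.npy", "zebrafish_60_2.npy", "zebrafish_60_3.npy"],
    session_name="example",
)

# model_train is expected to save the session in socialnet_session_<session_name>
# under the current working directory
expected_output_folder = Path.cwd() / "socialnet_session_example"

# Evaluate the trained model on the trajectory files
test_results = model_test(
    trajectory_files=["zebrafish_60_1.npy", "zebrafish_60_2.npy", "zebrafish_60_3.npy"],
    model_folder=expected_output_folder,
)
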
# plot slices of the neighbor x and y coordinates for different values of the focal velocity and the neighbor velocity
# fv = focal velocity
# nba = neighbour acceleration
# nbv = neighbour velocity
# nbx = neighbour x position
# nby = neighbour y position
fig_vars = ("fv", "nbv", "nbx", "nby") # row_var, col_var, x_var, y_var
plot_aggregation_subnetwork(expected_output_folder, fig_vars=fig_vars)
plot_interaction_subnetwork(expected_output_folder)
plot_interaction_scores(expected_output_folder, fig_vars=fig_vars)
plot_product(expected_output_folder, fig_vars=fig_vars)
SocialNet
Source code
Check the source code at https://gitlab.com/polavieja_lab/socialnet.
SocialNet, presented in [1], is a deep learning model of collective behavior that learns the rules governing how animals influence each other. It is organized into two interpretable modules: a pair-interaction subnetwork that maps how one individual affects another, and an aggregation subnetwork that describes how each individual weighs and combines the influences of all its neighbors.
The model is trained on trajectory files produced by idtracker.ai or any other compatible source.
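Before training, it can help to inspect a trajectory file. The sketch below assumes the idtracker.ai layout: a Python dict saved with `np.save`, with the positions under the key `"trajectories"` (shape frames × animals × 2) and the frame rate under `"frames_per_second"`. Files from other sources may use a different layout, and the helper names `load_trajectories` and `speeds` are hypothetical, not part of SocialNet.

```python
import numpy as np


def load_trajectories(path):
    """Return (positions, fps) from a trajectory file saved as a pickled dict with np.save.

    Assumes the idtracker.ai layout: positions under "trajectories" with shape
    (frames, animals, 2); frames where an animal was not detected are usually NaN.
    """
    data = np.load(path, allow_pickle=True).item()
    return data["trajectories"], data.get("frames_per_second")


def speeds(positions, fps):
    """Speed of each animal between consecutive frames, in position units per second.

    NaN positions propagate to NaN speeds.
    """
    displacement = np.diff(positions, axis=0)  # shape (frames - 1, animals, 2)
    return np.linalg.norm(displacement, axis=-1) * fps


if __name__ == "__main__":
    positions, fps = load_trajectories("zebrafish_60_1.npy")
    print(positions.shape, fps, np.nanmean(speeds(positions, fps)))
```

Speed is only one of the kinematic quantities SocialNet uses (the focal and neighbour velocities); this is a quick sanity check on the input data, not SocialNet's own preprocessing.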
The SocialNet architecture (extracted from Figure 1 in [1]). (A) Variables used to predict future turns: asocial variables, which involve only the focal animal, are shown in red; social variables, which involve both the focal animal and a neighbour, are shown in blue. (B) Pair-interaction subnetwork of SocialNet, which receives the asocial variables \(\alpha\) and the social variables \(\sigma_i\) of a single neighbour \(i\) and outputs a single scalar value. All pair-interaction networks share the same weights. (C) Aggregation subnetwork of SocialNet. It has the same structure as B, but its input is a restricted symmetric subset of the variables and its output is passed through an exponential function to make it positive. (D) General SocialNet architecture, showing how the outputs of the pair-interaction and aggregation subnetworks are integrated to produce a single logit \(z\) for the focal fish turning right after 1 s.
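To make panels B to D concrete, here is a toy NumPy forward pass with random weights. It is a sketch under stated assumptions, not the SocialNet implementation: the layer sizes, the column indices `agg_columns` chosen for the aggregation input, and the way the two subnetworks are combined (a weighted average of the pair-interaction outputs, using the positive aggregation outputs as weights) are all assumptions for illustration, and the symmetry of the aggregation input is not enforced here. The final step converts the logit \(z\) into a probability of turning right with the logistic sigmoid.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(n_in, n_hidden=16):
    """Random weights for a one-hidden-layer network with a scalar output (hypothetical sizes)."""
    return [
        (rng.normal(scale=0.1, size=(n_in, n_hidden)), np.zeros(n_hidden)),
        (rng.normal(scale=0.1, size=(n_hidden, 1)), np.zeros(1)),
    ]


def mlp(x, layers):
    """Apply tanh hidden layers, then a linear layer; returns one scalar per row of x."""
    h = x
    for w, b in layers[:-1]:
        h = np.tanh(h @ w + b)
    w, b = layers[-1]
    return (h @ w + b)[:, 0]


def socialnet_forward(alpha, sigma, pair_layers, agg_layers, agg_columns):
    """Toy forward pass.

    alpha: (n_asocial,) asocial variables of the focal animal.
    sigma: (n_neighbours, n_social) social variables, one row per neighbour.
    agg_columns: hypothetical indices of the restricted subset fed to the aggregation subnetwork.
    Returns the logit z and the probability of turning right.
    """
    n_neighbours = sigma.shape[0]
    inputs = np.hstack([np.tile(alpha, (n_neighbours, 1)), sigma])
    pair_out = mlp(inputs, pair_layers)  # same weights applied to every neighbour
    weights = np.exp(mlp(inputs[:, agg_columns], agg_layers))  # exp keeps weights positive
    z = np.sum(weights * pair_out) / np.sum(weights)  # assumed: normalized weighted sum
    p_right = 1.0 / (1.0 + np.exp(-z))  # logistic sigmoid maps the logit to a probability
    return z, p_right


if __name__ == "__main__":
    alpha = rng.normal(size=2)        # e.g. two asocial variables
    sigma = rng.normal(size=(4, 3))   # four neighbours, three social variables each
    z, p = socialnet_forward(alpha, sigma, init_mlp(5), init_mlp(2), agg_columns=[2, 3])
    print(z, p)
```

Because every neighbour goes through the same weights and the results are summed, the output does not depend on the order in which neighbours are listed, which is the property the shared-weight design in panel B is meant to provide.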
Left: Example interaction map showing how SocialNet predicts the influence of a neighbor on the focal animal's probability of turning right after 1 second (obtained with plot.plot_interaction_subnetwork()). Right: Example aggregation map showing how SocialNet weighs different neighbor variables when aggregating social information (obtained with plot.plot_aggregation_subnetwork()).

Install SocialNet
Warning
SocialNet is not included in idtracker.ai, so it needs to be installed separately. It depends on TensorFlow, which requires a Python version between 3.9 and 3.12. If you recently created your Conda environment for idtracker.ai, you probably built it with Python 3.13. You can check the Python version of your Conda environment with the command python --version. If you have Python 3.13 or higher, you will need to create a new Conda environment with Python 3.12 to install SocialNet:
If the idtracker.ai environment has a Python version between 3.9 and 3.12, you can reuse it and install SocialNet directly in it.
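The same version check can be run from Python inside the environment. This is a minimal sketch: the supported range is hard-coded from the warning (3.9 to 3.12), which reflects TensorFlow's requirements at the time of writing and may change, and `python_supported` is a hypothetical helper, not part of SocialNet.

```python
import sys

# Supported range as stated in the warning; TensorFlow's requirements may change over time.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 12)


def python_supported(version=None):
    """Return True if (major, minor) lies in the supported range; defaults to the running interpreter."""
    if version is None:
        version = sys.version_info[:2]
    return MIN_PYTHON <= tuple(version[:2]) <= MAX_PYTHON


if __name__ == "__main__":
    current = ".".join(map(str, sys.version_info[:2]))
    if python_supported():
        print(f"Python {current} is in the supported range for SocialNet.")
    else:
        print(f"Python {current} is outside 3.9-3.12; create a new Conda environment with Python 3.12.")
```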
First install TensorFlow following the TensorFlow installation guide, then install SocialNet from our repository using pip:
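After installing, you can confirm that both packages are visible to the active environment without importing them (importing TensorFlow can be slow). This sketch uses `importlib.util.find_spec` from the standard library; the module names `tensorflow` and `socialnet` are the import names used by TensorFlow and in the SocialNet example, and `missing_packages` is a hypothetical helper.

```python
from importlib.util import find_spec


def missing_packages(names=("tensorflow", "socialnet")):
    """Return the subset of top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not found in this environment:", ", ".join(missing))
    else:
        print("TensorFlow and SocialNet are installed.")
```

`find_spec` only checks that a module can be located; it does not guarantee that TensorFlow finds a GPU or that its native libraries load correctly.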
Basic usage of SocialNet
SocialNet has a simple Python API to train (model_train()) and test (model_test()) the model:

model_train(): Train a SocialNet model from a set of trajectory files.
model_test(): Test the model on the given trajectory files.
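The basic example trains and tests on the same files. If you want an estimate on unseen data, one option is leave-one-out over trajectory files: train on all files but one and test on the held-out file. This is a sketch, not a procedure prescribed by SocialNet. It only uses the call patterns shown in the example (`model_train(files, session_name=...)` and `model_test(trajectory_files=..., model_folder=...)`); the session naming (`holdout_<file stem>`) is hypothetical, and the output folder name `socialnet_session_<session_name>` is assumed from the example.

```python
from pathlib import Path

FILES = ["zebrafish_60_1.npy", "zebrafish_60_2.npy", "zebrafish_60_3.npy"]


def leave_one_out(files):
    """Yield (train_files, test_file) pairs, holding out each file exactly once."""
    for i, held_out in enumerate(files):
        yield files[:i] + files[i + 1:], held_out


def run_cross_validation(files=FILES):
    # Imported lazily so leave_one_out() can be used without SocialNet installed.
    from socialnet import model_test, model_train

    results = {}
    for train_files, test_file in leave_one_out(list(files)):
        session = f"holdout_{Path(test_file).stem}"  # hypothetical session naming
        model_train(train_files, session_name=session)
        # Assumes model_train writes to ./socialnet_session_<session_name>, as in the example.
        model_folder = Path.cwd() / f"socialnet_session_{session}"
        results[test_file] = model_test(
            trajectory_files=[test_file], model_folder=model_folder
        )
    return results


if __name__ == "__main__":
    for train_files, test_file in leave_one_out(FILES):
        print("train:", train_files, "test:", test_file)
```

With only three sample files each model sees less data than in the basic example, so held-out scores are not directly comparable to scores from a model trained on all files.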
After training, the model can be analyzed with the following plotting functions:

plot.plot_aggregation_subnetwork(): Visualizes the aggregation subnetwork of the model by plotting how weights vary as a function of the specified variables.
plot.plot_interaction_subnetwork(): Plots the logit of the probability of the focal animal turning right, as produced by the pair-interaction subnetwork, as a function of the orientation of the neighbour with respect to the focal (\(\theta_i\)) and the speed of the neighbour (\(v_i\)).
plot.plot_interaction_scores(): Plots processed interaction scores (attraction, alignment, and repulsion) as a function of kinematic variables of the focal and neighbour animals.
plot.plot_product(): Plots the combination (product) of plot.plot_interaction_scores() and plot.plot_aggregation_subnetwork(), visualizing how the interaction and aggregation (attention) subnetworks jointly affect the predicted interaction scores.

Here is an example of how to use SocialNet to train and test a model with trajectory files. In this example, we download some sample trajectory files from a Google Drive data repository, train a model, test it, and plot the results.
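The example uses a single slicing, fig_vars = ("fv", "nbv", "nbx", "nby"), meaning (row_var, col_var, x_var, y_var). A small sketch for producing every row/column arrangement of a set of slice variables, keeping the neighbour position on the axes, is shown below. Only the abbreviations and the `fig_vars` keyword come from the example; `slice_layouts` and `plot_all_layouts` are hypothetical helpers, and whether other abbreviations such as "nba" are accepted by a given plotting function is an assumption you should check against the SocialNet source.

```python
from itertools import permutations
from pathlib import Path


def slice_layouts(candidates=("fv", "nbv"), x_var="nbx", y_var="nby"):
    """Every (row_var, col_var, x_var, y_var) tuple using two of the candidate slice variables."""
    return [(row, col, x_var, y_var) for row, col in permutations(candidates, 2)]


def plot_all_layouts(model_folder, candidates=("fv", "nbv")):
    # Imported lazily so slice_layouts() works without SocialNet installed.
    from socialnet.plot import plot_aggregation_subnetwork, plot_interaction_scores

    for fig_vars in slice_layouts(candidates):
        plot_aggregation_subnetwork(Path(model_folder), fig_vars=fig_vars)
        plot_interaction_scores(Path(model_folder), fig_vars=fig_vars)


if __name__ == "__main__":
    for layout in slice_layouts():
        print(layout)
```

Swapping rows and columns does not change the underlying data, but it changes which variable is easiest to compare across panels.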
Alternative SocialNet CLI
SocialNet also has a command line interface (CLI). This provides a simple way to interact with the SocialNet API without needing to write Python code.
References