SocialNet

SocialNet, presented in [1], is a deep learning model of collective behavior that learns the rules governing how animals influence each other. It is organized into two interpretable modules: a pair-interaction subnetwork that maps how one individual affects another, and an aggregation subnetwork that describes how each individual weighs and combines the influences of all its neighbors.

The model is trained on trajectory files produced by idtracker.ai or any other compatible source.

SocialNet diagram

The SocialNet architecture (extracted from Figure 1 in [1]). (A) Variables used to predict future turns. Asocial variables (those involving only the focal fish) are shown in red; social variables (those involving both the focal fish and a neighbour) in blue. (B) Pair-interaction subnetwork of SocialNet, which receives the asocial variables \(\alpha\) and the social variables \(\sigma_i\) of a single neighbour \(i\) and outputs a single scalar value. All pair-interaction subnetworks share the same weights. (C) Aggregation subnetwork of SocialNet. It has the same structure as B, but its input is a restricted, symmetric subset of the variables and its output is passed through an exponential function to make it positive. (D) General SocialNet architecture, showing how the inputs of the pair-interaction and aggregation subnetworks are integrated to produce a single logit \(z\) for the focal fish turning right after 1 s.
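
The caption describes how the two subnetworks are wired together. The sketch below illustrates that data flow with NumPy and random weights. It is not SocialNet's implementation: the layer sizes, the ReLU activations, the helper names (init_mlp, mlp, socialnet_logit) and, in particular, the choice to combine neighbours through a normalized weighted average of the pair-interaction outputs are assumptions made for illustration.

```python
import numpy as np


def init_mlp(rng, sizes):
    """Random (weight, bias) pairs for a fully connected net with the given layer sizes."""
    return [(rng.normal(0.0, 0.5, (m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(x, layers):
    """ReLU hidden layers followed by a linear layer with one scalar output per row."""
    h = x
    for w, b in layers[:-1]:
        h = np.maximum(0.0, h @ w + b)
    w, b = layers[-1]
    return (h @ w + b)[:, 0]


def socialnet_logit(alpha, sigma, agg_inputs, pair_layers, agg_layers):
    """Combine per-neighbour outputs into one logit z for the focal turning right.

    alpha:      (n_asocial,) asocial variables of the focal
    sigma:      (n_neighbours, n_social) social variables, one row per neighbour
    agg_inputs: (n_neighbours, n_agg) restricted inputs of the aggregation subnetwork
    """
    n_neighbours = sigma.shape[0]
    # Pair-interaction subnetwork: the same weights for every neighbour (panel B)
    pair_in = np.hstack([np.tile(alpha, (n_neighbours, 1)), sigma])
    pair_out = mlp(pair_in, pair_layers)
    # Aggregation subnetwork: scalar output passed through exp to make it positive (panel C).
    # Subtracting the max before exp avoids overflow and cancels in the ratio below.
    a = mlp(agg_inputs, agg_layers)
    weights = np.exp(a - a.max())
    # ASSUMPTION: neighbours are combined by a normalized weighted average (panel D)
    return float(np.sum(weights * pair_out) / np.sum(weights))


rng = np.random.default_rng(0)
n_asocial, n_social, n_agg, n_neighbours = 2, 4, 3, 5
pair_layers = init_mlp(rng, [n_asocial + n_social, 16, 16, 1])
agg_layers = init_mlp(rng, [n_agg, 16, 1])

alpha = rng.normal(size=n_asocial)
sigma = rng.normal(size=(n_neighbours, n_social))
agg_inputs = rng.normal(size=(n_neighbours, n_agg))

z = socialnet_logit(alpha, sigma, agg_inputs, pair_layers, agg_layers)
p_right = 1.0 / (1.0 + np.exp(-z))  # logistic function: logit -> probability
print(f"z = {z:.3f}, P(turn right) = {p_right:.3f}")
```

Because every neighbour passes through the same weights and the result is a weighted average, reordering the neighbours leaves z unchanged, and z always lies between the smallest and largest pair-interaction output.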

https://gitlab.com/polavieja_lab/socialnet/-/raw/master/examples/interaction_subnetwork.png https://gitlab.com/polavieja_lab/socialnet/-/raw/master/examples/aggregation_subnetwork_vars_fv_nbv_nbx_nby.png

Left: Example interaction map showing how SocialNet predicts the influence of a neighbor on the focal animal’s probability of turning right after 1 second (obtained with plot.plot_interaction_subnetwork()).

Right: Example aggregation map showing how SocialNet weighs different neighbor variables when aggregating social information (obtained with plot.plot_aggregation_subnetwork()).

Install SocialNet

Warning

SocialNet is not included in idtracker.ai, so it needs to be installed separately. It depends on TensorFlow, which requires a Python version between 3.9 and 3.12. If you recently created your Conda environment for idtracker.ai, you probably built it with Python 3.13. You can check the Python version of your Conda environment with the command python --version. If you have Python 3.13 or higher, you will need to create a new Conda environment with Python 3.12 to install SocialNet:

conda create -n socialnet python=3.12
conda activate socialnet

If your idtracker.ai environment uses Python 3.9 to 3.12, you can reuse it and install SocialNet directly in it.
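
If you prefer to run the version check from inside Python rather than with python --version, a small helper like the one below compares the running interpreter against the 3.9 to 3.12 range stated above. The function name python_supports_socialnet is hypothetical and not part of SocialNet.

```python
import sys


def python_supports_socialnet(version_info=sys.version_info):
    """Return True if (major, minor) lies in the 3.9-3.12 range stated in this guide."""
    return (3, 9) <= tuple(version_info[:2]) <= (3, 12)


if __name__ == "__main__":
    version = ".".join(str(part) for part in sys.version_info[:3])
    if python_supports_socialnet():
        print(f"Python {version}: this environment can be reused for SocialNet.")
    else:
        print(f"Python {version}: create a new Conda environment with python=3.12.")
```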

First, install TensorFlow following the TensorFlow installation guide. Then install SocialNet from our repository using pip:

pip install git+https://gitlab.com/polavieja_lab/socialnet
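
To confirm that both packages are visible from the active environment, you can look them up without importing them. This sketch uses only the standard library's importlib.util.find_spec and assumes the packages are importable as tensorflow and socialnet (the latter matches the imports used in the usage example); the helper name find_missing is hypothetical.

```python
import importlib.util


def find_missing(packages=("tensorflow", "socialnet")):
    """Return the names in packages that cannot be found in the current environment."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing()
    if missing:
        print("Not installed in this environment:", ", ".join(missing))
    else:
        print("TensorFlow and SocialNet are installed.")
```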

Basic usage of SocialNet

SocialNet has a simple API that allows users to train (model_train()) and test (model_test()) the model using Python.

model_train

Train a SocialNet model from a set of trajectory files.

model_test

Test the model on the given trajectory files.

After training, the model can be analyzed by using the following plotting functions:

plot.plot_aggregation_subnetwork

Visualizes the aggregation subnetwork of the model by plotting how its output weights vary as a function of the specified variables.

plot.plot_interaction_subnetwork

Plots the logit (log-odds) of the focal animal turning right, as output by the pair-interaction subnetwork, as a function of the orientation of the neighbour relative to the focal animal (θi) and the speed of the neighbour (vi).

plot.plot_interaction_scores

Plots processed interaction scores (attraction, alignment, and repulsion) as a function of kinematic variables of the focal and neighbour animals.

plot.plot_product

Plots the product of the quantities shown by plot.plot_interaction_scores() and plot.plot_aggregation_subnetwork(), visualizing how the pair-interaction and aggregation (attention) subnetworks jointly shape the predicted interaction scores.
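
Two of the quantities above can be illustrated with plain NumPy: the interaction map is expressed as a logit, which the logistic function turns into a probability, and plot.plot_product() combines two maps defined over the same variables. This sketch assumes that "product" means an elementwise product over a shared grid of neighbour positions, and it uses made-up toy maps (a hypothetical tanh interaction score and a Gaussian aggregation weight) instead of a trained model. None of these helpers are part of the SocialNet API.

```python
import numpy as np


def logit_to_probability(z):
    """Logistic function: map a logit (log-odds) to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def product_map(interaction_scores, aggregation_weights):
    """Elementwise product of two maps evaluated on the same grid."""
    interaction_scores = np.asarray(interaction_scores)
    aggregation_weights = np.asarray(aggregation_weights)
    if interaction_scores.shape != aggregation_weights.shape:
        raise ValueError("both maps must be evaluated on the same grid")
    return interaction_scores * aggregation_weights


# Toy grid of neighbour x/y positions relative to the focal (arbitrary units)
nbx, nby = np.meshgrid(np.linspace(-5, 5, 101), np.linspace(-5, 5, 101))
toy_scores = np.tanh(nbx)                       # hypothetical interaction score
toy_weights = np.exp(-(nbx**2 + nby**2) / 4.0)  # hypothetical aggregation weight
combined = product_map(toy_scores, toy_weights)

print(logit_to_probability(0.0))  # 0.5
print(combined.shape)             # (101, 101)
```

In this toy, neighbours far from the focal receive small weights, so the product fades out with distance even where the score itself stays large.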

Here is an example of how to use SocialNet to train and test a model with trajectory files. In this example, we download some sample trajectory files from a Google Drive data repository, train a model, test it, and plot the results.

Example of training, testing and plotting SocialNet
from pathlib import Path

import gdown
from socialnet import model_test, model_train
from socialnet.plot import (
    plot_aggregation_subnetwork,
    plot_interaction_scores,
    plot_interaction_subnetwork,
    plot_product,
)

# Sample trajectory files from the Google Drive folder
# https://drive.google.com/drive/folders/1VH97_bNFz09Ke_kBL1oV2HbTS25BwnKC
# (output filename -> Google Drive file id)
sample_files = {
    "zebrafish_60_1.npy": "1y1ZhNr3eWbhYwA_ZPfIs9UjsinszKCwp",
    "zebrafish_60_2.npy": "1aJb2pgzJE8dhDkWVGKWdBgYgFRhpWkj7",
    "zebrafish_60_3.npy": "1LaNIqFGD5N9SUIq0yfIGEZUnrzF3TgLy",
}
for output, file_id in sample_files.items():
    # resume=True continues an interrupted download instead of restarting it
    gdown.download(id=file_id, output=output, resume=True)

trajectory_files = list(sample_files)

# Train a model; its outputs go to ./socialnet_session_<session_name>
results_dict = model_train(trajectory_files, session_name="example")
expected_output_folder = Path.cwd() / "socialnet_session_example"

# Evaluate the trained model (here on the same files used for training)
test_results = model_test(
    trajectory_files=trajectory_files,
    model_folder=expected_output_folder,
)

# plot slices of the neighbor x and y coordinates for different values of the focal velocity and the neighbor velocity
# fv = focal velocity
# nba = neighbour acceleration
# nbv = neighbour velocity
# nbx = neighbour x position
# nby = neighbour y position
fig_vars = ("fv", "nbv", "nbx", "nby")  # row_var, col_var, x_var, y_var

# Each plotting function takes the trained session folder
plot_aggregation_subnetwork(expected_output_folder, fig_vars=fig_vars)
plot_interaction_subnetwork(expected_output_folder)
plot_interaction_scores(expected_output_folder, fig_vars=fig_vars)
plot_product(expected_output_folder, fig_vars=fig_vars)
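
The example above tests the model on the same files it was trained on, which measures how well it fits those recordings rather than how well it generalizes. A minimal sketch of a leave-one-file-out evaluation follows. It relies only on the model_train and model_test calls shown in the example; the per-fold session names, the helper names leave_one_out_splits and run_leave_one_out, and the assumption that each run writes to socialnet_session_<session_name> in the working directory (inferred from the example's expected_output_folder) are all hypothetical.

```python
from pathlib import Path


def leave_one_out_splits(files):
    """Yield (train_files, test_file) pairs, holding out each file once."""
    files = list(files)
    for i, held_out in enumerate(files):
        yield files[:i] + files[i + 1 :], held_out


def run_leave_one_out(files):
    """Train one model per held-out file and test it on that file."""
    # Imported here so the split helper can be used without SocialNet installed
    from socialnet import model_test, model_train

    results = {}
    for k, (train_files, test_file) in enumerate(leave_one_out_splits(files)):
        session_name = f"fold_{k}"
        model_train(train_files, session_name=session_name)
        # ASSUMPTION: the output folder follows the socialnet_session_<name> pattern
        model_folder = Path.cwd() / f"socialnet_session_{session_name}"
        results[test_file] = model_test(
            trajectory_files=[test_file], model_folder=model_folder
        )
    return results


if __name__ == "__main__":
    files = ["zebrafish_60_1.npy", "zebrafish_60_2.npy", "zebrafish_60_3.npy"]
    for train_files, test_file in leave_one_out_splits(files):
        print("train on", train_files, "-> test on", test_file)
```

After downloading the sample files, calling run_leave_one_out(files) would train three models, each tested on the recording it never saw.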

Alternative SocialNet CLI

SocialNet also has a command line interface (CLI). This provides a simple way to interact with the SocialNet API without needing to write Python code.

Main CLI commands
socialnet train --help
socialnet test --help
Plotting CLI commands
socialnet plot --help
socialnet plot aggregation_subnetwork --help
socialnet plot interaction_subnetwork --help
socialnet plot interaction_scores --help
socialnet plot product --help
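
If you want to read all of these help pages in one go, for example to keep them alongside a project, the commands listed above can be scripted with Python's subprocess module. This sketch assumes the socialnet executable is on your PATH once the package is installed; it only runs the --help commands shown above, and the names HELP_COMMANDS and collect_help are hypothetical.

```python
import shutil
import subprocess

HELP_COMMANDS = [
    ["socialnet", "train", "--help"],
    ["socialnet", "test", "--help"],
    ["socialnet", "plot", "--help"],
    ["socialnet", "plot", "aggregation_subnetwork", "--help"],
    ["socialnet", "plot", "interaction_subnetwork", "--help"],
    ["socialnet", "plot", "interaction_scores", "--help"],
    ["socialnet", "plot", "product", "--help"],
]


def collect_help(commands=HELP_COMMANDS):
    """Run each help command and return {command string: captured output}."""
    if not commands or shutil.which(commands[0][0]) is None:
        return {}  # nothing to run, or the executable is not on PATH
    pages = {}
    for cmd in commands:
        completed = subprocess.run(cmd, capture_output=True, text=True, check=False)
        pages[" ".join(cmd)] = completed.stdout or completed.stderr
    return pages


if __name__ == "__main__":
    pages = collect_help()
    if not pages:
        print("socialnet executable not found; is the environment activated?")
    for command, text in pages.items():
        print(f"$ {command}\n{text}")
```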

References