In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
import socialforce

_ = torch.manual_seed(42)

# 1+1D


## Parametric

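The potential $V(b, d_{\perp})$ is modeled as the product of two 1D functions: the Social Force potential $\textrm{SF}(b)$ and a linear factor in the perpendicular distance $d_{\perp}$ that is clipped at zero: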
\begin{align}
    V(b, d_{\perp}) &= \textrm{SF}(b) \cdot \max(0, 1 + a d_{\perp})
\end{align}
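where $a$ is the `asymmetry` parameter of the `PedPedPotential2D` constructor. With `asymmetry=-1.0`, as used below, the factor becomes $1 - d_{\perp}$: it decreases linearly and vanishes for $d_{\perp} \geq 1$, which makes the potential asymmetric in $d_{\perp}$.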
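As a minimal sketch, the product above can be written directly in PyTorch. It assumes the exponential Social Force form $\textrm{SF}(b) = V^0 e^{-b/\sigma}$ with $V^0 = 2.1$ and $\sigma = 0.3$ (the values from Helbing and Molnár's original model). The helpers `sf` and `parametric_potential` are hypothetical and are not the library's `PedPedPotential2D` implementation. Autograd then gives $\partial V / \partial d_{\perp}$, which equals $a \cdot \textrm{SF}(b)$ where the clip is inactive and zero where it is active.

```python
import torch

# Assumed Social Force parameters (Helbing and Molnar values); not read from the library.
V0 = 2.1
SIGMA = 0.3


def sf(b: torch.Tensor) -> torch.Tensor:
    """1D Social Force potential SF(b) = V0 * exp(-b / sigma) (assumed form)."""
    return V0 * torch.exp(-b / SIGMA)


def parametric_potential(b: torch.Tensor, d_perp: torch.Tensor, asymmetry: float) -> torch.Tensor:
    """Hypothetical helper: V(b, d_perp) = SF(b) * max(0, 1 + a * d_perp)."""
    return sf(b) * torch.clamp(1.0 + asymmetry * d_perp, min=0.0)


# Evaluate at three perpendicular distances for a fixed b, with a = -1 as in the notebook.
b = torch.full((3,), 0.3)
d_perp = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
v = parametric_potential(b, d_perp, asymmetry=-1.0)

# dV/d(d_perp) via autograd: a * SF(b) where 1 + a * d_perp > 0, else 0.
(grad_d,) = torch.autograd.grad(v.sum(), d_perp)
print(v.detach())
print(grad_d)
```

At $d_{\perp} = 2$ the factor $1 - 2$ is clipped to zero, so both the potential and its gradient vanish there.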

In [None]:
V = socialforce.potentials.PedPedPotential2D(asymmetry=-1.0)
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as (ax1, ax2):
    socialforce.show.potential_2d(V, ax1)
    socialforce.show.potential_2d_grad(V, ax2)

## Scenarios

Here we use a combination of synthetic {ref}`Circle and ParallelOvertake scenarios <scenarios>`: five scenes from each, simulated with the parametric potential `V` defined above.

In [None]:
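# generate 5 scenes from each scenario, simulated with the parametric potential V
circle = socialforce.scenarios.Circle(ped_ped=V)
parallel = socialforce.scenarios.ParallelOvertake(ped_ped=V)
scenarios = circle.generate(5) + parallel.generate(5)
# convert the generated scenes into the experience used as training data
true_experience = socialforce.Trainer.scenes_to_experience(scenarios)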

## MLP

Next, we create a model for pedestrian-pedestrian interaction that is the
product of two 1D potentials: one as a function of $b$ and another as a
function of the perpendicular distance $d_{\perp}$. Each factor is a small MLP
initialized with random weights and biases:

\begin{align}
    V(b, d_{\perp}) &= \textrm{MLP}_b(b) \cdot \textrm{MLP}_{\perp}(d_{\perp}) \;\;\; .
\end{align}
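A minimal sketch of such a product model in plain PyTorch follows. The class `ProductPotential1p1D`, the helper `mlp_1d`, the hidden width and the `Softplus` activation are all hypothetical choices for illustration; this is not the library's `PedPedPotentialMLP1p1D`. The check at the end uses a property every product form must satisfy, $V(b_1, d_1)\,V(b_2, d_2) = V(b_1, d_2)\,V(b_2, d_1)$, and confirms that both MLPs receive gradients.

```python
import torch
from torch import nn


def mlp_1d(hidden: int = 16) -> nn.Sequential:
    """Small scalar-to-scalar MLP; PyTorch initializes weights and biases randomly."""
    return nn.Sequential(nn.Linear(1, hidden), nn.Softplus(), nn.Linear(hidden, 1))


class ProductPotential1p1D(nn.Module):
    """Hypothetical sketch of V(b, d_perp) = MLP_b(b) * MLP_perp(d_perp)."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.mlp_b = mlp_1d(hidden)
        self.mlp_perp = mlp_1d(hidden)

    def forward(self, b: torch.Tensor, d_perp: torch.Tensor) -> torch.Tensor:
        # Inputs are 1D tensors of equal length; add a feature axis for nn.Linear.
        v_b = self.mlp_b(b.unsqueeze(-1)).squeeze(-1)
        v_perp = self.mlp_perp(d_perp.unsqueeze(-1)).squeeze(-1)
        return v_b * v_perp


torch.manual_seed(0)
model = ProductPotential1p1D()
b = torch.tensor([0.2, 1.0])
d_perp = torch.tensor([-0.5, 0.5])

# Separability check: V(b1, d1) * V(b2, d2) == V(b1, d2) * V(b2, d1) for any product form.
v_diag = model(b, d_perp)
v_swap = model(b, d_perp.flip(0))
print(v_diag[0] * v_diag[1], v_swap[0] * v_swap[1])

# Both factors receive gradients, so an optimizer can update both MLPs.
v_diag.sum().backward()
```

Restricting the model to a product keeps each factor a function of one variable, which is what makes the two 1D factors easy to plot and inspect separately.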

In [None]:
V = socialforce.potentials.PedPedPotentialMLP1p1D()
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as (ax1, ax2):
    socialforce.show.potential_2d(V, ax1)
    socialforce.show.potential_2d_grad(V, ax2)

## Inference

Next, we train the ped-ped interaction model on the synthetic data created
above, using PyTorch's standard SGD optimizer (`torch.optim.SGD`) with a
learning rate of 1.0.
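`Trainer` wraps the optimizer updates. To show the underlying PyTorch pattern (`zero_grad`, `backward`, `step`) in isolation, here is a hypothetical toy loop that fits a single 1D MLP to the assumed curve $\textrm{SF}(b) = 2.1\, e^{-b/0.3}$ by full-batch SGD on a mean-squared error. The regression target, the loss and the learning rate are illustrative assumptions, not necessarily what `Trainer` optimizes internally.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy target: the assumed 1D Social Force curve SF(b) = 2.1 * exp(-b / 0.3).
b = torch.linspace(0.0, 2.0, 64).unsqueeze(-1)
target = 2.1 * torch.exp(-b / 0.3)

model = nn.Sequential(nn.Linear(1, 16), nn.Softplus(), nn.Linear(16, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

losses = []
for step in range(500):
    opt.zero_grad()                   # clear gradients from the previous step
    loss = loss_fn(model(b), target)  # full-batch mean-squared error
    loss.backward()                   # populate .grad on every parameter
    opt.step()                        # plain SGD update: p <- p - lr * p.grad
    losses.append(loss.item())

print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

The small learning rate suits this toy regression; it is not a recommendation for the `lr=1.0` used with `Trainer`.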

In [None]:
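# simulator that uses the MLP potential V for ped-ped interactions
simulator = socialforce.Simulator(ped_ped=V)
opt = torch.optim.SGD(V.parameters(), lr=1.0)
# train on the synthetic experience; log_interval sets how often progress is logged
socialforce.Trainer(simulator, opt).loop(20, true_experience, log_interval=5)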

In [None]:
with socialforce.show.canvas(figsize=(12, 6), ncols=2) as (ax1, ax2):
    socialforce.show.potential_2d(V, ax1)
    socialforce.show.potential_2d_grad(V, ax2)