12 lines
239 B
Python
12 lines
239 B
Python
"""Helper modules to build our networks."""
|
|
import torch as th
|
|
|
|
|
|
class Flatten(th.nn.Module):
    """Collapse every dimension except the first (batch) one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    A 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``(batch_size, num_features)``.

        Args:
            x: tensor with at least one dimension; dim 0 is kept as-is.

        Returns:
            A 2-D tensor. It is a view of ``x`` when the memory layout allows
            it; otherwise it is a copy (e.g. for non-contiguous inputs).
        """
        batch_size = x.shape[0]
        # Size the trailing dims explicitly instead of passing -1. With -1,
        # an empty batch (batch_size == 0) is ambiguous and raises.
        # torch.Size([]).numel() == 1, so 1-D inputs still map to (N, 1).
        num_features = x.shape[1:].numel()
        # reshape (unlike view) also works on non-contiguous tensors,
        # e.g. the output of transpose/permute.
        return x.reshape(batch_size, num_features)
|