
Improving Hypernetwork initialization #2740

@danielalcalde

Description

This question/issue was first raised in:

> I have a question: why is the hypernetwork just two linear layers without an activation?

Originally posted by @vexilligera in #2284 (comment)

I also asked myself the same question today while reviewing the code, and it makes no sense to me. I get the same or similar results using just one linear layer, with 4x less space needed for the .pt files.

Right now it is implemented like this:

```python
self.linear1 = torch.nn.Linear(dim, dim * 2)
self.linear2 = torch.nn.Linear(dim * 2, dim)
```
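For reference, two linear layers with no activation between them compose into a single affine map, so a single linear layer covers the same function class with roughly a quarter of the parameters. A quick check (a sketch; `dim = 768` is just an illustrative context dimension):

```python
import torch

dim = 768  # illustrative context dimension

# Current structure: two stacked linear layers, no activation in between
linear1 = torch.nn.Linear(dim, dim * 2)
linear2 = torch.nn.Linear(dim * 2, dim)
two_linear_params = sum(p.numel() for p in (*linear1.parameters(), *linear2.parameters()))

# A single linear layer expresses the same function class
# (the composition of two affine maps is itself an affine map)
single = torch.nn.Linear(dim, dim)
single_params = sum(p.numel() for p in single.parameters())

print(two_linear_params / single_params)  # ≈ 4x more parameters in the current version
```

This matches the roughly 4x smaller .pt files observed above.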

but if one wants to keep the two-linear structure, we could take advantage of it to reduce the optimization space by doing something like this:

```python
self.linear1 = torch.nn.Linear(dim, dim // n)
self.linear2 = torch.nn.Linear(dim // n, dim)
```

where n >= 2.

This should also make it easier to combine hypernetworks, since we would be working in a subspace of the context.
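To put a number on the reduction, a sketch with illustrative values (`dim = 768`, `n = 4` are assumptions, not values from the repo):

```python
import torch

dim, n = 768, 4  # illustrative values

# Proposed bottleneck: project down to dim // n, then back up
down = torch.nn.Linear(dim, dim // n)
up = torch.nn.Linear(dim // n, dim)
bottleneck_params = sum(p.numel() for p in (*down.parameters(), *up.parameters()))

# Current expanding structure, for comparison
wide1 = torch.nn.Linear(dim, dim * 2)
wide2 = torch.nn.Linear(dim * 2, dim)
wide_params = sum(p.numel() for p in (*wide1.parameters(), *wide2.parameters()))

print(wide_params / bottleneck_params)  # ≈ 8x fewer parameters with n = 4
```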

And while we are at it, I would also suggest a Xavier-style initialization:

std = 0.01 / sqrt(dim)

Specifically, I suggest either removing the double linear and having:

```python
import math

p = 0.01
self.linear = torch.nn.Linear(dim, dim)
self.linear.weight.data.normal_(mean=0.0, std=p / math.sqrt(dim))
```
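The scale `p / sqrt(dim)` makes the layer's output on unit-variance input have std ≈ p, so a residually applied hypernetwork starts out as a small perturbation. A sanity check (a sketch; the concrete `dim` value and the zeroed bias are assumptions, not part of the snippet above):

```python
import math
import torch

dim, p = 768, 0.01  # illustrative values

linear = torch.nn.Linear(dim, dim)
linear.weight.data.normal_(mean=0.0, std=p / math.sqrt(dim))
linear.bias.data.zero_()  # assumption: zero bias, so only the weight scale matters

# Each output is a sum of dim terms, each with variance p**2 / dim,
# so the output std on unit-variance input is approximately p
x = torch.randn(4096, dim)
print(linear(x).std())  # ≈ 0.01
```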

or, for n >= 2:

```python
import math

self.linear1 = torch.nn.Linear(dim, dim // n, bias=False)
self.linear2 = torch.nn.Linear(dim // n, dim)

std = math.sqrt(p) * math.sqrt(2 / (dim + dim / n))
self.linear1.weight.data.normal_(mean=0.0, std=std)
self.linear2.weight.data.normal_(mean=0.0, std=std)
```
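A quick empirical check of the second variant (a sketch; `dim`, `n`, and `p` are illustrative values, and `linear2`'s bias keeps PyTorch's default initialization, as in the snippet above):

```python
import math
import torch

dim, n, p = 768, 4, 0.01  # illustrative values

linear1 = torch.nn.Linear(dim, dim // n, bias=False)
linear2 = torch.nn.Linear(dim // n, dim)

std = math.sqrt(p) * math.sqrt(2 / (dim + dim / n))
linear1.weight.data.normal_(mean=0.0, std=std)
linear2.weight.data.normal_(mean=0.0, std=std)

# On unit-variance input the composed map's output is small, so a
# hypernetwork added residually starts close to the identity
x = torch.randn(1024, dim)
print(linear2(linear1(x)).std())  # small relative to the input std of ~1
```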

Labels: enhancement (New feature or request)