Rohan Koodli
Mentioned by @jjhong922 in #1608: we need a JAX equivalent of the PyTorch `FCLayers` to reduce redundancy in FlaxEncoder/Decoder.
Error when using Flax LayerNorm in the Decoder in place of BatchNorm for JaxPEAKVI

## Code
```python
class FlaxDecoder(nn.Module):
    n_input: int
    dropout_rate: float
    n_hidden: int

    def setup(self):
        self.dense1 = Dense(self.n_hidden)...
```
When pytest fixtures are used to create simple SCVI models in `test_model`, the runtime of the suite increases from 4 to 14 seconds, as reported in #1576.
## CNN
Use BEAR notation as additional structural feature

## SAP
For comparisons, use BEAR instead of dot-bracket notation
- Needs to work for moves with multiple base changes in one move
- Figure out a way to encode number of moves needed to complete the puzzle
- On 1 GPU, limited to 10 convolutional layers
- When parallelized across multiple GPUs, can add more layers as more memory available
Currently `structure_and_energy_at_current_time` works with only one puzzle ID. Supporting multiple puzzle IDs would reduce the number of pickles and the time spent unpickling during training.
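One possible way to batch several puzzle IDs into a single pickle is sketched below. The real signature of `structure_and_energy_at_current_time` is not shown here, so the sketch assumes a callable that takes one puzzle ID; `collect_for_puzzles` and `load_results` are hypothetical helper names.

```python
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Iterable


def collect_for_puzzles(
    per_puzzle_fn: Callable[[int], Any],
    puzzle_ids: Iterable[int],
    out_path: Path,
) -> Dict[int, Any]:
    """Run a single-puzzle function over many IDs and store all results in one pickle."""
    # per_puzzle_fn stands in for the existing single-ID function (assumed interface).
    results = {pid: per_puzzle_fn(pid) for pid in puzzle_ids}
    with open(out_path, "wb") as f:
        pickle.dump(results, f)
    return results


def load_results(path: Path) -> Dict[int, Any]:
    """Load the combined results with a single unpickle call."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Training code would then call `load_results` once and index the returned dictionary by puzzle ID, rather than opening one file per puzzle.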