feat(onnx): ViT zero-shot tasks #858

@QIN2DIM

Description

Intro

See the example code for details.

The CLIP multimodal model enables zero-shot image classification. I have tested it on multiple datasets: as long as an appropriate prompt is provided, the model is over 99.9% accurate.

We just need to write `positive_labels` and `negative_labels` based on the prompt keywords of known challenges (`image_binary_challenge`). If the program encounters a prompt that has never been processed before, it automatically converts and adjusts it for the binary classification task.

We reimplemented the processing module with NumPy, i.e., the pipeline no longer needs to rely on PyTorch.
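As a rough illustration of what a NumPy-only CLIP scoring step looks like, here is a minimal sketch. The function name and the `image_features`/`text_features` arrays are illustrative assumptions, not the project's actual API; it only shows the standard normalize–dot-product–softmax pattern:

```python
import numpy as np

def zero_shot_probs(image_features: np.ndarray, text_features: np.ndarray) -> np.ndarray:
    # L2-normalize both embedding sets so the dot product is a cosine similarity
    img = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    txt = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
    # Cosine similarity scaled by CLIP's logit scale (100 for the released checkpoints)
    logits = 100.0 * img @ txt.T
    # Numerically stable softmax over the label axis
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```

Given one image embedding and one text embedding per candidate label, the returned row is a probability distribution over the labels.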

By default, we use the RN50.openai variant of the model for classification tasks. We encapsulate both the ONNX branch and the ViT Transformers Pipeline branch, and the program switches between them automatically: if both torch and transformers are installed in your runtime environment and a CUDA GPU is available, the Transformers pipeline is used; otherwise it falls back to ONNX running on the CPU.
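The switching rule above can be sketched as follows. The branch names returned here are illustrative placeholders, not identifiers from the project:

```python
import importlib.util

def pick_inference_branch() -> str:
    """Mirror of the rule above: Transformers pipeline on CUDA, otherwise ONNX on CPU."""
    if importlib.util.find_spec("torch") and importlib.util.find_spec("transformers"):
        import torch
        # Both libraries are importable; still require a CUDA device
        if torch.cuda.is_available():
            return "transformers-pipeline"
    return "onnx-cpu"
```

`importlib.util.find_spec` lets us probe for optional dependencies without raising `ImportError` at module load time.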

DEFAULT_CLIP_VISUAL_MODEL: str = "visual_CLIP_RN50.openai.onnx"
DEFAULT_CLIP_TEXTUAL_MODEL: str = "textual_CLIP_RN50.openai.onnx"
"""
Available Models
--- 1180+ MiB
DEFAULT_CLIP_VISUAL_MODEL: str = "visual_CLIP_ViT-B-32.openai.onnx"
DEFAULT_CLIP_TEXTUAL_MODEL: str = "textual_CLIP_ViT-B-32.openai.onnx"
--- 658.3 MiB
DEFAULT_CLIP_VISUAL_MODEL: str = "visual_CLIP_RN50.openai.onnx"
DEFAULT_CLIP_TEXTUAL_MODEL: str = "textual_CLIP_RN50.openai.onnx"
--- 3300+ MiB
DEFAULT_CLIP_VISUAL_MODEL: str = "visual_CLIP-ViT-L-14-DataComp.XL-s13B-b90K.onnx"
DEFAULT_CLIP_TEXTUAL_MODEL: str = "textual_CLIP-ViT-L-14-DataComp.XL-s13B-b90K.onnx"
"""

DEMO

[diagram: datalake_post.drawio]

"""
1. **positive_labels** can contain just the split prompt keyword, i.e., the object the prompt asks for

2. **negative_labels** usually have multiple categories,

please observe the other labels that appear across the 9 images and fill in their label_name

3. **positive_labels** can contain more than one entry when the prompt is ambiguous.

   For example, if the prompt asks to select a `vehicle`, but both `car` and `airplane` appear in the task,

   you can fill it in like this: `positive_labels = ["vehicle", "car", "airplane"]`

4. Sometimes the prompt doesn't change, but its corresponding image group is replaced.
   If you observe this, update your `datalake_post` accordingly!

5. If a prompt has never appeared, i.e. you did not add it to the datalake, the program automatically splits the prompt
and adds simple antonyms to the mapping network to ensure that the binary classification task proceeds properly.

   This fallback sometimes works, but its accuracy is clearly lower than filling in the labels manually
"""
from hcaptcha_challenger import split_prompt_message, label_cleaning, DataLake


def handle(x): return split_prompt_message(label_cleaning(x), "en")


datalake_post = {
    # --> off-road vehicle
    handle("Please click each image containing an off-road vehicle"): {
        "positive_labels": ["off-road vehicle"],
        "negative_labels": ["car", "bicycle"],
    },
    # --> pair of headphones
    handle("Please click each image containing a pair of headphones"): {
        "positive_labels": ["headphones"],
        "negative_labels": ["car", "elephant", "cat"]
    },
    # --> item of office equipment
    handle("Please click each image containing an item of office equipment"): {
        "positive_labels": ["office equipment", "chair"],
        "negative_labels": ["shoes", "guitar", "drum", "musical instruments"]
    }
}


def common():
    from hcaptcha_challenger import ModelHub

    # ... the operations you are already familiar with

    modelhub = ModelHub.from_github_repo()
    modelhub.parse_objects()

    print(f"Before {modelhub.datalake.keys()=}")

    # Merge the data. And use this modelhub object later
    for prompt, serialized_binary in datalake_post.items():
        modelhub.datalake[prompt] = DataLake.from_serialized(serialized_binary)

    print(f"After {modelhub.datalake.keys()=}\n")

    for prompt, dl in modelhub.datalake.items():
        print(f"{prompt=}")
        print(f"{dl=}\n")

    # ... the operations you are already familiar with


if __name__ == '__main__':
    common()
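The point-5 fallback described above could be sketched roughly like this. This is a hypothetical helper, not the project's actual implementation; it only illustrates pairing the split prompt keyword with simple antonyms so an unseen prompt can still drive a binary task:

```python
def auto_binary_labels(prompt_keyword: str) -> dict:
    # Hypothetical fallback: when a prompt is missing from the datalake,
    # treat the split keyword as the positive class and synthesize crude
    # negatives so the binary classifier has two sides to compare.
    return {
        "positive_labels": [prompt_keyword],
        "negative_labels": [f"no {prompt_keyword}", "something else"],
    }
```

As the description notes, this automatic route works sometimes, but hand-written labels remain noticeably more accurate.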

datalake:
  furniture:
    positive_labels:
      - furniture
    negative_labels:
      - headphones
      - guitar
      - game tool
      - keyboard
  off-road vehicle:
    positive_labels:
      - off-road vehicle
    negative_labels:
      - car
      - bicycle
  pair of headphones:
    positive_labels:
      - pair of headphones
    negative_labels:
      - elephant
      - car
      - cat

Labels

feature (new feature or requirement) · fixed (bug fixed or issue resolved) · 🦜 blog