Skip to content

Fix/refactor dynamo ipex backend#93863

Closed
jansel wants to merge 9 commits into gh/jansel/37/base from
gh/jansel/37/head
Closed

Fix/refactor dynamo ipex backend#93863
jansel wants to merge 9 commits into gh/jansel/37/base from
gh/jansel/37/head

Conversation

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Feb 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93863

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 33919d0:

BROKEN TRUNK - The following jobs failed but were present on the merge base 6650aac:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 1, 2023

In order to get IPEX working with PyTorch nightlies I needed to patch a bunch of stuff. I just commented out or deleted code in two places when I hit build errors (mostly symint related), since I just wanted to test that the dynamo bindings worked.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a1ddfefb..f261d390 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,10 +23,6 @@ if(NOT EXISTS ${TORCH_INSTALL_PREFIX})
   message(FATAL_ERROR "Can NOT find torch install path at ${TORCH_INSTALL_PREFIX}!")
 endif()
 
-if(NOT ${Torch_COMP_VERION} VERSION_EQUAL "${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}")
-  message(FATAL_ERROR "Not compatible Torch version ${Torch_VERSION} at ${TORCH_INSTALL_PREFIX}!\nTorch ${Torch_COMP_VERION} is needed!")
-endif()
-
 include(${IPEX_ROOT_DIR}/cmake/Options.cmake)
 include(${IPEX_ROOT_DIR}/cmake/BuildFlags.cmake)
 
diff --git a/csrc/cpu/aten/TensorShape.cpp b/csrc/cpu/aten/TensorShape.cpp
index f293ba2c..bebfc816 100644
--- a/csrc/cpu/aten/TensorShape.cpp
+++ b/csrc/cpu/aten/TensorShape.cpp
@@ -63,6 +63,7 @@ void resize_out(
   // strides from the meta function and directly use the output tensor's
   // preexisting strides
   if (resized) {
+    /*
     if (!strides.empty()) {
       TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
       at::native::as_strided_(out, sizes, strides);
@@ -70,6 +71,7 @@ void resize_out(
       out.unsafeGetTensorImpl()->empty_tensor_restride(
           *options.memory_format_opt());
     }
+    */
   }
 }
 
@@ -303,4 +305,4 @@ IPEX_TORCH_LIBRARY_IMPL(aten, CPU, m) {
 }
 
 } // namespace cpu
-} // namespace torch_ipex
\ No newline at end of file
+} // namespace torch_ipex
diff --git a/csrc/cpu/aten/kernels/EmbeddingBagKrnl.cpp b/csrc/cpu/aten/kernels/EmbeddingBagKrnl.cpp
index 1794e205..bf20ab91 100644
--- a/csrc/cpu/aten/kernels/EmbeddingBagKrnl.cpp
+++ b/csrc/cpu/aten/kernels/EmbeddingBagKrnl.cpp
@@ -123,15 +123,7 @@ static inline at::Tensor _sparse_coo_tensor_unsafe(
   assert(options.has_layout() && options.layout() == c10::kSparse);
   int64_t sparse_dim = indices.size(0);
   int64_t dense_dim = values.dim() - 1;
-  return at::native::new_with_dims_and_tensor_sparse(
-      sparse_dim,
-      dense_dim,
-      size,
-      indices,
-      values,
-      values.scalar_type(),
-      c10::kSparse,
-      values.device());
+  return at::empty({1});  // HACK to make PT2 build work
 }
 
 template <typename T>
diff --git a/csrc/jit/initialization.cpp b/csrc/jit/initialization.cpp
index 6c09a75f..b6f91abe 100644
--- a/csrc/jit/initialization.cpp
+++ b/csrc/jit/initialization.cpp
@@ -47,17 +47,6 @@ void InitIPEX::check_pytorch_version() {
       IPEX_VERSION_MINOR = std::stoi(match[2]);
     }
   }
-  if (IPEX_VERSION_MAJOR != TORCH_VERSION_MAJOR ||
-      IPEX_VERSION_MINOR != TORCH_VERSION_MINOR) {
-    printf(
-        "ERROR! Intel® Extension for PyTorch* needs to work with PyTorch/libtorch %d.%d.*, but PyTorch/libtorch %d.%d.%d is found. Please switch to the matching version and run again.\n",
-        IPEX_VERSION_MAJOR,
-        IPEX_VERSION_MINOR,
-        TORCH_VERSION_MAJOR,
-        TORCH_VERSION_MINOR,
-        TORCH_VERSION_PATCH);
-    exit(127);
-  }
 }
 
 } // namespace torch_ipex
diff --git a/intel_extension_for_pytorch/__init__.py b/intel_extension_for_pytorch/__init__.py
index 3a2a0ba9..c3581efc 100644
--- a/intel_extension_for_pytorch/__init__.py
+++ b/intel_extension_for_pytorch/__init__.py
@@ -17,9 +17,6 @@ if matches and len(matches.groups()) == 1:
 matches = re.match('(\d+\.\d+).*', __version__)
 if matches and len(matches.groups()) == 1:
   ipex_version = matches.group(1)
-if torch_version == '' or ipex_version == '' or torch_version != ipex_version:
-  print('ERROR! Intel® Extension for PyTorch* needs to work with PyTorch {0}.*, but PyTorch {1} is found. Please switch to the matching version and run again.'.format(ipex_version, torch.__version__))
-  exit(127)
 
 from . import cpu
 from . import quantization
@@ -32,4 +29,4 @@ from .cpu._auto_kernel_selection import _enable_dnnl, _disable_dnnl, _using_dnnl
 
 # for xpu
 import intel_extension_for_pytorch.xpu
-from . import optim
\ No newline at end of file
+from . import optim

@jansel jansel added topic: not user facing topic category and removed release notes: onnx torch.onnx related changes that should show up in the release notes labels Feb 1, 2023
cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@jansel jansel requested a review from desertfire February 2, 2023 05:48
@desertfire
Copy link
Copy Markdown
Contributor

In order to get IPEX working with PyTorch nightlies I needed to patch a bunch of stuff. I just commented out or deleted code in two places when I hit build errors (mostly symint related), since I just wanted to test that the dynamo bindings worked.

Not sure I understand. Wouldn't you need to include those patches to pass some nightly tests?

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 2, 2023

Not sure I understand. Wouldn't you need to include those patches to pass some nightly tests?

Our CI runners don't have IPEX installed.

Those patches would need to go in the IPEX repo, not the PyTorch one. I assume the Intel folks plan to update IPEX in the future to support the latest PyTorch.

Copy link
Copy Markdown
Contributor

@desertfire desertfire left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamp to unblock, but you might want to sync with @jgong5 before landing this.

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@jgong5
Copy link
Copy Markdown
Collaborator

jgong5 commented Feb 3, 2023

In order to get IPEX working with PyTorch nightlies I needed to patch a bunch of stuff. I just commented out or deleted code in two places when I hit build errors (mostly symint related), since I just wanted to test that the dynamo bindings worked.

We will do a rebase of IPEX to make sure IPEX master can work with PyTorch nighties ASAP. cc @zhuhaozhe who is working on it.

return model


def ipex_fp32(gm: torch.fx.GraphModule, example_inputs):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jansel : @jiayisunx had a PR #92067 to combine ipex_fp32 and ipex_bf16 into a unified "ipex" backend which simplifies the usage. Do you prefer to have this PR landed before @jiayisunx 's PR or incorporate her changes in this PR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can go ahead and land that one. I will rebase onto it and apply changes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied that PR to #94028 (right above this in the stack). Feel free to land the other one which will make #94028 a no-op, which I can close.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#92067 has been landed, thanks!

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@jansel jansel added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 3, 2023
@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 3, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/jansel/37/head branch June 8, 2023 17:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants