diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index c9d9a8e7b45f717f30e124eff721e3b75130ff5c..5c97265c1c6e813dedd0784f6bde3f12472caf49 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -766,15 +766,13 @@ static inline int memcg_cache_id(struct mem_cgroup *memcg)
 	return memcg ? memcg->kmemcg_id : -1;
 }
 
-struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep);
+struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp);
 void __memcg_kmem_put_cache(struct kmem_cache *cachep);
 
-static inline bool __memcg_kmem_bypass(gfp_t gfp)
+static inline bool __memcg_kmem_bypass(void)
 {
 	if (!memcg_kmem_enabled())
 		return true;
-	if (!(gfp & __GFP_ACCOUNT))
-		return true;
 	if (in_interrupt() || (!current->mm) || (current->flags & PF_KTHREAD))
 		return true;
 	return false;
@@ -791,7 +789,9 @@ static inline bool __memcg_kmem_bypass(gfp_t gfp)
 static __always_inline int memcg_kmem_charge(struct page *page,
 					     gfp_t gfp, int order)
 {
-	if (__memcg_kmem_bypass(gfp))
+	if (__memcg_kmem_bypass())
+		return 0;
+	if (!(gfp & __GFP_ACCOUNT))
 		return 0;
 	return __memcg_kmem_charge(page, gfp, order);
 }
@@ -810,16 +810,15 @@ static __always_inline void memcg_kmem_uncharge(struct page *page, int order)
 /**
  * memcg_kmem_get_cache: selects the correct per-memcg cache for allocation
  * @cachep: the original global kmem cache
- * @gfp: allocation flags.
  *
  * All memory allocated from a per-memcg cache is charged to the owner memcg.
  */
 static __always_inline struct kmem_cache *
 memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp)
 {
-	if (__memcg_kmem_bypass(gfp))
+	if (__memcg_kmem_bypass())
 		return cachep;
-	return __memcg_kmem_get_cache(cachep);
+	return __memcg_kmem_get_cache(cachep, gfp);
 }
 
 static __always_inline void memcg_kmem_put_cache(struct kmem_cache *cachep)
diff --git a/include/linux/slab.h b/include/linux/slab.h
index 2037a861e3679910152a98ba98a667b395c6773c..3ffee74220126d9b6f52df7b600a6de33f29dbb1 100644
--- a/include/linux/slab.h
+++ b/include/linux/slab.h
@@ -86,6 +86,11 @@
 #else
 # define SLAB_FAILSLAB		0x00000000UL
 #endif
+#ifdef CONFIG_MEMCG_KMEM
+# define SLAB_ACCOUNT		0x04000000UL	/* Account to memcg */
+#else
+# define SLAB_ACCOUNT		0x00000000UL
+#endif
 
 /* The following flags affect the page allocator grouping pages by mobility */
 #define SLAB_RECLAIM_ACCOUNT	0x00020000UL		/* Objects are reclaimable */
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 14cb1db4c52b75ee16d1346e89e8bd6047d3db9f..4bd6c451339314502f13bc8d8c9fd93874c65af7 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2356,7 +2356,7 @@ static void memcg_schedule_kmem_cache_create(struct mem_cgroup *memcg,
  * Can't be called in interrupt context or from kernel threads.
  * This function needs to be called with rcu_read_lock() held.
  */
-struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep)
+struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp)
 {
 	struct mem_cgroup *memcg;
 	struct kmem_cache *memcg_cachep;
@@ -2364,6 +2364,12 @@ struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep)
 
 	VM_BUG_ON(!is_root_cache(cachep));
 
+	if (cachep->flags & SLAB_ACCOUNT)
+		gfp |= __GFP_ACCOUNT;
+
+	if (!(gfp & __GFP_ACCOUNT))
+		return cachep;
+
 	if (current->memcg_kmem_skip_account)
 		return cachep;
 
diff --git a/mm/slab.h b/mm/slab.h
index 7b608719799763cb075c35e3fd4c15e0fddc18b5..c63b8699cfa3d853c63de16b9162df09f3cc72db 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -128,10 +128,11 @@ static inline unsigned long kmem_cache_flags(unsigned long object_size,
 
 #if defined(CONFIG_SLAB)
 #define SLAB_CACHE_FLAGS (SLAB_MEM_SPREAD | SLAB_NOLEAKTRACE | \
-			  SLAB_RECLAIM_ACCOUNT | SLAB_TEMPORARY | SLAB_NOTRACK)
+			  SLAB_RECLAIM_ACCOUNT | SLAB_TEMPORARY | \
+			  SLAB_NOTRACK | SLAB_ACCOUNT)
 #elif defined(CONFIG_SLUB)
 #define SLAB_CACHE_FLAGS (SLAB_NOLEAKTRACE | SLAB_RECLAIM_ACCOUNT | \
-			  SLAB_TEMPORARY | SLAB_NOTRACK)
+			  SLAB_TEMPORARY | SLAB_NOTRACK | SLAB_ACCOUNT)
 #else
 #define SLAB_CACHE_FLAGS (0)
 #endif
diff --git a/mm/slab_common.c b/mm/slab_common.c
index 3c6a86b4ec25f8462c1584dcb5bcf01e4edbd4ff..e016178063e19e86c21a46ea4a39ced03b0fe001 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -37,7 +37,8 @@ struct kmem_cache *kmem_cache;
 		SLAB_TRACE | SLAB_DESTROY_BY_RCU | SLAB_NOLEAKTRACE | \
 		SLAB_FAILSLAB)
 
-#define SLAB_MERGE_SAME (SLAB_RECLAIM_ACCOUNT | SLAB_CACHE_DMA | SLAB_NOTRACK)
+#define SLAB_MERGE_SAME (SLAB_RECLAIM_ACCOUNT | SLAB_CACHE_DMA | \
+			 SLAB_NOTRACK | SLAB_ACCOUNT)
 
 /*
  * Merge control. If this is set then no merging of slab caches will occur.
diff --git a/mm/slub.c b/mm/slub.c
index 46997517406ede0e987388763c9fa1d07875c1db..2d0e610d195ae908586953126f0f125aebf94574 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -5362,6 +5362,8 @@ static char *create_unique_id(struct kmem_cache *s)
 		*p++ = 'F';
 	if (!(s->flags & SLAB_NOTRACK))
 		*p++ = 't';
+	if (s->flags & SLAB_ACCOUNT)
+		*p++ = 'A';
 	if (p != name + 1)
 		*p++ = '-';
 	p += sprintf(p, "%07d", s->size);