[Devel] [PATCH RHEL7 COMMIT] oom: fix NULL ptr deref on oom if memory cgroup is disabled

Konstantin Khorenko khorenko at virtuozzo.com
Fri Apr 29 09:28:22 PDT 2016


The commit is pushed to "branch-rh7-3.10.0-327.10.1.vz7.12.x-ovz" and will appear at https://src.openvz.org/scm/ovz/vzkernel.git
after rh7-3.10.0-327.10.1.vz7.12.15
------>
commit d5e7fba013768ed994b7af329f466b40a72ffe4e
Author: Vladimir Davydov <vdavydov at virtuozzo.com>
Date:   Fri Apr 29 20:28:22 2016 +0400

    oom: fix NULL ptr deref on oom if memory cgroup is disabled
    
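    When the memory cgroup controller is disabled, mem_cgroup_iter and
    try_get_mem_cgroup_from_mm return NULL. Handle this properly: fall back
    to the global OOM context and skip mem_cgroup_get/put on a NULL memcg.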
    
    https://jira.sw.ru/browse/PSBM-43328
    
    Signed-off-by: Vladimir Davydov <vdavydov at virtuozzo.com>
---
 include/linux/memcontrol.h |  5 +++--
 mm/memcontrol.c            |  4 +++-
 mm/oom_kill.c              | 20 +++++++++++---------
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 27b3c56..1427692 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -31,6 +31,8 @@ struct mm_struct;
 struct kmem_cache;
 struct oom_context;
 
+extern struct oom_context global_oom_ctx;
+
 /* Stats that can be updated by kernel. */
 enum mem_cgroup_page_stat_item {
 	MEMCG_NR_FILE_MAPPED, /* # of pages charged as file rss */
@@ -392,8 +394,7 @@ mem_cgroup_update_lru_size(struct lruvec *lruvec, enum lru_list lru,
 static inline struct oom_context *
 mem_cgroup_oom_context(struct mem_cgroup *memcg)
 {
-	extern struct oom_context oom_ctx;
-	return &oom_ctx;
+	return &global_oom_ctx;
 }
 
 static inline unsigned long mem_cgroup_overdraft(struct mem_cgroup *memcg)
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 82204b3..7061864 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1699,6 +1699,8 @@ void mem_cgroup_note_oom_kill(struct mem_cgroup *root_memcg,
 
 struct oom_context *mem_cgroup_oom_context(struct mem_cgroup *memcg)
 {
+	if (mem_cgroup_disabled())
+		return &global_oom_ctx;
 	if (!memcg)
 		memcg = root_mem_cgroup;
 	return &memcg->oom_ctx;
@@ -1708,7 +1710,7 @@ unsigned long mem_cgroup_overdraft(struct mem_cgroup *memcg)
 {
 	unsigned long long guarantee, usage;
 
-	if (mem_cgroup_is_root(memcg))
+	if (mem_cgroup_disabled() || mem_cgroup_is_root(memcg))
 		return 0;
 
 	guarantee = ACCESS_ONCE(memcg->oom_guarantee);
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 2d0fcac..f9a8e62 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -51,12 +51,10 @@ static DEFINE_SPINLOCK(oom_context_lock);
 #define OOM_BASE_RAGE	-10
 #define OOM_MAX_RAGE	20
 
-#ifndef CONFIG_MEMCG
-struct oom_context oom_ctx = {
+struct oom_context global_oom_ctx = {
 	.rage		= OOM_BASE_RAGE,
-	.waitq		= __WAIT_QUEUE_HEAD_INITIALIZER(oom_ctx.waitq),
+	.waitq		= __WAIT_QUEUE_HEAD_INITIALIZER(global_oom_ctx.waitq),
 };
-#endif
 
 void init_oom_context(struct oom_context *ctx)
 {
@@ -187,7 +185,8 @@ static unsigned long mm_overdraft(struct mm_struct *mm)
 	memcg = try_get_mem_cgroup_from_mm(mm);
 	ctx = mem_cgroup_oom_context(memcg);
 	overdraft = ctx->overdraft;
-	mem_cgroup_put(memcg);
+	if (memcg)
+		mem_cgroup_put(memcg);
 
 	return overdraft;
 }
@@ -485,7 +484,8 @@ void mark_oom_victim(struct task_struct *tsk)
 		ctx->marked = true;
 	}
 	spin_unlock(&oom_context_lock);
-	mem_cgroup_put(memcg);
+	if (memcg)
+		mem_cgroup_put(memcg);
 }
 
 /**
@@ -596,7 +596,7 @@ bool oom_trylock(struct mem_cgroup *memcg)
 		 * information will be used in oom_badness.
 		 */
 		ctx->overdraft = mem_cgroup_overdraft(iter);
-		parent = parent_mem_cgroup(iter);
+		parent = iter ? parent_mem_cgroup(iter) : NULL;
 		if (parent && iter != memcg)
 			ctx->overdraft = max(ctx->overdraft,
 				mem_cgroup_oom_context(parent)->overdraft);
@@ -633,7 +633,8 @@ void oom_unlock(struct mem_cgroup *memcg)
 			 * on it for the victim to exit below.
 			 */
 			victim_memcg = iter;
-			mem_cgroup_get(iter);
+			if (iter)
+				mem_cgroup_get(iter);
 
 			mem_cgroup_iter_break(memcg, iter);
 			break;
@@ -683,7 +684,8 @@ void oom_unlock(struct mem_cgroup *memcg)
 	 */
 	ctx = mem_cgroup_oom_context(victim_memcg);
 	__wait_oom_context(ctx);
-	mem_cgroup_put(victim_memcg);
+	if (victim_memcg)
+		mem_cgroup_put(victim_memcg);
 }
 
 /*


More information about the Devel mailing list