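Background: this issue came up during earlier work, described in "Out-of-memory (OOM) errors when training with TensorFlow in a for loop".

Problem description

I used matplotlib.pyplot to plot inside a thread created with threading.Thread, and got this warning: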
UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail

The official Matplotlib documentation says the following:

Working with threads:
Matplotlib is not thread-safe: in fact, there are known race conditions that affect certain artists. Hence, if you work with threads, it is your responsibility to set up the proper locks to serialize access to Matplotlib artists.
Note that (for the case where you are working with an interactive backend) most GUI backends require being run from the main thread as well.

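This does not mean that plt cannot be used in a child thread; the warning only says that doing so is unsafe.
In practice, plotting with plt in a child thread created by Thread did work for me. (Given the warning, it may well fail in other setups; I just have not run into a failure myself.)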
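
One way to keep GUI backends out of a worker thread entirely, which is not one of the methods covered in this post, is to build a matplotlib.figure.Figure directly instead of going through pyplot, and to render it to an image. The following is a minimal sketch, assuming that a PNG held in memory is enough and no window is needed; render_plot and the sample values are made up for illustration.

```python
import io
import threading

from matplotlib.figure import Figure


def render_plot(values, out):
    """Draw values on a private Figure and store the PNG bytes in out."""
    fig = Figure(figsize=(4, 3))   # built without pyplot, so no GUI window is started
    ax = fig.subplots()
    ax.plot(values, label="accuracy")
    ax.legend()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")  # render straight into memory
    out["png"] = buf.getvalue()


result = {}
worker = threading.Thread(target=render_plot, args=([0.5, 0.7, 0.8], result))
worker.start()
worker.join()
print(len(result["png"]), "bytes of PNG data")
```

Each thread works on its own Figure here, so no Matplotlib artist is shared between threads. If several threads had to touch the same figure, they would still need a lock, as the documentation quoted above says.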

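To get rid of the warning, use one of the methods below.

Solutions:

Method 1:

Store the data that plt needs in a global variable, wait for the child thread created by Thread to finish, and then do the plotting in the main thread.

The code plots the results obtained from model training.
The relevant parts are: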

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import threading
import time

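# distill_history holds the data we plot, so define it at module (global) level,
# right after the imports, much like a global variable in C.
# Defining it here guarantees the name exists before main_loop() assigns it;
# reading a global that was never assigned raises NameError.
distill_history = []
...
# Relevant part of main_loop():
def main_loop(alpha, T):
    ...
    global distill_history
    # distill teacher to student
    distill_history = distiller.fit(train_images, train_labels, epochs=20, validation_data=(test_images, test_labels))
    ...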

# Relevant part of draw_distill():
def draw_distill():
    # Evaluate the distilled student model
    plt.figure(figsize=(16, 8))
    # the accuracy plot
    plt.subplot(1, 2, 1)
    plt.plot(distill_history.history['accuracy'], label='accuracy')
    plt.plot(distill_history.history['val_accuracy'], label='val_accuracy')
    ...

if __name__ == '__main__':
    ...
    # Loop over the hyperparameters to compare the distillation results
    for alpha in (0.1, 0.2, 0.3):
        for T in range(5, 21, 5):
            print(time.strftime("%Hh-%Mm-%Ss: "))
            t = threading.Thread(target=main_loop, args=(alpha, T))
            t.start()
            # join() blocks the calling (main) thread until thread t has finished
            t.join()
            # Now plot in the main thread with draw_distill()
            draw_distill()
            ...
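
The same "compute in the worker, plot in the main thread" pattern can also be written without a global variable by handing results back through a queue.Queue. This is only a sketch: train and draw are hypothetical stand-ins for main_loop() and draw_distill(), and the dictionary stands in for the real training history.

```python
import queue
import threading


def train(alpha, T, results):
    """Hypothetical stand-in for main_loop(): produce a result and hand it back."""
    history = {"alpha": alpha, "T": T}
    results.put(history)  # Queue.put is safe to call from any thread


def draw(history):
    """Stand-in for draw_distill(); the real plotting code would run here."""
    return f"alpha={history['alpha']}, T={history['T']}"


results = queue.Queue()
summaries = []
for alpha in (0.1, 0.2):
    for T in (5, 10):
        t = threading.Thread(target=train, args=(alpha, T, results))
        t.start()
        t.join()                                # wait for the worker to finish
        summaries.append(draw(results.get()))   # consume its result in the main thread

print(summaries)
```

Because draw() is only ever called from the main thread, the plotting code it would contain never runs in a worker thread.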
                       
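Method 2:

Use the multiprocessing.Process class, i.e. run the work in a child process instead. (A process created with Process has its own memory space, separate from the main process, so it cannot simply share the main process's variables; everything it needs has to be passed in through args.)

For the differences between multiprocessing.Process and threading.Thread and how to use them, see my other article: "The differences between Python's multiprocessing.Process and threading.Thread, and some ways of using multiple processes and threads".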

# Relevant code using the Process class:
from multiprocessing import Process

if __name__ == '__main__':
    # Load the dataset
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
    # Normalize pixel values
    train_images, test_images = train_images / 255.0, test_images / 255.0

    teacher = build_model('teacher', 32, 64, 64, 64)
    # load teacher model from SavedModel
    teacher = keras.models.load_model('teacher_model')

    for alpha in (0.1, 0.2, 0.3):
        for T in range(5, 21, 5):
            print(time.strftime('%H-%M-%S: '))
            # With Process, every variable that main_loop needs must be passed in through args,
            # which makes the call verbose. Training in main_loop uses teacher, train_images,
            # train_labels and many other variables; leaving one out raises "xxx is not defined".
            p = Process(target=main_loop, args=(alpha, T, teacher, train_images, train_labels, test_images, test_labels))
            # The parameter list of main_loop has to be changed to match.
            # The plt plotting code can go directly inside main_loop: it then runs in the
            # child process's main thread, so the warning in the title does not appear.
            p.start()
            # join() blocks the calling (main) process until process p has finished
            p.join()

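
Passing arguments only covers the way into the child process. If the main process also needs a result back, one option is a multiprocessing.Queue. The sketch below assumes the result is small and picklable; child_task is a hypothetical stand-in for main_loop(), and the "score" it computes is invented for illustration.

```python
import multiprocessing as mp


def child_task(alpha, T, out_queue):
    """Hypothetical stand-in for main_loop(): runs in the child process."""
    # Everything the child needs arrives through its arguments.
    out_queue.put({"alpha": alpha, "T": T, "score": alpha * T})


if __name__ == "__main__":
    out_queue = mp.Queue()
    p = mp.Process(target=child_task, args=(0.5, 4, out_queue))
    p.start()
    # Read the result before join(), so the child is never stuck
    # waiting for its queued data to be consumed.
    result = out_queue.get(timeout=60)
    p.join()
    print(result)
```

The `if __name__ == "__main__":` guard matters here: with the spawn start method the child re-imports the main module, and unguarded code would try to start processes again.
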
Method 3:

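Subclass Thread, override run(), and add a get_result() method.

According to the official threading documentation, run() is the method representing the thread's activity, and it may be overridden in a subclass.

First, look at the source code of Thread: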

class Thread:
    ...
    def __init__(self, group=None, target=None, name=None,
                 args=(), kwargs=None, *, daemon=None):
        ...
        if kwargs is None:
            kwargs = {}
        self._target = target
        self._args = args
        self._kwargs = kwargs
        ...
    def run(self):
        """Method representing the thread's activity.

        You may override this method in a subclass. The standard run() method
        invokes the callable object passed to the object's constructor as the
        target argument, if any, with sequential and keyword arguments taken
        from the args and kwargs arguments, respectively.
        """
        try:
            if self._target:
                self._target(*self._args, **self._kwargs)
        finally:
            # Avoid a refcycle if the thread is running a function with
            # an argument that has a member that points to the thread.
            del self._target, self._args, self._kwargs
    ...

We modify this by adding a self._result attribute that stores the return value of the target function, and we wrap reading it in a get_result() method.

The modified code:

from threading import Thread

class MyThread(Thread):
    def __init__(self, target=None, args=()):
        super(MyThread, self).__init__()
        self._target = target
        self._args = args
        # Declare _result in __init__ first
        self._result = None

    def run(self):
        try:
            if self._target:
                # This is the modified line: keep the return value
                self._result = self._target(*self._args)
            else:
                print('target is None')
        finally:
            # Avoid a refcycle if the thread is running a function with
            # an argument that has a member that points to the thread.
            del self._target, self._args

    # Added method: get_result()
    def get_result(self):
        # Return the stored value as-is. Testing its truthiness first would
        # turn valid falsy results such as 0 into None.
        return self._result

Test:

def func_a(x, y):
    print('ok')
    return x * y

def func_b(num):
    print(f'num is : {num}')

if __name__ == '__main__':
    for i in range(5):
        t = MyThread(target=func_a, args=(i, i + 1))
        t.start()
        t.join()
        result = t.get_result()
        func_b(result)

Note:

When using method 3, remember to change the target function so that it returns the values you need.
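
For comparison, the standard library already offers a way to get a return value out of a thread without subclassing Thread: concurrent.futures.ThreadPoolExecutor. The sketch below reuses func_a from the test above; it is an alternative shown for reference, not the approach this post takes.

```python
from concurrent.futures import ThreadPoolExecutor


def func_a(x, y):
    return x * y


results = []
with ThreadPoolExecutor(max_workers=1) as pool:
    for i in range(5):
        future = pool.submit(func_a, i, i + 1)  # runs func_a in a worker thread
        results.append(future.result())          # blocks until done, then returns the value

print(results)  # [0, 2, 6, 12, 20]
```

The worker here is still not the main thread, so any plt plotting should stay in the main thread, after future.result() has returned.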

References

Matplotlib official documentation
Official threading documentation
