رفتن به نوشته‌ها

چگونه از tf.data در تنسورفلو استفاده کنیم؟

یکی از کارهایی که تقریبا تو همه تسک‌های ماشین لرنینگ و دیپ‌ لرنینگ انجام میشه خوراندن (feed) داده به مدل برای آموزش است. از اونجایی که با گسترش مدل‌های دیپ‌لرنیگ حجم دادگان هم افزایش پیدا کرده است، در نتیجه خوراندن بهینه (efficient) داده‌ها به مدل یکی از کارهایی مهمی است که انجام می‌شود. خوراندن باید به نحوی باشد که GPU تا حد امکان مشغول نگه داریم و نذاریم معطل کارهای مرتبط با IO باشد.

در تنسورفلو به طور کلی ۴ روش برای خوراندن داده به مدل داریم:

  • استفاده از tf.data
  • استفاده از placeholder و feed کردن به مدل در هر epoch
  • استفاده از QueueRunner که شما یک صف درست میکنید و به ابتدای گراف‌ تنسورفلو اضافه میکنید (روشی نسبتا پیچیده و دشوار)
  • استفاده از Preloaded data: تمام داده‌هاتون رو در قالب ثابت (constant) و متغیر (variable) نگه میدارید (این روش تنها برای دادگان کوچک عملی است)

تا چند ماه پیش من اغلب از روش دوم استفاده میکردم یعنی استفاده از placeholder و feed کردن داده‌ها در هر epoch و در اکثر آموزش‌ها و کدهایی که تو جاهای مختلف میدیدم از این روش استفاده میکردند که این روش طبق گفته وبسایت رسمی تنسورفلو کندترین روش برای خواندن داده‌هاست:

و کاری که انجام میدادم به صورت زیر بود:

ابتدا placeholderها رو تعریف میکردم (مثلا در مثال بالا برای داده‌های mnist) بعد یکی یکی batchها رو با استفاده از یک کلاس به نام DataHelper میخوندم و به مدل میدادم و با کمک اون مدل رو آموزش میدادم. این روش بسیار ساده است ولی مشکلی که داره یه تیکه کدتون (در مثال بالا dataHelper) در پایتون پیاده‌سازی شده است و مابقی کد در تنسورفلو. و همین موضوع باعث میشه که گلوگاه (bottleneck) دیتای ورودی برای مدل اتون ایجاد بشه و حداکثر استفاده رو نتونید از GPU ببرید.

اخیرا در یکی از کدهایی که داشتم میخوندم دیدم از tf.data استفاده شده است. که بنظرم یکی از جذاب‌ترین چیزهایی بوده که در تنسورفلو دیدم. اگر اشتباه نکنم این ویژگی در تنسورفلو ۱.۲ یا ۱.۴ به کد اصلی تنسورفلو اضافه شده است و در حال حاضر هم روی پکیج tf.contrib.data ویژگی‌های جالبتری در حال اضافه شدن به آن است و به زودی به کد اصلی تنسورفلو اضافه خواهد شد. علاوه بر بهینه بودن این روش، کار کردن با اون بسیار راحت است و دیگه درگیر placeholder و feed کردن نمیشید و از همه مهمتر خوندن فایل شما مثل یک گره (node) به گراف تنسورفلو شما اضافه خواهد شد و تنسورفلو خودش خواندن رو بهینه میکنه تا GPU شما معطل خواندن داده نشه.

خوب بیایید با یک مثال ساده شروع کنیم. فرض کنید یک فایل متنی با نام file.txt و محتوای زیر داریم و میخواییم اون رو بخونیم:

این یک متن تصادفی است.

یادگیری عمیق جذاب است.

یادگیری عمیق زیر مجموعه یادگیری ماشین محسوب میشود.

خوب بیاید فایل متنی بالا رو با استفاده از tf.data بخوانیم و یکی یکی آنها را چاپ کنیم:

نتیجه که میگیرید به صورت زیر خواهد بود:

چی شد؟ اکسپشن خوردیم! قرار بود کارمون رو راحت کنه که !!! واضح است که اررور میگه TextLineDataset قابل پیمایش نیست (فعلا با مود eager کار نداریم). بیایید یکم بررسی اش کنیم:

اگر نوع TextLineDataset رو چاپ کنیم مشاهده میکنیم که:

که اگر دقت کنید TextLineDataset جز تنسورفلو است. پس باید به گراف تنسورفلو اضافه بشه و عملیات خوندن از فایل رو در گراف تنسورفلو انجام بده. یعنی همونجوری که میومدیم عملیات (ضرب، جمع و تفریق) تعریف میکردیم و به گراف تنسورفلو اضافه‌اش میکردیم اینجا هم باید همچین کاری انجام بدهیم! یه گره به ابتدای گراف محاسباتی تنسورفلو اضافه کنیم که کارهای خوندن از فایل رو برای ما انجام بدهد. خوب بیاید فایل بالا رو پیمایش کنیم.

برای اینکه بتونیم فایل رو پیمایش کنیم باید یک Iterator تعریف کنیم. در حال حاضر در تنسورفلو دو نوع Iterator در tf.data وجود دارد:

  • Dataset.make_one_shot_iterator
  • Dataset.make_initializable_iterator

اولی برای مواقعی است که تنها میخواهیم یک بار داده‌های ورودی را بخوانیم و دومی برای زمانی است که میخواهیم چندین بار روی آنها پیمایش کنیم که ما از مورد دوم اغلب استفاده میکنیم ولی مورد اول ساده‌تر است. اگر فایل بالا رو با استفاده از تابع اولی پیمایش کنیم خواهیم داشت:

خیلی راحته نه؟ دیگه نه میخواهیم placeholder تعریف کنیم نه feeddata! همچنین داریم از sess.run برای گرفتن دیتا استفاده میکنیم!!! یعنی اینکه دیتای ورودی یه جورایی بخشی از گراف شما شما شده است!

اگر print آخر رو کامنت کنید به اکسپشن خواهید خورد چون make_one_shot_iterator تنها یکبار مجاز به پیمایش داده‌ها هستیم. پس برای رفع این موضوع از make_initializable_iterator استفاده خواهیم کرد:

در تابع make_initializable_iterator هر موقع که پیمایش تمام شد باید از دوباره آن را initialize کنیم. خب همانطور که میبینید مشکل مون حل شد :).

خوب حالا بیایید در یک مثال واقعی ببینیم که چطور میشه ازش استفاده کرد مثلا فرض کنید میخواهیم یک دسته‌بندی کننده روی داده‌های Mnist درست کنیم و داده‌های ورودی رو با استفاده از tf.data به آن بدیم:

در ابتدا باید داده‌های Mnist رو بگیریم. برای اینکار من از دیتاست‌های keras استفاده میکنم که خیلی راحت میشه ازش استفاده کرد. همچین نوع داده‌ای mnist در کراس int هست که آنها را به float32 تبدیل میکنم. مورد آخر اینکه داده‌های کراس sparse هست و من برای راحتی آنها را به categorical تبدیل میکنم تا به به این ترتیب برچسب هر عکس به صورت one-hot تبدیل شود که هر کدام ۱ بود نشان میدهد آن برچسب مورد نظر است.

در مرحله بعد باید این داده‌ها رو به تنسورفلو و tf.data بدهیم. چون میخوام ارزیابی هم انجام بدم دوتا دیتاست درست کردم یکی برای تست و یکی برای آموزش:

همچنین یک سری پارامتر هم به دیتاست اضافه کردم که مثلا دادگان رو Shuffle کنه و اندازه batchها چقدر باشه (برای مشاهده سایر پارامترها مثل prefetch، repeat و… به اینجا سر بزنید)

اگر در این مرحله مقادیر X و Y رو چاپ کنیم، خواهیم داشت:

print(X,Y)
Tensor(“IteratorGetNext:0”, shape=(?, 784), dtype=float32)
Tensor(“IteratorGetNext:1”, shape=(?, 10), dtype=float32)

که همانطور که ملاحضه میکنید هر دو از جنس تنسور هستند.

در مرحله بعدی باید مدل تنسورفلو رو تعریف کنیم. در ساده‌ترین حالت یک مدل logistic regression تعریف میکنیم.

همانطور که دقت میکنید دیگه خبری از placeholder نیست و مستقیم داده‌های ورودی‌مون رو از iterator داریم میگیریم.

 

برای آموزش و تست مدل هم خواهیم داشت:

که در اینجا هم دیگه خبری از feeddata نیست :). همانطور که مشاهده میکنید برای اینکه چندبار پیمایش کنم هر دفعه که اکسپشن خوردم از دوباره initialize کردم و به طور کلی کدشو میتونیم اینجوری داشته باشیم:

 

اگر گراف بالا رو با استفاده از تنسوربورد نمایش بدید:

tensorboard –logdir=../graphs/logreg

 

گراف مدل به صورت زیر خواهد بود:

که همانطور که مشاهده میکنید Iteratorها هرکدام یک گره در گراف محاسباتی تنسورفلو هستند. چون در ابتدا ما دوتا Dataset مختلف (یکی برای آموزش و یکی برای آزمون) تعریف کردیم‌، دو گره در گراف شکل گرفته است.

خوب فکر کنم تا همینجا بس باشه و اگر خودتون دوس داشتید میتونید به سایت تنسورفلو سر بزنید و بیشتر راجع به tf.data و کاربردهاش مطالعه کنید. امیدوارم که توضیحاتم کافی بوده باشه 🙂 در قسمت بعد به سراغ این میرم که چطور میشه مدل‌های مختلف رو به صورت شی‌گرا (OOP) پیاده سازی کرد.

 

منابع بیشتر برای مطالعه:

https://www.tensorflow.org/guide/datasets

https://www.tensorflow.org/api_docs/python/tf/data/Iterator

https://cs230-stanford.github.io/tensorflow-input-data.html

 

منتشر شده در تنسورفلو

اولین باشید که نظر می دهید

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *